[Gemma3n] Fix audio batching #24052
Conversation
```python
            input_features_mask=MultiModalFieldConfig.batched("audio"))
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            input_features=MultiModalFieldConfig.batched("audio"),
```
Do we still need `input_features` in that case?
I definitely want to review that once I enable the processor test that requires an HF Transformers bump. For now there's no real overhead at runtime because the unpadded tensor is just a view.
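For illustration, a small hedged PyTorch sketch of what "just a view" means here (the shapes are made up; it only demonstrates that slicing a padded tensor shares storage rather than copying):

```python
import torch

# Assumed shapes for illustration: (batch_size, max_seq_len, num_features).
padded = torch.zeros(2, 160, 128)

# Slicing out the unpadded portion of the first item returns a view that
# shares storage with the padded tensor, so no extra copy is made.
unpadded_item = padded[0, :97]
assert unpadded_item.data_ptr() == padded.data_ptr()  # same underlying storage

# Writing through the view is visible in the padded tensor (no copy was made).
unpadded_item[0, 0] = 1.0
assert padded[0, 0, 0].item() == 1.0
```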
Code Review

This pull request fixes an issue with batched audio inference for Gemma3n models by padding audio sequences. The core logic introduces a padded version of `input_features` for batched processing by the audio tower, while keeping an unpadded version for caching. The changes are generally good, but I've identified a critical issue with a `.squeeze(1)` call that will likely cause a crash, and a high-severity issue with an incorrect type hint.
```diff
 assert self.audio_tower is not None
-input_features = audio_input["input_features"].squeeze(1)
+# Run on padded features to enable batching
+input_features = audio_input["input_features_padded"].squeeze(1)
```
The use of `.squeeze(1)` here is likely incorrect and will cause a runtime error. `input_features_padded` is expected to have a shape of `(batch_size, seq_length, num_features)`. Calling `.squeeze(1)` will only succeed if `seq_length` is 1, which is not generally the case for audio features. This seems to be a pre-existing issue, but since this line is modified, it's important to address it. The `.squeeze(1)` should probably be removed.
Suggested change:

```diff
-input_features = audio_input["input_features_padded"].squeeze(1)
+input_features = audio_input["input_features_padded"]
```
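For context, a minimal PyTorch sketch of the tensor layout this comment describes and the effect of the suggested change (the shapes and the toy "audio tower" below are illustrative assumptions, not the actual Gemma3n modules):

```python
import torch

# Assumed illustrative shape: (batch_size, seq_length, num_features).
batch_size, seq_length, num_features = 2, 160, 128
input_features_padded = torch.randn(batch_size, seq_length, num_features)

# Stand-in for the audio tower: any module that consumes a batched
# (batch_size, seq_length, num_features) tensor works for this sketch.
audio_tower = torch.nn.Linear(num_features, 64)

# Per the suggestion, the padded batch is passed through as-is, keeping the
# (batch_size, seq_length, num_features) layout intact.
input_features = input_features_padded
audio_outputs = audio_tower(input_features)
print(audio_outputs.shape)  # torch.Size([2, 160, 64])
```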
Thanks for the fix @NickLucche!
LGTM as long as tests still pass
@DarkLight1337 looks green
Fix #24006 by enabling proper batched `audio_tower` inference. This is done by padding the audio sequences to the max sequence length in the batch.

Thanks to @pratapyash for reporting the bug!
Test with
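Below is a hedged, self-contained PyTorch sketch of the padding approach described above (the sequence lengths, feature size, and use of `pad_sequence` are illustrative assumptions, not the actual vLLM code path):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Illustrative per-request feature tensors of shape (seq_len, num_features);
# the lengths and feature size are made up for this example.
features = [torch.randn(97, 128), torch.randn(160, 128), torch.randn(42, 128)]

# Pad every sequence to the max seq len so they stack into a single batch of
# shape (batch_size, max_seq_len, num_features) for the audio tower.
input_features_padded = pad_sequence(features, batch_first=True)

# A mask marking real (non-padded) frames, analogous in spirit to the
# input_features_mask field seen in the diff above.
lengths = torch.tensor([f.shape[0] for f in features])
input_features_mask = (
    torch.arange(input_features_padded.shape[1])[None, :] < lengths[:, None]
)

print(input_features_padded.shape)   # torch.Size([3, 160, 128])
print(input_features_mask.sum(dim=1))  # tensor([ 97, 160,  42])
```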