Refactor patch and pos embed resampling based on feedback from https://github.com/stas-sl #2518
Conversation
timm/models/naflexvit.py (Outdated)
```python
    padding_mode='border',
).to(dtype=x.dtype)  # (B, C, H_out, W_out)

# NOTE if we bring in patch_valid, can explicitly mask padding tokens
```
Not sure I’m fully getting this part of the code, but if the goal is to explicitly mask padded tokens, can’t you just write it like this:

```python
x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]] * patch_valid[..., None]
```

Though, IMHO it shouldn’t really make a difference, as these tokens will be masked in the attention layers anyway.
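A minimal, self-contained sketch of that one-liner with toy shapes; the tensor names and sizes here are illustrative assumptions, not the actual naflexvit code:

```python
import torch

B, N, C, H, W = 2, 5, 8, 4, 4
x = torch.zeros(B, N, C)                          # patch tokens
pos_embed = torch.randn(B, C, H, W)               # per-sample resampled pos embed grid
patch_coord = torch.randint(0, 4, (B, N, 2))      # (y, x) coordinate of each patch
patch_valid = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 0]], dtype=torch.bool)

bi = torch.arange(B)[:, None]                     # (B, 1) batch index for broadcasting
# Advanced indices at dims 0, 2, 3 with a slice at dim 1 -> gathered shape is (B, N, C)
pe = pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]]
x += pe * patch_valid[..., None]                  # padded patches receive no pos embed
```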
Yes, it could be done that way. I was also going to explore whether the different indexing approach had any impact on throughput (in combination with whether the masking does). The masking shouldn't make a difference due to the attention and pooling masks...

I doubt there is much point in keeping the more complicated indexing or the masking; validation didn't show any noteworthy throughput difference, but I thought I'd check at least one train run.
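To make the "masked in attention anyway" point concrete, here is a rough sketch (toy shapes and assumed names, not the naflexvit attention code) of how a key-padding mask built from patch_valid keeps padded tokens from influencing valid ones, with masked pooling doing the same at the output:

```python
import torch
import torch.nn.functional as F

B, H, N, D = 2, 4, 5, 16
q = k = v = torch.randn(B, H, N, D)
patch_valid = torch.tensor([[1, 1, 1, 0, 0],
                            [1, 1, 1, 1, 0]], dtype=torch.bool)

# Boolean mask broadcast to (B, heads, N_query, N_key); True = may attend.
# Masking keys means whatever was added to padded tokens is never read by valid queries.
attn_mask = patch_valid[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Masked mean pooling likewise excludes padded tokens from the final representation.
tokens = out.mean(dim=1)                                # collapse heads for this toy example
denom = patch_valid.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (tokens * patch_valid[..., None]).sum(dim=1) / denom
print(pooled.shape)  # torch.Size([2, 16])
```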
Would you also be interested in having a similar implementation using grid_sample for the factorized (1D) pos embeds? It might look like this:

```python
import torch
import torch.nn.functional as F
from einops import repeat


def _apply_factorized_naflex_pos_embed_grid_sample(self, x, patch_coord):
    B, _, C = x.shape
    shapes = patch_coord.amax(dim=1) + 1          # per-sample (h, w) grid sizes
    # pos_embed_x: (1, W, C), pos_embed_y: (1, H, C) -> (B, C, 1, W/H) for grid_sample
    pe_x = repeat(self.pos_embed_x, '1 w c -> b c 1 w', b=B)
    pe_y = repeat(self.pos_embed_y, '1 h c -> b c 1 h', b=B)
    grid_size = shapes.amax(0)                    # max (H, W) over the batch

    # Affine params scale/shift the shared output grid so that sample i only
    # covers its own first shapes[i] positions of the 1D embedding tables.
    theta_x = torch.zeros(B, 2, 3, device=x.device)
    theta_x[:, 0, 0] = grid_size[1] / shapes[:, 1]
    theta_x[:, 0, 2] = theta_x[:, 0, 0] - 1
    theta_x[:, 1, 1] = 1
    theta_x[:, 1, 2] = 0
    theta_y = torch.zeros(B, 2, 3, device=x.device)
    theta_y[:, 0, 0] = grid_size[0] / shapes[:, 0]
    theta_y[:, 0, 2] = theta_y[:, 0, 0] - 1
    theta_y[:, 1, 1] = 1
    theta_y[:, 1, 2] = 0

    grid_x = F.affine_grid(theta_x, (B, C, 1, grid_size[1]), align_corners=False)
    grid_y = F.affine_grid(theta_y, (B, C, 1, grid_size[0]), align_corners=False)
    pe_x = F.grid_sample(pe_x, grid_x, mode=self.pos_embed_interp_mode, align_corners=False, padding_mode='border')
    pe_y = F.grid_sample(pe_y, grid_y, mode=self.pos_embed_interp_mode, align_corners=False, padding_mode='border')

    # Gather the resampled x/y embeddings at each patch's coordinates and add them in place.
    bi = torch.arange(B, device=x.device)[:, None]
    x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]]
```
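The sketch above can be exercised standalone roughly like this, assuming it is defined as a plain module-level function; the SimpleNamespace stand-in for the module, the embedding sizes, and the coordinate ranges are assumptions for illustration only, not timm's actual NaFlexVit interface:

```python
from types import SimpleNamespace
import torch

C = 32
owner = SimpleNamespace(
    pos_embed_y=torch.randn(1, 16, C),   # learned row embeddings, assumed shape (1, H, C)
    pos_embed_x=torch.randn(1, 16, C),   # learned column embeddings, assumed shape (1, W, C)
    pos_embed_interp_mode='bilinear',
)
B, N = 2, 12
x = torch.zeros(B, N, C)
patch_coord = torch.stack([
    torch.randint(0, 3, (B, N)),         # y coordinates
    torch.randint(0, 4, (B, N)),         # x coordinates
], dim=-1)
_apply_factorized_naflex_pos_embed_grid_sample(owner, x, patch_coord)
print(x.abs().sum() > 0)  # tokens now carry summed row + column embeddings
```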
@stas-sl yeah, thanks! I was going to look at the 1d factorized impl as I figured it wouldn't be much extra work once the 2d is working well. And yeah, there should be less of a difference there. Though worth noting, flipping between grid_sample and interpolate wasn't very different (especially with 'border' padding) for the models I had pretrained in pytorch with the 2d learned code while testing the naflexvit impl. I need to revisit the siglip-2 weights.
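For reference, the interpolate-based path being compared might look roughly like the following; the function name, shapes, and the per-sample loop are assumptions for illustration, not the actual timm implementation:

```python
import torch
import torch.nn.functional as F

def add_pos_embed_interpolate(x, patch_coord, pos_embed, interp_mode='bicubic'):
    # x: (B, N, C) tokens, patch_coord: (B, N, 2) as (y, x), pos_embed: (1, C, H, W)
    B, N, C = x.shape
    shapes = patch_coord.amax(dim=1) + 1                  # per-sample (h, w) grid size
    for i in range(B):
        # Python ints are needed for F.interpolate's size argument; under torch.compile
        # this data-dependent conversion is where a graph break tends to show up.
        h, w = int(shapes[i, 0]), int(shapes[i, 1])
        pe = F.interpolate(pos_embed, size=(h, w), mode=interp_mode, align_corners=False)
        # Gather the (y, x) positions used by this sample's patches: result is (N, C).
        x[i] += pe[0, :, patch_coord[i, :, 0], patch_coord[i, :, 1]]
    return x
```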
…nsistent, add factorized grid_sample and fixed grid size methods.
@stas-sl okay, had some time to go through and clean this up a bit. Added the 1d factorized + grid_sample, and also the missing fixed grid for factorized. Cleaned up consistency of the interface; no more calculation of grid size arrays when using grid_sample (removes a graph break in compiled mode).
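On the graph-break note, a tiny illustration (my own toy, not timm code) of the pattern being avoided: converting a data-dependent tensor value to a Python int inside a compiled region forces torch.compile to split the graph, whereas the grid_sample formulation can stay tensor-only.

```python
import torch
import torch.nn.functional as F

def resize_with_python_ints(pos_embed, shapes):
    # tensor -> Python int conversion is data-dependent, so Dynamo breaks the graph here
    h, w = int(shapes[0]), int(shapes[1])
    return F.interpolate(pos_embed, size=(h, w), mode='bilinear', align_corners=False)

# eager backend just to exercise Dynamo tracing without requiring an inductor toolchain
compiled = torch.compile(resize_with_python_ints, backend='eager')
out = compiled(torch.randn(1, 8, 16, 16), torch.tensor([7, 9]))
print(out.shape)  # torch.Size([1, 8, 7, 9])
```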
@rwightman, awesome! Thanks for taking the time - everything looks good to me now 👍
Refactoring patch embed resampling and adding a grid_sample pos embed alternative...