Skip to content

Conversation

rwightman
Copy link
Collaborator

Refactoring patch embed resampling, and adding grid sampling pos embed alternative...

@rwightman rwightman mentioned this pull request Jun 13, 2025
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

padding_mode='border',
).to(dtype=x.dtype) # (B, C, H_out, W_out)

# NOTE if we bring in patch_valid, can explicitly mask padding tokens
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I’m fully getting this part of the code, but if the goal is to explicitly mask padded tokens, can’t you just write like this:

x += pos_embed[bi, :, patch_coord[..., 0], patch_coord[..., 1]] * patch_valid[..., None]

Though, IMHO it shouldn’t really make a difference, as these tokens will be masked in the attention layers anyway.

Copy link
Collaborator Author

@rwightman rwightman Jun 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, could be done that way, I was also going to explore whether the different indexing approach had any impact on throughput (in combo with masking having any impact). The masking shouldn't make a difference due to the attention and pooling masks...

doubt there is much point in keeping the more complicated indexing or the masking, validation didn't show any noteworthy throughput difference, but thought i might check at least one train run

@stas-sl
Copy link

stas-sl commented Jun 18, 2025

Would you also be interested in having a similar implementation using grid_sample with factorized embeddings? No pressure at all - just thought it might be nice for completeness, since you already have this approach with the learned embeddings. In fact, for 1D interpolation, as there is no antialias=True the results for F.interpolate and F.grid_sample are very close numerically.

It might look like this:

def _apply_factorized_naflex_pos_embed_grid_sample(self, x, patch_coord):
    B, _, C = x.shape[0]
    shapes = patch_coord.amax(dim=1) + 1
    pe_x = rearrange('1 w c -> b c 1 w', self.pos_embed_x, b=B)
    pe_y = rearrange('1 h c -> b c 1 h', self.pos_embed_y, b=B)
    grid_size = shapes.amax(0)
    theta_x = torch.zeros(B, 2, 3, device=x.device)
    theta_x[:, 0, 0] = grid_size[1] / shapes[:, 1]
    theta_x[:, 0, 2] = theta_x[:, 0, 0] - 1
    theta_x[:, 1, 1] = 1
    theta_x[:, 1, 2] = 0
    theta_y = torch.zeros(B, 2, 3, device=x.device)
    theta_y[:, 0, 0] = grid_size[0] / shapes[:, 0]
    theta_y[:, 0, 2] = theta_y[:, 0, 0] - 1
    theta_y[:, 1, 1] = 1
    theta_y[:, 1, 2] = 0
    grid_x = F.affine_grid(theta_x, (B, C, 1, grid_size[1]), align_corners=False)
    grid_y = F.affine_grid(theta_y, (B, C, 1, grid_size[0]), align_corners=False)
    pe_x = F.grid_sample(pe_x, grid_x, mode=self.pos_embed_interp_mode, align_corners=False, padding_mode='border')
    pe_y = F.grid_sample(pe_y, grid_y, mode=self.pos_embed_interp_mode, align_corners=False, padding_mode='border')
    bi = torch.arange(B, device=x.device)[:, None]
    x += pe_x[bi, :, 0, patch_coord[..., 1]] + pe_y[bi, :, 0, patch_coord[..., 0]]

@rwightman
Copy link
Collaborator Author

rwightman commented Jun 19, 2025

@stas-sl yeah, thanks! I was going to look at the 1d factorized impl as figured it wouldn't much extra work once the 2d is working well.

And yeah, should be less difference. Though worth noting, flipping between grid_sample and interpolate wasn't very different (especially with 'border' for padding) for the models I had pretrained in pytorch with the 2d learned code while testing naflexvit impl. I need to revisit the siglip-2 weights.

…nsistent, add factorized grid_sample and fixed grid size methods.
@rwightman
Copy link
Collaborator Author

rwightman commented Jun 19, 2025

@stas-sl okay, had some time to go through and clean this up a bit. Added the 1d factorized + grid_sample, and also the missing fixed grid for factorized. Cleaned up consistency of the interface, no more calculation of grid size arrays when using grid_sample (removes a graph break in compiled mode).

@stas-sl
Copy link

stas-sl commented Jun 19, 2025

@rwightman, awesome! Thanks for taking the time - everything looks good to me now 👍

@rwightman rwightman merged commit 996c149 into main Jun 19, 2025
26 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants