```py import thunder import torch def fn(a, idx): return a[idx] a = torch.randn(5, 5, 7, requires_grad=True) idx = (None, [1, 2]) print(fn(a, idx).shape) # torch.Size([1, 2, 5, 7]) jfn = thunder.jit(fn) print(jfn(a, idx).shape) # torch.Size([2, 1, 5, 7]) ```