Skip to content

Instantly share code, notes, and snippets.

@dmelcer9
Created December 19, 2024 20:03
Show Gist options
  • Save dmelcer9/286d5fd674111380107b321998313f34 to your computer and use it in GitHub Desktop.
Save dmelcer9/286d5fd674111380107b321998313f34 to your computer and use it in GitHub Desktop.
A nice multi-level indexing function on tensors
def nice_index(tens: Tensor, index: Tensor) -> Tensor:
    """
    A multi-level indexing function.

    The last dimension of ``index`` is a multi-level specifier: a specifier
    ``[i_0, ..., i_{k-1}]`` selects ``tens[i_0, ..., i_{k-1}]``.
    All other dimensions of ``index`` are batch dimensions.

    For example, if the index is [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    the result will be [[tens[1, 2], tens[3, 4]], [tens[5, 6], tens[7, 8]]]

    A specifier of length 0 selects all of ``tens``, so the result is ``tens``
    repeated over the batch shape.

    Args:
        tens: Tensor to index into.
        index: Tensor with at least one dimension. Its values are converted
            with ``.long()`` before indexing (so e.g. a float
            ``torch.as_tensor([])`` works as an empty specifier).

    Returns:
        Tensor of shape ``index.shape[:-1] + tens.shape[k:]``, where
        ``k = index.shape[-1]``.

    Raises:
        IndexError: if ``index`` is 0-dimensional, or if ``k`` exceeds the
            number of dimensions of ``tens``.

    See tests for more details
    """
    if index.ndim == 0:
        raise IndexError(
            "Index must have at least one dimension (the multi-level specifier)"
        )
    num_levels_of_index = index.shape[-1]
    if num_levels_of_index > tens.ndim:
        # Report ndim, not len(tens): len() is the size of dim 0 and raises
        # TypeError on a 0-d tensor.
        raise IndexError(
            f"Index too long: {num_levels_of_index} levels on a {tens.ndim}-dimensional tensor"
        )

    batch_shape = index.shape[:-1]
    output_shape = batch_shape + tens.shape[num_levels_of_index:]
    # Product of the batch sizes; an empty shape gives 1.
    batch_flat_size = batch_shape.numel()

    if num_levels_of_index == 0:
        # Nothing to select: repeat the whole tensor once per batch element.
        # Tile only the new leading dim -- tiling any other dim would
        # interleave data when reshaping. tile copies, so the result never
        # aliases `tens`.
        result = tens.unsqueeze(0).tile((batch_flat_size,) + (1,) * tens.ndim)
    else:
        flat_batch_index = index.reshape(batch_flat_size, num_levels_of_index)
        # Transpose to (levels, batch) and index with one 1-D tensor per level;
        # advanced indexing then picks tens[i_0[b], ..., i_{k-1}[b]] for every
        # batch element b. Keeping the indices on tens' device avoids a
        # host round-trip and device mismatches.
        level_indices = flat_batch_index.T.long().to(tens.device)
        result = tens[tuple(level_indices)]
    return result.reshape(output_shape)
class TestNiceIndex(unittest.TestCase):
    """Behavioral tests for nice_index, written with plain nested lists."""

    def expect(self, tens, index, output):
        """Assert that nice_index(tens, index), converted to lists, equals output."""
        o = nice_index(torch.as_tensor(tens), torch.as_tensor(index))
        self.assertEqual(o.tolist(), output)

    def test_1d_1d(self):
        self.expect([1, 2, 3, 4], [3], 4)

    def test_1d_2d(self):
        self.expect([1, 2, 3, 4], [[2], [1]], [3, 2])

    def test_1d_3d(self):
        self.expect([0, 1, 2, 3], [[[1], [2]], [[3], [0]]], [[1, 2], [3, 0]])

    def test_1d_0d(self):
        # A zero-length specifier selects the whole tensor for each batch element.
        self.expect([0, 1, 2, 3], [], [0, 1, 2, 3])
        self.expect([0, 1, 2, 3], [[], []], [[0, 1, 2, 3], [0, 1, 2, 3]])

    def test_2d_1d(self):
        self.expect([[0, 1], [2, 3]], [1], [2, 3])
        self.expect([[0, 1], [2, 3]], [1, 1], 3)

    def test_2d_2d(self):
        self.expect([[0, 1], [2, 3]], [[1], [0]], [[2, 3], [0, 1]])
        self.expect([[0, 1], [2, 3]], [[1, 0], [0, 1]], [2, 1])

    def test_2d_3d(self):
        self.expect(
            [[0, 1], [2, 3]], [[[0, 1], [1, 0]], [[1, 1], [0, 0]]], [[1, 2], [3, 0]]
        )

    def test_empty_batch(self):
        # A batch dimension of size 0 yields an empty result.
        self.expect([0, 1, 2, 3], torch.zeros(0, 1, dtype=torch.long), [])

    def test_index_too_long(self):
        # More levels than tensor dimensions must be rejected with IndexError.
        with self.assertRaises(IndexError):
            nice_index(torch.zeros(2), torch.tensor([0, 0]))
@dmelcer9
Copy link
Author

The core indexing comes from https://stackoverflow.com/a/52092603; the rest of the function handles more complex batch shapes.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment