Generalized torch.gather
import torch
import re
from einops import rearrange, repeat
import copy
import numpy as np

def isolate_patterns(p):
    """Split a pattern like "m n[m k l] d" into the bracketed indices part,
    the non-bracketed tensor part, and the name of the gathered axis."""
    snippets = re.split(r"\[|\]", p)
    inside_brackets = snippets[1].strip()
    outside_brackets = " ".join([snippets[0].strip(), snippets[2].strip()])
    # The axis name immediately preceding "[" is the gathered axis
    # (this also works when it is the first token in the pattern).
    n_match = re.search(r"(\w+)\s*\[", p)
    select_axis = n_match.group(1) if n_match else ""
    return inside_brackets, outside_brackets, select_axis
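
# Illustrative example of the helper above:
#   isolate_patterns("m n[m k l] d") -> ("m k l", "m n d", "n")
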
def reshape_patterns(tensor_pattern, indices_pattern, select_axis):
    """Build the internal axis order: the axes shared between tensor and
    indices come first, then the gathered axis, then the remaining axes."""
    parsed_tensor_pattern = tensor_pattern.split()
    parsed_indices_pattern = indices_pattern.split()
    # Keep the tensor-pattern order; a raw set intersection would yield an
    # arbitrary ordering of the common axes.
    common_axis = [ax for ax in parsed_tensor_pattern if ax in parsed_indices_pattern]
    new_tensor_pattern = copy.copy(common_axis)
    new_tensor_pattern.append(select_axis)
    for ax in parsed_tensor_pattern:
        if ax not in common_axis and ax != select_axis:
            new_tensor_pattern.append(ax)
    new_tensor_pattern = " ".join(new_tensor_pattern)
    return new_tensor_pattern, common_axis
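
# Illustrative example (hypothetical inputs): the tensor axes are reordered
# so that the common axis "m" leads and the gathered axis "n" follows it:
#   reshape_patterns("m d n", "m k l", select_axis="n") -> ("m n d", ["m"])
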
def get_axis_dim(tensor, pattern):
    """Map each axis name in `pattern` to the corresponding size of `tensor`."""
    ret_dict = {}
    for i, ax in enumerate(pattern.split()):
        ret_dict[ax] = tensor.shape[i]
    return ret_dict
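
# Illustrative example:
#   get_axis_dim(torch.zeros(2, 3), "a b") -> {"a": 2, "b": 3}
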
def gather_general(tensor, indices, pattern):
    """
    An extended version of torch.gather.

    The original torch.gather only supports the case where
    tensor.ndim == indices.ndim; depending on `dim`, it gathers values
    along the dim-th axis. Now suppose we have the case below:
      * tensor  (M, N, D)
      * indices (M, K, L)
      * outputs (M, K, L, D)
    where outputs[m, k, l, :] = tensor[m, indices[m, k, l], :].

    This differs from the original in the following ways:
      * tensor.ndim != indices.ndim
      * axis positions are flexible on both sides
        (in the original, positions are fixed: with dim=2 and both
        ndim=3, output[i][j][k] = input[i][j][indices[i][j][k]])
      * indices may carry extra axes (K and L here); these act like
        batch axes of the output tensor
      * indices may omit some axes of the tensor (D here); the omitted
        axes are kept whole, so we gather "hyperplanes" instead of
        single values

    The output axes are the union of (1) the axes common to tensor and
    indices and (2) the extra axes from indices. The notation follows
    einops:
    ```
    m n[m k l] d -> m k l d
    ```
    The brackets mark the gathered axis and name the axes of `indices`;
    common axes behave as in the original gather, and extra axes are
    appended at the end.
    """
    groups = re.match(r"(.*)->(.*)", pattern)
    _tensor_pattern = groups[1].strip()
    output_pattern = groups[2].strip()
    device = tensor.device
    indices_flat = indices.reshape(-1)
    indices_pattern, tensor_pattern, select_axis = isolate_patterns(_tensor_pattern)
    dim_per_axis_indices = get_axis_dim(indices, indices_pattern)
    new_tensor_pattern, common_axis = reshape_patterns(
        tensor_pattern=tensor_pattern,
        indices_pattern=indices_pattern,
        select_axis=select_axis,
    )
    # Move the common axes to the front, followed by the gathered axis.
    reshaped_tensor = rearrange(tensor, f"{tensor_pattern} -> {new_tensor_pattern}")

    def get_pad_indices(ax):
        # Broadcast arange over the full indices shape and flatten, giving
        # the per-element index along axis `ax` for advanced indexing.
        pad_ind = repeat(
            torch.arange(dim_per_axis_indices[ax], device=device),
            f"{ax} -> {indices_pattern}",
            **dim_per_axis_indices,
        ).reshape(-1)
        return pad_ind

    select_indices = [get_pad_indices(ax) for ax in common_axis]
    select_indices += [indices_flat]
    # Advanced indexing expects a tuple of index tensors, not a list.
    selected_tensor = reshaped_tensor[tuple(select_indices)]
    parsed_indices_pattern = indices_pattern.split()
    selected_pattern = f"({indices_pattern}) " + " ".join(
        [ax for ax in new_tensor_pattern.split()
         if ax not in parsed_indices_pattern and ax != select_axis]
    )
    selected_tensor = rearrange(
        selected_tensor,
        f"{selected_pattern} -> {output_pattern}",
        **dim_per_axis_indices,
    )
    return selected_tensor
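
# Minimal usage sketch (hypothetical shapes, runs on CPU):
#   x = torch.randn(4, 5, 6)               # (m, n, d)
#   idx = torch.randint(0, 5, (4, 2, 3))   # (m, k, l); values index into n
#   out = gather_general(x, idx, "m n[m k l] d -> m k l d")
#   # out.shape == (4, 2, 3, 6), with out[m, k, l] == x[m, idx[m, k, l]]
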
def test_gather_general():
    M = 300
    K = 100
    L = 200
    N = 100
    D = 400
    tensor = torch.arange(M * N * D).reshape(M, N, D).cuda()
    indices = torch.randint(0, N, (M, K, L)).cuda()
    # Select tensor[m, indices[m, k, l], :] for every (k, l, m); the
    # result is laid out as (k, l, m, d).
    pattern = "m n[m k l] d -> k l m d"
    ret = gather_general(tensor, indices, pattern=pattern)

    def check_by_index(k, l, m):
        return (ret[k, l, m, :] == tensor[m, indices[m, k, l], :]).all().item()

    # Spot-check 100 random positions against direct indexing.
    for t in range(100):
        k = np.random.randint(0, K)
        l = np.random.randint(0, L)
        m = np.random.randint(0, M)
        check_res = check_by_index(k, l, m)
        if not check_res:
            print(f"Fail at {k, l, m}")
            break
        print(f"Pass at {k, l, m}")


if __name__ == "__main__":
    test_gather_general()