import torch
import itertools
from typing import Tuple

from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._dynamo.source import LocalSource
from torch._prims_common import is_same_shape, make_contiguous_strides_for

MAX_DIM = 3
""" | |
How to model check two meta function implementations? | |
Limit ourselves ONLY to sizes (dtype/device can be checked through exhaustive | |
enumeration / is quite a bit easier to solve for.) | |
General recipe: things that are symbolically represented, use Z3. Things | |
that are specialized (including guards), test combinatorially. | |
Specialized things: | |
- Number of dimensions | |
- 0/1 size dims | |
- ~~Duck sizing~~ turned off | |
Iterate through all specialized things (in particular, pick a size [0, 1, 2]). | |
For each configuration, run with Z3. The guard configuration says "what we've | |
tested". Invert it and ask Z3 for another example that doesn't match this. | |
Keep going (OR'ing together) until nothing less. Move onto next configuration. | |
""" | |

def gen_size():
    dim = MAX_DIM
    for size in itertools.product([1, 2], repeat=dim):
        yield size
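
# For example (added note): with MAX_DIM = 3, gen_size() yields the 2**3 = 8
# size tuples (1, 1, 1), (1, 1, 2), ..., (2, 2, 2), i.e. every combination of
# the specialized 1/2 extents from the recipe above (size 0 is not exercised
# here).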

class TensorSpec:
    __slots__ = ['_size', '_stride']

    def __init__(self, size, stride):
        self._size = size
        self._stride = stride

    def numel(self):
        r = 1
        for s in self._size:
            r *= s
        return r

    @property
    def ndim(self):
        return len(self._size)

    @property
    def shape(self):
        return self._size

    def size(self):
        return self._size

    def stride(self):
        return self._stride
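
# Illustrative usage (added): TensorSpec mimics just the torch.Tensor surface
# (shape, size(), stride(), ndim, numel()) that the functions below consume,
# and works equally well with plain ints or SymInts.
assert TensorSpec((2, 3), (1, 2)).numel() == 6
assert TensorSpec((2, 3), (1, 2)).ndim == 2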

def check_same_shape(*args):
    shape = None
    for arg in args:
        if shape is None:
            shape = arg.shape
        assert is_same_shape(shape, arg.shape)

# NOTE: Based on the implementation in TensorIterator.cpp.  However, the note
# [Computing output strides] there is wrong: it claims that strides are
# preserved even when they are not "non-overlapping and dense", but in fact
# the output of an elementwise operation is always given non-overlapping and
# dense strides.  This implementation is also INCORRECT in that it does not
# model TensorIterator's fast-path short-circuit, which can produce different
# strides (that is what fast_path below is cross-checked against).
def compute_elementwise_output_permutation(*tensors) -> Tuple[int, ...]:
    check_same_shape(*tensors)

    # Short-circuits for shapes with zero or one dimensions
    # TODO: are these necessary?
    ndim = tensors[0].ndim
    if ndim == 0:
        return ()
    if ndim == 1:
        return (1,)

    shape = tensors[0].shape

    def should_swap(idx_a, idx_b):
        for tensor in tensors:
            stride_a = tensor.stride()[idx_a]
            stride_b = tensor.stride()[idx_b]
            if stride_a == 0 or stride_b == 0:
                continue
            if stride_a < stride_b:
                return -1
            if stride_a > stride_b:
                return 1
            # stride_a == stride_b
            if shape[idx_a] > shape[idx_b]:
                return 1
        # Note: this case is hit if all strides are zero,
        # or all strides are equal and all dimensions have the same length
        return 0

    perm = list(reversed(range(ndim)))

    # insertion sort with support for ambiguous comparisons
    for i in range(1, ndim):
        dim1 = i
        for dim0 in reversed(range(i)):
            comparison = should_swap(perm[dim0], perm[dim1])
            if comparison > 0:
                perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
                dim1 = dim0
            elif comparison < 0:
                break

    # Identity permutation is [2, 1, 0]
    return perm

def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
    perm = compute_elementwise_output_permutation(*tensors)
    shape = tensors[0].shape
    return get_permuted_strides_for(perm, shape)
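
# Worked example (added): for two (2, 3) tensors that are both column-major
# (strides (1, 2)), should_swap flips perm from [1, 0] to [0, 1], so the
# output strides come back as (1, 2): the shared layout is preserved, but
# re-expressed as non-overlapping-and-dense strides.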

def get_permuted_strides_for(perm, shape):
    ndim = len(shape)
    permuted_shape = [-1] * ndim
    for idx, x in enumerate(reversed(perm)):
        permuted_shape[idx] = shape[x]
    new_strides = make_contiguous_strides_for(permuted_shape)
    permuted_strides = [-1] * ndim
    for idx, x in enumerate(reversed(perm)):
        permuted_strides[x] = new_strides[idx]
    return tuple(permuted_strides)
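
# Illustrative checks (added): for a 2-D shape, the "identity" perm [1, 0]
# reproduces row-major contiguous strides, while perm [0, 1] produces the
# column-major (transposed) layout.
assert get_permuted_strides_for([1, 0], (2, 3)) == (3, 1)
assert get_permuted_strides_for([0, 1], (2, 3)) == (1, 2)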

# This function is equivalent to compute_contiguous() from TensorImpl.cpp
def is_contiguous(a) -> bool:
    """
    Tests whether a tensor is contiguous or not.

    Tensors are contiguous when they have no elements,
    one element, or when they have "nested" strides.
    """
    if a.numel() < 2:
        return True

    expected_stride = 1
    for x, y in reversed(tuple(zip(a.shape, a.stride()))):
        # Skips checking strides when a dimension has length 1
        if x == 1:
            continue
        if y != expected_stride:
            return False
        expected_stride = expected_stride * x

    return True
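
# Illustrative check (added): length-1 dimensions are skipped, so their
# strides are free to be anything.
assert is_contiguous(TensorSpec((2, 1, 3), (3, 99, 1)))
assert not is_contiguous(TensorSpec((2, 3), (1, 2)))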

# This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
def is_channels_last_contiguous_2d(a) -> bool:
    # NHWC or not channels last 2D contiguous
    if a.ndim != 4:
        return False

    expected_stride = 1
    for idx in (1, 3, 2, 0):
        length = a.shape[idx]
        if length == 1:
            continue
        stride = a.stride()[idx]
        if stride != expected_stride:
            return False
        expected_stride *= length

    return True
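
# Illustrative check (added): an NHWC tensor of shape (N, C, H, W) = (2, 3, 4, 5)
# in channels-last layout has strides (H*W*C, 1, W*C, C) = (60, 1, 15, 3).
assert is_channels_last_contiguous_2d(TensorSpec((2, 3, 4, 5), (60, 1, 15, 3)))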

def fast_path(*operands):
    # Models TensorIterator's short-circuit: if every operand is contiguous
    # (or, eventually, channels-last contiguous), skip the permutation logic
    # and emit that layout directly.
    ndim = len(operands[0].shape)
    is_contiguous_ = True
    is_channels_last = True
    # TODO: is_non_overlapping_and_dense (not bound from Python)
    # no inplace, no out, everything defined
    for op in operands:
        is_contiguous_ = is_contiguous_ and is_contiguous(op)
        is_channels_last = is_channels_last and is_channels_last_contiguous_2d(op)
    if is_contiguous_:
        return get_permuted_strides_for(list(reversed(range(ndim))), operands[0].shape)
    # TODO: handle is_channels_last analogously
    return None
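
# Illustrative check (added): when every operand is contiguous, the fast path
# reports plain contiguous output strides; otherwise it declines with None.
assert fast_path(TensorSpec((2, 3), (3, 1))) == (3, 1)
assert fast_path(TensorSpec((2, 3), (1, 2))) is None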

for size in gen_size():
    for astride in gen_size():
        for bstride in gen_size():
            shape_env = ShapeEnv()

            def inflate(prefix, xs):
                return tuple(
                    shape_env.create_symintnode(
                        shape_env.create_symbol(x, LocalSource(f"{prefix}{i}"))
                    )
                    for i, x in enumerate(xs)
                )

            i_size = inflate("s", size)
            a = TensorSpec(i_size, inflate("a", astride))
            b = TensorSpec(i_size, inflate("b", bstride))
            r1 = compute_elementwise_output_strides(a, b)
            r2 = fast_path(a, b)

            def deflate(xs):
                return tuple(
                    shape_env.size_hint(x.node.expr)
                    if isinstance(x, torch.SymInt) else x
                    for x in xs
                )

            def check_significant_strides(size, stride_a, stride_b):
                # Strides only matter for dimensions with more than one element
                for idx in range(a.ndim):
                    if stride_a[idx] != stride_b[idx] and size[idx] > 1:
                        return False
                return True

            if r2 is not None:
                matches = check_significant_strides(size, r1, r2)
                if not matches:
                    print((torch.empty_strided(size, astride) + torch.empty_strided(size, bstride)).stride())
                    raise RuntimeError(f"{deflate(r1)} != {deflate(r2)} for {size} {astride} {bstride}")
            # print([g.expr for g in shape_env.guards])
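
# Note (added): as written, the harness only cross-checks the combinatorially
# enumerated configurations; the Z3 step from the module docstring (inverting
# the recorded shape_env.guards and asking for an uncovered example) is
# sketched in _z3_counterexample_sketch above but not wired in.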