Skip to content

Instantly share code, notes, and snippets.

@igul222
Created April 14, 2016 00:24
Show Gist options
  • Save igul222/06ef43954cd73e27b160ca2b4f547b15 to your computer and use it in GitHub Desktop.
Save igul222/06ef43954cd73e27b160ca2b4f547b15 to your computer and use it in GitHub Desktop.
# Alias for the builtin ``slice`` type. Presumably kept (as in TensorFlow's
# own array_ops) so the isinstance check below keeps working even if the name
# ``slice`` is later shadowed, e.g. by a ``slice`` op imported into the module.
_baseslice = slice


class Slice(object):
    """Python-style indexing for a tensor, built on ``tf.slice``.

    ``Slice(t)[spec]`` accepts an integer, a ``slice``, or a list/tuple of
    them (one entry per leading dimension). Integer indices and slice bounds
    may be Python/NumPy integers or scalar ``tf.Tensor`` values. Dimensions
    indexed by an integer are squeezed out of the result, as in Python/NumPy
    indexing.

    Unsupported forms are rejected eagerly instead of producing a wrong slice:

    * ``Ellipsis`` and slice steps other than 1 -> ``NotImplementedError``.
    * Negative Python-integer indices/bounds -> ``NotImplementedError``
      (``tf.slice`` would otherwise misread e.g. ``[:-1]`` as "to the end").
    * A Python-integer slice ``start`` greater than ``stop`` -> ``ValueError``.

    NOTE(review): ``tf`` is not imported in the visible file. This assumes
    ``import tensorflow as tf`` from a release that still provides ``tf.pack``
    and ``tf.squeeze(..., squeeze_dims=...)`` -- confirm.
    """

    def __init__(self, _tensor):
        # The tensor being indexed; it is only read, never modified.
        self.tensor = _tensor

    @staticmethod
    def _is_int(value):
        """Return True if *value* is a Python or NumPy integer."""
        import numbers

        return isinstance(value, numbers.Integral)

    def __getitem__(self, slice_spec):
        """Return ``self.tensor`` sliced according to *slice_spec*.

        Args:
            slice_spec: an integer, scalar tensor, ``slice``, or a list/tuple
                of those, one per leading dimension.

        Returns:
            The result of ``tf.slice``, with integer-indexed dimensions
            removed via ``tf.squeeze``.

        Raises:
            NotImplementedError: for ``Ellipsis``, a step other than 1, or a
                negative Python-integer index/bound.
            ValueError: if a Python-integer slice start exceeds its stop.
            TypeError: for any other kind of index.
        """
        import sys

        if not isinstance(slice_spec, (list, tuple)):
            slice_spec = [slice_spec]

        indices = []
        sizes = []
        squeeze_dims = []
        for dim, s in enumerate(slice_spec):
            if s is Ellipsis:
                raise NotImplementedError("Ellipsis is not currently supported")
            elif isinstance(s, _baseslice):
                # tf.slice has no stride; silently ignoring a step would
                # return the wrong elements.
                if s.step is not None and not (self._is_int(s.step) and s.step == 1):
                    raise NotImplementedError(
                        "Steps other than 1 are not currently supported")
                start = s.start if s.start is not None else 0
                stop = s.stop
                # Negative Python ints count from the end in Python, but
                # tf.slice would treat them as literal begin/size values.
                if self._is_int(start) and start < 0:
                    raise NotImplementedError(
                        "Negative start indices are not currently supported")
                if self._is_int(stop) and stop < 0:
                    raise NotImplementedError(
                        "Negative stop indices are not currently supported")
                indices.append(start)
                # NOTE(mrry): If the stop is not specified, Python substitutes
                # sys.maxsize, which is typically (2 ** 63) - 1. Since Slice currently
                # supports signed DT_INT32 arguments, we use -1 to specify that all
                # elements should be captured.
                if stop is None or (self._is_int(stop) and stop == sys.maxsize):
                    sizes.append(-1)
                else:
                    # Only checkable statically when both bounds are Python ints;
                    # tensor bounds are validated by tf.slice at run time.
                    if self._is_int(start) and self._is_int(stop) and start > stop:
                        raise ValueError("Stop must be at least start")
                    sizes.append(stop - start)
            elif self._is_int(s) or isinstance(s, tf.Tensor):
                if self._is_int(s) and s < 0:
                    raise NotImplementedError(
                        "Negative indices are currently unsupported")
                indices.append(s)
                sizes.append(1)
                squeeze_dims.append(dim)
            else:
                raise TypeError("Bad slice index %s of type %s" % (s, type(s)))

        sliced = tf.slice(self.tensor, tf.pack(indices), tf.pack(sizes))
        if squeeze_dims:
            # Drop the integer-indexed (size-1) dimensions, mirroring Python
            # indexing semantics.
            return tf.squeeze(sliced, squeeze_dims=squeeze_dims)
        return sliced
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment