Created
April 14, 2016 00:24
-
-
Save igul222/06ef43954cd73e27b160ca2b4f547b15 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Private alias for the built-in ``slice`` type, used by Slice.__getitem__ to
# recognise ``start:stop`` entries.
# NOTE(review): presumably kept so the built-in stays reachable if the module
# this is pasted into later defines its own ``slice`` -- confirm.
_baseslice = slice


def _reject_negative(value, what):
    """Raise NotImplementedError if ``value`` is a negative Python int.

    ``tf.slice`` needs non-negative begin offsets, and a size of -1 means
    "to the end of the dimension", so a negative Python index or bound would
    either fail inside TensorFlow or silently select the wrong range.
    Tensor-valued bounds cannot be inspected here and are left alone.
    """
    if isinstance(value, int) and value < 0:
        raise NotImplementedError(
            "Negative %s %d is not currently supported" % (what, value))


class Slice(object):
    """Wrap a tensor so it can be indexed with Python slice syntax.

    ``Slice(t)[spec]`` lowers ``spec`` to a single ``tf.slice`` call. It then
    applies ``tf.squeeze`` to every dimension that was indexed by a scalar (a
    Python ``int`` or a ``tf.Tensor``). As in NumPy, an integer index drops
    its dimension.

    Supported per-dimension entries:
      * ``int`` / ``tf.Tensor``: a single non-negative position. The
        dimension is removed from the result.
      * ``start:stop`` slices with non-negative int (or tensor) bounds and a
        step of ``None`` or ``1``. An omitted start means 0. An omitted stop
        (or ``sys.maxsize``) means "to the end of the dimension".

    NOTE(review): ``tf.slice`` expects ``begin``/``size`` to have one entry
    per dimension of the input, so ``spec`` presumably must cover every
    dimension of the wrapped tensor -- confirm with callers.
    """

    def __init__(self, _tensor):
        # The tensor that __getitem__ slices.
        self.tensor = _tensor

    def __getitem__(self, slice_spec):
        """Return ``self.tensor`` sliced according to ``slice_spec``.

        Args:
          slice_spec: a single index or slice, or a list/tuple of them, one
            per dimension starting from dimension 0.

        Returns:
          The ``tf.slice`` result, squeezed along the scalar-indexed
          dimensions. An empty spec returns ``self.tensor`` unchanged.

        Raises:
          NotImplementedError: for ``Ellipsis``, a slice step other than 1,
            or a negative Python ``int`` index or bound.
          TypeError: for any other kind of index.
        """
        if not isinstance(slice_spec, (list, tuple)):
            slice_spec = [slice_spec]
        if not slice_spec:
            # ``t[()]`` selects everything. An empty spec also has no
            # begin/size entries to pack.
            return self.tensor

        indices = []
        sizes = []
        squeeze_dims = []
        for dim, s in enumerate(slice_spec):
            if s is Ellipsis:
                raise NotImplementedError("Ellipsis is not currently supported")
            elif isinstance(s, _baseslice):
                # Only contiguous slices can be expressed with tf.slice.
                # Previously the step was silently ignored.
                if s.step is not None and not (
                        isinstance(s.step, int) and s.step == 1):
                    raise NotImplementedError(
                        "Slice steps other than 1 are not currently supported")
                start = s.start if s.start is not None else 0
                _reject_negative(start, "slice start")
                stop = s.stop
                # Must run before the size computation: a negative stop would
                # otherwise produce size -1 ("everything") or an invalid size.
                _reject_negative(stop, "slice stop")
                indices.append(start)
                # NOTE(mrry): If the stop is not specified, Python substitutes
                # sys.maxsize, which is typically (2 ** 63) - 1. Since Slice
                # currently supports signed DT_INT32 arguments, we use -1 to
                # specify that all elements should be captured.
                if stop is None or (not isinstance(stop, tf.Tensor)
                                    and stop == sys.maxsize):
                    sizes.append(-1)
                elif isinstance(stop, int) and isinstance(start, int):
                    # Python semantics: stop <= start selects nothing, whereas
                    # a negative size would be rejected (or mean "to the end").
                    sizes.append(max(stop - start, 0))
                else:
                    # At least one bound is a tensor, so compute the size in
                    # the graph.
                    sizes.append(stop - start)
            elif isinstance(s, int) or isinstance(s, tf.Tensor):
                _reject_negative(s, "index")
                indices.append(s)
                sizes.append(1)
                # A scalar index removes this dimension from the result.
                squeeze_dims.append(dim)
            else:
                raise TypeError("Bad slice index %s of type %s" % (s, type(s)))

        # ``tf.pack`` is the pre-1.0 TensorFlow name of ``tf.stack``. It turns
        # the mixed int/tensor lists into 1-D begin and size tensors.
        sliced = tf.slice(self.tensor, tf.pack(indices), tf.pack(sizes))
        if squeeze_dims:
            # Bug fix: this previously called the undefined name ``t``.
            return tf.squeeze(sliced, squeeze_dims=squeeze_dims)
        return sliced
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment