Skip to content

Instantly share code, notes, and snippets.

@gau-nernst
Created August 22, 2023 08:12
Show Gist options
Save gau-nernst/a8da6f1c2ffeea0daf76ec16eb8bf5fb to your computer and use it in GitHub Desktop.
PyTorch serialized list
import torch
# Modified from https://github.com/ppwwyyxx/RAM-multiprocess-dataloader
class PyTorchStrList:
    """Read-only list of ``str`` packed into two flat tensors.

    All strings are UTF-8 encoded and concatenated into one ``uint8`` tensor
    (``self.data``). ``self.index`` holds ``len(items) + 1`` cumulative byte
    offsets, so item ``i`` is the byte range ``data[index[i]:index[i + 1]]``.

    NOTE(review): keeping a couple of tensors instead of many Python ``str``
    objects is presumably meant (as in the RAM-multiprocess-dataloader project
    this is modified from) to avoid copy-on-write memory growth in forked
    DataLoader workers. Confirm against callers.
    """

    def __init__(self, items: list[str]):
        """Pack *items* into a byte buffer plus an offset index.

        Args:
            items: strings to store. They are iterated exactly once. Empty
                strings and an empty input are both allowed.
        """
        encoded = [s.encode() for s in items]
        # index[k] = total byte length of the first k items, with index[0] == 0.
        self.index = torch.tensor([0] + [len(b) for b in encoded]).cumsum_(0)
        # A single writable bytearray avoids torch.frombuffer's warning on
        # read-only ``bytes`` and avoids one temporary tensor per item.
        blob = bytearray().join(encoded)
        if blob:
            # clone() gives the tensor its own storage instead of aliasing the
            # temporary Python bytearray.
            self.data = torch.frombuffer(blob, dtype=torch.uint8).clone()
        else:
            # torch.frombuffer rejects zero-length buffers, so build the empty
            # tensor directly (all items empty, or no items at all).
            self.data = torch.empty(0, dtype=torch.uint8)

    def __getitem__(self, i: int) -> str:
        """Return item *i*, decoded from UTF-8.

        Negative indices count from the end, as for ``list``.

        Raises:
            IndexError: if *i* is out of range. Raising IndexError (rather than
                failing later on tuple unpacking) is also what lets the
                ``__getitem__``-based iteration protocol stop, so
                ``for s in lst`` and ``list(lst)`` work.
        """
        n = len(self)
        if i < 0:
            i += n
        if not 0 <= i < n:
            raise IndexError("PyTorchStrList index out of range")
        start, end = self.index[i : i + 2].tolist()
        return self.data[start:end].numpy().tobytes().decode()

    def __len__(self) -> int:
        """Return the number of stored strings (one less than the offset count)."""
        return self.index.shape[0] - 1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment