import torch

print(torch.__version__)

group_size = 32
w = torch.randn(512, 1024)
w_groups = w.unflatten(1, (-1, group_size))  # (512, 32 groups, 32 elements per group)
min_val = w_groups.amin(2, keepdim=True)
max_val = w_groups.amax(2, keepdim=True)
scale = (max_val - min_val) / 15  # asymmetric 4-bit quantization: 16 levels span (max - min) in 15 steps
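# A hedged continuation sketch (my addition, not necessarily the gist's next lines):
# with the group-wise min and scale above, asymmetric 4-bit quantization rounds each
# element to an integer code in [0, 15] and reconstructs it as code * scale + min_val.
safe_scale = scale.clamp(min=1e-8)  # avoid division by zero for constant groups
codes = ((w_groups - min_val) / safe_scale).round_().clamp_(0, 15).to(torch.uint8)
w_dequant = (codes.float() * scale + min_val).flatten(1)  # back to shape (512, 1024)
max_abs_err = (w - w_dequant).abs().max()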
import torch
from torch import Tensor, nn
from tqdm import tqdm

class PerLayerOffloadWithBackwardGradient:
    """This version also offloads gradients. To ensure proper synchronization, it will take control over the optimizer."""

    def __init__(
        self,
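# A minimal sketch of the general per-layer offload idea (my assumption about the
# technique, not the gist's actual implementation): parameters live on the CPU and
# each layer is copied to the GPU just before it runs, then released afterwards.
def attach_naive_offload_hooks(model: nn.Module, device: str = "cuda"):
    def pre_hook(module, args):
        module.to(device, non_blocking=True)  # bring this layer's weights onto the GPU

    def post_hook(module, args, output):
        module.to("cpu")  # free GPU memory once the layer has produced its output

    for layer in model.children():  # "per layer" here means top-level submodules
        layer.register_forward_pre_hook(pre_hook)
        layer.register_forward_hook(post_hook)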
import torch
from diffusers import FluxPipeline
from torch import nn

class ModelOffloaderV2:
    def __init__(self, model: nn.Module, record_stream: bool = False):
        # keep the master copy of the weights in CPU pinned memory, so that
        # host-to-device copies can run asynchronously (non_blocking=True)
        for p in model.parameters():
            p.data = p.data.cpu().pin_memory()
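# A hedged illustration of why pinned memory matters here (my addition, not part of
# the gist): copies from pinned host memory can run on a dedicated CUDA stream with
# non_blocking=True, overlapping the transfer with compute. `layer` below stands in
# for whichever submodule is being prefetched.
def prefetch_layer(layer: nn.Module, copy_stream: torch.cuda.Stream):
    with torch.cuda.stream(copy_stream):
        for p in layer.parameters():
            p.data = p.data.to("cuda", non_blocking=True)  # async: source is pinned
    # the compute stream must wait for the copy before it uses these weights
    torch.cuda.current_stream().wait_stream(copy_stream)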
import torch
import triton
import triton.language as tl
from torch import Tensor

# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
configs = [
    (128, 256, 64, 3, 8),
    (64, 256, 32, 4, 4),
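# A hedged sketch of how tuples like these are typically consumed (my wrapping, not
# necessarily the gist's): each tuple becomes a triton.Config, and triton.autotune
# picks the fastest one per (M, N, K) shape.
triton_configs = [
    triton.Config(
        dict(BLOCK_M=bm, BLOCK_N=bn, BLOCK_K=bk),
        num_stages=num_stages,
        num_warps=num_warps,
    )
    for bm, bn, bk, num_stages, num_warps in configs
]

@triton.autotune(configs=triton_configs, key=["M", "N", "K"])
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    ...  # kernel body omitted; see the Triton matmul tutorial linked above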
import shlex
import subprocess

import torch

def load_audio(path: str, sample_rate: int) -> torch.Tensor:
    # FFMPEG_PATH is assumed to be defined elsewhere (path to the ffmpeg binary);
    # decode to mono signed 32-bit little-endian PCM at the requested rate, written to stdout
    cmd = f"{FFMPEG_PATH} -i {path} -ar {sample_rate} -ac 1 -f s32le -"
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    if proc.returncode:
        raise RuntimeError(proc.stderr.decode())
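    # A hedged continuation (my addition, not necessarily the gist's code): ffmpeg wrote
    # raw mono s32le samples to stdout, so reinterpret the bytes as int32 and normalize
    # to float32 in [-1, 1].
    samples = torch.frombuffer(bytearray(proc.stdout), dtype=torch.int32)
    return samples.float() / (1 << 31)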
import torch

# Modified from https://github.com/ppwwyyxx/RAM-multiprocess-dataloader
class PyTorchStrList:
    def __init__(self, items: list[str]):
        # store all strings in two flat tensors (concatenated bytes + prefix-sum offsets)
        # instead of a Python list, so DataLoader workers don't duplicate per-item objects
        data = [torch.frombuffer(x.encode(), dtype=torch.uint8) for x in items]
        lengths = [0] + [x.shape[0] for x in data]
        self.data = torch.cat(data, 0)
        self.index = torch.tensor(lengths).cumsum_(0)
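    # A hedged sketch of the likely accessors (my addition): slice the flat byte buffer
    # with the prefix-sum offsets and decode back to a Python string.
    def __getitem__(self, i: int) -> str:
        start, end = self.index[i].item(), self.index[i + 1].item()
        return bytes(self.data[start:end].tolist()).decode()

    def __len__(self) -> int:
        return self.index.shape[0] - 1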
from typing import List
from playwright.sync_api import sync_playwright
import requests
import re
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import argparse
from enum import IntEnum

class FieldType(IntEnum):
    BYTE = 1
    ASCII = 2
    SHORT = 3
    LONG = 4
    RATIONAL = 5
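# A hedged companion sketch (my addition, assuming these are TIFF/EXIF IFD field types,
# which the numeric values above match): each type has a fixed size in bytes, used when
# computing how much data an IFD entry holds.
FIELD_TYPE_SIZE = {
    FieldType.BYTE: 1,
    FieldType.ASCII: 1,
    FieldType.SHORT: 2,
    FieldType.LONG: 4,
    FieldType.RATIONAL: 8,  # two LONGs: numerator and denominator
}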
from torch import nn
from torchvision.models import resnet, mobilenet, efficientnet
from torchvision.models.feature_extraction import create_feature_extractor

class _Extractor(nn.Module):
    def __init__(self, backbone, node_names):
        super().__init__()
        self.feat_extractor = create_feature_extractor(backbone, node_names)
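    # A hedged sketch of the forward pass (my addition, not necessarily the gist's code):
    # create_feature_extractor returns a dict mapping node name -> feature map, unpacked
    # here into a list in node order.
    def forward(self, x):
        feature_maps = self.feat_extractor(x)
        return list(feature_maps.values())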
from torch import nn

class SeparableConv2d(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, norm_layer=None, activation=None):
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation is None:
            activation = nn.ReLU6
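        # A hedged completion (my addition, following the standard depthwise-separable
        # pattern rather than the gist's exact layers): a depthwise conv (groups=in_channels)
        # followed by a 1x1 pointwise conv, each with normalization and activation.
        self.append(nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=False))
        self.append(norm_layer(in_channels))
        self.append(activation())
        self.append(nn.Conv2d(in_channels, out_channels, 1, bias=False))
        self.append(norm_layer(out_channels))
        self.append(activation())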