Skip to content

Instantly share code, notes, and snippets.

@gau-nernst
gau-nernst / int4mm_cpu.py
Created April 8, 2025 02:32
PyTorch int4mm_cpu
# Group-wise asymmetric quantization setup (gist title: "PyTorch int4mm_cpu").
# NOTE(review): preview appears truncated — the actual int4 packing /
# _weight_int4pack_mm call presumably follows in the full gist; verify there.
import torch
print(torch.__version__)
group_size = 32  # number of consecutive weight elements sharing one scale
w = torch.randn(512, 1024)
# (512, 1024) -> (512, 1024/group_size, group_size): one group per last dim
w_groups = w.unflatten(1, (-1, group_size))
min_val = w_groups.amin(2, keepdim=True)  # per-group minimum, shape (512, 32, 1)
max_val = w_groups.amax(2, keepdim=True)  # per-group maximum, shape (512, 32, 1)
# 15 steps span [min, max] — consistent with unsigned 4-bit levels 0..15
scale = (max_val - min_val) / 15 # scale (max-min) to 15
@gau-nernst
gau-nernst / offload.py
Created December 12, 2024 12:44
Full CPU offload for single-GPU training
import torch
from torch import Tensor, nn
from tqdm import tqdm
class PerLayerOffloadWithBackwardGradient:
"This version also offloads gradients. To ensure proper synchronization, it will take control over the optimizer."
def __init__(
self,
@gau-nernst
gau-nernst / flux_infer.py
Last active January 24, 2025 06:33
FLUX CPU offload
import torch
from diffusers import FluxPipeline
from torch import nn
class ModelOffloaderV2:
    """CPU offloading helper (gist title: "FLUX CPU offload").

    NOTE(review): preview is heavily truncated — `record_stream` is unused
    in the visible lines, so the real __init__ presumably continues
    (forward hooks, CUDA streams, etc.). Verify against the full gist.
    """

    def __init__(self, model: nn.Module, record_stream: bool = False):
        # move model to pinned memory. keep a model copy in CPU pinned memory.
        for p in model.parameters():
            # Pinned (page-locked) host memory allows asynchronous
            # host-to-device copies when layers are later moved to GPU.
            p.data = p.data.cpu().pin_memory()
@gau-nernst
gau-nernst / fp8_linear.py
Created August 24, 2024 04:51
FP8 linear triton with row-wise scaling
import torch
import triton
import triton.language as tl
from torch import Tensor
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
configs = [
(128, 256, 64, 3, 8),
(64, 256, 32, 4, 4),
@gau-nernst
gau-nernst / ffmpeg.py
Created September 17, 2023 03:39
Read audio with ffmpeg for PyTorch
import subprocess
import torch
def load_audio(path: str, sample_rate: int) -> torch.Tensor:
    """Decode an audio file with ffmpeg into raw mono s32le PCM on stdout.

    Command flags: `-ar` resamples to `sample_rate`; `-ac 1` downmixes to
    mono; `-f s32le -` writes raw signed 32-bit little-endian PCM to stdout.

    NOTE(review): preview is truncated — FFMPEG_PATH and `shlex` are not
    defined/imported in the visible lines, and the conversion of
    proc.stdout into the annotated torch.Tensor return value is missing.
    Also, a `path` containing spaces or shell-quote characters will break
    this f-string + shlex.split combination; passing an argument list to
    subprocess.run directly would be safer — confirm in the full gist.
    """
    cmd = f"{FFMPEG_PATH} -i {path} -ar {sample_rate} -ac 1 -f s32le -"
    proc = subprocess.run(shlex.split(cmd), capture_output=True)
    # Non-zero exit code: surface ffmpeg's own diagnostics to the caller.
    if proc.returncode:
        raise RuntimeError(proc.stderr.decode())
@gau-nernst
gau-nernst / pytorch_serialized_list.py
Created August 22, 2023 08:12
PyTorch serialized list
import torch
# Modified from https://github.com/ppwwyyxx/RAM-multiprocess-dataloader
class PyTorchStrList:
    """Store a list of strings as two flat torch tensors.

    Flattening into contiguous tensors is the technique from the linked
    RAM-multiprocess-dataloader repo — presumably to avoid per-object
    copy-on-access memory growth in forked DataLoader workers (confirm
    against that repo).

    NOTE(review): preview looks truncated — a __getitem__/__len__ that
    decodes ``self.data[self.index[i]:self.index[i+1]]`` presumably
    follows in the full gist.
    """

    def __init__(self, items: list[str]):
        # Each string becomes a uint8 tensor over its UTF-8-encoded bytes.
        data = [torch.frombuffer(x.encode(), dtype=torch.uint8) for x in items]
        # Byte length of each item, with a leading 0 so the cumulative sum
        # below yields start offsets: item i spans index[i]:index[i+1].
        lengths = [0] + [x.shape[0] for x in data]
        self.data = torch.cat(data, 0)
        self.index = torch.tensor(lengths).cumsum_(0)
@gau-nernst
gau-nernst / ddg_scrape.py
Created April 19, 2023 04:08
Scrape images from DuckDuckGo
from typing import List
from playwright.sync_api import sync_playwright
import requests
import re
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import argparse
@gau-nernst
gau-nernst / tiff_encoder.py
Created March 21, 2023 13:47
Simple TIFF encoder
from enum import IntEnum
class FieldType(IntEnum):
    """TIFF IFD field (tag value) types, numbered as in the TIFF 6.0 spec.

    NOTE(review): TIFF 6.0 also defines types 6-12 (SBYTE..DOUBLE); this
    preview may be truncated, or the encoder may intentionally support
    only these five baseline types — confirm against the full gist.
    """

    BYTE = 1      # 8-bit unsigned integer
    ASCII = 2     # 8-bit bytes holding 7-bit ASCII, NUL-terminated
    SHORT = 3     # 16-bit unsigned integer
    LONG = 4      # 32-bit unsigned integer
    RATIONAL = 5  # two LONGs: numerator then denominator
@gau-nernst
gau-nernst / torchvision_extractor.py
Last active December 22, 2021 13:52
Torchvision feature extractor
from torch import nn
from torchvision.models import resnet, mobilenet, efficientnet
from torchvision.models.feature_extraction import create_feature_extractor
class _Extractor(nn.Module):
    """Wrap a torchvision backbone with an FX-graph feature extractor over
    the requested node names.

    NOTE(review): preview appears truncated — a forward() delegating to
    self.feat_extractor presumably follows in the full file.
    """

    def __init__(self, backbone, node_names):
        super().__init__()
        # create_feature_extractor returns a module mapping each graph
        # node name in `node_names` to its intermediate output.
        self.feat_extractor = create_feature_extractor(backbone, node_names)
@gau-nernst
gau-nernst / separable_conv.py
Last active December 21, 2021 06:11
Separable Convolution Block in PyTorch
from torch import nn
class SeparableConv2d(nn.Sequential):
    """Depthwise-separable convolution block built on nn.Sequential.

    NOTE(review): the visible body only resolves default norm/activation
    classes; the depthwise + pointwise conv construction is presumably in
    the truncated remainder of the gist — verify there.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, norm_layer=None, activation=None):
        super().__init__()
        # Late-bound defaults: classes are resolved here rather than used
        # as default argument values, avoiding shared default instances.
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation is None:
            activation = nn.ReLU6