Vishal Goklani (vgoklani)
GitHub Gists
@willccbb
willccbb / grpo_demo.py
Last active May 2, 2025 08:43
GRPO Llama-1B
# train_grpo.py
#
# See https://github.com/willccbb/verifiers for ongoing developments
#
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
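The preview above stops after the imports. As a rough, hedged sketch of how those pieces are typically wired together with trl's GRPOTrainer (not the gist's actual script), it can look like the following; the model id, toy dataset, and reward function are illustrative assumptions:

# Hedged sketch, not the gist's script: minimal GRPOTrainer wiring (assumes trl >= 0.14).
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# GRPOTrainer expects a "prompt" column; two toy prompts stand in for a real dataset.
train_dataset = Dataset.from_dict({"prompt": ["What is 2 + 2?", "Name a prime number."]})

def reward_contains_digit(completions, **kwargs):
    # Toy reward: 1.0 if the completion contains a digit, else 0.0 (stand-in for a real verifier).
    return [1.0 if any(ch.isdigit() for ch in c) else 0.0 for c in completions]

args = GRPOConfig(output_dir="grpo-demo", per_device_train_batch_size=2, num_generations=2)
trainer = GRPOTrainer(
    model="meta-llama/Llama-3.2-1B-Instruct",  # assumption: any 1B-class causal LM id works here
    reward_funcs=reward_contains_digit,
    args=args,
    train_dataset=train_dataset,
)
trainer.train()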
@drisspg
drisspg / scaled_mm_api.md
Last active February 8, 2025 16:03
Scaled MM API

Summary

This doc serves as a quick reference for the _scaled_mm API and how it has changed over time across major versions of PyTorch.


NOTE: The leading underscore is intentional, and we currently make no FC/BC guarantees on this API. That said, it is currently the only op with native support for FP8 matmuls in the PyTorch library. We plan to add an official public API for this; until then the API is subject to change, but you can use this doc as a reference.
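As a rough illustration (hedged, since the signature has changed across releases), a current-style call looks roughly like this; the shapes, scale values, and out_dtype are arbitrary assumptions:

# Hedged sketch assuming a recent signature (roughly PyTorch 2.4+), where scale_a/scale_b
# are required and a single tensor is returned; older releases returned an (out, amax) tuple.
import torch

device = "cuda"  # FP8 matmul needs supported hardware (e.g. H100-class GPUs)
a = torch.randn(32, 64, device=device).to(torch.float8_e4m3fn)
# mat2 is expected in column-major layout, hence creating it transposed.
b = torch.randn(16, 64, device=device).to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device=device)
scale_b = torch.tensor(1.0, device=device)

out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # (32, 16)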


@andrewnc
andrewnc / lr_scheduler.py
Created December 27, 2024 02:24
God's Chosen Schedule
import math
import torch
from torch.optim.lr_scheduler import _LRScheduler
from dataclasses import dataclass
from typing import List
@dataclass
class SchedulePhase:
"""Defines a phase in the learning rate schedule"""
percent: float # Percentage of total steps this phase covers
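The preview cuts off after the first field of SchedulePhase. Below is a hedged sketch of how such a phase-based scheduler could be assembled; the extra fields and the cosine interpolation are assumptions, not the gist's implementation:

# Hedged sketch of a phase-based LR scheduler in the spirit of the snippet above.
import math
from dataclasses import dataclass
from typing import List
import torch
from torch.optim.lr_scheduler import _LRScheduler

@dataclass
class SchedulePhase:
    percent: float   # fraction of total steps this phase covers
    start_lr: float  # assumed field: LR at the start of the phase
    end_lr: float    # assumed field: LR at the end of the phase

class PhaseScheduler(_LRScheduler):
    def __init__(self, optimizer, phases: List[SchedulePhase], total_steps: int, last_epoch: int = -1):
        self.phases = phases
        self.total_steps = total_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        progress = min(self.last_epoch / max(self.total_steps, 1), 1.0)
        start = 0.0
        for phase in self.phases:
            end = start + phase.percent
            if progress <= end or phase is self.phases[-1]:
                frac = min((progress - start) / max(phase.percent, 1e-8), 1.0)
                # cosine interpolation within the phase (an assumption, not the gist's rule)
                lr = phase.end_lr + (phase.start_lr - phase.end_lr) * 0.5 * (1 + math.cos(math.pi * frac))
                return [lr for _ in self.optimizer.param_groups]
            start = end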
@awni
awni / resnet_mlx.py
Created September 7, 2024 20:02
MLX ResNet18 Inference Benchmark
from huggingface_hub import snapshot_download
import mlx.core as mx
import mlx.nn as nn
import time
class Block(nn.Module):
    def __init__(self, in_dims, dims, stride=1):
        super().__init__()
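The preview shows only the start of the residual Block. For the benchmarking part, a hedged sketch of the usual MLX timing pattern (with a stand-in model rather than the gist's ResNet18) looks like this; mx.eval is what forces the lazy computation to actually run:

# Hedged sketch of the benchmark pattern, not the gist's ResNet18 code.
import time
import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(224 * 224 * 3, 256), nn.ReLU(), nn.Linear(256, 1000))  # stand-in model
x = mx.random.normal((32, 224 * 224 * 3))

# Warm-up pass (MLX is lazy; mx.eval materializes the outputs).
mx.eval(model(x))

n_iters = 50
tic = time.perf_counter()
for _ in range(n_iters):
    mx.eval(model(x))
toc = time.perf_counter()
print(f"{(toc - tic) / n_iters * 1e3:.2f} ms / iteration")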
@kalomaze
kalomaze / modeling_mixtral.py
Created May 5, 2024 03:38
Fixed Mixtral training code for HF Transformers
# coding=utf-8
# Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@vgoklani
vgoklani / torch_ddp_verify.py
Created April 17, 2024 22:21 — forked from jxmorris12/torch_ddp_verify.py
verify parameter weights & gradients in pytorch
import torch

# NOTE: `get_world_size` and `gather` are distributed helpers defined elsewhere in the
# gist; one possible implementation is sketched after this snippet.
def verify_ddp_weights_equal(model: torch.nn.Module, atol: float = 1e-5) -> None:
    if hasattr(model, "module"):
        model = model.module  # unwrap DistributedDataParallel
    world_size = get_world_size()
    for name, param in model.named_parameters():
        # gather this parameter from every rank into shape (world_size, numel)
        gathered_param = gather(param).reshape((world_size, -1))
        # compare every rank's copy against rank 0
        absolute_diffs = (gathered_param[None, 0, :] - gathered_param).abs()
        rank_params_eq = (absolute_diffs < atol).all()
        assert rank_params_eq, f"❌ param [{name}] not equal - got max_absolute_diff={absolute_diffs.max()}"
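The snippet relies on get_world_size and gather helpers that are not shown. One possible way to write them with torch.distributed (an assumption; the gist's own helpers may differ):

# Hedged sketch of the helpers assumed above.
import torch
import torch.distributed as dist

def get_world_size() -> int:
    # Fall back to 1 when not running under torchrun / no process group is initialized.
    return dist.get_world_size() if dist.is_initialized() else 1

def gather(t: torch.Tensor) -> torch.Tensor:
    # All-gather a copy of `t` from every rank and stack along a new leading dimension.
    if not dist.is_initialized():
        return t.unsqueeze(0)
    buffers = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, t.contiguous())
    return torch.stack(buffers, dim=0)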
@Birch-san
Birch-san / llama_flash.py
Last active January 22, 2024 06:05
Loading llama with Flash Attention
from transformers import (
    AutoConfig,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    AutoModelForCausalLM,
    LlamaTokenizerFast,
    PreTrainedModel,
    TextIteratorStreamer,
    StoppingCriteria,
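The import list is truncated in the preview. In current transformers (roughly 4.36+) the simplest route to Flash Attention is the attn_implementation argument; the gist itself may wire it differently, and the model id and 4-bit settings below are illustrative assumptions:

# Hedged sketch using the modern transformers flag, not necessarily the gist's mechanism.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "meta-llama/Llama-2-7b-hf"  # assumption: any Llama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # needs the flash-attn package and a supported GPU
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16),
    device_map="auto",
)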
@mara004
mara004 / pypdfjs.py
Last active April 24, 2025 20:57
PDF rendering with pdf.js, from Python
# Four lines intentionally left blank
# SPDX-FileCopyrightText: 2025 geisserml <[email protected]>
# SPDX-License-Identifier: Apache-2.0 OR MPL-2.0
# See also https://github.com/extremeheat/JSPyBridge/blob/master/examples/python/pdfjs.py
@codekansas
codekansas / benchmark_self_attention.py
Last active March 11, 2023 18:34
Benchmarking script for attention
import argparse
import contextlib
import logging
import math
import random
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
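Only the imports of the benchmark survive in the preview. A hedged, minimal timing loop for scaled dot-product attention (not the gist's harness) would look something like this; the shapes are arbitrary, and CUDA synchronization is needed so the measurement covers the actual kernel work:

# Hedged sketch of a minimal attention timing loop.
import time
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
batch, heads, seq, dim = 8, 16, 1024, 64
q = torch.randn(batch, heads, seq, dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Warm-up so kernel selection isn't part of the measurement.
for _ in range(5):
    F.scaled_dot_product_attention(q, k, v)
if device == "cuda":
    torch.cuda.synchronize()

n_iters = 50
tic = time.perf_counter()
for _ in range(n_iters):
    F.scaled_dot_product_attention(q, k, v)
if device == "cuda":
    torch.cuda.synchronize()
toc = time.perf_counter()
print(f"{(toc - tic) / n_iters * 1e3:.3f} ms / call")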