This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import math | |
| def warmup_stable_decay(*, W: int, S: int, D: int, min_lr_scale_factor: float = 0.1): | |
| """ | |
| Returns a lambda function for PyTorch's LambdaLR scheduler implementing the | |
| WSD learning rate schedule. | |
| Parameters: | |
| - W: The last step of the warmup phase. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Callable, Protocol | |
| import torch | |
| from torch import Tensor | |
| from torch.nn import Linear, Module | |
| from torch.nn.functional import silu | |
| def compute_frequencies( | |
| *, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| from tqdm import tqdm | |
| from torch.nn import Module | |
| from torch.nn.functional import cross_entropy | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModelForCausalLM, | |
| ) | |
| BATCH = 16 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torch | |
| class GradientAccumulationSchedule: | |
| """ | |
| A schedule that linearly increases the number of gradient accumulation | |
| steps throughout training to converge faster. | |
| """ | |
| def __init__(self, *, min: int, max: int, steps: int, factor: int | None = None): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "architectures": ["Qwen3ForCausalLM"], | |
| "attention_bias": false, | |
| "attention_dropout": 0.0, | |
| "bos_token_id": 151643, | |
| "eos_token_id": 151645, | |
| "head_dim": 128, | |
| "hidden_act": "silu", | |
| "hidden_size": 1024, | |
| "initializer_range": 0.02, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from typing import Any, Dict, List | |
| import pyarrow.parquet as pq | |
| class ParquetReader: | |
| def __init__(self, file: str, batch_size: int = 256): | |
| self.fp = pq.ParquetFile(file) | |
| self.num_rows = self.fp.metadata.num_rows | |
| self.num_rows_read = 0 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import random | |
| YEAR_SHIFT = 32 | |
| def encode(uid, year): | |
| return (year << YEAR_SHIFT) | uid | |
| def decode(docid): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from random import choices, randint | |
| from typing import Any, Callable, Dict, Generic, List, TypeVar | |
| import torch | |
| from torch import Tensor | |
| T = TypeVar("T") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| class AttentionMask: | |
| """ | |
| A (Batch, 1, Queries, Keys & Values) attention mask for attention between queries and keys/values. | |
| The mask is "additive" or "inversed" meaning it is a tensor of floating point values | |
| that can be added to the attention scores before the softmax operation. | |
| >>> 0 = Unmasked | |
| >>> dtype.min = Masked | |
| """ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from dataclasses import dataclass | |
| from typing import List, Iterator | |
| @dataclass | |
| class Sequence: | |
| """Contains a single token sequence""" | |
| x: List[int] | |
| y: List[int] |
NewerOlder