import os
import subprocess
from functools import partial

import pyMSVC

# pyMSVC collects the environment variables (PATH, INCLUDE, LIB, ...) that the
# MSVC toolchain needs.
environment = pyMSVC.Environment()
print(environment)

os.environ.update(environment)  # should put MSVC's cl.exe on PATH

# check that the Visual Studio compiler is on PATH
subprocess.check_call('cl', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

# add the CUDA binaries to PATH
program_files = os.environ.get("ProgramFiles", r"C:\Program Files")
os.environ["PATH"] += os.pathsep + os.path.join(program_files, r"NVIDIA GPU Computing Toolkit\CUDA\v12.6\bin")
subprocess.check_call('ptxas --version', shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
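
# Optional early-failure check (a sketch added here, not part of the original
# setup): shutil.which reports which executable, if any, resolves on PATH,
# which is easier to act on than a bare CalledProcessError from the
# check_call lines above.
import shutil

for tool in ("cl", "ptxas"):
    if shutil.which(tool) is None:
        raise RuntimeError(f"{tool!r} was not found on PATH; check the setup above")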

# the following is adapted from https://github.com/pytorch-labs/attention-gym/tree/54755782172e89045bc908365f18ab75ae685708

import torch
from torch.nn.attention.flex_attention import (
    create_block_mask,
    flex_attention,
)

flex_attention = torch.compile(flex_attention, dynamic=False)


# Tanh Soft-Capping
# A custom op wrapping tanh, so that inductor can later swap in the faster
# PTX approximation when the op is compiled.
@torch.library.custom_op("approx::tanh", mutates_args=())
def tanh_approx(inp: torch.Tensor) -> torch.Tensor:
    return torch.tanh(inp)


# fake (meta) implementation so torch.compile can trace shapes and dtypes
@tanh_approx.register_fake
def _(inp: torch.Tensor) -> torch.Tensor:
    return torch.tanh(inp)
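
# Quick eager sanity check (a sketch, not part of the adapted example): outside
# of torch.compile the custom op simply dispatches to torch.tanh, so the two
# results should match exactly here.
_x = torch.randn(8)
torch.testing.assert_close(torch.ops.approx.tanh(_x), torch.tanh(_x))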

from torch._inductor.lowering import make_pointwise, register_lowering

# Some internal torch.compile details
from torch._inductor.virtualized import ops


def tanh_approx_lowering(inp):
    # lower approx::tanh to the PTX tanh.approx.f32 instruction
    # ($0 is the output register, $1 the input)
    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
    return make_pointwise(fn)(inp)


register_lowering(torch.ops.approx.tanh)(tanh_approx_lowering)
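
# Optional compile-path check (a sketch; assumes a CUDA device and the default
# Triton backend): compiling a tiny wrapper routes torch.ops.approx.tanh through
# the lowering above, so its output should agree with torch.tanh up to the
# precision of the PTX approximation.
_xc = torch.randn(1024, device="cuda")
_compiled_tanh = torch.compile(lambda t: torch.ops.approx.tanh(t))
torch.testing.assert_close(_compiled_tanh(_xc), torch.tanh(_xc), atol=1e-3, rtol=1e-3)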

class TanhApprox(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return torch.ops.approx.tanh(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        result = output
        ctx.save_for_backward(result)

    @staticmethod
    def backward(ctx, grad_output):
        (result,) = ctx.saved_tensors
        # d/dx tanh(x) = 1 - tanh(x)^2, computed from the saved output
        return grad_output * (1 - result * result)


tanh_approx = TanhApprox.apply
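
# Gradient sanity check (a sketch, not part of the adapted example): in eager
# mode the forward is exact tanh and the backward above is its analytic
# derivative, so gradcheck in float64 should pass.
_xg = torch.randn(8, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(tanh_approx, (_xg,))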

# score_mod: soft-cap attention scores smoothly into (-2, 2) via 2 * tanh(score / 2)
def tanh_soft_cap(score, b, h, q_idx, kv_idx):
    score = score / 2
    score = tanh_approx(score)
    return score * 2
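
# Usage note (a sketch): tanh_soft_cap is a drop-in score_mod, e.g.
#   flex_attention(q, k, v, score_mod=tanh_soft_cap)
# the call at the end of this script uses the alibi_plus_tanh_score variant instead.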

# mask_mod: causal attention, widened so each query can also see keys up to
# 3 positions ahead of it (kv_idx - q_idx < 4)
def causal_mask_but_look_4_back(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) | (kv_idx - q_idx < 4)
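
# Visualizing the mask (a sketch; create_mask from torch.nn.attention.flex_attention
# is the dense counterpart of create_block_mask): handy for eyeballing small masks.
from torch.nn.attention.flex_attention import create_mask

print(create_mask(causal_mask_but_look_4_back, 1, 1, 8, 8, device="cuda"))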

NUM_HEADS = 2
MAX_SEQ_LEN = 32


# score_mod: distance-based per-head bias (ALiBi-style slopes) followed by tanh
# soft-capping; non-finite scores (e.g. from the -inf query rows below) are
# forced back to -inf
def alibi_plus_tanh_score(score, b, h, q_idx, kv_idx):
    bias = q_idx - kv_idx
    scale = torch.exp2(-((h + 1) * 8.0 / NUM_HEADS))  # static per-head slope
    return torch.where(torch.isfinite(score), tanh_approx(score + bias * scale) * 2, -float("inf"))


block_mask = create_block_mask(causal_mask_but_look_4_back, 1, NUM_HEADS, MAX_SEQ_LEN, MAX_SEQ_LEN, device="cuda")
query = torch.randn(1, NUM_HEADS, MAX_SEQ_LEN, 64, device="cuda", dtype=torch.float16)
query[:, :, -8:, :] = -float("inf")  # deliberately make the last 8 query rows produce non-finite scores
key = torch.randn(1, NUM_HEADS, MAX_SEQ_LEN, 64, device="cuda", dtype=torch.float16)
value = key  # reuse the key tensor as the values
output, lse = flex_attention(query, key, value, block_mask=block_mask,
                             score_mod=torch.compile(alibi_plus_tanh_score),
                             return_lse=True)

print(lse)
print(output)
print(output.size())