Skip to content

Instantly share code, notes, and snippets.

@shauray8
Created April 16, 2025 12:10
Show Gist options
  • Save shauray8/212d8cbf0db5f0c83b0469ddcf716681 to your computer and use it in GitHub Desktop.
Save shauray8/212d8cbf0db5f0c83b0469ddcf716681 to your computer and use it in GitHub Desktop.
# AOT ID: ['0_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1 = args
args.clear()
assert_size_stride(arg0_1, (4096, 4096), (4096, 1))
assert_size_stride(arg1_1, (4096, 4096), (4096, 1))
with torch.cuda._DeviceGuard(0):
torch.cuda.set_device(0)
buf0 = empty_strided_cuda((4096, 4096), (4096, 1), torch.float16)
# Topologically Sorted Source Nodes: [mm], Original ATen: [aten.mm]
extern_kernels.mm(arg1_1, arg0_1, out=buf0)
del arg0_1
del arg1_1
return (buf0, )
def benchmark_compiled_module(times=10, repeat=10):
from torch._dynamo.testing import rand_strided
from torch._inductor.utils import print_performance
arg0_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.float16)
arg1_1 = rand_strided((4096, 4096), (4096, 1), device='cuda:0', dtype=torch.float16)
fn = lambda: call([arg0_1, arg1_1])
return print_performance(fn, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
compiled_module_main('None', benchmark_compiled_module)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment