Skip to content

Instantly share code, notes, and snippets.

@knwng
Created January 13, 2025 22:08
Show Gist options
  • Save knwng/d3c37535a8afbdba4332bb9e76b9c275 to your computer and use it in GitHub Desktop.
Save knwng/d3c37535a8afbdba4332bb9e76b9c275 to your computer and use it in GitHub Desktop.
import triton
import triton.language as tl
import triton.compiler as tc
from triton.backends.compiler import GPUTarget
@triton.jit
def type_convert_triton(src, dst, rounding: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Convert one tile of BLOCK_SIZE elements from ``src`` to ``dst``'s element type.

    Program ``pid`` reads ``src[pid * BLOCK_SIZE : (pid + 1) * BLOCK_SIZE]`` and
    writes the converted values to the same positions of ``dst``. Loads and
    stores are unmasked, so both buffers must cover every launched tile.

    Args:
        src: Pointer to the input elements.
        dst: Pointer to the output elements; its pointee type is the
            conversion target.
        rounding: Rounding mode string (e.g. ``'rtne'``) passed as
            ``fp_downcast_rounding``. It is forwarded only for float -> float
            conversions that narrow the bit width (e.g. fp16 -> fp8e4nv).
            Triton rejects the argument for any other conversion, so it has no
            effect there (e.g. fp8e4nv -> fp16).
        BLOCK_SIZE: Number of elements converted per program instance.
    """
    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + idxs)
    # The dtypes are compile-time constants, so this condition is resolved while
    # the kernel is specialized and only one cast is emitted. It is kept inline
    # and combined with `&` rather than a chained `and`; presumably needed for
    # the Triton frontend to fold it statically on older versions. Verify
    # before rewriting it.
    if (x.dtype.is_floating() & dst.dtype.element_ty.is_floating()
            & (dst.dtype.element_ty.primitive_bitwidth < x.dtype.primitive_bitwidth)):
        y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding)
    else:
        y = x.to(dst.dtype.element_ty)
    tl.store(dst + idxs, y)
# Kernel specialization to compile. The signature lists one type per kernel
# argument (src, dst, rounding, BLOCK_SIZE); the two pointer types select the
# fp8e4nv -> fp16 conversion. Use "*fp16,*fp8e4nv,constexpr,constexpr" to
# inspect the fp16 -> fp8e4nv direction instead.
src = tc.ASTSource(
    fn=type_convert_triton,
    signature="*fp8e4nv,*fp16,constexpr,constexpr",
    constexprs={"BLOCK_SIZE": 1024, "rounding": 'rtne'},
)

# Compile for the HIP backend, arch gfx942, warp size 64. The per-stage IR
# text is then available through ret.asm.
ret = triton.compile(
    src,
    target=GPUTarget("hip", 'gfx942', 64),
)
# Dump the textual IR of each compiler stage in the order listed, skipping any
# stage absent from ret.asm. 'hsaco' is deliberately not in the list,
# presumably because it is the binary code object rather than text (confirm).
for ir in ('ttir', 'ttgir', 'llir', 'amdgcn'):
    if ir in ret.asm:
        # Header line, the IR itself, then two blank lines as a separator:
        # the same bytes as three separate print calls.
        print(f'AMD IR: {ir}', ret.asm[ir], '\n', sep='\n')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment