Created
January 13, 2025 22:08
-
-
Save knwng/d3c37535a8afbdba4332bb9e76b9c275 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import triton | |
import triton.language as tl | |
import triton.compiler as tc | |
from triton.backends.compiler import GPUTarget | |
@triton.jit
def type_convert_triton(src, dst, rounding: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """Cast one BLOCK_SIZE-wide tile of ``src`` to ``dst``'s element type.

    Each program instance handles the ``BLOCK_SIZE`` consecutive elements
    starting at ``program_id(0) * BLOCK_SIZE``.

    Args:
        src: Pointer to the input elements.
        dst: Pointer to the output elements; its pointee type
            (``dst.dtype.element_ty``) is the target type of the cast.
        rounding: Compile-time rounding-mode value. Currently unused: the
            cast variant that forwards it as ``fp_downcast_rounding`` is
            disabled below.
        BLOCK_SIZE: Compile-time number of elements handled per program.

    Note:
        The load and store carry no mask, so every program instance reads
        and writes a full ``BLOCK_SIZE`` elements.
    """
    idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(src + idxs)
    # Disabled variant that passes the rounding mode through to the cast:
    # y = x.to(dst.dtype.element_ty, fp_downcast_rounding=rounding)
    y = x.to(dst.dtype.element_ty)
    tl.store(dst + idxs, y)
# Describe the kernel for ahead-of-time compilation: both constexpr parameters
# are pinned to concrete values, and the signature string lists one type per
# kernel parameter in declaration order.
src = tc.ASTSource(
    fn=type_convert_triton,
    constexprs={"BLOCK_SIZE": 1024, "rounding": 'rtne'},
    # Alternative signature with the source and destination types swapped:
    # signature="*fp16,*fp8e4nv,constexpr,constexpr",
    signature="*fp8e4nv,*fp16,constexpr,constexpr",
)

# Compile for the "hip" backend targeting gfx942; the last GPUTarget argument
# is the warp size.
ret = triton.compile(src, target=GPUTarget("hip", 'gfx942', 64))

# Print every requested compilation stage that the result actually contains.
# Variant that also dumps 'hsaco':
# for ir in ('ttir', 'ttgir', 'llir', 'amdgcn', 'hsaco'):
for ir in ('ttir', 'ttgir', 'llir', 'amdgcn'):
    if ir not in ret.asm:
        # Stage not produced by this compile; nothing to show.
        continue
    print(f'AMD IR: {ir}')
    print(ret.asm[ir])
    # Visual separator between stages (emits two newlines in total).
    print('\n')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment