This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.distributed.tensor.placement_types import Replicate, Shard | |
import torch.distributed as dist | |
from torch.distributed.device_mesh import init_device_mesh | |
from torch.distributed.tensor import DTensor | |
from torch.distributed.tensor.parallel import parallelize_module | |
def dist_print(*args, **kwargs): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.cpp_extension import _get_cuda_arch_flags | |
def test_fix(): | |
print("Testing CUDA arch flags fix...") | |
user_arch_flags = ['-gencode=arch=compute_86,code=sm_86'] | |
result = _get_cuda_arch_flags(user_arch_flags) | |
print(f"User provided: {user_arch_flags}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Stop all GPU monitoring services that block ncu | |
sudo systemctl stop nvidia-dcgm.service dynologd.service | |
# Verify they're stopped | |
sudo systemctl list-units --state=active | grep -E "(nvidia|dynolog)" | |
# Check GPU is clear | |
sudo lsof /dev/nvidia7 | grep -v python | |
# Now run ncu |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Limitations | |
1. Cannot do heavy templating, cannot use thrust for reductions | |
2. Cannot import any host includes | |
Thank you @malfet! | |
""" | |
import ctypes | |
import torch |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(pt) ➜ examples git:(msaroufim/noheader) ✗ python tensor_base_example.py | |
Clearing existing build directory: /home/marksaroufim/pytorch/examples/custom_extension_build | |
Created build directory: /home/marksaroufim/pytorch/examples/custom_extension_build | |
Compiling TensorBase CUDA extension with no_header=True... | |
Using /home/marksaroufim/.cache/torch_extensions/py310_cu124 as PyTorch extensions root... | |
Detected CUDA files, patching ldflags | |
Emitting ninja build file /home/marksaroufim/.cache/torch_extensions/py310_cu124/tensor_base_example/build.ninja... | |
Building extension module tensor_base_example... | |
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N) | |
[1/2] /home/marksaroufim/.conda/envs/pt/bin/nvcc --generate-dependencies-with-compile --dependency-output cuda.cuda.o.d -ccbin /home/marksaroufim/.conda/envs/pt/bin/x86_64-conda-linux-gnu-cc -DTORCH_EXTENSION_NAME=tensor_base_example -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYB |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Minimal example that: | |
- Only includes <ATen/core/Tensor.h> (for at::Tensor) | |
and <ATen/Functions.h> (for at::empty). | |
- Avoids <torch/extension.h> or <torch/types.h>. | |
- Uses <torch/csrc/utils/pybind.h> so PyBind can cast torch.Tensor <-> at::Tensor. | |
- Demonstrates a custom CUDA kernel that adds x + y + 1. | |
- Uses no_implicit_headers=True to reduce compile overhead. | |
""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Example showing how to use the no_implicit_headers mode with a TensorBase CUDA extension | |
This example creates a CUDA extension that directly includes ATen/core/TensorBase.h | |
instead of torch/extension.h or types.h, resulting in faster compilation | |
""" | |
from datetime import datetime | |
import torch | |
import torch.utils.cpp_extension | |
import shutil |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.utils.cpp_extension import load_inline | |
cpp_code = """ | |
torch::Tensor to_gray(torch::Tensor input); | |
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |
m.def("to_gray", &to_gray, "Convert RGB to Grayscale (CUDA)"); | |
} | |
""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Version 1: Using if-condition | |
__global__ void kernel_if(float* out, const float* in, int n) { | |
int idx = blockIdx.x * blockDim.x + threadIdx.x; | |
if (idx < n) { | |
if (in[idx] > 0.0f) { | |
out[idx] = in[idx] * 2.0f; | |
} else { | |
out[idx] = in[idx] / 2.0f; | |
} | |
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from pathlib import Path | |
from typing import Set, Dict | |
from collections import defaultdict | |
def analyze_imports(file_path: str) -> Dict[str, Set[str]]: | |
""" | |
Analyze Python file imports and return a dictionary of package dependencies. | |
Args: |
NewerOlder