Skip to content

Instantly share code, notes, and snippets.

@danieldk
Created May 8, 2025 07:44

Revisions

  1. danieldk created this gist May 8, 2025.
    110 changes: 110 additions & 0 deletions CMakeLists.txt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,110 @@
    # Build script for the "activation" Torch extension.
    cmake_minimum_required(VERSION 3.26)
    project(activation LANGUAGES CXX)

    # Backend selector. Later in this file, any value other than "cuda" or
    # "rocm" stops configuration before a GPU extension target is defined.
    set(TARGET_DEVICE
      "cuda"
      CACHE STRING
      "Target device backend for kernel")

    # Inject CMAKE_INSTALL_LOCAL_ONLY into the generated install script for
    # every install component.
    install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)

    # Make sure the FetchContent download directory exists and report it.
    include(FetchContent)
    file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR})
    message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")

    # CUDA compute capabilities this project is willing to build for.
    set(CUDA_SUPPORTED_ARCHS
      7.0 7.2 7.5
      8.0 8.6 8.7 8.9
      9.0
      10.0 10.1
      12.0)

    # ROCm/HIP GPU targets this project is willing to build for.
    set(HIP_SUPPORTED_ARCHS
      gfx906 gfx908 gfx90a
      gfx940 gfx941 gfx942
      gfx1030
      gfx1100 gfx1101)

    # Project-local helper commands used throughout the rest of this file.
    include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

    # Locate Python: the interpreter plus the Development and
    # Development.SABIModule components.
    if(DEFINED Python_EXECUTABLE)
      # Allow passing through the interpreter (e.g. from setup.py). REQUIRED is
      # intentionally omitted here so that a failure can be reported with the
      # interpreter path the caller asked for.
      find_package(Python COMPONENTS Development Development.SABIModule Interpreter)
      if(NOT Python_FOUND)
        message(FATAL_ERROR "Unable to find python matching: ${Python_EXECUTABLE}.")
      endif()
    else()
      # No interpreter was specified; let CMake pick one and fail if none fits.
      find_package(Python REQUIRED COMPONENTS Development Development.SABIModule Interpreter)
    endif()

    # Extend the CMake prefix path so the Torch package can be found, then
    # require it. append_cmake_prefix_path is not defined in this file
    # (presumably it comes from cmake/utils.cmake -- confirm).
    append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
    find_package(Torch REQUIRED)

    # Only the "cuda" and "rocm" backends are configured below; for any other
    # TARGET_DEVICE stop processing this file here.
    if(NOT (TARGET_DEVICE STREQUAL "cuda" OR TARGET_DEVICE STREQUAL "rocm"))
      return()
    endif()

    # Select the GPU language. HIP wins whenever it was found; CUDA is used only
    # when HIP is absent; having neither is a hard error.
    if(HIP_FOUND)
      set(GPU_LANG "HIP")

      # Importing torch recognizes and sets up some HIP/ROCm configuration but
      # does not let cmake recognize .hip files. In order to get cmake to
      # understand the .hip extension automatically, HIP must be enabled
      # explicitly.
      enable_language(HIP)
    elseif(CUDA_FOUND)
      set(GPU_LANG "CUDA")
    else()
      message(FATAL_ERROR "Can't find CUDA or HIP installation.")
    endif()


    # Work out which GPU architectures to build for. The helper commands used
    # here are not defined in this file (presumably cmake/utils.cmake -- confirm
    # their exact semantics there).
    if(GPU_LANG STREQUAL "CUDA")
      clear_cuda_arches(CUDA_ARCH_FLAGS)
      extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
      message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
      # Filter the target architectures by the supported archs
      # since for some files we will build for all CUDA_ARCHS.
      cuda_archs_loose_intersection(CUDA_ARCHS "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}")
      message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}")

      # Forward a user-provided NVCC_THREADS value to the compiler. GPU_LANG is
      # already known to be CUDA in this branch, so it is not re-checked.
      if(NVCC_THREADS)
        list(APPEND GPU_FLAGS "--threads=${NVCC_THREADS}")
      endif()
    else()
      if(GPU_LANG STREQUAL "HIP")
        set(ROCM_ARCHS "${HIP_SUPPORTED_ARCHS}")
      endif()
      # TODO: remove this once we can set specific archs per source file set.
      # Every non-CUDA language gets its arch list from <LANG>_SUPPORTED_ARCHS.
      override_gpu_arches(GPU_ARCHES
        ${GPU_LANG}
        "${${GPU_LANG}_SUPPORTED_ARCHS}")
    endif()

    # Add the flags reported by get_torch_gpu_compiler_flags for the selected
    # GPU language to the extension's compile flags.
    get_torch_gpu_compiler_flags(TORCH_GPU_FLAGS ${GPU_LANG})
    list(APPEND GPU_FLAGS ${TORCH_GPU_FLAGS})

    # Torch binding sources; always part of the extension.
    set(TORCH_activation_SRC
      "torch-ext/torch_binding.cpp"
      "torch-ext/torch_binding.h")
    list(APPEND SRC "${TORCH_activation_SRC}")

    # Sources of the "activation" kernel itself. This only declares the list;
    # it is appended to SRC further down.
    set(activation_SRC
      activation/activation_kernels.cu
      activation/cuda_compat.h
      activation/dispatch_utils.h)


    # For CUDA builds: narrow the target architectures to those supported,
    # attach per-source gencode flags, and add the kernel sources to SRC.
    # Note that the kernel sources are appended only when GPU_LANG is CUDA.
    if(GPU_LANG STREQUAL "CUDA")
      cuda_archs_loose_intersection(
        activation_ARCHS
        "${CUDA_SUPPORTED_ARCHS}"
        "${CUDA_ARCHS}")
      message(STATUS "Capabilities for kernel activation: ${activation_ARCHS}")
      set_gencode_flags_for_srcs(
        SRCS "${activation_SRC}"
        CUDA_ARCHS "${activation_ARCHS}")
      list(APPEND SRC "${activation_SRC}")
    endif()

    # Define the extension module target from the collected sources, flags and
    # architectures. define_gpu_extension_target is not defined in this file;
    # USE_SABI / WITH_SOABI presumably select stable-ABI and SOABI-suffixed
    # module naming -- confirm against the helper's definition.
    define_gpu_extension_target(
      _activation_psnp6q5y4k4wg
      DESTINATION _activation_psnp6q5y4k4wg
      LANGUAGE ${GPU_LANG}
      SOURCES ${SRC}
      COMPILE_FLAGS ${GPU_FLAGS}
      ARCHITECTURES ${GPU_ARCHES}
      USE_SABI 3
      WITH_SOABI)

    # Pass -static-libstdc++ when linking the extension target.
    target_link_options(_activation_psnp6q5y4k4wg PRIVATE -static-libstdc++)