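// Conv2d fprop + bias + ReLU for few-channel NHWC inputs, implemented as a
// CUTLASS implicit-GEMM kernel (FP16 tensor ops, SM70), plus a benchmark
// entry point that times five back-to-back launches.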
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/default_conv2d_group_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"

// Workspace size of the most recent run, reported by the benchmark entry point below.
static size_t GLOBAL_WORKSPACE_SIZE_DeviceConvFwdInstance_0 = 0;
#define CUTLASS_CHECK(status)                                              \
  {                                                                        \
    cutlass::Status error = status;                                        \
    if (error != cutlass::Status::kSuccess) {                              \
      static char msg[2048];                                               \
      snprintf(msg, sizeof(msg), "[%s] Got cutlass error: %s at line %d",  \
               __FILE__, cutlassGetStatusString(error), __LINE__);         \
      fprintf(stderr, "%s\n", msg);                                        \
      throw std::runtime_error(msg);                                       \
    }                                                                      \
  }
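// CUTLASS_CHECK converts a failing cutlass::Status into an stderr log line
// and a C++ exception carrying the same message.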
// Conv2dFprop Fixed_channels kernel instance "cutlass_tensorop_h884fprop_fixed_channels_256x128_32x3_nhwc_align4"
using cutlass_tensorop_h884fprop_fixed_channels_256x128_32x3_nhwc_align4_base =
  typename cutlass::conv::kernel::DefaultConv2dFprop<
    cutlass::half_t,                        // ElementA (activation)
    cutlass::layout::TensorNHWC,            // LayoutA
    cutlass::half_t,                        // ElementB (filter)
    cutlass::layout::TensorNHWC,            // LayoutB
    cutlass::half_t,                        // ElementC (output)
    cutlass::layout::TensorNHWC,            // LayoutC
    cutlass::half_t,                        // ElementAccumulator
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm70,
    cutlass::gemm::GemmShape<256, 128, 32>, // threadblock tile
    cutlass::gemm::GemmShape<64, 64, 32>,   // warp tile
    cutlass::gemm::GemmShape<8, 8, 4>,      // instruction shape (Volta HMMA 8x8x4)
    cutlass::epilogue::thread::LinearCombinationRelu<
      cutlass::half_t,
      8,                                    // elements per vectorized access
      cutlass::half_t,
      cutlass::half_t
    >,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, // alternative: cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>
    3,                                      // pipeline stages
    cutlass::arch::OpMultiplyAdd,
    cutlass::conv::IteratorAlgorithm::kFixedChannels,
    cutlass::conv::StrideSupport::kStrided,
    4,                                      // AlignmentA
    4                                       // AlignmentB
  >::Kernel;
using DeviceConvFwdInstance_0 =
    cutlass::conv::device::ImplicitGemmConvolution<cutlass_tensorop_h884fprop_fixed_channels_256x128_32x3_nhwc_align4_base>;
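// Name decoding (CUTLASS procedural naming): h884 = FP16 8x8x4 tensor-op MMA,
// 256x128 threadblock tile, 32x3 = K-tile 32 with a 3-stage pipeline,
// nhwc layout, alignment 4 (kFixedChannels targets inputs with C == 4).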
void conv2d_bias_relu_few_channels_cutlass_h884fprop_fixed_channels_256x128_32x3_nhwc_align_4_8 (
    void* in_ptr,
    void* weight_ptr,
    void* out_ptr,
    void* bias_ptr,
    uint8_t* workspace,
    int64_t* batch,
    int64_t* out_ch,
    int64_t* in_ch,
    int64_t* kernel_h,
    int64_t* kernel_w,
    int64_t* in_h,
    int64_t* in_w,
    int64_t* out_batch,
    int64_t* out_h,
    int64_t* out_w,
    int strideh,
    int dilationh,
    int padh,
    int stridew,
    int dilationw,
    int padw,
    cudaStream_t stream
) {
  int i32_batch = *batch;
  int i32_in_h = *in_h;
  int i32_in_w = *in_w;
  int i32_in_ch = *in_ch;
  int i32_out_ch = *out_ch;
  int i32_kernel_h = *kernel_h;
  int i32_kernel_w = *kernel_w;
  int i32_out_batch = *out_batch;
  int i32_out_h = *out_h;
  int i32_out_w = *out_w;
  using cutlass::layout::TensorNHWC;
  TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(i32_batch, i32_in_h, i32_in_w, i32_in_ch)));
  TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch)));
  TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(i32_out_batch, i32_out_h, i32_out_w, i32_out_ch)));
  cutlass::conv::Conv2dProblemSize problem_size(
      {i32_batch, i32_in_h, i32_in_w, i32_in_ch},          // cutlass::Tensor4DCoord input_size
      {i32_out_ch, i32_kernel_h, i32_kernel_w, i32_in_ch}, // cutlass::Tensor4DCoord filter_size
      {padh, padh, padw, padw},                            // cutlass::Tensor4DCoord padding
      {strideh, stridew},                                  // cutlass::MatrixCoord stride
      {dilationh, dilationw},                              // cutlass::MatrixCoord dilation
      {i32_out_batch, i32_out_h, i32_out_w, i32_out_ch},   // cutlass::Tensor4DCoord output_size
      cutlass::conv::Mode::kCrossCorrelation,              // cutlass::conv::Mode mode
      1                                                    // int split_k_slices
  );
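  // Implicit GEMM view of fprop (the standard CUTLASS mapping, noted for
  // reference): GEMM_M = N * HO * WO, GEMM_N = CO, GEMM_K = KH * KW * CI.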
  using ElementComputeEpilogue = typename DeviceConvFwdInstance_0::ElementCompute;
  // TODO: cast to right dtype
  typename DeviceConvFwdInstance_0::Arguments arguments{
      problem_size,                                           // ConvProblemSize const & problem_size
      {static_cast<cutlass::half_t*>(in_ptr), layout_A},      // TensorRefA const & ref_A
      {static_cast<cutlass::half_t*>(weight_ptr), layout_B},  // TensorRefB const & ref_B
      {static_cast<cutlass::half_t*>(bias_ptr), cutlass::layout::TensorNHWC::Stride(0)}, // TensorRefC const & ref_C
      {static_cast<cutlass::half_t*>(out_ptr), layout_C},     // TensorRefD const & ref_D
      {ElementComputeEpilogue(1), ElementComputeEpilogue(1)}, // typename EpilogueOutputOp::Params const & output_op
  };
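  // ref_C points at the 1-D bias with all strides zero, so every output
  // position (n, p, q) reads bias[k]; with alpha = beta = 1 the epilogue
  // computes D = ReLU(conv(A, B) + bias).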
  DeviceConvFwdInstance_0 conv_op;
  size_t workspace_size = conv_op.get_workspace_size(arguments);
  // The caller-provided workspace is ignored; a fresh device allocation is
  // made per call and lives until the end of this function.
  cutlass::device_memory::allocation<uint8_t> local_workspace(workspace_size);
  workspace = local_workspace.get();
  GLOBAL_WORKSPACE_SIZE_DeviceConvFwdInstance_0 = workspace_size;
  auto status = conv_op.can_implement(arguments);
  CUTLASS_CHECK(status);
  status = conv_op.initialize(arguments, workspace);
  CUTLASS_CHECK(status);
  status = conv_op(stream);
  CUTLASS_CHECK(status);
}
int benchmark_conv2d_bias_relu_few_channels_cutlass_h884fprop_fixed_channels_256x128_32x3_nhwc_align_4_8 (
    float* runtime,
    size_t* workspace_size,
    int64_t NI,
    int64_t HI,
    int64_t WI,
    int64_t CI,
    int64_t CO,
    int64_t KH,
    int64_t KW,
    int64_t NO,
    int64_t HO,
    int64_t WO,
    int strideh,
    int dilationh,
    int padh,
    int stridew,
    int dilationw,
    int padw,
    uint8_t* global_workspace_,
    cudaStream_t stream
) {
  using ElementInputA = typename DeviceConvFwdInstance_0::ElementA;
  using ElementInputB = typename DeviceConvFwdInstance_0::ElementB;
  using ElementOutput = typename DeviceConvFwdInstance_0::ElementC;
  cutlass::HostTensor<ElementInputA, typename DeviceConvFwdInstance_0::LayoutA> x({(int)NI, (int)HI, (int)WI, (int)CI});
  cutlass::HostTensor<ElementInputB, typename DeviceConvFwdInstance_0::LayoutB> w({(int)CO, (int)KH, (int)KW, (int)CI});
  cutlass::HostTensor<ElementInputB, typename DeviceConvFwdInstance_0::LayoutB> b({(int)CO, 1, 1, 1});
  cutlass::HostTensor<ElementOutput, typename DeviceConvFwdInstance_0::LayoutC> y({(int)NO, (int)HO, (int)WO, (int)CO});
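  // Tensor contents are left uninitialized: only kernel timing matters here.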
  // warmup
  conv2d_bias_relu_few_channels_cutlass_h884fprop_fixed_channels_256x128_32x3_nhwc_align_4_8(
      x.device_data(),
      w.device_data(),
      y.device_data(),
      b.device_data(),
      global_workspace_,
      &NI,
      &CO,
      &CI,
      &KH,
      &KW,
      &HI,
      &WI,
      &NO,
      &HO,
      &WO,
      strideh,
      dilationh,
      padh,
      stridew,
      dilationw,
      padw,
      stream
  );
  cudaEvent_t events[2];
  for (auto & event : events) {
    cudaEventCreate(&event);
  }
  cudaEventRecord(events[0], stream);
  for (int i = 0; i < 5; ++i) {
    conv2d_bias_relu_few_channels_cutlass_h884fprop_fixed_channels_256x128_32x3_nhwc_align_4_8(
        x.device_data(),
        w.device_data(),
        y.device_data(),
        b.device_data(),
        global_workspace_,
        &NI,
        &CO,
        &CI,
        &KH,
        &KW,
        &HI,
        &WI,
        &NO,
        &HO,
        &WO,
        strideh,
        dilationh,
        padh,
        stridew,
        dilationw,
        padw,
        stream
    );
  }
  cudaEventRecord(events[1], stream);
  cudaEventSynchronize(events[1]);
  float runtime_ms = 0;
  cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
  for (auto event : events) {
    (void)cudaEventDestroy(event);
  }
  // A zero elapsed time means the kernel never actually ran.
  if (runtime_ms < 0.00001) {
    throw std::runtime_error(
      "Conv kernel did not execute (measured runtime is zero)."
    );
  }
  *runtime = runtime_ms; // total across the 5 timed iterations
  *workspace_size = GLOBAL_WORKSPACE_SIZE_DeviceConvFwdInstance_0;
  return 0;
}
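// ---------------------------------------------------------------------------
// Usage sketch (not part of the original gist): a minimal driver, assuming
// this file is compiled with nvcc against CUTLASS on an SM70+ GPU. The shapes
// are hypothetical; with a 3x3 filter, pad 1, stride 1, and dilation 1, the
// output H/W equal the input H/W, and CI = 4 matches the kernel's
// fixed-channels alignment of 4.
int main() {
  float runtime_ms = 0.f;
  size_t workspace_size = 0;
  cudaStream_t stream = nullptr;  // default stream
  int rc = benchmark_conv2d_bias_relu_few_channels_cutlass_h884fprop_fixed_channels_256x128_32x3_nhwc_align_4_8(
      &runtime_ms, &workspace_size,
      /*NI=*/1, /*HI=*/224, /*WI=*/224, /*CI=*/4,
      /*CO=*/64, /*KH=*/3, /*KW=*/3,
      /*NO=*/1, /*HO=*/224, /*WO=*/224,
      /*strideh=*/1, /*dilationh=*/1, /*padh=*/1,
      /*stridew=*/1, /*dilationw=*/1, /*padw=*/1,
      /*global_workspace_=*/nullptr,  // unused: the kernel allocates its own
      stream);
  printf("5 iterations: %.3f ms total, workspace: %zu bytes\n",
         runtime_ms, workspace_size);
  return rc;
}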