Instantly share code, notes, and snippets.
Created October 13, 2017 20:47
Star (0) — you must be signed in to star a gist.
Fork (0) — you must be signed in to fork a gist.
Save shoaibkamil/a23a6c9804157467552cfac32c8aa087 to your computer and use it in GitHub Desktop.
Metal generated code
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <metal_stdlib> | |
using namespace metal; | |
namespace {
// Scalar float helpers and math-function aliases referenced by generated
// kernel code.
//
// NOTE: the #define aliases below are preprocessor macros. They are NOT
// scoped to this anonymous namespace and stay visible for the rest of the
// file.

// Reinterprets the 32-bit pattern `x` as a float (bit cast, no conversion).
constexpr float float_from_bits(unsigned int x) { return as_type<float>(x); }

// Special values, all built from explicit bit patterns via float_from_bits.
constexpr float nan_f32() { return float_from_bits(0x7fc00000); }      // quiet NaN
constexpr float neg_inf_f32() { return float_from_bits(0xff800000); }  // -infinity
constexpr float inf_f32() { return float_from_bits(0x7f800000); }      // +infinity

// Despite the name, this computes the exact reciprocal 1.0f / x.
float fast_inverse_f32(float x) { return 1.0f / x; }

// Name-mapping aliases onto the Metal standard library math functions.
#define sqrt_f32 sqrt
#define sin_f32 sin
#define cos_f32 cos
#define exp_f32 exp
#define log_f32 log
#define abs_f32 fabs
#define floor_f32 floor
#define ceil_f32 ceil
#define round_f32 round
#define trunc_f32 trunc
#define pow_f32 pow
#define asin_f32 asin
#define acos_f32 acos
#define tan_f32 tan
#define atan_f32 atan
#define atan2_f32 atan2
#define sinh_f32 sinh
#define asinh_f32 asinh
#define cosh_f32 cosh
#define acosh_f32 acosh
#define tanh_f32 tanh
#define atanh_f32 atanh
#define fast_inverse_sqrt_f32 rsqrt

// Expression-form barrier: calls threadgroup_barrier with the
// mem_threadgroup flag, then evaluates to 0 (comma operator), so it can
// appear where an expression is required.
#define halide_gpu_thread_barrier() \
    (threadgroup_barrier(mem_flags::mem_threadgroup), 0)
}
// Address-space qualifiers used in the buffer parameter declarations of
// kernel_output_s0_y_y___block_id_y. All four buffers live in device memory.
#define __address_space__dst device
#define __address_space__mask device
#define __address_space__output device
#define __address_space__src device

// Qualifier for "shared" storage, mapped to Metal threadgroup memory.
#define __address_space___shared threadgroup
// Scalar arguments read by kernel_output_s0_y_y___block_id_y through
// buffer(0).
//
// The field order and types ARE the binary layout of that buffer.
// Presumably the host writes it in exactly this order (TODO confirm against
// the host-side argument packing), so do not reorder, insert, or retype
// fields.
//
// Per buffer, the fields are:
//   *_min_0 / *_min_1   origin of the buffer's region in x / y
//   *_stride_1          row stride (elements per y step)
//   _output_extent_0/1  size of the region the kernel computes
struct kernel_output_s0_y_y___block_id_y_args {
int _dst_min_0;
int _dst_min_1;
int _dst_stride_1;
int _mask_min_0;
int _mask_min_1;
int _mask_stride_1;
int _output_extent_0;
int _output_extent_1;
int _output_min_0;
int _output_min_1;
int _output_stride_1;
int _src_min_0;
int _src_min_1;
int _src_stride_1;
};

// Guard the buffer(0) layout: exactly 14 tightly packed ints.
static_assert(sizeof(kernel_output_s0_y_y___block_id_y_args) == 14 * sizeof(int),
              "kernel_output_s0_y_y___block_id_y_args layout changed");
namespace {
// Per-channel rule applied by kernel_output_s0_y_y___block_id_y.
//
// Let e = src - 128, computed in uchar (so it wraps when src <= 128), and
// t = e + (e >> 6). Returns `mask` when src > 128 and dst < t; otherwise
// returns `dst`.
//
// When src > 128, e is in [1, 127] and t is in [1, 128], so narrowing t
// back to uchar is lossless. When src <= 128, t is computed but unused.
uchar output_blend_channel(uchar src, uchar dst, uchar mask)
{
    const bool above_128 = uchar(128) < src;
    const short excess = short(uchar(src - uchar(128)));
    const uchar threshold = uchar(short(excess + (excess >> 6)));
    return (above_128 && dst < threshold) ? mask : dst;
}
} // namespace

// One thread per output pixel, with 8x8 threadgroup tiles.
//
// For pixel (x, y), channels c = 0, 1, 2 are written. Each channel sits at
// element offset x * 4 + c + y * stride_1, relative to that buffer's min
// corner (presumably 4 interleaved channels per pixel). Each channel
// becomes output_blend_channel(src, dst, mask(x, y)). Channel 3 is never
// written.
//
// Edge tiles are shifted back inside the region rather than clipped:
//   - Columns: blocks with block_x >= max(extent_0 >> 3, 0) cover the last
//     8 columns.
//   - Rows: the starting row is clamped to min_1 + extent_1 - 8.
// This relies on extent_0 >= 8 and extent_1 >= 8. Smaller extents would
// address columns/rows before output_min. TODO confirm the host enforces it.
//
// `__shared` is part of the binding signature but is not used here.
kernel void kernel_output_s0_y_y___block_id_y(
 uint3 tgroup_index [[ threadgroup_position_in_grid ]],
 uint3 tid_in_tgroup [[ thread_position_in_threadgroup ]],
 const device kernel_output_s0_y_y___block_id_y_args *_scalar_args [[ buffer(0) ]],
 __address_space__dst const uchar *_dst [[ buffer(1) ]],
 __address_space__mask const uchar *_mask [[ buffer(2) ]],
 __address_space__output uchar *_output [[ buffer(3) ]],
 __address_space__src const uchar *_src [[ buffer(4) ]],
 threadgroup int16_t* __shared [[ threadgroup(0) ]])
{
    const int dst_min_0 = _scalar_args->_dst_min_0;
    const int dst_min_1 = _scalar_args->_dst_min_1;
    const int dst_stride_1 = _scalar_args->_dst_stride_1;
    const int mask_min_0 = _scalar_args->_mask_min_0;
    const int mask_min_1 = _scalar_args->_mask_min_1;
    const int mask_stride_1 = _scalar_args->_mask_stride_1;
    const int output_extent_0 = _scalar_args->_output_extent_0;
    const int output_extent_1 = _scalar_args->_output_extent_1;
    const int output_min_0 = _scalar_args->_output_min_0;
    const int output_min_1 = _scalar_args->_output_min_1;
    const int output_stride_1 = _scalar_args->_output_stride_1;
    const int src_min_0 = _scalar_args->_src_min_0;
    const int src_min_1 = _scalar_args->_src_min_1;
    const int src_stride_1 = _scalar_args->_src_stride_1;

    const int block_x = int(tgroup_index.x);
    const int block_y = int(tgroup_index.y);
    const int thread_x = int(tid_in_tgroup.x);
    const int thread_y = int(tid_in_tgroup.y);

    // Row: 8-row tiles; the last tile is clamped to end on the bottom edge.
    const int y = min(block_y * 8 + output_min_1,
                      output_min_1 + output_extent_1 - 8) + thread_y;

    // Column: full 8-column tiles first; any other block covers the final
    // 8 columns of the region.
    const int full_tiles_x = max(output_extent_0 >> 3, 0);
    const int tile_x0 = (block_x < full_tiles_x)
                            ? block_x * 8 + output_min_0
                            : output_min_0 + output_extent_0 - 8;
    const int x = tile_x0 + thread_x;

    // Offsets of pixel (x, y) relative to each buffer's min corner.
    const int src_base = (x * 4 + y * src_stride_1)
                       - (src_min_0 * 4 + src_min_1 * src_stride_1);
    const int dst_base = (x * 4 + y * dst_stride_1)
                       - (dst_min_0 * 4 + dst_min_1 * dst_stride_1);
    const int output_base = (x * 4 + y * output_stride_1)
                          - (output_min_0 * 4 + output_min_1 * output_stride_1);
    const int mask_index = (x + y * mask_stride_1)
                         - (mask_min_0 + mask_min_1 * mask_stride_1);

    for (int c = 0; c < 3; ++c) {
        // Reads and the write stay interleaved per channel, and the mask is
        // re-read each time. Results therefore do not depend on whether
        // _output aliases one of the input buffers.
        const uchar s = _src[src_base + c];
        const uchar d = _dst[dst_base + c];
        const uchar m = _mask[mask_index];
        _output[output_base + c] = output_blend_channel(s, d, m);
    }
}
// Release the per-buffer address-space macros set up for
// kernel_output_s0_y_y___block_id_y.
// (__address_space___shared is not undefined here.)
#undef __address_space__output
#undef __address_space__src
#undef __address_space__dst
#undef __address_space__mask
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.