Created
June 8, 2020 14:13
-
-
Save cbalint13/8829a4c5f1eb36fe8508e557eb2b4fc8 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
--- a/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp.orig 2020-06-08 17:01:47.788349629 +0300
+++ b/modules/dnn/src/cuda4dnn/csl/cudnn/convolution.hpp 2020-06-08 16:50:51.579297388 +0300
@@ -260,10 +260,10 @@
const TensorDescriptor<T>& output)
{
CUDA4DNN_CHECK_CUDNN(
- cudnnGetConvolutionForwardAlgorithm(
+ cudnnGetConvolutionForwardAlgorithm_v7(
handle.get(),
input.get(), filter.get(), conv.get(), output.get(),
- CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+ 0,
0, /* no memory limit */
&algo
)
@@ -273,7 +273,7 @@
cudnnGetConvolutionForwardWorkspaceSize(
handle.get(),
input.get(), filter.get(), conv.get(), output.get(),
- algo, &workspace_size
+ algo.algo, &workspace_size
)
);
}
@@ -281,13 +281,13 @@
ConvolutionAlgorithm& operator=(const ConvolutionAlgorithm&) = default;
ConvolutionAlgorithm& operator=(ConvolutionAlgorithm&& other) = default;
- cudnnConvolutionFwdAlgo_t get() const noexcept { return algo; }
+ cudnnConvolutionFwdAlgo_t get() const noexcept { return algo.algo; }
/** number of bytes of workspace memory required by the algorithm */
std::size_t get_workspace_size() const noexcept { return workspace_size; }
private:
- cudnnConvolutionFwdAlgo_t algo;
+ cudnnConvolutionFwdAlgoPerf_t algo;
std::size_t workspace_size;
};
--- a/modules/dnn/src/cuda4dnn/csl/cudnn/transpose_convolution.hpp.orig 2020-06-08 17:01:35.797383567 +0300
+++ b/modules/dnn/src/cuda4dnn/csl/cudnn/transpose_convolution.hpp 2020-06-08 16:52:44.001957698 +0300
@@ -36,10 +36,10 @@
const TensorDescriptor<T>& output)
{
CUDA4DNN_CHECK_CUDNN(
- cudnnGetConvolutionBackwardDataAlgorithm(
+ cudnnGetConvolutionBackwardDataAlgorithm_v7(
handle.get(),
filter.get(), input.get(), conv.get(), output.get(),
- CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+ 0,
0, /* no memory limit */
&dalgo
)
@@ -49,7 +49,7 @@
cudnnGetConvolutionBackwardDataWorkspaceSize(
handle.get(),
filter.get(), input.get(), conv.get(), output.get(),
- dalgo, &workspace_size
+ dalgo.algo, &workspace_size
)
);
}
@@ -57,12 +57,12 @@
TransposeConvolutionAlgorithm& operator=(const TransposeConvolutionAlgorithm&) = default;
TransposeConvolutionAlgorithm& operator=(TransposeConvolutionAlgorithm&& other) = default;
- cudnnConvolutionBwdDataAlgo_t get() const noexcept { return dalgo; }
+ cudnnConvolutionBwdDataAlgo_t get() const noexcept { return dalgo.algo; }
std::size_t get_workspace_size() const noexcept { return workspace_size; }
private:
- cudnnConvolutionBwdDataAlgo_t dalgo;
+ cudnnConvolutionBwdDataAlgoPerf_t dalgo;
std::size_t workspace_size;
};
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment