Created
January 2, 2025 07:52
-
-
Save ZJUGuoShuai/48d21b675d2801927b8f1ad89cb82109 to your computer and use it in GitHub Desktop.
GEMM_by_Conv2D
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
def matrix_mul_conv2d(A, B):
    """Compute the matrix product ``A @ B`` using a 1x1 2-D convolution.

    Each row of ``A`` is treated as one spatial position whose ``K`` entries
    are the input channels, and each column of ``B`` is treated as one 1x1
    filter over those channels, so the convolution output at position ``m``
    and channel ``n`` is ``sum_k A[m, k] * B[k, n]``.

    Args:
        A: 2-D tensor of shape [M, K].
        B: 2-D tensor of shape [K, N].

    Returns:
        2-D tensor of shape [M, N].

    Raises:
        ValueError: If either input is not 2-D, or the inner dimensions
            of ``A`` and ``B`` do not match.
    """
    if A.dim() != 2 or B.dim() != 2:
        raise ValueError(
            f"expected 2-D inputs, got A.dim()={A.dim()} and B.dim()={B.dim()}"
        )

    # Read the sizes from the arguments themselves instead of relying on
    # whatever M / K / N happen to exist in the caller's global scope.
    M, K = A.shape
    K_b, N = B.shape
    if K != K_b:
        raise ValueError(
            f"inner dimensions do not match: A is [{M}, {K}], B is [{K_b}, {N}]"
        )

    # 1. Lay out A [M, K] as a conv input [1, K, M, 1]:
    #    batch_size = 1, C_in = K, H = M, W = 1.
    x = A.t().reshape(1, K, M, 1)

    # 2. Lay out B [K, N] as conv weights [N, K, 1, 1]:
    #    C_out = N, C_in = K, kernel_size = (1, 1).
    weight = B.t().reshape(N, K, 1, 1)

    # 3. Run the convolution; the output has shape [1, N, M, 1].
    output = F.conv2d(x, weight)

    # 4. Drop the batch and width axes explicitly. A bare squeeze() would also
    #    remove the N or M axis whenever it has size 1 and break the [M, N]
    #    result shape.
    return output.reshape(N, M).t()
# Demo / self-check: compare the conv2d-based product with the regular one.
if __name__ == "__main__":
    # Example dimensions
    M = 3
    K = 4
    N = 2

    # Random operands
    A = torch.randn(M, K)  # [3, 4]
    B = torch.randn(K, N)  # [4, 2]

    # Reference result from the ordinary matrix product
    expected = A @ B  # [3, 2]

    # Result from the convolution-based implementation
    actual = matrix_mul_conv2d(A, B)

    # Report the largest element-wise absolute difference
    print("最大误差:", (expected - actual).abs().max())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment