Skip to content

Instantly share code, notes, and snippets.

@ZJUGuoShuai
Created January 2, 2025 07:52
Show Gist options
  • Save ZJUGuoShuai/48d21b675d2801927b8f1ad89cb82109 to your computer and use it in GitHub Desktop.
Save ZJUGuoShuai/48d21b675d2801927b8f1ad89cb82109 to your computer and use it in GitHub Desktop.
GEMM_by_Conv2D
import torch
import torch.nn.functional as F
def matrix_mul_conv2d(A, B):
    """Compute the matrix product ``A @ B`` with a single 1x1 ``F.conv2d``.

    A 1x1 convolution contracts over input channels. Treat each of the M
    rows of ``A`` as one spatial position carrying K input channels, and each
    of the N columns of ``B`` as one output filter with K weights. The
    convolution then computes exactly ``sum_k A[m, k] * B[k, n]``.

    Args:
        A: 2-D tensor of shape [M, K].
        B: 2-D tensor of shape [K, N]. ``F.conv2d`` expects it to share
            dtype/device with ``A``.

    Returns:
        Tensor of shape [M, N], equal to ``A @ B`` up to floating-point
        rounding. Note: this is a transposed (non-contiguous) view.

    Raises:
        ValueError: if ``A`` or ``B`` is not 2-D, or if their inner
            dimensions differ.
    """
    if A.dim() != 2 or B.dim() != 2:
        raise ValueError(
            f"expected 2-D matrices, got A.dim()={A.dim()} and B.dim()={B.dim()}"
        )
    # Derive the dimensions from the inputs themselves rather than relying on
    # module-level globals.
    M, K = A.shape
    K_b, N = B.shape
    if K != K_b:
        raise ValueError(
            f"inner dimensions differ: A is [{M}, {K}] but B is [{K_b}, {N}]"
        )

    # 1. Reshape A [M, K] into an NCHW input of shape [1, K, M, 1]:
    #    batch = 1, C_in = K (columns of A), H = M (rows of A), W = 1.
    x = A.t().reshape(1, K, M, 1)

    # 2. Reshape B [K, N] into conv weights of shape [N, K, 1, 1]:
    #    C_out = N (columns of B), C_in = K (rows of B), kernel = 1x1.
    weight = B.t().reshape(N, K, 1, 1)

    # 3. Run the convolution; the output has shape [1, N, M, 1].
    output = F.conv2d(x, weight)

    # 4. Drop the batch and width axes explicitly. squeeze() would also drop
    #    the M or N axis whenever M == 1 or N == 1, yielding a 1-D result.
    return output[0, :, :, 0].t()
# Self-test: compare the convolution-based product against torch.matmul.
if __name__ == "__main__":
    # Fixed seed so the reported error is reproducible from run to run.
    torch.manual_seed(0)

    # Example dimensions.
    M, K, N = 3, 4, 2

    # Random input matrices.
    A = torch.randn(M, K)  # [3, 4]
    B = torch.randn(K, N)  # [4, 2]

    # Reference result from ordinary matrix multiplication.
    C1 = torch.matmul(A, B)  # [3, 2]
    # Result from the convolution-based implementation.
    C2 = matrix_mul_conv2d(A, B)

    # Verify: report the maximum absolute error ("最大误差" = "max error"),
    # then fail loudly on a shape or value mismatch instead of only printing.
    print("最大误差:", torch.max(torch.abs(C1 - C2)))
    torch.testing.assert_close(C2, C1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment