import numpy as np
import torch

# Two attention heads over a 4-dimensional embedding; batch_first=True
# means inputs are shaped (batch, seq, embed_dim).
mha = torch.nn.MultiheadAttention(num_heads=2, embed_dim=4, batch_first=True)

# Fill every parameter with a constant so the run is deterministic and
# reproducible. in_proj_weight stacks the query/key/value projections,
# hence shape (3 * embed_dim, embed_dim) = (12, 4).
mha.in_proj_weight.data = torch.zeros(12, 4) + 0.1
mha.in_proj_bias.data = torch.zeros(12) + 0.11
mha.out_proj.weight.data = torch.zeros(4, 4) + 0.1
mha.out_proj.bias.data = torch.zeros(4) + 0.11

optim = torch.optim.SGD(mha.parameters(), lr=0.01)

# One batch of three tokens with four features each: shape (1, 3, 4).
# order='F' fills the array column-major, and the ravel(order='F')
# calls below print column-major too, so every value can be compared
# element-for-element against a column-major implementation.
x = torch.tensor(
    np.array(
        [0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12],
    ).reshape(3, 4, 1, order='F').transpose(2, 0, 1),
    dtype=torch.float32,
    requires_grad=True,
)

# Self-attention: the same tensor serves as query, key, and value.
out, weights = mha(x, x, x)
print('Output:', out.detach().numpy().ravel(order='F'))
print('Attention Weights:', weights.detach().numpy().ravel(order='F'))

# Upstream gradient dL/d(out) for backward(); it does not need
# requires_grad itself. backward() then fills x.grad and the
# parameter gradients.
gradient = torch.tensor(
    [.1, .1, .1, 3., 3., 3., 2., .1, 2., 3., .1, 3.],
).reshape(1, 3, 4)
out.backward(gradient=gradient)
print('Gradient:', x.grad.numpy().ravel(order='F'))

# One SGD step updates the attention parameters in place.
optim.step()
out, weights = mha(x, x, x)
print('Output after one step (SGD):', out.detach().numpy().ravel(order='F'))
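
# --- Hedged sketch: reproducing the forward pass by hand ---
# Not required by the script above; this re-derives `out` and `weights`
# from the module's current (post-step) parameters to show what
# nn.MultiheadAttention computes internally. in_proj_weight stacks
# [W_q; W_k; W_v] (each 4x4); each head works on a 2-wide slice, and the
# per-head results are concatenated before out_proj. The variable names
# (qh, kh, vh, ctx, manual) are illustrative, not PyTorch API.
E, H = 4, 2        # embed_dim, num_heads
d = E // H         # per-head dimension
W, b = mha.in_proj_weight.data, mha.in_proj_bias.data

with torch.no_grad():
    q = x[0] @ W[:E].T + b[:E]            # (seq, E) query projection
    k = x[0] @ W[E:2 * E].T + b[E:2 * E]  # key projection
    v = x[0] @ W[2 * E:].T + b[2 * E:]    # value projection

    # Split each projection into H heads of width d: (H, seq, d).
    qh, kh, vh = (t.reshape(-1, H, d).transpose(0, 1) for t in (q, k, v))

    # Scaled dot-product attention per head, then concatenate heads.
    attn = torch.softmax(qh @ kh.transpose(-2, -1) / d ** 0.5, dim=-1)
    ctx = (attn @ vh).transpose(0, 1).reshape(-1, E)
    manual = ctx @ mha.out_proj.weight.data.T + mha.out_proj.bias.data

# `weights` is the head-averaged attention map (average_attn_weights=True
# is the default), so it should equal attn.mean(dim=0).
print('Manual forward matches:', torch.allclose(manual, out[0].detach(), atol=1e-6))
print('Manual attention matches:', torch.allclose(attn.mean(dim=0), weights[0].detach(), atol=1e-6))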