Skip to content

Instantly share code, notes, and snippets.

@kumanna
Last active April 21, 2025 13:14
Show Gist options
  • Save kumanna/dd1dd59de394bfa868f0d8a75e90b12d to your computer and use it in GitHub Desktop.
Save kumanna/dd1dd59de394bfa868f0d8a75e90b12d to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def get_dominant_paths(Y_DD):
    """Locate the dominant peaks of a 2-D (delay x Doppler) grid.

    For every row ``l`` of ``Y_DD`` the column ``k`` holding the entry of
    largest magnitude is found, and the pair ``(l, k)`` is kept when that
    magnitude is strictly greater than 10% of the largest magnitude in the
    whole grid.

    Parameters
    ----------
    Y_DD : numpy.ndarray
        Two-dimensional real or complex array. Axis 0 is indexed by ``l``
        and axis 1 by ``k``.

    Returns
    -------
    set of tuple(int, int)
        ``(l, k)`` index pairs of the detected peaks, at most one per row.
        The set is empty for an all-zero grid or a grid with no elements.

    Raises
    ------
    ValueError
        If ``Y_DD`` is not two-dimensional.
    """
    magnitude = np.abs(Y_DD)
    M, _ = magnitude.shape  # unpacking enforces a 2-D input
    if magnitude.size == 0:
        return set()

    threshold = 0.1 * magnitude.max()

    # Rank entries by magnitude. Calling argmax on the complex values
    # themselves would order them by real part first (then imaginary part),
    # which picks the wrong column for e.g. a strong negative-real or purely
    # imaginary tap.
    peak_cols = magnitude.argmax(axis=1)
    peak_vals = magnitude[np.arange(M), peak_cols]

    # A row contributes its peak only if that peak clears the threshold.
    # Indices are converted to plain ints so the result prints and compares
    # cleanly.
    return {
        (int(l), int(k))
        for l, (k, value) in enumerate(zip(peak_cols, peak_vals))
        if value > threshold
    }
# Grid size: symbols is an M x N array (rows indexed by l, columns by k).
M = 4
N = 4


def gen_fourier_matrix(N):
    """Return the N x N DFT matrix with negative exponent, scaled by 1/sqrt(N)."""
    return np.exp(-1j * 2 * np.pi / N * np.outer(np.arange(N), np.arange(N))) / np.sqrt(N)


# Transmit grid: a single unit symbol at (0, 0), zeros elsewhere.
# Alternative (random QPSK on every bin):
# symbols = (np.sign(np.random.randn(M, N)) + 1j * np.sign(np.random.randn(M, N))) / np.sqrt(2)
symbols = np.zeros((M, N), dtype='complex')
symbols[0, 0] = 1

FM = gen_fourier_matrix(M)
FN = gen_fourier_matrix(N)

# Row-major vectorisation of the grid; H_DD below acts on this vector.
x_DD = symbols.flatten()
X_TF = FM @ symbols @ FN.T.conj()

# Sanity check without a channel: undoing the two transforms must give back
# the original symbols.
Y_DD = FM.T.conj() @ X_TF @ FN
assert np.max(np.abs(symbols - Y_DD)) < 1e-8


def GAMMA(li):
    """Return the M x M diagonal matrix of phases exp(-j*2*pi*li*m/M), m = 0..M-1."""
    return np.diag(np.exp(-1j * 2 * np.pi * li / M * np.arange(M)))


def DELTA(ki):
    """Return the N x N diagonal matrix of phases exp(+j*2*pi*ki*n/N), n = 0..N-1."""
    return np.diag(np.exp(1j * 2 * np.pi * ki / N * np.arange(N)))


# Each channel path is a tuple (complex gain hi, shift li, shift ki).
channel_coeffs = [(1+0j, -1, -1), (0.4j, -2, -2)]
# Alternative single-tap channel:
# channel_coeffs = [(1+0j, 0, 0)]

# Accumulate the (M*N) x (M*N) channel matrix as a sum of per-path terms.
H_DD = np.zeros((M*N, M*N), dtype='complex')
for hi, li, ki in channel_coeffs:
    row_operator = FM.T.conj() @ GAMMA(li) @ FM
    col_operator = FN @ DELTA(ki) @ FN.T.conj()
    # NOTE(review): this per-path phase has no 2*pi factor, unlike every
    # other exponent in this script -- confirm against the intended model.
    H_DD += hi * np.exp(-1j * li / M * ki / N) * np.kron(row_operator, col_operator)

# Pass the symbols through the channel, then check that solving the linear
# system recovers them.
y_DD = H_DD @ x_DD
x_DD_hat = np.linalg.solve(H_DD, y_DD)
X_DD_hat = x_DD_hat.reshape(M, N)
assert np.max(np.abs(symbols - X_DD_hat)) < 1e-8

# Received grid and its detected dominant (l, k) paths.
Y_DD = y_DD.reshape(M, N)
LK_set = get_dominant_paths(Y_DD)
# NOTE(review): only the label is printed here, not the matrix values.
print("Y_DD\n")
print("LK_set:\n", LK_set)
import matplotlib.pyplot as plt
import seaborn as sns
# Draw |Y_DD| (rounded to 3 decimals) as a heatmap and outline every
# detected (l, k) pair in red.
sns.set(style="whitegrid")
# Disabled experiment: zero out entries below a fixed threshold before plotting.
# threshold = 0.6
# Y_masked = np.where(np.round(Y_DD, 3) > threshold, np.round(Y_DD, 3), 0)
plt.figure(figsize=(8, 6))
magnitude = np.round(np.abs(Y_DD), 3)
ax = sns.heatmap(magnitude, cmap="viridis", cbar=True)
for (l, k) in LK_set:
    # 1x1 unfilled rectangle anchored at x = k (column), y = l (row).
    ax.add_patch(
        plt.Rectangle((k, l), 1, 1, fill=False, edgecolor='red', linewidth=2)
    )
plt.title("Y_DD Matrix (Heatmap of Values)", fontsize=14)
plt.xlabel("Doppler Index")
plt.ylabel("Delay Index")
plt.tight_layout()
plt.show()
# To write the figure to disk instead of showing it:
# plt.savefig("Y_DD_dominant_paths_all_vals.png", dpi=300)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment