Created
October 14, 2022 16:31
-
-
Save jeromekelleher/653f333d65d8fcd88ffc8a108b54f55d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import tskit | |
import numpy as np | |
import numba | |
@numba.njit
def _normalise(B):
    """
    Double-centre the matrix ``B`` and return the result as a new array.

    The returned array ``K`` has the shape and dtype of ``B`` and satisfies
    ``K[i, j] = B[i, j] - m[i] - m[j] + g`` where ``m[i]`` is the mean of the
    first ``N`` entries of row ``i`` (``N = B.shape[0]``) and ``g`` is the
    mean over all entries of ``B``. Row means are used for both the ``i``
    and the ``j`` correction terms. ``B`` itself is not modified.
    """
    size = B.shape[0]
    grand_mean = np.mean(B)

    # Row means are accumulated with explicit loops because numba does not
    # support np.mean(a, axis=0).
    row_mean = np.zeros(size)
    for r in range(size):
        for c in range(size):
            row_mean[r] += B[r, c]
        row_mean[r] /= size

    K = np.zeros_like(B)
    for r in range(size):
        for c in range(size):
            K[r, c] = B[r, c] - row_mean[r] - row_mean[c] + grand_mean
    return K
@numba.njit
def _update_B_matrix(B, area, samples, nodes_individual):
    """
    Add ``area`` into ``B`` for every node and every pair of nodes in
    ``samples``.

    Each entry of ``samples`` is looked up in ``nodes_individual`` to obtain
    a row/column index of ``B``. For every node, ``area`` is added to the
    diagonal entry of its index. For every unordered pair of distinct
    positions in ``samples``, ``area`` is added to both ``B[I, J]`` and
    ``B[J, I]`` of the corresponding indexes (which hit the same diagonal
    entry twice when the two nodes map to the same index).

    ``B`` is modified in place; nothing is returned.
    """
    count = samples.shape[0]
    for first in range(count):
        row = nodes_individual[samples[first]]
        # Contribution of the node paired with itself.
        B[row, row] += area
        # Contribution of the node paired with every later node, mirrored
        # so that B stays symmetric.
        for second in range(first + 1, count):
            col = nodes_individual[samples[second]]
            B[row, col] += area
            B[col, row] += area
def B_matrix_incremental(ts):
    """
    Build the un-normalised branch relatedness matrix for the individuals
    of ``ts`` in one pass over its edge differences.

    For every node ``u`` that currently has a parent, the branch above ``u``
    contributes ``area = branch_length * span`` where ``span`` is the length
    of genome over which both the parent of ``u`` and the set of
    individual-owned nodes below ``u`` stayed unchanged. That ``area`` and
    that node set are handed to ``_update_B_matrix``, which accumulates
    them into the matrix.

    :param ts: An object exposing ``num_individuals``, ``num_nodes``,
        ``nodes_time``, ``nodes_individual``, ``individuals()`` and
        ``edge_diffs(include_terminal=True)`` (the script passes the result
        of ``tskit.load``).
    :return: A ``(num_individuals, num_individuals)`` float array.
    """
    N = ts.num_individuals
    B = np.zeros((N, N))
    time = ts.nodes_time
    # Looked up once rather than on every branch update.
    nodes_individual = ts.nodes_individual
    # Left coordinate up to which the branch above each node has already
    # been accounted for in B.
    last_update = np.zeros(ts.num_nodes)
    parent = np.zeros(ts.num_nodes, dtype=np.int32) - 1

    # The Descendent nodes of individuals at each node in the tree.
    # Note: we could probably use the SampleLists in the C code to
    # do this, but we'd have to check the sample <-> Individual mapping.
    D = [set() for _ in range(ts.num_nodes)]
    for ind in ts.individuals():
        for u in ind.nodes:
            D[u].add(u)

    def flush(u, left):
        # Credit the branch above u for the span [last_update[u], left)
        # using the *current* parent[u] and D[u], then restart the span.
        distance = left - last_update[u]
        last_update[u] = left
        # Nothing to add when there is no branch, no span, or no
        # individual-owned node below u. The emptiness check also keeps a
        # float-typed empty array from reaching the compiled kernel.
        if parent[u] == -1 or distance == 0 or not D[u]:
            return
        area = (time[parent[u]] - time[u]) * distance
        samples = np.array(list(D[u]))
        _update_B_matrix(B, area, samples, nodes_individual)

    for (left, _right), edges_out, edges_in in ts.edge_diffs(include_terminal=True):
        for edge in edges_out:
            child = edge.child
            # Account for the branch that is about to disappear.
            flush(child, left)
            parent[child] = -1
            u = edge.parent
            while u != -1:
                # Flush with the old descendant set before shrinking it.
                flush(u, left)
                D[u] -= D[child]
                u = parent[u]

        for edge in edges_in:
            child = edge.child
            # The child had no parent up to this point, so there is no
            # branch to credit for the preceding span; only restart its
            # span at the current position.
            last_update[child] = left
            parent[child] = edge.parent
            u = edge.parent
            while u != -1:
                # Flush with the old descendant set before growing it, so
                # the newly attached nodes are not credited for genome to
                # the left of this position.
                flush(u, left)
                D[u] |= D[child]
                u = parent[u]
    return B
def branch_genetic_relatedness_matrix(ts):
    """
    Return the result of ``_normalise`` applied to the matrix produced by
    ``B_matrix_incremental(ts)``.
    """
    return _normalise(B_matrix_incremental(ts))
if __name__ == "__main__":
    # Command line entry point: expects exactly an input tree sequence path
    # and an output text path.
    if len(sys.argv) != 3:
        print(f"usage: {sys.argv[0]} file.trees relatedness.txt")
        sys.exit(1)
    trees_path, output_path = sys.argv[1], sys.argv[2]
    relatedness = branch_genetic_relatedness_matrix(tskit.load(trees_path))
    # Optional cross-check against a reference implementation (not defined
    # in this file):
    # K2 = genetic_relatedness_matrix(ts)
    # assert np.allclose(K, K2)
    np.savetxt(output_path, relatedness)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment