Created June 2, 2025 20:30
-
-
Save unixpickle/350f821297a2aa4cb1b04768fc590db1 to your computer and use it in GitHub Desktop.
Linear programming matching
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from scipy.optimize import linear_sum_assignment, linprog | |
def main():
    """Compare the LP-based matcher against SciPy's Hungarian-style solver.

    Draws a random 20x20 cost matrix with entries in [0.01, 1.0), solves the
    assignment problem both with ``solve_matching_linprog`` and with
    ``scipy.optimize.linear_sum_assignment``, and prints the total cost
    found by each so the two can be compared.
    """
    size = 20
    cost = np.random.uniform(low=0.01, high=1.0, size=(size, size))
    rows = np.arange(size)

    print("Using linear programming solution...")
    lp_cols = solve_matching_linprog(cost)
    # Every left vertex must be assigned a distinct right vertex.
    assert len(set(lp_cols)) == len(lp_cols)
    print("Total cost:", cost[rows, lp_cols].sum())

    print("Using linear_sum_assignment()...")
    lsa_rows, lsa_cols = linear_sum_assignment(cost)
    # The check below expects SciPy to report the rows in order 0..n-1.
    assert lsa_rows.tolist() == rows.tolist()
    print("Total cost:", cost[lsa_rows, lsa_cols].sum())
def solve_matching_linprog(weights: np.ndarray) -> np.ndarray:
    """Find a minimum-cost perfect matching of a square cost matrix via an LP.

    The assignment problem is posed as a linear program over edge variables
    ``x[i, j]`` (1 when left vertex ``i`` is matched to right vertex ``j``).
    The LP minimizes ``sum(weights * x)`` subject to every row and every
    column of ``x`` summing to exactly 1, with ``x >= 0``. That feasible
    region is the set of doubly stochastic matrices (the Birkhoff polytope),
    whose vertices are exactly the permutation matrices. A vertex optimum is
    therefore already a 0/1 matching. See also
    https://nvlpubs.nist.gov/nistpubs/jres/69B/jresv69Bn1-2p125_A1b.pdf

    Args:
        weights: Square ``(n, n)`` matrix of edge costs, where
            ``weights[i, j]`` is the cost of matching left vertex ``i`` to
            right vertex ``j``. Any finite real values are allowed.

    Returns:
        Integer array ``soln`` of length ``n``, where ``soln[i]`` is the
        right vertex assigned to left vertex ``i``. The entries form a
        permutation of ``range(n)``.

    Raises:
        ValueError: If ``weights`` is not a square 2-D matrix.
        RuntimeError: If the LP solver does not report success.
    """
    weights = np.asarray(weights, dtype=float)
    if weights.ndim != 2 or weights.shape[0] != weights.shape[1]:
        raise ValueError(
            f"weights must be a square matrix, got shape {weights.shape}"
        )
    n_verts = weights.shape[0]
    if n_verts == 0:
        return np.zeros(0, dtype=np.intp)

    # The LP variables are the row-major flattening of x, so edge (i, j) is
    # column i * n_verts + j of the constraint matrix.
    # Row i of left_rows selects every edge leaving left vertex i.
    left_rows = np.kron(np.eye(n_verts), np.ones((1, n_verts)))
    # Row j of right_rows selects every edge entering right vertex j.
    right_rows = np.tile(np.eye(n_verts), (1, n_verts))
    constraint_mat = np.vstack([left_rows, right_rows])
    constraint_bound = np.ones(n_verts * 2)

    # Equality constraints force a *perfect* matching whatever the sign or
    # scale of the weights. The previous "A x <= 1 with weights - 1"
    # formulation only worked when every weight was below 1; otherwise the
    # LP left vertices unmatched and argmax returned duplicate columns.
    result = linprog(
        weights.ravel(),
        A_eq=constraint_mat,
        b_eq=constraint_bound,
        bounds=(0, None),
    )
    if not result.success:
        raise RuntimeError(
            f"linprog failed to solve the matching LP: {result.message}"
        )
    # argmax also absorbs any tiny numerical noise around the 0/1 values.
    return result.x.reshape(n_verts, n_verts).argmax(axis=1)
if __name__ == "__main__":
    # Run the comparison demo only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.