Created
December 14, 2018 17:54
-
-
Save fjsj/9c9f7f36cfd3205343e333d86778433c to your computer and use it in GitHub Desktop.
Python fast sorted list intersection
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bisect
def bisect_index(arr, start, end, x):
    """Return the index of ``x`` within ``arr[start:end]``, or -1 if absent.

    ``arr`` must be sorted in ascending order over the searched window.
    ``end`` is exclusive; values past ``len(arr)`` are clamped so that an
    oversized window cannot index beyond the end of the list.
    """
    end = min(end, len(arr))
    i = bisect.bisect_left(arr, x, lo=start, hi=end)
    # bisect_left returns an insertion point; it is a hit only when that
    # slot lies inside the window and actually holds x.
    if i < end and arr[i] == x:
        return i
    return -1
def exponential_search(arr, start, x):
    """Return the index of ``x`` in sorted ``arr`` at or after ``start``.

    Probes positions ``start + 1``, ``start + 2``, ``start + 4``, ... until a
    value greater than ``x`` (or the end of the list) is reached, then
    binary-searches the last bracketed window with ``bisect_index``.

    Returns the absolute index of ``x`` in ``arr``, or -1 when ``x`` is not
    found at or after ``start`` (including when ``start`` is past the end).
    """
    n = len(arr)
    if start >= n:
        return -1
    if arr[start] == x:
        # Report the absolute position, not 0.
        return start
    # Double the offset from `start` (not the absolute index) so the first
    # probes stay close to `start` regardless of how large `start` is.
    offset = 1
    while start + offset < n and arr[start + offset] <= x:
        offset *= 2
    # x, if present, lies in [start + offset // 2, start + offset).
    return bisect_index(arr, start + offset // 2, min(start + offset, n), x)
def compute_intersection_list(l1, l2):
    """Return the elements common to two ascending-sorted lists.

    The shorter list is iterated and each of its elements is looked up in
    the longer one, following the approach described at
    https://stackoverflow.com/a/40538162/145349.

    Both inputs must be sorted in ascending order. The result contains, in
    order, every element of the shorter list that also occurs in the longer
    list (when the lengths are equal, ``l2`` is the one iterated).
    """
    if len(l1) < len(l2):
        small, large = l1, l2
    else:
        small, large = l2, l1

    n = len(large)
    # Invariant: every element of `large` before `pos` is smaller than the
    # current x. Because `small` is ascending, `pos` never has to move
    # backwards, so a miss does not restart the search from index 0.
    pos = 0
    intersection_list = []
    for x in small:
        # Gallop: probe at growing distances until reaching a value >= x.
        probe = pos
        step = 1
        while probe < n and large[probe] < x:
            pos = probe + 1
            probe += step
            step *= 2
        # The insertion point of x is now bracketed by [pos, min(probe, n)].
        pos = bisect.bisect_left(large, x, pos, min(probe, n))
        if pos < n and large[pos] == x:
            intersection_list.append(x)
    return intersection_list
# Self-check: the result must match a plain set intersection, sorted.
l1 = [1, 3, 4, 6, 7, 8, 9, 10]
l2 = [0, 2, 3, 6, 7, 9]
assert sorted(set(l1) & set(l2)) == compute_intersection_list(l1, l2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Nice, thanks!
Leaving a reference here: for NumPy arrays, especially sorted ones, `intersect1d` has a really fast implementation: https://numpy.org/doc/stable/reference/generated/numpy.intersect1d.html
It's simply:
https://github.com/numpy/numpy/blob/92ebe1e9a6aeb47a881a1226b08218175776f9ea/numpy/lib/arraysetops.py#L429-L430