Skip to content

Instantly share code, notes, and snippets.

@jackhftang
Last active May 17, 2025 07:24
Show Gist options
  • Save jackhftang/f3f64db84bc8d484c423f779022adfdb to your computer and use it in GitHub Desktop.
Save jackhftang/f3f64db84bc8d484c423f779022adfdb to your computer and use it in GitHub Desktop.
# https://arxiv.org/abs/2503.16306
from collections import Counter, defaultdict
from itertools import product, repeat
from math import prod
# Face values of the two dice compared by p(): `a` belongs to player A and
# `b` to player B. Each list entry is one equally likely face, so a repeated
# value is proportionally more likely to be rolled.
a = [1, 1, 4, 4, 5, 6]
b = [0, 1, 2, 6, 6, 6]
def _sum_counts(faces, n):
    """Return sorted ``(total, ways)`` pairs for the sum of ``n`` rolls of ``faces``."""
    face_counts = Counter(faces)
    sums = {0: 1}  # zero rolls: the only reachable sum is 0, in exactly one way
    for _ in range(n):
        following = defaultdict(int)
        for partial, ways in sums.items():
            for face, multiplicity in face_counts.items():
                following[partial + face] += ways * multiplicity
        sums = following
    return sorted(sums.items())


def _count_wins(winner, loser):
    """Count outcome pairs in which ``winner``'s sum strictly exceeds ``loser``'s."""
    wins = 0
    below = 0  # number of loser outcomes strictly below the current winner sum
    idx = 0
    # Both inputs are sorted by sum, so `idx` only ever moves forward.
    for total, ways in winner:
        while idx < len(loser) and loser[idx][0] < total:
            below += loser[idx][1]
            idx += 1
        wins += below * ways
    return wins


def p(n, dice_a=None, dice_b=None):
    """Return ``(P(sum_a > sum_b), P(sum_b > sum_a))`` after ``n`` rolls each.

    Each player rolls their own die ``n`` times and adds up the faces; a tie
    counts as a win for neither player, so the two values need not sum to 1.

    Outcomes are counted with exact integers and divided by the total number
    of outcomes only once at the end, so the two returned floats are not
    affected by accumulated floating-point rounding and can be compared
    reliably by the caller.

    Args:
        n: Number of rolls per player. ``n <= 0`` is treated as zero rolls,
            which yields ``(0.0, 0.0)`` because both sums are 0.
        dice_a: Face values of player A's die, one entry per equally likely
            face. Defaults to the module-level ``a``.
        dice_b: Face values of player B's die. Defaults to the module-level
            ``b``.

    Returns:
        A tuple ``(win_a, win_b)`` of floats.

    Raises:
        ValueError: If either die has no faces.
    """
    faces_a = a if dice_a is None else dice_a
    faces_b = b if dice_b is None else dice_b
    if len(faces_a) == 0 or len(faces_b) == 0:
        raise ValueError("each die must have at least one face")

    # For integer faces the number of distinct sums grows linearly with n,
    # so building each distribution costs O(n^2) overall.
    dist_a = _sum_counts(faces_a, n)
    dist_b = _sum_counts(faces_b, n)

    # Every (A outcome, B outcome) pair is equally likely.
    total = sum(ways for _, ways in dist_a) * sum(ways for _, ways in dist_b)

    win_a = _count_wins(dist_a, dist_b)
    win_b = _count_wins(dist_b, dist_a)
    return (win_a / total, win_b / total)
# Tabulate both win probabilities for 0 through 10 rolls per player and show
# which player is ahead at each roll count.
for n in range(11):
    wa, wb = p(n)
    if wa < wb:
        rel = "<"
    elif wa == wb:
        rel = "="
    else:
        rel = ">"
    print(f"n={n}: P(a_win)={wa:.8f} {rel} P(b_win)={wb:.8f} ")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment