Skip to content

Instantly share code, notes, and snippets.

@marthinwurer
Created September 10, 2023 21:40
Show Gist options
  • Save marthinwurer/9de97803378537d4b4907ebfe81f187f to your computer and use it in GitHub Desktop.
Save marthinwurer/9de97803378537d4b4907ebfe81f187f to your computer and use it in GitHub Desktop.
my solve for Katherine Johnson for chctf
#
# NOTE: this may not run as-is; it was copy-pasted from a Jupyter notebook.
#
import socket
import string
import random
from itertools import *
import time
import math
from collections import defaultdict
# Host name and TCP port that do_connection() connects to.
hostname = "0.cloud.chals.io"
port = 21838
# Load the reference initials, one entry per line of initials.txt.
# Lines are stripped *before* the emptiness check so that blank or
# whitespace-only lines do not end up in the list as empty strings.
with open("initials.txt", "r") as f:
    initials = [entry for entry in (line.strip() for line in f) if entry]
def encrypt(key, message):
    """Map each symbol of ``message`` through ``key`` and join the results.

    Symbols that have no entry in ``key`` contribute nothing to the output.
    """
    pieces = []
    for symbol in message:
        pieces.append(key.get(symbol, ''))
    return "".join(pieces)
def norm_dict(d, s):
    """Divide every value of ``d`` in place by ``len(d) + s``.

    The mapping is modified directly; nothing is returned.
    """
    denominator = len(d) + s
    for name, value in list(d.items()):
        d[name] = value / denominator
def gram_dict(n):
    """Return a dict holding 1 for every ``n``-letter uppercase ASCII string.

    Keys appear in the order produced by the cartesian product of the
    alphabet with itself ``n`` times ("AA", "AB", ... for ``n == 2``).
    """
    return {
        "".join(letters): 1
        for letters in product(string.ascii_uppercase, repeat=n)
    }
def build_dist(initials):
    """Build smoothed letter distributions from three-letter initials.

    Every table starts with a count of 1 per possible n-gram (from
    ``gram_dict``), each valid entry adds 1 to the matching slots, and
    ``norm_dict`` then scales each table by its total count.

    Entries that are not exactly three uppercase ASCII letters (empty
    strings, short fragments, lowercase text, ...) have no slot in the
    tables; they are skipped and do not count towards the normalisation
    total instead of raising ``KeyError``/``ValueError``.

    Args:
        initials: iterable of strings, ideally three uppercase letters each.

    Returns:
        A tuple ``(trigrams, bigrams, monograms)`` where ``trigrams`` is a
        dict keyed by three-letter strings, ``bigrams`` is a list of three
        dicts for the letter pairs (1st, 2nd), (2nd, 3rd) and (1st, 3rd),
        and ``monograms`` is a list of three dicts, one per letter position.
    """
    trigrams = gram_dict(3)
    # Pair tables by position: 0 -> ab, 1 -> bc, 2 -> ac.
    bigrams = [gram_dict(2) for _ in range(3)]
    monograms = [gram_dict(1) for _ in range(3)]

    weight = 1
    counted = 0
    for entry in initials:
        # Membership in the trigram table is exactly the "three uppercase
        # letters" check, so malformed entries are filtered out here.
        if entry not in trigrams:
            continue
        counted += 1
        a, b, c = entry
        trigrams[entry] += weight
        bigrams[0][a + b] += weight
        bigrams[1][b + c] += weight
        bigrams[2][a + c] += weight
        monograms[0][a] += weight
        monograms[1][b] += weight
        monograms[2][c] += weight

    # norm_dict divides by (number of slots + added mass), so each table
    # sums to 1 afterwards.
    added = counted * weight
    norm_dict(trigrams, added)
    for pair_table, single_table in zip(bigrams, monograms):
        norm_dict(pair_table, added)
        norm_dict(single_table, added)
    return trigrams, bigrams, monograms
# Reference trigram/bigram/monogram distributions of the known initials;
# count_existing() compares candidate decryptions against these.
trigrams, bigrams, monograms = build_dist(initials)
def kl_dicts(target, actual):
    """Return sum(target[k] * log(target[k] / actual[k])) over ``actual``'s keys.

    This is the Kullback-Leibler divergence of ``target`` from ``actual``.
    Zero probabilities follow the usual conventions instead of raising:
    a term with ``target[k] == 0`` contributes nothing, and a term with
    ``target[k] > 0`` but ``actual[k] == 0`` makes the result infinite.

    Args:
        target: mapping from key to probability (must contain every key
            of ``actual``).
        actual: mapping from key to probability.

    Returns:
        The divergence as a number; ``math.inf`` when ``actual`` assigns
        zero probability to something ``target`` considers possible.

    Raises:
        KeyError: if a key of ``actual`` is missing from ``target``.
    """
    total = 0
    for k, q in actual.items():
        p = target[k]
        if p == 0:
            # lim p->0 of p*log(p/q) is 0.
            continue
        if q == 0:
            return math.inf
        total += p * math.log(p / q)
    return total
def count_existing(ce_messages):
    """Score candidate plaintexts by their leading three characters.

    Builds distributions from the first three characters of every message
    and returns the summed ``kl_dicts`` value of those distributions
    against the module-level ``trigrams``, ``bigrams`` and ``monograms``.
    Lower means closer to the reference.
    """
    prefixes = [message[:3] for message in ce_messages]
    tri, pairs, singles = build_dist(prefixes)

    score = kl_dicts(tri, trigrams)
    score += sum(kl_dicts(pairs[pos], bigrams[pos]) for pos in range(3))
    score += sum(kl_dicts(singles[pos], monograms[pos]) for pos in range(3))
    return score
def invert_key(key):
    """Return a new dict mapping each value of ``key`` back to its key.

    When several keys share a value, the one iterated last wins.
    """
    return {cipher: plain for plain, cipher in key.items()}
def decrypt(inv_key, message):
    """Translate ``message`` symbol by symbol through ``inv_key``.

    Symbols missing from ``inv_key`` are omitted from the returned string.
    """
    lookup = inv_key.get
    return "".join(map(lambda symbol: lookup(symbol, ''), message))
def iteration(key, messages, score_func):
    """Decrypt all ``messages`` under ``key`` and return their score.

    ``key`` is inverted first, each message is decrypted with the inverse,
    and the list of decrypted strings is handed to ``score_func``, whose
    return value is passed through unchanged.
    """
    reverse = invert_key(key)
    return score_func([decrypt(reverse, text) for text in messages])
def simple_hill_climbing(initial, messages, iterations=1000):
    """Improve a substitution key by random pairwise swaps.

    Each round swaps the values of two distinct letters in the current
    key, scores the resulting key with ``iteration(..., count_existing)``
    and keeps it only if the score is strictly lower than the best seen.

    The starting key itself is not scored; the first candidate whose score
    is finite becomes the baseline.

    Args:
        initial: dict mapping every uppercase ASCII letter to a value.
        messages: the encrypted messages passed on to ``iteration``.
        iterations: number of swap attempts (default 1000).

    Returns:
        A tuple ``(key, score)`` with the best key found and its score;
        ``(initial, math.inf)`` if no candidate was ever accepted.
    """
    letters = list(string.ascii_uppercase)
    current = initial
    best_score = math.inf
    for _ in range(iterations):
        values = [current[letter] for letter in letters]
        # Two *distinct* positions: swapping a slot with itself would spend
        # a full scoring pass on a key identical to the current one.
        i, j = random.sample(range(len(letters)), 2)
        values[i], values[j] = values[j], values[i]
        child = dict(zip(letters, values))
        score = iteration(child, messages, count_existing)
        if score < best_score:
            current = child
            best_score = score
    return current, best_score
# Ordered uppercase alphabet; get_try() zips it with a shuffled copy to
# build the dict keys of each random starting substitution key.
key = list(string.ascii_uppercase)
def get_try(lines, time_budget=80, target="CABLAND"):
    """Search for a substitution key and return ``target`` encrypted with it.

    Repeatedly starts ``simple_hill_climbing`` from a fresh random key
    until ``time_budget`` seconds have elapsed, remembering the
    lowest-scoring result. Progress is printed after every restart.

    At least one restart is always performed, so a key is always available
    for the final encryption.

    Args:
        lines: the encrypted messages to attack.
        time_budget: wall-clock seconds to keep restarting (default 80).
        target: plaintext to encrypt with the best key (default "CABLAND").

    Returns:
        ``encrypt(best_key, target)``.
    """
    best = None
    best_score = math.inf
    # Monotonic clock: the budget is unaffected by system clock changes.
    start = time.monotonic()
    while True:
        shuffled = list(key)
        random.shuffle(shuffled)
        candidate = dict(zip(key, shuffled))
        res, s = simple_hill_climbing(candidate, lines)
        print(res, s, time.monotonic() - start)
        if best is None or s < best_score:
            print("new best", s)
            best = res
            best_score = s
        if time.monotonic() - start >= time_budget:
            break
    return encrypt(best, target)
def do_connection():
    """Run one attempt against the server and return its raw reply.

    Connects to ``hostname:port``, collects received chunks until one
    contains "!", computes an answer with ``get_try`` and sends it, then
    reads and returns up to 1024 bytes of response. The socket is closed
    on every exit path.

    Returns:
        The bytes received after sending the answer.

    Raises:
        ConnectionError: if the peer closes the connection before a chunk
            containing "!" arrives.
        socket.timeout: if no reply arrives within the reply timeout.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((hostname, port))
        lines = []
        while True:
            # NOTE(review): each recv(8) result is treated as one message;
            # this assumes the server's records line up with 8-byte reads
            # -- TODO confirm against the challenge protocol.
            chunk = s.recv(8)
            if not chunk:
                # recv() returns b"" at EOF; without this check the loop
                # would spin forever appending empty strings.
                raise ConnectionError(
                    "connection closed before the '!' prompt arrived"
                )
            line = chunk.decode("utf-8").strip()
            if "!" in line:
                break
            lines.append(line)
        print("Got lines")

        data = get_try(lines)
        # sendall() keeps writing until the whole answer is out.
        s.sendall(data.encode("utf-8"))

        # Give the reply a moment to accumulate, then wait a bounded time
        # for it rather than failing immediately when nothing is buffered.
        time.sleep(0.5)
        s.settimeout(5.0)
        new_data = s.recv(1024)
    print(new_data)
    return new_data
# Keep making attempts until the server's reply no longer contains the
# failure marker, then show the final reply.
result = do_connection()
while b"Uh oh.." in result:
    result = do_connection()
print(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment