Created
April 23, 2020 11:27
-
-
Save valkum/449bc1790a08063ba03ec47384b7670e to your computer and use it in GitHub Desktop.
check_signature: updated for v2 keys and Python 3
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import json
import logging
import sys
import time
import urllib
import urllib.request

import dns.resolver
from signedjson.key import (
    decode_verify_key_bytes,
    encode_verify_key_base64,
    is_signing_algorithm_supported,
    write_signing_keys,
)
from signedjson.sign import (
    SignatureVerifyException,
    encode_canonical_json,
    signature_ids,
    verify_signed_json,
)
from unpaddedbase64 import decode_base64

from synapse.storage.keys import FetchKeyResult
def get_targets(server_name):
    """Yield the (host, port) pairs to try when contacting *server_name*.

    * An explicit port ("example.org:8448", "[::1]:8448") is used as-is.
    * A bare bracketed IPv6 literal ("[::1]") uses the default port 8448.
    * Otherwise the ``_matrix._tcp.<server_name>`` SRV record is consulted,
      falling back to ``(server_name, 8448)`` when no record exists.
    """
    host, sep, port = server_name.rpartition(":")
    if sep and "]" not in port:
        # rpartition (not split) so the colons inside an IPv6 literal such as
        # "[::1]:8448" do not break the host/port split.
        yield (host, int(port))
        return

    if server_name.startswith("["):
        # Bare IPv6 literal: there is no DNS name to look up an SRV record for.
        yield (server_name, 8448)
        return

    try:
        answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
    except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
        yield (server_name, 8448)
        return

    for srv in answers:
        # srv.target is a dns.name.Name; render it without the trailing root
        # dot so it is usable in a URL and for TLS hostname matching.
        yield (srv.target.to_text(omit_final_dot=True), srv.port)
class KeyLookupError(ValueError):
    """Raised when a server's key response cannot be trusted."""


def get_server_keys(server_name, target, port):
    """Fetch the keys published by *server_name* and check they are self-signed.

    Requests ``https://<target>:<port>/_matrix/key/v2/server/``, decodes every
    verify key whose algorithm signedjson supports, and checks that the
    response carries a valid signature by *server_name* using one of them.

    Args:
        server_name: the server name whose keys are expected.
        target: host to connect to (e.g. one yielded by get_targets).
        port: TCP port to connect to.

    Returns:
        dict mapping key id to a FetchKeyResult whose ``valid_until_ts`` is
        the response's ``valid_until_ts``.

    Raises:
        KeyLookupError: the response names a different server, or has no
            signature by any of the keys it lists.
        SignatureVerifyException: a signature by a listed key does not verify.
    """
    url = "https://%s:%i/_matrix/key/v2/server/" % (target, port)
    print("Target {}:{}".format(target, port))
    # Close the HTTP response once its body has been parsed.
    with urllib.request.urlopen(url) as response:
        response_json = json.load(response)

    # Check the response is for the server we asked about rather than trusting
    # its self-declared name; otherwise any target could answer with a
    # (validly self-signed) key set belonging to some other server.
    claimed_name = response_json.get("server_name")
    if claimed_name != server_name:
        raise KeyLookupError(
            "Key response from %s:%s is for %r, expected %r"
            % (target, port, claimed_name, server_name)
        )

    ts_valid_until_ms = response_json["valid_until_ts"]
    verify_keys = {}
    for key_id, key_data in response_json["verify_keys"].items():
        if not is_signing_algorithm_supported(key_id):
            continue
        key_bytes = decode_base64(key_data["key"])
        verify_key = decode_verify_key_bytes(key_id, key_bytes)
        verify_keys[key_id] = FetchKeyResult(
            verify_key=verify_key, valid_until_ts=ts_valid_until_ms
        )

    # Accept the response as soon as one signature by a listed key verifies.
    # Signatures by keys absent from verify_keys (e.g. an unsupported
    # algorithm) are skipped.
    for key_id in response_json.get("signatures", {}).get(server_name, {}):
        key = verify_keys.get(key_id)
        if key is None:
            continue
        verify_signed_json(response_json, server_name, key.verify_key)
        return verify_keys

    raise KeyLookupError(
        "Key response for %s is not signed by the origin server" % (server_name,)
    )
def main():
    """Check a server's signatures on a JSON object against its published keys.

    Command line: ``SIGNATURE_NAME [INPUT_JSON]``.

    Keys for SIGNATURE_NAME are fetched from the first target (see
    get_targets) that answers successfully. They are echoed to stdout via
    write_signing_keys. Every signature that SIGNATURE_NAME made on
    INPUT_JSON (stdin by default) is then checked and reported as
    ``PASS <key_id>`` or ``FAIL <key_id>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("signature_name")
    parser.add_argument(
        "input_json", nargs="?", type=argparse.FileType("r"), default=sys.stdin
    )
    args = parser.parse_args()

    logging.basicConfig()

    server_name = args.signature_name
    result_keys = {}
    for target, port in get_targets(server_name):
        try:
            result_keys = get_server_keys(server_name, target, port)
        except Exception:
            # Best effort: log and try the next candidate target.
            logging.exception("Error talking to %s:%s", target, port)
            continue
        print("Using keys from https://%s:%s/_matrix/key/v2/server" % (target, port))
        write_signing_keys(
            sys.stdout, (result.verify_key for result in result_keys.values())
        )
        break
    else:
        logging.error("Could not fetch keys for %s from any target", server_name)

    now_ms = int(round(time.time() * 1000))

    print("Checking JSON:")
    json_to_check = json.load(args.input_json)
    key_ids = json_to_check.get("signatures", {}).get(server_name, {})
    if not key_ids:
        print("No signatures by %s found" % (server_name,))

    for key_id in key_ids:
        fetch_key_result = result_keys.get(key_id)
        if fetch_key_result is None:
            print("No key for this ID: %s" % (key_id,))
            print("FAIL %s" % (key_id,))
            continue
        if fetch_key_result.valid_until_ts < now_ms:
            # valid_until_ts (from the key response) is already in the past.
            print("key %s was not valid at this point" % (key_id,))
        try:
            # verify_signed_json needs the VerifyKey, not the FetchKeyResult
            # wrapper around it.
            verify_signed_json(
                json_to_check, server_name, fetch_key_result.verify_key
            )
        except SignatureVerifyException:
            logging.exception("Check for key %s failed", key_id)
            print("FAIL %s" % (key_id,))
        else:
            print("PASS %s" % (key_id,))


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment