Skip to content

Instantly share code, notes, and snippets.

@broncotc
Created August 27, 2017 03:31
Show Gist options
  • Save broncotc/1c768054d6a50e031e3d30fc7f3b8ea9 to your computer and use it in GitHub Desktop.
Save broncotc/1c768054d6a50e031e3d30fc7f3b8ea9 to your computer and use it in GitHub Desktop.
A bare-bones async UDP relay/proxy with PSK-based encryption
#!/usr/bin/env python3
import asyncio
import platform
import time
import sys
import socket
import string
import getopt
from Cryptodome.Cipher import AES
from Cryptodome.Hash import SHA256
from Cryptodome.Random import get_random_bytes, random
if platform.system() == "Linux":
    # uvloop is an optional, faster drop-in event loop. When it is not
    # installed, keep the stock asyncio loop instead of failing at import time.
    try:
        import uvloop
    except ImportError:
        pass
    else:
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# predefined constants (defaults; main() may override them from the command line)
# serverMode=True: datagrams from the bind side are decrypted before being
# forwarded to remoteAddress, and replies are encrypted. Client mode is the reverse.
serverMode = False
remoteAddress = ("127.0.0.1", 33339)
# Default listen port is picked at random between 25519 and 32767.
bindAddress = ("127.0.0.1", random.randint(25519, 32767))
# Default PSK: 32 random lowercase letters. The AES key is the first 16 bytes
# of its SHA-256 digest (i.e. an AES-128 key).
psk = ''.join(random.choice(string.ascii_lowercase) for _ in range(32))
bpsk = SHA256.new(psk.encode()).digest()[:16]
# A default timeout of 0.0 puts newly created sockets in non-blocking mode.
socket.setdefaulttimeout(0.0)
connectTable = {}  # inbound client address --> relayRemoteProtocol (outbound side)
reverseTable = {}  # inbound client address --> relayServerProtocol that received it
timeTable = {}     # inbound client address --> time.time() of last activity
# Create the loop explicitly: calling asyncio.get_event_loop() with no running
# loop is deprecated and raises RuntimeError on recent Python versions.
# new_event_loop() honours the policy installed above (e.g. uvloop).
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def cleanTable(interval=100, idleTimeout=1000):
    """Periodically evict relay sessions that have been idle too long; never returns.

    Every ``interval`` seconds, any client address whose ``timeTable``
    timestamp is more than ``idleTimeout`` seconds old is expired.

    - If the session has an outbound transport, that transport is closed.
      This releases its socket, and asyncio then invokes
      ``relayRemoteProtocol.connection_lost``, which removes the
      connectTable / reverseTable entries.
    - Otherwise the table entries are removed here directly.

    Args:
        interval: seconds to sleep between sweeps (default 100).
        idleTimeout: idle age in seconds after which a session is evicted
            (default 1000).
    """
    while True:
        await asyncio.sleep(interval)
        now = time.time()
        # Snapshot first: the tables must not change size while being iterated.
        expired = [addr for addr, lastSeen in timeTable.items()
                   if now - lastSeen > idleTimeout]
        for addr in expired:
            timeTable.pop(addr, None)
            remote = connectTable.get(addr)
            remoteTransport = getattr(remote, "transport", None)
            if remoteTransport is not None:
                # Closing twice is harmless; connection_lost does the table cleanup.
                remoteTransport.close()
            else:
                # pop() with a default: the entries may already be gone.
                connectTable.pop(addr, None)
                reverseTable.pop(addr, None)
def encryptPacketData(data):
    """Seal ``data`` with AES-GCM under the shared key ``bpsk``.

    A fresh random 16-byte nonce is drawn for every packet. The result is
    laid out as ``nonce (16 bytes) || tag || ciphertext``, which is the
    layout decryptPacketData slices apart.
    """
    freshNonce = get_random_bytes(16)
    sealed, authTag = AES.new(bpsk, AES.MODE_GCM, freshNonce).encrypt_and_digest(data)
    return b"".join((freshNonce, authTag, sealed))
def decryptPacketData(data):
    """Open a packet laid out as ``nonce[0:16] || tag[16:32] || ciphertext[32:]``.

    The packet is authenticated and decrypted with AES-GCM under ``bpsk``, and
    the plaintext is returned. PyCryptodome's ``decrypt_and_verify`` raises
    ValueError when the tag does not match.
    """
    packetNonce, packetTag, body = data[:16], data[16:32], data[32:]
    gcm = AES.new(bpsk, AES.MODE_GCM, packetNonce)
    return gcm.decrypt_and_verify(body, packetTag)
class relayServerProtocol(asyncio.DatagramProtocol):
    """Protocol for the listening (bind-address) socket.

    One instance serves every inbound peer. For each new peer address it
    creates a relayRemoteProtocol with its own outbound socket connected to
    ``remoteAddress``. Subsequent datagrams from that peer are transformed and
    forwarded through that socket: decrypted in server mode, encrypted in
    client mode.
    """

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        # Transform first so a bad datagram is rejected before any session is
        # created. In server mode this stops unauthenticated senders from
        # making the relay open outbound sockets. (For a brand-new session the
        # result is discarded: relayRemoteProtocol transforms its own copy.)
        try:
            payload = decryptPacketData(data) if serverMode else encryptPacketData(data)
        except ValueError:
            # Truncated or unauthenticated ciphertext: drop it quietly, as UDP may.
            return
        serverRemote = connectTable.get(addr)
        if serverRemote is None:
            connectTable[addr] = serverRemote = relayRemoteProtocol(data, addr)
            reverseTable[addr] = self
            coro = loop.create_datagram_endpoint(serverRemote, remote_addr=remoteAddress)
            task = loop.create_task(coro)
            task.add_done_callback(lambda fut, peer=addr: self._endpointDone(peer, fut))
            return
        remoteTransport = getattr(serverRemote, "transport", None)
        if remoteTransport is None:
            # The outbound endpoint is still being set up (connection_made has
            # not run yet). Drop the datagram rather than crash.
            return
        timeTable[addr] = time.time()
        remoteTransport.sendto(payload)

    def _endpointDone(self, addr, fut):
        """Forget a session whose outbound endpoint could not be created."""
        if fut.cancelled():
            failed = True
        else:
            error = fut.exception()  # marks the exception as retrieved
            failed = error is not None
            if failed:
                print("Could not open remote endpoint for %r: %r" % (addr, error),
                      file=sys.stderr)
        if failed:
            connectTable.pop(addr, None)
            reverseTable.pop(addr, None)
class relayRemoteProtocol(asyncio.DatagramProtocol):
    """Outbound-side protocol: one instance per inbound client address.

    The instance doubles as its own protocol factory (``__call__`` returns
    ``self``) so it can be passed directly to ``create_datagram_endpoint``.
    Replies from the remote are transformed (encrypted in server mode,
    decrypted in client mode) and sent back to ``clientAddr`` through the
    relayServerProtocol stored in ``reverseTable``.
    """

    def __init__(self, data, clientAddr):
        self.data = data              # first datagram, forwarded from connection_made
        self.clientAddr = clientAddr
        self.transport = None         # set once the outbound socket exists

    def connection_made(self, transport):
        self.transport = transport
        timeTable[self.clientAddr] = time.time()
        data = self.data
        del self.data
        try:
            payload = decryptPacketData(data) if serverMode else encryptPacketData(data)
        except ValueError:
            # The first datagram failed authentication: tear the session down.
            # connection_lost then removes the table entries.
            transport.close()
            return
        # The endpoint was created with remote_addr, so the socket is connected.
        # asyncio rejects an explicit addr that differs from the *resolved* peer
        # (e.g. an IPv6 4-tuple or a hostname), so send without one.
        transport.sendto(payload)

    def connection_lost(self, exc):
        # Only remove entries that still belong to this session. They may
        # already be gone, or belong to a newer session for the same address.
        if connectTable.get(self.clientAddr) is self:
            connectTable.pop(self.clientAddr, None)
            reverseTable.pop(self.clientAddr, None)
            timeTable.pop(self.clientAddr, None)

    def datagram_received(self, data, addr):
        inbound = reverseTable.get(self.clientAddr)
        if inbound is None:
            return  # session already evicted
        try:
            payload = encryptPacketData(data) if serverMode else decryptPacketData(data)
        except ValueError:
            return  # unauthenticated/truncated ciphertext: drop it
        timeTable[self.clientAddr] = time.time()
        inbound.transport.sendto(payload, self.clientAddr)

    def __call__(self, *args, **kwargs):
        return self
def _parseAddress(value):
    """Parse ``host:port`` (an IPv6 host may be written ``[addr]:port``).

    Returns ``(host, port)`` with an int port. Returns None when there is no
    ':' separator or the port is not an integer in 0-65535.
    """
    host, separator, port = value.rpartition(':')
    if not separator:
        return None
    try:
        portNumber = int(port)
    except ValueError:
        return None
    if not 0 <= portNumber <= 65535:
        return None
    if host.startswith('[') and host.endswith(']'):
        host = host[1:-1]
    return host, portNumber


def _formatAddress(address):
    """Render ``(host, port)`` as ``host:port``, bracketing hosts containing ':'."""
    host = address[0]
    if ':' in host:
        host = '[' + host + ']'
    return host + ':' + str(address[1])


def main():
    """Parse command-line options, then run the relay until interrupted.

    Options: -s/--server, -b/--bind-address HOST:PORT,
    -r/--remote-address HOST:PORT, -k/--pre-shared-key KEY.
    Exits with status 2 on bad options or addresses.
    """
    global remoteAddress, bindAddress, psk, serverMode, bpsk
    try:
        opts, _args = getopt.getopt(sys.argv[1:], "sb:r:k:",
                                    ["server", "bind-address=", "remote-address=", "pre-shared-key="])
    except getopt.GetoptError as err:
        print(err)  # will print something like "option -a not recognized"
        sys.exit(2)
    for k, v in opts:
        if k in ("-s", "--server"):
            serverMode = True
        elif k in ("-b", "--bind-address"):
            parsed = _parseAddress(v)
            if parsed is None:
                print("Illegal bind address")
                sys.exit(2)
            bindAddress = parsed
        elif k in ("-r", "--remote-address"):
            parsed = _parseAddress(v)
            if parsed is None:
                print("Illegal remote address")
                sys.exit(2)
            remoteAddress = parsed
        elif k in ("--pre-shared-key", "-k"):
            psk = v
            bpsk = SHA256.new(psk.encode()).digest()[:16]
    print("Relay is working in %s mode" % ("server" if serverMode else "client"))
    print("Bind address is %s" % _formatAddress(bindAddress))
    print("Remote address is %s" % _formatAddress(remoteAddress))
    print('Pre-shared Key (PSK) is "%s"' % psk)
    # One protocol instance will be created to serve all client requests
    listen = loop.create_datagram_endpoint(relayServerProtocol, local_addr=bindAddress)
    transport, _protocol = loop.run_until_complete(listen)
    # cleanTable() loops forever, so run it as a background task. Awaiting it
    # with run_until_complete would never return.
    cleaner = loop.create_task(cleanTable())
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        pass
    finally:
        cleaner.cancel()
        try:
            loop.run_until_complete(cleaner)
        except asyncio.CancelledError:
            pass
        for remote in list(connectTable.values()):
            remoteTransport = getattr(remote, "transport", None)
            if remoteTransport is not None:
                remoteTransport.close()
        transport.close()
        # Let the scheduled connection_lost callbacks run before closing.
        loop.run_until_complete(asyncio.sleep(0))
        loop.close()
if __name__ == "__main__":
    # Start the relay only when executed as a script, not when imported.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment