Created
September 19, 2022 01:48
-
-
Save jcrist/d7271415011cdc528ba82d2e5f328808 to your computer and use it in GitHub Desktop.
An example TCP Key-Value store written using msgspec and asyncio
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""An example key-value store server and client implementation using msgspec | |
and asyncio. | |
Requests are serialized using the MessagePack protocol, as implemented by | |
msgspec. Additionally, all messages are length-prefix framed using a 32 bit | |
big-endian integer. | |
Note that this encoding isn't tied to either asyncio or msgspec - this could | |
just as easily be implemented using sockets and a different serialization | |
protocol. Length-prefix framing is useful in that respect - it separates the IO | |
handling code from the actual encoding, making it easy to swap out transports | |
or serialization protocols without having to rewrite everything. | |
Also note that this example is written for clarity over efficiency, there are | |
more efficient ways to do this with msgspec than shown here. The protocol | |
defined in this example is very similar to the original reason I wrote msgspec; | |
it's kind of what it was designed to be best at. | |
For more information please see the msgspec docs: | |
https://jcristharif.com/msgspec/index.html. | |
In particular, the following sections are relevant: | |
- Structs: https://jcristharif.com/msgspec/structs.html | |
- Length-Prefix framing: https://jcristharif.com/msgspec/perf-tips.html#length-prefix-framing | |
""" | |
from __future__ import annotations | |
import asyncio | |
import msgspec | |
from typing import Union | |
async def prefixed_send(stream: asyncio.StreamWriter, buffer: bytes) -> None:
    """Send ``buffer`` over ``stream`` behind a 4-byte big-endian length header.

    After queueing both pieces, ``stream.drain()`` is awaited so that a slow
    peer applies backpressure to the sender instead of letting the write
    buffer grow without bound.
    """
    header = len(buffer).to_bytes(length=4, byteorder="big")
    for chunk in (header, buffer):
        stream.write(chunk)
    await stream.drain()
async def prefixed_recv(stream: asyncio.StreamReader) -> bytes:
    """Receive one length-prefixed message from ``stream``.

    The message starts with a 4-byte big-endian length, followed by exactly
    that many payload bytes, which are returned. ``readexactly`` raises
    ``asyncio.IncompleteReadError`` (an ``EOFError`` subclass) if the stream
    ends before the header or payload is complete.
    """
    size = int.from_bytes(await stream.readexactly(4), byteorder="big")
    return await stream.readexactly(size)
# Define some request types. ``tag=True`` makes these tagged msgspec Structs,
# which is what lets the server's ``Union[Get, Put, Del]`` decoder tell them
# apart on the wire (see msgspec's tagged-union docs).
class Get(msgspec.Struct, tag=True):
    """Request the value stored under ``key``; the server replies with it, or ``None`` if absent."""

    # Key to look up.
    key: str
class Put(msgspec.Struct, tag=True):
    """Store ``val`` under ``key``, overwriting any existing value; the server replies with ``None``."""

    # Key to write.
    key: str
    # Value to associate with the key.
    val: str
class Del(msgspec.Struct, tag=True):
    """Remove ``key`` if present (no error if missing); the server replies with ``None``."""

    # Key to remove.
    key: str
class Server: | |
"""An example TCP key-value server using asyncio and msgspec""" | |
def __init__(self, host="127.0.0.1", port=8888): | |
self.host = host | |
self.port = port | |
self.kv = {} | |
# A msgpack encoder for encoding responses | |
self.encoder = msgspec.msgpack.Encoder() | |
# A *typed* msgpack decoder for decoding requests. If a request doesn't | |
# match the specified types, a nice error will be raised. | |
self.decoder = msgspec.msgpack.Decoder(Union[Get, Put, Del]) | |
async def handle_connection(self, reader, writer): | |
"""Handle the full lifetime of a single connection""" | |
print("Connection opened") | |
while True: | |
try: | |
# Receive and decode a request | |
buffer = await prefixed_recv(reader) | |
req = self.decoder.decode(buffer) | |
# Process the request | |
resp = await self.handle_request(req) | |
# Encode and write the response | |
buffer = self.encoder.encode(resp) | |
await prefixed_send(writer, buffer) | |
except EOFError: | |
print("Connection closed") | |
return | |
async def handle_request(self, req: Get | Put | Del) -> str | None: | |
"""Handle a single request""" | |
# You don't have to use pattern matching here, but it works and is new | |
# and shiny. | |
match req: | |
case Get(key): | |
return self.kv.get(key) | |
case Put(key, val): | |
self.kv[key] = val | |
return None | |
case Del(key): | |
self.kv.pop(key, None) | |
return None | |
async def serve(self): | |
server = await asyncio.start_server( | |
self.handle_connection, self.host, self.port | |
) | |
print(f"Serving on tcp://{self.host}:{self.port}...") | |
async with server: | |
await server.serve_forever() | |
def run(self): | |
"""Run the server until ctrl-C""" | |
asyncio.run(self.serve()) | |
class Client:
    """An example TCP key-value client using asyncio and msgspec.

    Connect either as an async context manager (``async with Client() as c:``)
    or with ``c = await Client()`` followed later by ``await c.close()``.

    Each request sends one frame and treats the next received frame as its
    reply, so requests on a single client should not be issued concurrently.
    """

    def __init__(self, host="127.0.0.1", port=8888):
        self.host = host
        self.port = port
        # Stream pair for the open connection; both are None while disconnected.
        self.reader = self.writer = None

    async def close(self):
        """Close the client. Calling this while not connected is a no-op."""
        writer = self.writer
        if writer is None:
            return
        # Clear the attributes first so the client reads as disconnected even
        # if waiting for the close below raises.
        self.reader = None
        self.writer = None
        writer.close()
        await writer.wait_closed()

    async def __aenter__(self):
        """Open the connection if not already open, and return ``self``."""
        if self.reader is None:
            reader, writer = await asyncio.open_connection(self.host, self.port)
            self.reader = reader
            self.writer = writer
        return self

    async def __aexit__(self, *args):
        await self.close()

    def __await__(self):
        # Enables ``client = await Client()`` as shorthand for connecting.
        return self.__aenter__().__await__()

    async def request(self, req):
        """Send a request and await the response.

        Raises:
            RuntimeError: if the client is not connected.
        """
        if self.writer is None or self.reader is None:
            raise RuntimeError(
                "Client is not connected; use `await Client()` or `async with Client()`"
            )
        # Encode and send the request
        buffer = msgspec.msgpack.encode(req)
        await prefixed_send(self.writer, buffer)
        # Receive and decode the response
        buffer = await prefixed_recv(self.reader)
        return msgspec.msgpack.decode(buffer)

    async def get(self, key: str) -> str | None:
        """Get a key from the KV store, returning None if not present"""
        return await self.request(Get(key))

    async def put(self, key: str, val: str) -> None:
        """Put a key-val pair in the KV store"""
        return await self.request(Put(key, val))

    async def delete(self, key: str) -> None:
        """Delete a key from the KV store. No-op if not present"""
        return await self.request(Del(key))
if __name__ == "__main__":
    # Running this file as a script starts the server on its default address
    # (127.0.0.1:8888); it serves until interrupted with ctrl-C.
    Server().run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
An example usage session:
Server
Client