

@benc-uk
Last active February 23, 2025 22:22

Python Things

from fastapi import FastAPI, Request, status
from fastapi.responses import PlainTextResponse
import jwt
import logging

logger = logging.getLogger(__name__)

app = FastAPI()


@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    auth_header = request.headers.get("Authorization")

    # Response for unauthorized requests
    resp401 = PlainTextResponse("Unauthorized", status_code=status.HTTP_401_UNAUTHORIZED)

    # Reject requests without an Authorization header
    if not auth_header:
        return resp401

    # Get the token from the header, guarding against a missing or empty "Bearer " prefix
    parts = auth_header.split("Bearer ")
    if len(parts) < 2 or not parts[1]:
        return resp401
    token = parts[1]

    try:
        decoded_token = validate_token(token)
    except Exception as e:
        logger.error(f"ERROR: Problem validating token: {e}")
        return resp401

    if decoded_token:
        # Here you can do some authorization logic like checking scopes, roles, etc.
        # (a sketch of such a check follows this snippet)
        # But we don't, we just chain the request to the next middleware
        response = await call_next(request)
        return response

    return resp401


def validate_token(token: str):
    jwks_client = jwt.PyJWKClient(
        # Magic URL you might want to put in a config file or constant
        uri="https://login.microsoftonline.com/common/discovery/keys",
        cache_jwk_set=True,
        lifespan=600,
    )

    signing_key = jwks_client.get_signing_key_from_jwt(token)

    return jwt.decode(
        token,
        signing_key.key,
        # This is the algorithm that Azure AD and many other OIDC providers use
        algorithms=["RS256"],
        # For your API, this will be the Application ID (GUID) of the client you have registered
        audience="b79fbf4d-3ef9-4689-8143-76b194e85509",
    )


# Just a simple endpoint to demonstrate the middleware
@app.get("/")
def read_root():
    return {"Hello": "World"}
# Notebook cell 1: a toy "language model" built on a tiny hand-made vocabulary
import random
from dataclasses import dataclass
from typing import List


@dataclass
class Token:
    value: int
    length: int


class LanguageModel:
    def __init__(self):
        self.vocabulary = {
            # Single characters
            "a": 1, "b": 2, "c": 3, "d": 4, "e": 5, "f": 6, "g": 7, "h": 8,
            "i": 9, "j": 10, "k": 11, "l": 12, "m": 13, "n": 14, "o": 15, "p": 16,
            "q": 17, "r": 18, "s": 19, "t": 20, "u": 21, "v": 22, "w": 23, "x": 24,
            "y": 25, "z": 26, " ": 27,
            # Whole words and word fragments
            "jumper": 28, "kitten": 29, "harsha": 30, "umb": 31, "umbrella": 32,
            "lion": 33, "great": 34, "violet": 35, "ban": 36, "quick": 37,
            "open": 38, "fish": 39, "fis": 40, "energy": 41, "running": 42,
            "mango": 43, "tree": 44, "skyscraper": 45, "dog": 46, "abc": 47,
            "cat": 48, "world": 49, "intelligence": 50, "greater": 51, "band": 52,
            "int": 53, "python": 54, "door": 55, "pro": 56, "jump": 57,
            "banan": 58, "do": 59, "bandana": 60, "apple": 61, "house": 62,
            "apply": 63, "over": 64, "wolf": 65, "kite": 66, "yellow": 67,
            "end": 68, "car": 69, "unique": 70, "kit": 71, "night": 72,
            "wat": 73, "trees": 74, "vio": 75, "aha": 76, "writer": 77,
            "cart": 78, "jum": 79, "quicker": 80, "human": 81, "man": 82,
            "writ": 83, "machine": 84, "gone": 85, "yel": 86, "ab": 87,
            "intelligent": 88, "writing": 89, "banana": 90, "ap": 91, "ni": 92,
            "cater": 93, "program": 94, "run": 95, "go": 96, "tre": 97,
            "ma": 98, "fishing": 99, "zebra": 100, "qui": 101, "water": 102,
            "op": 103, "gre": 104, "bike": 500,
        }

        self.max_token = max(self.vocabulary.values())
        self.min_token = min(self.vocabulary.values())

    def find_longest_token(self, input: str) -> Token:
        # Greedily match the longest prefix of the input that is in the vocabulary
        while len(input) > 0:
            tok_val = self.vocabulary.get(input, False)
            if tok_val:
                print(f" > Found: '{input}' {tok_val} {len(input)}")
                return Token(tok_val, len(input))
            input = input[:-1]  # Chomp backwards one character

        # This should never happen with the current vocabulary
        return Token(1, 1)

    def encode(self, prompt: str) -> List[int]:
        tokens = []
        while len(prompt) > 0:
            print(" > Encode:", prompt)
            token = self.find_longest_token(prompt)
            tokens.append(token.value)
            prompt = prompt[token.length:]  # Chomp forward, remove the found token

        return tokens

    def decode(self, tokens: List[int]) -> str:
        out = []
        for token_in in tokens:
            matches = [token_str for token_str, token_val in self.vocabulary.items() if token_val == token_in]
            # We should only have one match
            if len(matches) > 0:
                out.append(matches.pop())

        return "".join(out)

    # Draw the rest of the owl (language model) here
    def sample_token(self, tokens: List[int]) -> int:
        # This is a very simple language model
        # It just samples a random token from the vocabulary

        # Don't assume contiguous values, brute force a random token that exists
        while True:
            tok = random.randint(self.min_token, self.max_token)
            if tok in self.vocabulary.values():
                return tok

    def generate(self, prompt: str, max_tokens: int = 10) -> str:
        encoding = self.encode(prompt)
        print("DEBUG! Encoded prompt:", encoding)

        out_samples = []
        for _ in range(max_tokens):
            sample = self.sample_token(encoding)
            out_samples.append(sample)
            encoding.append(sample)  # This is how language models work!

        return self.decode(out_samples)


# Notebook cell 2: run the model on a sample prompt
lm = LanguageModel()

input_prompt = "house like fish"
print("Prompt:", input_prompt)
print("Answer:", lm.generate(input_prompt))

# Output from the notebook run:
# Prompt: house like fish
#  > Encode: house like fish
#  > Found: 'house' 62 5
#  > Encode: like fish
#  > Found: ' ' 27 1
#  > Encode: like fish
#  > Found: 'l' 12 1
#  > Encode: ike fish
#  > Found: 'i' 9 1
#  > Encode: ke fish
#  > Found: 'k' 11 1
#  > Encode: e fish
#  > Found: 'e' 5 1
#  > Encode: fish
#  > Found: ' ' 27 1
#  > Encode: fish
#  > Found: 'fish' 39 4
# DEBUG! Encoded prompt: [62, 27, 12, 9, 11, 5, 27, 39]
# Answer: doorrunningnightcaterlgoneintelligentvioletworlddo
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
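
The encoder above always takes the longest vocabulary entry that prefixes the remaining text, which is why whole words win over their fragments. A small usage sketch of that behaviour, separate from the notebook run itself; the expected values in the comments follow from the vocabulary defined in the class, and the class also prints its own debug trace.

lm = LanguageModel()

# "bandana" is itself a vocabulary entry, so it becomes a single token
print(lm.encode("bandana"))               # [60]

# "catering" is not, so the longest matching prefix "cater" (93) is taken,
# then the leftover "ing" falls back to single characters: i=9, n=14, g=7
print(lm.encode("catering"))              # [93, 9, 14, 7]

# decode() inverts the mapping, so the round trip recovers the original text
print(lm.decode(lm.encode("catering")))   # catering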
from fastapi import FastAPI, Request
from dotenv import load_dotenv
import uvicorn
import requests
import os

load_dotenv()

app = FastAPI()

# The SNS_TOPIC_ARN env var must be set
allowed_topics = [
    os.environ['SNS_TOPIC_ARN'],
]


# Main endpoint to receive SNS messages
@app.post("/sns/subscription")
async def index(request: Request):
    if 'x-amz-sns-message-type' not in request.headers:
        print('A regular HTTP request, will be ignored!')
        return {"message": "Request filtered out", "status": 403}

    body = await request.json()
    message_type = request.headers.get('x-amz-sns-message-type', '')

    topic_arn = body["TopicArn"]
    if topic_arn not in allowed_topics:
        return {"message": "Topic not allowed", "status": 403}

    # More checks can be added here, like checking the signature of the message
    # (see the sketch after this snippet) - anyhow, this is just a simple example

    if message_type == 'SubscriptionConfirmation':
        print('Subscription confirmation in progress...')
        subscribe_url = body['SubscribeURL']

        # Confirm the subscription by sending a GET request to the SubscribeURL
        response = requests.get(subscribe_url)
        if response.status_code == 200:
            return {"message": "Successfully subscribed to the topic", "status": 200}
        else:
            return {"message": "Failed to subscribe to the topic", "status": 500}

    sns_message = body.get('Message', '')
    print(f'SNS message received, length: {len(sns_message)} bytes')

    # ============================================
    # Your message processing logic goes here!
    # ============================================

    return {"message": "Message received", "status": 200}


# Health check endpoint for Front Door and other load balancing probes
@app.get("/healthz")
async def health_check():
    return {"status": "OK"}


if __name__ == '__main__':
    uvicorn.run('main:app', host='0.0.0.0', port=8000)
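
The handler's comment mentions checking the message signature as a further hardening step. Below is a rough sketch of what that could look like for SignatureVersion "1" messages, based on the documented AWS procedure (build a canonical string from a fixed set of fields, then verify the base64 Signature against the certificate at SigningCertURL) and using the cryptography package; the field lists and the missing SigningCertURL validation are simplifications to tighten before real use.

# Hedged sketch of SNS SignatureVersion "1" verification - not part of the gist above
import base64
import requests
from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding

# Fields that make up the canonical string, per message type (Subject only when present)
SIGNED_FIELDS = {
    "Notification": ["Message", "MessageId", "Subject", "Timestamp", "TopicArn", "Type"],
    "SubscriptionConfirmation": ["Message", "MessageId", "SubscribeURL", "Timestamp", "Token", "TopicArn", "Type"],
}


def verify_sns_signature(body: dict) -> bool:
    # Build the newline-delimited "FieldName\nvalue\n" canonical string
    canonical = ""
    for field in SIGNED_FIELDS.get(body.get("Type", ""), []):
        if field in body:
            canonical += f"{field}\n{body[field]}\n"

    # In production, also check SigningCertURL really points at an https://sns.<region>.amazonaws.com host
    cert_pem = requests.get(body["SigningCertURL"]).content
    public_key = x509.load_pem_x509_certificate(cert_pem).public_key()

    try:
        public_key.verify(
            base64.b64decode(body["Signature"]),
            canonical.encode("utf-8"),
            padding.PKCS1v15(),
            hashes.SHA1(),  # SignatureVersion "1"; version "2" uses SHA256
        )
        return True
    except Exception:
        return False

Something like verify_sns_signature(body) could then be called right after the topic check, returning a 403-style response when it comes back False.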