Last active
February 23, 2025 22:22
-
-
Save benc-uk/8e576cf72b2361782060f20917f2e280 to your computer and use it in GitHub Desktop.
Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from fastapi import FastAPI, Request, status | |
| from fastapi.responses import PlainTextResponse | |
| import jwt | |
| import logging | |
| logger = logging.getLogger(__name__) | |
| app = FastAPI() | |
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    """Reject any request that does not carry a valid bearer token.

    The ``Authorization`` header must look like ``Bearer <token>`` (the
    scheme is matched case-insensitively). The token is checked with
    ``validate_token``. If the header is missing or malformed, or the token
    fails validation, a plain-text 401 is returned and the request is not
    passed further down the chain.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request to the next handler.

    Returns:
        The downstream response for authenticated requests, otherwise a
        401 ``PlainTextResponse``.
    """
    # Response for unauthorized requests
    resp401 = PlainTextResponse("Unauthorized", status_code=status.HTTP_401_UNAUTHORIZED)

    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return resp401

    # partition() cannot raise, so a header without the "Bearer " prefix
    # ends in a 401 instead of an unhandled IndexError.
    scheme, _, token = auth_header.partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        return resp401

    # Only the validation call is guarded: errors raised later by the
    # application itself must not be reported as authentication failures.
    try:
        decoded_token = validate_token(token)
    except Exception as e:
        logger.error(f"ERROR: Problem validating token: {e}")
        return resp401

    if not decoded_token:
        return resp401

    # Here you can do some authorization logic like checking scopes, roles, etc.
    # But we don't, we just chain the request to the next middleware
    return await call_next(request)
# Magic URL you might want to put in a config file or constant
_JWKS_URI = "https://login.microsoftonline.com/common/discovery/keys"
# For your API, this will be the Application ID (GUID) of the client you have registered
_DEFAULT_AUDIENCE = "b79fbf4d-3ef9-4689-8143-76b194e85509"
# Shared client, created on first use. The JWK-set cache lives on the client
# instance, so building a new client per call would refetch the keys each time.
_jwks_client = None


def _get_jwks_client():
    """Return the shared JWKS client, creating it on first use."""
    global _jwks_client
    if _jwks_client is None:
        _jwks_client = jwt.PyJWKClient(
            uri=_JWKS_URI,
            cache_jwk_set=True,
            lifespan=600,
        )
    return _jwks_client


def validate_token(token: str, audience: str = _DEFAULT_AUDIENCE):
    """Verify a JWT's signature and audience and return its decoded claims.

    The signing key is looked up from the JWKS endpoint using the token's
    header, then the token is decoded with RS256 and the given audience.

    Args:
        token: The encoded JWT (without the ``Bearer`` prefix).
        audience: Expected ``aud`` claim; defaults to the registered
            application ID.

    Returns:
        The decoded claims as returned by ``jwt.decode``.

    Raises:
        Whatever PyJWT raises when the key lookup or the validation fails;
        callers are expected to treat any exception as "invalid token".
    """
    # NOTE(review): no issuer is passed to jwt.decode, so the `iss` claim is
    # not checked here - confirm whether that is intended.
    signing_key = _get_jwks_client().get_signing_key_from_jwt(token)
    return jwt.decode(
        token,
        signing_key.key,
        # This is the algorithm that Azure AD uses and lots of other OIDC providers
        algorithms=["RS256"],
        audience=audience,
    )
| # Just a simple endpoint to demonstrate the middleware | |
@app.get("/")
def read_root():
    """Return a fixed greeting payload for the root path."""
    greeting = {"Hello": "World"}
    return greeting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 1, | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "import random\n", | |
| "from dataclasses import (\n", | |
| " dataclass,\n", | |
| ")\n", | |
| "from typing import (\n", | |
| " List,\n", | |
| ")\n", | |
| "\n", | |
| "\n", | |
| "@dataclass\n", | |
| "class Token:\n", | |
| " value: int\n", | |
| " length: int\n", | |
| "\n", | |
| "\n", | |
| "class LanguageModel:\n", | |
| " def __init__(\n", | |
| " self,\n", | |
| " ):\n", | |
| " self.vocabulary = {\n", | |
| " \"a\": 1,\n", | |
| " \"b\": 2,\n", | |
| " \"c\": 3,\n", | |
| " \"d\": 4,\n", | |
| " \"e\": 5,\n", | |
| " \"f\": 6,\n", | |
| " \"g\": 7,\n", | |
| " \"h\": 8,\n", | |
| " \"i\": 9,\n", | |
| " \"j\": 10,\n", | |
| " \"k\": 11,\n", | |
| " \"l\": 12,\n", | |
| " \"m\": 13,\n", | |
| " \"n\": 14,\n", | |
| " \"o\": 15,\n", | |
| " \"p\": 16,\n", | |
| " \"q\": 17,\n", | |
| " \"r\": 18,\n", | |
| " \"s\": 19,\n", | |
| " \"t\": 20,\n", | |
| " \"u\": 21,\n", | |
| " \"v\": 22,\n", | |
| " \"w\": 23,\n", | |
| " \"x\": 24,\n", | |
| " \"y\": 25,\n", | |
| " \"z\": 26,\n", | |
| " \" \": 27,\n", | |
| " \"jumper\": 28,\n", | |
| " \"kitten\": 29,\n", | |
| " \"harsha\": 30,\n", | |
| " \"umb\": 31,\n", | |
| " \"umbrella\": 32,\n", | |
| " \"lion\": 33,\n", | |
| " \"great\": 34,\n", | |
| " \"violet\": 35,\n", | |
| " \"ban\": 36,\n", | |
| " \"quick\": 37,\n", | |
| " \"open\": 38,\n", | |
| " \"fish\": 39,\n", | |
| " \"fis\": 40,\n", | |
| " \"energy\": 41,\n", | |
| " \"running\": 42,\n", | |
| " \"mango\": 43,\n", | |
| " \"tree\": 44,\n", | |
| " \"skyscraper\": 45,\n", | |
| " \"dog\": 46,\n", | |
| " \"abc\": 47,\n", | |
| " \"cat\": 48,\n", | |
| " \"world\": 49,\n", | |
| " \"intelligence\": 50,\n", | |
| " \"greater\": 51,\n", | |
| " \"band\": 52,\n", | |
| " \"int\": 53,\n", | |
| " \"python\": 54,\n", | |
| " \"door\": 55,\n", | |
| " \"pro\": 56,\n", | |
| " \"jump\": 57,\n", | |
| " \"banan\": 58,\n", | |
| " \"do\": 59,\n", | |
| " \"bandana\": 60,\n", | |
| " \"apple\": 61,\n", | |
| " \"house\": 62,\n", | |
| " \"apply\": 63,\n", | |
| " \"over\": 64,\n", | |
| " \"wolf\": 65,\n", | |
| " \"kite\": 66,\n", | |
| " \"yellow\": 67,\n", | |
| " \"end\": 68,\n", | |
| " \"car\": 69,\n", | |
| " \"unique\": 70,\n", | |
| " \"kit\": 71,\n", | |
| " \"night\": 72,\n", | |
| " \"wat\": 73,\n", | |
| " \"trees\": 74,\n", | |
| " \"vio\": 75,\n", | |
| " \"aha\": 76,\n", | |
| " \"writer\": 77,\n", | |
| " \"cart\": 78,\n", | |
| " \"jum\": 79,\n", | |
| " \"quicker\": 80,\n", | |
| " \"human\": 81,\n", | |
| " \"man\": 82,\n", | |
| " \"writ\": 83,\n", | |
| " \"machine\": 84,\n", | |
| " \"gone\": 85,\n", | |
| " \"yel\": 86,\n", | |
| " \"ab\": 87,\n", | |
| " \"intelligent\": 88,\n", | |
| " \"writing\": 89,\n", | |
| " \"banana\": 90,\n", | |
| " \"ap\": 91,\n", | |
| " \"ni\": 92,\n", | |
| " \"cater\": 93,\n", | |
| " \"program\": 94,\n", | |
| " \"run\": 95,\n", | |
| " \"go\": 96,\n", | |
| " \"tre\": 97,\n", | |
| " \"ma\": 98,\n", | |
| " \"fishing\": 99,\n", | |
| " \"zebra\": 100,\n", | |
| " \"qui\": 101,\n", | |
| " \"water\": 102,\n", | |
| " \"op\": 103,\n", | |
| " \"gre\": 104,\n", | |
| " \"bike\": 500,\n", | |
| " }\n", | |
| "\n", | |
| " self.max_token = max(self.vocabulary.values())\n", | |
| " self.min_token = min(self.vocabulary.values())\n", | |
| "\n", | |
"    def find_longest_token(\n",
"        self,\n",
"        input,\n",
"    ) -> Token:\n",
"        \"\"\"Return the longest vocabulary entry that is a prefix of `input`.\n",
"\n",
"        The candidate is shortened from the right, one character at a time,\n",
"        until it matches a vocabulary key.\n",
"\n",
"        Returns:\n",
"            Token holding the matched entry's value and its length in\n",
"            characters. If no prefix matches (the text starts with a\n",
"            character that is not in the vocabulary, e.g. an uppercase\n",
"            letter or punctuation), falls back to Token(1, 1) so the caller\n",
"            can still advance by one character.\n",
"        \"\"\"\n",
"        while len(input) > 0:\n",
"            tok_val = self.vocabulary.get(input)\n",
"            # Compare against None so a token value of 0 would still count as a match\n",
"            if tok_val is not None:\n",
"                print(f\" > Found: '{input}' {tok_val} {len(input)}\")\n",
"                return Token(\n",
"                    tok_val,\n",
"                    len(input),\n",
"                )\n",
"            input = input[:-1]  # Chomp backwards one character\n",
"\n",
"        # The fallback must be a Token, not a bare tuple: encode() reads .value and .length\n",
"        return Token(\n",
"            1,\n",
"            1,\n",
"        )\n",
| "\n", | |
| " def encode(\n", | |
| " self,\n", | |
| " prompt: str,\n", | |
| " ) -> List[int]:\n", | |
| " tokens = []\n", | |
| " while len(prompt) > 0:\n", | |
| " print(\n", | |
| " \" > Encode:\",\n", | |
| " prompt,\n", | |
| " )\n", | |
| " token = self.find_longest_token(prompt)\n", | |
| " tokens.append(token.value)\n", | |
| " prompt = prompt[token.length :] # Chomp forward, remove the found token\n", | |
| "\n", | |
| " return tokens\n", | |
| "\n", | |
"    def decode(\n",
"        self,\n",
"        tokens: List[int],\n",
"    ) -> str:\n",
"        \"\"\"Turn a list of token values back into text.\n",
"\n",
"        Values that are not present in the vocabulary are skipped.\n",
"        \"\"\"\n",
"        # Invert the vocabulary once instead of scanning every entry for each token.\n",
"        # Should two strings ever share a value, the later entry wins, as before.\n",
"        value_to_text = {token_val: token_str for token_str, token_val in self.vocabulary.items()}\n",
"        return \"\".join(value_to_text[token_in] for token_in in tokens if token_in in value_to_text)\n",
| "\n", | |
| " # Draw the rest of the owl (language model) here\n", | |
"    def sample_token(\n",
"        self,\n",
"        tokens: List[int],\n",
"    ) -> int:\n",
"        \"\"\"Pick the next token value uniformly at random from the vocabulary.\n",
"\n",
"        This is a very simple language model: `tokens` (the context so far)\n",
"        is accepted but not used.\n",
"        \"\"\"\n",
"        # Don't assume contiguous values: choose directly among the values that exist\n",
"        # instead of drawing integers between min_token and max_token until one hits\n",
"        return random.choice(list(self.vocabulary.values()))\n",
| "\n", | |
| " def generate(\n", | |
| " self,\n", | |
| " prompt: str,\n", | |
| " max_tokens: int = 10,\n", | |
| " ) -> str:\n", | |
| " encoding = self.encode(prompt)\n", | |
| " print(\n", | |
| " \"DEBUG! Encoded prompt:\",\n", | |
| " encoding,\n", | |
| " )\n", | |
| "\n", | |
| " out_samples = []\n", | |
| " for _ in range(max_tokens):\n", | |
| " sample = self.sample_token(encoding)\n", | |
| " out_samples.append(sample)\n", | |
| " encoding.append(sample) # This is how language models work!\n", | |
| "\n", | |
| " return self.decode(out_samples)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 9, | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "Prompt: house like fish\n", | |
| " > Encode: house like fish\n", | |
| " > Found: 'house' 62 5\n", | |
| " > Encode: like fish\n", | |
| " > Found: ' ' 27 1\n", | |
| " > Encode: like fish\n", | |
| " > Found: 'l' 12 1\n", | |
| " > Encode: ike fish\n", | |
| " > Found: 'i' 9 1\n", | |
| " > Encode: ke fish\n", | |
| " > Found: 'k' 11 1\n", | |
| " > Encode: e fish\n", | |
| " > Found: 'e' 5 1\n", | |
| " > Encode: fish\n", | |
| " > Found: ' ' 27 1\n", | |
| " > Encode: fish\n", | |
| " > Found: 'fish' 39 4\n", | |
| "DEBUG! Encoded prompt: [62, 27, 12, 9, 11, 5, 27, 39]\n", | |
| "Answer: doorrunningnightcaterlgoneintelligentvioletworlddo\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "lm = LanguageModel()\n", | |
| "\n", | |
| "input_prompt = \"house like fish\"\n", | |
| "print(\n", | |
| " \"Prompt:\",\n", | |
| " input_prompt,\n", | |
| ")\n", | |
| "print(\n", | |
| " \"Answer:\",\n", | |
| " lm.generate(input_prompt),\n", | |
| ")" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": ".venv", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.12" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 2 | |
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from fastapi import FastAPI, Request | |
| from dotenv import load_dotenv | |
| import uvicorn | |
| import requests | |
| import os | |
| load_dotenv() | |
| app = FastAPI() | |
| # The SNS_TOPIC_ARN env var must be set | |
| allowed_topics = [ | |
| os.environ['SNS_TOPIC_ARN'] | |
| ] | |
| # Main endpoint to receive SNS messages | |
def _is_aws_https_url(url) -> bool:
    """Tell whether `url` is an https URL whose host is under amazonaws.com."""
    from urllib.parse import urlparse

    if not isinstance(url, str):
        return False
    try:
        parsed = urlparse(url)
        host = parsed.hostname or ""
    except ValueError:
        return False
    return parsed.scheme == "https" and host.endswith((".amazonaws.com", ".amazonaws.com.cn"))


@app.post("/sns/subscription")
async def index(request: Request):
    """Receive SNS deliveries: confirm subscriptions and accept messages.

    Requests without the ``x-amz-sns-message-type`` header, with an
    unparseable body, or for a topic outside ``allowed_topics`` are
    rejected. A ``SubscriptionConfirmation`` is confirmed by fetching its
    ``SubscribeURL``; anything else is treated as a message.

    Args:
        request: The incoming HTTP request.

    Returns:
        A dict with a human-readable ``message`` and a ``status`` code. The
        status is reported inside the JSON payload.
    """
    # Function-scope import: fastapi is already a dependency of this file
    from fastapi.concurrency import run_in_threadpool

    if 'x-amz-sns-message-type' not in request.headers:
        print('A regular HTTP request, will be ignored!')
        return {"message": "Request filtered out", "status": 403}

    message_type = request.headers.get('x-amz-sns-message-type', '')

    # A malformed body gets an in-band rejection instead of an unhandled error
    try:
        body = await request.json()
    except ValueError:
        return {"message": "Invalid JSON body", "status": 400}
    if not isinstance(body, dict):
        return {"message": "Invalid JSON body", "status": 400}

    # .get() so a missing TopicArn is rejected like any other unknown topic
    if body.get("TopicArn") not in allowed_topics:
        return {"message": "Topic not allowed", "status": 403}

    # More checks can be added here, like checking the signature of the message
    # Anyhow, this is just a simple example

    if message_type == 'SubscriptionConfirmation':
        print('Subscription confirmation in progress...')
        subscribe_url = body.get('SubscribeURL', '')

        # The URL is taken from the request body, so refuse to fetch anything
        # that is not an https URL on an AWS host.
        if not _is_aws_https_url(subscribe_url):
            return {"message": "Failed to subscribe to the topic", "status": 500}

        # Confirm the subscription by sending a GET request to the SubscribeURL.
        # requests is blocking, so run it in the thread pool, with a timeout,
        # to keep the event loop free.
        try:
            response = await run_in_threadpool(requests.get, subscribe_url, timeout=10)
        except requests.RequestException:
            return {"message": "Failed to subscribe to the topic", "status": 500}

        if response.status_code == 200:
            return {"message": "Successfully subscribed to the topic", "status": 200}
        return {"message": "Failed to subscribe to the topic", "status": 500}

    sns_message = body.get('Message', '')
    print(f'SNS message received, length: {len(sns_message)} bytes')

    # ============================================
    # Your message processing logic goes here!
    # ============================================

    return {"message": "Message received", "status": 200}
| # Health check endpoint for Front Door and other load balancing probes | |
@app.get("/healthz")
async def index(request: Request):
    """Answer health probes with a fixed OK payload; `request` is not used."""
    health = {"status": "OK"}
    return health
| if __name__ == '__main__': | |
| uvicorn.run('main:app', host='0.0.0.0', port=8000) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment