@iamkahvi
Last active March 6, 2025 16:41
# llm_openai_proxy.py: an llm plugin that routes chat completions
# through an OpenAI-compatible proxy endpoint.
import llm
import urllib.request
import urllib.error
import json
import os
from typing import List, Dict, Optional


@llm.hookimpl
def register_models(register):
    register(OpenAIProxy())


class OpenAIProxy(llm.Model):
    model_id = "openai_proxy"
    can_stream = True

    # Declaring an Options class is how llm exposes -o/--option values to a
    # model; without it the temperature/max_tokens/timeout handling below is
    # unreachable from the CLI. Field names mirror the attributes used below.
    class Options(llm.Options):
        model: Optional[str] = None
        temperature: Optional[float] = None
        max_tokens: Optional[int] = None
        timeout: Optional[float] = None

    def __init__(self):
        # Define default option values
        self.default_model = "anthropic:claude-3-7-sonnet"
        self.default_timeout = 30.0

    def validate_options(self, options):
        """Validate the provided options."""
        errors = []
        # Validate temperature if provided
        if hasattr(options, 'temperature') and options.temperature is not None:
            if not 0 <= options.temperature <= 2:
                errors.append("temperature must be between 0 and 2")
        # Validate max_tokens if provided
        if hasattr(options, 'max_tokens') and options.max_tokens is not None:
            if options.max_tokens < 1:
                errors.append("max_tokens must be at least 1")
        # Validate timeout if provided
        if hasattr(options, 'timeout') and options.timeout is not None:
            if options.timeout <= 0:
                errors.append("timeout must be greater than 0")
        return errors

    def execute(self, prompt, stream, response, conversation):
        """Execute the OpenAI proxy model with the given prompt."""
        api_base_url = os.environ.get("OPENAI_API_BASE")
        if not api_base_url:
            error_message = "Error: OPENAI_API_BASE environment variable is not set"
            response.response_json = {"error": error_message}
            return [error_message]
        url = f"{api_base_url}/chat/completions"
        # Check for API key
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            error_message = "Error: OPENAI_API_KEY environment variable is not set"
            response.response_json = {"error": error_message}
            return [error_message]
        # Validate options
        option_errors = self.validate_options(prompt.options)
        if option_errors:
            error_message = f"Option validation errors: {', '.join(option_errors)}"
            response.response_json = {"error": error_message}
            return [error_message]
        # Get model and timeout from options, falling back to the defaults.
        # (getattr with a default is not enough here: declared options
        # exist with a value of None when unset.)
        model = getattr(prompt.options, 'model', None) or self.default_model
        timeout = getattr(prompt.options, 'timeout', None) or self.default_timeout
        # Prepare messages from conversation or prompt
        messages = self._prepare_messages(prompt, conversation)
        # Prepare request payload
        data = {
            "model": model,
            "messages": messages
        }
        # Add optional parameters if provided
        if hasattr(prompt.options, 'temperature') and prompt.options.temperature is not None:
            data["temperature"] = prompt.options.temperature
        if hasattr(prompt.options, 'max_tokens') and prompt.options.max_tokens is not None:
            data["max_tokens"] = prompt.options.max_tokens
        # Set streaming mode if requested
        if stream:
            data["stream"] = True
        # Store request details in response_json for logging
        response.response_json = {
            "request": {
                "model": data["model"],
                "messages_count": len(messages)
            }
        }
        try:
            # Convert data to JSON
            data_json = json.dumps(data).encode('utf-8')
            # Create request
            req = urllib.request.Request(url)
            req.add_header('Content-Type', 'application/json')
            req.add_header('Authorization', f'Bearer {api_key}')
            # Handle streaming vs non-streaming responses
            if stream:
                return self._handle_streaming_response(req, data_json, timeout, response)
            else:
                return self._handle_non_streaming_response(req, data_json, timeout, response)
        except Exception as e:
            error_message = f"Error: {str(e)}"
            response.response_json["error"] = error_message
            return [error_message]

    def _prepare_messages(self, prompt, conversation) -> List[Dict[str, str]]:
        """Convert conversation or prompt to OpenAI message format."""
        messages = []
        # If we have a conversation with messages, use those
        if conversation and hasattr(conversation, 'responses') and conversation.responses:
            # Add system prompt if available
            if prompt.system:
                messages.append({"role": "system", "content": prompt.system})
            # Add previous messages from conversation
            for prev_response in conversation.responses:
                # Add user message
                messages.append({"role": "user", "content": prev_response.prompt.prompt})
                # Add assistant response
                messages.append({"role": "assistant", "content": prev_response.text()})
            # Add the current prompt as the final user message
            messages.append({"role": "user", "content": prompt.prompt})
        else:
            # If no conversation, add system prompt if available
            if prompt.system:
                messages.append({"role": "system", "content": prompt.system})
            # Add the current prompt as a user message
            messages.append({"role": "user", "content": prompt.prompt})
        return messages

    def _handle_non_streaming_response(self, req, data_json, timeout, response):
        """Handle a non-streaming response from the API."""
        try:
            with urllib.request.urlopen(req, data_json, timeout=timeout) as response_obj:
                response_data = response_obj.read().decode('utf-8')
                result = json.loads(response_data)
                # Store API response details in response_json for logging
                response.response_json["api_response"] = {
                    "id": result.get("id"),
                    "model": result.get("model"),
                    "usage": result.get("usage")
                }
                # Extract content
                if 'choices' in result and result['choices'] and 'message' in result['choices'][0]:
                    content = result['choices'][0]['message'].get('content', '')
                    return [content]
                else:
                    error = f"Unexpected response format: {json.dumps(result)}"
                    response.response_json["error"] = error
                    return [error]
        except urllib.error.URLError as e:
            error = f"API request failed: {str(e)}"
            response.response_json["error"] = error
            return [error]
        except json.JSONDecodeError as e:
            error = f"Failed to parse API response: {str(e)}"
            response.response_json["error"] = error
            return [error]

    def _handle_streaming_response(self, req, data_json, timeout, response):
        """Handle a streaming response from the API."""
        try:
            with urllib.request.urlopen(req, data_json, timeout=timeout) as response_obj:
                buffer = ""
                for line in response_obj:
                    line = line.decode('utf-8').strip()
                    # Skip empty lines
                    if not line:
                        continue
                    # Skip "data: [DONE]" message
                    if line == "data: [DONE]":
                        continue
                    # Process data line
                    if line.startswith("data: "):
                        try:
                            data = json.loads(line[6:])  # Remove "data: " prefix
                            # Extract delta content if available
                            if ('choices' in data and data['choices'] and
                                    'delta' in data['choices'][0] and
                                    'content' in data['choices'][0]['delta']):
                                content = data['choices'][0]['delta']['content']
                                yield content
                                buffer += content
                        except json.JSONDecodeError:
                            # Skip malformed JSON
                            continue
                # Store final content in response_json for logging
                response.response_json["content_length"] = len(buffer)
        except Exception as e:
            error = f"Streaming error: {str(e)}"
            response.response_json["error"] = error
            yield error

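For reference, _handle_streaming_response consumes OpenAI-style server-sent events: each "data: " line carries a JSON chunk whose choices[0].delta.content field is the next text fragment, and "data: [DONE]" terminates the stream. An illustrative stream (not captured from a real proxy) looks like:

data: {"choices": [{"delta": {"content": "Hel"}}]}
data: {"choices": [{"delta": {"content": "lo!"}}]}
data: [DONE]
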
# pyproject.toml: minimal packaging so llm discovers the plugin module
[project]
name = "openai-proxy"
version = "0.1"

[project.entry-points.llm]
openai_proxy = "llm_openai_proxy"
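
To try the plugin, one plausible workflow (assuming the two files above are saved as llm_openai_proxy.py and pyproject.toml in the same directory) is to install it with "llm install -e ." and drive it through llm's Python API. The proxy URL, key, and prompt below are placeholders, not values from this gist:

import os
import llm

# Placeholder endpoint and key; point these at your actual proxy
os.environ["OPENAI_API_BASE"] = "https://proxy.example.com/v1"
os.environ["OPENAI_API_KEY"] = "sk-placeholder"

model = llm.get_model("openai_proxy")
# Declared options (model, temperature, max_tokens, timeout) pass as keyword arguments
response = model.prompt("Say hello in five words", temperature=0.2)
print(response.text())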