llm_openai_proxy.py
import llm
import urllib.request
import urllib.error
import json
import os
from typing import List, Dict


@llm.hookimpl
def register_models(register):
    register(OpenAIProxy())

class OpenAIProxy(llm.Model):
    model_id = "openai_proxy"
    can_stream = True
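    # Note (assumption, not in the original gist): for prompt.options to
    # actually carry these fields, the llm plugin API normally declares them
    # with a nested pydantic class. A hypothetical declaration matching the
    # options this model reads would be:
    #
    #     class Options(llm.Options):
    #         model: Optional[str] = None
    #         temperature: Optional[float] = None
    #         max_tokens: Optional[int] = None
    #         timeout: Optional[float] = None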

    def __init__(self):
        # Define default option values
        self.default_model = "anthropic:claude-3-7-sonnet"
        self.default_timeout = 30.0

    def validate_options(self, options):
        """Validate the provided options."""
        errors = []
        # Validate temperature if provided
        if hasattr(options, 'temperature') and options.temperature is not None:
            if not 0 <= options.temperature <= 2:
                errors.append("temperature must be between 0 and 2")
        # Validate max_tokens if provided
        if hasattr(options, 'max_tokens') and options.max_tokens is not None:
            if options.max_tokens < 1:
                errors.append("max_tokens must be at least 1")
        # Validate timeout if provided
        if hasattr(options, 'timeout') and options.timeout is not None:
            if options.timeout <= 0:
                errors.append("timeout must be greater than 0")
        return errors

    def execute(self, prompt, stream, response, conversation):
        """Execute the OpenAI proxy model with the given prompt."""
        api_base_url = os.environ.get("OPENAI_API_BASE")
        if not api_base_url:
            error_message = "Error: OPENAI_API_BASE environment variable is not set"
            response.response_json = {"error": error_message}
            return [error_message]
        url = f"{api_base_url}/chat/completions"
        # Check for API key
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            error_message = "Error: OPENAI_API_KEY environment variable is not set"
            response.response_json = {"error": error_message}
            return [error_message]
        # Validate options
        option_errors = self.validate_options(prompt.options)
        if option_errors:
            error_message = f"Option validation errors: {', '.join(option_errors)}"
            response.response_json = {"error": error_message}
            return [error_message]
        # Get model and timeout from options, falling back to the defaults
        # when the option is missing or explicitly None
        model = getattr(prompt.options, 'model', None) or self.default_model
        timeout = getattr(prompt.options, 'timeout', None) or self.default_timeout
        # Prepare messages from conversation or prompt
        messages = self._prepare_messages(prompt, conversation)
        # Prepare request payload
        data = {
            "model": model,
            "messages": messages
        }
        # Add optional parameters if provided
        if hasattr(prompt.options, 'temperature') and prompt.options.temperature is not None:
            data["temperature"] = prompt.options.temperature
        if hasattr(prompt.options, 'max_tokens') and prompt.options.max_tokens is not None:
            data["max_tokens"] = prompt.options.max_tokens
        # Set streaming mode if requested
        if stream:
            data["stream"] = True
        # Store request details in response_json for logging
        response.response_json = {
            "request": {
                "model": data["model"],
                "messages_count": len(messages)
            }
        }
        try:
            # Convert data to JSON
            data_json = json.dumps(data).encode('utf-8')
            # Create request
            req = urllib.request.Request(url)
            req.add_header('Content-Type', 'application/json')
            req.add_header('Authorization', f'Bearer {api_key}')
            # Handle streaming vs non-streaming responses
            if stream:
                return self._handle_streaming_response(req, data_json, timeout, response)
            else:
                return self._handle_non_streaming_response(req, data_json, timeout, response)
        except Exception as e:
            error_message = f"Error: {str(e)}"
            response.response_json["error"] = error_message
            return [error_message]

    def _prepare_messages(self, prompt, conversation) -> List[Dict[str, str]]:
        """Convert conversation or prompt to OpenAI message format."""
        messages = []
        # If we have a conversation with messages, use those
        if conversation and hasattr(conversation, 'responses') and conversation.responses:
            # Add system prompt if available
            if prompt.system:
                messages.append({"role": "system", "content": prompt.system})
            # Add previous messages from conversation
            for prev_response in conversation.responses:
                # Add user message
                messages.append({"role": "user", "content": prev_response.prompt.prompt})
                # Add assistant response
                messages.append({"role": "assistant", "content": prev_response.text()})
            # Add the current prompt as the final user message
            messages.append({"role": "user", "content": prompt.prompt})
        else:
            # If no conversation, add system prompt if available
            if prompt.system:
                messages.append({"role": "system", "content": prompt.system})
            # Add the current prompt as a user message
            messages.append({"role": "user", "content": prompt.prompt})
        return messages
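    # For reference, this produces the standard OpenAI chat payload shape,
    # e.g. for a fresh prompt with a system message:
    #
    #     [{"role": "system", "content": "Be concise."},
    #      {"role": "user", "content": "What is llm?"}]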

    def _handle_non_streaming_response(self, req, data_json, timeout, response):
        """Handle a non-streaming response from the API."""
        try:
            with urllib.request.urlopen(req, data_json, timeout=timeout) as response_obj:
                response_data = response_obj.read().decode('utf-8')
                result = json.loads(response_data)
                # Store API response details in response_json for logging
                response.response_json["api_response"] = {
                    "id": result.get("id"),
                    "model": result.get("model"),
                    "usage": result.get("usage")
                }
                # Extract content
                if 'choices' in result and result['choices'] and 'message' in result['choices'][0]:
                    content = result['choices'][0]['message'].get('content', '')
                    return [content]
                else:
                    error = f"Unexpected response format: {json.dumps(result)}"
                    response.response_json["error"] = error
                    return [error]
        except urllib.error.URLError as e:
            error = f"API request failed: {str(e)}"
            response.response_json["error"] = error
            return [error]
        except json.JSONDecodeError as e:
            error = f"Failed to parse API response: {str(e)}"
            response.response_json["error"] = error
            return [error]

    def _handle_streaming_response(self, req, data_json, timeout, response):
        """Handle a streaming (server-sent events) response from the API."""
        try:
            with urllib.request.urlopen(req, data_json, timeout=timeout) as response_obj:
                buffer = ""
                # Each SSE data line looks like:
                #   data: {"choices": [{"delta": {"content": "Hel"}}], ...}
                # and the stream is terminated by a final "data: [DONE]" line.
                for line in response_obj:
                    line = line.decode('utf-8').strip()
                    # Skip empty lines
                    if not line:
                        continue
                    # Skip "data: [DONE]" message
                    if line == "data: [DONE]":
                        continue
                    # Process data line
                    if line.startswith("data: "):
                        try:
                            data = json.loads(line[6:])  # Remove "data: " prefix
                            # Extract delta content if available
                            if ('choices' in data and data['choices'] and
                                    'delta' in data['choices'][0] and
                                    'content' in data['choices'][0]['delta']):
                                content = data['choices'][0]['delta']['content']
                                yield content
                                buffer += content
                        except json.JSONDecodeError:
                            # Skip malformed JSON
                            continue
                # Store final content length in response_json for logging
                response.response_json["content_length"] = len(buffer)
        except Exception as e:
            error = f"Streaming error: {str(e)}"
            response.response_json["error"] = error
            yield error

pyproject.toml
[project]
name = "openai-proxy"
version = "0.1"

[project.entry-points.llm]
openai_proxy = "llm_openai_proxy"
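
A minimal usage sketch, assuming the plugin is installed into the same environment as llm (for example with `llm install -e .` from the project directory) and that OPENAI_API_BASE and OPENAI_API_KEY point at the proxy. From the command line the model is then available as `llm -m openai_proxy "..."`; from Python:

import llm

model = llm.get_model("openai_proxy")

# Non-streaming: collect the full response text
print(model.prompt("Say hello").text())

# Streaming: iterate over chunks as they arrive
for chunk in model.prompt("Write a haiku about proxies"):
    print(chunk, end="", flush=True)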