@tail-call
Created August 19, 2025 08:13
LSW.py and Friends
import json
import warnings
from typing import Any

from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.prompt_values import ChatPromptValue

from sammo.base import Costs, LLMResult
from sammo.runners import BaseRunner
from sammo.schemas import JsonSchema
from sammo.utils import serialize_json


class LangchainSammoWrapper(BaseRunner):
    """A SAMMO runner that wraps a LangChain ``BaseLanguageModel`` instance.

    This class is designed as a drop-in replacement for ``sammo.runners.OpenAIChat``,
    allowing any LangChain-compatible model to be used within the SAMMO framework.

    Usage::

        >>> from langchain_openai import ChatOpenAI
        >>> from sammo.base import Template
        >>> from sammo.components import GenerateText, Output
        >>>
        >>> # 1. Initialize your LangChain model
        >>> lc_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)
        >>>
        >>> # 2. Wrap it with LangchainSammoWrapper
        >>> sammo_runner = LangchainSammoWrapper(langchain_llm=lc_model)
        >>>
        >>> # 3. Use it in your SAMMO workflow
        >>> output = Output(GenerateText(Template("Translate to French: {{input}}")))
        >>> result = output.run(sammo_runner, ["Hello, world!"])
    """
    RETRY_ERRORS = ()  # Start with no specific retry errors; can be customized.

    def __init__(self, langchain_llm: BaseLanguageModel, model_id: str | None = None, **kwargs: Any):
        """Initializes the wrapper.

        :param langchain_llm: An instantiated LangChain BaseLanguageModel.
        :param model_id: An optional model identifier for SAMMO's caching. If None,
            it is inferred from the LangChain model.
        :param kwargs: Additional arguments passed to SAMMO's `BaseRunner`,
            such as `cache`, `rate_limit`, or `max_retries`.
        """
        if not isinstance(langchain_llm, BaseLanguageModel):
            raise TypeError("langchain_llm must be an instance of LangChain's BaseLanguageModel")
        self.langchain_llm = langchain_llm
        # BaseRunner requires api_config, but we don't need one. Provide a dummy dict.
        kwargs.setdefault("api_config", {})
        # Infer model_id if not provided, for SAMMO's equivalence-class logic.
        if model_id is None:
            if hasattr(langchain_llm, "model_name"):
                model_id = str(langchain_llm.model_name)
            elif hasattr(langchain_llm, "model"):
                model_id = str(langchain_llm.model)
            else:
                model_id = langchain_llm.__class__.__name__
        super().__init__(model_id=model_id, **kwargs)
    async def generate_text(
        self,
        prompt: str,
        max_tokens: int | None = None,
        randomness: float | None = 0,
        seed: int = 0,
        priority: int = 0,
        system_prompt: str | None = None,
        history: list[dict] | None = None,
        json_mode: bool | JsonSchema = False,
    ) -> LLMResult:
        """Generates text using the wrapped LangChain model, matching SAMMO's runner interface."""
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        if history:
            messages.extend(history)
        messages.append({"role": "user", "content": prompt})
        request = {
            "messages": messages,
            "randomness": randomness,
            "max_tokens": max_tokens,
            "json_mode": json_mode,
        }
        # The fingerprint is used by SAMMO as the cache key, so it must contain
        # every parameter that influences the generation result.
        fingerprint = serialize_json({"seed": seed, "generative_model_id": self._equivalence_class, **request})
        return await self._execute_request(request, fingerprint, priority)
    async def _call_backend(self, request: dict) -> dict:
        """Calls the LangChain model. This method is invoked by `_execute_request`.

        It translates the SAMMO-style request into LangChain inputs and formats the
        LangChain output into a standardized dictionary for `_to_llm_result`.
        """
        lc_messages = self._to_langchain_messages(request["messages"])
        # LangChain forwards per-call kwargs to the underlying client; not every
        # provider honors them, in which case set these on the model instance instead.
        model_kwargs = {"temperature": request["randomness"]}
        if request["max_tokens"] is not None:
            model_kwargs["max_tokens"] = request["max_tokens"]
        if request.get("json_mode"):
            warnings.warn(
                "json_mode is not fully supported in LangchainSammoWrapper. "
                "Attempting to pass a 'response_format' kwarg to the model, which may not be supported."
            )
            model_kwargs["response_format"] = {"type": "json_object"}
        prompt_value = ChatPromptValue(messages=lc_messages)
        langchain_result = await self.langchain_llm.agenerate_prompt(
            prompts=[prompt_value], **model_kwargs
        )
        generation = langchain_result.generations[0][0]
        llm_output = langchain_result.llm_output or {}
        token_usage = llm_output.get("token_usage", {})
        # Fall back to counting tokens locally if the model's output omits usage data.
        prompt_tokens = token_usage.get("prompt_tokens")
        if prompt_tokens is None:
            prompt_tokens = self.langchain_llm.get_num_tokens_from_messages(lc_messages)
        completion_tokens = token_usage.get("completion_tokens")
        if completion_tokens is None:
            completion_tokens = self.langchain_llm.get_num_tokens(generation.text)
        # Mimic the OpenAI-style structure that _to_llm_result and _extract_costs expect.
        return {
            "choices": [{"message": {"role": "assistant", "content": generation.text}}],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
            },
        }
    def _to_llm_result(self, request: dict, json_data: dict, fingerprint: str | bytes) -> LLMResult:
        """Converts the dictionary from `_call_backend` into a SAMMO `LLMResult` object."""
        llm_response_text = json_data["choices"][0]["message"]["content"]
        # Attempt to parse JSON if json_mode was requested.
        if request.get("json_mode"):
            try:
                llm_response_obj = json.loads(llm_response_text)
            except json.JSONDecodeError:
                # If the model fails to produce valid JSON, return the raw string.
                llm_response_obj = llm_response_text
        else:
            llm_response_obj = llm_response_text
        return LLMResult(
            llm_response_obj,
            history=request["messages"] + [json_data["choices"][0]["message"]],
            costs=self._extract_costs(json_data),
            request_text=request["messages"][-1]["content"],
        )
    @staticmethod
    def _extract_costs(json_data: dict) -> Costs:
        """Extracts token counts into a SAMMO `Costs` object."""
        return Costs(
            input_costs=json_data["usage"].get("prompt_tokens", 0),
            output_costs=json_data["usage"].get("completion_tokens", 0),
            reasoning_costs=0,  # Not reported by most LangChain models this way.
        )
    @staticmethod
    def _to_langchain_messages(messages: list[dict]) -> list[BaseMessage]:
        """Converts SAMMO message history into LangChain's message format.

        Messages with unrecognized roles are silently dropped.
        """
        lc_messages = []
        for msg in messages:
            role = msg.get("role")
            content = str(msg.get("content", ""))
            if role == "user":
                lc_messages.append(HumanMessage(content=content))
            elif role == "system":
                lc_messages.append(SystemMessage(content=content))
            elif role == "assistant":
                lc_messages.append(AIMessage(content=content))
        return lc_messages
    @classmethod
    def _get_equivalence_class(cls, model_id: str) -> str:
        """Provides a coarse equivalence class for caching, similar to OpenAIBaseRunner.

        Customize this for more specific model families if needed.
        """
        model_id = model_id.lower()
        if "gpt-4" in model_id:
            return "gpt-4"
        if "gpt-3.5" in model_id or "gpt-3" in model_id:
            return "gpt-3.5"
        if "claude-3" in model_id:
            return "claude-3"
        if "gemini" in model_id:
            return "gemini"
        return model_id
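

# A hedged smoke-test sketch, not part of the original gist. It assumes that
# langchain_core's FakeListChatModel test double is available and that SAMMO's
# BaseRunner wires up `_execute_request` the way the built-in OpenAI runners do.
# Note: the token-count fallback in `_call_backend` uses LangChain's default
# tokenizer, which requires the `transformers` package to be installed.
if __name__ == "__main__":
    import asyncio

    from langchain_core.language_models.fake_chat_models import FakeListChatModel

    async def _smoke_test():
        # FakeListChatModel replays canned responses, so this exercises the
        # wrapper end to end without an API key or network access.
        fake_llm = FakeListChatModel(responses=["Bonjour, le monde!"])
        runner = LangchainSammoWrapper(langchain_llm=fake_llm)
        result = await runner.generate_text("Translate to French: Hello, world!")
        print(result)

    asyncio.run(_smoke_test())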

@tail-call (Author)

Warning: LLM slop, never tested it