@ehzawad
Created March 12, 2025 09:35
# SPDX-License-Identifier: Apache-2.0
import json
import random
import string
import torch
from vllm import LLM
from vllm.sampling_params import SamplingParams
# Much smaller models (under 3B parameters)
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Only 1.1B parameters
# Alternative ultra-small options:
# model_name = "microsoft/Phi-1.5" # Only 1.3B parameters
# model_name = "stabilityai/stablelm-2-1_6b" # 1.6B parameters
# Very conservative token allocation
sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
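# Note: in vLLM, temperature=0.0 selects greedy (deterministic) decoding.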
# Initialize LLM with maximum memory savings
llm = LLM(
    model=model_name,
    tokenizer_mode="auto",
    gpu_memory_utilization=0.7,  # Very conservative memory usage
    # quantization="bitsandbytes",  # Use bitsandbytes instead of int8
    max_model_len=1024,  # Very limited context length
    dtype="half",  # Use half precision (fp16) for additional memory savings
)
def generate_random_id(length=9):
    characters = string.ascii_letters + string.digits
    random_id = ''.join(random.choice(characters) for _ in range(length))
    return random_id
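# Illustrative example (hypothetical output): generate_random_id() might return
# a string like "aB3xY9kQ2" -- nine random letters/digits, used below as a
# tool_call_id.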
# simulate an API that can be called
def get_current_weather(city: str, state: str, unit: str):
    return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
            "partly cloudy, with highs in the 90's.")
tool_functions = {"get_current_weather": get_current_weather}
tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type": "string",
                    "description": "The two-letter abbreviation for the state that the city is in, e.g. 'CA' for 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]
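# For reference, the parsing step in the try block below assumes the model
# replies with a JSON list of tool calls roughly shaped like this
# (illustrative example; very small models often fail to emit this format):
# [{"name": "get_current_weather",
#   "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]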
# Add a simple system prompt to help the smaller model understand the task
messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that uses function calling to retrieve information."
    },
    {
        "role": "user",
        "content": "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
    }
]
# Aggressively clear cache
torch.cuda.empty_cache()
print("Loading model and running inference...")
try:
    outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
    output = outputs[0].outputs[0].text.strip()
    print("\nModel output (first pass):")
    print(output)

    # append the assistant message
    messages.append({
        "role": "assistant",
        "content": output,
    })

    # Clear GPU cache
    torch.cuda.empty_cache()

    # Try to parse and execute the model's output
    try:
        tool_calls = json.loads(output)
        print("\nParsed tool calls successfully")
        tool_answers = [
            tool_functions[call['name']](**call['arguments']) for call in tool_calls
        ]

        # append the answer as a tool message
        messages.append({
            "role": "tool",
            "content": "\n\n".join(tool_answers),
            "tool_call_id": generate_random_id(),
        })

        # Clear GPU cache again
        torch.cuda.empty_cache()

        print("\nRunning second inference...")
        outputs = llm.chat(messages, sampling_params, tools=tools)
        print("\nFinal result:")
        print(outputs[0].outputs[0].text.strip())
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        print(f"\nError processing model output: {str(e)}")
        print("Raw output was:")
        print(output)
        print("\nNote: Very small models may not properly support function calling format.")

        # Try a simple prompt to get a response without function calling
        simple_messages = [
            {
                "role": "user",
                "content": "What's the weather like in Dallas? Just say it's 85 degrees fahrenheit, partly cloudy with highs in the 90s."
            }
        ]

        torch.cuda.empty_cache()
        print("\nTrying simple prompt instead...")
        simple_outputs = llm.chat(simple_messages, sampling_params)
        print("\nSimple response:")
        print(simple_outputs[0].outputs[0].text.strip())
except Exception as e:
    print(f"Error running model: {str(e)}")
    print("Try reducing parameters further or using an even smaller model.")