# SPDX-License-Identifier: Apache-2.0
import json
import random
import string

import torch
from vllm import LLM
from vllm.sampling_params import SamplingParams

# Much smaller models (under 3B parameters)
model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # Only 1.1B parameters
# Alternative ultra-small options:
# model_name = "microsoft/Phi-1.5"  # Only 1.3B parameters
# model_name = "stabilityai/stablelm-2-1_6b"  # 1.6B parameters

# Very conservative token allocation
sampling_params = SamplingParams(max_tokens=512, temperature=0.0)

# Initialize LLM with maximum memory savings
llm = LLM(model=model_name,
          tokenizer_mode="auto",
          gpu_memory_utilization=0.7,  # Very conservative memory usage
          # quantization="bitsandbytes",  # Use bitsandbytes instead of int8
          max_model_len=1024,  # Very limited context length
          dtype="half")  # Use half precision (fp16) for additional memory savings


def generate_random_id(length=9):
    characters = string.ascii_letters + string.digits
    random_id = ''.join(random.choice(characters) for _ in range(length))
    return random_id
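
# Note: generate_random_id only fabricates a tool_call_id for the "tool" message
# appended later; there is no real call id to match in this script.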


# Simulate an API that can be called
def get_current_weather(city: str, state: str, unit: str):
    return (f"The weather in {city}, {state} is 85 degrees {unit}. It is "
            "partly cloudy, with highs in the 90's.")


tool_functions = {"get_current_weather": get_current_weather}

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to find the weather for, e.g. 'San Francisco'"
                },
                "state": {
                    "type": "string",
                    "description": "The two-letter abbreviation for the state that the city is in, e.g. 'CA' which would mean 'California'"
                },
                "unit": {
                    "type": "string",
                    "description": "The unit to fetch the temperature in",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["city", "state", "unit"]
        }
    }
}]
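
# The first pass is expected to emit the tool call(s) as a plain JSON list, since
# that is what the json.loads + dispatch step below assumes. Illustrative example:
# [{"name": "get_current_weather",
#   "arguments": {"city": "Dallas", "state": "TX", "unit": "fahrenheit"}}]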

# Add a simple system prompt to help the smaller model understand the task
messages = [
    {
        "role": "system",
        "content": "You are a helpful assistant that uses function calling to retrieve information."
    },
    {
        "role": "user",
        "content": "Can you tell me what the temperature will be in Dallas, in fahrenheit?"
    }
]

# Aggressively clear cache
torch.cuda.empty_cache()

print("Loading model and running inference...")

try:
    outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
    output = outputs[0].outputs[0].text.strip()
    print("\nModel output (first pass):")
    print(output)

    # Append the assistant message
    messages.append({
        "role": "assistant",
        "content": output,
    })

    # Clear GPU cache
    torch.cuda.empty_cache()

    # Try to parse and execute the model's output
    try:
        tool_calls = json.loads(output)
        print("\nParsed tool calls successfully")
        tool_answers = [
            tool_functions[call['name']](**call['arguments']) for call in tool_calls
        ]

        # Append the answer as a tool message
        messages.append({
            "role": "tool",
            "content": "\n\n".join(tool_answers),
            "tool_call_id": generate_random_id(),
        })

        # Clear GPU cache again
        torch.cuda.empty_cache()

        print("\nRunning second inference...")
        outputs = llm.chat(messages, sampling_params=sampling_params, tools=tools)
        print("\nFinal result:")
        print(outputs[0].outputs[0].text.strip())
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        print(f"\nError processing model output: {str(e)}")
        print("Raw output was:")
        print(output)
        print("\nNote: Very small models may not properly support function calling format.")

        # Try a simple prompt to get a response without function calling
        simple_messages = [
            {
                "role": "user",
                "content": "What's the weather like in Dallas? Just say it's 85 degrees fahrenheit, partly cloudy with highs in the 90s."
            }
        ]

        torch.cuda.empty_cache()
        print("\nTrying simple prompt instead...")
        simple_outputs = llm.chat(simple_messages, sampling_params=sampling_params)
        print("\nSimple response:")
        print(simple_outputs[0].outputs[0].text.strip())
except Exception as e:
    print(f"Error running model: {str(e)}")
    print("Try reducing parameters further or using an even smaller model.")