Created
August 18, 2024 22:38
Verification of Fix for JSON Serialization Error in Meta-Llama-3.1-8B-Instruct Model when using tool_calls
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Usage example of Tool Call on the Llama 3.1 model\n",
    "Code copied and modified from https://huggingface.co/docs/transformers/main/chat_templating#a-complete-tool-use-example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Important**\n",
    "\n",
    "A documented issue (see the [discussion](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct/discussions/104) on the HuggingFace Meta-Llama-3.1-8B-Instruct page) can trigger a `TypeError: Object of type Undefined is not JSON serializable` when the tokenizer is applied to messages containing `tool_calls`. To address this, you need to modify the `tokenizer_config.json` file associated with the model.\n",
    "\n",
    "The configuration file can typically be found on your system at:\n",
    "\n",
    "```\n",
    "~/.cache/huggingface/hub/[model_name]/snapshots/[snapshot_id]\n",
    "```\n",
    "\n",
    "Steps to fix the issue:\n",
    "\n",
    "1. Open the `tokenizer_config.json` file.\n",
    "2. Locate the `chat_template` entry within the file.\n",
    "3. In the template string, find all instances of the word \"parameters\".\n",
    "4. Replace each occurrence of \"parameters\" with \"arguments\". There should be three such replacements. (A script for this is sketched in the next cell.)"
   ]
  },
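  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you prefer not to edit the file by hand, the cell below sketches the same replacement in Python. It is illustrative only: it assumes the default HuggingFace cache layout, and the `snapshot_id` placeholder must be filled in from your local download."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch of the manual fix above; assumes the default HF cache layout.\n",
    "import json\n",
    "from pathlib import Path\n",
    "\n",
    "snapshot_id = \"...\"  # placeholder: copy from the model's snapshots/ folder\n",
    "config_path = Path(\n",
    "    \"~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3.1-8B-Instruct\"\n",
    "    f\"/snapshots/{snapshot_id}/tokenizer_config.json\"\n",
    ").expanduser()\n",
    "\n",
    "config = json.loads(config_path.read_text(encoding=\"utf-8\"))\n",
    "# Replace every occurrence, matching steps 3-4 above (there should be three).\n",
    "config[\"chat_template\"] = config[\"chat_template\"].replace(\"parameters\", \"arguments\")\n",
    "config_path.write_text(json.dumps(config, indent=2), encoding=\"utf-8\")"
   ]
  },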
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\zhtge\\miniconda3\\envs\\torch\\Lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 4/4 [00:10<00:00, 2.66s/it]\n"
     ]
    }
   ],
   "source": [
    "# From https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct\n",
    "\n",
    "model_id = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map=\"auto\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_current_temperature(location: str, unit: str) -> float:\n",
    "    \"\"\"\n",
    "    Get the current temperature at a location.\n",
    "\n",
    "    Args:\n",
    "        location: The location to get the temperature for, in the format \"City, Country\"\n",
    "        unit: The unit to return the temperature in. (choices: [\"celsius\", \"fahrenheit\"])\n",
    "    Returns:\n",
    "        The current temperature at the specified location in the specified units, as a float.\n",
    "    \"\"\"\n",
    "    return 22.  # A real function should probably actually get the temperature!\n",
    "\n",
    "def get_current_wind_speed(location: str) -> float:\n",
    "    \"\"\"\n",
    "    Get the current wind speed in km/h at a given location.\n",
    "\n",
    "    Args:\n",
    "        location: The location to get the wind speed for, in the format \"City, Country\"\n",
    "    Returns:\n",
    "        The current wind speed at the given location in km/h, as a float.\n",
    "    \"\"\"\n",
    "    return 6.  # A real function should probably actually get the wind speed!\n",
    "\n",
    "tools = [get_current_temperature, get_current_wind_speed]"
   ]
  },
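  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Under the hood, `apply_chat_template` converts each function's signature and docstring into a JSON schema for the template. The optional check below uses `transformers.utils.get_json_schema` (mentioned in the same chat templating docs; available in recent transformers releases) to verify the docstrings parsed as intended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: inspect the schema the chat template will see.\n",
    "import json\n",
    "from transformers.utils import get_json_schema\n",
    "\n",
    "print(json.dumps(get_json_schema(get_current_temperature), indent=2))"
   ]
  },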
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\"role\": \"system\", \"content\": \"You are a bot that responds to weather queries. You should reply with the unit used in the queried location.\"},\n",
    "    {\"role\": \"user\", \"content\": \"Hey, what's the temperature in Paris right now?\"}\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
      "c:\\Users\\zhtge\\miniconda3\\envs\\torch\\Lib\\site-packages\\transformers\\models\\llama\\modeling_llama.py:660: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\cb\\pytorch_1000000000000\\work\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:555.)\n",
      " attn_output = torch.nn.functional.scaled_dot_product_attention(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<|python_tag|>{\"name\": \"get_current_temperature\", \"arguments\": {\"location\": \"Paris, France\", \"unit\": \"celsius\"}}<|eom_id|>\n"
     ]
    }
   ],
   "source": [
    "inputs = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, return_dict=True, return_tensors=\"pt\")\n",
    "inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
    "out = model.generate(**inputs, max_new_tokens=128)\n",
    "print(tokenizer.decode(out[0][len(inputs[\"input_ids\"][0]):]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"name\": \"get_current_temperature\", \"arguments\": {\"location\": \"Paris, France\", \"unit\": \"celsius\"}}'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_str = tokenizer.decode(out[0][len(inputs[\"input_ids\"][0]):], skip_special_tokens=True)\n",
    "out_str"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'role': 'system',\n",
       "  'content': 'You are a bot that responds to weather queries. You should reply with the unit used in the queried location.'},\n",
       " {'role': 'user',\n",
       "  'content': \"Hey, what's the temperature in Paris right now?\"},\n",
       " {'role': 'assistant',\n",
       "  'tool_calls': [{'type': 'function',\n",
       "    'function': {'name': 'get_current_temperature',\n",
       "     'arguments': {'location': 'Paris, France', 'unit': 'celsius'}}}]}]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import json\n",
    "# tool_call = {\"name\": \"get_current_temperature\", \"arguments\": {\"location\": \"Paris, France\", \"unit\": \"celsius\"}}\n",
    "tool_call = json.loads(out_str)\n",
    "messages.append({\"role\": \"assistant\", \"tool_calls\": [{\"type\": \"function\", \"function\": tool_call}]})\n",
    "messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'role': 'system',\n",
       "  'content': 'You are a bot that responds to weather queries. You should reply with the unit used in the queried location.'},\n",
       " {'role': 'user',\n",
       "  'content': \"Hey, what's the temperature in Paris right now?\"},\n",
       " {'role': 'assistant',\n",
       "  'tool_calls': [{'type': 'function',\n",
       "    'function': {'name': 'get_current_temperature',\n",
       "     'arguments': {'location': 'Paris, France', 'unit': 'celsius'}}}]},\n",
       " {'role': 'tool', 'name': 'get_current_temperature', 'content': '22.0'}]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# messages.append({\"role\": \"tool\", \"name\": \"get_current_temperature\", \"content\": \"22.0\"})\n",
    "tool_call_result = globals()[tool_call[\"name\"]](**tool_call[\"arguments\"])\n",
    "messages.append({\"role\": \"tool\", \"name\": tool_call[\"name\"], \"content\": str(tool_call_result)})\n",
    "messages"
   ]
  },
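  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the dispatch above: `globals()[tool_call[\"name\"]]` will call whatever name the model emits. A slightly safer alternative (a sketch, not part of the original example) is an explicit registry built from the `tools` list:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Safer dispatch: only functions registered as tools can be invoked.\n",
    "tool_registry = {fn.__name__: fn for fn in tools}\n",
    "\n",
    "def run_tool_call(call: dict) -> str:\n",
    "    fn = tool_registry.get(call[\"name\"])\n",
    "    if fn is None:\n",
    "        raise ValueError(f\"Model requested unknown tool: {call['name']}\")\n",
    "    return str(fn(**call[\"arguments\"]))\n",
    "\n",
    "# e.g. run_tool_call(tool_call) returns '22.0'"
   ]
  },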
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The temperature in Paris, France is 22.0 degrees Celsius.<|eot_id|>\n"
     ]
    }
   ],
   "source": [
    "inputs = tokenizer.apply_chat_template(messages, tools=tools, add_generation_prompt=True, return_dict=True, return_tensors=\"pt\")\n",
    "inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
    "out = model.generate(**inputs, max_new_tokens=128)\n",
    "print(tokenizer.decode(out[0][len(inputs[\"input_ids\"][0]):]))"
   ]
  },
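  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The request/execute/respond steps above can be folded into a single loop. The sketch below is one way to wire it together under this notebook's assumptions: one tool call per assistant turn, signalled by Llama 3.1's `<|eom_id|>` token (visible in the output of cell 5), and the hypothetical `run_tool_call` helper from the registry cell. Treat it as a sketch rather than a general-purpose agent loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# End-to-end sketch combining the cells above.\n",
    "# Assumes `run_tool_call` from the registry cell and `json` imported earlier.\n",
    "def chat_with_tools(messages, max_new_tokens=128):\n",
    "    while True:\n",
    "        inputs = tokenizer.apply_chat_template(\n",
    "            messages, tools=tools, add_generation_prompt=True,\n",
    "            return_dict=True, return_tensors=\"pt\")\n",
    "        inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
    "        out = model.generate(**inputs, max_new_tokens=max_new_tokens)\n",
    "        new_tokens = out[0][len(inputs[\"input_ids\"][0]):]\n",
    "        raw = tokenizer.decode(new_tokens)\n",
    "        text = tokenizer.decode(new_tokens, skip_special_tokens=True)\n",
    "        if \"<|eom_id|>\" not in raw:  # no tool call: this is the final answer\n",
    "            return text\n",
    "        call = json.loads(text)  # the JSON emitted after <|python_tag|>\n",
    "        messages.append({\"role\": \"assistant\",\n",
    "                         \"tool_calls\": [{\"type\": \"function\", \"function\": call}]})\n",
    "        messages.append({\"role\": \"tool\", \"name\": call[\"name\"],\n",
    "                         \"content\": run_tool_call(call)})"
   ]
  },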
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}