Created
May 1, 2025 13:55
-
-
Save gante/a080c70c650bec5bfd451348ce6bbc50 to your computer and use it in GitHub Desktop.
generate - Check that there is no randomness associated with launching new processes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Check that there is no randomness associated with launching new processes.
# Run with `while true; do python this_script.py; done` and compare the lines
# printed by each run: they should be identical from one process to the next.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed

MODEL_ID = "Qwen/Qwen3-4B"
PROMPT = "Here's everything I know about cats. Cats"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16)
inputs = tokenizer([PROMPT], return_tensors="pt").to(model.device)

# Seed right before sampling so that the sampled continuation is the only
# thing that could differ between process launches.
set_seed(0)
gen_out = model.generate(
    **inputs,
    do_sample=True,
    max_new_tokens=256,
    return_dict_in_generate=True,
    output_scores=True,
)
decoded = tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
print(decoded)

# Sum of the scores for the generated tokens: a single number that is easy to
# compare across runs alongside the decoded text.
input_length = inputs.input_ids.shape[1]
# Everything after the prompt in the (single) returned sequence was generated.
generated_tokens = gen_out.sequences[0, input_length:]
# `gen_out.scores` holds one entry per generation step; pick, for each step,
# the score of the token that was actually emitted for batch item 0.
token_scores = [step_scores[0][token] for step_scores, token in zip(gen_out.scores, generated_tokens)]
print(sum(token_scores))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment