Last active: March 29, 2024 23:54
-
-
Save catid/47eda7bd667b4a744697d93e1509089f to your computer and use it in GitHub Desktop.
DBRX on 3x 3090 GPUs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run DBRX Instruct in 4-bit quantization, sharded across all visible GPUs
# (e.g. 3x RTX 3090).
#
# Environment setup:
#   conda create -n dbrx python=3.10 -y && conda activate dbrx
#   pip install torch transformers tiktoken flash_attn bitsandbytes

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Community re-upload of DBRX Instruct with quantization-compatible weights.
MODEL_ID = "SinclairSchneider/dbrx-instruct-quantization-fixed"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# The bare `load_in_4bit=True` kwarg is deprecated in transformers; pass an
# explicit BitsAndBytesConfig instead. device_map="auto" lets accelerate
# shard the layers across every visible GPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]

# Tokenize through the model's chat template. With device_map="auto" the
# embedding layer sits on the first GPU, so "cuda" (= cuda:0) is the right
# destination for the inputs.
input_ids = tokenizer.apply_chat_template(
    messages,
    return_dict=True,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
).to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment