@redgeoff · Last active July 16, 2023 15:05

mpt-7b-chat.py
    # !pip install -qU transformers accelerate einops langchain xformers

    from torch import cuda, bfloat16
    from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

    device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
    print(f"device={device}")

    # Initialize the tokenizer and the model
    tokenizer = AutoTokenizer.from_pretrained("mosaicml/mpt-7b-chat", trust_remote_code=True)

# Initialize the model with Triton optimization. This is supposed to speed the
# model up at the cost of using more memory, but I haven't been able to get it
# to work yet.
# optimize = False
#
# if optimize:
#     config = AutoConfig.from_pretrained(
#         'mosaicml/mpt-7b-chat',
#         trust_remote_code=True
#     )
#     config.attn_config['attn_impl'] = 'triton'
#     # config.update({"init_device": "meta"})  # This causes an issue when calling model.to(device)
#     config.update({"max_seq_len": 100})
# else:
#     config = {"init_device": "meta"}

    config={"init_device": "meta"}

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-chat",
    trust_remote_code=True,
    config=config,
    torch_dtype=bfloat16,
)

    print('loaded')

    # tokenizer.eval() # fails!
    # tokenizer.to(device)
    # model.eval() # TODO: needed?
    model.to(device)
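
# A possible alternative (an untested sketch): accelerate is installed above,
# so the weights could instead be placed on the GPU at load time with
# device_map="auto", skipping the separate model.to(device) step:
#
# model = AutoModelForCausalLM.from_pretrained(
#     "mosaicml/mpt-7b-chat",
#     trust_remote_code=True,
#     torch_dtype=bfloat16,
#     device_map="auto",
# )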

import time

def ask_question(question, max_length=100):
    start_time = time.time()

    # Encode the question
    input_ids = tokenizer.encode(question, return_tensors='pt')

    input_ids = input_ids.to(device)
    # input_ids = input_ids.to('cuda')

    # mpt-7b is trained to add "<|endoftext|>" at the end of generations
    stop_token_ids = tokenizer.convert_tokens_to_ids(["<|endoftext|>"])

    import torch
    from transformers import StoppingCriteria, StoppingCriteriaList

    # Define a custom stopping criteria object that halts generation as soon
    # as the last generated token is one of the stop tokens
    class StopOnTokens(StoppingCriteria):
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
            for stop_id in stop_token_ids:
                if input_ids[0][-1] == stop_id:
                    return True
            return False

    stopping_criteria = StoppingCriteriaList([StopOnTokens()])
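
    # A simpler alternative (a sketch, assuming "<|endoftext|>" is also the
    # tokenizer's registered EOS token), letting generate() stop on its own:
    #
    # output = model.generate(input_ids, max_length=max_length,
    #                         eos_token_id=tokenizer.eos_token_id)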

    # Generate a response. Note: temperature only takes effect when sampling
    # is enabled (do_sample=True); greedy decoding ignores it.
    output = model.generate(
        input_ids,
        max_length=max_length,
        # max_length=1000,
        temperature=0.9,

        # pad_token_id=stop_token_ids[0],
        # num_return_sequences=1,

        stopping_criteria=stopping_criteria,

        # top_p=0.15,  # select from top tokens whose probabilities add up to 15%
        # top_k=0,  # select from top 0 tokens (because zero, relies on top_p)
        # max_new_tokens=64,  # max number of tokens to generate in the output
        # repetition_penalty=1.1  # without this, output begins repeating
    )

    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)

    end_time = time.time()
    duration = end_time - start_time

    print(response)

    print("Function duration:", duration, "seconds")

    # Ask a question
    ask_question("What is the capital of France?")
    # ask_question("Explain to me the difference between nuclear fission and fusion.", 200)
    # ask_question("write python code that converts a csv into a pdf", 400)