@m3hrdadfi
Created May 10, 2022 09:41
DialoGPT Bot
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# The gist does not show how `tokenizer` and `model` are created; loading the
# "microsoft/DialoGPT-medium" checkpoint here is an assumption for completeness.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

i = 0
maxlen = 1024

while True:
    user_input = input(">> User: ").strip()
    if user_input.lower() == "q":
        break
    # Encode the new user turn, terminated by the EOS token.
    input_ids = tokenizer(user_input + tokenizer.eos_token, return_tensors="pt").input_ids
    # After the first turn, prepend the conversation history: the previous
    # `response` already contains all earlier turns plus the model's replies.
    input_ids = torch.cat([response, input_ids], dim=-1) if i > 0 else input_ids
    response = model.generate(input_ids,
                              max_length=maxlen,
                              pad_token_id=tokenizer.eos_token_id)
    # `generate` returns the prompt followed by the continuation, so strip
    # the prompt tokens before decoding.
    input_len = input_ids.shape[-1]
    generated = tokenizer.decode(response[:, input_len:][0], skip_special_tokens=True)
    print("DialoGPT: {}".format(generated))
    i += 1
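One caveat with this loop: the concatenated history grows with every turn, and once it reaches `maxlen` tokens `generate` has no room left to produce a reply. A common workaround (not part of the original gist; `trim_history` is a hypothetical helper) is to keep only the most recent tokens. A plain-list sketch of the idea:

def trim_history(input_ids, max_tokens):
    """Keep only the most recent `max_tokens` token ids.

    Shown here on a plain Python list of ids; with the tensor in the loop
    above the equivalent slice would be input_ids[:, -max_tokens:].
    """
    if len(input_ids) <= max_tokens:
        return input_ids
    return input_ids[-max_tokens:]

# Example: a 1500-token "history" trimmed to the last 1024 ids.
history = list(range(1500))
trimmed = trim_history(history, 1024)
print(len(trimmed))   # 1024
print(trimmed[0])     # 476 -- the oldest 476 ids were dropped

Dropping whole oldest turns (split on `tokenizer.eos_token_id`) instead of raw tokens would avoid cutting a turn in half, at the cost of slightly more bookkeeping.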