class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
r"""
[`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt
by default.
In the original [paper](https://huggingface.co/papers/1909.05858), the authors suggest the use of a penalty of around
1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.

    Args:
        penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
            tokens. Between 0.0 and 1.0 rewards previously generated tokens.
        prompt_ignore_length (`int`, *optional*):
            The length of the original prompt. If provided, the first `prompt_ignore_length` tokens are excluded
            from the penalty calculation.

    Examples:

    ```py
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor
>>> # Initializing the model and tokenizer for it
>>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
>>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")
>>> # This shows a normal generate without any specific parameters
>>> summary_ids = model.generate(**inputs)
>>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
I'm not going to be able to do that. I'm going to be able to do that
>>> # This generates a penalty for repeated tokens
>>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
>>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
I'm not going to be able to do that. I'll just have to go out and play
>>> # We can also exclude the input prompt by creating an instance of this class
>>> # with a `prompt_ignore_length` and passing it as a custom logit processor
>>> rep_pen_processor = RepetitionPenaltyLogitsProcessor(
... penalty=1.1,
... prompt_ignore_length=inputs["input_ids"].shape[-1]
... )
>>> penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
>>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
I'm not going to be able to do that. I'm going to have to go through a lot of things, and
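    >>> # A minimal tensor-level illustration of the rule itself: token ids that
    >>> # already occurred (here 0 and 1) have positive logits divided by
    >>> # `penalty` and negative logits multiplied by it
    >>> import torch
    >>> processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)
    >>> processor(torch.tensor([[0, 1]]), torch.tensor([[2.0, -2.0, 2.0]]))
    tensor([[ 1., -4.,  2.]])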
```
"""

    supports_continuous_batching = False

def __init__(self, penalty: float, prompt_ignore_length: int | None = None):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
if prompt_ignore_length is not None and (
not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0
):
raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}")
self.penalty = penalty
self.prompt_ignore_length = prompt_ignore_length
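        # Expected to be filled in externally (e.g. by a continuous-batching
        # generation loop) before `__call__` receives packed 3D scores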
self.logits_indices = None
self.cu_seq_lens_q = None

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
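        # Drop the prompt tokens so they do not contribute to the penalty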
if self.prompt_ignore_length:
input_ids = input_ids[:, self.prompt_ignore_length :]
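
        # A 3D `scores` tensor carries logits for several positions at once
        # rather than a single decoding step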
if scores.dim() == 3:
if self.logits_indices is not None and self.cu_seq_lens_q is not None:
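                # Continuous-batching layout: all sequences are packed along dim 1
                # of `scores`; `logits_indices` points at the last position of each
                # sequence and `cu_seq_lens_q` holds their cumulative lengths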
last_positions = self.logits_indices
last_scores = scores[0, last_positions, :]
# Prepare token mask
token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
cu_seq_lens = self.cu_seq_lens_q
lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
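                # Map every packed token id back to the row of its sequence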
seq_indices = torch.repeat_interleave(torch.arange(len(lengths), device=input_ids.device), lengths)
token_mask[seq_indices, input_ids] = True
# Apply penalty
penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
scores[0, last_positions, :] = torch.where(token_mask, penalty_scores, last_scores)
else:
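                # Static batch: `scores` holds logits for every position of every
                # row; only the last step's logits are penalized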
batch_size, seq_len, vocab_size = scores.shape
last_scores = scores[:, -1, :]
token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
if input_ids.dim() == 1:
unique_tokens = torch.unique(input_ids)
token_mask.scatter_(1, unique_tokens.unsqueeze(0), True)
else:
token_mask.scatter_(1, input_ids, True)
# if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities
penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
scores[:, -1, :] = torch.where(token_mask, penalty_scores, last_scores)
return scores
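
        # Default decoding path: `scores` is [batch_size, vocab_size] with the
        # logits of the current step only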
        if input_ids.dim() == 1:
            # Treat an unbatched sequence as a batch of one so the 2D
            # `gather`/`scatter` calls below receive a [1, seq_len] index
            input_ids = input_ids.unsqueeze(0)
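        # Gather the logits of previously seen tokens; duplicated ids gather (and
        # later scatter) the same value, so each token is penalized at most once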
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
score = torch.where(score < 0, score * self.penalty, score / self.penalty)
scores_processed = scores.scatter(1, input_ids, score)
return scores_processed