hf transformers repetition penalty
    repetition_penalty (`float`, *optional*):
        The parameter for repetition penalty. 1.0 means no penalty. See [this
        paper](https://huggingface.co/papers/1909.05858) for more details.
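
For intuition, here is a minimal sketch (plain torch, nothing transformers-specific) of the transform
this penalty applies: logits of previously seen tokens are divided by `penalty` when positive and
multiplied by it when negative, so the token's probability drops either way.

```py
import torch

penalty = 1.2
logits = torch.tensor([[2.0, -2.0, 0.5]])   # toy vocab of 3 tokens
seen = torch.tensor([[0, 1]])               # token ids 0 and 1 already appeared

s = torch.gather(logits, 1, seen)
s = torch.where(s < 0, s * penalty, s / penalty)
print(logits.scatter(1, seen, s))
# tensor([[ 1.6667, -2.4000,  0.5000]])  -> both seen tokens lose probability mass
```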

https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L416

# Excerpt from transformers/generation/logits_process.py; it assumes that module's
# imports (torch, LogitsProcessor, add_start_docstrings, LOGITS_PROCESSOR_INPUTS_DOCSTRING).
class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
    r"""
    [`LogitsProcessor`] that prevents the repetition of previous tokens through a penalty. This penalty is applied at
    most once per token. Note that, for decoder-only models like most LLMs, the considered tokens include the prompt
    by default.

    In the original [paper](https://huggingface.co/papers/1909.05858), the authors suggest the use of a penalty of around
    1.2 to achieve a good balance between truthful generation and lack of repetition. To penalize and reduce
    repetition, use `penalty` values above 1.0, where a higher value penalizes more strongly. To reward and encourage
    repetition, use `penalty` values between 0.0 and 1.0, where a lower value rewards more strongly.

    Args:
        penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. Above 1.0 penalizes previously generated
            tokens. Between 0.0 and 1.0 rewards previously generated tokens.
        prompt_ignore_length (`int`, *optional*):
            The length of the original prompt (input ids). If provided, tokens in that prefix are
            excluded from the penalty calculation.

    Examples:

    ```py
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, RepetitionPenaltyLogitsProcessor

    >>> # Initializing the model and tokenizer for it
    >>> model = AutoModelForCausalLM.from_pretrained("distilbert/distilgpt2")
    >>> tokenizer = AutoTokenizer.from_pretrained("distilbert/distilgpt2")
    >>> inputs = tokenizer(["I'm not going to"], return_tensors="pt")

    >>> # This shows a normal generate without any specific parameters
    >>> summary_ids = model.generate(**inputs)
    >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0])
    I'm not going to be able to do that. I'm going to be able to do that

    >>> # This generates a penalty for repeated tokens
    >>> penalized_ids = model.generate(**inputs, repetition_penalty=1.1)
    >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
    I'm not going to be able to do that. I'll just have to go out and play

    >>> # We can also exclude the input prompt by creating an instance of this class
    >>> # with a `prompt_ignore_length` and passing it as a custom logit processor
    >>> rep_pen_processor = RepetitionPenaltyLogitsProcessor(
    ...     penalty=1.1,
    ...     prompt_ignore_length=inputs["input_ids"].shape[-1]
    ... )
    >>> penalized_ids = model.generate(**inputs, logits_processor=[rep_pen_processor])
    >>> print(tokenizer.batch_decode(penalized_ids, skip_special_tokens=True)[0])
    I'm not going to be able to do that. I'm going to have to go through a lot of things, and
    ```
    """

    supports_continuous_batching = False

    def __init__(self, penalty: float, prompt_ignore_length: int | None = None):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        if prompt_ignore_length is not None and (
            not isinstance(prompt_ignore_length, int) or prompt_ignore_length < 0
        ):
            raise ValueError(f"`prompt_ignore_length` has to be a positive integer, but is {prompt_ignore_length}")

        self.penalty = penalty
        self.prompt_ignore_length = prompt_ignore_length
        self.logits_indices = None
        self.cu_seq_lens_q = None

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if self.prompt_ignore_length:
            input_ids = input_ids[:, self.prompt_ignore_length :]

        if scores.dim() == 3:
            # 3-D scores carry logits for every position (e.g. during prefill). `logits_indices` and
            # `cu_seq_lens_q` are set externally by the generation loop for packed sequences; when
            # present, they locate each sequence's last position inside the packed batch.
            if self.logits_indices is not None and self.cu_seq_lens_q is not None:
                last_positions = self.logits_indices
                last_scores = scores[0, last_positions, :]

                # Prepare token mask
                token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
                cu_seq_lens = self.cu_seq_lens_q
                lengths = cu_seq_lens[1:] - cu_seq_lens[:-1]
                seq_indices = torch.repeat_interleave(torch.arange(len(lengths), device=input_ids.device), lengths)
                token_mask[seq_indices, input_ids] = True

                # Apply penalty
                penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
                scores[0, last_positions, :] = torch.where(token_mask, penalty_scores, last_scores)
            else:
                batch_size, seq_len, vocab_size = scores.shape
                last_scores = scores[:, -1, :]
                token_mask = torch.zeros_like(last_scores, dtype=torch.bool)
                if input_ids.dim() == 1:
                    unique_tokens = torch.unique(input_ids)
                    token_mask.scatter_(1, unique_tokens.unsqueeze(0), True)
                else:
                    token_mask.scatter_(1, input_ids, True)
                # if last_scores < 0 then repetition penalty has to be multiplied to reduce the token probabilities
                penalty_scores = torch.where(last_scores < 0, last_scores * self.penalty, last_scores / self.penalty)
                scores[:, -1, :] = torch.where(token_mask, penalty_scores, last_scores)
            return scores

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(1)

        score = torch.gather(scores, 1, input_ids)
        # if score < 0 then repetition penalty has to be multiplied to reduce the token probabilities
        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
        scores_processed = scores.scatter(1, input_ids, score)
        return scores_processed
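
A quick way to sanity-check the 2-D path is to call the processor directly on toy tensors (a
sketch; assumes the class above, or `from transformers import RepetitionPenaltyLogitsProcessor`):

```py
import torch

proc = RepetitionPenaltyLogitsProcessor(penalty=1.2)
input_ids = torch.tensor([[5, 7]])       # tokens generated so far
scores = torch.zeros(1, 10)
scores[0, 5], scores[0, 7] = 2.0, -2.0   # one positive, one negative logit

out = proc(input_ids, scores)
print(out[0, 5].item(), out[0, 7].item())  # ~1.6667 and -2.4: both penalized
```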