Gist pcyin/b027ffec9b1bc1b87ba02286b55c2484, created September 7, 2018 14:29.
PyTorch masked `log_sum_exp`
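The function relies on the standard max-shift identity. The sketch below writes it for a 0/1 mask $m$ where $m_i = 1$ marks an entry to keep and $m_i = 0$ an entry to exclude. That is the convention the code actually implements. Since $e^{x_i} = e^{s} e^{x_i - s}$:

$$
\log \sum_{i:\,m_i=1} e^{x_i}
= \log\Big(e^{s} \sum_{i:\,m_i=1} e^{x_i - s}\Big)
= s + \log \sum_{i:\,m_i=1} e^{x_i - s},
\qquad s = \max_{i:\,m_i=1} x_i .
$$

The identity holds for any finite $s$. Choosing the maximum over kept entries has two effects: every exponent is at most zero, so nothing overflows, and at least one term equals 1, so the sum cannot underflow to $\log 0$. The code obtains $s$ as the row maximum of $x_i - 10^7(1 - m_i)$. This matches the maximum over kept entries as long as those entries are not roughly $10^7$ below the masked ones. Masked entries are then dropped from the sum by setting $x_i - s$ to $-\infty$.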
```python
import torch


def log_sum_exp(inputs, keepdim=False, mask=None):
    """Numerically stable logsumexp on the last dim of `inputs`.
    reference: https://github.com/pytorch/pytorch/issues/2591
    Args:
        inputs: A tensor of any shape.
        keepdim: A boolean; if True, the reduced last dim is kept with size 1.
        mask: A float tensor with the same shape as `inputs`.
            **ATTENTION** entries to include are marked ONE and invalid
            entries to exclude are marked ZERO.
    Returns:
        Equivalent of log(sum(exp(inputs), dim=-1, keepdim=keepdim)),
        taken over the unmasked entries only.
    """
    if mask is not None:
        # Invert the mask so that excluded entries become 1.
        mask = 1. - mask
        # Push excluded entries far down so they do not win the max.
        max_offset = -1e7 * mask
    else:
        max_offset = 0.
    # Subtract the (masked) row max before exponentiating to avoid overflow.
    s, _ = torch.max(inputs + max_offset, dim=-1, keepdim=True)
    inputs_offset = inputs - s
    if mask is not None:
        # exp(-inf) == 0, so excluded entries contribute nothing to the sum.
        # Recent PyTorch requires a bool mask here; .byte() is deprecated.
        inputs_offset.masked_fill_(mask.bool(), -float('inf'))
    outputs = s + inputs_offset.exp().sum(dim=-1, keepdim=True).log()
    if not keepdim:
        outputs = outputs.squeeze(-1)
    return outputs
```
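On PyTorch releases that provide `torch.logsumexp`, the same masked reduction can be sketched more compactly. The approach fills the excluded entries with `-inf` and lets the built-in do the max shift. This is a swapped-in technique, not the gist's method. `masked_logsumexp` is a hypothetical name, and the mask convention (1 = keep, 0 = exclude) is assumed to match `log_sum_exp` above.

```python
import torch


def masked_logsumexp(inputs, mask, dim=-1, keepdim=False):
    """Hypothetical helper: logsumexp over `dim`, ignoring entries where mask == 0."""
    # exp(-inf) == 0, so the filled entries drop out of the sum.
    filled = inputs.masked_fill(mask == 0, float('-inf'))
    # torch.logsumexp subtracts the max internally for numerical stability.
    return torch.logsumexp(filled, dim=dim, keepdim=keepdim)


inputs = torch.tensor([[1., 2., 3.],
                       [0., 0., 1000.]])
mask = torch.tensor([[1., 1., 0.],
                     [1., 1., 0.]])
out = masked_logsumexp(inputs, mask)
# Row 0 is log(e^1 + e^2); row 1 is log(2), with the masked 1000 ignored.
print(out)
```

One possible advantage of this design is that it does not depend on the fixed `-1e7` offset. That offset can select a masked entry as the shift when the kept entries lie roughly 1e7 or more below the masked ones.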