KL divergence in NumPy
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import unittest

import numpy as np


def softmax(x):
    """ Given a vector, apply the softmax activation function.

    Parameters
    ----------
    x : 1D numpy.array like

    Returns
    -------
    _x : 1D numpy.array
        The softmax of `x`; entries are positive and sum to 1.

    Notes
    -----
    Also see the reference: https://en.wikipedia.org/wiki/Softmax_function
    """
    _x = np.asarray(x, dtype=float)
    # exponentiate, then normalize so the entries sum to 1
    # (no max-shift is applied, so very large inputs may overflow)
    _x = np.exp(_x) / np.sum(np.exp(_x))
    return _x
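
# Quick example (a rough sketch; values rounded): softmax([1.0, 2.0]) is
# approximately [0.26894, 0.73106]; the entries are positive and sum to 1,
# so the output can be treated as a probability distribution.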


def kl(a, b):
    """ Given two vectors, calculate the KL divergence between a and b.

    Both inputs are first normalized with `softmax`, then
    D(a || b) = sum(a * log(a / b)) is computed.

    Parameters
    ----------
    a : 1D numpy.array like
    b : 1D numpy.array like

    Returns
    -------
    _kl : scalar
        KL divergence

    Notes
    -----
    Also see the reference: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    """
    _a = np.asarray(softmax(a), dtype=float)
    _b = np.asarray(softmax(b), dtype=float)
    # by convention, terms with a == 0 contribute 0 (0 * log 0 = 0)
    _kl = np.sum(np.where(_a != 0, _a * np.log(_a / _b), 0))
    return _kl
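
# The KL divergence is not symmetric: with the softmax normalization applied
# inside kl(), kl([1, 2], [1, 3]) is roughly 0.083 while kl([1, 3], [1, 2])
# is roughly 0.067 (a rough worked example; values rounded).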


class KLTest(unittest.TestCase):

    def test_kl(self):
        # test empty list
        _a = []
        _b = []
        self.assertEqual(kl(_a, _b), 0)
        _a = [1]
        self.assertEqual(kl(_a, _b), 0)
        _b = [1]
        self.assertEqual(kl(_a, _b), 0)
        _a = [1, 2]
        # here b = [1] is broadcast against a, so the broadcast b is not a
        # normalized distribution and the divergence comes out negative
        self.assertAlmostEqual(kl(_a, _b), -0.6, places=1)
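
    def test_kl_nonnegative(self):
        # additional sanity check (a minimal sketch): for equal-length inputs
        # both vectors become proper distributions after softmax, so by
        # Gibbs' inequality the KL divergence is non-negative
        self.assertGreaterEqual(kl([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]), 0.0)

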
if __name__ == "__main__":
    # generate two arrays
    np.random.seed(0)
    # a = np.random.randn(10)
    # b = np.random.randn(10)
    a = [1, 2]
    b = [1]
    # calculate the KL divergence
    print(kl(a, b))
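
    # Optional cross-check (a sketch, assuming SciPy is available; skipped
    # otherwise): for equal-length inputs, kl(a, b) should agree with
    # scipy.stats.entropy(softmax(a), softmax(b)), which also computes
    # sum(p * log(p / q)).
    try:
        from scipy.stats import entropy
        p, q = softmax([1, 2, 3]), softmax([0, 0, 1])
        print(kl([1, 2, 3], [0, 0, 1]), entropy(p, q))
    except ImportError:
        pass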