Last active
July 11, 2019 17:12
-
-
Save cshjin/2589aedba004935940b8f432ee811df8 to your computer and use it in GitHub Desktop.
KL divergence in NumPy
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
#!/bin/env python | |
import numpy as np | |
import unittest | |
def softmax(x):
    """ Given a vector, apply the softmax activation function

    Parameters
    ----------
    x : 1D numpy.array like

    Returns
    -------
    _x : 1D numpy.array
        ``exp(x) / sum(exp(x))`` as a float array with the same shape as
        ``x``. An empty input gives an empty output.

    Notes
    -----
    The maximum of ``x`` is subtracted before exponentiating. Softmax is
    invariant to adding the same constant to every entry, so the result is
    unchanged. Without the shift, ``np.exp`` overflows to ``inf`` for large
    inputs and the output becomes ``nan``.

    Also see the reference: https://en.wikipedia.org/wiki/Softmax_function
    """
    # ``np.float`` was only an alias of the builtin ``float``; NumPy 1.24
    # removed it, so use ``float`` directly.
    _x = np.asarray(x, dtype=float)
    if _x.size == 0:
        # np.max raises on an empty array; keep returning an empty result.
        return _x
    _x = np.exp(_x - np.max(_x))
    return _x / np.sum(_x)
def _log_softmax(x):
    """Return ``log(softmax(x))`` for an array-like ``x``, computed stably.

    The log-sum-exp trick (subtracting ``max(x)``) keeps every entry finite
    for finite input. Taking ``log`` of an already-computed softmax would
    give ``-inf`` when an entry underflows to 0. An empty input returns an
    empty float array.
    """
    _x = np.asarray(x, dtype=float)
    if _x.size == 0:
        return _x
    shifted = _x - np.max(_x)
    return shifted - np.log(np.sum(np.exp(shifted)))


def kl(a, b):
    """ Given two vectors, calculate the KL divergence between a and b

    Both inputs are treated as logits: they are turned into probability
    distributions with softmax, and ``KL(softmax(a) || softmax(b))`` is
    returned, in nats (natural logarithm).

    Parameters
    ----------
    a : 1D numpy.array like
    b : 1D numpy.array like
        Must have the same shape as ``a``.

    Returns
    -------
    _kl : scalar
        KL divergence, ``sum(p * (log p - log q))`` with ``p = softmax(a)``
        and ``q = softmax(b)``. It is non-negative, and 0 when the two
        distributions are equal. Two empty vectors give 0.

    Raises
    ------
    ValueError
        If ``a`` and ``b`` have different shapes. Softmax outputs of
        different lengths are not distributions over the same outcomes.

    Notes
    -----
    Also see the reference: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
    """
    log_p = _log_softmax(a)
    log_q = _log_softmax(b)
    if log_p.shape != log_q.shape:
        raise ValueError(
            "kl() needs a and b of the same shape, got %s and %s"
            % (log_p.shape, log_q.shape)
        )
    # Work in log-space: softmax entries can underflow to 0, which would
    # turn log(p / q) into nan/inf. An underflowed p only multiplies a
    # finite log difference here, so it contributes exactly 0.
    _p = np.exp(log_p)
    _kl = np.sum(_p * (log_p - log_q))
    return _kl


class KLTest(unittest.TestCase):
    def test_kl(self):
        # empty inputs: an empty sum
        self.assertEqual(kl([], []), 0)
        # identical inputs have zero divergence
        self.assertEqual(kl([1], [1]), 0)
        self.assertAlmostEqual(kl([1, 2, 3], [1, 2, 3]), 0)
        # softmax([1, 2]) = (1 - s, s) with s = sigmoid(1). Against the
        # swapped pair, KL = (1 - s) * (-1) + s * 1 = 2s - 1 = tanh(1/2).
        self.assertAlmostEqual(kl([1, 2], [2, 1]), np.tanh(0.5))
        # shifting the logits changes nothing, and large logits stay finite
        self.assertAlmostEqual(kl([1000, 1001], [1001, 1000]), np.tanh(0.5))
        # KL divergence is never negative
        self.assertGreaterEqual(kl([0.5, -1, 3], [2, 0, 1]), 0)
        # vectors of different lengths do not give comparable distributions
        with self.assertRaises(ValueError):
            kl([1, 2], [1])


if __name__ == "__main__":
    # Demo: KL divergence between the softmax distributions of two random
    # logit vectors, seeded so the printed value is reproducible.
    np.random.seed(0)
    a = np.random.randn(10)
    b = np.random.randn(10)
    print(kl(a, b))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment