import os
import caffe
import pdb
import tempfile
import unittest
import numpy.testing as npt

import numpy as np

# Fix NumPy's global random seed so the random inputs generated by the tests
# below (np.random.rand) are reproducible from run to run.
np.random.seed(0)

# Prototxt for the two-layer network exercised by the tests:
#   * 'data': Python layer caffe.py_blank_data_layer.PyBlankDataLayer whose
#     param_str requests shape [1,3,8,8] -- matching the (1,3,8,8) arrays the
#     tests write into the 'data' blob (presumably the layer only allocates
#     the blob; confirm against py_blank_data_layer).
#   * 'sin': Python layer caffe.py_sin_layer.PySinLayer reading 'data' and
#     writing 'sin'.
# force_backward: true is set, presumably so Net.backward() fills the diff of
# the 'data' blob, which test_backward reads. NOTE: this string is parsed by
# Caffe at runtime; do not edit it for cosmetic reasons.
NET_DEF = """
name: 'sin_test_net' force_backward: true
layer {
  type: 'Python'
  name: 'data'
  top: 'data'
  python_param { 
      module: 'caffe.py_blank_data_layer' 
      layer: 'PyBlankDataLayer' 
      param_str: 'shape: [1,3,8,8]'  
      }
}
layer { 
    type: 'Python' 
    name: 'sin' 
    bottom: 'data' 
    top: 'sin'
    python_param { 
        module: 'caffe.py_sin_layer' 
        layer: 'PySinLayer' 
    } 
}
"""

def python_net_file():
    """Write ``NET_DEF`` to a fresh temporary file and return its path.

    The file is opened in text mode with ``delete=False``, so it survives
    being closed; removing it afterwards is the caller's responsibility.
    """
    tmp = tempfile.NamedTemporaryFile(mode='w+', delete=False)
    with tmp:
        tmp.write(NET_DEF)
    path = tmp.name
    return path

@unittest.skipIf('Python' not in caffe.layer_type_list(),
    'Caffe built without Python layer support')
class TestPythonSinLayer(unittest.TestCase):
    """Tests for the Python ``PySinLayer`` in the network built from ``NET_DEF``.

    The forward output is compared with ``np.sin`` and the backward gradient
    with a central finite-difference estimate.
    """

    def setUp(self):
        net_file = python_net_file()
        try:
            self.net = caffe.Net(net_file, caffe.TRAIN)
        finally:
            # Clean up the temporary prototxt even if net construction fails.
            os.remove(net_file)

    def _forward(self, data):
        """Load ``data`` into the 'data' blob, run forward, return a copy of 'sin'."""
        self.net.blobs['data'].data[...] = data
        self.net.forward()
        # Copy: the blob buffer is overwritten by subsequent forward passes.
        return self.net.blobs['sin'].data.copy()

    def test_forward(self):
        data = np.random.rand(1, 3, 8, 8)
        res = self._forward(data)
        expected = np.sin(data)
        # Same criterion as np.isclose's defaults with atol=1e-6.
        npt.assert_allclose(res, expected, rtol=1e-5, atol=1e-6)

    def test_backward(self):
        """Check the layer's input gradient against central differences.

        With loss L = sum(top_diff * sin(x)), dL/dx_i is estimated as
        (L(x + delta*e_i) - L(x - delta*e_i)) / (2*delta).
        """
        delta = 1e-3
        data = np.random.rand(1, 3, 8, 8)
        top_diff = np.random.rand(1, 3, 8, 8)

        self._forward(data)
        self.net.blobs['sin'].diff[...] = top_diff
        self.net.backward()
        # Copy before further forward passes touch the net's buffers.
        calc_diff = self.net.blobs['data'].diff.copy()

        num_diff = np.zeros_like(data)
        for i, _ in np.ndenumerate(data):
            data_plus = data.copy()
            data_plus[i] += delta
            data_minus = data.copy()
            data_minus[i] -= delta
            y_plus = self._forward(data_plus)
            y_minus = self._forward(data_minus)
            num_diff[i] = (top_diff * (y_plus - y_minus)).sum() / (2 * delta)

        # Blobs are single precision, so rounding in (y_plus - y_minus) / 2e-3
        # is ~1e-4; a 1e-3 tolerance is tight enough to catch a wrong gradient.
        npt.assert_allclose(calc_diff, num_diff, rtol=1e-3, atol=1e-3)