Skip to content

Instantly share code, notes, and snippets.

@pranjal-joshi
Created April 24, 2017 10:42
Show Gist options
  • Save pranjal-joshi/4795c2736e893a17b8de3ad49fc2a998 to your computer and use it in GitHub Desktop.
Save pranjal-joshi/4795c2736e893a17b8de3ad49fc2a998 to your computer and use it in GitHub Desktop.
def eval_loss_and_grads(x):
    """Evaluate the loss and its gradients for the flat image vector ``x``.

    ``x`` is reshaped to ``(1, 3, img_nrows, img_ncols)`` and handed to
    ``f_outputs`` as a one-element list.  ``f_outputs``, ``img_nrows`` and
    ``img_ncols`` are module-level names defined outside this block.

    Args:
        x: array whose total size equals ``3 * img_nrows * img_ncols``.

    Returns:
        A ``(loss_value, grad_values)`` tuple.  ``loss_value`` is the first
        element returned by ``f_outputs``; ``grad_values`` is a 1-D
        ``float64`` array holding every remaining output, each flattened in
        C order and joined end to end in the order they were returned.
    """
    # NOTE(review): the hard-coded 3 looks like a channels-first RGB layout;
    # confirm against the model that builds f_outputs.
    x = x.reshape((1, 3, img_nrows, img_ncols))
    outs = f_outputs([x])
    loss_value = outs[0]

    # Flatten each gradient separately before joining.  Stacking them with
    # np.array(outs[1:]) only works when all gradients share one shape;
    # ragged shapes yield an object array (or an error on recent NumPy).
    grad_parts = [np.ravel(g) for g in outs[1:]]
    if grad_parts:
        grad_values = np.concatenate(grad_parts).astype('float64')
    else:
        # np.concatenate rejects an empty list; keep returning an empty
        # float64 vector when there are no gradient outputs.
        grad_values = np.array([], dtype='float64')
    return loss_value, grad_values
class Evaluator(object):
    """Cache one loss/gradient evaluation so both can be served separately.

    ``loss(x)`` calls ``eval_loss_and_grads(x)`` once and stores both
    results; the following ``grads(x)`` call returns the stored gradients
    and clears the cache.  The two methods must therefore be called in
    strict alternation: ``loss`` first, then ``grads``.

    Presumably intended for optimizers that take the objective and its
    gradient as two separate callbacks -- confirm against the caller.
    """

    def __init__(self):
        # Both stay None whenever no evaluation is pending.
        self.loss_value = None
        self.grad_values = None

    def loss(self, x):
        """Evaluate at ``x``, cache loss and gradients, return the loss.

        Raises:
            AssertionError: if a previous result has not been consumed by
                ``grads`` yet.
        """
        assert self.loss_value is None, \
            "loss() called again before grads() consumed the cached result"
        loss_value, grad_values = eval_loss_and_grads(x)
        self.loss_value = loss_value
        self.grad_values = grad_values
        return self.loss_value

    def grads(self, x):
        """Return a copy of the cached gradients and reset the cache.

        ``x`` is accepted for signature compatibility but not used: the
        gradients returned are those computed by the preceding ``loss``
        call.

        Raises:
            AssertionError: if ``loss`` has not been called first.
        """
        assert self.loss_value is not None, \
            "grads() called before loss() produced a cached result"
        # Copy before clearing so the caller owns an independent array.
        grad_values = np.copy(self.grad_values)
        self.loss_value = None
        self.grad_values = None
        return grad_values
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment