# Imports reconstructed from context. The gist distinguishes `keras` from
# `keras_module`; both are assumed to resolve to tf.keras here, and the code
# relies on the TF 2.11+ optimizer API (_build_learning_rate, add_variable_from_reference).
from typing import Callable, Dict, List, Optional, Union, cast

import tensorflow as tf
from tensorflow import keras

keras_module = keras


class MomentumOptimizer(keras.optimizers.Optimizer):
    # the superclass allows either a tf.Variable or a LearningRateSchedule here,
    # but we only support tf.Variable for simplicity
    _learning_rate: tf.Variable
    momentum: float
    _built: bool
    velocities: List[tf.Variable]

    def __init__(self, learning_rate: float, momentum: float, **kwargs):
        assert 0.0 <= momentum <= 1.0
        super().__init__(name="MomentumOptimizer", **kwargs)
        # _build_learning_rate mainly distinguishes float from LearningRateSchedule;
        # only float is supported here, but we use it for future compatibility
        self._learning_rate = cast(tf.Variable, self._build_learning_rate(learning_rate))
        self.momentum = momentum
        self._built = False
        self.velocities = []
    def decay_learning_rate(self, decay_rate: float) -> float:
        self._learning_rate.assign(tf.multiply(self._learning_rate, decay_rate))
        # decay the stored velocities by the same factor
        for v in self.velocities:
            v.assign(tf.multiply(v, decay_rate))
        applied_lr = self._learning_rate.read_value()
        return float(applied_lr)

    def current_velocities(self) -> List[tf.Tensor]:
        return [v.read_value() for v in self.velocities]

    def update_velocity(self, velocities: List[tf.Tensor]):
        assert len(velocities) == len(self.velocities)
        for v, new_v in zip(self.velocities, velocities):
            v.assign(new_v)
    # override
    # var_list holds the model's variables; for a variable named 'a' we create
    # a companion variable named 'velocity/a' to hold its momentum
    def build(self, var_list: List[tf.Variable]):
        super().build(var_list)
        if self._built:
            return
        for var in var_list:
            self.velocities.append(self.add_variable_from_reference(model_variable=var, variable_name="velocity"))
        self._built = True

    def built(self) -> bool:
        return self._built
    # override
    def update_step(self, gradient: Union[tf.Tensor, tf.IndexedSlices], variable: tf.Variable):
        assert self._built
        lr = tf.cast(self.learning_rate, variable.dtype)
        assert isinstance(lr, tf.Tensor)  # tf.cast returns a Tensor, not a Variable
        momentum = tf.cast(self.momentum, variable.dtype)
        assert isinstance(momentum, tf.Tensor)
        # the velocity must not be None here, because build() creates one for every known variable
        v = self.velocities[self._index_dict[self._var_key(variable)]]
        assert isinstance(v, tf.Variable)
        # momentum update: v <- momentum * v - lr * g, then w <- w + v
        add_value = calc(gradient, lambda g: tf.negative(g) * lr)
        v.assign(tf.multiply(v, momentum))
        assign_add(v, add_value)
        variable.assign_add(v)
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "learning_rate": self._serialize_hyperparameter(self._learning_rate),
                "momentum": self.momentum,
            }
        )
        return config
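

# A minimal sketch (not part of the original gist) replaying the update rule that
# update_step implements, on plain tensors: v <- momentum * v - lr * g, then w <- w + v.
# The learning rate, momentum, and gradient values below are illustrative only.
def _demo_momentum_update_rule() -> None:
    lr, momentum = 0.1, 0.9
    w = tf.Variable([1.0, 2.0])
    v = tf.Variable([0.0, 0.0])
    g = tf.constant([0.5, -0.5])  # a hypothetical gradient
    v.assign(momentum * v - lr * g)
    w.assign_add(v)
    # after one step: v == [-0.05, 0.05] and w == [0.95, 2.05]
    tf.debugging.assert_near(w, tf.constant([0.95, 2.05]))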


class RestoreBestWeightsAndVelocitiesCallback(keras.callbacks.Callback):
    # per-instance configuration
    patience: int
    baseline: float
    # for assertions
    started: bool
    # set by Keras before on_train_begin
    model: Optional[keras_module.Model]
    # training state
    wait: int
    stopped_epoch: int
    best: float
    best_epoch: int
    best_weights: Optional[List[tf.Tensor]]
    best_velocities: Optional[List[tf.Tensor]]
    prev_epoch: int
    prev_weights: Optional[List[tf.Tensor]]
    prev_velocities: Optional[List[tf.Tensor]]
    prev_best_weights: Optional[List[tf.Tensor]]
    prev_best_velocities: Optional[List[tf.Tensor]]

    def __init__(self, patience: int, baseline: float):
        super().__init__()
        self.patience = patience
        self.baseline = baseline
        self.started = False
        self.model = None
        self.wait = 0
        self.best = float("inf")
        self.best_epoch = -1
        self.best_weights = None
        self.best_velocities = None
        self.prev_epoch = -1
        self.prev_weights = None
        self.prev_velocities = None
        self.prev_best_weights = None
        self.prev_best_velocities = None
    def on_train_begin(self, logs: Optional[Dict[str, float]] = None):
        assert not self.started
        assert logs is None or isinstance(logs, dict)  # suppress unused warning
        assert isinstance(self.model, keras_module.Model)
        assert isinstance(self.model.optimizer, MomentumOptimizer)
        self.started = True
        self.wait = 0
        self.best = self.baseline
        self.best_epoch = -1
        if self.model.optimizer.built():
            self.best_weights = self.model.get_weights()
            self.best_velocities = self.model.optimizer.current_velocities()
        else:
            self.best_weights = None
            self.best_velocities = None
        self.prev_epoch = -1
        self.prev_weights = None
        self.prev_velocities = None
        self.prev_best_weights = None
        self.prev_best_velocities = None
    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None):
        assert self.started
        assert logs is not None
        assert self.prev_epoch == epoch - 1
        assert isinstance(self.model, keras_module.Model)
        assert isinstance(self.model.optimizer, MomentumOptimizer)
        current = logs.get("val_loss")
        assert isinstance(current, float)
        if current < self.best or self.best_weights is None:
            self.best = current
            self.best_epoch = epoch
            self.best_weights = self.model.get_weights()
            self.best_velocities = self.model.optimizer.current_velocities()
            # also remember the state from the epoch just before the best one
            self.prev_best_weights = self.prev_weights
            self.prev_best_velocities = self.prev_velocities
            self.wait = 0
        else:
            self.wait += 1
            if self.patience <= self.wait:
                self.model.stop_training = True
        self.prev_epoch = epoch
        self.prev_weights = self.model.get_weights()
        self.prev_velocities = self.model.optimizer.current_velocities()
    def on_train_end(self, logs: Optional[Dict[str, float]] = None):
        assert self.started
        assert logs is None or isinstance(logs, dict)  # suppress unused warning
        assert isinstance(self.model, keras_module.Model)
        assert isinstance(self.model.optimizer, MomentumOptimizer)
        assert self.best_weights is not None
        assert self.best_velocities is not None
        self.started = False
        # restore the weights and velocities captured just before the best epoch
        # when available, otherwise the best state itself
        if self.prev_best_weights is not None:
            self.model.set_weights(self.prev_best_weights)
        else:
            self.model.set_weights(self.best_weights)
        if self.prev_best_velocities is not None:
            self.model.optimizer.update_velocity(self.prev_best_velocities)
        else:
            self.model.optimizer.update_velocity(self.best_velocities)
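

# A small sketch (not in the original gist) exercising the sparse-gradient path of the
# `calc` and `assign_add` helpers defined just below: a tf.IndexedSlices gradient, as
# produced by e.g. embedding lookups, is transformed value-wise and scattered into the
# variable without being densified. The values here are illustrative only.
def _demo_sparse_helpers() -> None:
    var = tf.Variable([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
    grad = tf.IndexedSlices(values=tf.constant([[0.5, 0.5]]), indices=tf.constant([1]))
    scaled = calc(grad, lambda g: tf.negative(g) * 0.1)  # still an IndexedSlices
    assign_add(var, scaled)
    # only row 1 changes: [2.0, 2.0] + [-0.05, -0.05] == [1.95, 1.95]
    tf.debugging.assert_near(var[1], tf.constant([1.95, 1.95]))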


@tf.function
def assign_add(var: tf.Variable, value: Union[tf.Tensor, tf.IndexedSlices]):
    # sparse gradients (e.g. from embedding lookups) arrive as tf.IndexedSlices
    if isinstance(value, tf.IndexedSlices):
        var.scatter_add(value)
    else:
        var.assign_add(value)


@tf.function
def calc(value: Union[tf.Tensor, tf.IndexedSlices], fn: Callable[[tf.Tensor], tf.Tensor]) -> Union[tf.Tensor, tf.IndexedSlices]:
    # apply fn to the dense values while preserving the sparse structure
    if isinstance(value, tf.IndexedSlices):
        return tf.IndexedSlices(fn(value.values), value.indices)
    else:
        return fn(value)
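

# A minimal usage sketch; the toy model, synthetic data, and hyperparameters below are
# assumptions for illustration and do not appear in the original gist.
if __name__ == "__main__":
    import numpy as np

    x = np.random.rand(256, 4).astype("float32")
    y = x.sum(axis=1, keepdims=True)
    model = keras.Sequential([keras.layers.Dense(8, activation="relu"), keras.layers.Dense(1)])
    model.compile(optimizer=MomentumOptimizer(learning_rate=0.01, momentum=0.9), loss="mse")
    model.fit(
        x,
        y,
        validation_split=0.25,  # the callback reads val_loss from the epoch logs
        epochs=20,
        callbacks=[RestoreBestWeightsAndVelocitiesCallback(patience=3, baseline=float("inf"))],
        verbose=0,
    )
    _demo_momentum_update_rule()
    _demo_sparse_helpers()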