Source code for core.nas.momentum_eval
"""Calculates and incentivizes the stability of convergence.
"""
from tensorflow.keras.callbacks import Callback
class MomentumAugmentation(Callback):
    '''Keras callback that tracks the convergence momentum of a monitored metric.

    At the end of every epoch in which ``monitor`` is present in the logs, the
    callback stores the tuple ``(metric_value, momentum)`` under
    ``self.model.momentum[epoch]``. The dict is created on the model the first
    time it is needed. See :meth:`get_momentum` for how the momentum is
    computed.

    Attributes:
        monitor (str): name of the logged metric to track and compute momentums on
    '''

    def __init__(self, monitor='val_sparse_categorical_accuracy'):
        '''Initialize the callback.

        Args:
            monitor (str, optional): name of the logged metric to track and
                compute momentums on. Defaults to
                ``'val_sparse_categorical_accuracy'``.
        '''
        super().__init__()
        self.monitor = monitor

    def get_momentum(self, epoch, acc):
        r'''Compute the convergence momentum for ``epoch`` from recorded accuracies.

        .. math:: \mu_m(\epsilon) = \frac{\alpha_m(\epsilon) - \alpha_m(\epsilon - 1)}{\alpha_m(\epsilon - 1) - \alpha_m(\epsilon - 2)} \quad \forall \; \epsilon \ge 2

        The previous accuracies are read from ``self.model.momentum``.

        The accuracy itself is returned as the momentum (warm-up behaviour) in
        two cases:

        * for epochs before 2;
        * whenever the entries for the two preceding epochs are not both
          recorded. This happens, for example, when the monitored metric was
          missing from the logs on those epochs, or when training resumed
          from a non-zero ``initial_epoch``.

        Args:
            epoch (int): index of the current epoch
            acc (float): value of the monitored metric for the current epoch

        Returns:
            (float, float): a tuple ``(current accuracy, current momentum)``.
            The momentum is ``0.0`` when the two previous accuracies are
            exactly equal, which avoids a division by zero.
        '''
        # Tolerate a missing model / not-yet-created history instead of raising.
        history = getattr(self.model, 'momentum', {})
        prev = history.get(epoch - 1)
        prev2 = history.get(epoch - 2)
        if epoch < 2 or prev is None or prev2 is None:
            # Not enough history (epoch < 2, or gaps in the record):
            # the momentum is defined as the accuracy itself.
            return (acc, acc)
        prev_acc = prev[0]
        delta_1 = acc - prev_acc
        delta_2 = prev_acc - prev2[0]
        if delta_2 == 0.0:
            # The previous two accuracies are exactly equal: avoid dividing by
            # zero and report no momentum.
            return (acc, 0.0)
        return (acc, delta_1 / delta_2)

    def on_epoch_end(self, epoch, logs=None):
        '''Record the monitored metric and its momentum for the finished epoch.

        Keras invokes this hook at the end of each training epoch during
        ``fit()``; it is not called during ``evaluate()``.

        Nothing is recorded in either of these cases:

        * the callback is not attached to a model;
        * ``monitor`` is absent from ``logs``.

        Args:
            epoch (int): index of the epoch that just finished
            logs (dict, optional): metric results for this epoch, keyed by metric name
        '''
        logs = logs or {}
        if self.model is None:
            return
        if not hasattr(self.model, 'momentum'):
            # Lazily create the per-epoch {epoch: (acc, momentum)} history on the model.
            self.model.momentum = {}
        if self.monitor in logs:
            self.model.momentum[epoch] = self.get_momentum(epoch, logs[self.monitor])