Source code for tensornet.engine.lr_finder

# Some snippets in this file are adapted from
# https://github.com/davidtvs/pytorch-lr-finder


import os
import copy
import torch
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import _LRScheduler

from tensornet.engine.learner import Learner
from tensornet.utils.progress_bar import ProgressBar
from tensornet.data.processing import InfiniteDataLoader


class LRFinder:
    """Learning rate range test.

    The learning rate range test increases the learning rate in a
    pre-training run between two boundaries in a linear or exponential
    manner. It provides valuable information on how well the network can
    be trained over a range of learning rates and what the optimal
    learning rate is.

    Args:
        model (torch.nn.Module): Model instance.
        optimizer (torch.optim.Optimizer): Optimizer whose defined
            learning rate is assumed to be the lower boundary of the
            range test.
        criterion (torch.nn): Loss function.
        metric (:obj:`str`, optional): Metric to use for finding the best
            learning rate. Can be either 'loss' or 'accuracy'.
            (default: 'loss')
        device (:obj:`str` or :obj:`torch.device`, optional): Device where
            the computation will take place. If None, uses the same device
            as `model`. (default: None)
        memory_cache (:obj:`bool`, optional): If set to True, the
            state_dict of the model and optimizer will be cached in
            memory. Otherwise, they will be saved to files under
            `cache_dir`. (default: True)
        cache_dir (:obj:`str`, optional): Path for storing temporary
            files. If no path is specified, the system-wide temporary
            directory is used. Note that this parameter is ignored if
            `memory_cache` is True. (default: None)
    """

    def __init__(
        self,
        model,
        optimizer,
        criterion,
        metric='loss',
        device=None,
        memory_cache=True,
        cache_dir=None,
    ):
        # Parameter validation:
        # check that a valid 'metric' has been given
        if metric not in ('loss', 'accuracy'):
            raise ValueError(
                f'For "metric" expected one of (loss, accuracy), got {metric}')

        # Check if the optimizer is already attached to a scheduler
        self.optimizer = optimizer
        self._check_for_scheduler()

        self.model = model
        self.criterion = criterion
        self.metric = metric
        self.history = {'lr': [], 'metric': []}
        self.best_metric = None
        self.best_lr = None
        self.memory_cache = memory_cache
        self.cache_dir = cache_dir
        self.learner = None

        # Save the original state of the model and optimizer so they can
        # be restored if needed
        self.model_device = next(self.model.parameters()).device
        self.state_cacher = StateCacher(memory_cache, cache_dir=cache_dir)
        self.state_cacher.store('model', self.model.state_dict())
        self.state_cacher.store('optimizer', self.optimizer.state_dict())

        # If device is None, use the same device as the model
        self.device = device if device else self.model_device
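    # Usage sketch for constructing an LRFinder. `Net` and the optimizer
    # settings here are illustrative assumptions, not part of this module:
    #
    #     model = Net().to('cuda')
    #     optimizer = torch.optim.SGD(model.parameters(), lr=1e-7)
    #     criterion = torch.nn.CrossEntropyLoss()
    #     lr_finder = LRFinder(model, optimizer, criterion, metric='loss')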
    def reset(self):
        """Restores the model and optimizer to their initial states."""
        self.model.load_state_dict(self.state_cacher.retrieve('model'))
        self.optimizer.load_state_dict(self.state_cacher.retrieve('optimizer'))
        self.model.to(self.model_device)

        if self.learner is not None:
            self.learner.reset_history()
    def _check_for_scheduler(self):
        """Check if the optimizer has an existing scheduler attached to it."""
        for param_group in self.optimizer.param_groups:
            if 'initial_lr' in param_group:
                raise RuntimeError(
                    'Optimizer already has a scheduler attached to it')

    def _set_learning_rate(self, new_lrs):
        """Set the given learning rates in the optimizer."""
        if not isinstance(new_lrs, list):
            new_lrs = [new_lrs] * len(self.optimizer.param_groups)
        if len(new_lrs) != len(self.optimizer.param_groups):
            raise ValueError(
                'Length of new_lrs is not equal to the number of '
                'parameter groups in the given optimizer'
            )

        # Set the learning rates in the parameter groups
        for param_group, new_lr in zip(self.optimizer.param_groups, new_lrs):
            param_group['lr'] = new_lr
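    # Sketch of the two accepted forms for `new_lrs` (values illustrative):
    #
    #     self._set_learning_rate(1e-3)          # broadcast to all groups
    #     self._set_learning_rate([1e-3, 1e-4])  # one LR per param group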
    def range_test(
        self,
        train_loader,
        iterations,
        mode='iteration',
        learner=None,
        val_loader=None,
        start_lr=None,
        end_lr=10,
        step_mode='exp',
        smooth_f=0.0,
        diverge_th=5,
    ):
        """Performs the learning rate range test.

        Args:
            train_loader (torch.utils.data.DataLoader): The training set
                data loader.
            iterations (int): The number of iterations/epochs over which
                the test occurs. If 'mode' is set to 'iteration', it
                corresponds to the number of iterations; if 'mode' is set
                to 'epoch', it corresponds to the number of epochs.
            mode (:obj:`str`, optional): Whether to update the learning
                rate after each 'iteration' or after each 'epoch'.
                (default: 'iteration')
            learner (:obj:`Learner`, optional): Learner object for the
                model. (default: None)
            val_loader (:obj:`torch.utils.data.DataLoader`, optional): If
                None, the range test will only use the training metric.
                When given a data loader, the model is evaluated after
                each iteration on that dataset and the evaluation metric
                is used. Note that in this mode the test takes
                significantly longer but generally produces more precise
                results. (default: None)
            start_lr (:obj:`float`, optional): The starting learning rate
                for the range test. If None, uses the learning rate from
                the optimizer. (default: None)
            end_lr (:obj:`float`, optional): The maximum learning rate to
                test. (default: 10)
            step_mode (:obj:`str`, optional): One of the available
                learning rate policies, linear or exponential
                ('linear', 'exp'). (default: 'exp')
            smooth_f (:obj:`float`, optional): The metric smoothing factor
                within the [0, 1) interval. Disabled if set to 0,
                otherwise the metric is smoothed using exponential
                smoothing. (default: 0.0)
            diverge_th (:obj:`int`, optional): The test is stopped when
                the metric surpasses the threshold
                diverge_th * best_metric. To disable, set it to 0.
                (default: 5)
        """
        # Check if a valid 'mode' has been given
        if mode not in ('iteration', 'epoch'):
            raise ValueError(
                f'For "mode" expected one of (iteration, epoch), got {mode}')

        # Reset test results
        self.history = {'lr': [], 'metric': []}
        self.best_metric = None
        self.best_lr = None

        # Check if the optimizer is already attached to a scheduler
        self._check_for_scheduler()

        # Set the starting learning rate
        if start_lr:
            self._set_learning_rate(start_lr)

        # Initialize the proper learning rate policy
        if step_mode.lower() == 'exp':
            lr_schedule = ExponentialLR(self.optimizer, end_lr, iterations)
        elif step_mode.lower() == 'linear':
            lr_schedule = LinearLR(self.optimizer, end_lr, iterations)
        else:
            raise ValueError(f'Expected one of (exp, linear), got {step_mode}')

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError('smooth_f is outside the range [0, 1)')

        # Set accuracy metric if needed
        metrics = None
        if self.metric == 'accuracy':
            metrics = ['accuracy']

        # Get the learner object
        if learner is not None:
            self.learner = learner(
                train_loader, self.optimizer, self.criterion,
                device=self.device, val_loader=val_loader, metrics=metrics
            )
        else:
            self.learner = Learner(
                train_loader, self.optimizer, self.criterion,
                device=self.device, val_loader=val_loader, metrics=metrics
            )
        self.learner.set_model(self.model)
        train_iterator = InfiniteDataLoader(train_loader)

        pbar = ProgressBar(target=iterations, width=8)
        if mode == 'iteration':
            print(mode.title() + 's')
        for iteration in range(iterations):
            # Train model
            if mode == 'epoch':
                print(f'{mode.title()} {iteration + 1}:')
            self._train_model(mode, train_iterator)
            if val_loader:
                self.learner.validate(verbose=False)

            # Get metric value
            metric_value = self._get_metric(val_loader)

            # Update the learning rate
            lr_schedule.step()
            self.history['lr'].append(lr_schedule.get_lr()[0])

            # Track the best metric and smooth it if smooth_f is specified
            if iteration == 0:
                self.best_metric = metric_value
                self.best_lr = self.history['lr'][-1]
            else:
                if smooth_f > 0:
                    metric_value = smooth_f * metric_value + \
                        (1 - smooth_f) * self.history['metric'][-1]
                if (
                    (self.metric == 'loss' and metric_value < self.best_metric) or
                    (self.metric == 'accuracy' and metric_value > self.best_metric)
                ):
                    self.best_metric = metric_value
                    self.best_lr = self.history['lr'][-1]

            # Check if the metric has diverged; if it has, stop the test.
            # The raw value is used for the comparison; the display value
            # (accuracy as a percentage) is used only for reporting.
            self.history['metric'].append(metric_value)
            display_value = self._display_metric_value(metric_value)
            if (
                diverge_th > 0 and
                ((self.metric == 'loss' and metric_value > self.best_metric * diverge_th) or
                 (self.metric == 'accuracy' and metric_value < self.best_metric / diverge_th))
            ):
                if mode == 'iteration':
                    pbar.update(iterations - 1, values=[
                        ('lr', self.history['lr'][-1]),
                        (self.metric.title(), display_value)
                    ])
                print(f'\nStopping early, the {self.metric} has diverged.')
                break
            else:
                if mode == 'epoch':
                    lr = self.history['lr'][-1]
                    print(
                        f'Learning Rate: {lr:.4f}, '
                        f'{self.metric.title()}: {display_value:.2f}\n')
                elif mode == 'iteration':
                    pbar.update(iteration, values=[
                        ('lr', self.history['lr'][-1]),
                        (self.metric.title(), display_value)
                    ])

        best_value = self._display_metric_value(self.best_metric)
        if mode == 'epoch':
            print(
                f'Learning Rate: {self.best_lr:.4f}, '
                f'{self.metric.title()}: {best_value:.2f}\n')
        elif mode == 'iteration':
            pbar.add(1, values=[
                ('lr', self.best_lr),
                (self.metric.title(), best_value)
            ])
        print('Learning rate search finished.')
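    # Usage sketch, continuing the constructor example above; the loader
    # and iteration count are illustrative assumptions:
    #
    #     lr_finder.range_test(
    #         train_loader, iterations=100, start_lr=1e-7, end_lr=10,
    #         step_mode='exp'
    #     )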
    def _train_model(self, mode, train_iterator):
        # In 'iteration' mode, train on a single batch; in 'epoch' mode,
        # train on the full dataset once
        if mode == 'iteration':
            self.learner.model.train()
            data, targets = train_iterator.get_batch()
            loss = self.learner.train_batch((data, targets))
            self.learner.update_training_history(loss)
        elif mode == 'epoch':
            self.learner.train_epoch()

    def _get_metric(self, validation=None):
        # Accuracy is tracked internally as a fraction in [0, 1]
        if self.metric == 'loss':
            if validation:
                return self.learner.val_losses[-1]
            return self.learner.train_losses[-1]
        elif self.metric == 'accuracy':
            if validation:
                return self.learner.val_metrics[0][self.metric][-1] / 100
            return self.learner.train_metrics[0][self.metric][-1] / 100

    def _display_metric_value(self, value):
        # Convert accuracy back to a percentage for display
        if self.metric == 'accuracy':
            return value * 100
        return value
    def plot(self, log_lr=True, show_lr=None):
        """Plots the learning rate range test.

        Args:
            log_lr (:obj:`bool`, optional): True to plot the learning
                rate on a logarithmic scale; otherwise, it is plotted on
                a linear scale. (default: True)
            show_lr (:obj:`float`, optional): If set, adds a vertical
                line to visualize the specified learning rate.
                (default: None)
        """
        if show_lr is not None and not isinstance(show_lr, float):
            raise ValueError('show_lr must be float')

        # Get the data to plot from the history dictionary
        lrs = self.history['lr']
        metrics = self.history['metric']

        # Plot the metric as a function of the learning rate
        plt.plot(lrs, metrics)
        if log_lr:
            plt.xscale('log')
        plt.xlabel('Learning rate')
        plt.ylabel(self.metric.title())
        if show_lr is not None:
            plt.axvline(x=show_lr, color='red')
        plt.show()
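# Usage sketch, continuing the examples above: inspect the curve, then
# restore the original weights before the real training run:
#
#     lr_finder.plot(show_lr=lr_finder.best_lr)
#     lr_finder.reset()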
class LinearLR(_LRScheduler):
    """Linearly increases the learning rate between two boundaries over a
    number of iterations.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer.
        end_lr (float): The final learning rate.
        iterations (int): The number of iterations over which the test
            occurs.
        last_epoch (:obj:`int`, optional): The index of the last epoch.
            (default: -1)
    """

    def __init__(self, optimizer, end_lr, iterations, last_epoch=-1):
        self.end_lr = end_lr
        self.iterations = iterations
        super(LinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.iterations
        return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]


class ExponentialLR(_LRScheduler):
    """Exponentially increases the learning rate between two boundaries
    over a number of iterations.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer.
        end_lr (float): The final learning rate.
        iterations (int): The number of iterations/epochs over which the
            test occurs.
        last_epoch (:obj:`int`, optional): The index of the last epoch.
            (default: -1)
    """

    def __init__(self, optimizer, end_lr, iterations, last_epoch=-1):
        self.end_lr = end_lr
        self.iterations = iterations
        super(ExponentialLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        curr_iter = self.last_epoch + 1
        r = curr_iter / self.iterations
        return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]


class StateCacher(object):
    def __init__(self, in_memory, cache_dir=None):
        self.in_memory = in_memory
        self.cache_dir = cache_dir

        if self.cache_dir is None:
            import tempfile
            self.cache_dir = tempfile.gettempdir()
        elif not os.path.isdir(self.cache_dir):
            raise ValueError('Given cache_dir is not a valid directory.')

        self.cached = {}

    def store(self, key, state_dict):
        if self.in_memory:
            self.cached.update({key: copy.deepcopy(state_dict)})
        else:
            fn = os.path.join(self.cache_dir, f'state_{key}_{id(self)}.pt')
            self.cached.update({key: fn})
            torch.save(state_dict, fn)

    def retrieve(self, key):
        if key not in self.cached:
            raise KeyError(f'Target {key} was not cached.')

        if self.in_memory:
            return self.cached.get(key)

        fn = self.cached.get(key)
        if not os.path.exists(fn):
            raise RuntimeError(
                f"Failed to load state in {fn}. File doesn't exist anymore.")
        state_dict = torch.load(
            fn, map_location=lambda storage, location: storage)
        return state_dict

    def __del__(self):
        """Remove any cached files remaining in cache_dir before this
        instance is destroyed.
        """
        if self.in_memory:
            return

        for k in self.cached:
            if os.path.exists(self.cached[k]):
                os.remove(self.cached[k])
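# Worked example (values illustrative) of the two schedules: with
# base_lr = 1e-4, end_lr = 1 and iterations = 4, the learning rate after
# step i is
#
#     exp:    1e-4 * (1 / 1e-4) ** (i / 4)  ->  1e-4, 1e-3, 1e-2, 1e-1, 1
#     linear: 1e-4 + (i / 4) * (1 - 1e-4)   ->  1e-4, ~0.25, ~0.5, ~0.75, 1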