Source code for tensornet.engine.learner

import math
import time
import torch
from copy import deepcopy

from tensornet.engine.ops.regularizer import l1
from tensornet.engine.ops.checkpoint import ModelCheckpoint
from tensornet.engine.ops.tensorboard import TensorBoard
from tensornet.data.processing import InfiniteDataLoader
from tensornet.utils.progress_bar import ProgressBar


class Learner:
    """Model Trainer and Validator.

    Args:
        train_loader (torch.utils.data.DataLoader): Training data loader.
        optimizer (torch.optim): Optimizer for the model.
        criterion (torch.nn): Loss Function.
        device (:obj:`str` or :obj:`torch.device`, optional): Device where the data
            will be loaded. (default: 'cpu')
        epochs (:obj:`int`, optional): Number of epochs/iterations to train the
            model for. (default: 1)
        l1_factor (:obj:`float`, optional): L1 regularization factor. (default: 0)
        val_loader (:obj:`torch.utils.data.DataLoader`, optional): Validation data
            loader.
        callbacks (:obj:`list`, optional): List of callbacks to be used during
            training.
        metrics (:obj:`list`, optional): List of names of the metrics for model
            evaluation.

            *Note*: If the model has multiple outputs, then this will be a nested
            list where each individual sub-list will specify the metrics which are
            to be used for evaluating each output respectively. In such cases, the
            model checkpoint will consider only the metric of the first output for
            saving checkpoints.
        activate_loss_logits (:obj:`bool`, optional): If True, the logits will first
            pass through the `activate_logits` function before going to the
            criterion. (default: False)
        record_train (:obj:`bool`, optional): If False, metrics will be calculated
            only during validation. (default: True)
    """

    def __init__(
        self, train_loader, optimizer, criterion, device='cpu',
        epochs=1, l1_factor=0.0, val_loader=None, callbacks=None,
        metrics=None, activate_loss_logits=False, record_train=True
    ):
        self.model = None
        self.optimizer = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.device = device
        self.epochs = epochs
        self.val_loader = val_loader
        self.l1_factor = l1_factor
        self.activate_loss_logits = activate_loss_logits
        self.record_train = record_train

        self.lr_schedulers = {
            'step_lr': None,
            'lr_plateau': None,
            'one_cycle_policy': None,
            'cyclic_lr': None,
        }
        self.checkpoint = None
        self.summary_writer = None
        if callbacks is not None:
            self._setup_callbacks(callbacks)

        # Training
        self.train_losses = []  # Change in loss
        self.train_metrics = []  # Change in evaluation metric
        self.val_losses = []  # Change in loss
        self.val_metrics = []  # Change in evaluation metric

        # Set evaluation metrics
        self.metrics = []
        if metrics:
            self._setup_metrics(metrics)

    def _setup_callbacks(self, callbacks):
        """Extract callbacks passed to the class.

        Args:
            callbacks (list): List of callbacks.
        """
        for callback in callbacks:
            if isinstance(callback, torch.optim.lr_scheduler.StepLR):
                self.lr_schedulers['step_lr'] = callback
            elif isinstance(callback, torch.optim.lr_scheduler.ReduceLROnPlateau):
                self.lr_schedulers['lr_plateau'] = callback
            elif isinstance(callback, torch.optim.lr_scheduler.OneCycleLR):
                self.lr_schedulers['one_cycle_policy'] = callback
            elif isinstance(callback, ModelCheckpoint):
                if callback.monitor.startswith('train_'):
                    if self.record_train:
                        self.checkpoint = callback
                    else:
                        raise ValueError(
                            'Cannot use checkpoint for a training metric if '
                            'record_train is set to False'
                        )
                else:
                    self.checkpoint = callback
            elif isinstance(callback, TensorBoard):
                self.summary_writer = callback
            elif isinstance(callback, torch.optim.lr_scheduler.CyclicLR):
                self.lr_schedulers['cyclic_lr'] = callback
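    # Construction sketch (illustrative, not part of the library): scheduler
    # callbacks are plain torch schedulers and are routed by the isinstance
    # checks above; my_loader, my_optimizer and my_criterion are hypothetical
    # placeholders.
    #
    #     scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=10)
    #     learner = Learner(
    #         my_loader, my_optimizer, my_criterion, device='cuda',
    #         epochs=20, callbacks=[scheduler], metrics=['accuracy'],
    #     )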
    def set_model(self, model):
        """Assign model to learner.

        Args:
            model (torch.nn.Module): Model Instance.
        """
        self.model = model
        if self.summary_writer is not None:
            self.summary_writer.write_model(self.model)
    def _accuracy(self, label, prediction, idx=0):
        """Calculate accuracy.

        Args:
            label (torch.Tensor): Ground truth.
            prediction (torch.Tensor): Prediction.
        """
        self.metrics[idx]['accuracy']['sum'] += prediction.eq(
            label.view_as(prediction)
        ).sum().item()
        self.metrics[idx]['accuracy']['num_steps'] += len(label)
        self.metrics[idx]['accuracy']['value'] = round(
            100 * self.metrics[idx]['accuracy']['sum'] /
            self.metrics[idx]['accuracy']['num_steps'], 2
        )

    def _iou(self, label, prediction, idx=0):
        """Calculate Intersection over Union.

        Args:
            label (torch.Tensor): Ground truth.
            prediction (torch.Tensor): Prediction.
        """
        # Remove the channel dimension
        label = label.squeeze(1)
        prediction = prediction.squeeze(1)

        intersection = (prediction * label).sum(2).sum(1)
        union = (prediction + label).sum(2).sum(1) - intersection

        # epsilon is added to avoid 0/0
        epsilon = 1e-6
        iou = (intersection + epsilon) / (union + epsilon)

        self.metrics[idx]['iou']['sum'] += iou.sum().item()
        self.metrics[idx]['iou']['num_steps'] += label.size(0)
        self.metrics[idx]['iou']['value'] = round(
            self.metrics[idx]['iou']['sum'] /
            self.metrics[idx]['iou']['num_steps'], 3
        )

    def _pred_label_diff(self, label, prediction, rel=False):
        """Calculate the difference between label and prediction.

        Args:
            label (torch.Tensor): Ground truth.
            prediction (torch.Tensor): Prediction.
            rel (:obj:`bool`, optional): If True, return the relative difference.
                (default: False)

        Returns:
            Difference between label and prediction along with the number of
            valid elements, or None if there are no valid elements.
        """
        # For numerical stability
        valid_labels = label > 0.0001
        _label = label[valid_labels]
        _prediction = prediction[valid_labels]
        valid_element_count = _label.size(0)

        if valid_element_count > 0:
            diff = torch.abs(_label - _prediction)
            if rel:
                diff = torch.div(diff, _label)

            return diff, valid_element_count

    def _rmse(self, label, prediction, idx=0):
        """Calculate Root Mean Square Error.

        Args:
            label (torch.Tensor): Ground truth.
            prediction (torch.Tensor): Prediction.
        """
        diff = self._pred_label_diff(label, prediction)

        rmse = 0
        if diff is not None:
            rmse = math.sqrt(torch.sum(torch.pow(diff[0], 2)) / diff[1])

        self.metrics[idx]['rmse']['num_steps'] += label.size(0)
        self.metrics[idx]['rmse']['sum'] += rmse * label.size(0)
        self.metrics[idx]['rmse']['value'] = round(
            self.metrics[idx]['rmse']['sum'] /
            self.metrics[idx]['rmse']['num_steps'], 3
        )

    def _mae(self, label, prediction, idx=0):
        """Calculate Mean Absolute Error.

        Args:
            label (torch.Tensor): Ground truth.
            prediction (torch.Tensor): Prediction.
        """
        diff = self._pred_label_diff(label, prediction)

        mae = 0
        if diff is not None:
            mae = torch.sum(diff[0]).item() / diff[1]

        self.metrics[idx]['mae']['num_steps'] += label.size(0)
        self.metrics[idx]['mae']['sum'] += mae * label.size(0)
        self.metrics[idx]['mae']['value'] = round(
            self.metrics[idx]['mae']['sum'] /
            self.metrics[idx]['mae']['num_steps'], 3
        )

    def _abs_rel(self, label, prediction, idx=0):
        """Calculate Absolute Relative Error.

        Args:
            label (torch.Tensor): Ground truth.
            prediction (torch.Tensor): Prediction.
        """
        diff = self._pred_label_diff(label, prediction, rel=True)

        abs_rel = 0
        if diff is not None:
            abs_rel = torch.sum(diff[0]).item() / diff[1]

        self.metrics[idx]['abs_rel']['num_steps'] += label.size(0)
        self.metrics[idx]['abs_rel']['sum'] += abs_rel * label.size(0)
        self.metrics[idx]['abs_rel']['value'] = round(
            self.metrics[idx]['abs_rel']['sum'] /
            self.metrics[idx]['abs_rel']['num_steps'], 3
        )

    def _setup_metrics(self, metrics):
        """Validate the evaluation metrics passed to the class.

        Args:
            metrics (:obj:`list` or :obj:`dict`): Metrics.
        """
        if not isinstance(metrics[0], (list, tuple)):
            metrics = [metrics]

        for idx, metric_list in enumerate(metrics):
            metric_dict = {}
            for metric in metric_list:
                metric_info = {'value': 0, 'sum': 0, 'num_steps': 0}
                if metric == 'accuracy':
                    metric_info['func'] = self._accuracy
                elif metric == 'rmse':
                    metric_info['func'] = self._rmse
                elif metric == 'mae':
                    metric_info['func'] = self._mae
                elif metric == 'abs_rel':
                    metric_info['func'] = self._abs_rel
                elif metric == 'iou':
                    metric_info['func'] = self._iou

                if 'func' in metric_info:
                    metric_dict[metric] = metric_info

            if metric_dict:
                self.metrics.append(metric_dict)
                self.train_metrics.append({
                    x: [] for x in metric_dict.keys()
                })
                self.val_metrics.append({
                    x: [] for x in metric_dict.keys()
                })

    def _calculate_metrics(self, labels, predictions):
        """Update evaluation metric values.

        Args:
            labels (:obj:`torch.Tensor` or :obj:`dict`): Ground truth.
            predictions (:obj:`torch.Tensor` or :obj:`dict`): Prediction.
        """
        predictions = self.activate_logits(predictions)
        if not isinstance(labels, (list, tuple)):
            labels = [labels]
            predictions = [predictions]

        for idx, (label, prediction) in enumerate(zip(labels, predictions)):
            # If predictions are one-hot encoded
            if label.size() != prediction.size():
                prediction = prediction.argmax(dim=1, keepdim=True) * 1.0

            if idx < len(self.metrics):
                for metric in self.metrics[idx]:
                    self.metrics[idx][metric]['func'](
                        label, prediction, idx=idx
                    )

    def _reset_metrics(self):
        """Reset metric params."""
        for idx in range(len(self.metrics)):
            for metric in self.metrics[idx]:
                self.metrics[idx][metric]['value'] = 0
                self.metrics[idx][metric]['sum'] = 0
                self.metrics[idx][metric]['num_steps'] = 0

    def _get_pbar_values(self, loss):
        """Create progress bar description.

        Args:
            loss (float): Loss value.
        """
        pbar_values = [('loss', round(loss, 2))]
        if self.metrics and self.record_train:
            for idx in range(len(self.metrics)):
                for metric, info in self.metrics[idx].items():
                    metric_name = metric
                    if len(self.metrics) > 1:
                        metric_name = f'{idx} - {metric}'
                    pbar_values.append((metric_name, info['value']))
        return pbar_values
    def update_training_history(self, loss):
        """Update the training history.

        Args:
            loss (float): Loss value.
        """
        self.train_losses.append(loss)
        if self.record_train:
            for idx in range(len(self.metrics)):
                for metric in self.metrics[idx]:
                    self.train_metrics[idx][metric].append(
                        self.metrics[idx][metric]['value']
                    )
    def reset_history(self):
        """Reset the training history."""
        self.train_losses = []
        self.val_losses = []
        for idx in range(len(self.metrics)):
            for metric in self.metrics[idx]:
                self.train_metrics[idx][metric] = []
                self.val_metrics[idx][metric] = []
        self._reset_metrics()
    def activate_logits(self, logits):
        """Apply activation function to the logits if needed. After this, the
        logits will be sent for calculation of loss or evaluation metrics.

        Args:
            logits (torch.Tensor): Model output.

        Returns:
            (*torch.Tensor*): Activated logits.
        """
        return logits
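    # Override sketch (hedged): for a model whose raw outputs need an
    # activation before loss/metric computation, e.g. binary segmentation,
    # one might subclass Learner; SegmentationLearner is a hypothetical name.
    #
    #     class SegmentationLearner(Learner):
    #         def activate_logits(self, logits):
    #             return torch.sigmoid(logits)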
    def calculate_criterion(self, logits, targets, train=True):
        """Calculate loss.

        Args:
            logits (torch.Tensor): Prediction.
            targets (torch.Tensor): Ground truth.
            train (:obj:`bool`, optional): If True, loss is sent to the L1
                regularization function. (default: True)

        Returns:
            (*torch.Tensor*): Loss value.
        """
        if self.activate_loss_logits:
            logits = self.activate_logits(logits)
        if train:
            return l1(self.model, self.criterion(logits, targets), self.l1_factor)
        return self.criterion(logits, targets)
    def fetch_data(self, data):
        """Fetch data from loader and load it to GPU.

        Args:
            data (:obj:`tuple` or :obj:`list`): List containing inputs and targets.

        Returns:
            Inputs and targets loaded to GPU.
        """
        return data[0].to(self.device), data[1].to(self.device)
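    # Override sketch (hedged) for loaders that yield dict batches instead of
    # (input, target) tuples; the keys 'image' and 'mask' are hypothetical.
    #
    #     class DictLearner(Learner):
    #         def fetch_data(self, data):
    #             return data['image'].to(self.device), data['mask'].to(self.device)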
    def train_batch(self, data):
        """Train the model on a batch of data.

        Args:
            data (:obj:`tuple` or :obj:`list`): Input and target data for the model.

        Returns:
            (*float*): Batch loss.
        """
        inputs, targets = self.fetch_data(data)
        self.optimizer.zero_grad()  # Set gradients to zero before starting backpropagation
        y_pred = self.model(inputs)  # Predict output
        loss = self.calculate_criterion(y_pred, targets, train=True)  # Calculate loss

        # Perform backpropagation
        loss.backward()
        self.optimizer.step()

        if self.record_train:
            self._calculate_metrics(targets, y_pred)

        # One Cycle Policy for learning rate
        if self.lr_schedulers['one_cycle_policy'] is not None:
            self.lr_schedulers['one_cycle_policy'].step()

        # Cyclic LR policy
        if self.lr_schedulers['cyclic_lr'] is not None:
            self.lr_schedulers['cyclic_lr'].step()

        return loss.item()
    def train_epoch(self, verbose=True):
        """Run an epoch of model training.

        Args:
            verbose (:obj:`bool`, optional): Print logs. (default: True)
        """
        self.model.train()
        if verbose:
            pbar = ProgressBar(target=len(self.train_loader), width=8)

        for batch_idx, data in enumerate(self.train_loader, 0):
            # Train a batch
            loss = self.train_batch(data)

            # Update Progress Bar
            if verbose:
                pbar_values = self._get_pbar_values(loss)
                pbar.update(batch_idx, values=pbar_values)

        # Update training history
        self.update_training_history(loss)
        if verbose:
            pbar_values = self._get_pbar_values(loss)
            pbar.add(1, values=pbar_values)

        self._reset_metrics()
    def train_iterations(self, verbose=True):
        """Train the model for ``self.epochs`` number of batches.

        Args:
            verbose (:obj:`bool`, optional): Print logs. (default: True)
        """
        self.model.train()
        if verbose:
            pbar = ProgressBar(target=self.epochs, width=8)

        iterator = InfiniteDataLoader(self.train_loader)
        for iteration in range(self.epochs):
            # Train a batch
            loss = self.train_batch(iterator.get_batch())

            # Update Progress Bar
            if verbose:
                pbar_values = self._get_pbar_values(loss)
                pbar.update(iteration, values=pbar_values)

            # Update training history
            self.update_training_history(loss)

        if verbose:
            pbar.add(1, values=pbar_values)
    def evaluate(self, loader, verbose=True, log_message='Evaluation'):
        """Evaluate the model on a custom data loader.

        Args:
            loader (torch.utils.data.DataLoader): Data loader.
            verbose (:obj:`bool`, optional): Print loss and metrics. (default: True)
            log_message (:obj:`str`, optional): Prefix for the logs which are
                printed at the end. (default: 'Evaluation')

        Returns:
            Loss and metric values.
        """
        start_time = time.time()
        self.model.eval()
        eval_loss = 0
        with torch.no_grad():
            for data in loader:
                inputs, targets = self.fetch_data(data)
                output = self.model(inputs)  # Get trained model output
                eval_loss += self.calculate_criterion(
                    output, targets, train=False
                ).item()  # Sum up batch loss

                # Calculate evaluation metrics
                self._calculate_metrics(targets, output)

        eval_loss /= len(loader.dataset)
        eval_metrics = deepcopy(self.metrics)

        # Time spent during evaluation
        end_time = time.time()
        duration = int(end_time - start_time)
        minutes = duration // 60
        seconds = duration % 60

        if verbose:
            log = f'{log_message} (took {minutes} minutes, {seconds} seconds): Average loss: {eval_loss:.4f}'
            for idx in range(len(self.metrics)):
                for metric in self.metrics[idx]:
                    log += f', {metric}: {self.metrics[idx][metric]["value"]}'
            log += '\n'
            print(log)

        self._reset_metrics()

        return eval_loss, eval_metrics
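    # Typical call (test_loader stands in for any DataLoader):
    #
    #     test_loss, test_metrics = learner.evaluate(
    #         test_loader, log_message='Test set'
    #     )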
    def validate(self, verbose=True):
        """Validate an epoch of model training.

        Args:
            verbose (:obj:`bool`, optional): Print validation loss and metrics.
                (default: True)
        """
        eval_loss, eval_metrics = self.evaluate(
            self.val_loader, verbose=verbose, log_message='Validation set'
        )

        # Update validation logs
        self.val_losses.append(eval_loss)
        for idx in range(len(eval_metrics)):
            for metric in eval_metrics[idx]:
                self.val_metrics[idx][metric].append(
                    eval_metrics[idx][metric]['value']
                )
    def save_checkpoint(self, epoch=None):
        """Save model checkpoint.

        Args:
            epoch (:obj:`int`, optional): Current epoch number.
        """
        if self.checkpoint is not None:
            metric = None
            if self.checkpoint.monitor == 'train_loss':
                metric = self.train_losses[-1]
            elif self.checkpoint.monitor == 'val_loss':
                metric = self.val_losses[-1]
            elif self.metrics:
                if self.checkpoint.monitor.startswith('train_'):
                    if self.record_train:
                        metric = self.train_metrics[0][
                            self.checkpoint.monitor.split('train_')[-1]
                        ][-1]
                else:
                    metric = self.val_metrics[0][
                        self.checkpoint.monitor.split('val_')[-1]
                    ][-1]
            else:
                print('Invalid metric function, can\'t save checkpoint.')
                return

            self.checkpoint(self.model, metric, epoch)
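    # Monitor strings handled above: 'train_loss', 'val_loss', or
    # 'train_<metric>' / 'val_<metric>' for the first output's metrics,
    # e.g. a ModelCheckpoint monitoring 'val_accuracy' (this assumes only
    # that the checkpoint exposes the monitored quantity via its `monitor`
    # attribute, as the branches above rely on).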
    def write_summary(self, epoch, train):
        """Write training summary in tensorboard.

        Args:
            epoch (int): Current epoch number.
            train (bool): If True, summary will be written for model training,
                else it will be written for model validation.
        """
        if self.summary_writer is not None:
            if train:
                mode = 'train'

                # Write Images
                self.summary_writer.write_images(
                    self.model, self.activate_logits, f'prediction_epoch_{epoch}'
                )
                loss = self.train_losses[-1]
            else:
                mode = 'val'
                loss = self.val_losses[-1]

            # Write Loss
            self.summary_writer.write_scalar(f'Loss/{mode}', loss, epoch)

            if not train or self.record_train:
                for idx in range(len(self.metrics)):
                    for metric, info in self.metrics[idx].items():
                        self.summary_writer.write_scalar(
                            f'{idx}/{metric.title()}/{mode}', info['value'], epoch
                        )
    def fit(self, start_epoch=1, epochs=None, reset=True, verbose=True):
        """Perform model training.

        Args:
            start_epoch (:obj:`int`, optional): Start epoch for training.
                (default: 1)
            epochs (:obj:`int`, optional): Number of epochs/iterations to train the
                model for. If no value is given, the original value given during
                initialization of the learner will be used.
            reset (:obj:`bool`, optional): Flag to indicate that training is
                starting from scratch. (default: True)
            verbose (:obj:`bool`, optional): Print logs. (default: True)
        """
        if reset:
            self.reset_history()
        if epochs is not None:
            self.epochs = epochs

        for epoch in range(start_epoch, start_epoch + self.epochs):
            if verbose:
                print(f'Epoch {epoch}:')

            # Train an epoch
            self.train_epoch(verbose=verbose)
            self.write_summary(epoch, True)

            # Validate the model
            if self.val_loader is not None:
                self.validate(verbose=verbose)
                self.write_summary(epoch, False)

            # Save model checkpoint
            self.save_checkpoint(epoch)

            # Call Step LR
            if self.lr_schedulers['step_lr'] is not None:
                self.lr_schedulers['step_lr'].step()

            # Call Reduce LR on Plateau
            if self.lr_schedulers['lr_plateau'] is not None:
                self.lr_schedulers['lr_plateau'].step(self.val_losses[-1])
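
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): a minimal end-to-end
# run of the Learner on random data. The toy model, dataset and
# hyperparameters below are hypothetical placeholders; the sketch assumes
# the package's ProgressBar and l1 helpers behave as their names suggest.
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Toy dataset: 64 random 10-dim vectors with binary labels
    inputs = torch.randn(64, 10)
    labels = torch.randint(0, 2, (64,))
    loader = DataLoader(TensorDataset(inputs, labels), batch_size=16)

    # Toy classifier, optimizer and loss
    model = nn.Sequential(nn.Linear(10, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.CrossEntropyLoss()

    # Train for two epochs while tracking accuracy
    learner = Learner(
        loader, optimizer, criterion, epochs=2, metrics=['accuracy']
    )
    learner.set_model(model)
    learner.fit()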