Source code for tensornet.models.base_model

import torch
import torch.nn as nn

from .utils.summary import summary as model_summary
from tensornet.engine.learner import Learner
from typing import Tuple


[docs]class BaseModel(nn.Module): """This is the parent class for all the models that are to be created using ``TensorNet``.""" def __init__(self): """This function instantiates all the model layers.""" super(BaseModel, self).__init__() self.learner = None
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """This function defines the forward pass of the model. Args: x (torch.Tensor): Input. Returns: (*torch.Tensor*): Model output. """ raise NotImplementedError
[docs] def summary(self, input_size: Tuple[int]): """Generates model summary. Args: input_size (tuple): Size of input to the model. """ model_summary(self, input_size)
[docs] def create_learner( self, train_loader, optimizer, criterion, device='cpu', epochs=1, l1_factor=0.0, val_loader=None, callbacks=None, metrics=None, activate_loss_logits=False, record_train=True ): """Create Learner object. Args: train_loader (torch.utils.data.DataLoader): Training data loader. optimizer (torch.optim): Optimizer for the model. criterion (torch.nn): Loss Function. device (:obj:`str` or :obj:`torch.device`): Device where the data will be loaded. epochs (:obj:`int`, optional): Numbers of epochs to train the model. (default: 1) l1_factor (:obj:`float`, optional): L1 regularization factor. (default: 0) val_loader (:obj:`torch.utils.data.DataLoader`, optional): Validation data loader. callbacks (:obj:`list`, optional): List of callbacks to be used during training. track (:obj:`str`, optional): Can be set to either `'epoch'` or `'batch'` and will store the changes in loss and accuracy for each batch or the entire epoch respectively. (default: *'epoch'*) metrics (:obj:`list`, optional): List of names of the metrics for model evaluation. """ self.learner = Learner( train_loader, optimizer, criterion, device=device, epochs=epochs, val_loader=val_loader, l1_factor=l1_factor, callbacks=callbacks, metrics=metrics, activate_loss_logits=activate_loss_logits, record_train=record_train ) self.learner.set_model(self)
[docs] def set_learner(self, learner: Learner): """Assign a learner object to the model. Args: learner (:obj:`Learner`): Learner object. """ self.learner = learner self.learner.set_model(self)
[docs] def fit( self, train_loader, optimizer, criterion, device='cpu', epochs=1, l1_factor=0.0, val_loader=None, callbacks=None, metrics=None, activate_loss_logits=False, record_train=True, start_epoch=1, verbose=True, ): """Train the model. Args: train_loader (torch.utils.data.DataLoader): Training data loader. optimizer (torch.optim): Optimizer for the model. criterion (torch.nn): Loss Function. device (:obj:`str` or :obj:`torch.device`): Device where the data will be loaded. epochs (:obj:`int`, optional): Numbers of epochs to train the model. (default: 1) l1_factor (:obj:`float`, optional): L1 regularization factor. (default: 0) val_loader (:obj:`torch.utils.data.DataLoader`, optional): Validation data loader. callbacks (:obj:`list`, optional): List of callbacks to be used during training. track (:obj:`str`, optional): Can be set to either `'epoch'` or `'batch'` and will store the changes in loss and accuracy for each batch or the entire epoch respectively. (default: *'epoch'*) metrics (:obj:`list`, optional): List of names of the metrics for model evaluation. record_train (:obj:`bool`, optional): If False, metrics will be calculated only during validation. (default: True) activate_loss_logits (:obj:`bool`, optional): If True, the logits will first pass through the `activate_logits` function before going to the criterion. (default: False) start_epoch (:obj:`int`, optional): Starting epoch number to display during training. (default: 1) verbose (:obj:`bool`, optional): Print loss and metrics. (default: True) """ # Create learner object self.create_learner( train_loader, optimizer, criterion, device=device, epochs=epochs, l1_factor=l1_factor, val_loader=val_loader, callbacks=callbacks, metrics=metrics, activate_loss_logits=activate_loss_logits, record_train=record_train, ) # Train Model self.learner.fit(start_epoch=start_epoch, epochs=epochs, verbose=verbose)
[docs] def rfit(self, start_epoch=1, epochs=None, verbose=True): if self.learner is None: raise ValueError('No learner initialized.') self.learner.fit(start_epoch=start_epoch, epochs=epochs, reset=False, verbose=verbose)
[docs] def evaluate(self, loader, verbose=True, log_message='Evaluation'): """Evaluate the model on a custom data loader. Args: loader (torch.utils.data.DataLoader): Data loader. verbose (:obj:`bool`, optional): Print loss and metrics. (default: True) log_message (str): Prefix for the logs which are printed at the end. Returns: loss and metric values """ if self.learner is None: raise ValueError('Cannot evaluate without a learner. Create and assign a learner object first.') return self.learner.evaluate(loader, verbose=verbose, log_message=log_message)
[docs] def save(self, filepath: str, **kwargs): """Save the model. Args: filepath (str): File in which the model will be saved. **kwargs: Additional parameters to save with the model. """ if self.learner is None: raise ValueError('Cannot save un-trained model.') torch.save({ 'model_state_dict': self.state_dict(), 'optimizer_state_dict': self.learner.optimizer.state_dict(), **kwargs }, filepath)
[docs] def load(self, filepath: str) -> dict: """Load the model and return the additional parameters saved in in the checkpoint file. Args: filepath (str): File in which the model is be saved. Returns: (*dict*): Parameters saved inside the checkpoint file. """ checkpoint = torch.load(filepath) self.load_state_dict(checkpoint['model_state_dict']) return { k: v for k, v in checkpoint.items() if k != 'model_state_dict' }