Engine

Classes and methods used to train and test models.

class tensornet.engine.Learner(train_loader, optimizer, criterion, device='cpu', epochs=1, l1_factor=0.0, val_loader=None, callbacks=None, metrics=None, activate_loss_logits=False, record_train=True)

Model Trainer and Validator.

Parameters
  • train_loader (torch.utils.data.DataLoader) – Training data loader.

  • optimizer (torch.optim.Optimizer) – Optimizer for the model.

  • criterion (torch.nn) – Loss function.

  • device (str or torch.device, optional) – Device onto which the data will be loaded. (default: ‘cpu’)

  • epochs (int, optional) – Number of epochs/iterations to train the model for. (default: 1)

  • l1_factor (float, optional) – L1 regularization factor. (default: 0.0)

  • val_loader (torch.utils.data.DataLoader, optional) – Validation data loader. (default: None)

  • callbacks (list, optional) – List of callbacks to be used during training. (default: None)

  • metrics (list, optional) –

    List of names of the metrics for model evaluation.

    Note: If the model has multiple outputs, this should be a nested list in which each sub-list specifies the metrics used to evaluate the corresponding output. In that case, only the metric of the first output is considered when saving model checkpoints.

  • activate_loss_logits (bool, optional) – If True, the logits are first passed through the activate_logits function before being sent to the criterion. (default: False)

  • record_train (bool, optional) – If False, metrics will be calculated only during validation. (default: True)
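The note on nested metrics can be made concrete with a small sketch. The helper checkpoint_metric below is hypothetical and not part of tensornet; it assumes that, when the first output's sub-list holds several metrics, the first of them is the one used for checkpoints. The metric names are placeholders.

```python
# Hypothetical helper (not part of tensornet) showing how a flat or
# nested `metrics` argument could be resolved to the checkpoint metric.


def checkpoint_metric(metrics):
    """Return the metric name used for checkpointing, or None."""
    if not metrics:
        return None
    first = metrics[0]
    if isinstance(first, (list, tuple)):
        # Nested list: one sub-list per model output. Only the first
        # output is considered; taking its first metric is an assumption.
        return first[0] if first else None
    # Flat list: a single-output model.
    return first


single_output = ["accuracy"]             # placeholder metric name
multi_output = [["accuracy"], ["iou"]]   # placeholder metric names
print(checkpoint_metric(single_output))  # accuracy
print(checkpoint_metric(multi_output))   # accuracy
```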

set_model(model)

Assign a model to the learner.

Parameters

model (torch.nn.Module) – Model instance.

update_training_history(loss)

Update the training history.

Parameters

loss (float) – Loss value.

reset_history()

Reset the training history.

activate_logits(logits)

Apply an activation function to the logits, if needed. The resulting logits are then used to calculate the loss or the evaluation metrics.

Parameters

logits (torch.Tensor) – Model output.

Returns

Activated logits.

Return type

(torch.Tensor)
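As an illustration of what activate_logits might do, the sketch below applies two common activations, sigmoid and softmax, to plain Python lists standing in for torch.Tensor. Which activation, if any, tensornet applies by default is not stated in this reference, so both functions are illustrative assumptions.

```python
import math


def sigmoid_logits(logits):
    """Element-wise sigmoid: maps each raw score into (0, 1)."""
    out = []
    for x in logits:
        if x >= 0:
            out.append(1.0 / (1.0 + math.exp(-x)))
        else:
            # Equivalent form that avoids overflow in exp(-x) for large negative x.
            e = math.exp(x)
            out.append(e / (1.0 + e))
    return out


def softmax_logits(logits):
    """Softmax: maps raw scores to probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


print(sigmoid_logits([0.0]))       # [0.5]
print(softmax_logits([1.0, 1.0]))  # [0.5, 0.5]
```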

calculate_criterion(logits, targets, train=True)

Calculate loss.

Parameters
  • logits (torch.Tensor) – Prediction.

  • targets (torch.Tensor) – Ground truth.

  • train (bool, optional) – If True, L1 regularization is applied to the loss. (default: True)

Returns

Loss value.

Return type

(torch.Tensor)
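The effect of the train flag together with l1_factor can be sketched as adding an L1 penalty, l1_factor times the sum of absolute parameter values, to the loss only during training. This is a minimal sketch assuming the standard L1 formulation; plain floats stand in for the torch.Tensor loss and the model parameters, and the function names are hypothetical.

```python
def l1_penalty(params, l1_factor):
    """Standard L1 term: l1_factor * sum(|w|) over all parameters."""
    return l1_factor * sum(abs(w) for w in params)


def loss_with_l1(loss, params, l1_factor=0.0, train=True):
    """Add the L1 penalty to an already computed loss, during training only."""
    if train and l1_factor > 0:
        return loss + l1_penalty(params, l1_factor)
    return loss  # evaluation (train=False) or regularization disabled


params = [0.5, -1.5, 2.0]  # stand-in for the flattened model parameters
print(loss_with_l1(1.0, params, l1_factor=0.01))               # 1.0 + 0.01 * 4.0
print(loss_with_l1(1.0, params, l1_factor=0.01, train=False))  # 1.0
```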

fetch_data(data)

Fetch a batch from the data loader and move it to the learner's device.

Parameters

data (tuple or list) – List containing inputs and targets.

Returns

Inputs and targets, moved to the device.

train_batch(data)

Train the model on a batch of data.

Parameters

data (tuple or list) – Input and target data for the model.

Returns

Batch loss.

Return type

(float)

train_epoch(verbose=True)

Run an epoch of model training.

Parameters

verbose (bool, optional) – Print logs. (default: True)

train_iterations(verbose=True)

Train the model for ‘self.epochs’ batches (iterations).
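Training for a fixed number of batches needs a way to keep drawing batches once the loader is exhausted. The sketch below assumes the loader is simply restarted (for a shuffling DataLoader, each pass yields a new order); this restart behaviour is an assumption, and iterate_batches is a hypothetical helper.

```python
from itertools import islice


def iterate_batches(loader, num_iterations):
    """Yield exactly num_iterations batches, restarting loader when exhausted."""

    def cycle_loader():
        while True:
            produced = False
            for batch in loader:  # a fresh pass over the loader each time
                produced = True
                yield batch
            if not produced:  # empty loader: stop instead of looping forever
                return

    return islice(cycle_loader(), num_iterations)


loader = [("x0", "y0"), ("x1", "y1"), ("x2", "y2")]  # stand-in for a DataLoader
batches = list(iterate_batches(loader, 5))
print(len(batches))  # 5
print(batches[3])    # ('x0', 'y0'): the loader was restarted
```

itertools.cycle is deliberately avoided here: it caches the items of the first pass and replays them, so a shuffling loader would repeat the same order and every batch would be kept in memory.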

evaluate(loader, verbose=True, log_message='Evaluation')

Evaluate the model on a custom data loader.

Parameters
  • loader (torch.utils.data.DataLoader) – Data loader.

  • verbose (bool, optional) – Print loss and metrics. (default: True)

  • log_message (str, optional) – Prefix for the logs printed at the end of evaluation. (default: ‘Evaluation’)

Returns

Loss and metric values.
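When a single loss and metric value is reported for a whole loader, per-batch results have to be combined. The sketch below, with the hypothetical function aggregate_batches and plain lists in place of tensors, weights each batch's mean loss by its size so that a smaller final batch does not skew the average. Whether tensornet aggregates exactly this way is an assumption, and accuracy stands in for the configured metrics.

```python
def aggregate_batches(batches):
    """Combine per-batch results into (mean loss, accuracy).

    Each item is (batch_mean_loss, predictions, targets). The batch mean is
    re-weighted by batch size so that every sample counts equally.
    """
    total_loss, correct, seen = 0.0, 0, 0
    for batch_loss, preds, targets in batches:
        n = len(targets)
        total_loss += batch_loss * n
        correct += sum(p == t for p, t in zip(preds, targets))
        seen += n
    if seen == 0:
        raise ValueError("no samples to evaluate")
    return total_loss / seen, correct / seen


batches = [
    (0.5, [1, 0, 1], [1, 1, 1]),  # 3 samples, 2 correct
    (1.0, [0], [0]),              # 1 sample, 1 correct (smaller last batch)
]
print(aggregate_batches(batches))  # (0.625, 0.75)
```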

validate(verbose=True)

Validate the model after an epoch of training.

Parameters

verbose (bool, optional) – Print validation loss and metrics. (default: True)

save_checkpoint(epoch=None)

Save model checkpoint.

Parameters

epoch (int, optional) – Current epoch number.

write_summary(epoch, train)

Write the training summary to TensorBoard.

Parameters
  • epoch (int) – Current epoch number.

  • train (bool) – If True, the summary is written for model training; otherwise, it is written for model validation.

fit(start_epoch=1, epochs=None, reset=True, verbose=True)

Perform model training.

Parameters
  • start_epoch (int, optional) – Start epoch for training. (default: 1)

  • epochs (int, optional) – Number of epochs/iterations to train the model for. If None, the value given when the learner was initialized is used. (default: None)

  • reset (bool, optional) – Flag to indicate that training is starting from scratch. (default: True)

  • verbose (bool, optional) – Print logs. (default: True)
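The parameters of fit suggest an epoch loop along the following lines. This is a sketch under stated assumptions, not tensornet's implementation: the call order, calling reset_history when reset is True, saving a checkpoint every epoch, and the epoch range start_epoch to start_epoch + epochs - 1 are all assumptions, and StubLearner is a hypothetical stand-in that records calls instead of training.

```python
class StubLearner:
    """Hypothetical stand-in that records method calls instead of training."""

    def __init__(self, epochs, has_val_loader=True):
        self.epochs = epochs
        self.has_val_loader = has_val_loader
        self.calls = []

    def reset_history(self):
        self.calls.append("reset_history")

    def train_epoch(self, verbose=True):
        self.calls.append("train_epoch")

    def validate(self, verbose=True):
        self.calls.append("validate")

    def save_checkpoint(self, epoch=None):
        self.calls.append(f"save_checkpoint({epoch})")


def fit(learner, start_epoch=1, epochs=None, reset=True, verbose=True):
    """Sketch of an epoch loop; the call order is an assumption."""
    if reset:
        learner.reset_history()  # training starts from scratch
    if epochs is None:
        epochs = learner.epochs  # fall back to the value given at construction
    for epoch in range(start_epoch, start_epoch + epochs):
        learner.train_epoch(verbose=verbose)
        if learner.has_val_loader:
            learner.validate(verbose=verbose)
        learner.save_checkpoint(epoch)


learner = StubLearner(epochs=2)
fit(learner, start_epoch=3)
print(learner.calls)
```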

class tensornet.engine.LRFinder(model, optimizer, criterion, metric='loss', device=None, memory_cache=True, cache_dir=None)

Learning rate range test. During a short pre-training run, the test increases the learning rate between two boundaries, either linearly or exponentially, while recording the loss (or another metric). The resulting curve shows over which range of learning rates the network trains well, which helps in choosing a suitable learning rate.
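The linear and exponential increase can be written as an interpolation between start_lr and end_lr. The sketch below uses the common formulation lr_i = start_lr * (end_lr / start_lr) ** (i / (n - 1)) for the exponential case; the function name and the exact step indexing are assumptions, not tensornet's code.

```python
def lr_schedule(start_lr, end_lr, num_steps, step_mode="exp"):
    """Learning rates for a range test, from start_lr to end_lr inclusive."""
    if num_steps <= 0:
        return []
    if num_steps == 1:
        return [start_lr]
    if step_mode == "exp" and start_lr <= 0:
        raise ValueError("exponential mode needs start_lr > 0")
    lrs = []
    for i in range(num_steps):
        r = i / (num_steps - 1)  # progress through the test, from 0 to 1
        if step_mode == "exp":
            # Geometric interpolation: equal ratios between consecutive steps.
            lrs.append(start_lr * (end_lr / start_lr) ** r)
        elif step_mode == "linear":
            # Arithmetic interpolation: equal differences between steps.
            lrs.append(start_lr + r * (end_lr - start_lr))
        else:
            raise ValueError(f"unknown step_mode: {step_mode!r}")
    return lrs


print(lr_schedule(0.0, 1.0, 5, step_mode="linear"))  # [0.0, 0.25, 0.5, 0.75, 1.0]
print(lr_schedule(1e-4, 10.0, 6))  # roughly 1e-4, 1e-3, 1e-2, 1e-1, 1, 10
```

The exponential mode spends as many steps between 1e-4 and 1e-3 as between 1 and 10, which is why it is usually preferred when the range spans several orders of magnitude.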

Parameters
  • model (torch.nn.Module) – Model instance.

  • optimizer (torch.optim.Optimizer) – Optimizer whose configured learning rate is assumed to be the lower boundary of the range test.

  • criterion (torch.nn) – Loss function.

  • metric (str, optional) – Metric to use for finding the best learning rate. Can be either ‘loss’ or ‘accuracy’. (default: ‘loss’)

  • device (str or torch.device, optional) – Device where the computation will take place. If None, the device of the model is used. (default: None)

  • memory_cache (bool, optional) – If this flag is set to True, state_dict of model and optimizer will be cached in memory. Otherwise, they will be saved to files under the cache_dir. (default: True)

  • cache_dir (str, optional) – Path for storing temporary files. If no path is specified, the system-wide temporary directory is used. This parameter is ignored if memory_cache is True. (default: None)

reset()

Restores the model and optimizer to their initial states.
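The memory_cache and cache_dir options, together with reset(), describe snapshotting the model and optimizer state before the test and restoring it afterwards. The sketch below illustrates that idea with a hypothetical StateCache class, using copy.deepcopy for in-memory caching and pickle files for on-disk caching. With PyTorch one would typically snapshot state_dict() and use torch.save/torch.load instead; tensornet's actual mechanism is not shown here.

```python
import copy
import os
import pickle
import tempfile


class StateCache:
    """Hypothetical snapshot store for state dicts (in memory or on disk)."""

    def __init__(self, memory_cache=True, cache_dir=None):
        self.memory_cache = memory_cache
        # Only used for on-disk caching; default to the system temp directory.
        self.cache_dir = cache_dir if cache_dir is not None else tempfile.gettempdir()
        self._store = {}

    def store(self, key, state):
        if self.memory_cache:
            # Deep copy so later training steps cannot mutate the snapshot.
            self._store[key] = copy.deepcopy(state)
        else:
            fd, path = tempfile.mkstemp(prefix=f"state_{key}_", suffix=".pkl", dir=self.cache_dir)
            with os.fdopen(fd, "wb") as f:
                pickle.dump(state, f)
            self._store[key] = path

    def retrieve(self, key):
        if self.memory_cache:
            return copy.deepcopy(self._store[key])
        with open(self._store[key], "rb") as f:
            return pickle.load(f)

    def clear(self):
        """Delete any files written for on-disk caching."""
        if not self.memory_cache:
            for path in self._store.values():
                if os.path.exists(path):
                    os.remove(path)
        self._store.clear()


# Snapshot before the range test, restore afterwards (what reset() describes).
state = {"lr": 0.01, "weight": [0.1, 0.2]}
cache = StateCache(memory_cache=False)
cache.store("model", state)
state["weight"][0] = 99.0       # the range test modifies the live state
print(cache.retrieve("model"))  # {'lr': 0.01, 'weight': [0.1, 0.2]}
cache.clear()
```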

range_test(train_loader, iterations, mode='iteration', learner=None, val_loader=None, start_lr=None, end_lr=10, step_mode='exp', smooth_f=0.0, diverge_th=5)

Performs the learning rate range test.

Parameters
  • train_loader (torch.utils.data.DataLoader) – The training set data loader.

  • iterations (int) – The number of iterations or epochs over which the test runs. If ‘mode’ is ‘iteration’, this is the number of iterations; if ‘mode’ is ‘epoch’, it is the number of epochs.

  • mode (str, optional) – Whether the learning rate is updated after every iteration or after every epoch. Can be either ‘iteration’ or ‘epoch’. (default: ‘iteration’)

  • learner (Learner, optional) – Learner object for the model. (default: None)

  • val_loader (torch.utils.data.DataLoader, optional) – If None, the range test will only use the training metric. When given a data loader, the model is evaluated after each iteration on that dataset and the evaluation metric is used. Note that in this mode the test takes significantly longer but generally produces more precise results. (default: None)

  • start_lr (float, optional) – The starting learning rate for the range test. If None, uses the learning rate from the optimizer. (default: None)

  • end_lr (float, optional) – The maximum learning rate to test. (default: 10)

  • step_mode (str, optional) – Learning rate policy: ‘linear’ for a linear increase or ‘exp’ for an exponential increase. (default: ‘exp’)

  • smooth_f (float, optional) – The metric smoothing factor within the [0, 1] interval. Disabled if set to 0, otherwise the metric is smoothed using exponential smoothing. (default: 0.0)

  • diverge_th (int, optional) – The test is stopped when the metric surpasses the threshold: diverge_th * best_metric. To disable, set it to 0. (default: 5)
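The interplay of smooth_f and diverge_th can be sketched in a few lines. The hypothetical function below works on a precomputed list of raw metric values (one per step) instead of running a model. It assumes a loss-like metric where lower is better and uses the exponential-smoothing form new = smooth_f * raw + (1 - smooth_f) * previous, a common formulation that is an assumption with respect to tensornet.

```python
def smoothed_history(raw_values, smooth_f=0.0, diverge_th=5):
    """Smooth a loss-like metric and stop once it diverges.

    raw_values stands in for the metric recorded at each step while the
    learning rate increases. Lower values are assumed to be better.
    """
    history, best = [], None
    for value in raw_values:
        if history and smooth_f > 0:
            # Exponential smoothing: weight the new value by smooth_f.
            value = smooth_f * value + (1 - smooth_f) * history[-1]
        history.append(value)
        if best is None or value < best:
            best = value
        if diverge_th > 0 and value > diverge_th * best:
            break  # the metric has diverged; stop the test early
    return history


raw = [2.0, 1.0, 0.8, 6.0, 0.5]
print(smoothed_history(raw))                     # [2.0, 1.0, 0.8, 6.0]: 6.0 > 5 * 0.8
print(len(smoothed_history(raw, smooth_f=0.5)))  # 5: smoothing damps the spike
```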

plot(log_lr=True, show_lr=None)

Plots the learning rate range test.

Parameters
  • skip_start (int, optional) – Number of batches to trim from the start. (default: 10)

  • skip_end (int, optional) – Number of batches to trim from the end. (default: 5)

  • log_lr (bool, optional) – True to plot the learning rate in a logarithmic scale; otherwise, plotted in a linear scale. (default: True)

  • show_lr (float, optional) – If set, adds a vertical line to visualize the specified learning rate. (default: None)
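The trimming and log-scale options can be illustrated with the data preparation alone. The hypothetical function below applies skip_start and skip_end as listed in the parameters above and, when log_lr is True, converts learning rates to log10 values. With matplotlib one would more commonly keep the raw values and call ax.set_xscale("log"). This is a sketch, not tensornet's code.

```python
import math


def prepare_plot_data(lrs, losses, skip_start=10, skip_end=5, log_lr=True):
    """Trim the range-test history and compute x-axis values."""
    if skip_start < 0 or skip_end < 0:
        raise ValueError("skip_start and skip_end must be non-negative")
    # An explicit stop index avoids the lrs[:-0] pitfall (an empty slice)
    # when skip_end == 0, and clamps to 0 when skip_end exceeds the length.
    stop = max(len(lrs) - skip_end, 0)
    lrs, losses = lrs[skip_start:stop], losses[skip_start:stop]
    xs = [math.log10(lr) for lr in lrs] if log_lr else list(lrs)
    return xs, losses


lrs = [1e-3, 1e-2, 1e-1, 1.0, 10.0]
losses = [5.0, 4.0, 3.0, 2.0, 9.0]
xs, ys = prepare_plot_data(lrs, losses, skip_start=1, skip_end=1)
print(ys)  # [4.0, 3.0, 2.0]
```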