# Source code for tensornet.utils.display

import os
import matplotlib.pyplot as plt
from typing import Union, List, Dict, Tuple, Optional

def plot_metric(
    data: Union[List[float], Dict[str, List[float]]],
    metric: str,
    title: Optional[str] = None,
    size: Tuple[int, int] = (7, 5),
    legend_font: int = 15,
    legend_loc: str = 'lower right'
):
    """Plot an accuracy or loss graph and save it as a PNG file.

    The figure is written to a relative path derived from the title: the
    title is split on whitespace, joined with underscores, lower-cased and
    suffixed with ``.png``.

    Args:
        data (:obj:`list` or :obj:`dict`): If only single plot then this is
            a list, else for multiple plots this is a dict with keys
            containing the plot name and values being a list of points to
            plot.
        metric (str): Metric name which is to be plotted. Can be either
            loss or accuracy. Used as the y-axis label.
        title (:obj:`str`, optional): Title of the plot, if no title given
            then it defaults to ``'<metric> Change'``.
        size (:obj:`tuple`, optional): Size of the plot.
            (default: *'(7, 5)'*)
        legend_font (:obj:`int`, optional): Font size of the legend.
            (default: *'15'*)
        legend_loc (:obj:`str`, optional): Location of the legend box in
            the plot. No legend will be plotted if there is only a single
            plot. (default: *'lower right'*)
    """
    # isinstance rather than an exact type comparison, so that dict
    # subclasses (OrderedDict, defaultdict, ...) are treated as multi-plots.
    multi_plot = isinstance(data, dict)

    # Initialize a figure
    fig = plt.figure(figsize=size)

    # Plot data; for multi-plots keep the line handles for the legend
    if multi_plot:
        plots = [plt.plot(values)[0] for values in data.values()]
    else:
        plt.plot(data)

    # Set plot title
    if title is None:
        title = f'{metric} Change'
    plt.title(title)

    # Label axes
    plt.xlabel('Epoch')
    plt.ylabel(metric)

    if multi_plot:
        # Set legend: one entry per dict key, in insertion order
        plt.legend(
            tuple(plots),
            tuple(data.keys()),
            loc=legend_loc,
            shadow=True,
            prop={'size': legend_font}
        )

    # Save plot under a file name derived from the title
    fig.savefig(f'{"_".join(title.split()).lower()}.png')
def plot_predictions(
    data: List[dict],
    classes: Union[List[str], Tuple[str, ...]],
    plot_title: str,
    plot_path: str
):
    """Plot up to 25 predictions in a 5x5 grid and save the figure.

    Each panel shows one image with its ground truth and predicted class
    names as the panel title. Only the first 25 entries of ``data`` are
    drawn; panels left over when fewer entries are given stay blank.

    Args:
        data (list): List of dicts with the keys ``'image'``, ``'label'``
            and ``'prediction'``. Images should be numpy arrays. Labels and
            predictions must support ``.item()`` and yield an index into
            ``classes``.
        classes (:obj:`list` or :obj:`tuple`): List of classes in the
            dataset.
        plot_title (str): Title for the plot.
        plot_path (str): Complete path for saving the plot.
    """
    # Initialize plot
    fig, axs = plt.subplots(5, 5, figsize=(10, 10))
    fig.suptitle(plot_title)

    # Hide the axes of every panel up front, so that panels left unused
    # (fewer than 25 samples) do not render as empty framed axes.
    for ax in axs.flat:
        ax.axis('off')

    # axs.flat walks the grid row by row; zip stops after 25 panels or when
    # the data runs out, whichever comes first.
    for ax, result in zip(axs.flat, data):
        label = result['label'].item()
        prediction = result['prediction'].item()

        # Plot image
        ax.set_title(
            f'Label: {classes[label]}\nPrediction: {classes[prediction]}'
        )
        ax.imshow(result['image'])

    # Set spacing; the top margin leaves room for the figure title
    fig.tight_layout()
    fig.subplots_adjust(top=0.88)

    # Save image
    fig.savefig(plot_path, bbox_inches='tight')
def save_and_show_result(
    classes: Union[List[str], Tuple[str, ...]],
    correct_pred: Optional[List[dict]] = None,
    incorrect_pred: Optional[List[dict]] = None,
    path: Optional[str] = None
):
    """Plot network predictions and save them as images.

    The output directory is created if it does not exist. Correct
    predictions are saved as ``correct_predictions.png`` and incorrect ones
    as ``incorrect_predictions.png`` inside that directory; a plot is only
    produced for the lists that are given.

    Args:
        classes (:obj:`list` or :obj:`tuple`): List of classes in the
            dataset.
        correct_pred (:obj:`list`, optional): Contains correct model
            predictions and labels.
        incorrect_pred (:obj:`list`, optional): Contains incorrect model
            predictions and labels.
        path (:obj:`str`, optional): Path where the results will be saved.
            Defaults to a ``predictions`` directory next to this module.
    """
    # Create directories for saving predictions
    if path is None:
        path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'predictions'
        )
    # exist_ok avoids the check-then-create race of testing for the
    # directory first and creating it afterwards.
    os.makedirs(path, exist_ok=True)

    if correct_pred is not None:
        # Plot correct predictions
        plot_predictions(
            correct_pred,
            classes,
            'Correct Predictions',
            os.path.join(path, 'correct_predictions.png')
        )

    if incorrect_pred is not None:
        # Plot incorrect predictions
        # NOTE(review): the leading newline in this title is kept as-is;
        # confirm whether it is intentional.
        plot_predictions(
            incorrect_pred,
            classes,
            '\nIncorrect Predictions',
            os.path.join(path, 'incorrect_predictions.png')
        )