Source code for utils.evaluate

import logging
from typing import TYPE_CHECKING, Any, Callable, Tuple
import torch
from tqdm import tqdm

if TYPE_CHECKING:
    from models.utils.continual_model import ContinualModel
    from datasets.utils.continual_dataset import ContinualDataset


def mask_classes(outputs: torch.Tensor, dataset: 'ContinualDataset', k: int) -> None:
    """
    Given the output tensor, the dataset at hand and the current task,
    masks the former by setting the responses for the other tasks to -inf.
    It is used to obtain the results for the task-il setting.

    Args:
        outputs: the output tensor
        dataset: the continual dataset
        k: the task index
    """
    num_classes = dataset.N_CLASSES
    start_c, end_c = dataset.get_offsets(k)
    outputs[:, :start_c] = -float('inf')
    outputs[:, end_c:num_classes] = -float('inf')
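
# Illustrative sketch (not part of the upstream module): how mask_classes restricts
# predictions to a single task's classes. The `dataset` object below is hypothetical and
# is assumed to expose 10 classes with dataset.get_offsets(1) == (2, 4).
#
#     logits = torch.randn(8, 10)          # raw class-il outputs for a batch of 8
#     mask_classes(logits, dataset, k=1)   # in-place: logits outside [2, 4) become -inf
#     task_il_pred = logits.argmax(dim=1)  # argmax now always lands in {2, 3}
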
# Type alias for the evaluation function
EvalFn = Callable[['ContinualModel', 'ContinualDataset', bool, bool], Any]
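
# Illustrative sketch (not part of the upstream module): EvalFn describes the signature of
# evaluate-like callables, so an alternative evaluation routine can be passed wherever an
# EvalFn is expected. The function name below is hypothetical.
#
#     def my_evaluate(model: 'ContinualModel', dataset: 'ContinualDataset',
#                     last: bool = False, return_loss: bool = False) -> Any:
#         ...  # expected to return accuracies analogous to evaluate below
#
#     eval_fn: EvalFn = my_evaluate
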
@torch.no_grad()
def evaluate(model: 'ContinualModel', dataset: 'ContinualDataset', last=False, return_loss=False) -> Tuple[list, list]:
    """
    Evaluates the single-class top-1 accuracy of the model for each past task.

    The accuracy is evaluated for all the tasks up to the current one, only for the
    total number of classes seen so far.

    Args:
        model: the model to be evaluated
        dataset: the continual dataset at hand
        last: a boolean indicating whether to evaluate only the last task
        return_loss: a boolean indicating whether to return the loss in addition to the accuracy

    Returns:
        a tuple of lists, containing the class-il and task-il accuracy for each task.
        If return_loss is True, the loss is also returned as a third element.
    """
    status = model.net.training
    model.net.eval()
    accs, accs_mask_classes = [], []
    n_classes = dataset.get_offsets()[1]
    loss_fn = dataset.get_loss()
    avg_loss = 0
    tot_seen_samples = 0
    total_len = sum(len(x) for x in dataset.test_loaders) if hasattr(dataset.test_loaders[0], '__len__') else None
    pbar = tqdm(dataset.test_loaders, total=total_len, desc='Evaluating', disable=model.args.non_verbose)

    for k, test_loader in enumerate(dataset.test_loaders):
        if last and k < len(dataset.test_loaders) - 1:
            continue
        correct, correct_mask_classes, total = 0.0, 0.0, 0.0
        test_iter = iter(test_loader)
        i = 0
        while True:
            try:
                data = next(test_iter)
            except StopIteration:
                break
            if model.args.debug_mode and i > model.get_debug_iters():
                break
            inputs, labels = data[0], data[1]
            inputs, labels = inputs.to(model.device), labels.to(model.device)
            if 'class-il' not in model.COMPATIBILITY and 'general-continual' not in model.COMPATIBILITY:
                # models that cannot do class-il inference receive the task index explicitly
                outputs = model(inputs, k)
            else:
                if model.args.eval_future and k >= model.current_task:
                    outputs = model.future_forward(inputs)
                else:
                    outputs = model(inputs)

            if return_loss:
                loss = loss_fn(outputs, labels)
                avg_loss += loss.item()

            # class-il accuracy: predict over all classes seen so far
            _, pred = torch.max(outputs[:, :n_classes].data, 1)
            correct += torch.sum(pred == labels).item()
            total += labels.shape[0]
            i += 1
            pbar.set_postfix({f'acc_task_{k+1}': max(0, correct / total * 100)}, refresh=False)
            pbar.set_description(f"Evaluating Task {k+1}", refresh=False)
            pbar.update(1)

            if dataset.SETTING == 'class-il':
                # task-il accuracy: mask out the logits belonging to other tasks
                mask_classes(outputs, dataset, k)
                _, pred = torch.max(outputs.data, 1)
                correct_mask_classes += torch.sum(pred == labels).item()

        tot_seen_samples += total
        if correct > correct_mask_classes:
            logging.warning("Task-IL accuracy is LOWER than Class-IL accuracy. "
                            "This should NEVER happen and probably means there is a bug somewhere. "
                            "Hint: check if the dataloader returns the targets in the correct order.")
        accs.append(correct / total * 100
                    if 'class-il' in model.COMPATIBILITY or 'general-continual' in model.COMPATIBILITY else 0)
        accs_mask_classes.append(correct_mask_classes / total * 100)

    pbar.close()
    model.net.train(status)
    if return_loss:
        return accs, accs_mask_classes, avg_loss / tot_seen_samples
    return accs, accs_mask_classes
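
# Illustrative usage sketch (not part of the upstream module): a typical call after
# training on a task. Variable names below are hypothetical.
#
#     accs, accs_mask_classes = evaluate(model, dataset)
#     mean_class_il = sum(accs) / len(accs)                           # average class-il accuracy
#     mean_task_il = sum(accs_mask_classes) / len(accs_mask_classes)  # average task-il accuracy
#
#     # with return_loss=True a third element, the accumulated test loss
#     # normalized by the number of evaluated samples, is also returned
#     accs, accs_mask_classes, avg_loss = evaluate(model, dataset, return_loss=True)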