Source code for models.utils.future_model

"""
This is the base class for all models that support future prediction, i.e., zero-shot prediction.

It extends the ContinualModel class and adds the future_forward method, which should be implemented by all models that inherit from this class.
Such a method should take an input tensor and return a tensor representing the future prediction. This method is used by the future prediction evaluation protocol.

The change_transform method is used to update the transformation applied to the input data. This is useful when the model is trained on a dataset and then evaluated on a different dataset. In this case, the transformation should be updated to match the new dataset.
"""
import torch

from datasets.utils.continual_dataset import ContinualDataset
from .continual_model import ContinualModel


class FutureModel(ContinualModel):

    def future_forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Function that implements the forward pass of the model for future prediction. This method should be implemented by all models that inherit from this class.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor representing the future prediction.
        """
        raise NotImplementedError

    def change_transform(self, dataset: ContinualDataset):
        """
        Change the transformation applied to the input data. In Zero-shot learning, the model is trained on a dataset and then evaluated on a different one. In this case, the transformation should be updated to match the new dataset.

        Args:
            dataset (ContinualDataset): An instance of the dataset on which the model will be evaluated on new classes.
        """
        pass
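
The contract described above (a subclass must override future_forward, while change_transform is an optional hook) can be illustrated with a minimal sketch. It deliberately avoids importing the library or torch so that it runs anywhere: FutureModelStandIn, ToyDataset, ToyFutureModel and the test_transform attribute are hypothetical names invented for illustration, and plain lists of floats stand in for torch.Tensor. A real subclass would inherit from FutureModel and operate on tensors.

```python
class FutureModelStandIn:
    """Hypothetical stand-in mirroring the two methods of FutureModel."""

    def future_forward(self, x, **kwargs):
        # Same contract as the original: subclasses must override this.
        raise NotImplementedError

    def change_transform(self, dataset):
        # Same default as the original: do nothing.
        pass


class ToyDataset:
    """Hypothetical dataset exposing an evaluation-time transform."""

    def __init__(self, scale):
        # Assumption: the transform is a simple element-wise rescaling.
        self.test_transform = lambda x: [v * scale for v in x]


class ToyFutureModel(FutureModelStandIn):
    """Scores an input against one weight vector per class."""

    def __init__(self, class_weights):
        # One weight vector per class, including classes not yet trained on.
        self.class_weights = class_weights
        # Identity transform until change_transform is called.
        self.transform = lambda x: x

    def future_forward(self, x, **kwargs):
        # Apply the current transform, then take one dot product per class.
        x = self.transform(x)
        return [sum(w_i * x_i for w_i, x_i in zip(w, x))
                for w in self.class_weights]

    def change_transform(self, dataset):
        # Adopt the transform of the dataset used for evaluation.
        self.transform = dataset.test_transform


model = ToyFutureModel([[1.0, 0.0], [0.0, 1.0]])
scores = model.future_forward([2.0, 3.0])

# Switch to the transform of a different (hypothetical) evaluation dataset.
model.change_transform(ToyDataset(scale=0.5))
scaled = model.future_forward([2.0, 3.0])
```

Keeping the transform on the model and swapping it through change_transform means the evaluation code can call future_forward the same way regardless of which dataset supplied the inputs.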