STARPROMPT

Arguments

Options

--clip_backbone : str

Help: CLIP backbone architecture

  • Default: ViT-L/14

  • Choices: RN50, RN101, RN50x4, RN50x16, RN50x64, ViT-B/32, ViT-B/16, ViT-L/14, ViT-L/14@336px
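The option above can be mirrored with Python's standard argparse module. This is a minimal sketch, not the model's actual get_parser: the helper name add_clip_backbone_arg is hypothetical, and only the flag name, default and choices are taken from this page (the choices coincide with the model names accepted by OpenAI's CLIP package).

```python
import argparse

# Backbone names listed in this documentation for --clip_backbone.
CLIP_BACKBONES = ["RN50", "RN101", "RN50x4", "RN50x16", "RN50x64",
                  "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px"]


def add_clip_backbone_arg(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Hypothetical helper: register --clip_backbone on an existing parser."""
    parser.add_argument("--clip_backbone", type=str, default="ViT-L/14",
                        choices=CLIP_BACKBONES,
                        help="CLIP backbone architecture")
    return parser


if __name__ == "__main__":
    parser = add_clip_backbone_arg(argparse.ArgumentParser())
    print(parser.parse_args([]).clip_backbone)                            # ViT-L/14
    print(parser.parse_args(["--clip_backbone", "RN50"]).clip_backbone)  # RN50
```

Because choices is set, argparse rejects any name outside the list (for example ViT-H/14) with a usage error instead of failing later when the backbone is loaded.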

Frozen hyperparameters

--virtual_bs_n : int

Help: virtual batch size iterations

  • Default: 1
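Assuming --virtual_bs_n implements gradient accumulation (gradients from virtual_bs_n consecutive mini-batches are averaged before a single optimizer step, so the effective batch size is roughly batch_size * virtual_bs_n), the idea can be sketched on a toy 1-D least-squares problem. The function name and the toy model are illustrative only and are not part of the library.

```python
from typing import List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def train_with_accumulation(batches: List[Batch], virtual_bs_n: int = 1,
                            lr: float = 0.1, w: float = 0.0) -> Tuple[float, int]:
    """Fit y ~ w * x, taking one optimizer step every `virtual_bs_n` mini-batches.

    Returns the final weight and the number of optimizer steps taken.
    A trailing incomplete group of mini-batches is discarded without a step.
    """
    grad_acc = 0.0
    steps = 0
    for i, batch in enumerate(batches, start=1):
        # Gradient of the mini-batch mean squared error mean((w*x - y)^2) w.r.t. w.
        g = sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)
        # Dividing by virtual_bs_n makes the accumulated value an average.
        grad_acc += g / virtual_bs_n
        if i % virtual_bs_n == 0:
            w -= lr * grad_acc  # one optimizer step per virtual batch
            grad_acc = 0.0
            steps += 1
    return w, steps


if __name__ == "__main__":
    b1, b2 = [(1.0, 2.0)], [(2.0, 4.0)]
    # Two accumulated mini-batches give the same update as one merged batch.
    print(train_with_accumulation([b1, b2], virtual_bs_n=2))  # (1.0, 1)
    print(train_with_accumulation([b1 + b2], virtual_bs_n=1))  # (1.0, 1)
```

With the default of 1 every mini-batch triggers a step, which is ordinary mini-batch SGD.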

--ortho_split_val : int

Help: None

  • Default: 0

--gr_mog_n_iters_second_stage : int

Help: Number of EM iterations when fitting the MoG for GR in the second stage.

  • Default: 500

--gr_mog_n_iters_first_stage : int

Help: Number of EM iterations when fitting the MoG for GR in the first stage.

  • Default: 200

--gr_mog_n_components : int

Help: Number of MoG components for GR (both first and second stage).

  • Default: 5
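The three MoG options above control an expectation-maximization (EM) fit. The hypothetical pure-Python sketch below, on 1-D data, shows where n_components and the number of EM iterations enter and how replay samples are then drawn from the fitted mixture. It assumes one mixture fitted on scalar features; the actual implementation operates on feature vectors and may differ in initialization and regularization. Setting n_components=1 reduces it to the single-Gaussian case (--gr_model gaussian).

```python
import math
import random
from typing import List, Tuple

Mixture = Tuple[List[float], List[float], List[float]]  # weights, means, variances


def fit_mog_1d(data: List[float], n_components: int = 5, n_iters: int = 200,
               min_var: float = 1e-6) -> Mixture:
    """Fit a 1-D Gaussian mixture to `data` with `n_iters` EM iterations."""
    n = len(data)
    xs = sorted(data)
    # Initialization: means at evenly spaced quantiles, shared variance, uniform weights.
    means = [xs[int((k + 0.5) * n / n_components)] for k in range(n_components)]
    mu = sum(data) / n
    var0 = max(sum((x - mu) ** 2 for x in data) / n, min_var)
    variances = [var0] * n_components
    weights = [1.0 / n_components] * n_components

    for _ in range(n_iters):
        # E-step: responsibilities p(component k | x), computed in log space for stability.
        resp = []
        for x in data:
            log_p = [math.log(w) - 0.5 * math.log(2 * math.pi * v) - (x - m) ** 2 / (2 * v)
                     for w, m, v in zip(weights, means, variances)]
            top = max(log_p)
            p = [math.exp(lp - top) for lp in log_p]
            total = sum(p)
            resp.append([pk / total for pk in p])
        # M-step: re-estimate weights, means and variances from the responsibilities.
        for k in range(n_components):
            nk = sum(r[k] for r in resp) + 1e-12
            means[k] = sum(r[k] * x for r, x in zip(resp, data)) / nk
            variances[k] = max(sum(r[k] * (x - means[k]) ** 2
                                   for r, x in zip(resp, data)) / nk, min_var)
            weights[k] = nk / n
    return weights, means, variances


def sample_mog_1d(mixture: Mixture, num_samples: int, seed: int = 0) -> List[float]:
    """Draw `num_samples` synthetic values from a fitted mixture (the replay step)."""
    weights, means, variances = mixture
    rng = random.Random(seed)
    comps = rng.choices(range(len(weights)), weights=weights, k=num_samples)
    return [rng.gauss(means[k], math.sqrt(variances[k])) for k in comps]
```

The separate first- and second-stage iteration counts (200 and 500 by default) would simply be passed as n_iters in the respective stage.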

--batch_size_gr : int

Help: Batch size for Generative Replay (both first and second stage).

  • Default: 128

--num_samples_gr : int

Help: Number of samples for Generative Replay (both first and second stage).

  • Default: 256
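One plausible reading of --num_samples_gr and --batch_size_gr, assumed here rather than confirmed by this page, is that num_samples_gr synthetic features are drawn per class and the pooled, shuffled set is then iterated in batches of batch_size_gr. The sampler callables and make_replay_batches are hypothetical.

```python
import random
from typing import Callable, Dict, List, Tuple

Sampler = Callable[[int], List[float]]  # hypothetical: n -> n synthetic features


def make_replay_batches(samplers: Dict[int, Sampler], num_samples: int = 256,
                        batch_size: int = 128, seed: int = 0
                        ) -> List[List[Tuple[float, int]]]:
    """Draw `num_samples` features per class, shuffle them and split into batches."""
    rng = random.Random(seed)
    pool = [(feat, label)
            for label, sampler in samplers.items()
            for feat in sampler(num_samples)]
    rng.shuffle(pool)  # so batches are not grouped by class
    return [pool[i:i + batch_size] for i in range(0, len(pool), batch_size)]


if __name__ == "__main__":
    samplers = {0: lambda n: [0.0] * n, 1: lambda n: [1.0] * n}
    batches = make_replay_batches(samplers)
    print(len(batches), [len(b) for b in batches])  # 4 [128, 128, 128, 128]
```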

--prefix_tuning_prompt_len : int

Help: Prompt length for prefix tuning. Used only when --prompt_mode is concat.

  • Default: 5

Ablation hyperparameters

--gr_model : str

Help: Type of distribution model for Generative Replay (both first and second stage).
- mog: Mixture of Gaussians.
- gaussian: Single Gaussian distribution.

  • Default: mog

  • Choices: mog, gaussian

--enable_gr : 0|1|True|False -> bool

Help: Enable Generative Replay (both first and second stage).

  • Default: 1
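Boolean options such as --enable_gr and --enable_confidence_modulation accept 0, 1, True or False. A type converter with that behavior can be written for argparse as follows; str_to_bool is a hypothetical name, not necessarily the helper the library uses, and this sketch additionally accepts lower-case true/false.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Convert '0'/'1'/'True'/'False' (case-insensitive) to a bool for argparse."""
    mapping = {"0": False, "1": True, "false": False, "true": True}
    key = value.strip().lower()
    if key not in mapping:
        raise argparse.ArgumentTypeError(f"expected 0|1|True|False, got {value!r}")
    return mapping[key]


parser = argparse.ArgumentParser()
# A string default is passed through `type`, so "1" becomes True.
parser.add_argument("--enable_gr", type=str_to_bool, default="1")
parser.add_argument("--enable_confidence_modulation", type=str_to_bool, default="1")

if __name__ == "__main__":
    print(parser.parse_args([]).enable_gr)                        # True
    print(parser.parse_args(["--enable_gr", "False"]).enable_gr)  # False
```

A dedicated converter is used instead of type=bool because bool("False") is True in Python: any non-empty string would enable the option.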

--prompt_mode : str

Help: Prompt type for the second stage.
- residual: STAR-Prompt style prompting.
- concat: Prefix-Tuning style prompting.

  • Default: residual

  • Choices: residual, concat
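To illustrate the difference between the two --prompt_mode values, here is a toy single-head attention in pure Python. The concat branch follows the usual Prefix-Tuning recipe: prefix_tuning_prompt_len learnable key/value vectors are prepended inside attention, so the output keeps the original number of tokens. The residual branch is a simplified assumption (a prompt vector added to every token); the way STAR-Prompt actually injects residual prompts may differ. All function names are hypothetical.

```python
import math
from typing import List

Matrix = List[List[float]]  # one row per token


def attention(queries: Matrix, keys: Matrix, values: Matrix) -> Matrix:
    """Single-head scaled dot-product attention (no projections), pure Python."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        top = max(scores)
        w = [math.exp(s - top) for s in scores]  # softmax numerator, stabilized
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, values)) / z
                    for j in range(len(values[0]))])
    return out


def residual_mode(tokens: Matrix, prompt: List[float]) -> Matrix:
    """Simplified 'residual' prompting (assumption): add a prompt vector to each token."""
    shifted = [[t + p for t, p in zip(tok, prompt)] for tok in tokens]
    return attention(shifted, shifted, shifted)


def concat_mode(tokens: Matrix, prefix_k: Matrix, prefix_v: Matrix) -> Matrix:
    """Prefix-Tuning: prepend prompt_len learnable keys/values; queries stay the tokens."""
    return attention(tokens, prefix_k + tokens, prefix_v + tokens)


if __name__ == "__main__":
    tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
    prompt_len = 5  # --prefix_tuning_prompt_len default
    prefix = [[0.1 * i, -0.1 * i] for i in range(prompt_len)]
    print(len(residual_mode(tokens, [0.5, -0.5])))   # 3
    print(len(concat_mode(tokens, prefix, prefix)))  # 3
```

In both modes the sequence length seen by later layers is unchanged; only in concat mode does the attention softmax run over prompt_len extra keys.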

--enable_confidence_modulation : 0|1|True|False -> bool

Help: Enable confidence modulation with CLIP similarities (Eq. 5 of the main paper).

  • Default: 1

Tunable hyperparameters

--lambda_ortho_second_stage : float

Help: Orthogonality loss coefficient (second stage).

  • Default: 10
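The orthogonality coefficients (--lambda_ortho_second_stage here and --lambda_ortho_first_stage below) weight a penalty that discourages new prompts from overlapping with those learned for earlier tasks. The exact loss is not given on this page; a common form, assumed in this sketch, sums the squared dot products between the current prompt and each frozen previous prompt. The function ortho_loss is hypothetical.

```python
from typing import List, Sequence


def ortho_loss(current: Sequence[float], previous: List[Sequence[float]],
               lam: float = 10.0) -> float:
    """lam * sum over previous prompts of (current . previous)^2 (assumed form)."""
    return lam * sum(sum(a * b for a, b in zip(current, p)) ** 2 for p in previous)


if __name__ == "__main__":
    print(ortho_loss([1.0, 0.0], [[0.0, 1.0]]))            # 0.0
    print(ortho_loss([1.0, 0.0], [[1.0, 0.0]], lam=30.0))  # 30.0
```

The penalty is zero when the new prompt is orthogonal to all earlier ones and for the first task (no previous prompts), and it grows with lam, so the larger first-stage default (30) enforces orthogonality more strongly.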

--num_monte_carlo_gr_second_stage : int

Help: How many times to sample from the dataset for alignment (second stage).

  • Default: 1

--num_epochs_gr_second_stage : int

Help: Number of epochs for GR (second stage).

  • Default: 10

--learning_rate_gr_second_stage : float

Help: Learning rate for GR (second stage).

  • Default: 0.001

--num_monte_carlo_gr_first_stage : int

Help: How many times to sample from the dataset for alignment (first stage).

  • Default: 1

--learning_rate_gr_first_stage : float

Help: Learning rate for Generative Replay (first stage).

  • Default: 0.05

--lambda_ortho_first_stage : float

Help: Orthogonality loss coefficient for CoOp (first stage).

  • Default: 30

--num_epochs_gr_first_stage : int

Help: Number of epochs for Generative Replay (first stage).

  • Default: 10

First stage optimization hyperparameters

--first_stage_optim : str

Help: First stage optimizer

  • Default: sgd

  • Choices: sgd, adam

--first_stage_lr : float

Help: First stage learning rate

  • Default: 0.002

--first_stage_momentum : float

Help: First stage momentum

  • Default: 0

--first_stage_weight_decay : float

Help: First stage weight decay

  • Default: 0
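With --first_stage_optim sgd, the learning rate, momentum and weight decay options combine as in the standard SGD update: weight decay is added to the gradient, which then feeds a momentum buffer. The class below is a minimal scalar re-implementation following the semantics of PyTorch's torch.optim.SGD with dampening=0 and nesterov=False; ScalarSGD is a hypothetical name, the sketch is for illustration only, and it does not cover the adam choice.

```python
from typing import List, Optional


class ScalarSGD:
    """SGD on a list of float parameters, mirroring torch.optim.SGD with
    dampening=0 and nesterov=False. Defaults match the first-stage options."""

    def __init__(self, params: List[float], lr: float = 0.002,
                 momentum: float = 0.0, weight_decay: float = 0.0):
        self.params = params  # updated in place
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.buffers: List[Optional[float]] = [None] * len(params)

    def step(self, grads: List[float]) -> None:
        for i, g in enumerate(grads):
            # L2 weight decay is folded into the gradient.
            g = g + self.weight_decay * self.params[i]
            if self.momentum != 0.0:
                buf = self.buffers[i]
                # The first step initializes the buffer with the gradient itself.
                buf = g if buf is None else self.momentum * buf + g
                self.buffers[i] = buf
                g = buf
            self.params[i] -= self.lr * g


if __name__ == "__main__":
    opt = ScalarSGD([0.0], lr=1.0, momentum=0.9)
    opt.step([1.0])
    opt.step([1.0])
    print(opt.params)
```

With the documented defaults (momentum 0, weight decay 0) the update reduces to plain gradient descent, p = p - 0.002 * grad.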

--first_stage_epochs : int

Help: First stage epochs. If not set, it will be the same as n_epochs.

  • Default: None
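The fallback described for --first_stage_epochs amounts to a one-line rule. The helper below is hypothetical and only restates that rule.

```python
from typing import Optional


def resolve_first_stage_epochs(first_stage_epochs: Optional[int], n_epochs: int) -> int:
    """Hypothetical helper: None (the default) falls back to n_epochs."""
    return n_epochs if first_stage_epochs is None else first_stage_epochs


if __name__ == "__main__":
    print(resolve_first_stage_epochs(None, 20))  # 20
    print(resolve_first_stage_epochs(5, 20))     # 5
```

The check uses "is None" rather than a truthiness test, so an explicitly passed 0 would not be silently replaced by n_epochs.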

Classes

class models.starprompt.STARPrompt(backbone, loss, args, transform, dataset=None)

Bases: ContinualModel

Second stage of StarPrompt. Requires the keys saved by the first stage.

COMPATIBILITY: List[str] = ['class-il', 'domain-il', 'task-il', 'general-continual']
NAME: str = 'starprompt'
begin_task(dataset)
end_task(dataset)
forward(x)
get_parameters()
static get_parser(parser)
Return type:

ArgumentParser

get_scheduler()
net: STARPromptModel
observe(inputs, labels, not_aug_inputs, epoch=None)
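The hook methods listed above are invoked by the training loop rather than called directly. The sketch below assumes the usual continual-learning order (begin_task once per task, observe once per training batch, end_task once per task) and uses a hypothetical recording stub with the same method signatures instead of the real STARPrompt, which needs a backbone, loss, args and transform. The library's actual training loop also handles evaluation, schedulers (get_scheduler) and other bookkeeping omitted here.

```python
from typing import Any, List, Sequence, Tuple


class RecordingModel:
    """Stub exposing the same hook signatures as STARPrompt; it only records calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def begin_task(self, dataset: Any) -> None:
        self.calls.append("begin_task")

    def observe(self, inputs: Any, labels: Any, not_aug_inputs: Any, epoch=None) -> float:
        self.calls.append(f"observe(epoch={epoch})")
        return 0.0  # dummy loss value

    def end_task(self, dataset: Any) -> None:
        self.calls.append("end_task")


def run_task(model, dataset: Any, batches: Sequence[Tuple[Any, Any, Any]],
             n_epochs: int) -> None:
    """Hypothetical minimal loop for one task: begin, observe every batch, end."""
    model.begin_task(dataset)
    for epoch in range(n_epochs):
        for inputs, labels, not_aug_inputs in batches:
            model.observe(inputs, labels, not_aug_inputs, epoch=epoch)
    model.end_task(dataset)


if __name__ == "__main__":
    model = RecordingModel()
    run_task(model, dataset=None, batches=[("x0", "y0", "x0_raw")], n_epochs=2)
    print(model.calls)
    # ['begin_task', 'observe(epoch=0)', 'observe(epoch=1)', 'end_task']
```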