CGIL

Arguments

Options

--clip_backbone (str)

Help: CLIP backbone.

  • Default: ViT-L/14

--learning_rate_alignment (float)

Help: Learning rate for GR.

  • Default: 0.05

--optim_alignment (str)

Help: Optimizer for GR.

  • Default: adamw

  • Choices: sgd, adam, adamw

--optim_alignment_wd (float)

Help: Weight decay for GR.

  • Default: 0

--lambda_ortho_first_stage (float)

Help: Orthogonality loss coefficient for CoOp.

  • Default: 1

--num_epochs_alignment (int)

Help: Number of epochs for GR.

  • Default: 30

--batch_size_alignment (int)

Help: Batch size for alignment.

  • Default: 128

--gr_mog_n_components (int)

Help: Number of components for GR with MoG.

  • Default: 5

--gr_mog_n_iters (int)

Help: Number of EM iterations when fitting the MoG for GR.

  • Default: 500

--gr_vae_hidden_dim (int)

Help: Hidden dimension for GR with VAE.

  • Default: 512

--gr_vae_latent_dim (int)

Help: Latent dimension for GR with VAE.

  • Default: 256

--gr_vae_n_iters (int)

Help: Number of iterations for GR with VAE.

  • Default: 500

--train_only_current_prompts (int)

Help: Train only current prompts.

  • Default: 0

  • Choices: 0, 1

--align_with_ortholoss (int)

Help: Align with orthogonality loss.

  • Default: 0

  • Choices: 0, 1

--lr_vae (float)

Help: Learning rate for VAE.

  • Default: 0.0002

--general_context (int)

Help: Use general context (number of contexts created).

  • Default: 0

--generated_context (int)

Help: Use generated context.

  • Default: 0

--cocoop (int)

Help: Use image embedding to generate context.

  • Default: 0

--combo_context (int)

Help: Use both generated and prompt context.

  • Default: 1

--n_context (int)

Help: Number of contexts to use. (Note: the upstream help string for this option duplicates the one for --combo_context.)

  • Default: 1

--g_models (str)

Help: Generative model to use for alignment.

  • Default: vae

  • Choices: vae, mog, gauss, diffusion

Classes

class models.cgil.CGIL(backbone, loss, args, transform, dataset=None)

Bases: FutureModel

COMPATIBILITY: List[str] = ['class-il', 'domain-il', 'task-il', 'general-continual']
NAME: str = 'cgil'
begin_task(dataset)
change_transform(dataset)
end_task(dataset)
forward(x)
Return type:

Tensor

future_forward(x)
Return type:

Tensor

static get_parser(parser)
Return type:

ArgumentParser

observe(*args, **kwargs)