AFD (module models.twf_utils.afd)#

Classes#

class models.twf_utils.afd.BinaryGumbelSoftmax(tau=0.6666666666666666)[source]#

Bases: Module

forward(logits)[source]#
class models.twf_utils.afd.ChannelAttn(c, n_tasks, reduction_rate=1, activated_with_softmax=False)[source]#

Bases: Module

compute_distance(fm_s, fm_t, rho, use_overhaul_fd)[source]#
downsample(x, *args, **kwargs)[source]#
forward(fm_t, tasks_id)[source]#
upsample(x, desired_shape)[source]#
class models.twf_utils.afd.ConditionalLinear(fin, fout, n_tasks, use_bn=False, act_init='relu')[source]#

Bases: Module

forward(x, task_id)[source]#
init_parameters(act_init)[source]#
class models.twf_utils.afd.DiverseLoss(lambda_loss, temp=2.0)[source]#

Bases: Module

forward(logits)[source]#
class models.twf_utils.afd.DoubleAttn(c, n_tasks, reduction_rate=4)[source]#

Bases: Module

compute_distance(fm_s, fm_t, rho, use_overhaul_fd)[source]#
downsample(x, min_resize_threshold=16)[source]#
forward(fm_t, tasks_id)[source]#
init_parameters()[source]#
upsample(x, desired_shape)[source]#
class models.twf_utils.afd.HardAttentionSoftmax(fin, fout, n_tasks, tau=0.6666666666666666)[source]#

Bases: Module

forward(x, task_id, flag_stop_grad=None)[source]#
class models.twf_utils.afd.MultiTaskAFDAlternative(chw, n_tasks, cpt, clear_grad=False, use_overhaul_fd=False, lambda_diverse_loss=0.0, use_hard_softmax=True, teacher_forcing_or=False, lambda_forcing_loss=0.0, attn_mode='ch', resize_maps=False, min_resize_threshold=16)[source]#

Bases: Module

extend_like(teacher_forcing, y)[source]#
forward(fm_s, fm_t, targets, teacher_forcing, attention_map)[source]#
get_tasks_id(targets)[source]#
class models.twf_utils.afd.Normalize(eps=1e-06)[source]#

Bases: Module

forward(x)[source]#
class models.twf_utils.afd.SoftAttentionSoftmax(fin, fout, n_tasks)[source]#

Bases: Module

forward(x, task_id)[source]#
class models.twf_utils.afd.SpatialAttn(c, n_tasks, reduction_rate=4)[source]#

Bases: Module

forward(fm_t, tasks_id)[source]#
class models.twf_utils.afd.StudentTransform(chw, n_tasks, cpt)[source]#

Bases: Module

forward(fm_s, tasks_id)[source]#
init_parameters()[source]#
class models.twf_utils.afd.TeacherForcingLoss(teacher_forcing_or, lambda_forcing_loss)[source]#

Bases: Module

forward(logits, pred, target, teacher_forcing)[source]#
class models.twf_utils.afd.TeacherTransform[source]#

Bases: Module

forward(fm_t, targets)[source]#
get_margin(fm, eps=1e-06)[source]#

Functions#

models.twf_utils.afd.get_rnd_weight(num_tasks, fin, fout=None, nonlinearity='relu')[source]#