import math
import os
from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F
from utils.conditional_bn import ConditionalBatchNorm1d
from utils.conditional_bn import ConditionalBatchNorm2d
from utils.conf import warn_once
def get_rnd_weight(num_tasks, fin, fout=None, nonlinearity='relu'):
    results = []
    if fout is None:
        fout = fin
    for i in range(num_tasks):
        mat = torch.zeros((fout, fin))
        nn.init.kaiming_normal_(mat, mode='fan_out',
                                nonlinearity=nonlinearity)
        results.append(mat.view(-1))
    return torch.stack(results)
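
# Example (illustrative sizes): per-task Kaiming-initialized weights for
# 3 tasks, each row one flattened 8x16 matrix.
#   w = get_rnd_weight(num_tasks=3, fin=16, fout=8)   # -> (3, 128)
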
class ConditionalLinear(nn.Module):
    def __init__(self, fin: int, fout: int, n_tasks: int,
                 use_bn: bool = False, act_init: str = 'relu'):
        super(ConditionalLinear, self).__init__()
        self.fin, self.fout = fin, fout
        self.n_tasks = n_tasks
        # one flattened (fout x fin) weight matrix per task
        self.weight = nn.Embedding(self.n_tasks, self.fin * self.fout)
        self.condbn = None
        if use_bn:
            self.condbn = ConditionalBatchNorm1d(self.fout, self.n_tasks)
        self.init_parameters(act_init)

    def init_parameters(self, act_init: str):
        self.weight.weight.data.copy_(
            get_rnd_weight(self.n_tasks, fin=self.fin,
                           fout=self.fout, nonlinearity=act_init))

    def forward(self, x, task_id):
        weight = self.weight(task_id).view(-1, self.fout, self.fin)
        x = x.unsqueeze(2)  # B, fin, 1
        x = torch.bmm(weight, x)
        x = x.squeeze(2)  # B, fout
        if self.condbn is not None:
            x = self.condbn(x, task_id)
        return x
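
# Usage sketch (illustrative shapes): a linear layer whose weight matrix is
# selected per sample by task id.
#   layer = ConditionalLinear(fin=64, fout=32, n_tasks=5)
#   x = torch.randn(8, 64)
#   task_id = torch.randint(0, 5, (8,))
#   y = layer(x, task_id)   # -> (8, 32)
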
class DiverseLoss(nn.Module):
    def __init__(self, lambda_loss: float, temp: float = 2.0):
        super(DiverseLoss, self).__init__()
        self.lambda_loss = lambda_loss
        self.temp = temp

    def forward(self, logits: torch.Tensor):
        c = logits.shape[1]
        if len(logits.shape) > 2:
            logits = F.adaptive_avg_pool2d(logits, 1).view(-1, c)
        # standardize each row before measuring pairwise similarity
        mean = torch.mean(logits, dim=1, keepdim=True)
        std = torch.std(logits, dim=1, keepdim=True)
        normalized_logits = (logits - mean) / std
        dotlogits = torch.matmul(normalized_logits, normalized_logits.t()) / self.temp
        batch_size = normalized_logits.shape[0]
        loss = torch.logsumexp(dotlogits, dim=1).mean(0)
        loss -= 1 / self.temp
        loss -= math.log(batch_size)
        return self.lambda_loss * loss
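
# Usage sketch (illustrative shapes): pushes the attention logits of different
# samples apart; returns a scalar already scaled by lambda_loss.
#   criterion = DiverseLoss(lambda_loss=0.1)
#   loss = criterion(torch.randn(8, 32))
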
class SoftAttentionSoftmax(nn.Module):
    def __init__(self, fin: int, fout: int, n_tasks: int):
        super(SoftAttentionSoftmax, self).__init__()
        self.fin, self.fout = fin, fout
        self.n_tasks = n_tasks
        # ConditionalLinear initializes its own parameters
        self.l = ConditionalLinear(fin, fout, n_tasks)

    def forward(self, x, task_id):
        logits = self.l(x, task_id)
        rho = torch.softmax(logits, dim=-1)
        return rho, logits
class BinaryGumbelSoftmax(nn.Module):
    def __init__(self, tau: float = (2. / 3.)):
        super(BinaryGumbelSoftmax, self).__init__()
        self.tau = tau

    def forward(self, logits):
        if self.training:
            if logits.device.type == 'cpu':
                warn_once('GumbelSoftmax may be unstable on CPU '
                          '(see https://github.com/pytorch/pytorch/issues/101620)')
            h = nn.functional.gumbel_softmax(logits, tau=self.tau, hard=True)
            h = h[..., 0]
            return h
        # at evaluation time, take the deterministic hard decision instead
        h = torch.softmax(logits, -1)
        h = 1. - torch.argmax(h, -1)
        return h
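
# Usage sketch (illustrative shapes): the trailing dimension holds the two
# (on, off) logits per unit; the output is the binary "on" indicator.
#   gumbel = BinaryGumbelSoftmax()
#   logits = torch.randn(8, 32, 2)
#   h = gumbel(logits)   # -> (8, 32), values in {0, 1}
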
class HardAttentionSoftmax(nn.Module):
    def __init__(self, fin: int, fout: int, n_tasks: int,
                 tau: float = (2. / 3.)):
        super(HardAttentionSoftmax, self).__init__()
        self.fin, self.fout = fin, fout
        self.n_tasks = n_tasks
        self.gumbel = BinaryGumbelSoftmax(tau)
        self.l = ConditionalLinear(self.fin, 2 * self.fout, n_tasks)

    def forward(self, x, task_id, flag_stop_grad=None):
        assert len(task_id) == len(x)
        # two logits per output unit: one for "on", one for "off"
        logits = self.l(x, task_id).view(-1, self.fout, 2)
        h = self.gumbel(logits)
        return h, logits
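
# Usage sketch (illustrative shapes): a binary mask over fout units plus the
# raw per-unit (on, off) logits.
#   attn = HardAttentionSoftmax(fin=64, fout=32, n_tasks=5)
#   h, logits = attn(torch.randn(8, 64), torch.randint(0, 5, (8,)))
#   # h: (8, 32), logits: (8, 32, 2)
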
class SpatialAttn(nn.Module):
    def __init__(self, c: int, n_tasks: int, reduction_rate: int = 4):
        super(SpatialAttn, self).__init__()
        self.c_in = c
        self.c_out = self.c_in // reduction_rate
        self.n_tasks = n_tasks
        self.eps = 1e-6
        self.act = nn.ReLU()
        self.conv1 = nn.Conv2d(self.c_in, self.c_out, kernel_size=1, stride=1)
        self.condbn_1 = ConditionalBatchNorm2d(self.c_out, self.n_tasks)
        self.conv2 = nn.Conv2d(self.c_out, self.c_out, kernel_size=3, stride=1,
                               dilation=2, padding=2)
        self.condbn_2 = ConditionalBatchNorm2d(self.c_out, self.n_tasks)
        self.conv3 = nn.Conv2d(self.c_out, self.c_out, kernel_size=3, stride=1,
                               dilation=2, padding=2)
        self.condbn_3 = ConditionalBatchNorm2d(self.c_out, self.n_tasks)
        self.conv4 = nn.Conv2d(self.c_out, 1, kernel_size=1, stride=1)
        self.condbn_4 = ConditionalBatchNorm2d(1, self.n_tasks)

    def forward(self, fm_t: torch.Tensor, tasks_id: torch.Tensor):
        x = fm_t
        x = self.conv1(x)
        x = self.condbn_1(x, tasks_id)
        x = self.act(x)
        x = self.conv2(x)
        x = self.condbn_2(x, tasks_id)
        x = self.act(x)
        x = self.conv3(x)
        x = self.condbn_3(x, tasks_id)
        x = self.act(x)
        x = self.conv4(x)
        x = self.condbn_4(x, tasks_id)
        return x
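
# Usage sketch (illustrative shapes): squeezes a (B, c, h, w) feature map into
# a single-channel spatial attention map at the same resolution (the dilated
# 3x3 convolutions use padding=2, so h and w are preserved).
#   attn = SpatialAttn(c=64, n_tasks=5)
#   out = attn(torch.randn(8, 64, 14, 14), torch.randint(0, 5, (8,)))
#   # out: (8, 1, 14, 14)
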
class ChannelAttn(nn.Module):
    def __init__(self, c: int, n_tasks: int, reduction_rate: int = 1,
                 activated_with_softmax: bool = False):
        super(ChannelAttn, self).__init__()
        self.c_in = c
        self.c_out = self.c_in // reduction_rate
        self.n_tasks = n_tasks
        self.eps = 1e-6
        self.activated_with_softmax = activated_with_softmax
        self.l1 = ConditionalLinear(self.c_in, self.c_out, n_tasks,
                                    use_bn=True, act_init='tanh')
        self.l2 = ConditionalLinear(self.c_in, self.c_out, n_tasks,
                                    use_bn=True, act_init='sigmoid')
        # residual branch without batch norm
        self.lres = ConditionalLinear(self.c_in, self.c_out, n_tasks)
        self.attn_act = None
        if activated_with_softmax:
            self.attn_act = HardAttentionSoftmax(self.c_out, self.c_in, n_tasks)

    def upsample(self, x, desired_shape):
        # channel attention maps have no spatial extent: nothing to resize
        return x

    def downsample(self, x, *args, **kwargs):
        # channel attention maps have no spatial extent: nothing to resize
        return x

    def compute_distance(self, fm_s, fm_t, rho,
                         use_overhaul_fd):
        dist = (fm_s - fm_t) ** 2
        if use_overhaul_fd:
            # ignore positions where both student and teacher are inactive
            mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
            dist = dist * mask
        # rho weights channels, so average over space first
        dist = dist.mean(dim=(2, 3))
        dist = rho * dist
        dist = dist.sum(1).mean(0)
        return dist

    def forward(self, fm_t: torch.Tensor, tasks_id: torch.Tensor):
        c = fm_t.shape[1]  # b, c, h, w
        x = F.adaptive_avg_pool2d(fm_t, 1).view(-1, c)
        rho_a = self.l1(x, tasks_id)
        rho_a = torch.tanh(rho_a)
        rho_b = self.l2(x, tasks_id)
        rho_b = torch.sigmoid(rho_b)
        res = self.lres(x, tasks_id)
        rho = rho_a * rho_b + res
        if self.activated_with_softmax:
            rho, logits = self.attn_act(rho, tasks_id)
            return rho, logits
        return rho
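
# Usage sketch (illustrative shapes): global-average-pools the feature map,
# then emits one attention weight per channel (binarized when
# activated_with_softmax=True).
#   attn = ChannelAttn(c=64, n_tasks=5, activated_with_softmax=True)
#   rho, logits = attn(torch.randn(8, 64, 14, 14), torch.randint(0, 5, (8,)))
#   # rho: (8, 64), logits: (8, 64, 2)
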
class DoubleAttn(nn.Module):
    def __init__(self, c: int, n_tasks: int, reduction_rate: int = 4):
        super(DoubleAttn, self).__init__()
        self.c = c
        self.n_tasks = n_tasks
        self.channel_attn = ChannelAttn(c, n_tasks, reduction_rate=1,
                                        activated_with_softmax=False)
        self.spatial_attn = SpatialAttn(c, n_tasks, reduction_rate=reduction_rate)
        self.weight = nn.Embedding(self.n_tasks, self.c * (self.c * 2))
        self.gumbel = BinaryGumbelSoftmax()
        self.init_parameters()

    def init_parameters(self):
        self.weight.weight.data.copy_(
            get_rnd_weight(self.n_tasks, self.c, self.c * 2))

    def compute_distance(self, fm_s, fm_t, rho,
                         use_overhaul_fd):
        dist = (fm_s - fm_t) ** 2
        if use_overhaul_fd:
            # ignore positions where both student and teacher are inactive
            mask = 1.0 - ((fm_s <= fm_t) & (fm_t <= 0.0)).float()
            dist = dist * mask
        # rho is a full (B, c, h, w) mask here, so weight before pooling
        dist = rho * dist
        dist = dist.mean(dim=(2, 3))
        dist = dist.sum(1).mean(0)
        return dist

    def upsample(self, x, desired_shape):
        _, c, h, w = x.shape
        cd, hd, wd = desired_shape
        assert cd == c and h <= hd and w <= wd
        if h == hd and w == wd:
            return x
        return F.interpolate(x, (hd, wd))

    def downsample(self, x, min_resize_threshold=16):
        _, c, h, w = x.shape
        if h < min_resize_threshold:
            return x
        return F.interpolate(x, (h // 2, w // 2))

    def forward(self, fm_t: torch.Tensor, tasks_id: torch.Tensor):
        ch_attn = self.channel_attn(fm_t, tasks_id)  # (B, c)
        sp_attn = self.spatial_attn(fm_t, tasks_id)  # (B, 1, h, w)
        if 'ablation_type' in os.environ:
            if os.environ['ablation_type'] == 'chan_only':
                sp_attn = torch.ones_like(sp_attn)
            elif os.environ['ablation_type'] == 'space_only':
                ch_attn = torch.ones_like(ch_attn)
        ch_attn = ch_attn.unsqueeze(2).unsqueeze(3)  # (B, c, 1, 1)
        x = ch_attn + sp_attn  # broadcasts to (B, c, h, w)
        weight = self.weight(tasks_id).view(-1, self.c * 2, self.c)
        logits = torch.einsum('bji,bixy->bjxy', weight, x)  # (B, 2c, h, w)
        _, _, h, w = x.shape
        x = logits.permute((0, 2, 3, 1))
        x = x.view(-1, h, w, self.c, 2)
        rho = self.gumbel(x)  # (B, h, w, c)
        logits = logits.view(-1, self.c, 2, h, w)
        rho = rho.permute((0, 3, 1, 2))  # (B, c, h, w)
        return rho, logits
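
# Usage sketch (illustrative shapes): fuses channel and spatial attention into
# a binary per-channel, per-position mask.
#   attn = DoubleAttn(c=64, n_tasks=5)
#   rho, logits = attn(torch.randn(8, 64, 14, 14), torch.randint(0, 5, (8,)))
#   # rho: (8, 64, 14, 14), logits: (8, 64, 2, 14, 14)
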
class Normalize(nn.Module):
    def __init__(self, eps: float = 1e-6):
        super(Normalize, self).__init__()
        self.eps = eps

    def forward(self, x):
        # per-sample, per-channel L2 normalization over the spatial dimensions
        norm = torch.norm(x, dim=(2, 3), keepdim=True)
        return torch.div(x, norm + self.eps)
class TeacherForcingLoss(nn.Module):
    def __init__(self, teacher_forcing_or: bool, lambda_forcing_loss: float):
        super(TeacherForcingLoss, self).__init__()
        self.teacher_forcing_or = teacher_forcing_or
        self.lambda_forcing_loss = lambda_forcing_loss
        self.register_buffer('index', torch.LongTensor([0]))

    def forward(self, logits, pred, target, teacher_forcing):
        logits, pred, target = (logits[teacher_forcing], pred[teacher_forcing],
                                target[teacher_forcing])
        # keep only the "on" logit of each unit
        logits = torch.index_select(logits, dim=2, index=self.index).squeeze(2)
        teacher_loss = F.binary_cross_entropy_with_logits(logits, target,
                                                          reduction='none')
        if self.teacher_forcing_or:
            # only force units the student turned off but the teacher keeps on
            mask = (1. - pred) * target
            teacher_loss = mask * teacher_loss
        teacher_loss = teacher_loss.mean() if len(teacher_loss) > 0 else 0.0
        teacher_loss *= self.lambda_forcing_loss
        return teacher_loss
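
# Usage sketch (illustrative shapes): penalizes masks that disagree with the
# stored teacher attention maps on the teacher-forced samples.
#   tf_loss = TeacherForcingLoss(teacher_forcing_or=True, lambda_forcing_loss=0.5)
#   logits = torch.randn(8, 64, 2)                 # per-unit (on, off) logits
#   pred = torch.randint(0, 2, (8, 64)).float()    # current binary masks
#   target = torch.randint(0, 2, (8, 64)).float()  # stored teacher masks
#   forced = torch.rand(8) < 0.5                   # which samples are forced
#   loss = tf_loss(logits, pred, target, forced)
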
class MultiTaskAFDAlternative(nn.Module):
    def __init__(self, chw: Tuple[int, int, int], n_tasks: int,
                 cpt: int, clear_grad: bool = False,
                 use_overhaul_fd: bool = False,
                 lambda_diverse_loss: float = 0.0,
                 use_hard_softmax: bool = True,
                 teacher_forcing_or: bool = False,
                 lambda_forcing_loss: float = 0.0,
                 attn_mode: str = 'ch',
                 resize_maps: bool = False,
                 min_resize_threshold: int = 16):
        super(MultiTaskAFDAlternative, self).__init__()
        assert use_hard_softmax, 'use_hard_softmax must be True'
        assert attn_mode in ['ch', 'chsp'], 'wrong value of attn_mode'
        self.c, self.h, self.w = chw
        self.n_tasks = n_tasks
        self.cpt = cpt
        self.clear_grad = clear_grad
        self.use_overhaul_fd = use_overhaul_fd
        self.teacher_forcing_or = teacher_forcing_or
        self.lambda_forcing_loss = lambda_forcing_loss
        self.resize_maps = resize_maps
        self.min_resize_threshold = min_resize_threshold
        self.attn_fn = None
        if attn_mode == 'ch':
            self.attn_fn = ChannelAttn(self.c, n_tasks, activated_with_softmax=True)
        elif attn_mode == 'chsp':
            self.attn_fn = DoubleAttn(self.c, self.n_tasks)
        else:
            raise ValueError
        self.teacher_forcing_loss = TeacherForcingLoss(self.teacher_forcing_or,
                                                       self.lambda_forcing_loss)
        self.teacher_transform = TeacherTransform()
        self.norm = Normalize()
        self.diverse_loss = DiverseLoss(lambda_diverse_loss)

    def get_tasks_id(self, targets):
        # map class labels to task indices (cpt = classes per task)
        if 'ablation_type' in os.environ and os.environ['ablation_type'] == 'non_cond':
            return torch.zeros_like(targets)
        return torch.div(targets, self.cpt, rounding_mode='floor')

    def extend_like(self, teacher_forcing, y):
        dest_shape = (-1,) + (1,) * (len(y.shape) - 1)
        return teacher_forcing.view(dest_shape).expand(y.shape)

    def forward(self, fm_s, fm_t, targets, teacher_forcing, attention_map):
        assert (len(targets) == len(fm_s) == len(fm_t)
                == len(teacher_forcing) == len(attention_map))
        output_rho, logits = self.attn_fn(fm_t, self.get_tasks_id(targets))
        rho = output_rho
        loss = 0.0
        if self.lambda_forcing_loss <= 0.0:
            # hard teacher forcing: overwrite the predicted masks directly
            if teacher_forcing.any():
                if self.resize_maps:
                    attention_map = self.attn_fn.upsample(attention_map, fm_t.shape[1:])
                p1 = (torch.max(attention_map, output_rho)
                      if self.teacher_forcing_or else attention_map)
                rho = torch.where(self.extend_like(teacher_forcing, attention_map),
                                  p1, output_rho)
            else:
                rho = output_rho
        elif teacher_forcing.any():
            # soft teacher forcing: penalize disagreement with the stored maps
            if 'ablation_type' not in os.environ or os.environ['ablation_type'] != 'no_mask_replay':
                if self.resize_maps:
                    attention_map = self.attn_fn.upsample(attention_map, fm_t.shape[1:])
                loss += self.teacher_forcing_loss(logits, output_rho,
                                                  attention_map, teacher_forcing)
        if self.use_overhaul_fd:
            fm_t = self.teacher_transform(fm_t, targets)
        fm_t, fm_s = self.norm(fm_t), self.norm(fm_s)
        loss += self.attn_fn.compute_distance(fm_s, fm_t, rho, self.use_overhaul_fd)
        if 'ablation_type' not in os.environ or os.environ['ablation_type'] != 'no_diverse':
            loss += self.diverse_loss(rho[~teacher_forcing])
        if self.resize_maps:
            output_rho = self.attn_fn.downsample(output_rho,
                                                 min_resize_threshold=self.min_resize_threshold)
        return loss, output_rho
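
# Usage sketch (illustrative setup, assuming TeacherTransform is available from
# the surrounding codebase): distill teacher feature maps into the student with
# task-conditioned channel attention.
#   afd = MultiTaskAFDAlternative((64, 14, 14), n_tasks=5, cpt=2,
#                                 lambda_forcing_loss=0.5, attn_mode='ch')
#   fm_s = torch.randn(8, 64, 14, 14)        # student feature map
#   fm_t = torch.randn(8, 64, 14, 14)        # teacher feature map
#   targets = torch.randint(0, 10, (8,))     # labels, 2 classes per task
#   forced = torch.zeros(8, dtype=torch.bool)
#   maps = torch.zeros(8, 64)                # stored attention maps (replay)
#   loss, rho = afd(fm_s, fm_t, targets, forced, maps)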