# This perturber is based on the AWP code (https://github.com/csdongxian/AWP) and modified to our needs.
import torch
import torch.nn.functional as F
from collections import OrderedDict
from backbone import get_backbone
def add_perturb_args(parser):
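    """Register the weight-perturbation hyperparameters on an argparse parser."""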
    parser.add_argument('--p-steps', type=int, default=1)
    parser.add_argument('--p-lam', type=float, default=0.01)
    parser.add_argument('--p-gamma', type=float, default=0.05, help='how far we can go from original weights')
EPS = 1E-20
def diff_in_weights(model, proxy):
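    """Return an OrderedDict mapping weight names to the difference (proxy - model), skipping one-dimensional tensors."""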
    with torch.no_grad():
        diff_dict = OrderedDict()
        model_state_dict = model.state_dict()
        proxy_state_dict = proxy.state_dict()
        for (old_k, old_w), (new_k, new_w) in zip(model_state_dict.items(), proxy_state_dict.items()):
            if len(old_w.size()) <= 1:
                continue
            if 'weight' in old_k:
                diff_w = new_w - old_w
                diff_dict[old_k] = diff_w  # old_w.norm() / (diff_w.norm() + EPS) *
    return diff_dict
def add_into_weights(model, diff, coeff=1.0):
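    """Add coeff * diff[name] in place to every named parameter of `model` that appears in `diff`."""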
    names_in_diff = diff.keys()
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in names_in_diff:
                param.add_(coeff * diff[name])
def normalize(perturbations, weights):
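    """Rescale `perturbations` in place so that its norm matches the norm of `weights` (EPS guards against division by zero)."""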
    perturbations.mul_(weights.norm() / (perturbations.norm() + EPS))
def normalize_grad(weights, ref_weights):
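    """Normalize each parameter's gradient layer-wise to the norm of the corresponding reference weight; 1-D parameters are zeroed."""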
    with torch.no_grad():
        for w, ref_w in zip(weights, ref_weights):
            if w.dim() <= 1:
                w.grad.data.fill_(0)  # ignore perturbations with 1 dimension (e.g. BN, bias)
            else:
                normalize(w.grad.data, ref_w)
class Perturber():
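    """Adversarial weight perturbation (AWP-style) helper.

    A proxy copy of the network is driven away from the current predictions,
    the resulting weight perturbation is added onto the live network, and a
    KL consistency loss against the unperturbed predictions is exposed.
    """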
    EPS = 1E-20
    def __init__(self, continual_model):
        self.continual_model = continual_model
        self.device = continual_model.device
        self.args = continual_model.args
        self.net = continual_model.net
        # proxy network that is perturbed in place of the live network
        self.proxy = get_backbone(self.args).to(self.device)
        self.steps = self.args.p_steps
        self.lam = self.args.p_lam
        self.gamma = self.args.p_gamma
        self.diff = None
    def init_rand(self, model):
        """Add a tiny random perturbation to every weight tensor with more than one dimension."""
        with torch.no_grad():
            for w in model.parameters():
                if w.dim() <= 1:
                    continue
                else:
                    # z = torch.randn_like(w)  # uncomment for random perturbations
                    # z = z / torch.linalg.norm(z)
                    w.add_(torch.randn_like(w) * torch.norm(w) * EPS)  # z * torch.norm(w) * self.gamma)  # This is changed for ablation
    def perturb_model(self, X, y):
        """Perturb the proxy to maximally change its predictions, then add the resulting weight difference onto the live network."""
        out_o = F.softmax(self.net(X), dim=-1).detach()
        self.proxy.load_state_dict(self.net.state_dict())
        # initialize small random noise (delta = 0 is global minimizer)
        self.init_rand(self.proxy)
        self.proxy.train()
        pertopt = torch.optim.SGD(self.proxy.parameters(), lr=self.gamma / self.steps)
        # perturb the model only on samples the current network classifies correctly
        mask = torch.where(out_o.max(1)[1] == y, 1., 0.).detach()  # This is changed for ablation
        if mask.sum() < 2:
            return None, mask
        for idx in range(self.steps):  # multiple ascent steps (set to 1 step by default)
            pertopt.zero_grad()
            # maximize the masked KL divergence between the proxy's and the original predictions
            loss = -(F.kl_div(F.log_softmax(self.proxy(X), dim=1), out_o, reduction='none').sum(dim=1) * mask).sum() / mask.sum()
            loss.backward()
            normalize_grad(self.proxy.parameters(), self.net.parameters())
            pertopt.step()
        # calculate the weight perturbation and add it onto the original network
        self.diff = diff_in_weights(self.net, self.proxy)
        add_into_weights(self.net, self.diff, coeff=1.0)
        return out_o, mask
    def get_loss(self, X, y):
        """Return the scaled KL loss between the perturbed network's predictions and the original ones, or None if too few samples qualify."""
        outs, mask = self.perturb_model(X, y)
        out_n = self.net(X)
        if outs is not None:
            loss_kl = self.lam * (F.kl_div(F.log_softmax(out_n, dim=1), outs, reduction='none').sum(dim=1) * mask).sum() / mask.sum()
            return loss_kl
        else:
            return None
    def restore_model(self):
        """Undo the last applied perturbation by subtracting the stored weight differences."""
        add_into_weights(self.net, self.diff, coeff=-1.0)
    def __call__(self, X, y):
        """Compute the perturbation loss, backpropagate it, and restore the original weights."""
        X = X.to(self.device)
        y = y.to(self.device)
        loss_kl = self.get_loss(X, y)
        if loss_kl is not None:
            loss_kl.backward()
            self.restore_model()  # restore only when a perturbation was actually applied
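# Usage sketch (illustrative only, not part of the original module). `continual_model`
# is assumed to expose `device`, `args`, and `net`, as required by Perturber.__init__;
# `task_loss` and `optimizer` below are hypothetical names for the outer training loop:
#
#   perturber = Perturber(continual_model)
#   ...
#   task_loss.backward()   # gradients of the usual training loss on the clean weights
#   perturber(X, y)        # accumulates gradients of the KL consistency term, then restores the weights
#   optimizer.step()       # updates with the combined gradients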