Source code for models.coda_prompt_utils

"""
This package contains utility functions for the CoDA Prompt model. Implements a custom version of ViT to add prompt parameters.
"""

import copy

import torch


def gram_schmidt(vv, start_c, end_c, return_in_parameter=True):
    """
    Code for this function is modified from:
    https://github.com/legendongary/pytorch-gram-schmidt/blob/master/gram_schmidt.py

    Perform Gram-Schmidt orthogonalization on the input matrix vv.

    Each index along the first dimension of ``vv`` is treated as one vector
    (a 3D input has its last two dimensions flattened for the computation and
    restored afterwards). The result is built as follows:

    - vectors ``[0, start_c)`` are copied from ``vv`` unchanged;
    - vectors ``[start_c, end_c)`` are replaced by freshly sampled Gaussian
      vectors, each made orthogonal to every preceding vector and then scaled
      to unit norm (the values stored in ``vv`` at these indices are ignored);
    - vectors ``[end_c, ...)`` are left as zeros.

    Args:
        vv: 2D tensor, or 3D tensor whose last two dimensions are flattened.
        start_c: index of the first vector to (re)generate.
        end_c: index one past the last vector to (re)generate.
        return_in_parameter: if True, wrap the result in ``torch.nn.Parameter``.

    Returns:
        A tensor (or ``torch.nn.Parameter``) with the same shape as ``vv``.
    """
    # Squared norms below this threshold are treated as a zero vector.
    eps = 1e-8

    def projection(u, v):
        """Project ``v`` onto ``u``; return None if ``u`` is numerically zero."""
        denominator = (u * u).sum()
        if denominator < eps:
            return None
        return (v * u).sum() / denominator * u

    # check if the tensor is 3D and flatten the last two dimensions if necessary
    is_3d = len(vv.shape) == 3
    if is_3d:
        original_shape = vv.shape
        # reshape (not view) so that non-contiguous inputs are accepted too
        vv = vv.reshape(vv.shape[0], -1)

    # swap rows and columns: every vector is now a column
    vv = vv.T

    uu = torch.zeros_like(vv)

    # keep the already existing vectors as they are
    if start_c > 0:
        uu[:, 0:start_c] = vv[:, 0:start_c].clone()

    for k in range(start_c, end_c):
        vk = torch.randn_like(vv[:, k])
        uk = 0
        for j in range(k):
            proj = projection(uu[:, j].clone(), vk)
            # The projection onto a zero vector is zero, so such a column is
            # skipped. Resampling vk cannot help here because the degenerate
            # column itself does not change (that used to loop forever).
            if proj is not None:
                uk = uk + proj
        uu[:, k] = vk - uk

    # normalize only the newly generated vectors
    for k in range(start_c, end_c):
        uk = uu[:, k].clone()
        uu[:, k] = uk / uk.norm()

    # undo swapping of rows and columns
    uu = uu.T

    # return from 2D
    if is_3d:
        uu = uu.reshape(original_shape)

    if return_in_parameter:
        return torch.nn.Parameter(uu)
    return uu