Source code for models.tak_utils.ties_merging

import torch


## TIES MERGING UTILS
def topk_values_mask(M, K=0.7, return_mask=False):
    """Keep only the largest-magnitude entries in each row of ``M``.

    Args:
        M: 1-D or 2-D tensor. A 1-D tensor is treated as a single row, and the
            outputs are returned in the same 1-D layout.
        K: Fraction of entries to keep per row. Values greater than 1 are read
            as a percentage (``20`` -> ``0.2``). A fraction of 1 or more keeps
            every entry.
        return_mask: If True, also return the boolean keep-mask.

    Returns:
        ``(pruned, keep_fraction)``, or ``(pruned, keep_fraction, mask)`` when
        ``return_mask`` is True. ``pruned`` is ``M`` with every non-kept entry
        zeroed. ``mask`` is the boolean keep-mask (same shape as ``pruned``).
        ``keep_fraction`` is the mean of the mask over the last dimension.

    Note:
        The per-row threshold is the ``(d - int(d * K))``-th smallest magnitude,
        where ``d`` is the row length. Every entry whose magnitude is ``>=`` that
        threshold is kept, so ``int(d * K) + 1`` entries survive when magnitudes
        are distinct. More survive when magnitudes tie at the threshold.
    """
    if K > 1:
        K /= 100
    if K >= 1:
        # Nothing to prune. The mask is bool so it matches the pruning path below.
        keep_fraction = torch.ones_like(M).mean(dim=-1)
        if return_mask:
            return M, keep_fraction, torch.ones_like(M, dtype=torch.bool)
        return M, keep_fraction

    # Remember whether a row dim was added, so exactly that dim is removed again.
    # Comparing shapes after squeeze() fails for a single-element 1-D input.
    was_1d = M.dim() == 1
    if was_1d:
        M = M.unsqueeze(0)

    _, d = M.shape  # unpacking also rejects inputs that are not 2-D here
    magnitudes = M.abs()
    # kthvalue is 1-indexed and counts from the smallest value.
    k = d - int(d * K)
    if M.numel() == 1:
        kth_values = magnitudes
    else:
        kth_values, _ = magnitudes.kthvalue(k, dim=1, keepdim=True)
    mask = magnitudes >= kth_values

    if was_1d:
        M = M.squeeze(0)
        mask = mask.squeeze(0)

    pruned = M * mask
    keep_fraction = mask.float().mean(dim=-1)
    if return_mask:
        return pruned, keep_fraction, mask
    return pruned, keep_fraction
def resolve_zero_signs(sign_to_mult, method="majority"):
    """Fill the zero entries of a sign tensor in place.

    The "majority" sign is the sign of the tensor's total sum. With
    ``method="majority"`` zeros become that sign; with ``method="minority"``
    they become its negation. Any other ``method`` leaves the tensor untouched.
    If the total sum is exactly zero, the fill value is itself zero, so zeros
    stay zero.

    Args:
        sign_to_mult: Tensor of signs. It is modified in place.
        method: ``"majority"`` or ``"minority"``.

    Returns:
        The same tensor object that was passed in.
    """
    dominant = torch.sign(sign_to_mult.sum())
    if method == "majority":
        fill_value = dominant
    elif method == "minority":
        fill_value = -1 * dominant
    else:
        return sign_to_mult
    zero_positions = sign_to_mult == 0
    sign_to_mult[zero_positions] = fill_value
    return sign_to_mult
def chunked_disjoint_mean(vectors, chunk_size=10000):
    """Average each coordinate over dim 0, counting only non-zero entries.

    Rows are processed ``chunk_size`` at a time. For every position of
    ``vectors[0]``, the result is the sum of the entries at that position
    divided by the number of non-zero entries there. Positions with no
    non-zero entry yield 0.

    Sums accumulate in at least float32 and counts in int64. Accumulating in
    the input dtype (e.g. float16/bfloat16) would round both the running sum
    and the count once they grow large. The output dtype is the same as
    before: the accumulator dtype divided by a float32 count.

    Args:
        vectors: Tensor with at least one row along dim 0.
        chunk_size: Number of rows summed per step; must be positive.

    Returns:
        Tensor shaped like ``vectors[0]``.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    n_rows = vectors.size(0)
    acc_dtype = torch.promote_types(vectors.dtype, torch.float32)
    total_sum = torch.zeros_like(vectors[0], dtype=acc_dtype)
    non_zero_counts = torch.zeros_like(vectors[0], dtype=torch.long)
    for start in range(0, n_rows, chunk_size):
        chunk = vectors[start:start + chunk_size]
        total_sum += chunk.sum(dim=0, dtype=acc_dtype)
        non_zero_counts += (chunk != 0).sum(dim=0)
    # clamp avoids division by zero; those positions are then forced to 0.
    disjoint_aggs = total_sum / torch.clamp(non_zero_counts.float(), min=1)
    disjoint_aggs[non_zero_counts == 0] = 0
    return disjoint_aggs
def chunked_sum(tensor, chunk_size=10000):
    """Sum ``tensor`` over dim 0, ``chunk_size`` rows at a time.

    The running total starts as ``torch.zeros_like(tensor[0])``, so it has the
    shape and dtype of a single row. Each chunk's dim-0 sum is added to it in
    place.

    Args:
        tensor: Tensor with at least one row along dim 0.
        chunk_size: Number of rows summed per step.

    Returns:
        Tensor shaped like ``tensor[0]`` holding the dim-0 sum.
    """
    row_count = tensor.size(0)
    full_chunks, leftover = divmod(row_count, chunk_size)
    n_chunks = full_chunks + (1 if leftover else 0)

    accumulator = torch.zeros_like(tensor[0])
    for chunk_index in range(n_chunks):
        lo = chunk_index * chunk_size
        hi = min(lo + chunk_size, row_count)
        accumulator += tensor[lo:hi].sum(dim=0)
    return accumulator
def disjoint_merge(Tensor, merge_func, reference_sign_to_mult, weights=None):
    """Aggregate the rows of ``Tensor`` using only sign-consistent entries.

    When ``reference_sign_to_mult`` is given, an entry participates only if
    its sign matches the reference: strictly positive where the reference is
    > 0, strictly negative otherwise. Without a reference, every non-zero entry
    participates. Rows are aggregated over dim 0.

    Args:
        Tensor: Stack of vectors, one per row (dim 0 is the aggregated axis).
        merge_func: ``"mean"`` (mean over participating entries), ``"sum"``,
            ``"max"`` (largest participating magnitude multiplied by the
            reference sign; this needs ``reference_sign_to_mult``), or
            ``"unmerged"`` (return the filtered rows without aggregating).
        reference_sign_to_mult: Per-column sign tensor, or None.
        weights: Optional indexable; when given, every filtered row is
            multiplied by ``weights[0]``.

    Returns:
        ``(disjoint_aggs, rows_to_keep)``: the aggregate and the boolean
        participation mask (same shape as ``Tensor``).

    Raises:
        ValueError: If ``merge_func`` is not one of the names above.
    """
    if reference_sign_to_mult is None:
        keep = Tensor != 0
    else:
        positive_ref = reference_sign_to_mult.unsqueeze(0) > 0
        keep = torch.where(positive_ref, Tensor > 0, Tensor < 0)

    # Fresh tensor, so the in-place scaling below never touches the caller's input.
    picked = Tensor * keep
    if weights is not None:
        # NOTE(review): every row is scaled by weights[0], not by a per-row
        # weight. Confirm this is intended (e.g. a single shared coefficient)
        # rather than a missing per-row index.
        for row in picked:
            row.mul_(weights[0])

    if merge_func == "mean":
        counts = (picked != 0).sum(dim=0).float()
        merged = picked.sum(dim=0) / counts.clamp(min=1)
    elif merge_func == "sum":
        merged = chunked_sum(picked)
    elif merge_func == "max":
        merged = picked.abs().max(dim=0).values
        merged *= reference_sign_to_mult
    elif merge_func == "unmerged":
        merged = picked
    else:
        raise ValueError(f"Merge method {merge_func} is not defined.")
    return merged, keep
def resolve_sign(Tensor, mode=None):
    """Elect one sign per coordinate from the dim-0 sums of ``Tensor``.

    Each coordinate gets the sign of ``Tensor.sum(dim=0)``. Coordinates whose
    sum is exactly zero are then filled by ``resolve_zero_signs`` using the
    ``"majority"`` rule.

    Args:
        Tensor: Tensor whose rows are summed along dim 0.
        mode: Unused; kept for interface compatibility.

    Returns:
        Sign tensor shaped like ``Tensor.sum(dim=0)``.
    """
    per_coordinate_sign = torch.sign(Tensor.sum(dim=0))
    return resolve_zero_signs(per_coordinate_sign, "majority")
def ties_merging(vectors, topK=20, merging_type='mean', weights=None, **kwargs):
    """Merge task vectors with TIES: trim, elect a sign, then merge disjointly.

    Args:
        vectors: Sequence of tensors with equal element counts. Each is
            flattened before stacking, and the merged result is reshaped back
            to ``vectors[0].shape``.
        topK: Keep-fraction (or percentage when > 1) passed as ``K`` to
            ``topk_values_mask``.
        merging_type: Aggregation passed to ``disjoint_merge`` (``'mean'``,
            ``'sum'``, ``'max'`` or ``'unmerged'``).
        weights: Optional weights forwarded unchanged to ``disjoint_merge``.
        **kwargs: Accepted for call-site compatibility; ignored.

    Returns:
        ``(merged_tv, rows_to_keep, mask)``. ``merged_tv`` has shape
        ``vectors[0].shape``, or ``(len(vectors), *vectors[0].shape)`` for
        ``'unmerged'``. ``rows_to_keep`` is the sign-agreement mask from
        ``disjoint_merge``. ``mask`` is the trimming mask from
        ``topk_values_mask``. Both masks are laid out over the flattened
        ``(n_task, numel)`` stack.
    """
    # TODO: allow some layers to be skipped or left unpruned.
    original_shape = vectors[0].shape
    # vstack (via cat) always allocates a fresh (n_task, numel) tensor, so no clone is needed.
    stacked_vectors = torch.vstack([v.reshape(-1) for v in vectors])

    # 1) Trim: keep only the largest-magnitude entries of each task vector.
    pruned_vectors, _, mask = topk_values_mask(
        stacked_vectors, K=topK, return_mask=True
    )
    # 2) Elect: choose one sign per coordinate.
    vector_signs = resolve_sign(pruned_vectors)
    # 3) Disjoint merge: aggregate only entries that agree with the elected sign.
    merged_tv, rows_to_keep = disjoint_merge(
        pruned_vectors, merging_type, vector_signs, weights
    )

    # Restore the parameter shape while keeping any leading dims. 'unmerged'
    # returns one row per task, which cannot be reshaped to original_shape alone.
    merged_tv = merged_tv.reshape(merged_tv.shape[:-1] + original_shape)
    return merged_tv, rows_to_keep, mask