BUFFER LWS

Custom buffer for Learning without Shortcuts (LwS).

Classes

class utils.buffer_lws.Buffer(buffer_size, device, n_tasks, attributes=['examples', 'labels', 'logits', 'task_labels'], n_bin=8)

Bases: Dataset

add_data(examples, labels=None, clusters_labels=None, logits=None, clusters_logits=None, task_labels=None, loss_values=None)
empty()

Sets all the tensors to None.

get_bin_index(loss_value)

Determines the bin index for a given loss value.
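As an illustration of what a bin index can mean here, the sketch below assumes equal-width bins spanning the minimum and maximum loss seen so far (the range tracked by update_loss_range) and n_bin=8 as in the constructor default. The explicit min_loss/max_loss arguments are hypothetical stand-ins for state the real method keeps on the buffer, and the equal-width scheme is an assumption, not something the reference confirms.

```python
def get_bin_index(loss_value: float, min_loss: float, max_loss: float, n_bin: int = 8) -> int:
    """Map a loss value to one of n_bin equal-width bins over [min_loss, max_loss].

    Assumption: equal-width bins; min_loss/max_loss are passed in explicitly here.
    """
    # Degenerate range (e.g. only one distinct loss seen so far): use bin 0.
    if max_loss <= min_loss:
        return 0
    # Normalise to [0, 1], clamping losses that fall outside the observed range.
    ratio = (loss_value - min_loss) / (max_loss - min_loss)
    ratio = min(max(ratio, 0.0), 1.0)
    # Scale to [0, n_bin]; the maximum loss would map to n_bin, so clamp it into the last bin.
    return min(int(ratio * n_bin), n_bin - 1)
```

Clamping at both ends keeps the index valid even when a new loss lies outside the range recorded so far.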

get_data(size, transform=None, return_index=False, to_device=None)
Return type:

Tuple
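A sketch of how a get_data-style call could draw a replay minibatch, assuming uniform sampling without replacement over the slots filled so far, an optional per-example transform, and, when return_index is True, the sampled slot indexes placed first in the returned tuple. The examples, labels and num_seen_examples arguments are hypothetical stand-ins for buffer attributes, to_device is omitted, and plain Python lists replace tensors so the example stays framework-agnostic.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple


def get_data(
    examples: Sequence,
    labels: Sequence,
    num_seen_examples: int,
    size: int,
    transform: Optional[Callable] = None,
    return_index: bool = False,
) -> Tuple:
    """Sample up to `size` stored items uniformly, without replacement (sketch)."""
    # Only slots that have actually been written are eligible.
    n_filled = min(num_seen_examples, len(examples))
    size = min(size, n_filled)
    indexes: List[int] = random.sample(range(n_filled), size)
    # Apply the optional per-example transform (e.g. data augmentation).
    batch_x = [transform(examples[i]) if transform is not None else examples[i] for i in indexes]
    batch_y = [labels[i] for i in indexes]
    # Assumed ordering: indexes first, so callers can later refresh those slots.
    if return_index:
        return (indexes, batch_x, batch_y)
    return (batch_x, batch_y)
```

Returning the indexes is what makes it possible to write recomputed losses back to the same slots afterwards.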

get_losses()
get_task_labels()
init_tensors(examples, labels, logits, task_labels, clusters_labels=None, clusters_logits=None, loss_values=None)
is_empty()

Returns True if the buffer is empty, False otherwise.

Return type:

bool

reservoir_bin_loss(loss_value)

Modified reservoir sampling algorithm considering loss values and binning.

Return type:

int
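reservoir_bin_loss receives only a loss value, so it presumably relies on state held by the buffer, such as the number of examples seen and the loss bin of each stored item. The class below is a hypothetical sketch of one such scheme, not the LwS implementation: it runs the standard reservoir acceptance test (keep with probability buffer_size / (seen + 1)) and then evicts a random slot from the most populated loss bin, assuming equal-width bins over a fixed [min_loss, max_loss] range. The names BinnedReservoir and slot_bins are illustrative.

```python
import random
from collections import Counter
from typing import List, Optional


class BinnedReservoir:
    """Hypothetical loss-binned reservoir that tracks the bin of each buffer slot."""

    def __init__(self, buffer_size: int, n_bin: int = 8,
                 min_loss: float = 0.0, max_loss: float = 1.0) -> None:
        self.buffer_size = buffer_size
        self.n_bin = n_bin
        self.min_loss, self.max_loss = min_loss, max_loss
        self.num_seen_examples = 0
        self.slot_bins: List[Optional[int]] = [None] * buffer_size

    def _bin(self, loss_value: float) -> int:
        # Equal-width bins over [min_loss, max_loss], clamped to [0, n_bin - 1].
        span = self.max_loss - self.min_loss
        if span <= 0:
            return 0
        ratio = min(max((loss_value - self.min_loss) / span, 0.0), 1.0)
        return min(int(ratio * self.n_bin), self.n_bin - 1)

    def reservoir_bin_loss(self, loss_value: float) -> int:
        """Return the slot to overwrite, or -1 to discard the example."""
        bin_idx = self._bin(loss_value)
        seen = self.num_seen_examples
        self.num_seen_examples += 1
        # While the buffer is filling up, every example is stored in order.
        if seen < self.buffer_size:
            self.slot_bins[seen] = bin_idx
            return seen
        # Reservoir acceptance test: keep with probability buffer_size / (seen + 1).
        if random.randint(0, seen) >= self.buffer_size:
            return -1
        # Hypothetical twist: evict a random slot from the most crowded bin,
        # which keeps the stored losses spread across bins.
        counts = Counter(self.slot_bins)
        crowded = max(counts, key=counts.__getitem__)
        victim = random.choice([i for i, b in enumerate(self.slot_bins) if b == crowded])
        self.slot_bins[victim] = bin_idx
        return victim
```

Evicting from the fullest bin is only one plausible way to combine binning with reservoir sampling; other rules (e.g. per-bin quotas) would fit the same docstring.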

reservoir_loss(num_seen_examples, buffer_size, loss_value)

Modified reservoir sampling algorithm considering loss values.

Return type:

int
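For reference, standard reservoir sampling (Vitter's Algorithm R) returns the next free slot while the buffer is filling and, once it is full, keeps a new example with probability buffer_size / (num_seen_examples + 1). The sketch below layers a hypothetical loss-aware rule on top: an accepted example overwrites the lowest-loss slot, but only if its own loss is higher. The stored_losses parameter and this replacement rule are assumptions made for illustration; the reference does not document how reservoir_loss actually uses loss_value.

```python
import random
from typing import Sequence


def reservoir_loss(num_seen_examples: int, buffer_size: int, loss_value: float,
                   stored_losses: Sequence[float]) -> int:
    """Return the buffer slot for the new example, or -1 to discard it (sketch)."""
    # Buffer not yet full: write to the next free slot.
    if num_seen_examples < buffer_size:
        return num_seen_examples
    # Algorithm R acceptance test: keep with probability buffer_size / (num_seen_examples + 1).
    if random.randint(0, num_seen_examples) >= buffer_size:
        return -1
    # Hypothetical loss-aware rule: overwrite the lowest-loss slot, but only
    # if the incoming example is harder (higher loss) than that slot.
    victim = min(range(buffer_size), key=stored_losses.__getitem__)
    return victim if loss_value > stored_losses[victim] else -1
```

Returning -1 for "discard" matches the int return type while leaving the actual write to the caller.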

reset_bins()
reset_budget()
to(device)
update_loss_range(loss_value)

Updates the minimum and maximum loss values seen so far, which define the range used for binning.
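A minimal sketch of tracking that range, assuming a simple running minimum and maximum over every loss passed in. LossRange is a hypothetical name, and bin_edges is a hypothetical helper showing how n_bin equal-width bins would partition the recorded range.

```python
import math
from typing import List


class LossRange:
    """Running minimum and maximum of all loss values observed so far (sketch)."""

    def __init__(self) -> None:
        # +/- infinity so that the first observed loss sets both bounds.
        self.min_loss = math.inf
        self.max_loss = -math.inf

    def update_loss_range(self, loss_value: float) -> None:
        self.min_loss = min(self.min_loss, loss_value)
        self.max_loss = max(self.max_loss, loss_value)

    def bin_edges(self, n_bin: int = 8) -> List[float]:
        # Hypothetical helper: n_bin equal-width bins need n_bin + 1 edges.
        if self.min_loss > self.max_loss:
            raise ValueError("no loss values observed yet")
        width = (self.max_loss - self.min_loss) / n_bin
        return [self.min_loss + k * width for k in range(n_bin + 1)]
```

Starting from infinities avoids special-casing the first call.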

update_losses(loss_values, indexes)
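The signature update_losses(loss_values, indexes) suggests that stored losses are refreshed in place for specific buffer slots, for example after recomputing losses on a replayed minibatch whose slot indexes came from get_data(..., return_index=True). A minimal sketch over a plain list, where stored_losses is a hypothetical stand-in for the buffer's internal loss storage:

```python
from typing import List, Sequence


def update_losses(stored_losses: List[float], loss_values: Sequence[float],
                  indexes: Sequence[int]) -> None:
    """Overwrite the stored loss of each slot listed in `indexes`, in place (sketch)."""
    if len(loss_values) != len(indexes):
        raise ValueError("loss_values and indexes must have the same length")
    # Pair each slot index with its freshly computed loss.
    for slot, loss in zip(indexes, loss_values):
        stored_losses[slot] = float(loss)
```

With a tensor-backed buffer, the loop would typically reduce to a single indexed assignment such as stored_losses[indexes] = loss_values.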