Source code for delve.torchcallback

import os
import warnings
from collections import OrderedDict
from itertools import product
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from torch.nn.functional import interpolate
from torch.nn.modules import LSTM, Module
from torch.nn.modules.conv import Conv2d
from torch.nn.modules.linear import Linear

import delve
from delve.logger import log
from delve.metrics import *
from delve.torch_utils import TorchCovarianceMatrix
from delve.writers import STATMAP, WRITERS, CompositWriter, NPYWriter


class SaturationTracker(object):
    """Takes a PyTorch module and records layer saturation, intrinsic
    dimensionality and other scalars.

    Args:
        savefile (str)          : destination for summaries.
        save_to (str, List[Union[str, delve.writers.AbstractWriter]]) :
            specify one or multiple save strategies. You can use
            pre-implemented save strategies or inherit from AbstractWriter
            in order to implement your own preferred saving strategy.

            Pre-existing saving strategies are:
                csv         : stores all stats in a csv-file with one row
                              per epoch.
                plot        : produces plots from intrinsic dimensionality
                              and/or layer saturation.
                tensorboard : saves all stats to tensorboard.
                print       : prints all metrics on the console as soon as
                              they are logged.
                npy         : creates a folder structure with npy-files
                              containing the logged values. This is the only
                              save strategy that can save the full covariance
                              matrix. It is useful if you want to reproduce
                              intrinsic dimensionality and saturation values
                              with other thresholds without re-evaluating
                              model checkpoints.
        modules (torch modules or list of modules) : layer-containing object.
            By default, only Conv2d, Linear and LSTM cells are recorded.
        layer_filter (func)     : a filter function used to exclude layers
            from being tracked. It receives a dictionary as input and returns
            it with undesired entries removed. The dictionary maps string keys
            to torch.nn.Module objects. Default: identity function.
        writer_args (dict)      : additional arguments passed on to the
            writers. Only used when a writer is initialized through a
            string key.
        log_interval (int)      : distance between two batches used for
            updating the covariance matrix. The default value of 1 means that
            all data is used for computing intrinsic dimensionality and
            saturation. Increasing the log interval is useful on very large
            datasets to reduce numeric instability.
        max_samples (int)       : (optional) the covariance matrix in each
            layer stops updating itself once max_samples is reached. The use
            case is similar to log_interval, when datasets are very large.
        stats (list of str)     : list of stats to compute.

            Supported stats are:
                idim  : intrinsic dimensionality.
                lsat  : layer saturation (intrinsic dimensionality divided by
                        feature space dimensionality).
                cov   : the covariance matrix (only saveable using the 'npy'
                        save strategy).
                det   : the determinant of the covariance matrix (also known
                        as generalized variance).
                trc   : the trace of the covariance matrix, generally a more
                        useful metric than det for determining the total
                        variance of the data. Note, however, that it does not
                        take the correlation between features into account.
                        On the other hand, in most cases the determinant will
                        be zero, since there will be strongly correlated
                        features, so the trace may be the better option.
                dtrc  : the trace of the diagonalized covariance matrix,
                        another way of measuring the dispersion of the data.
                embed : samples embedded in the eigenspace of dimension 2.
        layerwise_sat (bool)    : whether to include layerwise saturation when
            saving.
        average_sat (bool)      : whether to additionally log the saturation
            averaged over all layers.
        ignore_layer_names (List[str]) : layers named here are excluded when
            computing and logging the stats.
        reset_covariance (bool) : True by default; resets the covariance every
            time the stats are computed. Disabling this option will strongly
            bias the covariance, since the gradient will influence the model.
            We recommend computing saturation at the end of training and
            testing.
        include_conv            : setting to False includes only linear
            layers.
        conv_method (str)       : how to subsample convolutional layers.
            Default is channelwise, which means that each position of the
            filter tensor is considered a datapoint, effectively yielding a
            data matrix of shape (height*width*batch_size, num_filters).

            Supported methods are:
                channelwise : treats every depth vector of the tensor as a
                              datapoint, effectively reshaping the data tensor
                              from shape (batch_size, height, width, channel)
                              into (batch_size*height*width, channel).
                mean        : applies global average pooling on each
                              feature map.
                max         : applies global max pooling on each feature map.
                median      : applies global median pooling on each
                              feature map.
                flatten     : flattens the entire feature map to a vector,
                              reshaping the data tensor into a data matrix of
                              shape (batch_size, height*width*channel). This
                              strategy is extremely memory intensive and will
                              likely cause memory and performance problems for
                              any non-toy problem.
        timeseries_method (str) : how to subsample timeseries. Default is
            last_timestep.

            Supported methods are:
                timestepwise  : stacks each sample timestep-by-timestep.
                last_timestep : selects the last timestep's output.
        nosave (bool)           : if True, disables saving artifacts (images).
            Default is False.
        verbose (bool)          : print saturation for every layer during
            training.
        sat_threshold (float)   : threshold used to determine the number of
            eigendirections belonging to the latent space. In effect, this is
            the threshold determining the intrinsic dimensionality. Default
            value is 0.99 (99% of the explained variance), which is a
            compromise between a good and an interpretable approximation. From
            experience, the threshold should be between 0.97 and 0.9995 for
            meaningful results.
        device (str)            : device to do the computations on. Default is
            cuda:0. Generally it is recommended to do the computations on the
            GPU in order to get maximum performance. Using the CPU is
            generally slower, but it lets delve use regular RAM instead of the
            generally more limited VRAM of the GPU. Not having delve run on
            the same device as the network causes a slight performance
            decrease due to copying memory between devices during each forward
            pass. Delve can handle models distributed over multiple GPUs, but
            delve itself will always run on a single device.
        initial_epoch (int)     : the initial epoch to start with. Default is
            0, which corresponds to a new run. If initial_epoch != 0, the
            writers will look for save states that they can resume. If set to
            zero, all existing states will be overwritten. If set to a lower
            epoch than actually recorded, the behavior of the writers is
            undefined and may result in crashes, loss of data or corrupted
            data.
        interpolation_strategy (str) : default is None (disabled). If set to a
            string key accepted by the mode argument of
            torch.nn.functional.interpolate, the feature map will be resized
            to match the interpolated size. This is useful if you work with
            large resolutions and want to save computation time. No
            interpolation is done if the resolution is already smaller than
            the target size.
        interpolation_downsampling (int) : default is 32. The target
            resolution if downsampling is enabled.
    """

    def __init__(self,
                 savefile: str,
                 save_to: Union[str, delve.writers.AbstractWriter],
                 modules: Module,
                 layer_filter: Callable[[Dict[str, Module]],
                                        Dict[str, Module]] = lambda x: x,
                 writer_args: Optional[Dict[str, Any]] = None,
                 log_interval=1,
                 max_samples=None,
                 stats: list = ['lsat'],
                 layerwise_sat: bool = True,
                 reset_covariance: bool = True,
                 average_sat: bool = False,
                 ignore_layer_names: List[str] = [],
                 include_conv: bool = True,
                 conv_method: str = 'channelwise',
                 timeseries_method: str = 'last_timestep',
                 sat_threshold: float = .99,
                 nosave=False,
                 verbose: bool = False,
                 device='cuda:0',
                 initial_epoch: int = 0,
                 interpolation_strategy: Optional[str] = None,
                 interpolation_downsampling: int = 32):
        self.nosave = nosave
        self.verbose = verbose
        # self.disable_compute: bool = False
        self.include_conv = include_conv
        self.conv_method = conv_method
        self.timeseries_method = timeseries_method
        self.threshold = sat_threshold
        self.layers = layer_filter(self.get_layers_recursive(modules))
        self.max_samples = max_samples
        self.log_interval = log_interval
        self.reset_covariance = reset_covariance
        self.initial_epoch = initial_epoch
        self.interpolation_strategy = interpolation_strategy
        self.interpolation_downsampling = interpolation_downsampling

        writer_args = writer_args or {}
        writer_args['savepath'] = savefile
        os.makedirs(savefile, exist_ok=True)

        self.writer = self._get_writer(save_to, writer_args)
        self.interval = log_interval

        self._warn_if_covariance_not_saveable(stats)
        self.logs, self.stats = self._check_stats(stats)
        self.layerwise_sat = layerwise_sat
        self.average_sat = average_sat
        self.ignore_layer_names = ignore_layer_names
        self.seen_samples = {'train': {}, 'eval': {}}
        self.global_steps = 0
        self.global_hooks_registered = False
        self.is_notebook = None
        self.device = device
        self.record = True
        for name, layer in self.layers.items():
            if isinstance(layer, Conv2d) or isinstance(layer, Linear) \
                    or isinstance(layer, LSTM):
                self._register_hooks(layer=layer,
                                     layer_name=name,
                                     interval=log_interval)
        if self.initial_epoch != 0:
            self.writer.resume_from_saved_state(self.initial_epoch)

    def _warn_if_covariance_not_saveable(self, stats: List[str]):
        warn = False
        if 'cov' in stats:
            if isinstance(self.writer, CompositWriter):
                for writer in self.writer.writers:
                    if isinstance(writer, NPYWriter):
                        return
                warn = True
            elif not isinstance(self.writer, NPYWriter):
                warn = True
        if warn:
            warnings.warn("'cov' was selected as stat, but 'npy' (NPYWriter) "
                          "is not used as a save strategy, which is the only "
                          "writer able to save the covariance matrix. The "
                          "training and logging will run normally, but the "
                          "covariance matrix will not be saved. Note that you "
                          "can add multiple writers by passing a list.")

    def __getattr__(self, name):
        if name.startswith('add_') and name != 'add_saturations':
            if not self.nosave:
                return getattr(self.writer, name)
            else:

                def noop(*args, **kwargs):
                    log.info(
                        f'Logging disabled, not logging: {args}, {kwargs}')

                return noop
        else:
            try:
                # Redirect to writer object
                return self.writer.__getattribute__(name)
            except Exception:
                # Default behaviour
                return self.__getattribute__(name)

    def __repr__(self):
        return self.layers.keys().__repr__()

    def is_recording(self) -> bool:
        return self.record

    def stop(self):
        self.record = False

    def resume(self):
        self.record = True

    def close(self):
        """User endpoint to close writer and progress bars."""
        return self.writer.close()
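
    # Note (illustrative, not part of the class): recording can be paused so
    # that forward passes which should not contribute to the covariance
    # estimate (e.g. ad-hoc sanity checks) are ignored by the hooks:
    #
    #     tracker.stop()      # hooks stop updating the covariance matrices
    #     run_sanity_check()  # hypothetical helper performing forward passes
    #     tracker.resume()    # continue accumulating covariance updates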

    def _format_saturation(self, saturation_status):
        raise NotImplementedError

    def _check_stats(self, stats: list):
        if not isinstance(stats, list):
            stats = list(stats)
        supported_stats = [
            'lsat',
            'idim',
            'cov',
            'det',
            'trc',
            'dtrc',
            'embed',
        ]
        compatible = [
            stat in supported_stats
            if "_" not in stat else stat.split("_")[0] in stats
            for stat in stats
        ]
        incompatible = [i for i, x in enumerate(compatible) if not x]
        assert all(compatible), "Stat {} is not supported".format(
            stats[incompatible[0]])

        name_mapper = STATMAP
        logs = {
            f'{mode}-{name_mapper[stat]}': OrderedDict()
            for mode, stat in product(['train', 'eval'], ['cov'])
        }
        return logs, stats

    def _add_conv_layer(self, layer: torch.nn.Module):
        layer.out_features = layer.out_channels
        layer.conv_method = self.conv_method

    def _add_lstm_layer(self, layer: torch.nn.Module):
        layer.out_features = layer.hidden_size
        layer.timeseries_method = self.timeseries_method

    def get_layer_from_submodule(self,
                                 submodule: torch.nn.Module,
                                 layers: dict,
                                 name_prefix: str = ''):
        if len(submodule._modules) > 0:
            for idx, (name, subsubmodule) in \
                    enumerate(submodule._modules.items()):
                new_prefix = name if name_prefix == '' else name_prefix + \
                    '-' + name
                self.get_layer_from_submodule(subsubmodule, layers,
                                              new_prefix)
            return layers
        else:
            layer_name = name_prefix
            layer_type = layer_name
            if not self._check_is_supported_layer(submodule):
                log.info(f"Skipping {layer_type}")
                return layers
            if isinstance(submodule, Conv2d) and self.include_conv:
                self._add_conv_layer(submodule)
            layers[layer_name] = submodule
            log.info('added layer {}'.format(layer_name))
            return layers

    def _check_is_supported_layer(self, layer: torch.nn.Module) -> bool:
        return isinstance(layer, Conv2d) or isinstance(
            layer, Linear) or isinstance(layer, LSTM)

    def get_layers_recursive(self, modules: Union[list, torch.nn.Module]):
        layers = {}
        if not isinstance(modules, list) and not hasattr(
                modules, 'out_features'):
            # submodules = modules._modules  # OrderedDict
            layers = self.get_layer_from_submodule(modules, layers, '')
        elif self._check_is_supported_layer(modules):
            for module in modules:
                layers = self.get_layer_from_submodule(
                    module, layers, type(module))
        else:
            for i, module in enumerate(modules):
                layers = self.get_layer_from_submodule(
                    module, layers,
                    '' if not self._check_is_supported_layer(module) else
                    f'Module-{i}-{type(module).__name__}')
        return layers

    def _get_writer(self, save_to, writers_args) -> \
            delve.writers.AbstractWriter:
        """Create a writer to log history to `writer_dir`."""
        if issubclass(type(save_to), delve.writers.AbstractWriter):
            return save_to
        if isinstance(save_to, list):
            all_writers = []
            for saver in save_to:
                all_writers.append(
                    self._get_writer(save_to=saver,
                                     writers_args=writers_args))
            return CompositWriter(all_writers)
        if save_to in WRITERS:
            writer = WRITERS[save_to](**writers_args)
        else:
            raise ValueError(
                'Illegal argument for save_to "{}"'.format(save_to))
        return writer

    def _register_hooks(self, layer: torch.nn.Module, layer_name: str,
                        interval):
        layer.eval_layer_history = getattr(layer, 'eval_layer_history',
                                           list())
        layer.train_layer_history = getattr(layer, 'train_layer_history',
                                            list())
        layer.layer_svd = getattr(layer, 'layer_svd', None)
        layer.forward_iter = getattr(layer, 'forward_iter', 0)
        layer.interval = getattr(layer, 'interval', interval)
        layer.writer = getattr(layer, 'writer', self.writer)
        layer.name = getattr(layer, 'name', layer_name)
        self.register_forward_hooks(layer, self.stats)
        return self

    def _record_stat(self, activations_batch: torch.Tensor, lstm_ae: bool,
                     layer: torch.nn.Module, training_state: str, stat: str):
        if activations_batch.dim() == 4:  # conv layer (B x C x H x W)
            if self.interpolation_strategy is not None and (
                    activations_batch.shape[3] >
                    self.interpolation_downsampling
                    or activations_batch.shape[2] >
                    self.interpolation_downsampling):
                activations_batch = interpolate(
                    activations_batch,
                    size=self.interpolation_downsampling,
                    mode=self.interpolation_strategy)
            if self.conv_method == 'median':
                shape = activations_batch.shape
                reshaped_batch = activations_batch.reshape(
                    shape[0], shape[1], shape[2] * shape[3])
                activations_batch, _ = torch.median(reshaped_batch,
                                                    dim=2)  # channel median
            elif self.conv_method == 'max':
                shape = activations_batch.shape
                reshaped_batch = activations_batch.reshape(
                    shape[0], shape[1], shape[2] * shape[3])
                activations_batch, _ = torch.max(reshaped_batch,
                                                 dim=2)  # channel max
            elif self.conv_method == 'mean':
                activations_batch = torch.mean(activations_batch, dim=(2, 3))
            elif self.conv_method == 'flatten':
                activations_batch = activations_batch.view(
                    activations_batch.size(0), -1)
            elif self.conv_method == 'channelwise':
                reshaped_batch: torch.Tensor = activations_batch.permute(
                    [1, 0, 2, 3])
                shape = reshaped_batch.shape
                reshaped_batch: torch.Tensor = reshaped_batch.flatten(1)
                reshaped_batch: torch.Tensor = reshaped_batch.permute([1, 0])
                activations_batch = reshaped_batch
        elif activations_batch.dim() == 3:  # LSTM layer (B x T x U)
            if self.timeseries_method == 'timestepwise':
                activations_batch = activations_batch.flatten(1)
            elif self.timeseries_method == 'last_timestep':
                activations_batch = activations_batch[:, -1, :]

        if layer.name not in self.logs[f'{training_state}-{stat}'] or (
                not isinstance(self.logs[f'{training_state}-{stat}'],
                               TorchCovarianceMatrix) and self.record):
            save_data = 'embed' in self.stats
            self.logs[f'{training_state}-{stat}'][
                layer.name] = TorchCovarianceMatrix(device=self.device,
                                                    save_data=save_data)

        self.logs[f'{training_state}-{stat}'][layer.name].update(
            activations_batch, lstm_ae)
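
    # Shape illustration for conv_method='channelwise' (the numbers are just
    # an example): an activation tensor of shape
    # (batch=8, channels=64, height=7, width=7) is rearranged into a data
    # matrix of shape (8*7*7, 64) = (392, 64), so every spatial position
    # contributes one 64-dimensional sample to the covariance estimate.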

    def register_forward_hooks(self, layer: torch.nn.Module, stats: list):
        """Register hook to show `stats` in `layer`."""

        def record_layer_saturation(layer: torch.nn.Module, input, output):
            """Hook to register in `layer` module."""
            if not self.record:
                if layer.name not in self.logs[
                        f'{"train" if layer.training else "eval"}-{"covariance-matrix"}']:
                    # save_data = 'embed' in self.stats
                    self.logs[
                        f'{"train" if layer.training else "eval"}-{"covariance-matrix"}'][
                            layer.name] = np.nan
                return

            # Increment step counter
            layer.forward_iter += 1

            # VAE output is a tuple; hence output.data would throw an exception
            lstm_ae = False
            if layer.name in [
                    'encoder_lstm', 'encoder_output', 'decoder_lstm',
                    'decoder_output'
            ]:
                output = output[1][0]
                lstm_ae = True
            elif isinstance(layer, torch.nn.LSTM):
                output = output[0]

            training_state = 'train' if layer.training else 'eval'
            if layer.name not in self.seen_samples[training_state]:
                self.seen_samples[training_state][layer.name] = 0
            if (self.max_samples is None
                    or self.seen_samples[training_state][layer.name] <
                    self.max_samples
                ) and layer.forward_iter % self.log_interval == 0:
                num_samples = min(
                    output.data.shape[0], self.max_samples -
                    self.seen_samples[training_state][layer.name]
                ) if self.max_samples is not None else output.data.shape[0]
                activations_batch = output.data[:num_samples]
                self.seen_samples[training_state][layer.name] += num_samples
                self._record_stat(activations_batch, lstm_ae, layer,
                                  training_state, 'covariance-matrix')

        layer.register_forward_hook(record_layer_saturation)

    def add_saturations(self, save=True):
        """
        Computes saturation and saves all stats
        :return:
        """
        for key in self.logs:
            train_sats = []
            val_sats = []
            for i, layer_name in enumerate(self.logs[key]):
                if layer_name in self.ignore_layer_names:
                    continue
                if self.record and self.logs[key][layer_name]._cov_mtx is None:
                    raise ValueError("Attempting to compute intrinsic "
                                     "dimensionality when covariance "
                                     "is not initialized")
                if self.record:
                    cov_mat = self.logs[key][layer_name].fix()
                log_values = {}
                sample_log_values = {}
                for stat in self.stats:
                    if stat == 'lsat':
                        log_values[key.replace(STATMAP['cov'],
                                               STATMAP['lsat']) + '_' +
                                   layer_name] = compute_saturation(
                                       cov_mat, thresh=self.threshold
                                   ) if self.record else np.nan
                    elif stat == 'idim':
                        log_values[
                            key.replace(STATMAP['cov'], STATMAP['idim']) +
                            '_' +
                            layer_name] = compute_intrinsic_dimensionality(
                                cov_mat, thresh=self.threshold
                            ) if self.record else np.nan
                    elif stat == 'cov':
                        log_values[key + '_' +
                                   layer_name] = cov_mat.cpu().numpy()
                    elif stat == 'det':
                        log_values[key.replace(STATMAP['cov'],
                                               STATMAP['det']) + '_' +
                                   layer_name] = compute_cov_determinant(
                                       cov_mat) if self.record else np.nan
                    elif stat == 'trc':
                        log_values[key.replace(STATMAP['cov'],
                                               STATMAP['trc']) + '_' +
                                   layer_name] = compute_cov_trace(cov_mat)
                    elif stat == 'dtrc':
                        log_values[key.replace(STATMAP['cov'],
                                               STATMAP['dtrc']) + '_' +
                                   layer_name] = compute_diag_trace(cov_mat)
                    elif stat == 'embed':
                        transformation_matrix = torch.mm(
                            cov_mat[0:2].transpose(0, 1), cov_mat[0:2])
                        saved_samples = self.logs[key][
                            layer_name].saved_samples
                        sample_log_values['embed'] = list()
                        for (index, sample) in enumerate(saved_samples):
                            coord = torch.matmul(transformation_matrix,
                                                 sample)
                            sample_log_values['embed'].append(
                                (coord[0], coord[1]))
                self.seen_samples[key.split('-')[0]][layer_name] = 0
                if self.reset_covariance and self.record:
                    self.logs[key][layer_name]._cov_mtx = None
                if self.layerwise_sat:
                    self.writer.add_scalars(
                        prefix='',
                        value_dict=log_values,
                        sample_value_dict=sample_log_values)

        if self.average_sat:
            self.writer.add_scalar('average-train-sat', np.mean(train_sats))
            self.writer.add_scalar('average-eval-sat', np.mean(val_sats))

        if save:
            self.save()
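
    # Typical cadence (illustrative): call add_saturations() once per epoch,
    # after that epoch's forward passes. With reset_covariance=True (the
    # default) the per-layer covariance estimates are cleared afterwards, so
    # every epoch yields an independent saturation / intrinsic-dimensionality
    # estimate.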

    def save(self):
        self.writer.save()


class CheckLayerSat(SaturationTracker):

    def __init_subclass__(self):
        warnings.warn("CheckLayerSat has been renamed to SaturationTracker",
                      DeprecationWarning, 2)
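

# Minimal end-to-end sketch (assumptions: a toy two-layer model, random data,
# the 'csv' writer and a local 'toy_stats' output directory; none of this is
# part of the library itself). It shows the intended call pattern: construct
# the tracker around a model, train as usual, and call add_saturations() once
# per epoch so saturation and intrinsic dimensionality are written out.
if __name__ == "__main__":
    import torch.nn as nn

    toy_model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
    tracker = SaturationTracker(savefile='toy_stats',
                                save_to='csv',
                                modules=toy_model,
                                stats=['lsat', 'idim'],
                                device='cpu')

    optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
    criterion = nn.MSELoss()
    for epoch in range(2):
        inputs, targets = torch.randn(64, 10), torch.randn(64, 2)
        optimizer.zero_grad()
        loss = criterion(toy_model(inputs), targets)
        loss.backward()
        optimizer.step()
        # compute saturation / intrinsic dimensionality and log one row
        tracker.add_saturations()
    tracker.close()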