delve.TorchCovarianceMatrix

class delve.TorchCovarianceMatrix(bias: bool = False, device: str = 'cuda:0', save_data: bool = False)

Computes the covariance matrix of layer features, as described in https://arxiv.org/pdf/2006.08679.pdf:

\begin{eqnarray} Q(Z_l, Z_l) = \frac{\sum^{B}_{b=1}A_{l,b}^T A_{l,b}}{n} - (\bar{A}_l \otimes \bar{A}_l) \end{eqnarray}

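where \(A_{l,b}\) is the output matrix of layer \(l\) for batch \(b\) (one row per sample), \(B\) is the number of batches, \(n\) is the total number of samples, and \(\bar{A}_l\) is the mean of the layer's outputs over all \(n\) samples.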
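A minimal, dependency-free sketch of the batched accumulation in the formula above follows. `StreamingCovariance` is a hypothetical helper written for illustration, not delve's implementation: it keeps running sums of \(A_{l,b}^T A_{l,b}\) and of the sample rows, so batches can be processed one at a time, and divides by \(n\) as the formula does. It does not model the `bias`, `device`, or `save_data` arguments, and it uses plain Python floats (64-bit) instead of torch tensors.

```python
from typing import List, Sequence

Matrix = List[List[float]]


class StreamingCovariance:
    """Hypothetical helper (not part of delve).

    Accumulates Q = (sum_b A_b^T A_b) / n - (mean outer mean) over batches.
    """

    def __init__(self, num_features: int) -> None:
        self.d = num_features
        self.n = 0  # total number of samples seen so far
        self.sum_x = [0.0] * num_features  # running sum of sample rows
        # running sum of A_b^T A_b, i.e. of per-sample outer products
        self.sum_xx = [[0.0] * num_features for _ in range(num_features)]

    def update(self, batch: Sequence[Sequence[float]]) -> None:
        # batch plays the role of A_{l,b}: one row per sample, one column per feature
        for row in batch:
            if len(row) != self.d:
                raise ValueError("feature dimension mismatch")
            self.n += 1
            for i in range(self.d):
                self.sum_x[i] += row[i]
                for j in range(self.d):
                    self.sum_xx[i][j] += row[i] * row[j]

    def covariance(self) -> Matrix:
        if self.n == 0:
            raise ValueError("no samples accumulated")
        mean = [s / self.n for s in self.sum_x]
        # second moment divided by n, minus the outer product of the mean
        return [
            [self.sum_xx[i][j] / self.n - mean[i] * mean[j] for j in range(self.d)]
            for i in range(self.d)
        ]


# Usage: two batches (B = 2) with n = 3 samples in total
acc = StreamingCovariance(2)
acc.update([[1.0, 2.0], [3.0, 4.0]])
acc.update([[5.0, 0.0]])
cov = acc.covariance()  # approximately [[8/3, -4/3], [-4/3, 8/3]]
```

Accumulating sums instead of storing activations keeps memory at \(O(d^2)\) for \(d\) features, independent of \(n\). The trade-off is that subtracting the outer product of the mean from the second moment can lose precision when feature means are large relative to their spread.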

Note

The method enforces float64 precision; results may nonetheless be numerically unstable in some cases.