"""
This file contains alternative file writers
"""
import os
import pathlib
import pickle as pkl
import warnings
from abc import ABC, abstractmethod
from shutil import make_archive
from typing import Callable, Dict, List, Tuple
import matplotlib
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from delve.logger import log
try:
from tensorboardX import SummaryWriter
except ModuleNotFoundError:
pass
# Short stat identifiers (as used in names such as 'lsat_train') mapped to the
# long statistic names used for axis labels and for selecting a statistic.
STATMAP = dict(
    idim='intrinsic-dimensionality',
    lsat='saturation',
    cov='covariance-matrix',
    det='covariance-determinant',
    trc='covariance-trace',
    dtrc='diagonal-trace',
    embed='embedded-sample',
)
class AbstractWriter(ABC):
    """Interface shared by all writers defined in this module."""

    def _check_savestate_ok(self, savepath: str) -> bool:
        """
        Report whether a savestate exists on disk; emit a warning if it does not.

        :param savepath: the path to the savestate
        :return: True when ``savepath`` exists, False otherwise
        """
        savestate_found = os.path.exists(savepath)
        if not savestate_found:
            warnings.warn(
                f'{savepath} does not exists, savestate for {self.__class__.__name__} cannot be loaded'
            )
        return savestate_found

    @abstractmethod
    def resume_from_saved_state(self, initial_epoch: int):
        """Restore the writer's state so that writing can continue from ``initial_epoch``."""
        raise NotImplementedError()

    @abstractmethod
    def add_scalar(self, name, value, **kwargs):
        """Record a single ``value`` under ``name``."""
        raise NotImplementedError()

    @abstractmethod
    def add_scalars(self, prefix, value_dict, global_step, **kwargs):
        """Record every entry of ``value_dict``."""
        raise NotImplementedError()

    @abstractmethod
    def save(self):
        """Persist what has been recorded so far."""

    @abstractmethod
    def close(self):
        """Release any resources held by the writer."""
class CompositWriter(AbstractWriter):
    """Writer that forwards every call to each writer of a given list."""

    def __init__(self, writers: List[AbstractWriter]):
        """
        This writer combines multiple writers.

        :param writers: List of writers. Each writer is called when the CompositWriter is invoked.
        """
        super(CompositWriter, self).__init__()
        self.writers = writers

    def resume_from_saved_state(self, initial_epoch: int):
        """
        Ask every wrapped writer to resume from ``initial_epoch``.

        A writer that raises ``NotImplementedError`` is skipped with a warning
        so that the remaining writers can still resume.

        :param initial_epoch: the epoch to resume from
        """
        for w in self.writers:
            try:
                w.resume_from_saved_state(initial_epoch)
            except NotImplementedError:
                # The two literals are concatenated, so the separator between
                # the sentences has to be part of the first one.
                warnings.warn(
                    f'Writer {w.__class__.__name__} raised a NotImplementedError when attempting to resume training. '
                    'This may result in corrupted or overwritten data.')

    def add_scalar(self, name, value, **kwargs):
        """Forward ``add_scalar`` to every wrapped writer."""
        for w in self.writers:
            w.add_scalar(name, value, **kwargs)

    def add_scalars(self, prefix, value_dict, **kwargs):
        """Forward ``add_scalars`` to every wrapped writer."""
        for w in self.writers:
            w.add_scalars(prefix, value_dict, **kwargs)

    def save(self):
        """Call ``save`` on every wrapped writer."""
        for w in self.writers:
            w.save()

    def close(self):
        """Call ``close`` on every wrapped writer."""
        for w in self.writers:
            w.close()
class CSVWriter(AbstractWriter):
    """
    This writer produces a csv file with all recorded scalar values.
    The csv-file is overwritten with
    an updated version every time save() is called.

    :param savepath: CSV file path (without the '.csv' extension, which is appended)
    """

    def __init__(self, savepath: str, **kwargs):
        super(CSVWriter, self).__init__()
        # name -> list of recorded values, one entry per add_scalar call
        self.value_dict = {}
        self.savepath = savepath
        # Initialised here so the attribute exists even if resume is never called.
        self.epoch_counter: int = 0

    def resume_from_saved_state(self, initial_epoch: int):
        """
        Reload previously saved values from ``<savepath>.csv`` if that file exists.

        :param initial_epoch: the epoch to resume from; stored as ``epoch_counter``
        """
        self.epoch_counter = initial_epoch
        csv_path = self.savepath + '.csv'
        if self._check_savestate_ok(csv_path):
            self.value_dict = pd.read_csv(csv_path, sep=';',
                                          index_col=0).to_dict('list')

    def add_scalar(self, name, value, **kwargs):
        """
        Append ``value`` to the series stored under ``name``.

        Names containing 'covariance-matrix' are ignored.
        """
        if 'covariance-matrix' in name:
            return
        self.value_dict.setdefault(name, []).append(value)

    def add_scalars(self, prefix, value_dict, **kwargs):
        """
        Record every entry of ``value_dict`` via :meth:`add_scalar`.

        ``prefix`` is accepted for interface compatibility but is not used.
        """
        for name, value in value_dict.items():
            self.add_scalar(name, value)

    def save(self):
        """Write all recorded values to ``<savepath>.csv`` using ';' as separator."""
        pd.DataFrame.from_dict(self.value_dict).to_csv(self.savepath + '.csv',
                                                       sep=';')

    def close(self):
        """Nothing to release; defined because ``close`` is abstract in the base class."""
class NPYWriter(AbstractWriter):
    """Writer that stores every recorded value as an individual npy-file."""

    def __init__(self, savepath: str, zip: bool = False, **kwargs):
        """
        The NPYWriter creates a folder containing one subfolder for each stat.
        Each subfolder contains a npy-file with the saturation value for each epoch.
        This writer saves non-scalar values and can thus be used to save
        the covariance-matrix.

        :param savepath: The root folder to save the folder structure to
        :param zip: Whether to zip the output folder after every invocation
        """
        super(NPYWriter, self).__init__()
        self.savepath = savepath
        # stat name -> index of the most recently written file for that stat
        self.epoch_counter = {}
        self.zip = zip

    def _counter_path(self) -> str:
        """Return the location of the pickled epoch counter."""
        return os.path.join(self.savepath, 'epoch_counter.pkl')

    def resume_from_saved_state(self, initial_epoch: int):
        """
        Restore the per-stat epoch counters from ``epoch_counter.pkl`` if present.

        :param initial_epoch: unused; the counters are read from disk
        """
        counter_path = self._counter_path()
        if self._check_savestate_ok(counter_path):
            # NOTE(security): pickle must only be loaded from trusted files.
            with open(counter_path, 'rb') as fp:
                self.epoch_counter = pkl.load(fp)

    def _update_epoch_counter(self, name: str) -> int:
        """Advance and return the counter of ``name`` (starting at 0)."""
        if name not in self.epoch_counter:
            self.epoch_counter[name] = 0
        else:
            self.epoch_counter[name] += 1
        return self.epoch_counter[name]

    def _get_and_create_savepath(self, name: str, epoch) -> str:
        """Create the subfolder of ``name`` if needed and return the file stem for ``epoch``."""
        savepath = os.path.join(self.savepath, name)
        os.makedirs(savepath, exist_ok=True)
        return os.path.join(savepath, name + f'_epoch{epoch}')

    def add_scalar(self, name, value, **kwargs):
        """
        Save ``value`` with ``np.save`` into the subfolder of ``name``.

        :param kwargs: forwarded to ``np.save``
        """
        epoch = self._update_epoch_counter(name)
        svpth = self._get_and_create_savepath(name, epoch)
        np.save(svpth, value, **kwargs)

    def add_scalars(self, prefix, value_dict, **kwargs):
        """Save every entry of ``value_dict`` under the name ``<prefix>_<key>``."""
        for key, value in value_dict.items():
            self.add_scalar(prefix + '_' + key, value)

    def save(self):
        """Persist the epoch counters and, if requested, create a zip archive."""
        # The root folder only exists once a value was written; create it so
        # that saving before the first add_scalar does not fail.
        os.makedirs(self.savepath, exist_ok=True)
        with open(self._counter_path(), 'wb') as fp:
            pkl.dump(self.epoch_counter, fp)
        if self.zip:
            # NOTE(review): this archives the parent directory of savepath and
            # writes the archive relative to the working directory - confirm
            # that this is the intended scope.
            make_archive(base_name=os.path.basename(self.savepath),
                         format='zip',
                         root_dir=os.path.dirname(self.savepath),
                         verbose=True)

    def close(self):
        """Nothing to release; defined because ``close`` is abstract in the base class."""
class PrintWriter(AbstractWriter):
    """Writer that prints every recorded value to the console."""

    def __init__(self, **kwargs):
        """
        Prints output to the console
        """
        super(PrintWriter, self).__init__()

    def resume_from_saved_state(self, initial_epoch: int):
        """No state is kept, so there is nothing to resume."""

    def add_scalar(self, name, value, **kwargs):
        """Print ``name`` and ``value`` separated by a colon."""
        print(name, ':', value)

    def add_scalars(self, prefix, value_dict, **kwargs):
        """Print every entry of ``value_dict`` under the name ``<prefix>_<key>``."""
        for key, value in value_dict.items():
            self.add_scalar(prefix + '_' + key, value)

    def save(self):
        """Nothing to persist; defined because ``save`` is abstract in the base class."""

    def close(self):
        """Nothing to release; defined because ``close`` is abstract in the base class."""
class TensorBoardWriter(AbstractWriter):
    """
    Writes output to TensorBoard logs via ``tensorboardX.SummaryWriter``

    :param savepath: the path for result logging
    """

    def __init__(self, savepath: str, **kwargs):
        super(TensorBoardWriter, self).__init__()
        self.savepath = savepath
        # SummaryWriter is only defined when the optional tensorboardX import
        # at the top of the module succeeded.
        self.writer = SummaryWriter(savepath)

    def resume_from_saved_state(self, initial_epoch: int):
        """
        Resuming is not supported.

        :raises NotImplementedError: always
        """
        raise NotImplementedError(
            'Resuming is not yet implemented for TensorBoardWriter')

    def add_scalar(self, name, value, **kwargs):
        """
        Log ``value`` under ``name``.

        Names containing 'covariance-matrix' are ignored.
        """
        if 'covariance-matrix' in name:
            return
        self.writer.add_scalar(name, value)

    def add_scalars(self, prefix, value_dict, **kwargs):
        """Log all entries of ``value_dict`` grouped under ``prefix``."""
        self.writer.add_scalars(prefix, value_dict)

    def save(self):
        """No explicit save step; defined because ``save`` is abstract in the base class."""

    def close(self):
        """Close the underlying summary writer."""
        self.writer.close()
class CSVandPlottingWriter(CSVWriter):
    """
    This writer produces CSV files and plots.

    :param savepath: Path to store plots and CSV files
    :param plot_manipulation_func: A function mapping an axis object to an axis object by
                                using pyplot code.
    :param kwargs: optional ``primary_metric``, ``fontsize`` (default 16) and
                   ``figsize`` entries that are passed on to the plotting functions
    """

    def __init__(self,
                 savepath: str,
                 plot_manipulation_func: Callable[[plt.Axes], plt.Axes] = None,
                 **kwargs):
        super(CSVandPlottingWriter, self).__init__(savepath)
        self.plot_man_func = plot_manipulation_func if plot_manipulation_func is not None else lambda x: x
        self.primary_metric = kwargs.get('primary_metric')
        self.fontsize = kwargs.get('fontsize', 16)
        self.figsize = kwargs.get('figsize')
        self.epoch_counter: int = 0
        # stats (e.g. 'lsat_train') that get a per-layer plot on save()
        self.stats = []
        # stats that get a scatter plot of samples on save()
        self.sample_stats = list()
        self.sample_value_dict = dict()

    def resume_from_saved_state(self, initial_epoch: int):
        """
        Reload previously saved values from ``<savepath>.csv`` if that file exists.

        :param initial_epoch: the epoch to resume from; stored as ``epoch_counter``
        """
        self.epoch_counter = initial_epoch
        if not self._check_savestate_ok(self.savepath + '.csv'):
            return
        self.value_dict = pd.read_csv(self.savepath + '.csv',
                                      sep=';',
                                      index_col=0).to_dict('list')

    def _look_for_stats(self):
        """
        Determine which stats to plot from the names recorded so far.

        Does nothing once ``self.stats`` has been populated.
        """
        if self.stats:
            return
        # (substring searched in the recorded names, short stat identifier)
        known_stats = (
            ('saturation', 'lsat'),
            ('intrinsic-dimensionality', 'idim'),
            ('covariance-determinant', 'det'),
            ('covariance-trace', 'trc'),
            ('diagonal-trace', 'dtrc'),
        )
        for long_name, short_name in known_stats:
            if any(long_name in key for key in self.value_dict):
                self.stats.append(f'{short_name}_train')
                self.stats.append(f'{short_name}_eval')
        # The sample plot is always enabled. Register it only once: this method
        # runs again on every save() for as long as no stat has been found.
        if 'embed' not in self.sample_stats:
            self.sample_stats.append('embed')

    def add_scalar(self, name, value, **kwargs):
        """
        Append ``value`` to the series stored under ``name``.

        Names containing 'covariance-matrix' are ignored.
        """
        if 'covariance-matrix' in name:
            return
        self.value_dict.setdefault(name, []).append(value)

    def add_sample_scalar(self, name, value, **kwargs):
        """Append ``value`` to the sample series stored under ``name``."""
        self.sample_value_dict.setdefault(name, []).append(value)

    def add_scalars(self, prefix, value_dict, sample_value_dict=None, **kwargs):
        """
        Record every entry of ``value_dict`` and, if given, of ``sample_value_dict``.

        :param prefix: accepted for interface compatibility; not used
        :param value_dict: name -> value, recorded via :meth:`add_scalar`
        :param sample_value_dict: optional name -> value, recorded via
                                  :meth:`add_sample_scalar`
        """
        for name, value in value_dict.items():
            self.add_scalar(name, value)
        if sample_value_dict:
            for name, value in sample_value_dict.items():
                self.add_sample_scalar(name, value)

    def _find_longest_entry(self) -> int:
        """Return the length of the longest recorded series (0 if there is none)."""
        return max((len(value) for value in self.value_dict.values()),
                   default=0)

    def _pad_entry(self, entry, max):
        """Left-pad ``entry`` with NaN so that it has ``max`` elements."""
        if len(entry) == max:
            return entry
        return [np.nan] * (max - len(entry)) + list(entry)

    def _pad_stat(self):
        """Return a copy of ``value_dict`` in which all series have the same length."""
        max_entry = self._find_longest_entry()
        return {
            k: self._pad_entry(v, max_entry)
            for k, v in self.value_dict.items()
        }

    def save(self):
        """Write the CSV file and create a plot for every detected stat."""
        self._look_for_stats()
        # Series that were first recorded in a later epoch are shorter than the
        # rest; pad them so that the DataFrame can be built.
        pd.DataFrame.from_dict(self._pad_stat()).to_csv(self.savepath + '.csv',
                                                        sep=';')
        for stat in self.stats:
            plot_stat_level_from_results(self.savepath + '.csv',
                                         stat=stat,
                                         epoch=-1,
                                         primary_metric=self.primary_metric,
                                         fontsize=self.fontsize,
                                         figsize=self.figsize)
        for stat in self.sample_stats:
            plot_scatter_from_results(
                self.savepath + '.csv', -1, stat,
                pd.DataFrame.from_dict(self.sample_value_dict))
        self.epoch_counter += 1

    def close(self):
        """Nothing to release; defined because ``close`` is abstract in the base class."""
# Registry of writer classes by name; several names refer to the same class.
WRITERS: Dict[str, AbstractWriter] = {
    "csv": CSVWriter,
    "npy": NPYWriter,
    "print": PrintWriter,
    # TensorBoard
    "tb": TensorBoardWriter,
    "tensorboard": TensorBoardWriter,
    # CSV plus plots
    "csvplotting": CSVandPlottingWriter,
    "csvplot": CSVandPlottingWriter,
    "plot": CSVandPlottingWriter,
    "plotcsv": CSVandPlottingWriter,
}
def plot_stat(df,
              stat,
              pm=-1,
              savepath='run.png',
              epoch=0,
              primary_metric=None,
              fontsize=16,
              figsize=None,
              line=True,
              scatter=True,
              ylim=(0, 1.0),
              alpha_line=.6,
              alpha_scatter=1.0,
              color_line=None,
              color_scatter=None,
              primary_metric_loc=(0.7, 0.8),
              show_col_label_x=True,
              show_col_label_y=True,
              show_grid=True,
              save=True,
              samples=False,
              stat_mode="train") -> matplotlib.axes.Axes:
    """Plot the first row of ``df`` and optionally save the figure.

    :param df: data to plot; only the first row is used. One column per x-position,
               or (with ``samples=True``) a first cell holding (x, y) pairs
    :param stat: stat identifier; used for the y-label (translated via STATMAP
                 when possible) and for the file name
    :param pm: value of the primary metric, shown as text in the plot
    :param savepath: path the title and the output file name are derived from
    :param epoch: epoch shown in title and file name; -1 selects the last index value of ``df``
    :param primary_metric: name of the primary metric; no text is drawn when None
    :param fontsize: font size of labels, title and y-ticks
    :param figsize: figure size passed to ``plt.figure``; pyplot default when None
    :param line: draw a line plot (ignored when ``samples`` is True)
    :param scatter: draw a scatter plot
    :param ylim: y-axis limits; left untouched when None
    :param alpha_line: alpha of the line plot
    :param alpha_scatter: alpha of the scatter plot
    :param color_line: color of the line plot
    :param color_scatter: color of the scatter plot
    :param primary_metric_loc: (x, y) position of the primary metric text
    :param show_col_label_x: show x-label and title
    :param show_col_label_y: show y-label
    :param show_grid: draw a grid
    :param save: save the figure as png next to ``savepath``
    :param samples: treat the first cell of ``df`` as a collection of (x, y) samples
    :param stat_mode: mode string (e.g. "train") that becomes part of the file name
    :return: the axes that were drawn on
    """
    # Discard whatever figure pyplot currently holds.
    plt.clf()
    plt.cla()
    plt.close()

    if figsize is not None:
        plt.figure(figsize=figsize)
    ax = plt.gca()

    if len(df) == 0:
        # Without a row there is nothing to plot (and no last epoch to look up).
        warnings.warn("Received an empty DataFrame, nothing was plotted")
        return ax

    if epoch == -1:
        epoch = df.index.values[-1]

    col_names = list(df.columns)
    first_row = df.values[0]
    x_positions = list(range(len(col_names)))

    try:
        holds_samples = len(first_row) == 0 or (
            isinstance(first_row[0], list)
            and isinstance(first_row[0][0], tuple))
        # A row consisting of NaN only carries no information worth plotting.
        if not holds_samples and np.all(np.isnan(first_row)):
            return ax
    except TypeError:
        message = (
            "Experienced a TypeError during checking for nan values in, likely caused by non float or int values. "
            "Plotting non np.ndarray may lead to crashes or inconsistent results"
        )
        warnings.warn(message)
        log.exception(message)

    if line and not samples:
        ax.plot(x_positions, first_row, alpha=alpha_line, color=color_line)

    if scatter:
        if samples:
            for sample in first_row[0]:
                x = float(sample[0])
                y = float(sample[1])
                ax.scatter(x, y, alpha=alpha_scatter, color=color_scatter)
        else:
            ax.scatter(x_positions,
                       first_row,
                       alpha=alpha_scatter,
                       color=color_scatter)

    if not samples:
        # Label each tick with the part after the first underscore; fall back
        # to the full name when there is no underscore.
        tick_labels = [
            str(col_name).split('_')[1]
            if '_' in str(col_name) else str(col_name)
            for col_name in col_names
        ]
        plt.xticks(x_positions, tick_labels, rotation=90)

    if ylim is not None:
        ax.set_ylim(ylim)

    if primary_metric is not None:
        ax.text(primary_metric_loc[0], primary_metric_loc[1],
                f'{primary_metric}: {pm}')

    plt.yticks(fontsize=fontsize)

    if show_col_label_x:
        plt.xlabel('layers', fontsize=fontsize)
        plt.title(pathlib.Path(savepath).name.replace('_', ' ').replace(
            '.csv', f' epoch: {epoch}'),
                  fontsize=fontsize)
    if show_col_label_y:
        plt.ylabel(STATMAP.get(stat, stat),
                   rotation='vertical',
                   fontsize=fontsize)
    if show_grid:
        plt.grid()

    plt.tight_layout()
    if save:
        final_savepath = savepath.replace(
            '.csv', f'_{stat}_{stat_mode}_epoch_{epoch}.png')
        log.info(final_savepath)
        plt.savefig(final_savepath)
    return ax
def plot_stat_level_from_results(savepath,
                                 epoch,
                                 stat,
                                 primary_metric=None,
                                 fontsize=16,
                                 figsize=None,
                                 line=True,
                                 scatter=True,
                                 ylim=(0, 1.0),
                                 alpha_line=.6,
                                 alpha_scatter=1.0,
                                 color_line=None,
                                 color_scatter=None,
                                 primary_metric_loc=(0.7, 0.8),
                                 show_col_label_x=True,
                                 show_col_label_y=True,
                                 show_grid=True,
                                 save=True,
                                 stat_mode="train"):
    """Read a results CSV and plot one stat of one epoch with :func:`plot_stat`.

    :param savepath: path of the ';'-separated CSV file to read
    :param epoch: epoch to plot; -1 selects the last index value of the file
    :param stat: short stat identifier (a key of STATMAP), optionally followed by
                 '_<mode>', in which case the mode overrides ``stat_mode``
    :param primary_metric: name of the primary metric, passed on to the
                           extraction and plotting functions
    :param ylim: not used; the limits are (0, 1.0) for 'lsat' and automatic otherwise
    :param stat_mode: mode (e.g. "train") used when ``stat`` carries no suffix
    :return: the axes returned by :func:`plot_stat`

    All remaining parameters are passed to :func:`plot_stat` unchanged.
    """
    results = pd.read_csv(savepath, sep=';')

    if "_" in stat:
        stat, stat_mode = stat.split("_")
    if epoch == -1:
        epoch = results.index.values[-1]

    epoch_df, pm = extract_layer_stat(results,
                                      stat=STATMAP[stat],
                                      epoch=epoch,
                                      primary_metric=primary_metric,
                                      state_mode=stat_mode)

    # Only 'lsat' gets fixed limits; everything else is scaled by matplotlib.
    y_limits = (0, 1.0) if stat == 'lsat' else None

    plot_options = dict(
        df=epoch_df,
        pm=pm,
        savepath=savepath,
        epoch=epoch,
        primary_metric=primary_metric,
        fontsize=fontsize,
        figsize=figsize,
        stat=stat,
        ylim=y_limits,
        line=line,
        scatter=scatter,
        alpha_line=alpha_line,
        alpha_scatter=alpha_scatter,
        color_line=color_line,
        color_scatter=color_scatter,
        primary_metric_loc=primary_metric_loc,
        show_col_label_x=show_col_label_x,
        show_col_label_y=show_col_label_y,
        show_grid=show_grid,
        save=save,
        stat_mode=stat_mode,
    )
    return plot_stat(**plot_options)
def plot_scatter_from_results(savepath, epoch, stat, df):
    """Create a scatter plot of samples with :func:`plot_stat`.

    :param savepath: path the plot title and file name are derived from
    :param epoch: epoch passed on to :func:`plot_stat`
    :param stat: stat identifier; anything after the first '_' is dropped
    :param df: the sample data to plot
    :return: the axes returned by :func:`plot_stat`, or None when ``df`` is empty
    """
    if not len(df) > 0:
        return None

    base_stat = stat.split("_")[0] if "_" in stat else stat
    return plot_stat(df=df,
                     savepath=savepath,
                     epoch=epoch,
                     stat=base_stat,
                     line=False,
                     save=True,
                     samples=True,
                     ylim=None)