from copy import deepcopy
from dataclasses import dataclass
from logging import Logger, getLogger
from pathlib import Path
from typing import Any
import numpy as np
import denspp.offline.dnn.plots as dnn_plot
from denspp.offline import check_keylist_elements_all, get_path_to_project
from denspp.offline.data_format import JsonHandler
from denspp.offline.dnn import DatasetFromFile, DefaultSettingsDataset, SettingsDataset
from denspp.offline.dnn.model_library import DatasetLoaderLibrary
from denspp.offline.dnn.training import (
DataValidation,
DefaultSettingsTrainingCE,
DefaultSettingsTrainingMSE,
SettingsAutoencoder,
SettingsClassifier,
TrainAutoencoder,
TrainClassifier,
)
from denspp.offline.logger import define_logger_runtime
@dataclass
class TrainingResults:
    """Dataclass bundling everything a training routine of a deep learning model returns

    Attributes:
        metrics: Dictionary with metrics from training, keyed by fold; each fold entry is
            itself a dictionary holding "loss_train", "loss_valid" and the custom metrics
        data: Dataclass with results from validation phase
        settings: Dictionary with the used settings, stored under the keys "train",
            "model" and "data"
        path: Path to training results
        metrics_custom: List with string names of custom metrics used during training
    """

    metrics: dict
    data: DataValidation
    settings: dict
    path: Path
    metrics_custom: list[str]
@dataclass
class SettingsTraining:
    """Configuration class for handling the training phase of deep neural networks

    Attributes:
        mode_train: Integer of selected training routine regarding the training handler
            [0: Classifier (CL), 1: Autoencoder (AE), 2: Autoencoder-based Classifier (AE+CL),
            3: Long-Short Term-Memory (LSTM)]
        do_block: Boolean value to block the generated plots after training
        do_ptq: Apply Post Training Quantization (PTQ) scheme during post-training validation
        ptq_total_bitwidth: Integer for total bitwidth in PTQ
        ptq_frac_bitwidth: Integer for fractional bitwidth in PTQ
    """

    mode_train: int
    do_block: bool
    do_ptq: bool
    ptq_total_bitwidth: int
    ptq_frac_bitwidth: int
# Fallback training configuration: classifier routine (mode 0), non-blocking plots,
# no post-training quantization, 8 bit total / 4 bit fractional width for PTQ.
DefaultSettingsTraining = SettingsTraining(
    mode_train=0,
    do_block=False, do_ptq=False,
    ptq_total_bitwidth=8, ptq_frac_bitwidth=4,
)
class PyTorchPlot:
    """Collection of plot routines for results of the PyTorch Training Handler"""

    _logger: Logger

    def __init__(self):
        """Class for handling all suitable plot options for the PyTorch Training Handler"""
        self._logger: Logger = getLogger(__name__)

    @staticmethod
    def _get_fold_overview(data: TrainingResults) -> list:
        """Returning the fold keys stored in the training metrics
        :param data: Dataclass TrainingResults with results from Training
        :return: List with all fold keys of data.metrics (insertion order)
        :raises AttributeError: if the metrics dictionary contains no fold
        """
        fold_overview = list(data.metrics)
        if not fold_overview:
            raise AttributeError("No fold available in dataset")
        return fold_overview

    def loss(
        self,
        data: TrainingResults,
        loss_type: str,
        fold_num: int = 0,
        epoch_zoom=None,
        show_plot: bool = False,
    ) -> None:
        """Plotting the loss values of each epoch during training
        :param data: Dataclass TrainingResults with results from Training
        :param loss_type: String with name of the used loss function
        :param fold_num: Integer with fold number to analyse (position in the fold overview)
        :param epoch_zoom: Optional list with ranges for zooming loss data
        :param show_plot: Boolean value to show the plot
        :return: None
        :raises AttributeError: if no fold is available in the metrics
        """
        fold_overview = self._get_fold_overview(data)
        self._logger.info(f"... plotting metric: {loss_type}")
        fold_metrics = data.metrics[fold_overview[fold_num]]
        dnn_plot.plot_loss(
            loss_train=fold_metrics["loss_train"],
            loss_valid=fold_metrics["loss_valid"],
            loss_type=loss_type,
            path2save=str(data.path),
            epoch_zoom=epoch_zoom,
            do_logy=False,
            show_plot=show_plot,
        )

    def custom_loss(
        self,
        data: TrainingResults,
        fold_num: int = 0,
        epoch_zoom=None,
        show_plot: bool = False,
    ) -> None:
        """Plotting the custom metrics of each epoch during training
        :param data: Dataclass TrainingResults with results from Training
        :param fold_num: Integer with fold number to analyse (position in the fold overview)
        :param epoch_zoom: Optional list with ranges for zooming loss data
        :param show_plot: Boolean value to show the plot
        :return: None
        :raises AttributeError: if no fold is available in the metrics
        """
        fold_overview = self._get_fold_overview(data)
        fold_metrics = data.metrics[fold_overview[fold_num]]
        # The metric names do not change while plotting, so they are collected only once
        available_metric = list(fold_metrics.keys())
        for used_metric in data.metrics_custom:
            # Custom metrics without recorded values in this fold are skipped silently
            if used_metric not in available_metric:
                continue
            # Only the plot of the last recorded metric is allowed to show the figures
            last_ite = used_metric == available_metric[-1]
            self._logger.info(f"... plotting custom metric: {used_metric}")
            # Exact type match on purpose: subclasses are not treated as classifier settings
            if type(data.settings["model"]) is SettingsClassifier or used_metric == "ptq_loss":
                dnn_plot.plot_custom_loss_classifier(
                    data=fold_metrics[used_metric],
                    loss_name=used_metric,
                    fold_num=fold_num,
                    do_logy=False,
                    epoch_zoom=epoch_zoom,
                    path2save=str(data.path),
                    show_plot=show_plot and last_ite,
                )
            else:
                dnn_plot.plot_custom_loss_autoencoder(
                    data=fold_metrics[used_metric],
                    loss_name=used_metric,
                    do_boxplot=False,
                    do_logy=False,
                    epoch_zoom=epoch_zoom,
                    path2save=str(data.path),
                    show_plot=show_plot and last_ite,
                )

    def statistics(self, data: TrainingResults, show_plot: bool = False) -> None:
        """Plotting the statistics of the used dataset for training
        :param data: Dataclass TrainingResults with results from Training
        :param show_plot: Boolean value to show the plot
        :return: None
        """
        self._logger.info("... plotting statistics from dataset")
        dnn_plot.plot_statistic(
            train_cl=data.data.train_label,
            valid_cl=data.data.valid_label,
            path2save=str(data.path),
            cl_dict=data.data.label_names,
            show_plot=show_plot,
        )
[docs]
class PyTorchTrainer:
_logger: Logger
_plotter: PyTorchPlot
_settings_ml: SettingsTraining
_settings_data: SettingsDataset
_settings_model: SettingsClassifier | SettingsAutoencoder
_path2config: Path
_dataloader: Any
__default_model: str
__use_case: str
def __init__(
self,
use_case: str,
settings: SettingsTraining = DefaultSettingsTraining,
default_model: str = "",
path2config: str = "config",
) -> None:
"""Class for handling and wrapping all PyTorch Training Routines incl. Report Generation and Plotting
:param use_case: String with name of use-case
:param settings: Dataclass for defining trainer properties
:param default_model: String with name of default model
:param path2config: Path to folder with configuration files
:return: None
"""
define_logger_runtime(save_file=False)
self._logger: Logger = getLogger(__name__)
self._plotter = PyTorchPlot()
self._path2config = Path(get_path_to_project(path2config))
self._path2config.mkdir(parents=True, exist_ok=True)
self._do_init = self.config_available
self.__default_model = default_model
self.__use_case = use_case
self.__prepare_training()
if settings == DefaultSettingsTraining:
self._settings_ml = self._get_config_ml(
use_case=use_case, default_training_mode=settings.mode_train
)
else:
self._settings_ml = settings
@property
def config_available(self) -> bool:
"""Checking if configs are in the folder available or must be initialized"""
return self.path2config.exists() and len(list(self.path2config.glob("Config*_*.json"))) > 0
@property
def path2config(self) -> Path:
"""Returning the absolute path to config folder"""
return self._path2config.absolute()
[docs]
def get_type_metric_calculation(self, use_case: int) -> list[str]:
"""Returning an overview of custom metric calculation methods during PyTorch Training
:param use_case: Number with use case of the model type (0=Classifier, 1=Autoencoder, 2=Autoencoder-based Classifier)
:return: List of custom metric calculation methods
"""
match use_case:
case 0:
method = TrainClassifier
default = DefaultSettingsTrainingCE
case 1:
method = TrainAutoencoder
default = DefaultSettingsTrainingMSE
case _:
raise NotImplementedError("Training routine not implemented")
return method(
config_train=default, config_data=DefaultSettingsDataset, do_train=False
).get_epoch_metric_custom_methods
@property
def get_custom_metric_calculation(self) -> list[str]:
"""Returning an overview of custom metric calculation methods during PyTorch Training
:return: List of custom metric calculation methods
"""
return self.get_type_metric_calculation(self._settings_ml.mode_train)
@property
def get_model_overview(self) -> list[str]:
"""Returning an overview of model training methods during PyTorch Training"""
if not self._settings_model:
raise ValueError("Available ")
return self._settings_model.get_model_overview()
[docs]
def get_model(self, *args, **kwargs):
"""Returning the deep learning model for training loaded from ModelLibrary"""
return self._settings_model.get_model(*args, **kwargs)
def _get_config_ml(self, use_case: str, default_training_mode: int = 0) -> SettingsTraining:
default_set = deepcopy(DefaultSettingsTraining)
default_set.mode_train = default_training_mode
return JsonHandler(
template=default_set,
path=str(self.path2config),
file_name=f"ConfigTraining_{use_case}",
).get_class(SettingsTraining)
def _get_config_dataset(self, default_dataset_name: str, use_case: str) -> SettingsDataset:
default_set: SettingsTraining = deepcopy(DefaultSettingsDataset)
default_set.data_type = default_dataset_name
self._settings_model = JsonHandler(
template=default_set,
path=str(self.path2config),
file_name=f"ConfigDataset_{use_case}",
).get_class(SettingsDataset)
return self._settings_model
@staticmethod
def _get_dataset_loader() -> Any:
datalib = DatasetLoaderLibrary().get_registry()
matches = [item for item in datalib.get_library_overview() if "DatasetLoader" == item]
if len(matches) == 0:
raise AttributeError("No DatasetLoader available")
return datalib.build_object(matches[0])
[docs]
def get_dataset(self) -> DatasetFromFile:
"""Getting the dataset with rawdata, label and label names for training a deep learning model
:return: Dataclass DatasetFromFile with loaded dataset
"""
return self._dataloader(self._settings_data).load_dataset()
def _get_config_classifier(self, default_model_name: str, use_case: str) -> SettingsClassifier:
default_set: SettingsClassifier = deepcopy(DefaultSettingsTrainingCE)
default_set.model_name = default_model_name
default_set.custom_metrics = self.get_type_metric_calculation(0)
self._settings_model = JsonHandler(
template=default_set,
path=str(self.path2config),
file_name=f"ConfigClassifier_{use_case}",
).get_class(SettingsClassifier)
return self._settings_model
def _prepare_training_classifier(self, used_dataset: DatasetFromFile) -> TrainClassifier:
"""PyTorch Training Routing for Classifiers
:return: Training Handler
"""
self._get_config_classifier(default_model_name=self.__default_model, use_case=self.__use_case)
# --- Processing Step #0: Get dataset and build model
model_signature = self._settings_model.get_signature()
if len(model_signature) and check_keylist_elements_all(
keylist=model_signature, elements=["input_size", "output_size"]
):
sets = dict(
input_size=int(np.prod(used_dataset.data.shape[1:])),
output_size=np.unique(used_dataset.label).size,
)
else:
sets = dict()
used_model = self.get_model(**sets)
# ---Processing Step #1: Prepare Training Handler
train_handler = TrainClassifier(
config_train=self._settings_model,
config_data=self._settings_data,
do_train=True,
)
train_handler.load_model(model=used_model)
train_handler.load_dataset(dataset=used_dataset)
return train_handler
def _get_config_autoencoder(self, default_model_name: str, use_case: str) -> SettingsAutoencoder:
default_set: SettingsAutoencoder = deepcopy(DefaultSettingsTrainingMSE)
default_set.model_name = default_model_name
default_set.custom_metrics = self.get_type_metric_calculation(1)
return JsonHandler(
template=default_set,
path=str(self.path2config),
file_name=f"ConfigAutoencoder_{use_case}",
).get_class(SettingsAutoencoder)
def _prepare_training_autoencoder(self, used_dataset: DatasetFromFile) -> TrainAutoencoder:
"""PyTorch Training Routing for Autoencoders
:return: Training handler
"""
self._settings_model = self._get_config_autoencoder(
default_model_name=self.__default_model, use_case=self.__use_case
)
# --- Processing Step #0: Get dataset and build model
model_signature = self._settings_model.get_signature()
if len(model_signature) and check_keylist_elements_all(
keylist=model_signature, elements=["input_size", "output_size"]
):
if self._settings_model.feat_size:
sets = dict(
input_size=int(np.prod(used_dataset.data.shape[1:])),
output_size=self._settings_model.feat_size,
)
else:
sets = dict(
input_size=int(np.prod(used_dataset.data.shape[1:])),
output_size=int(np.prod(used_dataset.data.shape[1:])),
)
else:
sets = dict()
used_model = self.get_model(**sets)
# ---Processing Step #1: Prepare Trainings Handler
train_handler = TrainAutoencoder(
config_train=self._settings_model,
config_data=self._settings_data,
do_train=True,
)
train_handler.load_model(model=used_model)
train_handler.load_dataset(dataset=used_dataset)
return train_handler
def _save_training_results(
self,
addon: str,
metrics: dict,
data_result: DataValidation,
custom_metrics: list,
path2save: Path,
) -> TrainingResults:
results = TrainingResults(
settings={
"train": self._settings_ml,
"model": self._settings_model,
"data": self._settings_data,
},
metrics=metrics,
data=data_result,
path=path2save,
metrics_custom=custom_metrics,
)
data2save = path2save / f"results_{addon}.npy"
self._logger.debug(f"... saving results: {data2save}")
np.save(data2save, results, allow_pickle=True)
return results
def _run_training_single(
self, mode: int, used_dataset: DatasetFromFile, path2save=Path(".")
) -> list[TrainingResults]:
# --- Processing Step #1: Prepare Trainings handler with dataset and model
match mode:
case 0:
dut: TrainClassifier = self._prepare_training_classifier(used_dataset=used_dataset)
addon = "cl"
case 1:
dut: TrainAutoencoder = self._prepare_training_autoencoder(used_dataset=used_dataset)
addon = "ae"
case 2:
raise ValueError(
"It enables the Autoencoder-based Classifier, but you should select either 'cl' or 'ae'"
)
case _:
raise NotImplementedError("Training Routine is not implemented")
# --- Processing Step #2: Do Training and Validation
dut.define_ptq_level(
total_bitwidth=self._settings_ml.ptq_total_bitwidth,
frac_bitwidth=self._settings_ml.ptq_frac_bitwidth,
)
metrics = dut.do_training(path2save=path2save)
data_result = dut.do_post_training_validation(do_ptq=self._settings_ml.do_ptq)
# --- Processing Step #3: Saving results
return [
self._save_training_results(
addon=addon,
metrics=metrics,
data_result=data_result,
custom_metrics=dut.get_epoch_metric_custom_methods,
path2save=dut.get_saving_path(),
)
]
def _run_training_sequence(
self, used_dataset: DatasetFromFile, path2save=Path(".")
) -> list[TrainingResults]:
# --- Processing Step #1: Do Training and Validation of Autoencoder
self.__default_model = (
self.__default_model
if "ae" in self.__default_model
else self.__default_model.replace("cl", "ae")
)
results_ae = self._run_training_single(mode=1, used_dataset=used_dataset, path2save=path2save)[0]
# --- Processing Step #3: Extract Feature Space and build new dataset for Classifier
dut0: TrainAutoencoder = self._prepare_training_autoencoder(used_dataset=used_dataset)
dut0.define_ptq_level(
total_bitwidth=self._settings_ml.ptq_total_bitwidth,
frac_bitwidth=self._settings_ml.ptq_frac_bitwidth,
)
feat_dataset = dut0.extract_feature_space(path2model=results_ae.path, rawdata=self.get_dataset())
# --- Processing Step #3: Do Training and Validation of Classifier
self.__default_model = (
self.__default_model
if "cl" in self.__default_model
else self.__default_model.replace("ae", "cl")
)
results_cl = self._run_training_single(
mode=0, used_dataset=feat_dataset, path2save=results_ae.path
)[0]
# --- Processing Step #4: Returning Results
return [results_ae, results_cl]
def __prepare_training(self) -> None:
self._dataloader = self._get_dataset_loader()
self._settings_data = self._get_config_dataset(
default_dataset_name=self.__use_case, use_case=self.__use_case
)
[docs]
def do_training(self, path2save=Path(".")) -> list[TrainingResults]:
"""Running PyTorch Training for specified configuration
:param path2save: Path to save the results and models after training [default runs/<YYYYMMDD>_<model>]
:return: Dataclass TrainingResults with internal metrics, data and path to run folder
"""
if not self._do_init:
raise AttributeError("Configs are generated - Please adapt and restart!")
used_dataset = self.get_dataset()
if self._settings_ml.mode_train == 2:
return self._run_training_sequence(used_dataset=used_dataset, path2save=path2save)
else:
return self._run_training_single(
mode=self._settings_ml.mode_train,
used_dataset=used_dataset,
path2save=path2save,
)
[docs]
def do_plot_dataset(self, path2save: Path = Path(".")) -> None:
"""Function for plotting the dataset content
:param path2save: Path to save the dataset plot
:return: None
"""
dataset = self.get_dataset()
match self._settings_data.data_type.lower():
case "mnist":
dnn_plot.plot_mnist_dataset(
data=dataset.data,
label=dataset.label,
title="",
path2save=str(path2save.absolute()),
show_plot=False,
)
case "waveforms":
dnn_plot.plot_waveforms_dataset(
dataset=dataset,
num_samples_class=10,
path2save=str(path2save.absolute()),
show_plot=False,
)
case _:
dnn_plot.plot_frames_dataset(
data=dataset,
take_samples=100,
do_norm=True,
add_subtitle=False,
path2save=str(path2save.absolute()),
show_plot=False,
)
[docs]
def do_plot_results(self, results: TrainingResults, epoch_zoom=None, do_plot: bool = True) -> None:
"""Function for plotting the results from training [metric, performance, data statistics]
:param results: Dataclass TrainingResults with internal metrics, data and path to run folder
:param epoch_zoom: Optional list with ranges for zooming loss data
:param do_plot: Optional boolean to plot results
:return: None
"""
self._plotter.loss(
data=results,
loss_type=results.settings["model"].loss,
epoch_zoom=epoch_zoom,
show_plot=False,
)
self._plotter.custom_loss(data=results, epoch_zoom=epoch_zoom, show_plot=False)
self._plotter.statistics(data=results, show_plot=False)
match self._settings_ml.mode_train:
case 0:
self._plotter.performance_classifier(
data=results,
show_plot=results.settings["train"].do_block and do_plot,
)
case 1:
if results.settings["data"].data_type.lower() == "mnist":
self._plotter.performance_autoencoder_mnist(
data=results,
show_plot=results.settings["train"].do_block and do_plot,
)
else:
self._plotter.performance_autoencoder(
data=results,
mean_value=results.data.mean,
show_plot=results.settings["train"].do_block and do_plot,
)
case _:
raise NotImplementedError
@staticmethod
def _load_results(path2file: Path) -> TrainingResults:
if not path2file.is_file():
raise AttributeError(f"{path2file} is not a file")
if not path2file.exists():
raise AttributeError(f"{path2file} does not exists")
data = np.load(path2file, allow_pickle=True).flatten()[0]
return data
[docs]
def read_file_and_plot(
self, path2file: Path, epoch_zoom=None, do_plot: bool = True
) -> TrainingResults:
"""Loading results file from training and plot the results
:param path2file: Path to file with results from PyTorch Training
:param epoch_zoom: Optional list or tuple with zoom on specific epoch range
:param do_plot: Boolean for plotting results
:return: Dataclass with metrics and results from training and validation phase
"""
data = self._load_results(path2file)
self.do_plot_results(results=data, epoch_zoom=epoch_zoom, do_plot=do_plot)
return data