Source code for denspp.offline.dnn.training.autoencoder_train

from dataclasses import dataclass
from datetime import datetime
from logging import Logger, getLogger
from pathlib import Path
from shutil import copy

import numpy as np
from torch import (
    Tensor,
    cat,
    concatenate,
    cuda,
    flatten,
    from_numpy,
    inference_mode,
    load,
    randn,
    save,
)

from denspp.offline import check_keylist_elements_any
from denspp.offline.dnn.data_config import DatasetFromFile, SettingsDataset
from denspp.offline.dnn.training.autoencoder_dataset import DatasetAutoencoder
from denspp.offline.dnn.training.ptq_help import quantize_model_fxp
from denspp.offline.metric.snr import calculate_dsnr_tensor, calculate_snr_tensor

from .common_train import DataValidation, PyTorchHandler, SettingsPytorch


@dataclass
class SettingsAutoencoder(SettingsPytorch):
    """Settings of the PyTorch training/inference pipeline for autoencoders.

    Extends the generic PyTorch settings with three autoencoder-specific fields.

    Attributes:
        model_name:         String with the model name
        patience:           Integer with the number of epochs before early stopping
        optimizer:          String with the PyTorch optimizer name
        loss:               String with the method name of the loss function
        deterministic_do:   Boolean if deterministic training should be done
        deterministic_seed: Integer with the seed for deterministic training
        num_kfold:          Integer with the number of folds for k-fold cross validation
        num_epochs:         Integer with the number of epochs
        batch_size:         Integer with the batch size
        data_split_ratio:   Float for splitting the input dataset between training and validation
        data_do_shuffle:    Boolean if data should be shuffled before training
        custom_metrics:     List with names of custom metrics to calculate during training
        trainings_mode:     Integer selecting the training mode of the autoencoder
                            [0: Autoencoder,
                             1: Denoising Autoencoder (mean),
                             2: Denoising Autoencoder (add random noise),
                             3: Denoising Autoencoder (add gaussian noise)]
        feat_size:          Integer with the feature size of the encoder output / decoder input
        noise_std:          Float with the standard deviation of the noise added to the input data
    """

    # Fields added on top of the inherited SettingsPytorch fields (order defines __init__ order)
    trainings_mode: int
    feat_size: int
    noise_std: float
# Default settings for training an autoencoder with the MSE loss
DefaultSettingsTrainingMSE = SettingsAutoencoder(
    # --- Model and optimisation
    model_name="",
    optimizer="Adam",
    loss="MSE",
    patience=20,
    # --- Reproducibility
    deterministic_do=False,
    deterministic_seed=42,
    # --- Training schedule
    num_kfold=1,
    num_epochs=10,
    batch_size=256,
    # --- Dataset handling
    data_split_ratio=0.2,
    data_do_shuffle=True,
    custom_metrics=[],
    # --- Autoencoder specific
    trainings_mode=0,
    feat_size=4,
    noise_std=0.1,
)
class TrainAutoencoder(PyTorchHandler):
    """Handler for training and validating autoencoder models.

    The wrapped model is called as ``model(x)`` and is expected to return a pair whose
    first entry is the feature (latent) output and whose second entry is the reconstruction.
    Batches delivered by the data loaders are dictionaries with the keys ``"in"``, ``"out"``,
    ``"mean"`` and ``"class"``.
    """

    _logger: Logger
    _settings_train: SettingsAutoencoder
    _train_loader: list
    _valid_loader: list
    _mean_data: np.ndarray

    def __init__(
        self,
        config_train: SettingsAutoencoder,
        config_data: SettingsDataset,
        do_train: bool = True,
    ) -> None:
        """Class for Handling Training of Autoencoders

        :param config_train:    Settings for handling the PyTorch Trainings Routine of an Autoencoder
        :param config_data:     Settings for handling and loading the dataset (just for saving)
        :param do_train:        Do training of model otherwise only inference
        :return:                None
        """
        super().__init__(config_train, config_data, do_train)
        self._settings_train = config_train
        self._logger = getLogger(__name__)
        # Per-epoch scratch buffer and per-fold history of the custom metrics
        self.__metric_buffer = {}
        self.__metric_result = {}
        # Custom metrics selectable by name via the settings field `custom_metrics`
        self._metric_methods = {
            "snr_in": self.__determine_snr_input,
            "snr_out": self.__determine_snr_output,
            "dsnr_all": self.__determine_dsnr_all,
            "ptq_loss": self.__determine_ptq_loss,
        }

    def extract_feature_space(self, path2model: Path, rawdata: DatasetFromFile) -> DatasetFromFile:
        """Function for extracting the feature space from the autoencoder to use it for classification.

        :param path2model:  Path to the folder containing the autoencoder model (``*ae*.pt``)
        :param rawdata:     Dataset with the autoencoder rawdata
        :return:            Dataset with autoencoder feature space as input
        :raises FileNotFoundError: If no matching model file is found in ``path2model``
        """
        # Sorted for a reproducible choice. "model_ae_reset.pt" is the state dict written by
        # do_training() (not a full model), so it must not be picked up here.
        overview_model = [
            path for path in sorted(path2model.glob("*ae*.pt")) if path.name != "model_ae_reset.pt"
        ]
        if not overview_model:
            raise FileNotFoundError("No .pt files found in {}".format(path2model))

        self._logger.debug(f"Use pt-file {overview_model[0]} for extracting feature space")
        # NOTE: weights_only=False unpickles arbitrary objects - only load trusted model files
        model_ae = load(overview_model[0], weights_only=False).to("cpu")
        model_ae.eval()
        with inference_mode():
            # First model output is the feature (latent) representation
            feat = model_ae(from_numpy(np.array(rawdata.data, dtype=np.float32)))[0]
        return DatasetFromFile(
            data=feat.detach().numpy(),
            label=rawdata.label,
            dict=rawdata.dict,
            mean=rawdata.mean,
        )

    def load_dataset(self, dataset: DatasetFromFile) -> None:
        """Wrapping the loaded dataset into an autoencoder dataset and preparing the dataloaders

        :param dataset:     Dataclass with dataset loaded from extern
        :return:            None
        """
        dataset0 = DatasetAutoencoder(
            dataset=dataset,
            noise_std=self._settings_train.noise_std,
            mode_train=self._settings_train.trainings_mode,
        )
        # Kept for attaching to the output of the post-training validation
        self._mean_data = dataset0.get_mean_waveforms
        self._prepare_dataset_for_training(data_set=dataset0, num_workers=0)

    def __compute_loss(self, prediction: Tensor, target: Tensor) -> Tensor:
        """Apply the loss function on per-sample flattened tensors.

        Both tensors are flattened from dimension 1 onwards whenever they carry a batch axis
        plus at least one further axis. The decision depends on the tensor rank only, so it
        does not change with the batch size.

        Args:
            prediction: Tensor with the model output
            target:     Tensor with the reference data
        Return:
            Tensor with the loss value
        """
        if min(prediction.dim(), target.dim()) >= 2:
            return self._loss_fn(flatten(prediction, 1), flatten(target, 1))
        return self._loss_fn(prediction, target)

    def __do_training_epoch(self) -> float:
        """Do training during epoch of training

        Return:
            Floating value with training loss value (mean over all batches)
        """
        train_loss = 0.0
        total_batches = 0

        self._model.train(True)
        for tdata in self._train_loader[self._run_kfold]:
            data_x = tdata["in"].to(self._used_hw_dev)
            data_y = tdata["out"].to(self._used_hw_dev)
            # Second model output is the reconstruction
            data_p = self._model(data_x)[1]

            self._optimizer.zero_grad()
            loss = self.__compute_loss(data_p, data_y)
            loss.backward()
            self._optimizer.step()

            train_loss += loss.item()
            total_batches += 1
        return float(train_loss / total_batches)

    def __do_valid_epoch(self, epoch_custom_metrics: list) -> float:
        """Do validation during epoch of training

        Args:
            epoch_custom_metrics:   List with entries of custom-made metric calculations
        Return:
            Floating value with validation loss value (mean over all batches)
        """
        # Running batch counter, also read by the PTQ loss metric
        self._total_batches_valid = 0
        valid_loss = 0.0

        self._model.eval()
        with inference_mode():
            for vdata in self._valid_loader[self._run_kfold]:
                data_x = vdata["in"].to(self._used_hw_dev)
                data_y = vdata["out"].to(self._used_hw_dev)
                data_m = vdata["mean"].to(self._used_hw_dev)
                data_id = vdata["class"].to(self._used_hw_dev)
                data_p = self._model(data_x)[1]

                self._total_batches_valid += 1
                valid_loss += self.__compute_loss(data_p, data_y).item()

                # --- Calculating custom made metrics
                for metric_used in epoch_custom_metrics:
                    self._determine_epoch_metrics(metric_used)(
                        data_x, data_p, data_m, metric=metric_used, id=data_id
                    )
        return float(valid_loss / self._total_batches_valid)

    def __process_epoch_metrics_calculation(self, init_phase: bool, custom_made_metrics: list) -> None:
        """Function for preparing the custom-made metric calculation

        Args:
            init_phase:             Boolean decision if processing part is in init (True) or in training phase (False)
            custom_made_metrics:    List with custom metrics for calculation during validation phase
        Return:
            None
        """
        assert check_keylist_elements_any(
            keylist=custom_made_metrics, elements=self.get_epoch_metric_custom_methods
        ), (
            f"Used custom made metrics not found in: {self.get_epoch_metric_custom_methods} - Please adapt in settings!"
        )
        if init_phase:
            # --- Init phase for generating empty data structure (fresh list objects)
            for key0 in custom_made_metrics:
                self.__metric_result[key0] = []
                self.__metric_buffer[key0] = []
        else:
            # --- Processing results: move the epoch buffer into the history and clear it
            for key0 in self.__metric_buffer:
                self.__metric_result[key0].append(self.__metric_buffer[key0])
                self.__metric_buffer[key0] = []

    def __store_metric_batch(self, metric: str, values: Tensor) -> None:
        """Collect the metric values of one batch in the epoch buffer.

        An empty buffer is a list; the first batch replaces it by the tensor and all following
        batches are concatenated along dimension 0.

        Args:
            metric: Name of the metric (key of the buffer)
            values: Tensor with the metric values of the current batch
        Return:
            None
        """
        buffered = self.__metric_buffer[metric]
        if isinstance(buffered, list):
            self.__metric_buffer[metric] = values
        else:
            self.__metric_buffer[metric] = concatenate((buffered, values), dim=0)

    def __determine_snr_input(
        self,
        input_waveform: Tensor,
        pred_waveform: Tensor,
        mean_waveform: Tensor,
        **kwargs,
    ) -> None:
        """Buffer the SNR of the input waveform against the mean waveform (key in kwargs["metric"])"""
        out = calculate_snr_tensor(input_waveform, mean_waveform)
        self.__store_metric_batch(kwargs["metric"], out)

    def __determine_snr_output(
        self,
        input_waveform: Tensor,
        pred_waveform: Tensor,
        mean_waveform: Tensor,
        **kwargs,
    ) -> None:
        """Buffer the SNR of the predicted waveform against the mean waveform (key in kwargs["metric"])"""
        out = calculate_snr_tensor(pred_waveform, mean_waveform)
        self.__store_metric_batch(kwargs["metric"], out)

    def __determine_dsnr_all(
        self,
        input_waveform: Tensor,
        pred_waveform: Tensor,
        mean_waveform: Tensor,
        **kwargs,
    ) -> None:
        """Buffer the result of calculate_dsnr_tensor for input, prediction and mean waveform"""
        out = calculate_dsnr_tensor(input_waveform, pred_waveform, mean_waveform)
        self.__store_metric_batch(kwargs["metric"], out)

    def __determine_ptq_loss(
        self,
        input_waveform: Tensor,
        pred_waveform: Tensor,
        mean_waveform: Tensor,
        **kwargs,
    ) -> None:
        """Buffer the mean loss of the post-training-quantized model over the validation batches.

        The model is quantized with the fixed-point configuration from ``self._ptq_level``
        (passed as second and third argument to ``quantize_model_fxp``) and its reconstruction
        is compared with the input waveform. The buffer holds one float that is updated as a
        running mean, using ``self._total_batches_valid`` as the number of batches seen so far.
        """
        # --- Load model and make inference
        model_ptq = quantize_model_fxp(self._model, self._ptq_level[0], self._ptq_level[1])
        model_ptq.eval()
        pred_waveform_ptq = model_ptq(input_waveform)[1]

        # --- Calculate loss
        loss = self.__compute_loss(pred_waveform_ptq, input_waveform).item()

        # --- Saving results (incremental mean: m_k = m_{k-1} + (x_k - m_{k-1}) / k)
        buffered = self.__metric_buffer[kwargs["metric"]]
        if buffered:
            buffered[0] += (loss - buffered[0]) / self._total_batches_valid
        else:
            buffered.append(loss)

    def do_training(self, path2save: Path = Path(".")) -> dict:
        """Start model training incl. validation and custom-own metric calculation

        Args:
            path2save:  Path for saving the results, forwarded to ``_init_train`` [Default: Path(".")]
        Returns:
            Dictionary with one entry per fold ("fold_000", ...), each holding the metrics from
            training (loss_train, loss_valid, own_metrics)
        """
        metrics = self._settings_train.custom_metrics
        self._init_train(path2save=path2save, addon="_ae")
        if self._kfold_do:
            self._logger.info(
                f"Starting Kfold cross validation training in {self._settings_train.num_kfold} steps"
            )

        metric_out = {}
        # Initial weights are stored once so that every fold starts from the same state
        path2model_init = self._path2save / "model_ae_reset.pt"
        save(self._model.state_dict(), path2model_init)
        timestamp_start = datetime.now()
        timestamp_string = timestamp_start.strftime("%H:%M:%S")
        self._logger.info(f"\nTraining starts on {timestamp_string}")
        self._logger.info(
            "====================================================================================="
        )

        for fold in range(self._settings_train.num_kfold):
            # --- Init fold
            # Fresh metric containers per fold, otherwise all folds would share (and keep
            # extending) the same history lists
            self.__process_epoch_metrics_calculation(True, metrics)
            best_loss = [float("inf"), float("inf")]
            path2model = None
            patience_counter = self._settings_train.patience
            epoch_loss_train = []
            epoch_loss_valid = []
            self._model.load_state_dict(load(path2model_init, weights_only=False))
            self._run_kfold = fold

            if self._kfold_do:
                self._logger.info(f"\nStarting with Fold #{fold}")

            for epoch in range(0, self._settings_train.num_epochs):
                if self._settings_train.deterministic_do:
                    self._deterministic_generator.manual_seed(
                        self._settings_train.deterministic_seed + epoch
                    )

                loss_train = self.__do_training_epoch()
                loss_valid = self.__do_valid_epoch(metrics)
                epoch_loss_train.append(loss_train)
                epoch_loss_valid.append(loss_valid)
                self.__process_epoch_metrics_calculation(False, metrics)

                self._logger.info(
                    f"... results of epoch {epoch + 1}/{self._settings_train.num_epochs} "
                    f"[{(epoch + 1) / self._settings_train.num_epochs * 100:.2f} %]: "
                    f"train_loss = {loss_train:.5f},"
                    f"\tvalid_loss = {loss_valid:.5f},"
                    f"\tdelta_loss = {loss_train - loss_valid:.6f}"
                )

                # Tracking the best performance and saving the model
                if loss_valid < best_loss[1]:
                    best_loss = [loss_train, loss_valid]
                    path2model = self._path2temp / f"model_ae_fold{fold:03d}_epoch{epoch:04d}.pt"
                    save(self._model, path2model)
                    patience_counter = self._settings_train.patience
                else:
                    patience_counter -= 1

                # Early Stopping
                if patience_counter <= 0:
                    self._logger.info(f"... training stopped due to no change after {epoch + 1} epochs!")
                    break

            if path2model is None:
                # Happens e.g. with zero epochs or a validation loss that is never finite
                self._logger.warning(f"... no model saved for fold #{fold} - nothing to copy")
            else:
                copy(path2model, self._path2save)
            self._save_train_results(best_loss[0], best_loss[1], "Loss")

            # --- Saving results
            metric_fold = {
                "loss_train": epoch_loss_train,
                "loss_valid": epoch_loss_valid,
            }
            metric_fold.update(self.__metric_result)
            metric_out.update({f"fold_{fold:03d}": metric_fold})

        # --- Ending of all trainings phases
        self._end_training_routine(timestamp_start)
        return self._converting_tensor_to_numpy(metric_out)

    def do_post_training_validation(self, do_ptq: bool = False) -> DataValidation:
        """Performing the post-training validation with the best model

        :param do_ptq:  Boolean for activating post training quantization during post-training validation
        :return:        Dataclass with results from validation phase
        :raises RuntimeError: If no trained model is found or the validation loader yields no batch
        """
        if cuda.is_available():
            cuda.empty_cache()

        # --- Do the Inference with Best Model
        overview_models = self.get_best_model("ae")
        if len(overview_models) == 0:
            raise RuntimeError(f"No models found on {str(self._path2save)} - Please start training!")

        path2model = overview_models[0]
        # NOTE: weights_only=False unpickles arbitrary objects - only load trusted model files
        model_test = load(path2model, weights_only=False)
        if do_ptq:
            model_test = quantize_model_fxp(
                model=model_test,
                total_bits=self._ptq_level[0],
                frac_bits=self._ptq_level[1],
            )

        self._logger.info("=================================================================")
        self._logger.info(f"Do Validation with best model: {path2model}")

        feat_chunks = []
        pred_chunks = []
        clus_chunks = []
        data_chunks = []
        model_test.eval()
        with inference_mode():
            # Validation loader of the last fold is used
            for vdata in self._valid_loader[-1]:
                feat, pred = model_test(vdata["in"].to(self._used_hw_dev))
                feat_chunks.append(feat.detach().cpu())
                pred_chunks.append(pred.detach().cpu())
                clus_chunks.append(vdata["class"])
                data_chunks.append(vdata["in"])

        if not data_chunks:
            raise RuntimeError("Validation loader is empty - no data for post-training validation!")

        # Single concatenation instead of re-allocating the result in every iteration
        feat_model = cat(feat_chunks, dim=0)
        pred_model = cat(pred_chunks, dim=0)
        clus_orig_list = cat(clus_chunks, dim=0)
        data_orig_list = cat(data_chunks, dim=0)

        # --- Preparing output
        data_out = self._getting_data_for_plotting(
            valid_input=data_orig_list.numpy(),
            valid_label=clus_orig_list.numpy(),
            addon="ae",
        )
        data_out.output = pred_model.numpy()
        data_out.feat = feat_model.numpy()
        data_out.mean = self._mean_data
        return data_out