Source code for denspp.offline.dnn.training.classifier_train

from dataclasses import dataclass
from datetime import datetime
from logging import Logger, getLogger
from pathlib import Path
from shutil import copy

import numpy as np
from torch import (
    Tensor,
    add,
    cat,
    concatenate,
    cuda,
    div,
    inference_mode,
    load,
    randn,
    save,
    tensor,
    zeros,
)

from denspp.offline import check_keylist_elements_any
from denspp.offline.dnn.data_config import DatasetFromFile, SettingsDataset
from denspp.offline.dnn.training.classifier_dataset import DatasetClassifier
from denspp.offline.dnn.training.ptq_help import quantize_model_fxp
from denspp.offline.metric.data_torch import (
    calculate_fbeta,
    calculate_number_true_predictions,
    calculate_precision,
    calculate_recall,
)

from .common_train import DataValidation, PyTorchHandler, SettingsPytorch


@dataclass
class SettingsClassifier(SettingsPytorch):
    """Settings of the PyTorch training/inference pipeline for a classifier.

    No additional fields are declared here; every attribute below is inherited
    from ``SettingsPytorch``. The subclass gives the classifier pipeline its
    own settings type.

    Attributes:
        model_name: Name of the model
        patience: Number of epochs without improvement before early stopping
        optimizer: Name of the PyTorch optimizer
        loss: Name of the loss function
        deterministic_do: Whether deterministic training is performed
        deterministic_seed: Seed used for deterministic training
        num_kfold: Number of folds (training runs) for k-fold cross-validation
        num_epochs: Number of training epochs
        batch_size: Batch size
        data_split_ratio: Ratio for splitting the input dataset into training and validation data
        data_do_shuffle: Whether the data is shuffled before training
        custom_metrics: Names of custom metrics calculated during training
    """
# Default classifier training settings: Adam optimizer with the
# "Cross Entropy" loss, a single training run (num_kfold=1), 10 epochs.
DefaultSettingsTrainingCE = SettingsClassifier(
    model_name="",
    # --- optimisation
    optimizer="Adam",
    loss="Cross Entropy",
    num_epochs=10,
    batch_size=256,
    patience=20,
    # --- data handling
    num_kfold=1,
    data_split_ratio=0.2,
    data_do_shuffle=True,
    # --- reproducibility
    deterministic_do=False,
    deterministic_seed=42,
    # --- evaluation
    custom_metrics=[],
)
[docs] class TrainClassifier(PyTorchHandler): _logger: Logger _settings_train: SettingsClassifier _train_loader: list _valid_loader: list def __init__( self, config_train: SettingsClassifier, config_data: SettingsDataset, do_train: bool = True, ) -> None: """Class for Handling Training of Classifiers :param config_data: Settings for handling and loading the dataset (just for saving) :param config_train: Settings for handling the PyTorch Trainings Routine of a Classifier :param do_train: Do training of model otherwise only inference :return: None """ PyTorchHandler.__init__(self, config_train, config_data, do_train) self._logger = getLogger(__name__) self.__metric_buffer = dict() self.__metric_result = dict() self._metric_methods = { "accuracy": self.__determine_accuracy_per_class, "precision": self.__determine_buffering_metric_calculation, "recall": self.__determine_buffering_metric_calculation, "fbeta": self.__determine_buffering_metric_calculation, "ptq_acc": self.__determine_ptq_acc, }
[docs] def load_dataset(self, dataset: DatasetFromFile) -> None: """Loading the loaded dataset and transform it into right dataloader :param dataset: Dataclass with dataset loaded from extern :return: None """ dataset0 = DatasetClassifier( dataset=dataset, ) self._prepare_dataset_for_training(data_set=dataset0, num_workers=0)
def __do_training_epoch(self) -> tuple[float, float]: """Do training during epoch of training Return: Floating value of training loss and accuracy of used epoch """ train_loss = 0.0 total_batches = 0 total_correct = 0 total_samples = 0 self._model.train(True) for tdata in self._train_loader[self._run_kfold]: self._optimizer.zero_grad() tdata_out = tdata["out"].to(self._used_hw_dev) pred_cl, dec_cl = self._model(tdata["in"].to(self._used_hw_dev)) loss = self._loss_fn(pred_cl, tdata_out) loss.backward() self._optimizer.step() train_loss += loss.item() total_batches += 1 total_correct += calculate_number_true_predictions(dec_cl, tdata_out) total_samples += len(tdata["in"]) train_acc = float(int(total_correct) / total_samples) train_loss = float(train_loss / total_batches) return train_loss, train_acc def __do_valid_epoch(self, epoch_custom_metrics: list) -> tuple[float, float]: """Do validation during epoch of training Args: epoch_custom_metrics: List with entries of custom-made metric calculations Return: Floating value of validation loss and validation accuracy of used epoch """ valid_loss = 0.0 total_batches = 0 total_correct = 0 total_samples = 0 self._model.eval() with inference_mode(): for vdata in self._valid_loader[self._run_kfold]: # --- Validation phase of model pred_cl, dec_cl = self._model(vdata["in"].to(self._used_hw_dev)) true_cl = vdata["out"].to(self._used_hw_dev) valid_loss += self._loss_fn(pred_cl, true_cl).item() total_batches += 1 total_correct += calculate_number_true_predictions(dec_cl, true_cl) total_samples += len(vdata["in"]) # --- Calculating custom made metrics for metric_used in epoch_custom_metrics: self._determine_epoch_metrics(metric_used)( dec_cl, true_cl, metric=metric_used, frame=vdata["in"].to(self._used_hw_dev), ) valid_acc = float(int(total_correct) / total_samples) valid_loss = float(valid_loss / total_batches) return valid_loss, valid_acc def __process_epoch_metrics_calculation(self, init_phase: bool, custom_made_metrics: list) 
-> None: """Function for preparing the custom-made metric calculation Args: init_phase: Boolean decision if processing part is in init (True) or in post-training phase (False) custom_made_metrics:List with custom metrics for calculation during validation phase Return: None """ assert check_keylist_elements_any( keylist=custom_made_metrics, elements=self.get_epoch_metric_custom_methods ), ( f"Used custom made metrics not found in: {self.get_epoch_metric_custom_methods} - Please adapt in settings!" ) # --- Init phase for generating empty data structure if init_phase: for key0 in custom_made_metrics: self.__metric_result.update({key0: list()}) match key0: case "accuracy": self.__metric_buffer.update( { key0: [ zeros((len(self._cell_classes),)), zeros((len(self._cell_classes),)), ] } ) case "precision": self.__metric_buffer.update({key0: [[], []]}) case "recall": self.__metric_buffer.update({key0: [[], []]}) case "fbeta": self.__metric_buffer.update({key0: [[], []]}) case "ptq_acc": self.__metric_buffer.update( { key0: [ zeros( 1, ), zeros( 1, ), ] } ) # --- Processing results else: for key0 in self.__metric_buffer.keys(): match key0: case "accuracy": self.__metric_result[key0].append( div( self.__metric_buffer[key0][0], self.__metric_buffer[key0][1], ) ) self.__metric_buffer.update( { key0: [ zeros((len(self._cell_classes),)), zeros((len(self._cell_classes),)), ] } ) case "precision": out = self._separate_classes_from_label( self.__metric_buffer[key0][0], self.__metric_buffer[key0][1], key0, calculate_precision, ) self.__metric_result[key0].append(out[0]) self.__metric_buffer.update({key0: [[], []]}) case "recall": out = self._separate_classes_from_label( self.__metric_buffer[key0][0], self.__metric_buffer[key0][1], key0, calculate_recall, ) self.__metric_result[key0].append(out[0]) self.__metric_buffer.update({key0: [[], []]}) case "fbeta": out = self._separate_classes_from_label( self.__metric_buffer[key0][0], self.__metric_buffer[key0][1], key0, calculate_fbeta, ) 
self.__metric_result[key0].append(out[0]) self.__metric_buffer.update({key0: [[], []]}) case "ptq_acc": self.__metric_result[key0].append( div( self.__metric_buffer[key0][0], self.__metric_buffer[key0][1], ) ) self.__metric_buffer.update( { key0: [ zeros( 1, ), zeros( 1, ), ] } ) def __determine_accuracy_per_class(self, pred: Tensor, true: Tensor, **kwargs) -> None: out = self._separate_classes_from_label( pred, true, kwargs["metric"], calculate_number_true_predictions ) self.__metric_buffer[kwargs["metric"]][0] = add(self.__metric_buffer[kwargs["metric"]][0], out[0]) self.__metric_buffer[kwargs["metric"]][1] = add(self.__metric_buffer[kwargs["metric"]][1], out[1]) def __determine_buffering_metric_calculation(self, pred: Tensor, true: Tensor, **kwargs) -> None: if len(self.__metric_buffer[kwargs["metric"]][0]) == 0: self.__metric_buffer[kwargs["metric"]][0] = true self.__metric_buffer[kwargs["metric"]][1] = pred else: self.__metric_buffer[kwargs["metric"]][0] = concatenate( (self.__metric_buffer[kwargs["metric"]][0], true), dim=0 ) self.__metric_buffer[kwargs["metric"]][1] = concatenate( (self.__metric_buffer[kwargs["metric"]][1], pred), dim=0 ) def __determine_ptq_acc(self, pred: Tensor, true: Tensor, **kwargs) -> None: model_ptq = quantize_model_fxp( model=self._model, total_bits=self._ptq_level[0], frac_bits=self._ptq_level[1], ) model_ptq.eval() pred_cl, dec_cl = model_ptq(kwargs["frame"]) num_true = calculate_number_true_predictions(dec_cl, true) a = tensor([num_true]) b = tensor([kwargs["frame"].shape[0]]) if self.__metric_buffer[kwargs["metric"]][0].size == 1: self.__metric_buffer[kwargs["metric"]][0] = a self.__metric_buffer[kwargs["metric"]][1] = b else: self.__metric_buffer[kwargs["metric"]][0] = concatenate( (self.__metric_buffer[kwargs["metric"]][0], a), dim=0 ) self.__metric_buffer[kwargs["metric"]][1] = concatenate( (self.__metric_buffer[kwargs["metric"]][1], b), dim=0 )
[docs] def do_training(self, path2save=Path(".")) -> dict: """Start model training incl. validation and custom-own metric calculation Args: path2save: Path for saving the results [Default: '' --> generate new folder] Returns: Dictionary with metrics from training (loss_train, loss_valid, own_metrics) """ metrics = self._settings_train.custom_metrics self._init_train(path2save=path2save, addon="_cl") if self._kfold_do: self._logger.info( f"Starting Kfold cross validation training in {self._settings_train.num_kfold} steps" ) path2model = str() path2model_init = self._path2save / "model_cl_reset.pt" save(self._model.state_dict(), path2model_init) timestamp_start = datetime.now() timestamp_string = timestamp_start.strftime("%H:%M:%S") self._logger.info(f"Training starts on {timestamp_string}") self._logger.info( "=====================================================================================" ) metric_out = dict() self.__process_epoch_metrics_calculation(True, metrics) for fold in np.arange(self._settings_train.num_kfold): # --- Init fold best_loss = [1e6, 1e6] best_acc = [0.0, 0.0] patience_counter = self._settings_train.patience epoch_train_acc = list() epoch_valid_acc = list() epoch_train_loss = list() epoch_valid_loss = list() self._model.load_state_dict(load(path2model_init, weights_only=False)) self._run_kfold = fold if self._kfold_do: self._logger.info(f"Starting with Fold #{fold}") for epoch in range(0, self._settings_train.num_epochs): if self._settings_train.deterministic_do: self._deterministic_generator.manual_seed( self._settings_train.deterministic_seed + epoch ) train_loss, train_acc = self.__do_training_epoch() valid_loss, valid_acc = self.__do_valid_epoch(metrics) self._logger.info( f"... 
results of epoch {epoch + 1}/{self._settings_train.num_epochs} " f"[{(epoch + 1) / self._settings_train.num_epochs * 100:.2f} %]: " f"train_loss = {train_loss:.5f}, delta_loss = {train_loss - valid_loss:.5f}, " f"train_acc = {100 * train_acc:.4f} %, delta_acc = {100 * (train_acc - valid_acc):.4f} %" ) # Saving metrics after each epoch epoch_train_acc.append(train_acc) epoch_train_loss.append(train_loss) epoch_valid_acc.append(valid_acc) epoch_valid_loss.append(valid_loss) self.__process_epoch_metrics_calculation(False, metrics) # Tracking the best performance and saving the model if valid_loss < best_loss[1]: best_loss = [train_loss, valid_loss] best_acc = [train_acc, valid_acc] path2model = self._path2temp / f"model_cl_fold{fold:03d}_epoch{epoch:04d}.pt" save(self._model, path2model) patience_counter = self._settings_train.patience else: patience_counter -= 1 # Early Stopping if patience_counter <= 0: self._logger.info(f"... training stopped due to no change after {epoch + 1} epochs!") break copy(path2model, self._path2save) self._save_train_results(best_loss[0], best_loss[1], "Loss") self._save_train_results(best_acc[0], best_acc[1], "Acc.") # --- Saving metrics after each fold metric_fold = { "acc_train": epoch_train_acc, "acc_valid": epoch_valid_acc, "loss_train": epoch_train_loss, "loss_valid": epoch_valid_loss, } metric_fold.update(self.__metric_result) metric_out.update({f"fold_{fold:03d}": metric_fold}) # --- Ending of all trainings phases self._end_training_routine(timestamp_start) return self._converting_tensor_to_numpy(metric_out)
[docs] def do_post_training_validation(self, do_ptq: bool = False) -> DataValidation: """Performing the post-training validation with the best model :param do_ptq: Boolean for activating post training quantization during post-training validation :return: Dataclass with results from validation phase """ if cuda.is_available(): cuda.empty_cache() # --- Do the Inference with Best Model overview_models = self.get_best_model("cl") if len(overview_models) == 0: raise RuntimeError(f"No models found on {self._path2save} - Please start training!") path2model = overview_models[0] if do_ptq: model_test = quantize_model_fxp( model=load(path2model, weights_only=False), total_bits=self._ptq_level[0], frac_bits=self._ptq_level[1], ) else: model_test = load(path2model, weights_only=False) self._logger.info("=================================================================") self._logger.info(f"Do Validation with best model: {path2model}") clus_pred_list = randn(32, 1) clus_orig_list = randn(32, 1) data_orig_list = randn(32, 1) first_cycle = True model_test.eval() for vdata in self._valid_loader[-1]: _, clus_pred = model_test(vdata["in"].to(self._used_hw_dev)) if first_cycle: clus_pred_list = clus_pred.detach().cpu() clus_orig_list = vdata["out"] data_orig_list = vdata["in"] else: clus_pred_list = cat((clus_pred_list, clus_pred.detach().cpu()), dim=0) clus_orig_list = cat((clus_orig_list, vdata["out"]), dim=0) data_orig_list = cat((data_orig_list, vdata["in"]), dim=0) first_cycle = False # --- Preparing output data_out = self._getting_data_for_plotting( valid_input=data_orig_list.numpy(), valid_label=clus_orig_list.numpy(), addon="cl", ) data_out.output = clus_pred_list.numpy() return data_out