from dataclasses import dataclass
from datetime import datetime
from logging import Logger, getLogger
from pathlib import Path
from shutil import copy
import numpy as np
from torch import (
Tensor,
cat,
concatenate,
cuda,
flatten,
from_numpy,
inference_mode,
load,
randn,
save,
)
from denspp.offline import check_keylist_elements_any
from denspp.offline.dnn.data_config import DatasetFromFile, SettingsDataset
from denspp.offline.dnn.training.autoencoder_dataset import DatasetAutoencoder
from denspp.offline.dnn.training.ptq_help import quantize_model_fxp
from denspp.offline.metric.snr import calculate_dsnr_tensor, calculate_snr_tensor
from .common_train import DataValidation, PyTorchHandler, SettingsPytorch
[docs]
@dataclass
class SettingsAutoencoder(SettingsPytorch):
"""Class for handling the PyTorch training/inference pipeline
Attributes:
model_name: String with the model name
patience: Integer value with number of epochs before early stopping
optimizer: String with PyTorch optimizer name
loss: String with method name for the loss function
deterministic_do: Boolean if deterministic training should be done
deterministic_seed: Integer with the seed for deterministic training
num_kfold: Integer value with applying k-fold cross validation
num_epochs: Integer value with number of epochs
batch_size: Integer value with batch size
data_split_ratio: Float value for splitting the input dataset between training and validation
data_do_shuffle: Boolean if data should be shuffled before training
custom_metrics: List with string of custom metrics to calculate during training
trainings_mode: Integer to define trainings mode of the autoencoder
[0: Autoencoder,
1: Denoising Autoencoder (mean),
2: Denoising Autoencoder (add random noise),
3: Denoising Autoencoder (add gaussian noise)]
feat_size: Integer with defining the feature size of the encoder output / decoder input
noise_std: Float value for adding noise standard deviation on input data
"""
trainings_mode: int
feat_size: int
noise_std: float
DefaultSettingsTrainingMSE = SettingsAutoencoder(
model_name="",
patience=20,
optimizer="Adam",
loss="MSE",
deterministic_do=False,
deterministic_seed=42,
num_kfold=1,
num_epochs=10,
batch_size=256,
data_do_shuffle=True,
data_split_ratio=0.2,
custom_metrics=[],
trainings_mode=0,
feat_size=4,
noise_std=0.1,
)
[docs]
class TrainAutoencoder(PyTorchHandler):
_logger: Logger
_settings_train: SettingsAutoencoder
_train_loader: list
_valid_loader: list
_mean_data: np.ndarray
def __init__(
self,
config_train: SettingsAutoencoder,
config_data: SettingsDataset,
do_train: bool = True,
) -> None:
"""Class for Handling Training of Autoencoders
:param config_data: Settings for handling and loading the dataset (just for saving)
:param config_train: Settings for handling the PyTorch Trainings Routine of an Autoencoder
:param do_train: Do training of model otherwise only inference
:return: None
"""
PyTorchHandler.__init__(self, config_train, config_data, do_train)
self._settings_train = config_train
self._logger = getLogger(__name__)
self.__metric_buffer = dict()
self.__metric_result = dict()
self._metric_methods = {
"snr_in": self.__determine_snr_input,
"snr_out": self.__determine_snr_output,
"dsnr_all": self.__determine_dsnr_all,
"ptq_loss": self.__determine_ptq_loss,
}
[docs]
def load_dataset(self, dataset: DatasetFromFile) -> None:
"""Loading the loaded dataset and transform it into right dataloader
:param dataset: Dataclass with dataset loaded from extern
:return: None
"""
dataset0 = DatasetAutoencoder(
dataset=dataset,
noise_std=self._settings_train.noise_std,
mode_train=self._settings_train.trainings_mode,
)
self._mean_data = dataset0.get_mean_waveforms
self._prepare_dataset_for_training(data_set=dataset0, num_workers=0)
def __do_training_epoch(self) -> float:
"""Do training during epoch of training
Return:
Floating value with training loss value
"""
train_loss = 0.0
total_batches = 0
self._model.train(True)
for tdata in self._train_loader[self._run_kfold]:
data_x = tdata["in"].to(self._used_hw_dev)
data_y = tdata["out"].to(self._used_hw_dev)
data_p = self._model(data_x)[1]
self._optimizer.zero_grad()
if len(data_y) > 2:
loss = self._loss_fn(flatten(data_p, 1), flatten(data_y, 1))
else:
loss = self._loss_fn(data_p, data_y)
loss.backward()
self._optimizer.step()
train_loss += loss.item()
total_batches += 1
return float(train_loss / total_batches)
def __do_valid_epoch(self, epoch_custom_metrics: list) -> float:
"""Do validation during epoch of training
Args:
epoch_custom_metrics: List with entries of custom-made metric calculations
Return:
Floating value with validation loss value
"""
self._total_batches_valid = 0
valid_loss = 0.0
self._model.eval()
with inference_mode():
for vdata in self._valid_loader[self._run_kfold]:
data_x = vdata["in"].to(self._used_hw_dev)
data_y = vdata["out"].to(self._used_hw_dev)
data_m = vdata["mean"].to(self._used_hw_dev)
data_id = vdata["class"].to(self._used_hw_dev)
data_p = self._model(data_x)[1]
self._total_batches_valid += 1
if len(data_y) > 2:
valid_loss += self._loss_fn(flatten(data_p, 1), flatten(data_y, 1)).item()
else:
valid_loss += self._loss_fn(data_p, data_y).item()
# --- Calculating custom made metrics
for metric_used in epoch_custom_metrics:
self._determine_epoch_metrics(metric_used)(
data_x, data_p, data_m, metric=metric_used, id=data_id
)
return float(valid_loss / self._total_batches_valid)
def __process_epoch_metrics_calculation(self, init_phase: bool, custom_made_metrics: list) -> None:
"""Function for preparing the custom-made metric calculation
Args:
init_phase: Boolean decision if processing part is in init (True) or in training phase (False)
custom_made_metrics:List with custom metrics for calculation during validation phase
Return:
None
"""
assert check_keylist_elements_any(
keylist=custom_made_metrics, elements=self.get_epoch_metric_custom_methods
), (
f"Used custom made metrics not found in: {self.get_epoch_metric_custom_methods} - Please adapt in settings!"
)
# --- Init phase for generating empty data structure
if init_phase:
for key0 in custom_made_metrics:
self.__metric_result.update({key0: list()})
self.__metric_buffer.update({key0: list()})
# --- Processing results
else:
for key0 in self.__metric_buffer.keys():
self.__metric_result[key0].append(self.__metric_buffer[key0])
self.__metric_buffer.update({key0: list()})
def __determine_snr_input(
self,
input_waveform: Tensor,
pred_waveform: Tensor,
mean_waveform: Tensor,
**kwargs,
) -> None:
out = calculate_snr_tensor(input_waveform, mean_waveform)
if isinstance(self.__metric_buffer[kwargs["metric"]], list):
self.__metric_buffer[kwargs["metric"]] = out
else:
self.__metric_buffer[kwargs["metric"]] = concatenate(
(self.__metric_buffer[kwargs["metric"]], out), dim=0
)
def __determine_snr_output(
self,
input_waveform: Tensor,
pred_waveform: Tensor,
mean_waveform: Tensor,
**kwargs,
) -> None:
out = calculate_snr_tensor(pred_waveform, mean_waveform)
if isinstance(self.__metric_buffer[kwargs["metric"]], list):
self.__metric_buffer[kwargs["metric"]] = out
else:
self.__metric_buffer[kwargs["metric"]] = concatenate(
(self.__metric_buffer[kwargs["metric"]], out), dim=0
)
def __determine_dsnr_all(
self,
input_waveform: Tensor,
pred_waveform: Tensor,
mean_waveform: Tensor,
**kwargs,
) -> None:
out = calculate_dsnr_tensor(input_waveform, pred_waveform, mean_waveform)
if isinstance(self.__metric_buffer[kwargs["metric"]], list):
self.__metric_buffer[kwargs["metric"]] = out
else:
self.__metric_buffer[kwargs["metric"]] = concatenate(
(self.__metric_buffer[kwargs["metric"]], out), dim=0
)
def __determine_ptq_loss(
self,
input_waveform: Tensor,
pred_waveform: Tensor,
mean_waveform: Tensor,
**kwargs,
) -> None:
"""if not hasattr(self.model, 'bit_config'):
raise NotImplementedError('PTQ Test is only available with elasticAI.creator Models or '
'model includes variable \"bit_config\" = [total_bitwidth, frac_bitwidth]')
else:"""
# --- Load model and make inference
model_ptq = quantize_model_fxp(self._model, self._ptq_level[0], self._ptq_level[1])
model_ptq.eval()
pred_waveform_ptq = model_ptq(input_waveform)[1]
# --- Calculate loss
if len(input_waveform) > 2:
loss = (
self._loss_fn(flatten(pred_waveform_ptq, 1), flatten(input_waveform, 1)).item()
/ self._total_batches_valid
)
else:
loss = self._loss_fn(pred_waveform_ptq, input_waveform).item() / self._total_batches_valid
# --- Saving results
if len(self.__metric_buffer[kwargs["metric"]]):
self.__metric_buffer[kwargs["metric"]][0] = self.__metric_buffer[kwargs["metric"]][0] + loss
else:
self.__metric_buffer[kwargs["metric"]].append(loss)
[docs]
def do_training(self, path2save=Path(".")) -> dict:
"""Start model training incl. validation and custom-own metric calculation
Args:
path2save: Path for saving the results [Default: '' --> generate new folder]
Returns:
Dictionary with metrics from training (loss_train, loss_valid, own_metrics)
"""
metrics = self._settings_train.custom_metrics
self._init_train(path2save=path2save, addon="_ae")
if self._kfold_do:
self._logger.info(
f"Starting Kfold cross validation training in {self._settings_train.num_kfold} steps"
)
metric_out = dict()
path2model = str()
path2model_init = self._path2save / "model_ae_reset.pt"
save(self._model.state_dict(), path2model_init)
timestamp_start = datetime.now()
timestamp_string = timestamp_start.strftime("%H:%M:%S")
self._logger.info(f"\nTraining starts on {timestamp_string}")
self._logger.info(
"====================================================================================="
)
self.__process_epoch_metrics_calculation(True, metrics)
for fold in np.arange(self._settings_train.num_kfold):
# --- Init fold
best_loss = np.array((1_000_000.0, 1_000_000.0), dtype=float)
patience_counter = self._settings_train.patience
epoch_loss_train = list()
epoch_loss_valid = list()
self._model.load_state_dict(load(path2model_init, weights_only=False))
self._run_kfold = fold
if self._kfold_do:
self._logger.info(f"\nStarting with Fold #{fold}")
for epoch in range(0, self._settings_train.num_epochs):
if self._settings_train.deterministic_do:
self._deterministic_generator.manual_seed(
self._settings_train.deterministic_seed + epoch
)
loss_train = self.__do_training_epoch()
loss_valid = self.__do_valid_epoch(metrics)
epoch_loss_train.append(loss_train)
epoch_loss_valid.append(loss_valid)
self.__process_epoch_metrics_calculation(False, metrics)
self._logger.info(
f"... results of epoch {epoch + 1}/{self._settings_train.num_epochs} "
f"[{(epoch + 1) / self._settings_train.num_epochs * 100:.2f} %]: "
f"train_loss = {loss_train:.5f},"
f"\tvalid_loss = {loss_valid:.5f},"
f"\tdelta_loss = {loss_train - loss_valid:.6f}"
)
# Tracking the best performance and saving the model
if loss_valid < best_loss[1]:
best_loss = [loss_train, loss_valid]
path2model = self._path2temp / f"model_ae_fold{fold:03d}_epoch{epoch:04d}.pt"
save(self._model, path2model)
patience_counter = self._settings_train.patience
else:
patience_counter -= 1
# Early Stopping
if patience_counter <= 0:
self._logger.info(f"... training stopped due to no change after {epoch + 1} epochs!")
break
copy(path2model, self._path2save)
self._save_train_results(best_loss[0], best_loss[1], "Loss")
# --- Saving results
metric_fold = {
"loss_train": epoch_loss_train,
"loss_valid": epoch_loss_valid,
}
metric_fold.update(self.__metric_result)
metric_out.update({f"fold_{fold:03d}": metric_fold})
# --- Ending of all trainings phases
self._end_training_routine(timestamp_start)
return self._converting_tensor_to_numpy(metric_out)
[docs]
def do_post_training_validation(self, do_ptq: bool = False) -> DataValidation:
"""Performing the post-training validation with the best model
:param do_ptq: Boolean for activating post training quantization during post-training validation
:return: Dataclass with results from validation phase
"""
if cuda.is_available():
cuda.empty_cache()
# --- Do the Inference with Best Model
overview_models = self.get_best_model("ae")
if len(overview_models) == 0:
raise RuntimeError(f"No models found on {str(self._path2save)} - Please start training!")
path2model = overview_models[0]
if do_ptq:
model_test = quantize_model_fxp(
model=load(path2model, weights_only=False),
total_bits=self._ptq_level[0],
frac_bits=self._ptq_level[1],
)
else:
model_test = load(path2model, weights_only=False)
self._logger.info("=================================================================")
self._logger.info(f"Do Validation with best model: {path2model}")
pred_model = randn(32, 1)
feat_model = randn(32, 1)
clus_orig_list = randn(32, 1)
data_orig_list = randn(32, 1)
first_cycle = True
model_test.eval()
for vdata in self._valid_loader[-1]:
feat, pred = model_test(vdata["in"].to(self._used_hw_dev))
if first_cycle:
feat_model = feat.detach().cpu()
pred_model = pred.detach().cpu()
clus_orig_list = vdata["class"]
data_orig_list = vdata["in"]
else:
feat_model = cat((feat_model, feat.detach().cpu()), dim=0)
pred_model = cat((pred_model, pred.detach().cpu()), dim=0)
clus_orig_list = cat((clus_orig_list, vdata["class"]), dim=0)
data_orig_list = cat((data_orig_list, vdata["in"]), dim=0)
first_cycle = False
# --- Preparing output
data_out = self._getting_data_for_plotting(
valid_input=data_orig_list.numpy(),
valid_label=clus_orig_list.numpy(),
addon="ae",
)
data_out.output = pred_model.numpy()
data_out.feat = feat_model.numpy()
data_out.mean = self._mean_data
return data_out