import numpy as np
from dataclasses import dataclass
from logging import getLogger, Logger
from pathlib import Path
from shutil import copy
from datetime import datetime
from torch import Tensor, load, save, inference_mode, flatten, cuda, cat, concatenate, randn
from denspp.offline import check_keylist_elements_any
from denspp.offline.dnn.data_config import DatasetFromFile, SettingsDataset
from denspp.offline.dnn.training.ptq_help import quantize_model_fxp
from denspp.offline.dnn.training.autoencoder_dataset import DatasetAutoencoder
from denspp.offline.metric.snr import calculate_snr_tensor, calculate_dsnr_tensor
from .common_train import PyTorchHandler, SettingsPytorch, DataValidation
@dataclass
class SettingsAutoencoder(SettingsPytorch):
"""Class for handling the PyTorch training/inference pipeline
Attributes:
model_name: String with the model name
patience: Integer value with number of epochs before early stopping
optimizer: String with PyTorch optimizer name
loss: String with method name for the loss function
deterministic_do: Boolean if deterministic training should be done
deterministic_seed: Integer with the seed for deterministic training
num_kfold: Integer value with applying k-fold cross validation
num_epochs: Integer value with number of epochs
batch_size: Integer value with batch size
data_split_ratio: Float value for splitting the input dataset between training and validation
data_do_shuffle: Boolean if data should be shuffled before training
custom_metrics: List with string of custom metrics to calculate during training
trainings_mode: Integer to define trainings mode of the autoencoder
[0: Autoencoder,
1: Denoising Autoencoder (mean),
2: Denoising Autoencoder (add random noise),
3: Denoising Autoencoder (add gaussian noise)]
feat_size: Integer with defining the feature size of the encoder output / decoder input
noise_std: Float value for adding noise standard deviation on input data
"""
trainings_mode: int
feat_size: int
noise_std: float
DefaultSettingsTrainingMSE = SettingsAutoencoder(
model_name='',
patience=20,
optimizer='Adam',
loss='MSE',
deterministic_do=False,
deterministic_seed=42,
num_kfold=1,
num_epochs=10,
batch_size=256,
data_do_shuffle=True,
data_split_ratio=0.2,
custom_metrics=[],
trainings_mode=0,
feat_size=4,
noise_std=0.1
)
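# Usage sketch (hedged): per-experiment settings are typically derived from this
# default via dataclasses.replace; the model name and epoch count below are
# illustrative assumptions, not values from this module.
#   from dataclasses import replace
#   my_settings = replace(DefaultSettingsTrainingMSE, model_name='my_autoencoder', num_epochs=100)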
class TrainAutoencoder(PyTorchHandler):
_logger: Logger
_settings_train: SettingsAutoencoder
_train_loader: list
_valid_loader: list
_mean_data: np.ndarray
def __init__(self, config_train: SettingsAutoencoder, config_data: SettingsDataset, do_train: bool=True) -> None:
"""Class for Handling Training of Autoencoders
:param config_data: Settings for handling and loading the dataset (just for saving)
:param config_train: Settings for handling the PyTorch Trainings Routine of an Autoencoder
:param do_train: Do training of model otherwise only inference
:return: None
"""
PyTorchHandler.__init__(self, config_train, config_data, do_train)
self._settings_train = config_train
self._logger = getLogger(__name__)
self.__metric_buffer = dict()
self.__metric_result = dict()
self._metric_methods = {
'snr_in': self.__determine_snr_input,
'snr_out': self.__determine_snr_output,
'dsnr_all': self.__determine_dsnr_all,
'ptq_loss': self.__determine_ptq_loss
}
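# Note: the entries of ``custom_metrics`` in the settings are expected to match the
# keys of ``_metric_methods`` above ('snr_in', 'snr_out', 'dsnr_all', 'ptq_loss');
# unknown names trip the assertion in ``__process_epoch_metrics_calculation``.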
def load_dataset(self, dataset: DatasetFromFile) -> None:
"""Loading the loaded dataset and transform it into right dataloader
:param dataset: Dataclass with dataset loaded from extern
:return: None
"""
dataset0 = DatasetAutoencoder(
dataset=dataset,
noise_std=self._settings_train.noise_std,
mode_train=self._settings_train.trainings_mode
)
self._mean_data = dataset0.get_mean_waveforms
self._prepare_dataset_for_training(
data_set=dataset0,
num_workers=0
)
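# Note: the prepared dataloaders yield batches as dictionaries; the training and
# validation loops below access the keys 'in', 'out', 'mean' and 'class'.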
def __do_training_epoch(self) -> float:
"""Do training during epoch of training
Return:
Floating value with training loss value
"""
train_loss = 0.0
total_batches = 0
self._model.train(True)
for tdata in self._train_loader[self._run_kfold]:
data_x = tdata['in'].to(self._used_hw_dev)
data_y = tdata['out'].to(self._used_hw_dev)
data_p = self._model(data_x)[1]
self._optimizer.zero_grad()
# Flatten multi-dimensional tensors to (batch, features) before computing the loss
if data_y.dim() > 2:
loss = self._loss_fn(flatten(data_p, 1), flatten(data_y, 1))
else:
loss = self._loss_fn(data_p, data_y)
loss.backward()
self._optimizer.step()
train_loss += loss.item()
total_batches += 1
return float(train_loss / total_batches)
def __do_valid_epoch(self, epoch_custom_metrics: list) -> float:
"""Do validation during epoch of training
Args:
epoch_custom_metrics: List with entries of custom-made metric calculations
Return:
Floating value with validation loss value
"""
self._total_batches_valid = 0
valid_loss = 0.0
self._model.eval()
with inference_mode():
for vdata in self._valid_loader[self._run_kfold]:
data_x = vdata['in'].to(self._used_hw_dev)
data_y = vdata['out'].to(self._used_hw_dev)
data_m = vdata['mean'].to(self._used_hw_dev)
data_id = vdata['class'].to(self._used_hw_dev)
data_p = self._model(data_x)[1]
self._total_batches_valid += 1
if data_y.dim() > 2:
valid_loss += self._loss_fn(flatten(data_p, 1), flatten(data_y, 1)).item()
else:
valid_loss += self._loss_fn(data_p, data_y).item()
# --- Calculating custom made metrics
for metric_used in epoch_custom_metrics:
self._determine_epoch_metrics(metric_used)(data_x, data_p, data_m, metric=metric_used, id=data_id)
return float(valid_loss / self._total_batches_valid)
def __process_epoch_metrics_calculation(self, init_phase: bool, custom_made_metrics: list) -> None:
"""Function for preparing the custom-made metric calculation
Args:
init_phase: Boolean decision if processing part is in init (True) or in training phase (False)
custom_made_metrics:List with custom metrics for calculation during validation phase
Return:
None
"""
assert check_keylist_elements_any(
keylist=custom_made_metrics,
elements=self.get_epoch_metric_custom_methods
), f"Used custom made metrics not found in: {self.get_epoch_metric_custom_methods} - Please adapt in settings!"
# --- Init phase for generating empty data structure
if init_phase:
for key0 in custom_made_metrics:
self.__metric_result.update({key0: list()})
self.__metric_buffer.update({key0: list()})
# --- Processing results
else:
for key0 in self.__metric_buffer.keys():
self.__metric_result[key0].append(self.__metric_buffer[key0])
self.__metric_buffer.update({key0: list()})
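# Buffer layout sketch: during an epoch each metric accumulates per-batch values in
# ``__metric_buffer[key]``; at the end of the epoch that buffer is appended to
# ``__metric_result[key]`` and reset, so the result holds one entry per epoch.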
def __determine_snr_input(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
"""Calculate the SNR of the input waveforms against the mean waveforms and accumulate it in the metric buffer"""
out = calculate_snr_tensor(input_waveform, mean_waveform)
if isinstance(self.__metric_buffer[kwargs['metric']], list):
self.__metric_buffer[kwargs['metric']] = out
else:
self.__metric_buffer[kwargs['metric']] = concatenate((self.__metric_buffer[kwargs['metric']], out), dim=0)
def __determine_snr_output(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
"""Calculate the SNR of the predicted waveforms against the mean waveforms and accumulate it in the metric buffer"""
out = calculate_snr_tensor(pred_waveform, mean_waveform)
if isinstance(self.__metric_buffer[kwargs['metric']], list):
self.__metric_buffer[kwargs['metric']] = out
else:
self.__metric_buffer[kwargs['metric']] = concatenate((self.__metric_buffer[kwargs['metric']], out), dim=0)
def __determine_dsnr_all(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
"""Calculate the delta SNR between input and predicted waveforms against the mean waveforms and accumulate it in the metric buffer"""
out = calculate_dsnr_tensor(input_waveform, pred_waveform, mean_waveform)
if isinstance(self.__metric_buffer[kwargs['metric']], list):
self.__metric_buffer[kwargs['metric']] = out
else:
self.__metric_buffer[kwargs['metric']] = concatenate((self.__metric_buffer[kwargs['metric']], out), dim=0)
def __determine_ptq_loss(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
"""Quantize the model via fixed-point post-training quantization and accumulate the resulting reconstruction loss in the metric buffer
Note: intended for elasticAI.creator models or models providing the variable "bit_config" = [total_bitwidth, frac_bitwidth]
"""
# --- Load model and make inference
model_ptq = quantize_model_fxp(self._model, self._ptq_level[0], self._ptq_level[1])
model_ptq.eval()
pred_waveform_ptq = model_ptq(input_waveform)[1]
# --- Calculate loss (normalized by the number of validation batches so the buffered sum is a batch average)
if input_waveform.dim() > 2:
loss = self._loss_fn(flatten(pred_waveform_ptq, 1), flatten(input_waveform, 1)).item() / self._total_batches_valid
else:
loss = self._loss_fn(pred_waveform_ptq, input_waveform).item() / self._total_batches_valid
# --- Saving results
if len(self.__metric_buffer[kwargs['metric']]):
self.__metric_buffer[kwargs['metric']][0] = self.__metric_buffer[kwargs['metric']][0] + loss
else:
self.__metric_buffer[kwargs['metric']].append(loss)
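# For orientation, a minimal sketch of the signed fixed-point rounding that
# ``quantize_model_fxp`` is assumed to apply per weight (an assumption about the
# helper, not its verified implementation), with t = total_bits and f = frac_bits:
#   scale = 2.0 ** f
#   w_q = (w * scale).round().clamp(-2 ** (t - 1), 2 ** (t - 1) - 1) / scale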
def do_training(self, path2save=Path(".")) -> dict:
"""Start model training incl. validation and custom-own metric calculation
Args:
path2save: Path for saving the results [Default: '' --> generate new folder]
Returns:
Dictionary with metrics from training (loss_train, loss_valid, own_metrics)
"""
metrics = self._settings_train.custom_metrics
self._init_train(path2save=path2save, addon='_ae')
if self._kfold_do:
self._logger.info(f"Starting Kfold cross validation training in {self._settings_train.num_kfold} steps")
metric_out = dict()
path2model = str()
path2model_init = self._path2save / 'model_ae_reset.pt'
save(self._model.state_dict(), path2model_init)
timestamp_start = datetime.now()
timestamp_string = timestamp_start.strftime('%H:%M:%S')
self._logger.info(f'\nTraining starts at {timestamp_string}')
self._logger.info("=====================================================================================")
self.__process_epoch_metrics_calculation(True, metrics)
for fold in np.arange(self._settings_train.num_kfold):
# --- Init fold
best_loss = np.array((1_000_000., 1_000_000.), dtype=float)
patience_counter = self._settings_train.patience
epoch_loss_train = list()
epoch_loss_valid = list()
self._model.load_state_dict(load(path2model_init, weights_only=False))
self._run_kfold = fold
if self._kfold_do:
self._logger.info(f'\nStarting with Fold #{fold}')
for epoch in range(0, self._settings_train.num_epochs):
if self._settings_train.deterministic_do:
self._deterministic_generator.manual_seed(self._settings_train.deterministic_seed + epoch)
loss_train = self.__do_training_epoch()
loss_valid = self.__do_valid_epoch(metrics)
epoch_loss_train.append(loss_train)
epoch_loss_valid.append(loss_valid)
self.__process_epoch_metrics_calculation(False, metrics)
self._logger.info(f'... results of epoch {epoch + 1}/{self._settings_train.num_epochs} '
f'[{(epoch + 1) / self._settings_train.num_epochs * 100:.2f} %]: '
f'train_loss = {loss_train:.5f},'
f'\tvalid_loss = {loss_valid:.5f},'
f'\tdelta_loss = {loss_train - loss_valid:.6f}')
# Tracking the best performance and saving the model
if loss_valid < best_loss[1]:
best_loss = [loss_train, loss_valid]
path2model = self._path2temp / f'model_ae_fold{fold:03d}_epoch{epoch:04d}.pt'
save(self._model, path2model)
patience_counter = self._settings_train.patience
else:
patience_counter -= 1
# Early Stopping
if patience_counter <= 0:
self._logger.info(f"... training stopped due to no change after {epoch + 1} epochs!")
break
copy(path2model, self._path2save)
self._save_train_results(best_loss[0], best_loss[1], 'Loss')
# --- Saving results
metric_fold = {
'loss_train': epoch_loss_train,
'loss_valid': epoch_loss_valid
}
metric_fold.update(self.__metric_result)
metric_out.update({f"fold_{fold:03d}": metric_fold})
# --- Ending of all trainings phases
self._end_training_routine(timestamp_start)
return self._converting_tensor_to_numpy(metric_out)
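# Usage sketch (hedged; the dataset objects are illustrative placeholders, not part
# of this module):
#   trainer = TrainAutoencoder(DefaultSettingsTrainingMSE, my_dataset_settings)
#   trainer.load_dataset(my_dataset)  # DatasetFromFile instance
#   metrics = trainer.do_training(path2save=Path('runs'))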
def do_post_training_validation(self, do_ptq: bool=False) -> DataValidation:
"""Performing the post-training validation with the best model
:param do_ptq: Boolean for activating post training quantization during post-training validation
:return: Dataclass with results from validation phase
"""
if cuda.is_available():
cuda.empty_cache()
# --- Do the Inference with Best Model
overview_models = self.get_best_model('ae')
if len(overview_models) == 0:
raise RuntimeError(f"No models found on {str(self._path2save)} - Please start training!")
path2model = overview_models[0]
if do_ptq:
model_test = quantize_model_fxp(
model=load(path2model, weights_only=False),
total_bits=self._ptq_level[0],
frac_bits=self._ptq_level[1]
)
else:
model_test = load(path2model, weights_only=False)
self._logger.info("=================================================================")
self._logger.info(f"Do Validation with best model: {path2model}")
# Placeholder tensors; overwritten during the first cycle of the validation loop below
pred_model = randn(32, 1)
feat_model = randn(32, 1)
clus_orig_list = randn(32, 1)
data_orig_list = randn(32, 1)
first_cycle = True
model_test.eval()
for vdata in self._valid_loader[-1]:
feat, pred = model_test(vdata['in'].to(self._used_hw_dev))
if first_cycle:
feat_model = feat.detach().cpu()
pred_model = pred.detach().cpu()
clus_orig_list = vdata['class']
data_orig_list = vdata['in']
else:
feat_model = cat((feat_model, feat.detach().cpu()), dim=0)
pred_model = cat((pred_model, pred.detach().cpu()), dim=0)
clus_orig_list = cat((clus_orig_list, vdata['class']), dim=0)
data_orig_list = cat((data_orig_list, vdata['in']), dim=0)
first_cycle = False
# --- Preparing output
data_out = self._getting_data_for_plotting(
valid_input=data_orig_list.numpy(),
valid_label=clus_orig_list.numpy(),
addon='ae'
)
data_out.output = pred_model.numpy()
data_out.feat = feat_model.numpy()
data_out.mean = self._mean_data
return data_out
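# Usage sketch (hedged continuation of the example after ``do_training``):
#   result = trainer.do_post_training_validation(do_ptq=False)
#   # result.output, result.feat and result.mean then hold the reconstructed
#   # waveforms, latent features and mean waveforms for plotting.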