import numpy as np
from logging import getLogger, Logger
from os.path import join
from shutil import copy
from datetime import datetime
from torch import Tensor, load, save, inference_mode, flatten, cuda, cat, concatenate, randn
from denspp.offline import check_keylist_elements_any
from denspp.offline.dnn.ptq_help import quantize_model_fxp
from denspp.offline.dnn.pytorch_handler import ConfigPytorch, SettingsDataset, PyTorchHandler
from denspp.offline.metric.snr import calculate_snr_tensor_waveform, calculate_dsnr_tensor_waveform
class TrainAutoencoder(PyTorchHandler):
    """Training / validation handler for autoencoder models on top of ``PyTorchHandler``.

    The wrapped model is called as ``model(x)`` and returns a pair ``(features, prediction)``;
    the prediction (index 1) is compared with the batch target ``'out'`` by ``self.loss_fn``.
    """
    _logger: Logger

    def __init__(self, config_train: ConfigPytorch, config_data: SettingsDataset, do_train: bool = True) -> None:
        """Class for Handling Training of Autoencoders
        :param config_train: Settings for handling the PyTorch Trainings Routine
        :param config_data: Settings for handling and loading the dataset (just for saving)
        :param do_train: Do training of model otherwise only inference
        :return: None
        """
        PyTorchHandler.__init__(self, config_train, config_data, do_train)
        self._logger = getLogger(__name__)
        # Per-epoch accumulation buffer and per-fold history of the custom metrics
        self.__metric_buffer = dict()
        self.__metric_result = dict()
        # Dispatch table: metric name (from the settings) -> bound calculation method
        self._metric_methods = {'snr_in': self.__determine_snr_input, 'snr_in_cl': self.__determine_snr_input_class,
                                'snr_out': self.__determine_snr_output, 'snr_out_cl': self.__determine_snr_output_class,
                                'dsnr_all': self.__determine_dsnr_all, 'dsnr_cl': self.__determine_dsnr_class,
                                'ptq_loss': self.__determine_ptq_loss}

    # ------------------------------------------------------------------ helpers
    def __calculate_loss(self, pred: Tensor, target: Tensor) -> Tensor:
        """Apply ``self.loss_fn`` on prediction and target.

        Multi-dimensional tensors are flattened to ``(batch, -1)`` first. The decision is based on
        the tensor rank (not on the batch size), so a small last batch is treated like every other one.
        Args:
            pred:   Model prediction
            target: Reference values
        Returns:
            Loss tensor returned by ``self.loss_fn``
        """
        if pred.dim() > 1 and target.dim() > 1:
            return self.loss_fn(flatten(pred, 1), flatten(target, 1))
        return self.loss_fn(pred, target)

    def __append_metric(self, key: str, values: Tensor) -> None:
        """Concatenate the per-batch metric ``values`` (along dim 0) to the epoch buffer of ``key``."""
        buffer = self.__metric_buffer[key]
        if isinstance(buffer, list):
            # First batch of the epoch: buffer is still the empty list from the reset
            self.__metric_buffer[key] = values
        else:
            self.__metric_buffer[key] = concatenate((buffer, values), dim=0)

    def __append_metric_classwise(self, key: str, values: Tensor, labels: Tensor) -> None:
        """Split ``values`` by class label and concatenate each class part to the epoch buffer of ``key``.

        NOTE(review): entries are merged by position in the list returned from
        ``_separate_classes_from_label`` - assumes it always yields the classes in the same order.
        """
        per_class = self._separate_classes_from_label(pred=values, true=labels, label=key)[0]
        buffer = self.__metric_buffer[key]
        if len(buffer) == 0:
            self.__metric_buffer[key] = list(per_class)
        else:
            for idx, class_values in enumerate(per_class):
                buffer[idx] = concatenate((buffer[idx], class_values), dim=0)

    # ------------------------------------------------------------------ epochs
    def __do_training_epoch(self) -> float:
        """Do training during epoch of training
        Return:
            Floating value with the mean training loss over all batches
        Raises:
            ValueError: if the training loader of the current fold yields no batch
        """
        train_loss = 0.0
        total_batches = 0
        self.model.train(True)
        for tdata in self.train_loader[self._run_kfold]:
            data_x = tdata['in'].to(self.used_hw_dev)
            data_y = tdata['out'].to(self.used_hw_dev)
            data_p = self.model(data_x)[1]

            self.optimizer.zero_grad()
            loss = self.__calculate_loss(data_p, data_y)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            total_batches += 1
        if total_batches == 0:
            raise ValueError(f"Training loader of fold {self._run_kfold} is empty")
        return float(train_loss / total_batches)

    def __do_valid_epoch(self, epoch_custom_metrics: list) -> float:
        """Do validation during epoch of training
        Args:
            epoch_custom_metrics: List with entries of custom-made metric calculations
        Return:
            Floating value with the mean validation loss over all batches
        Raises:
            ValueError: if the validation loader of the current fold yields no batch
        """
        self.total_batches_valid = 0
        valid_loss = 0.0
        self.model.eval()
        with inference_mode():
            for vdata in self.valid_loader[self._run_kfold]:
                data_x = vdata['in'].to(self.used_hw_dev)
                data_y = vdata['out'].to(self.used_hw_dev)
                data_m = vdata['mean'].to(self.used_hw_dev)
                data_id = vdata['class'].to(self.used_hw_dev)
                data_p = self.model(data_x)[1]

                # Incremented before the custom metrics run: __determine_ptq_loss uses it as the
                # 1-based index of the current batch
                self.total_batches_valid += 1
                valid_loss += self.__calculate_loss(data_p, data_y).item()

                # --- Calculating custom made metrics
                for metric_used in epoch_custom_metrics:
                    self._determine_epoch_metrics(metric_used)(data_x, data_p, data_m, metric=metric_used, id=data_id)
        if self.total_batches_valid == 0:
            raise ValueError(f"Validation loader of fold {self._run_kfold} is empty")
        return float(valid_loss / self.total_batches_valid)

    def __process_epoch_metrics_calculation(self, init_phase: bool, custom_made_metrics: list) -> None:
        """Function for preparing the custom-made metric calculation
        Args:
            init_phase:             True: create fresh, empty containers (call at the start of each fold);
                                    False: move the epoch buffer into the fold history and reset the buffer
            custom_made_metrics:    List with custom metrics for calculation during validation phase
        Return:
            None
        """
        assert check_keylist_elements_any(
            keylist=custom_made_metrics,
            elements=self.get_epoch_metric_custom_methods()
        ), f"Used custom made metrics not found in: {self.get_epoch_metric_custom_methods()} - Please adapt in settings!"

        if init_phase:
            # New dict/list objects, so results already handed out for a previous fold stay untouched
            self.__metric_result = {key: list() for key in custom_made_metrics}
            self.__metric_buffer = {key: list() for key in custom_made_metrics}
        else:
            for key in self.__metric_buffer:
                self.__metric_result[key].append(self.__metric_buffer[key])
                self.__metric_buffer[key] = list()

    # ------------------------------------------------------------------ metrics
    def __determine_snr_input(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
        """Accumulate the SNR of the input waveforms w.r.t. the mean waveforms."""
        self.__append_metric(kwargs['metric'], calculate_snr_tensor_waveform(input_waveform, mean_waveform))

    def __determine_snr_input_class(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
        """Accumulate the input SNR separated by class label (``kwargs['id']``)."""
        self.__append_metric_classwise(
            kwargs['metric'], calculate_snr_tensor_waveform(input_waveform, mean_waveform), kwargs['id']
        )

    def __determine_snr_output(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
        """Accumulate the SNR of the predicted waveforms w.r.t. the mean waveforms."""
        self.__append_metric(kwargs['metric'], calculate_snr_tensor_waveform(pred_waveform, mean_waveform))

    def __determine_snr_output_class(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
        """Accumulate the output SNR separated by class label (``kwargs['id']``)."""
        self.__append_metric_classwise(
            kwargs['metric'], calculate_snr_tensor_waveform(pred_waveform, mean_waveform), kwargs['id']
        )

    def __determine_dsnr_all(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
        """Accumulate the SNR difference between prediction and input over all samples."""
        self.__append_metric(
            kwargs['metric'], calculate_dsnr_tensor_waveform(input_waveform, pred_waveform, mean_waveform)
        )

    def __determine_dsnr_class(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
        """Accumulate the SNR difference between prediction and input separated by class label."""
        self.__append_metric_classwise(
            kwargs['metric'], calculate_dsnr_tensor_waveform(input_waveform, pred_waveform, mean_waveform), kwargs['id']
        )

    def __determine_ptq_loss(self, input_waveform: Tensor, pred_waveform: Tensor, mean_waveform: Tensor, **kwargs) -> None:
        """Accumulate the mean loss of a fixed-point quantized copy of the model over the epoch's batches.

        Quantization uses ``self._ptq_level`` as ``(total_bits, frac_bits)``. The buffer holds a single
        value: the running mean of the per-batch losses, using ``self.total_batches_valid`` as the
        1-based index of the current batch.
        NOTE(review): the loss is taken against ``input_waveform``, whereas the validation loss uses
        the batch target ``'out'`` - confirm this is intended.
        NOTE(review): the model is re-quantized for every batch; cache per epoch if this is slow.
        """
        model_ptq = quantize_model_fxp(self.model, self._ptq_level[0], self._ptq_level[1])
        model_ptq.eval()
        pred_waveform_ptq = model_ptq(input_waveform)[1]
        loss = self.__calculate_loss(pred_waveform_ptq, input_waveform).item()

        buffer = self.__metric_buffer[kwargs['metric']]
        if len(buffer):
            # Incremental mean: m_k = m_{k-1} + (l_k - m_{k-1}) / k
            buffer[0] += (loss - buffer[0]) / max(self.total_batches_valid, 1)
        else:
            buffer.append(loss)

    # ------------------------------------------------------------------ public API
    def do_training(self, path2save: str = '', metrics=()) -> dict:
        """Start model training incl. validation and custom-own metric calculation
        Args:
            path2save:  Path for saving the results [Default: '' --> generate new folder]
            metrics:    List with strings of used metric [Default: empty]
        Returns:
            Dictionary with one entry per fold ("fold_XXX"), each holding loss_train, loss_valid
            and the custom metrics of that fold
        """
        self._init_train(path2save=path2save, addon='_AE')
        if self._kfold_do:
            self._logger.info(f"Starting Kfold cross validation training in {self.settings_train.num_kfold} steps")

        metric_out = dict()
        # Initial weights, restored at the start of every fold
        path2model_init = join(self._path2save, 'model_ae_reset.pt')
        save(self.model.state_dict(), path2model_init)
        timestamp_start = datetime.now()
        timestamp_string = timestamp_start.strftime('%H:%M:%S')
        self._logger.info(f'\nTraining starts on {timestamp_string}')
        self._logger.info("=====================================================================================")

        for fold in np.arange(self.settings_train.num_kfold):
            # --- Init fold (fresh metric containers and checkpoint tracking for every fold)
            self.__process_epoch_metrics_calculation(True, metrics)
            best_loss = [float('inf'), float('inf')]
            path2model = ''
            patience_counter = self.settings_train.patience
            epoch_loss_train = list()
            epoch_loss_valid = list()
            self.model.load_state_dict(load(path2model_init, weights_only=False))
            self._run_kfold = fold
            if self._kfold_do:
                self._logger.info(f'\nStarting with Fold #{fold}')

            for epoch in range(0, self.settings_train.num_epochs):
                if self.settings_train.deterministic_do:
                    self.deterministic_generator.manual_seed(self.settings_train.deterministic_seed + epoch)
                loss_train = self.__do_training_epoch()
                loss_valid = self.__do_valid_epoch(metrics)
                epoch_loss_train.append(loss_train)
                epoch_loss_valid.append(loss_valid)
                self.__process_epoch_metrics_calculation(False, metrics)

                self._logger.info(f'... results of epoch {epoch + 1}/{self.settings_train.num_epochs} '
                                  f'[{(epoch + 1) / self.settings_train.num_epochs * 100:.2f} %]: '
                                  f'train_loss = {loss_train:.5f},'
                                  f'\tvalid_loss = {loss_valid:.5f},'
                                  f'\tdelta_loss = {loss_train - loss_valid:.6f}')

                # Tracking the best performance and saving the model
                if loss_valid < best_loss[1]:
                    best_loss = [loss_train, loss_valid]
                    path2model = join(self._path2temp, f'model_ae_fold{fold:03d}_epoch{epoch:04d}.pt')
                    save(self.model, path2model)
                    patience_counter = self.settings_train.patience
                else:
                    patience_counter -= 1

                # Early Stopping
                if patience_counter <= 0:
                    self._logger.info(f"... training stopped due to no change after {epoch + 1} epochs!")
                    break

            if path2model:
                copy(path2model, self._path2save)
            else:
                # Only happens if the validation loss never improved (e.g. NaN losses or no epochs)
                self._logger.warning(f"... no checkpoint was saved for fold {fold}")
            self._save_train_results(best_loss[0], best_loss[1], 'Loss')

            # --- Saving results
            metric_fold = {'loss_train': epoch_loss_train, 'loss_valid': epoch_loss_valid}
            metric_fold.update(self.__metric_result)
            metric_out[f"fold_{fold:03d}"] = metric_fold

        # --- Ending of all trainings phases
        self._end_training_routine(timestamp_start)
        metric_save = self._converting_tensor_to_numpy(metric_out)
        np.save(join(self._path2save, 'metric_ae'), metric_save, allow_pickle=True)
        return metric_out

    def do_validation_after_training(self, do_ptq_valid: bool = False) -> dict:
        """Performing the validation with the best model after training for plotting and saving results
        Args:
            do_ptq_valid:   If True, validate a fixed-point quantized copy of the best model
                            (bit widths from ``self._ptq_level``)
        Returns:
            Output of ``_getting_data_for_plotting`` built from the last validation loader
        Raises:
            ValueError: if the validation loader yields no batch
        """
        if cuda.is_available():
            cuda.empty_cache()

        # --- Do the Inference with Best Model
        path2model = self.get_best_model('ae')[0]
        model_test = load(path2model, weights_only=False)
        if do_ptq_valid:
            model_test = quantize_model_fxp(
                model=model_test,
                total_bits=self._ptq_level[0],
                frac_bits=self._ptq_level[1]
            )

        self._logger.info("=================================================================")
        self._logger.info(f"Do Validation with best model: {path2model}")

        feat_parts, pred_parts, clus_parts, data_parts = list(), list(), list(), list()
        model_test.eval()
        with inference_mode():
            for vdata in self.valid_loader[-1]:
                feat, pred = model_test(vdata['in'].to(self.used_hw_dev))
                feat_parts.append(feat.detach().cpu())
                pred_parts.append(pred.detach().cpu())
                clus_parts.append(vdata['class'])
                data_parts.append(vdata['in'])

            if not feat_parts:
                raise ValueError("Validation loader is empty - no data for validation after training")
            # One concatenation at the end instead of growing the tensors batch by batch
            feat_model = cat(feat_parts, dim=0)
            pred_model = cat(pred_parts, dim=0)
            clus_orig = cat(clus_parts, dim=0)
            data_orig = cat(data_parts, dim=0)

        # --- Preparing output
        return self._getting_data_for_plotting(data_orig.numpy(), clus_orig.numpy(),
                                               {'feat': feat_model.numpy(), 'pred': pred_model.numpy()}, addon='ae')