from dataclasses import dataclass
from datetime import datetime
from logging import Logger, getLogger
from pathlib import Path
from shutil import copy
import numpy as np
from torch import (
Tensor,
add,
cat,
concatenate,
cuda,
div,
inference_mode,
load,
randn,
save,
tensor,
zeros,
)
from denspp.offline import check_keylist_elements_any
from denspp.offline.dnn.data_config import DatasetFromFile, SettingsDataset
from denspp.offline.dnn.training.classifier_dataset import DatasetClassifier
from denspp.offline.dnn.training.ptq_help import quantize_model_fxp
from denspp.offline.metric.data_torch import (
calculate_fbeta,
calculate_number_true_predictions,
calculate_precision,
calculate_recall,
)
from .common_train import DataValidation, PyTorchHandler, SettingsPytorch
[docs]
@dataclass
class SettingsClassifier(SettingsPytorch):
"""Class for handling the PyTorch training/inference pipeline
Attributes:
model_name: String with the model name
patience: Integer value with number of epochs before early stopping
optimizer: String with PyTorch optimizer name
loss: String with method name for the loss function
deterministic_do: Boolean if deterministic training should be done
deterministic_seed: Integer with the seed for deterministic training
num_kfold: Integer value with applying k-fold cross validation
num_epochs: Integer value with number of epochs
batch_size: Integer value with batch size
data_split_ratio: Float value for splitting the input dataset between training and validation
data_do_shuffle: Boolean if data should be shuffled before training
custom_metrics: List with string of custom metrics to calculate during training
"""
pass
DefaultSettingsTrainingCE = SettingsClassifier(
model_name="",
patience=20,
optimizer="Adam",
loss="Cross Entropy",
num_kfold=1,
num_epochs=10,
batch_size=256,
data_do_shuffle=True,
data_split_ratio=0.2,
deterministic_do=False,
deterministic_seed=42,
custom_metrics=[],
)
[docs]
class TrainClassifier(PyTorchHandler):
_logger: Logger
_settings_train: SettingsClassifier
_train_loader: list
_valid_loader: list
def __init__(
self,
config_train: SettingsClassifier,
config_data: SettingsDataset,
do_train: bool = True,
) -> None:
"""Class for Handling Training of Classifiers
:param config_data: Settings for handling and loading the dataset (just for saving)
:param config_train: Settings for handling the PyTorch Trainings Routine of a Classifier
:param do_train: Do training of model otherwise only inference
:return: None
"""
PyTorchHandler.__init__(self, config_train, config_data, do_train)
self._logger = getLogger(__name__)
self.__metric_buffer = dict()
self.__metric_result = dict()
self._metric_methods = {
"accuracy": self.__determine_accuracy_per_class,
"precision": self.__determine_buffering_metric_calculation,
"recall": self.__determine_buffering_metric_calculation,
"fbeta": self.__determine_buffering_metric_calculation,
"ptq_acc": self.__determine_ptq_acc,
}
[docs]
def load_dataset(self, dataset: DatasetFromFile) -> None:
"""Loading the loaded dataset and transform it into right dataloader
:param dataset: Dataclass with dataset loaded from extern
:return: None
"""
dataset0 = DatasetClassifier(
dataset=dataset,
)
self._prepare_dataset_for_training(data_set=dataset0, num_workers=0)
def __do_training_epoch(self) -> tuple[float, float]:
"""Do training during epoch of training
Return:
Floating value of training loss and accuracy of used epoch
"""
train_loss = 0.0
total_batches = 0
total_correct = 0
total_samples = 0
self._model.train(True)
for tdata in self._train_loader[self._run_kfold]:
self._optimizer.zero_grad()
tdata_out = tdata["out"].to(self._used_hw_dev)
pred_cl, dec_cl = self._model(tdata["in"].to(self._used_hw_dev))
loss = self._loss_fn(pred_cl, tdata_out)
loss.backward()
self._optimizer.step()
train_loss += loss.item()
total_batches += 1
total_correct += calculate_number_true_predictions(dec_cl, tdata_out)
total_samples += len(tdata["in"])
train_acc = float(int(total_correct) / total_samples)
train_loss = float(train_loss / total_batches)
return train_loss, train_acc
def __do_valid_epoch(self, epoch_custom_metrics: list) -> tuple[float, float]:
"""Do validation during epoch of training
Args:
epoch_custom_metrics: List with entries of custom-made metric calculations
Return:
Floating value of validation loss and validation accuracy of used epoch
"""
valid_loss = 0.0
total_batches = 0
total_correct = 0
total_samples = 0
self._model.eval()
with inference_mode():
for vdata in self._valid_loader[self._run_kfold]:
# --- Validation phase of model
pred_cl, dec_cl = self._model(vdata["in"].to(self._used_hw_dev))
true_cl = vdata["out"].to(self._used_hw_dev)
valid_loss += self._loss_fn(pred_cl, true_cl).item()
total_batches += 1
total_correct += calculate_number_true_predictions(dec_cl, true_cl)
total_samples += len(vdata["in"])
# --- Calculating custom made metrics
for metric_used in epoch_custom_metrics:
self._determine_epoch_metrics(metric_used)(
dec_cl,
true_cl,
metric=metric_used,
frame=vdata["in"].to(self._used_hw_dev),
)
valid_acc = float(int(total_correct) / total_samples)
valid_loss = float(valid_loss / total_batches)
return valid_loss, valid_acc
def __process_epoch_metrics_calculation(self, init_phase: bool, custom_made_metrics: list) -> None:
"""Function for preparing the custom-made metric calculation
Args:
init_phase: Boolean decision if processing part is in init (True) or in post-training phase (False)
custom_made_metrics:List with custom metrics for calculation during validation phase
Return:
None
"""
assert check_keylist_elements_any(
keylist=custom_made_metrics, elements=self.get_epoch_metric_custom_methods
), (
f"Used custom made metrics not found in: {self.get_epoch_metric_custom_methods} - Please adapt in settings!"
)
# --- Init phase for generating empty data structure
if init_phase:
for key0 in custom_made_metrics:
self.__metric_result.update({key0: list()})
match key0:
case "accuracy":
self.__metric_buffer.update(
{
key0: [
zeros((len(self._cell_classes),)),
zeros((len(self._cell_classes),)),
]
}
)
case "precision":
self.__metric_buffer.update({key0: [[], []]})
case "recall":
self.__metric_buffer.update({key0: [[], []]})
case "fbeta":
self.__metric_buffer.update({key0: [[], []]})
case "ptq_acc":
self.__metric_buffer.update(
{
key0: [
zeros(
1,
),
zeros(
1,
),
]
}
)
# --- Processing results
else:
for key0 in self.__metric_buffer.keys():
match key0:
case "accuracy":
self.__metric_result[key0].append(
div(
self.__metric_buffer[key0][0],
self.__metric_buffer[key0][1],
)
)
self.__metric_buffer.update(
{
key0: [
zeros((len(self._cell_classes),)),
zeros((len(self._cell_classes),)),
]
}
)
case "precision":
out = self._separate_classes_from_label(
self.__metric_buffer[key0][0],
self.__metric_buffer[key0][1],
key0,
calculate_precision,
)
self.__metric_result[key0].append(out[0])
self.__metric_buffer.update({key0: [[], []]})
case "recall":
out = self._separate_classes_from_label(
self.__metric_buffer[key0][0],
self.__metric_buffer[key0][1],
key0,
calculate_recall,
)
self.__metric_result[key0].append(out[0])
self.__metric_buffer.update({key0: [[], []]})
case "fbeta":
out = self._separate_classes_from_label(
self.__metric_buffer[key0][0],
self.__metric_buffer[key0][1],
key0,
calculate_fbeta,
)
self.__metric_result[key0].append(out[0])
self.__metric_buffer.update({key0: [[], []]})
case "ptq_acc":
self.__metric_result[key0].append(
div(
self.__metric_buffer[key0][0],
self.__metric_buffer[key0][1],
)
)
self.__metric_buffer.update(
{
key0: [
zeros(
1,
),
zeros(
1,
),
]
}
)
def __determine_accuracy_per_class(self, pred: Tensor, true: Tensor, **kwargs) -> None:
out = self._separate_classes_from_label(
pred, true, kwargs["metric"], calculate_number_true_predictions
)
self.__metric_buffer[kwargs["metric"]][0] = add(self.__metric_buffer[kwargs["metric"]][0], out[0])
self.__metric_buffer[kwargs["metric"]][1] = add(self.__metric_buffer[kwargs["metric"]][1], out[1])
def __determine_buffering_metric_calculation(self, pred: Tensor, true: Tensor, **kwargs) -> None:
if len(self.__metric_buffer[kwargs["metric"]][0]) == 0:
self.__metric_buffer[kwargs["metric"]][0] = true
self.__metric_buffer[kwargs["metric"]][1] = pred
else:
self.__metric_buffer[kwargs["metric"]][0] = concatenate(
(self.__metric_buffer[kwargs["metric"]][0], true), dim=0
)
self.__metric_buffer[kwargs["metric"]][1] = concatenate(
(self.__metric_buffer[kwargs["metric"]][1], pred), dim=0
)
def __determine_ptq_acc(self, pred: Tensor, true: Tensor, **kwargs) -> None:
model_ptq = quantize_model_fxp(
model=self._model,
total_bits=self._ptq_level[0],
frac_bits=self._ptq_level[1],
)
model_ptq.eval()
pred_cl, dec_cl = model_ptq(kwargs["frame"])
num_true = calculate_number_true_predictions(dec_cl, true)
a = tensor([num_true])
b = tensor([kwargs["frame"].shape[0]])
if self.__metric_buffer[kwargs["metric"]][0].size == 1:
self.__metric_buffer[kwargs["metric"]][0] = a
self.__metric_buffer[kwargs["metric"]][1] = b
else:
self.__metric_buffer[kwargs["metric"]][0] = concatenate(
(self.__metric_buffer[kwargs["metric"]][0], a), dim=0
)
self.__metric_buffer[kwargs["metric"]][1] = concatenate(
(self.__metric_buffer[kwargs["metric"]][1], b), dim=0
)
[docs]
def do_training(self, path2save=Path(".")) -> dict:
"""Start model training incl. validation and custom-own metric calculation
Args:
path2save: Path for saving the results [Default: '' --> generate new folder]
Returns:
Dictionary with metrics from training (loss_train, loss_valid, own_metrics)
"""
metrics = self._settings_train.custom_metrics
self._init_train(path2save=path2save, addon="_cl")
if self._kfold_do:
self._logger.info(
f"Starting Kfold cross validation training in {self._settings_train.num_kfold} steps"
)
path2model = str()
path2model_init = self._path2save / "model_cl_reset.pt"
save(self._model.state_dict(), path2model_init)
timestamp_start = datetime.now()
timestamp_string = timestamp_start.strftime("%H:%M:%S")
self._logger.info(f"Training starts on {timestamp_string}")
self._logger.info(
"====================================================================================="
)
metric_out = dict()
self.__process_epoch_metrics_calculation(True, metrics)
for fold in np.arange(self._settings_train.num_kfold):
# --- Init fold
best_loss = [1e6, 1e6]
best_acc = [0.0, 0.0]
patience_counter = self._settings_train.patience
epoch_train_acc = list()
epoch_valid_acc = list()
epoch_train_loss = list()
epoch_valid_loss = list()
self._model.load_state_dict(load(path2model_init, weights_only=False))
self._run_kfold = fold
if self._kfold_do:
self._logger.info(f"Starting with Fold #{fold}")
for epoch in range(0, self._settings_train.num_epochs):
if self._settings_train.deterministic_do:
self._deterministic_generator.manual_seed(
self._settings_train.deterministic_seed + epoch
)
train_loss, train_acc = self.__do_training_epoch()
valid_loss, valid_acc = self.__do_valid_epoch(metrics)
self._logger.info(
f"... results of epoch {epoch + 1}/{self._settings_train.num_epochs} "
f"[{(epoch + 1) / self._settings_train.num_epochs * 100:.2f} %]: "
f"train_loss = {train_loss:.5f}, delta_loss = {train_loss - valid_loss:.5f}, "
f"train_acc = {100 * train_acc:.4f} %, delta_acc = {100 * (train_acc - valid_acc):.4f} %"
)
# Saving metrics after each epoch
epoch_train_acc.append(train_acc)
epoch_train_loss.append(train_loss)
epoch_valid_acc.append(valid_acc)
epoch_valid_loss.append(valid_loss)
self.__process_epoch_metrics_calculation(False, metrics)
# Tracking the best performance and saving the model
if valid_loss < best_loss[1]:
best_loss = [train_loss, valid_loss]
best_acc = [train_acc, valid_acc]
path2model = self._path2temp / f"model_cl_fold{fold:03d}_epoch{epoch:04d}.pt"
save(self._model, path2model)
patience_counter = self._settings_train.patience
else:
patience_counter -= 1
# Early Stopping
if patience_counter <= 0:
self._logger.info(f"... training stopped due to no change after {epoch + 1} epochs!")
break
copy(path2model, self._path2save)
self._save_train_results(best_loss[0], best_loss[1], "Loss")
self._save_train_results(best_acc[0], best_acc[1], "Acc.")
# --- Saving metrics after each fold
metric_fold = {
"acc_train": epoch_train_acc,
"acc_valid": epoch_valid_acc,
"loss_train": epoch_train_loss,
"loss_valid": epoch_valid_loss,
}
metric_fold.update(self.__metric_result)
metric_out.update({f"fold_{fold:03d}": metric_fold})
# --- Ending of all trainings phases
self._end_training_routine(timestamp_start)
return self._converting_tensor_to_numpy(metric_out)
[docs]
def do_post_training_validation(self, do_ptq: bool = False) -> DataValidation:
"""Performing the post-training validation with the best model
:param do_ptq: Boolean for activating post training quantization during post-training validation
:return: Dataclass with results from validation phase
"""
if cuda.is_available():
cuda.empty_cache()
# --- Do the Inference with Best Model
overview_models = self.get_best_model("cl")
if len(overview_models) == 0:
raise RuntimeError(f"No models found on {self._path2save} - Please start training!")
path2model = overview_models[0]
if do_ptq:
model_test = quantize_model_fxp(
model=load(path2model, weights_only=False),
total_bits=self._ptq_level[0],
frac_bits=self._ptq_level[1],
)
else:
model_test = load(path2model, weights_only=False)
self._logger.info("=================================================================")
self._logger.info(f"Do Validation with best model: {path2model}")
clus_pred_list = randn(32, 1)
clus_orig_list = randn(32, 1)
data_orig_list = randn(32, 1)
first_cycle = True
model_test.eval()
for vdata in self._valid_loader[-1]:
_, clus_pred = model_test(vdata["in"].to(self._used_hw_dev))
if first_cycle:
clus_pred_list = clus_pred.detach().cpu()
clus_orig_list = vdata["out"]
data_orig_list = vdata["in"]
else:
clus_pred_list = cat((clus_pred_list, clus_pred.detach().cpu()), dim=0)
clus_orig_list = cat((clus_orig_list, vdata["out"]), dim=0)
data_orig_list = cat((data_orig_list, vdata["in"]), dim=0)
first_cycle = False
# --- Preparing output
data_out = self._getting_data_for_plotting(
valid_input=data_orig_list.numpy(),
valid_label=clus_orig_list.numpy(),
addon="cl",
)
data_out.output = clus_pred_list.numpy()
return data_out