Source code for denspp.offline.dnn.pytorch_pipeline

from typing import Sequence, Tuple

import matplotlib.pyplot as plt

from denspp.offline.dnn.plots.plot_metric import plot_confusion, plot_loss, plot_statistic_data
from denspp.offline.dnn.dnn_handler import ConfigMLPipeline
from denspp.offline.dnn.pytorch_config_data import SettingsDataset
from denspp.offline.dnn.pytorch_config_model import ConfigPytorch
from denspp.offline.dnn.pytorch.classifier import TrainClassifier
from denspp.offline.dnn.pytorch.autoencoder import TrainAutoencoder


def train_classifier_template(config_ml: ConfigMLPipeline, config_data: SettingsDataset, config_train: ConfigPytorch,
                              used_dataset, used_model, path2save: str = '',
                              ptq_quant_lvl: Sequence[int] = (12, 11)) -> Tuple[dict, dict, str]:
    """Template for training DL classifiers using PyTorch (incl. plotting)
    Args:
        config_ml:      Settings for handling the ML Pipeline (reads ``do_plot`` and ``do_block``)
        config_data:    Settings for handling and loading the dataset (just for saving)
        config_train:   Settings for handling the PyTorch Trainings Routine
        used_dataset:   Used custom-made DataLoader with data set
        used_model:     Used custom-made PyTorch DL model
        path2save:      Path for saving the results [Default: '' --> generate new subfolder in runs]
        ptq_quant_lvl:  Quantization level for PTQ [total bitwidth, frac bitwidth]
    Returns:
        Tuple of (metrics, data_result, path2folder): dictionary with the training metrics (keyed by fold),
        dictionary with the validation results after training, and the path to the folder used for saving plots
    """
    # --- Processing Step #1: Preparing Trainings Handler, Build Model
    train_handler = TrainClassifier(config_train=config_train, config_data=config_data, do_train=True)
    train_handler.load_model(model=used_model)
    train_handler.load_data(data_set=used_dataset)
    train_handler.define_ptq_level(ptq_quant_lvl[0], ptq_quant_lvl[1])

    # --- Processing Step #2: Do Training and Validation
    metrics = train_handler.do_training(path2save=path2save, metrics=config_train.custom_metrics)
    path2folder = train_handler.get_saving_path()
    data_result = train_handler.do_validation_after_training()

    # --- Processing Step #3: Plotting
    if config_ml.do_plot:
        plt.close('all')
        # Only the curves of the first fold (first key in insertion order) are plotted
        used_first_fold = next(iter(metrics))
        fold_metrics = metrics[used_first_fold]
        plot_loss(fold_metrics['acc_train'], fold_metrics['acc_valid'],
                  type='Acc.', path2save=path2folder)
        plot_loss(fold_metrics['loss_train'], fold_metrics['loss_valid'],
                  type=f'{config_train.loss} (CL)', path2save=path2folder)
        plot_confusion(data_result['valid_clus'], data_result['yclus'],
                       path2save=path2folder, cl_dict=used_dataset.get_dictionary)
        plot_statistic_data(data_result['train_clus'], data_result['valid_clus'],
                            path2save=path2folder, cl_dict=used_dataset.get_dictionary,
                            show_plot=config_ml.do_block)

    # --- Output
    return metrics, data_result, path2folder
def train_autoencoder_template(config_ml: ConfigMLPipeline, config_data: SettingsDataset, config_train: ConfigPytorch,
                               used_dataset, used_model, path2save: str = '',
                               ptq_quant_lvl: Sequence[int] = (12, 8)) -> Tuple[dict, dict, str]:
    """Template for training DL autoencoders using PyTorch (incl. plotting)
    Args:
        config_ml:      Settings for handling the ML Pipeline (reads ``do_plot`` and ``do_block``)
        config_data:    Settings for handling and loading the dataset (just for saving)
        config_train:   Settings for handling the PyTorch Trainings Routine
        used_dataset:   Used custom-made DataLoader with data set
        used_model:     Used custom-made PyTorch DL model
        path2save:      Path for saving the results [Default: '' --> generate new subfolder in runs]
        ptq_quant_lvl:  Quantization level for PTQ [total bitwidth, frac bitwidth]
    Returns:
        Tuple of (metrics, data_result, path2folder): dictionary with the training metrics (keyed by fold),
        dictionary with the validation results after training, and the path to the folder used for saving plots
    """
    # --- Processing Step #1: Preparing Trainings Handler, Build Model
    train_handler = TrainAutoencoder(config_train=config_train, config_data=config_data, do_train=True)
    train_handler.load_model(model=used_model)
    train_handler.load_data(data_set=used_dataset)
    train_handler.define_ptq_level(ptq_quant_lvl[0], ptq_quant_lvl[1])

    # --- Processing Step #2: Do Training and Validation
    metrics = train_handler.do_training(path2save=path2save, metrics=config_train.custom_metrics)
    path2folder = train_handler.get_saving_path()
    data_result = train_handler.do_validation_after_training()

    # --- Processing Step #3: Plotting
    if config_ml.do_plot:
        plt.close('all')
        # Only the curves of the first fold (first key in insertion order) are plotted
        used_first_fold = next(iter(metrics))
        fold_metrics = metrics[used_first_fold]
        plot_loss(loss_train=fold_metrics['loss_train'], loss_valid=fold_metrics['loss_valid'],
                  type=config_train.loss, path2save=path2folder)
        # show_plot is driven by config_ml.do_block, as in train_classifier_template;
        # previously this flag was silently ignored for autoencoders
        plot_statistic_data(train_cl=data_result['train_clus'], valid_cl=data_result['valid_clus'],
                            path2save=path2folder, cl_dict=used_dataset.get_dictionary,
                            show_plot=config_ml.do_block)

    # --- Output
    return metrics, data_result, path2folder