Source code for denspp.offline.dnn.handler.train_ae_cl

from copy import deepcopy
from denspp.offline.yaml_handler import YamlHandler
from denspp.offline.dnn.dnn_handler import ConfigMLPipeline
from denspp.offline.dnn.pytorch_config_data import SettingsDataset, DefaultSettingsDataset
from denspp.offline.dnn.pytorch_config_model import ConfigPytorch, DefaultSettingsTrainMSE, DefaultSettingsTrainCE
from denspp.offline.dnn.pytorch_pipeline import train_autoencoder_template, train_classifier_template
from denspp.offline.dnn.plots.plot_dnn import results_training
from denspp.offline.dnn.dataset.autoencoder import prepare_training as get_dataset_ae
from denspp.offline.dnn.dataset.autoencoder_class import prepare_training as get_dataset_cl


[docs] def do_train_ae_classifier(class_dataset, settings: ConfigMLPipeline, yaml_name_index: str= 'Config_ACL', model_ae_default_name: str='', model_cl_default_name: str='', used_dataset_name:str='quiroga') -> dict: """Training routine for Autoencoders and Classifier with Encoder after Autoencoder-Training Args: class_dataset: Class of custom-made SettingsDataset from src_dnn/call_dataset.py settings: Handler for configuring the routine selection for train deep neural networks yaml_name_index: Index of yaml file name model_ae_default_name: Default name of the autoencoder model model_cl_default_name: Default name of the classifier model used_dataset_name: Default name of the dataset for training [default: quiroga] Returns: Dictionary with metric results from Autoencoder and Classifier Training """ # --- Processing Step #0: Loading YAML files # --- Loading the YAML file: Dataset default_data = deepcopy(DefaultSettingsDataset) default_data.data_file_name = used_dataset_name config_data = YamlHandler( template=default_data, path=settings.get_path2config, file_name=f'{yaml_name_index}_Dataset' ).get_class(SettingsDataset) # --- Loading the YAML file: Autoencoder Model training default_train_ae = deepcopy(DefaultSettingsTrainMSE) default_train_ae.model_name = model_ae_default_name default_train_ae.custom_metrics = ['dsnr_all'] config_train_ae = YamlHandler( template=default_train_ae, path=settings.get_path2config, file_name=f'{yaml_name_index}_TrainAE' ).get_class(ConfigPytorch) # --- Loading the YAML file: Classifier Model training default_train_cl = deepcopy(DefaultSettingsTrainCE) default_train_cl.model_name = model_cl_default_name config_train_cl = YamlHandler( template=default_train_cl, path=settings.get_path2config, file_name=f'{yaml_name_index}_TrainCL' ).get_class(ConfigPytorch) del default_train_cl, default_train_ae, default_data # --- Processing Step #1.1: Loading dataset and Build Model dataset_ae = get_dataset_ae( rawdata=class_dataset(settings=config_data).load_dataset(), do_classification=False, mode_train_ae=settings.autoencoder_mode, noise_std=settings.autoencoder_noise_std ) if settings.autoencoder_feat_size: used_model_ae = config_train_ae.get_model(input_size=dataset_ae[0]['in'].size, output_size=settings.autoencoder_feat_size) else: used_model_ae = config_train_ae.get_model(input_size=dataset_ae[0]['in'].size) print("\n# ----------- Step #1: TRAINING AUTOENCODER") # --- Processing Step #1.2: Train Autoencoder and Plot Results metrics_ae, valid_data_ae, path2folder = train_autoencoder_template( config_ml=settings, config_data=config_data, config_train=config_train_ae, used_dataset=dataset_ae, used_model=used_model_ae, path2save='' ) if settings.do_plot: used_first_fold = [key for key in metrics_ae.keys()][0] results_training( path=path2folder, cl_dict=dataset_ae.get_dictionary, feat=valid_data_ae['feat'], yin=valid_data_ae['input'], ypred=valid_data_ae['pred'], ymean=dataset_ae.get_mean_waveforms, yclus=valid_data_ae['valid_clus'], snr=metrics_ae[used_first_fold]['dsnr_all'] ) del dataset_ae print("\n# ----------- Step #2: TRAINING CLASSIFIER") # --- Processing Step #2.1: Loading dataset and Build Model dataset_cl = get_dataset_cl( rawdata=class_dataset(settings=config_data).load_dataset(), path2model=path2folder, print_state=True ) num_feat = dataset_cl[0]['in'].shape[0] if not settings.autoencoder_feat_size else settings.autoencoder_feat_size used_model_cl = config_train_cl.get_model(input_size=num_feat, output_size=dataset_cl.get_cluster_num) # --- Processing Step #1.2: Train Classifier metrics_cl, valid_data_cl, _ = train_classifier_template( config_ml=settings, config_data=config_data, config_train=config_train_cl, used_dataset=dataset_cl, used_model=used_model_cl, path2save=path2folder ) del dataset_cl return {"path2model_ae": path2folder, 'metric_ae': metrics_ae, 'metrics_cl': metrics_cl}