Source code for denspp.offline.dnn.handler.train_cl

from copy import deepcopy
from denspp.offline.yaml_handler import YamlHandler
from denspp.offline.dnn.dnn_handler import ConfigMLPipeline
from denspp.offline.dnn.pytorch_config_data import SettingsDataset, DefaultSettingsDataset
from denspp.offline.dnn.pytorch_config_model import ConfigPytorch, DefaultSettingsTrainCE
from denspp.offline.dnn.pytorch_pipeline import train_classifier_template
from denspp.offline.dnn.dataset.classifier import prepare_training


[docs] def do_train_classifiers(class_dataset, settings: ConfigMLPipeline, yaml_name_index: str='Config_Neural', used_dataset_name: str='quiroga', used_model_name: str='') -> str: """Training routine for Classification DL models :param class_dataset: Class of custom-made SettingsDataset from src_dnn/call_dataset.py :param settings: Handler for configuring the routine selection for train deep neural networks :param yaml_name_index: Index of yaml file name :param used_dataset_name: Used dataset name :param used_model_name: Used model for DNN training :return: String with path to folder in which results are saved """ # --- Loading the YAML file: Dataset default_data = deepcopy(DefaultSettingsDataset) default_data.data_file_name = used_dataset_name config_data = YamlHandler( template=default_data, path=settings.get_path2config, file_name=f'{yaml_name_index}_Dataset' ).get_class(SettingsDataset) # --- Loading the YAML file: Model training default_train = deepcopy(DefaultSettingsTrainCE) default_train.model_name = used_model_name config_train = YamlHandler( template=default_train, path=settings.get_path2config, file_name=f'{yaml_name_index}_TrainCL' ).get_class(ConfigPytorch) del default_train, default_data # --- Loading Data, Build Model and Do Inference dataset = prepare_training( rawdata=class_dataset(settings=config_data).load_dataset() ) used_model = config_train.get_model(input_size=dataset[0]['in'].size, output_size=dataset.get_cluster_num) _, _, path2folder = train_classifier_template( config_ml=settings, config_data=config_data, config_train=config_train, used_dataset=dataset, used_model=used_model ) return path2folder