# Source code for denspp.offline.dnn.handler.train_ae
from copy import deepcopy
from denspp.offline.yaml_handler import YamlHandler
from denspp.offline.dnn.dnn_handler import ConfigMLPipeline
from denspp.offline.dnn.pytorch_config_data import SettingsDataset, DefaultSettingsDataset
from denspp.offline.dnn.pytorch_config_model import ConfigPytorch, DefaultSettingsTrainMSE
from denspp.offline.dnn.pytorch_pipeline import train_autoencoder_template
from denspp.offline.dnn.plots.plot_dnn import results_training
from denspp.offline.dnn.dataset.autoencoder import prepare_training
# [docs]
def do_train_autoencoder(class_dataset, settings: ConfigMLPipeline, yaml_name_index: str = 'Config_AE',
                         used_dataset_name: str = 'quiroga', model_default_name: str = '') -> tuple[dict, dict]:
    """Training routine for autoencoders (e.g. in neural applications for spike frames).

    Args:
        class_dataset:      Class of custom-made SettingsDataset from src_dnn/call_dataset.py
        settings:           Handler for configuring the routine selection for training deep neural networks
        yaml_name_index:    Stem of the yaml config file names [default: 'Config_AE']
        used_dataset_name:  Default name of the dataset used for training [default: 'quiroga']
        model_default_name: Optional name of the model to load [default: '']

    Returns:
        Tuple with two dictionaries: (training metrics, validation data)
    """
    # --- Loading the YAML file: Dataset
    default_data = deepcopy(DefaultSettingsDataset)
    default_data.data_file_name = used_dataset_name
    config_data = YamlHandler(
        template=default_data,
        path=settings.get_path2config,
        file_name=f'{yaml_name_index}_Dataset'
    ).get_class(SettingsDataset)

    # --- Loading the YAML file: Model training
    default_train = deepcopy(DefaultSettingsTrainMSE)
    default_train.model_name = model_default_name
    default_train.custom_metrics = ['dsnr_all']
    config_train = YamlHandler(
        template=default_train,
        path=settings.get_path2config,
        file_name=f'{yaml_name_index}_TrainAE'
    ).get_class(ConfigPytorch)
    del default_train, default_data

    # --- Loading data, building the model and training
    dataset = prepare_training(
        rawdata=class_dataset(settings=config_data).load_dataset(),
        do_classification=False,
        mode_train_ae=settings.autoencoder_mode,
        noise_std=settings.autoencoder_noise_std,
        print_state=True
    )
    # If a bottleneck/feature size is configured, use it as the model's output size;
    # otherwise the autoencoder reconstructs frames of the same size as the input.
    input_size = dataset[0]['in'].size
    output_size = settings.autoencoder_feat_size if settings.autoencoder_feat_size else input_size
    used_model = config_train.get_model(input_size=input_size, output_size=output_size)

    metrics, data_result, path2folder = train_autoencoder_template(
        config_ml=settings, config_data=config_data, config_train=config_train,
        used_dataset=dataset, used_model=used_model
    )

    # --- Plotting (optional)
    if settings.do_plot:
        # Metrics are keyed per cross-validation fold; plot results of the first fold.
        used_first_fold = next(iter(metrics))
        results_training(
            path=path2folder, cl_dict=data_result['cl_dict'], feat=data_result['feat'],
            yin=data_result['input'], ypred=data_result['pred'], ymean=dataset.get_mean_waveforms,
            yclus=data_result['valid_clus'], snr=metrics[used_first_fold]['dsnr_all'],
            show_plot=settings.do_block
        )
    return metrics, data_result