"""Source code for denspp.offline.dnn.pytorch_config_model."""

from dataclasses import dataclass
from typing import Any
from torch import optim, nn
from copy import deepcopy
from denspp.offline.dnn.model_library import ModelLibrary


[docs] @dataclass class ConfigPytorch: """Class for handling the PyTorch training/inference pipeline Attributes: model_name: String with the model name patience: Integer value with number of epochs before early stopping optimizer: String with PyTorch optimizer name loss: String with method name for the loss function deterministic_do: Boolean if deterministic training should be done deterministic_seed: Integer with the seed for deterministic training num_kfold: Integer value with applying k-fold cross validation num_epochs: Integer value with number of epochs batch_size: Integer value with batch size data_split_ratio: Float value for splitting the input dataset between training and validation data_do_shuffle: Boolean if data should be shuffled before training custom_metrics: List with string of custom metrics to calculate during training """ model_name: str patience: int optimizer: str loss: str deterministic_do: bool deterministic_seed: int num_kfold: int num_epochs: int batch_size: int data_split_ratio: float data_do_shuffle: bool custom_metrics: list
[docs] @staticmethod def get_model_overview(print_overview: bool=False, index: str='') -> None: """Function for getting an overview of existing models inside library""" models_bib = ModelLibrary().get_registry() models_bib.get_library_overview(index, do_print=print_overview)
[docs] def get_loss_func(self) -> Any: """Getting the loss function""" match self.loss: case 'L1': loss_func = nn.L1Loss case 'MSE': loss_func = nn.MSELoss() case 'Cross Entropy': loss_func = nn.CrossEntropyLoss() case 'Cosine Similarity': loss_func = nn.CosineSimilarity() case _: raise NotImplementedError("Loss function unknown! - Please implement or check!") return loss_func
[docs] def load_optimizer(self, model, learn_rate: float=0.1) -> Any: """Loading the optimizer function""" match self.optimizer: case 'Adam': optim_func = optim.Adam(model.parameters()) case 'SGD': optim_func = optim.SGD(model.parameters(), lr=learn_rate) case _: raise NotImplementedError("Optimizer function unknown! - Please implement or check!") return optim_func
[docs] def get_model(self, *args, **kwargs): """Function for loading the model to train""" models_bib = ModelLibrary().get_registry() if not self.model_name: models_bib.get_library_overview(do_print=True) raise NotImplementedError("Please select one model above and type-in the name into yaml file") else: if models_bib.check_module_available(self.model_name): used_model = deepcopy(models_bib.build(self.model_name, *args, **kwargs)) return used_model else: ovr = models_bib.get_library_overview(do_print=True) raise NotImplementedError(f"Model is not available - Please check again!")
# Default training configuration for regression-style tasks (MSE loss).
DefaultSettingsTrainMSE = ConfigPytorch(
    model_name='',
    patience=20,
    optimizer='Adam',
    loss='MSE',
    deterministic_do=False,
    deterministic_seed=42,
    num_kfold=1,
    num_epochs=10,
    batch_size=256,
    data_split_ratio=0.2,
    data_do_shuffle=True,
    custom_metrics=[],
)

# Default training configuration for classification tasks (Cross Entropy loss).
# Identical to the MSE default except for the loss selection.
DefaultSettingsTrainCE = ConfigPytorch(
    model_name='',
    patience=20,
    optimizer='Adam',
    loss='Cross Entropy',
    deterministic_do=False,
    deterministic_seed=42,
    num_kfold=1,
    num_epochs=10,
    batch_size=256,
    data_split_ratio=0.2,
    data_do_shuffle=True,
    custom_metrics=[],
)