# Source code for denspp.offline.dnn.training.classifier_dataset
import numpy as np
from torch import is_tensor
from torch.utils.data import Dataset
from denspp.offline.dnn import DatasetFromFile
class DatasetClassifier(Dataset):
    """PyTorch ``Dataset`` wrapper for classification tasks.

    Every item is a dict ``{"in": sample, "out": label}`` where ``sample`` is
    a float32 numpy array (or scalar) and ``label`` is an integer label.
    """

    def __init__(self, dataset: "DatasetFromFile"):
        """Dataset Loader for Classification Tasks

        :param dataset: Dataclass DatasetFromFile; this loader reads its
            ``data``, ``label`` and ``dict`` attributes.
        """
        self.__data = np.array(dataset.data, dtype=np.float32)

        # Labels are stored as uint8 whenever every value fits in [0, 255]
        # (the historical behaviour). A plain uint8 cast would silently wrap
        # larger or negative labels (e.g. 300 -> 44), so widen to int64 instead.
        labels = np.asarray(dataset.label)
        fits_uint8 = labels.size == 0 or (
            labels.min() >= 0 and labels.max() <= np.iinfo(np.uint8).max
        )
        self.__label = labels.astype(np.uint8 if fits_uint8 else np.int64)

        # Only a list is accepted as the class-name dictionary; any other
        # value (e.g. None) falls back to an empty list.
        self.__name = dataset.dict if isinstance(dataset.dict, list) else []

    def __len__(self):
        """Return the number of samples (size of the first data axis)."""
        return self.__data.shape[0]

    def __getitem__(self, idx):
        """Return the sample(s) at ``idx`` as ``{"in": data, "out": label}``.

        :param idx: int, list of ints, or a torch tensor holding either.
        """
        if is_tensor(idx):
            # Turn a tensor index into a plain int / list before indexing
            # the numpy arrays.
            idx = idx.tolist()
        # ``...`` selects along the first axis only, so samples of any rank
        # work (the former ``[idx, :]`` failed for 1-D sample arrays).
        return {"in": self.__data[idx, ...], "out": self.__label[idx]}

    @property
    def get_dictionary(self) -> list:
        """Getting the dictionary of labeled dataset (class names, or [] if none was given)"""
        return self.__name

    @property
    def get_topology_type(self) -> str:
        """Getting the information of used deep learning topology"""
        return "Classifier"

    @property
    def get_cluster_num(self) -> int:
        """Getting the number of clusters.

        Counts the distinct label values present in this dataset, so classes
        absent from this split are not counted.
        """
        return int(np.unique(self.__label).size)