# Source code for denspp.offline.dnn.models.autoencoder_class
from torch import nn, Tensor, argmax
class synthetic_ae_cl_v1(nn.Module):
    """Classification model of autoencoder output.

    A small fully-connected classifier:
    ``input_size -> 16 -> 12 -> output_size``. Each linear layer is followed
    by ``BatchNorm1d``; ReLU is applied after the first two blocks only.

    Args:
        input_size: Number of input features per sample.
        output_size: Number of output classes.
    """

    def __init__(self, input_size=6, output_size=5):
        super().__init__()
        # Presumably read by external tooling as the shape of one sample
        # (batch of 1) -- TODO confirm against callers.
        self.model_shape = (1, input_size)
        lin_size = [input_size, 16, 12, output_size]
        lin_drop = [0.0, 0.0]
        do_train_bias = True
        # NOTE: the order of modules below must stay fixed: nn.Sequential names
        # its children by index, so reordering would break existing state_dicts.
        self.classifier = nn.Sequential(
            # p=0.0 makes this an identity; it is kept as a placeholder so the
            # indices of the following layers (and their state_dict keys) stay stable.
            nn.Dropout(0.0),
            nn.Linear(lin_size[0], lin_size[1]),
            nn.BatchNorm1d(lin_size[1], affine=do_train_bias),
            nn.ReLU(),
            nn.Dropout(lin_drop[0]),
            nn.Linear(lin_size[1], lin_size[2]),
            nn.BatchNorm1d(lin_size[2], affine=do_train_bias),
            nn.ReLU(),
            nn.Dropout(lin_drop[1]),
            nn.Linear(lin_size[2], lin_size[3]),
            nn.BatchNorm1d(lin_size[3], affine=do_train_bias),
            # No softmax: the network emits unnormalised scores (logits).
            # Argmax is unaffected, since softmax is monotonic.
        )

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Run the classifier.

        Args:
            x: Input features, assumed to have shape ``(batch, input_size)``
                (TODO confirm).

        Returns:
            A tuple ``(scores, pred)``. ``scores`` is the raw classifier output
            (no softmax applied). ``pred`` holds the indices of the maximum
            score along dim 1.
        """
        val = self.classifier(x)
        return val, argmax(val, dim=1)