"""Spike classifier models (source of ``denspp.offline.dnn.models.spike_classifier``)."""
from torch import nn, Tensor, argmax, flatten
class synthetic_cl_v1(nn.Module):
    """DL model for classifying neural spike activity (MLP).

    The network is a stack of fully-connected blocks with layer widths
    ``[input_size, 40, 32, 20, 12, output_size]``. Every block is
    ``Linear -> BatchNorm1d``; every block except the last is followed by
    ``ReLU``. No softmax is applied, so the first output of :meth:`forward`
    holds raw class scores (logits).

    Args:
        input_size: Number of features per sample after flattening.
        output_size: Number of output classes.
    """

    def __init__(self, input_size=32, output_size=5):
        super().__init__()
        # Nominal per-sample input shape (without batch dim). Not read inside
        # this class; presumably consumed by callers (summaries/export).
        self.model_shape = (1, input_size)

        # --- Settings of model
        do_train_bias = True   # learnable bias in every Linear layer
        do_train_batch = True  # learnable scale/shift (affine) in BatchNorm1d
        config_network = [input_size, 40, 32, 20, 12, output_size]
        last_idx = len(config_network) - 1

        # --- Model Deployment
        self.model = nn.Sequential()
        layer_pairs = zip(config_network[:-1], config_network[1:])
        for idx, (in_size, out_size) in enumerate(layer_pairs, start=1):
            self.model.add_module(
                f"linear_{idx:02d}",
                nn.Linear(in_features=in_size, out_features=out_size, bias=do_train_bias),
            )
            self.model.add_module(
                f"batch1d_{idx:02d}",
                nn.BatchNorm1d(num_features=out_size, affine=do_train_batch),
            )
            # The output block gets no activation, so forward() returns raw
            # scores (presumably because the training loss applies its own
            # softmax, e.g. nn.CrossEntropyLoss -- confirm with the trainer).
            if idx != last_idx:
                self.model.add_module(f"act_{idx:02d}", nn.ReLU())

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Classify a batch of spike waveforms.

        Args:
            x: Batch of samples. Every dimension after the first is flattened,
                so the flattened per-sample size must equal ``input_size``.

        Returns:
            Tuple ``(logits, labels)``: raw class scores of shape
            ``(batch, output_size)`` and the index of the highest score per
            sample, shape ``(batch,)``.
        """
        x = flatten(x, start_dim=1)
        logits = self.model(x)
        # argmax over raw scores equals argmax over softmax probabilities.
        return logits, argmax(logits, 1)