Source code for denspp.offline.metric.data_torch
from sklearn.metrics import precision_recall_fscore_support
from torch import Tensor, tensor, sum, eq, ne
[docs]
def calculate_number_true_predictions(pred: Tensor, true: Tensor) -> Tensor:
"""Function for determining the true predicted values
Args:
pred: Tensor with predicted values from model
true: Tensor with true labels from dataset
Return
Tensor with metric
"""
assert pred.shape == true.shape, "Dimension / shape mismatch"
return sum(eq(pred, true))
[docs]
def calculate_number_false_predictions(pred: Tensor, true: Tensor) -> Tensor:
"""Function for determining the true predicted values
Args:
pred: Tensor with predicted values from model
true: Tensor with true labels from dataset
Return
Tensor with metric
"""
assert pred.shape == true.shape, "Dimension / shape mismatch"
return sum(ne(pred, true))
[docs]
def calculate_precision(pred: Tensor, true: Tensor) -> Tensor:
"""Function for determining the precision metric
Args:
pred: Tensor with predicted values from model
true: Tensor with true labels from dataset
Return
Tensor with metrics [precision]
"""
assert pred.shape == true.shape, "Dimension / shape mismatch"
return tensor(precision_recall_fscore_support(true, pred, average="micro", warn_for=tuple())[0])
[docs]
def calculate_recall(pred: Tensor, true: Tensor) -> Tensor:
"""Function for determining the precision metric
Args:
pred: Tensor with predicted values from model
true: Tensor with true labels from dataset
Return
Tensor with metrics [precision]
"""
assert pred.shape == true.shape, "Dimension / shape mismatch"
return tensor(precision_recall_fscore_support(true, pred, average="micro", warn_for=tuple())[1])
[docs]
def calculate_fbeta(pred: Tensor, true: Tensor, beta: float=1.0) -> Tensor:
"""Function for determining the precision metric
Args:
pred: Tensor with predicted values from model
true: Tensor with true labels from dataset
beta: Beta value for getting Fbeta metric
Return
Tensor with metrics [precision]
"""
assert pred.shape == true.shape, "Dimension / shape mismatch"
return tensor(precision_recall_fscore_support(true, pred, beta=beta, average="micro", warn_for=tuple())[2])