Source code for denspp.offline.dnn.plots.plot_classifier

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, precision_recall_fscore_support

from denspp.offline.plot_helper import (
    save_figure,
    cm_to_inch
)


[docs] def plot_confusion(true_labels: list | np.ndarray, pred_labels: list | np.ndarray, plotting: str="class", show_accuracy: bool=False, cl_dict=None, path2save: str="", name_addon: str="", timestamps_result: list=(), timestamps_f1: list=(), timestamps_accuracy: list=(), show_plots: bool=False) -> None: """This function is designed to generate and display confusion matrices for classification results and timestamp-based comparisons. The function takes various parameters, including true and predicted labels, as well as additional information such as timestamps and plotting preferences. The confusion matrix for classification is displayed using the ConfusionMatrixDisplay class from scikit-learn. It also calculates and prints precision, recall, and F-beta score for the classification case. The timestamp-based comparison is visualized using a heatmap. The resulting plots can be saved to a specified path if provided. Args: true_labels: List or numpy array containing true class labels. pred_labels: List or numpy array containing predicted class labels. timestamps_result: Resulting array for timestamp comparison (for plotting timestamps). timestamps_f1: F1-score for timestamp comparison. timestamps_accuracy: Accuracy for timestamp comparison. plotting: Specifies the type of plotting to perform ("class", "timestamps", or "both"). show_accuracy: Boolean indicating whether to display accuracy in timestamp plots. cl_dict: Dictionary mapping class indices to labels. path2save: Path to save the generated plots. name_addon: Additional name for saved plots. show_plots: Command for showing plots in the end [Default: False] Returns: The function generates and displays confusion matrices and timestamp-based plots. If path2save is provided, the plots are saved to the specified path. """ dict_available = False if plotting == "class" or plotting == "both": """Plotting the Confusion Matrix""" if isinstance(cl_dict, np.ndarray): cl_used = cl_dict.tolist() else: cl_used = cl_dict if isinstance(cl_dict, list): dict_available = not len(cl_dict) == 0 elif isinstance(cl_dict, dict): dict_available = not len(cl_dict) == 0 else: dict_available = False max_key_length = 0 precision, recall, fbeta, _ = precision_recall_fscore_support(true_labels, pred_labels, average='weighted') if plotting == "timestamps" or plotting == "both": # --- Plotting the results for the timestamp comparison plt.imshow(timestamps_result, cmap=plt.cm.Blues, interpolation='nearest') for i in range(timestamps_result.shape[0]): for j in range(timestamps_result.shape[1]): plt.text(j, i, f'{timestamps_result[i, j]:.2f}', ha='center', va='center', color='white') xtick_labels = ['true', 'false'] plt.xticks(np.arange(2), xtick_labels) ytick_labels = ['positive', 'negative'] plt.yticks(np.arange(2), ytick_labels) if show_accuracy: plt.title(f'F1-Score = {timestamps_f1:.4f} - Accuracy = {timestamps_accuracy:.4f}') else: plt.title(f'F1-Score = {timestamps_f1:.4f}') elif dict_available: for keys in cl_used: max_key_length = len(keys) if len(keys) > max_key_length else max_key_length do_xticks_vertical = bool(max_key_length > 5) and np.unique(true_labels).size > 3 use_cl_dict = list() if isinstance(cl_dict, dict): for key in cl_dict.keys(): use_cl_dict.append(key) else: for idx in np.unique(true_labels): use_cl_dict.append(cl_used[int(idx)]) cmp = ConfusionMatrixDisplay.from_predictions( y_true=true_labels, y_pred=pred_labels, normalize='pred', display_labels=use_cl_dict ) else: do_xticks_vertical = False cmp = ConfusionMatrixDisplay.from_predictions( y_true=true_labels, y_pred=pred_labels, normalize='pred', ) # --- Plotting the results of the class confusion matrix ax = plt.subplots(figsize=(cm_to_inch(12), cm_to_inch(12.5)))[1] cmp.plot(ax=ax, colorbar=False, values_format='.3f', text_kw={'fontsize': 8}, cmap=plt.cm.Blues, xticks_rotation=('vertical' if do_xticks_vertical else 'horizontal')) cmp.ax_.set_title(f'Precision = {100*precision:.2f}%, Recall = {100*recall:.2f}%') plt.tight_layout() # --- saving if path2save: save_figure(plt, path2save, f"confusion_matrix{name_addon}") if show_plots: plt.show(block=True)