"""Implementation of AUROC metric based on TorchMetrics."""# Copyright (C) 2022 Intel Corporation# SPDX-License-Identifier: Apache-2.0fromtypingimportTupleimporttorchfrommatplotlib.figureimportFigurefromtorchimportTensorfromtorchmetricsimportROCfromtorchmetrics.functionalimportaucfrom.plotting_utilsimportplot_figure
class AUROC(ROC):
    """Area under the ROC curve.

    Every update flattens its inputs before handing them to ``ROC``, so inputs
    of any shape are accumulated as one binary-classification problem.
    """

    def compute(self) -> Tensor:
        """Build the ROC curve and return the area under it.

        Returns:
            Tensor: The AUROC value.
        """
        false_pos, true_pos = self._compute()

        # Consecutive differences of the fpr values; used to decide whether
        # ``auc`` must reorder the points before integrating.
        steps = false_pos.diff()
        if torch.all(steps <= 0) or torch.all(steps >= 0):
            # fpr is already monotonic (non-increasing or non-decreasing).
            return auc(false_pos, true_pos)

        # TODO: use stable sort after upgrading to pytorch 1.9.x (https://github.com/openvinotoolkit/anomalib/issues/92)
        # fpr is neither non-increasing nor non-decreasing: let ``auc`` reorder.
        return auc(false_pos, true_pos, reorder=True)

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Accumulate a batch of predictions and ground-truth targets.

        ``ROC`` expects flat inputs for binary classification, so both tensors
        are flattened before being forwarded to the parent ``update``.

        Args:
            preds (Tensor): Predictions of the model.
            target (Tensor): Ground-truth targets.
        """
        flat_preds = preds.flatten()
        flat_target = target.flatten()
        super().update(flat_preds, flat_target)

    def _compute(self) -> Tuple[Tensor, Tensor]:
        """Return the ROC curve as an ``(fpr, tpr)`` pair.

        The thresholds produced by ``ROC.compute`` are discarded.

        Returns:
            Tuple[Tensor, Tensor]: False-positive rates and true-positive rates.
        """
        fpr_values, tpr_values, _ = super().compute()
        return fpr_values, tpr_values

    def generate_figure(self) -> Tuple[Figure, str]:
        """Draw the ROC curve, its AUROC and a dashed diagonal baseline.

        Returns:
            Tuple[Figure, str]: The figure and the title to use when logging it.
        """
        false_pos, true_pos = self._compute()
        score = self.compute()

        title = "ROC"
        fig, axis = plot_figure(
            false_pos,
            true_pos,
            score,
            (0.0, 1.0),  # x-axis limits
            (0.0, 1.0),  # y-axis limits
            "False Positive Rate",
            "True Positive Rate",
            "lower right",  # legend location
            title,
        )

        # Baseline: dashed y = x diagonal across the unit square.
        axis.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--", figure=fig)

        return fig, title