Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ignite/metrics/matthews_corrcoef.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Union
from collections.abc import Callable

import torch

Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
self,
output_transform: Callable = lambda x: x,
check_compute_fn: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
):
try:
Expand Down
35 changes: 18 additions & 17 deletions ignite/metrics/mean_average_precision.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from typing import Callable, List, Optional, Sequence, Tuple, Union, cast
from collections.abc import Callable, Sequence
from typing import cast

import torch
from typing_extensions import Literal
Expand All @@ -14,8 +15,8 @@
class _BaseAveragePrecision:
def __init__(
self,
rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None,
class_mean: Optional[Literal["micro", "macro", "weighted"]] = "macro",
rec_thresholds: Optional[Sequence[float] | torch.Tensor] = None,
class_mean: Literal["micro", "macro", "weighted"] | None = "macro",
) -> None:
r"""Base class for Average Precision metric.

Expand Down Expand Up @@ -56,15 +57,15 @@ def __init__(
computes macro precision which is unweighted mean of AP computed across classes/labels. Default.
"""
if rec_thresholds is not None:
self.rec_thresholds: Optional[torch.Tensor] = self._setup_thresholds(rec_thresholds, "rec_thresholds")
self.rec_thresholds: torch.Tensor | None = self._setup_thresholds(rec_thresholds, "rec_thresholds")
else:
self.rec_thresholds = None

if class_mean is not None and class_mean not in ("micro", "macro", "weighted"):
raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}")
self.class_mean = class_mean

def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor:
def _setup_thresholds(self, thresholds: Sequence[float] | torch.Tensor, threshold_type: str) -> torch.Tensor:
if isinstance(thresholds, Sequence):
thresholds = torch.tensor(thresholds, dtype=torch.double)

Expand Down Expand Up @@ -108,10 +109,10 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens


def _cat_and_agg_tensors(
tensors: List[torch.Tensor],
tensor_shape_except_last_dim: Tuple[int],
tensors: list[torch.Tensor],
tensor_shape_except_last_dim: tuple[int],
dtype: torch.dtype,
device: Union[str, torch.device],
device: str | torch.device,
) -> torch.Tensor:
"""
Concatenate tensors in ``tensors`` at their last dimension and gather all tensors from across all processes.
Expand Down Expand Up @@ -139,16 +140,16 @@ def _cat_and_agg_tensors(


class MeanAveragePrecision(_BaseClassification, _BaseAveragePrecision):
_y_pred: List[torch.Tensor]
_y_true: List[torch.Tensor]
_y_pred: list[torch.Tensor]
_y_true: list[torch.Tensor]

def __init__(
self,
rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None,
class_mean: Optional["Literal['micro', 'macro', 'weighted']"] = "macro",
rec_thresholds: Optional[Sequence[float] | torch.Tensor] = None,
class_mean: "Literal['micro', 'macro', 'weighted']" | None = "macro",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class_mean: "Literal['micro', 'macro', 'weighted']" | None = "macro",
class_mean: Literal['micro', 'macro', 'weighted'] | None = "macro",

is_multilabel: bool = False,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
) -> None:
r"""Calculate the mean average precision metric i.e. mean of the averaged-over-recall precision for
Expand Down Expand Up @@ -297,7 +298,7 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
return yp, yt

@reinit__is_reduced
def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
"""Metric update function using prediction and target.

Args:
Expand All @@ -315,7 +316,7 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:

def _compute_recall_and_precision(
self, y_true: torch.Tensor, y_pred: torch.Tensor, y_true_positive_count: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
r"""Measuring recall & precision.

Shape of function inputs and return values follow the table below.
Expand Down Expand Up @@ -357,7 +358,7 @@ def _compute_recall_and_precision(
precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive)
return recall, precision

def compute(self) -> Union[torch.Tensor, float]:
def compute(self) -> torch.Tensor | float:
"""
Compute method of the metric
"""
Expand All @@ -367,7 +368,7 @@ def compute(self) -> Union[torch.Tensor, float]:

y_true = _cat_and_agg_tensors(
self._y_true,
cast(Tuple[int], ()) if self._type == "multiclass" else (num_classes,),
cast(tuple[int], ()) if self._type == "multiclass" else (num_classes,),
torch.long if self._type == "multiclass" else torch.uint8,
self._device,
)
Expand Down
4 changes: 2 additions & 2 deletions ignite/metrics/multilabel_confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Sequence, Union
from collections.abc import Callable, Sequence

import torch

Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
self,
num_classes: int,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
normalized: bool = False,
skip_unrolling: bool = False,
):
Expand Down
11 changes: 6 additions & 5 deletions ignite/metrics/precision_recall_curve.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, cast, Tuple, Union
from collections.abc import Callable
from typing import Any, cast

import torch

Expand All @@ -7,7 +8,7 @@
from ignite.metrics.epoch_metric import EpochMetric


def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[Any, Any, Any]:
def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> tuple[Any, Any, Any]:
from sklearn.metrics import precision_recall_curve

y_true = y_targets.cpu().numpy()
Expand Down Expand Up @@ -77,7 +78,7 @@ def __init__(
self,
output_transform: Callable = lambda x: x,
check_compute_fn: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
device: str | torch.device = torch.device("cpu"),
skip_unrolling: bool = False,
) -> None:
try:
Expand All @@ -93,7 +94,7 @@ def __init__(
skip_unrolling=skip_unrolling,
)

def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore[override]
def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore[override]
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.")

Expand Down Expand Up @@ -126,4 +127,4 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: i

self._result = (precision, recall, thresholds) # type: ignore[assignment]

return cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result)
return cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result)