diff --git a/ignite/metrics/matthews_corrcoef.py b/ignite/metrics/matthews_corrcoef.py index eddc9ca82bae..6c3903886203 100644 --- a/ignite/metrics/matthews_corrcoef.py +++ b/ignite/metrics/matthews_corrcoef.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from collections.abc import Callable import torch @@ -80,7 +80,7 @@ def __init__( self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False, - device: Union[str, torch.device] = torch.device("cpu"), + device: str | torch.device = torch.device("cpu"), skip_unrolling: bool = False, ): try: diff --git a/ignite/metrics/mean_average_precision.py b/ignite/metrics/mean_average_precision.py index ab04f29eacb6..3d331a3f9176 100644 --- a/ignite/metrics/mean_average_precision.py +++ b/ignite/metrics/mean_average_precision.py @@ -1,5 +1,6 @@ import warnings -from typing import Callable, List, Optional, Sequence, Tuple, Union, cast +from collections.abc import Callable, Sequence +from typing import cast import torch from typing_extensions import Literal @@ -14,8 +15,8 @@ class _BaseAveragePrecision: def __init__( self, - rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - class_mean: Optional[Literal["micro", "macro", "weighted"]] = "macro", + rec_thresholds: Sequence[float] | torch.Tensor | None = None, + class_mean: Literal["micro", "macro", "weighted"] | None = "macro", ) -> None: r"""Base class for Average Precision metric. @@ -56,7 +57,7 @@ def __init__( computes macro precision which is unweighted mean of AP computed across classes/labels. Default. 
""" if rec_thresholds is not None: - self.rec_thresholds: Optional[torch.Tensor] = self._setup_thresholds(rec_thresholds, "rec_thresholds") + self.rec_thresholds: torch.Tensor | None = self._setup_thresholds(rec_thresholds, "rec_thresholds") else: self.rec_thresholds = None @@ -64,7 +65,7 @@ def __init__( raise ValueError(f"Wrong `class_mean` parameter, given {class_mean}") self.class_mean = class_mean - def _setup_thresholds(self, thresholds: Union[Sequence[float], torch.Tensor], threshold_type: str) -> torch.Tensor: + def _setup_thresholds(self, thresholds: Sequence[float] | torch.Tensor, threshold_type: str) -> torch.Tensor: if isinstance(thresholds, Sequence): thresholds = torch.tensor(thresholds, dtype=torch.double) @@ -108,10 +109,10 @@ def _compute_average_precision(self, recall: torch.Tensor, precision: torch.Tens def _cat_and_agg_tensors( - tensors: List[torch.Tensor], - tensor_shape_except_last_dim: Tuple[int], + tensors: list[torch.Tensor], + tensor_shape_except_last_dim: tuple[int], dtype: torch.dtype, - device: Union[str, torch.device], + device: str | torch.device, ) -> torch.Tensor: """ Concatenate tensors in ``tensors`` at their last dimension and gather all tensors from across all processes. 
@@ -139,16 +140,16 @@ def _cat_and_agg_tensors( class MeanAveragePrecision(_BaseClassification, _BaseAveragePrecision): - _y_pred: List[torch.Tensor] - _y_true: List[torch.Tensor] + _y_pred: list[torch.Tensor] + _y_true: list[torch.Tensor] def __init__( self, - rec_thresholds: Optional[Union[Sequence[float], torch.Tensor]] = None, - class_mean: Optional["Literal['micro', 'macro', 'weighted']"] = "macro", + rec_thresholds: Sequence[float] | torch.Tensor | None = None, + class_mean: Literal["micro", "macro", "weighted"] | None = "macro", is_multilabel: bool = False, output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + device: str | torch.device = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: r"""Calculate the mean average precision metric i.e. mean of the averaged-over-recall precision for @@ -297,7 +298,7 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens return yp, yt @reinit__is_reduced - def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: + def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None: """Metric update function using prediction and target. Args: @@ -315,7 +316,7 @@ def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None: def _compute_recall_and_precision( self, y_true: torch.Tensor, y_pred: torch.Tensor, y_true_positive_count: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: r"""Measuring recall & precision. Shape of function inputs and return values follow the table below. 
@@ -357,7 +358,7 @@ def _compute_recall_and_precision( precision = tp_summation / torch.where(predicted_positive == 0, 1, predicted_positive) return recall, precision - def compute(self) -> Union[torch.Tensor, float]: + def compute(self) -> torch.Tensor | float: """ Compute method of the metric """ @@ -367,7 +368,7 @@ def compute(self) -> Union[torch.Tensor, float]: y_true = _cat_and_agg_tensors( self._y_true, - cast(Tuple[int], ()) if self._type == "multiclass" else (num_classes,), + cast(tuple[int], ()) if self._type == "multiclass" else (num_classes,), torch.long if self._type == "multiclass" else torch.uint8, self._device, ) diff --git a/ignite/metrics/multilabel_confusion_matrix.py b/ignite/metrics/multilabel_confusion_matrix.py index 624bbbe4c02f..3dfc6a0d764a 100644 --- a/ignite/metrics/multilabel_confusion_matrix.py +++ b/ignite/metrics/multilabel_confusion_matrix.py @@ -1,4 +1,4 @@ -from typing import Callable, Sequence, Union +from collections.abc import Callable, Sequence import torch @@ -92,7 +92,7 @@ def __init__( self, num_classes: int, output_transform: Callable = lambda x: x, - device: Union[str, torch.device] = torch.device("cpu"), + device: str | torch.device = torch.device("cpu"), normalized: bool = False, skip_unrolling: bool = False, ): diff --git a/ignite/metrics/precision_recall_curve.py b/ignite/metrics/precision_recall_curve.py index d77f7a9160a7..a2a70c396504 100644 --- a/ignite/metrics/precision_recall_curve.py +++ b/ignite/metrics/precision_recall_curve.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, cast, Tuple, Union +from collections.abc import Callable +from typing import Any, cast import torch @@ -7,7 +8,7 @@ from ignite.metrics.epoch_metric import EpochMetric -def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> Tuple[Any, Any, Any]: +def precision_recall_curve_compute_fn(y_preds: torch.Tensor, y_targets: torch.Tensor) -> tuple[Any, Any, Any]: from sklearn.metrics import 
precision_recall_curve y_true = y_targets.cpu().numpy() @@ -77,7 +78,7 @@ def __init__( self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False, - device: Union[str, torch.device] = torch.device("cpu"), + device: str | torch.device = torch.device("cpu"), skip_unrolling: bool = False, ) -> None: try: @@ -93,7 +94,7 @@ def __init__( skip_unrolling=skip_unrolling, ) - def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore[override] + def compute(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore[override] if len(self._predictions) < 1 or len(self._targets) < 1: raise NotComputableError("PrecisionRecallCurve must have at least one example before it can be computed.") @@ -109,7 +110,7 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: i if idist.get_rank() == 0: # Run compute_fn on zero rank only - precision, recall, thresholds = cast(Tuple, self.compute_fn(_prediction_tensor, _target_tensor)) + precision, recall, thresholds = cast(tuple, self.compute_fn(_prediction_tensor, _target_tensor)) precision = torch.tensor(precision, device=_prediction_tensor.device, dtype=self._double_dtype) recall = torch.tensor(recall, device=_prediction_tensor.device, dtype=self._double_dtype) # thresholds can have negative strides, not compatible with torch tensors @@ -126,4 +127,4 @@ def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: i self._result = (precision, recall, thresholds) # type: ignore[assignment] - return cast(Tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result) + return cast(tuple[torch.Tensor, torch.Tensor, torch.Tensor], self._result)