diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index 13ef7b3347e7..d526ec00af4c 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -17,6 +17,34 @@ value is then computed using the output of the engine's ``process_function``: metric = Accuracy() metric.attach(engine, "accuracy") +If the engine's prediction output ``y_pred`` represents probability estimates, it can be binarized using the +``Mode.PROBABILITIES``, with default threshold of 0.5: +.. code-block:: python + + def process_function(engine, batch): + # ... + y = torch.from_numpy(np.array([0, 0, 1])) + y_pred = torch.from_numpy(np.array([0.1, 0.2, 0.7])) + return y_pred, y + + engine = Engine(process_function) + metric = Accuracy(mode=Accuracy.Mode.PROBABILITIES) + metric.attach(engine, "accuracy") + +If the engine's prediction output ``y_pred`` represents logits, it can be binarized using the +``Mode.LOGITS``, with default threshold of 0.5: +.. code-block:: python + + def process_function(engine, batch): + # ... + y = torch.from_numpy(np.array([0, 0, 1])) + y_pred = torch.from_numpy(np.array([-2.1, 0.6, 1.7])) + return y_pred, y + + engine = Engine(process_function) + metric = Accuracy(mode=Accuracy.Mode.LOGITS) + metric.attach(engine, "accuracy") + If the engine's output is not in the format ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``, the user can use the ``output_transform`` argument to transform it: @@ -41,13 +69,13 @@ use the ``output_transform`` argument to transform it: .. warning:: Please, be careful when using ``lambda`` functions to setup multiple ``output_transform`` for multiple metrics - + .. code-block:: python # Wrong # metrics_group = [Accuracy(output_transform=lambda output: output[name]) for name in names] # As lambda can not store `name` and all `output_transform` will use the last `name` - + # A correct way. For example, using functools.partial from functools import partial @@ -55,7 +83,7 @@ use the ``output_transform`` argument to transform it: return output[name] metrics_group = [Accuracy(output_transform=partial(ot_func, name=name)) for name in names] - + For more details, see `here `_ .. Note :: diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 7d6c939e4b53..8e0f50572bcf 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -1,3 +1,4 @@ +import enum from typing import Callable, Sequence, Union import torch @@ -103,19 +104,10 @@ class Accuracy(_BaseClassification): - `y` and `y_pred` must be in the following shape of (batch_size, num_categories, ...) and num_categories must be greater than 1 for multilabel cases. - In binary and multilabel cases, the elements of `y` and `y_pred` should have 0 or 1 values. Thresholding of + In binary and multilabel cases, the elements of `y` should have 0 or 1 values, while `y_pred` can represent + probabilities using PROBABILITIES Mode or logits using LOGITS Mode. predictions can be done as below: - .. code-block:: python - - def thresholded_output_transform(output): - y_pred, y = output - y_pred = torch.round(y_pred) - return y_pred, y - - binary_accuracy = Accuracy(thresholded_output_transform) - - Args: output_transform (callable, optional): a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the @@ -125,17 +117,30 @@ def thresholded_output_transform(output): device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + mode (Mode): specifies in which form input will be passed. This can be useful to directly compute + accuracy on the output of a neural network, which ofter return probabilities. By default, UNCHANGED. + threshold (float): threshold for binarization of the input, in case a Mode that uses + binarization is used. """ + class Mode(enum.Enum): + UNCHANGED = 0 + PROBABILITIES = 1 + LOGITS = 2 + def __init__( self, output_transform: Callable = lambda x: x, is_multilabel: bool = False, device: Union[str, torch.device] = torch.device("cpu"), + mode: Mode = Mode.UNCHANGED, + threshold: float = 0.5, ): self._num_correct = None self._num_examples = None + self._mode = mode + self._threshold = threshold super(Accuracy, self).__init__(output_transform=output_transform, is_multilabel=is_multilabel, device=device) @reinit__is_reduced @@ -147,9 +152,15 @@ def reset(self) -> None: @reinit__is_reduced def update(self, output: Sequence[torch.Tensor]) -> None: self._check_shape(output) - self._check_type(output) y_pred, y = output[0].detach(), output[1].detach() + if self._mode == self.Mode.PROBABILITIES: + y_pred = (y_pred >= self._threshold).int() + if self._mode == self.Mode.LOGITS: + y_pred = (torch.sigmoid(y_pred) >= self._threshold).int() + + self._check_type([y_pred, y]) + if self._type == "binary": correct = torch.eq(y_pred.view(-1).to(y), y.view(-1)) elif self._type == "multiclass": diff --git a/tests/ignite/metrics/test_accuracy.py b/tests/ignite/metrics/test_accuracy.py index 3960a09ec7f7..90ee9c31ab9e 100644 --- a/tests/ignite/metrics/test_accuracy.py +++ b/tests/ignite/metrics/test_accuracy.py @@ -91,6 +91,156 @@ def _test(): _test() +def test_binary_input_N_probabilities(): + # Binary accuracy on probabilities input of shape (N, 1) or (N, ) + def _test(): + acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES) + + y_pred = torch.rand(size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + acc.update((y_pred, y)) + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().round().ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # Batched Updates + acc.reset() + y_pred = torch.rand(size=(100,)) + y = torch.randint(0, 2, size=(100,)).long() + + n_iters = 16 + batch_size = y.shape[0] // n_iters + 1 + + for i in range(n_iters): + idx = i * batch_size + acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().round().ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # check multiple random inputs as random exact occurencies are rare + for _ in range(10): + _test() + + +def test_binary_input_N_probabilities_threshold(): + # Binary accuracy on probabilities input of shape (N, 1) or (N, ), + # with custom binarization threshold. + def _test(): + acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES, threshold=0.75) + + y_pred = torch.rand(size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + acc.update((y_pred, y)) + np_y = y.numpy().ravel() + np_y_pred = (y_pred.numpy() >= 0.75).astype(int).ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # Batched Updates + acc.reset() + y_pred = torch.rand(size=(100,)) + y = torch.randint(0, 2, size=(100,)).long() + + n_iters = 16 + batch_size = y.shape[0] // n_iters + 1 + + for i in range(n_iters): + idx = i * batch_size + acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + + np_y = y.numpy().ravel() + np_y_pred = (y_pred.numpy() >= 0.75).astype(int).ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # check multiple random inputs as random exact occurencies are rare + for _ in range(10): + _test() + + +def test_binary_input_N_logits(): + # Binary accuracy on logits input of shape (N, 1) or (N, ) + def _test(): + acc = Accuracy(mode=Accuracy.Mode.LOGITS) + + y_pred = torch.randn(size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + acc.update((y_pred, y)) + np_y = y.numpy().ravel() + np_y_pred = torch.sigmoid(y_pred).numpy().round().ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # Batched Updates + acc.reset() + y_pred = torch.randn(size=(100,)) + y = torch.randint(0, 2, size=(100,)).long() + + n_iters = 16 + batch_size = y.shape[0] // n_iters + 1 + + for i in range(n_iters): + idx = i * batch_size + acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + + np_y = y.numpy().ravel() + np_y_pred = torch.sigmoid(y_pred).numpy().round().ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # check multiple random inputs as random exact occurencies are rare + for _ in range(10): + _test() + + +def test_binary_input_N_logits_threshold(): + # Binary accuracy on logits input of shape (N, 1) or (N, ), + # with custom binarization threshold. + def _test(): + acc = Accuracy(mode=Accuracy.Mode.LOGITS, threshold=0.75) + + y_pred = torch.randn(size=(10,)) + y = torch.randint(0, 2, size=(10,)).long() + acc.update((y_pred, y)) + np_y = y.numpy().ravel() + np_y_pred = (torch.sigmoid(y_pred).numpy() >= 0.75).astype(int).ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # Batched Updates + acc.reset() + y_pred = torch.randn(size=(100,)) + y = torch.randint(0, 2, size=(100,)).long() + + n_iters = 16 + batch_size = y.shape[0] // n_iters + 1 + + for i in range(n_iters): + idx = i * batch_size + acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + + np_y = y.numpy().ravel() + np_y_pred = (torch.sigmoid(y_pred).numpy() >= 0.75).astype(int).ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # check multiple random inputs as random exact occurencies are rare + for _ in range(10): + _test() + + def test_binary_input_NL(): # Binary accuracy on input of shape (N, L) def _test(): @@ -138,6 +288,53 @@ def _test(): _test() +def test_binary_input_NL_probabilities(): + # Binary accuracy on probabilities input of shape (N, L) + def _test(): + acc = Accuracy(mode=Accuracy.Mode.PROBABILITIES) + + y_pred = torch.rand(size=(10, 5)) + y = torch.randint(0, 2, size=(10, 5)).long() + acc.update((y_pred, y)) + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().round().ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + acc.reset() + y_pred = torch.rand(size=(10, 1, 5)) + y = torch.randint(0, 2, size=(10, 1, 5)).long() + acc.update((y_pred, y)) + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().round().ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # Batched Updates + acc.reset() + y_pred = torch.rand(size=(100, 8)) + y = torch.randint(0, 2, size=(100, 8)).long() + + n_iters = 16 + batch_size = y.shape[0] // n_iters + 1 + + for i in range(n_iters): + idx = i * batch_size + acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) + + np_y = y.numpy().ravel() + np_y_pred = y_pred.numpy().round().ravel() + assert acc._type == "binary" + assert isinstance(acc.compute(), float) + assert accuracy_score(np_y, np_y_pred) == pytest.approx(acc.compute()) + + # check multiple random inputs as random exact occurencies are rare + for _ in range(10): + _test() + + def test_binary_input_NHW(): # Binary accuracy on input of shape (N, H, W, ...) def _test():