
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/filters/plot_j_invariant_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_filters_plot_j_invariant_tutorial.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_filters_plot_j_invariant_tutorial.py:


=========================================================
Full Tutorial on Calibrating Denoisers Using J-Invariance
=========================================================

In this example, we show how to find an optimally calibrated
version of any denoising algorithm.

The calibration method is based on the `noise2self` algorithm of [1]_.

.. [1] J. Batson & L. Royer. Noise2Self: Blind Denoising by Self-Supervision,
       International Conference on Machine Learning, pp. 524-533 (2019).

.. seealso::
   A simple example of the method is given in
   :ref:`sphx_glr_auto_examples_filters_plot_j_invariant.py`.

.. GENERATED FROM PYTHON SOURCE LINES 20-21

Calibrating a wavelet denoiser

.. GENERATED FROM PYTHON SOURCE LINES 21-77

.. code-block:: Python


    import numpy as np
    from matplotlib import pyplot as plt
    from matplotlib import gridspec

    from skimage.data import chelsea, hubble_deep_field
    from skimage.metrics import mean_squared_error as mse
    from skimage.metrics import peak_signal_noise_ratio as psnr
    from skimage.restoration import (
        calibrate_denoiser,
        denoise_wavelet,
        denoise_tv_chambolle,
        denoise_nl_means,
        estimate_sigma,
    )
    from skimage.util import img_as_float, random_noise
    from skimage.color import rgb2gray
    from functools import partial

    _denoise_wavelet = partial(denoise_wavelet, rescale_sigma=True)

    image = img_as_float(chelsea())
    sigma = 0.2
    noisy = random_noise(image, var=sigma**2)

    # Parameters to test when calibrating the denoising algorithm
    parameter_ranges = {
        'sigma': np.arange(0.1, 0.3, 0.02),
        'wavelet': ['db1', 'db2'],
        'convert2ycbcr': [True, False],
        'channel_axis': [-1],
    }

    # Denoised image using default parameters of `denoise_wavelet`
    default_output = denoise_wavelet(noisy, channel_axis=-1, rescale_sigma=True)

    # Calibrate denoiser
    calibrated_denoiser = calibrate_denoiser(
        noisy, _denoise_wavelet, denoise_parameters=parameter_ranges
    )

    # Denoised image using calibrated denoiser
    calibrated_output = calibrated_denoiser(noisy)

    fig, axes = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(15, 5))

    for ax, img, title in zip(
        axes,
        [noisy, default_output, calibrated_output],
        ['Noisy Image', 'Denoised (Default)', 'Denoised (Calibrated)'],
    ):
        ax.imshow(img)
        ax.set_title(title)
        ax.set_yticks([])
        ax.set_xticks([])




.. image-sg:: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_001.png
   :alt: Noisy Image, Denoised (Default), Denoised (Calibrated)
   :srcset: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.07628155468205111..0.9396137159825669].




.. GENERATED FROM PYTHON SOURCE LINES 78-102

The Self-Supervised Loss and J-Invariance
=========================================
The key to this calibration method is the notion of J-invariance. A denoising
function is J-invariant if the prediction it makes for each pixel does
not depend on the value of that pixel in the noisy input image. The prediction
for each pixel may instead use all the relevant information contained in the
rest of the image, which is typically quite significant. Any function
can be converted into a J-invariant one using a simple masking procedure,
as described in [1]_.

The pixel-wise error of a J-invariant denoiser is uncorrelated
with the noise, provided the noise has zero mean and is independent from
pixel to pixel. Consequently, the mean squared difference between the denoised
image and the noisy image, the *self-supervised loss*, equals the mean squared
difference between the denoised image and the original clean image, the
*ground-truth loss*, up to an additive constant (the noise variance).
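
In symbols, write :math:`x` for the clean image, :math:`y = x + n` for the
noisy image, and :math:`f` for a J-invariant denoiser. Following [1]_, if the
noise :math:`n` has zero mean given :math:`x` and is independent from pixel to
pixel, then

.. math::

   \mathbb{E}\,\|f(y) - y\|^2 = \mathbb{E}\,\|f(y) - x\|^2 + \mathbb{E}\,\|y - x\|^2 .

Expanding :math:`\|(f(y) - x) - n\|^2` leaves a cross term
:math:`-2\,\langle f(y) - x,\, n \rangle`. Because the :math:`j`-th output of
:math:`f` does not depend on :math:`y_j`, it is independent of :math:`n_j`
given :math:`x`, so every entry of that cross term has zero expectation. The
remaining term, :math:`\mathbb{E}\,\|y - x\|^2`, is the total noise variance
and does not depend on :math:`f`; dividing both sides by the number of pixels
turns the squared norms into the mean squared errors computed in this tutorial.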

This means that the best J-invariant denoiser for a given image can
be found using the noisy data alone, by selecting the denoiser minimizing
the self-supervised loss. Below, we demonstrate this
for a family of wavelet denoisers with a varying `sigma` parameter. The
self-supervised loss (solid blue line) and the ground-truth loss (dashed
blue line) differ by a roughly constant offset, so they have the same shape
and the same minimizer.
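
As a NumPy-only illustration of this selection procedure, the sketch below
builds a small family of J-invariant denoisers by hand. The helper
`donut_mean` is hypothetical (it is not part of scikit-image): it averages a
square window around each pixel while skipping the pixel itself. The periodic
synthetic image, the Gaussian noise level, and the wraparound boundary handling
of `np.roll` are assumptions made so the example is self-contained. If the
reasoning above holds, the gap between the two losses should be close to the
noise variance for every radius, so both losses pick the same radius.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(0)

    # Hypothetical periodic test image, so np.roll's wraparound adds no seams.
    i = np.arange(256)
    w = 2 * np.pi * 12 / 256  # 12 full periods across the image
    clean = 0.5 + 0.25 * np.outer(np.sin(w * i), np.cos(w * i))
    sigma = 0.1
    noisy = clean + rng.normal(0, sigma, clean.shape)


    def donut_mean(img, radius):
        """J-invariant denoiser: mean over a (2r+1) x (2r+1) window that skips
        the center pixel, so a pixel's prediction never uses its own value."""
        total = np.zeros_like(img)
        for di in range(-radius, radius + 1):
            for dj in range(-radius, radius + 1):
                if di == 0 and dj == 0:
                    continue  # exclude the pixel being predicted
                total += np.roll(img, (di, dj), axis=(0, 1))
        return total / ((2 * radius + 1) ** 2 - 1)


    def mse(a, b):
        return np.mean((a - b) ** 2)


    radii = [1, 2, 3, 4]
    denoised = [donut_mean(noisy, r) for r in radii]
    ss_loss = [mse(d, noisy) for d in denoised]  # needs only the noisy image
    gt_loss = [mse(d, clean) for d in denoised]  # needs the clean image

    # The gap between the two losses should be close to sigma**2 for every radius.
    print(np.array(ss_loss) - np.array(gt_loss))
    print('radius chosen from noisy data:', radii[int(np.argmin(ss_loss))])
    print('radius chosen from clean data:', radii[int(np.argmin(gt_loss))])

    # The identity map is not J-invariant: its self-supervised loss is 0
    # even though it removes no noise at all.
    print(mse(noisy, noisy), mse(noisy, clean))

The identity map at the end shows why J-invariance matters: without it, the
self-supervised loss rewards simply copying the noisy input.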


.. GENERATED FROM PYTHON SOURCE LINES 102-179

.. code-block:: Python


    from skimage.restoration import denoise_invariant

    sigma_range = np.arange(sigma / 2, 1.5 * sigma, 0.025)

    parameters_tested = [
        {'sigma': sigma, 'convert2ycbcr': True, 'wavelet': 'db2', 'channel_axis': -1}
        for sigma in sigma_range
    ]

    denoised_invariant = [
        denoise_invariant(noisy, _denoise_wavelet, denoiser_kwargs=params)
        for params in parameters_tested
    ]

    self_supervised_loss = [mse(img, noisy) for img in denoised_invariant]
    ground_truth_loss = [mse(img, image) for img in denoised_invariant]

    opt_idx = np.argmin(self_supervised_loss)
    plot_idx = [0, opt_idx, len(sigma_range) - 1]


    def get_inset(x):
        return x[25:225, 100:300]


    plt.figure(figsize=(10, 12))

    gs = gridspec.GridSpec(3, 3)
    ax1 = plt.subplot(gs[0, :])
    ax2 = plt.subplot(gs[1, :])
    ax_image = [plt.subplot(gs[2, i]) for i in range(3)]

    ax1.plot(sigma_range, self_supervised_loss, color='C0', label='Self-Supervised Loss')
    ax1.scatter(
        sigma_range[opt_idx],
        self_supervised_loss[opt_idx] + 0.0003,
        marker='v',
        color='red',
        label='optimal sigma',
    )

    ax1.set_ylabel('MSE')
    ax1.set_xticks([])
    ax1.legend()
    ax1.set_title('Self-Supervised Loss')

    ax2.plot(
        sigma_range,
        ground_truth_loss,
        color='C0',
        linestyle='--',
        label='Ground Truth Loss',
    )
    ax2.scatter(
        sigma_range[opt_idx],
        ground_truth_loss[opt_idx] + 0.0003,
        marker='v',
        color='red',
        label='optimal sigma',
    )
    ax2.set_ylabel('MSE')
    ax2.legend()
    ax2.set_xlabel('sigma')
    ax2.set_title('Ground-Truth Loss')

    for i in range(3):
        ax = ax_image[i]
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(get_inset(denoised_invariant[plot_idx[i]]))
        ax.set_xlabel('sigma = ' + str(np.round(sigma_range[plot_idx[i]], 2)))

    for spine in ax_image[1].spines.values():
        spine.set_edgecolor('red')
        spine.set_linewidth(5)




.. image-sg:: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_002.png
   :alt: Self-Supervised Loss, Ground-Truth Loss
   :srcset: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.054405485278038096..1.0024969736511142].
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.07538645784206764..0.9358830138623191].
    Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). Got range [-0.06355978060341062..0.9504132575392662].




.. GENERATED FROM PYTHON SOURCE LINES 180-195

Conversion to J-Invariance
=========================================
The function `denoise_invariant` acts as a J-invariant version of a
given denoiser. It works by masking a fraction of the pixels, replacing them
with values interpolated from their neighbors, running the original denoiser,
and keeping only the values it returns at the masked pixels. Repeating this
for a set of masks that together cover every pixel yields a fully J-invariant
output.
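
The following is a minimal NumPy-only sketch of this masking idea. It is not
scikit-image's implementation of `denoise_invariant`: the helpers
`interpolate_masked`, `make_invariant`, and `box_blur` are hypothetical, and
the 4-neighbor interpolation, the grid masks, and the stride of 4 are
assumptions chosen for illustration. The check at the end perturbs one input
pixel and confirms that the wrapped denoiser's output at that pixel does not
change, while the plain 3x3 blur's output does.

.. code-block:: Python

    import numpy as np


    def interpolate_masked(img, mask):
        """Replace each masked pixel by the mean of its 4 neighbors.

        'reflect' padding never reuses the edge pixel itself, and with a grid
        stride >= 2 no masked pixel is a 4-neighbor of another masked pixel,
        so the replacement values never depend on any masked input value.
        """
        p = np.pad(img, 1, mode='reflect')
        neighbor_mean = (p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:]) / 4
        out = img.copy()
        out[mask] = neighbor_mean[mask]
        return out


    def make_invariant(denoiser, stride=4):
        """Wrap a 2D denoiser so that each output pixel ignores its own input value."""

        def invariant_denoiser(img):
            output = np.empty_like(img)
            # The stride x stride grid masks partition the image: each pixel is masked once.
            for di in range(stride):
                for dj in range(stride):
                    mask = np.zeros(img.shape, dtype=bool)
                    mask[di::stride, dj::stride] = True
                    denoised = denoiser(interpolate_masked(img, mask))
                    output[mask] = denoised[mask]  # keep predictions at masked pixels only
            return output

        return invariant_denoiser


    def box_blur(img):
        """A 3x3 mean filter; not J-invariant, since it averages in the center pixel."""
        p = np.pad(img, 1, mode='reflect')
        h, w = img.shape
        return sum(p[a:a + h, b:b + w] for a in range(3) for b in range(3)) / 9


    rng = np.random.default_rng(0)
    noisy = rng.random((64, 64))
    perturbed = noisy.copy()
    perturbed[10, 10] += 1.0  # change a single input pixel

    invariant_blur = make_invariant(box_blur)
    print(invariant_blur(noisy)[10, 10] == invariant_blur(perturbed)[10, 10])  # True
    print(box_blur(noisy)[10, 10] == box_blur(perturbed)[10, 10])  # False

Each pixel's output comes from the one pass in which that pixel was masked and
replaced, so its own input value never reaches the denoiser; the cost is one
denoiser call per mask (16 calls for a stride of 4).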

For any given set of parameters, the J-invariant version of a denoiser
is different from the original denoiser, but it is not necessarily better
or worse. In the plot below, we see that, for the test image of a cat,
the J-invariant version of a wavelet denoiser is significantly better
than the original at small values of the noise-level parameter `sigma` and
imperceptibly worse at larger values.


.. GENERATED FROM PYTHON SOURCE LINES 195-232

.. code-block:: Python


    parameters_tested = [
        {'sigma': sigma, 'convert2ycbcr': True, 'wavelet': 'db2', 'channel_axis': -1}
        for sigma in sigma_range
    ]

    denoised_original = [_denoise_wavelet(noisy, **params) for params in parameters_tested]

    ground_truth_loss_invariant = [mse(img, image) for img in denoised_invariant]
    ground_truth_loss_original = [mse(img, image) for img in denoised_original]

    fig, ax = plt.subplots(figsize=(10, 4))

    ax.plot(
        sigma_range,
        ground_truth_loss_invariant,
        color='C0',
        linestyle='--',
        label='J-invariant',
    )
    ax.plot(
        sigma_range,
        ground_truth_loss_original,
        color='C1',
        linestyle='--',
        label='Original',
    )
    ax.scatter(
        sigma_range[opt_idx], ground_truth_loss[opt_idx] + 0.001, marker='v', color='red'
    )
    ax.legend()
    ax.set_title(
        'J-Invariant Denoiser Has Comparable Or Better Performance At Same Parameters'
    )
    ax.set_ylabel('MSE')
    ax.set_xlabel('sigma')




.. image-sg:: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_003.png
   :alt: J-Invariant Denoiser Has Comparable Or Better Performance At Same Parameters
   :srcset: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_003.png
   :class: sphx-glr-single-img





.. GENERATED FROM PYTHON SOURCE LINES 233-246

Comparing Different Classes of Denoiser
=========================================
The self-supervised loss can be used to compare different classes of
denoiser in addition to choosing parameters for a single class.
This lets the user choose, in an unbiased way, both the best class of
denoiser for a given image and the best parameters for that class.

Below, we show this for an image of the Hubble Deep Field with significant
speckle noise added. For each of three families of denoisers (non-local means,
wavelet, and TV norm), the J-invariant calibrated denoiser achieves a higher
PSNR than the original denoiser run with the same calibrated parameters
(labeled *Default* below). Additionally, the self-supervised loss correctly
identifies the TV norm denoiser as the best for this noisy image.
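
The same selection can be sketched without scikit-image: compute the
self-supervised loss for every (class, parameter) pair and keep the overall
minimum. In this NumPy-only sketch, the two classes (`donut mean` and
`donut median`, which average or take the median over a window that skips the
center pixel) are hypothetical hand-written J-invariant denoisers standing in
for non-local means, wavelet, and TV denoising; the synthetic image and the
additive Gaussian noise are also assumptions.

.. code-block:: Python

    import numpy as np

    rng = np.random.default_rng(1)
    i = np.arange(256)
    w = 2 * np.pi * 12 / 256  # hypothetical periodic test image
    clean = 0.5 + 0.25 * np.outer(np.sin(w * i), np.cos(w * i))
    noisy = clean + rng.normal(0, 0.1, clean.shape)


    def donut_stack(img, radius):
        """Shifted copies of img over a (2r+1)^2 window, center excluded (wraparound)."""
        return np.stack([
            np.roll(img, (di, dj), axis=(0, 1))
            for di in range(-radius, radius + 1)
            for dj in range(-radius, radius + 1)
            if (di, dj) != (0, 0)
        ])


    # Two hypothetical classes of J-invariant denoiser, each with one parameter.
    families = {
        'donut mean': lambda img, radius: donut_stack(img, radius).mean(axis=0),
        'donut median': lambda img, radius: np.median(donut_stack(img, radius), axis=0),
    }


    def mse(a, b):
        return np.mean((a - b) ** 2)


    # Self-supervised loss for every (class, parameter) pair; uses only `noisy`.
    results = {
        (name, radius): mse(denoise(noisy, radius), noisy)
        for name, denoise in families.items()
        for radius in (1, 2, 3)
    }
    best = min(results, key=results.get)
    print('selected from noisy data:', best)

    # Ground-truth check, possible here only because the clean image is known.
    gt = {key: mse(families[key[0]](noisy, key[1]), clean) for key in results}
    print('best by ground truth:', min(gt, key=gt.get))

In the tutorial code below, `calibrate_denoiser` with `extra_output=True`
returns the tested parameters and their self-supervised losses, which play the
role of `results` here.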


.. GENERATED FROM PYTHON SOURCE LINES 246-350

.. code-block:: Python


    image = rgb2gray(img_as_float(hubble_deep_field()[100:250, 50:300]))

    sigma = 0.4
    noisy = random_noise(image, mode='speckle', var=sigma**2)

    parameter_ranges_tv = {'weight': np.arange(0.01, 0.3, 0.02)}
    _, (parameters_tested_tv, losses_tv) = calibrate_denoiser(
        noisy,
        denoise_tv_chambolle,
        denoise_parameters=parameter_ranges_tv,
        extra_output=True,
    )
    print(f'Minimum self-supervised loss TV: {np.min(losses_tv):.4f}')

    best_parameters_tv = parameters_tested_tv[np.argmin(losses_tv)]
    denoised_calibrated_tv = denoise_invariant(
        noisy, denoise_tv_chambolle, denoiser_kwargs=best_parameters_tv
    )
    denoised_default_tv = denoise_tv_chambolle(noisy, **best_parameters_tv)

    psnr_calibrated_tv = psnr(image, denoised_calibrated_tv)
    psnr_default_tv = psnr(image, denoised_default_tv)

    parameter_ranges_wavelet = {'sigma': np.arange(0.01, 0.3, 0.03)}
    _, (parameters_tested_wavelet, losses_wavelet) = calibrate_denoiser(
        noisy, _denoise_wavelet, parameter_ranges_wavelet, extra_output=True
    )
    print(f'Minimum self-supervised loss wavelet: {np.min(losses_wavelet):.4f}')

    best_parameters_wavelet = parameters_tested_wavelet[np.argmin(losses_wavelet)]
    denoised_calibrated_wavelet = denoise_invariant(
        noisy, _denoise_wavelet, denoiser_kwargs=best_parameters_wavelet
    )
    denoised_default_wavelet = _denoise_wavelet(noisy, **best_parameters_wavelet)

    psnr_calibrated_wavelet = psnr(image, denoised_calibrated_wavelet)
    psnr_default_wavelet = psnr(image, denoised_default_wavelet)

    # Only `sigma` is calibrated for NL means; `h` keeps its default value.
    parameter_ranges_nl = {'sigma': np.arange(0.01, 0.3, 0.03)}
    _, (parameters_tested_nl, losses_nl) = calibrate_denoiser(
        noisy, denoise_nl_means, parameter_ranges_nl, extra_output=True
    )
    print(f'Minimum self-supervised loss NL means: {np.min(losses_nl):.4f}')

    best_parameters_nl = parameters_tested_nl[np.argmin(losses_nl)]
    denoised_calibrated_nl = denoise_invariant(
        noisy, denoise_nl_means, denoiser_kwargs=best_parameters_nl
    )
    denoised_default_nl = denoise_nl_means(noisy, **best_parameters_nl)

    psnr_calibrated_nl = psnr(image, denoised_calibrated_nl)
    psnr_default_nl = psnr(image, denoised_default_nl)

    print('                       PSNR')
    print(f'NL means (Default)   : {psnr_default_nl:.1f}')
    print(f'NL means (Calibrated): {psnr_calibrated_nl:.1f}')
    print(f'Wavelet  (Default)   : {psnr_default_wavelet:.1f}')
    print(f'Wavelet  (Calibrated): {psnr_calibrated_wavelet:.1f}')
    print(f'TV norm  (Default)   : {psnr_default_tv:.1f}')
    print(f'TV norm  (Calibrated): {psnr_calibrated_tv:.1f}')

    plt.subplots(figsize=(10, 12))
    plt.imshow(noisy, cmap='Greys_r')
    plt.xticks([])
    plt.yticks([])
    plt.title('Noisy Image')


    def get_inset(x):
        return x[0:100, -140:]


    fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(15, 8))

    for ax in axes.ravel():
        ax.set_xticks([])
        ax.set_yticks([])

    axes[0, 0].imshow(get_inset(denoised_default_nl), cmap='Greys_r')
    axes[0, 0].set_title('NL Means Default')
    axes[1, 0].imshow(get_inset(denoised_calibrated_nl), cmap='Greys_r')
    axes[1, 0].set_title('NL Means Calibrated')
    axes[0, 1].imshow(get_inset(denoised_default_wavelet), cmap='Greys_r')
    axes[0, 1].set_title('Wavelet Default')
    axes[1, 1].imshow(get_inset(denoised_calibrated_wavelet), cmap='Greys_r')
    axes[1, 1].set_title('Wavelet Calibrated')
    axes[0, 2].imshow(get_inset(denoised_default_tv), cmap='Greys_r')
    axes[0, 2].set_title('TV Norm Default')
    axes[1, 2].imshow(get_inset(denoised_calibrated_tv), cmap='Greys_r')
    axes[1, 2].set_title('TV Norm Calibrated')

    for spine in axes[1, 2].spines.values():
        spine.set_edgecolor('red')
        spine.set_linewidth(5)

    plt.show()



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_004.png
         :alt: Noisy Image
         :srcset: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_005.png
         :alt: NL Means Default, Wavelet Default, TV Norm Default, NL Means Calibrated, Wavelet Calibrated, TV Norm Calibrated
         :srcset: /auto_examples/filters/images/sphx_glr_plot_j_invariant_tutorial_005.png
         :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Minimum self-supervised loss TV: 0.0036
    Minimum self-supervised loss wavelet: 0.0037
    Minimum self-supervised loss NL means: 0.0042
                           PSNR
    NL means (Default)   : 25.2
    NL means (Calibrated): 26.9
    Wavelet  (Default)   : 25.6
    Wavelet  (Calibrated): 28.7
    TV norm  (Default)   : 27.5
    TV norm  (Calibrated): 29.0





.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 11.857 seconds)


.. _sphx_glr_download_auto_examples_filters_plot_j_invariant_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_j_invariant_tutorial.ipynb <plot_j_invariant_tutorial.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_j_invariant_tutorial.py <plot_j_invariant_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_j_invariant_tutorial.zip <plot_j_invariant_tutorial.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
