iunets.networks module

class iunets.networks.iUNet(channels, architecture, dim, create_module_fn=<function create_standard_module>, module_kwargs=None, learnable_resampling=True, resampling_stride=2, resampling_method='cayley', resampling_init='haar', resampling_kwargs=None, learnable_channel_mixing=True, channel_mixing_freq=-1, channel_mixing_method='cayley', channel_mixing_kwargs=None, padding_mode='constant', padding_value=0, revert_input_padding=True, disable_custom_gradient=False, verbose=1, **kwargs)

Bases: torch.nn.modules.module.Module

Fully-invertible U-Net (iUNet).

This model can be used for memory-efficient backpropagation, e.g. in high-dimensional (such as 3D) segmentation tasks.

Parameters
  • channels (Tuple[int, …]) – The number of channels at each resolution. For example: if one wants 5 resolution levels (i.e. 4 up-/downsampling operations), it should be a tuple of 5 numbers, e.g. (32,64,128,256,384).

  • architecture (Tuple[int, …]) –

    Determines the number of invertible layers at each resolution (both left and right), e.g. [2,3,4] results in the following structure:

    2-----2
     3---3
      4-4
    

    Must be the same length as channels.

  • dim (int) – Either 1, 2 or 3, signifying whether a 1D, 2D or 3D invertible U-Net should be created.

  • create_module_fn (Callable[[int, Optional[dict]], Module]) – Function which outputs an invertible layer. This layer should be a torch.nn.Module with a method forward(*x) and a method inverse(*x). create_module_fn should have the signature create_module_fn(in_channels, **kwargs). Additional keyword arguments passed on via kwargs are dim (whether this is a 1D, 2D or 3D iUNet), the coordinates of the specific module within the iUNet (branch, level and module_index) as well as architecture. By default, this creates an additive coupling layer, whose block consists of a number of convolutional layers, followed by an instance normalization layer and a leaky ReLU activation function. The number of blocks can be controlled by setting "depth" in module_kwargs; it defaults to 2.

  • module_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to create_module_fn.

  • learnable_resampling (bool) – Whether to train the invertible learnable up- and downsampling or to leave it at the initialized values. Defaults to True.

  • resampling_stride (int) – Controls the stride of the invertible up- and downsampling. The format can be either a single integer, a single tuple (whose length corresponds to the number of spatial dimensions of the data), or a list containing either of the previous two options (whose length has to equal the number of downsampling operations). For example: 2 results in an up-/downsampling with a factor of 2 along each dimension; (2,1,4) applies (at every resampling) a factor of 2, 1 and 4 to the height, width and depth dimensions respectively; whereas for a 3D iUNet with 3 up-/downsampling stages, [(2,1,3), (2,2,2), (4,3,1)] results in different strides at different up-/downsampling stages.

  • resampling_method (str) – Chooses the method for parametrizing orthogonal matrices for invertible up- and downsampling. Can be either "exp" (i.e. exponentiation of skew-symmetric matrices) or "cayley" (i.e. the Cayley transform, acting on skew-symmetric matrices). Defaults to "cayley".

  • resampling_init (Union[str, ndarray, Tensor]) – Sets the initialization for the learnable up- and downsampling operators. Can be "haar", "pixel_shuffle" (aliases: "squeeze", "zeros"), a specific torch.Tensor or a numpy.ndarray. Defaults to "haar", i.e. the Haar transform.

  • resampling_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to the invertible up- and downsampling modules.

  • channel_mixing_freq (int) – How often an invertible channel mixing layer is applied, which (in 2D) is an orthogonal 1x1-convolution. -1 means that it is only applied before the channel splitting and before the recombination in the decoder branch. For any other n, every n-th module is followed by an invertible channel mixing; in particular, 0 deactivates invertible channel mixing. Defaults to -1.

  • channel_mixing_method (str) – How the orthogonal matrix for invertible channel mixing is parametrized. Same as resampling_method. Defaults to "cayley".

  • channel_mixing_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to the invertible channel mixing modules.

  • padding_mode (Optional[str]) – If downsampling is not possible without residue (e.g. when halving odd-valued spatial resolutions), the input gets padded so that the padded input can be processed invertibly. padding_mode takes the same keywords as the mode argument of torch.nn.functional.pad. If set to None, this behavior is deactivated. Defaults to "constant".

  • padding_value (int) – If padding_mode is set to "constant", this is the value that the input is padded with, e.g. 0. Defaults to 0.

  • revert_input_padding (bool) – Whether to revert the input padding in the output, such that the input resolution is preserved, even when padding is required. Defaults to True.

  • disable_custom_gradient (bool) – If set to True, normal backpropagation (i.e. storing activations instead of reconstructing activations) is used. Defaults to False.

  • verbose (int) – Level of verbosity. Currently only 0 (no warnings) or 1, which includes warnings. Defaults to 1.
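
A minimal construction sketch based on the signature and parameter descriptions above; the batch size and spatial size 128 are arbitrary illustration values, and the channel counts are chosen only for the example:

    import torch
    from iunets.networks import iUNet

    # Three resolution levels (two invertible up-/downsampling stages),
    # with the per-level module counts taken from the architecture
    # example above.
    model = iUNet(
        channels=(32, 64, 128),   # one entry per resolution level
        architecture=(2, 3, 4),   # must have the same length as channels
        dim=2,                    # 2D iUNet
        resampling_stride=2,      # factor 2 along each spatial dimension
    )
    model.print_layout()

    # Fully invertible: the output has the same shape as the input.
    x = torch.randn(1, 32, 128, 128)  # channels[0] input channels
    y = model(x)
    print(y.shape)  # torch.Size([1, 32, 128, 128])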

get_padding(x)

Calculates the required padding for the input.

Parameters

x (torch.Tensor) – The input for which the required padding is calculated.

revert_padding(x, padding)

Reverses a given padding.

Parameters
  • x (Tensor) – The image that was originally padded.

  • padding (List[int]) – The padding that is removed from x.

pad(x, padded_shape=None, padding=None)

Applies the chosen padding to the input, if required.
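
A sketch of how the three padding helpers might be chained, assuming get_padding returns the padded shape together with the padding list that pad and revert_padding expect, and that pad returns the padded tensor (these return values are not spelled out above):

    import torch
    from iunets.networks import iUNet

    model = iUNet(channels=(32, 64), architecture=(2, 2), dim=2)

    # An odd spatial size cannot be halved without residue,
    # so padding is required before downsampling.
    x = torch.randn(1, 32, 63, 63)

    # Assumed return values: the target (padded) shape and the padding list.
    padded_shape, padding = model.get_padding(x)
    x_padded = model.pad(x, padded_shape=padded_shape, padding=padding)

    # Remove the padding again to recover the original resolution.
    x_restored = model.revert_padding(x_padded, padding)
    print(x_restored.shape)  # torch.Size([1, 32, 63, 63])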

encode(x, use_padding=False)

Encodes x, i.e. applies the contractive part of the iUNet.

decode(*codes)

Applies the expansive, i.e. decoding, portion of the iUNet.
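
A sketch of using the two halves separately, assuming encode returns the multi-resolution codes in the form that decode(*codes) consumes (the exact return structure is not documented above):

    import torch
    from iunets.networks import iUNet

    model = iUNet(channels=(32, 64, 128), architecture=(2, 2, 2), dim=2)
    x = torch.randn(1, 32, 64, 64)  # divisible by the strides, so no padding

    # Contractive (encoder) half of the iUNet.
    codes = model.encode(x)

    # Expansive (decoder) half, consuming the codes produced by encode.
    y = model.decode(*codes)
    print(y.shape)  # torch.Size([1, 32, 64, 64])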

forward(x)

Applies the forward mapping of the iUNet to x.

Parameters

x (torch.Tensor) – The input to which the forward mapping is applied.

decoder_inverse(x, use_padding=False)

Applies the inverse of the decoder portion of the iUNet.

encoder_inverse(*codes)

Applies the inverse of the encoder portion of the iUNet.

inverse(x)

Applies the inverse of the iUNet to x.

Parameters

x (torch.Tensor) – The input to which the inverse mapping is applied.
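
A sketch of the round trip implied by forward and inverse, assuming the spatial size is divisible by the product of the resampling strides so that no padding is triggered:

    import torch
    from iunets.networks import iUNet

    model = iUNet(channels=(32, 64, 128), architecture=(2, 2, 2), dim=2)
    x = torch.randn(1, 32, 64, 64)

    y = model(x)              # forward mapping of the iUNet
    x_rec = model.inverse(y)  # inverse mapping of the iUNet

    # Up to floating-point error, the inverse recovers the input.
    print(torch.allclose(x, x_rec, atol=1e-4))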

print_layout()

Prints the layout of the iUNet.