iunets.networks module
-
class iunets.networks.iUNet(channels, architecture, dim, create_module_fn=<function create_standard_module>, module_kwargs=None, learnable_resampling=True, resampling_stride=2, resampling_method='cayley', resampling_init='haar', resampling_kwargs=None, learnable_channel_mixing=True, channel_mixing_freq=-1, channel_mixing_method='cayley', channel_mixing_kwargs=None, padding_mode='constant', padding_value=0, revert_input_padding=True, disable_custom_gradient=False, verbose=1, **kwargs)
Bases: torch.nn.modules.module.Module
Fully-invertible U-Net (iUNet).
This model can be used for memory-efficient backpropagation, e.g. in high-dimensional (such as 3D) segmentation tasks.
- Parameters
channels (Tuple[int, …]) – The number of channels at each resolution. For example: If one wants 5 resolution levels (i.e. 4 up-/downsampling operations), it should be a tuple of 5 numbers, e.g. (32,64,128,256,384).
architecture (Tuple[int, …]) – Determines the number of invertible layers at each resolution (both left and right), e.g. [2,3,4] results in the following structure:
    2-----2
     3---3
      4-4
Must be the same length as channels.
dim (int) – Either 1, 2 or 3, signifying whether a 1D, 2D or 3D invertible U-Net should be created.
create_module_fn (Callable[[int, Optional[dict]], Module]) – Function which outputs an invertible layer. This layer should be a torch.nn.Module with a method forward(*x) and a method inverse(*x). create_module_fn should have the signature create_module_fn(in_channels, **kwargs). Additional keyword arguments passed on via kwargs are dim (whether this is a 1D, 2D or 3D iUNet), the coordinates of the specific module within the iUNet (branch, level and module_index) as well as architecture. By default, this creates an additive coupling layer, whose block consists of a number of convolutional layers, followed by an instance normalization layer and a leaky ReLU activation function. The number of blocks can be controlled by setting "depth" in module_kwargs, whose default value is 2.
module_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to create_module_fn.
learnable_resampling (bool) – Whether to train the invertible learnable up- and downsampling or to leave it at the initialized values. Defaults to True.
resampling_stride (int) – Controls the stride of the invertible up- and downsampling. The format can be either a single integer, a single tuple (where the length corresponds to the spatial dimensions of the data), or a list containing either of the last two options (where the length of the list has to be equal to the number of downsampling operations). For example: 2 would result in an up-/downsampling with a factor of 2 along each dimension; (2,1,4) would apply (at every resampling) a factor of 2, 1 and 4 for the height, width and depth dimensions respectively, whereas for a 3D iUNet with 3 up-/downsampling stages, [(2,1,3), (2,2,2), (4,3,1)] would result in different strides at different up-/downsampling stages.
resampling_method (str) – Chooses the method for parametrizing orthogonal matrices for invertible up- and downsampling. Can be either "exp" (i.e. exponentiation of skew-symmetric matrices) or "cayley" (i.e. the Cayley transform, acting on skew-symmetric matrices). Defaults to "cayley".
resampling_init (Union[str, ndarray, Tensor]) – Sets the initialization for the learnable up- and downsampling operators. Can be "haar", "pixel_shuffle" (aliases: "squeeze", "zeros"), a specific torch.Tensor or a numpy.ndarray. Defaults to "haar", i.e. the Haar transform.
resampling_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to the invertible up- and downsampling modules.
channel_mixing_freq (int) – How often an invertible channel mixing is applied, which (in 2D) is an orthogonal 1x1 convolution. -1 means that this will only be applied before the channel splitting and before the recombination in the decoder branch. For any other n, this means that every n-th module is followed by an invertible channel mixing. In particular, 0 deactivates the usage of invertible channel mixing. Defaults to -1.
channel_mixing_method (str) – How the orthogonal matrix for invertible channel mixing is parametrized. Same as resampling_method. Defaults to "cayley".
channel_mixing_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to the invertible channel mixing modules.
padding_mode (Optional[str]) – If downsampling is not possible without residue (e.g. when halving odd-valued spatial resolutions), the input gets padded to allow for invertibility of the padded input. padding_mode takes the same keywords as torch.nn.functional.pad for mode. If set to None, this behavior is deactivated. Defaults to "constant".
padding_value (int) – If padding_mode is set to "constant", this is the value that the input is padded with, e.g. 0. Defaults to 0.
revert_input_padding (bool) – Whether to revert the input padding in the output, such that the input resolution is preserved, even when padding is required. Defaults to True.
disable_custom_gradient (bool) – If set to True, normal backpropagation (i.e. storing activations instead of reconstructing activations) is used. Defaults to False.
verbose (int) – Level of verbosity. Currently only 0 (no warnings) and 1 (warnings enabled) are supported. Defaults to 1.
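The three accepted formats for resampling_stride (single integer, single tuple, list with one entry per resampling stage) can be pictured as being expanded into one tuple per stage. The following is only a sketch of that expansion in plain Python; normalize_stride is a hypothetical helper written for illustration and is not part of iunets, whose internal handling may differ.

```python
def normalize_stride(stride, dim, num_stages):
    """Return one `dim`-tuple of strides per resampling stage (hypothetical helper)."""

    def expand(entry):
        # A single integer applies the same factor along every spatial dimension.
        if isinstance(entry, int):
            return (entry,) * dim
        # A tuple must name exactly one factor per spatial dimension.
        if isinstance(entry, tuple) and len(entry) == dim:
            return entry
        raise ValueError(f"cannot interpret stride entry {entry!r} for dim={dim}")

    if isinstance(stride, list):
        # A list needs one entry (integer or tuple) per resampling stage.
        if len(stride) != num_stages:
            raise ValueError("a list needs one entry per resampling stage")
        return [expand(entry) for entry in stride]
    # A bare integer or tuple is reused at every stage.
    return [expand(stride)] * num_stages


print(normalize_stride(2, dim=3, num_stages=3))
print(normalize_stride((2, 1, 4), dim=3, num_stages=3))
print(normalize_stride([(2, 1, 3), (2, 2, 2), (4, 3, 1)], dim=3, num_stages=3))
```

Expanding everything to a per-stage list up front means later code only has to deal with a single format.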
-
get_padding(x)
Calculates the required padding for the input.
- Parameters
x (torch.Tensor) –
-
revert_padding(x, padding)
Reverses a given padding.
- Parameters
x (Tensor) – The image that was originally padded.
padding (List[int]) – The padding that is removed from x.
-
pad(x, padded_shape=None, padding=None)
Applies the chosen padding to the input, if required.
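To illustrate why padding is needed at all: a spatial size can only pass through every downsampling stage without residue if it is divisible by the product of the strides along that axis. The sketch below computes the missing amount per axis and shows a pad/revert round trip on a 1D list. All function names are hypothetical, the padding is appended at the end of each axis as a simplifying assumption, and the actual padding layout used by get_padding, pad and revert_padding may differ.

```python
import math


def required_padding(spatial_shape, strides_per_stage):
    """Extra elements needed per axis so that every stage divides evenly."""
    padding = []
    for axis, size in enumerate(spatial_shape):
        # Divisibility by the product of all strides along this axis is
        # exactly what is needed for each successive stage to divide evenly.
        total_stride = math.prod(stage[axis] for stage in strides_per_stage)
        padded_size = -(-size // total_stride) * total_stride  # ceil to a multiple
        padding.append(padded_size - size)
    return padding


def pad_1d(signal, amount, value=0):
    """Constant padding at the end of a 1D signal (assumed layout)."""
    return list(signal) + [value] * amount


def revert_padding_1d(signal, amount):
    """Remove the padding added by pad_1d."""
    return list(signal[: len(signal) - amount])


# Three stages with stride 2 along each of three axes: sizes must be multiples of 8.
print(required_padding((65, 64, 30), [(2, 2, 2)] * 3))  # [7, 0, 2]

signal = [1, 2, 3, 4, 5]
amount = required_padding((len(signal),), [(2,), (2,)])[0]
padded = pad_1d(signal, amount)
print(padded, revert_padding_1d(padded, amount) == signal)
```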
-
encode(x, use_padding=False)
Encodes x, i.e. applies the contractive part of the iUNet.
-
decode(*codes)
Applies the expansive, i.e. decoding, portion of the iUNet.
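The contractive and expansive parts rely on invertible down- and upsampling, which trades spatial resolution for channels without discarding values. As a rough picture of the "pixel_shuffle" (alias "squeeze") initialization, the following plain-Python sketch rearranges a 1D multi-channel signal with stride 2 and undoes it again. squeeze_1d and unsqueeze_1d are hypothetical names, and the channel ordering is an assumption; the learnable operators in iunets act on tensors and may order channels differently.

```python
def squeeze_1d(x, stride=2):
    """Turn C channels of length N into stride*C channels of length N/stride."""
    if len(x[0]) % stride != 0:
        raise ValueError("length must be divisible by the stride (pad first)")
    out = []
    for channel in x:
        # Each phase (every stride-th sample, shifted by offset) becomes a channel.
        for offset in range(stride):
            out.append(channel[offset::stride])
    return out


def unsqueeze_1d(y, stride=2):
    """Inverse of squeeze_1d: interleave groups of `stride` channels again."""
    out = []
    for start in range(0, len(y), stride):
        group = y[start:start + stride]
        channel = [None] * (len(group[0]) * stride)
        for offset, phase in enumerate(group):
            channel[offset::stride] = phase
        out.append(channel)
    return out


x = [[1, 2, 3, 4], [5, 6, 7, 8]]   # 2 channels, length 4
y = squeeze_1d(x)                  # 4 channels, length 2
print(y)
print(unsqueeze_1d(y) == x)
```

Because the number of values is unchanged, the downsampling loses no information, which is what makes an exact inverse possible.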
-
forward(x)
Applies the forward mapping of the iUNet to x.
- Parameters
x (torch.Tensor) –
-
decoder_inverse(x, use_padding=False)
Applies the inverse of the decoder portion of the iUNet.
-
encoder_inverse(*codes)
Applies the inverse of the encoder portion of the iUNet.
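Inverting the encoder and decoder requires inverting the resampling and channel mixing operators, which the constructor documentation describes as orthogonal matrices parametrized from skew-symmetric matrices via "exp" or "cayley". An orthogonal matrix Q has Q^T as its inverse, so inversion is cheap. The sketch below works through the 2x2 case of the Cayley transform in plain Python, using Q = (I + A)(I - A)^-1, which is one common convention; the exact convention and implementation inside iunets are not stated here and are an assumption.

```python
import math


def matmul(a, b):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def cayley_2x2(theta):
    """Map the skew-symmetric A = [[0, theta], [-theta, 0]] to Q = (I + A)(I - A)^-1."""
    i_plus_a = [[1.0, theta], [-theta, 1.0]]
    # I - A = [[1, -theta], [theta, 1]] has determinant 1 + theta^2 > 0,
    # so it is invertible for every real theta.
    det = 1.0 + theta * theta
    i_minus_a_inv = [[1.0 / det, theta / det], [-theta / det, 1.0 / det]]
    return matmul(i_plus_a, i_minus_a_inv)


q = cayley_2x2(0.5)
product = matmul(q, transpose(q))  # should be (numerically) the identity
is_orthogonal = all(
    math.isclose(product[i][j], 1.0 if i == j else 0.0, abs_tol=1e-12)
    for i in range(2) for j in range(2)
)
print(q, is_orthogonal)
```

Note that theta = 0 yields the identity matrix, so a zero-initialized parameter corresponds to an operator that leaves its input unchanged.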
-
training: bool
-
inverse(x)
Applies the inverse of the iUNet to x.
- Parameters
x (torch.Tensor) –
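The existence of an exact inverse is what enables memory-efficient backpropagation: activations can be reconstructed from a layer's output instead of being stored. The default layer is described above as an additive coupling layer; a minimal plain-Python sketch of that idea follows. The function names and the toy function f are illustrative assumptions, not the library's implementation, which operates on tensors with convolutional blocks.

```python
def coupling_forward(x1, x2, f):
    """Additive coupling: y1 = x1, y2 = x2 + f(x1)."""
    shift = f(x1)
    return x1, [b + s for b, s in zip(x2, shift)]


def coupling_inverse(y1, y2, f):
    """Exact inverse: x1 = y1, x2 = y2 - f(y1). f itself is never inverted."""
    shift = f(y1)
    return y1, [b - s for b, s in zip(y2, shift)]


def f(values):
    # Any function works here, even a non-invertible one.
    return [v * v - 3 for v in values]


x1, x2 = [1, -2, 3], [4, 5, 6]
y1, y2 = coupling_forward(x1, x2, f)
r1, r2 = coupling_inverse(y1, y2, f)
print(y2, (r1, r2) == (x1, x2))
```

Only half of the input is modified per layer, and the modification depends solely on the untouched half, which is why the input can be recovered without inverting f.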
-
print_layout()
Prints the layout of the iUNet.
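The exact output format of print_layout is not documented here. As an illustration of what a layout means, the following sketch reproduces the diagram style used in the description of architecture (one row per resolution level, with the layer count on the encoder and decoder side). layout_lines is a hypothetical helper, and it assumes single-digit layer counts for alignment.

```python
def layout_lines(architecture):
    """Build a U-shaped text diagram from the per-level layer counts."""
    depth = len(architecture)
    lines = []
    for level, n_layers in enumerate(architecture):
        # Deeper levels are indented further and have a shorter connection.
        dashes = "-" * (2 * (depth - 1 - level) + 1)
        lines.append(" " * level + f"{n_layers}{dashes}{n_layers}")
    return lines


for line in layout_lines([2, 3, 4]):
    print(line)
```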