iunets package
Submodules
Module contents
-
class iunets.iUNet(channels, architecture, dim, create_module_fn=<function create_standard_module>, module_kwargs=None, learnable_resampling=True, resampling_stride=2, resampling_method='cayley', resampling_init='haar', resampling_kwargs=None, learnable_channel_mixing=True, channel_mixing_freq=-1, channel_mixing_method='cayley', channel_mixing_kwargs=None, padding_mode='constant', padding_value=0, revert_input_padding=True, disable_custom_gradient=False, verbose=1, **kwargs)
Bases: torch.nn.modules.module.Module
Fully-invertible U-Net (iUNet).
This model can be used for memory-efficient backpropagation, e.g. in high-dimensional (such as 3D) segmentation tasks.
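Memory-efficient backpropagation relies on every layer being exactly invertible, so that activations can be recomputed from a layer's output instead of being stored. The following is a minimal sketch of that property for an additive coupling layer (the default layer type according to the create_module_fn description), written on plain Python lists. The names coupling_forward, coupling_inverse and the toy function f are illustrative assumptions and are not part of iunets.

```python
from typing import Callable, List, Tuple

Vector = List[float]


def coupling_forward(x1: Vector, x2: Vector,
                     f: Callable[[Vector], Vector]) -> Tuple[Vector, Vector]:
    """Additive coupling: y1 = x1, y2 = x2 + f(x1)."""
    shift = f(x1)
    return list(x1), [b + s for b, s in zip(x2, shift)]


def coupling_inverse(y1: Vector, y2: Vector,
                     f: Callable[[Vector], Vector]) -> Tuple[Vector, Vector]:
    """Exact inverse: x1 = y1, x2 = y2 - f(y1). f itself need not be invertible."""
    shift = f(y1)
    return list(y1), [b - s for b, s in zip(y2, shift)]


def f(v: Vector) -> Vector:
    # Toy stand-in for the convolutional block (hypothetical); any function works.
    return [max(0.0, 2.0 * a - 1.0) for a in v]


x1, x2 = [0.5, 2.0], [1.0, -3.0]
y1, y2 = coupling_forward(x1, x2, f)
r1, r2 = coupling_inverse(y1, y2, f)  # recovers x1 and x2 from the output alone
```

Because the input can be rebuilt from the output, a backward pass does not need the stored input of each layer, which is where the memory saving comes from.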
- Parameters
channels (Tuple[int, …]) – The number of channels at each resolution. For example, for 5 resolution levels (i.e. 4 up-/downsampling operations), it should be a tuple of 5 numbers, e.g. (32,64,128,256,384).
architecture (Tuple[int, …]) – Determines the number of invertible layers at each resolution (both left and right), e.g. [2,3,4] results in the following structure:
    2-----2
     3---3
      4-4
Must be the same length as channels.
dim (int) – Either 1, 2 or 3, signifying whether a 1D, 2D or 3D invertible U-Net should be created.
create_module_fn (Callable[[int, Optional[dict]], Module]) – Function which outputs an invertible layer. This layer should be a torch.nn.Module with a method forward(*x) and a method inverse(*x). create_module_fn should have the signature create_module_fn(in_channels, **kwargs). Additional keyword arguments passed on via kwargs are dim (whether this is a 1D, 2D or 3D iUNet), the coordinates of the specific module within the iUNet (branch, level and module_index) as well as architecture. By default, this creates an additive coupling layer, whose block consists of a number of convolutional layers, followed by an instance normalization layer and a leaky ReLU activation function. The number of blocks can be controlled by setting "depth" in module_kwargs, whose default value is 2.
module_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to create_module_fn.
learnable_resampling (bool) – Whether to train the invertible learnable up- and downsampling or to leave it at the initialized values. Defaults to True.
resampling_stride (int) – Controls the stride of the invertible up- and downsampling. The format can be either a single integer, a single tuple (where the length corresponds to the spatial dimensions of the data), or a list containing either of the last two options (where the length of the list has to be equal to the number of downsampling operations). For example: 2 would result in an up-/downsampling with a factor of 2 along each dimension; (2,1,4) would apply (at every resampling) a factor of 2, 1 and 4 for the height, width and depth dimensions respectively, whereas for a 3D iUNet with 3 up-/downsampling stages, [(2,1,3), (2,2,2), (4,3,1)] would result in different strides at different up-/downsampling stages.
resampling_method (str) – Chooses the method for parametrizing orthogonal matrices for invertible up- and downsampling. Can be either "exp" (i.e. exponentiation of skew-symmetric matrices) or "cayley" (i.e. the Cayley transform, acting on skew-symmetric matrices). Defaults to "cayley".
resampling_init (Union[str, ndarray, Tensor]) – Sets the initialization for the learnable up- and downsampling operators. Can be "haar", "pixel_shuffle" (aliases: "squeeze", "zeros"), a specific torch.Tensor or a numpy.ndarray. Defaults to "haar", i.e. the Haar transform.
resampling_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to the invertible up- and downsampling modules.
channel_mixing_freq (int) – How often an invertible channel mixing is applied, which (in 2D) is an orthogonal 1x1-convolution. -1 means that this will only be applied before the channel splitting and before the recombination in the decoder branch. For any other n, this means that every n-th module is followed by an invertible channel mixing. In particular, 0 deactivates the usage of invertible channel mixing. Defaults to -1.
channel_mixing_method (str) – How the orthogonal matrix for invertible channel mixing is parametrized. Same as resampling_method. Defaults to "cayley".
channel_mixing_kwargs (Optional[dict]) – dict of optional, additional keyword arguments that are passed on to the invertible channel mixing modules.
padding_mode (Optional[str]) – If downsampling is not possible without residue (e.g. when halving odd-valued spatial resolutions), the input gets padded to allow for invertibility of the padded input. padding_mode takes the same keywords as torch.nn.functional.pad for mode. If set to None, this behavior is deactivated. Defaults to "constant".
padding_value (int) – If padding_mode is set to "constant", this is the value that the input is padded with, e.g. 0. Defaults to 0.
revert_input_padding (bool) – Whether to revert the input padding in the output, such that the input resolution is preserved, even when padding is required. Defaults to True.
disable_custom_gradient (bool) – If set to True, normal backpropagation (i.e. storing activations instead of reconstructing activations) is used. Defaults to False.
verbose (int) – Level of verbosity. Currently only 0 (no warnings) and 1 (with warnings) are supported. Defaults to 1.
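The three accepted formats of resampling_stride can be easier to follow once expanded into one tuple per resampling stage. The helper below is a hypothetical sketch of such an expansion. normalize_stride is not an iunets function, and the library's internal handling may differ.

```python
from typing import List, Tuple


def normalize_stride(stride, dim: int, num_stages: int) -> List[Tuple[int, ...]]:
    """Expand a stride specification into one dim-tuple per resampling stage."""
    if isinstance(stride, int):
        # One factor shared by every dimension and every stage.
        return [(stride,) * dim for _ in range(num_stages)]
    if isinstance(stride, tuple):
        # One factor per spatial dimension, shared by every stage.
        if len(stride) != dim:
            raise ValueError(f"expected {dim} entries, got {len(stride)}")
        return [tuple(stride) for _ in range(num_stages)]
    if isinstance(stride, list):
        # One entry (int or tuple) per stage.
        if len(stride) != num_stages:
            raise ValueError(f"expected {num_stages} stages, got {len(stride)}")
        result = []
        for entry in stride:
            if isinstance(entry, list):
                raise TypeError("list entries must be ints or tuples")
            result.extend(normalize_stride(entry, dim, 1))
        return result
    raise TypeError(f"unsupported stride specification: {stride!r}")


uniform = normalize_stride(2, dim=3, num_stages=3)
anisotropic = normalize_stride((2, 1, 4), dim=3, num_stages=3)
per_stage = normalize_stride([(2, 1, 3), (2, 2, 2), (4, 3, 1)], dim=3, num_stages=3)
```

Here uniform is three copies of (2, 2, 2), anisotropic is three copies of (2, 1, 4), and per_stage keeps the three distinct tuples in order.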
-
get_padding(x)
Calculates the required padding for the input.
- Parameters
x (torch.Tensor) –
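How much padding is required follows from the strides: after all downsampling stages, every spatial size must be divisible by the product of the per-stage strides along that dimension. The sketch below works this out on plain shape tuples. total_factor, required_padding and to_pad_argument are hypothetical helpers, and padding only at the end of each dimension is an assumption, since this reference does not state how iunets distributes the padding.

```python
import math
from typing import List, Sequence


def total_factor(strides: Sequence[Sequence[int]]) -> List[int]:
    """Multiply per-stage strides into one overall factor per spatial dimension."""
    return [math.prod(per_dim) for per_dim in zip(*strides)]


def required_padding(spatial_shape: Sequence[int],
                     factors: Sequence[int]) -> List[int]:
    """Elements to add per dimension so each size is divisible by its factor."""
    return [(-size) % factor for size, factor in zip(spatial_shape, factors)]


def to_pad_argument(extra: Sequence[int]) -> List[int]:
    """Flatten into the (left, right) pairs used by torch.nn.functional.pad,
    which are ordered from the last dimension backwards.
    Assumption: all padding is appended at the end of each dimension."""
    pad: List[int] = []
    for amount in reversed(extra):
        pad.extend([0, amount])
    return pad


factors = total_factor([(2, 2, 2)] * 3)           # three stride-2 stages
extra = required_padding((65, 64, 30), factors)   # per-dimension padding
pad_argument = to_pad_argument(extra)
```

For a 65 x 64 x 30 volume and three stride-2 stages, the overall factor is 8 per dimension, so the padded shape becomes 72 x 64 x 32.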
-
revert_padding(x, padding)
Reverses a given padding.
- Parameters
x (Tensor) – The image that was originally padded.
padding (List[int]) – The padding that is removed from x.
-
pad(x, padded_shape=None, padding=None)
Applies the chosen padding to the input, if required.
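Padding and reverting the padding are meant to round-trip, so that the original resolution is recovered. A minimal 1D sketch of that round trip on a Python list follows. pad_1d and revert_padding_1d are illustrative stand-ins and not the iunets methods, which operate on tensors. The [left, right] layout of padding is an assumption.

```python
from typing import List


def pad_1d(x: List[float], padding: List[int], value: float = 0.0) -> List[float]:
    """Constant-pad a 1D signal with padding = [left, right]."""
    left, right = padding
    return [value] * left + list(x) + [value] * right


def revert_padding_1d(x: List[float], padding: List[int]) -> List[float]:
    """Remove a previously applied padding = [left, right]."""
    left, right = padding
    # Slice with an explicit end index so that right == 0 also works.
    return list(x[left:len(x) - right])


signal = [1.0, 2.0, 3.0, 4.0, 5.0]
padding = [0, 3]                      # length 5 -> 8, divisible by 8
padded = pad_1d(signal, padding)
restored = revert_padding_1d(padded, padding)
```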
-
encode(x, use_padding=False)
Encodes x, i.e. applies the contractive part of the iUNet.
-
decode(*codes)
Applies the expansive, i.e. decoding, portion of the iUNet.
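Both the contractive and the expansive part rest on invertible down- and upsampling, which trades spatial resolution for channels without discarding information. The following 1D, stride-2 sketch illustrates the two initializations named under resampling_init. The function names are hypothetical, and the sign and ordering conventions of the Haar kernel used by iunets may differ from the ones assumed here.

```python
import math
from typing import List, Tuple

Signal = List[float]
SCALE = 1.0 / math.sqrt(2.0)


def pixel_shuffle_down(x: Signal) -> Tuple[Signal, Signal]:
    """'pixel_shuffle' style: even and odd samples become two channels."""
    return x[0::2], x[1::2]


def haar_down(x: Signal) -> Tuple[Signal, Signal]:
    """Haar style: scaled sums and differences of neighbouring samples."""
    even, odd = x[0::2], x[1::2]
    average = [(a + b) * SCALE for a, b in zip(even, odd)]
    detail = [(a - b) * SCALE for a, b in zip(even, odd)]
    return average, detail


def haar_up(average: Signal, detail: Signal) -> Signal:
    """Inverse of haar_down; the transform is orthogonal, so the same
    scaled sum/difference pattern undoes it."""
    x: Signal = []
    for a, d in zip(average, detail):
        x.extend([(a + d) * SCALE, (a - d) * SCALE])
    return x


x = [4.0, 2.0, 1.0, 3.0]
shuffled = pixel_shuffle_down(x)
average, detail = haar_down(x)
restored = haar_up(average, detail)
energy_in = sum(v * v for v in x)
energy_out = sum(v * v for v in average + detail)
```

Up to floating-point rounding, restored matches x and energy_out matches energy_in, which is the defining property of an orthogonal transform.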
-
forward(x)
Applies the forward mapping of the iUNet to x.
- Parameters
x (torch.Tensor) –
-
decoder_inverse(x, use_padding=False)
Applies the inverse of the decoder portion of the iUNet.
-
encoder_inverse(*codes)
Applies the inverse of the encoder portion of the iUNet.
-
training: bool
-
inverse(x)
Applies the inverse of the iUNet to x.
- Parameters
x (torch.Tensor) –
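The method descriptions above suggest how the pieces compose: the forward mapping runs the encoder and then the decoder, while the inverse undoes the decoder first and the encoder second. The toy class below mirrors the method names to illustrate that ordering on Python lists. It is a sketch under that assumption and says nothing about the actual iunets internals, including padding handling.

```python
from typing import List, Tuple

Signal = List[float]


class ToyInvertibleNet:
    """Toy stand-in that mirrors the method names; not the iunets code."""

    def encode(self, x: Signal) -> Tuple[Signal, Signal]:
        # Split an even-length signal into two half-resolution codes.
        return x[0::2], x[1::2]

    def encoder_inverse(self, *codes: Signal) -> Signal:
        low, skip = codes
        x: Signal = []
        for a, b in zip(low, skip):
            x.extend([a, b])
        return x

    def decode(self, *codes: Signal) -> Signal:
        # Recombine the codes with an additive coupling style update.
        low, skip = codes
        x: Signal = []
        for a, b in zip(low, skip):
            x.extend([a, b + a])
        return x

    def decoder_inverse(self, x: Signal) -> Tuple[Signal, Signal]:
        low = x[0::2]
        skip = [b - a for a, b in zip(low, x[1::2])]
        return low, skip

    def forward(self, x: Signal) -> Signal:
        return self.decode(*self.encode(x))

    def inverse(self, y: Signal) -> Signal:
        return self.encoder_inverse(*self.decoder_inverse(y))


net = ToyInvertibleNet()
x = [1.0, 2.0, 3.0, 4.0]
y = net.forward(x)
x_restored = net.inverse(y)
```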
-
print_layout()
Prints the layout of the iUNet.
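The exact output of print_layout is not shown in this reference. As a rough illustration only, the sketch below renders the kind of diagram used in the architecture description, with one line per resolution level. layout_lines is a hypothetical helper, and the real method's output may look different.

```python
from typing import List, Sequence


def layout_lines(architecture: Sequence[int]) -> List[str]:
    """Render one line per resolution level, e.g. '2-----2' for the top level."""
    depth = len(architecture)
    lines = []
    for level, count in enumerate(architecture):
        # Deeper levels are indented further and have a narrower gap.
        gap = 2 * (depth - 1 - level) + 1
        lines.append(" " * level + f"{count}" + "-" * gap + f"{count}")
    return lines


lines = layout_lines([2, 3, 4])
print("\n".join(lines))
```

The left number on each line stands for the layers in the encoder branch and the right number for the layers in the decoder branch at that resolution.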