iunets.layers module

class iunets.layers.OrthogonalResamplingLayer(low_channel_number, stride, method='cayley', init='haar', learnable=True, init_kwargs=None, **kwargs)

Bases: torch.nn.modules.module.Module

Base class for orthogonal up- and downsampling operators.

Parameters
  • low_channel_number (int) – Lower number of channels. These are the input channels in the case of downsampling ops, and the output channels in the case of upsampling ops.

  • stride (Union[int, Tuple[int, …]]) – The downsampling / upsampling factor for each dimension.

  • channel_multiplier – The channel multiplier, i.e. the factor by which the number of channels is multiplied (downsampling) or divided (upsampling).

  • method (str) – The method used for parametrizing the orthogonal matrices which are used as convolutional kernels. Either "exp", "cayley" or "householder".
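The default "cayley" method builds an orthogonal matrix from an unconstrained weight via the Cayley transform: the skew-symmetric part \(A = W - W^T\) is mapped to \(Q = (I + A)^{-1}(I - A)\), which is always orthogonal. A minimal, library-independent NumPy sketch of this parametrization:

```python
import numpy as np

def cayley(weight):
    """Map an arbitrary square matrix to an orthogonal one.

    A = W - W^T is skew-symmetric, so Q = (I + A)^{-1} (I - A)
    is orthogonal (the Cayley transform).
    """
    A = weight - weight.T                 # skew-symmetric part
    I = np.eye(A.shape[0])
    return np.linalg.solve(I + A, I - A)  # (I + A)^{-1} (I - A)

Q = cayley(np.random.randn(4, 4))
# orthogonality: Q @ Q.T == I up to floating-point error
assert np.allclose(Q @ Q.T, np.eye(4))
```

Because the transform is differentiable in the underlying weight, such kernels can be trained with ordinary gradient descent while remaining exactly orthogonal.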

property kernel_matrix

The orthogonal matrix created by the chosen parametrisation method.

property kernel

The kernel associated with the invertible up-/downsampling.

training: bool
class iunets.layers.InvertibleDownsampling1D(in_channels, stride=2, method='cayley', init='haar', learnable=True, *args, **kwargs)

Bases: iunets.layers.OrthogonalResamplingLayer

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

inverse(x)
training: bool
class iunets.layers.InvertibleUpsampling1D(in_channels, stride=2, method='cayley', init='haar', learnable=True, *args, **kwargs)

Bases: iunets.layers.OrthogonalResamplingLayer

forward(x)

inverse(x)
training: bool
class iunets.layers.InvertibleDownsampling2D(in_channels, stride=2, method='cayley', init='haar', learnable=True, *args, **kwargs)

Bases: iunets.layers.OrthogonalResamplingLayer

forward(x)

inverse(x)
training: bool
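Invertible downsampling trades spatial resolution for channels: with stride 2 in 2D, a (C, H, W) activation becomes (4·C, H/2, W/2), so no information is lost. The channel/space bookkeeping can be illustrated with a plain space-to-depth rearrangement, which corresponds to the special case of a permutation-like orthogonal kernel (the actual layer uses a learned orthogonal kernel instead):

```python
import numpy as np

def space_to_depth(x, s=2):
    """(C, H, W) -> (s*s*C, H//s, W//s): stack each s-by-s block into channels."""
    C, H, W = x.shape
    x = x.reshape(C, H // s, s, W // s, s)
    return x.transpose(2, 4, 0, 1, 3).reshape(s * s * C, H // s, W // s)

def depth_to_space(y, s=2):
    """Exact inverse of space_to_depth."""
    C4, h, w = y.shape
    C = C4 // (s * s)
    y = y.reshape(s, s, C, h, w)
    return y.transpose(2, 3, 0, 4, 1).reshape(C, h * s, w * s)

x = np.random.randn(3, 8, 8)
y = space_to_depth(x)
assert y.shape == (12, 4, 4)              # 4x channels, half the resolution
assert np.allclose(depth_to_space(y), x)  # lossless round trip
```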
class iunets.layers.InvertibleUpsampling2D(in_channels, stride=2, method='cayley', init='haar', learnable=True, *args, **kwargs)

Bases: iunets.layers.OrthogonalResamplingLayer

forward(x)

inverse(x)
training: bool
class iunets.layers.InvertibleDownsampling3D(in_channels, stride=2, method='cayley', init='haar', learnable=True, *args, **kwargs)

Bases: iunets.layers.OrthogonalResamplingLayer

forward(x)

inverse(x)
training: bool
class iunets.layers.InvertibleUpsampling3D(in_channels, stride=2, method='cayley', init='haar', learnable=True, *args, **kwargs)

Bases: iunets.layers.OrthogonalResamplingLayer

forward(x)

inverse(x)
training: bool
class iunets.layers.SplitChannels(split_location)

Bases: torch.nn.modules.module.Module

forward(x)

inverse(x, y)
training: bool
class iunets.layers.ConcatenateChannels(split_location)

Bases: torch.nn.modules.module.Module

forward(x, y)

inverse(x)
training: bool
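SplitChannels and ConcatenateChannels are mutually inverse: splitting a tensor at a channel index and concatenating the two halves recovers the input exactly. A sketch with plain tensor slicing, assuming split_location is the channel index at which the tensor is cut:

```python
import numpy as np

split_location = 2

def split_channels(x):
    # cut along the channel axis (axis 1 for NCHW tensors)
    return x[:, :split_location], x[:, split_location:]

def concatenate_channels(x, y):
    return np.concatenate([x, y], axis=1)

x = np.random.randn(1, 5, 4, 4)        # batch of 5-channel 4x4 images
a, b = split_channels(x)
assert a.shape[1] == 2 and b.shape[1] == 3
assert np.allclose(concatenate_channels(a, b), x)  # exact round trip
```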
class iunets.layers.AdditiveCoupling(F, channel_split_pos)

Bases: torch.nn.modules.module.Module

Additive coupling layer, a basic invertible layer.

By splitting the input activation \(x\) and output activation \(y\) into two groups of channels (i.e. \((x_1, x_2) \cong x\) and \((y_1, y_2) \cong y\)), additive coupling layers define an invertible mapping \(x \mapsto y\) via

\[ \begin{align}\begin{aligned}y_1 &= x_2\\y_2 &= x_1 + F(x_2),\end{aligned}\end{align} \]

where the coupling function \(F\) is an (almost) arbitrary mapping: \(F\) only has to map from the space of \(x_2\) to the space of \(x_1\). In practice, this can, for instance, be a sequence of convolutional layers with batch normalization.

The inverse of the above mapping is computed algebraically via

\[ \begin{align}\begin{aligned}x_1 &= y_2 - F(y_1)\\x_2 &= y_1.\end{aligned}\end{align} \]

Warning: This differs from the definition of additive coupling layers in MemCNN, which are equivalent to two consecutive instances of the coupling layer defined above. Hence, the variant implemented here is twice as memory-efficient as the MemCNN variant.

Parameters
  • F (Module) – The coupling function of the additive coupling layer, typically a sequence of neural network layers.

  • channel_split_pos (int) – The index of the channel at which the input and output activations are split.
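The forward and inverse mappings above can be verified in a few lines. A self-contained sketch (NumPy arrays standing in for the two channel groups, and a fixed nonlinearity standing in for the learned coupling function \(F\)):

```python
import numpy as np

def F(x2):
    # an (almost) arbitrary coupling function; here a fixed nonlinearity
    return np.tanh(1.5 * x2 + 0.3)

def forward(x1, x2):
    # y1 = x2,  y2 = x1 + F(x2)
    return x2, x1 + F(x2)

def inverse(y1, y2):
    # x1 = y2 - F(y1),  x2 = y1
    return y2 - F(y1), y1

x1, x2 = np.random.randn(8), np.random.randn(8)
y1, y2 = forward(x1, x2)
rx1, rx2 = inverse(y1, y2)
assert np.allclose(rx1, x1) and np.allclose(rx2, x2)
```

Note that the inverse never needs to invert F itself, which is why F may be an arbitrary (non-invertible) network.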

forward(x)

inverse(y)
training: bool
class iunets.layers.StandardBlock(dim, num_in_channels, num_out_channels, depth=2, zero_init=False, normalization='instance', **kwargs)

Bases: torch.nn.modules.module.Module

forward(x)

training: bool
iunets.layers.create_standard_module(in_channels, **kwargs)
class iunets.layers.OrthogonalChannelMixing(in_channels, method='cayley', learnable=True, **kwargs)

Bases: torch.nn.modules.module.Module

Base class for all orthogonal channel mixing layers.

property kernel_matrix

The orthogonal matrix created by the chosen parametrisation method.

property kernel_matrix_transposed

The transpose of the orthogonal matrix created by the chosen parametrisation method.

training: bool
class iunets.layers.InvertibleChannelMixing1D(in_channels, method='cayley', learnable=True, **kwargs)

Bases: iunets.layers.OrthogonalChannelMixing

Orthogonal (and hence invertible) channel mixing layer for 1D data.

This layer linearly combines the input channels to form each output channel. The number of output channels equals the number of input channels, and the matrix specifying the connectivity between the channels is orthogonal.

Parameters
  • in_channels (int) – The number of input (and output) channels.

  • method (str) – The chosen method for parametrizing the orthogonal matrix which determines the orthogonal channel mixing. Either "exp", "cayley" or "householder".
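Channel mixing applies the same orthogonal matrix at every position of the signal, i.e. a 1x1 convolution with an orthogonal kernel, so its inverse is simply the transposed matrix. A library-independent sketch using a hand-picked rotation matrix in place of the learned parametrization:

```python
import numpy as np

# a fixed 2x2 orthogonal matrix (a rotation); the layer would instead
# build one from learnable weights via e.g. the Cayley transform
theta = 0.7
Q = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])

x = np.random.randn(2, 10)   # (channels, length) 1D signal
y = Q @ x                    # mix channels at every spatial position
x_rec = Q.T @ y              # the inverse is just the transpose

assert np.allclose(x_rec, x)
```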

property kernel
forward(x)

inverse(x)
training: bool
class iunets.layers.InvertibleChannelMixing2D(in_channels, method='cayley', learnable=True, **kwargs)

Bases: iunets.layers.OrthogonalChannelMixing

property kernel
forward(x)

inverse(x)
training: bool
class iunets.layers.InvertibleChannelMixing3D(in_channels, method='cayley', learnable=True, **kwargs)

Bases: iunets.layers.OrthogonalChannelMixing

property kernel
forward(x)

inverse(x)
training: bool