scico.flax.blocks
Flax implementation of different convolutional blocks.
Functions
| Nearest neighbor upscale for image batches of shape (N, H, W, C). |
Classes
| ConvBNBlock | Define convolution and batch normalization Flax block. |
| ConvBNMultiBlock | Block constructed from successive applications of ConvBNBlock. |
| ConvBNPoolBlock | Define convolution, batch normalization and pooling Flax block. |
| ConvBNUpsampleBlock | Define convolution, batch normalization and upsample Flax block. |
| ConvBlock | Define Flax convolution block. |
- class scico.flax.blocks.ConvBNBlock(num_filters, conv, norm, act, kernel_size=(3, 3), strides=(1, 1), parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Define convolution and batch normalization Flax block.
- Parameters:
  - num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  - conv (Any) – Flax module implementing the convolution layer to apply.
  - norm (Any) – Flax module implementing the batch normalization layer to apply.
  - act (Callable[..., Array]) – Flax function defining the activation operation to apply.
  - kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters.
  - strides (Tuple[int, int]) – A shape tuple defining the size of strides in convolution.
- class scico.flax.blocks.ConvBlock(num_filters, conv, act, kernel_size=(3, 3), strides=(1, 1), parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Define Flax convolution block.
- Parameters:
  - num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  - conv (Any) – Flax module implementing the convolution layer to apply.
  - act (Callable[..., Array]) – Flax function defining the activation operation to apply.
  - kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters.
  - strides (Tuple[int, int]) – A shape tuple defining the size of strides in convolution.
- class scico.flax.blocks.ConvBNPoolBlock(num_filters, conv, norm, act, pool, kernel_size, strides, window_shape, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Define convolution, batch normalization and pooling Flax block.
- Parameters:
  - num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  - conv (Any) – Flax module implementing the convolution layer to apply.
  - norm (Any) – Flax module implementing the batch normalization layer to apply.
  - act (Callable[..., Array]) – Flax function defining the activation operation to apply.
  - pool (Callable[..., Array]) – Flax function defining the pooling operation to apply.
  - kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters.
  - strides (Tuple[int, int]) – A shape tuple defining the size of strides in convolution.
  - window_shape (Tuple[int, int]) – A shape tuple defining the window to reduce over in the pooling operation.
- class scico.flax.blocks.ConvBNUpsampleBlock(num_filters, conv, norm, act, upfn, kernel_size, strides, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Define convolution, batch normalization and upsample Flax block.
- Parameters:
  - num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  - conv (Any) – Flax module implementing the convolution layer to apply.
  - norm (Any) – Flax module implementing the batch normalization layer to apply.
  - act (Callable[..., Array]) – Flax function defining the activation operation to apply.
  - upfn (Callable[..., Array]) – Flax function defining the upsampling operation to apply.
  - kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters.
  - strides (Tuple[int, int]) – A shape tuple defining the size of strides in convolution.
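A nearest neighbor upscale for batches of shape (N, H, W, C) is one operation that could be supplied as upfn. The sketch below shows the idea with jax.numpy.repeat. The function name nearest_upscale and its scale argument are hypothetical and are not taken from this module.

```python
import jax.numpy as jnp


def nearest_upscale(x, scale=2):
    """Hypothetical nearest neighbor upscale for arrays of shape (N, H, W, C)."""
    # Repeat every row, then every column, `scale` times.
    x = jnp.repeat(x, scale, axis=1)
    return jnp.repeat(x, scale, axis=2)


x = jnp.arange(4.0).reshape(1, 2, 2, 1)  # a single 2 x 2 image with one channel
y = nearest_upscale(x)
print(y.shape)
print(y[0, :, :, 0])
```

Each input pixel becomes a scale x scale patch of identical values, so the batch and channel dimensions are unchanged while H and W are multiplied by scale.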
- class scico.flax.blocks.ConvBNMultiBlock(num_blocks, num_filters, conv, norm, act, kernel_size=(3, 3), strides=(1, 1), parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Block constructed from successive applications of ConvBNBlock.
- Parameters:
  - num_blocks (int) – Number of convolutional batch normalization blocks to apply. Each block has its own parameters for convolution and batch normalization.
  - num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  - conv (Any) – Flax module implementing the convolution layer to apply.
  - norm (Any) – Flax module implementing the batch normalization layer to apply.
  - act (Callable[..., Array]) – Flax function defining the activation operation to apply.
  - kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters.
  - strides (Tuple[int, int]) – A shape tuple defining the size of strides in convolution.