scico.flax.train.input_pipeline
Generalized data handling for a training script.
Includes construction of a data iterator and instantiation for parallel processing.
Functions
| Reshape input batch for parallel training. |
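As a rough illustration of what reshaping a batch for parallel training usually involves, the sketch below splits the leading batch axis into a device axis and a per-device axis. It uses NumPy rather than JAX, and the helper name `shard_batch`, the explicit `num_devices` argument, and the `"image"`/`"label"` keys are assumptions made for this example, not the signature or behavior of the function in this module.

```python
import numpy as np


def shard_batch(batch, num_devices):
    """Reshape each array in `batch` from (B, ...) to (num_devices, B // num_devices, ...).

    Hypothetical helper for illustration only; not the scico implementation.
    """
    sharded = {}
    for name, x in batch.items():
        x = np.asarray(x)
        if x.shape[0] % num_devices != 0:
            raise ValueError("batch size must be divisible by num_devices")
        # New leading axis indexes the device; -1 infers the per-device batch size.
        sharded[name] = x.reshape((num_devices, -1) + x.shape[1:])
    return sharded


batch = {
    "image": np.zeros((8, 32, 32, 1)),  # 8 single-channel 32x32 images
    "label": np.zeros((8, 32, 32, 1)),
}
sharded = shard_batch(batch, num_devices=4)
print(sharded["image"].shape)  # (4, 2, 32, 32, 1)
```

With this layout, each slice along the first axis can be handed to one device by a parallel mapping primitive.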
Classes
| IterateData | Class to load data for training and testing. |
- class scico.flax.train.input_pipeline.IterateData(dt, batch_size, train=True, key=None)
Bases: object
Class to load data for training and testing.
It uses the generator pattern to obtain an iterable object.
Initialize an IterateData object.
- Parameters:
  - dt (DataSetDict) – Dictionary of data for supervised training, including images and labels.
  - batch_size (int) – Size of batch for iterating through the data.
  - train (bool) – Flag indicating use of the iterator for training. The iterator for training is infinite; the iterator for testing passes once through the data. Default: True.
  - key (Optional[Array]) – A PRNGKey used as the random key. Default: None.
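The generator pattern described above, with an infinite iterator for training and a single pass for testing, can be sketched as follows. This is a minimal NumPy stand-in, not the scico code: the class name `BatchIterator`, the integer `seed` (used in place of a PRNGKey), the `"image"`/`"label"` keys, the per-epoch shuffling, and the dropping of a trailing partial batch are all assumptions made for illustration.

```python
import numpy as np


class BatchIterator:
    """Hypothetical stand-in that mirrors the documented arguments."""

    def __init__(self, dt, batch_size, train=True, seed=0):
        self.dt = dt
        self.batch_size = batch_size
        self.train = train
        self.n = dt["image"].shape[0]
        self.rng = np.random.default_rng(seed)  # assumption: replaces the PRNGKey

    def __iter__(self):
        while True:
            # Assumption: shuffle each epoch when training, keep order when testing.
            perm = self.rng.permutation(self.n) if self.train else np.arange(self.n)
            # Assumption: a trailing partial batch is dropped.
            for start in range(0, self.n - self.batch_size + 1, self.batch_size):
                idx = perm[start:start + self.batch_size]
                yield {k: v[idx] for k, v in self.dt.items()}
            if not self.train:
                return  # testing: stop after one pass through the data


dt = {
    "image": np.arange(10).reshape(10, 1),
    "label": np.arange(10).reshape(10, 1),
}

# Testing mode: the iterator is exhausted after one pass.
test_batches = list(BatchIterator(dt, batch_size=4, train=False))
print(len(test_batches))  # 2

# Training mode: the iterator never ends, so draw batches with next().
train_iter = iter(BatchIterator(dt, batch_size=4, train=True))
first_five = [next(train_iter) for _ in range(5)]
```

Because the training iterator is infinite, a training loop built on it would be bounded by a step count rather than by exhaustion of the iterator.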