Models
The model API provides a convenient high-level interface for training and prediction on a network described using the symbolic API.
class AbstractModel

The abstract super type of all models in MXNet.jl.
class FeedForward

The feedforward model provides a convenient interface for training and prediction on feedforward architectures such as multi-layer perceptrons and ConvNets. There is no explicit handling of a time index, but it is relatively easy to implement unrolled RNNs / LSTMs under this framework (TODO: add example). For models that handle sequential data explicitly, please use TODO…

FeedForward(arch :: SymbolicNode, ctx)

Parameters:

- arch – the architecture of the network constructed using the symbolic API.
- ctx – the device(s) on which this model should do computation. It can be a single Context or a list of Context objects. In the latter case, data parallelization will be used for training. If no context is provided, the default context cpu() will be used.
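As a sketch of how a FeedForward model might be constructed (the architecture, layer names, and the `context` keyword argument here are illustrative assumptions, not taken from the entry above):

```julia
using MXNet

# Describe a small MLP using the symbolic API (illustrative architecture).
mlp = @mx.chain mx.Variable(:data)                 =>
      mx.FullyConnected(name=:fc1, num_hidden=128) =>
      mx.Activation(name=:relu1, act_type=:relu)   =>
      mx.FullyConnected(name=:fc2, num_hidden=10)  =>
      mx.SoftmaxOutput(name=:softmax)

# Single device: falls back to the default cpu() context.
model = mx.FeedForward(mlp)

# Multiple devices: data parallelization will be used for training.
model_multi = mx.FeedForward(mlp, context=[mx.gpu(0), mx.gpu(1)])
```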
init_model(self, initializer; overwrite=false, input_shapes...)

Initialize the weights in the model.

This method is called automatically when training a model, so there is usually no need to call it unless one needs to inspect a model with only randomly initialized weights.

Parameters:

- self (FeedForward) – the model to be initialized.
- initializer (AbstractInitializer) – an initializer describing how the weights should be initialized.
- overwrite (Bool) – keyword argument; force initialization even when weights already exist.
- input_shapes – the shapes of all data and label inputs to this model, given as keyword arguments. For example, data=(28,28,1,100), label=(100,).
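A minimal sketch of an explicit call, assuming `model` is a previously constructed FeedForward and using the shape example from the parameter list above (28×28×1 images in batches of 100):

```julia
# Initialize weights explicitly, e.g. to inspect a model with only
# randomly initialized parameters before any training.
mx.init_model(model, mx.UniformInitializer(0.01),
              data=(28, 28, 1, 100), label=(100,))
```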
predict(self, data; overwrite=false, callback=nothing)

Predict using an existing model. The model should already be initialized, either trained or loaded from a checkpoint. An overloaded version allows passing the callback as the first argument, so it is possible to write

```julia
predict(model, data) do batch_output
    # consume or write batch_output to file
end
```

Parameters:

- self (FeedForward) – the model.
- data (AbstractDataProvider) – the data to perform prediction on.
- overwrite (Bool) – an Executor is initialized the first time predict is called. The memory allocation of the Executor depends on the mini-batch size of the test data provider, so if you call predict twice with data providers of the same batch size, the executor can potentially be re-used. If overwrite is false, we will try to re-use the executor, and raise an error if the batch size changed. If overwrite is true (the default), a new Executor will be created to replace the old one.
Note

Prediction is computationally much less costly than training, so the bottleneck is sometimes the IO of copying mini-batches of data. Since there is no concern about convergence in prediction, it is better to set the mini-batch size as large as possible (limited by your device memory) if prediction speed is a concern.

For the same reason, prediction currently uses only the first device, even if multiple devices were provided when constructing the model.
Note

If you perform further training after prediction, be aware that the weights are not automatically synchronized when overwrite is set to false and the old predictor is re-used. In this case, setting overwrite to true (the default) will re-initialize the predictor the next time you call predict and synchronize the weights again.

See also: train(), fit(), init_model(), load_checkpoint()
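To illustrate both calling forms described above, here is a sketch in which `test_provider` is assumed to be an AbstractDataProvider constructed elsewhere:

```julia
# Callback form: consume outputs one mini-batch at a time instead of
# accumulating all predictions in memory.
mx.predict(model, test_provider) do batch_output
    # batch_output holds the network output for one mini-batch;
    # write it to disk or post-process it here.
    println(size(batch_output))
end

# Plain form: collect all predictions at once.
preds = mx.predict(model, test_provider)
```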
train(model :: FeedForward, ...)

Alias to fit().
fit(model :: FeedForward, optimizer, data; kwargs...)

Train the model on data with the optimizer.

Parameters:

- model (FeedForward) – the model to be trained.
- optimizer (AbstractOptimizer) – the optimization algorithm to use.
- data (AbstractDataProvider) – the training data provider.
- n_epoch (Int) – default 10; the number of full data passes to run.
- eval_data (AbstractDataProvider) – keyword argument, default nothing. The data provider for the validation set.
- eval_metric (AbstractEvalMetric) – keyword argument, default Accuracy(). The metric used to evaluate training performance. If eval_data is provided, the same metric is also calculated on the validation set.
- kvstore (KVStore or Base.Symbol) – keyword argument, default :local. The key-value store used to synchronize gradients and parameters when multiple devices are used for training.
- initializer (AbstractInitializer) – keyword argument, default UniformInitializer(0.01).
- force_init (Bool) – keyword argument, default false. By default, random initialization with the provided initializer is skipped if the model weights already exist, perhaps from a previous call to train() or an explicit call to init_model() or load_checkpoint(). When this option is set, random initialization is always performed at the beginning of training.
- callbacks (Vector{AbstractCallback}) – keyword argument, default []. Callbacks to be invoked at each epoch or mini-batch; see AbstractCallback.
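Putting the pieces together, a typical training call might look like the following sketch, assuming `model` is a FeedForward and `train_provider` / `eval_provider` are AbstractDataProviders constructed elsewhere (the SGD hyperparameters are illustrative):

```julia
# An SGD optimizer with assumed learning rate and momentum values.
optimizer = mx.SGD(lr=0.1, momentum=0.9)

# Train for 10 epochs, evaluating Accuracy() on the validation set
# after each epoch.
mx.fit(model, optimizer, train_provider,
       n_epoch     = 10,
       eval_data   = eval_provider,
       eval_metric = mx.Accuracy())
```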