public class SymmetricTrainerContext extends Object implements TrainerContext
Creates DefaultTrainer instances for use with ParallelWrapper.

| Constructor and Description |
|---|
| `SymmetricTrainerContext()` |
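The class description says this context creates DefaultTrainer instances for ParallelWrapper. The following minimal Java sketch illustrates that factory role without DL4J on the classpath. `SketchTrainer` and `SketchSymmetricContext` are hypothetical names, not DL4J classes. Reading "symmetric" as "every worker gets an identically configured trainer that differs only in its thread id" is an assumption, not documented behavior.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for a Trainer: a daemon worker thread that knows its
// thread id and averaging frequency. This is NOT the DL4J DefaultTrainer class.
class SketchTrainer extends Thread {
    final int threadId;
    final int averagingFrequency;

    SketchTrainer(int threadId, int averagingFrequency) {
        this.threadId = threadId;
        this.averagingFrequency = averagingFrequency;
        setName("SketchTrainer-" + threadId); // hypothetical naming scheme
        setDaemon(true); // a worker should not keep the JVM alive on its own
    }

    @Override
    public void run() {
        // A real trainer would fit its model replica on incoming batches here.
    }
}

// Hypothetical "symmetric" factory: every worker gets the same kind of trainer,
// configured identically except for its thread id.
class SketchSymmetricContext {
    SketchTrainer create(int threadId, int averagingFrequency) {
        return new SketchTrainer(threadId, averagingFrequency);
    }

    List<SketchTrainer> createAll(int workers, int averagingFrequency) {
        List<SketchTrainer> trainers = new ArrayList<>(workers);
        for (int i = 0; i < workers; i++) {
            trainers.add(create(i, averagingFrequency));
        }
        return trainers;
    }
}
```

For example, `new SketchSymmetricContext().createAll(4, 5)` yields four unstarted daemon threads named `SketchTrainer-0` through `SketchTrainer-3`. The caller remains responsible for calling `start()` on each one.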
| Modifier and Type | Method and Description |
|---|---|
| Trainer | `create(int threadId, org.deeplearning4j.nn.api.Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, org.deeplearning4j.nn.conf.WorkspaceMode mode, int averagingFrequency)` Create a Trainer based on the given parameters. |
| void | `finalizeRound(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)` This method is called at averagingFrequency. |
| void | `finalizeTraining(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)` This method is called |
| void | `init(org.deeplearning4j.nn.api.Model model, Object... args)` Initialize the context. |
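The four methods above suggest a lifecycle: `init` once, `create` once per worker, `finalizeRound` at averagingFrequency, and `finalizeTraining` at the end. The sketch below models that ordering with a hypothetical driver. `SketchContext`, `RecordingContext` and `SketchDriver` are illustrative names, and models are reduced to strings so that the code runs without DL4J. The exact call schedule is an assumption about how ParallelWrapper uses a TrainerContext, not something this page documents. In particular, the sketch assumes that `finalizeRound` runs every `averagingFrequency` iterations and that `finalizeTraining` runs once after the last iteration.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical mirror of the four TrainerContext hooks, with models reduced
// to plain strings so the sketch runs without DL4J on the classpath.
interface SketchContext {
    void init(String model, Object... args);
    String create(int threadId, String model, int averagingFrequency);
    void finalizeRound(String originalModel, String... models);
    void finalizeTraining(String originalModel, String... models);
}

// Records which hooks were called, in order.
class RecordingContext implements SketchContext {
    final List<String> calls = new ArrayList<>();

    public void init(String model, Object... args) {
        // args may be null, as the init documentation notes, so it is not dereferenced.
        calls.add("init");
    }

    public String create(int threadId, String model, int averagingFrequency) {
        calls.add("create:" + threadId);
        return model + "#" + threadId; // stand-in for a per-worker replica
    }

    public void finalizeRound(String originalModel, String... models) {
        calls.add("round");
    }

    public void finalizeTraining(String originalModel, String... models) {
        calls.add("training");
    }
}

// Hypothetical driver: one init, one create per worker, finalizeRound every
// averagingFrequency iterations, and finalizeTraining once at the end.
class SketchDriver {
    static void train(SketchContext ctx, String model, int workers,
                      int iterations, int averagingFrequency) {
        ctx.init(model, (Object[]) null);
        String[] replicas = new String[workers];
        for (int i = 0; i < workers; i++) {
            replicas[i] = ctx.create(i, model, averagingFrequency);
        }
        for (int it = 1; it <= iterations; it++) {
            // ...each replica would fit one batch here...
            if (it % averagingFrequency == 0) {
                ctx.finalizeRound(model, replicas);
            }
        }
        ctx.finalizeTraining(model, replicas);
    }
}
```

With 2 workers, 10 iterations and an averaging frequency of 3, the recorded calls are `init`, `create:0`, `create:1`, three rounds (after iterations 3, 6 and 9), and finally `training`.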
`public void init(org.deeplearning4j.nn.api.Model model, Object... args)`

Initialize the context.

Specified by: init in interface TrainerContext

Parameters:
- `model`
- `args`: the arguments to initialize with (may be null)

`public Trainer create(int threadId, org.deeplearning4j.nn.api.Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, org.deeplearning4j.nn.conf.WorkspaceMode mode, int averagingFrequency)`
Create a Trainer based on the given parameters.

Specified by: create in interface TrainerContext

Parameters:
- `threadId`: the thread id to use for this worker
- `model`: the model to start the trainer with
- `rootDevice`: the root device id
- `useMDS`: whether to use the MagicQueue or not
- `wrapper`: the wrapper instance to use with this trainer (this reference is needed for coordination with the ParallelWrapper's IterationListener)

`public void finalizeRound(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)`
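Description copied from interface: TrainerContext

This method is called at averagingFrequency.

Specified by: finalizeRound in interface TrainerContext

`public void finalizeTraining(org.deeplearning4j.nn.api.Model originalModel, org.deeplearning4j.nn.api.Model... models)`

Description copied from interface: TrainerContext

This method is called

Specified by: finalizeTraining in interface TrainerContext

Copyright © 2017. All rights reserved.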