public static class ParallelWrapper.Builder<T extends org.deeplearning4j.nn.api.Model> extends Object
| Modifier and Type | Field and Description |
|---|---|
| protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator | accumulator |
| protected boolean | averageUpdaters |
| protected int | averagingFrequency |
| protected boolean | isMQ |
| protected boolean | legacyAveraging |
| protected T | model |
| protected int | prefetchSize |
| protected boolean | reportScore |
| protected TrainerContext | trainerContext |
| protected Object[] | trainerContextArgs |
| protected ParallelWrapper.TrainingMode | trainingMode |
| protected int | workers |
| protected org.deeplearning4j.nn.conf.WorkspaceMode | workspaceMode |

| Constructor and Description |
|---|
| Builder(T model): Build a ParallelWrapper for the given model (for example, a MultiLayerNetwork). |

| Modifier and Type | Method and Description |
|---|---|
| ParallelWrapper.Builder | averageUpdaters(boolean reallyAverage): This method enables/disables updater averaging. |
| ParallelWrapper.Builder | averagingFrequency(int freq): Model averaging frequency (number of iterations between averaging). |
| ParallelWrapper | build(): This method returns the ParallelWrapper instance. |
| ParallelWrapper.Builder | gradientsAccumulator(org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator): This method allows you to specify the GradientsAccumulator instance to be used by this ParallelWrapper instance. |
| ParallelWrapper.Builder | prefetchBuffer(int size): Size of the prefetch buffer used for background data prefetching. |
| ParallelWrapper.Builder | reportScoreAfterAveraging(boolean reallyReport): This method enables/disables reporting of the averaged model score. |
| ParallelWrapper.Builder | trainerContextArgs(Object... trainerContextArgs): Trainer context args are passed to the TrainerContext init method when ParallelWrapper starts training. |
| ParallelWrapper.Builder | trainerFactory(TrainerContext trainerContext): Specify a TrainerContext for the given ParallelWrapper instance. |
| ParallelWrapper.Builder | trainingMode(ParallelWrapper.TrainingMode mode) |
| ParallelWrapper.Builder | workers(int num): This method configures the number of workers used for parallel training. |
| ParallelWrapper.Builder | workspaceMode(org.deeplearning4j.nn.conf.WorkspaceMode mode) |
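To make workers and averagingFrequency more concrete, here is a toy simulation in plain Java. It is not DL4J code and does not reproduce DL4J's averaging logic. Each hypothetical worker holds a double[] standing in for model parameters and takes a local step every iteration. After every freq-th iteration, all workers are replaced by their element-wise mean. The class and method names (AveragingSketch, train, average) and the fixed per-worker update are illustrative assumptions.

```java
import java.util.Arrays;

/**
 * Toy simulation of periodic parameter averaging (not DL4J code).
 * Each row of workerParams is one hypothetical worker's parameter vector.
 */
public class AveragingSketch {

    /** Replaces every worker's parameters with the element-wise mean across workers. */
    public static void average(double[][] workerParams) {
        int n = workerParams[0].length;
        double[] mean = new double[n];
        for (double[] p : workerParams) {
            for (int i = 0; i < n; i++) {
                mean[i] += p[i] / workerParams.length;
            }
        }
        for (double[] p : workerParams) {
            System.arraycopy(mean, 0, p, 0, n);
        }
    }

    /**
     * Runs `iterations` local steps on every worker and averages after every
     * `freq`-th iteration. Returns the number of averaging rounds performed.
     */
    public static int train(double[][] workerParams, int iterations, int freq) {
        if (freq <= 0) throw new IllegalArgumentException("freq must be positive");
        int rounds = 0;
        for (int iter = 1; iter <= iterations; iter++) {
            for (int w = 0; w < workerParams.length; w++) {
                // Hypothetical local update: worker w drifts by (w + 1) * 0.1 per step,
                // so replicas diverge between averaging rounds.
                for (int i = 0; i < workerParams[w].length; i++) {
                    workerParams[w][i] += (w + 1) * 0.1;
                }
            }
            if (iter % freq == 0) {
                average(workerParams);
                rounds++;
            }
        }
        return rounds;
    }

    public static void main(String[] args) {
        double[][] params = new double[4][3]; // 4 workers, 3 parameters each
        int rounds = train(params, 9, 3);
        System.out.println("averaging rounds: " + rounds); // prints averaging rounds: 3
        System.out.println(Arrays.toString(params[0]));
    }
}
```

With 9 iterations and freq = 3, averaging runs three times and all replicas end up identical. A smaller freq keeps replicas closer together at the cost of more frequent synchronization. By analogy, averageUpdaters(true) would apply the same averaging to updater state (for example, momentum buffers), which this sketch leaves out.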
protected ParallelWrapper.TrainingMode trainingMode
protected T extends org.deeplearning4j.nn.api.Model model
protected int workers
protected int prefetchSize
protected int averagingFrequency
protected boolean reportScore
protected boolean averageUpdaters
protected boolean legacyAveraging
protected boolean isMQ
protected TrainerContext trainerContext
protected Object[] trainerContextArgs
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator
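The accumulator field pairs with trainingMode. In DL4J releases we are aware of, ParallelWrapper.TrainingMode includes AVERAGING and SHARED_GRADIENTS, and a GradientsAccumulator is used in gradient-sharing mode. Treat this as an assumption and check it against your version. The toy below illustrates gradient sharing. Every worker contributes its local gradient to a shared accumulator, then every replica applies the same aggregated update, so the replicas stay identical without parameter averaging. It uses plain dense averaging of gradients. DL4J's own accumulator may encode or compress updates, and this sketch does not. SharedGradientsSketch, ToyAccumulator, store, drain and step are hypothetical names.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch of gradient sharing (not DL4J's GradientsAccumulator):
 * local gradients are aggregated, then one shared update is applied to every replica.
 */
public class SharedGradientsSketch {

    /** Minimal thread-safe accumulator: sums contributions, then hands out their mean. */
    public static class ToyAccumulator {
        private final double[] sum;
        private int contributions;

        public ToyAccumulator(int size) {
            sum = new double[size];
        }

        public synchronized void store(double[] gradient) {
            for (int i = 0; i < sum.length; i++) sum[i] += gradient[i];
            contributions++;
        }

        /** Returns the mean of stored gradients and resets for the next iteration. */
        public synchronized double[] drain() {
            double[] mean = new double[sum.length];
            for (int i = 0; i < sum.length; i++) {
                mean[i] = (contributions == 0) ? 0.0 : sum[i] / contributions;
                sum[i] = 0.0;
            }
            contributions = 0;
            return mean;
        }
    }

    /** One synchronous step: collect all local gradients, apply the shared SGD update everywhere. */
    public static void step(double[][] replicas, double[][] localGradients,
                            ToyAccumulator acc, double learningRate) {
        for (double[] g : localGradients) acc.store(g);
        double[] shared = acc.drain();
        for (double[] params : replicas) {
            for (int i = 0; i < params.length; i++) params[i] -= learningRate * shared[i];
        }
    }

    public static void main(String[] args) {
        double[][] replicas = {{1.0, 1.0}, {1.0, 1.0}};
        double[][] grads = {{0.2, 0.4}, {0.6, 0.0}}; // each worker sees different data
        step(replicas, grads, new ToyAccumulator(2), 0.5);
        System.out.println(Arrays.toString(replicas[0]));
        System.out.println(Arrays.toString(replicas[1]));
    }
}
```

Both replicas receive the same update (the mean gradient scaled by the learning rate), so they remain in sync after every step.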
public Builder(@NonNull T model)
Build a ParallelWrapper for the given model (for example, a MultiLayerNetwork).
model - the model to train in parallel

public ParallelWrapper.Builder trainerContextArgs(Object... trainerContextArgs)
Trainer context args are passed to the TrainerContext init method when ParallelWrapper starts training.
trainerContextArgs - the args to use (may be null)

public ParallelWrapper.Builder trainerFactory(@NonNull TrainerContext trainerContext)
Specify a TrainerContext for the given ParallelWrapper instance. If none is specified, DefaultTrainerContext is used.
trainerContext - the trainer factory to use

public ParallelWrapper.Builder workspaceMode(@NonNull org.deeplearning4j.nn.conf.WorkspaceMode mode)
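The trainerContextArgs / trainerFactory pair follows a common pattern. The builder stores a context plus an Object... array, and the built object forwards that array to the context's init hook once, when training starts. The sketch below shows only that forwarding pattern. ToyTrainerContext, RecordingContext, ToyWrapper and ToyBuilder are hypothetical stand-ins, not DL4J types. The builder method names mirror the ones documented here, but their behavior is assumed for illustration.

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not DL4J code): a builder stores varargs and a context,
 * and the built wrapper forwards those args to the context's init hook once,
 * when training starts.
 */
public class TrainerContextArgsSketch {

    /** Hypothetical stand-in for a trainer context with an init hook. */
    public interface ToyTrainerContext {
        void init(Object... args);
    }

    /** Records the arguments it was initialized with. */
    public static class RecordingContext implements ToyTrainerContext {
        public Object[] received;

        @Override
        public void init(Object... args) {
            // Defensive copy; a null array (no args configured) is kept as null.
            received = (args == null) ? null : Arrays.copyOf(args, args.length);
        }
    }

    /** Hypothetical wrapper that owns the context and the stored args. */
    public static class ToyWrapper {
        private final ToyTrainerContext context;
        private final Object[] contextArgs;

        ToyWrapper(ToyTrainerContext context, Object[] contextArgs) {
            this.context = context;
            this.contextArgs = contextArgs;
        }

        /** Called at the start of training: forwards the stored args as-is. */
        public void startTraining() {
            context.init(contextArgs);
        }
    }

    /** Hypothetical builder whose method names mirror the documented ones. */
    public static class ToyBuilder {
        private ToyTrainerContext context = new RecordingContext(); // stand-in default
        private Object[] contextArgs; // stays null unless configured

        public ToyBuilder trainerContextArgs(Object... args) {
            this.contextArgs = args;
            return this;
        }

        public ToyBuilder trainerFactory(ToyTrainerContext ctx) {
            this.context = ctx;
            return this;
        }

        public ToyWrapper build() {
            return new ToyWrapper(context, contextArgs);
        }
    }

    public static void main(String[] args) {
        RecordingContext ctx = new RecordingContext();
        new ToyBuilder()
                .trainerFactory(ctx)
                .trainerContextArgs("device-0", 42)
                .build()
                .startTraining();
        System.out.println(Arrays.toString(ctx.received)); // prints [device-0, 42]
    }
}
```

If trainerContextArgs is never called, the stored array stays null and is forwarded as null. This matches the "(may be null)" note on the parameter.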
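
public ParallelWrapper.Builder workers(int num)
This method configures the number of workers used for parallel training.
num - number of workers

public ParallelWrapper.Builder averagingFrequency(int freq)
Model averaging frequency.
freq - number of iterations between averaging

public ParallelWrapper.Builder averageUpdaters(boolean reallyAverage)
This method enables/disables updater averaging.
reallyAverage - true to average updater state, false otherwise

public ParallelWrapper.Builder prefetchBuffer(int size)
Size of the prefetch buffer used for background data prefetching.
size - 0 to disable prefetching, or any positive number to set the buffer size

public ParallelWrapper.Builder trainingMode(@NonNull ParallelWrapper.TrainingMode mode)
mode - the training mode to use

public ParallelWrapper.Builder gradientsAccumulator(@NonNull org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator accumulator)
This method allows you to specify the GradientsAccumulator instance to be used by this ParallelWrapper instance.
accumulator - the GradientsAccumulator to use

public ParallelWrapper.Builder reportScoreAfterAveraging(boolean reallyReport)
This method enables/disables reporting of the averaged model score.
reallyReport - true to report the score after averaging, false otherwise

public ParallelWrapper build()
This method returns the ParallelWrapper instance.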
Copyright © 2017. All rights reserved.