public class ParallelWrapper extends Object implements AutoCloseable
| Modifier and Type | Class and Description |
|---|---|
static class |
ParallelWrapper.Builder<T extends org.deeplearning4j.nn.api.Model> |
static class |
ParallelWrapper.TrainingMode |
| Modifier and Type | Field and Description |
|---|---|
protected boolean |
averageUpdaters |
protected int |
averagingFrequency |
protected boolean |
debug |
protected ThreadPoolExecutor |
executorService |
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator |
gradientsAccumulator |
protected boolean |
isMQ |
protected AtomicLong |
iterationsCounter |
protected boolean |
legacyAveraging |
protected List<org.deeplearning4j.optimize.api.IterationListener> |
listeners |
protected org.deeplearning4j.nn.api.Model |
model |
protected int |
prefetchSize |
protected boolean |
reportScore |
protected AtomicBoolean |
stopFit |
protected StatsStorageRouter |
storageRouter |
protected TrainerContext |
trainerContext |
protected Object[] |
trainerContextArgs |
protected boolean |
wasAveraged |
protected AtomicInteger |
workerCounter |
protected int |
workers |
protected org.deeplearning4j.nn.conf.WorkspaceMode |
workspaceMode |
protected Trainer[] |
zoo |
| Modifier | Constructor and Description |
|---|---|
protected |
ParallelWrapper(org.deeplearning4j.nn.api.Model model,
int workers,
int prefetchSize) |
| Modifier and Type | Method and Description |
|---|---|
void |
broadcastGradients(org.deeplearning4j.optimize.listeners.SharedGradient gradients)
This method will propagate gradients across all workers
|
void |
close() |
void |
fit(org.nd4j.linalg.dataset.api.iterator.DataSetIterator source)
This method takes DataSetIterator, and starts training over it by scheduling DataSets to different executors
|
void |
fit(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator source) |
protected void |
init() |
void |
setListeners(Collection<org.deeplearning4j.optimize.api.IterationListener> listeners)
This method allows you to specify IterationListeners for this model.
|
void |
setListeners(org.deeplearning4j.optimize.api.IterationListener... listeners)
This method allows you to specify IterationListeners for this model.
|
void |
setListeners(StatsStorageRouter statsStorage,
Collection<? extends org.deeplearning4j.optimize.api.IterationListener> listeners)
Set the listeners, along with a StatsStorageRouter that the results will be shuffled to (in the case of any listeners
that implement the
RoutingIterationListener interface) |
void |
setListeners(StatsStorageRouter statsStorage,
org.deeplearning4j.optimize.api.IterationListener... listeners)
Set the listeners, along with a StatsStorageRouter that the results will be shuffled to (in the case of any listeners
that implement the
RoutingIterationListener interface) |
void |
shutdown()
This method causes all threads used for parallel training to stop
|
void |
stopFit()
Will stop a fit operation from continuing to iterate.
|
protected org.deeplearning4j.nn.api.Model model
protected int workers
protected int prefetchSize
protected int averagingFrequency
protected Trainer[] zoo
protected TrainerContext trainerContext
protected AtomicLong iterationsCounter
protected boolean reportScore
protected boolean averageUpdaters
protected boolean legacyAveraging
protected boolean wasAveraged
protected AtomicBoolean stopFit
protected List<org.deeplearning4j.optimize.api.IterationListener> listeners
protected StatsStorageRouter storageRouter
protected boolean isMQ
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
protected Object[] trainerContextArgs
protected boolean debug
protected ThreadPoolExecutor executorService
protected final AtomicInteger workerCounter
protected org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator gradientsAccumulator
protected ParallelWrapper(org.deeplearning4j.nn.api.Model model,
int workers,
int prefetchSize)
protected void init()
public void close()
throws Exception
Specified by: close in interface AutoCloseable
Throws: Exception
public void shutdown()
public void stopFit()
public void fit(@NonNull
org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator source)
Parameters: source -
public void setListeners(@NonNull
Collection<org.deeplearning4j.optimize.api.IterationListener> listeners)
See Also: setListeners(StatsStorageRouter, Collection)
Parameters: listeners - Listeners to set
public void setListeners(@NonNull
org.deeplearning4j.optimize.api.IterationListener... listeners)
See Also: setListeners(StatsStorageRouter, Collection)
Parameters: listeners - Listeners to set
public void setListeners(StatsStorageRouter statsStorage, org.deeplearning4j.optimize.api.IterationListener... listeners)
RoutingIterationListener interface)
Parameters: statsStorage - Stats storage router to place the results into
listeners - Listeners to set
public void setListeners(StatsStorageRouter statsStorage, Collection<? extends org.deeplearning4j.optimize.api.IterationListener> listeners)
RoutingIterationListener interface)
Parameters: statsStorage - Stats storage router to place the results into
listeners - Listeners to set
public void broadcastGradients(org.deeplearning4j.optimize.listeners.SharedGradient gradients)
Parameters: gradients -
public void fit(@NonNull
org.nd4j.linalg.dataset.api.iterator.DataSetIterator source)
Parameters: source -
Copyright © 2017. All rights reserved.