public class DefaultTrainer extends Thread implements Trainer
| Modifier and Type | Class and Description |
|---|---|
static class |
DefaultTrainer.DefaultTrainerBuilder |
Nested classes/interfaces inherited from class java.lang.Thread: Thread.State, Thread.UncaughtExceptionHandler
| Modifier and Type | Field and Description |
|---|---|
protected int |
averagingFrequency |
protected AtomicBoolean |
isStopped |
protected AtomicLong |
lastEtlTime |
protected org.nd4j.linalg.dataset.api.DataSet |
nullDataSet |
protected AtomicBoolean |
nullMode |
protected boolean |
onRootModel |
protected org.deeplearning4j.nn.api.Model |
originalModel |
protected ParallelWrapper |
parallelWrapper |
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> |
queue |
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> |
queueMDS |
protected org.deeplearning4j.nn.api.Model |
replicatedModel |
protected AtomicInteger |
running |
protected AtomicBoolean |
shouldStop |
protected AtomicBoolean |
shouldUpdate |
protected int |
threadId |
protected Exception |
thrownException |
protected boolean |
useMDS |
protected String |
uuid |
protected org.deeplearning4j.nn.conf.WorkspaceMode |
workspaceMode |
Fields inherited from class java.lang.Thread: MAX_PRIORITY, MIN_PRIORITY, NORM_PRIORITY
| Constructor and Description |
|---|
DefaultTrainer() |
| Modifier and Type | Method and Description |
|---|---|
boolean |
averagingRequired()
This method returns TRUE if this Trainer implementation assumes periodic averaging.
|
protected static org.deeplearning4j.optimize.api.IterationListener |
cloneListener(org.deeplearning4j.optimize.api.IterationListener original) |
protected void |
configureListeners(String workerUUID,
Collection<org.deeplearning4j.optimize.api.IterationListener> oldListeners,
Collection<org.deeplearning4j.optimize.api.IterationListener> replicatedListeners) |
void |
feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet,
long etlTime)
Train on a
DataSet |
void |
feedMultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet dataSet,
long etlTime)
Train on a
MultiDataSet |
protected void |
fit(org.nd4j.linalg.dataset.api.DataSet dataSet) |
protected void |
fit(org.nd4j.linalg.dataset.api.MultiDataSet dataSet) |
org.deeplearning4j.nn.api.Model |
getModel()
The current model for the trainer.
|
boolean |
isRunning() |
protected void |
postInit()
This method does post-initialization configuration of Model.
|
void |
run() |
protected void |
setupIfNeccessary() |
void |
shutdown()
Shut down this worker.
|
void |
updateModel(org.deeplearning4j.nn.api.Model model)
Update the current
Model
for the worker |
void |
waitTillRunning()
Block the main thread
until the trainer is up and running.
|
Methods inherited from class java.lang.Thread: activeCount, checkAccess, clone, countStackFrames, currentThread, destroy, dumpStack, enumerate, getAllStackTraces, getContextClassLoader, getDefaultUncaughtExceptionHandler, getId, getName, getPriority, getStackTrace, getState, getThreadGroup, getUncaughtExceptionHandler, holdsLock, interrupt, interrupted, isAlive, isDaemon, isInterrupted, join, join, join, resume, setContextClassLoader, setDaemon, setDefaultUncaughtExceptionHandler, setName, setPriority, setUncaughtExceptionHandler, sleep, sleep, start, stop, stop, suspend, toString, yield
Methods inherited from class java.lang.Object: equals, finalize, getClass, hashCode, notify, notifyAll, wait, wait, wait
Methods inherited from interface Trainer: setUncaughtExceptionHandler, start
protected org.deeplearning4j.nn.api.Model replicatedModel
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.DataSet> queue
protected LinkedBlockingQueue<org.nd4j.linalg.dataset.api.MultiDataSet> queueMDS
protected AtomicInteger running
protected AtomicBoolean shouldUpdate
protected AtomicBoolean shouldStop
protected Exception thrownException
protected volatile boolean useMDS
protected final String uuid
protected boolean onRootModel
protected volatile AtomicLong lastEtlTime
protected AtomicBoolean nullMode
protected org.nd4j.linalg.dataset.api.DataSet nullDataSet
protected AtomicBoolean isStopped
protected ParallelWrapper parallelWrapper
protected org.deeplearning4j.nn.conf.WorkspaceMode workspaceMode
protected int averagingFrequency
protected int threadId
protected org.deeplearning4j.nn.api.Model originalModel
public void feedMultiDataSet(@NonNull
org.nd4j.linalg.dataset.api.MultiDataSet dataSet,
long etlTime)
Description copied from interface: Trainer
Train on a MultiDataSet
Specified by: feedMultiDataSet in interface Trainer
Parameters: dataSet - the data set to train on
public void feedDataSet(org.nd4j.linalg.dataset.api.DataSet dataSet,
long etlTime)
Description copied from interface: Trainer
Train on a DataSet
Specified by: feedDataSet in interface Trainer
Parameters: dataSet - the data set to train on
public org.deeplearning4j.nn.api.Model getModel()
Description copied from interface: Trainer
The current model for the trainer.
public void updateModel(@NonNull
org.deeplearning4j.nn.api.Model model)
Description copied from interface: Trainer
Update the current Model
for the worker
Specified by: updateModel in interface Trainer
Parameters: model - the new model for this worker
protected void setupIfNeccessary()
public void shutdown()
Description copied from interface: Trainer
Shut down this worker.
protected void fit(org.nd4j.linalg.dataset.api.DataSet dataSet)
protected void fit(org.nd4j.linalg.dataset.api.MultiDataSet dataSet)
protected void postInit()
public void waitTillRunning()
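Description copied from interface: Trainer
Block the main thread until the trainer is up and running.
Specified by: waitTillRunning in interface Trainer
public boolean averagingRequired()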
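Description copied from interface: Trainer
This method returns TRUE if this Trainer implementation assumes periodic averaging.
Specified by: averagingRequired in interface Trainer
protected static org.deeplearning4j.optimize.api.IterationListener cloneListener(org.deeplearning4j.optimize.api.IterationListener original)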
protected void configureListeners(String workerUUID, Collection<org.deeplearning4j.optimize.api.IterationListener> oldListeners, Collection<org.deeplearning4j.optimize.api.IterationListener> replicatedListeners)
Copyright © 2017. All rights reserved.