public class CheckpointListener extends BaseTrainingListener implements Serializable
Models can be restored using `loadCheckpointMLN(Checkpoint)` and `loadCheckpointCG(Checkpoint)`.
Model files can be obtained using `getFileForCheckpoint(Checkpoint)`. Checkpoints can be obtained using `lastCheckpoint()` and `availableCheckpoints()`.
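Example 1: save a checkpoint every 2 epochs and keep all model files.

```java
import org.deeplearning4j.optimize.listeners.CheckpointListener;

CheckpointListener l = new CheckpointListener.Builder("/save/directory")
        .keepAll()              // Don't delete any models
        .saveEveryNEpochs(2)    // Save at the end of every 2nd epoch
        .build();
```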
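To make the epoch schedule of Example 1 concrete, here is a small stand-alone model of an "every N epochs" trigger. It is not DL4J code: `EpochTriggerSketch` and its methods are hypothetical, and the sketch assumes a save happens whenever the number of completed epochs is a multiple of N, which may not match the listener's exact epoch counting. Because Example 1 uses `keepAll()`, every checkpoint listed would stay on disk.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical model (not DL4J code) of an "every N epochs" save trigger,
 * such as saveEveryNEpochs(2). Assumption: a save happens when the number
 * of completed epochs is a positive multiple of N.
 */
final class EpochTriggerSketch {

    // True if a checkpoint should be written after `epochsCompleted` epochs.
    static boolean shouldSave(int epochsCompleted, int everyNEpochs) {
        if (everyNEpochs <= 0) {
            throw new IllegalArgumentException("everyNEpochs must be positive");
        }
        return epochsCompleted > 0 && epochsCompleted % everyNEpochs == 0;
    }

    // Completed-epoch counts (1..totalEpochs) at which a save would occur.
    static List<Integer> saveEpochs(int totalEpochs, int everyNEpochs) {
        List<Integer> saves = new ArrayList<>();
        for (int completed = 1; completed <= totalEpochs; completed++) {
            if (shouldSave(completed, everyNEpochs)) {
                saves.add(completed);
            }
        }
        return saves;
    }

    public static void main(String[] args) {
        // With keepAll(), none of these checkpoints would be deleted.
        System.out.println(saveEpochs(6, 2)); // [2, 4, 6]
    }
}
```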
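Example 2: save a checkpoint every 1000 iterations, keeping only the last 3 model files (older model files are deleted automatically).

```java
import java.io.File;
import org.deeplearning4j.optimize.listeners.CheckpointListener;

CheckpointListener l = new CheckpointListener.Builder(new File("/save/directory"))
        .keepLast(3)                  // Keep only the 3 most recent model files
        .saveEveryNIterations(1000)   // Save every 1000 iterations
        .build();
```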
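A minimal model of Example 2's behavior follows. The class `KeepLastSketch` is hypothetical (not part of DL4J), and it assumes that a checkpoint is written whenever the iteration count is a positive multiple of 1000, that checkpoints are numbered 0, 1, 2, ... in save order, and that the oldest file is evicted as soon as more than three are held.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

/**
 * Hypothetical model (not DL4J code) of keepLast(n) combined with
 * saveEveryNIterations(k).
 */
final class KeepLastSketch {
    private final int keepLast;
    private final int everyNIterations;
    private final Deque<Integer> kept = new ArrayDeque<>(); // oldest first
    private final List<Integer> deleted = new ArrayList<>();
    private int nextCheckpointNum = 0;

    KeepLastSketch(int keepLast, int everyNIterations) {
        if (keepLast <= 0 || everyNIterations <= 0) {
            throw new IllegalArgumentException("arguments must be positive");
        }
        this.keepLast = keepLast;
        this.everyNIterations = everyNIterations;
    }

    // Called once per training iteration with the iteration count so far.
    void iterationDone(int iteration) {
        if (iteration > 0 && iteration % everyNIterations == 0) {
            kept.addLast(nextCheckpointNum++);
            // Evict the oldest checkpoints once more than keepLast are held.
            while (kept.size() > keepLast) {
                deleted.add(kept.removeFirst());
            }
        }
    }

    List<Integer> keptCheckpoints() { return new ArrayList<>(kept); }

    List<Integer> deletedCheckpoints() { return new ArrayList<>(deleted); }

    public static void main(String[] args) {
        KeepLastSketch s = new KeepLastSketch(3, 1000);
        for (int it = 1; it <= 6000; it++) {
            s.iterationDone(it);
        }
        // Prints: kept=[3, 4, 5] deleted=[0, 1, 2]
        System.out.println("kept=" + s.keptCheckpoints() + " deleted=" + s.deletedCheckpoints());
    }
}
```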
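Example 3: save a checkpoint every 15 minutes, keeping the 3 most recent checkpoints and otherwise every 4th checkpoint file.

```java
import java.io.File;
import java.util.concurrent.TimeUnit;
import org.deeplearning4j.optimize.listeners.CheckpointListener;

CheckpointListener l = new CheckpointListener.Builder(new File("/save/directory"))
        .keepLastAndEvery(3, 4)            // Keep the last 3, plus every 4th checkpoint
        .saveEvery(15, TimeUnit.MINUTES)   // Save every 15 minutes
        .build();
```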
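`keepLastAndEvery(3, 4)` retains the union of two sets: the most recent checkpoints and a periodic subset of older ones. The sketch below computes that union for a given number of saved checkpoints. `KeepLastAndEverySketch` is a hypothetical helper, not a DL4J class; it assumes checkpoints are numbered from 0 in save order and that "every 4th" means checkpoint numbers divisible by 4. The library's exact rule may differ.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical model (not DL4J code) of keepLastAndEvery(keepLast, keepEvery).
 * Assumptions: checkpoints are numbered 0, 1, 2, ... in save order, and a
 * checkpoint is "periodic" when its number is divisible by keepEvery.
 */
final class KeepLastAndEverySketch {

    // Checkpoint numbers (out of 0..count-1) that would be retained.
    static List<Integer> retained(int count, int keepLast, int keepEvery) {
        List<Integer> keep = new ArrayList<>();
        for (int num = 0; num < count; num++) {
            boolean isRecent = num >= count - keepLast;   // among the newest keepLast
            boolean isPeriodic = num % keepEvery == 0;    // every keepEvery-th checkpoint
            if (isRecent || isPeriodic) {
                keep.add(num);
            }
        }
        return keep;
    }

    public static void main(String[] args) {
        // Ten checkpoints: 7, 8, 9 are recent; 0, 4, 8 are periodic.
        System.out.println(retained(10, 3, 4)); // [0, 4, 7, 8, 9]
    }
}
```

Under these assumptions, checkpoints 1, 2, 3, 5 and 6 would be deleted in the ten-checkpoint case.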
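These options can be combined. For example, to save at the end of every epoch and, independently of the last save time, every 15 minutes, use `.saveEveryEpoch().saveEvery(15, TimeUnit.MINUTES)`. To save every epoch and every 15 minutes since the last model save, use `.saveEveryEpoch().saveEvery(15, TimeUnit.MINUTES, true)`; the third argument (`sinceLast`) resets the 15-minute timer whenever a model is saved.

| Modifier and Type | Class and Description |
|---|---|
| `static class` | `CheckpointListener.Builder` |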
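The difference between the two- and three-argument `saveEvery` calls on the `Builder` shows up when epoch-end saves and timed saves interleave. The hypothetical `TimeTriggerSketch` below (not DL4J code) simulates a 40-minute run with epochs ending at minutes 20 and 40. It assumes one training iteration finishes every minute and that, without `sinceLast`, the interval is measured from the previous timed save (or from the start of training).

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
 * Hypothetical model (not DL4J code) of saveEveryEpoch() combined with a
 * time-based trigger, with and without the sinceLast flag. Time is simulated
 * in whole minutes.
 */
final class TimeTriggerSketch {

    // Minutes at which a checkpoint would be written.
    static List<Integer> saveTimes(int totalMinutes, Set<Integer> epochEnds,
                                   int intervalMinutes, boolean sinceLast) {
        List<Integer> saves = new ArrayList<>();
        int lastTimedSave = 0; // start of training counts as time 0
        int lastAnySave = 0;
        for (int t = 1; t <= totalMinutes; t++) {
            // sinceLast=true: measure from the most recent save of any kind.
            // sinceLast=false: measure from the previous time-triggered save.
            int reference = sinceLast ? lastAnySave : lastTimedSave;
            boolean timeDue = t - reference >= intervalMinutes;
            boolean epochDue = epochEnds.contains(t);
            if (timeDue || epochDue) {
                saves.add(t); // at most one save per minute
                lastAnySave = t;
                if (timeDue) {
                    lastTimedSave = t;
                }
            }
        }
        return saves;
    }

    public static void main(String[] args) {
        Set<Integer> epochEnds = Set.of(20, 40);
        System.out.println(saveTimes(40, epochEnds, 15, false)); // [15, 20, 30, 40]
        System.out.println(saveTimes(40, epochEnds, 15, true));  // [15, 20, 35, 40]
    }
}
```

With `sinceLast = true`, the epoch-end save at minute 20 restarts the timer, so the next timed save moves from minute 30 to minute 35.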
| Modifier and Type | Method and Description |
|---|---|
| `List<Checkpoint>` | `availableCheckpoints()`: List all available checkpoints. |
| `protected static int` | `getEpoch(Model model)` |
| `File` | `getFileForCheckpoint(Checkpoint checkpoint)`: Get the model file for the given checkpoint. |
| `File` | `getFileForCheckpoint(int checkpointNum)`: Get the model file for the given checkpoint number. |
| `protected static int` | `getIter(Model model)` |
| `protected static String` | `getModelType(Model model)` |
| `void` | `iterationDone(Model model, int iteration, int epoch)`: Event listener for each iteration. |
| `Checkpoint` | `lastCheckpoint()`: Return the most recent checkpoint, if one exists; otherwise return `null`. |
| `ComputationGraph` | `loadCheckpointCG(Checkpoint checkpoint)`: Load a ComputationGraph for the given checkpoint. |
| `ComputationGraph` | `loadCheckpointCG(int checkpointNum)`: Load a ComputationGraph for the given checkpoint number. |
| `MultiLayerNetwork` | `loadCheckpointMLN(Checkpoint checkpoint)`: Load a MultiLayerNetwork for the given checkpoint. |
| `MultiLayerNetwork` | `loadCheckpointMLN(int checkpointNum)`: Load a MultiLayerNetwork for the given checkpoint number. |
| `void` | `onEpochEnd(Model model)`: Called once at the end of each epoch, when using methods such as `MultiLayerNetwork.fit(DataSetIterator)`, `ComputationGraph.fit(DataSetIterator)` or `ComputationGraph.fit(MultiDataSetIterator)`. |
Methods inherited from class `BaseTrainingListener`: `onBackwardPass`, `onEpochStart`, `onForwardPass` (two overloads), `onGradientCalculation`

`public void onEpochEnd(Model model)`
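Description copied from interface: `TrainingListener`. Called once at the end of each epoch, when using methods such as `MultiLayerNetwork.fit(DataSetIterator)`, `ComputationGraph.fit(DataSetIterator)` or `ComputationGraph.fit(MultiDataSetIterator)`.
Specified by: `onEpochEnd` in interface `TrainingListener`.
Overrides: `onEpochEnd` in class `BaseTrainingListener`.

`public void iterationDone(Model model, int iteration, int epoch)`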
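Description copied from interface: `TrainingListener`. Event listener for each iteration.
Specified by: `iterationDone` in interface `TrainingListener`.
Overrides: `iterationDone` in class `BaseTrainingListener`.
Parameters: `model` - the model iterating; `iteration` - the iteration.

`protected static int getIter(Model model)`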
`protected static int getEpoch(Model model)`

`public List<Checkpoint> availableCheckpoints()`

List all available checkpoints.
`public Checkpoint lastCheckpoint()`

Return the most recent checkpoint, if one exists; otherwise return `null`.
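The two queries above interact with the retention settings: once a policy such as `keepLast(3)` deletes a file, that checkpoint should no longer be listed. The sketch below models this with plain checkpoint numbers. `CheckpointQuerySketch` is hypothetical, it does not use DL4J's `Checkpoint` class, and it assumes a checkpoint stops being available as soon as its file is deleted.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.TreeSet;

/**
 * Hypothetical model (not DL4J code) of availableCheckpoints() and
 * lastCheckpoint(), tracking checkpoint numbers instead of Checkpoint objects.
 */
final class CheckpointQuerySketch {
    // Numbers of checkpoints whose files still exist, sorted ascending.
    private final TreeSet<Integer> available = new TreeSet<>();

    void onSaved(int checkpointNum) { available.add(checkpointNum); }

    void onDeleted(int checkpointNum) { available.remove(checkpointNum); }

    // Like availableCheckpoints(): every checkpoint that can still be loaded.
    List<Integer> availableCheckpoints() {
        return Collections.unmodifiableList(new ArrayList<>(available));
    }

    // Like lastCheckpoint(): the newest checkpoint, or null if none exists.
    Integer lastCheckpoint() {
        return available.isEmpty() ? null : available.last();
    }

    public static void main(String[] args) {
        CheckpointQuerySketch q = new CheckpointQuerySketch();
        System.out.println(q.lastCheckpoint()); // null: nothing saved yet
        for (int n = 0; n < 4; n++) {
            q.onSaved(n);
        }
        q.onDeleted(0); // e.g. removed by a keepLast(3) policy
        System.out.println(q.availableCheckpoints() + " last=" + q.lastCheckpoint()); // [1, 2, 3] last=3
    }
}
```

Callers of the real `lastCheckpoint()` should likewise check for `null` before passing the result to `loadCheckpointMLN(Checkpoint)` or `loadCheckpointCG(Checkpoint)`.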
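`public File getFileForCheckpoint(Checkpoint checkpoint)`

Get the model file for the given checkpoint.
Parameters: `checkpoint` - Checkpoint to get the model file for.

`public File getFileForCheckpoint(int checkpointNum)`

Get the model file for the given checkpoint number.
Parameters: `checkpointNum` - Checkpoint number to get the model file for.

`public MultiLayerNetwork loadCheckpointMLN(Checkpoint checkpoint)`

Load a MultiLayerNetwork for the given checkpoint.
Parameters: `checkpoint` - Checkpoint model to load.

`public MultiLayerNetwork loadCheckpointMLN(int checkpointNum)`

Load a MultiLayerNetwork for the given checkpoint number.
Parameters: `checkpointNum` - Checkpoint model number to load.

`public ComputationGraph loadCheckpointCG(Checkpoint checkpoint)`

Load a ComputationGraph for the given checkpoint.
Parameters: `checkpoint` - Checkpoint model to load.

`public ComputationGraph loadCheckpointCG(int checkpointNum)`

Load a ComputationGraph for the given checkpoint number.
Parameters: `checkpointNum` - Checkpoint model number to load.

Copyright © 2018. All rights reserved.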