Skip to content

Commit ca81667

Browse files
cypofshelhamer
authored andcommitted
Refactor solvers regularization and logging code
1 parent c255709 commit ca81667

File tree

2 files changed

+214
-298
lines changed

2 files changed

+214
-298
lines changed

include/caffe/solver.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ class Solver {
3939
int iter() { return iter_; }
4040

4141
protected:
42-
// Get the update value for the current iteration.
43-
virtual void ComputeUpdateValue() = 0;
42+
// Get and apply the update value for the current iteration.
43+
virtual void MakeUpdate() = 0;
4444
// The Solver::Snapshot function implements the basic snapshotting utility
4545
// that stores the learned net. You should implement the SnapshotSolverState()
4646
// function that produces a SolverState protocol buffer that needs to be
@@ -80,7 +80,9 @@ class SGDSolver : public Solver<Dtype> {
8080
protected:
8181
void PreSolve();
8282
Dtype GetLearningRate();
83-
virtual void ComputeUpdateValue();
83+
virtual void MakeUpdate();
84+
virtual void Regularize(int param_id);
85+
virtual void ComputeUpdateValue(int param_id, Dtype rate);
8486
virtual void ClipGradients();
8587
virtual void SnapshotSolverState(SolverState * state);
8688
virtual void RestoreSolverState(const SolverState& state);
@@ -102,7 +104,7 @@ class NesterovSolver : public SGDSolver<Dtype> {
102104
: SGDSolver<Dtype>(param_file) {}
103105

104106
protected:
105-
virtual void ComputeUpdateValue();
107+
virtual void ComputeUpdateValue(int param_id, Dtype rate);
106108

107109
DISABLE_COPY_AND_ASSIGN(NesterovSolver);
108110
};
@@ -116,7 +118,7 @@ class AdaGradSolver : public SGDSolver<Dtype> {
116118
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
117119

118120
protected:
119-
virtual void ComputeUpdateValue();
121+
virtual void ComputeUpdateValue(int param_id, Dtype rate);
120122
void constructor_sanity_check() {
121123
CHECK_EQ(0, this->param_.momentum())
122124
<< "Momentum cannot be used with AdaGrad.";

0 commit comments

Comments
 (0)