@@ -39,8 +39,8 @@ class Solver {
39
39
int iter () { return iter_; }
40
40
41
41
protected:
42
- // Get the update value for the current iteration.
43
- virtual void ComputeUpdateValue () = 0;
42
+ // Get and apply the update value for the current iteration.
43
+ virtual void MakeUpdate () = 0;
44
44
// The Solver::Snapshot function implements the basic snapshotting utility
45
45
// that stores the learned net. You should implement the SnapshotSolverState()
46
46
// function that produces a SolverState protocol buffer that needs to be
@@ -80,7 +80,9 @@ class SGDSolver : public Solver<Dtype> {
80
80
protected:
81
81
void PreSolve ();
82
82
Dtype GetLearningRate ();
83
- virtual void ComputeUpdateValue ();
83
+ virtual void MakeUpdate ();
84
+ virtual void Regularize (int param_id);
85
+ virtual void ComputeUpdateValue (int param_id, Dtype rate);
84
86
virtual void ClipGradients ();
85
87
virtual void SnapshotSolverState (SolverState * state);
86
88
virtual void RestoreSolverState (const SolverState& state);
@@ -102,7 +104,7 @@ class NesterovSolver : public SGDSolver<Dtype> {
102
104
: SGDSolver<Dtype>(param_file) {}
103
105
104
106
protected:
105
- virtual void ComputeUpdateValue ();
107
+ virtual void ComputeUpdateValue (int param_id, Dtype rate );
106
108
107
109
DISABLE_COPY_AND_ASSIGN (NesterovSolver);
108
110
};
@@ -116,7 +118,7 @@ class AdaGradSolver : public SGDSolver<Dtype> {
116
118
: SGDSolver<Dtype>(param_file) { constructor_sanity_check (); }
117
119
118
120
protected:
119
- virtual void ComputeUpdateValue ();
121
+ virtual void ComputeUpdateValue (int param_id, Dtype rate );
120
122
void constructor_sanity_check () {
121
123
CHECK_EQ (0 , this ->param_ .momentum ())
122
124
<< " Momentum cannot be used with AdaGrad." ;
0 commit comments