Skip to content

Commit 2f55f42

Browse files
Rok Mandeljcrokm
authored andcommitted
matcaffe: allow destruction of individual networks and solvers
1 parent 25422de commit 2f55f42

File tree

3 files changed

+30
-0
lines changed

3 files changed

+30
-0
lines changed

matlab/+caffe/Net.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
self.layer_names = self.attributes.layer_names;
6969
self.blob_names = self.attributes.blob_names;
7070
end
71+
function delete (self)
72+
caffe_('delete_net', self.hNet_self);
73+
end
7174
function layer = layers(self, layer_name)
7275
CHECK(ischar(layer_name), 'layer_name must be a string');
7376
layer = self.layer_vec(self.name2layer_index(layer_name));

matlab/+caffe/Solver.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
self.test_nets(n) = caffe.Net(self.attributes.hNet_test_nets(n));
3737
end
3838
end
39+
function delete (self)
40+
caffe_('delete_solver', self.hSolver_self);
41+
end
3942
function iter = iter(self)
4043
iter = caffe_('solver_get_iter', self.hSolver_self);
4144
end

matlab/+caffe/private/caffe_.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,17 @@ static void get_solver(MEX_ARGS) {
197197
mxFree(solver_file);
198198
}
199199

200+
// Usage: caffe_('delete_solver', hSolver)
201+
static void delete_solver(MEX_ARGS) {
202+
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
203+
"Usage: caffe_('delete_solver', hSolver)");
204+
Solver<float>* solver = handle_to_ptr<Solver<float> >(prhs[0]);
205+
solvers_.erase(std::remove_if(solvers_.begin(), solvers_.end(),
206+
[solver] (const shared_ptr< Solver<float> > &solverPtr) {
207+
return solverPtr.get() == solver;
208+
}), solvers_.end());
209+
}
210+
200211
// Usage: caffe_('solver_get_attr', hSolver)
201212
static void solver_get_attr(MEX_ARGS) {
202213
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -271,6 +282,17 @@ static void get_net(MEX_ARGS) {
271282
mxFree(phase_name);
272283
}
273284

285+
// Usage: caffe_('delete_solver', hSolver)
286+
static void delete_net(MEX_ARGS) {
287+
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
288+
"Usage: caffe_('delete_solver', hNet)");
289+
Net<float>* net = handle_to_ptr<Net<float> >(prhs[0]);
290+
nets_.erase(std::remove_if(nets_.begin(), nets_.end(),
291+
[net] (const shared_ptr< Net<float> > &netPtr) {
292+
return netPtr.get() == net;
293+
}), nets_.end());
294+
}
295+
274296
// Usage: caffe_('net_get_attr', hNet)
275297
static void net_get_attr(MEX_ARGS) {
276298
mxCHECK(nrhs == 1 && mxIsStruct(prhs[0]),
@@ -522,12 +544,14 @@ struct handler_registry {
522544
static handler_registry handlers[] = {
523545
// Public API functions
524546
{ "get_solver", get_solver },
547+
{ "delete_solver", delete_solver },
525548
{ "solver_get_attr", solver_get_attr },
526549
{ "solver_get_iter", solver_get_iter },
527550
{ "solver_restore", solver_restore },
528551
{ "solver_solve", solver_solve },
529552
{ "solver_step", solver_step },
530553
{ "get_net", get_net },
554+
{ "delete_net", delete_net },
531555
{ "net_get_attr", net_get_attr },
532556
{ "net_forward", net_forward },
533557
{ "net_backward", net_backward },

0 commit comments

Comments
 (0)