From 49981070165d4a5640bcf713911c8cf2d7538219 Mon Sep 17 00:00:00 2001
From: Shane Li
Date: Thu, 18 Oct 2018 19:19:33 +0800
Subject: [PATCH] Enable distribution validation run.

---
 src/caffe/solver.cpp | 42 +++++++++++++++++++++---------------------
 1 file changed, 21 insertions(+), 21 deletions(-)

diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp
index a9892716c..a2216a1f2 100644
--- a/src/caffe/solver.cpp
+++ b/src/caffe/solver.cpp
@@ -347,8 +347,8 @@ void Solver<Dtype>::Step(int iters) {
       ApplyUpdate();
       PERFORMANCE_MEASUREMENT_END_STATIC("weights_update");
     }else{
-        //While using multinodes mode, force to print current lr to logs
-        PrintLearningRate();
+      //While using multinodes mode, force to print current lr to logs
+      PrintLearningRate();
     }
 
     iter_time += iter_timer.MilliSeconds();
@@ -358,7 +358,7 @@ void Solver<Dtype>::Step(int iters) {
     net_->ResetTimers();
 
 #ifdef USE_MLSL
-    if (mn::is_root() == true)
+    if (mn::is_root())
 #endif
       LOG(INFO) << "iter " << iter_ << ", forward_backward_update_time: "
         << iter_time << " ms";
@@ -479,7 +479,16 @@ void Solver<Dtype>::TestClassification(const int test_net_id) {
   vector<int> test_score_output_id;
   const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
   Dtype loss = 0;
-  for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
+  int global_test_iter = param_.test_iter(test_net_id);
+#ifdef USE_MLSL
+  int local_test_iter = global_test_iter / mn::get_nodes_count();
+  int left_test_iter = global_test_iter % mn::get_nodes_count();
+  if (mn::get_node_id() < left_test_iter)
+    local_test_iter += 1;
+#else
+  int local_test_iter = global_test_iter;
+#endif
+  for (int i = 0; i < local_test_iter; ++i) {
     SolverAction::Enum request = GetRequestedAction();
     // Check to see if stoppage of testing/training has been requested.
     while (request != SolverAction::NONE) {
@@ -526,35 +535,26 @@ void Solver<Dtype>::TestClassification(const int test_net_id) {
   if (param_.test_compute_loss()) {
 #ifdef USE_MLSL
     mn::allreduce(&loss, 1);
-    loss /= (param_.test_iter(test_net_id) * mn::get_group_size());
-    if (mn::is_root() == true)
-      LOG(INFO) << "Test loss: " << loss;
-#else /* !USE_MLSL */
-    loss /= param_.test_iter(test_net_id);
+    if (mn::is_root()) {
+#endif /* USE_MLSL */
+    loss /= global_test_iter;
     LOG(INFO) << "Test loss: " << loss;
+#ifdef USE_MLSL
+    }
 #endif /* USE_MLSL */
   }
 
 #ifdef USE_MLSL
   mn::allreduce(test_score.data(), test_score.size());
-  if (mn::is_root() == true)
+  if (mn::is_root())
 #endif /* USE_MLSL */
   for (int i = 0; i < test_score.size(); ++i) {
     const int output_blob_index =
         test_net->output_blob_indices()[test_score_output_id[i]];
     const string& output_name = test_net->blob_names()[output_blob_index];
-    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index]
-#ifdef USE_MLSL
-        * mn::get_distrib()->get_data_parts()
-#endif
-        ;
+    const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
     ostringstream loss_msg_stream;
-#ifdef USE_MLSL
-    const Dtype mean_score =
-        test_score[i] / (param_.test_iter(test_net_id) * mn::get_group_size());
-#else /* !USE_MLSL */
-    const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
-#endif /* USE_MLSL */
+    const Dtype mean_score = test_score[i] / global_test_iter;
     if (loss_weight) {
       loss_msg_stream << " (* " << loss_weight << " = "
                      << loss_weight * mean_score << " loss)";
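
Note (not part of the patch): a minimal standalone sketch of the iteration split introduced above, under the assumption of one test run partitioned over N nodes. LocalTestIter and its parameters are illustrative stand-ins for the solver's test_iter setting and the mn::get_nodes_count() / mn::get_node_id() helpers.

// Illustrative sketch: how the validation run is partitioned across nodes.
#include <cassert>
#include <iostream>

// Each node runs global_test_iter / nodes iterations; the first
// global_test_iter % nodes nodes run one extra iteration, so the
// per-node counts sum back to the global count.
int LocalTestIter(int global_test_iter, int nodes, int node_id) {
  int local_test_iter = global_test_iter / nodes;
  int left_test_iter = global_test_iter % nodes;
  if (node_id < left_test_iter)
    local_test_iter += 1;
  return local_test_iter;
}

int main() {
  const int global_test_iter = 1000;  // e.g. test_iter from a solver prototxt
  const int nodes = 16;
  int total = 0;
  for (int node_id = 0; node_id < nodes; ++node_id)
    total += LocalTestIter(global_test_iter, nodes, node_id);
  // Because the split covers every test iteration exactly once, the
  // allreduced loss/score can be divided by the global count, as the
  // patch does with global_test_iter.
  assert(total == global_test_iter);
  std::cout << "per-node counts sum to " << total << std::endl;
  return 0;
}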