代码阅读
http://alanse7en.github.io/caffedai-ma-jie-xi-4/
三.
從一個比較宏觀的層面上去了解caffe怎么去完成一些初始化的工作和使用Solver的接口函數,本文將主要分為四部分的內容:
- Google Flags的使用
- Register Brew Function的宏的定義和使用
- train()函數的具體實現
- SolverParameter的具體解析過程
Google Flags的使用
從Caffe官網中可以看到,caffe的Command Line Interfaces一共提供了四個功能:train/test/time/device_query,而Interfaces的輸入除了這四種功能還可以輸入諸如-solver/-weights/-snapshot/-gpu等參數。這些參數的解析是通過Google Flags這個工具來完成的。
在caffe.cpp(位于/CAFFE_ROOT/tools/caffe.cpp)的開頭,我們可以看到很多這樣的宏:
DEFINE_string(gpu, "","Optional; run in GPU mode on given device IDs separated by ','.""Use '-gpu all' to run on all available GPUs. The effective training ""batch size is multiplied by the number of devices.");這個宏的使用方式為DEFINE_xxx(name, default_value, instruction);,這樣就定義了一個xxx類型名為FLAGS_name的標志,如果用戶沒有在Command Line中提供其值,那么會默認為default_value,instruction是這個標志含義的說明。因此,上面的代碼定義了一個string類型的名為FLAGS_gpu的標志,如果在Command Line中用戶沒有提供值,那么會默認為空字符串,根據說明可以得知這個標志是提供給用戶來指定caffe將使用的GPU的。其余的定義也是類似的理解方式就不一一列舉了。
解析這些標志的代碼在caffe.cpp中的main()中調用了/CAFFE_ROOT/src/common.cpp中的GlobalInit(&argc, &argv)函數:
1 void GlobalInit(int* pargc, char*** pargv) { 2 // Google flags. 3 ::gflags::ParseCommandLineFlags(pargc, pargv, true); 4 // Google logging. 5 ::google::InitGoogleLogging(*(pargv)[0]); 6 // Provide a backtrace on segfault. 7 ::google::InstallFailureSignalHandler(); 8 }第三行的函數就是Google Flags用來解析輸入的參數的,前兩個參數分別是指向main()的argc和argv的指針,第三個參數為true,表示在解析完所有的標志之后將這些標志從argv中清除,因此在解析完成之后,argc的值為2,argv[0]為main,argv[1]為train/test/time/device_query中的一個。
Register Brew Function的宏的定義和使用
Caffe在Command Line Interfaces中一共提供了4種功能:train/test/time/device_query,分別對應著四個函數,這四個函數的調用是通過一個叫做g_brew_map的全局變量來完成的:
1 // A simple registry for caffe commands. 2 typedef int (*BrewFunction)(); 3 typedef std::map<caffe::string, BrewFunction> BrewMap; 4 BrewMap g_brew_map;g_brew_map是一個key為string類型,value為BrewFunction類型的一個map類型的全局變量,BrewFunction是一個函數指針類型,指向的是參數為空,返回值為int的函數,也就是train/test/time/device_query這四個函數的類型。在train等四個函數實現的后面都緊跟著這樣一句宏的調用:RegisterBrewFunction(train);
其中使用的宏的具體定義為:
1 \#define RegisterBrewFunction(func) \ 2 namespace { \ 3 class __Registerer_##func { \ 4 public: /* NOLINT */ \ 5 __Registerer_##func() { \ 6 g_brew_map[#func] = &func; \ 7 } \ 8 }; \ 9 __Registerer_##func g_registerer_##func; \ 10 }以train函數為例子,RegisterBrewFunction(train)這個宏的作用是定義了一個名為__Register_train的類,在定義完這個類之后,定義了一個這個類的變量,會調用構造函數,這個類的構造函數在前面提到的g_brew_map中添加了key為”train”,value為指向train函數的指針的一個元素。
然后函數的調用在main()函數中是通過下面的這段代碼實現的,在完成初始化(GlobalInit)之后,有這樣一句代碼:
1 // main()中的調用代碼 2 return GetBrewFunction(caffe::string(argv[1]))(); 3 // BrewFunction的具體實現 4 static BrewFunction GetBrewFunction(const caffe::string& name) { 5 if (g_brew_map.count(name)) { 6 return g_brew_map[name]; 7 } else { 8 LOG(ERROR) << "Available caffe actions:"; 9 for (BrewMap::iterator it = g_brew_map.begin(); 10 it != g_brew_map.end(); ++it) { 11 LOG(ERROR) << "\t" << it->first; 12 } 13 LOG(FATAL) << "Unknown action: " << name; 14 return NULL; // not reachable, just to suppress old compiler warnings. 15 } 16 }還是以train函數為例子,如果我們在Command Line中輸入了caffe train <args>,經過Google Flags的解析argv[1]=train,因此,在GetBrewFunction中會通過g_brew_map返回一個指向train函數的函數指針,最后在main函數中就通過這個返回的函數指針完成了對train函數的調用。
總結一下:RegisterBrewFunction這個宏在每一個實現主要功能的函數之后將這個函數的名字和其對應的函數指針添加到了g_brew_map中,然后在main函數中,通過GetBrewFunction得到了我們需要調用的那個函數的函數指針,并完成了調用。
train()函數的具體實現
接下來我們仔細地分析一下在train()的具體實現。
首先是這樣的一段代碼:
1 CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; 2 CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size()) 3 << "Give a snapshot to resume training or weights to finetune " 4 "but not both.";這段代碼的第一行使用了glog的CHECK_GT宏(含義為check greater than),檢查FLAGS_solver的size是否大于0,如果小于或等于0則輸出提示:”Need a solver definition to train”。FLAGS_solver是最開始通過DEFINE_string定義的標志,如果我們希望訓練一個模型,那么自然應該應該提供對應的solver定義文件的路徑,這一句話正是在確保我們提供了這樣的路徑。這樣的檢查語句在后續的代碼中會經常出現,將不再一一詳細解釋,如果有不清楚含義的glog宏可以去看看文檔。 與第一行代碼類似,第二行代碼是確保用戶沒有同時提供snapshot和weights參數,這兩個參數都是繼續之前的訓練或者進行fine-tuning的,如果同時指明了這兩個標志,則不知道到底應該從哪個路徑的文件去讀入模型的相關參數更為合適。
然后出現了SolverParameter solver_param的聲明和解析的代碼:
1 caffe::SolverParameter solver_param; 2 caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);SolverParameter是通過Google Protocol Buffer自動生成的一個類,如果有不清楚的可以參考上一篇文章。而具體的解析函數將在下一部分具體解釋。
接下來這一部分的代碼是根據用戶的設置來選擇caffe工作的模式(GPU或CPU)以及使用哪些GPU(caffe已經支持了多GPU同時工作!具體參考:官網tutorial的Parallelism部分):
1 // If the gpus flag is not provided, allow the mode and device to be set 2 // in the solver prototxt. 3 if (FLAGS_gpu.size() == 0 4 && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { 5 if (solver_param.has_device_id()) { 6 FLAGS_gpu = "" + 7 boost::lexical_cast<string>(solver_param.device_id()); 8 } else { // Set default GPU if unspecified 9 FLAGS_gpu = "" + boost::lexical_cast<string>(0); 10 } 11 } 12 vector<int> gpus; 13 get_gpus(&gpus); 14 if (gpus.size() == 0) { 15 LOG(INFO) << "Use CPU."; 16 Caffe::set_mode(Caffe::CPU); 17 } else { 18 ostringstream s; 19 for (int i = 0; i < gpus.size(); ++i) { 20 s << (i ? ", " : "") << gpus[i]; 21 } 22 LOG(INFO) << "Using GPUs " << s.str(); 23 24 solver_param.set_device_id(gpus[0]); 25 Caffe::SetDevice(gpus[0]); 26 Caffe::set_mode(Caffe::GPU); 27 Caffe::set_solver_count(gpus.size()); 28 }首先是判斷用戶在Command Line中是否輸入了gpu相關的參數,如果沒有(FLAGS_gpu.size()==0)但是用戶在solver的prototxt定義中提供了相關的參數,那就把相關的參數放到FLAGS_gpu中,如果用戶僅僅是選擇了在solver的prototxt定義中選擇了GPU模式,但是沒有指明具體的gpu_id,那么就默認設置為0。
接下來的代碼則通過一個get_gpus的函數,將存放在FLAGS_gpu中的string轉成了一個vector,并完成了具體的設置。
下面的代碼聲明并通過SolverRegistry初始化了一個指向Solver類型的shared_ptr。并通過這個shared_ptr指明了在遇到系統信號(用戶按了ctrl+c或者關閉了當前的terminal)時的處理方式。
1 caffe::SignalHandler signal_handler( 2 GetRequestedAction(FLAGS_sigint_effect), 3 GetRequestedAction(FLAGS_sighup_effect)); 4 5 shared_ptr<caffe::Solver<float> > 6 solver(caffe::SolverRegistry<float>::CreateSolver(solver_param)); 7 8 solver->SetActionFunction(signal_handler.GetActionFunction());接下來判斷了一下用戶是否定義了snapshot或者weights這兩個參數中的一個,如果定義了則需要通過Solver提供的接口從snapshot或者weights文件中去讀取已經訓練好的網絡的參數:
1 if (FLAGS_snapshot.size()) { 2 LOG(INFO) << "Resuming from " << FLAGS_snapshot; 3 solver->Restore(FLAGS_snapshot.c_str()); 4 } else if (FLAGS_weights.size()) { 5 CopyLayers(solver.get(), FLAGS_weights); 6 }最后,如果用戶設置了要使用多個gpu,那么要聲明一個P2PSync類型的對象,并通過這個對象來完成多gpu的計算,這一部分的代碼,這一系列的文章會暫時先不涉及。而如果是只使用單個gpu,那么就通過Solver的Solve()開始具體的優化過程。在優化結束之后,函數將0值返回給main函數,整個train過程到這里也就結束了:
1 if (gpus.size() > 1) { 2 caffe::P2PSync<float> sync(solver, NULL, solver->param()); 3 sync.run(gpus); 4 } else { 5 LOG(INFO) << "Starting Optimization"; 6 solver->Solve(); 7 } 8 LOG(INFO) << "Optimization Done."; 9 return 0;上面的代碼中涉及了很多Solver這個類的接口,這些內容都將在下一篇文章中進行具體的分析。
SolverParameter的具體解析過程
前面提到了SolverParameter是通過ReadSolverParamsFromTextFileOrDie來完成解析的,這個函數的實現在/CAFFE_ROOT/src/caffe/util/upgrade_proto.cpp里,我們來看一下具體的過程:
1 // Read parameters from a file into a SolverParameter proto message. 2 void ReadSolverParamsFromTextFileOrDie(const string& param_file, 3 SolverParameter* param) { 4 CHECK(ReadProtoFromTextFile(param_file, param)) 5 << "Failed to parse SolverParameter file: " << param_file; 6 UpgradeSolverAsNeeded(param_file, param); 7 }這里調用了先后調用了兩個函數,首先是ReadProtoFromTextFile,這個函數的作用是從param_file這個路徑去讀取solver的定義,并將文件中的內容解析存到param這個指針指向的對象,具體的實現在/CAFFE_ROOT/src/caffe/util/io.cpp的開始:
1 bool ReadProtoFromTextFile(const char* filename, Message* proto) { 2 int fd = open(filename, O_RDONLY); 3 CHECK_NE(fd, -1) << "File not found: " << filename; 4 FileInputStream* input = new FileInputStream(fd); 5 bool success = google::protobuf::TextFormat::Parse(input, proto); 6 delete input; 7 close(fd); 8 return success; 9 }這段代碼首先是打開了文件,并且讀取到了一個FileInputStream的指針中,然后通過protobuf的TextFormat::Parse函數完成了解析。
然后UpgradeSolverAsNeeded完成了新老版本caffe.proto的兼容處理:
1 // Check for deprecations and upgrade the SolverParameter as needed. 2 bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) { 3 bool success = true; 4 // Try to upgrade old style solver_type enum fields into new string type 5 if (SolverNeedsTypeUpgrade(*param)) { 6 LOG(INFO) << "Attempting to upgrade input file specified using deprecated " 7 << "'solver_type' field (enum)': " << param_file; 8 if (!UpgradeSolverType(param)) { 9 success = false; 10 LOG(ERROR) << "Warning: had one or more problems upgrading " 11 << "SolverType (see above)."; 12 } else { 13 LOG(INFO) << "Successfully upgraded file specified using deprecated " 14 << "'solver_type' field (enum) to 'type' field (string)."; 15 LOG(WARNING) << "Note that future Caffe releases will only support " 16 << "'type' field (string) for a solver's type."; 17 } 18 } 19 return success; 20 }主要的問題就是在舊版本中Solver的type是enum類型,而新版本的變為了string。
總結
本文從主要分析了caffe.cpp中實現各種具體功能的函數的調用的機制,以及在Command Line中用戶輸入的各種參數是怎么解析的,以及最常用的train函數的具體代碼。通過這些分析,我們對Solver類型的接口有了一個初步的認識和了解,在下一篇文章中,我們將去具體地分析Solver的實現。
四.
在上文對Command Line Interfaces進行了簡單的介紹之后,本文將對caffe的Solver相關的代碼進行分析。
本文將主要分為四部分的內容:
- Solver的初始化(Register宏和構造函數)
- SIGINT和SIGHUP信號的處理
- Solver::Solve()具體實現
- SGDSolver::ApplyUpdate具體實現
Solver的初始化(Register宏和構造函數)
shared_ptr<caffe::Solver<float> >solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));caffe.cpp中的train函數中通過上面的代碼定義了一個指向Solver<float>的shared_ptr。其中主要是通過調用SolverRegistry這個類的靜態成員函數CreateSolver得到一個指向Solver的指針來構造shared_ptr類型的solver。而且由于C++多態的特性,盡管solver是一個指向基類Solver類型的指針,通過solver這個智能指針來調用各個成員函數會調用到各個子類(SGDSolver等)的函數。具體的過程如下面的流程圖所示:
Create solver下面我們就來具體看一下SolverRegistry這個類的代碼,以便理解是如何通過同一個函數得到不同類型的Solver:
1 class SolverRegistry { 2 public: 3 typedef Solver<Dtype>* (*Creator)(const SolverParameter&); 4 typedef std::map<string, Creator> CreatorRegistry; 5 static CreatorRegistry& Registry() { 6 static CreatorRegistry* g_registry_ = new CreatorRegistry(); 7 return *g_registry_; 8 } 9 static void AddCreator(const string& type, Creator creator) { 10 CreatorRegistry& registry = Registry(); 11 CHECK_EQ(registry.count(type), 0) 12 << "Solver type " << type << " already registered."; 13 registry[type] = creator; 14 } 15 static Solver<Dtype>* CreateSolver(const SolverParameter& param) { 16 const string& type = param.type(); 17 CreatorRegistry& registry = Registry(); 18 CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type 19 << " (known types: " << SolverTypeListString() << ")"; 20 return registry[type](param); 21 } 22 static vector<string> SolverTypeList() { 23 CreatorRegistry& registry = Registry(); 24 vector<string> solver_types; 25 for (typename CreatorRegistry::iterator iter = registry.begin(); 26 iter != registry.end(); ++iter) { 27 solver_types.push_back(iter->first); 28 } 29 return solver_types; 30 } 31 private: 32 SolverRegistry() {} 33 static string SolverTypeListString() { 34 vector<string> solver_types = SolverTypeList(); 35 string solver_types_str; 36 for (vector<string>::iterator iter = solver_types.begin(); 37 iter != solver_types.end(); ++iter) { 38 if (iter != solver_types.begin()) { 39 solver_types_str += ", "; 40 } 41 solver_types_str += *iter; 42 } 43 return solver_types_str; 44 } 45 };首先需要注意的是這個類的構造函數是private的,也就是用我們沒有辦法去構造一個這個類型的變量,這個類也沒有數據成員,所有的成員函數也都是static的,可以直接調用。
我們首先從CreateSolver函數(第15行)入手,這個函數先定義了string類型的變量type,表示Solver的類型(‘SGD’/’Nestrov’等),然后定義了一個key類型為string,value類型為Creator的map:registry,其中Creator是一個函數指針類型,指向的函數的參數為SolverParameter類型,返回類型為Solver<Dtype>*(見第2行和第3行)。如果是一個已經register過的Solver類型,那么registry.count(type)應該為1,然后通過registry這個map返回了我們需要類型的Solver的creator,并調用這個creator函數,將creator返回的Solver<Dtype>*返回。
上面的代碼中,Registry這個函數(第5行)中定義了一個static的變量g_registry,這個變量是一個指向CreatorRegistry這個map類型的指針,然后直接返回,因為這個變量是static的,所以即使多次調用這個函數,也只會定義一個g_registry,而且在其他地方修改這個map里的內容,是存儲在這個map中的。事實上各個Solver的register的過程正是往g_registry指向的那個map里添加以Solver的type為key,對應的Creator函數指針為value的內容。Register的過程如流程圖所示:
Register Solver下面我們具體來看一下Solver的register的過程:
1 template <typename Dtype> 2 class SolverRegisterer { 3 public: 4 SolverRegisterer(const string& type, 5 Solver<Dtype>* (*creator)(const SolverParameter&)) { 6 // LOG(INFO) << "Registering solver type: " << type; 7 SolverRegistry<Dtype>::AddCreator(type, creator); 8 } 9 }; 10 #define REGISTER_SOLVER_CREATOR(type, creator) \ 11 static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \ 12 static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \ 13 14 #define REGISTER_SOLVER_CLASS(type) \ 15 template <typename Dtype> \ 16 Solver<Dtype>* Creator_##type##Solver( \ 17 const SolverParameter& param) \ 18 { \ 19 return new type##Solver<Dtype>(param); \ 20 } \ 21 REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver) 22 } 23 // register SGD Solver 24 REGISTER_SOLVER_CLASS(SGD);在sgd_solver.cpp(SGD Solver對應的cpp文件)末尾有上面第24行的代碼,使用了REGISTER_SOLVER_CLASS這個宏,這個宏會定義一個名為Creator_SGDSolver的函數,這個函數即為Creator類型的指針指向的函數,在這個函數中調用了SGDSolver的構造函數,并將構造的這個變量得到的指針返回,這也就是Creator類型函數的作用:構造一個對應類型的Solver對象,將其指針返回。然后在這個宏里又調用了REGISTER_SOLVER_CREATOR這個宏,這里分別定義了SolverRegisterer這個模板類的float和double類型的static變量,這會去調用各自的構造函數,而在SolverRegisterer的構造函數中調用了之前提到的SolverRegistry類的AddCreator函數,這個函數就是將剛才定義的Creator_SGDSolver這個函數的指針存到g_registry指向的map里面。類似地,所有的Solver對應的cpp文件的末尾都調用了這個宏來完成注冊,在所有的Solver都注冊之后,我們就可以通過之前描述的方式,通過g_registry得到對應的Creator函數的指針,并通過調用這個Creator函數來構造對應的Solver。Register和Create對應的流程圖如下所示:
SIGINT和SIGHUP信號的處理
Caffe在train或者test的過程中都有可能會遇到系統信號(用戶按下ctrl+c或者關掉了控制的terminal),我們可以通過對sigint_effect和sighup_effect來設置遇到系統信號的時候希望進行的處理方式:
caffe train –solver=/path/to/solver.prototxt –sigint_effect=EFFECT –sighup_effect=EFFECT
在caffe.cpp中定義了一個GetRequesedAction函數來將設置的string類型的標志轉變為枚舉類型的變量:
1 caffe::SolverAction::Enum GetRequestedAction( 2 const std::string& flag_value) { 3 if (flag_value == "stop") { 4 return caffe::SolverAction::STOP; 5 } 6 if (flag_value == "snapshot") { 7 return caffe::SolverAction::SNAPSHOT; 8 } 9 if (flag_value == "none") { 10 return caffe::SolverAction::NONE; 11 } 12 LOG(FATAL) << "Invalid signal effect \""<< flag_value << "\" was specified"; 13 } 14 // SolverAction::Enum的定義 15 namespace SolverAction { 16 enum Enum { 17 NONE = 0, // Take no special action. 18 STOP = 1, // Stop training. snapshot_after_train controls whether a 19 // snapshot is created. 20 SNAPSHOT = 2 // Take a snapshot, and keep training. 21 }; 22 }其中SolverAction::Enum的定義在solver.hpp中,這是一個定義為枚舉類型的數據類型,只有三個可能的值,分別對應了三種處理系統信號的方式:NONE(忽略信號什么都不做)/STOP(停止訓練)/SNAPSHOT(保存當前的訓練狀態,繼續訓練)。在caffe.cpp中的train函數里Solver設置如何處理系統信號的代碼為:
1 caffe::SignalHandler signal_handler( 2 GetRequestedAction(FLAGS_sigint_effect), 3 GetRequestedAction(FLAGS_sighup_effect)); 4 5 solver->SetActionFunction(signal_handler.GetActionFunction());FLAGS_sigint_effect和FLAGS_sighup_effect是通過gflags定義和解析的兩個Command Line Interface的輸入參數,分別對應遇到sigint和sighup信號的處理方式,如果用戶不設定(大部分時候我自己就沒設定),sigint的默認值為”stop”,sighup的默認值為”snapshot”。GetRequestedAction函數會將string類型的FLAGS_xx轉為SolverAction::Enum類型,并用來定義一個SignalHandler類型的對象signal_handler。我們可以看到這部分代碼都依賴于SignalHandler這個類的接口,我們先來看看這個類都做了些什么:
1 // header file 2 class SignalHandler { 3 public: 4 // Contructor. Specify what action to take when a signal is received. 5 SignalHandler(SolverAction::Enum SIGINT_action, 6 SolverAction::Enum SIGHUP_action); 7 ~SignalHandler(); 8 ActionCallback GetActionFunction(); 9 private: 10 SolverAction::Enum CheckForSignals() const; 11 SolverAction::Enum SIGINT_action_; 12 SolverAction::Enum SIGHUP_action_; 13 }; 14 // source file 15 SignalHandler::SignalHandler(SolverAction::Enum SIGINT_action, 16 SolverAction::Enum SIGHUP_action): 17 SIGINT_action_(SIGINT_action), 18 SIGHUP_action_(SIGHUP_action) { 19 HookupHandler(); 20 } 21 void HookupHandler() { 22 if (already_hooked_up) { 23 LOG(FATAL) << "Tried to hookup signal handlers more than once."; 24 } 25 already_hooked_up = true; 26 struct sigaction sa; 27 sa.sa_handler = &handle_signal; 28 // ... 29 } 30 static volatile sig_atomic_t got_sigint = false; 31 static volatile sig_atomic_t got_sighup = false; 32 void handle_signal(int signal) { 33 switch (signal) { 34 case SIGHUP: 35 got_sighup = true; 36 break; 37 case SIGINT: 38 got_sigint = true; 39 break; 40 } 41 } 42 ActionCallback SignalHandler::GetActionFunction() { 43 return boost::bind(&SignalHandler::CheckForSignals, this); 44 } 45 SolverAction::Enum SignalHandler::CheckForSignals() const { 46 if (GotSIGHUP()) { 47 return SIGHUP_action_; 48 } 49 if (GotSIGINT()) { 50 return SIGINT_action_; 51 } 52 return SolverAction::NONE; 53 } 54 bool GotSIGINT() { 55 bool result = got_sigint; 56 got_sigint = false; 57 return result; 58 } 59 bool GotSIGHUP() { 60 bool result = got_sighup; 61 got_sighup = false; 62 return result; 63 } 64 // ActionCallback的含義 65 typedef boost::function<SolverAction::Enum()> ActionCallback;SignalHandler這個類有兩個數據成員,都是SolverAction::Enum類型的,分別對應sigint和sighup信號,在構造函數中,用解析FLAGS_xx得到的結果分別給兩個成員賦值,然后調用了HookupHandler函數,這個函數的主要作用是定義了一個sigaction類型(應該是系統級別的代碼)的對象sa,然后通過sa.sa_handler = &handle_signal來設置,當有遇到系統信號時,調用handle_signal函數來處理,而我們可以看到這個函數的處理很簡單,就是判斷一下當前的信號是什么類型,如果是sigint就將全局的static變量got_sigint變為true,sighup的處理類似。
在根據用戶設置(或者默認值)的參數定義了signal_handler之后,solver通過SetActionFunction來設置了如何處理系統信號。這個函數的輸入為signal_handler的GetActionFunction的返回值,根據上面的代碼我們可以看到,GetActionFunction會返回signal_handler這個對象的CheckForSignals函數的地址(boost::bind的具體使用請參考boost官方文檔)。而在Solver的SetActionFunction函數中只是簡單的把Solver的一個成員action_request_function_賦值為輸入參數的值,以當前的例子來說就是,solver對象的action_request_function_指向了signal_handler對象的CheckForSignals函數的地址。其中的ActionCallback是一個函數指針類型,指向了參數為空,返回值為SolverAction::Enum類型的函數(boost::function具體用法參考官方文檔)。
總結起來,我們通過定義一個SignalHandler類型的對象,告知系統在遇到系統信號的時候回調handle_signal函數來改變全局變量got_sigint和got_sighup的值,然后通過Solver的接口設置了其遇到系統函數將調用signal_handler的Check函數,這個函數實際上就是去判斷當前是否遇到了系統信號,如果遇到某個類型的信號,就返回我們之前設置的處理方式(SolverAction::Enum類型)。剩余的具體處理再交給Solver的其它函數,后面會具體分析。
Solver::Solve()具體實現
Solve函數實現了具體的網絡的優化過程,下面我們來具體分析一下這部分的代碼,分析見注釋:
1 void Solver<Dtype>::Solve(const char* resume_file) { 2 // 檢查當前是否是root_solver(多GPU模式下,只有root_solver才運行這一部分的代碼) 3 CHECK(Caffe::root_solver()); 4 // 然后輸出learning policy(更新學習率的策略) 5 LOG(INFO) << "Solving " << net_->name(); 6 LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); 7 // requested_early_exit_`一開始被賦值為false,也就是現在沒有要求在優化結束前退出 8 requested_early_exit_ = false; 9 // 判斷`resume_file`這個指針是否NULL,如果不是則需要從resume_file存儲的路徑里讀取之前訓練的狀態 10 if (resume_file) { 11 LOG(INFO) << "Restoring previous solver status from " << resume_file; 12 Restore(resume_file); 13 } 14 // 然后調用了'Step'函數,這個函數執行了實際的逐步的迭代過程 15 Step(param_.max_iter() - iter_); 16 // 迭代結束或者遇到系統信號提前結束后,判斷是否需要在訓練結束之后snapshot 17 // 這個可以在solver.prototxt里設置 18 if (param_.snapshot_after_train() 19 && (!param_.snapshot() || iter_ % param_.snapshot() != 0)) { 20 Snapshot(); 21 } 22 // 如果在`Step`函數的迭代過程中遇到了系統信號,且我們的處理方式設置為`STOP`, 23 // 那么`requested_early_exit_`會被修改為true,迭代提前結束,輸出相關信息 24 if (requested_early_exit_) { 25 LOG(INFO) << "Optimization stopped early."; 26 return; 27 } 28 // 判斷是否需要輸出最后的loss 29 if (param_.display() && iter_ % param_.display() == 0) { 30 Dtype loss; 31 net_->ForwardPrefilled(&loss); 32 LOG(INFO) << "Iteration " << iter_ << ", loss = " << loss; 33 } 34 // 判斷是否需要最后Test 35 if (param_.test_interval() && iter_ % param_.test_interval() == 0) { 36 TestAll(); 37 } 38 LOG(INFO) << "Optimization Done."; 39 }下面繼續分析具體的迭代過程發生的Step函數:
1 template <typename Dtype> 2 void Solver<Dtype>::Step(int iters) { 3 vector<Blob<Dtype>*> bottom_vec; 4 // 設置開始的迭代次數(如果是從之前的snapshot恢復的,那iter_等于snapshot時的迭代次數)和結束的迭代次數 5 const int start_iter = iter_; 6 const int stop_iter = iter_ + iters; 7 // 輸出的loss為前average_loss次loss的平均值,在solver.prototxt里設置,默認為1, 8 // losses存儲之前的average_loss個loss,smoothed_loss為最后要輸出的均值 9 int average_loss = this->param_.average_loss(); 10 vector<Dtype> losses; 11 Dtype smoothed_loss = 0; 12 // 迭代 13 while (iter_ < stop_iter) { 14 // 清空上一次所有參數的梯度 15 net_->ClearParamDiffs(); 16 // 判斷是否需要測試 17 if (param_.test_interval() && iter_ % param_.test_interval() == 0 18 && (iter_ > 0 || param_.test_initialization()) 19 && Caffe::root_solver()) { 20 TestAll(); 21 // 判斷是否需要提前結束迭代 22 if (requested_early_exit_) { 23 break; 24 } 25 } 26 for (int i = 0; i < callbacks_.size(); ++i) { 27 callbacks_[i]->on_start(); 28 } 29 // 判斷當前迭代次數是否需要顯示loss等信息 30 const bool display = param_.display() && iter_ % param_.display() == 0; 31 net_->set_debug_info(display && param_.debug_info()); 32 Dtype loss = 0; 33 // iter_size也是在solver.prototxt里設置,實際上的batch_size=iter_size*網絡定義里的batch_size, 34 // 因此每一次迭代的loss是iter_size次迭代的和,再除以iter_size,這個loss是通過調用`Net::ForwardBackward`函數得到的 35 // 這個設置我的理解是在GPU的顯存不夠的時候使用,比如我本來想把batch_size設置為128,但是會out_of_memory, 36 // 借助這個方法,可以設置batch_size=32,iter_size=4,那實際上每次迭代還是處理了128個數據 37 for (int i = 0; i < param_.iter_size(); ++i) { 38 loss += net_->ForwardBackward(bottom_vec); 39 } 40 loss /= param_.iter_size(); 41 // 計算要輸出的smoothed_loss,如果losses里還沒有存夠average_loss個loss則將當前的loss插入,如果已經存夠了,則將之前的替換掉 42 if (losses.size() < average_loss) { 43 losses.push_back(loss); 44 int size = losses.size(); 45 smoothed_loss = (smoothed_loss * (size - 1) + loss) / size; 46 } else { 47 int idx = (iter_ - start_iter) % average_loss; 48 smoothed_loss += (loss - losses[idx]) / average_loss; 49 losses[idx] = loss; 50 } 51 // 輸出當前迭代的信息 52 if (display) { 53 LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_ 54 << ", loss = " << smoothed_loss; 55 const vector<Blob<Dtype>*>& result = net_->output_blobs(); 56 int score_index = 0; 57 for (int j = 0; j < result.size(); ++j) { 58 const Dtype* result_vec = result[j]->cpu_data(); 59 const string& output_name = 60 net_->blob_names()[net_->output_blob_indices()[j]]; 61 const Dtype loss_weight = 62 net_->blob_loss_weights()[net_->output_blob_indices()[j]]; 63 for (int k = 0; k < result[j]->count(); ++k) { 64 ostringstream loss_msg_stream; 65 if (loss_weight) { 66 loss_msg_stream << " (* " << loss_weight 67 << " = " << loss_weight * result_vec[k] << " loss)"; 68 } 69 LOG_IF(INFO, Caffe::root_solver()) << " Train net output #" 70 << score_index++ << ": " << output_name << " = " 71 << result_vec[k] << loss_msg_stream.str(); 72 } 73 } 74 } 75 for (int i = 0; i < callbacks_.size(); ++i) { 76 callbacks_[i]->on_gradients_ready(); 77 } 78 // 執行梯度的更新,這個函數在基類`Solver`中沒有實現,會調用每個子類自己的實現,后面具體分析`SGDSolver`的實現 79 ApplyUpdate(); 80 // 迭代次數加1 81 ++iter_; 82 // 調用GetRequestedAction,實際是通過action_request_function_函數指針調用之前設置好(通過`SetRequestedAction`)的 83 // signal_handler的`CheckForSignals`函數,這個函數的作用是 84 // 會根據之前是否遇到系統信號以及信號的類型和我們設置(或者默認)的方式返回處理的方式 85 SolverAction::Enum request = GetRequestedAction(); 86 // 判斷當前迭代是否需要snapshot,如果request等于`SNAPSHOT`則也需要 87 if ((param_.snapshot() 88 && iter_ % param_.snapshot() == 0 89 && Caffe::root_solver()) || 90 (request == SolverAction::SNAPSHOT)) { 91 Snapshot(); 92 } 93 // 如果request為`STOP`則修改`requested_early_exit_`為true,之后就會提前結束迭代 94 if (SolverAction::STOP == request) { 95 requested_early_exit_ = true; 96 break; 97 } 98 } 99 }SGDSolver::ApplyUpdate具體實現
每一組網絡中的參數的更新都是在不同類型的Solver自己實現的ApplyUpdate函數中完成的,下面我們就以最常用的SGD為例子來分析這個函數具體的功能:
1 template <typename Dtype> 2 void SGDSolver<Dtype>::ApplyUpdate() { 3 CHECK(Caffe::root_solver()); 4 // GetLearningRate根據設置的lr_policy來計算當前迭代的learning rate的值 5 Dtype rate = GetLearningRate(); 6 // 判斷是否需要輸出當前的learning rate 7 if (this->param_.display() && this->iter_ % this->param_.display() == 0) { 8 LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; 9 } 10 // 避免梯度爆炸,如果梯度的二范數超過了某個數值則進行scale操作,將梯度減小 11 ClipGradients(); 12 // 對所有可更新的網絡參數進行操作 13 for (int param_id = 0; param_id < this->net_->learnable_params().size(); 14 ++param_id) { 15 // 將第param_id個參數的梯度除以iter_size,這一步的作用是保證實際的batch_size=iter_size*設置的batch_size 16 Normalize(param_id); 17 // 將正則化部分的梯度降入到每個參數的梯度中 18 Regularize(param_id); 19 // 計算SGD算法的梯度(momentum等) 20 ComputeUpdateValue(param_id, rate); 21 } 22 // 調用`Net::Update`更新所有的參數 23 this->net_->Update(); 24 }下面我們繼續具體分析一下Normalize/Regularize/ComputeUpdateValue的實現,我們均以CPU的代碼為例子,GPU部分的處理原理是一樣的:
Normalize
1 template <typename Dtype> 2 void SGDSolver<Dtype>::Normalize(int param_id) { 3 // 如果iter_size的值為1,則不需要任何處理直接return 4 if (this->param_.iter_size() == 1) { return; } 5 // 通過net_返回所有可以學習的參數,是一個vector<shared_ptr<Blob<Dtype> > > 6 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); 7 // 要乘以的系數等于1/iter_size 8 const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size(); 9 switch (Caffe::mode()) { 10 case Caffe::CPU: { 11 // caffe_scal在/CAFFE_ROOT/src/caffe/util/math_functions.cpp中 12 // 是blas的scale函數的一個封裝,第一個參數是數據的個數,第二個參數是乘以的系數, 13 // 第三個參數是數據的指針 14 caffe_scal(net_params[param_id]->count(), accum_normalization, 15 net_params[param_id]->mutable_cpu_diff()); 16 break; 17 } 18 case Caffe::GPU: { 19 // GPU代碼略 20 } 21 }Regularize
1 template <typename Dtype> 2 void SGDSolver<Dtype>::Regularize(int param_id) { 3 // 獲取所有可以學習的參數的vector 4 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); 5 // 獲取所有的參數對應的weight_decay的vector 6 const vector<float>& net_params_weight_decay = 7 this->net_->params_weight_decay(); 8 // 模型整體的weight_decay數值 9 Dtype weight_decay = this->param_.weight_decay(); 10 // 獲取正則化的類型:L1 或 L2 11 string regularization_type = this->param_.regularization_type(); 12 // 實際的weight_decay等于整體模型的數值乘以具體每個參數的數值 13 Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; 14 switch (Caffe::mode()) { 15 case Caffe::CPU: { 16 // 如果weight_decay不為0,則計算 17 if (local_decay) { 18 if (regularization_type == "L2") { 19 // L2的梯度為diff_ = weight_decay*data_ + diff_ 20 // caffe_axpy的功能是 y = a*x + y 21 // 第一個參數是數據的個數,第二個是上式的a,第三個是x的指針,第四個是y的指針 22 caffe_axpy(net_params[param_id]->count(), 23 local_decay, 24 net_params[param_id]->cpu_data(), 25 net_params[param_id]->mutable_cpu_diff()); 26 } else if (regularization_type == "L1") { 27 // L1的梯度為diff_ = diff_ + sign(data_) 28 // temp_ = sign(data_) 29 caffe_cpu_sign(net_params[param_id]->count(), 30 net_params[param_id]->cpu_data(), 31 temp_[param_id]->mutable_cpu_data()); 32 // 將temp_加到diff_中 diff_ = weight_decay*temp_ + diff_ 33 caffe_axpy(net_params[param_id]->count(), 34 local_decay, 35 temp_[param_id]->cpu_data(), 36 net_params[param_id]->mutable_cpu_diff()); 37 } else { 38 LOG(FATAL) << "Unknown regularization type: " << regularization_type; 39 } 40 } 41 break; 42 } 43 // GPU代碼略 44 }ComputeUpdatedValue
1 template <typename Dtype> 2 void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) { 3 // 獲取所有可以更新的參數的vector 4 const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params(); 5 // 獲取所有參數對應的learning_rate的vector 6 const vector<float>& net_params_lr = this->net_->params_lr(); 7 // 獲取momentum數值 8 Dtype momentum = this->param_.momentum(); 9 // 實際的learning_rate為全局的learning_rate乘以每個參數對應的learning_rate 10 Dtype local_rate = rate * net_params_lr[param_id]; 11 switch (Caffe::mode()) { 12 case Caffe::CPU: { 13 // 關于SGD的公式參考caffe官網tutorial的Solver部分 14 // history_存儲了上一次的梯度,下面這個函數: 15 // history_ = learning_rate*diff_ + momentum*history 16 caffe_cpu_axpby(net_params[param_id]->count(), local_rate, 17 net_params[param_id]->cpu_diff(), momentum, 18 history_[param_id]->mutable_cpu_data()); 19 // 把當前的梯度拷貝給參數Blob的diff_ 20 caffe_copy(net_params[param_id]->count(), 21 history_[param_id]->cpu_data(), 22 net_params[param_id]->mutable_cpu_diff()); 23 break; 24 } 25 case Caffe::GPU: { 26 // GPU代碼略 27 } 28 }至此Solver主要的代碼都已經分析完了,總結起來主要有:(1)solver_factory的register和create不同類型Solver的機制,(2)通過signal_handler來獲取系統信號,并根據用戶或默認的設置進行相應的處理,(3)Solver::Solve函數的具體實現的分析,(4)SGDSolver::ApplyUpdate函數的具體實現。前面三個部分都屬于基類的,最后一個是SGDSolver這個子類的,如果用戶想要實現自己的Solver類,也應該類似地去繼承基類,并實現自己的ApplyUpdate函數,在代碼的末尾通過register宏完成注冊,便可以被成功的調用。
在train()中的solver->Solve()中的Step()中的ForwardBackward()中進行的各個layers的計算
總結
- 上一篇: 刀剑物语1.41攻略
- 下一篇: 手机闹钟怎么设置?怎么设置手机闹铃