梳理caffe代码blob(三)

梳理caffe代码blob(三)

贯穿整个caffe的就是数据blob:

 

[cpp] view plain copy
 
 在CODE上查看代码片派生到我的代码片
  1. #ifndef CAFFE_BLOB_HPP_  
  2. #define CAFFE_BLOB_HPP_  
  3.   
  4. #include <algorithm>  
  5. #include <string>  
  6. #include <vector>  
  7.   
  8. #include "caffe/common.hpp"  
  9. #include "caffe/proto/caffe.pb.h"  
  10. #include "caffe/syncedmem.hpp"  
  11. #include "caffe/util/math_functions.hpp"  
  12.   
  13. const int kMaxBlobAxes = INT_MAX;  
  14.   
  15. namespace caffe {  
  16.   
  17. /** 
  18.  * @brief A wrapper around SyncedMemory holders serving as the basic 
  19.  *        computational unit through which Layer%s, Net%s, and Solver%s 
  20.  *        interact. 
  21.  * 
  22.  * TODO(dox): more thorough description. 
  23.  */  
  24.   
  25.   
  26. template <typename Dtype>  
  27. class Blob {  
  28.  public:  
  29.   Blob()  
  30.        : data_(), diff_(), count_(0), capacity_(0) {}  
  31.   
  32.   /// @brief Deprecated; use <code>Blob(const vector<int>& shape)</code>.  
  33.   //explicit关键字的作用是禁止单参数构造函数的隐式转换  
  34.   explicit Blob(const int num, const int channels, const int height,  
  35.       const int width);  
  36.   explicit Blob(const vector<int>& shape);  
  37.   
  38.   /// @brief Deprecated; use <code>Reshape(const vector<int>& shape)</code>.  
  39. /* 
  40. Reshape函数将num,channels,height,width传递给vector shape_  
  41. */  
  42.   void Reshape(const int num, const int channels, const int height,  
  43.       const int width);  
  44.  /** 
  45.  *Blob作为一个最基础的类,其中构造函数开辟一个内存空间来存储数据,Reshape函数在Layer中的 
  46.  *reshape或者forward 操作中来adjust the dimensions of a top blob。同时在改变Blob大小时, 
  47.  *内存将会被重新分配如果内存大小不够了,并且额外的内存将不会被释放。对input的blob进行reshape, 
  48.  *如果立马调用Net::Backward是会出错的,因为reshape之后,要么Net::forward或者Net::Reshape就会 
  49.  *被调用来将新的input shape 传播到高层 
  50.  */  
  51.   //根据shape来初始化shape_和shape_data_,以及为data_ 和diff_ 分配空间。   
  52.   void Reshape(const vector<int>& shape);  
  53.   void Reshape(const BlobShape& shape);  
  54.   void ReshapeLike(const Blob& other);  
  55.   //iniline主要是将代码进行复制,扩充,会使代码总量上升,好处就是可以节省调用的开销,以string形式获取shape_  
  56.   inline string shape_string() const {  
  57.     ostringstream stream;  
  58.     for (int i = 0; i < shape_.size(); ++i) {  
  59.       stream << shape_[i] << " ";  
  60.     }  
  61.     stream << "(" << count_ << ")";  
  62.     return stream.str();  
  63.   }  
  64. //获取shape_  
  65.   inline const vector<int>& shape() const { return shape_; }  
  66.   /** 
  67.    * @brief Returns the dimension of the index-th axis (or the negative index-th 
  68.    *        axis from the end, if index is negative). 
  69.    * 
  70.    * @param index the axis index, which may be negative as it will be 
  71.    *        "canonicalized" using CanonicalAxisIndex. 
  72.    *        Dies on out of range index. 
  73.    */  
  74. //获取index维的大小  
  75.   inline int shape(int index) const {  
  76.     return shape_[CanonicalAxisIndex(index)];  
  77.   }  
  78. //获取维的个数  
  79.   inline int num_axes() const { return shape_.size(); }  
  80. //获取当前data的大小  
  81.   inline int count() const { return count_; }  
  82.   
  83.   /** 
  84.    * @brief Compute the volume of a slice; i.e., the product of dimensions 
  85.    *        among a range of axes. 
  86.    * 
  87.    * @param start_axis The first axis to include in the slice. 
  88.    * 
  89.    * @param end_axis The first axis to exclude from the slice. 
  90.    */  
  91. /*多个count()函数,主要还是为了统计Blob的容量(volume),或者是某一片(slice), 
  92. 从某个axis到具体某个axis的shape乘积。 
  93. */  
  94. //获取某几维数据的大小  
  95.   inline int count(int start_axis, int end_axis) const {  
  96.     CHECK_LE(start_axis, end_axis);  
  97.     CHECK_GE(start_axis, 0);  
  98.     CHECK_GE(end_axis, 0);  
  99.     CHECK_LE(start_axis, num_axes());  
  100.     CHECK_LE(end_axis, num_axes());  
  101.     int count = 1;  
  102.     for (int i = start_axis; i < end_axis; ++i) {  
  103.       count *= shape(i);  
  104.     }  
  105.     return count;  
  106.   }  
  107.   /** 
  108.    * @brief Compute the volume of a slice spanning from a particular first 
  109.    *        axis to the final axis. 
  110.    * 
  111.    * @param start_axis The first axis to include in the slice. 
  112.    */  
  113. //获取某一维到结束数据的大小  
  114.   inline int count(int start_axis) const {  
  115.     return count(start_axis, num_axes());  
  116.   }  
  117.   
  118.   /** 
  119.    * @brief Returns the 'canonical' version of a (usually) user-specified axis, 
  120.    *        allowing for negative indexing (e.g., -1 for the last axis). 
  121.    * 
  122.    * @param index the axis index. 
  123.    *        If 0 <= index < num_axes(), return index. 
  124.    *        If -num_axes <= index <= -1, return (num_axes() - (-index)), 
  125.    *        e.g., the last axis index (num_axes() - 1) if index == -1, 
  126.    *        the second to last if index == -2, etc. 
  127.    *        Dies on out of range index. 
  128.    */  
  129.   //Blob的Index是可以从负坐标开始读的,标准化索引,主要是对参数索引进行标准化,以满足要求  
  130.   inline int CanonicalAxisIndex(int axis_index) const {  
  131.     CHECK_GE(axis_index, -num_axes())  
  132.         << "axis " << axis_index << " out of range for " << num_axes()  
  133.         << "-D Blob with shape " << shape_string();  
  134.     CHECK_LT(axis_index, num_axes())  
  135.         << "axis " << axis_index << " out of range for " << num_axes()  
  136.         << "-D Blob with shape " << shape_string();  
  137.     if (axis_index < 0) {  
  138.       return axis_index + num_axes();  
  139.     }  
  140.     return axis_index;  
  141.   }  
  142.   //Blob中的4个基本变量num,channel,height,width可以直接通过shape(0),shape(1),shape(2),shape(3)来访问  
  143.   /// @brief Deprecated legacy shape accessor num: use shape(0) instead.  
  144.   inline int num() const { return LegacyShape(0); }  
  145.   /// @brief Deprecated legacy shape accessor channels: use shape(1) instead.  
  146.   inline int channels() const { return LegacyShape(1); }  
  147.   /// @brief Deprecated legacy shape accessor height: use shape(2) instead.  
  148.   inline int height() const { return LegacyShape(2); }  
  149.   /// @brief Deprecated legacy shape accessor width: use shape(3) instead.  
  150.   inline int width() const { return LegacyShape(3); }  
  151. //data_维数不大于4时才能使用,功能同shape()类似。  
  152.   inline int LegacyShape(int index) const {  
  153.     CHECK_LE(num_axes(), 4)  
  154.         << "Cannot use legacy accessors on Blobs with > 4 axes.";  
  155.     CHECK_LT(index, 4);  
  156.     CHECK_GE(index, -4);  
  157.     if (index >= num_axes() || index < -num_axes()) {  
  158.       // Axis is out of range, but still in [0, 3] (or [-4, -1] for reverse  
  159.       // indexing) -- this special case simulates the one-padding used to fill  
  160.       // extraneous axes of legacy blobs.  
  161.       return 1;  
  162.     }  
  163.     return shape(index);  
  164.   }  
  165.   //计算offset,offset计算的方式也支持两种方式,一种直接指定n,c,h,w或者放到一个vector中进行计算,  
  166.   //偏差是根据对应的n,c,h,w,返回的offset是((n*channels()+c)*height()+h)*width()+w  
  167.   inline int offset(const int n, const int c = 0, const int h = 0,  
  168.       const int w = 0) const {  
  169.     CHECK_GE(n, 0);  
  170.     CHECK_LE(n, num());  
  171.     CHECK_GE(channels(), 0);  
  172.     CHECK_LE(c, channels());  
  173.     CHECK_GE(height(), 0);  
  174.     CHECK_LE(h, height());  
  175.     CHECK_GE(width(), 0);  
  176.     CHECK_LE(w, width());  
  177.     return ((n * channels() + c) * height() + h) * width() + w;  
  178.   }  
  179.   
  180.   inline int offset(const vector<int>& indices) const {  
  181.     CHECK_LE(indices.size(), num_axes());  
  182.     int offset = 0;  
  183.     for (int i = 0; i < num_axes(); ++i) {  
  184.       offset *= shape(i);  
  185.       if (indices.size() > i) {  
  186.         CHECK_GE(indices[i], 0);  
  187.         CHECK_LT(indices[i], shape(i));  
  188.         offset += indices[i];  
  189.       }  
  190.     }  
  191.     return offset;  
  192.   }  
  193.   /** 
  194.    * @brief Copy from a source Blob. 
  195.    * 
  196.    * @param source the Blob to copy from 
  197.    * @param copy_diff if false, copy the data; if true, copy the diff 
  198.    * @param reshape if false, require this Blob to be pre-shaped to the shape 
  199.    *        of other (and die otherwise); if true, Reshape this Blob to other's 
  200.    *        shape if necessary 
  201.    */  
  202.   //一个blob中copy数据 ,通过开关控制是否copy_diff,如果是False则copy data。reshape控制是否需要reshape  
  203.   void CopyFrom(const Blob<Dtype>& source, bool copy_diff = false,  
  204.       bool reshape = false);  
  205. /*这一部分函数主要通过给定的位置访问数据,根据位置计算与数据起始 
  206.   的偏差offset,在通过cpu_data*指针获得地址 
  207. */  
  208. //获取某位置的data_数据  
  209.   inline Dtype data_at(const int n, const int c, const int h,  
  210.       const int w) const {  
  211.     return cpu_data()[offset(n, c, h, w)];  
  212.   }  
  213. //获取某位置的diff_数据  
  214.   inline Dtype diff_at(const int n, const int c, const int h,  
  215.       const int w) const {  
  216.     return cpu_diff()[offset(n, c, h, w)];  
  217.   }  
  218.   
  219.   inline Dtype data_at(const vector<int>& index) const {  
  220.     return cpu_data()[offset(index)];  
  221.   }  
  222.   
  223.   inline Dtype diff_at(const vector<int>& index) const {  
  224.     return cpu_diff()[offset(index)];  
  225.   }  
  226. //获取data_  
  227.   inline const shared_ptr<SyncedMemory>& data() const {  
  228.     CHECK(data_);  
  229.     return data_;  
  230.   }  
  231. //获取diff_  
  232.   inline const shared_ptr<SyncedMemory>& diff() const {  
  233.     CHECK(diff_);  
  234.     return diff_;  
  235.   }  
  236.   //这里有data和diff两类数据,而这个diff就是我们所熟知的偏差,前者主要存储  
  237.   //前向传递的数据,而后者存储的是反向传播中的梯度  
  238.   const Dtype* cpu_data() const;//获取data_ cpu指针  
  239.   void set_cpu_data(Dtype* data);//设置data_的cpu指针,只是修改了指针  
  240.   const Dtype* gpu_data() const;//获取data_的gpu指针  
  241.   const Dtype* cpu_diff() const;//获取diff_的cpu指针  
  242.   const Dtype* gpu_diff() const;//获取diff_的gpu指针  
  243.   Dtype* mutable_cpu_data();//见SyncedMemory的mutable_cpu_data();  
  244.   Dtype* mutable_gpu_data();//见SyncedMemory的mutable_gpu_data();  
  245.   Dtype* mutable_cpu_diff();//见SyncedMemory的mutable_cpu_data();  
  246.   Dtype* mutable_gpu_diff();//见SyncedMemory的mutable_gpu_data();  
  247.   //更新data_的数据,减去diff_的数据  
  248.   void Update();  
  249. /* 
  250. 其中用到math_functions.hpp中的函数caffe_axpy(),该函数封装了cblas_saxpy,实现的是Y=alpha*X+Y。 
  251. 由此,知该函数的功能是data_=(data_-diff_)。另外,该函数只实现了对double和float型数据, 
  252. 对于unsigned int和int由于该函数主要是在Net中被调用,只有Blob<float>和Blob<double>型式, 
  253. 因此没有定义unsigned int和int。 
  254. */  
  255.   void FromProto(const BlobProto& proto, bool reshape = true);  
  256. /* 
  257. 由BlobProto对Blob进行赋值操作。reshape代表是否允许修改shape_的大小。 
  258. 需要注意的是再这里有double和float两种类型的数据 ,在代码中可以看到具体的体现 
  259. */  
  260.   void ToProto(BlobProto* proto, bool write_diff = false) const;  
  261.   
  262.   /// @brief Compute the sum of absolute values (L1 norm) of the data.  
  263. /* 
  264. 功能:计算L1范数 
  265. 说明:其中用到了math_function.hpp中的函数caffe_cpu_asum()和caffe_gpu_asum,实现的功能是对向量X求其每个元素绝对值的和,不同的是X分别在cpu和gpu中。 
  266. */  
  267.   Dtype asum_data() const;  
  268.   /// @brief Compute the sum of absolute values (L1 norm) of the diff.  
  269.   Dtype asum_diff() const;  
  270.   /// @brief Compute the sum of squares (L2 norm squared) of the data.  
  271. /* 
  272. 功能:计算L2范数。 
  273. 说明:用到了math_function.hpp中的caffe_cpu_dot(),caffe_cpu_strided_dot(),caffe_gpu_dot(), caffe_gpu_strided_dot()。具体就是就向量X的平方和。 
  274. */  
  275.   Dtype sumsq_data() const;  
  276.   /// @brief Compute the sum of squares (L2 norm squared) of the diff.  
  277.   Dtype sumsq_diff() const;  
  278.   
  279.   /// @brief Scale the blob data by a constant factor.  
  280. /* 
  281. 功能:正规化data_。 
  282. 说明:用到math_function.hpp中的caffe_scal()和caffe_gpu_scal()函数,就是对向量X乘上一个因子。 
  283. */  
  284.   void scale_data(Dtype scale_factor);  
  285.   /// @brief Scale the blob diff by a constant factor.  
  286.   void scale_diff(Dtype scale_factor);  
  287.   
  288.   /** 
  289.    * @brief Set the data_ shared_ptr to point to the SyncedMemory holding the 
  290.    *        data_ of Blob other -- useful in Layer%s which simply perform a copy 
  291.    *        in their Forward pass. 
  292.    * 
  293.    * This deallocates the SyncedMemory holding this Blob's data_, as 
  294.    * shared_ptr calls its destructor when reset with the "=" operator. 
  295.    */  
  296.   void ShareData(const Blob& other);//本Blob共享other的data_  
  297.   /** 
  298.    * @brief Set the diff_ shared_ptr to point to the SyncedMemory holding the 
  299.    *        diff_ of Blob other -- useful in Layer%s which simply perform a copy 
  300.    *        in their Forward pass. 
  301.    * 
  302.    * This deallocates the SyncedMemory holding this Blob's diff_, as 
  303.    * shared_ptr calls its destructor when reset with the "=" operator. 
  304.    */  
  305.   void ShareDiff(const Blob& other);//本Blob共享other的diff_  
  306.   
  307.   bool ShapeEquals(const BlobProto& other);//判断other与本Blob形状是否相同。  
  308.   
  309.  protected:  
  310. //data_指针,指针类型是shared_ptr,属于boost库的一个智能指针,这一部分主要用来申请内存存储data,data主要是正向传播的时候用的  
  311.   shared_ptr<SyncedMemory> data_;  
  312. //diff_主要用来存储偏差,update data  
  313.   shared_ptr<SyncedMemory> diff_;  
  314. //shape_存储Blob的形状  
  315.   vector<int> shape_;  
  316. //count_表示Blob中的元素个数,也就是个数*通道数*高度*宽度  
  317.   int count_;  
  318. //capacity表示当前的元素个数,因为Blob可能会reshape  
  319.   int capacity_;  
  320.   
  321.   DISABLE_COPY_AND_ASSIGN(Blob);  
  322. };  // class Blob  
  323.   
  324. }  // namespace caffe  
  325.   
  326. #endif  // CAFFE_BLOB_HPP_  

顺便将实现部分也贴出来,方便对照:

[cpp] view plain copy
 
 在CODE上查看代码片派生到我的代码片
    1. #include <climits>  
    2. #include <vector>  
    3.   
    4. #include "caffe/blob.hpp"  
    5. #include "caffe/common.hpp"  
    6. #include "caffe/syncedmem.hpp"  
    7. #include "caffe/util/math_functions.hpp"  
    8.   
    9. namespace caffe {  
    10.   
    11. template <typename Dtype>  
    12. //该函数将num,channels,height,width传递给vector shape_   
    13. void Blob<Dtype>::Reshape(const int num, const int channels, const int height,  
    14.     const int width) {  
    15.   vector<int> shape(4);  
    16.   shape[0] = num;  
    17.   shape[1] = channels;  
    18.   shape[2] = height;  
    19.   shape[3] = width;  
    20.   Reshape(shape);  
    21. }  
    22.   
    23. template <typename Dtype>  
    24. void Blob<Dtype>::Reshape(const vector<int>& shape) {  
    25.   CHECK_LE(shape.size(), kMaxBlobAxes);  
    26.   count_ = 1;  
    27.   shape_.resize(shape.size());//重新定义vector shape_ 的size  
    28.   for (int i = 0; i < shape.size(); ++i) {  
    29.     CHECK_GE(shape[i], 0);//确保shape 每个元素为正数  
    30.     CHECK_LE(shape[i], INT_MAX / count_) << "blob size exceeds INT_MAX";  
    31.     count_ *= shape[i];  
    32.     shape_[i] = shape[i];  
    33.   }  
    34.   //由于count_超过了当前capacity_ 因此需要重新分配内存空间  
    35.   if (count_ > capacity_) {  
    36.     capacity_ = count_;  
    37.     data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));  
    38.     diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));  
    39.   }  
    40. }  
    41.   
    42. template <typename Dtype>// BlobShape 在caffe.proto 中定义  
    43. void Blob<Dtype>::Reshape(const BlobShape& shape) {  
    44.   CHECK_LE(shape.dim_size(), kMaxBlobAxes);  
    45.   vector<int> shape_vec(shape.dim_size());  
    46.   for (int i = 0; i < shape.dim_size(); ++i) {  
    47.     shape_vec[i] = shape.dim(i);//dim 包含num,channels,height, width  
    48.   }  
    49.   Reshape(shape_vec);//用protobuf传递来dim 对shape_ 进行reshape  
    50. }  
    51. //用已知的Blob的shape来对shape_ 进行reshape  
    52. template <typename Dtype>  
    53. void Blob<Dtype>::ReshapeLike(const Blob<Dtype>& other) {  
    54.   Reshape(other.shape());  
    55. }  
    56. //用num,channels,height, width 初始化  
    57. template <typename Dtype>  
    58. Blob<Dtype>::Blob(const int num, const int channels, const int height,  
    59.     const int width)  
    60.   // capacity_ must be initialized before calling Reshape  
    61.   : capacity_(0) {  
    62.   Reshape(num, channels, height, width);  
    63. }  
    64. //用shape 初始化  
    65. template <typename Dtype>  
    66. Blob<Dtype>::Blob(const vector<int>& shape)  
    67.   // capacity_ must be initialized before calling Reshape  
    68.   : capacity_(0) {  
    69.   Reshape(shape);  
    70. }  
    71. //返回cpu 中的数据  
    72. template <typename Dtype>  
    73. const Dtype* Blob<Dtype>::cpu_data() const {  
    74.   CHECK(data_);  
    75.   return (const Dtype*)data_->cpu_data();  
    76. }  
    77. // 清空cpu 数据  
    78. template <typename Dtype>  
    79. void Blob<Dtype>::set_cpu_data(Dtype* data) {  
    80.   CHECK(data);  
    81.   data_->set_cpu_data(data);  
    82. }  
    83. //返回gpu 中的数据  
    84. template <typename Dtype>  
    85. const Dtype* Blob<Dtype>::gpu_data() const {  
    86.   CHECK(data_);  
    87.   return (const Dtype*)data_->gpu_data();  
    88. }  
    89. //反向传播导数diff_ 操作函数,返回cpu 中的数据  
    90. template <typename Dtype>  
    91. const Dtype* Blob<Dtype>::cpu_diff() const {  
    92.   CHECK(diff_);  
    93.   return (const Dtype*)diff_->cpu_data();  
    94. }  
    95. //返回gpu 中的数据  
    96. template <typename Dtype>  
    97. const Dtype* Blob<Dtype>::gpu_diff() const {  
    98.   CHECK(diff_);  
    99.   return (const Dtype*)diff_->gpu_data();  
    100. }  
    101.   
    102. template <typename Dtype>  
    103. Dtype* Blob<Dtype>::mutable_cpu_data() {  
    104.   CHECK(data_);  
    105.   return static_cast<Dtype*>(data_->mutable_cpu_data());  
    106. }  
    107.   
    108. template <typename Dtype>  
    109. Dtype* Blob<Dtype>::mutable_gpu_data() {  
    110.   CHECK(data_);  
    111.   return static_cast<Dtype*>(data_->mutable_gpu_data());  
    112. }  
    113.   
    114. template <typename Dtype>  
    115. Dtype* Blob<Dtype>::mutable_cpu_diff() {  
    116.   CHECK(diff_);  
    117.   return static_cast<Dtype*>(diff_->mutable_cpu_data());  
    118. }  
    119.   
    120. template <typename Dtype>  
    121. Dtype* Blob<Dtype>::mutable_gpu_diff() {  
    122.   CHECK(diff_);  
    123.   return static_cast<Dtype*>(diff_->mutable_gpu_data());  
    124. }  
    125. //当前的blob 的data_ 指向已知blob的数据  
    126. template <typename Dtype>  
    127. void Blob<Dtype>::ShareData(const Blob& other) {  
    128.   CHECK_EQ(count_, other.count());  
    129.   data_ = other.data();  
    130. }  
    131. //当前的blob 的diff_ 指向已知blob的反向传播导数  
    132. template <typename Dtype>  
    133. void Blob<Dtype>::ShareDiff(const Blob& other) {  
    134.   CHECK_EQ(count_, other.count());  
    135.   diff_ = other.diff();  
    136. }  
    137.   
    138. // The "update" method is used for parameter blobs in a Net, which are stored  
    139. // as Blob<float> or Blob<double> -- hence we do not define it for  
    140. // Blob<int> or Blob<unsigned int>.  
    141. template <> void Blob<unsigned int>::Update() { NOT_IMPLEMENTED; }  
    142. template <> void Blob<int>::Update() { NOT_IMPLEMENTED; }  
    143. //Updata函数用于参数blob的更新(weight,bias 等减去对应的导数)  
    144. template <typename Dtype>  
    145. void Blob<Dtype>::Update() {  
    146.   // We will perform update based on where the data is located.  
    147.   switch (data_->head()) {  
    148.   case SyncedMemory::HEAD_AT_CPU://数据在cpu上,则在cpu上进行计算  
    149.     // perform computation on CPU  
    150.     caffe_axpy<Dtype>(count_, Dtype(-1),  
    151.         static_cast<const Dtype*>(diff_->cpu_data()),  
    152.         static_cast<Dtype*>(data_->mutable_cpu_data()));  
    153.     break;  
    154.   case SyncedMemory::HEAD_AT_GPU:  
    155.   case SyncedMemory::SYNCED:  
    156. #ifndef CPU_ONLY//如果没有定义CPU_ONLY,且数据在gpu上,则在gpu上进行计算  
    157.     // perform computation on GPU  
    158.     caffe_gpu_axpy<Dtype>(count_, Dtype(-1),  
    159.         static_cast<const Dtype*>(diff_->gpu_data()),  
    160.         static_cast<Dtype*>(data_->mutable_gpu_data()));  
    161. #else  
    162.     NO_GPU;  
    163. #endif  
    164.     break;  
    165.   default:  
    166.     LOG(FATAL) << "Syncedmem not initialized.";  
    167.   }  
    168. }  
    169.   
    170. template <> unsigned int Blob<unsigned int>::asum_data() const {  
    171.   NOT_IMPLEMENTED;  
    172.   return 0;  
    173. }  
    174.   
    175. template <> int Blob<int>::asum_data() const {  
    176.   NOT_IMPLEMENTED;  
    177.   return 0;  
    178. }  
    179. //返回data_ 中所有 element 的绝对值之和  
    180. template <typename Dtype>  
    181. Dtype Blob<Dtype>::asum_data() const {  
    182.   if (!data_) { return 0; }  
    183.   switch (data_->head()) {  
    184.   case SyncedMemory::HEAD_AT_CPU:  
    185.     return caffe_cpu_asum(count_, cpu_data());  
    186.   case SyncedMemory::HEAD_AT_GPU:  
    187.   case SyncedMemory::SYNCED:  
    188. #ifndef CPU_ONLY  
    189.   {  
    190.     Dtype asum;  
    191.     caffe_gpu_asum(count_, gpu_data(), &asum);  
    192.     return asum;  
    193.   }  
    194. #else  
    195.     NO_GPU;  
    196. #endif  
    197.   case SyncedMemory::UNINITIALIZED:  
    198.     return 0;  
    199.   default:  
    200.     LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head();  
    201.   }  
    202.   return 0;  
    203. }  
    204.   
    205. template <> unsigned int Blob<unsigned int>::asum_diff() const {  
    206.   NOT_IMPLEMENTED;  
    207.   return 0;  
    208. }  
    209.   
    210. template <> int Blob<int>::asum_diff() const {  
    211.   NOT_IMPLEMENTED;  
    212.   return 0;  
    213. }  
    214. //返回diff_ 中所有 element 的绝对值之和  
    215. template <typename Dtype>  
    216. Dtype Blob<Dtype>::asum_diff() const {  
    217.   if (!diff_) { return 0; }  
    218.   switch (diff_->head()) {  
    219.   case SyncedMemory::HEAD_AT_CPU:  
    220.     return caffe_cpu_asum(count_, cpu_diff());  
    221.   case SyncedMemory::HEAD_AT_GPU:  
    222.   case SyncedMemory::SYNCED:  
    223. #ifndef CPU_ONLY  
    224.   {  
    225.     Dtype asum;  
    226.     caffe_gpu_asum(count_, gpu_diff(), &asum);  
    227.     return asum;  
    228.   }  
    229. #else  
    230.     NO_GPU;  
    231. #endif  
    232.   case SyncedMemory::UNINITIALIZED:  
    233.     return 0;  
    234.   default:  
    235.     LOG(FATAL) << "Unknown SyncedMemory head state: " << diff_->head();  
    236.   }  
    237.   return 0;  
    238. }  
    239.   
    240. template <> unsigned int Blob<unsigned int>::sumsq_data() const {  
    241.   NOT_IMPLEMENTED;  
    242.   return 0;  
    243. }  
    244.   
    245. template <> int Blob<int>::sumsq_data() const {  
    246.   NOT_IMPLEMENTED;  
    247.   return 0;  
    248. }  
    249. //返回 data_ 中所有 element 的平方和  
    250. template <typename Dtype>  
    251. Dtype Blob<Dtype>::sumsq_data() const {  
    252.   Dtype sumsq;  
    253.   const Dtype* data;  
    254.   if (!data_) { return 0; }  
    255.   switch (data_->head()) {  
    256.   case SyncedMemory::HEAD_AT_CPU:  
    257.     data = cpu_data();  
    258.     sumsq = caffe_cpu_dot(count_, data, data);  
    259.     break;  
    260.   case SyncedMemory::HEAD_AT_GPU:  
    261.   case SyncedMemory::SYNCED:  
    262. #ifndef CPU_ONLY  
    263.     data = gpu_data();  
    264.     caffe_gpu_dot(count_, data, data, &sumsq);  
    265. #else  
    266.     NO_GPU;  
    267. #endif  
    268.     break;  
    269.   case SyncedMemory::UNINITIALIZED:  
    270.     return 0;  
    271.   default:  
    272.     LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head();  
    273.   }  
    274.   return sumsq;  
    275. }  
    276.   
    277. template <> unsigned int Blob<unsigned int>::sumsq_diff() const {  
    278.   NOT_IMPLEMENTED;  
    279.   return 0;  
    280. }  
    281.   
    282. template <> int Blob<int>::sumsq_diff() const {  
    283.   NOT_IMPLEMENTED;  
    284.   return 0;  
    285. }  
    286. //返回 diff_ 中所有 element 的平方和  
    287. template <typename Dtype>  
    288. Dtype Blob<Dtype>::sumsq_diff() const {  
    289.   Dtype sumsq;  
    290.   const Dtype* diff;  
    291.   if (!diff_) { return 0; }  
    292.   switch (diff_->head()) {  
    293.   case SyncedMemory::HEAD_AT_CPU:  
    294.     diff = cpu_diff();  
    295.     sumsq = caffe_cpu_dot(count_, diff, diff);  
    296.     break;  
    297.   case SyncedMemory::HEAD_AT_GPU:  
    298.   case SyncedMemory::SYNCED:  
    299. #ifndef CPU_ONLY  
    300.     diff = gpu_diff();  
    301.     caffe_gpu_dot(count_, diff, diff, &sumsq);  
    302.     break;  
    303. #else  
    304.     NO_GPU;  
    305. #endif  
    306.   case SyncedMemory::UNINITIALIZED:  
    307.     return 0;  
    308.   default:  
    309.     LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head();  
    310.   }  
    311.   return sumsq;  
    312. }  
    313.   
    314. template <> void Blob<unsigned int>::scale_data(unsigned int scale_factor) {  
    315.   NOT_IMPLEMENTED;  
    316. }  
    317.   
    318. template <> void Blob<int>::scale_data(int scale_factor) {  
    319.   NOT_IMPLEMENTED;  
    320. }  
    321. // 给data乘以scale_factor  
    322. template <typename Dtype>  
    323. void Blob<Dtype>::scale_data(Dtype scale_factor) {  
    324.   Dtype* data;  
    325.   if (!data_) { return; }  
    326.   switch (data_->head()) {  
    327.   case SyncedMemory::HEAD_AT_CPU:  
    328.     data = mutable_cpu_data();  
    329.     caffe_scal(count_, scale_factor, data);  
    330.     return;  
    331.   case SyncedMemory::HEAD_AT_GPU:  
    332.   case SyncedMemory::SYNCED:  
    333. #ifndef CPU_ONLY  
    334.     data = mutable_gpu_data();  
    335.     caffe_gpu_scal(count_, scale_factor, data);  
    336.     return;  
    337. #else  
    338.     NO_GPU;  
    339. #endif  
    340.   case SyncedMemory::UNINITIALIZED:  
    341.     return;  
    342.   default:  
    343.     LOG(FATAL) << "Unknown SyncedMemory head state: " << data_->head();  
    344.   }  
    345. }  
    346.   
    347. template <> void Blob<unsigned int>::scale_diff(unsigned int scale_factor) {  
    348.   NOT_IMPLEMENTED;  
    349. }  
    350.   
    351. template <> void Blob<int>::scale_diff(int scale_factor) {  
    352.   NOT_IMPLEMENTED;  
    353. }  
    354. // 给diff乘以scale_factor  
    355. template <typename Dtype>  
    356. void Blob<Dtype>::scale_diff(Dtype scale_factor) {  
    357.   Dtype* diff;  
    358.   if (!diff_) { return; }  
    359.   switch (diff_->head()) {  
    360.   case SyncedMemory::HEAD_AT_CPU:  
    361.     diff = mutable_cpu_diff();  
    362.     caffe_scal(count_, scale_factor, diff);  
    363.     return;  
    364.   case SyncedMemory::HEAD_AT_GPU:  
    365.   case SyncedMemory::SYNCED:  
    366. #ifndef CPU_ONLY  
    367.     diff = mutable_gpu_diff();  
    368.     caffe_gpu_scal(count_, scale_factor, diff);  
    369.     return;  
    370. #else  
    371.     NO_GPU;  
    372. #endif  
    373.   case SyncedMemory::UNINITIALIZED:  
    374.     return;  
    375.   default:  
    376.     LOG(FATAL) << "Unknown SyncedMemory head state: " << diff_->head();  
    377.   }  
    378. }  
    379. //BlobProto 是定义在caffe.proto 中的一个message,其字段有 data,diff,shape,num,channels,height,width  
    380. template <typename Dtype>  
    381. bool Blob<Dtype>::ShapeEquals(const BlobProto& other) {  
    382.   if (other.has_num() || other.has_channels() ||  
    383.       other.has_height() || other.has_width()) {  
    384.     // Using deprecated 4D Blob dimensions --  
    385.     // shape is (num, channels, height, width).  
    386.     // Note: we do not use the normal Blob::num(), Blob::channels(), etc.  
    387.     // methods as these index from the beginning of the blob shape, where legacy  
    388.     // parameter blobs were indexed from the end of the blob shape (e.g., bias  
    389.     // Blob shape (1 x 1 x 1 x N), IP layer weight Blob shape (1 x 1 x M x N)).  
    390.     return shape_.size() <= 4 &&  
    391.            LegacyShape(-4) == other.num() &&  
    392.            LegacyShape(-3) == other.channels() &&  
    393.            LegacyShape(-2) == other.height() &&  
    394.            LegacyShape(-1) == other.width();  
    395.   }  
    396.   vector<int> other_shape(other.shape().dim_size());  
    397.   for (int i = 0; i < other.shape().dim_size(); ++i) {  
    398.     other_shape[i] = other.shape().dim(i);  
    399.   }  
    400.   return shape_ == other_shape;  
    401. }//检查当前的blob和已知的 other 的 shape 是否相同,相同返回true  
    402.   
    403. template <typename Dtype>  
    404. void Blob<Dtype>::CopyFrom(const Blob& source, bool copy_diff, bool reshape) {  
    405.   if (source.count() != count_ || source.shape() != shape_) {  
    406.     if (reshape) {  
    407.       ReshapeLike(source);  
    408.     } else {  
    409.       LOG(FATAL) << "Trying to copy blobs of different sizes.";  
    410.     }  
    411.   }  
    412.   switch (Caffe::mode()) {  
    413.   case Caffe::GPU:  
    414.     if (copy_diff) {  
    415.       caffe_copy(count_, source.gpu_diff(),  
    416.           static_cast<Dtype*>(diff_->mutable_gpu_data()));  
    417.     } else {  
    418.       caffe_copy(count_, source.gpu_data(),  
    419.           static_cast<Dtype*>(data_->mutable_gpu_data()));  
    420.     }  
    421.     break;  
    422.   case Caffe::CPU:  
    423.     if (copy_diff) {  
    424.       caffe_copy(count_, source.cpu_diff(),  
    425.           static_cast<Dtype*>(diff_->mutable_cpu_data()));  
    426.     } else {  
    427.       caffe_copy(count_, source.cpu_data(),  
    428.           static_cast<Dtype*>(data_->mutable_cpu_data()));  
    429.     }  
    430.     break;  
    431.   default:  
    432.     LOG(FATAL) << "Unknown caffe mode.";  
    433.   }  
    434. }//从source 拷贝数据,copy_diff控制是拷贝diff还是data  
    435.   
    436. template <typename Dtype>  
    437. void Blob<Dtype>::FromProto(const BlobProto& proto, bool reshape) {  
    438.   if (reshape) {  
    439.     vector<int> shape;  
    440.     if (proto.has_num() || proto.has_channels() ||  
    441.         proto.has_height() || proto.has_width()) {  
    442.       // Using deprecated 4D Blob dimensions --  
    443.       // shape is (num, channels, height, width).  
    444.       shape.resize(4);  
    445.       shape[0] = proto.num();  
    446.       shape[1] = proto.channels();  
    447.       shape[2] = proto.height();  
    448.       shape[3] = proto.width();  
    449.     } else {  
    450.       shape.resize(proto.shape().dim_size());  
    451.       for (int i = 0; i < proto.shape().dim_size(); ++i) {  
    452.         shape[i] = proto.shape().dim(i);  
    453.       }  
    454.     }  
    455.     Reshape(shape);  
    456.   } else {//如果不做reshape要求当前的blob的shape和proto传入的shape相同  
    457.     CHECK(ShapeEquals(proto)) << "shape mismatch (reshape not set)";  
    458.   }  
    459.   // copy data  
    460.   Dtype* data_vec = mutable_cpu_data();  
    461.   for (int i = 0; i < count_; ++i) {  
    462.     data_vec[i] = proto.data(i);  
    463.   }//将proto传入的data拷贝到cpu数据  
    464.   if (proto.diff_size() > 0) {  
    465.     Dtype* diff_vec = mutable_cpu_diff();  
    466.     for (int i = 0; i < count_; ++i) {  
    467.       diff_vec[i] = proto.diff(i);  
    468.     }//将proto传入的diff 拷贝到cpu数据  
    469.   }  
    470. }  
    471.   
    472. template <typename Dtype>  
    473. void Blob<Dtype>::ToProto(BlobProto* proto, bool write_diff) const {  
    474.   proto->clear_shape();  
    475.   for (int i = 0; i < shape_.size(); ++i) {  
    476.     proto->mutable_shape()->add_dim(shape_[i]);  
    477.   }  
    478.   proto->clear_data();  
    479.   proto->clear_diff();  
    480.   const Dtype* data_vec = cpu_data();  
    481.   for (int i = 0; i < count_; ++i) {  
    482.     proto->add_data(data_vec[i]);//将data写入proto  
    483.   }  
    484.   if (write_diff) {  
    485.     const Dtype* diff_vec = cpu_diff();  
    486.     for (int i = 0; i < count_; ++i) {  
    487.       proto->add_diff(diff_vec[i]);//将diff写入proto  
    488.     }  
    489.   }  
    490. }  
    491.   
    492. INSTANTIATE_CLASS(Blob);  
    493. template class Blob<int>;  
    494. template class Blob<unsigned int>;  
    495.   
    496. }  // namespace caffe  
posted @ 2016-03-31 15:26  菜鸡一枚  阅读(5187)  评论(0编辑  收藏  举报