基于查表法实现 Softmax 函数

1 简介

在深度学习领域,Softmax 函数是一种广泛应用的激活函数,尤其在多分类问题中表现突出。它能够将原始的得分转换为概率分布,使每个类别的概率值都处于 0 到 1 之间且总和为 1。Softmax 内含有大量的指数运算,这使得它在嵌入式端(例如 RV1106)上计算较慢。

针对量化模型,模型的输出一般为一个 8/16/64 位数据类型的原始数据 data、一个 float 类型的 scale 和一个 int 类型的 zero_point,我们可以用查表法对这种特殊情况做优化。以 int8 数据类型为例,它可以表示的范围是从 -128 到 127,共 256 个不同的数据,反量化后的 float 类型数据也是 256 个。针对小批量有限数据,用查表法做优化是比较合适的。

2 代码实现

简介中已经提到,Softmax 的计算瓶颈主要在指数函数的运算上,因此优化 Softmax 函数的核心在于优化指数运算。

2.1 实现反量化函数

首先,模型的输出是一个 int8 类型的量化值,但是 softmax 的输入应该是一个 float 类型的反量化值,因此我们需要先实现一个反量化函数,这里我参考了 Rockchip 团队的实现方法。

template <typename T>
inline float DequantizeToFP32(T data, int32_t zero_point, float scale) {
  static_assert(
      (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value ||
       std::is_same<T, int16_t>::value || std::is_same<T, uint16_t>::value ||
       std::is_same<T, int32_t>::value || std::is_same<T, uint32_t>::value),
      "DequantizeToFP32 only support key type is "
      "uint8_t/int8_t/int16_t/uint16_t/int32_t/uint32_t");

  return ((float)data - (float)zero_point) * scale;
}

2.2 实现指数运算表

接下来我们建立一个数据表,表的索引是 int8 类型的数据,表的值是先对 int8 类型的数据做反量化,再求指数后的值,代码实现如下:

class ExpTable {
 public:
  ExpTable(int32_t zero_point, float scale) {
    static_assert((min_key <= max_key),
                  "ExpTable need min_key less than max_key");

    for (int32_t key = min_key; key <= max_key; ++key) {
      exp_table_[key] = std::exp(DequantizeToFP32(key, zero_point, scale));
    }
  }

  float GetExp(int32_t index) const {
    if (index < min_key || index > max_key) {
      LOGGER_ERROR("index < min_key || index > max_key. [%d, %d, %d]", index, min_key, max_key);
      return 0;
    }
    return exp_table_.at(index);
  }

 private:
  std::unordered_map<int32_t, float> exp_table_;
};

2.3 实现 Softmax 函数

接下来将 Softmax 的函数中的指数函数运算全部替换为指数表

template <typename T>
class Softmax {
 public:
  Softmax(int32_t zero_point, float scale) {
    static_assert(
        (std::is_same<T, signed char>::value ||
         std::is_same<T, unsigned char>::value ||
         std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
         std::is_same<T, uint8_t>::value || std::is_same<T, uint16_t>::value),
        "ExpTable only support key type is uint8_t/int8_t/uint16_t/int16_t");

    exp_table_ = new ExpTable<min_, max_>(zero_point, scale);
  }

  virtual ~Softmax() { delete exp_table_; }

  void Run(T *input, float *output, size_t size) {
    for (size_t i = 0; i < size; ++i) {
      output[i] = exp_table_->GetExp(input[i]);
    }

    float sum = std::accumulate(output, output + size, 0.0);
    for (size_t i = 0; i < size; ++i) {
      output[i] /= sum;
    }
  }

 protected:
  constexpr static T min_ = std::numeric_limits<T>::min();
  constexpr static T max_ = std::numeric_limits<T>::max();
  ExpTable<min_, max_> *exp_table_ = nullptr;
};

3 参考文档

posted @ 2024-11-30 17:25  Zheng-Bicheng  阅读(224)  评论(0)    收藏  举报