基于查表法实现 Softmax 函数
1 简介
在深度学习领域,Softmax 函数是一种广泛应用的激活函数,尤其在多分类问题中表现突出。它能够将原始的得分转换为概率分布,使每个类别的概率值都处于 0 到 1 之间且总和为 1。Softmax 内含有大量的指数运算,这使得它在嵌入式端(例如 RV1106)上计算较慢。
针对量化模型,模型的输出一般为一个 8/16/64 位数据类型的原始数据 data、一个 float 类型的 scale 和一个 int 类型的 zero_point,我们可以用查表法对这种特殊情况做优化。以 int8 数据类型为例,它可以表示的范围是从 -128 到 127,共 256 个不同的数据,反量化后的 float 类型数据也是 256 个。针对小批量有限数据,用查表法做优化是比较合适的。
2 代码实现
简介中已经提到,Softmax 的计算瓶颈主要在指数函数的运算上,因此优化 Softmax 函数的核心在于优化指数运算。
2.1 实现反量化函数
首先,模型的输出是一个 int8 类型的量化值,但是 softmax 的输入应该是一个 float 类型的反量化值,因此我们需要先实现一个反量化函数,这里我参考了 Rockchip 团队的实现方法。
template <typename T>
inline float DequantizeToFP32(T data, int32_t zero_point, float scale) {
static_assert(
(std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value ||
std::is_same<T, int16_t>::value || std::is_same<T, uint16_t>::value ||
std::is_same<T, int32_t>::value || std::is_same<T, uint32_t>::value),
"DequantizeToFP32 only support key type is "
"uint8_t/int8_t/int16_t/uint16_t/int32_t/uint32_t");
return ((float)data - (float)zero_point) * scale;
}
2.2 实现指数运算表
接下来我们建立一个数据表,表的索引是 int8 类型的数据,表的值是先对 int8 类型的数据做反量化,再求指数后的值,代码实现如下:
class ExpTable {
public:
ExpTable(int32_t zero_point, float scale) {
static_assert((min_key <= max_key),
"ExpTable need min_key less than max_key");
for (int32_t key = min_key; key <= max_key; ++key) {
exp_table_[key] = std::exp(DequantizeToFP32(key, zero_point, scale));
}
}
float GetExp(int32_t index) const {
if (index < min_key || index > max_key) {
LOGGER_ERROR("index < min_key || index > max_key. [%d, %d, %d]", index, min_key, max_key);
return 0;
}
return exp_table_.at(index);
}
private:
std::unordered_map<int32_t, float> exp_table_;
};
2.3 实现 Softmax 函数
接下来将 Softmax 的函数中的指数函数运算全部替换为指数表
template <typename T>
class Softmax {
public:
Softmax(int32_t zero_point, float scale) {
static_assert(
(std::is_same<T, signed char>::value ||
std::is_same<T, unsigned char>::value ||
std::is_same<T, int8_t>::value || std::is_same<T, int16_t>::value ||
std::is_same<T, uint8_t>::value || std::is_same<T, uint16_t>::value),
"ExpTable only support key type is uint8_t/int8_t/uint16_t/int16_t");
exp_table_ = new ExpTable<min_, max_>(zero_point, scale);
}
virtual ~Softmax() { delete exp_table_; }
void Run(T *input, float *output, size_t size) {
for (size_t i = 0; i < size; ++i) {
output[i] = exp_table_->GetExp(input[i]);
}
float sum = std::accumulate(output, output + size, 0.0);
for (size_t i = 0; i < size; ++i) {
output[i] /= sum;
}
}
protected:
constexpr static T min_ = std::numeric_limits<T>::min();
constexpr static T max_ = std::numeric_limits<T>::max();
ExpTable<min_, max_> *exp_table_ = nullptr;
};

浙公网安备 33010602011771号