AI 编译器CINN中的OpLowering优化Pass
一、Lower 主逻辑
在 OpLower::Lower()
接口中,主要分为两大类:
- Elementwise类,主要涉及的
OpPattern
包括:kElementwise
、kBroadcast
、kInjective
- Reduce 类,主要涉及的OpPattern包括:
kReduction
std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {
VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind;
group->input_names.clear();
group->output_names.clear();
if (FLAGS_cinn_ir_schedule) {
switch (group->op_pattern_kind) {
case framework::kElementWise:
case framework::kBroadcast:
case framework::kInjective:
return IRLowerOp(&OpLowerer::IRElementwiseCompute, &OpLowerer::IRElementwiseSchedule, group); // << --- 第一大类 Elementwise相关
case framework::kReduction:
return IRLowerOp(&OpLowerer::IRReduceCompute, &OpLowerer::IRReduceSchedule, group); // << --- 第二大类 Reduce 相关
case framework::kOutFusible:
LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
case framework::kNonFusible:
return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ true);
default:
LOG(FATAL) << "Group Pattern Kind Is Unknown!";
}
} else {
LOG(FATAL) << "Previous IR Schedule Is Not Implemented!";
}
}
二、Optimize 逻辑
在 op_lowering.cc
中的 IRLowerOp
的最后,会创建一个 LoweredFunc
对象,并对其调用 optim::Optimize()
函数。
std::vector<ir::LoweredFunc> OpLowerer::IRLowerOp(compute, schedule, group){
// .... 省略
auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body);
auto func =
ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers);
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); // <----在函数最后会调用 Optimizer 对函数表达式进行优化,注意与 target 相关
return {func};
}
其接口实现是在 optimize.cc
文件中,主要是对 LoweredFunc
对应的Expr应用各种优化 pass:
// 中层是 optimize.cc 中的:
Expr Optimize(Expr e, Target target, bool runtime_debug_info, bool remove_gpu_for_loops) {
auto copied = IRCopy(e);
// 与 target 无关的通用优化
FoldCINNCallArguments(&copied);
TransformPolyForToFor(&copied);
ReplaceConstParamToInteger(&copied);
CastSimplify(&copied);
Simplify(&copied);
UnrollLoop(&copied);
// 与 target 有关的优化
VectorizeLoops(&copied, target);
MapExternCall(&copied, target); // <---- 此处是这里要关注和讨论的 MapExternCall 优化
// 仅在 CUDA 上的优化
ir::SetCudaAxisInfo(&copied);
RemoveGpuForloopsAxis(&copied);
CudaSyncThreadsDropIfThenElse(&copied);
// 又是与 target 无关的通用优化
RemoveNestedBlock(&copied);
ExternCallMultiOutputShallowStore(&copied);
CastSimplify(&copied);
Simplify(&copied);
IfSimplify(&copied);
// 与调试相关通用优化
InsertDebugLogCallee(&copied);
}
三、各个优化Pass
接下来,我们逐个来研究每个 pass 的角色和作用。
3.1 FoldCINNCallArguments
此 Pass 的功能是通过 FoldCINNCallArgumentsMutator
来实现的:
void FoldCINNCallArguments(Expr* expr) { FoldCINNCallArgumentsMutator()(expr); }
此 Mutator 只关心ir::Block和ir::Store两种类型节点:
struct FoldCINNCallArgumentsMutator : public ir::IRMutator<> {
void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
private:
void Visit(const ir::Block* op, Expr* expr); // <----- Block
void Visit(const ir::Store* op, Expr* expr); // <----- Store
void MutateCall(ir::Call* call);
private:
// To avoid the same call triggered duplicately.
std::unordered_set<std::string> visited_call_;
};
其中对于 ir::Block
类型的节点,找到所有的 CallType::CINN
的 Expr,然后判断是否已经存在过,若是,则删除 Block 中的此 statement 语句。
这里补充下 ir::Call
节点中的 CallType
枚举类型的值都有哪些:CINN
、Intrinsic
、Extern
、ISL
void Visit(const ir::Block* op, Expr* expr) override {
auto* node = expr->As<ir::Block>();
for (auto it = node->stmts.begin(); it != node->stmts.end();) {
if (it->As<ir::Store>()) {
auto* call = it->As<ir::Store>()->value.As<ir::Call>(); // <---- 针对 x = cinn_call_func(args) 场景?
if (call && call->is_cinn_call()) {
// remove the duplicate calls.
std::string key = utils::GetStreamCnt(Expr(call));
if (visited_call_.count(key)) {
it = node->stmts.erase(it);
continue;
}
ir::IRMutator<>::Visit(&(*it), &(*it)); // <--- 这里会触发下面的 ir::Store 的处理逻辑
visited_call_.insert(key);
continue;
}
}
ir::IRMutator<>::Visit(&(*it), &(*it));
++it;
}
}
对于 ir::Store
类型的节点,仅针对 CallType::CINN
类型的节点调用 MutateCall
函数进行修改和替换;
void Visit(const ir::Store* op, Expr* expr) override {
auto* node = expr->As<ir::Store>();
if (node->value.As<ir::Call>()) {
auto* call = node->value.As<ir::Call>();
switch (call->call_type) {
case ir::CallType::CINN:
MutateCall(call);
*expr = node->value;
break;
case ir::CallType::Intrinsic:
break;
case ir::CallType::Extern:
break;
default:
CINN_NOT_IMPLEMENTED
}
}
}
MuteCall
函数是此 Pass 的最核心逻辑,其作用是 call 节点中所有的输入、输出 args 中的 Tensor 类型,确认其都 defined 了 buffer ,并将 buffer 作为真正的 args 替换原来的 read_args 和 write_args 。
思考:为什么要单独对CINN类型的CallType多做这样一件事情呢?背景是什么?
void MutateCall(ir::Call* call) {
if (call->call_type == ir::CallType::Extern) return;
std::vector<Expr> read_args;
std::vector<Expr> write_args;
for (auto& arg : call->read_args) {
if (arg.as_tensor()) {
CHECK(arg.as_tensor()->buffer.defined()) << "arg tensor [" << arg.as_tensor()->name << "] not has buffer";
read_args.push_back(arg.as_tensor()->buffer);
} else {
read_args.push_back(arg);
}
}
for (auto& arg : call->write_args) {
if (arg.as_tensor()) {
write_args.push_back(arg.as_tensor()->buffer);
} else {
write_args.push_back(arg);
}
}
call->read_args = read_args;
call->write_args = write_args;
}
3.2 ReplaceConstParamToInteger
这个 Pass 比较简单,只针对 ir::Var
类型的节点,如果其 name 是以 _const_
开头的,则取其具体的值,转为Expr
(如 Intmm)
static const char* kIslParamConstPrefix = "_const_";
struct Mutator : public ir::IRMutator<> {
using ir::IRMutator<>::Visit;
void Visit(const ir::_Var_* op, Expr* expr) override {
if (utils::Startswith(op->name, poly::kIslParamConstPrefix)) {
std::string value = op->name.substr(strlen(poly::kIslParamConstPrefix));
*expr = Expr(std::stoi(value)); // <----- 这里强转为 int 类型,是只存在类似 _const_12 这种情况
}
}
};
} // namespace
void ReplaceConstParamToInteger(Expr* e) {
Mutator mutator;
mutator.Visit(e, e);
}
那这个const
前缀字符串拼接是在哪里做的呢?是在 cinn::poly::ast_gen
中做的,相关逻辑代码如下:
isl::set TransIdentityExtentToContextId(isl::set set) {
std::vector<std::tuple<int, int>> iden_dim_offsets;
for (int i = 0; i < isl_set_dim(set.get(), isl_dim_set); i++) {
if (isl_set_axis_has_noparam_constant_bound(set.get(), i)) {
auto range = isl_set_get_axis_range(set.get(), i);
auto& minv = std::get<0>(range);
auto& maxv = std::get<1>(range);
int min_iv = minv.get_num_si();
int max_iv = maxv.get_num_si();
if (max_iv == min_iv) {
iden_dim_offsets.emplace_back(i, max_iv);
}
}
}
isl::set res_set = set;
for (auto offset_val : iden_dim_offsets) {
auto& offset = std::get<0>(offset_val);
auto& val = std::get<1>(offset_val); // <---- 是个 int 类型
res_set = isl::manage(isl_set_drop_constraints_involving_dims(res_set.copy(), isl_dim_set, offset, 1));
std::string const_param_name = llvm::formatv("{0}{1}", kIslParamConstPrefix, val); //<---- 在此处进行拼接的
std::string cond_str = llvm::formatv(
"{0} <= {1} <= {2}", val, isl_set_get_dim_name(res_set.get(), isl_dim_set, offset), const_param_name);
std::string param_cond_str = llvm::formatv("{0} <= {1} < {2}", val, const_param_name, val + 2);
std::string set_repr = llvm::formatv("[{0}] -> { {1}[{2}]: {3} and {4} }",
const_param_name,
isl_set_get_tuple_name(res_set.get()),
utils::Join(isl_get_dim_names(res_set.get()), ","),
cond_str,
param_cond_str);
VLOG(4) << "repr: " << set_repr;
isl::set new_set(res_set.ctx(), set_repr);
res_set = res_set.intersect(new_set);
}
return res_set;
}
注:通过检索Bert 模型中的GLOG_v=10 的日志,并没有发现
ReplaceConstParamToInteger
生效的地方。
如下是一个 CINN 框架里的单测,可以辅助帮助理解上面这个函数的作用效果,样例中 j=0
,其中会把 0 这个常量值先创建一个 _const_0
,然后做了变换?
TEST(TransIdentityExtentToContextId, basic) {
isl_ctx* ctx = isl_ctx_alloc();
isl::set set(ctx, "{ s[i,j=0,k] : 0<=i<12 and 12<k<32 }");
auto new_set = TransIdentityExtentToContextId(set);
LOG(INFO) << new_set;
ASSERT_EQ(utils::GetStreamCnt(new_set),
"[_const_0] -> { s[i, j, k] : _const_0 <= 1 and 0 <= i <= 11 and 0 <= j <= _const_0 and 13 <= k <= 31 }");
}
3.3 CastSimplify
此Pass 仅会对 constant 的Expr进行处理,比如 IntImm、FloatImm、UIntImm,作用是将其持有的value值按照 ir::Cast.type()
进行数值类型转换,然后包裹一个Expr重新返回。
void CastSimplify(Expr* e) {
Mutator mutator;
mutator.Visit(e, e);
}
struct Mutator : ir::IRMutator<> {
using ir::IRMutator<>::Visit;
void Visit(const ir::Cast* op, Expr* expr) {
auto* node = expr->As<ir::Cast>();
Visit(&node->v(), &node->v()); // <<--- 类似 AST 的 generic_visit,深度优先递归处理 node->v() 节点
if (op->type() == op->v().type()) {
*expr = op->v(); // Caset 1: 如果 value 类型已经与dst_type 一致,则直接返回 node->v() 以替换当前节点。
return;
}
#define __CAST_TO_TYPE(type__) \
if (auto* i = op->v().As<ir::IntImm>()) { \
*expr = Expr(static_cast<type__>(i->value)); \
} else if (auto* f = op->v().As<ir::FloatImm>()) { \
*expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \ // <<----- 注意:这里对Float类型进行了特殊处理,因为存在转Float16的场景
} else if (auto* u = op->v().As<ir::UIntImm>()) { \
*expr = Expr(static_cast<type__>(u->value)); \
} else { \
CINN_NOT_IMPLEMENTED \
}
if (op->v().is_constant()) { // <----- 注意:此pass仅支持 ir::Cast->v()为常量类型的场景
if (op->type() == type_of<int8_t>()) {
__CAST_TO_TYPE(int8_t)
} else if (op->type() == type_of<int16_t>()) {
__CAST_TO_TYPE(int16_t)
} else if (op->type() == type_of<int32_t>()) {
__CAST_TO_TYPE(int32_t)
} else if (op->type() == type_of<int64_t>()) {
__CAST_TO_TYPE(int64_t)
} else if (op->type() == type_of<uint8_t>()) {
__CAST_TO_TYPE(uint8_t)
} else if (op->type() == type_of<uint16_t>()) {
__CAST_TO_TYPE(uint16_t)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<float>()) {
__CAST_TO_TYPE(float)
} else if (op->type() == type_of<double>()) {
__CAST_TO_TYPE(double)
} else if (op->type() == type_of<bool>()) {
__CAST_TO_TYPE(bool)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<float16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(float16)
} else {
CINN_NOT_IMPLEMENTED
}
}
#undef __CAST_TO_TYPE
}
};
在上面流程代码中,我们可以看出对于 FloatImm
类型的处理额外借助了 NormCastValue
这个函数,原因是对于 Float32 到 Float16 的转写,要考虑上溢、下溢、Nan
、Inf
的场景:
template <typename CastType, typename T>
CastType NormCastValue(T value) {
if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
// not support uint
return static_cast<CastType>(value);
}
if (std::isinf(value)) {
return std::numeric_limits<CastType>::infinity();
} else if (std::isnan(value)) {
return std::numeric_limits<CastType>::signaling_NaN();
} else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
return std::numeric_limits<CastType>::max();
} else if (value <= static_cast<T>(std::numeric_limits<CastType>::lowest())) {
return std::numeric_limits<CastType>::lowest();
}
return static_cast<CastType>(value);
}
3.4 Simplify
这个 Pass 包括的子逻辑比较多,单测文件 ir_simplify_test.cc
里可以帮助理解效果:
void Simplify(Expr* expr) {
optim::CastSimplify(expr); // 先调用了 CastsSimplify,这个似乎会比较多余?在递归调用时更会导致效率低下
SimplifyRampMutator()(expr);
SimplifyLoadMutator()(expr);
SimplifyStoreMutator()(expr);
SimplifyIfThenElseMutator()(expr);
common::cas_intervals_t var_intervals;
SimplifyButStoreLoadMutator mutator(var_intervals); // 又额外来了一遍,这里似乎也比较低效?
mutator(expr);
ReplaceFracWithDivMutator()(expr); // 这里将 ir::Frac 替换为了 ir::Div,似乎也不是必要的,没有看到哪里构造了 ir::Frac
}
效果:
// case 1:
C = 1. //shape = [100, 20]
B = C[i, 0] + 1 * 0 + 100 + 24.5
// 经过此 Pass 后:
B = C[i, 0] + 124.5
// case 2:
{
serial for (i, 0, 100)
{
serial for (j, 0, 20)
{
B[i, j] = (X[i + 0, j + 0] + Y[i, j * 0] * 1.f + 0.f * X[i, j] + 25.f + 100.f - 0.f +
9.f * 10000.f * 1.f * 1.f * 0.f
)
}
}
}
// 经过此 Pass 后:
{
serial for (i, 0, 100)
{
serial for (j, 0, 20)
{
B[i, j] = (125.000000f + (X[i, j] + y[i, 0]))
}
}
}
首先看 SimplifyRampMutator
的角色作用,从源码上来看,只关心两种节点:ir::Ramp
和 ir::Add
。
- 对于
ir::Add
节点,如果两个操作数都是ir::Ramp
类型,且其 lanes 属性值是一样的话,则会构建一个ir::Ramp
节点来替换掉ir::Add
节点 - 对于
ir::Ramp
节点,则递归对其base
和stride
属性调用Simplify
函数。
struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
void Visit(const Ramp* op, Expr* expr) override {
auto* node = expr->As<ir::Ramp>();
CHECK(common::IsPureMath(node->base)) << node->base << "is not a pure math!";
CHECK(common::IsPureMath(node->stride)) << node->stride << "is not a pure math!";
;
Simplify(&node->base);
Simplify(&node->stride);
}
// ramp + ramp
void Visit(const Add* op, Expr* expr) override {
auto* node = expr->As<ir::Add>();
Expr a = node->a();
Expr b = node->b();
auto a_ramp = a.As<ir::Ramp>();
auto b_ramp = b.As<ir::Ramp>();
if (a_ramp && b_ramp && a_ramp->lanes == b_ramp->lanes) {
Expr base_add = common::AutoSimplify(a_ramp->base + b_ramp->base); // 这里会做CAS
Expr stride_add = common::AutoSimplify(a_ramp->stride + b_ramp->stride);
*expr = ir::Ramp::Make(base_add, stride_add, a_ramp->lanes);
}
}
};
我们这里瞅一眼 ir::Ramp
节点是什么样子的:
//! A linear ramp node.
struct Ramp : public ExprNode<Ramp> {
Expr base, stride;
int lanes;
static Expr Make(Expr base, Expr stride, int lanes);
void Verify() const override;
static const IrNodeTy _node_type_ = IrNodeTy::Ramp;
};
接下来看第二个 SimplifyLoadMutator
的角色,简单理解就是对 X[i+0, j+0]
以及 Y[i, j*0]
进行优化,得到 X[i, j]
,Y[i, 0]
:
struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
void Visit(const Load* expr, Expr* op) override {
auto* node = op->As<Load>();
for (auto& idx : node->indices) {
if (common::IsPureMath(idx)) {
PartialSimplify(&idx, var_intervals_); // << 也是借助了CAS了
} else {
SimplifyButStoreLoadMutator mutator(var_intervals_); // 根据节点类型,分发调用 PartialSimplify 函数
mutator(&idx);
}
}
}
void Visit(const For* op, Expr* expr) override {
auto* min_i = op->min.As<IntImm>();
auto* extent_i = op->extent.As<IntImm>();
if (min_i && extent_i && extent_i->value > min_i->value) {
var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1});
}
auto* node = expr->As<For>();
operator()(&node->body);
operator()(&node->extent);
if (min_i && extent_i) {
var_intervals_.erase(op->loop_var->name);
}
}
common::cas_intervals_t var_intervals_;
};
第三个 SimplifyStoreMutator
的代码逻辑基本与 SimplifyLoadMutator
一致,这里我们不再赘述。
第四个 SimplifyIfThenElseMutator
,这个也很好理解,对 condition
调用 CAS 逻辑:
struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
using ir::IRMutator<>::Visit;
void Visit(const IfThenElse* op, Expr* expr) override {
auto* node = expr->As<ir::IfThenElse>();
node->condition = common::AutoSimplify(node->condition); // 核心点
if (node->true_case.defined()) Visit(&node->true_case, &node->true_case); // 访问者模式分发
if (node->false_case.defined()) Visit(&node->false_case, &node->false_case); // 访问者模式分发
}
};
第五个 SimplifyButStoreLoadMutator
本来在第二、三个子逻辑会局部触发,这里为何单独触发了一遍?从函数实现了是对其他节点都遍历一遍进行简化处理,唯独除了 Store
和 Load
节点(因为这两个节点主要出现在 ir::For
节点中)
第六个 ReplaceFracWithDivMutator
,这个很有意思,是把所有的 ir::FracOp
替换为 ir::Div
,这两个不一样么?仔细看了下,在一些 CodeGen
模块里,如 codegen_llvm.cc
中,是没有实现 ir::FracOp
里的代码生成逻辑的,只有 ir::Div
实现了。那为什么不直接把 ir::FracOp
节点删除呢?
struct ReplaceFracWithDivMutator : public ir::IRMutator<> {
void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
void Visit(const FracOp* op, Expr* expr) override {
auto* node = expr->As<ir::FracOp>();
ir::IRMutator<>::Visit(&node->operand(0), &node->operand(0));
ir::IRMutator<>::Visit(&node->operand(1), &node->operand(1));
*expr = ir::Div::Make(node->operand(0), node->operand(1));
}
};
llvm::Value *CodeGenLLVM::Visit(const ir::FracOp *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); }
在CINN里我检索了 lang/pe
等模块源码,没有看到在 IR 层面直接使用或构造 ir::Frac
节点的代码,只有单测和不相关模块:
3.5 MapExternCall 逻辑
从调用栈来看, 底层是 map_extern_call.cc
中具体的 MapExternCall
的实现
void MapExternCall(Expr *e, Target target) {
Mutator m(target);
m(e);
}
所有的工作都是交给基于 Ast 的 Mutator 来做的,原理:借助「访问者模式」仅识别和处理 ir::Call
对象:
void Visit(const ir::Call *op, Expr *expr) override {
auto *node = expr->As<ir::Call>();
CHECK(node);
OptimizeConstantPow(node);
if (target.arch == Target::Arch::NVGPU) {
DealWithNvGpuintrinsics(node, expr);
} else {
DealWithCpuintrinsics(node, expr);
}
}
我们比较关心 CUDA 上的变换,进一步看 DealWithNvGpuintrinsics
函数:
void DealWithNvGpuintrinsics(ir::Call *node, Expr *expr) {
auto arg_size = node->read_args.size();
if (arg_size == 0UL) {
// some node like __syncthreads hasn't arguments
return;
}
const auto &dtype = node->read_args.front().type();
const auto &name = node->name;
bool node_in_extern_fp32 = kExternFp32CallsGPU.count(name);
bool node_in_extern_int32 = kExternInt32CallsGPU.count(name);
if (!node_in_extern_fp32 && !node_in_extern_int32) {
return;
}
std::string suffix;
if (dtype.is_int() && node_in_extern_int32) {
if (dtype.is_int(32)) {
suffix = "_int32";
} else if (dtype.is_int(64)) {
suffix = "_int64";
}
} else if (dtype.is_float() && node_in_extern_fp32) {
if (dtype.is_float(64)) {
suffix = "_fp64";
} else if (dtype.is_float(32)) {
suffix = "_fp32";
} else if (dtype.is_float(16)) {
suffix = "_fp16";
}
}
CHECK(!suffix.empty()) << name << " not support data type " << dtype;
std::string extern_func = "cinn_nvgpu_" + name + suffix; // <------ 主要是按照OpNode白名单+dtype拼接要替换的外部 API (其实也是在CINN层里wrapper注册了一层)
*expr = lang::CallExtern(extern_func, node->read_args); // 直接替换 ir::Call 对应的 Expr 对象
}