LLVM official Tutorial:Chap4
一般的常量折叠优化
第三章中的IR Builder在编译代码时会给出简单的常量优化
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
ret double %addtmp
}
如果是通过解析输入所构建的AST的文字转录,不进行常量折叠,上面这段代码就会是:
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 2.000000e+00, 1.000000e+00
%addtmp1 = fadd double %addtmp, %x
ret double %addtmp1
}
如上所述,常量折叠是一种非常常见且非常重要的优化方法:很多语言实现者在AST表示中都实现了常量折叠支持。
因为所有构建LLVM IR的调用都要经过LLVM IR Builder,所以当你调用LLVM IR Builder时,Builder本身会检查是否存在需要常量折叠的地方。如果是这样,它只是进行常数折叠并返回常数,而不是创建指令。
这很简单:)。
实际上,我们建议在生成这样的代码时始终使用IR Builder。它的使用没有“语法开销”(你不必通过到处进行常量检查来丑化你的编译器) ,它可以显著地减少在某些情况下生成的LLVM IR 的数量(特别是对于使用宏预处理器或使用大量常量的语言)。
另一方面,IR Builder的局限性在于,它在构建代码时将所有的分析都内联在一起。比如有一个稍微复杂一点的例子:
ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
%addtmp1 = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp1
ret double %multmp
}
在这种情况下,乘法的LHS
和RHS
值相同。我们希望看到这个生成”tmp = x+3; Result = tmp * tmp;”而不是两次计算"x+3"。
不幸的是,没有任何局部分析能够检测和纠正这一点。这需要两个转换:表达式的重新关联(使添加的词汇相同)和公共子表达式消除(CSE),以删除冗余的添加指令。但是,LLVM以“passes”的形式提供了广泛的优化。
LLVM Optimization Passes
LLVM提供了很多 optimization passes,LLVM与其他系统不同的是,不要求一组优化适用于任何语言和任何情况。LLVM允许编译器的实现者完全自主决定在不同情况和不同需求下使用不同的优化。
作为一个具体的例子。LLVM支持两种passes:
- ”whole module“ passes,通常是整个文件,如果是在链接时运行,也可以是整个程序的某个重要部分
- ”per-function” passes,每次只对单个函数操作,不考虑其他函数
- 更多内容请参考如何写一个Pass和LLVM Passes列表
回到我们当前的工作,对我们所做的Kaleidoscope来说,我们正在动态的生成函数,每次一个,所以我们接下来要使用per-function优化。如果想要创建一个完整的”static Kaleidoscope Compiler“,我们会在现有代码的基础上,推迟运行优化器,在整个文件被parsed之后再优化,而不是像现在这样每次用户输入一个expression就优化。
start optimize
为了进行per-function的优化,我们需要设置一个FunctionPassManager来保存和组织我们想要运行的LLVM优化。
一旦我们有了它,我们就可以添加一组要运行的优化功能。我们需要为每个想要优化的模块创建一个新的FunctionPassManager,所以我们将编写一个函数来创建和初始化module和pass manager:
// 基于上一章的代码添加
void InitializeModuleAndPassManager(void) {
// Open a new module.
// 打开一个新的module
TheModule = std::make_unique<Module>("my cool jit", TheContext);
// Create a new pass manager attached to it.
// 创建一个新的pass manager并链接到刚刚打开的module上
TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
// Do simple "peephole" optimizations and bit-twiddling optzns.
// 简单的"peephole"和 bit-twiddling优化
TheFPM->add(createInstructionCombiningPass());
// Reassociate expressions.
// 重关联表达式
TheFPM->add(createReassociatePass());
// Eliminate Common SubExpressions.
// 消除公共子表达式
TheFPM->add(createGVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
// 简化控制流程图(比如消除不可达块)
TheFPM->add(createCFGSimplificationPass());
TheFPM->doInitialization();
}
上面这段代码初始化全局模块:TheModule 和pass manager:TheFPM,并将TheFPM和TheModule关联起来。
一旦pass manager被设置,就可以向其中添加一系列LLVM passes。
在本例中,我们选择添加了四个优化passes。我们在这里选择的passes是一组非常标准的“cleanup”优化,对各种各样的代码都很有用。我不会深入研究它们是怎么做的,但是相信我,它们是一个很好的起点:)。
当PassManager被初始化和设置完毕之后,就可以使用它。它被调用的时机在函数被创建之后,被返回之前:也就是FunctionAST::codegen()
被调用,但是还未返回之前:
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder.CreateRet(RetVal);
// Validate the generated code, checking for consistency.
// 验证生成的代码,检查一致性
verifyFunction(*TheFunction);
// Optimize the function.
TheFPM->run(*TheFunction);
return TheFunction;
}
FunctionPassManager在适当的地方优化和更新LLVM Function*
,改进(我们希望如此)它的主体。有了这个,我们可以再次尝试上面的测试:
ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp
ret double %multmp
}
正如预期的那样,我们现在得到了经过精心优化的代码,从该函数的每次执行中保存了一条浮点添加指令。
LLVM 提供了在不同情况下可以使用的各种优化。一些关于各种passes的文档是可用的,但是并不是非常完整。另一个好的方法是查看Clang
运行的passes。“opt”工具允许您从命令行试验通行证,因此您可以看到他们是否做了什么。
现在我们已经有了来自前端的合理的代码,接下来,让我们讨论一下如何执行它!
Adding a JIT Compiler
LLVM IR 中可用的代码可以使用各种各样的工具。例如,您可以对其运行优化(如前文所述) ,也可以将其转储为文本形式或二进制形式,可以将代码编译为汇编文件(.s文件)。或者可以通过JIT编译它。LLVM IR表示的优点是,它是编译器不同部分之间的“通用货币”。
在本节中,我们将向解释器添加 JIT 编译器支持。我们希望 Kaleidoscope 的基本思想是让用户像现在一样输入函数体,但是立即计算他们输入的顶级表达式。例如,如果他们输入“1 + 2;”,我们应该计算并打印出3。如果他们定义了一个函数,他们应该能够从命令行调用它。
为此,我们首先准备环境,为当前本机目标创建代码,并声明和初始化 JIT。这是通过调用一些 InitializeNativeTarget\*
函数,添加一个全局变量 TheJIT,并在 main 中初始化它来实现的:
// 声明
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
int main() {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
// 新增部分,初始化
TheJIT = std::make_unique<KaleidoscopeJIT>();
// Run the main "interpreter loop" now.
MainLoop();
return 0;
}
我们还需要设置JIT的数据布局:
void InitializeModuleAndPassManager(void) {
// Open a new module.
TheModule = std::make_unique<Module>("my cool jit", TheContext);
TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
// Create a new pass manager attached to it.
TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
...
KaleidoscopeJIT类是为了本教程而建立的一个简单的JIT,可以在llvm-src/examples/Kaleidoscope/include/KaleidoscopeJIT.h中查看。在之后的章节我们会探讨它是如何工作的并添加一些新的特性。它的API非常简单:
addModule
:添加一个LLVM IR模块到JIT中,使它的函数可以执行;removeModule
:移除一个module,释放该Module中所有相关的代码的内存;findSymbol
:允许我们查找编译后的代码的指针
我们可以修改top-level expression的parse代码来使用这些API:
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
// 将top-level表达式计算为匿名函数
if (auto FnAST = ParseTopLevelExpr()) {
if (FnAST->codegen()) {
// JIT the module containing the anonymous expression, keeping a handle so we can free it later.
// 对保存了匿名表达式的module进行JIT,维护一个句柄H,以便在之后进行释放工作
auto H = TheJIT->addModule(std::move(TheModule));
// 还记得Initial函数中做了哪些工作;
//1. 初始化TheModule;2. 添加TheFPM
InitializeModuleAndPassManager();
// Search the JIT for the __anon_expr symbol.
// 在JIT中搜索__anon_expr符号
auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
assert(ExprSymbol && "Function not found");
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
// 找到符号的地址,并将其转换为正确的函数type(本例中是返回值double,无参数)。
// 将其作为native function调用
double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
fprintf(stderr, "Evaluated to %f\n", FP());
// Delete the anonymous expression module from the JIT.
// 释放H
TheJIT->removeModule(H);
}
如果parse并且codegen成功,下一步就是将包含顶级表达式的模块添加到JIT中。
我们通过调用addModule
来做到这一点,它会触发模块中所有函数的代码生成,并返回一个句柄,该句柄可用于稍后从JIT中删除该模块。
一旦这个模块被添加到JIT中,它就不能再被修改了,所以我们还通过调用InitializeModuleAndPassManager()
打开一个新模块来保存后续的代码。
将模块添加到JIT之后,我们需要获得一个指向最终生成代码的指针。为此,我们调用JIT的findSymbol方法,并传递顶级表达式函数的名称:__anon_expr
。因为我们刚刚添加了这个函数,所以我们断言findSymbol返回了一个结果。
接下来,我们通过对符号调用getAddress()
来获得__anon_expr
函数的内存地址。回想一下,我们将顶级表达式编译成一个自包含的LLVM函数,该函数不接受参数,并返回计算后的double值。因为LLVM JIT编译器匹配本机平台ABI,这意味着您可以直接将结果指针转换为该类型的函数指针并直接调用它。这意味着,JIT编译的代码和静态链接到应用程序中的本地机器码之间没有区别。
最后,由于我们不支持顶级表达式的重新求值,所以在释放相关内存时,我们将从JIT中删除该模块。但是,请记住,我们在前面(通过InitializeModuleAndPassManager
)创建的模块仍然是打开的,并等待添加新代码。
只有这两个变化,让我们看看kaleidoscope现在是如何工作的:
ready> 4+5;
Read top-level expression:
define double @0() {
entry:
ret double 9.000000e+00
}
Evaluated to 9.000000
好吧,这看起来基本上是有效的。函数的转储显示了我们为输入的每个顶级表达式合成的“总是返回double的无参数函数”。这演示了非常基本的功能,但是我们能做更多吗?
ready> def testfunc(x y) x + y*2;
Read function definition:
define double @testfunc(double %x, double %y) {
entry:
%multmp = fmul double %y, 2.000000e+00
%addtmp = fadd double %multmp, %x
ret double %addtmp
}
ready> testfunc(4, 10);
Read top-level expression:
define double @1() {
entry:
%calltmp = call double @testfunc(double 4.000000e+00, double 1.000000e+01)
ret double %calltmp
}
Evaluated to 24.000000
ready> testfunc(5, 10);
ready> LLVM ERROR: Program used external function 'testfunc' which could not be resolved!
函数定义和调用也可以工作,但是最后一行出了很大的问题。这个调用看起来是有效的,那么发生了什么?正如你可能已经从API中猜到的那样,Module是JIT的一个分配单元,而testfunc
是该Mudole中包含匿名表达式的一部分。当我们从JIT中删除该模块以释放用于匿名表达式的内存时,我们删除了testfunc
的定义。然后,当我们尝试第二次调用testfunc
时,JIT再也找不到它了。
解决这个问题的最简单方法是将匿名表达式与其他函数定义放在单独的模块中。只要每个被调用的函数都有一个原型,并且在调用之前将其添加到JIT中,JIT就会很高兴地跨模块边界解析函数调用。通过将匿名表达式放在不同的模块中,我们可以在不影响其他函数的情况下删除它。
事实上,我们将更进一步,把每个函数放在它自己的模块中。这样做允许我们利用KaleidoscopeJIT的一个有用属性,它将使我们的环境更像REPL:函数可以多次添加到JIT中(不像每个函数都必须有唯一定义的模块)。当你在KaleidoscopeJIT中查找一个符号时,它总是返回最近的定义:
ready> def foo(x) x + 1;
Read function definition:
define double @foo(double %x) {
entry:
%addtmp = fadd double %x, 1.000000e+00
ret double %addtmp
}
ready> foo(2);
Evaluated to 3.000000
ready> def foo(x) x + 2;
define double @foo(double %x) {
entry:
%addtmp = fadd double %x, 2.000000e+00
ret double %addtmp
}
ready> foo(2);
Evaluated to 4.000000
为了让每个函数驻留在它自己的模块中,我们需要一种方法,在我们打开的每个新模块中重新生成之前的函数声明。
首先明确一件事,每次我们在终端输入一个top-level表达式,比如一个函数定义:def foo(x) x+1;
,或者一个函数调用foo(2);
,都会经历lexer→parse→codegen的过程,之前提到过codegen的过程由IR Builder管理,codegen()
产生的代码与其当前所在TheMudole
相关联。
同时在上文JIT的过程中,我们通过std::move
将当前的变量TheModule
(比如叫module1)移动到JIT中,并重新创建新的TheModule
(比如叫module2,通过那个initialize...()
函数创建),用于之后的终端输入,JIT完成后,删除与其绑定的TheModule
,也就是module1,那么我们接下来在终端所输入的内容都会在module2中处理,也就无法调用之前在module1中定义生成的代码。
那么如何处理这个问题,请先大致看一下代码:
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
Function *getFunction(std::string Name) {
// First, see if the function has already been added to the current module.
// 首先查看当前Mudole中有无对应的Proto,如果有的话就直接返回,此处的auto类型应该是Function*
if (auto *F = TheModule->getFunction(Name))
return F;
// If not, check whether we can codegen the declaration from some existing prototype.
// 如果当前Module中没有,则在全局变量FunctionProtos中查找,如果找到了就说明之前定义过同名函数,直接进行codegen()
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
return FI->second->codegen();
// If no existing prototype exists, return null.
// 如果之前从未定义过对应的Proto,则返回nullptr
return nullptr;
}
...
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
// getFunction()函数实现了在当前Module和全局中同时查找
Function *CalleeF = getFunction(Callee);
...
Function *FunctionAST::codegen() {
// Transfer ownership of the prototype to the FunctionProtos map, but keep a reference to it for use below.
// 将FunctionAST::Proto的管理权交给FunctionProtos这个map,不过对其复制一个引用用来管理
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
return nullptr;
首先需要一个全局变量:FunctionProtos
,保存了所有的函数的最近一次的定义。其类型应该为std::map<string, unique_ptr<PrototypeAST>>
(在完整代码中可以看到)。这个全局变量的作用在于让函数定义脱离Module,独立管理。
并且添加辅助一个方法:getFunction()
。
用来替换函数调用,也就是CallExprAST::codegen()
方法中第一行的TheModule->getFunction()
。
从最顶层,也就是终端的输入来看,所以我们要处理的有两点,第一是函数定义,第二是函数调用,也就是说我们需要改动的代码在CallExprAST::codegen()
和FunctionAST::codegen()
中。
- 对于
CallExprAST::codegen()
改动在于脱离TheModule来查找函数定义 - 对于
FunctionAST::codegen()
的改动在于,函数定义时同时,添加到全局变量FunctionProtos
中去。 - 具体细节请参考代码。
接下来,还需要更新一下HandleDefinition
and HandleExtern
:
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
TheJIT->addModule(std::move(TheModule));
InitializeModuleAndPassManager();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
在HandleDefinition
中,我们添加了两行代码,将新定义的函数传递给JIT,并打开一个新模块。在HandleExtern
中,我们只需要添加一行代码来将原型添加到FunctionProtos
中。
做了这些更改后,让我们再次尝试我们的REPL(这次我删除了匿名函数的转储,你现在应该明白了:):
ready> def foo(x) x + 1;
ready> foo(2);
Evaluated to 3.000000
ready> def foo(x) x + 2;
ready> foo(2);
Evaluated to 4.000000
即使有了这段简单的代码,我们还是得到了一些令人惊讶的强大功能——来看看:
ready> extern sin(x);
Read extern:
declare double @sin(double)
ready> extern cos(x);
Read extern:
declare double @cos(double)
ready> sin(1.0);
Read top-level expression:
define double @2() {
entry:
ret double 0x3FEAED548F090CEE
}
Evaluated to 0.841471
ready> def foo(x) sin(x)*sin(x) + cos(x)*cos(x);
Read function definition:
define double @foo(double %x) {
entry:
%calltmp = call double @sin(double %x)
%multmp = fmul double %calltmp, %calltmp
%calltmp2 = call double @cos(double %x)
%multmp4 = fmul double %calltmp2, %calltmp2
%addtmp = fadd double %multmp, %multmp4
ret double %addtmp
}
ready> foo(4.0);
Read top-level expression:
define double @3() {
entry:
%calltmp = call double @foo(double 4.000000e+00)
ret double %calltmp
}
Evaluated to 1.000000
JIT是怎么知道sin和cos的?答案非常简单:KaleidoscopeJIT有十分简单的符号解析规则,可以用来查找在所有Module中都没有定义的符号:首先,它搜索的所有已经添加到JIT的模块,从最新的到最久的,并取最新的定义。如果在JIT中没有找到定义,它就会返回到Kaleidoscope进程本身上调用“dlsym("sin")
”。因为“sin
”是在JIT的地址空间中定义的,所以它只是简单地修补模块中的调用,以直接调用sin的libm版本。但在某些情况下,这甚至更进一步:因为sin和cos是标准数学函数的名称,当使用常量调用函数时(如上面的“sin(1.0)
”),常数文件夹将直接求出正确的结果。
在将来,我们将看到如何通过调整这个符号解析规则来启用各种有用的特性,从安全性(限制JIT代码可用的符号集)到基于符号名的动态代码生成,甚至是延迟编译。
符号解析规则的一个直接好处是,我们现在可以通过编写任意c++代码来实现操作来扩展语言。例如,如果我们添加:
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// putchard - putchar that takes a double and returns 0.
/// putchard - putchar接受一个double值并返回0。
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}
注意,对于Windows,我们需要实际导出函数,因为动态符号加载器将使用GetProcAddress来查找符号。
现在,我们可以通过使用类似于“extern putchard(x);
",它在控制台上打印小写的'x
'(120是'x'
的ASCII码)。类似的代码可以用于实现文件I/O、控制台输入和Kaleidoscope中的许多其他功能。
这就完成了Kaleidoscope教程的JIT和优化器章节。此时,我们可以编译一种非图灵完备的编程语言,并以用户驱动的方式对其进行优化和JIT编译。接下来,我们将研究如何用控制流结构扩展该语言,并在此过程中解决一些有趣的LLVM IR问题。
Full Code Listing
编译
# Compile
clang++ -g toy.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core orcjit native` -O3 -o toy
# Run
./toy
如果在Linux上编译,请确保还添加了“-rdynamic”选项。这确保了外部函数在运行时被正确解析。
完整代码
第一行的”../include/KaleidoscopeJIT.h”,请使用sudo find / -name "KaleidoscopeJIT.h"
查找绝对路径并替换
```cpp
#include "../include/KaleidoscopeJIT.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>
using namespace llvm;
using namespace llvm::orc;
//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//
// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
tok_eof = -1,
// commands
tok_def = -2,
tok_extern = -3,
// primary
tok_identifier = -4,
tok_number = -5
};
static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal; // Filled in if tok_number
/// gettok - Return the next token from standard input.
static int gettok() {
static int LastChar = ' ';
// Skip any whitespace.
while (isspace(LastChar))
LastChar = getchar();
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
IdentifierStr = LastChar;
while (isalnum((LastChar = getchar())))
IdentifierStr += LastChar;
if (IdentifierStr == "def")
return tok_def;
if (IdentifierStr == "extern")
return tok_extern;
return tok_identifier;
}
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
std::string NumStr;
do {
NumStr += LastChar;
LastChar = getchar();
} while (isdigit(LastChar) || LastChar == '.');
NumVal = strtod(NumStr.c_str(), nullptr);
return tok_number;
}
if (LastChar == '#') {
// Comment until end of line.
do
LastChar = getchar();
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
if (LastChar != EOF)
return gettok();
}
// Check for end of file. Don't eat the EOF.
if (LastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
int ThisChar = LastChar;
LastChar = getchar();
return ThisChar;
}
//===----------------------------------------------------------------------===//
// Abstract Syntax Tree (aka Parse Tree)
//===----------------------------------------------------------------------===//
namespace {
/// ExprAST - Base class for all expression nodes.
class ExprAST {
public:
virtual ~ExprAST() = default;
virtual Value *codegen() = 0;
};
/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double Val;
public:
NumberExprAST(double Val) : Val(Val) {}
Value *codegen() override;
};
/// VariableExprAST - Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
std::string Name;
public:
VariableExprAST(const std::string &Name) : Name(Name) {}
Value *codegen() override;
};
/// BinaryExprAST - Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
char Op;
std::unique_ptr<ExprAST> LHS, RHS;
public:
BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
std::unique_ptr<ExprAST> RHS)
: Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
Value *codegen() override;
};
/// CallExprAST - Expression class for function calls.
class CallExprAST : public ExprAST {
std::string Callee;
std::vector<std::unique_ptr<ExprAST>> Args;
public:
CallExprAST(const std::string &Callee,
std::vector<std::unique_ptr<ExprAST>> Args)
: Callee(Callee), Args(std::move(Args)) {}
Value *codegen() override;
};
/// PrototypeAST - This class represents the "prototype" for a function,
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
class PrototypeAST {
std::string Name;
std::vector<std::string> Args;
public:
PrototypeAST(const std::string &Name, std::vector<std::string> Args)
: Name(Name), Args(std::move(Args)) {}
Function *codegen();
const std::string &getName() const { return Name; }
};
/// FunctionAST - This class represents a function definition itself.
class FunctionAST {
std::unique_ptr<PrototypeAST> Proto;
std::unique_ptr<ExprAST> Body;
public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
std::unique_ptr<ExprAST> Body)
: Proto(std::move(Proto)), Body(std::move(Body)) {}
Function *codegen();
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
/// token the parser is looking at. getNextToken reads another token from the
/// lexer and updates CurTok with its results.
static int CurTok;
static int getNextToken() { return CurTok = gettok(); }
/// BinopPrecedence - This holds the precedence for each binary operator that is
/// defined.
static std::map<char, int> BinopPrecedence;
/// GetTokPrecedence - Get the precedence of the pending binary operator token.
static int GetTokPrecedence() {
if (!isascii(CurTok))
return -1;
// Make sure it's a declared binop.
int TokPrec = BinopPrecedence[CurTok];
if (TokPrec <= 0)
return -1;
return TokPrec;
}
/// LogError* - These are little helper functions for error handling.
std::unique_ptr<ExprAST> LogError(const char *Str) {
fprintf(stderr, "Error: %s\n", Str);
return nullptr;
}
std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
LogError(Str);
return nullptr;
}
static std::unique_ptr<ExprAST> ParseExpression();
/// numberexpr ::= number
static std::unique_ptr<ExprAST> ParseNumberExpr() {
auto Result = std::make_unique<NumberExprAST>(NumVal);
getNextToken(); // consume the number
return std::move(Result);
}
/// parenexpr ::= '(' expression ')'
static std::unique_ptr<ExprAST> ParseParenExpr() {
getNextToken(); // eat (.
auto V = ParseExpression();
if (!V)
return nullptr;
if (CurTok != ')')
return LogError("expected ')'");
getNextToken(); // eat ).
return V;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression* ')'
static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
std::string IdName = IdentifierStr;
getNextToken(); // eat identifier.
if (CurTok != '(') // Simple variable ref.
return std::make_unique<VariableExprAST>(IdName);
// Call.
getNextToken(); // eat (
std::vector<std::unique_ptr<ExprAST>> Args;
if (CurTok != ')') {
while (true) {
if (auto Arg = ParseExpression())
Args.push_back(std::move(Arg));
else
return nullptr;
if (CurTok == ')')
break;
if (CurTok != ',')
return LogError("Expected ')' or ',' in argument list");
getNextToken();
}
}
// Eat the ')'.
getNextToken();
return std::make_unique<CallExprAST>(IdName, std::move(Args));
}
/// primary
/// ::= identifierexpr
/// ::= numberexpr
/// ::= parenexpr
static std::unique_ptr<ExprAST> ParsePrimary() {
switch (CurTok) {
default:
return LogError("unknown token when expecting an expression");
case tok_identifier:
return ParseIdentifierExpr();
case tok_number:
return ParseNumberExpr();
case '(':
return ParseParenExpr();
}
}
/// binoprhs
/// ::= ('+' primary)*
static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
std::unique_ptr<ExprAST> LHS) {
// If this is a binop, find its precedence.
while (true) {
int TokPrec = GetTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
if (TokPrec < ExprPrec)
return LHS;
// Okay, we know this is a binop.
int BinOp = CurTok;
getNextToken(); // eat binop
// Parse the primary expression after the binary operator.
auto RHS = ParsePrimary();
if (!RHS)
return nullptr;
// If BinOp binds less tightly with RHS than the operator after RHS, let
// the pending operator take RHS as its LHS.
int NextPrec = GetTokPrecedence();
if (TokPrec < NextPrec) {
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
if (!RHS)
return nullptr;
}
// Merge LHS/RHS.
LHS =
std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
}
}
/// expression
/// ::= primary binoprhs
///
static std::unique_ptr<ExprAST> ParseExpression() {
auto LHS = ParsePrimary();
if (!LHS)
return nullptr;
return ParseBinOpRHS(0, std::move(LHS));
}
/// prototype
/// ::= id '(' id* ')'
static std::unique_ptr<PrototypeAST> ParsePrototype() {
if (CurTok != tok_identifier)
return LogErrorP("Expected function name in prototype");
std::string FnName = IdentifierStr;
getNextToken();
if (CurTok != '(')
return LogErrorP("Expected '(' in prototype");
std::vector<std::string> ArgNames;
while (getNextToken() == tok_identifier)
ArgNames.push_back(IdentifierStr);
if (CurTok != ')')
return LogErrorP("Expected ')' in prototype");
// success.
getNextToken(); // eat ')'.
return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
}
/// definition ::= 'def' prototype expression
static std::unique_ptr<FunctionAST> ParseDefinition() {
getNextToken(); // eat def.
auto Proto = ParsePrototype();
if (!Proto)
return nullptr;
if (auto E = ParseExpression())
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
return nullptr;
}
/// toplevelexpr ::= expression
static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
if (auto E = ParseExpression()) {
// Make an anonymous proto.
auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
std::vector<std::string>());
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
}
return nullptr;
}
/// external ::= 'extern' prototype
static std::unique_ptr<PrototypeAST> ParseExtern() {
getNextToken(); // eat extern.
return ParsePrototype();
}
//===----------------------------------------------------------------------===//
// Code Generation
//===----------------------------------------------------------------------===//
static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<Module> TheModule;
static std::unique_ptr<IRBuilder<>> Builder;
static std::map<std::string, Value *> NamedValues;
static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
static ExitOnError ExitOnErr;
Value *LogErrorV(const char *Str) {
LogError(Str);
return nullptr;
}
Function *getFunction(std::string Name) {
// First, see if the function has already been added to the current module.
if (auto *F = TheModule->getFunction(Name))
return F;
// If not, check whether we can codegen the declaration from some existing
// prototype.
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
return FI->second->codegen();
// If no existing prototype exists, return null.
return nullptr;
}
Value *NumberExprAST::codegen() {
return ConstantFP::get(*TheContext, APFloat(Val));
}
Value *VariableExprAST::codegen() {
// Look this variable up in the function.
Value *V = NamedValues[Name];
if (!V)
return LogErrorV("Unknown variable name");
return V;
}
Value *BinaryExprAST::codegen() {
Value *L = LHS->codegen();
Value *R = RHS->codegen();
if (!L || !R)
return nullptr;
switch (Op) {
case '+':
return Builder->CreateFAdd(L, R, "addtmp");
case '-':
return Builder->CreateFSub(L, R, "subtmp");
case '*':
return Builder->CreateFMul(L, R, "multmp");
case '<':
L = Builder->CreateFCmpULT(L, R, "cmptmp");
// Convert bool 0/1 to double 0.0 or 1.0
return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
default:
return LogErrorV("invalid binary operator");
}
}
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = getFunction(Callee);
if (!CalleeF)
return LogErrorV("Unknown function referenced");
// If argument mismatch error.
if (CalleeF->arg_size() != Args.size())
return LogErrorV("Incorrect # arguments passed");
std::vector<Value *> ArgsV;
for (unsigned i = 0, e = Args.size(); i != e; ++i) {
ArgsV.push_back(Args[i]->codegen());
if (!ArgsV.back())
return nullptr;
}
return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}
Function *PrototypeAST::codegen() {
// Make the function type: double(double,double) etc.
std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
FunctionType *FT =
FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
Function *F =
Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
// Set names for all arguments.
unsigned Idx = 0;
for (auto &Arg : F->args())
Arg.setName(Args[Idx++]);
return F;
}
Function *FunctionAST::codegen() {
// Transfer ownership of the prototype to the FunctionProtos map, but keep a
// reference to it for use below.
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
return nullptr;
// Create a new basic block to start insertion into.
BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
Builder->SetInsertPoint(BB);
// Record the function arguments in the NamedValues map.
NamedValues.clear();
for (auto &Arg : TheFunction->args())
NamedValues[std::string(Arg.getName())] = &Arg;
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder->CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
// Run the optimizer on the function.
TheFPM->run(*TheFunction);
return TheFunction;
}
// Error reading body, remove function.
TheFunction->eraseFromParent();
return nullptr;
}
//===----------------------------------------------------------------------===//
// Top-Level parsing and JIT Driver
//===----------------------------------------------------------------------===//
static void InitializeModuleAndPassManager() {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("my cool jit", *TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create a new pass manager attached to it.
TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->add(createInstructionCombiningPass());
// Reassociate expressions.
TheFPM->add(createReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->add(createGVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->add(createCFGSimplificationPass());
TheFPM->doInitialization();
}
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
ExitOnErr(TheJIT->addModule(
ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
InitializeModuleAndPassManager();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
if (auto FnAST = ParseTopLevelExpr()) {
if (FnAST->codegen()) {
// Create a ResourceTracker to track JIT'd memory allocated to our
// anonymous expression -- that way we can free it after executing.
auto RT = TheJIT->getMainJITDylib().createResourceTracker();
auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
InitializeModuleAndPassManager();
// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
fprintf(stderr, "Evaluated to %f\n", FP());
// Delete the anonymous expression module from the JIT.
ExitOnErr(RT->remove());
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
/// top ::= definition | external | expression | ';'
static void MainLoop() {
while (true) {
fprintf(stderr, "ready> ");
switch (CurTok) {
case tok_eof:
return;
case ';': // ignore top-level semicolons.
getNextToken();
break;
case tok_def:
HandleDefinition();
break;
case tok_extern:
HandleExtern();
break;
default:
HandleTopLevelExpression();
break;
}
}
}
//===----------------------------------------------------------------------===//
// "Library" functions that can be "extern'd" from user code.
//===----------------------------------------------------------------------===//
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}
/// printd - printf that takes a double prints it as "%f\n", returning 0.
extern "C" DLLEXPORT double printd(double X) {
fprintf(stderr, "%f\n", X);
return 0;
}
//===----------------------------------------------------------------------===//
// Main driver code.
//===----------------------------------------------------------------------===//
int main() {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
InitializeModuleAndPassManager();
// Run the main "interpreter loop" now.
MainLoop();
return 0;
}
```