pytorch jit script的学习
pytorch jit的学习
TorchScript:
TorchScript是一个静态类型的Python子集,可以直接编写(使用@torch.jit。
脚本装饰器)或通过跟踪从Python代码自动生成。
在使用跟踪时,通过只记录张量上的实际操作符,并简单地执行和丢弃周围的其他Python代码,代码会自动转换为Python的这个子集。
当使用@torch.jit直接编写TorchScript时。
脚本装饰器,程序员必须只使用TorchScript中支持的Python子集。
本节记录了TorchScript支持的内容,就好像它是一种独立语言的语言参考。
本参考中未提及的Python特性都不是TorchScript的一部分。
有关可用Pytorch张量方法、模块和函数的完整参考,请参阅内置函数。
作为Python的子集,任何有效的TorchScript函数也是一个有效的Python函数。
这使得禁用TorchScript和使用标准Python工具(如pdb)调试该函数成为可能。
反之则不然:有许多有效的Python程序不是有效的TorchScript程序。
相反,TorchScript专门关注Python的一些特性,这些特性需要在PyTorch中表示神经网络模型.
以上节选自pytorch官网介绍
简而言之:pytorch script 以一种特定语言描述从python导出模型,并可在任意非python环境中导入使用
简单案例
import torch
import torchvision
class MyScriptModule(torch.nn.Module):
def __init__(self):
super(MyScriptModule, self).__init__()
self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
.resize_(1, 3, 1, 1))
self.resnet = torch.jit.trace(torchvision.models.resnet18(),
torch.rand(1, 3, 224, 224))
def forward(self, input):
return self.resnet(input - self.means)
my_script_module = torch.jit.script(MyScriptModule())
利用装饰器:
import torch
# 跟踪函数
@torch.jit.script
def foo(x, tup):
# type: (int, Tuple[Tensor, Tensor]) -> Tensor
t0, t1 = tup
return t0 + t1 + x
print(foo(3, (torch.rand(3), torch.rand(3))))
常用操作可见官网:
TorchScript
script, trace差异
import torch
def foo(x, y):
return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3),torch.rand(3)))
trace仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。
当你的模型涉及复杂控制流操作,得用script
@torch.jit.script
def foo(len):
# type: (int) -> torch.Tensor
rv = torch.zeros(3, 4)
for i in range(len):
if i < 10:
rv = rv - 1.0
else:
rv = rv + 1.0
return rv
torch.script 保存读取
Script中的核心数据结构是ScriptModule。 它是Torch的nn.Module的类似物,代表整个模型作为子模块树。 与普通模块一样,ScriptModule中的每个单独模块都可以包含子模块,参数和方法
对于sciptmodule的保存跟普通moudle类似
cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
# 利用torch.jit.save保存模型
torch.jit.save(traced_cpu, "cpu.pt")
traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")
# ... later, when using the model:
if use_gpu:
# 对应利用jit.load读取模型
model = torch.jit.load("gpu.pt")
else:
model = torch.jit.load("cpu.pt")
model(input)
Script Module的解释图表
TorchScript使用静态单一赋值(SSA)中间表示(IR)来表示计算。 这种格式的指令包括ATen(PyTorch的C ++后端)运算符和其他原始运算符,包括循环和条件的控制流运算符。 举个例子:

code属性:

graph属性:

打印子层
idx = 0
for name, cr in m.named_children():
print(f"{idx} layer: {name}")
print(cr)
idx+=1
额外小知识
利用torch.jit.save保存的.pth文件可以通过压缩软件打开,可以直接看到里面的code

pytorch jit中的一些优化
torch._C._jit_set_profiling_mode()
torch.jit.optimized_execution()
参考:

浙公网安备 33010602011771号