pytorch jit script的学习

pytorch jit的学习

TorchScript:

TorchScript是一个静态类型的Python子集,可以直接编写(使用@torch.jit。
脚本装饰器)或通过跟踪从Python代码自动生成。
在使用跟踪时,通过只记录张量上的实际操作符,并简单地执行和丢弃周围的其他Python代码,代码会自动转换为Python的这个子集。
当使用@torch.jit直接编写TorchScript时。
脚本装饰器,程序员必须只使用TorchScript中支持的Python子集。
本节记录了TorchScript支持的内容,就好像它是一种独立语言的语言参考。
本参考中未提及的Python特性都不是TorchScript的一部分。
有关可用Pytorch张量方法、模块和函数的完整参考,请参阅内置函数。
作为Python的子集,任何有效的TorchScript函数也是一个有效的Python函数。
这使得禁用TorchScript和使用标准Python工具(如pdb)调试该函数成为可能。
反之则不然:有许多有效的Python程序不是有效的TorchScript程序。
相反,TorchScript专门关注Python的一些特性,这些特性需要在PyTorch中表示神经网络模型.
以上节选自pytorch官网介绍

简而言之:pytorch script 以一种特定语言描述从python导出模型,并可在任意非python环境中导入使用

简单案例

import torch
import torchvision

class MyScriptModule(torch.nn.Module):
    def __init__(self):
        super(MyScriptModule, self).__init__()
        self.means = torch.nn.Parameter(torch.tensor([103.939, 116.779, 123.68])
                                        .resize_(1, 3, 1, 1))
        self.resnet = torch.jit.trace(torchvision.models.resnet18(),
                                      torch.rand(1, 3, 224, 224))

    def forward(self, input):
        return self.resnet(input - self.means)

my_script_module = torch.jit.script(MyScriptModule())

利用装饰器:

import torch
# 跟踪函数
@torch.jit.script
def foo(x, tup):
    # type: (int, Tuple[Tensor, Tensor]) -> Tensor
    t0, t1 = tup
    return t0 + t1 + x

print(foo(3, (torch.rand(3), torch.rand(3))))

常用操作可见官网:
TorchScript

script, trace差异

import torch
def foo(x, y):
   	return 2*x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3),torch.rand(3)))

trace仅记录张量上的操作,因此它不会记录任何控制流操作,如if语句或循环。
当你的模型涉及复杂控制流操作,得用script

@torch.jit.script
def foo(len):
  # type: (int) -> torch.Tensor
  rv = torch.zeros(3, 4)
  for i in range(len):
    if i < 10:
        rv = rv - 1.0
    else:
        rv = rv + 1.0
  return rv

torch.script 保存读取

Script中的核心数据结构是ScriptModule。 它是Torch的nn.Module的类似物,代表整个模型作为子模块树。 与普通模块一样,ScriptModule中的每个单独模块都可以包含子模块,参数和方法
对于sciptmodule的保存跟普通moudle类似

cpu_model = gpu_model.cpu()
sample_input_cpu = sample_input_gpu.cpu()
traced_cpu = torch.jit.trace(cpu_model, sample_input_cpu)
# 利用torch.jit.save保存模型
torch.jit.save(traced_cpu, "cpu.pt")

traced_gpu = torch.jit.trace(gpu_model, sample_input_gpu)
torch.jit.save(traced_gpu, "gpu.pt")

# ... later, when using the model:

if use_gpu:
  # 对应利用jit.load读取模型
  model = torch.jit.load("gpu.pt")
else:
  model = torch.jit.load("cpu.pt")

model(input)

Script Module的解释图表

TorchScript使用静态单一赋值(SSA)中间表示(IR)来表示计算。 这种格式的指令包括ATen(PyTorch的C ++后端)运算符和其他原始运算符,包括循环和条件的控制流运算符。 举个例子:

code属性:

graph属性:

打印子层

idx = 0
for name, cr in m.named_children():
    print(f"{idx} layer: {name}")
    print(cr)
    idx+=1

额外小知识

利用torch.jit.save保存的.pth文件可以通过压缩软件打开,可以直接看到里面的code

pytorch jit中的一些优化

torch._C._jit_set_profiling_mode()

torch.jit.optimized_execution()

参考:

https://blog.csdn.net/xxradon/article/details/86504906

posted @ 2022-04-04 15:17  HiIcy  阅读(1859)  评论(0)    收藏  举报