Hidet uses rule-based scheduling.

Defining the computation

The overall flow is similar to TVM's compute description.

First define the input tensors, specifying a name, data type, and shape:

from hidet.ir.compute import tensor_input

# a 1-D float32 tensor with 10 elements
a = tensor_input('a', dtype='float32', shape=[10])
# an empty shape declares a scalar
b = tensor_input('b', dtype='float32', shape=[])
# a typical NCHW image tensor
data = tensor_input('data', dtype='float16', shape=[1, 3, 224, 224])

Then use compute to define a computation, specifying the name, shape, and the compute expression:

b = compute('copy', shape=[10], fcompute=lambda i: a[i])

which is semantically equivalent to:

for i in range(10):
    b[i] = a[i]

compute also accepts additional parameters, which together with reduce can express reductions and similar patterns, as the sketch below shows.
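
A minimal sketch of a reduction that sums each row of a matrix (the tensor names here are illustrative):

from hidet.ir.compute import tensor_input, compute, reduce

a = tensor_input('a', dtype='float32', shape=[10, 10])
# row_sum[i] = sum over j of a[i, j]
row_sum = compute(
    'row_sum',
    shape=[10],
    fcompute=lambda i: reduce(shape=[10], fcompute=lambda j: a[i, j], reduce_type='sum')
)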

Wrapping into a Task

Wrap the computation into a Task; Hidet's rule-based scheduler then generates code for it automatically.

A task consists of a name plus tensor inputs and outputs, which correspond to the inputs and outputs of the computation, as sketched below.
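
A minimal sketch (reusing tensor_input and compute from above; the add-one computation is purely illustrative):

from hidet.ir.compute import tensor_input, compute
from hidet.ir.task import Task

# an input node and an element-wise computation over it
a = tensor_input('a', dtype='float32', shape=[5])
b = compute('b', shape=[5], fcompute=lambda i: a[i] + 1.0)

# the task's inputs/outputs are the computation's input/output nodes
add_one_task = Task(name='add_one', inputs=[a], outputs=[b])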

The helper below builds a given task into a compiled function and runs it on concrete inputs:

from typing import List
import hidet
from hidet.ir.task import Task


def run_task(task: Task, inputs: List[hidet.Tensor]):
    """Run given task and print inputs and outputs"""
    from hidet.runtime import CompiledTask

    # build the task
    func: CompiledTask = hidet.drivers.build_task(task, target='cpu')

    # run the compiled task
    outputs = func.run_async(inputs)

    print('Task:', task.name)
    print('Inputs:')
    for tensor in inputs:
        print(tensor)
    print('Output:')
    for tensor in outputs:
        print(tensor)
    print()
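
For instance (a sketch, assuming the add_one_task built earlier), the helper can be called with a random input:

run_task(add_one_task, [hidet.randn([5])])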

build_task lowers the task to an executable function; the main steps are:

  1. Dispatch the task to a scheduler based on the target device and other information
  2. The scheduler lowers the task to an IRModule
  3. Optimize the IRModule and continue lowering it
  4. Generate code for the target device

As an example, the following Task computes the element-wise absolute value of its input:

from hidet.ir.compute import TensorNode, compute
from hidet.ir.task import Task
from hidet.ir.expr import if_then_else

class AbsTask(Task):
    def __init__(self, input: TensorNode):
        out = compute(
            name='out',
            shape=input.shape,
            fcompute=lambda *indices: if_then_else(input[indices] < 0, -input[indices], input[indices])
        )

        super().__init__(
            name='abs',
            inputs=[input],
            outputs=[out]
        )
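
As a quick sanity check (a sketch, reusing tensor_input and the run_task helper from above):

x = tensor_input('x', dtype='float32', shape=[6])
run_task(AbsTask(x), [hidet.randn([6])])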

Next, wrap the task in an operator class whose inputs are hidet Tensors:


from hidet.graph import Operator, Tensor
from hidet.graph.ops.utils import input_like

class AbsOp(Operator):
    def __init__(self, input: Tensor):
        super().__init__(
            inputs=[input],
            attributes={},
            task=AbsTask(
                input_like(input, 'input')
            )
        )

def abs(input: Tensor) -> Tensor:
    return AbsOp(input).outputs[0]

The results are available in the operator's outputs. By registering the function for torch.abs and specifying hidet as the backend in torch.compile, the operator is substituted at runtime. The complete demo script:

import hidet

from hidet.ir.compute import TensorNode, compute
from hidet.ir.task import Task
from hidet.ir.expr import if_then_else

hidet.option.cache_dir('./outs/cache')

op_device = "cuda"

class AbsTask(Task):
    def __init__(self, input: TensorNode):
        out = compute(
            name='out',
            shape=input.shape,
            fcompute=lambda *indices: if_then_else(input[indices] < 0, -input[indices], input[indices])
        )

        super().__init__(
            name='abs',
            inputs=[input],
            outputs=[out]
        )

from hidet.graph import Operator, Tensor
from hidet.graph.ops.utils import input_like

class AbsOp(Operator):
    def __init__(self, input: Tensor):
        super().__init__(
            inputs=[input],
            attributes={},
            task=AbsTask(
                input_like(input, 'input')
            )
        )

import torch
from torch import nn
from hidet.graph.frontend.torch.interpreter import (
    register_function,
)

@register_function(torch.abs)
def abs_demo(input: Tensor) -> Tensor:
    return AbsOp(input).outputs[0]

class Model(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = torch.abs(x)
        return x

def run_demo():
    torch_ref = torch.randn([2, 3, 3], device=op_device)
    print("input: ", torch_ref)
    m = Model()
    m = torch.compile(m, backend='hidet')
    print("output: ", m(torch_ref))
    
run_demo()

Running the demo prints:
input:  tensor([[[ 0.3670,  0.3044, -0.8114],
         [-0.3331,  0.3331, -0.1969],
         [-1.4178,  0.3059,  0.4833]],

        [[-0.6634,  2.7176, -0.4525],
         [ 0.9879, -0.0581,  0.9540],
         [ 0.1877, -0.2522,  0.0652]]], device='cuda:0')
output:  tensor([[[0.3670, 0.3044, 0.8114],
         [0.3331, 0.3331, 0.1969],
         [1.4178, 0.3059, 0.4833]],

        [[0.6634, 2.7176, 0.4525],
         [0.9879, 0.0581, 0.9540],
         [0.1877, 0.2522, 0.0652]]], device='cuda:0')