# PatternMatcher-Pytorch: fusing custom add + relu ops with Inductor's pattern matcher.

import os
import torch
import torch.nn as nn
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.compile_fx import compile_fx


# Redirect TorchInductor's compiled-artifact cache into a workspace directory.
# The base directory may be overridden via the PM_CACHE_DIR environment
# variable; it falls back to the original hard-coded path so existing
# setups keep working.
cache_dir = os.environ.get("PM_CACHE_DIR", "/home/xytpai/workspace/work/temp")
envs = {
    "TORCHINDUCTOR_CACHE_DIR": os.path.join(cache_dir, "inductor"),
}
# Export every entry before torch._inductor reads its configuration.
os.environ.update(envs)


@torch.library.custom_op("myops::add", mutates_args=["result"])
def myops_add(result: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> None:
    torch.add(x, y, out=result)


@torch.library.custom_op("myops::relu", mutates_args=["result"])
def myops_relu(result: torch.Tensor, x: torch.Tensor) -> None:
    result.copy_(x)
    torch.relu_(result)


@torch.library.custom_op("myops::add_relu", mutates_args=["result"])
def myops_add_relu(result: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> None:
    z = x + y
    result.copy_(z)
    torch.relu_(result)


def pattern(result: torch.Tensor, result_add: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
    """Graph shape to search for: myops.add whose output feeds myops.relu.

    Both ops are wrapped in `auto_functionalized`, matching how the
    post-grad FX graph represents mutating custom ops; element [1] of
    the returned tuple is the functionalized `result` tensor.
    """
    added = auto_functionalized(
        torch.ops.myops.add.default, result=result_add, x=x, y=y
    )
    activated = auto_functionalized(
        torch.ops.myops.relu.default, result=result, x=added[1]
    )
    return activated[1]


def replacement(result: torch.Tensor, result_add: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
    """Replacement graph: a single fused myops.add_relu call.

    `result_add` is unused here but must stay in the signature so it
    mirrors `pattern`'s parameters for the matcher.
    """
    fused = auto_functionalized(
        torch.ops.myops.add_relu.default, result=result, x=x, y=y
    )
    return fused[1]


# Example tensors used to trace `pattern` and `replacement`:
# (result, result_add, x, y) — four placeholders of identical shape/dtype.
inputs = [torch.empty(5, 4, dtype=torch.float) for _ in range(4)]


# Module-level pattern-matcher pass; `custom_pass` below applies it to the
# post-grad FX graph.
pm_pass = pm.PatternMatcherPass(pass_name="fusion_pass")
# Trace `pattern` and `replacement` with the example `inputs` (forward-only
# tracing) and register the rewrite rule into `pm_pass`.
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)


def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
    """Post-grad graph pass: apply the registered fusion rewrite in place.

    Prints the graph before and after, plus the number of matches, for
    debugging; dead nodes left behind by the rewrite are then removed.
    """
    print(graph)
    match_count = pm_pass.apply(graph)
    print(match_count)
    print(graph)
    graph.eliminate_dead_code()
    return graph


def custom_backend(graph: torch.fx.GraphModule, example_inputs):
    """torch.compile backend: Inductor with `custom_pass` hooked in post-grad."""
    from torch._inductor import config

    cfg = config.get_config_copy()
    cfg["post_grad_custom_post_pass"] = custom_pass
    return compile_fx(graph, example_inputs, config_patches=cfg)


# def fw_add(x, y):
#     out = torch.empty_like(x)
#     torch.ops.myops.add(out, x, y)
#     return out


# def fw_relu(x):
#     out = torch.empty_like(x)
#     torch.ops.myops.relu(out, x)
#     return out


# NOTE(review): torch.compile is applied to the class itself rather than to an
# instance — confirm the installed PyTorch version supports class-level
# decoration with a custom backend.
@torch.compile(backend=custom_backend)
class SimpleModel(nn.Module):
    # The v1 auto-functionalization format is forced so the traced graph
    # matches the `auto_functionalized` shape used in `pattern`/`replacement`.
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    # def forward(self, x, y):
    #     x = fw_add(x, y)
    #     x = fw_relu(x)
    #     return x
    def forward(self, x, y):
        """Run add then relu via the mutating custom ops.

        Pre-allocates both output buffers so the ops' `result` arguments
        can be mutated in place; the fusion pass should rewrite this
        add→relu chain into a single myops.add_relu call.
        """
        out = torch.empty_like(x)
        out2 = torch.empty_like(x)
        torch.ops.myops.add(out, x, y)
        torch.ops.myops.relu(out2, out)
        return out2


# Build the model and run it once; the first call triggers torch.compile
# tracing and exercises the custom fusion backend.
model = SimpleModel()
x, y = torch.rand(10, 10), torch.rand(10, 10)
z = model(x, y)
# Originally posted 2025-09-22 by xytpai.