import os
import torch
import torch.nn as nn
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.compile_fx import compile_fx
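
# Route TorchInductor's compile cache to a local scratch directory so
# repeated runs can reuse compiled artifacts.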
cache_dir = "/home/xytpai/workspace/work/temp"
envs = {
"TORCHINDUCTOR_CACHE_DIR": os.path.join(cache_dir, "inductor"),
}
for k, v in envs.items():
os.environ[k] = v
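
# Custom ops that write into a caller-provided `result` tensor. Each op
# returns None and only mutates `result`, so no separate fake/meta kernel
# should be needed for torch.compile.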
@torch.library.custom_op("myops::add", mutates_args=["result"])
def myops_add(result: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> None:
torch.add(x, y, out=result)
@torch.library.custom_op("myops::relu", mutates_args=["result"])
def myops_relu(result: torch.Tensor, x: torch.Tensor) -> None:
result.copy_(x)
torch.relu_(result)
@torch.library.custom_op("myops::add_relu", mutates_args=["result"])
def myops_add_relu(result: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> None:
z = x + y
result.copy_(z)
torch.relu_(result)
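
# Search pattern: myops::add feeding myops::relu. Under functionalization each
# mutating custom op appears in the FX graph as an auto_functionalized
# higher-order op; since the op itself returns None, index [1] of the output
# tuple is the functionalized value of the mutated `result` argument.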
def pattern(result: torch.Tensor, result_add: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
at1 = auto_functionalized(
torch.ops.myops.add.default,
result=result_add,
x=x,
y=y)
at2 = auto_functionalized(
torch.ops.myops.relu.default,
result=result,
x=at1[1])
return at2[1]
def replacement(result: torch.Tensor, result_add: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
at = auto_functionalized(
torch.ops.myops.add_relu.default,
result=result,
x=x,
y=y)
return at[1]
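
# Example inputs used only to trace the pattern and replacement functions;
# the shapes are arbitrary placeholders and do not restrict what the pattern
# matches at runtime.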
inputs = [
torch.empty(5, 4, dtype=torch.float), # result
torch.empty(5, 4, dtype=torch.float), # result_add
torch.empty(5, 4, dtype=torch.float), # x
torch.empty(5, 4, dtype=torch.float), # y
]
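
# Trace both functions and register the rewrite on a dedicated pass: every
# matched add -> relu chain gets replaced with the fused add_relu op.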
pm_pass = pm.PatternMatcherPass(pass_name="fusion_pass")
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)
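
# Post-grad graph pass: apply the pattern matcher, then drop any nodes the
# rewrite left dead.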
def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
    print("graph before fusion:", graph, sep="\n")
    count = pm_pass.apply(graph)
    print(f"replaced {count} pattern(s)")
    print("graph after fusion:", graph, sep="\n")
    graph.eliminate_dead_code()
    return graph
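
# Compile with Inductor, installing the fusion pass through the
# `post_grad_custom_post_pass` config hook.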
def custom_backend(graph: torch.fx.GraphModule, example_inputs):
    from torch._inductor import config
    current_config = config.get_config_copy()
    current_config["post_grad_custom_post_pass"] = custom_pass
    # The pattern above is written against the v1 auto_functionalized HOP,
    # so keep the v2 rewrite disabled while this graph is compiled.
    current_config["enable_auto_functionalized_v2"] = False
    return compile_fx(graph, example_inputs, config_patches=current_config)
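
# A toy model whose forward chains the two custom ops; compiling it with the
# custom backend should trigger the fusion.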
class SimpleModel(nn.Module):
def forward(self, x, y):
out = torch.empty_like(x)
out2 = torch.empty_like(x)
torch.ops.myops.add(out, x, y)
torch.ops.myops.relu(out2, out)
return out2
model = torch.compile(SimpleModel(), backend=custom_backend)
x = torch.rand(10, 10)
y = torch.rand(10, 10)
z = model(x, y)
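
# Sanity check: the compiled result should match eager add + relu.
torch.testing.assert_close(z, torch.relu(x + y))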