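# extract_triton_kernels.py: print name, self GPU time (ms) and call count for
# every 'triton_' and 'aten::' row of the text report passed as argv[1].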

import sys

# The report to parse is the first command-line argument; fail with a usage
# message instead of an IndexError traceback when it is missing.
if len(sys.argv) < 2:
    sys.exit('usage: extract_triton_kernels.py REPORT_FILE')
filename = sys.argv[1]
with open(filename, 'r') as f:
    lines = f.readlines()

def extract_info(line):
    """Parse one report row into ``[name, self_gpu_time_ms, num_of_calls]``.

    The row is split on whitespace. Field 0 is the op/kernel name, field 6 is
    the self GPU time with a unit suffix, and field 10 is the call count.
    NOTE(review): these column positions presumably match the profiler table
    layout this script was written for, and assume the name has no spaces.
    Confirm against the report format.

    The time is normalized to milliseconds. Accepted suffixes are ``s``,
    ``ms``, ``us`` and ``ns``.

    Raises:
        IndexError: if the row has fewer than 11 whitespace-separated fields.
        ValueError: if the time has an unrecognized unit suffix, or the time or
            call count is not numeric.
    """
    fields = line.split()  # split() already drops surrounding whitespace
    name = fields[0]
    raw_time = fields[6]
    num_of_calls = int(fields[10])
    # Two-letter suffixes must be tested before the bare 's' suffix. Otherwise
    # 'ms'/'us'/'ns' would be misread as seconds ('500ns' -> float('500n')).
    if raw_time.endswith('ms'):
        self_gpu_time = float(raw_time[:-2])
    elif raw_time.endswith('us'):
        self_gpu_time = float(raw_time[:-2]) / 1000.0
    elif raw_time.endswith('ns'):
        self_gpu_time = float(raw_time[:-2]) / 1000000.0
    elif raw_time.endswith('s'):
        self_gpu_time = float(raw_time[:-1]) * 1000.0
    else:
        # Previously the raw string leaked through into the numeric column.
        raise ValueError(f'unrecognized time unit in {raw_time!r}')
    return [name, self_gpu_time, num_of_calls]

total_timems = 0
triton_kernel_info = []
aten_kernel_info = []

# Sort each row into one of two tables. The 'XPU Triton kernel:' label is
# removed before parsing so the row splits like the others. A row containing
# both markers counts as a Triton row.
for raw_row in lines:
    row = raw_row.strip().replace('XPU Triton kernel:', '')
    if 'triton_' in row:
        triton_kernel_info.append(extract_info(row))
    elif 'aten::' in row:
        aten_kernel_info.append(extract_info(row))

# Emit both tables as CSV-like text. The second header is preceded by a blank
# line to separate the two sections.
for header_prefix, table in (('', triton_kernel_info), ('\n', aten_kernel_info)):
    print(header_prefix + 'opname, self_gpu_timems, calls')
    for op_name, gpu_ms, calls in table:
        print(f'{op_name}, {gpu_ms}, {calls}')
# Posted 2023-12-15 10:05 by xytpai (views: 3, comments: 0)