import torch
import triton
import numpy as np
import os
import time

# Element type under test; change this to benchmark a different dtype.
data_type = torch.float16

# GEMM problem size: A is (M, K), B is (K, N), so the product is (M, N).
M, N, K = 15360, 18176, 8192

# Random operands allocated on the CUDA device. A is drawn before B so the
# random number generator is consumed in the same order on every run.
A = torch.randn(M, K, dtype=data_type, device="cuda")
B = torch.randn(K, N, dtype=data_type, device="cuda")

def compute_kernel():
    """Multiply the module-level operands ``A`` and ``B`` once.

    This is the zero-argument callable handed to the benchmark harness.
    The product is not returned; only the cost of computing it matters.
    """
    _ = A.matmul(B)

# Benchmark with triton.testing.do_bench.
#
# NOTE: do_bench's `warmup` and `rep` arguments are time budgets in
# milliseconds, not iteration counts. The harness divides each budget by its
# own estimate of a single call's runtime to decide how many calls to make.
# A GEMM of this size takes several milliseconds or more per call, so the
# previous budgets (5 ms / 10 ms) collapsed to roughly one warmup call and one
# timed call, and the reported "average" was effectively a single sample.
# The budgets below are large enough to average over multiple timed calls.
WARMUP_MS = 200
REP_MS = 1000
avg_time_ms = triton.testing.do_bench(
    compute_kernel, warmup=WARMUP_MS, rep=REP_MS
)

# A matrix multiply of (M, K) by (K, N) performs M * N * K multiply-adds,
# counted as two floating-point operations each.
flops = 2 * M * N * K

# Convert to TFLOPS: FLOP / (ms * 1e-3 s/ms) / 1e12 == FLOP / (ms * 1e9).
tflops = flops / (avg_time_ms * 1e9)
print(f"Average time: {avg_time_ms:.6f} ms")
print(f"Performance: {tflops:.2f} TFLOPS")

