import torch
import triton
import numpy as np
import os
import time

# Benchmark configuration.
#
# Element type of the operands. To benchmark another type (e.g. bfloat16),
# change the assignment below rather than adding a second one above it:
#     data_type = torch.bfloat16
#
# Optional TF32 on NVIDIA Ampere or newer GPUs. Uncomment to allow float32
# matmuls (and cuDNN convolutions, if any are used) to run in TF32:
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True
data_type = torch.float32

# GEMM problem size: A is (M, K) and B is (K, N), so A @ B is (M, N).
M, N, K = 15360, 18176, 8192

# Random operands created directly on the current CUDA device
# (A first, then B, matching the RNG consumption order).
A = torch.randn(M, K, dtype=data_type, device='cuda')
B = torch.randn(K, N, dtype=data_type, device='cuda')

def compute_kernel():
    """Multiply the module-level operands ``A @ B`` once, discarding the product.

    Zero-argument callable handed to the benchmark harness for timing.
    """
    A.matmul(B)

# Time budgets for triton.testing.do_bench, in MILLISECONDS (not iteration
# counts). do_bench first estimates the cost of one call, then derives how many
# warm-up and timed calls fit into these budgets, running at least one of each.
# This GEMM performs 2*M*N*K ~= 4.6e12 FLOPs. To finish within the previous
# 10 ms "rep" budget, it would need well over 100 TFLOPS. So the old values
# (warmup=5, rep=10) collapsed to a single timed call, i.e. one sample.
# NOTE(review): per-call time depends on the GPU; raise REP_MS if slower
# hardware still yields too few samples.
WARMUP_MS = 500
REP_MS = 2000

# Benchmark with triton.testing.do_bench; the returned time is in milliseconds.
# NOTE(review): depending on the installed Triton version, do_bench reports the
# mean or the median of the timed calls — confirm before relying on the label.
avg_time_ms = triton.testing.do_bench(compute_kernel, warmup=WARMUP_MS, rep=REP_MS)

# A dense (M x K) @ (K x N) GEMM does M*N*K multiply-adds, counted as 2 FLOPs each.
flops = 2 * M * N * K

# Convert to TFLOPS: flops / (ms * 1e-3 s/ms) / 1e12 == flops / (ms * 1e9).
tflops = flops / (avg_time_ms * 1e9)
print(f"Average time: {avg_time_ms:.6f} ms")
print(f"Performance: {tflops:.2f} TFLOPS")

