import torch
import triton
import numpy as np
import os
import time

# Source dtype of the random inputs, also used as the GEMM output dtype.
data_type = torch.bfloat16
device = 'cuda'

# GEMM problem size: (M, K) @ (K, N) -> (M, N).
M = 15360
N = 18176
K = 8192



A = torch.randn((M, K), dtype=data_type, device=device)
B = torch.randn((K, N), dtype=data_type, device=device)

# Convert the bfloat16 inputs to float8_e4m3fn.
A_fp8 = A.to(dtype=torch.float8_e4m3fn)
# .t().contiguous().t() keeps the logical (K, N) shape but stores the data in
# column-major order, the layout torch._scaled_mm expects for its second operand.
B_fp8 = B.to(dtype=torch.float8_e4m3fn).t().contiguous().t()

# Unit per-tensor scales, created once here rather than inside the benchmarked
# function, so each timed call does not also pay for two small allocations and
# host-to-device copies. dtype is explicit because _scaled_mm takes float32 scales.
scale_a = torch.tensor(1.0, dtype=torch.float32, device=device)
scale_b = torch.tensor(1.0, dtype=torch.float32, device=device)


def do_fp8_matmul(use_fast_accum=False):
    """Run one FP8 GEMM ``A_fp8 @ B_fp8`` through ``torch._scaled_mm``.

    Uses the module-level FP8 operands and unit scales and produces an
    (M, N) result in ``data_type``.

    Args:
        use_fast_accum: Forwarded to ``torch._scaled_mm``. The default (False)
            is also PyTorch's default, so calling with no arguments (as
            ``triton.testing.do_bench`` does) benchmarks the same kernel
            configuration as before.

    Returns:
        The output tensor of ``torch._scaled_mm``.
    """
    # Measured: out_dtype=float16 vs. bfloat16 makes little performance difference.
    # NOTE(review): an earlier comment mentioned use_fast_accum=True, but it was
    # never passed to the call; pass use_fast_accum=True explicitly to opt in.
    return torch._scaled_mm(
        A_fp8,
        B_fp8,
        scale_a,
        scale_b,
        out_dtype=data_type,
        use_fast_accum=use_fast_accum,
    )

# Benchmark with triton.testing.do_bench.
# NOTE: `warmup` and `rep` are time budgets in milliseconds, not iteration
# counts. Assuming one GEMM of this size takes a few milliseconds, the previous
# warmup=5 / rep=10 gave only one or two timed runs; Triton's defaults
# (25 ms / 100 ms) give a less noisy estimate.
avg_time_ms = triton.testing.do_bench(do_fp8_matmul, warmup=25, rep=100)




# Compute FLOPs and TFLOPS.
# A dense (M, K) x (K, N) GEMM performs M * N * K multiply-adds, counted as 2 FLOPs each.
flops = 2 * M * N * K
# avg_time_ms is in milliseconds: FLOP / (ms * 1e-3) / 1e12 == FLOP / (ms * 1e9).
tflops = flops / (avg_time_ms * 1e9)

print(
    f"Average time: {avg_time_ms:.6f} ms",
    f"Performance: {tflops:.2f} TFLOPS",
    sep="\n",
)

