import contextlib
import time
import torch
import torch.nn.functional as F
import os
import logging

# Logging setup: every benchmark table line is appended to
# <log_dir>/train_op.log as the bare message (no timestamp or level prefix).
log_dir = "/root/presurelog"
log_file = os.path.join(log_dir, "train_op.log")
os.makedirs(log_dir, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    filename=log_file,
    format='%(message)s',
)

# Index of the CUDA device that GPU tensors are placed on ("cuda:<CUR_RANK>").
CUR_RANK = 0


class Data:
    """Module-wide scratch storage for the tensors produced by ``op_bench``.

    Each list holds one slot per benchmark iteration: the iteration-``i`` call
    of an op reads and writes slot ``i``, so iterations never share tensors.
    """

    # A operands: pinned host tensors (init_a), later replaced by their device
    # copies (H2D).
    a_list = []
    # B operands on the device (init_b), overwritten by silu / sigmoid.
    b_list = []
    # Device results of matmul, overwritten by norm.
    c_list = []
    # Pinned host buffers (empty_c) that D2H copies c_list into.
    c_host_list = []

    @classmethod
    def reset(cls, len):
        """Give every slot list a fresh list of ``len`` ``None`` entries.

        The previous lists are discarded, dropping this class's references
        to the tensors they held.
        """
        slots = range(len)
        cls.a_list = [None for _ in slots]
        cls.b_list = [None for _ in slots]
        cls.c_list = [None for _ in slots]
        cls.c_host_list = [None for _ in slots]


@contextlib.contextmanager
def add_nvtx_event(event_name):
    """Wrap the enclosed code in an NVTX range named ``event_name``.

    The range is visible in NVTX-aware profilers such as Nsight Systems.
    It is popped in a ``finally`` clause so that an exception raised inside
    the ``with`` body cannot leave the NVTX push/pop stack unbalanced.
    """
    torch.cuda.nvtx.range_push(event_name)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()


def bench_with_loop(loop_iters, func, *args, enable_print=True, **kwargs):
    """Call ``func`` ``loop_iters`` times and log its average per-call cost.

    ``func`` is called as ``func(i, *args, **kwargs)`` for each ``i`` in
    ``range(loop_iters)``, inside an NVTX range named after ``func``. Two
    per-call averages are computed, both in microseconds:

    * CPU cost: host wall-clock time from just before the first call to just
      after the last call returns. It is taken *before* the final
      synchronize, so asynchronous GPU work still in flight is not included.
    * GPU cost: time between two CUDA events recorded around the loop.

    When ``enable_print`` is true, one table row is written via ``logging``.

    A non-positive ``loop_iters`` runs nothing and logs nothing, because
    there is no average to report. Previously 0 raised ``ZeroDivisionError``.
    This lets callers skip a warm-up phase by passing 0.
    """
    if loop_iters <= 0:
        return

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Drain work queued by earlier code so it is not billed to this loop.
    torch.cuda.synchronize()

    # perf_counter() is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    start_event.record()
    with add_nvtx_event(func.__name__):
        for i in range(loop_iters):
            func(i, *args, **kwargs)
    end_event.record()
    end_time = time.perf_counter()
    # Both events must have completed before elapsed_time() can be queried.
    torch.cuda.synchronize()

    # Event.elapsed_time() returns milliseconds -> microseconds per call.
    elapsed_time = start_event.elapsed_time(end_event) * 1000 / loop_iters
    # perf_counter() differences are seconds -> microseconds per call.
    elapsed_time_cpu = (end_time - start_time) * 1000 * 1000 / loop_iters

    if enable_print:
        logging.info(
            f"|{func.__name__.rjust(20)}|{(elapsed_time_cpu):>16.5f} us|{(elapsed_time):>16.5f} us|"
        )


# Square matrix sizes swept by op_bench: 512 doubled up to 16384.
DEFAULT_DIM = [512 * 2**k for k in range(6)]


def op_bench(warmup_iters, loop_iters, dtype=torch.bfloat16, dims=None):
    """Benchmark a fixed pipeline of tensor ops over a sweep of matrix sizes.

    For each ``dim``, every op in the pipeline is run through
    ``bench_with_loop`` in two phases:

    * ``warmup_iters`` untimed iterations;
    * ``loop_iters`` timed iterations, logged as one table per ``dim``.

    Iteration ``i`` of each op reads and writes slot ``i`` of the ``Data``
    lists, so the ops form a chain per iteration:

    1. Create A (pinned host, ``dim x dim``) and B (device).
    2. Allocate a pinned host buffer C_host.
    3. Copy A to the device.
    4. Apply SiLU, then sigmoid, to B.
    5. Compute C = A @ B.
    6. Apply ``F.normalize(C, p=2)``.
    7. Copy C into C_host.

    Args:
        warmup_iters: Untimed iterations per op before measuring.
        loop_iters: Timed iterations per op; logged costs are per-call averages.
        dtype: Element dtype of every tensor created.
        dims: Square matrix sizes to sweep. ``None`` (the default) means
            ``DEFAULT_DIM``.

    Note:
        ``Data`` keeps one set of tensors per iteration alive at once, so
        pinned-host and device memory grow with ``loop_iters * dim**2``.
    """
    device = f"cuda:{CUR_RANK}"

    # These functions' __name__ values appear in the log table; keep them.
    def init_a(i, M, N):
        Data.a_list[i] = torch.rand(
            [M, N], dtype=dtype, device="cpu", pin_memory=True
        )

    def init_b(i, N, K):
        Data.b_list[i] = torch.rand([N, K], dtype=dtype, device=device)

    def empty_c(i, M, K):
        Data.c_host_list[i] = torch.empty(
            [M, K], dtype=dtype, device="cpu", pin_memory=True
        )

    def H2D(i):
        Data.a_list[i] = Data.a_list[i].to(device)

    def silu(i):
        Data.b_list[i] = F.silu(Data.b_list[i])

    def sigmoid(i):
        Data.b_list[i] = F.sigmoid(Data.b_list[i])

    def matmul(i):
        Data.c_list[i] = torch.matmul(Data.a_list[i], Data.b_list[i])

    def norm(i):
        Data.c_list[i] = F.normalize(Data.c_list[i], p=2)

    def D2H(i):
        Data.c_host_list[i].copy_(Data.c_list[i])

    def pipeline(dim):
        """Return the ordered (op, extra_args) schedule for one ``dim``."""
        # Single source of truth shared by warm-up and the timed run, so the
        # two phases cannot drift apart.
        square = (dim, dim)
        return [
            (init_a, square),
            (init_b, square),
            (empty_c, square),
            (H2D, ()),
            (silu, ()),
            (sigmoid, ()),
            (matmul, ()),
            (norm, ()),
            (D2H, ()),
        ]

    def run_pipeline(iters, dim, enable_print):
        """Benchmark every op of the pipeline in order, then synchronize."""
        for func, args in pipeline(dim):
            bench_with_loop(iters, func, *args, enable_print=enable_print)
        torch.cuda.synchronize()

    rule = f"+{'-' * 60}+"
    for dim in (DEFAULT_DIM if dims is None else dims):
        # Warming up
        Data.reset(warmup_iters)
        run_pipeline(warmup_iters, dim, enable_print=False)

        # Timed run: drop the warm-up tensors and release cached blocks first.
        Data.reset(loop_iters)
        torch.cuda.empty_cache()
        title = "Avg Cost for dim " + str(dim)
        logging.info(rule)
        logging.info(f"|{title.center(60)}|")
        logging.info(rule)
        logging.info(
            f"|{'function name'.rjust(20)}|{'CPU cost'.rjust(19)}|{'GPU cost'.rjust(19)}|"
        )
        logging.info(rule)
        run_pipeline(loop_iters, dim, enable_print=True)
        logging.info(rule)
        logging.info("")


if __name__ == "__main__":
    # 5 untimed warm-up iterations, then 20 timed iterations per op and size.
    # Guarded so that importing this module does not launch the benchmark.
    op_bench(5, 20)
