import os
import json
import numpy as np

def read_combined_json_files(directory):
    """Read every ``*_combined_numbers.json`` file in *directory* and merge them.

    Each matching file is loaded with ``json.load`` and iterated with
    ``.items()`` (so it is assumed to contain a JSON object). For every key, the
    file's value is appended to a list stored under that key in the result, so a
    key present in N files maps to a list of N values, one per file.

    Files are processed in sorted filename order so the per-key lists (and
    anything computed from them) are reproducible; ``os.listdir`` itself returns
    entries in arbitrary order.

    Args:
        directory: Directory to scan (top level only, not recursive).

    Returns:
        dict mapping each key to the list of values collected from all files.

    Raises:
        OSError (e.g. FileNotFoundError) if *directory* cannot be listed or a
        file cannot be opened; json.JSONDecodeError if a file is not valid JSON.
    """
    combined_data = {}
    for json_file in sorted(os.listdir(directory)):
        if not json_file.endswith('_combined_numbers.json'):
            continue
        filepath = os.path.join(directory, json_file)
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale
        # encoding, which differs e.g. on Windows.
        with open(filepath, 'r', encoding='utf-8') as file:
            data = json.load(file)
        for key, values in data.items():
            combined_data.setdefault(key, []).append(values)
    return combined_data

def filter_data_by_grid_length(data, expected_length):
    """Keep only the entries whose ``len()`` equals *expected_length*.

    Entries of any other length are dropped. When at least one entry is
    dropped, a one-line notice with the number discarded is printed.

    Args:
        data: Sequence of sized entries (e.g. lists of per-grid numbers).
        expected_length: Required length of each kept entry.

    Returns:
        A new list of the matching entries, in their original order.
    """
    kept = list(filter(lambda entry: len(entry) == expected_length, data))
    dropped = len(data) - len(kept)
    if dropped > 0:
        print(f"Filtered out {dropped} entries due to incorrect length.")
    return kept

def calculate_p95(numbers):
    """Return the 95th percentile of *numbers*, or ``None`` if there is no data.

    Delegates to ``np.percentile(numbers, 95)`` with NumPy's default (linear)
    interpolation.

    Args:
        numbers: A sequence of numbers, such as a list or a 1-D NumPy array.
            ``None`` is treated as "no data".

    Returns:
        The value returned by ``np.percentile`` (a NumPy float scalar for 1-D
        input), or ``None`` when *numbers* is ``None`` or empty.
    """
    # Check emptiness explicitly instead of with ``not numbers``: truthiness of a
    # multi-element NumPy array raises ValueError ("truth value ... is ambiguous").
    if numbers is None or len(numbers) == 0:
        return None
    return np.percentile(numbers, 95)

def calculate_p95_per_grid(data, expected_length):
    """Compute the 95th percentile at each grid position across all entries.

    *data* is first passed through ``filter_data_by_grid_length``, so only
    entries with exactly *expected_length* items are used. That helper prints a
    notice if any entries are dropped. Then, for each position
    ``0 .. expected_length - 1``, the values found at that position in the
    remaining entries are reduced with ``calculate_p95``.

    Returns:
        A list of length *expected_length*. An element is ``None`` when no entry
        survived filtering, because ``calculate_p95`` returns ``None`` for an
        empty list.
    """
    rows = filter_data_by_grid_length(data, expected_length)
    return [
        calculate_p95([row[column] for row in rows])
        for column in range(expected_length)
    ]

def main(logs_dir='./logs_summary/'):
    """Print per-grid p95 values for every test found in the combined JSON logs.

    Args:
        logs_dir: Directory containing the ``*_combined_numbers.json`` files.
            Defaults to ``./logs_summary/`` relative to the current working
            directory.
    """
    # Expected number of grid values per entry for each test key. Entries of
    # any other length are discarded before the p95 is computed.
    grid_lengths = {
        "test_cpu": 1,
        "test_gpu": 16,  # adjust to match the actual data
        "test_mlc_latency_matrix": 4,
        "test_mlc_max_bandwidth": 4,
        "test_pcie_p2p": 7,  # adjust to match the actual data
    }

    combined_data = read_combined_json_files(logs_dir)

    # Sort keys so the report order does not depend on directory listing order.
    for key in sorted(combined_data):
        data = combined_data[key]
        expected_length = grid_lengths.get(key)
        if expected_length is None:
            print(f"Warning: No expected length defined for {key}, skipping p95 calculation.")
            continue
        grid_p95_values = calculate_p95_per_grid(data, expected_length)
        # Convert NumPy scalars to plain floats for display. Under NumPy >= 2
        # their repr is "np.float64(...)", which would clutter the printed list.
        printable = [None if value is None else float(value) for value in grid_p95_values]
        print(f"{key} p95 values per grid: {printable}")

if __name__ == "__main__":
    # Run the report only when executed as a script, not when imported.
    main()
