#!/usr/bin/env python3
import time
import curses
import subprocess
import re

def _pci_bus_number(address):
    """Return the bus field of a ``domain:bus:device.function`` string.

    The bus is parsed as hexadecimal so that case and zero-padding do not
    affect comparisons. Returns ``None`` when *address* has no second
    colon-separated field or that field is not hexadecimal.
    """
    fields = address.split(':')
    try:
        return int(fields[1], 16)
    except (IndexError, ValueError):
        return None


def EnableTelemetry():
    """Find the running QAT 4xxx devices and enable telemetry on each one.

    Telemetry directories are discovered with ``find`` under
    ``/sys/devices/`` and devices with ``adf_ctl status``. A device is kept
    when its type is ``4xxx``, its state is ``up`` and a telemetry directory
    sits directly below a directory whose bus number equals the device's.
    For every kept device, ``1`` is written to ``<telemetry>/control`` when
    that file currently reads ``off``.

    Returns:
        tuple: ``(devices, paths)``. ``devices`` is a list of ``(name, bus)``
        string tuples and ``paths[i]`` is the telemetry directory that
        belongs to ``devices[i]``. Both lists are empty when ``find`` or
        ``adf_ctl`` exits with a non-zero status.

    Raises:
        SystemExit: with status 1 when no running 4xxx device has a
            telemetry directory.
    """
    devices = []
    paths = []

    result = subprocess.run('find /sys/devices/ -name "telemetry"', shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    if result.returncode != 0:
        print(f"Error executing 'find' command: {result.stderr}")
        return devices, paths

    # Index the telemetry directories by the bus number of their parent
    # directory. Entries whose parent is not a PCI-style address are skipped
    # instead of crashing the lookup.
    telemetry_by_bus = {}
    for line in sorted(result.stdout.splitlines()):
        path = line.strip()
        ancestors = path.rstrip('/').rsplit('/', 2)
        if len(ancestors) != 3:
            continue
        bus_number = _pci_bus_number(ancestors[1])
        if bus_number is not None:
            # Sorted input + setdefault keeps the choice deterministic.
            telemetry_by_bus.setdefault(bus_number, path)

    result = subprocess.run("adf_ctl status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True)
    if result.returncode != 0:
        print(f"Error executing 'adf_ctl status': {result.stderr}")
        return devices, paths

    device_info_pattern = re.compile(r"(\w+)\s+-\s+type:\s+(\w+),\s+.*\s+bsf:\s+([\da-fA-F:.]+),\s+.*\s+state:\s+(\w+)")
    for match in device_info_pattern.finditer(result.stdout):
        name, device_type, bsf, state = match.groups()
        if device_type != "4xxx" or state != "up":
            continue
        # A None bus number is never a key, so unparsable addresses fall
        # through to the "no telemetry directory" case below.
        path = telemetry_by_bus.get(_pci_bus_number(bsf))
        if path is None:
            continue
        # Append to both lists together so they stay index-aligned.
        devices.append((name, bsf.split(':')[1]))
        paths.append(path)

    if not paths:
        print("No telemetry supported QAT endpoints found... exiting.")
        # quit() is a helper injected by the site module for interactive
        # use; raise SystemExit directly and report the failure status.
        raise SystemExit(1)

    for path in paths:
        control_file_name = f"{path}/control"
        try:
            # Finish reading before reopening the same file for writing.
            with open(control_file_name, 'r') as control_file:
                current = control_file.read().strip()
            if current == "off":
                with open(control_file_name, 'w') as control_file:
                    control_file.write("1")
        except OSError as e:
            # Report the failure and carry on with the remaining devices.
            print(f"Failed to enable telemetry for {path}: {e}")

    return devices, paths

def calculate_utilization(data, prefix, num_slices):
    """Average the counters ``<prefix>0`` .. ``<prefix><num_slices - 1>``.

    Args:
        data: mapping of counter name to a value accepted by ``int()``.
        prefix: counter name without its trailing slice index.
        num_slices: number of consecutive slice indices to average over.

    Returns:
        int: the rounded mean of the counters, with absent counters counted
        as 0; 0 when there are no slices to average.
    """
    total = 0
    slices_seen = 0
    for index in range(num_slices):
        total += int(data.get(f"{prefix}{index}", 0))
        slices_seen += 1

    if slices_seen == 0:
        return 0
    return round(total / slices_seen)

def get_slice_counts(first_path):
    """Count the utilisation counters of each engine type for one device.

    ``<first_path>/device_data`` is read as whitespace-separated
    ``name value`` pairs. For each engine type the result holds the number
    of counter names that start with ``util_<type>``.

    Args:
        first_path: telemetry directory that contains ``device_data``.

    Returns:
        dict: counts keyed by ``'cpr'``, ``'dcpr'``, ``'pke'``, ``'cph'``,
        ``'ath'`` and ``'ucs'``. Every count is 0 when the file cannot be
        read.
    """
    device_data_path = first_path + "/device_data"
    try:
        # Read the file directly: no shell is spawned and a path containing
        # spaces or shell metacharacters is handled correctly.
        with open(device_data_path) as device_data:
            tokens = device_data.read().split()
    except OSError:
        tokens = []

    # zip() drops a trailing name that has no value instead of raising.
    data = dict(zip(tokens[::2], tokens[1::2]))

    engine_types = ('cpr', 'dcpr', 'pke', 'cph', 'ath', 'ucs')
    return {
        engine: sum(1 for key in data if key.startswith(f"util_{engine}"))
        for engine in engine_types
    }

def _read_device_data(device_data_path):
    """Parse a telemetry ``device_data`` file into a ``{name: value}`` dict.

    The file is read as whitespace-separated ``name value`` pairs and the
    values are kept as strings. An unreadable file yields an empty dict, so
    the caller's defaults apply and the row shows zeros.
    """
    try:
        with open(device_data_path) as device_data:
            tokens = device_data.read().split()
    except OSError:
        return {}
    # zip() drops a trailing name that has no value instead of raising.
    return dict(zip(tokens[::2], tokens[1::2]))


def pbar(window, devices, set_paths, slice_counts):
    """Draw a per-device utilisation table in *window* until interrupted.

    Each pass redraws the header, one row per device, and a closing rule,
    then sleeps for two seconds. On every fifth pass the window is cleared
    and the device list, telemetry paths and slice counts are re-read.

    Args:
        window: curses window to draw into.
        devices: list of ``(name, bus)`` tuples; ``name`` labels the row.
        set_paths: telemetry directories, paired with *devices* by index.
        slice_counts: mapping of engine type (``'cpr'``, ``'dcpr'``,
            ``'pke'``, ``'cph'``, ``'ath'``, ``'ucs'``) to the number of
            counters to average for that type.

    The loop ends on ``KeyboardInterrupt``, on any other exception (which is
    printed first), or when a re-probe returns no telemetry paths.
    """
    refresh_counter = 0

    while True:
        try:
            refresh_counter += 1

            window.addstr(0, 10, "Intel(R) QuickAssist Device Utilization")
            window.addstr(2, 10, "Device\t%Comp\t%Decomp\t%PKE\t%Cipher\t%Auth\t%UCS\tLatency(ns)")
            window.addstr(3, 10, "=" * 73)

            # zip() stops at the shorter list, so a device without a
            # matching path is skipped rather than raising IndexError.
            rows = list(zip(devices, set_paths))
            for row, (device, telemetry_path) in enumerate(rows):
                data = _read_device_data(telemetry_path + "/device_data")

                # Compression is averaged over its slices like every other
                # engine type instead of showing only util_cpr0.
                compress_utilization = calculate_utilization(data, 'util_cpr', slice_counts['cpr'])
                decompress_utilization = calculate_utilization(data, 'util_dcpr', slice_counts['dcpr'])
                pke_utilization = calculate_utilization(data, 'util_pke', slice_counts['pke'])
                cph_utilization = calculate_utilization(data, 'util_cph', slice_counts['cph'])
                ath_utilization = calculate_utilization(data, 'util_ath', slice_counts['ath'])
                ucs_utilization = calculate_utilization(data, 'util_ucs', slice_counts['ucs'])
                latency = data.get("lat_acc_avg", '0')

                # Trailing spaces overwrite leftovers from a longer previous row.
                window.addstr(4 + row, 10, f"{device[0]}\t{compress_utilization}\t{decompress_utilization}\t{pke_utilization}\t{cph_utilization}\t{ath_utilization}\t{ucs_utilization}\t{latency}    ")

            window.addstr(4 + len(rows), 10, "=" * 73)
            window.refresh()
            time.sleep(2)

            if refresh_counter % 5 == 0:
                window.clear()
                devices, set_paths = EnableTelemetry()
                if not set_paths:
                    # EnableTelemetry has already reported why nothing was
                    # returned; stop instead of failing on set_paths[0].
                    break
                slice_counts = get_slice_counts(set_paths[0])

        except KeyboardInterrupt:
            break
        except Exception as e:
            print(f"An error occurred: {e}")
            break

# Script entry point: discover the devices, size the per-engine averages from
# the first device's telemetry, then run the display loop under curses.
if __name__ == "__main__":
    devices, paths = EnableTelemetry()
    if not devices:
        print("No QAT devices found... exiting.")
        # exit() is a helper injected by the site module for interactive
        # use; raise SystemExit directly and report the failure status.
        raise SystemExit(1)

    slice_counts = get_slice_counts(paths[0])
    # curses.wrapper passes the window as the first argument to pbar.
    curses.wrapper(pbar, devices, paths, slice_counts)