from pathlib import Path

from .file_operations import read_file
from .logging import logger


def get_pci_devices() -> list[str]:
    """Return ``vendor:device`` ID strings for every PCI device.

    The IDs come from each device's ``vendor`` and ``device`` sysfs files;
    see ``_get_pci_usb_devices`` for the exact formatting.
    """
    return _get_pci_usb_devices(
        bus="pci",
        vendor_file_name="vendor",
        device_file_name="device",
    )


def get_usb_devices() -> list[str]:
    """Return ``vendor:product`` ID strings for every USB device.

    The IDs come from each device's ``idVendor`` and ``idProduct`` sysfs
    files; see ``_get_pci_usb_devices`` for the exact formatting.
    """
    return _get_pci_usb_devices(
        bus="usb",
        vendor_file_name="idVendor",
        device_file_name="idProduct",
    )


def _get_pci_usb_devices(bus: str, vendor_file_name: str, device_file_name: str) -> list[str]:
    """
    Return ``vendor:device`` ID strings for every device on a PCI or USB bus.

    Walks ``/sys/bus/<bus>/devices/``. For each device directory that has
    non-empty vendor and device ID files, builds a lowercase, zero-padded
    (at least 4 hex digits) ``vendor:device`` string with any ``0x`` prefix
    removed.

    Example: ["8086:1911", "10ec:8168"]

    Args:
        bus: Either ``"pci"`` or ``"usb"``; any other value is rejected.
        vendor_file_name: Per-device file holding the vendor ID
            (``"vendor"`` for PCI, ``"idVendor"`` for USB).
        device_file_name: Per-device file holding the device/product ID
            (``"device"`` for PCI, ``"idProduct"`` for USB).

    Returns:
        One entry per device, duplicates included, ordered by sysfs device
        path so repeated calls give the same order. Empty list for an
        unsupported bus or a missing bus directory.
    """
    if bus not in ("pci", "usb"):
        logger.error(f"Invalid bus type: {bus}")
        return []

    base = Path(f"/sys/bus/{bus}/devices/")
    logger.debug(f"Inspecting bus path: {base}")

    if not base.exists():
        logger.warning(f"Bus path does not exist: {base}")
        return []

    devices: list[str] = []

    # iterdir() yields entries in arbitrary, filesystem-dependent order; sort
    # so the result is deterministic for anything comparing or caching it.
    for dev in sorted(base.iterdir()):
        if not dev.is_dir():
            continue

        vendor = _normalize_hex_id(read_file(dev / vendor_file_name))
        device = _normalize_hex_id(read_file(dev / device_file_name))

        # Skip entries missing either ID (presumably e.g. USB interface
        # nodes, which have no idVendor/idProduct files — confirm).
        if not (vendor and device):
            continue

        id_str = f"{vendor}:{device}"
        devices.append(id_str)
        logger.debug(f"Found device: {id_str}")

    logger.debug(f"Collected {len(devices)} {bus.upper()} devices")

    return devices


def _normalize_hex_id(raw: "str | None") -> str:
    """Turn a raw sysfs hex ID (e.g. ``"0x8086\\n"``) into ``"8086"``; ``""`` if empty/missing."""
    if not raw:
        return ""
    # Strip defensively: sysfs values end in a newline and it is not visible
    # here whether read_file() already removes it.
    value = raw.strip().lower().removeprefix("0x")
    # Pad short IDs to 4 hex digits; a blank or bare "0x" value yields "".
    return value.zfill(4) if value else ""


def count_gpu_devices(class_prefixes: tuple[str, ...] = ("0x0300", "0x0302", "0x0380")) -> int:
    """
    Count GPUs by inspecting the PCI class code of every PCI device in sysfs.

    A device is counted when its ``/sys/bus/pci/devices/<dev>/class`` value
    starts with one of *class_prefixes*. The defaults are the PCI "display
    controller" subclasses used by GPUs:

    * ``0x0300`` - VGA-compatible controller
    * ``0x0302`` - 3D controller (used by many headless/compute GPUs, which
      a VGA-only check would miss)
    * ``0x0380`` - other display controller

    Args:
        class_prefixes: Class-code prefixes to count, matched
            case-insensitively. Pass ``("0x0300",)`` for VGA-only counting.

    Returns:
        Number of matching PCI devices; 0 when the PCI sysfs path is missing.
    """
    base = Path("/sys/bus/pci/devices")
    if not base.exists():
        logger.warning("PCI devices path missing")
        return 0

    # Normalize once; str.startswith accepts a tuple of alternatives.
    prefixes = tuple(p.lower() for p in class_prefixes)
    count = 0

    for dev in base.iterdir():
        class_hex = read_file(dev / "class")
        if not class_hex:
            continue

        # Strip/lowercase defensively: it is not visible here whether
        # read_file() removes the trailing newline.
        if class_hex.strip().lower().startswith(prefixes):
            count += 1
            logger.debug(f"GPU detected: {dev.name}")

    logger.info(f"Detected {count} GPU(s)")
    return count


def get_hardware_specs() -> dict:
    """
    Collect a hardware summary for this machine.

    Returns:
        A dict with exactly these keys:
          'PCI'       -> list of PCI ``vendor:device`` IDs (get_pci_devices)
          'USB'       -> list of USB ``vendor:product`` IDs (get_usb_devices)
          'GPU_COUNT' -> int GPU count (count_gpu_devices)
    """
    # Dict-literal values evaluate left to right, so the collectors run in
    # the order PCI, USB, then GPU counting.
    specs = {
        "PCI": get_pci_devices(),
        "USB": get_usb_devices(),
        "GPU_COUNT": count_gpu_devices(),
    }

    logger.debug("Hardware specs collected")
    logger.debug(specs)
    return specs
