#!/usr/bin/env python3

import argparse
import hashlib
import json
import os
import re
import shutil
import subprocess
import sys
import pathlib


def run_gcc_command(cmd):
    """Execute ``cmd`` as a subprocess and return what it wrote to stdout

    ``cmd`` is the argument list handed to ``subprocess.run``. Both output
    streams are captured as text. If the process exits with a non-zero
    status, the command line, the status and the captured stderr are printed
    and ``None`` is returned. Any exception raised along the way is printed
    and also results in ``None``.
    """
    try:
        # check=False: a failing exit status is reported below instead of
        # being turned into CalledProcessError.
        proc = subprocess.run(cmd, capture_output=True, text=True, check=False)

        if proc.returncode == 0:
            return proc.stdout

        print(f"Error running GCC command: {' '.join(cmd)}")
        print(f"Return code: {proc.returncode}")
        print(f"Error output: {proc.stderr}")
    except Exception as e:
        print(f"Exception while running GCC: {e}")

    # Reached on a non-zero exit status as well as after an exception
    return None


def get_kernel_headers_from_deps(prefix: pathlib.Path, deps_output: str) -> set[str]:
    """Extract kernel header paths from GCC dependency output

    Scans ``deps_output`` for paths that start with ``prefix`` and whose
    first component below it is ``linux``, ``asm`` or ``asm-generic``.

    Args:
        prefix: Directory the kernel headers were included from. It is
            matched literally (regex metacharacters are escaped).
        deps_output: Text of a make-style dependency rule, as produced by
            ``gcc -M``.

    Returns:
        The matching header paths relative to ``prefix`` (for example
        ``"linux/bpf.h"``). Empty when nothing matches.
    """
    escaped_prefix = re.escape(str(prefix))

    # (?<!\S) anchors the prefix at the start of a dependency token (start of
    # the text or right after whitespace). Without it, a prefix such as
    # /usr/include would also match inside /opt/sysroot/usr/include/... and
    # report a header that does not live under ``prefix``.
    #
    # A path ends at the first space, colon, backslash or newline, so the
    # captured text never carries a trailing line continuation or rule
    # separator.
    header_re = re.compile(
        rf"(?<!\S){escaped_prefix}/((?:linux|asm|asm-generic)/[^ :\\\n]+)"
    )

    return set(header_re.findall(deps_output))


def get_file_hash(path: pathlib.Path) -> str:
    """Return the hexadecimal MD5 digest of the content of ``path``

    The file is read in binary mode in fixed-size chunks, so it is never
    loaded into memory as a whole, and the handle is closed before returning.

    Raises:
        OSError: if ``path`` cannot be opened or read.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        while chunk := f.read(65536):
            digest.update(chunk)
    return digest.hexdigest()


def copy_header(header_path, khdr_dir, dest_dir):
    """Copy one header into the destination tree if it is missing or stale

    ``header_path`` is taken relative to both ``khdr_dir`` (where the header
    is read from) and ``dest_dir`` (where it is written to). The copy happens
    only when the destination file does not exist yet or when its content
    hash differs from the source; otherwise nothing is done. Missing parent
    directories of the destination are created. Only this single file is
    handled.
    """
    source = khdr_dir.absolute() / header_path
    target = dest_dir.absolute() / header_path

    if target.exists():
        if get_file_hash(source) == get_file_hash(target):
            # Destination is already up to date
            return
        reason = "changed"
    else:
        reason = "new"

    # Make sure the directory receiving the header exists
    os.makedirs(target.parent, exist_ok=True)

    shutil.copy2(source, target)
    print(f"Copied ({reason}): {source} -> {target}")


def main():
    """Collect the kernel headers used by the project sources and copy them

    Parses the command line, loads ``compile_commands.json`` from the build
    directory and runs ``gcc -M`` on every listed source file located under
    ``<srcdir>/src``. Each ``linux/``, ``asm/`` or ``asm-generic/`` header
    reported under the kernel header directory is then copied into the
    destination directory.

    Exits with status 1 when ``compile_commands.json`` is missing or cannot
    be read or parsed. A source file for which GCC fails is reported and
    skipped; it does not abort the run.
    """
    parser = argparse.ArgumentParser(
        description="Copy required Linux kernel headers using compile_commands.json"
    )
    parser.add_argument(
        "srcdir", type=pathlib.Path, help="Path to the sources directory"
    )
    parser.add_argument(
        "builddir", type=pathlib.Path, help="Path to the build directory"
    )
    parser.add_argument(
        "destdir", type=pathlib.Path, help="Destination directory of the kernel headers"
    )
    parser.add_argument(
        "--kernel-hdr-dir",
        type=pathlib.Path,
        default=pathlib.Path("/usr/include"),
        help="Path to the kernel UAPI headers directory (default: /usr/include)",
    )

    args = parser.parse_args()

    # Check if compile_commands.json exists
    ccmds_path = args.builddir / "compile_commands.json"
    if not ccmds_path.exists():
        print(f"Error: {ccmds_path} does not exist.")
        sys.exit(1)

    # Load compile_commands.json
    try:
        with open(ccmds_path, "r", encoding="utf-8") as f:
            ccmds = json.load(f)
    except json.JSONDecodeError as e:
        print(f"Error parsing {ccmds_path}: {e}")
        sys.exit(1)
    except Exception as e:
        print(f"Error reading {ccmds_path}: {e}")
        sys.exit(1)

    all_headers = set()

    # Only files below this directory are processed. It is normalised
    # lexically (no symlink resolution) so that ".." components in either the
    # command-line argument or the database entries cannot defeat the check.
    src_root = pathlib.Path(os.path.normpath(args.srcdir.absolute())) / "src"

    # Include paths handed to GCC; they are the same for every source file
    include_args = [
        f"-I{args.kernel_hdr_dir}",
        f"-I{args.srcdir}/src",
        f"-I{args.builddir}/include",
        f"-I{args.builddir}/src/bfcli/generated/include",
    ]

    # Process each compilation command
    for entry in ccmds:
        # An entry describes its compiler invocation with either "command"
        # or "arguments"; anything else is treated as malformed and skipped.
        if "file" not in entry:
            continue
        if "command" not in entry and "arguments" not in entry:
            continue

        # A relative "file" is relative to the entry's "directory"
        src_path = pathlib.Path(entry["file"])
        if not src_path.is_absolute():
            src_path = pathlib.Path(entry.get("directory", "")) / src_path
        src_path = pathlib.Path(os.path.normpath(src_path))

        # Ignore tests files
        if not src_path.is_relative_to(src_root):
            continue

        print(f"Processing: {src_path}")

        # Use GCC to get dependencies
        gcc_cmd = ["gcc", "-M"] + include_args + [str(src_path)]
        deps_output = run_gcc_command(gcc_cmd)

        if deps_output:
            # Extract kernel headers from dependencies
            headers = get_kernel_headers_from_deps(args.kernel_hdr_dir, deps_output)
            all_headers.update(headers)

    # Copy every collected header, in sorted order so that the log output is
    # reproducible from one run to the next
    for header in sorted(all_headers):
        copy_header(header, args.kernel_hdr_dir, args.destdir)

    print("Kernel header copy completed!")


# Run the tool only when this file is executed as a script, not when imported
if __name__ == "__main__":
    main()
