#!/usr/bin/env python3
"""
cortex_eeg.py -- Single-file Python SDK for the Cortex EEG Analysis Tool.

Usage:
    pip install requests
    python3 cortex_eeg.py sub-01.edf  # analyze a file from the command line

    # In your analysis script:
    from cortex_eeg import analyze, analyze_batch, save_csv

    result = analyze("sub-01.edf")
    print(result["exponent"], result["fit_quality"])

    results = analyze_batch(["sub-01.edf", "sub-02.edf", "sub-03.edf"])
    save_csv(results, "results.csv")

Download: https://cortex-ten-rose.vercel.app/cortex_eeg.py
Backend:  https://rayo-lab-cortex-backend.hf.space
Docs:     https://cortex-ten-rose.vercel.app/methods
"""

import csv
import json
import os
from pathlib import Path
from typing import Optional

try:
    import requests
except ImportError:
    raise ImportError("Install requests first: pip install requests")


# Backend base URL used when the caller does not pass backend_url / --backend.
DEFAULT_BACKEND = "https://rayo-lab-cortex-backend.hf.space"
# Default lower and upper bounds (Hz) of the fitting frequency range.
DEFAULT_FREQ_MIN = 1.0
DEFAULT_FREQ_MAX = 20.0
# Maximum number of files uploaded in one /analyze-batch request;
# analyze_batch() splits longer file lists into chunks of this size.
MAX_BATCH_SIZE = 20


def analyze(
    filepath: str,
    backend_url: str = DEFAULT_BACKEND,
    freq_min: float = DEFAULT_FREQ_MIN,
    freq_max: float = DEFAULT_FREQ_MAX,
    timeout: int = 120,
) -> dict:
    """
    Analyze a single EEG file.

    Thin wrapper around analyze_batch(): uploads one file and returns the
    first (and normally only) result record.

    Parameters
    ----------
    filepath : str
        Path to the EEG file (.edf, .bdf, .vhdr, .mff, .set, .cnt).
    backend_url : str
        Cortex backend URL. Defaults to the public HuggingFace Space.
    freq_min, freq_max : float
        Aperiodic fitting frequency range in Hz. Default: 1-20 Hz.
    timeout : int
        Request timeout in seconds. Default: 120 (backend may be cold-starting).

    Returns
    -------
    dict with keys: filename, exponent, r2, fit_quality, nearest_population,
                    distance_sd, umap_x, umap_y, error (None on success).
    If the batch call yields no record at all, an error record with the same
    keys is returned instead of raising IndexError.
    """
    results = analyze_batch(
        [filepath],
        backend_url=backend_url,
        freq_min=freq_min,
        freq_max=freq_max,
        timeout=timeout,
    )
    if not results:
        # An empty result list would otherwise surface as an opaque IndexError.
        return {
            "filename": Path(filepath).name,
            "error": "Backend returned no result for this file.",
            "exponent": None, "r2": None, "fit_quality": None,
            "nearest_population": None, "distance_sd": None,
            "umap_x": None, "umap_y": None,
        }
    return results[0]


def _error_result(filepath, message) -> dict:
    """Build a result record (analyze() schema) for a file that failed."""
    return {
        "filename": Path(filepath).name,
        "error": message,
        "exponent": None, "r2": None, "fit_quality": None,
        "nearest_population": None, "distance_sd": None,
        "umap_x": None, "umap_y": None,
    }


def _http_error_detail(exc):
    """Return the "detail" field of an HTTPError's JSON body, else str(exc)."""
    try:
        return exc.response.json().get("detail", str(exc))
    except (AttributeError, TypeError, ValueError):
        # No response object, a non-JSON body, or JSON that is not an object.
        return str(exc)


def _post_chunk(url, sent, files_payload, data, timeout) -> list:
    """POST one chunk; return the backend's rows, or one error row per file."""
    try:
        resp = requests.post(url, files=files_payload, data=data, timeout=timeout)
        resp.raise_for_status()
    except requests.exceptions.ConnectionError:
        # Backend may be cold-starting -- surface clear message
        message = "Connection failed -- backend may be starting up. Retry in 60 seconds."
    except requests.exceptions.Timeout:
        message = (
            f"Request timed out after {timeout} seconds -- backend may be "
            "starting up. Retry or increase the timeout."
        )
    except requests.exceptions.HTTPError as e:
        message = f"HTTP error: {_http_error_detail(e)}"
    except requests.exceptions.RequestException as e:
        # Any other transport-level failure (invalid URL, redirects, ...).
        message = f"Request failed: {e}"
    else:
        try:
            payload = resp.json()
        except ValueError:
            message = "Invalid response -- backend did not return valid JSON."
        else:
            if isinstance(payload, list):
                return payload
            message = "Invalid response -- expected a JSON list of results."
    return [_error_result(fp, message) for fp in sent]


def analyze_batch(
    filepaths: list,
    backend_url: str = DEFAULT_BACKEND,
    freq_min: float = DEFAULT_FREQ_MIN,
    freq_max: float = DEFAULT_FREQ_MAX,
    timeout: int = 180,
) -> list:
    """
    Analyze multiple EEG files in batches of up to MAX_BATCH_SIZE.

    Each batch is uploaded as one multipart POST to "<backend_url>/analyze-batch".
    Failures are reported as error records rather than exceptions: a file that
    cannot be opened gets its own error record (the rest of its batch is still
    sent), and a failed request or unusable response yields an error record
    for every file that was uploaded in that batch.

    Parameters
    ----------
    filepaths : list of str
        Paths to EEG files (.edf, .bdf, .vhdr, .mff, .set, .cnt).
    backend_url : str
        Cortex backend URL.
    freq_min, freq_max : float
        Fitting frequency range in Hz.
    timeout : int
        Request timeout per batch call in seconds.

    Returns
    -------
    list of dicts -- one per file, same schema as analyze().
    """
    backend = backend_url.rstrip("/")
    url = f"{backend}/analyze-batch"
    data = {
        "freq_min": str(freq_min),
        "freq_max": str(freq_max),
    }
    all_results = []
    no_row = object()  # sentinel: backend returned fewer rows than files sent

    # Split into chunks of MAX_BATCH_SIZE
    filepaths = list(filepaths)
    chunks = [filepaths[i:i + MAX_BATCH_SIZE] for i in range(0, len(filepaths), MAX_BATCH_SIZE)]

    for chunk_idx, chunk in enumerate(chunks):
        if len(chunks) > 1:
            print(f"Batch {chunk_idx + 1}/{len(chunks)}: sending {len(chunk)} files...")

        unreadable = {}  # position in chunk -> error record
        sent = []        # paths actually uploaded, in upload order
        files_payload = []
        opened = []
        try:
            for pos, fp in enumerate(chunk):
                try:
                    fh = open(fp, "rb")
                except OSError as e:
                    unreadable[pos] = _error_result(fp, f"Could not open file: {e}")
                    continue
                opened.append(fh)
                sent.append(fp)
                files_payload.append(("files", (Path(fp).name, fh, "application/octet-stream")))

            # Skip the request entirely when nothing in the chunk was readable.
            remote = _post_chunk(url, sent, files_payload, data, timeout) if sent else []
        finally:
            # Handles must stay open until the upload has finished.
            for fh in opened:
                fh.close()

        # Re-interleave local open failures with backend rows so the output
        # follows input order. Assumes the backend returns rows in upload
        # order -- TODO confirm against the backend.
        remote_rows = iter(remote)
        for pos in range(len(chunk)):
            if pos in unreadable:
                all_results.append(unreadable[pos])
                continue
            row = next(remote_rows, no_row)
            if row is not no_row:
                all_results.append(row)
        all_results.extend(remote_rows)  # keep any surplus rows

    return all_results


def save_csv(results: list, output_path: str) -> None:
    """
    Save batch results to a CSV file.

    The file is written as UTF-8 with a header row and a fixed column order.
    Keys that are not part of the column set are ignored; missing keys and
    None values become empty cells. A one-line confirmation is printed.

    Parameters
    ----------
    results : list of dicts from analyze() or analyze_batch()
    output_path : str -- path for the output CSV file
    """
    fieldnames = [
        "filename", "exponent", "r2", "fit_quality",
        "nearest_population", "distance_sd", "umap_x", "umap_y", "error",
    ]
    # Explicit UTF-8: the platform default encoding can fail on (or mangle)
    # non-ASCII filenames and error messages. newline="" is required by csv.
    with open(output_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction="ignore")
        writer.writeheader()
        writer.writerows(results)
    print(f"Saved {len(results)} rows -> {output_path}")


def print_result(result: dict) -> None:
    """Write a human-readable summary of one result record to stdout.

    A record with a truthy "error" is reported on a single ERROR line.
    Otherwise the filename is printed, followed by the indented metrics.
    """
    def _report_lines():
        # Generator: each line is formatted only right before it is printed.
        if result.get("error"):
            yield f"  ERROR: {result['filename']}: {result['error']}"
            return
        yield f"  {result['filename']}"
        yield f"    exponent:    {result['exponent']:.4f}"
        yield f"    r2:          {result['r2']:.4f}  ({result['fit_quality']})"
        yield f"    nearest pop: {result['nearest_population']}  ({result['distance_sd']:+.2f} SD)"
        yield f"    umap:        ({result['umap_x']:.3f}, {result['umap_y']:.3f})"

    for line in _report_lines():
        print(line)


_CLI_EPILOG = """
Examples:
  # Analyze a single file
  python3 cortex_eeg.py sub-01.edf

  # Analyze multiple files and save CSV
  python3 cortex_eeg.py sub-01.edf sub-02.edf sub-03.edf --csv results.csv

  # Custom frequency range
  python3 cortex_eeg.py sub-01.edf --freq-min 2 --freq-max 30

  # Local backend
  python3 cortex_eeg.py sub-01.edf --backend http://localhost:7860

  # Allow more time for a cold-starting backend
  python3 cortex_eeg.py sub-01.edf --timeout 300
"""


def _main(argv=None) -> None:
    """Command-line entry point.

    Parses arguments (from ``argv``, or ``sys.argv[1:]`` when None), runs
    analyze_batch() on the given files, prints each result, and then either
    saves a CSV (``--csv``) or dumps the results as JSON to stdout.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Cortex EEG Analysis Tool -- Python SDK",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=_CLI_EPILOG,
    )
    parser.add_argument("files", nargs="+", help="EEG files to analyze (.edf, .bdf, .vhdr, .mff, .set, .cnt)")
    parser.add_argument("--backend", default=DEFAULT_BACKEND, help=f"Backend URL (default: {DEFAULT_BACKEND})")
    parser.add_argument("--freq-min", type=float, default=DEFAULT_FREQ_MIN, help="Min fitting frequency Hz")
    parser.add_argument("--freq-max", type=float, default=DEFAULT_FREQ_MAX, help="Max fitting frequency Hz")
    # Default matches analyze_batch()'s own timeout default.
    parser.add_argument("--timeout", type=int, default=180, help="Request timeout per batch in seconds (default: 180)")
    parser.add_argument("--csv", default=None, help="Save results to CSV file")

    args = parser.parse_args(argv)

    print(f"Cortex EEG Analysis -- {len(args.files)} file(s)")
    print(f"Backend: {args.backend}")
    print(f"Fitting range: {args.freq_min}-{args.freq_max} Hz\n")

    results = analyze_batch(args.files, args.backend, args.freq_min, args.freq_max, args.timeout)

    for r in results:
        print_result(r)

    if args.csv:
        save_csv(results, args.csv)
    else:
        # Print JSON to stdout if no CSV requested
        print("\nJSON output:")
        print(json.dumps(results, indent=2))


if __name__ == "__main__":
    _main()
