"""
LoRA Overfit Watchdog
======================
Measures ArcFace similarity per checkpoint to monitor + log training.
Lets training run the full 3000 steps without auto-stopping; arcface_analysis.py
then picks the optimal checkpoint.

Warning conditions:
  1. AVG declines for N consecutive checks (relative to peak)
  2. 0.8 multiplier AVG > 1.0 multiplier AVG (djdante criterion)

Auto chaining:
  When --next-job-id / --next-dataset are given, the next job auto-starts
  after the current one finishes and is monitored under the same conditions.

Usage:
  python overfit_watchdog.py --job-id <JOB_ID> --dataset <dataset_folder> [options]
  python overfit_watchdog.py --job-id <ID1> --dataset <folder1> --next-job-id <ID2> --next-dataset <folder2>
"""

import os
import sys
import time
import json
import urllib.request
import cv2
import numpy as np
from datetime import datetime

API_BASE = "http://localhost:8675"
CHECK_INTERVAL = 60  # seconds between checks
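# API_BASE assumes a local AI-Toolkit web UI listening at this address; adjust it
# (and the endpoint paths used below) if your install serves the API elsewhere.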


def imread_unicode(path):
    """cv2.imread replacement that also handles non-ASCII paths (reads raw bytes, then imdecode)."""
    buf = np.fromfile(path, dtype=np.uint8)
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)


def init_arcface():
    from insightface.app import FaceAnalysis
    # buffalo_l bundles a face detector plus an ArcFace recognition model;
    # ctx_id=-1 selects CPU inference.
    app = FaceAnalysis(name="buffalo_l", allowed_modules=["detection", "recognition"])
    app.prepare(ctx_id=-1, det_size=(640, 640))
    return app


def get_embedding(app, img):
    """Return the normed embedding of the largest detected face, or None if no face is found."""
    faces = app.get(img)
    if not faces:
        return None
    face = max(faces, key=lambda f: (f.bbox[2] - f.bbox[0]) * (f.bbox[3] - f.bbox[1]))
    return face.normed_embedding
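
# Note: insightface's normed_embedding is L2-normalized, so the plain dot product
# used throughout this script is exactly the cosine similarity, in [-1, 1].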


def build_reference_embedding(app, dataset_dir):
    """Compute the mean embedding across the entire dataset."""
    exts = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
    embs = []
    for f in sorted(os.listdir(dataset_dir)):
        if os.path.splitext(f)[1].lower() not in exts:
            continue
        img = imread_unicode(os.path.join(dataset_dir, f))
        if img is None:
            continue
        emb = get_embedding(app, img)
        if emb is not None:
            embs.append(emb)
    if not embs:
        print("[FATAL] No faces found in dataset.")
        sys.exit(1)
    avg = np.mean(embs, axis=0)
    avg = avg / np.linalg.norm(avg)
    print(f"[INFO] Reference embedding built ({len(embs)} images)")
    return avg


def api_get(path):
    try:
        resp = urllib.request.urlopen(f"{API_BASE}{path}", timeout=10)
        return json.loads(resp.read())
    except Exception as e:
        print(f"[WARN] API call failed: {path} — {e}")
        return None


def api_stop_job(job_id):
    """Stop training."""
    result = api_get(f"/api/jobs/{job_id}/stop")
    return result is not None


def api_start_job(job_id):
    """Start training (queue + run queue)."""
    r1 = api_get(f"/api/jobs/{job_id}/start")
    r2 = api_get("/api/queue/0/start")
    return r1 is not None and r2 is not None


def get_sample_paths(job_id):
    """List of sample image paths generated so far."""
    data = api_get(f"/api/jobs/{job_id}/samples")
    if data is None:
        return []
    return data.get("samples", [])


def get_job_status(job_id):
    data = api_get("/api/jobs")
    if data is None:
        return None
    for job in data.get("jobs", []):
        if job["id"] == job_id:
            return job.get("status")
    return None


def parse_sample_filename(path):
    """Extract step and index from a sample filename. e.g. ....__000001100_2.jpg -> (1100, 2)"""
    fname = os.path.splitext(os.path.basename(path))[0]
    parts = fname.split('_')
    try:
        step = int(parts[-2])
        idx = int(parts[-1])
        return step, idx
    except (ValueError, IndexError):
        return None, None


def measure_checkpoint(app, ref_emb, sample_paths, step):
    """Measure similarity across the 6 samples for a given step."""
    scores = {}
    for path in sample_paths:
        s, idx = parse_sample_filename(path)
        if s != step:
            continue
        img = imread_unicode(path)
        if img is None:
            continue
        emb = get_embedding(app, img)
        if emb is not None:
            scores[idx] = float(np.dot(ref_emb, emb))
        else:
            scores[idx] = -1.0
    return scores


def analyze_scores(scores_dict):
    """
    scores_dict: {0: 0.75, 1: 0.70, 2: 0.55, 3: 0.50, 4: 0.72, 5: 0.68}
    idx 0,2,4 = multiplier 1.0 / idx 1,3,5 = multiplier 0.8
    Returns: (avg_all, avg_1_0, avg_0_8)
    """
    all_valid = [v for v in scores_dict.values() if v > 0]
    scores_1_0 = [scores_dict.get(i, -1) for i in [0, 2, 4]]
    scores_0_8 = [scores_dict.get(i, -1) for i in [1, 3, 5]]

    avg_all = np.mean(all_valid) if all_valid else 0
    avg_1_0 = np.mean([v for v in scores_1_0 if v > 0]) if any(v > 0 for v in scores_1_0) else 0
    avg_0_8 = np.mean([v for v in scores_0_8 if v > 0]) if any(v > 0 for v in scores_0_8) else 0

    return avg_all, avg_1_0, avg_0_8
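
# Worked example with the docstring's illustrative values:
#   analyze_scores({0: 0.75, 1: 0.70, 2: 0.55, 3: 0.50, 4: 0.72, 5: 0.68})
#   -> avg_all ≈ 0.650, avg_1_0 ≈ 0.673 (idx 0,2,4), avg_0_8 ≈ 0.627 (idx 1,3,5)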


def main():
    import argparse
    parser = argparse.ArgumentParser(description="LoRA overfit auto-monitoring")
    parser.add_argument("--job-id", required=True, help="AI-Toolkit job ID")
    parser.add_argument("--dataset", required=True, help="Dataset folder path")
    parser.add_argument("--decline-count", type=int, default=3, help="Stop after N consecutive AVG declines (default: 3)")
    parser.add_argument("--check-interval", type=int, default=60, help="Check interval in seconds (default: 60)")
    parser.add_argument("--next-job-id", default=None, help="Next training job ID (auto-chain)")
    parser.add_argument("--next-dataset", default=None, help="Next training dataset folder")
    args = parser.parse_args()

    # Initialize ArcFace
    print("\n[INFO] Loading ArcFace model...")
    app = init_arcface()

    # Build job queue
    jobs = [(args.job_id, args.dataset)]
    if args.next_job_id and args.next_dataset:
        jobs.append((args.next_job_id, args.next_dataset))

    for job_idx, (job_id, dataset_dir) in enumerate(jobs):
        job_name = "JOB %d/%d" % (job_idx + 1, len(jobs))

        print("\n" + "=" * 70)
        print("  LoRA Overfit Watchdog — %s" % job_name)
        print("=" * 70)
        print("  Job ID:    %s" % job_id)
        print("  Dataset:   %s" % dataset_dir)
        print("  Warn cond: AVG declines %d in a row OR 0.8 > 1.0 inversion (no auto-stop)" % args.decline_count)
        print("  Interval:  %ds" % args.check_interval)
        print("=" * 70)

        # If this is a follow-up job, start it
        if job_idx > 0:
            print("\n[ACTION] Starting next training: %s" % job_id)
            if api_start_job(job_id):
                print("[OK] Training started. Waiting 30s (init)...")
                time.sleep(30)
            else:
                print("[FATAL] Failed to start training. Start it manually.")
                break

        ref_emb = build_reference_embedding(app, dataset_dir)

        measured_steps = {}
        peak_avg = 0.0
        decline_streak = 0
        stopped_by_watchdog = False

        print("\n[INFO] Monitoring started — %s" % datetime.now().strftime("%H:%M:%S"))
        header = "%6s %6s | %6s %6s | %6s %6s | %6s %6s | %6s %6s %6s | %s" % (
            "", "Step", "A-1.0", "A-0.8", "B-1.0", "B-0.8", "C-1.0", "C-0.8", "AVG", "1.0", "0.8", "Status")
        print(header)
        print("-" * 95)

        while True:
            status = get_job_status(job_id)
            if status != "running":
                print("\n[INFO] Training status: %s" % status)
                if status == "completed":
                    print("[INFO] Reached 4000 steps — finished cleanly with no overfit detected")
                break

            sample_paths = get_sample_paths(job_id)
            if not sample_paths:
                time.sleep(args.check_interval)
                continue

            steps_available = set()
            for p in sample_paths:
                s, _ = parse_sample_filename(p)
                if s is not None:
                    steps_available.add(s)

            new_steps = sorted(steps_available - set(measured_steps.keys()))

            for step in new_steps:
                step_samples = [p for p in sample_paths if parse_sample_filename(p)[0] == step]
                if len(step_samples) < 6:
                    continue

                scores = measure_checkpoint(app, ref_emb, step_samples, step)
                if len(scores) < 4:
                    continue

                avg_all, avg_1_0, avg_0_8 = analyze_scores(scores)
                measured_steps[step] = (avg_all, avg_1_0, avg_0_8)

                flag = ""
                if avg_all > peak_avg:
                    peak_avg = avg_all
                    decline_streak = 0
                    flag = "NEW PEAK"
                else:
                    decline_streak += 1
                    flag = "decline %d/%d" % (decline_streak, args.decline_count)

                reversed_flag = ""
                if 0 < avg_1_0 < avg_0_8:
                    reversed_flag = " *** 0.8 > 1.0 inversion! ***"

                vals = [scores.get(i, -1) for i in range(6)]
                parts = []
                for v in vals:
                    parts.append("%6.3f" % v if v > 0 else "   N/A")
                now = datetime.now().strftime("%H:%M")
                print("%6s %6d | %s %s | %s %s | %s %s | %6.3f %6.3f %6.3f | %s%s" % (
                    now, step, parts[0], parts[1], parts[2], parts[3], parts[4], parts[5],
                    avg_all, avg_1_0, avg_0_8, flag, reversed_flag
                ))
                sys.stdout.flush()

                # Warn-only logging (no auto-stop — runs all 3000 steps)
                if decline_streak >= args.decline_count:
                    print("  [WARN] AVG declined %d in a row (peak %.3f -> current %.3f)" % (args.decline_count, peak_avg, avg_all))

                if 0.3 < avg_1_0 < avg_0_8:
                    print("  [WARN] 0.8(%.3f) > 1.0(%.3f) inversion detected" % (avg_0_8, avg_1_0))

            time.sleep(args.check_interval)

        # Final summary for this job
        if measured_steps:
            print("\n" + "=" * 70)
            print("  %s final summary" % job_name)
            print("=" * 70)
            peak_step = max(measured_steps, key=lambda s: measured_steps[s][0])
            peak_val = measured_steps[peak_step][0]
            print("  Checkpoints measured: %d" % len(measured_steps))
            print("  Peak AVG: %.3f (step %d)" % (peak_val, peak_step))
            # Pull job name
            data = api_get("/api/jobs")
            jname = job_id[:8]
            if data:
                for j in data.get("jobs", []):
                    if j["id"] == job_id:
                        jname = j.get("name", jname)
                        break
            print("  Recommended checkpoint: %s_%09d.safetensors" % (jname, peak_step))
            print("=" * 70)

    print("\n[DONE] All monitoring complete — %s" % datetime.now().strftime("%Y-%m-%d %H:%M:%S"))


if __name__ == "__main__":
    main()
