"""
LoRA Retraining Auto-Run Script
================================
Runs training in joseonnamja -> mzyeoja order, then uses
arcface_analysis.py to identify the best checkpoint.

Usage:
  cd G:\StabilityMatrix\Packages\AI-Toolkit
  venv\Scripts\python.exe "G:\작업\조선왕자\배우 로라 데이터셋\run_training_all.py"
"""

import subprocess
import sys
import os
from datetime import datetime

# Folder holding this script, the training YAML configs, and the run log.
BASE_DIR = r"G:\작업\조선왕자\배우 로라 데이터셋"
# AI-Toolkit install; run.py and its venv interpreter live under it.
TOOLKIT_DIR = r"G:\StabilityMatrix\Packages\AI-Toolkit"
PYTHON = os.path.join(TOOLKIT_DIR, "venv", "Scripts", "python.exe")
RUN_PY = os.path.join(TOOLKIT_DIR, "run.py")

# Jobs execute in list order. Each entry names the job and points at its
# training config, its dataset folder, and its training output folder.
JOBS = [
    {
        "name": job_name,
        "config": os.path.join(BASE_DIR, config_file),
        "dataset": os.path.join(TOOLKIT_DIR, "datasets", job_name),
        "output": os.path.join(TOOLKIT_DIR, "output", job_name),
    }
    for job_name, config_file in (
        ("joseonnamja", "train_joseonnamja.yaml"),
        ("mzyeoja", "train_mzyeoja.yaml"),
    )
]

LOG_FILE = os.path.join(BASE_DIR, "training_run_log.txt")

def log(msg):
    """Echo *msg* to stdout with a timestamp and append the same line to LOG_FILE.

    The log file is opened in append mode as UTF-8 on every call.
    """
    stamp = f"{datetime.now():%Y-%m-%d %H:%M:%S}"
    entry = f"[{stamp}] {msg}"
    print(entry)
    with open(LOG_FILE, mode="a", encoding="utf-8") as handle:
        print(entry, file=handle)


def run_training(config_path, job_name):
    """Train one LoRA by running AI-Toolkit's run.py on *config_path*.

    The child process inherits this console, so its output streams live.
    Every failure mode is logged and reported through the return value
    rather than raised, so the caller can move on to the next job.

    Args:
        config_path: Path to the job's training YAML config.
        job_name: Label used only in log messages.

    Returns:
        True if run.py exited with code 0; False if it exited non-zero,
        the config file does not exist, or the interpreter could not start.
    """
    log(f"=== {job_name} training started ===")
    log(f"Config: {config_path}")

    if not os.path.isfile(config_path):
        log(f"=== {job_name} training failed (config not found: {config_path}) ===")
        return False

    # Flush our buffered prints first so that, when stdout is redirected,
    # they are not emitted after the child's output.
    sys.stdout.flush()
    try:
        # run.py must be executed from the AI-Toolkit folder
        result = subprocess.run([PYTHON, RUN_PY, config_path], cwd=TOOLKIT_DIR)
    except OSError as err:
        # e.g. the venv interpreter is missing: report it instead of letting
        # the exception abort the remaining jobs.
        log(f"=== {job_name} training failed (could not start {PYTHON}: {err}) ===")
        return False

    if result.returncode == 0:
        log(f"=== {job_name} training complete (success) ===")
        return True
    log(f"=== {job_name} training failed (exit code: {result.returncode}) ===")
    return False


def run_arcface_analysis(job_name, dataset_dir, output_dir, facecheck_python=None):
    """Run arcface_analysis.py on one job's results to pick the best checkpoint.

    Problems are logged as warnings and never raised, so one job's analysis
    failure cannot stop the remaining jobs.

    Args:
        job_name: Job label; forwarded to the script as ``--name``.
        dataset_dir: Dataset folder; forwarded as ``--dataset``.
        output_dir: Training output folder; forwarded as ``--output``.
        facecheck_python: Interpreter to run the analysis with. When None,
            the ``FACECHECK_PYTHON`` environment variable is used if set,
            otherwise the facecheck conda env's python.exe.

    Returns:
        True if the analysis ran and exited with code 0, otherwise False.
    """
    analysis_script = os.path.join(BASE_DIR, "arcface_analysis.py")
    if not os.path.exists(analysis_script):
        log(f"[WARN] arcface_analysis.py not found, skipping analysis: {job_name}")
        return False

    log(f"=== {job_name} ArcFace analysis started ===")
    log(f"Dataset: {dataset_dir}")
    log(f"Output: {output_dir}")

    if facecheck_python is None:
        # Use the python from the facecheck conda env (overridable per machine)
        facecheck_python = os.environ.get(
            "FACECHECK_PYTHON",
            r"C:\Users\USER\AppData\Local\Programs\Miniconda3\envs\facecheck\python.exe",
        )

    # Flush our buffered prints so they precede the child's live output.
    sys.stdout.flush()
    try:
        result = subprocess.run(
            [facecheck_python, analysis_script,
             "--dataset", dataset_dir,
             "--output", output_dir,
             "--name", job_name],
            cwd=BASE_DIR,
            # Force UTF-8 stdio in the child; paths here contain Korean text.
            env={**os.environ, "PYTHONIOENCODING": "utf-8"},
        )
    except OSError as err:
        # e.g. the conda env was moved or deleted: warn instead of crashing.
        log(f"[WARN] {job_name} ArcFace analysis could not start ({facecheck_python}): {err}")
        return False

    if result.returncode == 0:
        log(f"=== {job_name} ArcFace analysis complete ===")
        return True
    log(f"[WARN] {job_name} ArcFace analysis failed (exit code: {result.returncode})")
    return False


def main():
    """Train every job in JOBS in order, running ArcFace analysis after each success.

    A job whose training fails is logged and skipped; later jobs still run.

    Returns:
        List of names of jobs whose training failed (empty when all succeeded).
    """
    banner = "=" * 70
    log(banner)
    log("LoRA retraining auto-run started")
    log(banner)

    failed = []
    for job in JOBS:
        # Banner lines only: embedding "\n" in a log() message would leave a
        # timestamp dangling on its own line in both the console and the file.
        log(banner)
        log(f"JOB: {job['name']}")
        log(banner)

        if run_training(job["config"], job["name"]):
            run_arcface_analysis(job["name"], job["dataset"], job["output"])
        else:
            failed.append(job["name"])
            log(f"[ERROR] {job['name']} training failed, moving on to next job")

    log(banner)
    if failed:
        log(f"Finished with {len(failed)} failed job(s): {', '.join(failed)}")
    else:
        log("All done!")
    log(f"Log: {LOG_FILE}")
    log(banner)
    return failed


if __name__ == "__main__":
    # Exit non-zero when main() reports failed jobs, so callers (batch files,
    # schedulers) can detect failure; a None/empty result means success.
    sys.exit(1 if main() else 0)
