"""
Parameter Sweep: CFG -> Steps -> ModelSamplingFlux
3 sequential rounds with automatic ArcFace analysis after each.
"""
import json, urllib.request, uuid, time, re, copy, sys, os
import numpy as np, cv2
from pathlib import Path

# Make stdout UTF-8 so the non-ASCII directory names below print without errors.
sys.stdout.reconfigure(encoding='utf-8')

# Frontend-format workflow JSON exported from ComfyUI.
WORKFLOW_PATH = "G:/StabilityMatrix/Images/Text2Img/ComfyUI_2026-03-19_00017_.png.json"
# Local ComfyUI HTTP API endpoint.
COMFYUI_API = "http://127.0.0.1:8188"
# Directory where ComfyUI writes generated images.
OUTPUT_DIR = Path("G:/StabilityMatrix/Packages/ComfyUI/output")
# Higgsfield Soul 2.0 reference renders (the benchmark to match).
HF_DIR = Path("G:/작업/조선왕자/배우 로라 데이터셋/Higgsfield Soul 2.0 용 데이터셋/Higgsfield Soul 2.0 결과물")
# LoRA training dataset images (ground-truth identity).
DATASET_DIR = Path("G:/StabilityMatrix/Packages/AI-Toolkit/datasets/mzyeoja")

# Node IDs from workflow analysis
NODE_KSAMPLER = 77
NODE_MODELSAMPLINGFLUX = 102
NODE_SAVEIMAGE = 9

# ============================================================
# Workflow conversion (from submit_workflow.py)
# ============================================================

# Frontend-only node types that must not appear in the API prompt.
VIRTUAL_TYPES = {"SetNode", "GetNode", "Reroute", "Note", "Fast Groups Bypasser (rgthree)"}

# Ordered widget names per node type: widgets_values in the frontend format is
# positional, and this mapping assigns each value its named API input.
NODE_SPECS = {
    "KSampler": ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"],
    "ModelSamplingFlux": ["max_shift", "base_shift", "width", "height"],
    "EmptyFlux2LatentImage": ["width", "height", "batch_size"],
    "LoraLoaderModelOnly": ["lora_name", "strength_model"],
    "UNETLoader": ["unet_name", "weight_dtype"],
    "CLIPLoader": ["clip_name", "type"],
    "VAELoader": ["vae_name"],
    "UpscaleModelLoader": ["model_name"],
    "ImageScaleBy": ["upscale_method", "scale_by"],
    "SaveImage": ["filename_prefix"],
    "CLIPTextEncode": ["text"],
    "easy int": ["value"],
    "easy string": ["value"],
    "TextBox1": ["text1"],
    "PerturbedAttention": ["scale", "adaptive_scale", "unet_block", "unet_block_id",
                           "sigma_start", "sigma_end", "rescale", "rescale_mode"],
    "EmptyLatentImagePresets": ["dimensions", "invert", "batch_size"],
    "EmptySD3LatentImage": ["width", "height", "batch_size"],
    "easy promptLine": ["prompt", "start_index", "max_rows", "remove_empty_lines"],
    "easy promptReplace": ["find1", "replace1", "find2", "replace2", "find3", "replace3"],
}


def build_api_prompt(wf):
    """Convert frontend workflow JSON to API prompt format.

    ComfyUI's /prompt endpoint expects {node_id: {"class_type", "inputs"}}
    rather than the node/link graph the frontend saves. This walks the saved
    graph, resolves virtual nodes (SetNode/GetNode/Reroute/switch) and
    bypassed (mode == 4) nodes to their real upstream sources, and maps the
    positional widgets_values onto named inputs using NODE_SPECS.
    """
    nodes = {n["id"]: n for n in wf["nodes"]}
    # Each link entry is [link_id, from_node, from_slot, to_node, to_slot, ...].
    links = {
        l[0]: {"from_node": l[1], "from_slot": l[2], "to_node": l[3], "to_slot": l[4]}
        for l in wf["links"]
    }
    # node_inputs[to_node][to_slot] = (from_node, from_slot)
    node_inputs = {}
    for lid, link in links.items():
        tid = link["to_node"]
        node_inputs.setdefault(tid, {})[link["to_slot"]] = (link["from_node"], link["from_slot"])

    # Map SetNode variable names to node ids so GetNode references resolve.
    set_nodes = {}
    for n in wf["nodes"]:
        if n["type"] == "SetNode":
            varname = n.get("widgets_values", [""])[0]
            if varname:
                set_nodes[varname] = n["id"]

    def resolve_source(nid, slot, depth=0):
        # Follow virtual/bypassed nodes upstream until a real producing node
        # is found; the depth cap guards against cyclic Set/Get wiring.
        if depth > 20: return None
        n = nodes.get(nid)
        if not n: return None
        ntype = n.get("type", "")
        mode = n.get("mode", 0)
        if ntype == "GetNode":
            # A GetNode yields whatever feeds the matching SetNode's input 0.
            varname = n.get("widgets_values", [""])[0]
            set_nid = set_nodes.get(varname)
            if set_nid:
                inp = node_inputs.get(set_nid, {})
                if 0 in inp: return resolve_source(inp[0][0], inp[0][1], depth+1)
            return None
        if ntype in ("SetNode", "Reroute"):
            # Pure pass-through nodes: follow their single input.
            inp = node_inputs.get(nid, {})
            if 0 in inp: return resolve_source(inp[0][0], inp[0][1], depth+1)
            return None
        if ntype == "Any Switch (rgthree)":
            # Take the first connected input that resolves to a live source.
            inp = node_inputs.get(nid, {})
            for s in sorted(inp.keys()):
                result = resolve_source(inp[s][0], inp[s][1], depth+1)
                if result: return result
            return None
        if mode == 4:
            # Bypassed node: treat it as a pass-through of its first input.
            inp = node_inputs.get(nid, {})
            if 0 in inp: return resolve_source(inp[0][0], inp[0][1], depth+1)
            return None
        return (str(nid), slot)

    api_prompt = {}
    for n in wf["nodes"]:
        nid = n["id"]
        ntype = n.get("type", "")
        mode = n.get("mode", 0)
        # Virtual and bypassed nodes are omitted entirely; any links through
        # them were already resolved by resolve_source above.
        if ntype in VIRTUAL_TYPES or mode == 4:
            continue
        inputs = {}
        spec = NODE_SPECS.get(ntype, [])
        wvals = n.get("widgets_values", [])
        inp_links = node_inputs.get(nid, {})
        node_type_inputs = n.get("inputs", [])
        linked_input_names = set()
        # Link-fed inputs win over widget values of the same name.
        for i, inp_def in enumerate(node_type_inputs):
            if inp_def.get("link") is not None and i in inp_links:
                src = resolve_source(inp_links[i][0], inp_links[i][1])
                if src:
                    inputs[inp_def["name"]] = [src[0], src[1]]
                    linked_input_names.add(inp_def["name"])
        # Assign remaining widget values positionally, in NODE_SPECS order.
        wi = 0
        for param in spec:
            if param in linked_input_names: continue
            if wi < len(wvals):
                val = wvals[wi]
                # Widget streams may carry a seed-control mode string
                # ("fixed"/"randomize"/...); skip over it to the next value.
                if isinstance(val, str) and val in ("fixed", "increment", "decrement", "randomize"):
                    wi += 1
                    if wi < len(wvals): val = wvals[wi]
                inputs[param] = val
                wi += 1
        if ntype == "EmptyFlux2LatentImage":
            # Force a sane batch size when it is missing or implausibly large.
            if "batch_size" not in inputs or inputs.get("batch_size", 0) > 16:
                inputs["batch_size"] = 1
        if ntype == "SaveImage" and "filename_prefix" in inputs:
            # Strip %token% placeholders and characters illegal in filenames.
            prefix = re.sub(r"%[^%]*%", "", inputs["filename_prefix"])
            prefix = re.sub(r'[:<>"|?*]', "", prefix).strip("_- ")
            inputs["filename_prefix"] = prefix or "sweep"
        api_prompt[str(nid)] = {"class_type": ntype, "inputs": inputs}
    return api_prompt


def modify_workflow(wf, cfg=None, steps=None, max_shift=None, base_shift=None, prefix=None):
    """Return a deep copy of the workflow with selected node widgets overridden.

    Any argument left as None keeps the value already stored in the workflow.
    """
    updated = copy.deepcopy(wf)
    for node in updated["nodes"]:
        node_id = node["id"]
        if node_id == NODE_KSAMPLER:
            widgets = node.get("widgets_values", [])
            # Layout: [seed, "fixed", steps, cfg, sampler_name, scheduler, denoise]
            if len(widgets) >= 7:
                if steps is not None:
                    widgets[2] = steps
                if cfg is not None:
                    widgets[3] = cfg
        elif node_id == NODE_MODELSAMPLINGFLUX:
            widgets = node.get("widgets_values", [])
            # Layout: [max_shift, base_shift, width, height]
            if len(widgets) >= 2:
                if max_shift is not None:
                    widgets[0] = max_shift
                if base_shift is not None:
                    widgets[1] = base_shift
        elif node_id == NODE_SAVEIMAGE and prefix is not None:
            node["widgets_values"] = [prefix]
    return updated


def submit_workflow(api_prompt):
    """POST an API-format prompt to ComfyUI and return the queued prompt_id."""
    body = json.dumps({"prompt": api_prompt, "client_id": str(uuid.uuid4())}).encode("utf-8")
    request = urllib.request.Request(
        f"{COMFYUI_API}/prompt",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as resp:
        response = json.loads(resp.read())
    return response.get("prompt_id")


def get_queue_status():
    """Return (running, pending) job counts from the ComfyUI /queue endpoint."""
    with urllib.request.urlopen(f"{COMFYUI_API}/queue") as resp:
        queue = json.loads(resp.read())
    return (
        len(queue.get("queue_running", [])),
        len(queue.get("queue_pending", [])),
    )


def wait_for_queue(timeout=600):
    """Poll the ComfyUI queue until it drains; return False on timeout."""
    started = time.time()
    while time.time() - started < timeout:
        running, pending = get_queue_status()
        if not running and not pending:
            return True
        waited = time.time() - started
        print(f"\r  Queue: {running} running, {pending} pending ({waited:.0f}s)", end="", flush=True)
        time.sleep(2)
    print("\n  Timeout!")
    return False


# ============================================================
# ArcFace Analysis
# ============================================================

def init_arcface():
    """Load the buffalo_l ArcFace stack (detection + recognition) on device 0."""
    from insightface.app import FaceAnalysis
    analyzer = FaceAnalysis(name='buffalo_l', allowed_modules=['detection', 'recognition'])
    analyzer.prepare(ctx_id=0, det_size=(640, 640))
    return analyzer

def imread_unicode(path):
    """Decode an image via np.fromfile + cv2.imdecode so non-ASCII paths work on Windows."""
    raw = np.fromfile(str(path), dtype=np.uint8)
    return cv2.imdecode(raw, cv2.IMREAD_COLOR)

def get_embedding(app, img):
    """Return the normalized embedding of the largest detected face, or None."""
    detections = app.get(img)
    if not detections:
        return None

    def _bbox_area(face):
        x1, y1, x2, y2 = face.bbox[:4]
        return (x2 - x1) * (y2 - y1)

    largest = max(detections, key=_bbox_area)
    return largest.normed_embedding

def compute_dataset_centroid(app):
    """Compute the unit-norm centroid of all LoRA dataset face embeddings.

    Scans DATASET_DIR for jpg/webp/png images, embeds the largest face in
    each readable image, and averages the embeddings.

    Returns:
        (centroid, count): the L2-normalized mean embedding and the number
        of faces that contributed to it.

    Raises:
        RuntimeError: if no face could be embedded from the dataset. The
        previous code fell through to np.mean([]) here, which silently
        produced a NaN centroid that poisoned every downstream score.
    """
    embs = []
    files = sorted(
        list(DATASET_DIR.glob('*.jpg'))
        + list(DATASET_DIR.glob('*.webp'))
        + list(DATASET_DIR.glob('*.png'))
    )
    for f in files:
        img = imread_unicode(f)
        if img is None:
            continue  # unreadable/corrupt file
        emb = get_embedding(app, img)
        if emb is not None:
            embs.append(emb)
    if not embs:
        raise RuntimeError(f"No faces found in dataset dir: {DATASET_DIR}")
    ref = np.mean(embs, axis=0)
    return ref / np.linalg.norm(ref), len(embs)

def compute_hf_embeddings(app):
    """Embed every Higgsfield reference PNG; unreadable/faceless images are skipped."""
    embeddings = []
    for path in sorted(HF_DIR.glob('*.png')):
        image = imread_unicode(path)
        if image is None:
            continue
        embedding = get_embedding(app, image)
        if embedding is not None:
            embeddings.append(embedding)
    return embeddings

def analyze_sweep_results(app, ref_centroid, hf_embs, prefix_pattern):
    """Score sweep outputs matching prefix_pattern against the references.

    For each PNG in OUTPUT_DIR whose name starts with prefix_pattern,
    records cosine similarity versus the dataset centroid and versus the
    Higgsfield set (average and maximum). Images with no detectable face
    score 0 on every metric; unreadable files are skipped.
    """
    results = []
    for path in sorted(OUTPUT_DIR.glob(f"{prefix_pattern}*.png")):
        image = imread_unicode(path)
        if image is None:
            continue
        emb = get_embedding(app, image)
        row = {'file': path.name, 'vs_dataset': 0, 'vs_hf_avg': 0, 'vs_hf_max': 0}
        if emb is not None:
            row['vs_dataset'] = float(np.dot(ref_centroid, emb))
            hf_scores = [float(np.dot(emb, hf)) for hf in hf_embs]
            if hf_scores:
                row['vs_hf_avg'] = np.mean(hf_scores)
                row['vs_hf_max'] = max(hf_scores)
        results.append(row)
    return results


# ============================================================
# Main Sweep
# ============================================================

def run_sweep(wf_original, sweep_configs, round_name):
    """Queue one generation job per config, wait for the queue, return configs.

    Each config dict gains a 'prefix' key recording the filename prefix its
    images were saved under, so callers can locate the outputs afterwards.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"  {round_name}")
    print(f"{banner}")

    prompt_ids = {}
    total = len(sweep_configs)
    for index, config in enumerate(sweep_configs, start=1):
        cfg = config.get('cfg', 5)
        steps = config.get('steps', 50)
        max_shift = config.get('max_shift', 1.15)
        base_shift = config.get('base_shift', 0.5)
        prefix = f"sweep_c{cfg}_s{steps}_ms{max_shift}_bs{base_shift}"
        config['prefix'] = prefix

        modified = modify_workflow(wf_original, cfg=cfg, steps=steps,
                                   max_shift=max_shift, base_shift=base_shift, prefix=prefix)
        prompt_id = submit_workflow(build_api_prompt(modified))
        prompt_ids[prefix] = prompt_id
        print(f"  [{index}/{total}] {prefix} -> {prompt_id}")

    print(f"\n  Submitted {total} jobs. Waiting...")
    wait_for_queue(timeout=1200)
    print(f"\n  {round_name} complete!")
    # Give ComfyUI a moment to flush output files to disk.
    time.sleep(3)
    return sweep_configs


def print_results_table(results, round_name):
    """Print a similarity table sorted by vs_hf_avg (descending); return the top row."""
    rule = "=" * 80
    print(f"\n{rule}")
    print(f"  {round_name} — ArcFace Results")
    print(f"{rule}")
    print(f"  {'File':<45s}  {'vs Dataset':>10s}  {'vs HF avg':>10s}  {'vs HF max':>10s}")
    print(f"  {'-'*45}  {'-'*10}  {'-'*10}  {'-'*10}")

    ranked = sorted(results, key=lambda row: row['vs_hf_avg'], reverse=True)
    best = ranked[0] if ranked else None

    for row in ranked:
        # '*' marks the winner (and any row that compares equal to it).
        marker = " *" if row == best else ""
        print(f"  {row['file']:<45s}  {row['vs_dataset']:>10.4f}  {row['vs_hf_avg']:>10.4f}  {row['vs_hf_max']:>10.4f}{marker}")

    if best:
        print(f"\n  Best: {best['file']} (vs HF avg={best['vs_hf_avg']:.4f})")
    return best


def extract_params_from_prefix(prefix):
    """Extract cfg, steps, max_shift, base_shift from a sweep filename prefix.

    Prefixes are built as f"sweep_c{cfg}_s{steps}_ms{ms}_bs{bs}", so a value
    that was an int (cfg=3) appears without a decimal point ("c3") while a
    float (ms=1.0) keeps one ("ms1.0"). The previous code parsed every field
    with float(), breaking the round trip: a best cfg of 3 came back as 3.0,
    and later rounds then built filters like "sweep_c3.0_s..." that never
    matched the round-1 files named "sweep_c3_s...". We now return an int
    whenever the matched text has no decimal point, so re-formatting the
    value reproduces the original prefix exactly.

    Returns an empty dict if the prefix does not match the sweep pattern.
    """
    m = re.match(r'sweep_c([\d.]+)_s(\d+)_ms([\d.]+)_bs([\d.]+)', prefix)
    if not m:
        return {}

    def _num(text):
        # Preserve int vs float so f"{value}" round-trips to the source text.
        return float(text) if '.' in text else int(text)

    return {
        'cfg': _num(m.group(1)),
        'steps': int(m.group(2)),
        'max_shift': _num(m.group(3)),
        'base_shift': _num(m.group(4)),
    }


def main():
    """Run the three-round parameter sweep and print ArcFace rankings.

    Rounds are sequential and greedy: the best CFG from round 1 is fixed for
    round 2 (steps), whose winner is fixed for rounds 3a/3b (max_shift then
    base_shift). Requires a running ComfyUI instance and GPU-backed ArcFace.
    """
    print("=" * 60)
    print("  ComfyUI Parameter Sweep — mzyeoja LoRA")
    print("  Benchmark: Higgsfield Soul 2.0")
    print("=" * 60)

    # Load workflow
    with open(WORKFLOW_PATH, "r", encoding="utf-8") as f:
        wf_original = json.load(f)
    print(f"  Workflow loaded: {WORKFLOW_PATH}")

    # Init ArcFace and build both reference embedding sets once up front.
    print("  Initializing ArcFace...")
    app = init_arcface()
    ref_centroid, n_ref = compute_dataset_centroid(app)
    print(f"  Dataset centroid: {n_ref} faces")
    hf_embs = compute_hf_embeddings(app)
    print(f"  Higgsfield embeddings: {len(hf_embs)} faces")

    # HF benchmark: how close the Higgsfield renders are to the dataset, and
    # to each other (pairwise) — the consistency level the LoRA should reach.
    hf_vs_dataset = [float(np.dot(ref_centroid, e)) for e in hf_embs]
    print(f"  HF vs Dataset benchmark: AVG={np.mean(hf_vs_dataset):.4f}")
    hf_pairwise = []
    for i in range(len(hf_embs)):
        for j in range(i+1, len(hf_embs)):
            hf_pairwise.append(float(np.dot(hf_embs[i], hf_embs[j])))
    print(f"  HF pairwise benchmark:   AVG={np.mean(hf_pairwise):.4f}")

    # ============================================================
    # ROUND 1: CFG Sweep
    # ============================================================
    round1_configs = [
        {"cfg": c, "steps": 50, "max_shift": 1.15, "base_shift": 0.5}
        for c in [2, 3, 3.5, 4, 5, 6, 7]
    ]
    run_sweep(wf_original, round1_configs, "ROUND 1: CFG Sweep")

    # Analyze Round 1
    r1_results = analyze_sweep_results(app, ref_centroid, hf_embs, "sweep_c")
    # Filter only round 1 results (steps=50, ms=1.15, bs=0.5)
    r1_filtered = [r for r in r1_results if "_s50_ms1.15_bs0.5" in r['file']]
    best_r1 = print_results_table(r1_filtered, "ROUND 1: CFG Sweep")

    # Extract best CFG
    if best_r1:
        params = extract_params_from_prefix(best_r1['file'])
        best_cfg = params.get('cfg', 5)
    else:
        best_cfg = 5
    print(f"\n  -> Best CFG = {best_cfg}")

    # ============================================================
    # ROUND 2: Steps Sweep (with best CFG)
    # ============================================================
    round2_configs = [
        {"cfg": best_cfg, "steps": s, "max_shift": 1.15, "base_shift": 0.5}
        for s in [20, 30, 40, 60, 70]
        # skip 50 since it was already done in round 1
    ]
    run_sweep(wf_original, round2_configs, f"ROUND 2: Steps Sweep (CFG={best_cfg})")

    # Analyze Round 2 (include round 1's best for comparison)
    # NOTE(review): this glob only picks up round-1 files if f"{best_cfg}"
    # formats identically to the cfg used when the round-1 prefix was built
    # (e.g. "3" vs "3.0") — verify the round trip through
    # extract_params_from_prefix preserves int formatting.
    r2_results = analyze_sweep_results(app, ref_centroid, hf_embs,
                                       f"sweep_c{best_cfg}_s")
    r2_filtered = [r for r in r2_results if f"_ms1.15_bs0.5" in r['file']]
    best_r2 = print_results_table(r2_filtered, f"ROUND 2: Steps Sweep (CFG={best_cfg})")

    if best_r2:
        params = extract_params_from_prefix(best_r2['file'])
        best_steps = params.get('steps', 50)
    else:
        best_steps = 50
    print(f"\n  -> Best Steps = {best_steps}")

    # ============================================================
    # ROUND 3a: max_shift Sweep
    # ============================================================
    round3a_configs = [
        {"cfg": best_cfg, "steps": best_steps, "max_shift": ms, "base_shift": 0.5}
        for ms in [0.5, 0.8, 1.0, 1.3, 1.5]
        # skip 1.15 since already done
    ]
    run_sweep(wf_original, round3a_configs,
              f"ROUND 3a: max_shift Sweep (CFG={best_cfg}, Steps={best_steps})")

    r3a_results = analyze_sweep_results(app, ref_centroid, hf_embs,
                                        f"sweep_c{best_cfg}_s{best_steps}_ms")
    r3a_filtered = [r for r in r3a_results if "_bs0.5" in r['file']]
    best_r3a = print_results_table(r3a_filtered,
                                   f"ROUND 3a: max_shift Sweep (CFG={best_cfg}, Steps={best_steps})")

    if best_r3a:
        params = extract_params_from_prefix(best_r3a['file'])
        best_ms = params.get('max_shift', 1.15)
    else:
        best_ms = 1.15
    print(f"\n  -> Best max_shift = {best_ms}")

    # ============================================================
    # ROUND 3b: base_shift Sweep
    # ============================================================
    round3b_configs = [
        {"cfg": best_cfg, "steps": best_steps, "max_shift": best_ms, "base_shift": bs}
        for bs in [0.2, 0.3, 0.4, 0.6, 0.7, 0.8]
        # skip 0.5 since already done
    ]
    run_sweep(wf_original, round3b_configs,
              f"ROUND 3b: base_shift Sweep (CFG={best_cfg}, Steps={best_steps}, max_shift={best_ms})")

    r3b_results = analyze_sweep_results(app, ref_centroid, hf_embs,
                                        f"sweep_c{best_cfg}_s{best_steps}_ms{best_ms}_bs")
    best_r3b = print_results_table(r3b_results,
                                   f"ROUND 3b: base_shift Sweep (CFG={best_cfg}, Steps={best_steps}, max_shift={best_ms})")

    if best_r3b:
        params = extract_params_from_prefix(best_r3b['file'])
        best_bs = params.get('base_shift', 0.5)
    else:
        best_bs = 0.5
    print(f"\n  -> Best base_shift = {best_bs}")

    # ============================================================
    # FINAL SUMMARY
    # ============================================================
    print("\n" + "=" * 80)
    print("  FINAL OPTIMAL PARAMETERS")
    print("=" * 80)
    print(f"  CFG        = {best_cfg}")
    print(f"  Steps      = {best_steps}")
    print(f"  max_shift  = {best_ms}")
    print(f"  base_shift = {best_bs}")
    # The next three lines echo fixed settings of the workflow, not swept ones.
    print(f"  Sampler    = dpmpp_2m / beta57")
    print(f"  LoRA       = mzyeoja_v3_000001900 @ 1.0")
    print(f"  Resolution = 1536x2048")
    print()

    # Collect ALL results
    all_results = analyze_sweep_results(app, ref_centroid, hf_embs, "sweep_")
    print(f"\n  Total sweep images: {len(all_results)}")
    print_results_table(all_results, "ALL SWEEP RESULTS — GLOBAL RANKING")

    # Final comparison: did the best sweep image reach Higgsfield-level
    # self-consistency (hf_pairwise average)?
    print(f"\n  Higgsfield benchmark:  vs Dataset AVG={np.mean(hf_vs_dataset):.4f}")
    if all_results:
        best_overall = max(all_results, key=lambda x: x['vs_hf_avg'])
        print(f"  Best LoRA output:      vs Dataset={best_overall['vs_dataset']:.4f}  vs HF avg={best_overall['vs_hf_avg']:.4f}")
        delta = best_overall['vs_hf_avg'] - np.mean(hf_pairwise)
        if delta >= 0:
            print(f"  -> +{delta:.4f} vs HF consistency (HF level reached!)")
        else:
            print(f"  -> {delta:.4f} vs HF consistency (needs further optimization)")


if __name__ == "__main__":
    main()
