"""Analyze all sweep output images with ArcFace — compare to Higgsfield benchmark"""
import sys, re, cv2, numpy as np
from pathlib import Path
# Make stdout UTF-8 so non-ASCII report text (e.g. the "—" in the final
# verdict lines) can be printed even when the console default is narrower.
sys.stdout.reconfigure(encoding='utf-8')

from insightface.app import FaceAnalysis

# Face model: insightface 'buffalo_l' pack, loading only the detection and
# recognition modules (this script reads just bbox and normed_embedding).
app = FaceAnalysis(
    name='buffalo_l',
    allowed_modules=['detection', 'recognition'],
)
# NOTE(review): in insightface ctx_id=0 should mean the first GPU (negative =
# CPU) — confirm against the installed version. Detector input is 640x640.
app.prepare(ctx_id=0, det_size=(640, 640))

# Input folders.
# LoRA training dataset whose faces are averaged into the reference centroid.
DATASET_DIR = Path("G:/StabilityMatrix/Packages/AI-Toolkit/datasets/mzyeoja")
# Higgsfield Soul 2.0 result images used as the benchmark set.
HF_DIR = Path("G:/작업/조선왕자/배우 로라 데이터셋/Higgsfield Soul 2.0 용 데이터셋/Higgsfield Soul 2.0 결과물")
# ComfyUI output folder holding the sweep_*.png images under evaluation.
OUTPUT_DIR = Path("G:/StabilityMatrix/Packages/ComfyUI/output")

def imread_unicode(path):
    """Load an image file into an OpenCV array, tolerating non-ASCII paths.

    The raw file bytes are read with NumPy and decoded in memory instead of
    handing the path to ``cv2.imread`` — presumably because ``cv2.imread``
    cannot open non-ASCII paths (such as HF_DIR) on Windows.

    Args:
        path: A ``str`` or ``Path`` pointing at the image file.

    Returns:
        The decoded 3-channel image (``cv2.IMREAD_COLOR``), or None when
        OpenCV cannot decode the bytes.
    """
    raw_bytes = np.fromfile(str(path), dtype=np.uint8)
    image = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR)
    return image

def get_embedding(img):
    """Return the embedding of the most prominent face in ``img``.

    Runs the module-level insightface ``app`` on the image. When several
    faces are detected, the one whose bounding box covers the largest area
    is chosen (the first such face on ties).

    Returns:
        That face's ``normed_embedding``, or None if no face was detected.
    """
    detected = app.get(img)
    if not detected:
        return None

    def bbox_area(face):
        # bbox is read as (x1, y1, x2, y2).
        x1, y1, x2, y2 = face.bbox[:4]
        return (x2 - x1) * (y2 - y1)

    return max(detected, key=bbox_area).normed_embedding

# Dataset centroid: embed the largest face in every LoRA dataset image
# (.jpg/.webp/.png) and average the per-face normed_embedding vectors into a
# single reference vector `ref`, re-normalised to unit length.
embs = []
dataset_files = sorted(
    p for pattern in ('*.jpg', '*.webp', '*.png') for p in DATASET_DIR.glob(pattern)
)
for f in dataset_files:
    img = imread_unicode(f)
    if img is None:
        continue  # undecodable file
    emb = get_embedding(img)
    if emb is not None:
        embs.append(emb)
if not embs:
    # np.mean over an empty list yields NaN (only a RuntimeWarning), which
    # would silently turn every similarity score below into NaN.
    print(f"ERROR: no faces detected in LoRA dataset images under {DATASET_DIR}")
    sys.exit(1)
ref = np.mean(embs, axis=0)
ref = ref / np.linalg.norm(ref)
print(f"REF: {len(embs)} faces from LoRA dataset")

# HF embeddings: Higgsfield result images form the benchmark. For each HF
# face we record its similarity to the dataset centroid, plus the similarity
# of every unordered pair of HF faces (the set's self-consistency).
hf_embs = []
for f in sorted(HF_DIR.glob('*.png')):
    img = imread_unicode(f)
    if img is None:
        continue  # undecodable file
    emb = get_embedding(img)
    if emb is not None:
        hf_embs.append(emb)
if not hf_embs:
    # Every later comparison is against these faces; without them the sweep
    # scoring would crash on max([]). Stop here with a clear message.
    print(f"ERROR: no faces detected in HF benchmark images under {HF_DIR}")
    sys.exit(1)
hf_vs_dataset = [float(np.dot(ref, e)) for e in hf_embs]
# All pairs (i < j), in the same order as a nested i/j loop.
hf_pw = [float(np.dot(a, b)) for i, a in enumerate(hf_embs) for b in hf_embs[i + 1:]]
# A single HF face has no pairs: report NaN without np.mean's empty-slice warning.
hf_pw_avg = float(np.mean(hf_pw)) if hf_pw else float('nan')
print(f"HF: {len(hf_embs)} faces, vs Dataset AVG={np.mean(hf_vs_dataset):.4f}, pairwise AVG={hf_pw_avg:.4f}")

# Analyze sweep files: score every sweep_*.png against the dataset centroid
# (`ref`) and against each HF face, then print a table sorted by mean HF
# similarity, highest first.
if not hf_embs:
    # With no HF faces the per-file mean/max are undefined (max([]) raises
    # ValueError); stop with an explicit message instead.
    print("ERROR: no HF benchmark faces available — cannot score sweep images")
    sys.exit(1)

print(f"\n{'='*90}")
print(f"  {'File':<50s}  {'vs Dataset':>10s}  {'vs HF avg':>10s}  {'vs HF max':>10s}")
print(f"  {'-'*50}  {'-'*10}  {'-'*10}  {'-'*10}")

# Rows are (file name, vs dataset, mean vs HF, max vs HF).
# - Undecodable files are skipped entirely.
# - Images with no detected face keep the (name, 0, 0, 0) placeholder row, so
#   later sections still see them.
# - Those names are also tracked in `no_face`, so the table labels them
#   instead of presenting them as genuine 0.0 scores.
results = []
no_face = set()
for f in sorted(OUTPUT_DIR.glob("sweep_*.png")):
    img = imread_unicode(f)
    if img is None:
        continue
    emb = get_embedding(img)
    if emb is None:
        results.append((f.name, 0, 0, 0))
        no_face.add(f.name)
        continue
    vs_d = float(np.dot(ref, emb))
    vs_hf = [float(np.dot(emb, h)) for h in hf_embs]
    results.append((f.name, vs_d, float(np.mean(vs_hf)), max(vs_hf)))

# Sort by vs_hf_avg (stable sort, so ties keep file-name order).
results.sort(key=lambda x: x[2], reverse=True)
for rank, (name, vs_d, vs_hf_avg, vs_hf_max) in enumerate(results):
    if name in no_face:
        marker = "  (no face detected)"
    elif rank == 0:
        marker = " <<< BEST"
    else:
        marker = ""
    print(f"  {name:<50s}  {vs_d:>10.4f}  {vs_hf_avg:>10.4f}  {vs_hf_max:>10.4f}{marker}")

# Group analysis: for each sweep parameter encoded in the file name
# (sweep_c<cfg>_s<steps>_ms<max_shift>_bs<base_shift>...), average the
# "vs HF avg" score over all images sharing the same parameter value.
print(f"\n{'='*90}")
print("GROUPED ANALYSIS")
print(f"{'='*90}")

# groups[param][value] -> list of vs-HF-avg scores; non-matching names are ignored.
sweep_name_re = re.compile(r'sweep_c([\d.]+)_s(\d+)_ms([\d.]+)_bs([\d.]+)')
param_names = ('cfg', 'steps', 'max_shift', 'base_shift')
groups = {}
for entry in results:
    match = sweep_name_re.match(entry[0])
    if match is None:
        continue
    for param, value in zip(param_names, match.groups()):
        groups.setdefault(param, {}).setdefault(value, []).append(entry[2])

# Per parameter, list its values from best to worst mean score.
for param in param_names:
    buckets = groups.get(param)
    if not buckets:
        continue
    print(f"\n  --- {param} ---")
    ranked = sorted(buckets.items(), key=lambda kv: np.mean(kv[1]), reverse=True)
    for value, scores in ranked:
        print(f"    {param}={value:<6s}  vs HF avg={np.mean(scores):.4f}  (n={len(scores)})")

# Best overall: summarise the top-ranked sweep image and compare its mean HF
# similarity with the HF set's own mean pairwise similarity (the benchmark).
if results:
    best = results[0]
    # Guard the benchmark means:
    # - np.mean([]) returns NaN with a RuntimeWarning.
    # - With fewer than 2 HF faces there are no pairs to average.
    hf_vs_dataset_avg = float(np.mean(hf_vs_dataset)) if hf_vs_dataset else float('nan')
    hf_pw_avg = float(np.mean(hf_pw)) if hf_pw else float('nan')
    print(f"\n{'='*90}")
    print(f"BEST RESULT: {best[0]}")
    print(f"  vs Dataset = {best[1]:.4f}")
    print(f"  vs HF avg  = {best[2]:.4f}")
    print(f"  vs HF max  = {best[3]:.4f}")
    print(f"  HF benchmark: vs Dataset AVG={hf_vs_dataset_avg:.4f}, pairwise AVG={hf_pw_avg:.4f}")
    delta = best[2] - hf_pw_avg
    if not hf_pw:
        # delta is NaN here. Since NaN >= 0 is False, the old code fell into
        # the "needs further optimization" branch, which was a false verdict.
        print("  >>> HF pairwise benchmark unavailable (fewer than 2 HF faces) — no verdict")
    elif delta >= 0:
        print(f"  >>> +{delta:.4f} vs HF pairwise — HF level reached!")
    else:
        print(f"  >>> {delta:.4f} vs HF pairwise — needs further optimization")
