"""ArcFace-based LoRA overfit post-analysis — joseonnamja + mzyeoja"""
import os, sys, cv2, numpy as np
from pathlib import Path

# Build a single InsightFace analyzer at import time. Only the detection and
# recognition models are enabled: get_embedding() needs nothing beyond
# app.get(...) returning faces with a bbox and a normed_embedding.
# NOTE(review): ctx_id=-1 is, per InsightFace's convention, CPU execution;
# det_size is the detector's input resolution.
print('[INFO] Loading ArcFace model...')
from insightface.app import FaceAnalysis
app = FaceAnalysis(
    name='buffalo_l',
    allowed_modules=['detection', 'recognition'],
)
app.prepare(
    ctx_id=-1,
    det_size=(640, 640),
)

def imread_unicode(path):
    """Load the image at *path* as a 3-channel colour array.

    The file bytes are read with NumPy and decoded in memory by
    ``cv2.imdecode``, so the path is never handed to OpenCV itself
    (presumably so non-ASCII / Unicode paths work on Windows).

    Returns the decoded image, or None when OpenCV cannot decode the bytes.
    """
    encoded_bytes = np.fromfile(str(path), dtype=np.uint8)
    decoded_image = cv2.imdecode(encoded_bytes, cv2.IMREAD_COLOR)
    return decoded_image

def get_embedding(img):
    """Embed the most prominent face in *img*.

    Runs the module-level ``app`` detector/recognizer on the image. When
    several faces are found, the one with the largest bounding-box area wins.

    Returns that face's ``normed_embedding``, or None if no face is detected.
    """
    def box_area(face):
        x1, y1, x2, y2 = face.bbox[0], face.bbox[1], face.bbox[2], face.bbox[3]
        return (x2 - x1) * (y2 - y1)

    detected = app.get(img)
    if not detected:
        return None
    return max(detected, key=box_area).normed_embedding

import re
from collections import defaultdict

# AI-Toolkit install root: datasets/<name> supplies the reference images and
# output/<name>/samples holds the per-step sample images scored against them.
TOOLKIT_ROOT = Path('G:/StabilityMatrix/Packages/AI-Toolkit')

# Sample filenames look like '<prefix>__<step>_<index>.jpg' (step typically
# zero-padded to 9 digits). Capture the step and the sample index.
SAMPLE_NAME_RE = re.compile(r'__(\d+)_(\d+)\.jpg$', re.IGNORECASE)

# Minimum average similarity for the "0.8 beats 1.0" warning and for a step
# to be considered a balance candidate.
MIN_SIM = 0.3


def _reference_embedding(dataset_dir):
    """Average the ArcFace embeddings of every face found in dataset_dir/*.jpg.

    Returns (unit-norm mean embedding, number of faces used), or (None, 0)
    when no image yields a face. Without this guard, np.mean of an empty list
    would give a NaN reference and silently turn every score into N/A.
    """
    embs = []
    for f in sorted(dataset_dir.glob('*.jpg')):
        img = imread_unicode(f)
        if img is None:
            continue
        emb = get_embedding(img)
        if emb is not None:
            embs.append(emb)
    if not embs:
        return None, 0
    ref = np.mean(embs, axis=0)
    return ref / np.linalg.norm(ref), len(embs)


def _group_samples_by_step(samples_dir):
    """Return {step: {sample index: path}} for the sample images in samples_dir.

    Files whose names do not match SAMPLE_NAME_RE are ignored rather than
    crashing the run. If two files share a step and index, the later one in
    sorted order wins.
    """
    by_step = defaultdict(dict)
    for f in sorted(samples_dir.glob('*.jpg')):
        m = SAMPLE_NAME_RE.search(f.name)
        if m is not None:
            by_step[int(m.group(1))][int(m.group(2))] = f
    return dict(by_step)


def _score(ref, path):
    """Cosine similarity between ref and the largest face in path.

    Returns None (not a numeric sentinel) when the image cannot be decoded or
    has no detectable face, so genuine similarities <= 0 are still counted.
    """
    img = imread_unicode(path)
    if img is None:
        return None
    emb = get_embedding(img)
    # ref is unit-norm and emb is a normed embedding, so the dot is the cosine.
    return None if emb is None else float(np.dot(ref, emb))


def _mean_or_none(values):
    """Mean of values as a float, or None for an empty list."""
    return float(np.mean(values)) if values else None


def _fmt(value):
    """Six-character table cell: '%6.3f' for a number, right-aligned 'N/A' for None."""
    return '   N/A' if value is None else '%6.3f' % value


for name in ['joseonnamja', 'mzyeoja']:
    dataset_dir = TOOLKIT_ROOT / 'datasets' / name
    samples_dir = TOOLKIT_ROOT / 'output' / name / 'samples'

    print('\n' + '='*80)
    print(f'  {name.upper()} - ArcFace Cosine Similarity')
    print('='*80)

    ref, n_ref = _reference_embedding(dataset_dir)
    print(f'[REF] {n_ref} faces from dataset')
    if ref is None:
        print(f'[WARN] no detectable faces in {dataset_dir}; skipping {name}')
        continue

    by_step = _group_samples_by_step(samples_dir)
    if not by_step:
        print(f'[WARN] no sample images found in {samples_dir}; skipping {name}')
        continue

    header = f"{'Step':>6} | {'A-1.0':>6} {'A-0.8':>6} | {'B-1.0':>6} {'B-0.8':>6} | {'C-1.0':>6} {'C-0.8':>6} | {'AVG':>6} {'avg1.0':>6} {'avg0.8':>6} | Flag"
    print(header)
    print('-'*95)

    peak_avg = float('-inf')
    decline = 0
    results = []

    for step in sorted(by_step):
        scores = {idx: _score(ref, path) for idx, path in by_step[step].items()}

        # Indices 0/2/4 are the 1.0-strength samples, 1/3/5 the 0.8 ones
        # (matching the A/B/C column layout of the header).
        vals = [scores.get(i) for i in range(6)]
        v10 = [vals[i] for i in (0, 2, 4) if vals[i] is not None]
        v08 = [vals[i] for i in (1, 3, 5) if vals[i] is not None]
        va = [v for v in scores.values() if v is not None]

        a_all = _mean_or_none(va)
        a_10 = _mean_or_none(v10)
        a_08 = _mean_or_none(v08)

        flag = ''
        if a_all is not None:
            if a_all > peak_avg:
                peak_avg = a_all
                decline = 0
                flag = 'PEAK'
            else:
                # Consecutive scored steps without a new peak since the last one.
                decline += 1
                flag = f'decline {decline}'

        if a_08 is not None and a_10 is not None and a_08 > a_10 and a_10 > MIN_SIM:
            flag += ' *** 0.8>1.0! ***'

        parts = [_fmt(v) for v in vals]
        print(f'{step:>6} | {parts[0]} {parts[1]} | {parts[2]} {parts[3]} | {parts[4]} {parts[5]} | {_fmt(a_all)} {_fmt(a_10)} {_fmt(a_08)} | {flag}')
        results.append((step, a_all, a_10, a_08))
        sys.stdout.flush()

    scored = [r for r in results if r[1] is not None]
    if scored:
        pk = max(scored, key=lambda r: r[1])
        print(f'\n>>> PEAK: Step {pk[0]} (AVG={_fmt(pk[1]).strip()}, 1.0={_fmt(pk[2]).strip()}, 0.8={_fmt(pk[3]).strip()})')

        candidates = [
            r for r in scored
            if r[2] is not None and r[3] is not None and r[2] > MIN_SIM and r[3] > MIN_SIM
        ]
        if candidates:
            best = min(candidates, key=lambda r: abs(r[2] - r[3]))
            print(f'>>> BEST BALANCE: Step {best[0]} (1.0={best[2]:.3f}, 0.8={best[3]:.3f}, gap={abs(best[2]-best[3]):.3f})')
