"""
LoRA Dataset Curation Pipeline
==============================
Auto-selects an optimal training dataset (~29 images) from a large pool of portrait photos.

Usage:
  python dataset_pipeline.py --input <source_folder> --ref <reference_image> --trigger <trigger_word> --gender <man/woman> [options]

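Pipeline:
  [1]   ArcFace identity verification -> reject other people
  [2]   Auto chest-up crop
  [3]   Quality scoring -> drop the bottom --quality-cut fraction
        (default 0.0 = disabled; e.g. 0.3 drops the bottom 30%)
  [3.5] Face occlusion filter -> drop faces whose eye/mouth region is mostly
        very dark pixels (hair, hands, shadows; see --occlusion-threshold)
  [4]   Tone consistency filter -> drop LAB z-score outliers
  [5]   Deduplication -> remove near-duplicates whose similarity reaches
        --dup-threshold (default 1.0 = effectively disabled; e.g. 0.9)
  [6]   Diversity report
  [7]   Auto captioning + trigger word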

Required packages:
  pip install insightface onnxruntime opencv-python "numpy>=1.26,<2.0" imagededup

Examples:
  python dataset_pipeline.py --input ./raw_photos --ref ./ref.jpg --trigger mytrigger --gender man
  python dataset_pipeline.py --input ./raw --ref ./ref.jpg --trigger myperson --gender woman --skip-crop
"""

import argparse
import os
import sys
import shutil
import cv2
import numpy as np
from pathlib import Path
from datetime import datetime

# File suffixes recognised as images (lower-case, leading dot included).
# find_images() lower-cases each file's suffix before the membership test,
# so "PHOTO.JPG" is matched as well.
IMAGE_EXTENSIONS = {
    '.bmp',
    '.jpeg',
    '.jpg',
    '.png',
    '.webp',
}


# ============================================================
# Utilities
# ============================================================

def imread_unicode(path):
    """Load an image as a BGR array, supporting unicode/non-ASCII paths.

    Replaces cv2.imread on Windows, which cannot open non-ASCII paths: the
    bytes are read with NumPy and decoded in memory with ``cv2.imdecode``.

    Args:
        path: Path of the image file (str or os.PathLike).

    Returns:
        The decoded image, or ``None`` when the file cannot be read, is empty,
        or cannot be decoded. Every caller already treats ``None`` as
        "load failed".
    """
    try:
        buf = np.fromfile(path, dtype=np.uint8)
    except OSError:
        # Missing file, directory, permission denied, ...: report as a load
        # failure instead of aborting the whole pipeline.
        return None
    if buf.size == 0:
        # cv2.imdecode asserts on an empty buffer (raises cv2.error) rather than
        # returning None, so a zero-byte file would otherwise crash the run.
        return None
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)


def imwrite_unicode(path, img):
    """Write *img* to *path*, supporting unicode/non-ASCII paths.

    The output format is chosen from *path*'s suffix. The image is encoded in
    memory with ``cv2.imencode`` and the bytes are written with NumPy.

    Returns:
        True when encoding succeeded and the file was written; False when
        cv2.imencode reported failure, in which case nothing is written.
    """
    ok, encoded = cv2.imencode(Path(path).suffix, img)
    if not ok:
        return False
    encoded.tofile(path)
    return True


def find_images(folder, extensions=None):
    """Return the paths of image files directly inside *folder*, sorted by name.

    Args:
        folder: Directory to scan (not recursive).
        extensions: Optional iterable of accepted suffixes such as ``{'.jpg'}``,
            matched case-insensitively. Defaults to ``IMAGE_EXTENSIONS``.

    Returns:
        ``os.path.join(folder, name)`` for every regular file whose suffix is
        accepted, in sorted file-name order.
    """
    if extensions is None:
        extensions = IMAGE_EXTENSIONS
    wanted = {ext.lower() for ext in extensions}

    images = []
    for name in sorted(os.listdir(folder)):
        path = os.path.join(folder, name)
        # Skip directories (or other non-files) that merely *look* like images,
        # e.g. a sub-folder called "backup.jpg"; loading them later would fail.
        if Path(name).suffix.lower() in wanted and os.path.isfile(path):
            images.append(path)
    return images


def safe_copy(src, dst_dir, filename=None):
    """Copy *src* into *dst_dir* and return the path of the copy.

    *dst_dir* is created if it does not exist. The copy is named *filename*
    when one is given, otherwise it keeps the source's base name. File
    metadata is preserved (shutil.copy2), and an existing file of the same
    name is overwritten.
    """
    os.makedirs(dst_dir, exist_ok=True)
    target = os.path.join(dst_dir, filename if filename else os.path.basename(src))
    shutil.copy2(src, target)
    return target


def log(msg, file=None):
    """Print *msg* to stdout and, if *file* is truthy, also write it there as one line."""
    print(msg)
    if not file:
        return
    file.write(msg + "\n")


def get_largest_face(faces):
    """Return the face whose bounding box has the largest area.

    Each face must expose ``bbox`` indexable as (x1, y1, x2, y2, ...). When
    several faces share the largest area, the earliest one in *faces* wins.
    Raises IndexError if *faces* is empty.
    """
    def bbox_area(face):
        box = face.bbox
        return (box[2] - box[0]) * (box[3] - box[1])

    missing = object()
    largest = max(faces, key=bbox_area, default=missing)
    if largest is missing:
        raise IndexError("list index out of range")
    return largest


# ============================================================
# [1] Identity verification (ArcFace)
# ============================================================

def step_identity_check(app, images, ref_path, threshold, output_dir, reject_dir, log_f):
    """Step 1: keep only photos whose largest face matches the reference person.

    The reference embedding is taken from the largest face in *ref_path*. For
    each candidate, the largest detected face is scored by the dot product of
    the two ``normed_embedding`` vectors (cosine similarity, given the
    embeddings are L2-normalised). A score >= *threshold* passes.

    Passing files are copied to *output_dir*. Unreadable files, files with no
    detected face, and mismatches are copied to ``<reject_dir>/01_identity``.

    Returns:
        A list of ``(copied_path, similarity, face)`` tuples for the passing
        images, or ``[]`` when the reference image is unusable.
    """
    separator = '=' * 60
    log(f"\n{separator}", log_f)
    log(f"[1/7] Identity verification (ArcFace, threshold={threshold})", log_f)
    log(separator, log_f)

    reference = imread_unicode(ref_path)
    if reference is None:
        log(f"  [FATAL] Failed to load reference image: {ref_path}", log_f)
        return []

    reference_faces = app.get(reference)
    if not reference_faces:
        log("  [FATAL] No face detected in reference image", log_f)
        return []
    reference_embedding = get_largest_face(reference_faces).normed_embedding

    identity_reject_dir = os.path.join(reject_dir, "01_identity")
    matches = []
    for candidate in images:
        name = os.path.basename(candidate)
        picture = imread_unicode(candidate)
        if picture is None:
            log(f"  [SKIP] Load failed: {name}", log_f)
            safe_copy(candidate, identity_reject_dir, name)
            continue

        detected = app.get(picture)
        if not detected:
            log(f"  [REJECT] No face detected: {name}", log_f)
            safe_copy(candidate, identity_reject_dir, name)
            continue

        best_face = get_largest_face(detected)
        similarity = float(np.dot(reference_embedding, best_face.normed_embedding))
        if similarity >= threshold:
            copied = safe_copy(candidate, output_dir, name)
            matches.append((copied, similarity, best_face))
            log(f"  [PASS]   {similarity:.4f}  {name}", log_f)
            continue

        safe_copy(candidate, identity_reject_dir, name)
        log(f"  [REJECT] {similarity:.4f}  {name}", log_f)

    log(f"  => {len(matches)}/{len(images)} passed", log_f)
    return matches


# ============================================================
# [2] Auto chest-up crop
# ============================================================

def step_chest_crop(app, images_with_scores, output_dir, skip, log_f):
    """Step 2: crop each verified photo to a chest-up framing around its largest face.

    Args:
        app: Face analyser whose ``get(img)`` returns faces exposing ``bbox``.
        images_with_scores: ``(path, id_score, face)`` tuples from step 1.
        output_dir: Folder that receives the cropped (or unchanged) images.
        skip: When true (``--skip-crop``), every image is copied unchanged.
        log_f: Open report file (or None), forwarded to ``log``.

    Returns:
        ``(path_in_output_dir, id_score)`` tuples. Images that cannot be loaded
        are logged and dropped. Images whose face already fills more than 8% of
        the frame, whose face cannot be re-detected, whose crop would be smaller
        than 100 px, or whose crop cannot be written are kept uncropped.
    """
    log(f"\n{'='*60}", log_f)
    log(f"[2/7] Auto chest-up crop", log_f)
    log(f"{'='*60}", log_f)

    if skip:
        log("  [SKIP] Skipped via --skip-crop", log_f)
        result = []
        for img_path, score, _face in images_with_scores:
            dst = safe_copy(img_path, output_dir)
            result.append((dst, score))
        return result

    result = []
    for img_path, id_score, _ in images_with_scores:
        fname = os.path.basename(img_path)
        img = imread_unicode(img_path)
        if img is None:
            # Log instead of dropping silently, so the report explains the gap.
            log(f"  [SKIP] Load failed: {fname}", log_f)
            continue

        h, w = img.shape[:2]
        faces = app.get(img)
        if len(faces) == 0:
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, id_score))
            log(f"  [KEEP]  original kept (re-detect failed): {fname}", log_f)
            continue

        face = get_largest_face(faces)
        x1, y1, x2, y2 = face.bbox
        face_w = x2 - x1
        face_h = y2 - y1
        face_cx = (x1 + x2) / 2
        face_cy = (y1 + y2) / 2

        # If the face takes up 8%+ of the image, it's already chest-up
        face_ratio = (face_w * face_h) / (w * h)
        if face_ratio > 0.08:
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, id_score))
            log(f"  [KEEP]  already chest-up (face={face_ratio:.1%}): {fname}", log_f)
            continue

        # Target window: from 1.2 face-heights above the face centre (head room)
        # to 2.0 face-heights below it (chest), 3:4 width:height, centred on the face.
        crop_top = face_cy - face_h * 1.2
        crop_bottom = face_cy + face_h * 2.0
        crop_height = crop_bottom - crop_top
        crop_width = crop_height * 0.75

        crop_left = face_cx - crop_width / 2
        crop_right = face_cx + crop_width / 2

        # Clip to image bounds. Near an edge this trims the window, so the final
        # crop can end up narrower or shorter than 3:4.
        crop_top = max(0, int(crop_top))
        crop_bottom = min(h, int(crop_bottom))
        crop_left = max(0, int(crop_left))
        crop_right = min(w, int(crop_right))

        cropped = img[crop_top:crop_bottom, crop_left:crop_right]
        if cropped.shape[0] < 100 or cropped.shape[1] < 100:
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, id_score))
            log(f"  [KEEP]  crop too small, keeping original: {fname}", log_f)
            continue

        os.makedirs(output_dir, exist_ok=True)
        dst = os.path.join(output_dir, fname)
        try:
            written = imwrite_unicode(dst, cropped)
        except (cv2.error, OSError):
            # cv2.imencode raises (rather than returning False) for some codec
            # problems; tofile raises OSError on disk errors.
            written = False
        if not written:
            # Never record a path that was not actually written: fall back to
            # the original so the image is not lost from later steps.
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, id_score))
            log(f"  [KEEP]  crop write failed, keeping original: {fname}", log_f)
            continue

        result.append((dst, id_score))
        log(f"  [CROP]  {w}x{h} -> {cropped.shape[1]}x{cropped.shape[0]}: {fname}", log_f)

    log(f"  => {len(result)} processed", log_f)
    return result


# ============================================================
# [3] Quality filter
# ============================================================

def step_quality_filter(app, images_with_scores, cutoff_ratio, output_dir, reject_dir, log_f):
    """Step 3: rank images by a composite quality score and drop the lowest fraction.

    Per image, four metrics are measured:
    - sharpness: variance of the Laplacian of the grayscale image;
    - det_score: the largest face's detection confidence (0.0 if no face);
    - face_ratio: the share of pixels covered by that face's bbox (0.0 if no face);
    - resolution: the pixel count.

    Each metric is min-max normalised across the batch (1.0 for every image
    when the metric is constant). They are combined as
    0.35*sharpness + 0.30*det_score + 0.20*face_ratio + 0.15*resolution.

    The best ``max(1, int(n * (1 - cutoff_ratio)))`` images are copied to
    *output_dir*; the rest go to ``<reject_dir>/03_quality``. Unreadable images
    are dropped without logging.

    Returns:
        ``(path_in_output_dir, id_score)`` tuples for the kept images, best first.
    """
    bar = '=' * 60
    log(f"\n{bar}", log_f)
    log(f"[3/7] Quality filter (drop bottom {cutoff_ratio:.0%})", log_f)
    log(bar, log_f)

    records = []
    for path, id_score in images_with_scores:
        image = imread_unicode(path)
        if image is None:
            continue

        pixel_count = image.shape[0] * image.shape[1]
        grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        record = {
            'path': path,
            'fname': os.path.basename(path),
            'id_score': id_score,
            'sharpness': cv2.Laplacian(grayscale, cv2.CV_64F).var(),
            'det_score': 0.0,
            'face_ratio': 0.0,
            'resolution': pixel_count,
        }
        faces = app.get(image)
        if faces:
            face = get_largest_face(faces)
            box = face.bbox
            record['det_score'] = float(face.det_score)
            record['face_ratio'] = (box[2] - box[0]) * (box[3] - box[1]) / pixel_count
        records.append(record)

    if not records:
        return []

    # Min-max normalise every metric to 0..1 across the batch.
    for metric in ('sharpness', 'det_score', 'face_ratio', 'resolution'):
        low = min(r[metric] for r in records)
        high = max(r[metric] for r in records)
        for r in records:
            r[metric + '_n'] = (r[metric] - low) / (high - low) if high > low else 1.0

    weights = (
        ('sharpness_n', 0.35),
        ('det_score_n', 0.30),
        ('face_ratio_n', 0.20),
        ('resolution_n', 0.15),
    )
    for r in records:
        r['quality'] = sum(r[key] * weight for key, weight in weights)

    records.sort(key=lambda r: r['quality'], reverse=True)
    keep_count = max(1, int(len(records) * (1 - cutoff_ratio)))
    kept, dropped = records[:keep_count], records[keep_count:]

    def describe(r):
        return f"Q={r['quality']:.3f}  sharp={r['sharpness']:>8.0f}  det={r['det_score']:.3f}  {r['fname']}"

    result = []
    for r in kept:
        result.append((safe_copy(r['path'], output_dir, r['fname']), r['id_score']))
        log(f"  [PASS]   {describe(r)}", log_f)

    quality_reject_dir = os.path.join(reject_dir, "03_quality")
    for r in dropped:
        safe_copy(r['path'], quality_reject_dir, r['fname'])
        log(f"  [REJECT] {describe(r)}", log_f)

    log(f"  => {len(result)}/{len(records)} passed", log_f)
    return result


# ============================================================
# [3.5] Face occlusion filter (detects hair etc. covering the face)
# ============================================================

def step_occlusion_filter(app, images_with_scores, threshold, output_dir, reject_dir, log_f):
    """Step 3.5: reject photos whose inner face region is dominated by very dark pixels.

    The inner region is the box spanned by landmarks 0/1 (eyes) and 3/4 (mouth
    corners), as the variable names below assume. It is padded by 30% of its
    width on each side and 15% of its height above and below, then clipped to
    the face bbox. "VDark" is the fraction of its grayscale pixels below 50.
    Images with VDark > *threshold* go to ``<reject_dir>/03_occlusion``.
    Hair or hands over the face raise the ratio, but so does any dark content
    (e.g. low-key lighting).

    Images with no re-detected face, no landmarks, or an empty inner region
    pass unchanged (and are logged as such). Unreadable images are logged and
    dropped.

    Returns:
        ``(path_in_output_dir, score)`` tuples for the passing images.
    """
    log(f"\n{'='*60}", log_f)
    log(f"[3.5/7] Face occlusion filter (drop VDark > {threshold})", log_f)
    log(f"{'='*60}", log_f)

    result = []
    rejected_count = 0
    for img_path, score in images_with_scores:
        fname = os.path.basename(img_path)
        img = imread_unicode(img_path)
        if img is None:
            log(f"  [SKIP] Load failed: {fname}", log_f)
            continue

        faces = app.get(img)
        if not faces:
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, score))
            log(f"  [PASS]   vdark=N/A (re-detect failed): {fname}", log_f)
            continue

        face = get_largest_face(faces)
        x1, y1, x2, y2 = [int(v) for v in face.bbox]
        h, w = img.shape[:2]
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(w, x2), min(h, y2)

        # getattr guards face objects without a kps attribute, matching the
        # hasattr checks used by the diversity/captioning steps.
        kps = getattr(face, 'kps', None)
        if kps is None:
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, score))
            log(f"  [PASS]   vdark=N/A (no landmarks): {fname}", log_f)
            continue

        # Inner face region (eyes to mouth), in coordinates relative to the clipped face crop
        eye_top = int(min(kps[0][1], kps[1][1])) - y1
        mouth_bottom = int(max(kps[3][1], kps[4][1])) - y1
        left_x = int(min(kps[0][0], kps[3][0])) - x1
        right_x = int(max(kps[1][0], kps[4][0])) - x1

        # 30% lateral / 15% vertical padding, clamped to the face crop
        pad_x = int((right_x - left_x) * 0.3)
        pad_y = int((mouth_bottom - eye_top) * 0.15)
        ey = max(0, eye_top - pad_y)
        my = min(y2 - y1, mouth_bottom + pad_y)
        lx = max(0, left_x - pad_x)
        rx = min(x2 - x1, right_x + pad_x)

        face_crop = img[y1:y2, x1:x2]
        inner = face_crop[ey:my, lx:rx]
        if inner.size == 0:
            # Landmarks fell outside the bbox (or were degenerate): nothing to
            # measure. Pass, but say so, like the other N/A paths.
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, score))
            log(f"  [PASS]   vdark=N/A (empty inner region): {fname}", log_f)
            continue

        gray_inner = cv2.cvtColor(inner, cv2.COLOR_BGR2GRAY)
        vdark_ratio = float(np.count_nonzero(gray_inner < 50) / max(1, gray_inner.size))

        if vdark_ratio > threshold:
            safe_copy(img_path, os.path.join(reject_dir, "03_occlusion"), fname)
            log(f"  [REJECT] vdark={vdark_ratio:.3f}: {fname}", log_f)
            rejected_count += 1
        else:
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, score))
            log(f"  [PASS]   vdark={vdark_ratio:.3f}: {fname}", log_f)

    log(f"  => {len(result)}/{len(result)+rejected_count} passed ({rejected_count} occluded removed)", log_f)
    return result


# ============================================================
# [4] Tone consistency filter (LAB color space)
# ============================================================

def step_tone_filter(images_with_scores, threshold, output_dir, reject_dir, log_f):
    """Step 4: drop colour/brightness outliers using each image's mean LAB values.

    Each image is reduced to the mean of its L, A and B channels (OpenCV LAB).
    Per channel, the absolute z-score against the batch is computed, using the
    population std; z is 0 when a channel is constant. An image whose largest
    z exceeds *threshold* is copied to ``<reject_dir>/tone_outlier``; the
    others are copied to *output_dir*.

    With 2 or fewer inputs, or 2 or fewer loadable images, everything is copied
    through unfiltered. Note: with n images the population z-score of any one
    image is at most sqrt(n - 1) (Samuelson's inequality). So a threshold of
    3.0 cannot reject anything until there are at least 11 images.

    Returns:
        ``(path_in_output_dir, score)`` tuples for the kept images.
    """
    bar = '=' * 60
    log(f"\n{bar}", log_f)
    log(f"[4/7] Tone consistency filter (drop z-score > {threshold})", log_f)
    log(bar, log_f)

    def passthrough(entries):
        copied = []
        for path, score in entries:
            copied.append((safe_copy(path, output_dir), score))
        return copied

    if len(images_with_scores) <= 2:
        log("  [SKIP] 2 or fewer images — skipping filter", log_f)
        return passthrough(images_with_scores)

    # Mean L, A, B per loadable image
    stats = []
    for path, score in images_with_scores:
        image = imread_unicode(path)
        if image is None:
            continue
        lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB).astype(np.float32)
        channel_means = [lab[:, :, c].mean() for c in range(3)]
        stats.append({
            'path': path,
            'score': score,
            'fname': os.path.basename(path),
            'L': channel_means[0],
            'A': channel_means[1],
            'B': channel_means[2],
        })

    if len(stats) <= 2:
        return passthrough([(s['path'], s['score']) for s in stats])

    # Absolute per-channel z-score against the batch
    for channel in ('L', 'A', 'B'):
        column = np.array([s[channel] for s in stats])
        mu, sigma = column.mean(), column.std()
        for s in stats:
            s[channel + '_z'] = abs(s[channel] - mu) / sigma if sigma > 0 else 0.0

    outlier_dir = os.path.join(reject_dir, "tone_outlier")
    kept = []
    removed = 0
    for s in stats:
        worst = max(s['L_z'], s['A_z'], s['B_z'])
        detail = f"z={worst:.2f} (L={s['L']:.1f} A={s['A']:.1f} B={s['B']:.1f}): {s['fname']}"
        if worst > threshold:
            safe_copy(s['path'], outlier_dir, s['fname'])
            log(f"  [REJECT] {detail}", log_f)
            removed += 1
        else:
            kept.append((safe_copy(s['path'], output_dir, s['fname']), s['score']))
            log(f"  [PASS]   {detail}", log_f)

    log(f"  => {len(kept)}/{len(stats)} passed ({removed} tone outliers removed)", log_f)
    return kept


# ============================================================
# [5] Deduplication
# ============================================================

def _select_duplicates_to_remove(duplicates, score_map):
    """Return the file names to drop so each duplicate group keeps only its best-scoring member.

    Args:
        duplicates: Mapping of file name -> names judged similar to it
            (as returned by imagededup's ``find_duplicates``).
        score_map: Mapping of file name -> score for the files of this run.

    Names missing from *score_map* (e.g. leftovers from a previous run sitting
    in the scanned folder) are ignored entirely. Keys and groups are walked in
    sorted order, so the result does not depend on set/hash ordering.
    """
    to_remove = set()
    processed = set()
    for fname in sorted(duplicates):
        if fname in processed or fname not in score_map:
            continue
        group = sorted({fname, *(d for d in duplicates[fname] if d in score_map)})
        if len(group) <= 1:
            continue
        # max() keeps the first of equal maxima; group is sorted, so ties go to
        # the alphabetically first name: deterministic across runs.
        best = max(group, key=lambda g: score_map[g])
        for g in group:
            processed.add(g)
            if g != best:
                to_remove.add(g)
    return to_remove


def step_dedup(images_with_scores, threshold, output_dir, reject_dir, log_f):
    """Step 5: remove near-duplicate images, keeping the best-scoring image of each group.

    Runs imagededup's CNN over the folder containing *images_with_scores* (all
    entries are expected to live in that one folder, the previous step's
    output) with ``min_similarity_threshold=threshold``. In each duplicate group
    the image with the highest score survives; the others are copied to
    ``<reject_dir>/05_duplicate``. With at most one image, or when imagededup
    is not installed, every image is copied through unchanged.

    Returns:
        ``(path_in_output_dir, score)`` tuples for the kept images.
    """
    log(f"\n{'='*60}", log_f)
    log(f"[5/7] Deduplication (threshold={threshold})", log_f)
    log(f"{'='*60}", log_f)

    if len(images_with_scores) <= 1:
        return [(safe_copy(p, output_dir), s) for p, s in images_with_scores]

    try:
        from imagededup.methods import CNN
    except ImportError:
        log("  [WARN] imagededup not installed. Skipping deduplication.", log_f)
        log("         install: pip install imagededup", log_f)
        result = []
        for img_path, score in images_with_scores:
            dst = safe_copy(img_path, output_dir)
            result.append((dst, score))
        return result

    # imagededup operates on a folder -> use the folder from the previous step directly
    src_dir = os.path.dirname(images_with_scores[0][0])
    score_map = {os.path.basename(p): s for p, s in images_with_scores}

    cnn = CNN()
    duplicates = cnn.find_duplicates(image_dir=src_dir, min_similarity_threshold=threshold)
    to_remove = _select_duplicates_to_remove(duplicates, score_map)

    result = []
    for img_path, score in images_with_scores:
        fname = os.path.basename(img_path)
        if fname in to_remove:
            safe_copy(img_path, os.path.join(reject_dir, "05_duplicate"), fname)
            log(f"  [REJECT] duplicate: {fname}", log_f)
        else:
            dst = safe_copy(img_path, output_dir, fname)
            result.append((dst, score))
            log(f"  [PASS]   {fname}", log_f)

    log(f"  => {len(result)}/{len(images_with_scores)} passed ({len(to_remove)} duplicates removed)", log_f)
    return result


# ============================================================
# [6] Diversity report
# ============================================================

def estimate_yaw(kps):
    """Estimate left/right head rotation from 5-point landmarks.

    Returns the horizontal offset of the nose (landmark 2) from the midpoint of
    the eyes (landmarks 0 and 1), divided by the inter-eye distance:
    - about 0 for a frontal face;
    - negative when the nose sits left of the eye midpoint in image coordinates;
    - positive when it sits right of it.

    The value is not clamped to [-1, 1]. Returns 0.0 when the two eye
    landmarks coincide.

    Args:
        kps: Landmarks with at least 3 (x, y) points, either a NumPy array or
            any nested sequence (lists/tuples are converted with np.asarray).
    """
    # np.asarray makes list input work: with plain lists, "+" would concatenate
    # and "/ 2" would raise TypeError.
    points = np.asarray(kps)
    left_eye, right_eye, nose = points[0], points[1], points[2]
    eye_center = (left_eye + right_eye) / 2
    eye_dist = np.linalg.norm(right_eye - left_eye)
    if eye_dist == 0:
        return 0.0
    return float((nose[0] - eye_center[0]) / eye_dist)


def step_diversity_report(app, images_with_scores, log_f):
    """Step 6: bucket the final images by head pose and warn if the set is unbalanced.

    Pose comes from |estimate_yaw| of the largest face:
    - below 0.15 -> front;
    - below 0.35 -> three_quarter;
    - otherwise  -> side.

    Images with no detected face or no landmarks count as front. Unreadable
    images are skipped but still count towards the percentage denominator.
    Only writes to the log; returns None.
    """
    bar = '=' * 60
    log(f"\n{bar}", log_f)
    log("[6/7] Diversity report", log_f)
    log(bar, log_f)

    poses = {'front': [], 'three_quarter': [], 'side': []}

    for path, _score in images_with_scores:
        image = imread_unicode(path)
        if image is None:
            continue
        name = os.path.basename(path)

        landmarks = None
        faces = app.get(image)
        if faces:
            landmarks = getattr(get_largest_face(faces), 'kps', None)

        if landmarks is None:
            bucket = 'front'
        else:
            yaw = abs(estimate_yaw(landmarks))
            if yaw < 0.15:
                bucket = 'front'
            elif yaw < 0.35:
                bucket = 'three_quarter'
            else:
                bucket = 'side'
        poses[bucket].append(name)

    total = len(images_with_scores)
    for pose, names in poses.items():
        share = len(names) / total * 100 if total > 0 else 0
        log(f"  {pose}: {len(names)} ({share:.0f}%)", log_f)
        for entry in names:
            log(f"    - {entry}", log_f)

    if total <= 0:
        return

    front_share = len(poses['front']) / total
    if front_share > 0.7:
        log(f"\n  [!] Warning: front shots are {front_share:.0%} of dataset. Add more 3/4 view and side shots.", log_f)
    if not poses['side']:
        log("  [!] Warning: no side shots. Adding 1-2 is recommended.", log_f)


# ============================================================
# [7] Captioning
# ============================================================

def step_captioning(app, images_with_scores, trigger, gender, output_dir, skip, log_f):
    """Step 7: write a template caption ``<stem>.txt`` into *output_dir* for each image.

    Caption format:
    "<trigger>, a young <man|woman>, <gaze>, <orientation>, upper body".
    Orientation and gaze come from estimate_yaw of the largest face:
    - |yaw| < 0.15 -> "front view" / "looking at viewer";
    - |yaw| < 0.35 -> "three-quarter view" / "looking slightly <dir>";
    - otherwise    -> "side view" / "looking away, from <dir> side".
    <dir> is "right" for positive yaw and "left" otherwise.

    Images with no face or no landmarks get the front-view wording; unreadable
    images get no caption. Existing .txt files are overwritten. When *skip*
    is true nothing is written.
    """
    bar = '=' * 60
    log(f"\n{bar}", log_f)
    log(f"[7/7] Captioning (trigger={trigger}, gender={gender})", log_f)
    log(bar, log_f)

    if skip:
        log("  [SKIP] Skipped via --skip-caption", log_f)
        return

    subject = "man" if gender == "man" else "woman"

    for path, _score in images_with_scores:
        name = os.path.basename(path)
        image = imread_unicode(path)
        if image is None:
            continue

        # Defaults used when no face/landmarks are available
        orientation, gaze = "front view", "looking at viewer"

        faces = app.get(image)
        if faces:
            landmarks = getattr(get_largest_face(faces), 'kps', None)
            if landmarks is not None:
                signed_yaw = estimate_yaw(landmarks)
                side = "right" if signed_yaw > 0 else "left"
                magnitude = abs(signed_yaw)
                if magnitude < 0.15:
                    orientation, gaze = "front view", "looking at viewer"
                elif magnitude < 0.35:
                    orientation, gaze = "three-quarter view", f"looking slightly {side}"
                else:
                    orientation, gaze = "side view", f"looking away, from {side} side"

        caption = f"{trigger}, a young {subject}, {gaze}, {orientation}, upper body"

        caption_path = os.path.join(output_dir, Path(name).stem + ".txt")
        with open(caption_path, 'w', encoding='utf-8') as handle:
            handle.write(caption)

        log(f"  {name}", log_f)
        log(f"    -> {caption}", log_f)

    for note in (
        "\n  [INFO] Captions are auto-generated templates.",
        "  [INFO] Open each .txt and add [hair color], [clothing], [background], etc.",
        "  [INFO] djdante style: short and objective. No subjective words (beautiful, etc.).",
    ):
        log(note, log_f)


# ============================================================
# Main
# ============================================================

def _training_recommendation(n_images):
    """Return suggested AI-Toolkit ``(num_repeats, steps, approx_epochs)`` for *n_images* images.

    - up to 8 images: 4 repeats, 3500 steps;
    - up to 15 images: 2 repeats, 4000 steps;
    - more: 1 repeat, 4000 steps.

    Epochs are estimated as steps * 2 / (n_images * repeats). The factor 2
    presumably assumes batch size 2 (TODO confirm).
    """
    if n_images <= 8:
        repeats, steps = 4, 3500
    elif n_images <= 15:
        repeats, steps = 2, 4000
    else:
        repeats, steps = 1, 4000
    return repeats, steps, steps * 2 / (n_images * repeats)


def main():
    """CLI entry point: parse arguments, run steps 1-7 in order and write the report.

    Each stage writes into its own sub-folder of ``--output``. Rejected images
    are collected under ``<output>/rejected``. Everything logged is mirrored to
    ``<output>/pipeline_report.txt``. Exits with status 1 on a missing input
    folder or reference image, an input folder with no images, or when no image
    passes identity verification.
    """
    parser = argparse.ArgumentParser(
        description="LoRA dataset curation pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--input', required=True, help='Source photos folder')
    parser.add_argument('--ref', required=True, help='Reference image (front-facing close-up recommended)')
    parser.add_argument('--trigger', required=True, help='Trigger word (English)')
    parser.add_argument('--gender', required=True, choices=['man', 'woman'], help='Subject gender')
    parser.add_argument('--output', default='./pipeline_output', help='Output folder (default: ./pipeline_output)')
    parser.add_argument('--id-threshold', type=float, default=0.41, help='Identity match threshold (default: 0.41)')
    parser.add_argument('--quality-cut', type=float, default=0.0, help='Quality bottom cutoff (default: 0.0 = disabled)')
    parser.add_argument('--dup-threshold', type=float, default=1.0, help='Duplicate threshold (default: 1.0 = disabled)')
    parser.add_argument('--tone-threshold', type=float, default=3.0, help='Tone outlier z-score threshold (default: 3.0)')
    parser.add_argument('--occlusion-threshold', type=float, default=0.15, help='Face occlusion VDark threshold (default: 0.15)')
    parser.add_argument('--skip-crop', action='store_true', help='Skip crop (already chest-up)')
    parser.add_argument('--skip-caption', action='store_true', help='Skip captioning')

    args = parser.parse_args()

    # Path validation
    if not os.path.isdir(args.input):
        print(f"[ERROR] Input folder not found: {args.input}")
        sys.exit(1)
    if not os.path.isfile(args.ref):
        print(f"[ERROR] Reference image not found: {args.ref}")
        sys.exit(1)

    images = find_images(args.input)
    if not images:
        print(f"[ERROR] No images found: {args.input}")
        sys.exit(1)

    # Output folders
    dirs = {
        'verified':      os.path.join(args.output, '01_verified'),
        'cropped':       os.path.join(args.output, '02_cropped'),
        'quality':       os.path.join(args.output, '03_quality_filtered'),
        'occlusion':     os.path.join(args.output, '03b_occlusion_filtered'),
        'tone_filtered': os.path.join(args.output, '04_tone_filtered'),
        'deduped':       os.path.join(args.output, '05_deduplicated'),
        'final':         os.path.join(args.output, '06_final'),
        'rejected':      os.path.join(args.output, 'rejected'),
    }
    # Folders that already hold files come from an earlier run. Those files are
    # not removed: they would be scanned by deduplication (it reads the whole
    # folder) and would end up alongside this run's images in 06_final.
    stale_dirs = [d for d in dirs.values() if os.path.isdir(d) and os.listdir(d)]
    for d in dirs.values():
        os.makedirs(d, exist_ok=True)

    log_path = os.path.join(args.output, 'pipeline_report.txt')
    # The context manager closes (and flushes) the report even if a step raises
    # or sys.exit() is called below.
    with open(log_path, 'w', encoding='utf-8') as log_f:
        log("LoRA Dataset Curation Pipeline", log_f)
        log(f"Run: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", log_f)
        log(f"Input: {args.input} ({len(images)} images)", log_f)
        log(f"Reference: {args.ref}", log_f)
        log(f"Trigger: {args.trigger} / Gender: {args.gender}", log_f)
        if stale_dirs:
            log("  [WARN] These output folders already contained files before this run;", log_f)
            log("         leftovers are NOT removed and will mix with the new results:", log_f)
            for d in stale_dirs:
                log(f"           {d}", log_f)

        # Initialize InsightFace
        print("\nLoading InsightFace model...")
        from insightface.app import FaceAnalysis
        face_app = FaceAnalysis(name="buffalo_l", allowed_modules=["detection", "recognition"])
        face_app.prepare(ctx_id=-1, det_size=(640, 640))
        print("Loaded.\n")

        # Run pipeline
        verified = step_identity_check(
            face_app, images, args.ref, args.id_threshold,
            dirs['verified'], dirs['rejected'], log_f
        )
        if not verified:
            log("\n[FATAL] 0 images passed identity verification. Aborting.", log_f)
            sys.exit(1)

        cropped = step_chest_crop(
            face_app, verified, dirs['cropped'], args.skip_crop, log_f
        )

        quality_filtered = step_quality_filter(
            face_app, cropped, args.quality_cut,
            dirs['quality'], dirs['rejected'], log_f
        )

        occlusion_filtered = step_occlusion_filter(
            face_app, quality_filtered, args.occlusion_threshold,
            dirs['occlusion'], dirs['rejected'], log_f
        )

        tone_filtered = step_tone_filter(
            occlusion_filtered, args.tone_threshold,
            dirs['tone_filtered'], dirs['rejected'], log_f
        )

        deduped = step_dedup(
            tone_filtered, args.dup_threshold,
            dirs['deduped'], dirs['rejected'], log_f
        )

        # Copy to final folder
        final_images = []
        for img_path, score in deduped:
            dst = safe_copy(img_path, dirs['final'])
            final_images.append((dst, score))

        step_diversity_report(face_app, final_images, log_f)

        step_captioning(
            face_app, final_images, args.trigger, args.gender,
            dirs['final'], args.skip_caption, log_f
        )

        # ==================== Final summary ====================
        log(f"\n{'='*60}", log_f)
        log("Final results", log_f)
        log(f"{'='*60}", log_f)
        log(f"  Source:           {len(images)}", log_f)
        log(f"  Identity passed:  {len(verified)}", log_f)
        log(f"  Cropped:          {len(cropped)}", log_f)
        log(f"  Quality passed:   {len(quality_filtered)}", log_f)
        log(f"  Occlusion passed: {len(occlusion_filtered)}", log_f)
        log(f"  Tone passed:      {len(tone_filtered)}", log_f)
        log(f"  After dedup:      {len(deduped)}", log_f)
        log("", log_f)
        log(f"  Final dataset: {dirs['final']}", log_f)
        log(f"  Rejected:      {dirs['rejected']}", log_f)
        log(f"  Report:        {log_path}", log_f)

        n = len(deduped)
        if n > 0:
            rr, rs, epochs = _training_recommendation(n)
            log("\n  AI-Toolkit recommended settings:", log_f)
            log(f"    num_repeats: {rr}", log_f)
            log(f"    steps:       {rs}", log_f)
            log(f"    epochs:      ~{epochs:.0f}", log_f)

        if not args.skip_caption:
            log("\n  Next steps:", log_f)
            log(f"    1. Manually augment .txt captions in {dirs['final']} with [hair color], [clothing], [background]", log_f)
            log("    2. After augmentation, point AI-Toolkit datasets path to that folder", log_f)

    print(f"\nDone! Report: {log_path}")


if __name__ == "__main__":
    # Script entry point: parse the CLI arguments and run the whole pipeline.
    main()
