"""
LoRA Dataset Bible Validation Tool
====================================
Comprehensively validates a dataset against lora_dataset_bible.md and emits a report.
Independent of dataset_pipeline.py (used as final validation after the pipeline).

Usage:
  python dataset_validator.py --dataset <dataset_folder> [--name <name>]
  python dataset_validator.py --dataset G:\StabilityMatrix\Packages\AI-Toolkit\datasets\joseonnamja --name joseonnamja

Required packages (facecheck env):
  pip install insightface onnxruntime opencv-python "numpy>=1.26,<2.0"
"""

import argparse
import os
import re
import sys
import json
import cv2
import numpy as np
from datetime import datetime

IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}

# ============================================================
# Bible criteria constants
# ============================================================

# Hard NO criteria
MIN_RESOLUTION = 512  # Minimum px on shortest side
MIN_SHARPNESS = 30    # Laplacian variance (beauty filter: 12~25, normal: hundreds~thousands)
MAX_FACES = 1         # Single subject only

# Diversity 5-axis minimum count
MIN_ANGLES = 2        # Angle types (front/side/back-side)
MIN_DISTANCES = 2     # Distance types (close-up/upper body/full body)
MIN_CLOSEUPS = 5      # Minimum number of close-ups

# Image count range
MIN_IMAGES = 10
MAX_IMAGES = 50
OPTIMAL_MAX = 30

# num_repeats threshold
MAX_SAFE_REPEATS = 3

# Forbidden caption words
FORBIDDEN_CAPTION_WORDS = [
    # Abstract quality words
    "high quality", "best quality", "masterpiece", "8k", "4k", "uhd",
    "high resolution", "high-resolution", "hq",
    # Abstract lighting words
    "soft lighting", "cinematic lighting", "dramatic lighting",
    "perfect lighting", "beautiful lighting",
    # Subjective evaluation
    "beautiful", "attractive", "pretty", "gorgeous", "stunning",
    "handsome", "cute", "sexy",
    # Medium quality
    "realistic", "photorealistic", "hyperrealistic", "ultra realistic",
    "photo-realistic", "lifelike",
    "detailed", "highly detailed", "ultra detailed", "sharp focus",
    # Other
    "masterwork", "professional photo", "award winning",
    "raw photo",  # debated, but unnecessary by djdante standard
]

# Fixed identity words (warn if present in caption)
IDENTITY_CAPTION_WARNINGS = [
    "brown eyes", "blue eyes", "green eyes", "dark eyes", "black eyes",
    "round face", "oval face", "square jaw", "sharp jaw",
    "small nose", "big nose", "high cheekbones",
    "slim build", "athletic build", "muscular build",
    # Innate traits in their natural state
    "natural black hair", "natural brown hair",  # natural hair color, not dyed
]


def imread_unicode(path):
    """Load an image with unicode/non-ASCII path support.

    cv2.imread fails on non-ASCII paths on Windows, so read the raw bytes
    with numpy and decode them instead.
    """
    buf = np.fromfile(str(path), dtype=np.uint8)
    if buf.size == 0:
        return None  # empty/unreadable file; callers treat None as a load failure
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)


# ============================================================
# Image analysis
# ============================================================

def analyze_image(img_path, app):
    """Analyze a single image against the bible criteria."""
    result = {
        'path': str(img_path),
        'fname': os.path.basename(img_path),
        'issues': [],       # Hard NO violations
        'warnings': [],     # Cautions
        'metrics': {},
    }

    img = imread_unicode(img_path)
    if img is None:
        result['issues'].append("Failed to load image")
        return result

    h, w = img.shape[:2]
    result['metrics']['width'] = w
    result['metrics']['height'] = h
    result['metrics']['resolution'] = f"{w}x{h}"

    # 1. Resolution check
    min_side = min(h, w)
    if min_side < MIN_RESOLUTION:
        result['issues'].append(f"Low resolution: {min_side}px < {MIN_RESOLUTION}px (hard NO)")
    result['metrics']['min_side'] = min_side

    # 2. Sharpness (Laplacian variance)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # .var() yields numpy.float64; cast to float so the JSON dump works later
    sharpness = float(cv2.Laplacian(gray, cv2.CV_64F).var())
    result['metrics']['sharpness'] = round(sharpness, 1)

    if sharpness < MIN_SHARPNESS:
        result['issues'].append(
            f"Extremely low sharpness: {sharpness:.0f} (threshold: {MIN_SHARPNESS}+). "
            "Suspect beauty filter / skin smoothing"
        )
    elif sharpness < 100:
        result['warnings'].append(f"Low sharpness: {sharpness:.0f} (typical photos: hundreds~thousands)")

    # 3. Face detection
    faces = app.get(img)
    result['metrics']['face_count'] = len(faces)

    if len(faces) == 0:
        result['issues'].append("No face detected")
    elif len(faces) > MAX_FACES:
        result['issues'].append(f"Multiple faces detected: {len(faces)} people (hard NO)")
    else:
        face = max(faces, key=lambda f: (f.bbox[2]-f.bbox[0])*(f.bbox[3]-f.bbox[1]))
        det_score = float(face.det_score)
        result['metrics']['det_score'] = round(det_score, 3)

        # Face ratio
        face_w = float(face.bbox[2] - face.bbox[0])
        face_h = float(face.bbox[3] - face.bbox[1])
        face_area = face_w * face_h
        img_area = h * w
        face_ratio = face_area / img_area  # plain float, safe for json.dump
        result['metrics']['face_ratio'] = round(face_ratio, 3)
        result['metrics']['face_px'] = f"{int(face_w)}x{int(face_h)}"

        # Face size classification (distance)
        if face_ratio > 0.25:
            result['metrics']['distance'] = 'closeup'
        elif face_ratio > 0.08:
            result['metrics']['distance'] = 'upper_body'
        else:
            result['metrics']['distance'] = 'full_body'
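        # Worked example: a 540x540px face in a 1024x1536 frame gives
        # 291,600 / 1,572,864 ~ 0.185, which lands in 'upper_body'.
        # The cutoffs are heuristics mapped to the bible's three distance buckets.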

        if face_w < 80 or face_h < 80:
            result['warnings'].append(
                f"Face is very small: {int(face_w)}x{int(face_h)}px. "
                "Hard to learn facial detail"
            )

        # Gaze/angle estimation (yaw-based)
        if hasattr(face, 'pose') and face.pose is not None:
            yaw = abs(face.pose[1])  # Left/right rotation
            pitch = abs(face.pose[0])  # Up/down
            result['metrics']['yaw'] = round(float(yaw), 1)
            result['metrics']['pitch'] = round(float(pitch), 1)

            if yaw < 15:
                result['metrics']['angle'] = 'front'
            elif yaw < 35:
                result['metrics']['angle'] = 'slight_side'
            elif yaw < 55:
                result['metrics']['angle'] = 'side'
            elif yaw < 75:
                result['metrics']['angle'] = 'back_side'
            else:
                result['metrics']['angle'] = 'back'
                result['warnings'].append(
                    f"Extreme side/back-of-head: yaw={yaw:.0f}deg. "
                    "Near-pure back-of-head shots (75deg+) waste ROI"
                )

    # 4. JPEG compression artifact estimation
    ext = os.path.splitext(img_path)[1].lower()
    if ext in ('.jpg', '.jpeg'):
        # Heuristic: a very small file at adequate resolution implies heavy compression
        result['metrics']['format'] = 'JPEG'
        if os.path.getsize(img_path) < 50_000 and min_side >= 512:
            result['warnings'].append("Very small JPEG file size — suspect heavy compression")
    else:
        result['metrics']['format'] = ext.upper().lstrip('.')

    return result
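
# Illustrative (made-up) metrics dict for a single passing image:
#   {'width': 1024, 'height': 1536, 'resolution': '1024x1536', 'min_side': 1024,
#    'sharpness': 412.7, 'face_count': 1, 'det_score': 0.91,
#    'face_ratio': 0.185, 'face_px': '540x540', 'distance': 'upper_body',
#    'yaw': 12.0, 'pitch': 4.0, 'angle': 'front', 'format': 'JPEG'}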


def analyze_caption(txt_path):
    """Validate a caption file against the bible criteria."""
    result = {
        'path': str(txt_path),
        'fname': os.path.basename(txt_path),
        'issues': [],
        'warnings': [],
        'caption': '',
    }

    if not os.path.exists(txt_path):
        result['issues'].append("Caption file missing")
        return result

    with open(txt_path, 'r', encoding='utf-8') as f:
        caption = f.read().strip()

    result['caption'] = caption

    if not caption:
        result['issues'].append("Caption is empty")
        return result

    caption_lower = caption.lower()

    # Forbidden word check (word-boundary match so short entries like "hq"
    # don't fire on substrings of longer words)
    for word in FORBIDDEN_CAPTION_WORDS:
        if re.search(rf"\b{re.escape(word)}\b", caption_lower):
            result['issues'].append(f"Forbidden caption word: \"{word}\"")

    # Fixed identity description warning (same word-boundary matching)
    for word in IDENTITY_CAPTION_WARNINGS:
        if re.search(rf"\b{re.escape(word)}\b", caption_lower):
            result['warnings'].append(
                f"Suspected fixed identity description: \"{word}\" — "
                "Consensus is to omit innate traits from captions "
                "(OK if variable, e.g. dyed hair / contacts)"
            )

    # Caption structure check (djdante format)
    if ',' not in caption:
        result['warnings'].append("No commas in caption — verify djdante format (trigger, gaze, view, clothing, bg)")

    if len(caption) < 20:
        result['warnings'].append(f"Caption too short ({len(caption)} chars) — likely lacking concrete description")

    return result
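
# Illustrative captions against the checks above (the specifics are hypothetical;
# order follows the djdante hint: trigger, gaze, view, clothing, bg):
#   PASS: "joseonnamja, looking at camera, front view, black coat, night street"
#   FAIL: "masterpiece, best quality, handsome man, brown eyes"
#         (three forbidden words, plus a fixed-identity warning for "brown eyes")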


# ============================================================
# Diversity analysis
# ============================================================

def analyze_diversity(image_results):
    """Analyze whole-dataset diversity across the 5 axes."""
    report = {
        'total': len(image_results),
        'axes': {},
        'issues': [],
        'warnings': [],
    }

    # Angle distribution
    angles = {}
    for r in image_results:
        angle = r['metrics'].get('angle', 'unknown')
        angles[angle] = angles.get(angle, 0) + 1
    report['axes']['angles'] = angles

    angle_types = len([a for a in angles if a != 'unknown'])
    if angle_types < MIN_ANGLES:
        report['issues'].append(
            f"Insufficient angle diversity: {angle_types} types (need >= {MIN_ANGLES}). "
            "Mix of front/side/back-side required"
        )

    # Distance distribution
    distances = {}
    for r in image_results:
        dist = r['metrics'].get('distance', 'unknown')
        distances[dist] = distances.get(dist, 0) + 1
    report['axes']['distances'] = distances

    closeup_count = distances.get('closeup', 0)
    if closeup_count < MIN_CLOSEUPS:
        report['issues'].append(
            f"Not enough close-ups: {closeup_count} (need >= {MIN_CLOSEUPS}). "
            "Required for facial detail learning"
        )

    distance_types = len([d for d in distances if d != 'unknown'])
    if distance_types < MIN_DISTANCES:
        report['issues'].append(
            f"Insufficient distance diversity: {distance_types} types (need >= {MIN_DISTANCES})"
        )

    # Sharpness distribution
    sharpness_values = [r['metrics'].get('sharpness', 0) for r in image_results]
    if sharpness_values:
        report['axes']['sharpness'] = {
            'min': round(min(sharpness_values), 1),
            'max': round(max(sharpness_values), 1),
            'mean': round(float(np.mean(sharpness_values)), 1),
            'std': round(float(np.std(sharpness_values)), 1),
        }

    # Image count check
    n = len(image_results)
    if n < MIN_IMAGES:
        report['issues'].append(f"Not enough images: {n} (need >= {MIN_IMAGES})")
    elif n > MAX_IMAGES:
        report['warnings'].append(
            f"Too many images: {n} ({MAX_IMAGES}+ is reported to hurt performance). "
            "Re-check whether all are truly high quality"
        )
    elif n > OPTIMAL_MAX:
        report['warnings'].append(f"{n} images ({OPTIMAL_MAX} or fewer is optimal). Verify quality wasn't compromised")

    return report
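
# Illustrative (made-up) axis distributions for a passing 25-image set:
#   angles:    {'front': 10, 'slight_side': 7, 'side': 5, 'back_side': 3}  -> 4 types >= MIN_ANGLES
#   distances: {'closeup': 8, 'upper_body': 11, 'full_body': 6}            -> 3 types >= MIN_DISTANCES,
#              8 closeups >= MIN_CLOSEUPS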


# ============================================================
# Exposure budget calculation
# ============================================================

def check_exposure(n_images, num_repeats, steps):
    """Validate total exposure against djdante criterion."""
    exposure = n_images * num_repeats * steps
    target = 100_000

    result = {
        'images': n_images,
        'repeats': num_repeats,
        'steps': steps,
        'total_exposure': exposure,
        'target': target,
        'ratio': round(exposure / target, 2),
        'issues': [],
        'warnings': [],
    }

    if num_repeats > MAX_SAFE_REPEATS:
        result['warnings'].append(
            f"num_repeats={num_repeats} (risky). "
            f"Recommend <= {MAX_SAFE_REPEATS}. Repeating the same image leans toward overfitting"
        )

    if exposure < target * 0.5:
        result['issues'].append(
            f"Insufficient total exposure: {exposure:,} ({exposure/target:.0%} of target). Risk of undertraining"
        )
    elif exposure > target * 1.5:
        result['warnings'].append(
            f"Excessive total exposure: {exposure:,} ({exposure/target:.0%} of target). Watch for overfitting"
        )

    return result
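
# Worked example (illustrative numbers): 25 images x repeats=2 x 2,000 steps
# = 100,000 total exposure, i.e. 100% of the djdante target -> no flags.
# The same 25 images at repeats=1 x 1,500 steps = 37,500 (38%) would trip
# the undertraining issue at the 50% floor.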


# ============================================================
# Report generation
# ============================================================

def generate_report(name, image_results, caption_results, diversity, exposure=None):
    """Generate the full validation report."""
    lines = []
    critical = 0
    warning = 0

    lines.append("=" * 70)
    lines.append(f"  LoRA Dataset Bible Validation Report: {name}")
    lines.append(f"  Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    lines.append("=" * 70)

    # --- Summary ---
    img_issues = sum(len(r['issues']) for r in image_results)
    img_warnings = sum(len(r['warnings']) for r in image_results)
    cap_issues = sum(len(r['issues']) for r in caption_results)
    cap_warnings = sum(len(r['warnings']) for r in caption_results)
    div_issues = len(diversity.get('issues', []))
    div_warnings = len(diversity.get('warnings', []))

    total_issues = img_issues + cap_issues + div_issues
    total_warnings = img_warnings + cap_warnings + div_warnings

    if exposure:
        total_issues += len(exposure.get('issues', []))
        total_warnings += len(exposure.get('warnings', []))

    lines.append("")
    lines.append(f"  Total images: {len(image_results)}")

    status = "PASS" if total_issues == 0 else "FAIL"
    status_icon = "[O]" if total_issues == 0 else "[X]"
    lines.append(f"  Verdict: {status_icon} {status}")
    lines.append(f"  Violations: {total_issues}  /  Warnings: {total_warnings}")

    # --- Per-image detail ---
    lines.append("")
    lines.append("-" * 70)
    lines.append("  [Image validation]")
    lines.append("-" * 70)

    for r in image_results:
        m = r['metrics']
        flag = "[X]" if r['issues'] else "[!]" if r['warnings'] else "[O]"
        dist = m.get('distance', '?')
        angle = m.get('angle', '?')
        sharp = m.get('sharpness', 0)

        lines.append(f"  {flag} {r['fname']}")
        lines.append(f"      res={m.get('resolution','?')}  sharp={sharp:.0f}  "
                      f"face={m.get('face_px','?')}({m.get('face_ratio',0):.1%})  "
                      f"distance={dist}  angle={angle}")

        for issue in r['issues']:
            lines.append(f"      [X violation] {issue}")
        for warn in r['warnings']:
            lines.append(f"      [! warning] {warn}")

    # --- Caption validation ---
    lines.append("")
    lines.append("-" * 70)
    lines.append("  [Caption validation]")
    lines.append("-" * 70)

    for r in caption_results:
        flag = "[X]" if r['issues'] else "[!]" if r['warnings'] else "[O]"
        lines.append(f"  {flag} {r['fname']}")
        if r['caption']:
            display = r['caption'][:80] + "..." if len(r['caption']) > 80 else r['caption']
            lines.append(f"      \"{display}\"")
        for issue in r['issues']:
            lines.append(f"      [X violation] {issue}")
        for warn in r['warnings']:
            lines.append(f"      [! warning] {warn}")

    # --- Diversity ---
    lines.append("")
    lines.append("-" * 70)
    lines.append("  [Diversity analysis (5 axes)]")
    lines.append("-" * 70)

    angles = diversity['axes'].get('angles', {})
    lines.append(f"  Angle distribution:    {dict(angles)}")

    distances = diversity['axes'].get('distances', {})
    lines.append(f"  Distance distribution: {dict(distances)}")

    sharp_stats = diversity['axes'].get('sharpness', {})
    if sharp_stats:
        lines.append(f"  Sharpness:             min={sharp_stats['min']}  max={sharp_stats['max']}  "
                      f"mean={sharp_stats['mean']}  std={sharp_stats['std']}")

    for issue in diversity.get('issues', []):
        lines.append(f"  [X violation] {issue}")
    for warn in diversity.get('warnings', []):
        lines.append(f"  [! warning] {warn}")

    # --- Exposure budget ---
    if exposure:
        lines.append("")
        lines.append("-" * 70)
        lines.append("  [Exposure budget validation]")
        lines.append("-" * 70)
        lines.append(f"  images={exposure['images']}  x  repeats={exposure['repeats']}  "
                      f"x  steps={exposure['steps']}  =  {exposure['total_exposure']:,}")
        lines.append(f"  vs djdante target: {exposure['ratio']:.0%}")
        for issue in exposure.get('issues', []):
            lines.append(f"  [X violation] {issue}")
        for warn in exposure.get('warnings', []):
            lines.append(f"  [! warning] {warn}")

    # --- Bible checklist ---
    lines.append("")
    lines.append("-" * 70)
    lines.append("  [Bible checklist]")
    lines.append("-" * 70)

    n = len(image_results)
    checks = [
        (MIN_IMAGES <= n <= OPTIMAL_MAX, f"Image count {n} ({MIN_IMAGES}~{OPTIMAL_MAX} optimal)"),
        (img_issues == 0, f"All 5 hard-NOs eliminated (violations: {img_issues})"),
        (div_issues == 0, f"5 diversity axes met (violations: {div_issues})"),
        (cap_issues == 0, f"No forbidden caption words (violations: {cap_issues})"),
    ]

    for ok, desc in checks:
        mark = "[v]" if ok else "[ ]"
        lines.append(f"  {mark} {desc}")

    lines.append("")
    lines.append("=" * 70)

    return "\n".join(lines)


# ============================================================
# Main
# ============================================================

def main():
    parser = argparse.ArgumentParser(description="LoRA Dataset Bible validation")
    parser.add_argument("--dataset", required=True, help="Dataset folder path")
    parser.add_argument("--name", default="dataset", help="Dataset name")
    parser.add_argument("--repeats", type=int, default=1, help="num_repeats value")
    parser.add_argument("--steps", type=int, default=3000, help="Training step count")
    parser.add_argument("--output", help="Report save path (prints to console if omitted)")
    args = parser.parse_args()

    dataset_dir = args.dataset
    if not os.path.isdir(dataset_dir):
        print(f"[ERROR] Folder not found: {dataset_dir}")
        sys.exit(1)

    # Image list
    image_files = sorted([
        os.path.join(dataset_dir, f)
        for f in os.listdir(dataset_dir)
        if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS
    ])

    if not image_files:
        print(f"[ERROR] No images found: {dataset_dir}")
        sys.exit(1)

    print(f"Dataset: {dataset_dir}")
    print(f"Images: {len(image_files)}")
    print("Loading ArcFace model...")

    # Initialize InsightFace (imported lazily so argument errors surface
    # before the heavy model load)
    import insightface
    app = insightface.app.FaceAnalysis(
        name='buffalo_l',
        providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
    )
    app.prepare(ctx_id=0, det_size=(640, 640))
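    # ctx_id=0 selects the first GPU; with only CPUExecutionProvider available,
    # onnxruntime falls back to CPU. det_size=(640, 640) is the usual detector
    # input size for the buffalo_l pack.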

    print("Analyzing images...")
    image_results = []
    for img_path in image_files:
        r = analyze_image(img_path, app)
        image_results.append(r)
        # Progress indicator
        flag = "X" if r['issues'] else "!" if r['warnings'] else "O"
        print(f"  [{flag}] {r['fname']}")

    print("Validating captions...")
    caption_results = []
    for img_path in image_files:
        txt_path = os.path.splitext(img_path)[0] + '.txt'
        r = analyze_caption(txt_path)
        caption_results.append(r)

    print("Analyzing diversity...")
    diversity = analyze_diversity(image_results)

    exposure = check_exposure(len(image_files), args.repeats, args.steps)

    report = generate_report(args.name, image_results, caption_results, diversity, exposure)

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            f.write(report)
        print(f"\nReport saved: {args.output}")
    else:
        print("\n" + report)

    # Save JSON (for programmatic use)
    json_path = os.path.join(
        dataset_dir,
        f"validation_{args.name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    )
    json_data = {
        'name': args.name,
        'timestamp': datetime.now().isoformat(),
        'total_images': len(image_results),
        'total_issues': sum(len(r['issues']) for r in image_results)
                        + sum(len(r['issues']) for r in caption_results)
                        + len(diversity.get('issues', [])),
        'total_warnings': sum(len(r['warnings']) for r in image_results)
                          + sum(len(r['warnings']) for r in caption_results)
                          + len(diversity.get('warnings', [])),
        'images': image_results,
        'captions': caption_results,
        'diversity': diversity,
        'exposure': exposure,
    }
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)
    print(f"JSON saved: {json_path}")


if __name__ == "__main__":
    main()
