"""Submit Flux-2-Klein workflow JSON to ComfyUI via API.

Converts a frontend (UI-format) workflow JSON, which may contain virtual
nodes such as SetNode/GetNode, rgthree Any Switch and Reroute as well as
bypassed nodes, to API format and submits it to the local ComfyUI server.

Usage:
    python submit_workflow.py [workflow.json] [--prompt "custom prompt"]
        [--steps 30] [--cfg 3.5] [--negative "blurry, plastic skin"]
"""
import json
import re
import sys
import urllib.error
import urllib.request
import uuid

# The workflow file is the first positional argument; when it is missing
# (or the first token is an option flag) the author's default file is used.
_DEFAULT_WORKFLOW_PATH = "G:/워크플로우/Flux-2-klein-mzyeoja-v2.json"
_argv_tail = sys.argv[1:]
if _argv_tail and not _argv_tail[0].startswith("--"):
    WORKFLOW_PATH = _argv_tail[0]
else:
    WORKFLOW_PATH = _DEFAULT_WORKFLOW_PATH

# Optional CLI overrides
def get_arg(name):
    """Return the token that follows the flag ``name`` in ``sys.argv``.

    Returns None when the flag is absent or is the last token on the
    command line.
    """
    try:
        idx = sys.argv.index(name)
    except ValueError:
        return None
    return sys.argv[idx + 1] if idx + 1 < len(sys.argv) else None


custom_prompt, custom_steps, custom_cfg, custom_negative = (
    get_arg(flag) for flag in ("--prompt", "--steps", "--cfg", "--negative")
)

with open(WORKFLOW_PATH, "r", encoding="utf-8") as f:
    wf = json.load(f)

# Node lookup by node id, and link lookup by link id. Each wf["links"] entry
# is read positionally as [link_id, from_node, from_slot, to_node, to_slot, ...].
nodes = {n["id"]: n for n in wf["nodes"]}
links = {
    lk[0]: {"from_node": lk[1], "from_slot": lk[2], "to_node": lk[3], "to_slot": lk[4]}
    for lk in wf["links"]
}

# nid -> {slot: (from_nid, from_slot)}
node_inputs = {}
for link in links.values():
    node_inputs.setdefault(link["to_node"], {})[link["to_slot"]] = (
        link["from_node"],
        link["from_slot"],
    )

# SetNode lookup: variable_name -> node_id
set_nodes = {}
for n in wf["nodes"]:
    if n.get("type") == "SetNode":
        # widgets_values may be missing, null or empty; treat all as "no name"
        # instead of crashing with IndexError/TypeError.
        wv = n.get("widgets_values") or [""]
        varname = wv[0] if isinstance(wv, list) else ""
        if varname:
            set_nodes[varname] = n["id"]

# UI-only helper node types. They are never emitted into the API prompt
# (ComfyUI's backend would reject an unknown class_type); links that pass
# through SetNode/GetNode/Reroute are followed by resolve_source instead.
VIRTUAL_TYPES = {
    "SetNode",
    "GetNode",
    "Reroute",
    "Note",
    "MarkdownNote",
    "Fast Groups Bypasser (rgthree)",
    "Fast Groups Muter (rgthree)",
    "Fast Bypasser (rgthree)",
    "Fast Muter (rgthree)",
    "Bookmark (rgthree)",
    "Label (rgthree)",
}


def resolve_source(nid, slot, depth=0):
    """Resolve through virtual/muted/bypassed nodes to find the real source.

    Args:
        nid: Id of the node feeding the link.
        slot: Output slot index on that node.
        depth: Recursion depth; the walk gives up past 20 hops, which also
            terminates reference cycles.

    Returns:
        ``(str(node_id), output_slot)`` of the first node that is submitted
        as-is, or ``None`` when the chain dead-ends (unknown node, unlinked
        input, unset GetNode variable, muted node, too deep).
    """
    if depth > 20:
        return None
    n = nodes.get(nid)
    if not n:
        return None
    ntype = n.get("type", "")
    mode = n.get("mode", 0)
    inp = node_inputs.get(nid, {})

    def follow(in_links, in_slot):
        """Continue the walk from input ``in_slot`` if it is linked."""
        if in_slot not in in_links:
            return None
        src_nid, src_slot = in_links[in_slot]
        return resolve_source(src_nid, src_slot, depth + 1)

    if ntype == "GetNode":
        # Jump to whatever feeds the SetNode with the same variable name.
        wv = n.get("widgets_values") or [""]
        varname = wv[0] if isinstance(wv, list) else ""
        set_nid = set_nodes.get(varname)
        if set_nid is None:
            return None
        return follow(node_inputs.get(set_nid, {}), 0)

    if ntype in ("SetNode", "Reroute"):
        return follow(inp, 0)

    # Muted node (mode=2): it is not executed, so it provides nothing.
    if mode == 2:
        return None

    if ntype == "Any Switch (rgthree)":
        # First input (by slot order) that resolves to a real source wins.
        for s in sorted(inp):
            result = follow(inp, s)
            if result:
                return result
        return None

    # Bypassed node (mode=4): pass through an input of the same type as the
    # requested output, preferring the input at the same index (this mirrors
    # ComfyUI's frontend bypass rule). Passing input 0 blindly is wrong for
    # multi-output nodes, e.g. a bypassed LoraLoader's CLIP output.
    if mode == 4:
        outputs = n.get("outputs") or []
        in_defs = n.get("inputs") or []
        out_type = outputs[slot].get("type") if 0 <= slot < len(outputs) else None
        if out_type is not None:
            for i in (slot, *range(len(in_defs))):
                if 0 <= i < len(in_defs) and in_defs[i].get("type") == out_type:
                    return follow(inp, i)
        # No type information or no matching input: fall back to input 0.
        return follow(inp, 0)

    return (str(nid), slot)


# Widget specs: class_type -> widget parameter names, listed in the same
# positional order in which they appear in a node's ``widgets_values``.
NODE_SPECS = {
    # Sampling / model patches
    "KSampler": ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"],
    "ModelSamplingFlux": ["max_shift", "base_shift", "width", "height"],
    "PerturbedAttention": [
        "scale", "adaptive_scale", "unet_block", "unet_block_id",
        "sigma_start", "sigma_end", "rescale", "rescale_mode",
    ],
    # Model loaders
    "UNETLoader": ["unet_name", "weight_dtype"],
    "LoraLoaderModelOnly": ["lora_name", "strength_model"],
    "CLIPLoader": ["clip_name", "type"],
    "VAELoader": ["vae_name"],
    "UpscaleModelLoader": ["model_name"],
    # Empty latents
    "EmptyFlux2LatentImage": ["width", "height", "batch_size"],
    "EmptySD3LatentImage": ["width", "height", "batch_size"],
    "EmptyLatentImagePresets": ["dimensions", "invert", "batch_size"],
    # Text / prompt nodes
    "CLIPTextEncode": ["text"],
    "TextBox1": ["text1"],
    "easy promptLine": ["prompt", "start_index", "max_rows", "remove_empty_lines"],
    "easy promptReplace": [
        "find1", "replace1", "find2", "replace2", "find3", "replace3",
    ],
    # Primitive value nodes
    "easy int": ["value"],
    "easy string": ["value"],
    # Image operations and output
    "ImageScaleBy": ["upscale_method", "scale_by"],
    "SaveImage": ["filename_prefix"],
}

api_prompt = {}

# Hidden "control_after_generate" widget value that the frontend stores right
# after seed-like integer widgets.
_CONTROL_AFTER_GENERATE = ("fixed", "increment", "decrement", "randomize")


def _map_widget_values(spec, wvals, linked_names):
    """Pair the widget names in ``spec`` with positional ``wvals``.

    Every spec entry consumes one ``wvals`` slot, even when that input is
    linked. Linked entries are consumed but not emitted, so later widgets
    stay aligned. A control_after_generate word that directly follows an
    integer widget value is skipped.

    NOTE(review): relies on the frontend keeping the values of widgets that
    were converted to inputs in ``widgets_values`` (current ComfyUI
    frontends do) - confirm for your frontend version.

    Returns:
        dict of param name -> widget value for the unlinked params.
    """
    out = {}
    wi = 0
    for param in spec:
        if wi >= len(wvals):
            break
        val = wvals[wi]
        wi += 1
        if (isinstance(val, int) and not isinstance(val, bool)
                and wi < len(wvals) and wvals[wi] in _CONTROL_AFTER_GENERATE):
            wi += 1
        if param not in linked_names:
            out[param] = val
    return out


# Muted nodes (mode 2) are not submitted, so links must not point at them.
_muted_ids = {str(m["id"]) for m in wf["nodes"] if m.get("mode", 0) == 2}

for n in wf["nodes"]:
    nid = n["id"]
    ntype = n.get("type", "")
    mode = n.get("mode", 0)
    # Virtual nodes never reach the backend; muted (2) and bypassed (4)
    # nodes are not executed.
    if ntype in VIRTUAL_TYPES or mode in (2, 4):
        continue

    spec = NODE_SPECS.get(ntype, [])
    wvals = n.get("widgets_values") or []
    if not isinstance(wvals, list):
        # dict-shaped widgets_values cannot be mapped positionally
        wvals = []

    # Resolve linked inputs (the index in node["inputs"] is the link's to_slot)
    inputs = {}
    linked_input_names = set()
    inp_links = node_inputs.get(nid, {})
    for i, inp_def in enumerate(n.get("inputs") or []):
        if inp_def.get("link") is None or i not in inp_links:
            continue
        src = resolve_source(inp_links[i][0], inp_links[i][1])
        if src and src[0] not in _muted_ids:
            inputs[inp_def["name"]] = [src[0], src[1]]
            linked_input_names.add(inp_def["name"])

    # Map widget values (linked params keep their slot but are not emitted)
    inputs.update(_map_widget_values(spec, wvals, linked_input_names))

    # Fix batch_size overflow. A linked value is a [node, slot] list and is
    # left untouched (comparing it with 16 would raise TypeError).
    if ntype == "EmptyFlux2LatentImage":
        batch = inputs.get("batch_size")
        if batch is None or (isinstance(batch, (int, float)) and batch > 16):
            inputs["batch_size"] = 1

    # Clean Windows-invalid chars from filename (only for literal strings;
    # a linked prefix is a [node, slot] list)
    if ntype == "SaveImage" and isinstance(inputs.get("filename_prefix"), str):
        prefix = re.sub(r"%[^%]*%", "", inputs["filename_prefix"])
        prefix = re.sub(r'[:<>"|?*]', "", prefix).strip("_- ")
        inputs["filename_prefix"] = prefix or "mzyeoja"

    # Apply custom prompt override (TextBox1 for v1, easy promptLine for v2).
    # easy promptLine's text widget is named "prompt" (see NODE_SPECS).
    if custom_prompt:
        if ntype == "TextBox1":
            inputs["text1"] = custom_prompt
        elif ntype == "easy promptLine":
            inputs["prompt"] = custom_prompt

    # Apply KSampler overrides (any active KSampler)
    if ntype == "KSampler":
        if custom_steps:
            inputs["steps"] = int(custom_steps)
        if custom_cfg:
            inputs["cfg"] = float(custom_cfg)

    # Apply negative prompt override (node 100 for v1, node 60 for v2)
    if ntype == "CLIPTextEncode" and nid in (60, 100) and custom_negative:
        inputs["text"] = custom_negative

    api_prompt[str(nid)] = {"class_type": ntype, "inputs": inputs}

# Submit the API-format prompt to the local ComfyUI server.
payload = {"prompt": api_prompt, "client_id": str(uuid.uuid4())}
data = json.dumps(payload).encode("utf-8")
req = urllib.request.Request(
    "http://127.0.0.1:8188/prompt",
    data=data,
    headers={"Content-Type": "application/json"},
)

try:
    # Without a timeout the script could block forever on a stalled server.
    with urllib.request.urlopen(req, timeout=60) as resp:
        result = json.loads(resp.read())
except urllib.error.HTTPError as e:
    # Show (at most 1000 chars of) the server's error body; a malformed
    # encoding must not mask the original error.
    err = e.read().decode("utf-8", errors="replace")
    print(f"[ERROR] HTTP {e.code}: {err[:1000]}")
    sys.exit(1)  # non-zero exit so callers/scripts can detect the failure
except Exception as e:
    print(f"[ERROR] {e}")
    sys.exit(1)
else:
    print("[OK] Workflow queued!")
    print(f"  prompt_id: {result.get('prompt_id')}")
    print(f"  number:    {result.get('number')}")
