# lora_merge_ui.py
# LoRA Merger (Primary + Secondary) for .safetensors
#
# - Outputs ONLY into a subfolder:  .\merge-files\
# - Auto-numbered outputs: merged_lora001.safetensors, merged_lora002...
# - Counter stored inside merge-files\merge_counter.txt
# - If counter is missing, it scans existing merged_lora### files in merge-files and continues from highest
# - Prevents Windows file-lock overwrite errors (Invoke holding old merged files open)
# - Offline-friendly: local URL only
#
# Works best when both LoRAs are from the SAME model family (Z-Image+Z-Image, Flux+Flux, etc.)

import os
import re
import time
from typing import Dict, Tuple

import torch
import gradio as gr
from safetensors.torch import load_file, save_file

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# All merged outputs (and the counter file) live in this sub-folder of the
# script's directory; it is created at import time if missing.
OUTPUT_DIR_NAME = "merge-files"  # your folder
OUTPUT_DIR = os.path.join(ROOT_DIR, OUTPUT_DIR_NAME)
os.makedirs(OUTPUT_DIR, exist_ok=True)

PREFIX = "merged_lora"
DIGITS = 3  # minimum zero-padded width: 001..999, then 1000, 1001, ...
COUNTER_FILE = os.path.join(OUTPUT_DIR, "merge_counter.txt")
# Matches generated output file names (case-insensitive). The digit run is
# ``{DIGITS,}`` rather than exactly ``{DIGITS}`` because the ``0{DIGITS}d``
# format only sets a MINIMUM width: past 999 names become merged_lora1000 etc.,
# and those must still be found when the counter is rebuilt from the folder.
PATTERN = re.compile(
    rf"^{re.escape(PREFIX)}(\d{{{DIGITS},}})\.safetensors$", re.IGNORECASE
)


def _path_from_gradio_file(f) -> str:
    if f is None:
        return ""
    if isinstance(f, str):
        return f
    p = getattr(f, "name", "")
    if p:
        return p
    if isinstance(f, dict) and "name" in f:
        return f["name"]
    return ""


def _scan_highest_existing_number() -> int:
    """Return the highest number among output files matching PATTERN in OUTPUT_DIR.

    Best-effort: returns 0 when no file matches or when the folder cannot be
    listed (any OSError from ``os.listdir``).
    """
    try:
        names = os.listdir(OUTPUT_DIR)
    except OSError:
        return 0
    matches = (PATTERN.match(fn) for fn in names)
    return max((int(m.group(1)) for m in matches if m), default=0)


def _read_counter() -> int:
    """Return the last-used output number.

    Reads the integer stored in COUNTER_FILE, clamped to be >= 0. If the file
    is missing or unreadable, is not valid UTF-8, or does not hold an integer,
    the value is rebuilt from the numbered files in the output folder instead.
    """
    try:
        with open(COUNTER_FILE, "r", encoding="utf-8") as fp:
            return max(int(fp.read().strip()), 0)
    except (OSError, ValueError):
        # OSError: missing/locked file. ValueError: non-integer content, and
        # also UnicodeDecodeError (a ValueError subclass) for garbage bytes.
        return _scan_highest_existing_number()


def _write_counter(n: int) -> None:
    """Overwrite COUNTER_FILE with *n* (coerced to int) as plain decimal text."""
    with open(COUNTER_FILE, encoding="utf-8", mode="w") as counter_fp:
        counter_fp.write(format(int(n), "d"))


def _next_output() -> Tuple[str, str, int]:
    """Reserve the next unused numbered output file name.

    Starts one past the higher of the stored counter and the highest numbered
    file already in OUTPUT_DIR. It then skips forward past any name that
    already exists on disk. As a result, an earlier merge (e.g. one another
    program still has open) is never overwritten, even when the counter file
    is stale. The chosen number is written to the counter before returning.

    Returns:
        (full output path, output file name, number used)
    """
    n = max(_read_counter(), _scan_highest_existing_number()) + 1
    while True:
        out_name = f"{PREFIX}{n:0{DIGITS}d}.safetensors"
        out_path = os.path.join(OUTPUT_DIR, out_name)
        if not os.path.exists(out_path):
            break
        n += 1
    _write_counter(n)
    return out_path, out_name, n


def _resolve_dtype(out_dtype) -> torch.dtype:
    """Map a dtype label to a torch dtype.

    "fp16"/"float16" -> float16, "bf16"/"bfloat16" -> bfloat16 (case- and
    whitespace-insensitive); ``None``/"" falls back to "fp16" (the parameter
    default of merge_primary_secondary); any other label -> float32.
    """
    label = (out_dtype or "fp16").strip().lower()
    if label in ("fp16", "float16"):
        return torch.float16
    if label in ("bf16", "bfloat16"):
        return torch.bfloat16
    return torch.float32


def merge_primary_secondary(
    primary_path: str,
    secondary_path: str,
    out_path: str,
    secondary_strength: float,
    out_dtype: str = "fp16",
) -> Tuple[int, int, float]:
    """Add a scaled secondary .safetensors file onto a primary one and save it.

    A primary tensor is merged when all of these hold:
    - the secondary file has the same key;
    - the two tensors have the same shape;
    - the primary tensor is floating point.

    Merged tensors become ``primary + secondary_strength * secondary``,
    computed in float32 and then cast to *out_dtype*. Other floating primary
    tensors are copied, cast to *out_dtype*. Non-floating primary tensors are
    copied unchanged. Keys that exist only in the secondary file are NOT
    written.

    NOTE(review): tensors are summed element-wise as stored. If the files hold
    low-rank factor pairs and scalar alpha entries, summing factors is not the
    same as summing the two LoRA deltas. Confirm this is the intended merge.

    Args:
        primary_path: Path of the base .safetensors file.
        secondary_path: Path of the file added on top.
        out_path: Destination path (written with ``save_file``).
        secondary_strength: Finite, non-negative scale for the secondary tensors.
        out_dtype: "fp16"/"float16", "bf16"/"bfloat16", anything else -> fp32.

    Returns:
        (merged_count, total_primary_tensors, seconds)

    Raises:
        ValueError: if secondary_strength is negative, NaN or infinite.
        Errors from loading or saving the files propagate unchanged.
    """
    t0 = time.perf_counter()

    strength = float(secondary_strength)
    # The chained comparison is False for NaN (every comparison with NaN is
    # False) and for +inf. Either value would silently fill every merged
    # tensor with NaN/inf.
    if not 0.0 <= strength < float("inf"):
        raise ValueError("Secondary strength must be a finite number >= 0")

    dtype = _resolve_dtype(out_dtype)

    p: Dict[str, torch.Tensor] = load_file(primary_path)
    s: Dict[str, torch.Tensor] = load_file(secondary_path)

    out: Dict[str, torch.Tensor] = {}
    merged = 0

    for k, pv in p.items():
        if not pv.dtype.is_floating_point:
            # Integer/bool tensors keep their dtype and are never merged.
            out[k] = pv.to("cpu")
            continue
        sv = s.get(k)
        if sv is not None and tuple(sv.shape) == tuple(pv.shape):
            # Accumulate in float32 so fp16/bf16 inputs don't lose precision
            # in the addition; cast to the requested dtype only at the end.
            ov = pv.to("cpu", dtype=torch.float32) + sv.to("cpu", dtype=torch.float32) * strength
            out[k] = ov.to("cpu", dtype=dtype)
            merged += 1
        else:
            out[k] = pv.to("cpu", dtype=dtype)

    meta = {
        "merged_from": f"{os.path.basename(primary_path)} + {os.path.basename(secondary_path)}",
        "secondary_strength": str(secondary_strength),
        "method": "universal_key_match",
    }
    save_file(out, out_path, metadata=meta)

    secs = time.perf_counter() - t0
    return merged, len(p), secs


def swap_files(primary, secondary):
    """Return the two file values in reversed order (used by the Swap button)."""
    swapped = (secondary, primary)
    return swapped


def ui_merge(primary_file, secondary_file, secondary_strength, out_dtype):
    """Gradio handler for the Merge button.

    Returns a ``(status, full_output_path, output_filename)`` triple for the
    three output textboxes. If a file is missing, or reserving the output
    name or merging fails, the problem is reported in the status text. In
    that case the path and name fields are left empty, so the user sees the
    reason in the UI instead of a bare error.
    """
    p = _path_from_gradio_file(primary_file)
    s = _path_from_gradio_file(secondary_file)

    if not p or not s:
        return "Pick BOTH files first.", "", ""

    try:
        # The number is reserved (counter written) before merging, so a failed
        # merge leaves a gap in the numbering rather than reusing a name.
        out_path, out_name, _ = _next_output()
        merged, total, secs = merge_primary_secondary(
            primary_path=p,
            secondary_path=s,
            out_path=out_path,
            secondary_strength=float(secondary_strength),
            out_dtype=out_dtype,
        )
    except Exception as exc:  # UI boundary: surface any failure in the status box
        return f"FAILED: {type(exc).__name__}: {exc}", "", ""

    if merged == 0:
        return (
            f"DONE, but merged 0/{total} tensors.\n"
            "These two LoRAs likely do NOT match (different model family or different internal keys).\n"
            f"Saved in {OUTPUT_DIR_NAME}: {out_name}",
            out_path,
            out_name,
        )

    return (
        f"DONE. Merged {merged}/{total} tensors in {secs:.2f}s.\n"
        f"Saved in {OUTPUT_DIR_NAME}: {out_name}",
        out_path,
        out_name,
    )


def main():
    """Build the LoRA merger web UI and launch it.

    Wires the Swap button to ``swap_files`` and the Merge button to
    ``ui_merge``. It then calls ``demo.launch(inbrowser=True)``, asking Gradio
    to also open the page in a browser.

    NOTE(review): no ``share``/``server_name`` arguments are passed, so Gradio's
    defaults decide where the app is served. Confirm they keep it local-only,
    as the file header intends.
    """
    with gr.Blocks(title="LoRA Merger (Primary + Secondary)") as demo:
        # Usage notes shown at the top of the page.
        gr.Markdown(
            "## LoRA Merger (Primary + Secondary)\n"
            "- Output is saved ONLY in the **merge-files** folder.\n"
            "- Output auto-names: merged_lora001.safetensors, merged_lora002, etc.\n"
            "- If you want to reverse order, click **Swap**."
        )

        # The two inputs; both pickers are restricted to .safetensors files.
        with gr.Row():
            primary = gr.File(label="Primary LoRA (.safetensors)", file_types=[".safetensors"])
            secondary = gr.File(label="Secondary LoRA (.safetensors)", file_types=[".safetensors"])

        with gr.Row():
            swap_btn = gr.Button("Swap Primary / Secondary")

        # Scale applied to the secondary tensors before they are added to the
        # primary ones; the UI limits it to 0.0-2.0.
        secondary_strength = gr.Slider(
            label="Secondary strength (applied onto Primary)",
            minimum=0.0,
            maximum=2.0,
            value=0.35,
            step=0.01,
        )

        # Output tensor dtype label, passed through ui_merge to the merger.
        out_dtype = gr.Dropdown(["fp16", "bf16", "fp32"], value="fp16", label="Output dtype")

        merge_btn = gr.Button("Merge (creates new file each time)")

        # Read-only result fields, filled from ui_merge's
        # (status, full path, file name) return value, in that order.
        status = gr.Textbox(label="Status", interactive=False)
        saved_path = gr.Textbox(label="Full output path", interactive=False)
        saved_name = gr.Textbox(label="Output filename", interactive=False)

        # Swap writes the two current file values back in reversed order.
        swap_btn.click(fn=swap_files, inputs=[primary, secondary], outputs=[primary, secondary])
        merge_btn.click(
            fn=ui_merge,
            inputs=[primary, secondary, secondary_strength, out_dtype],
            outputs=[status, saved_path, saved_name],
        )

    demo.launch(inbrowser=True)


if __name__ == "__main__":
    # Start the UI only when run as a script, not when the module is imported.
    main()
