#!/usr/bin/env python3
"""
LoRA Converter for z-image-turbo / Lumina2 (GUI + Batch Mode)
Converts separate to_q/to_k/to_v LoRA weights to merged qkv format
"""

import os
import sys
import traceback
from collections import defaultdict

# Decide the run mode before the heavy imports below: a directory passed as the
# first command-line argument selects batch mode; anything else runs the GUI.
BATCH_MODE = False
INPUT_DIR = None

if sys.argv[1:]:
    arg_path = sys.argv[1]
    BATCH_MODE = os.path.isdir(arg_path)
    INPUT_DIR = arg_path if BATCH_MODE else None

# Import core dependencies needed for both modes. If they are missing, tell the
# user how to install them (a Tk dialog in GUI mode when one can be shown,
# otherwise the console) and exit with status 1.
try:
    import torch
    from safetensors.torch import load_file, save_file
    from safetensors import safe_open
except ImportError as e:
    if BATCH_MODE:
        print(f"Dependency Missing: {e}")
        print("Please run: pip install torch safetensors")
        sys.exit(1)
    else:
        # Try to show GUI error if possible
        try:
            import tkinter as tk
            from tkinter import messagebox
            root = tk.Tk()
            root.withdraw()
            messagebox.showerror("Dependency Missing", f"Required libraries missing: {e}\n\nPlease run in terminal:\npip install torch safetensors")
        except Exception:
            # Tk unavailable (not installed, or no display): fall back to the console.
            # `Exception` rather than a bare except so Ctrl+C / SystemExit still propagate.
            print(f"Dependency Missing: {e}")
            print("Please run: pip install torch safetensors")
        sys.exit(1)

# Per-projection attention layers that Lumina2 stores as one fused `qkv` linear.
_QKV_PROJECTIONS = ('to_q', 'to_k', 'to_v')


def _report_problem(progress_callback, message):
    """Surface a non-fatal problem: as a -1 progress event in the GUI, or on stdout otherwise."""
    if progress_callback:
        progress_callback(-1, message)
    else:
        print(message)


def _load_lora_dict(input_path):
    """Load a LoRA state dict from a .safetensors, .pt or .pth file (extension case-insensitive).

    Raises:
        ValueError: unsupported extension, or a .pt/.pth file whose content is not a dict.
    """
    lowered = input_path.lower()
    if lowered.endswith('.safetensors'):
        return load_file(input_path)
    if lowered.endswith(('.pt', '.pth')):
        # NOTE(security): torch.load unpickles the file; unless `weights_only` loading is in
        # effect for the installed torch version, a malicious .pt/.pth can execute code.
        # Only convert pickled files from trusted sources (prefer .safetensors).
        lora_dict = torch.load(input_path, map_location='cpu')
        if not isinstance(lora_dict, dict):
            raise ValueError(
                f"Expected a dict of tensors in {input_path}, got {type(lora_dict).__name__}")
        return lora_dict
    raise ValueError("Only .safetensors / .pt / .pth files supported")


def _parse_qkv_key(key):
    """Split a separate q/k/v LoRA weight key into (base_key, attn_type, lora_type).

    Example: 'm.layers.0.attention.to_q.lora_A.weight' -> ('m.layers.0.attention', 'q', 'lora_A').
    Returns None if the key lacks a `layers.<idx>` segment, a to_q/to_k/to_v segment
    or a lora_A/lora_B segment.
    """
    parts = key.split('.')
    layer_idx = attn_type = lora_type = None
    for i, part in enumerate(parts):
        if part == 'layers' and i + 1 < len(parts):
            layer_idx = parts[i + 1]
        elif part in _QKV_PROJECTIONS:
            attn_type = part[3:]  # 'to_q' -> 'q'
        elif part in ('lora_A', 'lora_B'):
            lora_type = part
    if not (layer_idx and attn_type and lora_type):
        return None
    # Drop the projection name and the trailing `lora_X.weight` to get the shared prefix.
    base_parts = [p for p in parts if p not in _QKV_PROJECTIONS]
    return '.'.join(base_parts[:-2]), attn_type, lora_type


def _module_alphas(lora_dict, weight_key):
    """Yield (alpha_key, alpha) for alpha entries of the module owning a LoRA weight key."""
    module = weight_key.rsplit('.lora_', 1)[0]
    for alpha_key in (f"{module}.alpha", f"{module}.lora_A.alpha"):
        if alpha_key in lora_dict:
            yield alpha_key, lora_dict[alpha_key]


def _merge_qkv_group(group):
    """Fuse one layer's q/k/v LoRA pairs into a single (qkv_A, qkv_B) pair.

    `group` maps 'q'/'k'/'v' -> {'lora_A'/'lora_B': (original_key, tensor)}.
    Returns None when any of the six tensors is missing.

    qkv_B = block_diag(qB, kB, vB) and qkv_A = cat(qA, kA, vA), so qkv_B @ qkv_A stacks
    qB@qA, kB@kA, vB@vA row-wise, i.e. the delta of the fused qkv projection. q, k and v
    may have different output widths (grouped-query attention) but must share the LoRA
    rank and input width.

    Raises:
        ValueError: if the shapes cannot be fused.
    """
    entries = [group.get(t, {}) for t in 'qkv']
    if not all('lora_A' in e and 'lora_B' in e for e in entries):
        return None
    qA, kA, vA = (e['lora_A'][1] for e in entries)
    qB, kB, vB = (e['lora_B'][1] for e in entries)

    if qA.dim() != 2 or not (qA.shape == kA.shape == vA.shape):
        raise ValueError(
            f"lora_A shapes must be 2-D and identical, got "
            f"{tuple(qA.shape)}, {tuple(kA.shape)}, {tuple(vA.shape)}")
    rank = qA.shape[0]
    if any(b.dim() != 2 or b.shape[1] != rank for b in (qB, kB, vB)):
        raise ValueError(
            f"lora_B shapes {tuple(qB.shape)}, {tuple(kB.shape)}, {tuple(vB.shape)} "
            f"do not match LoRA rank {rank}")

    qkv_B = torch.block_diag(qB, kB, vB).to(qB.dtype)
    qkv_A = torch.cat([qA, kA, vA], dim=0)
    return qkv_A, qkv_B


def _restore_group(group, lora_dict, output_dict):
    """Write an unfusable group's original q/k/v tensors (and their alphas) back unchanged."""
    for by_lora_type in group.values():
        for orig_key, tensor in by_lora_type.values():
            output_dict[orig_key] = tensor
            output_dict.update(_module_alphas(lora_dict, orig_key))


def _save_lora(output_dict, input_path, output_path):
    """Write output_dict in the format implied by output_path's extension."""
    if output_path.lower().endswith(('.pt', '.pth')):
        # Keep content and extension in agreement: torch.load cannot read safetensors bytes.
        # torch.save has no metadata slot, so the safetensors header metadata is not carried.
        torch.save(output_dict, output_path)
        return

    metadata = {}
    if input_path.lower().endswith('.safetensors'):
        try:
            with safe_open(input_path, framework="pt", device="cpu") as f:
                metadata = dict(f.metadata() or {})
        except Exception:
            # Metadata is best-effort; an unreadable header must not abort the conversion.
            metadata = {}
    metadata['converted_for'] = 'z-image-turbo / Lumina2 (GUI converter)'
    metadata['conversion_script'] = 'convert_lora_to_zimage_gui.py'
    save_file(output_dict, output_path, metadata=metadata)


def convert_lora(input_path, output_path, progress_callback=None):
    """Convert separate to_q/to_k/to_v LoRA weights into the fused qkv layout.

    Core conversion used by both GUI and batch modes. For each key:
      * `.attention.to_out.0.` is renamed to `.attention.out.` (its alpha too);
      * per-layer to_q/to_k/to_v lora_A/lora_B pairs are fused into `<prefix>.qkv.*`;
        the q alpha is multiplied by 3 because the fused rank is 3x the original
        (keeps alpha/rank unchanged, assuming the loader scales by alpha / rank);
      * every other key is copied unchanged.
    q/k/v groups that cannot be fused (incomplete, or incompatible shapes) are reported
    and written back under their original keys instead of being dropped.

    Args:
        input_path: .safetensors / .pt / .pth LoRA file.
        output_path: destination; saved with torch.save when it ends in .pt/.pth,
            otherwise as safetensors carrying over the input's safetensors metadata.
        progress_callback: optional callable(percent, message); percent -1 marks a
            non-fatal problem message.

    Returns:
        Number of attention layers fused into qkv.

    Raises:
        ValueError: unsupported input extension or non-dict .pt/.pth content.
    """
    if progress_callback:
        progress_callback(0, "Loading LoRA file...")
    lora_dict = _load_lora_dict(input_path)

    total_keys = len(lora_dict)
    if progress_callback:
        progress_callback(10, f"Loading complete, found {total_keys} keys")

    # base_key -> 'q'/'k'/'v' -> 'lora_A'/'lora_B' -> (original key, tensor)
    layer_groups = defaultdict(lambda: defaultdict(dict))
    output_dict = {}

    for processed, (key, value) in enumerate(lora_dict.items(), 1):
        if progress_callback and processed % 50 == 0:
            progress_callback(10 + 40 * processed // total_keys, f"Parsing keys... ({processed}/{total_keys})")

        # 1. Rename to_out.0 -> out (Lumina2 compatibility).
        if '.attention.to_out.0.' in key:
            output_dict[key.replace('.to_out.0.', '.out.')] = value
            if 'lora_A' in key:
                alpha_key = f"{key.rsplit('.lora_A', 1)[0]}.alpha"
                if alpha_key in lora_dict:
                    output_dict[alpha_key.replace('.to_out.0.', '.out.')] = lora_dict[alpha_key]
            continue

        # 2. Attention alphas are emitted together with their weights (fused, restored
        #    or passed through below), never on their own.
        if '.attention.to_' in key and '.alpha' in key:
            continue

        # 3. Collect q/k/v LoRA weights per layer for fusing.
        if '.attention.to_' in key and any(f'.{p}.' in key for p in _QKV_PROJECTIONS):
            parsed = _parse_qkv_key(key)
            if parsed is not None:
                base_key, attn_type, lora_type = parsed
                layer_groups[base_key][attn_type][lora_type] = (key, value)
                continue

        # 4. Everything else passes through unchanged; attention weights keep the
        #    alpha that step 2 held back so their scale is not lost.
        output_dict[key] = value
        if '.attention.to_' in key:
            output_dict.update(_module_alphas(lora_dict, key))

    if progress_callback:
        progress_callback(60, f"Found {len(layer_groups)} attention layers to merge")

    # ==================== Merge qkv ====================
    converted_count = 0
    step = 30.0 / max(len(layer_groups), 1)
    for current, (base_key, group) in enumerate(layer_groups.items(), 1):
        if progress_callback:
            progress_callback(60 + step * current, f"Merging layer {current}/{len(layer_groups)}")

        try:
            merged = _merge_qkv_group(group)
        except Exception as e:  # best-effort per layer: keep converting the rest
            _report_problem(progress_callback, f"Merge failed for {base_key}: {e}")
            traceback.print_exc()
            merged = None
        else:
            if merged is None:
                _report_problem(
                    progress_callback,
                    f"Incomplete q/k/v LoRA for {base_key}; original keys kept")

        if merged is None:
            _restore_group(group, lora_dict, output_dict)
            continue

        qkv_A, qkv_B = merged
        output_dict[f"{base_key}.qkv.lora_B.weight"] = qkv_B
        output_dict[f"{base_key}.qkv.lora_A.weight"] = qkv_A
        converted_count += 1

        # Fused rank is 3x the original, so alpha grows 3x to keep alpha/rank unchanged.
        q_alpha = next(_module_alphas(lora_dict, group['q']['lora_A'][0]), None)
        if q_alpha is not None:
            output_dict[f"{base_key}.qkv.alpha"] = q_alpha[1] * 3.0

    if progress_callback:
        progress_callback(95, "Saving file...")

    _save_lora(output_dict, input_path, output_path)

    if progress_callback:
        progress_callback(100, f"Done! Merged {converted_count} attention layers")

    return converted_count

def batch_convert(input_dir):
    """Convert every supported LoRA file directly inside input_dir (non-recursive).

    Each `<name><ext>` is written next to it as `<name>_zimage<ext>`. Files whose
    output already exists are skipped. Files that are themselves `_zimage` outputs
    are not treated as inputs, so re-running on the same folder does not convert a
    result a second time. Prints a per-file report and a summary, and on Windows
    waits for Enter before returning.
    """
    supported_ext = ['.safetensors', '.pt', '.pth']
    output_suffix = '_zimage'
    files = sorted(
        f for f in os.listdir(input_dir)
        if os.path.isfile(os.path.join(input_dir, f))
        and os.path.splitext(f)[1].lower() in supported_ext
        # Outputs of an earlier run are not inputs; converting them again would
        # create `<name>_zimage_zimage<ext>` files.
        and not os.path.splitext(f)[0].endswith(output_suffix)
    )

    if not files:
        print(f"No compatible files found in {input_dir}")
        print(f"Supported formats: {', '.join(supported_ext)}")
        return

    print(f"Found {len(files)} files to convert in: {input_dir}")
    print("-" * 60)

    success_count = 0
    fail_count = 0

    for i, filename in enumerate(files, 1):
        input_path = os.path.join(input_dir, filename)
        base, ext = os.path.splitext(filename)
        output_path = os.path.join(input_dir, f"{base}{output_suffix}{ext}")

        # Skip if output already exists
        if os.path.exists(output_path):
            print(f"[{i}/{len(files)}] SKIPPED: {filename} (output already exists)")
            continue

        print(f"[{i}/{len(files)}] Processing: {filename}")
        try:
            converted = convert_lora(input_path, output_path)
            print(f"  => SUCCESS: Merged {converted} attention layers")
            print(f"  => Output: {os.path.basename(output_path)}")
            success_count += 1
        except Exception as e:
            # One bad file must not stop the batch; report it and move on.
            print(f"  => FAILED: {str(e)}")
            traceback.print_exc()
            fail_count += 1
        print("-" * 60)

    print("\nBatch conversion completed!")
    print(f"Successful conversions: {success_count}")
    print(f"Failed conversions:     {fail_count}")
    print(f"Skipped files:          {len(files) - success_count - fail_count}")

    # Pause before exiting for Windows users
    if os.name == 'nt':
        input("\nPress Enter to exit...")

# ====================== GUI MODE ======================
def run_gui():
    """Run the graphical interface; blocks in the Tk main loop until the window closes."""
    import queue
    import threading
    import tkinter as tk
    from tkinter import filedialog, messagebox, ttk

    class LoRAConverterGUI(tk.Tk):
        """Single-window converter: choose input/output paths, convert on a worker thread."""

        # Interval (ms) at which the Tk thread drains events posted by the worker thread.
        POLL_INTERVAL_MS = 50

        def __init__(self):
            super().__init__()
            self.title("LoRA → z-image-turbo / Lumina2 Converter")
            self.geometry("660x360")
            self.resizable(False, False)

            # Styling
            style = ttk.Style(self)
            style.theme_use('clam')

            # Main frame
            main_frame = ttk.Frame(self, padding=20)
            main_frame.pack(fill=tk.BOTH, expand=True)

            ttk.Label(main_frame, text="Zimage(ToKit2Comy)Fok&PAseer", font=("Segoe UI", 16, "bold")).pack(pady=(0, 20))

            # Input file
            input_frame = ttk.Frame(main_frame)
            input_frame.pack(fill=tk.X, pady=8)
            ttk.Label(input_frame, text="Input LoRA:", width=12).pack(side=tk.LEFT)
            self.input_path = tk.StringVar()
            ttk.Entry(input_frame, textvariable=self.input_path, width=50).pack(side=tk.LEFT, padx=5, expand=True, fill=tk.X)
            ttk.Button(input_frame, text="Browse...", command=self.browse_input).pack(side=tk.RIGHT)

            # Output file
            output_frame = ttk.Frame(main_frame)
            output_frame.pack(fill=tk.X, pady=8)
            ttk.Label(output_frame, text="Output path:", width=12).pack(side=tk.LEFT)
            self.output_path = tk.StringVar()
            ttk.Entry(output_frame, textvariable=self.output_path, width=50).pack(side=tk.LEFT, padx=5, expand=True, fill=tk.X)
            ttk.Button(output_frame, text="Browse...", command=self.browse_output).pack(side=tk.RIGHT)

            # Progress bar
            self.progress = ttk.Progressbar(main_frame, mode='determinate')
            self.progress.pack(fill=tk.X, pady=20)

            self.status_label = ttk.Label(main_frame, text="Ready", foreground="gray")
            self.status_label.pack(pady=5)

            # Convert button
            self.convert_btn = ttk.Button(main_frame, text="Start Conversion", command=self.start_conversion)
            self.convert_btn.pack(pady=10)

            # Log area
            log_frame = ttk.LabelFrame(main_frame, text="Log")
            log_frame.pack(fill=tk.BOTH, expand=True, pady=10)
            self.log_text = tk.Text(log_frame, height=6, state='disabled', font=("Consolas", 9))
            scrollbar = ttk.Scrollbar(log_frame, orient="vertical", command=self.log_text.yview)
            self.log_text.configure(yscrollcommand=scrollbar.set)
            self.log_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)
            scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

            # tkinter widgets should only be used from the thread running mainloop(), so
            # the conversion worker never touches them: it posts (kind, payload) events
            # here and _poll_events applies them on the Tk thread.
            self._events = queue.Queue()

        def log(self, msg):
            """Append a line to the read-only log box (Tk thread only)."""
            self.log_text.config(state='normal')
            self.log_text.insert(tk.END, msg + "\n")
            self.log_text.see(tk.END)
            self.log_text.config(state='disabled')
            self.update_idletasks()

        def browse_input(self):
            """Pick the input file and suggest `<name>_zimage<ext>` next to it as output."""
            path = filedialog.askopenfilename(
                title="Select Input LoRA File",
                filetypes=[("LoRA files", "*.safetensors *.pt *.pth"), ("All files", "*.*")]
            )
            if path:
                self.input_path.set(path)
                # Auto-suggest output path
                dir_name = os.path.dirname(path)
                base = os.path.basename(path)
                name, ext = os.path.splitext(base)
                suggested = os.path.join(dir_name, f"{name}_zimage{ext}")
                self.output_path.set(suggested)

        def browse_output(self):
            """Pick the output file path."""
            path = filedialog.asksaveasfilename(
                title="Save Converted LoRA",
                defaultextension=".safetensors",
                filetypes=[("SafeTensors", "*.safetensors"), ("PyTorch", "*.pt")]
            )
            if path:
                self.output_path.set(path)

        def update_progress(self, value, status):
            """Show progress (Tk thread only); a negative value marks an error message."""
            if value < 0:  # Error status
                self.status_label.config(text=status, foreground="red")
                self.log(f"[ERROR] {status}")
            else:
                self.progress['value'] = value
                self.status_label.config(text=status, foreground="black" if value < 100 else "green")
                self.log(f"[{value:.1f}%] {status}")
            self.update_idletasks()

        def _post_progress(self, value, status):
            """Thread-safe progress callback handed to convert_lora (runs on the worker)."""
            self._events.put(('progress', (value, status)))

        def _poll_events(self):
            """Apply queued worker events on the Tk thread; reschedule until a terminal event."""
            while True:
                try:
                    kind, payload = self._events.get_nowait()
                except queue.Empty:
                    break
                if kind == 'progress':
                    self.update_progress(*payload)
                    continue
                # Terminal event ('done' or 'failed'): report, re-enable, stop polling.
                if kind == 'done':
                    out_path, converted = payload
                    self.update_progress(100, f"Conversion complete! Merged {converted} attention layers")
                    messagebox.showinfo("Success", f"Conversion completed!\nSaved to:\n{out_path}")
                else:
                    error, details = payload
                    self.log(details)
                    self.update_progress(-1, f"Conversion failed: {str(error)}")
                    messagebox.showerror("Conversion Failed", f"Error during conversion:\n{error}")
                self.convert_btn.config(state='normal')
                return
            self.after(self.POLL_INTERVAL_MS, self._poll_events)

        def start_conversion(self):
            """Validate the paths, then run convert_lora on a background thread."""
            in_path = self.input_path.get().strip()
            out_path = self.output_path.get().strip()

            if not in_path or not os.path.isfile(in_path):
                messagebox.showerror("Error", "Please select a valid input LoRA file")
                return
            if not out_path:
                messagebox.showerror("Error", "Please specify output file path")
                return

            if os.path.exists(out_path):
                if not messagebox.askyesno("File Exists", f"File already exists, overwrite?\n{out_path}"):
                    return

            self.convert_btn.config(state='disabled')
            self.log_text.config(state='normal')
            self.log_text.delete(1.0, tk.END)
            self.log_text.config(state='disabled')
            self.progress['value'] = 0
            self.update_progress(0, "Starting conversion...")

            def worker():
                # Runs off the Tk thread: communicate only via the queue. The finally clause
                # always posts a terminal event so the button is re-enabled even if the
                # thread dies from a non-Exception error.
                outcome = ('failed', (RuntimeError("conversion thread stopped unexpectedly"), ""))
                try:
                    converted = convert_lora(
                        in_path,
                        out_path,
                        progress_callback=self._post_progress
                    )
                    outcome = ('done', (out_path, converted))
                except Exception as e:
                    outcome = ('failed', (e, traceback.format_exc()))
                finally:
                    self._events.put(outcome)

            # Run in separate thread to prevent UI freeze
            threading.Thread(target=worker, daemon=True).start()
            self.after(self.POLL_INTERVAL_MS, self._poll_events)

    app = LoRAConverterGUI()
    app.mainloop()

# ====================== MAIN ENTRY POINT ======================
if __name__ == "__main__":
    if BATCH_MODE:
        print("Running in batch mode...")
        print(f"Input directory: {INPUT_DIR}")
        print("\nIMPORTANT: This will process ALL compatible files in the directory")
        print("Output files will have '_zimage' appended to their names")
        print("-" * 60)

        # Safety confirmation for batch mode. Ctrl+C cancels cleanly on every platform;
        # EOFError covers a closed or non-interactive stdin, which is treated as "no".
        try:
            input("Press Enter to start batch conversion (Ctrl+C to cancel)...")
        except (KeyboardInterrupt, EOFError):
            print("\nBatch conversion cancelled by user")
            sys.exit(0)

        batch_convert(INPUT_DIR)
    else:
        print("Running in GUI mode...")
        run_gui()