import struct
from wasmtime import Store, Module, Instance, Linker, FuncType, ValType
from PIL import Image
from io import BytesIO


class DRMWasm:
    """Run the ``drmwasm`` WebAssembly module under wasmtime with Python glue.

    The module imports a handful of functions from ``./drmwasm_bg.js``; this
    class supplies Python implementations of them. Values the module asks the
    host to create are kept in ``self.objects``, keyed by the integer handles
    that are passed back and forth across the wasm boundary.
    """

    def __init__(self, wasm_path):
        """Compile ``wasm_path`` and instantiate it with the Python glue imports.

        Args:
            wasm_path: Filesystem path of the ``drmwasm_bg.wasm`` binary.
        """
        self.store = Store()
        # The Linker and the compiled Module are created for this engine, so
        # every fresh Store made by _reset_state must use the same engine.
        self.engine = self.store.engine
        self._clear_heap()

        self.linker = Linker(self.engine)
        glue = "./drmwasm_bg.js"
        self.linker.define_func(glue, "__wbindgen_object_drop_ref",
                                FuncType([ValType.i32()], []), self._object_drop_ref)
        self.linker.define_func(glue, "__wbindgen_number_new",
                                FuncType([ValType.f64()], [ValType.i32()]), self._number_new)
        self.linker.define_func(glue, "__wbindgen_object_clone_ref",
                                FuncType([ValType.i32()], [ValType.i32()]), self._object_clone_ref)
        self.linker.define_func(glue, "__wbindgen_string_new",
                                FuncType([ValType.i32(), ValType.i32()], [ValType.i32()]), self._string_new)
        self.linker.define_func(glue, "__wbg_set_f975102236d3c502",
                                FuncType([ValType.i32(), ValType.i32(), ValType.i32()], []), self._obj_set)
        self.linker.define_func(glue, "__wbg_new_16b304a2cfa7ff4a",
                                FuncType([], [ValType.i32()]), self._new_object)
        self.linker.define_func(glue, "__wbg_new_72fb9a18b5ae2624",
                                FuncType([], [ValType.i32()]), self._new_uint8array)
        self.linker.define_func(glue, "__wbg_set_d4638f722068f043",
                                FuncType([ValType.i32(), ValType.i32(), ValType.i32()], []), self._uint8array_set)
        self.linker.define_func(glue, "__wbindgen_throw",
                                FuncType([ValType.i32(), ValType.i32()], []), self._throw)

        self.module = Module.from_file(self.engine, wasm_path)
        self._instantiate()

    # ------------------------------------------------------------------
    # Instance / heap management
    # ------------------------------------------------------------------

    def _clear_heap(self):
        """Forget every host-side object and reset the byte-capture state."""
        self.objects = {0: None}  # handle 0 is reserved for None
        self.next_obj_id = 1
        self.object_count = 0  # containers created so far (see _new_container)
        self.collecting_bytes = False
        self.real_img_bytes = bytearray()

    def _instantiate(self):
        """Instantiate the module in ``self.store`` and cache the exports used."""
        self.instance = self.linker.instantiate(self.store, self.module)
        exports = self.instance.exports(self.store)
        self.memory = exports["memory"]
        self.malloc = exports["__wbindgen_malloc"]
        self.add_to_stack_pointer = exports["__wbindgen_add_to_stack_pointer"]
        self.get_replacement_jpeg_fn = exports["get_replacement_jpeg"]
        self.get_replacement_webp_fn = exports["get_replacement_webp"]

    def _reset_state(self):
        """Reset state for a new image."""
        self._clear_heap()
        # A wasmtime Store never frees the instances created inside it, so
        # re-instantiating into one long-lived Store leaked an instance (and
        # its linear memory) per image. Start from a fresh Store instead; the
        # Linker and Module are engine-level and can be reused.
        self.store = Store(self.engine)
        self._instantiate()

    def _add_object(self, obj):
        """Store ``obj`` in the host heap and return its new handle."""
        idx = self.next_obj_id
        self.objects[idx] = obj
        self.next_obj_id += 1
        return idx

    # ------------------------------------------------------------------
    # Imports called by the wasm module
    # ------------------------------------------------------------------

    def _object_drop_ref(self, idx):
        # Intentionally a no-op: the whole heap is discarded on every reset.
        pass

    def _number_new(self, val):
        """Box a number; while capturing, also record its low byte."""
        if self.collecting_bytes:
            self.real_img_bytes.append(int(val) & 0xFF)
        return self._add_object(val)

    def _object_clone_ref(self, idx):
        """Return a new handle that refers to the same value as ``idx``."""
        return self._add_object(self.objects.get(idx))

    def _string_new(self, ptr, length):
        """Decode a UTF-8 string from wasm memory and box it.

        Creating the string ``'real_img'`` ends byte capture (presumably the
        key under which the captured array is then stored — confirm against
        the module).
        """
        data = self.memory.data_ptr(self.store)[ptr:ptr + length]
        s = bytes(data).decode('utf-8')
        if s == 'real_img' and self.collecting_bytes:
            self.collecting_bytes = False
        return self._add_object(s)

    def _obj_set(self, obj_idx, key_idx, val_idx):
        """Set ``obj[key] = val`` when obj is a dict and key is a str."""
        obj = self.objects.get(obj_idx)
        key = self.objects.get(key_idx)
        val = self.objects.get(val_idx)
        if isinstance(obj, dict) and isinstance(key, str):
            obj[key] = val

    def _new_container(self):
        """Allocate an empty container; the 2nd one of a run starts byte capture."""
        self.object_count += 1
        if self.object_count == 2:
            self.collecting_bytes = True
        return self._add_object({})

    def _new_object(self):
        return self._new_container()

    def _new_uint8array(self):
        # Modeled as a plain dict; its bytes are captured in _number_new.
        return self._new_container()

    def _uint8array_set(self, arr_idx, index_idx, value_idx):
        # Bytes are captured in _number_new, so element writes are ignored.
        pass

    def _throw(self, ptr, length):
        """Surface an error message the module reports through its throw import."""
        data = self.memory.data_ptr(self.store)[ptr:ptr + length]
        raise RuntimeError(f"WASM error: {bytes(data).decode('utf-8')}")

    # ------------------------------------------------------------------
    # Linear-memory helpers
    # ------------------------------------------------------------------

    def _write_bytes(self, data):
        """Copy ``data`` into a fresh module allocation; return ``(ptr, len)``."""
        import ctypes

        length = len(data)
        ptr = self.malloc(self.store, length, 1)
        # Fetch the base pointer only after malloc, which may grow memory.
        base = self.memory.data_ptr(self.store)
        # One bulk copy instead of a per-byte Python loop.
        ctypes.memmove(ctypes.addressof(base.contents) + ptr, bytes(data), length)
        return ptr, length

    def _read_i32(self, ptr):
        """Read a little-endian signed 32-bit integer at ``ptr``."""
        mem = self.memory.data_ptr(self.store)
        return struct.unpack('<i', bytes(mem[ptr:ptr + 4]))[0]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_replacement_image(self, image_bytes, chmkeys):
        """Run the wasm decoder on one image and return the recovered patch.

        Args:
            image_bytes: Encoded WebP or JPEG image (bytes-like).
            chmkeys: Key material forwarded verbatim to the decoder (bytes-like).

        Returns:
            dict with int ``x``, ``y``, ``width``, ``height``, the input's MIME
            ``type`` and the captured ``real_img`` bytes.

        Raises:
            ValueError: ``image_bytes`` is neither WebP nor JPEG.
            RuntimeError: the decoder signalled an error, did not return a
                dict, or no ``real_img`` bytes were captured.
        """
        # Reject unsupported input before paying for a fresh instantiation.
        is_webp = image_bytes[:4] == b'RIFF' and image_bytes[8:12] == b'WEBP'
        is_jpeg = image_bytes[:2] == b'\xff\xd8'
        if is_webp:
            img_type = 'image/webp'
        elif is_jpeg:
            img_type = 'image/jpeg'
        else:
            raise ValueError("Unknown image format")

        self._reset_state()
        # Pick the export only after the reset, which rebinds these attributes.
        decoder_fn = self.get_replacement_webp_fn if is_webp else self.get_replacement_jpeg_fn

        # Reserve 16 bytes on the module's stack as the out-parameter the export
        # writes its three i32 results into; always give the space back.
        result_ptr = self.add_to_stack_pointer(self.store, -16)
        try:
            img_ptr, img_len = self._write_bytes(image_bytes)
            key_ptr, key_len = self._write_bytes(chmkeys)
            decoder_fn(self.store, result_ptr, img_ptr, img_len, key_ptr, key_len)
            r0 = self._read_i32(result_ptr)
            r1 = self._read_i32(result_ptr + 4)
            r2 = self._read_i32(result_ptr + 8)
        finally:
            self.add_to_stack_pointer(self.store, 16)

        if r2:
            # r1 is presumably the heap handle of the error value (wasm-bindgen
            # Result convention) — include it so failures are diagnosable.
            raise RuntimeError(f"Decoder error: {self.objects.get(r1)!r} (r1={r1})")

        result_obj = self.objects.get(r0)
        if not isinstance(result_obj, dict):
            raise RuntimeError(f"Result is not a dict: {type(result_obj)}")

        if not self.real_img_bytes:
            raise RuntimeError("Failed to extract real_img data")

        return {
            'x': int(result_obj.get('x', 0)),
            'y': int(result_obj.get('y', 0)),
            'width': int(result_obj.get('width', 0)),
            'height': int(result_obj.get('height', 0)),
            'type': img_type,
            'real_img': bytes(self.real_img_bytes),
        }


# Shared decoder instance used by process_image(). It is built at import time
# from a relative path, so the .wasm file is resolved against the process's
# current working directory.
wasm = DRMWasm(wasm_path='drmwasm_bg.wasm')

def process_image(image_bytes, chmkeys):
    """Process image bytes and return unwatermarked image bytes.

    The replacement patch is recovered with the shared ``wasm`` decoder,
    pasted over the original image at the reported ``(x, y)`` offset, and the
    result is re-encoded in the original image's format.

    Args:
        image_bytes: Encoded JPEG or WebP image.
        chmkeys: Key material passed through to the decoder.

    Returns:
        The re-encoded image as ``bytes``.

    Raises:
        ValueError, RuntimeError: propagated from
            ``DRMWasm.get_replacement_image``.
    """
    patch = wasm.get_replacement_image(image_bytes, chmkeys)

    with Image.open(BytesIO(image_bytes)) as original, \
            Image.open(BytesIO(patch['real_img'])) as patch_img:
        original.paste(patch_img, (patch['x'], patch['y']))

        save_kwargs = {'lossless': True}
        if original.format == 'JPEG':
            # Pillow's JPEG encoder ignores ``lossless``. Without this it would
            # re-encode at its default quality; 'keep' reuses the source file's
            # quantization tables and subsampling instead.
            save_kwargs['quality'] = 'keep'
        # Pillow does not carry colour profile / EXIF over on save unless they
        # are passed explicitly; forward them when the source had them.
        for meta_key in ('icc_profile', 'exif'):
            meta_value = original.info.get(meta_key)
            if meta_value:
                save_kwargs[meta_key] = meta_value

        output = BytesIO()
        original.save(output, format=original.format, **save_kwargs)

    return output.getvalue()