#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import time
import ctypes
from ctypes import *

# =================================================
# DLL Loading
# =================================================
# The vendor DLL is loaded by absolute path from the directory that contains
# this script (derived from __file__), not from the current working directory.
_here = os.path.dirname(os.path.abspath(__file__))
_dll  = os.path.join(_here, "CH347DLLA64.dll")
# WinDLL is only provided by ctypes on Windows, so this script is Windows-only.
# NOTE(review): the "A64" suffix suggests the 64-bit DLL, which would require
# a 64-bit Python interpreter -- confirm for your setup.
ch347 = WinDLL(_dll)

# =================================================
# SPI configuration structure passed to CH347SPI_Init (mirrors the DLL's C struct layout)
# =================================================
class SPI_CONFIG(Structure):
    """ctypes mirror of the SPI configuration struct passed to CH347SPI_Init.

    The DLL reads this struct through a raw pointer (see the CH347SPI_Init
    prototype), so field order and C types must match the DLL's definition
    exactly.

    NOTE(review): no ``_pack_`` is set, so ctypes applies native alignment
    (e.g. padding after iByteOrder and after iSpiOutDefaultData). This must
    match how the vendor header declares the struct -- confirm against
    CH347DLL.H for your DLL version.
    """
    _fields_ = [
        ("iMode", c_ubyte),                # SPI mode 0..3
        ("iClock", c_ubyte),               # Clock setting index; each value selects a different SPI frequency
        ("iByteOrder", c_ubyte),           # 0=LSB first, 1=MSB first
        ("iSpiWriteReadInterval", c_ushort),
        ("iSpiOutDefaultData", c_ubyte),
        # 0x80 selects CS1 (the value this script uses).
        # NOTE(review): vendor docs appear to encode this as bit 7 = enable and
        # bits 1:0 = 00/01 for CS1/CS2, i.e. CS2 would be 0x81, not 0x40 -- verify.
        ("iChipSelect", c_ulong),
        ("CS1Polarity", c_ubyte),
        ("CS2Polarity", c_ubyte),
        ("iIsAutoDeativeCS", c_ushort),    # 1 = automatically deassert CS after each operation
        ("iActiveDelay", c_ushort),
        ("iDelayDeactive", c_ulong),
    ]

# =================================================
# Function Prototypes (ctypes)
# =================================================
# HANDLE CH347OpenDevice(ULONG DevI)
# The result is a pointer-sized HANDLE. c_long is only 32 bits on Windows and
# would truncate it under 64-bit Python. c_ssize_t is pointer-sized and signed,
# so INVALID_HANDLE_VALUE still reads back as -1, which is the value
# CH347SPI.__init__ compares against.
ch347.CH347OpenDevice.argtypes  = [c_ulong]
ch347.CH347OpenDevice.restype   = c_ssize_t

# CH347CloseDevice(ULONG iIndex) -- the result is not used by this script.
ch347.CH347CloseDevice.argtypes = [c_ulong]
ch347.CH347CloseDevice.restype  = None

# CH347SPI_Init(ULONG iIndex, SPI_CONFIG *cfg) -> nonzero on success
ch347.CH347SPI_Init.argtypes    = [c_ulong, POINTER(SPI_CONFIG)]
ch347.CH347SPI_Init.restype     = c_int

# int CH347SPI_WriteRead(ULONG iIndex, ULONG iChipSelect, ULONG iLength, PVOID ioBuffer)
# Full-duplex: ioBuffer is shifted out on MOSI and overwritten in place with
# the bytes received on MISO.
# The Windows DLL takes these 4 parameters. Other builds (e.g. the Linux
# library) take 5; if yours does, adjust these argtypes and the call in
# CH347SPI.xfer to match.
ch347.CH347SPI_WriteRead.argtypes = [c_ulong, c_ulong, c_ulong, c_void_p]
ch347.CH347SPI_WriteRead.restype  = c_int

# =================================================
# Basic Parameters
# =================================================
USB_ID = 0          # Device index passed to CH347OpenDevice; use 1, 2, ... when several CH347s are attached
CS1    = 0x80       # Chip-select value used for SPI_CONFIG.iChipSelect and for every CH347SPI_WriteRead call
# NOTE(review): to drive CS2 instead, change the value above. The vendor header
# appears to encode CS2 as 0x81 (bit 7 = enable, bits 1:0 = 01), not 0x40 --
# confirm against your CH347DLL.H before relying on either value.

# =================================================
# CH347 SPI Wrapper
# =================================================
class CH347SPI:
    """Minimal wrapper around one CH347 device's SPI interface.

    The Windows DLL addresses devices by index, so the index is what gets
    passed to every SPI/close call. The value returned by CH347OpenDevice is
    only used to detect failure.

    Can be used as a context manager (``with CH347SPI() as dev: ...``). The
    device is closed on exit, and ``close()`` is safe to call more than once.
    """

    def __init__(self, usb_id=USB_ID):
        """Open the CH347 device at index *usb_id*.

        Raises:
            RuntimeError: if CH347OpenDevice reports failure (-1).
        """
        self.id = c_ulong(usb_id)
        h = ch347.CH347OpenDevice(self.id)
        if h == -1:
            raise RuntimeError("CH347OpenDevice failed (driver not installed / no device)")
        self.handle = self.id  # Windows DLL mostly uses iIndex, not actual handle
        self._closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def init(self):
        """Configure SPI: mode 0, MSB first, CS1, CS auto-deasserted per transfer.

        Raises:
            RuntimeError: if CH347SPI_Init reports failure.
        """
        cfg = SPI_CONFIG()
        cfg.iMode = 0x00              # Mode 0
        cfg.iClock = 0x02             # Lower speed for stability (you can increase)
        cfg.iByteOrder = 0x01         # MSB first
        cfg.iSpiWriteReadInterval = 0
        cfg.iSpiOutDefaultData = 0xFF
        cfg.iChipSelect = CS1         # CS1/CS2
        cfg.CS1Polarity = 0
        cfg.CS2Polarity = 0
        cfg.iIsAutoDeativeCS = 1      # Auto deassert CS: each xfer() is one flash command
        cfg.iActiveDelay = 0
        cfg.iDelayDeactive = 0

        r = ch347.CH347SPI_Init(self.handle, byref(cfg))
        # The DLL returns a C BOOL: any nonzero value means success.
        if not r:
            raise RuntimeError("CH347SPI_Init failed")
        print("SPI init OK")

    def xfer(self, tx: bytes) -> bytes:
        """Clock *tx* out on MOSI and return the same number of bytes read from MISO.

        Raises:
            RuntimeError: if CH347SPI_WriteRead reports failure.
        """
        if not tx:
            # Nothing to clock; don't hand the DLL a zero-length transfer.
            return b""
        # The DLL overwrites this buffer in place with the received bytes.
        buf = (c_ubyte * len(tx))(*tx)
        r = ch347.CH347SPI_WriteRead(self.handle, CS1, len(tx), byref(buf))
        if not r:
            raise RuntimeError("CH347SPI_WriteRead failed")
        return bytes(buf)

    def close(self):
        """Close the device; further calls are no-ops."""
        if getattr(self, "_closed", False):
            return
        self._closed = True
        ch347.CH347CloseDevice(self.handle)

# =================================================
# W25Q Commands
# =================================================
def w25q_read_id(dev: CH347SPI) -> bytes:
    """Send JEDEC ID (0x9F) and return the three ID bytes after the opcode.

    A W25Q128JV answers EF 40 18.
    """
    request = bytes((0x9F,)) + bytes(3)   # opcode followed by three dummy bytes
    response = dev.xfer(request)
    # Byte 0 is clocked in while the opcode goes out, so it is discarded.
    return response[1:4]

def w25q_sr1(dev: CH347SPI) -> int:
    """Read Status Register-1 (opcode 0x05) and return it as an int."""
    reply = dev.xfer(bytes((0x05, 0x00)))
    status = reply[1]   # byte 0 arrives while the opcode is being sent
    return status

def w25q_wait_ready(dev: CH347SPI, timeout_s=5.0, poll_s=0.01):
    """Poll Status Register-1 until its BUSY bit (bit 0) clears.

    Args:
        dev: initialised CH347SPI device.
        timeout_s: maximum time to wait, in seconds.
        poll_s: delay between status reads, in seconds.

    Raises:
        TimeoutError: if the bit is still set after *timeout_s* seconds.
    """
    # Use monotonic time so wall-clock adjustments (NTP sync, manual changes)
    # cannot cut the wait short or stretch it.
    deadline = time.monotonic() + timeout_s
    while True:
        if (w25q_sr1(dev) & 0x01) == 0:
            return
        if time.monotonic() > deadline:
            raise TimeoutError("W25Q busy timeout (check power=3.3V, wiring, CS)")
        time.sleep(poll_s)

def w25q_we(dev: CH347SPI):
    """Send the single-byte Write Enable command (0x06).

    Called by this file before every sector erase and page program.
    """
    write_enable_opcode = 0x06
    dev.xfer(bytes((write_enable_opcode,)))

def w25q_sector_erase_4k(dev: CH347SPI, addr: int):
    """Erase the 4 KiB sector at *addr* (opcode 0x20) and block until done.

    Sends Write Enable first, then waits up to 20 s for the chip to go idle.
    """
    w25q_we(dev)
    address_bytes = [(addr >> shift) & 0xFF for shift in (16, 8, 0)]
    dev.xfer(bytes([0x20, *address_bytes]))
    w25q_wait_ready(dev, timeout_s=20.0)

def w25q_page_program(dev: "CH347SPI", addr: int, data: bytes):
    """Write *data* at *addr* with Page Program (0x02) and wait for completion.

    Args:
        dev: initialised CH347SPI device.
        addr: 24-bit flash address (0 .. 0xFFFFFF).
        data: at most 256 bytes, which must not cross a 256-byte page boundary.

    Raises:
        ValueError: if *addr* does not fit in 24 bits, *data* is longer than
            256 bytes, or the write would cross a page boundary.
    """
    # The command carries only 3 address bytes. Out-of-range values would
    # otherwise be silently masked into a different location.
    if not 0 <= addr <= 0xFFFFFF:
        raise ValueError("address out of 24-bit range")
    if len(data) > 256:
        raise ValueError("page program max 256 bytes")
    # Don't cross page boundaries (safer)
    if (addr & 0xFF) + len(data) > 256:
        raise ValueError("cross-page write not allowed in this demo")
    if not data:
        # Nothing to program: avoid sending Write Enable plus a data-less
        # Page Program command.
        return
    w25q_we(dev)
    cmd = bytes([0x02, (addr >> 16) & 0xFF, (addr >> 8) & 0xFF, addr & 0xFF]) + data
    dev.xfer(cmd)
    w25q_wait_ready(dev, timeout_s=5.0)

def w25q_read(dev: CH347SPI, addr: int, n: int) -> bytes:
    """Read *n* bytes starting at *addr* with the Read Data command (0x03).

    The reply is as long as the request. Its first four bytes line up with
    the opcode and address, so they are dropped.
    """
    header = bytes([0x03, *((addr >> shift) & 0xFF for shift in (16, 8, 0))])
    dummy = b"\x00" * n   # clocked out while the data bytes are shifted in
    reply = dev.xfer(header + dummy)
    return reply[len(header):]

# =================================================
# Main Program: Erase -> Write a test message -> Read Back and Verify
# =================================================
def main():
    """Demo: erase sector 0, program a test message, then read it back and verify.

    WARNING: this erases the first 4 KiB sector (address 0) of the attached flash.
    """
    dev = CH347SPI(USB_ID)
    try:
        dev.init()

        jedec = w25q_read_id(dev)
        # Show the ID as hex (e.g. "EF 40 18") rather than a bytes repr.
        print("JEDEC ID:", " ".join("%02X" % b for b in jedec))
        if jedec != b"\xEF\x40\x18":
            print("Warning: not W25Q128JV (expected EF 40 18). Still continue...")

        addr = 0x000000
        msg = b"When I was a child, I dug a hole in the wall to steal the light from someone's home"
        print("Before erase:", w25q_read(dev, addr, 8))

        print("Erasing 4KB sector...")
        w25q_sector_erase_4k(dev, addr)
        # Check the whole span that will be programmed, not just a few bytes.
        # Programming over non-erased bytes would make the final verify fail.
        erased = w25q_read(dev, addr, len(msg))
        print("After erase :", erased[:8])
        if erased != b"\xFF" * len(msg):
            print("Warning: region does not read back as all FF after erase")

        print("Write:", msg)
        w25q_page_program(dev, addr, msg)

        rb = w25q_read(dev, addr, len(msg))
        print("Read :", rb)
        print("PASS" if rb == msg else "FAIL")

    finally:
        dev.close()

# Run the erase/program/verify demo only when executed as a script.
if __name__ == "__main__":
    main()