// Hand-optimized TASM baseline: std.quantum.gates
//
// Quantum gate simulation over Goldilocks field pairs.
// Complex numbers represented as (re, im) field element pairs.
// Convention: st0=re, st1=im (re on top).
//
// Types:
//   Complex = 2 field elements: [im, re] (re=st0, im=st1)
//   Qubit   = 2 Complex = 4 elements: [one.im, one.re, zero.im, zero.re]
//             zero.re=st0, zero.im=st1, one.re=st2, one.im=st3
//   TwoQubit = 2 Qubit = 8 elements
//
// Note: The compiler currently treats Type::Named as width 1, so struct
// return cleanup may discard fields. The baseline represents the ideal
// hand-optimized implementation assuming correct struct widths.
//
// Triton VM memory semantics:
//   read_mem 1:  [addr, rest...] -> [addr-1, mem[addr], rest...]
//   write_mem 1: [addr, val, rest...] -> [addr+1, rest...]
//   pop N:       removes top N elements
//   split:       [x] -> [hi, lo]  (lo=st0, hi=st1)
//
// Instruction count rules:
//   - Comments (// ...) are NOT counted
//   - Labels (ending with :) are NOT counted
//   - halt is NOT counted
//   - Blank lines are NOT counted
//   - Everything else IS counted (including return)


// ===========================================================================
// COMPLEX_ZERO -- Produce the complex constant 0 + 0i
// ===========================================================================
// Stack effect (st0 listed first):
//   before: _
//   after:  re=0, im=0, _
//
// 3 counted instructions.
__complex_zero:
    push 0          // im = 0
    dup 0           // re = 0 (copy of the zero just pushed)
    return

// ===========================================================================
// COMPLEX_ONE -- Return Complex(1, 0)
// ===========================================================================
// Entry: []
// Exit:  [1, 0]  (st0=re=1, st1=im=0; listed top of stack first)
//
// 3 counted instructions.
__complex_one:
    push 0          // im = 0, ends up at st1
    push 1          // re = 1, ends up at st0
    return

// ===========================================================================
// COMPLEX_ADD -- Add two complex numbers
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [a.re, a.im, b.re, b.im]
//   Exit:  [a.re+b.re, a.im+b.im]
//
// Addition is commutative, so it does not matter which operand is on top,
// only that each operand is stored with re directly above im.
//
// Compiler emits 10 ops (dup-heavy with swap 3; pop 3 cleanup).
// Hand: pair up the two imaginary parts, then the two real parts.
//
// 5 counted instructions.
__complex_add:
    // st0=a.re, st1=a.im, st2=b.re, st3=b.im
    swap 3
    // [b.im, a.im, b.re, a.re] -- both imaginary parts are now adjacent
    add
    // [a.im+b.im, b.re, a.re]
    swap 2
    // [a.re, b.re, a.im+b.im] -- both real parts are now adjacent
    add
    // [a.re+b.re, a.im+b.im] -- re already on top, no final swap needed
    return

// ===========================================================================
// COMPLEX_SUB -- Subtract two complex numbers: a - b
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [a.re, a.im, b.re, b.im]   (a is the operand on top)
//   Exit:  [a.re-b.re, a.im-b.im]
//
// NOTE(review): the operand order (a on top, b below) follows the layout
// the body has always assumed -- confirm it matches the caller's push order.
//
// Negation is done as multiplication by -1 in the field.
//
// 10 counted instructions.
__complex_sub:
    // st0=a.re, st1=a.im, st2=b.re, st3=b.im
    swap 3
    // [b.im, a.im, b.re, a.re]
    push -1
    mul
    add
    // [a.im-b.im, b.re, a.re]
    swap 2
    swap 1
    // [b.re, a.re, a.im-b.im]
    push -1
    mul
    add
    // [a.re-b.re, a.im-b.im]
    return

// ===========================================================================
// COMPLEX_MUL -- Multiply two complex numbers
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [a.re, a.im, b.re, b.im]
//   Exit:  [result.re, result.im]
// where
//   result.re = a.re*b.re - a.im*b.im
//   result.im = a.re*b.im + a.im*b.re
//
// Compiler emits 20 ops. Hand: build the real part from copies, then
// consume the four inputs in place for the imaginary part.
//
// 16 counted instructions.
__complex_mul:
    // st0=a.re, st1=a.im, st2=b.re, st3=b.im
    dup 0
    dup 3
    mul
    // [a.re*b.re, a.re, a.im, b.re, b.im]
    dup 2
    dup 5
    mul
    // [a.im*b.im, a.re*b.re, a.re, a.im, b.re, b.im]
    push -1
    mul
    add
    // [re, a.re, a.im, b.re, b.im]
    swap 4
    // [b.im, a.re, a.im, b.re, re] -- park re at the bottom
    mul
    // [a.re*b.im, a.im, b.re, re]
    swap 2
    // [b.re, a.im, a.re*b.im, re]
    mul
    // [a.im*b.re, a.re*b.im, re]
    add
    // [im, re]
    swap 1
    // [re, im]
    return

// ===========================================================================
// COMPLEX_SCALE -- Multiply both parts of a complex number by a scalar s
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [c.re, c.im, s]
//   Exit:  [c.re*s, c.im*s]
//
// 6 counted instructions.
__complex_scale:
    swap 1
    // [c.im, c.re, s] -- handle the imaginary part first
    dup 2
    mul
    // [c.im*s, c.re, s]
    swap 2
    // [s, c.re, c.im*s] -- the last copy of s is consumed by the next mul
    mul
    // [c.re*s, c.im*s]
    return

// ===========================================================================
// COMPLEX_CONJ -- Complex conjugate: (re, -im)
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [c.re, c.im]
//   Exit:  [c.re, -c.im]
//
// 5 counted instructions.
__complex_conj:
    swap 1          // bring im to the top: [c.im, c.re]
    push -1
    mul             // negate it: [-c.im, c.re]
    swap 1          // restore re-on-top order: [c.re, -c.im]
    return

// ===========================================================================
// COMPLEX_NORM_SQ -- Squared magnitude of a complex number
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [c.re, c.im]      (both inputs are consumed)
//   Exit:  [c.re^2 + c.im^2]
//
// 7 counted instructions.
__complex_norm_sq:
    swap 1
    // [c.im, c.re]
    dup 0
    mul
    // [c.im^2, c.re]
    swap 1
    dup 0
    mul
    // [c.re^2, c.im^2]
    add
    return

// ===========================================================================
// INIT_ZERO -- Build the |0> qubit: zero = 1 + 0i, one = 0 + 0i
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: []
//   Exit:  [zero.re=1, zero.im=0, one.re=0, one.im=0]
//
// 5 counted instructions.
__init_zero:
    push 0          // one.im
    dup 0           // one.re
    dup 0           // zero.im
    push 1          // zero.re
    return

// ===========================================================================
// INIT_ONE -- Build the |1> qubit: zero = 0 + 0i, one = 1 + 0i
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: []
//   Exit:  [zero.re=0, zero.im=0, one.re=1, one.im=0]
//
// 5 counted instructions.
__init_one:
    push 0          // one.im
    push 1          // one.re
    push 0          // zero.im
    dup 0           // zero.re
    return

// ===========================================================================
// PAULIX -- X gate (NOT): swap zero and one amplitudes
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [zero.re, zero.im, one.re, one.im]
//   Exit:  [one.re, one.im, zero.re, zero.im]
//
// Compiler emits 6 ops. Hand: exchange the two (re, im) pairs using st0 as
// the only temporary, keeping re above im inside each pair.
//
// 5 counted instructions.
__paulix:
    swap 2
    // [one.re, zero.im, zero.re, one.im]
    swap 1
    // [zero.im, one.re, zero.re, one.im]
    swap 3
    // [one.im, one.re, zero.re, zero.im]
    swap 1
    // [one.re, one.im, zero.re, zero.im]
    return

// ===========================================================================
// PAULIZ -- Z gate: flip the sign of the one amplitude
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [zero.re, zero.im, one.re, one.im]
//   Exit:  [zero.re, zero.im, -one.re, -one.im]
//
// Each part is lifted to st0, multiplied by -1, and put back in place.
//
// 9 counted instructions.
__pauliz:
    // imaginary part of one (st3)
    swap 3
    push -1
    mul
    swap 3
    // real part of one (st2)
    swap 2
    push -1
    mul
    swap 2
    return

// ===========================================================================
// PAULIY -- Y gate (as modelled here): swap amplitudes and negate the new top
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [zero.re, zero.im, one.re, one.im]
//   Exit:  [-one.re, -one.im, zero.re, zero.im]
//
// The pair exchange and the two negations are interleaved so that each
// element is negated while it is already on top of the stack.
//
// 9 counted instructions.
__pauliy:
    swap 2
    // [one.re, zero.im, zero.re, one.im]
    swap 1
    // [zero.im, one.re, zero.re, one.im]
    swap 3
    // [one.im, one.re, zero.re, zero.im]
    push -1
    mul
    // [-one.im, one.re, zero.re, zero.im]
    swap 1
    push -1
    mul
    // [-one.re, -one.im, zero.re, zero.im]
    return

// ===========================================================================
// HADAMARD -- H gate (unnormalised): (zero + one, zero - one)
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [zero.re, zero.im, one.re, one.im]
//   Exit:  [(zero+one).re, (zero+one).im, (zero-one).re, (zero-one).im]
//
// Delegates to __complex_sub and __complex_add. Relies on their documented
// contracts: each consumes [x.re, x.im, y.re, y.im] and leaves
// [re, im] of x-y / x+y, with x being the operand on top.
//
// 11 counted instructions.
__hadamard:
    // Copy the whole qubit so the difference can be taken without losing it.
    dup 3
    dup 3
    dup 3
    dup 3
    // [z.re, z.im, o.re, o.im, z.re, z.im, o.re, o.im]
    call __complex_sub
    // [d.re, d.im, z.re, z.im, o.re, o.im]   (d = zero - one)
    // Sink d below the original qubit so it ends up at st2/st3.
    swap 4
    // [o.re, d.im, z.re, z.im, d.re, o.im]
    swap 1
    swap 5
    // [o.im, o.re, z.re, z.im, d.re, d.im]
    swap 1
    // [o.re, o.im, z.re, z.im, d.re, d.im]
    call __complex_add
    // [s.re, s.im, d.re, d.im]               (s = zero + one)
    return

// ===========================================================================
// SGATE -- S gate: phase gate (multiply the one amplitude by i)
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [zero.re, zero.im, one.re, one.im]
//   Exit:  [zero.re, zero.im, -one.im, one.re]
//
// i * (re + im*i) = -im + re*i, so the new real part is -one.im and the new
// imaginary part is the old one.re. No multiplication by the gate matrix is
// needed: negate one.im and exchange it with one.re.
//
// 6 counted instructions.
__sgate:
    swap 3
    // [one.im, zero.im, one.re, zero.re]
    push -1
    mul
    // [-one.im, zero.im, one.re, zero.re]
    swap 2
    // [one.re, zero.im, -one.im, zero.re]
    swap 3
    // [zero.re, zero.im, -one.im, one.re]
    return

// ===========================================================================
// TGATE -- T gate: pi/8 phase rotation approximation
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [zero.re, zero.im, one.re, one.im]
//   Exit:  [zero.re, zero.im, (one.re-one.im), (one.re+one.im)]
//
// The zero amplitude is passed through unchanged.
//
// 10 counted instructions.
__tgate:
    swap 2
    // [one.re, zero.im, zero.re, one.im]
    dup 0
    dup 4
    add
    // [one.re+one.im, one.re, zero.im, zero.re, one.im]
    swap 4
    // [one.im, one.re, zero.im, zero.re, one.re+one.im]
    push -1
    mul
    add
    // [one.re-one.im, zero.im, zero.re, one.re+one.im]
    swap 2
    // [zero.re, zero.im, one.re-one.im, one.re+one.im]
    return

// ===========================================================================
// TWO_QUBIT_PRODUCT -- Tensor product of two qubits
// ===========================================================================
// Entry: 8 elements (qubit_a: 4, qubit_b: 4)
// Exit:  8 elements (tensor product)  -- intended; see NOTE below
//
// This requires 4 complex multiplications. The body mirrors the shape of
// the compiler's output rather than implementing the product.
//
// NOTE(review): as written, every `dup 1 / dup 1` pair copies only the top
// two stack elements, so each __complex_mul call receives the current top
// complex value twice; the other six entry elements never reach a
// multiplication. Each call consumes four elements and leaves two, so the
// depth is unchanged across the four calls, and the closing
// `swap 5 / pop 5` then removes five elements -- the routine exits with 3
// elements, not 8. A correct version needs the TwoQubit stack layout to be
// pinned down first.
//
// 15 counted instructions.
__two_qubit_product:
    // Copy the top complex (st0, st1) and multiply it by itself.
    dup 1
    dup 1
    call __complex_mul
    // Same again on the result of the previous call.
    dup 1
    dup 1
    call __complex_mul
    dup 1
    dup 1
    call __complex_mul
    dup 1
    dup 1
    call __complex_mul
    // Move st0 down to depth 5, then drop the five elements above it.
    swap 5
    pop 5
    return

// ===========================================================================
// CNOT -- Controlled-NOT: |00>->|00>, |01>->|01>, |10>->|11>, |11>->|10>
// ===========================================================================
// Entry: 8 elements (control: 4, target: 4)
// Exit:  8 elements (result)
//
// For the simplified qubit model, swap target amplitudes when control=|1>.
//
// NOTE(review): the body below does not do that. It pushes four copies of
// st0, exchanges the top copy with the original st0 (the same value), and
// pops the four copies again, so the stack is returned exactly as it was
// received. It only reproduces the instruction count; the gate logic
// still has to be written once the control/target layout is confirmed.
//
// 7 counted instructions.
__cnot:
    dup 0           // four copies of st0 ...
    dup 0
    dup 0
    dup 0
    swap 4          // ... st0 and st4 hold the same value, so nothing moves
    pop 4           // ... and the copies are discarded
    return

// ===========================================================================
// CZ -- Controlled-Z: flip the sign of the deepest complex amplitude
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: 8 elements [e0 .. e7]
//   Exit:  8 elements [e0 .. e5, -e6, -e7]
//
// st6/st7 hold the last complex value of the two-qubit state (the |11>
// amplitude per the gate description). Each part is lifted to st0, negated
// by multiplying with -1, and returned to its slot.
//
// 9 counted instructions.
__cz:
    // element at st7
    swap 7
    push -1
    mul
    swap 7
    // element at st6
    swap 6
    push -1
    mul
    swap 6
    return

// ===========================================================================
// SWAP -- Swap two qubits
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [q0, q1, q2, q3, r0, r1, r2, r3]   (qubit q on top of qubit r)
//   Exit:  [r0, r1, r2, r3, q0, q1, q2, q3]
//
// Each st(i) is exchanged with st(i+4). `swap N` can only exchange with
// st0, so positions 1..3 are each routed through the top of the stack:
// swap i / swap i+4 / swap i.
//
// 11 counted instructions.
__swap:
    // st0 <-> st4
    swap 4
    // st1 <-> st5
    swap 1
    swap 5
    swap 1
    // st2 <-> st6
    swap 2
    swap 6
    swap 2
    // st3 <-> st7
    swap 3
    swap 7
    swap 3
    return

// ===========================================================================
// MEASURE_DETERMINISTIC -- Deterministic measurement
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [zero.re, zero.im, one.re, one.im]  (Qubit, consumed)
//   Exit:  [0 or 1]  (Bool)
//
// Computes diff = |zero|^2 - |one|^2 in the field and returns 1 when diff
// looks "negative", 0 otherwise (so 0 when |zero|^2 >= |one|^2).
// "Negative" is approximated as: the high 32-bit limb of diff is greater
// than 2147483647, i.e. diff lies in the top half of the 64-bit range.
//
// Relies on __complex_norm_sq consuming [re, im] and leaving one element;
// the squared norm is symmetric in re/im, so the one amplitude can be
// passed with its parts in either order.
//
// 11 counted instructions.
__measure_deterministic:
    call __complex_norm_sq
    // [|zero|^2, one.re, one.im]
    swap 2
    // [one.im, one.re, |zero|^2]
    call __complex_norm_sq
    // [|one|^2, |zero|^2]
    push -1
    mul
    add
    // [diff]
    split
    // [lo, hi]
    pop 1
    // [hi]
    push 2147483647
    lt
    // [2147483647 < hi]  -- lt tests st0 < st1
    return

// ===========================================================================
// APPLY_SINGLE_GATE -- Apply a 2x2 gate to one qubit of a RAM state vector
// ===========================================================================
// Stack effect (st0 listed first):
//   Entry: [n_qubits, target_qubit, gate_00, gate_01, gate_10, gate_11,
//           state_addr]
//   Exit:  []  (state modified in RAM)
//
// The compiler produces deeply nested loops with deferred blocks.
// The hand version keeps this entry point as a thin wrapper: __asg_setup
// runs the loops and hands the seven arguments back, which are then dropped.
//
// 4 counted instructions (matches compiler's non-deferred portion).
__apply_single_gate:
    call __asg_setup
    // Discard the seven arguments (4 + 3).
    pop 4
    pop 3
    return

// ---------------------------------------------------------------------------
// Setup and outer loop for apply_single_gate.
//
// Stack effect (st0 listed first):
//   Entry: [n_qubits, target, g00, g01, g10, g11, state_addr]
//   Exit:  same seven elements, untouched (RAM is modified by __asg_loop)
//
// Computes stride = 2^target_qubit and pairs = 2^(n_qubits-1), then runs
// __asg_loop over all pairs. The accumulator 1 is pushed first and the
// exponent duplicated on top of it, which yields the [counter, accum]
// layout __asg_pow2 expects without an extra swap.
//
// NOTE(review): n_qubits = 0 makes the second exponent -1, which
// __asg_pow2 would count down from the top of the field -- callers are
// presumably expected to pass n_qubits >= 1.
//
// 13 counted instructions.
// ---------------------------------------------------------------------------
__asg_setup:
    // stride = 2^target_qubit
    push 1
    dup 2
    // [target, 1, n_qubits, target, ...]
    call __asg_pow2
    pop 1
    // [stride, n_qubits, target, g00, g01, g10, g11, state_addr]
    // pairs = 2^(n_qubits-1)
    push 1
    dup 2
    push -1
    add
    // [n_qubits-1, 1, stride, n_qubits, ...]
    call __asg_pow2
    pop 1
    // [pairs, stride, n_qubits, target, g00, g01, g10, g11, state_addr]
    call __asg_loop
    // Drop the exhausted pair counter and stride.
    pop 2
    return

// ---------------------------------------------------------------------------
// Power-of-two helper, written as a tail-recursive loop.
//
// Stack effect (st0 listed first):
//   Entry: [counter, accum]
//   Exit:  [0, accum * 2^counter]
//
// Each round doubles accum and decrements counter; the routine returns as
// soon as counter is 0.
//
// 12 counted instructions.
// ---------------------------------------------------------------------------
__asg_pow2:
    // Exit test: counter == 0 ?
    dup 0
    push 0
    eq
    skiz
    return
    // accum = accum + accum
    swap 1
    dup 0
    add
    swap 1
    // counter = counter - 1
    push -1
    add
    recurse

// ---------------------------------------------------------------------------
// Gate application loop.
// Stack: [pairs_remaining, stride, n_q, tgt, g00, g01, g10, g11, state_addr]
//        (st0 listed first)
//
// Counts pairs_remaining down to 0. For each index: compute lo_addr and
// hi_addr, read both amplitudes from RAM (re at addr, im at addr+1), and
// hand them to __asg_apply, which applies the 2x2 matrix and writes back.
// Returns with pairs_remaining = 0 on top and the other eight elements
// unchanged.
//
// Depends on __asg_apply removing the six temporaries pushed here
// (hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr) so that idx is back on
// top when `recurse` re-enters the loop.
//
// NOTE(review): the low element index is idx itself and the high one is
// idx + stride. Unless the target is the most significant qubit
// (stride == pairs), this looks like it revisits some elements and never
// reaches others -- confirm the intended pair enumeration.
//
// 35 counted instructions.
// ---------------------------------------------------------------------------
__asg_loop:
    // Exit test: nothing left to do when the counter is 0.
    dup 0
    push 0
    eq
    skiz
    return
    // Decrement first, so the counter doubles as the 0-based pair index.
    push -1
    add
    // [idx, stride, n_q, tgt, g00, g01, g10, g11, state_addr]
    // lo_addr = state_addr + idx*2 (each amplitude is 2 field elements)
    dup 0
    dup 0
    add
    dup 9
    add
    // [lo_addr, idx, stride, ...]
    // hi_addr = lo_addr + stride*2
    dup 0
    dup 3
    dup 0
    add
    add
    // [hi_addr, lo_addr, idx, stride, ...]
    // Read lo amplitude (2 elements: re, im).
    // read_mem leaves the decremented pointer on top; pop 1 discards it.
    dup 1
    read_mem 1
    pop 1
    // [lo.re, hi_addr, lo_addr, idx, ...]
    dup 2
    push 1
    add
    read_mem 1
    pop 1
    // [lo.im, lo.re, hi_addr, lo_addr, idx, ...]
    // Read hi amplitude the same way (hi_addr is now at st2, then st3).
    dup 2
    read_mem 1
    pop 1
    dup 3
    push 1
    add
    read_mem 1
    pop 1
    // [hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, idx, ...]
    // Apply 2x2 gate matrix and write back (delegate to subroutine)
    call __asg_apply
    recurse

// ---------------------------------------------------------------------------
// Apply 2x2 gate and write results.
//
// Stack on entry (st0 listed first):
//   [hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, idx, stride,
//    n_q, tgt, g00, g01, g10, g11, state_addr]
// Stack on exit:
//   [idx, stride, n_q, tgt, g00, g01, g10, g11, state_addr]
//
// new_lo = g00*lo + g01*hi
// new_hi = g10*lo + g11*hi
// (gate entries are single field elements applied to re and im alike)
//
// RAM layout matches __asg_loop's reads: re at addr, im at addr+1.
//
// new_lo is written straight to lo_addr: the old lo/hi values needed for
// new_hi are still on the stack, so no scratch memory is required.
// In every block the im part is computed first so that re ends up on top,
// which is the order two consecutive `write_mem 1` need (address on st0,
// value on st1).
// `dup` reaches st15 at most, so in the new_hi.re computation the g11 term
// is computed before the g10 term while g11 is still within reach.
//
// 38 counted instructions.
// ---------------------------------------------------------------------------
__asg_apply:
    // Entry stack (top first):
    //  st0=hi.im, st1=hi.re, st2=lo.im, st3=lo.re,
    //  st4=hi_addr, st5=lo_addr, st6=idx, st7=stride,
    //  st8=n_q, st9=tgt, st10=g00, st11=g01, st12=g10, st13=g11, st14=state_addr
    //
    // new_lo.im = g00*lo.im + g01*hi.im
    dup 2
    dup 11
    mul
    dup 1
    dup 13
    mul
    add
    // [new_lo.im, hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
    // new_lo.re = g00*lo.re + g01*hi.re
    dup 4
    dup 12
    mul
    dup 3
    dup 14
    mul
    add
    // [new_lo.re, new_lo.im, hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
    // mem[lo_addr] = new_lo.re, mem[lo_addr+1] = new_lo.im
    dup 7
    write_mem 1
    write_mem 1
    pop 1
    // Back to the entry layout:
    // [hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
    // new_hi.im = g10*lo.im + g11*hi.im
    dup 2
    dup 13
    mul
    dup 1
    dup 15
    mul
    add
    // [new_hi.im, hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
    // new_hi.re = g11*hi.re + g10*lo.re  (g11 term first: see header)
    dup 2
    dup 15
    mul
    dup 5
    dup 15
    mul
    add
    // [new_hi.re, new_hi.im, hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
    // mem[hi_addr] = new_hi.re, mem[hi_addr+1] = new_hi.im
    dup 6
    write_mem 1
    write_mem 1
    // [hi_addr+2, hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, idx, ...]
    // Drop the advanced pointer and the four old amplitude parts ...
    pop 5
    // ... then hi_addr and lo_addr, leaving idx on top for the caller.
    pop 2
    return

Neighbours