// Hand-optimized TASM baseline: std.quantum.gates
//
// Quantum gate simulation over Goldilocks field pairs.
// Complex numbers represented as (re, im) field element pairs.
// Convention: st0=re, st1=im (re on top).
//
// Types (element lists in this section are written bottom-to-top, so the
// rightmost element is st0):
// Complex = 2 field elements: [im, re] (re=st0, im=st1)
// Qubit = 2 Complex = 4 elements: [one.im, one.re, zero.im, zero.re]
// zero.re=st0, zero.im=st1, one.re=st2, one.im=st3
// TwoQubit = 2 Qubit = 8 elements
//
// Note: The compiler currently treats Type::Named as width 1, so struct
// return cleanup may discard fields. The baseline represents the ideal
// hand-optimized implementation assuming correct struct widths.
//
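// Triton VM instruction semantics used below (stacks written top-first,
// i.e. st0 is the leftmost element):
// read_mem 1: [addr, rest...] -> [addr-1, mem[addr], rest...]
// write_mem 1: [addr, val, rest...] -> [addr+1, rest...] (stores mem[addr] = val)
// pop N: removes top N elements
// split: [x, rest...] -> [lo, hi, rest...] (lo=st0, hi=st1; 32-bit limbs of x)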
//
// Instruction count rules:
// - Comments (// ...) are NOT counted
// - Labels (ending with :) are NOT counted
// - halt is NOT counted
// - Blank lines are NOT counted
// - Everything else IS counted (including return)
// ===========================================================================
// COMPLEX_ZERO -- Return Complex(0, 0)
// ===========================================================================
// Entry: []
// Exit:  st0 = re = 0, st1 = im = 0
//
// One zero constant is pushed and then duplicated, so both components
// come from a single literal.
//
// 3 counted instructions.
__complex_zero:
push 0
dup 0
return
// ===========================================================================
// COMPLEX_ONE -- Return Complex(1, 0)
// ===========================================================================
// Entry: []
// Exit: st0 = re = 1, st1 = im = 0
//
// The imaginary part is pushed first so that the real part ends up on top.
//
// 3 counted instructions.
__complex_one:
// im = 0 (ends at st1)
push 0
// re = 1 (ends at st0)
push 1
return
// ===========================================================================
// COMPLEX_ADD -- Add two complex numbers
// ===========================================================================
// Entry: st0=a.re, st1=a.im, st2=b.re, st3=b.im
// Exit:  st0=(a.re+b.re), st1=(a.im+b.im)
//
// Compiler emits 10 ops (dup-heavy with swap 3; pop 3 cleanup).
// Hand: swap 3 puts b.im directly on top of a.im, so the imaginary parts
// are summed first. One more swap then lines up the real parts. No pops
// and no final reordering are needed.
//
// Stack comments list st0 first.
//
// 5 counted instructions.
__complex_add:
// [a.re, a.im, b.re, b.im]
swap 3
// [b.im, a.im, b.re, a.re]
add
// [a.im+b.im, b.re, a.re]
swap 2
// [a.re, b.re, a.im+b.im]
add
// [a.re+b.re, a.im+b.im]
return
// ===========================================================================
// COMPLEX_SUB -- Subtract two complex numbers: a - b
// ===========================================================================
// Entry: st0=a.re, st1=a.im, st2=b.re, st3=b.im
// Exit:  st0=(a.re-b.re), st1=(a.im-b.im)
//
// Same pairing strategy as COMPLEX_ADD.
// 1. Bring b.im next to a.im, negate it and add.
// 2. Bring b.re on top of a.re, negate it and add.
//
// Stack comments list st0 first.
//
// 10 counted instructions.
__complex_sub:
// [a.re, a.im, b.re, b.im]
swap 3
// [b.im, a.im, b.re, a.re]
push -1
mul
add
// [a.im-b.im, b.re, a.re]
swap 2
swap 1
// [b.re, a.re, a.im-b.im]
push -1
mul
add
// [a.re-b.re, a.im-b.im]
return
// ===========================================================================
// COMPLEX_MUL -- Multiply two complex numbers: (a.re*b.re - a.im*b.im, a.re*b.im + a.im*b.re)
// ===========================================================================
// Entry: st0=a.re, st1=a.im, st2=b.re, st3=b.im
// Exit:  st0=result.re, st1=result.im
//
// Compiler emits 20 ops.
// Hand:
// 1. Compute the real part from copies, leaving the four inputs in place.
// 2. Build the imaginary part destructively from the originals.
//
// Stack comments list st0 first.
//
// 16 counted instructions.
__complex_mul:
// a.re*b.re
dup 0
dup 3
mul
// [a.re*b.re, a.re, a.im, b.re, b.im]
// a.im*b.im
dup 2
dup 5
mul
// [a.im*b.im, a.re*b.re, a.re, a.im, b.re, b.im]
push -1
mul
add
// [re, a.re, a.im, b.re, b.im]
swap 4
// [b.im, a.re, a.im, b.re, re]
mul
// [a.re*b.im, a.im, b.re, re]
swap 2
// [b.re, a.im, a.re*b.im, re]
mul
add
// [im, re]
swap 1
// [re, im]
return
// ===========================================================================
// COMPLEX_SCALE -- Scale complex by real: (re*s, im*s)
// ===========================================================================
// Entry: st0=c.re, st1=c.im, st2=s
// Exit:  st0=c.re*s, st1=c.im*s
//
// The imaginary part is scaled first using a copy of s. The original s is
// then consumed by the real part, which lands on top without a final swap.
//
// Stack comments list st0 first.
//
// 6 counted instructions.
__complex_scale:
swap 1
// [c.im, c.re, s]
dup 2
mul
// [c.im*s, c.re, s]
swap 2
// [s, c.re, c.im*s]
mul
// [c.re*s, c.im*s]
return
// ===========================================================================
// COMPLEX_CONJ -- Complex conjugate: (re, -im)
// ===========================================================================
// Entry: st0=c.re, st1=c.im
// Exit: st0=c.re, st1=-c.im
//
// Brings im to the top, negates it (multiply by -1), and swaps it back.
//
// 5 counted instructions.
__complex_conj:
// st0=c.im, st1=c.re
swap 1
push -1
mul
// st0=-c.im, st1=c.re
swap 1
return
// ===========================================================================
// COMPLEX_NORM_SQ -- Squared norm: re^2 + im^2
// ===========================================================================
// Entry: st0=c.re, st1=c.im
// Exit: st0=re^2 + im^2
//
// Squares st0, then st1, and adds them. The result is symmetric in the two
// inputs, so the component order on entry does not matter.
//
// 7 counted instructions.
__complex_norm_sq:
// st0 = re^2
dup 0
mul
// st0 = im, st1 = re^2
swap 1
// st0 = im^2
dup 0
mul
add
return
// ===========================================================================
// INIT_ZERO -- |0> state: zero=Complex(1,0), one=Complex(0,0)
// ===========================================================================
// Entry: []
// Exit:  st0=zero.re=1, st1=zero.im=0, st2=one.re=0, st3=one.im=0
//
// Three zeros are produced from one literal plus two dups, then the 1 for
// zero.re goes on top.
//
// 5 counted instructions.
__init_zero:
push 0
dup 0
dup 0
push 1
return
// ===========================================================================
// INIT_ONE -- |1> state: zero=Complex(0,0), one=Complex(1,0)
// ===========================================================================
// Entry: []
// Exit:  st0=zero.re=0, st1=zero.im=0, st2=one.re=1, st3=one.im=0
//
// Build bottom-up: one.im, one.re, then the zero amplitude (two zeros, the
// second produced by dup).
//
// 5 counted instructions.
__init_one:
push 0
push 1
push 0
dup 0
return
// ===========================================================================
// PAULIX -- X gate (NOT): swap zero and one amplitudes
// ===========================================================================
// Entry: [zero.re, zero.im, one.re, one.im]  (st0 first)
// Exit:  [one.re, one.im, zero.re, zero.im]
//
// Compiler emits 6 ops. Hand: the permutation is (st0 st2)(st1 st3).
// - swap 2 handles the first transposition directly.
// - swap 1 / swap 3 / swap 1 routes st1 <-> st3 through the top.
// Four st0-swaps is the minimum for this permutation.
//
// 5 counted instructions.
__paulix:
swap 2
// [one.re, zero.im, zero.re, one.im]
swap 1
// [zero.im, one.re, zero.re, one.im]
swap 3
// [one.im, one.re, zero.re, zero.im]
swap 1
// [one.re, one.im, zero.re, zero.im]
return
// ===========================================================================
// PAULIZ -- Z gate: negate one amplitude
// ===========================================================================
// Entry: [zero.re, zero.im, one.re, one.im]  (st0 first)
// Exit:  [zero.re, zero.im, -one.re, -one.im]
//
// Steps:
// 1. Negate one.re.
// 2. Swap one.im up, without first restoring the stack, and negate it.
// 3. Two swaps put everything back in order.
//
// 9 counted instructions.
__pauliz:
swap 2
push -1
mul
// [-one.re, zero.im, zero.re, one.im]
swap 3
push -1
mul
// [-one.im, zero.im, zero.re, -one.re]
swap 3
// [-one.re, zero.im, zero.re, -one.im]
swap 2
// [zero.re, zero.im, -one.re, -one.im]
return
// ===========================================================================
// PAULIY -- Y gate: swap and negate
// ===========================================================================
// Entry: [zero.re, zero.im, one.re, one.im]  (st0 first)
// Exit:  [-one.re, -one.im, zero.re, zero.im]
//
// (-one, zero) equals the Pauli-Y action up to a global phase factor.
//
// The negations are folded into the pair exchange: each one-component is
// negated as soon as it reaches the top.
//
// 9 counted instructions.
__pauliy:
swap 2
// [one.re, zero.im, zero.re, one.im]
push -1
mul
// [-one.re, zero.im, zero.re, one.im]
swap 1
// [zero.im, -one.re, zero.re, one.im]
swap 3
// [one.im, -one.re, zero.re, zero.im]
push -1
mul
// [-one.im, -one.re, zero.re, zero.im]
swap 1
// [-one.re, -one.im, zero.re, zero.im]
return
// ===========================================================================
// HADAMARD -- H gate: superposition
// ===========================================================================
// Entry: [zero.re, zero.im, one.re, one.im]  (st0 first)
// Exit:  [(zero+one).re, (zero+one).im, (zero-one).re, (zero-one).im]
//
// Unnormalized: the 1/sqrt(2) factor is omitted.
//
// Delegates to complex_sub / complex_add. Both consume their operands as
// [a.re, a.im, b.re, b.im] (st0 first) and leave [re, im].
// 1. Duplicate the whole qubit.
// 2. Compute zero-one from the copy.
// 3. Sink the difference below the originals.
// 4. Add the originals.
//
// 11 counted instructions.
__hadamard:
dup 3
dup 3
dup 3
dup 3
// [zero.re, zero.im, one.re, one.im, zero.re, zero.im, one.re, one.im]
call __complex_sub
// [diff.re, diff.im, zero.re, zero.im, one.re, one.im]
swap 4
swap 1
swap 5
swap 1
// [one.re, one.im, zero.re, zero.im, diff.re, diff.im]
call __complex_add
// [sum.re, sum.im, diff.re, diff.im]
return
// ===========================================================================
// SGATE -- S gate: phase gate (multiply one by i)
// ===========================================================================
// Entry: [zero.re, zero.im, one.re, one.im]  (st0 first)
// Exit:  [zero.re, zero.im, -one.im, one.re]
//
// one * i = (one.re*0 - one.im*1, one.re*1 + one.im*0) = (-one.im, one.re).
// Steps:
// 1. Bring one.im up and negate it.
// 2. Rotate the three affected slots into place with two swaps.
//
// 6 counted instructions.
__sgate:
swap 3
// [one.im, zero.im, one.re, zero.re]
push -1
mul
// [-one.im, zero.im, one.re, zero.re]
swap 2
// [one.re, zero.im, -one.im, zero.re]
swap 3
// [zero.re, zero.im, -one.im, one.re]
return
// ===========================================================================
// TGATE -- T gate: pi/8 phase rotation approximation
// ===========================================================================
// Entry: [zero.re, zero.im, one.re, one.im]  (st0 first)
// Exit:  [zero.re, zero.im, (one.re-one.im), (one.re+one.im)]
//
// one * (1+i): an unnormalized stand-in for the e^{i*pi/4} phase.
// Steps:
// 1. Park zero.re under the one components.
// 2. Form the sum from copies.
// 3. Sink the sum and form the difference from the originals.
// 4. One swap restores zero.re.
//
// 10 counted instructions.
__tgate:
swap 2
// [one.re, zero.im, zero.re, one.im]
dup 3
dup 1
add
// [one.re+one.im, one.re, zero.im, zero.re, one.im]
swap 4
// [one.im, one.re, zero.im, zero.re, one.re+one.im]
push -1
mul
add
// [one.re-one.im, zero.im, zero.re, one.re+one.im]
swap 2
// [zero.re, zero.im, one.re-one.im, one.re+one.im]
return
// ===========================================================================
// TWO_QUBIT_PRODUCT -- Tensor product of two qubits
// ===========================================================================
// Entry: 8 elements, st0 first:
//   [a.zero.re, a.zero.im, a.one.re, a.one.im,
//    b.zero.re, b.zero.im, b.one.re, b.one.im]
//   NOTE(review): assumes qubit_a is the top 4 elements (first argument on
//   top, as in the other routines here). TODO confirm.
// Exit: 8 elements = 4 Complex, st0 first:
//   [c00, c01, c10, c11], where c_xy = a.x * b.y.
//   The |11> amplitude ends at st6..st7, matching __cz's "last complex".
//
// Four complex multiplications via __complex_mul. Each consumes 4 elements
// [x.re, x.im, y.re, y.im] and leaves [re, im].
// - Results are produced deepest-first (c11, c10, c01, c00), from copies of
//   the inputs (dup k; dup k copies a pair with order preserved).
// - The 8 spent input elements then sit at st8..st15. They are dropped by
//   rotating the 8-element result over them ((swap 8; pop 1) x 8).
//
// 37 counted instructions.
__two_qubit_product:
// c11 = a.one * b.one
dup 3
dup 3
dup 9
dup 9
call __complex_mul
// [c11, a.zero, a.one, b.zero, b.one]
// c10 = a.one * b.zero
dup 5
dup 5
dup 9
dup 9
call __complex_mul
// [c10, c11, a.zero, a.one, b.zero, b.one]
// c01 = a.zero * b.one
dup 5
dup 5
dup 13
dup 13
call __complex_mul
// [c01, c10, c11, a.zero, a.one, b.zero, b.one]
// c00 = a.zero * b.zero
dup 7
dup 7
dup 13
dup 13
call __complex_mul
// [c00, c01, c10, c11, a.zero, a.one, b.zero, b.one]
// Drop st8..st15: each round buries the current top result element at st8
// and pops the spent element that surfaces.
swap 8
pop 1
swap 8
pop 1
swap 8
pop 1
swap 8
pop 1
swap 8
pop 1
swap 8
pop 1
swap 8
pop 1
swap 8
pop 1
// [c00, c01, c10, c11]
return
// ===========================================================================
// CNOT -- Controlled-NOT: |00>->|00>, |01>->|01>, |10>->|11>, |11>->|10>
// ===========================================================================
// Entry: 8 elements (control: 4, target: 4)
// Exit:  8 elements (result)
//
// NOTE(review): assumes the TwoQubit layout used by __cz.
// - Four Complex amplitudes |00> st0..1, |01> st2..3, |10> st4..5, |11> st6..7.
// - Control is the top qubit.
// TODO confirm against the source.
// CNOT then exchanges the |10> and |11> amplitudes (the target amplitudes
// inside the control=|1> half): element transpositions (st4 st6)(st5 st7),
// each routed through st0.
//
// 7 counted instructions.
__cnot:
// st4 <-> st6
swap 4
swap 6
swap 4
// st5 <-> st7
swap 5
swap 7
swap 5
return
// ===========================================================================
// CZ -- Controlled-Z: negate |11> amplitude
// ===========================================================================
// Entry: 8 elements
// Exit:  8 elements (last complex, st6..st7, negated)
//
// Steps:
// 1. Fetch st6, negate it, and trade it straight into st7's slot.
// 2. Negate st7's value.
// 3. Two swaps sort both negated values back into their own slots.
//
// 9 counted instructions.
__cz:
swap 6
push -1
mul
// st0=-e6, st6=e0
swap 7
push -1
mul
// st0=-e7, st6=e0, st7=-e6
swap 7
// st0=-e6, st6=e0, st7=-e7
swap 6
// st0=e0, st6=-e6, st7=-e7
return
// ===========================================================================
// SWAP -- Swap two qubits
// ===========================================================================
// Entry: 8 elements (qubit_a: 4, qubit_b: 4)
// Exit:  8 elements (swapped)
//
// NOTE(review): assumes the TwoQubit amplitude layout used by __cz
// (|00> st0..1, |01> st2..3, |10> st4..5, |11> st6..7). TODO confirm.
// Under that layout SWAP exchanges the |01> and |10> amplitudes:
// element transpositions (st2 st4)(st3 st5), each routed through st0.
//
// 7 counted instructions.
__swap:
// st2 <-> st4
swap 2
swap 4
swap 2
// st3 <-> st5
swap 3
swap 5
swap 3
return
// ===========================================================================
// MEASURE_DETERMINISTIC -- Deterministic measurement
// ===========================================================================
// Entry: [zero.re, zero.im, one.re, one.im] (Qubit, st0 first)
// Exit:  [0 or 1] (Bool)
//
// Computes |zero|^2 and |one|^2. Returns 0 if |zero|^2 >= |one|^2, else 1.
// Uses split + lt for field comparison (top half = "negative"):
// - diff = |zero|^2 - |one|^2 is split into 32-bit limbs.
// - diff counts as negative when its high limb exceeds 2^31-1, i.e. its
//   canonical value is >= 2^63.
// This is exact as long as |diff| < 2^63.
// NOTE(review): assumes lt computes (st0 < st1) on u32 operands, per the
// Triton VM spec. TODO confirm.
//
// The qubit is consumed entirely; only the Bool remains.
//
// 11 counted instructions.
__measure_deterministic:
call __complex_norm_sq
// [|zero|^2, one.re, one.im]
swap 2
// [one.im, one.re, |zero|^2] (norm_sq is symmetric in its two inputs)
call __complex_norm_sq
// [|one|^2, |zero|^2]
push -1
mul
add
// [diff]
split
// [lo, hi]
pop 1
// [hi]
push 2147483647
// [2^31-1, hi]
lt
// [(2^31-1) < hi] = 1 iff diff is "negative"
return
// ===========================================================================
// APPLY_SINGLE_GATE -- Apply gate to qubit in state vector (RAM-based)
// ===========================================================================
// Entry: [n_qubits, target_qubit, gate_00, gate_01, gate_10, gate_11, state_addr]
//        (st0 = n_qubits)
// Exit:  [] (state modified in RAM)
//
// Thin wrapper. __asg_setup does all of the work and returns with the seven
// arguments still in place; they are discarded here.
//
// 4 counted instructions.
__apply_single_gate:
call __asg_setup
// discard the seven untouched arguments
pop 2
pop 5
return
// ---------------------------------------------------------------------------
// Setup and outer loop for apply_single_gate.
// Computes stride = 2^target_qubit and pairs = 2^(n_qubits-1), then
// iterates over all pairs applying the 2x2 gate matrix.
//
// Entry: [n_qubits, target, g00, g01, g10, g11, state_addr]  (st0 first)
// Exit:  same seven elements, unchanged.
//
// "push 1; dup 2" builds the [counter, accum=1] argument for __asg_pow2 in
// two instructions (instead of dup/push/swap).
// NOTE(review): n_qubits = 0 makes the pairs counter p-1, so __asg_pow2
// would not terminate in practice. Callers presumably pass n_qubits >= 1.
//
// 13 counted instructions.
// ---------------------------------------------------------------------------
__asg_setup:
// stride = 2^target
push 1
dup 2
// [target, 1, n_qubits, target, ...]
call __asg_pow2
pop 1
// [stride, n_qubits, target, g00, g01, g10, g11, state_addr]
// pairs = 2^(n_qubits-1)
push 1
dup 2
push -1
add
// [n_qubits-1, 1, stride, n_qubits, ...]
call __asg_pow2
pop 1
// [pairs, stride, n_qubits, target, g00, g01, g10, g11, state_addr]
call __asg_loop
// [0, stride, n_qubits, target, ...]
pop 2
return
// ---------------------------------------------------------------------------
// Power of 2 loop: [counter, accum] -> [0, 2^counter * accum]
//
// Each iteration first doubles the accumulator (multiply by 2), then
// decrements the counter. The loop exits through the skiz/return guard as
// soon as the counter reaches 0.
//
// 12 counted instructions.
// ---------------------------------------------------------------------------
__asg_pow2:
dup 0
push 0
eq
skiz
return
// [counter, accum], counter != 0
swap 1
push 2
mul
swap 1
// [counter, 2*accum]
push -1
add
recurse
// ---------------------------------------------------------------------------
// Gate application loop.
// Stack: [pairs_remaining, stride, n_q, tgt, g00, g01, g10, g11, state_addr]
//
// For each pair index idx (pairs-1 down to 0):
// - lo_idx = idx + (idx div stride) * stride. This inserts a 0 at bit tgt,
//   so lo_idx has bit tgt clear.
// - hi_idx = lo_idx + stride.
// Each amplitude occupies 2 consecutive words (re, im) at
// state_addr + 2*index. Both amplitudes are read, then __asg_apply applies
// the 2x2 matrix and writes them back.
//
// NOTE(review): div_mod requires idx and stride to be u32 (n_q <= 32).
// It is assumed to map [numerator, divisor] -> [remainder, quotient]
// (st0 first), per the Triton VM spec. TODO confirm.
//
// Returns with [0, stride, n_q, tgt, g00, g01, g10, g11, state_addr].
//
// 40 counted instructions.
// ---------------------------------------------------------------------------
__asg_loop:
dup 0
push 0
eq
skiz
return
push -1
add
// [idx, stride, n_q, tgt, g00, g01, g10, g11, state_addr]
// lo_idx = idx + (idx div stride) * stride
dup 1
dup 1
div_mod
// [idx mod stride, idx div stride, idx, stride, ...]
pop 1
dup 2
mul
dup 1
add
// [lo_idx, idx, stride, ...]
// lo_addr = state_addr + 2*lo_idx
dup 0
add
dup 9
add
// [lo_addr, idx, stride, ...]
// hi_addr = lo_addr + 2*stride
dup 0
dup 3
dup 0
add
add
// [hi_addr, lo_addr, idx, stride, ...]
// Read lo: start at lo_addr+1, and read_mem 1 walks down to lo_addr.
dup 1
push 1
add
read_mem 1
read_mem 1
pop 1
// [lo.re, lo.im, hi_addr, lo_addr, idx, ...]
swap 1
// [lo.im, lo.re, hi_addr, lo_addr, idx, ...]
// Read hi the same way
dup 2
push 1
add
read_mem 1
read_mem 1
pop 1
swap 1
// [hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, idx, ...]
// Apply 2x2 gate matrix and write back; leaves [idx, stride, ...]
call __asg_apply
recurse
// ---------------------------------------------------------------------------
// Apply 2x2 gate and write results.
// Stack: [hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, idx, stride,
//         n_q, tgt, g00, g01, g10, g11, state_addr]   (st0 first)
//
// new_lo = g00*lo + g01*hi
// new_hi = g10*lo + g11*hi
// Each gate entry is one field element, applied to re and im alike.
//
// Strategy:
// - new_lo goes straight to lo_addr. Both old amplitudes are already on the
//   stack, so no scratch memory is needed.
// - After new_hi.im is computed, the dead hi.im is dropped. This keeps g11
//   within dup range (index <= 15) while computing new_hi.re.
//
// Exit: [idx, stride, n_q, tgt, g00, g01, g10, g11, state_addr]
//
// 40 counted instructions.
// ---------------------------------------------------------------------------
__asg_apply:
// new_lo.im = g00*lo.im + g01*hi.im
dup 2
dup 11
mul
dup 1
dup 13
mul
add
// [new_lo.im, hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
// new_lo.re = g00*lo.re + g01*hi.re
dup 4
dup 12
mul
dup 3
dup 14
mul
add
// [new_lo.re, new_lo.im, hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
// mem[lo_addr] = new_lo.re, mem[lo_addr+1] = new_lo.im
dup 7
write_mem 1
write_mem 1
pop 1
// [hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...] (g10 at st12, g11 at st13)
// new_hi.im = g10*lo.im + g11*hi.im
dup 2
dup 13
mul
dup 1
dup 15
mul
add
// [new_hi.im, hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
// hi.im is dead from here on; drop it so g11 stays reachable.
swap 1
pop 1
// [new_hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, ...]
// new_hi.re = g10*lo.re + g11*hi.re
dup 3
dup 13
mul
dup 2
dup 15
mul
add
// [new_hi.re, new_hi.im, hi.re, lo.im, lo.re, hi_addr, lo_addr, idx, ...]
// mem[hi_addr] = new_hi.re, mem[hi_addr+1] = new_hi.im
dup 5
write_mem 1
write_mem 1
// [hi_addr+2, hi.re, lo.im, lo.re, hi_addr, lo_addr, idx, ...]
pop 5
pop 1
// [idx, stride, n_q, tgt, g00, g01, g10, g11, state_addr]
return