Skip to content

Commit

Permalink
feat: SltuChip (#150)
Browse files Browse the repository at this point in the history
* feat: SltuChip

* rd == x0 case in SLTU chip

* refactor: streamline trace assertion in SLTU chip tests

* rebase

* improve panic messages

* drop todo

---------

Co-authored-by: Kristian Sosnin <[email protected]>
  • Loading branch information
2 people authored and sjudson committed Feb 5, 2025
1 parent 53b02b4 commit 2043c7a
Show file tree
Hide file tree
Showing 6 changed files with 267 additions and 4 deletions.
2 changes: 1 addition & 1 deletion prover/src/machine2/chips/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ impl AddChip {
fn execute(program_step: &ProgramStep) -> ExecutionResult {
let result = program_step
.get_result()
.expect("Instruction does not have result");
.expect("Add instruction must have a result");
let rd_is_x0 = program_step.is_value_a_x0();

// Recompute 32-bit result from 8-bit limbs.
Expand Down
3 changes: 3 additions & 0 deletions prover/src/machine2/chips/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ impl MachineChip for CpuChip {
Some(BuiltinOpcode::SUB) => {
trace_column_mut!(traces, row_idx, IsSub)
}
Some(BuiltinOpcode::SLTU) => {
trace_column_mut!(traces, row_idx, IsSltu)
}
_ => panic!(
"Unsupported opcode: {:?}",
step.instruction.opcode.builtin()
Expand Down
3 changes: 2 additions & 1 deletion prover/src/machine2/chips/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ mod add;
mod sub;

mod cpu;
mod sltu;

pub use self::{add::AddChip, cpu::CpuChip, sub::SubChip};
pub use self::{add::AddChip, cpu::CpuChip, sltu::SltuChip, sub::SubChip};
251 changes: 251 additions & 0 deletions prover/src/machine2/chips/sltu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
use num_traits::Zero;
use stwo_prover::{constraint_framework::EvalAtRow, core::fields::m31::BaseField};

use nexus_vm::{riscv::BuiltinOpcode, WORD_SIZE};

use crate::machine2::{
column::Column::{self, *},
trace::{
eval::{trace_eval, TraceEval},
trace_column_mut, ProgramStep, Traces,
},
traits::MachineChip,
};

// Support SLTU opcode.
pub struct SltuChip;

struct ExecutionResult {
borrow_bits: [u32; WORD_SIZE],
helper_bytes: [u32; WORD_SIZE],
rd_is_x0: bool,
}

impl SltuChip {
fn execute(program_step: &ProgramStep) -> ExecutionResult {
let result = program_step
.get_result()
.expect("Slt instruction must have a result");
let rd_is_x0 = program_step.is_value_a_x0();

// Recompute 32-bit result from 8-bit limbs.
// 1. Break the computation to 8-bit limbs.
// 2. Compute the diff and borrow of each limb.
// 3. Check that the final result matches the expected result.

// Step 1. Break the computation to 8-bit limbs.
let value_b = program_step.get_value_b();
let (value_c, _) = program_step.get_value_c();

let mut helper1_bytes = [0u8; WORD_SIZE];
let mut borrow = [false; WORD_SIZE];

// Step 2. Compute the diff and borrow of each limb.
let (diff, c0) = value_b[0].overflowing_sub(value_c[0]);
borrow[0] = c0;
helper1_bytes[0] = diff;

// Process the remaining bytes
for i in 1..WORD_SIZE {
// Add the bytes and the previous borrow
let (diff, b1) = value_b[i].overflowing_sub(borrow[i - 1] as u8);
let (diff, b2) = diff.overflowing_sub(value_c[i]);

// There can't be 2 borrow in: a - b - cary, at most either c1 or c2 is true.
borrow[i] = b1 || b2;
helper1_bytes[i] = diff;
}
let mut rd_bytes = [0u8; WORD_SIZE];
rd_bytes[0] = borrow[3] as u8;

// Step 3. Check that the final result matches the expected result.
assert_eq!(rd_bytes, result);

// Map borrow bits to 0/1 values, and expand to 32-bit words.
let borrow_bits: [u32; WORD_SIZE] = borrow.map(|c| c as u32);
let helper_bytes = helper1_bytes.map(|b| b as u32);

ExecutionResult {
borrow_bits,
helper_bytes,
rd_is_x0,
}
}
}

impl MachineChip for SltuChip {
fn fill_main_trace(traces: &mut Traces, row_idx: usize, vm_step: &ProgramStep) {
if !matches!(
vm_step.step.instruction.opcode.builtin(),
Some(BuiltinOpcode::SLTU)
) {
return;
}

let ExecutionResult {
borrow_bits,
helper_bytes,
rd_is_x0,
} = Self::execute(vm_step);

let helper_col = trace_column_mut!(traces, row_idx, Helper1);
for (i, b) in helper_bytes.iter().enumerate() {
*helper_col[i] = BaseField::from(*b);
}
let borrow_col = trace_column_mut!(traces, row_idx, CarryFlag);
for (i, b) in borrow_bits.iter().enumerate() {
*borrow_col[i] = BaseField::from(*b);
}
let rd_col = trace_column_mut!(traces, row_idx, ValueA);
*rd_col[0] = BaseField::from(borrow_bits[3]);
let rd_effective_col = trace_column_mut!(traces, row_idx, ValueAEffective);
*rd_effective_col[0] = if rd_is_x0 {
BaseField::zero()
} else {
BaseField::from(borrow_bits[3])
};
}

fn add_constraints<E: EvalAtRow>(eval: &mut E, trace_eval: &TraceEval<E>) {
let (_, is_sltu) = trace_eval!(trace_eval, IsSltu);
let is_sltu = is_sltu[0].clone();
// modulus for 8-bit limbs
let modulus = E::F::from(256u32.into());

// Reusing the CarryFlag as borrow flag.
let (_, borrow_flag) = trace_eval!(trace_eval, CarryFlag);
let (_, rs1_val) = trace_eval!(trace_eval, ValueB);
let (_, rs2_val) = trace_eval!(trace_eval, ValueC);
let (_, rd_val) = trace_eval!(trace_eval, ValueA);
let (_, helper1_val) = trace_eval!(trace_eval, Helper1);

for i in 0..WORD_SIZE {
let borrow = i
.checked_sub(1)
.map(|j| borrow_flag[j].clone())
.unwrap_or(E::F::zero());

// SLTU a, b, c
// h_1[i] - h1[i] * 2^8 = rs1val[i] - rs2val[i] - borrow[i - 1]
eval.add_constraint(
is_sltu.clone()
* (helper1_val[i].clone()
- borrow_flag[i].clone() * modulus.clone()
- (rs1_val[i].clone() - rs2_val[i].clone() - borrow)),
);

match i {
0 =>
// SLTU rd[0] = borrow[3]
{
eval.add_constraint(
is_sltu.clone() * (rd_val[i].clone() - borrow_flag[3].clone()),
)
}
1..=3 =>
// SLTU rd[1,2,3] = 0
{
eval.add_constraint(is_sltu.clone() * rd_val[i].clone())
}
_ => panic!("never reached"),
}
}
// TODO: range check CarryFlag to be in {0, 1}.
// TODO: range check r{s1,s2}_val[i] to be in [0, 255].
// TODO: range check helper1_val[i] to be in [0, 255].
// TODO: range check rd_val[i] to be in {0, 1}.
// TODO: constrain ValueAEffective in CpuChip.
}
}

#[cfg(test)]
mod test {
use crate::machine2::chips::{AddChip, CpuChip}; // needed for ADDI to put a non-zero value in a register

use super::*;
use nexus_vm::{
riscv::{BasicBlock, BuiltinOpcode, Instruction, InstructionType, Opcode},
trace::k_trace_direct,
};

const LOG_SIZE: u32 = 8;

#[rustfmt::skip]
fn setup_basic_block_ir() -> Vec<BasicBlock>
{
let basic_block = BasicBlock::new(vec![
// Set x0 = 0 (default constant), x1 = 1
Instruction::new(Opcode::from(BuiltinOpcode::ADDI), 1, 0, 1, InstructionType::IType),
// x2 = x1 < x0
// x3 = x2 < x1 ... and so on
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 2, 1, 0, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 3, 2, 1, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 4, 3, 2, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 5, 4, 3, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 6, 5, 4, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 7, 6, 5, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 8, 7, 6, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 9, 8, 7, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 10, 9, 8, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 11, 10, 9, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 12, 11, 10, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 13, 12, 11, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 14, 13, 12, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 15, 14, 13, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 16, 15, 14, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 17, 16, 15, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 18, 17, 16, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 19, 18, 17, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 20, 19, 18, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 21, 20, 19, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 22, 21, 20, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 23, 22, 21, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 24, 23, 22, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 25, 24, 23, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 26, 25, 24, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 27, 26, 25, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 28, 27, 26, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 29, 28, 27, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 30, 29, 28, InstructionType::RType),
Instruction::new(Opcode::from(BuiltinOpcode::SLTU), 31, 30, 29, InstructionType::RType),
]);
vec![basic_block]
}

#[test]
fn test_k_trace_constrained_add_instructions() {
let basic_block = setup_basic_block_ir();
let k = 1;

// Get traces from VM K-Trace interface
let vm_traces = k_trace_direct(&basic_block, k).expect("Failed to create trace");

// Trace circuit
let mut traces = Traces::new(LOG_SIZE);
let mut row_idx = 0;

// We iterate each block in the trace for each instruction
for trace in vm_traces.blocks.iter() {
let regs = trace.regs;
for step in trace.steps.iter() {
let program_step = ProgramStep {
regs,
step: step.clone(),
};

// Now fill in the traces with ValueA and CarryFlags
CpuChip::fill_main_trace(&mut traces, row_idx, &program_step);
// AddChip::fill_main_trace() needs to be called because the first step is ADDI.
AddChip::fill_main_trace(&mut traces, row_idx, &program_step);
SltuChip::fill_main_trace(&mut traces, row_idx, &program_step);

row_idx += 1;
}
}

traces.assert_as_original_trace(|eval, trace_eval| {
AddChip::add_constraints(eval, trace_eval);
SltuChip::add_constraints(eval, trace_eval);
});
}
}
8 changes: 8 additions & 0 deletions prover/src/machine2/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,20 @@ pub enum Column {
/// The previous timestamp for c.
#[size = 4]
PrevTsC,

/// Boolean flag on whether the row is an addition.
#[size = 1]
IsAdd,
/// Boolean flag on whether the row is a subtraction.
#[size = 1]
IsSub,
/// Boolean flag on whether the row is a SLTU.
#[size = 1]
IsSltu,

/// Helper variable 1. Called h_1 in document.
#[size = 4]
Helper1,
/// Multiplicity column for byte range check. Multipllicity256[row_idx] counts how many times the number Range256[row_idx] is used in the entire trace.
#[size = 1]
Multiplicity256,
Expand Down
4 changes: 2 additions & 2 deletions prover/src/machine2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ pub mod traits;

pub use crate::utils::WORD_SIZE;

use chips::{AddChip, CpuChip, SubChip};
use chips::{AddChip, CpuChip, SltuChip, SubChip};
use components::{MachineComponent, MachineEval, LOG_CONSTRAINT_DEGREE};
use traits::MachineChip;

pub type Components = (CpuChip, AddChip, SubChip);
pub type Components = (CpuChip, AddChip, SubChip, SltuChip);
pub type Proof = StarkProof<Blake2sMerkleHasher>;

pub struct Machine<C = Components> {
Expand Down

0 comments on commit 2043c7a

Please sign in to comment.