Skip to content
This repository was archived by the owner on Jan 17, 2022. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Use linear time algorithm to inject stack height metering
  • Loading branch information
athei committed Sep 7, 2021
commit 25e5b89f97d652dc1bb497105988ba4adfecdd18
113 changes: 54 additions & 59 deletions src/stack_height/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@
//! between the frames.
//! - upon entry into the function entire stack frame is allocated.

use crate::std::{string::String, vec::Vec};
use crate::std::{mem, string::String, vec::Vec};

use parity_wasm::{
builder,
elements::{self, Type},
elements::{self, Instruction, Instructions, Type},
};

/// Macro to generate preamble and postamble.
Expand Down Expand Up @@ -145,7 +145,7 @@ fn generate_stack_height_global(module: &mut elements::Module) -> u32 {
.value_type()
.i32()
.mutable()
.init_expr(elements::Instruction::I32Const(0))
.init_expr(Instruction::I32Const(0))
.build();

// Try to find an existing global section.
Expand Down Expand Up @@ -253,75 +253,70 @@ fn instrument_functions(ctx: &mut Context, module: &mut elements::Module) -> Res
///
/// drop
/// ```
fn instrument_function(
ctx: &mut Context,
instructions: &mut elements::Instructions,
) -> Result<(), Error> {
use parity_wasm::elements::Instruction::*;

let mut cursor = 0;
loop {
if cursor >= instructions.elements().len() {
break
}
fn instrument_function(ctx: &mut Context, func: &mut Instructions) -> Result<(), Error> {
use Instruction::*;

enum Action {
InstrumentCall { callee_idx: u32, callee_stack_cost: u32 },
Nop,
}
struct InstrumentCall {
offset: usize,
callee: u32,
cost: u32,
}

let action: Action = {
let instruction = &instructions.elements()[cursor];
match instruction {
Call(callee_idx) => {
let callee_stack_cost = ctx.stack_cost(*callee_idx).ok_or_else(|| {
Error(format!("Call to function that out-of-bounds: {}", callee_idx))
})?;

// Instrument only calls to a functions which stack_cost is
// non-zero.
if callee_stack_cost > 0 {
Action::InstrumentCall { callee_idx: *callee_idx, callee_stack_cost }
let calls: Vec<_> = func
.elements()
.iter()
.enumerate()
.filter_map(|(offset, instruction)| {
if let Call(callee) = instruction {
ctx.stack_cost(*callee).and_then(|cost| {
if cost > 0 {
Some(InstrumentCall { callee: *callee, offset, cost })
} else {
Action::Nop
None
}
},
_ => Action::Nop,
})
} else {
None
}
};

match action {
// We need to wrap a `call idx` instruction
// with a code that adjusts stack height counter
// and then restores it.
Action::InstrumentCall { callee_idx, callee_stack_cost } => {
})
.collect();

// The `instrumented_call!` contains the call itself. This is why we need to subtract one.
let len = func.elements().len() + calls.len() * (instrument_call!(0, 0, 0, 0).len() - 1);
let original_instrs = mem::replace(func.elements_mut(), Vec::with_capacity(len));
let new_instrs = func.elements_mut();

let mut calls = calls.into_iter().peekable();
for (original_pos, instr) in original_instrs.into_iter().enumerate() {
// whether there is some call instruction at this position that needs to be instrumented
let did_instrument = if let Some(call) = calls.peek() {
if call.offset == original_pos {
let new_seq = instrument_call!(
callee_idx,
callee_stack_cost as i32,
call.callee,
call.cost as i32,
ctx.stack_height_global_idx(),
ctx.stack_limit()
);
new_instrs.extend(new_seq);
true
} else {
false
}
} else {
false
};

// Replace the original `call idx` instruction with
// a wrapped call sequence.
//
// To splice actually take a place, we need to consume iterator
// splice returns. So we just `count()` it.
let _ = instructions
.elements_mut()
.splice(cursor..(cursor + 1), new_seq.iter().cloned())
.count();

// Advance cursor to be after the inserted sequence.
cursor += new_seq.len();
},
// Do nothing for other instructions.
_ => {
cursor += 1;
},
if did_instrument {
calls.next();
} else {
new_instrs.push(instr);
}
}

if calls.next().is_some() {
return Err(Error("Not all calls were used".into()))
}

Ok(())
}

Expand Down