Skip to content

Commit

Permalink
changed loop generated funcion key from exprid to stable_ptr for enab… (
Browse files Browse the repository at this point in the history
  • Loading branch information
TomerStarkware authored Jan 30, 2025
1 parent 655edd4 commit b90e1e6
Show file tree
Hide file tree
Showing 17 changed files with 160 additions and 145 deletions.
48 changes: 25 additions & 23 deletions crates/bin/get-lowering/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use anyhow::Context;
use cairo_lang_compiler::db::RootDatabase;
use cairo_lang_compiler::project::{check_compiler_path, setup_project};
use cairo_lang_debug::debug::DebugWithDb;
use cairo_lang_defs::ids::TopLevelLanguageElementId;
use cairo_lang_defs::ids::{NamedLanguageElementId, TopLevelLanguageElementId};
use cairo_lang_filesystem::ids::CrateId;
use cairo_lang_lowering::FlatLowered;
use cairo_lang_lowering::add_withdraw_gas::add_withdraw_gas;
Expand All @@ -27,7 +27,7 @@ use cairo_lang_semantic::items::functions::{
};
use cairo_lang_starknet::starknet_plugin_suite;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{Intern, LookupIntern};
use cairo_lang_utils::{Intern, LookupIntern, Upcast};
use clap::Parser;
use convert_case::Casing;
use itertools::Itertools;
Expand Down Expand Up @@ -73,9 +73,9 @@ struct Args {
#[arg(short, long)]
all: bool,

/// The id the expr id of the generated function to output.
/// The index of the generated function to output.
#[arg(long)]
expr_id: Option<usize>,
generated_function_index: Option<usize>,

/// The output file name (default: stdout).
output: Option<String>,
Expand Down Expand Up @@ -233,33 +233,35 @@ fn main() -> anyhow::Result<()> {

let res = if let Some(function_path) = args.function_path {
let mut function_id = get_func_id_by_name(db, &main_crate_ids, function_path)?;
if let Some(expr_id) = args.expr_id {
if let Some(generated_function_index) = args.generated_function_index {
let multi = db
.priv_function_with_body_multi_lowering(
function_id.function_with_body_id(db).base_semantic_function(db),
)
.unwrap();
let key = *multi
let keys = multi
.generated_lowerings
.keys()
.find(|key| match key {
GeneratedFunctionKey::Loop(id) => id.index() == expr_id,
// TODO(ilya): Support other types of generated functions.
_ => false,
.sorted_by_key(|key| match key {
GeneratedFunctionKey::Loop(id) => {
(id.0.lookup(db).span_without_trivia(db.upcast()), "".into())
}
GeneratedFunctionKey::TraitFunc(trat_function, id) => (
id.syntax_node(db).span_without_trivia(db.upcast()),
trat_function.name(db),
),
})
.with_context(|| {
format!(
"expr_id not found - available expr_ids: {:?}",
multi
.generated_lowerings
.keys()
.filter_map(|key| match key {
GeneratedFunctionKey::Loop(id) => Some(id.index()),
_ => None,
})
.collect_vec()
)
})?;
.take(generated_function_index + 1)
.collect_vec();

let key = **keys.get(generated_function_index).with_context(|| {
format!(
"Invalid generated function index. There are {} generated functions in the \
function",
keys.len()
)
})?;

function_id = db.intern_lowering_concrete_function_with_body(
ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction {
parent: function_id.base_semantic_function(db),
Expand Down
37 changes: 25 additions & 12 deletions crates/cairo-lang-lowering/src/ids.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use cairo_lang_semantic::corelib::panic_destruct_trait_fn;
use cairo_lang_semantic::items::functions::ImplGenericFunctionId;
use cairo_lang_semantic::items::imp::ImplLongId;
use cairo_lang_semantic::{GenericArgumentId, TypeLongId};
use cairo_lang_syntax::node::ast::ExprPtr;
use cairo_lang_syntax::node::kind::SyntaxKind;
use cairo_lang_syntax::node::{TypedStablePtr, ast};
use cairo_lang_utils::{Intern, LookupIntern, define_short_id, try_extract_matches};
use defs::diagnostic_utils::StableLocation;
Expand Down Expand Up @@ -215,15 +217,10 @@ impl ConcreteFunctionWithBodyId {
let semantic_db = db.upcast();
Ok(match self.lookup_intern(db) {
ConcreteFunctionWithBodyLongId::Semantic(id) => id.stable_location(semantic_db),
ConcreteFunctionWithBodyLongId::Generated(generated) => {
let parent_id = generated.parent.function_with_body_id(semantic_db);
match generated.key {
GeneratedFunctionKey::Loop(expr_id) => StableLocation::new(
db.function_body(parent_id)?.arenas.exprs[expr_id].stable_ptr().untyped(),
),
GeneratedFunctionKey::TraitFunc(_, stable_location) => stable_location,
}
}
ConcreteFunctionWithBodyLongId::Generated(generated) => match generated.key {
GeneratedFunctionKey::Loop(stable_ptr) => StableLocation::new(stable_ptr.untyped()),
GeneratedFunctionKey::TraitFunc(_, stable_location) => stable_location,
},
})
}
}
Expand Down Expand Up @@ -379,7 +376,7 @@ impl<'a> DebugWithDb<dyn LoweringGroup + 'a> for FunctionLongId {
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub enum GeneratedFunctionKey {
/// Generated loop functions are identified by the loop expr_id.
Loop(semantic::ExprId),
Loop(ExprPtr),
TraitFunc(TraitFunctionId, StableLocation),
}

Expand All @@ -405,8 +402,24 @@ impl<'a> DebugWithDb<dyn LoweringGroup + 'a> for GeneratedFunction {
db: &(dyn LoweringGroup + 'a),
) -> std::fmt::Result {
match self.key {
GeneratedFunctionKey::Loop(expr_id) => {
write!(f, "{:?}[expr{}]", self.parent.debug(db), expr_id.index())
GeneratedFunctionKey::Loop(expr_ptr) => {
let mut func_ptr = expr_ptr.untyped();
while !matches!(
func_ptr.kind(db.upcast()),
SyntaxKind::FunctionWithBody | SyntaxKind::TraitItemFunction
) {
func_ptr = func_ptr.parent(db.upcast())
}

let span = expr_ptr.0.lookup(db.upcast()).span(db.upcast());
let function_start = func_ptr.lookup(db.upcast()).span(db.upcast()).start.as_u32();
write!(
f,
"{:?}[{}-{}]",
self.parent.debug(db),
span.start.as_u32() - function_start,
span.end.as_u32() - function_start
)
}
GeneratedFunctionKey::TraitFunc(trait_func, loc) => {
let trait_id = trait_func.trait_id(db.upcast());
Expand Down
8 changes: 4 additions & 4 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1432,7 +1432,7 @@ fn lower_expr_loop(
// Get the function id.
let function = FunctionWithBodyLongId::Generated {
parent: ctx.semantic_function_id,
key: GeneratedFunctionKey::Loop(loop_expr_id),
key: GeneratedFunctionKey::Loop(stable_ptr),
}
.intern(ctx.db);

Expand All @@ -1449,7 +1449,7 @@ fn lower_expr_loop(
)
.map_err(LoweringFlowError::Failed)?;
// TODO(spapini): Recursive call.
encapsulating_ctx.lowerings.insert(GeneratedFunctionKey::Loop(loop_expr_id), lowered);
encapsulating_ctx.lowerings.insert(GeneratedFunctionKey::Loop(stable_ptr), lowered);
ctx.encapsulating_ctx = Some(encapsulating_ctx);
let old_loop_expr_id = std::mem::replace(&mut ctx.current_loop_expr_id, Some(loop_expr_id));
for snapshot_param in snap_usage.values() {
Expand Down Expand Up @@ -1480,11 +1480,11 @@ fn call_loop_func(
stable_ptr: SyntaxStablePtrId,
) -> LoweringResult<LoweredExpr> {
let location = ctx.get_location(stable_ptr);

let loop_stable_ptr = ctx.function_body.arenas.exprs[loop_expr_id].stable_ptr();
// Call it.
let function = FunctionLongId::Generated(GeneratedFunction {
parent: ctx.concrete_function_id.base_semantic_function(ctx.db),
key: GeneratedFunctionKey::Loop(loop_expr_id),
key: GeneratedFunctionKey::Loop(loop_stable_ptr),
})
.intern(ctx.db);
let inputs = loop_signature
Expand Down
8 changes: 4 additions & 4 deletions crates/cairo-lang-lowering/src/lower/test_data/for
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Statements:
(v16: core::array::Array::<core::felt252>, v17: @core::array::Array::<core::felt252>) <- snapshot(v15)
(v18: core::array::Span::<core::felt252>) <- core::array::ArrayToSpan::<core::felt252>::span(v17)
(v19: core::array::SpanIter::<core::felt252>) <- core::array::SpanIntoIterator::<core::felt252>::into_iter(v18)
(v21: core::array::SpanIter::<core::felt252>, v22: core::felt252, v20: ()) <- test::foo[expr30](v19, v0, v2)
(v21: core::array::SpanIter::<core::felt252>, v22: core::felt252, v20: ()) <- test::foo[118-164](v19, v0, v2)
End:
Return(v22)

Expand All @@ -69,7 +69,7 @@ Statements:
(v14: core::array::Array::<core::felt252>, v15: @core::array::Array::<core::felt252>) <- snapshot(v10)
(v16: core::array::Span::<core::felt252>) <- struct_construct(v15)
(v17: core::array::SpanIter::<core::felt252>) <- struct_construct(v16)
(v18: core::RangeCheck, v19: core::gas::GasBuiltin, v20: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[expr30](v0, v1, v17, v11, v13)
(v18: core::RangeCheck, v19: core::gas::GasBuiltin, v20: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[118-164](v0, v1, v17, v11, v13)
End:
Match(match_enum(v20) {
PanicResult::Ok(v21) => blk1,
Expand Down Expand Up @@ -111,7 +111,7 @@ End:
blk1:
Statements:
(v6: core::felt252) <- core::Felt252Add::add(v1, v2)
(v8: core::array::SpanIter::<core::felt252>, v9: core::felt252, v7: ()) <- test::foo[expr30](v4, v6, v2)
(v8: core::array::SpanIter::<core::felt252>, v9: core::felt252, v7: ()) <- test::foo[118-164](v4, v6, v2)
End:
Goto(blk3, {v9 -> v12, v8 -> v13, v7 -> v11})

Expand Down Expand Up @@ -195,7 +195,7 @@ End:
blk8:
Statements:
(v30: core::felt252) <- core::felt252_add(v3, v4)
(v31: core::RangeCheck, v32: core::gas::GasBuiltin, v33: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[expr30](v5, v6, v27, v30, v4)
(v31: core::RangeCheck, v32: core::gas::GasBuiltin, v33: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[118-164](v5, v6, v27, v30, v4)
End:
Return(v31, v32, v33)

Expand Down
Loading

0 comments on commit b90e1e6

Please sign in to comment.