From 805cc644d991eac67aed513ff72aa5b8a90efc2e Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 27 Mar 2026 00:34:13 +0900 Subject: [PATCH 1/6] Translate and register formula_fn --- src/analyze.rs | 37 +++++ src/analyze/annot.rs | 4 + src/analyze/annot_fn.rs | 297 +++++++++++++++++++++++++++++++++++++++ src/analyze/crate_.rs | 10 +- src/analyze/local_def.rs | 10 ++ src/lib.rs | 3 + src/main.rs | 4 + 7 files changed, 364 insertions(+), 1 deletion(-) create mode 100644 src/analyze/annot_fn.rs diff --git a/src/analyze.rs b/src/analyze.rs index b9203f2..438d4d6 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -24,6 +24,7 @@ use crate::refine::{self, BasicBlockType, TypeBuilder}; use crate::rty; mod annot; +mod annot_fn; mod basic_block; mod crate_; mod did_cache; @@ -32,6 +33,32 @@ mod local_def; // TODO: organize structure and remove cross dependency between refine pub use did_cache::DefIdCache; +pub fn mir_borrowck_skip_formula_fn( + tcx: rustc_middle::ty::TyCtxt<'_>, + local_def_id: rustc_span::def_id::LocalDefId, +) -> rustc_middle::query::queries::mir_borrowck::ProvidedValue { + // TODO: unify impl with local_def::Analyzer + let is_annotated_as_formula_fn = tcx + .get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::formula_fn_path()) + .next() + .is_some(); + + if is_annotated_as_formula_fn { + tracing::debug!(?local_def_id, "skipping borrow check for formula fn"); + let dummy_result = rustc_middle::mir::BorrowCheckResult { + concrete_opaque_types: Default::default(), + closure_requirements: None, + used_mut_upvars: Default::default(), + tainted_by_errors: None, + }; + return tcx.arena.alloc(dummy_result); + } + + (rustc_interface::DEFAULT_QUERY_PROVIDERS + .queries + .mir_borrowck)(tcx, local_def_id) +} + pub fn local_of_function_param(idx: rty::FunctionParamIdx) -> Local { Local::from(idx.index() + 1) } @@ -173,6 +200,9 @@ pub struct Analyzer<'tcx> { /// (at least for every defs referenced by local def bodies) defs: HashMap>, + /// Collection of functions with `#[thrust::formula_fn]` attribute. + formula_fns: HashMap>, + /// Resulting CHC system. system: Rc>, @@ -199,12 +229,14 @@ impl<'tcx> Analyzer<'tcx> { impl<'tcx> Analyzer<'tcx> { pub fn new(tcx: TyCtxt<'tcx>) -> Self { let defs = Default::default(); + let formula_fns = Default::default(); let system = Default::default(); let basic_blocks = Default::default(); let enum_defs = Default::default(); Self { tcx, defs, + formula_fns, system, basic_blocks, def_ids: did_cache::DefIdCache::new(tcx), @@ -410,6 +442,11 @@ impl<'tcx> Analyzer<'tcx> { Some(expected) } + pub fn register_formula_fn(&mut self, def_id: DefId, formula_fn: annot_fn::FormulaFn<'tcx>) { + tracing::info!(def_id = ?def_id, formula_fn = %formula_fn.display(), "register_formula_fn"); + self.formula_fns.insert(def_id, formula_fn); + } + pub fn register_basic_block_ty( &mut self, def_id: LocalDefId, diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index c9d80e3..8cb43d2 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -49,6 +49,10 @@ pub fn ignored_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("ignored")] } +pub fn formula_fn_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("formula_fn")] +} + pub fn model_ty_path() -> [Symbol; 3] { [ Symbol::intern("thrust"), diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs new file mode 100644 index 0000000..219e5cc --- /dev/null +++ b/src/analyze/annot_fn.rs @@ -0,0 +1,297 @@ +use std::collections::HashMap; + +use pretty::{termcolor, Pretty}; +use rustc_hir::{def_id::LocalDefId, HirId}; +use rustc_index::IndexVec; +use rustc_middle::ty::{self as mir_ty, TyCtxt}; + +use crate::annot::AnnotFormula; +use crate::chc; +use crate::rty; + +#[derive(Debug, Clone)] +pub struct FormulaFn<'tcx> { + params: IndexVec>, + formula: chc::Formula, +} + +impl<'a, 'b, D> Pretty<'a, D, termcolor::ColorSpec> for &'b FormulaFn<'_> +where + D: pretty::DocAllocator<'a, termcolor::ColorSpec>, + D::Doc: Clone, +{ + fn pretty(self, allocator: &'a D) -> pretty::DocBuilder<'a, D, termcolor::ColorSpec> { + allocator + .intersperse( + self.params.iter_enumerated().map(|(idx, ty)| { + idx.pretty(allocator) + .append(": ") + .append(allocator.as_string(ty)) + }), + ", ", + ) + .enclose("|", "|") + .group() + .append(self.formula.pretty(allocator)) + .group() + } +} + +impl<'tcx> FormulaFn<'tcx> { + pub fn to_require_annot( + &self, + sig: mir_ty::FnSig<'tcx>, + ) -> Option> { + if &self.params.raw != sig.inputs() { + return None; + } + Some(AnnotFormula::Formula(self.formula.clone())) + } + + pub fn to_ensure_annot( + &self, + sig: mir_ty::FnSig<'tcx>, + ) -> Option>> { + if &self.params.raw[1..] != sig.inputs() { + return None; + } + if self.params.raw[0] != sig.output() { + return None; + } + Some(AnnotFormula::Formula(self.formula.clone().map_var(|v| { + if v.as_usize() == 0 { + rty::RefinedTypeVar::Value + } else { + rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(v.as_usize() - 1)) + } + }))) + } +} + +#[derive(Debug, Clone, Copy)] +enum AmbiguousBinOp { + Eq, + Ne, + Ge, + Le, + Gt, + Lt, +} + +#[derive(Debug, Clone)] +enum FormulaOrTerm { + Formula(chc::Formula), + Term(chc::Term), + BinOp(chc::Term, AmbiguousBinOp, chc::Term), + Not(Box>), + Literal(bool), +} + +impl FormulaOrTerm { + fn into_formula(self) -> Option> { + let fo = match self { + FormulaOrTerm::Formula(fo) => fo, + FormulaOrTerm::Term { .. } => return None, + FormulaOrTerm::BinOp(lhs, binop, rhs) => { + let pred = match binop { + AmbiguousBinOp::Eq => chc::KnownPred::EQUAL, + AmbiguousBinOp::Ne => chc::KnownPred::NOT_EQUAL, + AmbiguousBinOp::Ge => chc::KnownPred::GREATER_THAN_OR_EQUAL, + AmbiguousBinOp::Le => chc::KnownPred::LESS_THAN_OR_EQUAL, + AmbiguousBinOp::Gt => chc::KnownPred::GREATER_THAN, + AmbiguousBinOp::Lt => chc::KnownPred::LESS_THAN, + }; + chc::Formula::Atom(chc::Atom::new(pred.into(), vec![lhs, rhs])) + } + FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_formula()?.not(), + FormulaOrTerm::Literal(b) => { + if b { + chc::Formula::top() + } else { + chc::Formula::bottom() + } + } + }; + Some(fo) + } + + fn into_term(self) -> Option> { + let t = match self { + FormulaOrTerm::Formula { .. } => return None, + FormulaOrTerm::Term(t) => t, + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Eq, rhs) => lhs.eq(rhs), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Ne, rhs) => lhs.ne(rhs), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Ge, rhs) => lhs.ge(rhs), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Le, rhs) => lhs.le(rhs), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Gt, rhs) => lhs.gt(rhs), + FormulaOrTerm::BinOp(lhs, AmbiguousBinOp::Lt, rhs) => lhs.lt(rhs), + FormulaOrTerm::Not(formula_or_term) => formula_or_term.into_term()?.not(), + FormulaOrTerm::Literal(b) => chc::Term::bool(b), + }; + Some(t) + } +} + +pub struct AnnotFnTranslator<'tcx> { + tcx: TyCtxt<'tcx>, + local_def_id: LocalDefId, + + typeck: &'tcx mir_ty::TypeckResults<'tcx>, + body: &'tcx rustc_hir::Body<'tcx>, + + env: HashMap>, +} + +impl<'tcx> AnnotFnTranslator<'tcx> { + pub fn new(tcx: TyCtxt<'tcx>, local_def_id: LocalDefId) -> Self { + let map = tcx.hir(); + let body_id = map.body_owned_by(local_def_id); + let body = map.body(body_id); + + let typeck = tcx.typeck(local_def_id); + let mut translator = Self { + tcx, + local_def_id, + typeck, + body, + env: HashMap::default(), + }; + translator.build_env_from_params(); + translator + } + + fn build_env_from_params(&mut self) { + for (idx, param) in self.body.params.iter().enumerate() { + let param_idx = rty::FunctionParamIdx::from(idx); + let term = chc::Term::var(param_idx); + self.build_env_from_pat(term, param.pat); + } + } + + fn build_env_from_pat( + &mut self, + param: chc::Term, + pat: &'tcx rustc_hir::Pat<'tcx>, + ) { + use rustc_hir::PatKind; + + match pat.kind { + PatKind::Binding(_, hir_id, _, None) => { + self.env.insert(hir_id, param); + } + PatKind::TupleStruct(_, subpats, _) | PatKind::Tuple(subpats, _) => { + for (idx, subpat) in subpats.iter().enumerate() { + let field_term = param.clone().tuple_proj(idx.into()); + self.build_env_from_pat(field_term, subpat); + } + } + _ => unimplemented!("unsupported pattern in formula: {:?}", pat), + } + } + + pub fn to_formula_fn(&self) -> FormulaFn<'tcx> { + let formula = self.to_formula(self.body.value); + let params = self + .tcx + .fn_sig(self.local_def_id.to_def_id()) + .instantiate_identity() + .skip_binder() + .inputs() + .to_vec(); + FormulaFn { + params: IndexVec::from_raw(params), + formula, + } + } + + fn to_formula(&self, hir: &'tcx rustc_hir::Expr<'tcx>) -> chc::Formula { + self.to_formula_or_term(hir) + .into_formula() + .expect("expected a formula") + } + + fn to_term(&self, hir: &'tcx rustc_hir::Expr<'tcx>) -> chc::Term { + self.to_formula_or_term(hir) + .into_term() + .expect("expected a term") + } + + fn to_formula_or_term( + &self, + hir: &'tcx rustc_hir::Expr<'tcx>, + ) -> FormulaOrTerm { + use rustc_hir::ExprKind; + + match hir.kind { + ExprKind::Binary(op, lhs, rhs) => { + match op.node { + rustc_hir::BinOpKind::Or => { + let lhs = self.to_formula(lhs); + let rhs = self.to_formula(rhs); + return FormulaOrTerm::Formula(lhs.or(rhs)); + } + rustc_hir::BinOpKind::And => { + let lhs = self.to_formula(lhs); + let rhs = self.to_formula(rhs); + return FormulaOrTerm::Formula(lhs.and(rhs)); + } + _ => {} + } + + let binop = match op.node { + rustc_hir::BinOpKind::Eq => AmbiguousBinOp::Eq, + rustc_hir::BinOpKind::Ne => AmbiguousBinOp::Ne, + rustc_hir::BinOpKind::Ge => AmbiguousBinOp::Ge, + rustc_hir::BinOpKind::Le => AmbiguousBinOp::Le, + rustc_hir::BinOpKind::Gt => AmbiguousBinOp::Gt, + rustc_hir::BinOpKind::Lt => AmbiguousBinOp::Lt, + _ => unimplemented!("unsupported binary operator in formula: {:?}", op), + }; + let lhs = self.to_formula_or_term(lhs); + let rhs = self.to_formula_or_term(rhs); + FormulaOrTerm::BinOp(lhs.into_term().unwrap(), binop, rhs.into_term().unwrap()) + } + ExprKind::Unary(op, operand) => match op { + rustc_hir::UnOp::Neg => { + let operand = self.to_term(operand); + FormulaOrTerm::Term(operand.neg()) + } + rustc_hir::UnOp::Not => { + FormulaOrTerm::Not(Box::new(self.to_formula_or_term(operand))) + } + _ => unimplemented!("unsupported unary operator in formula: {:?}", op), + }, + ExprKind::Lit(lit) => match lit.node { + rustc_ast::LitKind::Int(i, _) => { + let n = i64::try_from(i.get()) + .expect("integer literal out of i64 range in formula"); + FormulaOrTerm::Term(chc::Term::int(n)) + } + rustc_ast::LitKind::Bool(b) => FormulaOrTerm::Literal(b), + _ => unimplemented!("unsupported literal in formula: {:?}", lit), + }, + ExprKind::Path(qpath) => { + if let rustc_hir::def::Res::Local(hir_id) = + self.typeck.qpath_res(&qpath, hir.hir_id) + { + FormulaOrTerm::Term( + self.env + .get(&hir_id) + .expect("unbound variable in formula") + .clone(), + ) + } else { + unimplemented!("unsupported path in formula: {:?}", qpath); + } + } + ExprKind::Block(block, _) => { + if block.stmts.is_empty() { + self.to_formula_or_term(block.expr.expect("expected an expression in block")) + } else { + unimplemented!("unsupported block in formula: {:?}", block); + } + } + _ => unimplemented!("unsupported expression in formula: {:?}", hir), + } + } +} diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 965eb13..ed7c547 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -6,7 +6,7 @@ use rustc_hir::def_id::CRATE_DEF_ID; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::LocalDefId; -use crate::analyze; +use crate::analyze::{self, annot_fn::AnnotFnTranslator}; use crate::chc; use crate::rty::{self, ClauseBuilderExt as _}; @@ -94,6 +94,14 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { return; } + if analyzer.is_annotated_as_formula_fn() { + let formula_fn = AnnotFnTranslator::new(self.tcx, local_def_id).to_formula_fn(); + self.ctx + .register_formula_fn(local_def_id.to_def_id(), formula_fn); + self.skip_analysis.insert(local_def_id); + return; + } + let target_def_id = if analyzer.is_annotated_as_extern_spec_fn() { analyzer.extern_spec_fn_target_def_id() } else { diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index a76823d..9a6a678 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -212,6 +212,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .is_some() } + pub fn is_annotated_as_formula_fn(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::formula_fn_path(), + ) + .next() + .is_some() + } + // TODO: unify this logic with extraction functions above pub fn is_fully_annotated(&self) -> bool { let has_require = self diff --git a/src/lib.rs b/src/lib.rs index 55fcaca..14499d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,11 @@ #![feature(rustc_private)] extern crate rustc_ast; +extern crate rustc_borrowck; extern crate rustc_data_structures; extern crate rustc_hir; extern crate rustc_index; +extern crate rustc_interface; extern crate rustc_middle; extern crate rustc_mir_dataflow; extern crate rustc_span; @@ -23,4 +25,5 @@ mod rty; // utility mod pretty; +pub use analyze::mir_borrowck_skip_formula_fn; pub use analyze::Analyzer; diff --git a/src/main.rs b/src/main.rs index d36b88c..391424e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,10 @@ impl Callbacks for CompilerCalls { let attrs = &mut config.opts.unstable_opts.crate_attr; attrs.push("feature(register_tool)".to_owned()); attrs.push("register_tool(thrust)".to_owned()); + + config.override_queries = Some(|_sess, providers| { + providers.mir_borrowck = thrust::mir_borrowck_skip_formula_fn; + }); } fn after_crate_root_parsing<'tcx>( From ff72e5ac2bd58647a0f2c79796258ba28b7c4a0b Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 30 Mar 2026 23:48:02 +0900 Subject: [PATCH 2/6] Pick extern_spec target function from last expression --- src/analyze/local_def.rs | 82 ++++++++++++++++++---------------------- std.rs | 2 +- 2 files changed, 37 insertions(+), 47 deletions(-) diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 9a6a678..8ce6a37 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -379,59 +379,49 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } /// Extract the target DefId from `#[thrust::extern_spec_fn]` function. + /// + /// The target is identified as the tail call expression (last expression without + /// semicolon) in the function body block. pub fn extern_spec_fn_target_def_id(&self) -> DefId { - struct ExtractDefId<'tcx> { - tcx: TyCtxt<'tcx>, - outer_def_id: LocalDefId, - inner_def_id: Option, - } + let node = self.tcx.hir_node_by_def_id(self.local_def_id); + let rustc_hir::Node::Item(item) = node else { + panic!("extern_spec_fn must be a function item"); + }; + let rustc_hir::ItemKind::Fn(_, _, body_id) = item.kind else { + panic!("extern_spec_fn must be a function"); + }; - impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ExtractDefId<'tcx> { - type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies; + let body = self.tcx.hir().body(body_id); - fn nested_visit_map(&mut self) -> Self::Map { - self.tcx.hir() - } + // The body is a block; the tail expression is the function call to the target. + let rustc_hir::ExprKind::Block(block, _) = &body.value.kind else { + panic!("extern_spec_fn body must be a block"); + }; + let tail_expr = block + .expr + .expect("extern_spec_fn block must end with a tail call expression"); - fn visit_qpath( - &mut self, - qpath: &rustc_hir::QPath<'tcx>, - hir_id: rustc_hir::HirId, - _span: rustc_span::Span, - ) { - let typeck_result = self.tcx.typeck(self.outer_def_id); - if let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id) - { - if matches!( - self.tcx.def_kind(def_id), - rustc_hir::def::DefKind::Fn | rustc_hir::def::DefKind::AssocFn - ) { - assert!(self.inner_def_id.is_none(), "invalid extern_spec_fn"); - - let args = typeck_result.node_args(hir_id); - let param_env = self.tcx.param_env(self.outer_def_id); - let instance = - mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap(); - if let Some(instance) = instance { - self.inner_def_id = Some(instance.def_id()); - } else { - self.inner_def_id = Some(def_id); - } - } - } - } - } + let rustc_hir::ExprKind::Call(func_expr, _) = &tail_expr.kind else { + panic!("extern_spec_fn tail expression must be a function call"); + }; + let rustc_hir::ExprKind::Path(qpath) = &func_expr.kind else { + panic!("extern_spec_fn call must be a path expression"); + }; - use rustc_hir::intravisit::Visitor as _; - let mut visitor = ExtractDefId { - tcx: self.tcx, - outer_def_id: self.local_def_id, - inner_def_id: None, + let typeck_result = self.tcx.typeck(self.local_def_id); + let hir_id = func_expr.hir_id; + let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id) else { + panic!("extern_spec_fn call must resolve to a definition"); }; - if let rustc_hir::Node::Item(item) = self.tcx.hir_node_by_def_id(self.local_def_id) { - visitor.visit_item(item); + + let args = typeck_result.node_args(hir_id); + let param_env = self.tcx.param_env(self.local_def_id); + let instance = mir_ty::Instance::resolve(self.tcx, param_env, def_id, args).unwrap(); + if let Some(instance) = instance { + instance.def_id() + } else { + def_id } - visitor.inner_def_id.expect("invalid extern_spec_fn") } fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool { diff --git a/std.rs b/std.rs index 6d60601..4564277 100644 --- a/std.rs +++ b/std.rs @@ -115,7 +115,7 @@ fn _extern_spec_box_new(x: T) -> Box where T: thrust_models::Model { #[thrust::requires(true)] #[thrust::ensures(*x == ^y && *y == ^x)] fn _extern_spec_std_mem_swap(x: &mut T, y: &mut T) where T: thrust_models::Model { - std::mem::swap(x, y); + std::mem::swap(x, y) } #[thrust::extern_spec_fn] From 3aaa735b8071a15fd7e67dda8249f72ac3eb764a Mon Sep 17 00:00:00 2001 From: coord_e Date: Fri, 27 Mar 2026 00:34:28 +0900 Subject: [PATCH 3/6] Enable to annotate functions with formula_fn --- src/analyze.rs | 99 +++++++++++++++++++++++++++++-- src/analyze/annot.rs | 8 +++ src/analyze/annot_fn.rs | 25 ++------ src/analyze/local_def.rs | 26 +++++--- tests/ui/fail/annot_formula_fn.rs | 54 +++++++++++++++++ tests/ui/pass/annot_formula_fn.rs | 54 +++++++++++++++++ 6 files changed, 231 insertions(+), 35 deletions(-) create mode 100644 tests/ui/fail/annot_formula_fn.rs create mode 100644 tests/ui/pass/annot_formula_fn.rs diff --git a/src/analyze.rs b/src/analyze.rs index 438d4d6..d8cf918 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -15,6 +15,7 @@ use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::{DefId, LocalDefId}; +use rustc_span::Symbol; use crate::analyze; use crate::annot::{AnnotFormula, AnnotParser, Resolver}; @@ -535,20 +536,73 @@ impl<'tcx> Analyzer<'tcx> { self.fn_sig_with_body(def_id, body) } + fn extract_path_with_attr( + &self, + local_def_id: LocalDefId, + attr_path: &[Symbol], + ) -> Option { + let map = self.tcx.hir(); + let body_id = map.maybe_body_owned_by(local_def_id)?; + let body = map.body(body_id); + + let rustc_hir::ExprKind::Block(block, _) = body.value.kind else { + return None; + }; + for stmt in block.stmts { + if map + .attrs(stmt.hir_id) + .iter() + .all(|attr| !attr.path_matches(attr_path)) + { + continue; + } + let rustc_hir::StmtKind::Semi(expr) = stmt.kind else { + self.tcx.dcx().span_err( + stmt.span, + "annotated path is expected to be a semi statement", + ); + continue; + }; + let rustc_hir::ExprKind::Path(qpath) = expr.kind else { + self.tcx.dcx().span_err( + expr.span, + "annotated path is expected to be a path expression", + ); + continue; + }; + let rustc_hir::QPath::Resolved(_, path) = qpath else { + self.tcx.dcx().span_err( + expr.span, + "annotated path is expected to be a resolved path", + ); + continue; + }; + let rustc_hir::def::Res::Def(_, def_id) = path.res else { + self.tcx.dcx().span_err( + path.span, + "annotated path is expected to refer to a definition", + ); + continue; + }; + return Some(def_id); + } + None + } + fn extract_require_annot( &self, - def_id: DefId, + local_def_id: LocalDefId, resolver: T, self_type_name: Option, ) -> Option> where - T: Resolver, + T: Resolver, { let mut require_annot = None; let parser = AnnotParser::new(&resolver, self_type_name); for attrs in self .tcx - .get_attrs_by_path(def_id, &analyze::annot::requires_path()) + .get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::requires_path()) { if require_annot.is_some() { unimplemented!(); @@ -557,23 +611,40 @@ impl<'tcx> Analyzer<'tcx> { let require = parser.parse_formula(ts).unwrap(); require_annot = Some(require); } + + if let Some(formula_def_id) = + self.extract_path_with_attr(local_def_id, &analyze::annot::requires_path_path()) + { + if require_annot.is_some() { + unimplemented!(); + } + let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else { + panic!( + "require annotation {:?} is not a formula function", + formula_def_id + ); + }; + require_annot = Some(formula_fn.to_require_annot()); + } + require_annot } fn extract_ensure_annot( &self, - def_id: DefId, + local_def_id: LocalDefId, resolver: T, self_type_name: Option, ) -> Option> where - T: Resolver, + T: Resolver>, { let mut ensure_annot = None; + let parser = AnnotParser::new(&resolver, self_type_name); for attrs in self .tcx - .get_attrs_by_path(def_id, &analyze::annot::ensures_path()) + .get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::ensures_path()) { if ensure_annot.is_some() { unimplemented!(); @@ -582,6 +653,22 @@ impl<'tcx> Analyzer<'tcx> { let ensure = parser.parse_formula(ts).unwrap(); ensure_annot = Some(ensure); } + + if let Some(formula_def_id) = + self.extract_path_with_attr(local_def_id, &analyze::annot::ensures_path_path()) + { + if ensure_annot.is_some() { + unimplemented!(); + } + let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else { + panic!( + "ensure annotation {:?} is not a formula function", + formula_def_id + ); + }; + ensure_annot = Some(formula_fn.to_ensure_annot()); + } + ensure_annot } diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 8cb43d2..18d2f8f 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -53,6 +53,14 @@ pub fn formula_fn_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("formula_fn")] } +pub fn requires_path_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("requires_path")] +} + +pub fn ensures_path_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("ensures_path")] +} + pub fn model_ty_path() -> [Symbol; 3] { [ Symbol::intern("thrust"), diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 219e5cc..f3f5db8 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -38,33 +38,18 @@ where } impl<'tcx> FormulaFn<'tcx> { - pub fn to_require_annot( - &self, - sig: mir_ty::FnSig<'tcx>, - ) -> Option> { - if &self.params.raw != sig.inputs() { - return None; - } - Some(AnnotFormula::Formula(self.formula.clone())) + pub fn to_require_annot(&self) -> AnnotFormula { + AnnotFormula::Formula(self.formula.clone()) } - pub fn to_ensure_annot( - &self, - sig: mir_ty::FnSig<'tcx>, - ) -> Option>> { - if &self.params.raw[1..] != sig.inputs() { - return None; - } - if self.params.raw[0] != sig.output() { - return None; - } - Some(AnnotFormula::Formula(self.formula.clone().map_var(|v| { + pub fn to_ensure_annot(&self) -> AnnotFormula> { + AnnotFormula::Formula(self.formula.clone().map_var(|v| { if v.as_usize() == 0 { rty::RefinedTypeVar::Value } else { rty::RefinedTypeVar::Free(rty::FunctionParamIdx::from(v.as_usize() - 1)) } - }))) + })) } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 8ce6a37..ca1f81d 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -224,22 +224,30 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { // TODO: unify this logic with extraction functions above pub fn is_fully_annotated(&self) -> bool { - let has_require = self + let has_requires = self .tcx .get_attrs_by_path( self.local_def_id.to_def_id(), &analyze::annot::requires_path(), ) .next() - .is_some(); - let has_ensure = self + .is_some() + || self + .ctx + .extract_path_with_attr(self.local_def_id, &analyze::annot::requires_path_path()) + .is_some(); + let has_ensures = self .tcx .get_attrs_by_path( self.local_def_id.to_def_id(), &analyze::annot::ensures_path(), ) .next() - .is_some(); + .is_some() + || self + .ctx + .extract_path_with_attr(self.local_def_id, &analyze::annot::ensures_path_path()) + .is_some(); let annotated_params: Vec<_> = self .tcx .get_attrs_by_path(self.local_def_id.to_def_id(), &analyze::annot::param_path()) @@ -260,7 +268,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .iter() .all(|ident| annotated_params.contains(ident)); self.is_annotated_as_callable() - || (has_require && has_ensure) + || (has_requires && has_ensures) || (all_params_annotated && has_ret) } @@ -299,13 +307,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let self_type_name = self.impl_type().map(|ty| ty.to_string()); let mut require_annot = self.ctx.extract_require_annot( - self.local_def_id.to_def_id(), + self.local_def_id, ¶m_resolver, self_type_name.clone(), ); let mut ensure_annot = self.ctx.extract_ensure_annot( - self.local_def_id.to_def_id(), + self.local_def_id, &result_param_resolver, self_type_name.clone(), ); @@ -313,12 +321,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { if let Some(trait_item_id) = self.trait_item_id() { tracing::info!("trait item fonud: {:?}", trait_item_id); let trait_require_annot = self.ctx.extract_require_annot( - trait_item_id.into(), + trait_item_id, ¶m_resolver, self_type_name.clone(), ); let trait_ensure_annot = self.ctx.extract_ensure_annot( - trait_item_id.into(), + trait_item_id, &result_param_resolver, self_type_name.clone(), ); diff --git a/tests/ui/fail/annot_formula_fn.rs b/tests/ui/fail/annot_formula_fn.rs new file mode 100644 index 0000000..e2ae600 --- /dev/null +++ b/tests/ui/fail/annot_formula_fn.rs @@ -0,0 +1,54 @@ +//@error-in-other-file: Unsat + +#[thrust::formula_fn] +fn _thrust_requires_rand_except(x: i64) -> bool { + true +} + +#[thrust::formula_fn] +fn _thrust_ensures_rand_except(result: i64, x: i64) -> bool { + result != x +} + +fn rand_except(x: i64) -> i64 { + #[thrust::requires_path] + _thrust_requires_rand_except; + #[thrust::ensures_path] + _thrust_ensures_rand_except; + + if x == 0 { + 1 + } else { + 0 + } +} + +#[thrust::formula_fn] +fn _thrust_requires_f(x: i64) -> bool { + true +} + +#[thrust::formula_fn] +fn _thrust_ensures_f(result: i64, x: i64) -> bool { + (result == 1) || (result == -1) && result == 0 +} + +fn f(x: i64) -> i64 { + #[thrust::requires_path] + _thrust_requires_f; + #[thrust::ensures_path] + _thrust_ensures_f; + + let y = rand_except(x); + if y > x { + 1 + } else if y < x { + -1 + } else { + 0 + } +} + +fn main() { + assert!(rand_except(1) == 0); +} diff --git a/tests/ui/pass/annot_formula_fn.rs b/tests/ui/pass/annot_formula_fn.rs new file mode 100644 index 0000000..035015b --- /dev/null +++ b/tests/ui/pass/annot_formula_fn.rs @@ -0,0 +1,54 @@ +//@check-pass + +#[thrust::formula_fn] +fn _thrust_requires_rand_except(x: i64) -> bool { + true +} + +#[thrust::formula_fn] +fn _thrust_ensures_rand_except(result: i64, x: i64) -> bool { + result != x +} + +fn rand_except(x: i64) -> i64 { + #[thrust::requires_path] + _thrust_requires_rand_except; + #[thrust::ensures_path] + _thrust_ensures_rand_except; + + if x == 0 { + 1 + } else { + 0 + } +} + +#[thrust::formula_fn] +fn _thrust_requires_f(x: i64) -> bool { + true +} + +#[thrust::formula_fn] +fn _thrust_ensures_f(result: i64, x: i64) -> bool { + (result == 1) || (result == -1) +} + +fn f(x: i64) -> i64 { + #[thrust::requires_path] + _thrust_requires_f; + #[thrust::ensures_path] + _thrust_ensures_f; + + let y = rand_except(x); + if y > x { + 1 + } else if y < x { + -1 + } else { + 0 + } +} + +fn main() { + assert!(rand_except(1) != 1); +} From f5b66dd41671931a554da94373b290ea29b14dda Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 30 Mar 2026 23:49:49 +0900 Subject: [PATCH 4/6] Support more annotation expressions --- src/analyze/annot.rs | 16 +++ src/analyze/annot_fn.rs | 72 ++++++++++- src/analyze/crate_.rs | 3 +- src/analyze/did_cache.rs | 25 ++++ src/analyze/local_def.rs | 2 +- std.rs | 132 +++++++++++++++++++++ tests/ui/fail/annot_mut_term_formula_fn.rs | 30 +++++ tests/ui/pass/annot_mut_term_formula_fn.rs | 30 +++++ 8 files changed, 305 insertions(+), 5 deletions(-) create mode 100644 tests/ui/fail/annot_mut_term_formula_fn.rs create mode 100644 tests/ui/pass/annot_mut_term_formula_fn.rs diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 18d2f8f..9696c39 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -109,6 +109,22 @@ pub fn closure_model_path() -> [Symbol; 3] { ] } +pub fn mut_model_new_path() -> [Symbol; 3] { + [ + Symbol::intern("thrust"), + Symbol::intern("def"), + Symbol::intern("mut_new"), + ] +} + +pub fn box_model_new_path() -> [Symbol; 3] { + [ + Symbol::intern("thrust"), + Symbol::intern("def"), + Symbol::intern("box_new"), + ] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index f3f5db8..7d1a4b6 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -5,6 +5,7 @@ use rustc_hir::{def_id::LocalDefId, HirId}; use rustc_index::IndexVec; use rustc_middle::ty::{self as mir_ty, TyCtxt}; +use crate::analyze::did_cache::DefIdCache; use crate::annot::AnnotFormula; use crate::chc; use crate::rty; @@ -124,11 +125,12 @@ pub struct AnnotFnTranslator<'tcx> { typeck: &'tcx mir_ty::TypeckResults<'tcx>, body: &'tcx rustc_hir::Body<'tcx>, + def_ids: DefIdCache<'tcx>, env: HashMap>, } impl<'tcx> AnnotFnTranslator<'tcx> { - pub fn new(tcx: TyCtxt<'tcx>, local_def_id: LocalDefId) -> Self { + pub fn new(tcx: TyCtxt<'tcx>, def_ids: DefIdCache<'tcx>, local_def_id: LocalDefId) -> Self { let map = tcx.hir(); let body_id = map.body_owned_by(local_def_id); let body = map.body(body_id); @@ -139,6 +141,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> { local_def_id, typeck, body, + def_ids, env: HashMap::default(), }; translator.build_env_from_params(); @@ -166,7 +169,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> { } PatKind::TupleStruct(_, subpats, _) | PatKind::Tuple(subpats, _) => { for (idx, subpat) in subpats.iter().enumerate() { - let field_term = param.clone().tuple_proj(idx.into()); + let field_term = param.clone().tuple_proj(idx); self.build_env_from_pat(field_term, subpat); } } @@ -220,6 +223,21 @@ impl<'tcx> AnnotFnTranslator<'tcx> { let rhs = self.to_formula(rhs); return FormulaOrTerm::Formula(lhs.and(rhs)); } + rustc_hir::BinOpKind::Add => { + let lhs = self.to_term(lhs); + let rhs = self.to_term(rhs); + return FormulaOrTerm::Term(lhs.add(rhs)); + } + rustc_hir::BinOpKind::Sub => { + let lhs = self.to_term(lhs); + let rhs = self.to_term(rhs); + return FormulaOrTerm::Term(lhs.sub(rhs)); + } + rustc_hir::BinOpKind::Mul => { + let lhs = self.to_term(lhs); + let rhs = self.to_term(rhs); + return FormulaOrTerm::Term(lhs.mul(rhs)); + } _ => {} } @@ -244,7 +262,23 @@ impl<'tcx> AnnotFnTranslator<'tcx> { rustc_hir::UnOp::Not => { FormulaOrTerm::Not(Box::new(self.to_formula_or_term(operand))) } - _ => unimplemented!("unsupported unary operator in formula: {:?}", op), + rustc_hir::UnOp::Deref => { + let operand_ty = self.typeck.expr_ty(operand); + let adt = operand_ty + .ty_adt_def() + .expect("deref operand must be a model type"); + let term = self.to_term(operand); + if Some(adt.did()) == self.def_ids.mut_model() { + FormulaOrTerm::Term(term.mut_current()) + } else if Some(adt.did()) == self.def_ids.box_model() { + FormulaOrTerm::Term(term.box_current()) + } else { + unimplemented!( + "unsupported deref operand type in formula: {:?}", + operand_ty + ) + } + } }, ExprKind::Lit(lit) => match lit.node { rustc_ast::LitKind::Int(i, _) => { @@ -269,6 +303,38 @@ impl<'tcx> AnnotFnTranslator<'tcx> { unimplemented!("unsupported path in formula: {:?}", qpath); } } + ExprKind::Tup(exprs) => { + let terms = exprs.iter().map(|e| self.to_term(e)).collect(); + FormulaOrTerm::Term(chc::Term::tuple(terms)) + } + ExprKind::Field(expr, field) => { + let index = field + .name + .as_str() + .parse::() + .expect("tuple field index must be a non-negative integer"); + let term = self.to_term(expr); + FormulaOrTerm::Term(term.tuple_proj(index)) + } + ExprKind::Call(func_expr, args) => { + if let ExprKind::Path(qpath) = &func_expr.kind { + let res = self.typeck.qpath_res(qpath, func_expr.hir_id); + if let rustc_hir::def::Res::Def(_, def_id) = res { + if Some(def_id) == self.def_ids.mut_model_new() { + assert_eq!(args.len(), 2, "Mut::new takes exactly 2 arguments"); + let t1 = self.to_term(&args[0]); + let t2 = self.to_term(&args[1]); + return FormulaOrTerm::Term(chc::Term::mut_(t1, t2)); + } + if Some(def_id) == self.def_ids.box_model_new() { + assert_eq!(args.len(), 1, "Box::new takes exactly 1 argument"); + let t = self.to_term(&args[0]); + return FormulaOrTerm::Term(chc::Term::box_(t)); + } + } + } + unimplemented!("unsupported call in formula: {:?}", func_expr) + } ExprKind::Block(block, _) => { if block.stmts.is_empty() { self.to_formula_or_term(block.expr.expect("expected an expression in block")) diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index ed7c547..1a81b64 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -95,7 +95,8 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } if analyzer.is_annotated_as_formula_fn() { - let formula_fn = AnnotFnTranslator::new(self.tcx, local_def_id).to_formula_fn(); + let formula_fn = + AnnotFnTranslator::new(self.tcx, self.ctx.def_ids(), local_def_id).to_formula_fn(); self.ctx .register_formula_fn(local_def_id.to_def_id(), formula_fn); self.skip_analysis.insert(local_def_id); diff --git a/src/analyze/did_cache.rs b/src/analyze/did_cache.rs index 9646be8..3c2ff10 100644 --- a/src/analyze/did_cache.rs +++ b/src/analyze/did_cache.rs @@ -19,6 +19,9 @@ struct DefIds { box_model: OnceCell>, array_model: OnceCell>, closure_model: OnceCell>, + + mut_model_new: OnceCell>, + box_model_new: OnceCell>, } /// Retrieves and caches well-known [`DefId`]s. @@ -90,6 +93,14 @@ impl<'tcx> DefIdCache<'tcx> { } } } + if let rustc_hir::ItemKind::Impl(impl_) = item.kind { + for impl_item_ref in impl_.items { + let def_id = impl_item_ref.id.owner_id.to_def_id(); + if self.tcx.get_attrs_by_path(def_id, path).next().is_some() { + return Some(def_id); + } + } + } } None } @@ -135,4 +146,18 @@ impl<'tcx> DefIdCache<'tcx> { .closure_model .get_or_init(|| self.annotated_def(&crate::analyze::annot::closure_model_path())) } + + pub fn mut_model_new(&self) -> Option { + *self + .def_ids + .mut_model_new + .get_or_init(|| self.annotated_def(&crate::analyze::annot::mut_model_new_path())) + } + + pub fn box_model_new(&self) -> Option { + *self + .def_ids + .box_model_new + .get_or_init(|| self.annotated_def(&crate::analyze::annot::box_model_new_path())) + } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index ca1f81d..4fbf0d4 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -319,7 +319,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { ); if let Some(trait_item_id) = self.trait_item_id() { - tracing::info!("trait item fonud: {:?}", trait_item_id); + tracing::info!("trait item found: {:?}", trait_item_id); let trait_require_annot = self.ctx.extract_require_annot( trait_item_id, ¶m_resolver, diff --git a/std.rs b/std.rs index 4564277..7afcaf6 100644 --- a/std.rs +++ b/std.rs @@ -12,21 +12,141 @@ mod thrust_models { #[thrust::def::int_model] pub struct Int; + impl PartialEq for Int where T: super::Model { + #[thrust::ignored] + fn eq(&self, _other: &T) -> bool { + unimplemented!() + } + } + + impl PartialOrd for Int where T: super::Model { + #[thrust::ignored] + fn partial_cmp(&self, _other: &T) -> Option { + unimplemented!() + } + } + + impl std::ops::Add for Int where T: super::Model { + type Output = Self; + + #[thrust::ignored] + fn add(self, _rhs: T) -> Self::Output { + unimplemented!() + } + } + + impl std::ops::Sub for Int where T: super::Model { + type Output = Self; + + #[thrust::ignored] + fn sub(self, _rhs: T) -> Self::Output { + unimplemented!() + } + } + + impl std::ops::Mul for Int where T: super::Model { + type Output = Self; + + #[thrust::ignored] + fn mul(self, _rhs: T) -> Self::Output { + unimplemented!() + } + } + + impl std::ops::Neg for Int { + type Output = Self; + + #[thrust::ignored] + fn neg(self) -> Self::Output { + unimplemented!() + } + } + #[thrust::def::mut_model] pub struct Mut(PhantomData); + impl Mut { + #[allow(dead_code)] + #[thrust::def::mut_new] + #[thrust::ignored] + pub fn new(_a: T, _b: T) -> Self { + unimplemented!() + } + } + + impl PartialEq for Mut where U: super::Model { + #[thrust::ignored] + fn eq(&self, _other: &U) -> bool { + unimplemented!() + } + } + + impl std::ops::Deref for Mut { + type Target = T; + + #[thrust::ignored] + fn deref(&self) -> &Self::Target { + unimplemented!() + } + } + + impl std::ops::Not for Mut { + type Output = T; + + #[thrust::ignored] + fn not(self) -> Self::Output { + unimplemented!() + } + } + #[thrust::def::box_model] pub struct Box(PhantomData); + impl Box { + #[allow(dead_code)] + #[thrust::def::box_new] + #[thrust::ignored] + pub fn new(_x: T) -> Self { + unimplemented!() + } + } + + impl PartialEq for Box where U: super::Model { + #[thrust::ignored] + fn eq(&self, _other: &U) -> bool { + unimplemented!() + } + } + + impl std::ops::Deref for Box { + type Target = T; + + #[thrust::ignored] + fn deref(&self) -> &Self::Target { + unimplemented!() + } + } + #[thrust::def::array_model] pub struct Array(PhantomData, PhantomData); + impl PartialEq for Array where U: super::Model { + #[thrust::ignored] + fn eq(&self, _other: &U) -> bool { + unimplemented!() + } + } + #[thrust::def::closure_model] pub struct Closure(PhantomData); pub struct Vec(pub Array, pub usize); } + impl Model for model::Int { + type Ty = model::Int; + } + impl Model for isize { type Ty = model::Int; } @@ -83,6 +203,10 @@ mod thrust_models { type Ty = model::Mut<::Ty>; } + impl Model for model::Mut { + type Ty = model::Mut; + } + impl<'a, T> Model for &'a T where T: Model { type Ty = &'a ::Ty; } @@ -91,10 +215,18 @@ mod thrust_models { type Ty = model::Box<::Ty>; } + impl Model for model::Box { + type Ty = model::Box; + } + impl Model for Vec where T: Model { type Ty = model::Vec<::Ty>; } + impl Model for model::Vec { + type Ty = model::Vec; + } + impl Model for Option where T: Model { type Ty = Option<::Ty>; } diff --git a/tests/ui/fail/annot_mut_term_formula_fn.rs b/tests/ui/fail/annot_mut_term_formula_fn.rs new file mode 100644 index 0000000..7b7a582 --- /dev/null +++ b/tests/ui/fail/annot_mut_term_formula_fn.rs @@ -0,0 +1,30 @@ +//@error-in-other-file: Unsat + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_requires_f(x: thrust_models::model::Mut, y: i64) -> bool { + true +} + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_ensures_f(result: (), x: thrust_models::model::Mut, y: i64) -> bool { + x == thrust_models::model::Mut::new(*x, y) +} + +#[allow(path_statements)] +fn f(x: &mut i64, y: i64) { + #[thrust::requires_path] + _thrust_requires_f; + + #[thrust::ensures_path] + _thrust_ensures_f; + + *x = y; +} + +fn main() { + let mut a = 1; + f(&mut a, 2); + assert!(a == 1); +} diff --git a/tests/ui/pass/annot_mut_term_formula_fn.rs b/tests/ui/pass/annot_mut_term_formula_fn.rs new file mode 100644 index 0000000..37d9e9a --- /dev/null +++ b/tests/ui/pass/annot_mut_term_formula_fn.rs @@ -0,0 +1,30 @@ +//@check-pass + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_requires_f(x: thrust_models::model::Mut, y: i64) -> bool { + true +} + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_ensures_f(result: (), x: thrust_models::model::Mut, y: i64) -> bool { + x == thrust_models::model::Mut::new(*x, y) +} + +#[allow(path_statements)] +fn f(x: &mut i64, y: i64) { + #[thrust::requires_path] + _thrust_requires_f; + + #[thrust::ensures_path] + _thrust_ensures_f; + + *x = y; +} + +fn main() { + let mut a = 1; + f(&mut a, 2); + assert!(a == 2); +} From 06737293edccc4c55753df6eb3805c0c1e5b3d03 Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 31 Mar 2026 01:57:28 +0900 Subject: [PATCH 5/6] Refer mut final via ! operator --- src/analyze/annot_fn.rs | 12 ++++++++++- tests/ui/fail/annot_formula_fn_mut.rs | 31 +++++++++++++++++++++++++++ tests/ui/pass/annot_formula_fn_mut.rs | 31 +++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 tests/ui/fail/annot_formula_fn_mut.rs create mode 100644 tests/ui/pass/annot_formula_fn_mut.rs diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 7d1a4b6..26383d3 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -260,7 +260,17 @@ impl<'tcx> AnnotFnTranslator<'tcx> { FormulaOrTerm::Term(operand.neg()) } rustc_hir::UnOp::Not => { - FormulaOrTerm::Not(Box::new(self.to_formula_or_term(operand))) + let operand_ty = self.typeck.expr_ty(operand); + match operand_ty.ty_adt_def() { + Some(adt) if Some(adt.did()) == self.def_ids.mut_model() => { + let operand = self.to_term(operand); + FormulaOrTerm::Term(operand.mut_final()) + } + _ => { + let operand = self.to_formula_or_term(operand); + FormulaOrTerm::Not(Box::new(operand)) + } + } } rustc_hir::UnOp::Deref => { let operand_ty = self.typeck.expr_ty(operand); diff --git a/tests/ui/fail/annot_formula_fn_mut.rs b/tests/ui/fail/annot_formula_fn_mut.rs new file mode 100644 index 0000000..c63cb9a --- /dev/null +++ b/tests/ui/fail/annot_formula_fn_mut.rs @@ -0,0 +1,31 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_requires_incr(m: thrust_models::model::Mut, x: i64) -> bool { + true +} + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_ensures_incr(result: (), m: thrust_models::model::Mut, x: i64) -> bool { + !m == *m + 1 +} + +#[allow(path_statements)] +fn incr(m: &mut i64, x: i64) { + #[thrust::requires_path] + _thrust_requires_incr; + #[thrust::ensures_path] + _thrust_ensures_incr; + + *m += x; +} + +fn main() { + let mut x = 0; + incr(&mut x, 1); + incr(&mut x, 1); + assert!(x == 2); +} diff --git a/tests/ui/pass/annot_formula_fn_mut.rs b/tests/ui/pass/annot_formula_fn_mut.rs new file mode 100644 index 0000000..1fb3a1f --- /dev/null +++ b/tests/ui/pass/annot_formula_fn_mut.rs @@ -0,0 +1,31 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_requires_incr(m: thrust_models::model::Mut, x: i64) -> bool { + true +} + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_ensures_incr(result: (), m: thrust_models::model::Mut, x: i64) -> bool { + !m == *m + x +} + +#[allow(path_statements)] +fn incr(m: &mut i64, x: i64) { + #[thrust::requires_path] + _thrust_requires_incr; + #[thrust::ensures_path] + _thrust_ensures_incr; + + *m += x; +} + +fn main() { + let mut x = 0; + incr(&mut x, 1); + incr(&mut x, 1); + assert!(x == 2); +} From abe5984f8fd56ea978ad90f5ae423c3928ad9cb3 Mon Sep 17 00:00:00 2001 From: coord_e Date: Tue, 31 Mar 2026 21:20:39 +0900 Subject: [PATCH 6/6] Translate formula fns after type params are instantiated --- src/analyze.rs | 62 +++++++++++++++++-- src/analyze/annot_fn.rs | 30 +++++++-- src/analyze/crate_.rs | 7 +-- src/analyze/local_def.rs | 9 +++ .../ui/fail/annot_mut_term_formula_fn_poly.rs | 31 ++++++++++ .../ui/pass/annot_mut_term_formula_fn_poly.rs | 31 ++++++++++ 6 files changed, 154 insertions(+), 16 deletions(-) create mode 100644 tests/ui/fail/annot_mut_term_formula_fn_poly.rs create mode 100644 tests/ui/pass/annot_mut_term_formula_fn_poly.rs diff --git a/src/analyze.rs b/src/analyze.rs index d8cf918..85ba3b2 100644 --- a/src/analyze.rs +++ b/src/analyze.rs @@ -190,6 +190,11 @@ impl refine::EnumDefProvider for Rc> { pub type Env = refine::Env>>; +#[derive(Debug, Clone)] +struct DeferredFormulaFnDef<'tcx> { + cache: Rc, annot_fn::FormulaFn<'tcx>>>>, +} + #[derive(Clone)] pub struct Analyzer<'tcx> { tcx: TyCtxt<'tcx>, @@ -202,7 +207,7 @@ pub struct Analyzer<'tcx> { defs: HashMap>, /// Collection of functions with `#[thrust::formula_fn]` attribute. - formula_fns: HashMap>, + formula_fns: HashMap>, /// Resulting CHC system. system: Rc>, @@ -391,6 +396,30 @@ impl<'tcx> Analyzer<'tcx> { }) } + pub fn formula_fn_with_args( + &self, + local_def_id: LocalDefId, + generic_args: mir_ty::GenericArgsRef<'tcx>, + ) -> Option> { + let deferred_formula_fn = self.formula_fns.get(&local_def_id)?; + + let deferred_formula_fn_cache = Rc::clone(&deferred_formula_fn.cache); + if let Some(formula_fn) = deferred_formula_fn_cache.borrow().get(&generic_args) { + return Some(formula_fn.clone()); + } + + let translator = annot_fn::AnnotFnTranslator::new(self.tcx, local_def_id) + .with_generic_args(generic_args) + .with_def_id_cache(self.def_ids()); + let formula_fn = translator.to_formula_fn(); + deferred_formula_fn_cache + .borrow_mut() + .insert(generic_args, formula_fn.clone()); + + tracing::info!(?local_def_id, formula_fn = %formula_fn.display(), ?generic_args, "formula_fn_with_args"); + Some(formula_fn) + } + pub fn def_ty_with_args( &mut self, def_id: DefId, @@ -443,9 +472,14 @@ impl<'tcx> Analyzer<'tcx> { Some(expected) } - pub fn register_formula_fn(&mut self, def_id: DefId, formula_fn: annot_fn::FormulaFn<'tcx>) { - tracing::info!(def_id = ?def_id, formula_fn = %formula_fn.display(), "register_formula_fn"); - self.formula_fns.insert(def_id, formula_fn); + pub fn register_formula_fn(&mut self, local_def_id: LocalDefId) { + tracing::info!(?local_def_id, "register_formula_fn"); + self.formula_fns.insert( + local_def_id, + DeferredFormulaFnDef { + cache: Rc::new(RefCell::new(HashMap::new())), + }, + ); } pub fn register_basic_block_ty( @@ -589,11 +623,13 @@ impl<'tcx> Analyzer<'tcx> { None } + // TODO: reduce number of args fn extract_require_annot( &self, local_def_id: LocalDefId, resolver: T, self_type_name: Option, + generic_args: mir_ty::GenericArgsRef<'tcx>, ) -> Option> where T: Resolver, @@ -615,10 +651,16 @@ impl<'tcx> Analyzer<'tcx> { if let Some(formula_def_id) = self.extract_path_with_attr(local_def_id, &analyze::annot::requires_path_path()) { + let Some(formula_def_id) = formula_def_id.as_local() else { + panic!( + "require annotation with path is expected to refer to a local def, but found: {:?}", + formula_def_id + ); + }; if require_annot.is_some() { unimplemented!(); } - let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else { + let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else { panic!( "require annotation {:?} is not a formula function", formula_def_id @@ -630,11 +672,13 @@ impl<'tcx> Analyzer<'tcx> { require_annot } + // TODO: reduce number of args fn extract_ensure_annot( &self, local_def_id: LocalDefId, resolver: T, self_type_name: Option, + generic_args: mir_ty::GenericArgsRef<'tcx>, ) -> Option> where T: Resolver>, @@ -657,10 +701,16 @@ impl<'tcx> Analyzer<'tcx> { if let Some(formula_def_id) = self.extract_path_with_attr(local_def_id, &analyze::annot::ensures_path_path()) { + let Some(formula_def_id) = formula_def_id.as_local() else { + panic!( + "ensure annotation with path is expected to refer to a local def, but found: {:?}", + formula_def_id + ); + }; if ensure_annot.is_some() { unimplemented!(); } - let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else { + let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else { panic!( "ensure annotation {:?} is not a formula function", formula_def_id diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 26383d3..3b870f0 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -124,23 +124,26 @@ pub struct AnnotFnTranslator<'tcx> { typeck: &'tcx mir_ty::TypeckResults<'tcx>, body: &'tcx rustc_hir::Body<'tcx>, + generic_args: mir_ty::GenericArgsRef<'tcx>, def_ids: DefIdCache<'tcx>, env: HashMap>, } impl<'tcx> AnnotFnTranslator<'tcx> { - pub fn new(tcx: TyCtxt<'tcx>, def_ids: DefIdCache<'tcx>, local_def_id: LocalDefId) -> Self { + pub fn new(tcx: TyCtxt<'tcx>, local_def_id: LocalDefId) -> Self { let map = tcx.hir(); let body_id = map.body_owned_by(local_def_id); let body = map.body(body_id); - + let generic_args = tcx.mk_args(&[]); let typeck = tcx.typeck(local_def_id); + let def_ids = DefIdCache::new(tcx); let mut translator = Self { tcx, local_def_id, typeck, body, + generic_args, def_ids, env: HashMap::default(), }; @@ -148,6 +151,16 @@ impl<'tcx> AnnotFnTranslator<'tcx> { translator } + pub fn with_generic_args(mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> Self { + self.generic_args = generic_args; + self + } + + pub fn with_def_id_cache(mut self, def_ids: DefIdCache<'tcx>) -> Self { + self.def_ids = def_ids; + self + } + fn build_env_from_params(&mut self) { for (idx, param) in self.body.params.iter().enumerate() { let param_idx = rty::FunctionParamIdx::from(idx); @@ -177,12 +190,19 @@ impl<'tcx> AnnotFnTranslator<'tcx> { } } + fn expr_ty(&self, expr: &'tcx rustc_hir::Expr<'tcx>) -> mir_ty::Ty<'tcx> { + let ty = self.typeck.expr_ty(expr); + let instantiated = mir_ty::EarlyBinder::bind(ty).instantiate(self.tcx, self.generic_args); + let param_env = mir_ty::ParamEnv::reveal_all(); + self.tcx.normalize_erasing_regions(param_env, instantiated) + } + pub fn to_formula_fn(&self) -> FormulaFn<'tcx> { let formula = self.to_formula(self.body.value); let params = self .tcx .fn_sig(self.local_def_id.to_def_id()) - .instantiate_identity() + .instantiate(self.tcx, self.generic_args) .skip_binder() .inputs() .to_vec(); @@ -260,7 +280,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> { FormulaOrTerm::Term(operand.neg()) } rustc_hir::UnOp::Not => { - let operand_ty = self.typeck.expr_ty(operand); + let operand_ty = self.expr_ty(operand); match operand_ty.ty_adt_def() { Some(adt) if Some(adt.did()) == self.def_ids.mut_model() => { let operand = self.to_term(operand); @@ -273,7 +293,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> { } } rustc_hir::UnOp::Deref => { - let operand_ty = self.typeck.expr_ty(operand); + let operand_ty = self.expr_ty(operand); let adt = operand_ty .ty_adt_def() .expect("deref operand must be a model type"); diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index 1a81b64..51e21db 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -6,7 +6,7 @@ use rustc_hir::def_id::CRATE_DEF_ID; use rustc_middle::ty::{self as mir_ty, TyCtxt}; use rustc_span::def_id::LocalDefId; -use crate::analyze::{self, annot_fn::AnnotFnTranslator}; +use crate::analyze; use crate::chc; use crate::rty::{self, ClauseBuilderExt as _}; @@ -95,10 +95,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { } if analyzer.is_annotated_as_formula_fn() { - let formula_fn = - AnnotFnTranslator::new(self.tcx, self.ctx.def_ids(), local_def_id).to_formula_fn(); - self.ctx - .register_formula_fn(local_def_id.to_def_id(), formula_fn); + self.ctx.register_formula_fn(local_def_id); self.skip_analysis.insert(local_def_id); return; } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 4fbf0d4..e245dbe 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -48,6 +48,8 @@ pub struct Analyzer<'tcx, 'ctx> { local_def_id: LocalDefId, body: Body<'tcx>, + /// to substitute HIR types during translation in [`crate::analyze::annot_fn`] + generic_args: mir_ty::GenericArgsRef<'tcx>, drop_points: HashMap, type_builder: TypeBuilder<'tcx>, } @@ -310,12 +312,14 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.local_def_id, ¶m_resolver, self_type_name.clone(), + self.generic_args, ); let mut ensure_annot = self.ctx.extract_ensure_annot( self.local_def_id, &result_param_resolver, self_type_name.clone(), + self.generic_args, ); if let Some(trait_item_id) = self.trait_item_id() { @@ -324,11 +328,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { trait_item_id, ¶m_resolver, self_type_name.clone(), + self.generic_args, ); let trait_ensure_annot = self.ctx.extract_ensure_annot( trait_item_id, &result_param_resolver, self_type_name.clone(), + self.generic_args, ); assert!(require_annot.is_none() || trait_require_annot.is_none()); @@ -851,17 +857,20 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { let body = tcx.optimized_mir(local_def_id.to_def_id()).clone(); let drop_points = Default::default(); let type_builder = TypeBuilder::new(tcx, ctx.def_ids(), local_def_id.to_def_id()); + let generic_args = tcx.mk_args(&[]); Self { ctx, tcx, local_def_id, body, + generic_args, drop_points, type_builder, } } pub fn generic_args(&mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> &mut Self { + self.generic_args = generic_args; self.body = mir_ty::EarlyBinder::bind(self.body.clone()).instantiate(self.tcx, generic_args); self diff --git a/tests/ui/fail/annot_mut_term_formula_fn_poly.rs b/tests/ui/fail/annot_mut_term_formula_fn_poly.rs new file mode 100644 index 0000000..e87f5e3 --- /dev/null +++ b/tests/ui/fail/annot_mut_term_formula_fn_poly.rs @@ -0,0 +1,31 @@ +//@error-in-other-file: Unsat + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_requires_swap(x: thrust_models::model::Mut, y: i64) -> bool { + true +} + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_ensures_swap(result: (), x: thrust_models::model::Mut, y: thrust_models::model::Mut) -> bool { + *x == *y && *y == *x +} + +#[allow(path_statements)] +fn swap(x: &mut T, y: &mut T) { + #[thrust::requires_path] + _thrust_requires_swap; + + #[thrust::ensures_path] + _thrust_ensures_swap; + + std::mem::swap(x, y) +} + +fn main() { + let mut a = 1; + let mut b = 2; + swap(&mut a, &mut b); + assert!(a == 2 && b == 1); +} diff --git a/tests/ui/pass/annot_mut_term_formula_fn_poly.rs b/tests/ui/pass/annot_mut_term_formula_fn_poly.rs new file mode 100644 index 0000000..ffdda7e --- /dev/null +++ b/tests/ui/pass/annot_mut_term_formula_fn_poly.rs @@ -0,0 +1,31 @@ +//@check-pass + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_requires_swap(x: thrust_models::model::Mut, y: i64) -> bool { + true +} + +#[allow(unused_variables)] +#[thrust::formula_fn] +fn _thrust_ensures_swap(result: (), x: thrust_models::model::Mut, y: thrust_models::model::Mut) -> bool { + !x == *y && !y == *x +} + +#[allow(path_statements)] +fn swap(x: &mut T, y: &mut T) { + #[thrust::requires_path] + _thrust_requires_swap; + + #[thrust::ensures_path] + _thrust_ensures_swap; + + std::mem::swap(x, y) +} + +fn main() { + let mut a = 1; + let mut b = 2; + swap(&mut a, &mut b); + assert!(a == 2 && b == 1); +}