diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 9696c39..3163803 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -125,6 +125,14 @@ pub fn box_model_new_path() -> [Symbol; 3] { ] } +pub fn exists_path() -> [Symbol; 3] { + [ + Symbol::intern("thrust"), + Symbol::intern("def"), + Symbol::intern("exists"), + ] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/annot_fn.rs b/src/analyze/annot_fn.rs index 3b870f0..4e7c396 100644 --- a/src/analyze/annot_fn.rs +++ b/src/analyze/annot_fn.rs @@ -8,6 +8,7 @@ use rustc_middle::ty::{self as mir_ty, TyCtxt}; use crate::analyze::did_cache::DefIdCache; use crate::annot::AnnotFormula; use crate::chc; +use crate::refine::TypeBuilder; use crate::rty; #[derive(Debug, Clone)] @@ -118,6 +119,7 @@ impl FormulaOrTerm { } } +#[derive(Clone)] pub struct AnnotFnTranslator<'tcx> { tcx: TyCtxt<'tcx>, local_def_id: LocalDefId, @@ -127,6 +129,7 @@ pub struct AnnotFnTranslator<'tcx> { generic_args: mir_ty::GenericArgsRef<'tcx>, def_ids: DefIdCache<'tcx>, + type_builder: TypeBuilder<'tcx>, env: HashMap>, } @@ -138,6 +141,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> { let generic_args = tcx.mk_args(&[]); let typeck = tcx.typeck(local_def_id); let def_ids = DefIdCache::new(tcx); + let type_builder = TypeBuilder::new(tcx, def_ids.clone(), local_def_id.to_def_id()); let mut translator = Self { tcx, local_def_id, @@ -145,6 +149,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> { body, generic_args, def_ids, + type_builder, env: HashMap::default(), }; translator.build_env_from_params(); @@ -158,6 +163,11 @@ impl<'tcx> AnnotFnTranslator<'tcx> { pub fn with_def_id_cache(mut self, def_ids: DefIdCache<'tcx>) -> Self { self.def_ids = def_ids; + self.type_builder = TypeBuilder::new( + self.tcx, + self.def_ids.clone(), + self.local_def_id.to_def_id(), + ); self } @@ -197,6 +207,13 @@ impl<'tcx> AnnotFnTranslator<'tcx> { self.tcx.normalize_erasing_regions(param_env, instantiated) } + fn pat_ty(&self, pat: &'tcx rustc_hir::Pat<'tcx>) -> mir_ty::Ty<'tcx> { + let ty = self.typeck.pat_ty(pat); + let instantiated = mir_ty::EarlyBinder::bind(ty).instantiate(self.tcx, self.generic_args); + let param_env = mir_ty::ParamEnv::reveal_all(); + self.tcx.normalize_erasing_regions(param_env, instantiated) + } + pub fn to_formula_fn(&self) -> FormulaFn<'tcx> { let formula = self.to_formula(self.body.value); let params = self @@ -224,6 +241,28 @@ impl<'tcx> AnnotFnTranslator<'tcx> { .expect("expected a term") } + fn variant_ctor_term( + &self, + ctor_did: rustc_span::def_id::DefId, + result_ty: mir_ty::Ty<'tcx>, + field_terms: Vec>, + ) -> chc::Term { + let variant_did = self.tcx.parent(ctor_did); + let adt_did = self.tcx.parent(variant_did); + let d_sym = crate::refine::datatype_symbol(self.tcx, adt_did); + let variant_name = self.tcx.item_name(variant_did); + let v_sym = chc::DatatypeSymbol::new(format!("{}.{}", d_sym, variant_name)); + let sort_args = if let mir_ty::TyKind::Adt(_, generic_args) = result_ty.kind() { + generic_args + .types() + .map(|ty| self.type_builder.build(ty).to_sort()) + .collect() + } else { + panic!("expected an ADT type for variant constructor") + }; + chc::Term::datatype_ctor(d_sym, sort_args, v_sym, field_terms) + } + fn to_formula_or_term( &self, hir: &'tcx rustc_hir::Expr<'tcx>, @@ -319,20 +358,21 @@ impl<'tcx> AnnotFnTranslator<'tcx> { rustc_ast::LitKind::Bool(b) => FormulaOrTerm::Literal(b), _ => unimplemented!("unsupported literal in formula: {:?}", lit), }, - ExprKind::Path(qpath) => { - if let rustc_hir::def::Res::Local(hir_id) = - self.typeck.qpath_res(&qpath, hir.hir_id) - { - FormulaOrTerm::Term( - self.env - .get(&hir_id) - .expect("unbound variable in formula") - .clone(), - ) - } else { - unimplemented!("unsupported path in formula: {:?}", qpath); + ExprKind::Path(qpath) => match self.typeck.qpath_res(&qpath, hir.hir_id) { + rustc_hir::def::Res::Local(hir_id) => FormulaOrTerm::Term( + self.env + .get(&hir_id) + .expect("unbound variable in formula") + .clone(), + ), + rustc_hir::def::Res::Def( + rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _), + ctor_did, + ) => { + FormulaOrTerm::Term(self.variant_ctor_term(ctor_did, self.expr_ty(hir), vec![])) } - } + _ => unimplemented!("unsupported path in formula: {:?}", qpath), + }, ExprKind::Tup(exprs) => { let terms = exprs.iter().map(|e| self.to_term(e)).collect(); FormulaOrTerm::Term(chc::Term::tuple(terms)) @@ -349,7 +389,40 @@ impl<'tcx> AnnotFnTranslator<'tcx> { ExprKind::Call(func_expr, args) => { if let ExprKind::Path(qpath) = &func_expr.kind { let res = self.typeck.qpath_res(qpath, func_expr.hir_id); - if let rustc_hir::def::Res::Def(_, def_id) = res { + if let rustc_hir::def::Res::Def(def_kind, def_id) = res { + if Some(def_id) == self.def_ids.exists() { + assert_eq!(args.len(), 1, "exists takes exactly 1 argument"); + let ExprKind::Closure(closure) = args[0].kind else { + panic!("exists argument must be a closure"); + }; + let closure_body = self.tcx.hir().body(closure.body); + + let mut inner_translator = self.clone(); + let mut vars = Vec::new(); + for param in closure_body.params { + let rustc_hir::PatKind::Binding(_, hir_id, ident, None) = + param.pat.kind + else { + panic!( + "exists closure parameter must be a simple binding: {:?}", + param.pat + ); + }; + let param_ty = self.pat_ty(param.pat); + let sort = self.type_builder.build(param_ty).to_sort(); + let var_term = chc::Term::FormulaExistentialVar( + sort.clone(), + ident.name.to_string(), + ); + inner_translator.env.insert(hir_id, var_term); + vars.push((ident.name.to_string(), sort)); + } + let body_formula = inner_translator.to_formula(closure_body.value); + return FormulaOrTerm::Formula(chc::Formula::exists( + vars, + body_formula, + )); + } if Some(def_id) == self.def_ids.mut_model_new() { assert_eq!(args.len(), 2, "Mut::new takes exactly 2 arguments"); let t1 = self.to_term(&args[0]); @@ -361,6 +434,17 @@ impl<'tcx> AnnotFnTranslator<'tcx> { let t = self.to_term(&args[0]); return FormulaOrTerm::Term(chc::Term::box_(t)); } + if matches!( + def_kind, + rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _) + ) { + let field_terms = args.iter().map(|arg| self.to_term(arg)).collect(); + return FormulaOrTerm::Term(self.variant_ctor_term( + def_id, + self.expr_ty(hir), + field_terms, + )); + } } } unimplemented!("unsupported call in formula: {:?}", func_expr) diff --git a/src/analyze/did_cache.rs b/src/analyze/did_cache.rs index 3c2ff10..99b6e4c 100644 --- a/src/analyze/did_cache.rs +++ b/src/analyze/did_cache.rs @@ -22,6 +22,8 @@ struct DefIds { mut_model_new: OnceCell>, box_model_new: OnceCell>, + + exists: OnceCell>, } /// Retrieves and caches well-known [`DefId`]s. @@ -160,4 +162,11 @@ impl<'tcx> DefIdCache<'tcx> { .box_model_new .get_or_init(|| self.annotated_def(&crate::analyze::annot::box_model_new_path())) } + + pub fn exists(&self) -> Option { + *self + .def_ids + .exists + .get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path())) + } } diff --git a/std.rs b/std.rs index 7afcaf6..3318656 100644 --- a/std.rs +++ b/std.rs @@ -234,6 +234,13 @@ mod thrust_models { impl Model for Result where T: Model, E: Model { type Ty = Result<::Ty, ::Ty>; } + + #[allow(dead_code)] + #[thrust::def::exists] + #[thrust::ignored] + pub fn exists(_x: T) -> bool { + unimplemented!() + } } #[thrust::extern_spec_fn] diff --git a/tests/ui/fail/annot_enum_simple_formula_fn.rs b/tests/ui/fail/annot_enum_simple_formula_fn.rs new file mode 100644 index 0000000..30d19f5 --- /dev/null +++ b/tests/ui/fail/annot_enum_simple_formula_fn.rs @@ -0,0 +1,40 @@ +//@error-in-other-file: Unsat + +#[derive(PartialEq)] +pub enum X { + A(i64), + B(bool), +} + +impl thrust_models::Model for X { + type Ty = X; +} + +#[thrust::formula_fn] +fn _thrust_requires_test(x: X) -> bool { + x == X::A(1) +} + +#[thrust::formula_fn] +fn _thrust_ensures_test(_result: (), _x: X) -> bool { + true +} + +#[allow(path_statements)] +fn test(x: X) { + #[thrust::requires_path] + _thrust_requires_test; + + #[thrust::ensures_path] + _thrust_ensures_test; + + if let X::A(i) = x { + assert!(i == 2); + } else { + loop {} + } +} + +fn main() { + test(X::A(1)); +} diff --git a/tests/ui/fail/annot_exists_formula_fn.rs b/tests/ui/fail/annot_exists_formula_fn.rs new file mode 100644 index 0000000..767aea7 --- /dev/null +++ b/tests/ui/fail/annot_exists_formula_fn.rs @@ -0,0 +1,31 @@ +//@error-in-other-file: Unsat +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper + +#[thrust::trusted] +#[thrust::callable] +fn rand() -> i32 { unimplemented!() } + +#[thrust::formula_fn] +fn _thrust_requires_f() -> bool { + true +} + +#[thrust::formula_fn] +fn _thrust_ensures_f(result: i32) -> bool { + thrust_models::exists(|x: i32| result == 2 * x) +} + +#[allow(path_statements)] +fn f() -> i32 { + #[thrust::requires_path] + _thrust_requires_f; + + #[thrust::ensures_path] + _thrust_ensures_f; + + let x = rand(); + x + x + x +} + +fn main() {} diff --git a/tests/ui/pass/annot_enum_simple_formula_fn.rs b/tests/ui/pass/annot_enum_simple_formula_fn.rs new file mode 100644 index 0000000..0a31951 --- /dev/null +++ b/tests/ui/pass/annot_enum_simple_formula_fn.rs @@ -0,0 +1,40 @@ +//@check-pass + +#[derive(PartialEq)] +pub enum X { + A(i64), + B(bool), +} + +impl thrust_models::Model for X { + type Ty = X; +} + +#[thrust::formula_fn] +fn _thrust_requires_test(x: X) -> bool { + x == X::A(1) +} + +#[thrust::formula_fn] +fn _thrust_ensures_test(_result: (), _x: X) -> bool { + true +} + +#[allow(path_statements)] +fn test(x: X) { + #[thrust::requires_path] + _thrust_requires_test; + + #[thrust::ensures_path] + _thrust_ensures_test; + + if let X::A(i) = x { + assert!(i == 1); + } else { + loop {} + } +} + +fn main() { + test(X::A(1)); +} diff --git a/tests/ui/pass/annot_exists_formula_fn.rs b/tests/ui/pass/annot_exists_formula_fn.rs new file mode 100644 index 0000000..3f376d6 --- /dev/null +++ b/tests/ui/pass/annot_exists_formula_fn.rs @@ -0,0 +1,31 @@ +//@check-pass +//@compile-flags: -C debug-assertions=off +//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper + +#[thrust::trusted] +#[thrust::callable] +fn rand() -> i32 { unimplemented!() } + +#[thrust::formula_fn] +fn _thrust_requires_f() -> bool { + true +} + +#[thrust::formula_fn] +fn _thrust_ensures_f(result: i32) -> bool { + thrust_models::exists(|x: i32| result == 2 * x) +} + +#[allow(path_statements)] +fn f() -> i32 { + #[thrust::requires_path] + _thrust_requires_f; + + #[thrust::ensures_path] + _thrust_ensures_f; + + let x = rand(); + x + x +} + +fn main() {}