Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,14 @@ pub fn box_model_new_path() -> [Symbol; 3] {
]
}

pub fn exists_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("exists"),
]
}

/// A [`annot::Resolver`] implementation for resolving function parameters.
///
/// The parameter names and their sorts needs to be configured via
Expand Down
112 changes: 98 additions & 14 deletions src/analyze/annot_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use rustc_middle::ty::{self as mir_ty, TyCtxt};
use crate::analyze::did_cache::DefIdCache;
use crate::annot::AnnotFormula;
use crate::chc;
use crate::refine::TypeBuilder;
use crate::rty;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -118,6 +119,7 @@ impl<T> FormulaOrTerm<T> {
}
}

#[derive(Clone)]
pub struct AnnotFnTranslator<'tcx> {
tcx: TyCtxt<'tcx>,
local_def_id: LocalDefId,
Expand All @@ -127,6 +129,7 @@ pub struct AnnotFnTranslator<'tcx> {
generic_args: mir_ty::GenericArgsRef<'tcx>,

def_ids: DefIdCache<'tcx>,
type_builder: TypeBuilder<'tcx>,
env: HashMap<HirId, chc::Term<rty::FunctionParamIdx>>,
}

Expand All @@ -138,13 +141,15 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
let generic_args = tcx.mk_args(&[]);
let typeck = tcx.typeck(local_def_id);
let def_ids = DefIdCache::new(tcx);
let type_builder = TypeBuilder::new(tcx, def_ids.clone(), local_def_id.to_def_id());
let mut translator = Self {
tcx,
local_def_id,
typeck,
body,
generic_args,
def_ids,
type_builder,
env: HashMap::default(),
};
translator.build_env_from_params();
Expand All @@ -158,6 +163,11 @@ impl<'tcx> AnnotFnTranslator<'tcx> {

pub fn with_def_id_cache(mut self, def_ids: DefIdCache<'tcx>) -> Self {
self.def_ids = def_ids;
self.type_builder = TypeBuilder::new(
self.tcx,
self.def_ids.clone(),
self.local_def_id.to_def_id(),
);
self
}

Expand Down Expand Up @@ -197,6 +207,13 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
self.tcx.normalize_erasing_regions(param_env, instantiated)
}

fn pat_ty(&self, pat: &'tcx rustc_hir::Pat<'tcx>) -> mir_ty::Ty<'tcx> {
let ty = self.typeck.pat_ty(pat);
let instantiated = mir_ty::EarlyBinder::bind(ty).instantiate(self.tcx, self.generic_args);
let param_env = mir_ty::ParamEnv::reveal_all();
self.tcx.normalize_erasing_regions(param_env, instantiated)
}

pub fn to_formula_fn(&self) -> FormulaFn<'tcx> {
let formula = self.to_formula(self.body.value);
let params = self
Expand Down Expand Up @@ -224,6 +241,28 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
.expect("expected a term")
}

fn variant_ctor_term(
&self,
ctor_did: rustc_span::def_id::DefId,
result_ty: mir_ty::Ty<'tcx>,
field_terms: Vec<chc::Term<rty::FunctionParamIdx>>,
) -> chc::Term<rty::FunctionParamIdx> {
let variant_did = self.tcx.parent(ctor_did);
let adt_did = self.tcx.parent(variant_did);
let d_sym = crate::refine::datatype_symbol(self.tcx, adt_did);
let variant_name = self.tcx.item_name(variant_did);
let v_sym = chc::DatatypeSymbol::new(format!("{}.{}", d_sym, variant_name));
let sort_args = if let mir_ty::TyKind::Adt(_, generic_args) = result_ty.kind() {
generic_args
.types()
.map(|ty| self.type_builder.build(ty).to_sort())
.collect()
} else {
panic!("expected an ADT type for variant constructor")
};
chc::Term::datatype_ctor(d_sym, sort_args, v_sym, field_terms)
}

fn to_formula_or_term(
&self,
hir: &'tcx rustc_hir::Expr<'tcx>,
Expand Down Expand Up @@ -319,20 +358,21 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
rustc_ast::LitKind::Bool(b) => FormulaOrTerm::Literal(b),
_ => unimplemented!("unsupported literal in formula: {:?}", lit),
},
ExprKind::Path(qpath) => {
if let rustc_hir::def::Res::Local(hir_id) =
self.typeck.qpath_res(&qpath, hir.hir_id)
{
FormulaOrTerm::Term(
self.env
.get(&hir_id)
.expect("unbound variable in formula")
.clone(),
)
} else {
unimplemented!("unsupported path in formula: {:?}", qpath);
ExprKind::Path(qpath) => match self.typeck.qpath_res(&qpath, hir.hir_id) {
rustc_hir::def::Res::Local(hir_id) => FormulaOrTerm::Term(
self.env
.get(&hir_id)
.expect("unbound variable in formula")
.clone(),
),
rustc_hir::def::Res::Def(
rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _),
ctor_did,
) => {
FormulaOrTerm::Term(self.variant_ctor_term(ctor_did, self.expr_ty(hir), vec![]))
}
}
_ => unimplemented!("unsupported path in formula: {:?}", qpath),
},
ExprKind::Tup(exprs) => {
let terms = exprs.iter().map(|e| self.to_term(e)).collect();
FormulaOrTerm::Term(chc::Term::tuple(terms))
Expand All @@ -349,7 +389,40 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
ExprKind::Call(func_expr, args) => {
if let ExprKind::Path(qpath) = &func_expr.kind {
let res = self.typeck.qpath_res(qpath, func_expr.hir_id);
if let rustc_hir::def::Res::Def(_, def_id) = res {
if let rustc_hir::def::Res::Def(def_kind, def_id) = res {
if Some(def_id) == self.def_ids.exists() {
assert_eq!(args.len(), 1, "exists takes exactly 1 argument");
let ExprKind::Closure(closure) = args[0].kind else {
panic!("exists argument must be a closure");
};
let closure_body = self.tcx.hir().body(closure.body);

let mut inner_translator = self.clone();
let mut vars = Vec::new();
for param in closure_body.params {
let rustc_hir::PatKind::Binding(_, hir_id, ident, None) =
param.pat.kind
else {
panic!(
"exists closure parameter must be a simple binding: {:?}",
param.pat
);
};
let param_ty = self.pat_ty(param.pat);
let sort = self.type_builder.build(param_ty).to_sort();
let var_term = chc::Term::FormulaExistentialVar(
sort.clone(),
ident.name.to_string(),
);
inner_translator.env.insert(hir_id, var_term);
vars.push((ident.name.to_string(), sort));
}
let body_formula = inner_translator.to_formula(closure_body.value);
return FormulaOrTerm::Formula(chc::Formula::exists(
vars,
body_formula,
));
}
if Some(def_id) == self.def_ids.mut_model_new() {
assert_eq!(args.len(), 2, "Mut::new takes exactly 2 arguments");
let t1 = self.to_term(&args[0]);
Expand All @@ -361,6 +434,17 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
let t = self.to_term(&args[0]);
return FormulaOrTerm::Term(chc::Term::box_(t));
}
if matches!(
def_kind,
rustc_hir::def::DefKind::Ctor(rustc_hir::def::CtorOf::Variant, _)
) {
let field_terms = args.iter().map(|arg| self.to_term(arg)).collect();
return FormulaOrTerm::Term(self.variant_ctor_term(
def_id,
self.expr_ty(hir),
field_terms,
));
}
}
}
unimplemented!("unsupported call in formula: {:?}", func_expr)
Expand Down
9 changes: 9 additions & 0 deletions src/analyze/did_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ struct DefIds {

mut_model_new: OnceCell<Option<DefId>>,
box_model_new: OnceCell<Option<DefId>>,

exists: OnceCell<Option<DefId>>,
}

/// Retrieves and caches well-known [`DefId`]s.
Expand Down Expand Up @@ -160,4 +162,11 @@ impl<'tcx> DefIdCache<'tcx> {
.box_model_new
.get_or_init(|| self.annotated_def(&crate::analyze::annot::box_model_new_path()))
}

pub fn exists(&self) -> Option<DefId> {
*self
.def_ids
.exists
.get_or_init(|| self.annotated_def(&crate::analyze::annot::exists_path()))
}
}
7 changes: 7 additions & 0 deletions std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,13 @@ mod thrust_models {
impl<T, E> Model for Result<T, E> where T: Model, E: Model {
type Ty = Result<<T as Model>::Ty, <E as Model>::Ty>;
}

#[allow(dead_code)]
#[thrust::def::exists]
#[thrust::ignored]
pub fn exists<T>(_x: T) -> bool {
unimplemented!()
}
}

#[thrust::extern_spec_fn]
Expand Down
40 changes: 40 additions & 0 deletions tests/ui/fail/annot_enum_simple_formula_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//@error-in-other-file: Unsat

#[derive(PartialEq)]
pub enum X {
A(i64),
B(bool),
}

impl thrust_models::Model for X {
type Ty = X;
}

#[thrust::formula_fn]
fn _thrust_requires_test(x: X) -> bool {
x == X::A(1)
}

#[thrust::formula_fn]
fn _thrust_ensures_test(_result: (), _x: X) -> bool {
true
}

#[allow(path_statements)]
fn test(x: X) {
#[thrust::requires_path]
_thrust_requires_test;

#[thrust::ensures_path]
_thrust_ensures_test;

if let X::A(i) = x {
assert!(i == 2);
} else {
loop {}
}
}

fn main() {
test(X::A(1));
}
31 changes: 31 additions & 0 deletions tests/ui/fail/annot_exists_formula_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//@error-in-other-file: Unsat
//@compile-flags: -C debug-assertions=off
//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper

#[thrust::trusted]
#[thrust::callable]
fn rand() -> i32 { unimplemented!() }

#[thrust::formula_fn]
fn _thrust_requires_f() -> bool {
true
}

#[thrust::formula_fn]
fn _thrust_ensures_f(result: i32) -> bool {
thrust_models::exists(|x: i32| result == 2 * x)
}

#[allow(path_statements)]
fn f() -> i32 {
#[thrust::requires_path]
_thrust_requires_f;

#[thrust::ensures_path]
_thrust_ensures_f;

let x = rand();
x + x + x
}

fn main() {}
40 changes: 40 additions & 0 deletions tests/ui/pass/annot_enum_simple_formula_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//@check-pass

#[derive(PartialEq)]
pub enum X {
A(i64),
B(bool),
}

impl thrust_models::Model for X {
type Ty = X;
}

#[thrust::formula_fn]
fn _thrust_requires_test(x: X) -> bool {
x == X::A(1)
}

#[thrust::formula_fn]
fn _thrust_ensures_test(_result: (), _x: X) -> bool {
true
}

#[allow(path_statements)]
fn test(x: X) {
#[thrust::requires_path]
_thrust_requires_test;

#[thrust::ensures_path]
_thrust_ensures_test;

if let X::A(i) = x {
assert!(i == 1);
} else {
loop {}
}
}

fn main() {
test(X::A(1));
}
31 changes: 31 additions & 0 deletions tests/ui/pass/annot_exists_formula_fn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//@check-pass
//@compile-flags: -C debug-assertions=off
//@rustc-env: THRUST_SOLVER=tests/thrust-pcsat-wrapper

#[thrust::trusted]
#[thrust::callable]
fn rand() -> i32 { unimplemented!() }

#[thrust::formula_fn]
fn _thrust_requires_f() -> bool {
true
}

#[thrust::formula_fn]
fn _thrust_ensures_f(result: i32) -> bool {
thrust_models::exists(|x: i32| result == 2 * x)
}

#[allow(path_statements)]
fn f() -> i32 {
#[thrust::requires_path]
_thrust_requires_f;

#[thrust::ensures_path]
_thrust_ensures_f;

let x = rand();
x + x
}

fn main() {}