Skip to content

Commit 668226f

Browse files
committed
Translate formula fns after type params are instantiated
1 parent f71f919 commit 668226f

6 files changed

Lines changed: 146 additions & 14 deletions

File tree

src/analyze.rs

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ impl refine::EnumDefProvider for Rc<RefCell<EnumDefs>> {
190190

191191
pub type Env = refine::Env<Rc<RefCell<EnumDefs>>>;
192192

193+
#[derive(Debug, Clone)]
194+
struct DeferredFormulaFnDef<'tcx> {
195+
cache: Rc<RefCell<HashMap<mir_ty::GenericArgsRef<'tcx>, annot_fn::FormulaFn<'tcx>>>>,
196+
}
197+
193198
#[derive(Clone)]
194199
pub struct Analyzer<'tcx> {
195200
tcx: TyCtxt<'tcx>,
@@ -202,7 +207,7 @@ pub struct Analyzer<'tcx> {
202207
defs: HashMap<DefId, DefTy<'tcx>>,
203208

204209
/// Collection of functions with `#[thrust::formula_fn]` attribute.
205-
formula_fns: HashMap<DefId, annot_fn::FormulaFn<'tcx>>,
210+
formula_fns: HashMap<LocalDefId, DeferredFormulaFnDef<'tcx>>,
206211

207212
/// Resulting CHC system.
208213
system: Rc<RefCell<chc::System>>,
@@ -391,6 +396,29 @@ impl<'tcx> Analyzer<'tcx> {
391396
})
392397
}
393398

399+
pub fn formula_fn_with_args(
400+
&self,
401+
local_def_id: LocalDefId,
402+
generic_args: mir_ty::GenericArgsRef<'tcx>,
403+
) -> Option<annot_fn::FormulaFn<'tcx>> {
404+
let deferred_formula_fn = self.formula_fns.get(&local_def_id)?;
405+
406+
let deferred_formula_fn_cache = Rc::clone(&deferred_formula_fn.cache);
407+
if let Some(formula_fn) = deferred_formula_fn_cache.borrow().get(&generic_args) {
408+
return Some(formula_fn.clone());
409+
}
410+
411+
let translator = annot_fn::AnnotFnTranslator::new(self.tcx, self.def_ids(), local_def_id)
412+
.with_generic_args(generic_args);
413+
let formula_fn = translator.to_formula_fn();
414+
deferred_formula_fn_cache
415+
.borrow_mut()
416+
.insert(generic_args, formula_fn.clone());
417+
418+
tracing::info!(?local_def_id, formula_fn = %formula_fn.display(), ?generic_args, "formula_fn_with_args");
419+
Some(formula_fn)
420+
}
421+
394422
pub fn def_ty_with_args(
395423
&mut self,
396424
def_id: DefId,
@@ -443,9 +471,14 @@ impl<'tcx> Analyzer<'tcx> {
443471
Some(expected)
444472
}
445473

446-
pub fn register_formula_fn(&mut self, def_id: DefId, formula_fn: annot_fn::FormulaFn<'tcx>) {
447-
tracing::info!(def_id = ?def_id, formula_fn = %formula_fn.display(), "register_formula_fn");
448-
self.formula_fns.insert(def_id, formula_fn);
474+
pub fn register_formula_fn(&mut self, local_def_id: LocalDefId) {
475+
tracing::info!(?local_def_id, "register_formula_fn");
476+
self.formula_fns.insert(
477+
local_def_id,
478+
DeferredFormulaFnDef {
479+
cache: Rc::new(RefCell::new(HashMap::new())),
480+
},
481+
);
449482
}
450483

451484
pub fn register_basic_block_ty(
@@ -585,11 +618,13 @@ impl<'tcx> Analyzer<'tcx> {
585618
None
586619
}
587620

621+
// TODO: reduce number of args
588622
fn extract_require_annot<T>(
589623
&self,
590624
local_def_id: LocalDefId,
591625
resolver: T,
592626
self_type_name: Option<String>,
627+
generic_args: mir_ty::GenericArgsRef<'tcx>,
593628
) -> Option<AnnotFormula<T::Output>>
594629
where
595630
T: Resolver<Output = rty::FunctionParamIdx>,
@@ -611,10 +646,16 @@ impl<'tcx> Analyzer<'tcx> {
611646
if let Some(formula_def_id) =
612647
self.extract_path_with_attr(local_def_id, &analyze::annot::requires_path_path())
613648
{
649+
let Some(formula_def_id) = formula_def_id.as_local() else {
650+
panic!(
651+
"require annotation with path is expected to refer to a local def, but found: {:?}",
652+
formula_def_id
653+
);
654+
};
614655
if require_annot.is_some() {
615656
unimplemented!();
616657
}
617-
let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else {
658+
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
618659
panic!(
619660
"require annotation {:?} is not a formula function",
620661
formula_def_id
@@ -626,11 +667,13 @@ impl<'tcx> Analyzer<'tcx> {
626667
require_annot
627668
}
628669

670+
// TODO: reduce number of args
629671
fn extract_ensure_annot<T>(
630672
&self,
631673
local_def_id: LocalDefId,
632674
resolver: T,
633675
self_type_name: Option<String>,
676+
generic_args: mir_ty::GenericArgsRef<'tcx>,
634677
) -> Option<AnnotFormula<T::Output>>
635678
where
636679
T: Resolver<Output = rty::RefinedTypeVar<rty::FunctionParamIdx>>,
@@ -653,10 +696,16 @@ impl<'tcx> Analyzer<'tcx> {
653696
if let Some(formula_def_id) =
654697
self.extract_path_with_attr(local_def_id, &analyze::annot::ensures_path_path())
655698
{
699+
let Some(formula_def_id) = formula_def_id.as_local() else {
700+
panic!(
701+
"require annotation with path is expected to refer to a local def, but found: {:?}",
702+
formula_def_id
703+
);
704+
};
656705
if ensure_annot.is_some() {
657706
unimplemented!();
658707
}
659-
let Some(formula_fn) = self.formula_fns.get(&formula_def_id) else {
708+
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
660709
panic!(
661710
"ensure annotation {:?} is not a formula function",
662711
formula_def_id

src/analyze/annot_fn.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ pub struct AnnotFnTranslator<'tcx> {
124124

125125
typeck: &'tcx mir_ty::TypeckResults<'tcx>,
126126
body: &'tcx rustc_hir::Body<'tcx>,
127+
generic_args: mir_ty::GenericArgsRef<'tcx>,
127128

128129
def_ids: DefIdCache<'tcx>,
129130
env: HashMap<HirId, chc::Term<rty::FunctionParamIdx>>,
@@ -134,20 +135,27 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
134135
let map = tcx.hir();
135136
let body_id = map.body_owned_by(local_def_id);
136137
let body = map.body(body_id);
138+
let generic_args = tcx.mk_args(&[]);
137139

138140
let typeck = tcx.typeck(local_def_id);
139141
let mut translator = Self {
140142
tcx,
141143
local_def_id,
142144
typeck,
143145
body,
146+
generic_args,
144147
def_ids,
145148
env: HashMap::default(),
146149
};
147150
translator.build_env_from_params();
148151
translator
149152
}
150153

154+
pub fn with_generic_args(mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> Self {
155+
self.generic_args = generic_args;
156+
self
157+
}
158+
151159
fn build_env_from_params(&mut self) {
152160
for (idx, param) in self.body.params.iter().enumerate() {
153161
let param_idx = rty::FunctionParamIdx::from(idx);
@@ -177,12 +185,19 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
177185
}
178186
}
179187

188+
fn expr_ty(&self, expr: &'tcx rustc_hir::Expr<'tcx>) -> mir_ty::Ty<'tcx> {
189+
let ty = self.typeck.expr_ty(expr);
190+
let instantiated = mir_ty::EarlyBinder::bind(ty).instantiate(self.tcx, self.generic_args);
191+
let param_env = mir_ty::ParamEnv::reveal_all();
192+
self.tcx.normalize_erasing_regions(param_env, instantiated)
193+
}
194+
180195
pub fn to_formula_fn(&self) -> FormulaFn<'tcx> {
181196
let formula = self.to_formula(self.body.value);
182197
let params = self
183198
.tcx
184199
.fn_sig(self.local_def_id.to_def_id())
185-
.instantiate_identity()
200+
.instantiate(self.tcx, self.generic_args)
186201
.skip_binder()
187202
.inputs()
188203
.to_vec();
@@ -260,7 +275,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
260275
FormulaOrTerm::Term(operand.neg())
261276
}
262277
rustc_hir::UnOp::Not => {
263-
let operand_ty = self.typeck.expr_ty(operand);
278+
let operand_ty = self.expr_ty(operand);
264279
match operand_ty.ty_adt_def() {
265280
Some(adt) if Some(adt.did()) == self.def_ids.mut_model() => {
266281
let operand = self.to_term(operand);
@@ -273,7 +288,7 @@ impl<'tcx> AnnotFnTranslator<'tcx> {
273288
}
274289
}
275290
rustc_hir::UnOp::Deref => {
276-
let operand_ty = self.typeck.expr_ty(operand);
291+
let operand_ty = self.expr_ty(operand);
277292
let adt = operand_ty
278293
.ty_adt_def()
279294
.expect("deref operand must be a model type");

src/analyze/crate_.rs

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_hir::def_id::CRATE_DEF_ID;
66
use rustc_middle::ty::{self as mir_ty, TyCtxt};
77
use rustc_span::def_id::LocalDefId;
88

9-
use crate::analyze::{self, annot_fn::AnnotFnTranslator};
9+
use crate::analyze;
1010
use crate::chc;
1111
use crate::rty::{self, ClauseBuilderExt as _};
1212

@@ -95,10 +95,7 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
9595
}
9696

9797
if analyzer.is_annotated_as_formula_fn() {
98-
let formula_fn =
99-
AnnotFnTranslator::new(self.tcx, self.ctx.def_ids(), local_def_id).to_formula_fn();
100-
self.ctx
101-
.register_formula_fn(local_def_id.to_def_id(), formula_fn);
98+
self.ctx.register_formula_fn(local_def_id);
10299
self.skip_analysis.insert(local_def_id);
103100
return;
104101
}

src/analyze/local_def.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ pub struct Analyzer<'tcx, 'ctx> {
4848
local_def_id: LocalDefId,
4949

5050
body: Body<'tcx>,
51+
/// to substitute HIR types during translation in [`crate::analyze::annot_fn`]
52+
generic_args: mir_ty::GenericArgsRef<'tcx>,
5153
drop_points: HashMap<BasicBlock, analyze::basic_block::DropPoints>,
5254
type_builder: TypeBuilder<'tcx>,
5355
}
@@ -310,12 +312,14 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
310312
self.local_def_id,
311313
&param_resolver,
312314
self_type_name.clone(),
315+
self.generic_args,
313316
);
314317

315318
let mut ensure_annot = self.ctx.extract_ensure_annot(
316319
self.local_def_id,
317320
&result_param_resolver,
318321
self_type_name.clone(),
322+
self.generic_args,
319323
);
320324

321325
if let Some(trait_item_id) = self.trait_item_id() {
@@ -324,11 +328,13 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
324328
trait_item_id,
325329
&param_resolver,
326330
self_type_name.clone(),
331+
self.generic_args,
327332
);
328333
let trait_ensure_annot = self.ctx.extract_ensure_annot(
329334
trait_item_id,
330335
&result_param_resolver,
331336
self_type_name.clone(),
337+
self.generic_args,
332338
);
333339

334340
assert!(require_annot.is_none() || trait_require_annot.is_none());
@@ -851,17 +857,20 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> {
851857
let body = tcx.optimized_mir(local_def_id.to_def_id()).clone();
852858
let drop_points = Default::default();
853859
let type_builder = TypeBuilder::new(tcx, ctx.def_ids(), local_def_id.to_def_id());
860+
let generic_args = tcx.mk_args(&[]);
854861
Self {
855862
ctx,
856863
tcx,
857864
local_def_id,
858865
body,
866+
generic_args,
859867
drop_points,
860868
type_builder,
861869
}
862870
}
863871

864872
pub fn generic_args(&mut self, generic_args: mir_ty::GenericArgsRef<'tcx>) -> &mut Self {
873+
self.generic_args = generic_args;
865874
self.body =
866875
mir_ty::EarlyBinder::bind(self.body.clone()).instantiate(self.tcx, generic_args);
867876
self
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//@error-in-other-file: Unsat
2+
3+
#[allow(unused_variables)]
4+
#[thrust::formula_fn]
5+
fn _thrust_requires_swap(x: thrust_models::model::Mut<i64>, y: i64) -> bool {
6+
true
7+
}
8+
9+
#[allow(unused_variables)]
10+
#[thrust::formula_fn]
11+
fn _thrust_ensures_swap(result: (), x: thrust_models::model::Mut<i64>, y: thrust_models::model::Mut<i64>) -> bool {
12+
*x == *y && *y == *x
13+
}
14+
15+
#[allow(path_statements)]
16+
fn swap<T>(x: &mut T, y: &mut T) {
17+
#[thrust::requires_path]
18+
_thrust_requires_swap;
19+
20+
#[thrust::ensures_path]
21+
_thrust_ensures_swap;
22+
23+
std::mem::swap(x, y)
24+
}
25+
26+
fn main() {
27+
let mut a = 1;
28+
let mut b = 2;
29+
swap(&mut a, &mut b);
30+
assert!(a == 2 && b == 1);
31+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//@check-pass
2+
3+
#[allow(unused_variables)]
4+
#[thrust::formula_fn]
5+
fn _thrust_requires_swap(x: thrust_models::model::Mut<i64>, y: i64) -> bool {
6+
true
7+
}
8+
9+
#[allow(unused_variables)]
10+
#[thrust::formula_fn]
11+
fn _thrust_ensures_swap(result: (), x: thrust_models::model::Mut<i64>, y: thrust_models::model::Mut<i64>) -> bool {
12+
!x == *y && !y == *x
13+
}
14+
15+
#[allow(path_statements)]
16+
fn swap<T>(x: &mut T, y: &mut T) {
17+
#[thrust::requires_path]
18+
_thrust_requires_swap;
19+
20+
#[thrust::ensures_path]
21+
_thrust_ensures_swap;
22+
23+
std::mem::swap(x, y)
24+
}
25+
26+
fn main() {
27+
let mut a = 1;
28+
let mut b = 2;
29+
swap(&mut a, &mut b);
30+
assert!(a == 2 && b == 1);
31+
}

0 commit comments

Comments
 (0)