@@ -190,6 +190,11 @@ impl refine::EnumDefProvider for Rc<RefCell<EnumDefs>> {
190190
191191pub type Env = refine:: Env < Rc < RefCell < EnumDefs > > > ;
192192
193+ #[ derive( Debug , Clone ) ]
194+ struct DeferredFormulaFnDef < ' tcx > {
195+ cache : Rc < RefCell < HashMap < mir_ty:: GenericArgsRef < ' tcx > , annot_fn:: FormulaFn < ' tcx > > > > ,
196+ }
197+
193198#[ derive( Clone ) ]
194199pub struct Analyzer < ' tcx > {
195200 tcx : TyCtxt < ' tcx > ,
@@ -202,7 +207,7 @@ pub struct Analyzer<'tcx> {
202207 defs : HashMap < DefId , DefTy < ' tcx > > ,
203208
204209 /// Collection of functions with `#[thrust::formula_fn]` attribute.
205- formula_fns : HashMap < DefId , annot_fn :: FormulaFn < ' tcx > > ,
210+ formula_fns : HashMap < LocalDefId , DeferredFormulaFnDef < ' tcx > > ,
206211
207212 /// Resulting CHC system.
208213 system : Rc < RefCell < chc:: System > > ,
@@ -391,6 +396,29 @@ impl<'tcx> Analyzer<'tcx> {
391396 } )
392397 }
393398
399+ pub fn formula_fn_with_args (
400+ & self ,
401+ local_def_id : LocalDefId ,
402+ generic_args : mir_ty:: GenericArgsRef < ' tcx > ,
403+ ) -> Option < annot_fn:: FormulaFn < ' tcx > > {
404+ let deferred_formula_fn = self . formula_fns . get ( & local_def_id) ?;
405+
406+ let deferred_formula_fn_cache = Rc :: clone ( & deferred_formula_fn. cache ) ;
407+ if let Some ( formula_fn) = deferred_formula_fn_cache. borrow ( ) . get ( & generic_args) {
408+ return Some ( formula_fn. clone ( ) ) ;
409+ }
410+
411+ let translator = annot_fn:: AnnotFnTranslator :: new ( self . tcx , self . def_ids ( ) , local_def_id)
412+ . with_generic_args ( generic_args) ;
413+ let formula_fn = translator. to_formula_fn ( ) ;
414+ deferred_formula_fn_cache
415+ . borrow_mut ( )
416+ . insert ( generic_args, formula_fn. clone ( ) ) ;
417+
418+ tracing:: info!( ?local_def_id, formula_fn = %formula_fn. display( ) , ?generic_args, "formula_fn_with_args" ) ;
419+ Some ( formula_fn)
420+ }
421+
394422 pub fn def_ty_with_args (
395423 & mut self ,
396424 def_id : DefId ,
@@ -443,9 +471,14 @@ impl<'tcx> Analyzer<'tcx> {
443471 Some ( expected)
444472 }
445473
446- pub fn register_formula_fn ( & mut self , def_id : DefId , formula_fn : annot_fn:: FormulaFn < ' tcx > ) {
447- tracing:: info!( def_id = ?def_id, formula_fn = %formula_fn. display( ) , "register_formula_fn" ) ;
448- self . formula_fns . insert ( def_id, formula_fn) ;
474+ pub fn register_formula_fn ( & mut self , local_def_id : LocalDefId ) {
475+ tracing:: info!( ?local_def_id, "register_formula_fn" ) ;
476+ self . formula_fns . insert (
477+ local_def_id,
478+ DeferredFormulaFnDef {
479+ cache : Rc :: new ( RefCell :: new ( HashMap :: new ( ) ) ) ,
480+ } ,
481+ ) ;
449482 }
450483
451484 pub fn register_basic_block_ty (
@@ -585,11 +618,13 @@ impl<'tcx> Analyzer<'tcx> {
585618 None
586619 }
587620
621+ // TODO: reduce number of args
588622 fn extract_require_annot < T > (
589623 & self ,
590624 local_def_id : LocalDefId ,
591625 resolver : T ,
592626 self_type_name : Option < String > ,
627+ generic_args : mir_ty:: GenericArgsRef < ' tcx > ,
593628 ) -> Option < AnnotFormula < T :: Output > >
594629 where
595630 T : Resolver < Output = rty:: FunctionParamIdx > ,
@@ -611,10 +646,16 @@ impl<'tcx> Analyzer<'tcx> {
611646 if let Some ( formula_def_id) =
612647 self . extract_path_with_attr ( local_def_id, & analyze:: annot:: requires_path_path ( ) )
613648 {
649+ let Some ( formula_def_id) = formula_def_id. as_local ( ) else {
650+ panic ! (
651+ "require annotation with path is expected to refer to a local def, but found: {:?}" ,
652+ formula_def_id
653+ ) ;
654+ } ;
614655 if require_annot. is_some ( ) {
615656 unimplemented ! ( ) ;
616657 }
617- let Some ( formula_fn) = self . formula_fns . get ( & formula_def_id) else {
658+ let Some ( formula_fn) = self . formula_fn_with_args ( formula_def_id, generic_args ) else {
618659 panic ! (
619660 "require annotation {:?} is not a formula function" ,
620661 formula_def_id
@@ -626,11 +667,13 @@ impl<'tcx> Analyzer<'tcx> {
626667 require_annot
627668 }
628669
670+ // TODO: reduce number of args
629671 fn extract_ensure_annot < T > (
630672 & self ,
631673 local_def_id : LocalDefId ,
632674 resolver : T ,
633675 self_type_name : Option < String > ,
676+ generic_args : mir_ty:: GenericArgsRef < ' tcx > ,
634677 ) -> Option < AnnotFormula < T :: Output > >
635678 where
636679 T : Resolver < Output = rty:: RefinedTypeVar < rty:: FunctionParamIdx > > ,
@@ -653,10 +696,16 @@ impl<'tcx> Analyzer<'tcx> {
653696 if let Some ( formula_def_id) =
654697 self . extract_path_with_attr ( local_def_id, & analyze:: annot:: ensures_path_path ( ) )
655698 {
699+ let Some ( formula_def_id) = formula_def_id. as_local ( ) else {
700+ panic ! (
701+ "require annotation with path is expected to refer to a local def, but found: {:?}" ,
702+ formula_def_id
703+ ) ;
704+ } ;
656705 if ensure_annot. is_some ( ) {
657706 unimplemented ! ( ) ;
658707 }
659- let Some ( formula_fn) = self . formula_fns . get ( & formula_def_id) else {
708+ let Some ( formula_fn) = self . formula_fn_with_args ( formula_def_id, generic_args ) else {
660709 panic ! (
661710 "ensure annotation {:?} is not a formula function" ,
662711 formula_def_id
0 commit comments