Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 180 additions & 6 deletions src/analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use rustc_index::IndexVec;
use rustc_middle::mir::{self, BasicBlock, Local};
use rustc_middle::ty::{self as mir_ty, TyCtxt};
use rustc_span::def_id::{DefId, LocalDefId};
use rustc_span::Symbol;

use crate::analyze;
use crate::annot::{AnnotFormula, AnnotParser, Resolver};
Expand All @@ -24,6 +25,7 @@ use crate::refine::{self, BasicBlockType, TypeBuilder};
use crate::rty;

mod annot;
mod annot_fn;
mod basic_block;
mod crate_;
mod did_cache;
Expand All @@ -32,6 +34,32 @@ mod local_def;
// TODO: organize structure and remove cross dependency between refine
pub use did_cache::DefIdCache;

pub fn mir_borrowck_skip_formula_fn(
tcx: rustc_middle::ty::TyCtxt<'_>,
local_def_id: rustc_span::def_id::LocalDefId,
) -> rustc_middle::query::queries::mir_borrowck::ProvidedValue {
// TODO: unify impl with local_def::Analyzer
let is_annotated_as_formula_fn = tcx
.get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::formula_fn_path())
.next()
.is_some();

if is_annotated_as_formula_fn {
tracing::debug!(?local_def_id, "skipping borrow check for formula fn");
let dummy_result = rustc_middle::mir::BorrowCheckResult {
concrete_opaque_types: Default::default(),
closure_requirements: None,
used_mut_upvars: Default::default(),
tainted_by_errors: None,
};
return tcx.arena.alloc(dummy_result);
}

(rustc_interface::DEFAULT_QUERY_PROVIDERS
.queries
.mir_borrowck)(tcx, local_def_id)
}

pub fn local_of_function_param(idx: rty::FunctionParamIdx) -> Local {
Local::from(idx.index() + 1)
}
Expand Down Expand Up @@ -162,6 +190,11 @@ impl refine::EnumDefProvider for Rc<RefCell<EnumDefs>> {

pub type Env = refine::Env<Rc<RefCell<EnumDefs>>>;

#[derive(Debug, Clone)]
struct DeferredFormulaFnDef<'tcx> {
cache: Rc<RefCell<HashMap<mir_ty::GenericArgsRef<'tcx>, annot_fn::FormulaFn<'tcx>>>>,
}

#[derive(Clone)]
pub struct Analyzer<'tcx> {
tcx: TyCtxt<'tcx>,
Expand All @@ -173,6 +206,9 @@ pub struct Analyzer<'tcx> {
/// (at least for every defs referenced by local def bodies)
defs: HashMap<DefId, DefTy<'tcx>>,

/// Collection of functions with `#[thrust::formula_fn]` attribute.
formula_fns: HashMap<LocalDefId, DeferredFormulaFnDef<'tcx>>,

/// Resulting CHC system.
system: Rc<RefCell<chc::System>>,

Expand All @@ -199,12 +235,14 @@ impl<'tcx> Analyzer<'tcx> {
impl<'tcx> Analyzer<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>) -> Self {
let defs = Default::default();
let formula_fns = Default::default();
let system = Default::default();
let basic_blocks = Default::default();
let enum_defs = Default::default();
Self {
tcx,
defs,
formula_fns,
system,
basic_blocks,
def_ids: did_cache::DefIdCache::new(tcx),
Expand Down Expand Up @@ -358,6 +396,30 @@ impl<'tcx> Analyzer<'tcx> {
})
}

pub fn formula_fn_with_args(
&self,
local_def_id: LocalDefId,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Option<annot_fn::FormulaFn<'tcx>> {
let deferred_formula_fn = self.formula_fns.get(&local_def_id)?;

let deferred_formula_fn_cache = Rc::clone(&deferred_formula_fn.cache);
if let Some(formula_fn) = deferred_formula_fn_cache.borrow().get(&generic_args) {
return Some(formula_fn.clone());
}

let translator = annot_fn::AnnotFnTranslator::new(self.tcx, local_def_id)
.with_generic_args(generic_args)
.with_def_id_cache(self.def_ids());
let formula_fn = translator.to_formula_fn();
deferred_formula_fn_cache
.borrow_mut()
.insert(generic_args, formula_fn.clone());

tracing::info!(?local_def_id, formula_fn = %formula_fn.display(), ?generic_args, "formula_fn_with_args");
Some(formula_fn)
}

pub fn def_ty_with_args(
&mut self,
def_id: DefId,
Expand Down Expand Up @@ -410,6 +472,16 @@ impl<'tcx> Analyzer<'tcx> {
Some(expected)
}

pub fn register_formula_fn(&mut self, local_def_id: LocalDefId) {
tracing::info!(?local_def_id, "register_formula_fn");
self.formula_fns.insert(
local_def_id,
DeferredFormulaFnDef {
cache: Rc::new(RefCell::new(HashMap::new())),
},
);
}

pub fn register_basic_block_ty(
&mut self,
def_id: LocalDefId,
Expand Down Expand Up @@ -498,20 +570,75 @@ impl<'tcx> Analyzer<'tcx> {
self.fn_sig_with_body(def_id, body)
}

fn extract_path_with_attr(
&self,
local_def_id: LocalDefId,
attr_path: &[Symbol],
) -> Option<DefId> {
let map = self.tcx.hir();
let body_id = map.maybe_body_owned_by(local_def_id)?;
let body = map.body(body_id);

let rustc_hir::ExprKind::Block(block, _) = body.value.kind else {
return None;
};
for stmt in block.stmts {
if map
.attrs(stmt.hir_id)
.iter()
.all(|attr| !attr.path_matches(attr_path))
{
continue;
}
let rustc_hir::StmtKind::Semi(expr) = stmt.kind else {
self.tcx.dcx().span_err(
stmt.span,
"annotated path is expected to be a semi statement",
);
continue;
};
let rustc_hir::ExprKind::Path(qpath) = expr.kind else {
self.tcx.dcx().span_err(
expr.span,
"annotated path is expected to be a path expression",
);
continue;
};
let rustc_hir::QPath::Resolved(_, path) = qpath else {
self.tcx.dcx().span_err(
expr.span,
"annotated path is expected to be a resolved path",
);
continue;
};
let rustc_hir::def::Res::Def(_, def_id) = path.res else {
self.tcx.dcx().span_err(
path.span,
"annotated path is expected to refer to a definition",
);
continue;
};
return Some(def_id);
}
None
}

// TODO: reduce number of args
fn extract_require_annot<T>(
&self,
def_id: DefId,
local_def_id: LocalDefId,
resolver: T,
self_type_name: Option<String>,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Option<AnnotFormula<T::Output>>
where
T: Resolver,
T: Resolver<Output = rty::FunctionParamIdx>,
{
let mut require_annot = None;
let parser = AnnotParser::new(&resolver, self_type_name);
for attrs in self
.tcx
.get_attrs_by_path(def_id, &analyze::annot::requires_path())
.get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::requires_path())
{
if require_annot.is_some() {
unimplemented!();
Expand All @@ -520,23 +647,48 @@ impl<'tcx> Analyzer<'tcx> {
let require = parser.parse_formula(ts).unwrap();
require_annot = Some(require);
}

if let Some(formula_def_id) =
self.extract_path_with_attr(local_def_id, &analyze::annot::requires_path_path())
{
let Some(formula_def_id) = formula_def_id.as_local() else {
panic!(
"require annotation with path is expected to refer to a local def, but found: {:?}",
formula_def_id
);
};
if require_annot.is_some() {
unimplemented!();
}
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
panic!(
"require annotation {:?} is not a formula function",
formula_def_id
);
};
require_annot = Some(formula_fn.to_require_annot());
}

require_annot
}

// TODO: reduce number of args
fn extract_ensure_annot<T>(
&self,
def_id: DefId,
local_def_id: LocalDefId,
resolver: T,
self_type_name: Option<String>,
generic_args: mir_ty::GenericArgsRef<'tcx>,
) -> Option<AnnotFormula<T::Output>>
where
T: Resolver,
T: Resolver<Output = rty::RefinedTypeVar<rty::FunctionParamIdx>>,
{
let mut ensure_annot = None;

let parser = AnnotParser::new(&resolver, self_type_name);
for attrs in self
.tcx
.get_attrs_by_path(def_id, &analyze::annot::ensures_path())
.get_attrs_by_path(local_def_id.to_def_id(), &analyze::annot::ensures_path())
{
if ensure_annot.is_some() {
unimplemented!();
Expand All @@ -545,6 +697,28 @@ impl<'tcx> Analyzer<'tcx> {
let ensure = parser.parse_formula(ts).unwrap();
ensure_annot = Some(ensure);
}

if let Some(formula_def_id) =
self.extract_path_with_attr(local_def_id, &analyze::annot::ensures_path_path())
{
let Some(formula_def_id) = formula_def_id.as_local() else {
panic!(
"ensure annotation with path is expected to refer to a local def, but found: {:?}",
formula_def_id
);
};
if ensure_annot.is_some() {
unimplemented!();
}
let Some(formula_fn) = self.formula_fn_with_args(formula_def_id, generic_args) else {
panic!(
"ensure annotation {:?} is not a formula function",
formula_def_id
);
};
ensure_annot = Some(formula_fn.to_ensure_annot());
}

ensure_annot
}

Expand Down
28 changes: 28 additions & 0 deletions src/analyze/annot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ pub fn ignored_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("ignored")]
}

pub fn formula_fn_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("formula_fn")]
}

pub fn requires_path_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("requires_path")]
}

pub fn ensures_path_path() -> [Symbol; 2] {
[Symbol::intern("thrust"), Symbol::intern("ensures_path")]
}

pub fn model_ty_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Expand Down Expand Up @@ -97,6 +109,22 @@ pub fn closure_model_path() -> [Symbol; 3] {
]
}

pub fn mut_model_new_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("mut_new"),
]
}

pub fn box_model_new_path() -> [Symbol; 3] {
[
Symbol::intern("thrust"),
Symbol::intern("def"),
Symbol::intern("box_new"),
]
}

/// A [`annot::Resolver`] implementation for resolving function parameters.
///
/// The parameter names and their sorts needs to be configured via
Expand Down
Loading