diff --git a/rr_frontend/translation/src/body/translator.rs b/rr_frontend/translation/src/body/translator.rs index 0adc5e3123d4d0abdc5077b8b3907bc929160326..bfb76258e500b8f8834dd9d4d29b2f5c85eda33a 100644 --- a/rr_frontend/translation/src/body/translator.rs +++ b/rr_frontend/translation/src/body/translator.rs @@ -95,22 +95,6 @@ pub fn get_arg_syntypes_for_procedure_call<'tcx, 'def>( Ok(syntypes) } -// solve the constraints for the new_regions -// we first identify what kinds of constraints these new regions are subject to -#[derive(Debug)] -enum CallRegionKind { - // this is just an intersection of local regions. - Intersection(HashSet<Region>), - // this is equal to a specific region - EqR(Region), -} - -struct CallRegions { - pub early_regions: Vec<Region>, - pub late_regions: Vec<Region>, - pub classification: HashMap<Region, CallRegionKind>, -} - /// Struct that keeps track of all information necessary to translate a MIR Body to a `radium::Function`. /// `'a` is the lifetime of the translator and ends after translation has finished. /// `'def` is the lifetime of the generated code (the code may refer to struct defs). @@ -819,145 +803,6 @@ impl<'a, 'def: 'a, 'tcx: 'def> TX<'a, 'def, 'tcx> { //Self::dump_body(body); } - fn compute_call_regions( - &self, - func: &Constant<'tcx>, - loc: Location, - ) -> Result<CallRegions, TranslationError<'tcx>> { - let midpoint = self.info.interner.get_point_index(&facts::Point { - location: loc, - typ: facts::PointType::Mid, - }); - - // first identify substitutions for the early-bound regions - let (target_did, sig, substs, _) = self.call_expr_op_split_inst(func)?; - info!("calling function {:?}", target_did); - let mut early_regions = Vec::new(); - info!("call substs: {:?} = {:?}, {:?}", func, sig, substs); - for a in substs { - if let ty::GenericArgKind::Lifetime(r) = a.unpack() { - if let ty::RegionKind::ReVar(r) = r.kind() { - early_regions.push(r); - } - } - } - info!("call region instantiations (early): {:?}", early_regions); - - // this is a hack to identify the inference variables introduced for the - // call's late-bound universals. - // TODO: Can we get this information in a less hacky way? - // One approach: compute the early + late bound regions for a given DefId, similarly to how - // we do it when starting to translate a function - // Problem: this doesn't give a straightforward way to compute their instantiation - - // now find all the regions that appear in type parameters we instantiate. - // These are regions that the callee doesn't know about. - let mut generic_regions = HashSet::new(); - let mut clos = |r: ty::Region<'tcx>, _| match r.kind() { - ty::RegionKind::ReVar(rv) => { - generic_regions.insert(rv); - r - }, - _ => r, - }; - - for a in substs { - if let ty::GenericArgKind::Type(c) = a.unpack() { - let mut folder = ty::fold::RegionFolder::new(self.env.tcx(), &mut clos); - folder.fold_ty(c); - } - } - info!("Regions of generic args: {:?}", generic_regions); - - // go over all region constraints initiated at this location - let new_constraints = self.info.get_new_subset_constraints_at_point(midpoint); - let mut new_regions = HashSet::new(); - let mut relevant_constraints = Vec::new(); - for (r1, r2) in &new_constraints { - if matches!(self.info.get_region_kind(*r1), polonius_info::RegionKind::Unknown) { - // this is probably a inference variable for the call - new_regions.insert(*r1); - relevant_constraints.push((*r1, *r2)); - } - if matches!(self.info.get_region_kind(*r2), polonius_info::RegionKind::Unknown) { - new_regions.insert(*r2); - relevant_constraints.push((*r1, *r2)); - } - } - // first sort this to enable cycle resolution - let mut new_regions_sorted: Vec<Region> = new_regions.iter().copied().collect(); - new_regions_sorted.sort(); - - // identify the late-bound regions - let mut late_regions = Vec::new(); - for r in &new_regions_sorted { - // only take the ones which are not early bound and - // which are not due to a generic (the callee doesn't care about generic regions) - if !early_regions.contains(r) && !generic_regions.contains(r) { - late_regions.push(*r); - } - } - info!("call region instantiations (late): {:?}", late_regions); - - // Notes: - // - if two of the call regions need to be equal due to constraints on the function, we define the one - // with the larger id in terms of the other one - // - we ignore unidirectional subset constraints between call regions (these do not help in finding a - // solution if we take the transitive closure beforehand) - // - if a call region needs to be equal to a local region, we directly define it in terms of the local - // region - // - otherwise, it will be an intersection of local regions - let mut new_regions_classification = HashMap::new(); - // compute transitive closure of constraints - let relevant_constraints = polonius_info::compute_transitive_closure(relevant_constraints); - for r in &new_regions_sorted { - for (r1, r2) in &relevant_constraints { - if *r2 != *r { - continue; - } - - // i.e. (flipping it around when we are talking about lifetimes), - // r needs to be a sublft of r1 - if relevant_constraints.contains(&(*r2, *r1)) { - // if r1 is also a new region and r2 is ordered before it, we will - // just define r1 in terms of r2 - if new_regions.contains(r1) && r2.as_u32() < r1.as_u32() { - continue; - } - // need an equality constraint - new_regions_classification.insert(*r, CallRegionKind::EqR(*r1)); - // do not consider the rest of the constraints as r is already - // fully specified - break; - } - - // the intersection also needs to contain r1 - if new_regions.contains(r1) { - // we do not need this constraint, since we already computed the - // transitive closure. - continue; - } - - let kind = new_regions_classification - .entry(*r) - .or_insert(CallRegionKind::Intersection(HashSet::new())); - - let CallRegionKind::Intersection(s) = kind else { - unreachable!(); - }; - - s.insert(*r1); - } - } - info!("call arg classification: {:?}", new_regions_classification); - - Ok(CallRegions { - early_regions, - late_regions, - classification: new_regions_classification, - }) - } - fn translate_function_call( &mut self, func: &Operand<'tcx>, @@ -981,6 +826,11 @@ impl<'a, 'def: 'a, 'tcx: 'def> TX<'a, 'def, 'tcx> { }); }; + // Get the type of the return value from the function + let (target_did, sig, generic_args, inst_sig) = self.call_expr_op_split_inst(func_constant)?; + info!("calling function {:?}", target_did); + info!("call substs: {:?} = {:?}, {:?}", func, sig, generic_args); + // for lifetime annotations: // 1. get the regions involved here. for that, get the instantiation of the function. // + if it's a FnDef type, that should be easy. @@ -1001,20 +851,21 @@ impl<'a, 'def: 'a, 'tcx: 'def> TX<'a, 'def, 'tcx> { // substituted regions) should be. // 6. annotate the return value on assignment and establish constraints. - let classification = self.compute_call_regions(func_constant, loc)?; + let classification = + regions::calls::compute_call_regions(self.env, &self.inclusion_tracker, generic_args, loc); // update the inclusion tracker with the new regions we have introduced // We just add the inclusions and ignore that we resolve it in a "tight" way. // the cases where we need the reverse inclusion should be really rare. for (r, c) in &classification.classification { match c { - CallRegionKind::EqR(r2) => { + regions::calls::CallRegionKind::EqR(r2) => { // put it at the start point, because the inclusions come into effect // at the point right before. self.inclusion_tracker.add_static_inclusion(*r, *r2, startpoint); self.inclusion_tracker.add_static_inclusion(*r2, *r, startpoint); }, - CallRegionKind::Intersection(lfts) => { + regions::calls::CallRegionKind::Intersection(lfts) => { // all the regions represented by lfts need to be included in r for r2 in lfts { self.inclusion_tracker.add_static_inclusion(*r2, *r, startpoint); @@ -1050,8 +901,6 @@ impl<'a, 'def: 'a, 'tcx: 'def> TX<'a, 'def, 'tcx> { info!("Call lifetime instantiation (early): {:?}", classification.early_regions); info!("Call lifetime instantiation (late): {:?}", classification.late_regions); - // Get the type of the return value from the function - let (_, _, _, inst_sig) = self.call_expr_op_split_inst(func_constant)?; // TODO: do we need to do something with late bounds? let output_ty = inst_sig.output().skip_binder(); info!("call has instantiated type {:?}", inst_sig); @@ -1136,12 +985,12 @@ impl<'a, 'def: 'a, 'tcx: 'def> TX<'a, 'def, 'tcx> { for (r, class) in &classification.classification { let lft = self.format_region(*r); match class { - CallRegionKind::EqR(r2) => { + regions::calls::CallRegionKind::EqR(r2) => { let lft2 = self.format_region(*r2); stmt_annots.push(radium::Annotation::CopyLftName(lft2, lft)); }, - CallRegionKind::Intersection(rs) => { + regions::calls::CallRegionKind::Intersection(rs) => { match rs.len() { 0 => { return Err(TranslationError::UnsupportedFeature { @@ -1408,67 +1257,6 @@ impl<'a, 'def: 'a, 'tcx: 'def> TX<'a, 'def, 'tcx> { .and_then(|m| m.is_ignore().then_some(*did))) } - /// Get the regions appearing in a type. - fn get_regions_of_ty(&self, ty: Ty<'tcx>) -> HashSet<ty::RegionVid> { - let mut regions = HashSet::new(); - let mut clos = |r: ty::Region<'tcx>, _| match r.kind() { - ty::RegionKind::ReVar(rv) => { - regions.insert(rv); - r - }, - _ => r, - }; - let mut folder = ty::fold::RegionFolder::new(self.env.tcx(), &mut clos); - folder.fold_ty(ty); - regions - } - - /// On creating a composite value (e.g. a struct or enum), the composite value gets its own - /// Polonius regions. We need to map these regions properly to the respective lifetimes. - fn get_composite_rvalue_creation_annots( - &mut self, - loc: Location, - rhs_ty: ty::Ty<'tcx>, - ) -> Vec<radium::Annotation> { - let info = &self.info; - let input_facts = &info.borrowck_in_facts; - let subset_base = &input_facts.subset_base; - - let regions_of_ty = self.get_regions_of_ty(rhs_ty); - - let mut annots = Vec::new(); - - // Polonius subset constraint are spawned for the midpoint - let midpoint = self.info.interner.get_point_index(&facts::Point { - location: loc, - typ: facts::PointType::Mid, - }); - - for (s1, s2, point) in subset_base { - if *point == midpoint { - let lft1 = self.info.mk_atomic_region(*s1); - let lft2 = self.info.mk_atomic_region(*s2); - - // a place lifetime is included in a value lifetime - if lft2.is_value() && lft1.is_place() { - // make sure it's not due to an assignment constraint - if regions_of_ty.contains(s2) && !subset_base.contains(&(*s2, *s1, midpoint)) { - // we enforce this inclusion by setting the lifetimes to be equal - self.inclusion_tracker.add_static_inclusion(*s1, *s2, midpoint); - self.inclusion_tracker.add_static_inclusion(*s2, *s1, midpoint); - - let annot = radium::Annotation::CopyLftName( - self.ty_translator.format_atomic_region(&lft1), - self.ty_translator.format_atomic_region(&lft2), - ); - annots.push(annot); - } - } - } - } - annots - } - /** * Translate a single basic block. */ @@ -1560,7 +1348,8 @@ impl<'a, 'def: 'a, 'tcx: 'def> TX<'a, 'def, 'tcx> { loc, plc_strongly_writeable, plc_ty, rhs_ty); // TODO; maybe move this to rvalue - let composite_annots = self.get_composite_rvalue_creation_annots(loc, rhs_ty); + let composite_annots = regions::composite::get_composite_rvalue_creation_annots( + self.env, &mut self.inclusion_tracker, &self.ty_translator, loc, rhs_ty); cont_stmt = radium::Stmt::with_annotations( cont_stmt, diff --git a/rr_frontend/translation/src/regions/calls.rs b/rr_frontend/translation/src/regions/calls.rs new file mode 100644 index 0000000000000000000000000000000000000000..c0fb2db5321c0a8738d7174ab3d30b5da8ce2267 --- /dev/null +++ b/rr_frontend/translation/src/regions/calls.rs @@ -0,0 +1,182 @@ +// © 2024, The RefinedRust Developers and Contributors +// +// This Source Code Form is subject to the terms of the BSD-3-clause License. +// If a copy of the BSD-3-clause license was not distributed with this +// file, You can obtain one at https://opensource.org/license/bsd-3-clause/. + +//! Provides functionality for generating lifetime annotations for calls. + +use std::collections::{BTreeMap, HashMap, HashSet}; + +use derive_more::{Constructor, Debug}; +use log::{info, warn}; +use rr_rustc_interface::hir::def_id::DefId; +use rr_rustc_interface::middle::mir::tcx::PlaceTy; +use rr_rustc_interface::middle::mir::{BasicBlock, BorrowKind, Location, Rvalue}; +use rr_rustc_interface::middle::ty; +use rr_rustc_interface::middle::ty::{Ty, TyCtxt, TyKind, TypeFoldable, TypeFolder}; +use ty::TypeSuperFoldable; + +use crate::base::{self, Region}; +use crate::environment::borrowck::facts; +use crate::environment::polonius_info::PoloniusInfo; +use crate::environment::{dump_borrowck_info, polonius_info, Environment}; +use crate::regions::arg_folder::ty_instantiate; +use crate::regions::inclusion_tracker::{self, InclusionTracker}; +use crate::regions::EarlyLateRegionMap; +use crate::{regions, types}; + +// solve the constraints for the new_regions +// we first identify what kinds of constraints these new regions are subject to +#[derive(Debug)] +pub enum CallRegionKind { + // this is just an intersection of local regions. + Intersection(HashSet<Region>), + // this is equal to a specific region + EqR(Region), +} + +pub struct CallRegions { + pub early_regions: Vec<Region>, + pub late_regions: Vec<Region>, + pub classification: HashMap<Region, CallRegionKind>, +} + +// `substs` are the substitutions for the early-bound regions +pub fn compute_call_regions<'tcx>( + env: &Environment<'tcx>, + incl_tracker: &InclusionTracker<'_, '_>, + substs: &[ty::GenericArg<'tcx>], + loc: Location, +) -> CallRegions { + let info = incl_tracker.info(); + + let midpoint = info.interner.get_point_index(&facts::Point { + location: loc, + typ: facts::PointType::Mid, + }); + + let mut early_regions = Vec::new(); + for a in substs { + if let ty::GenericArgKind::Lifetime(r) = a.unpack() { + if let ty::RegionKind::ReVar(r) = r.kind() { + early_regions.push(r); + } + } + } + info!("call region instantiations (early): {:?}", early_regions); + + // this is a hack to identify the inference variables introduced for the + // call's late-bound universals. + // TODO: Can we get this information in a less hacky way? + // One approach: compute the early + late bound regions for a given DefId, similarly to how + // we do it when starting to translate a function + // Problem: this doesn't give a straightforward way to compute their instantiation + + // now find all the regions that appear in type parameters we instantiate. + // These are regions that the callee doesn't know about. + let mut generic_regions = HashSet::new(); + let mut clos = |r: ty::Region<'tcx>, _| match r.kind() { + ty::RegionKind::ReVar(rv) => { + generic_regions.insert(rv); + r + }, + _ => r, + }; + + for a in substs { + if let ty::GenericArgKind::Type(c) = a.unpack() { + let mut folder = ty::fold::RegionFolder::new(env.tcx(), &mut clos); + folder.fold_ty(c); + } + } + info!("Regions of generic args: {:?}", generic_regions); + + // go over all region constraints initiated at this location + let new_constraints = info.get_new_subset_constraints_at_point(midpoint); + let mut new_regions = HashSet::new(); + let mut relevant_constraints = Vec::new(); + for (r1, r2) in &new_constraints { + if matches!(info.get_region_kind(*r1), polonius_info::RegionKind::Unknown) { + // this is probably a inference variable for the call + new_regions.insert(*r1); + relevant_constraints.push((*r1, *r2)); + } + if matches!(info.get_region_kind(*r2), polonius_info::RegionKind::Unknown) { + new_regions.insert(*r2); + relevant_constraints.push((*r1, *r2)); + } + } + // first sort this to enable cycle resolution + let mut new_regions_sorted: Vec<Region> = new_regions.iter().copied().collect(); + new_regions_sorted.sort(); + + // identify the late-bound regions + let mut late_regions = Vec::new(); + for r in &new_regions_sorted { + // only take the ones which are not early bound and + // which are not due to a generic (the callee doesn't care about generic regions) + if !early_regions.contains(r) && !generic_regions.contains(r) { + late_regions.push(*r); + } + } + info!("call region instantiations (late): {:?}", late_regions); + + // Notes: + // - if two of the call regions need to be equal due to constraints on the function, we define the one + // with the larger id in terms of the other one + // - we ignore unidirectional subset constraints between call regions (these do not help in finding a + // solution if we take the transitive closure beforehand) + // - if a call region needs to be equal to a local region, we directly define it in terms of the local + // region + // - otherwise, it will be an intersection of local regions + let mut new_regions_classification = HashMap::new(); + // compute transitive closure of constraints + let relevant_constraints = polonius_info::compute_transitive_closure(relevant_constraints); + for r in &new_regions_sorted { + for (r1, r2) in &relevant_constraints { + if *r2 != *r { + continue; + } + + // i.e. (flipping it around when we are talking about lifetimes), + // r needs to be a sublft of r1 + if relevant_constraints.contains(&(*r2, *r1)) { + // if r1 is also a new region and r2 is ordered before it, we will + // just define r1 in terms of r2 + if new_regions.contains(r1) && r2.as_u32() < r1.as_u32() { + continue; + } + // need an equality constraint + new_regions_classification.insert(*r, CallRegionKind::EqR(*r1)); + // do not consider the rest of the constraints as r is already + // fully specified + break; + } + + // the intersection also needs to contain r1 + if new_regions.contains(r1) { + // we do not need this constraint, since we already computed the + // transitive closure. + continue; + } + + let kind = new_regions_classification + .entry(*r) + .or_insert(CallRegionKind::Intersection(HashSet::new())); + + let CallRegionKind::Intersection(s) = kind else { + unreachable!(); + }; + + s.insert(*r1); + } + } + info!("call arg classification: {:?}", new_regions_classification); + + CallRegions { + early_regions, + late_regions, + classification: new_regions_classification, + } +} diff --git a/rr_frontend/translation/src/regions/composite.rs b/rr_frontend/translation/src/regions/composite.rs new file mode 100644 index 0000000000000000000000000000000000000000..f473b8ab42b449393fabef293832f720bb2084c7 --- /dev/null +++ b/rr_frontend/translation/src/regions/composite.rs @@ -0,0 +1,90 @@ +// © 2024, The RefinedRust Developers and Contributors +// +// This Source Code Form is subject to the terms of the BSD-3-clause License. +// If a copy of the BSD-3-clause license was not distributed with this +// file, You can obtain one at https://opensource.org/license/bsd-3-clause/. + +//! Provides functionality for generating lifetime annotations for composite expressions. + +use std::collections::{BTreeMap, HashMap, HashSet}; + +use derive_more::{Constructor, Debug}; +use log::{info, warn}; +use rr_rustc_interface::hir::def_id::DefId; +use rr_rustc_interface::middle::mir::tcx::PlaceTy; +use rr_rustc_interface::middle::mir::{BasicBlock, BorrowKind, Location, Rvalue}; +use rr_rustc_interface::middle::ty; +use rr_rustc_interface::middle::ty::{Ty, TyCtxt, TyKind, TypeFoldable, TypeFolder}; +use ty::TypeSuperFoldable; + +use crate::base::{self, Region}; +use crate::environment::borrowck::facts; +use crate::environment::polonius_info::PoloniusInfo; +use crate::environment::{dump_borrowck_info, polonius_info, Environment}; +use crate::regions::arg_folder::ty_instantiate; +use crate::regions::inclusion_tracker::{self, InclusionTracker}; +use crate::regions::EarlyLateRegionMap; +use crate::{regions, types}; + +/// On creating a composite value (e.g. a struct or enum), the composite value gets its own +/// Polonius regions We need to map these regions properly to the respective lifetimes. +pub fn get_composite_rvalue_creation_annots<'tcx>( + env: &Environment<'tcx>, + inclusion_tracker: &mut InclusionTracker<'_, 'tcx>, + ty_translator: &types::LocalTX<'_, 'tcx>, + loc: Location, + rhs_ty: ty::Ty<'tcx>, +) -> Vec<radium::Annotation> { + let info = inclusion_tracker.info(); + let input_facts = &info.borrowck_in_facts; + let subset_base = &input_facts.subset_base; + + let regions_of_ty = get_regions_of_ty(env, rhs_ty); + + let mut annots = Vec::new(); + + // Polonius subset constraint are spawned for the midpoint + let midpoint = info.interner.get_point_index(&facts::Point { + location: loc, + typ: facts::PointType::Mid, + }); + + for (s1, s2, point) in subset_base { + if *point == midpoint { + let lft1 = info.mk_atomic_region(*s1); + let lft2 = info.mk_atomic_region(*s2); + + // a place lifetime is included in a value lifetime + if lft2.is_value() && lft1.is_place() { + // make sure it's not due to an assignment constraint + if regions_of_ty.contains(s2) && !subset_base.contains(&(*s2, *s1, midpoint)) { + // we enforce this inclusion by setting the lifetimes to be equal + inclusion_tracker.add_static_inclusion(*s1, *s2, midpoint); + inclusion_tracker.add_static_inclusion(*s2, *s1, midpoint); + + let annot = radium::Annotation::CopyLftName( + ty_translator.format_atomic_region(&lft1), + ty_translator.format_atomic_region(&lft2), + ); + annots.push(annot); + } + } + } + } + annots +} + +/// Get the regions appearing in a type. +fn get_regions_of_ty<'tcx>(env: &Environment<'tcx>, ty: Ty<'tcx>) -> HashSet<ty::RegionVid> { + let mut regions = HashSet::new(); + let mut clos = |r: ty::Region<'tcx>, _| match r.kind() { + ty::RegionKind::ReVar(rv) => { + regions.insert(rv); + r + }, + _ => r, + }; + let mut folder = ty::fold::RegionFolder::new(env.tcx(), &mut clos); + folder.fold_ty(ty); + regions +} diff --git a/rr_frontend/translation/src/regions/mod.rs b/rr_frontend/translation/src/regions/mod.rs index 2d0c4dec580986069673943870f2751c31e33c25..efbc908d42b6039b9560668d7ae02d73a20c294c 100644 --- a/rr_frontend/translation/src/regions/mod.rs +++ b/rr_frontend/translation/src/regions/mod.rs @@ -8,6 +8,8 @@ mod arg_folder; pub mod assignment; +pub mod calls; +pub mod composite; pub mod inclusion_tracker; pub mod init;