From 00723737488d913a2e510b1a1698dbe975ad118e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lennard=20G=C3=A4her?= <gaeher@mpi-sws.org>
Date: Tue, 16 Jul 2024 18:27:00 +0000
Subject: [PATCH] Fixes + frontend support for casts
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Lennard Gäher <l.gaeher@posteo.de>
---
 rr_frontend/radium/src/code.rs               | 11 +++
 rr_frontend/translation/src/function_body.rs | 73 ++++++++++++++++----
 rr_frontend/translation/src/lib.rs           | 14 ++--
 theories/rust_typing/ltypes.v                | 16 ++---
 4 files changed, 89 insertions(+), 25 deletions(-)

diff --git a/rr_frontend/radium/src/code.rs b/rr_frontend/radium/src/code.rs
index 8999c1b9..f60213eb 100644
--- a/rr_frontend/radium/src/code.rs
+++ b/rr_frontend/radium/src/code.rs
@@ -491,6 +491,9 @@ pub enum Unop {
 
     #[display("NotIntOp")]
     NotInt,
+
+    #[display("CastOp {}", _0)]
+    Cast(OpType),
 }
 
 #[derive(Clone, Eq, PartialEq, Debug)]
@@ -1081,6 +1084,14 @@ pub struct StaticMeta<'def> {
     pub ty: Type<'def>,
 }
 
+/// Information on a used const place
+#[derive(Clone, Debug)]
+pub struct ConstPlaceMeta<'def> {
+    pub ident: String,
+    pub loc_name: String,
+    pub ty: Type<'def>,
+}
+
 /// A `CaesiumFunctionBuilder` allows to incrementally construct the functions's code and the spec
 /// at the same time. It ensures that both definitions line up in the right way (for instance, by
 /// ensuring that other functions are linked up in a consistent way).
diff --git a/rr_frontend/translation/src/function_body.rs b/rr_frontend/translation/src/function_body.rs
index d91376c1..d8917956 100644
--- a/rr_frontend/translation/src/function_body.rs
+++ b/rr_frontend/translation/src/function_body.rs
@@ -183,7 +183,31 @@ impl<'def> ProcedureScope<'def> {
 
 /// Scope of consts that are available
 pub struct ConstScope<'def> {
-    pub statics: HashMap<DefId, radium::StaticMeta<'def>>,
+    // statics are explicitly declared
+    statics: HashMap<DefId, radium::StaticMeta<'def>>,
+    // const places are constants that lie in a static memory segment because they are referred to
+    // by-ref
+    const_places: HashMap<DefId, radium::ConstPlaceMeta<'def>>,
+}
+
+impl<'def> ConstScope<'def> {
+    /// Create a new const scope.
+    pub fn empty() -> Self {
+        Self {
+            statics: HashMap::new(),
+            const_places: HashMap::new(),
+        }
+    }
+
+    /// Register a static
+    pub fn register_static(&mut self, did: DefId, meta: radium::StaticMeta<'def>) {
+        self.statics.insert(did, meta);
+    }
+
+    /// Register a const place
+    pub fn register_const_place(&mut self, did: DefId, meta: radium::ConstPlaceMeta<'def>) {
+        self.const_places.insert(did, meta);
+    }
 }
 
 // solve the constraints for the new_regions
@@ -3260,10 +3284,16 @@ impl<'a, 'def: 'a, 'tcx: 'def> BodyTranslator<'a, 'def, 'tcx> {
                 }
             },
 
-            Rvalue::Cast(kind, op, ty) => {
+            Rvalue::Cast(kind, op, to_ty) => {
                 let op_ty = self.get_type_of_operand(op);
+                let op_st = self.ty_translator.translate_type_to_syn_type(op_ty)?;
+                let op_ot = self.ty_translator.translate_syn_type_to_op_type(&op_st);
+
                 let translated_op = self.translate_operand(op, true)?;
 
+                let target_st = self.ty_translator.translate_type_to_syn_type(*to_ty)?;
+                let target_ot = self.ty_translator.translate_syn_type_to_op_type(&target_st);
+
                 match kind {
                     mir::CastKind::PointerCoercion(x) => {
                         match x {
@@ -3292,9 +3322,11 @@ impl<'a, 'def: 'a, 'tcx: 'def> BodyTranslator<'a, 'def, 'tcx> {
                     }),
 
                     mir::CastKind::IntToInt => {
-                        // TODO
-                        Err(TranslationError::Unimplemented {
-                            description: "RefinedRust does currently not support int-to-int cast".to_owned(),
+                        // Cast integer to integer
+                        Ok(radium::Expr::UnOp {
+                            o: radium::Unop::Cast(target_ot),
+                            ot: op_ot,
+                            e: Box::new(translated_op),
                         })
                     },
 
@@ -3311,7 +3343,7 @@ impl<'a, 'def: 'a, 'tcx: 'def> BodyTranslator<'a, 'def, 'tcx> {
                     }),
 
                     mir::CastKind::PtrToPtr => {
-                        match (op_ty.kind(), ty.kind()) {
+                        match (op_ty.kind(), to_ty.kind()) {
                             (TyKind::RawPtr(_), TyKind::RawPtr(_)) => {
                                 // Casts between raw pointers are NOPs for us
                                 Ok(translated_op)
@@ -3343,14 +3375,21 @@ impl<'a, 'def: 'a, 'tcx: 'def> BodyTranslator<'a, 'def, 'tcx> {
                         ),
                     }),
 
-                    mir::CastKind::PointerExposeAddress | mir::CastKind::PointerFromExposedAddress => {
-                        Err(TranslationError::UnsupportedFeature {
-                            description: format!(
-                                "RefinedRust does currently not support this kind of cast (got: {:?})",
-                                rval
-                            ),
+                    mir::CastKind::PointerExposeAddress => {
+                        // Cast pointer to integer
+                        Ok(radium::Expr::UnOp {
+                            o: radium::Unop::Cast(target_ot),
+                            ot: radium::OpType::Ptr,
+                            e: Box::new(translated_op),
                         })
                     },
+
+                    mir::CastKind::PointerFromExposedAddress => Err(TranslationError::UnsupportedFeature {
+                        description: format!(
+                            "RefinedRust does currently not support this kind of cast (got: {:?})",
+                            rval
+                        ),
+                    }),
                 }
             },
 
@@ -3549,7 +3588,15 @@ impl<'a, 'def: 'a, 'tcx: 'def> BodyTranslator<'a, 'def, 'tcx> {
                             self.collected_statics.insert(did);
                             Ok(radium::Expr::Literal(radium::Literal::Loc(s.loc_name.clone())))
                         },
-
+                        middle::mir::interpret::GlobalAlloc::Memory(alloc) => {
+                            // TODO: this is needed
+                            Err(TranslationError::UnsupportedFeature {
+                                description: format!(
+                                    "RefinedRust does currently not support GlobalAlloc {:?} for scalar {:?} at type {:?}",
+                                    glob_alloc, sc, ty
+                                ),
+                            })
+                        },
                         _ => Err(TranslationError::UnsupportedFeature {
                             description: format!(
                                 "RefinedRust does currently not support GlobalAlloc {:?} for scalar {:?} at type {:?}",
diff --git a/rr_frontend/translation/src/lib.rs b/rr_frontend/translation/src/lib.rs
index bb8f4d4d..15851881 100644
--- a/rr_frontend/translation/src/lib.rs
+++ b/rr_frontend/translation/src/lib.rs
@@ -983,6 +983,14 @@ fn get_attributes_of_function<'a>(env: &'a Environment, did: DefId) -> Vec<&'a a
         let filtered_impl_attrs = utils::filter_tool_attrs(impl_attrs);
         filtered_attrs.extend(filtered_impl_attrs.into_iter().filter(|x| propagate_attr_from_impl(x)));
     }
+
+    // for closures, propagate from the surrounding function
+    if env.tcx().is_closure(did) {
+        let parent_did = env.tcx().parent(did);
+        let parent_attrs = get_attributes_of_function(env, parent_did);
+        filtered_attrs.extend(parent_attrs.into_iter().filter(|x| propagate_attr_from_impl(x)));
+    }
+
     filtered_attrs
 }
 
@@ -1152,7 +1160,7 @@ pub fn register_consts<'rcx, 'tcx>(vcx: &mut VerificationCtxt<'tcx, 'rcx>) -> Re
                     loc_name,
                     ty: translated_ty,
                 };
-                vcx.const_registry.statics.insert(s.to_def_id(), meta);
+                vcx.const_registry.register_static(s.to_def_id(), meta);
             },
             Err(e) => {
                 println!("Warning: static {:?} has unsupported type, skipping: {:?}", s, e);
@@ -1305,9 +1313,7 @@ where
         coq_path_prefix: path_prefix,
         shim_registry,
         dune_package: package,
-        const_registry: ConstScope {
-            statics: HashMap::new(),
-        },
+        const_registry: ConstScope::empty(),
     };
 
     register_functions(&mut vcx)?;
diff --git a/theories/rust_typing/ltypes.v b/theories/rust_typing/ltypes.v
index 4937377e..88b01333 100644
--- a/theories/rust_typing/ltypes.v
+++ b/theories/rust_typing/ltypes.v
@@ -4272,18 +4272,18 @@ Ltac simp_ltype_st Heq :=
 
 Ltac simp_ltype :=
   match goal with
-  | |- context[ltype_core ?lt] =>
+  | |- context[@ltype_core _ _ ?rt ?lt] =>
       assert_fails (is_var lt);
       let ltc := fresh "ltc" in
       let Heq := fresh "Heq_lt" in
-      remember (ltype_core lt) as ltc eqn:Heq;
+      remember (ltype_core (rt:=rt) lt) as ltc eqn:Heq;
       simp_ltype_core Heq;
       subst ltc
-  | |- context[ltype_st ?lt] =>
+  | |- context[@ltype_st _ _ ?rt ?lt] =>
       assert_fails (is_var lt);
       let ltc := fresh "ltc" in
       let Heq := fresh "Heq_lt" in
-      remember (ltype_st lt) as ltc eqn:Heq;
+      remember (ltype_st (rt:=rt) lt) as ltc eqn:Heq;
       simp_ltype_st Heq;
       subst ltc
   end.
@@ -4291,18 +4291,18 @@ Ltac simp_ltypes := repeat simp_ltype.
 
 Tactic Notation "simp_ltype" "in" hyp(H) :=
   match type of H with
-  | context[ltype_core ?lt] =>
+  | context[@ltype_core _ _ ?rt ?lt] =>
       assert_fails (is_var lt);
       let ltc := fresh "ltc" in
       let Heq := fresh "Heq_lt" in
-      remember (ltype_core lt) as ltc eqn:Heq in H;
+      remember (ltype_core (rt:=rt) lt) as ltc eqn:Heq in H;
       simp_ltype_core Heq;
       subst ltc
-  | context[ltype_st ?lt] =>
+  | context[@ltype_st _ _ ?rt ?lt] =>
       assert_fails (is_var lt);
       let ltc := fresh "ltc" in
       let Heq := fresh "Heq_lt" in
-      remember (ltype_st lt) as ltc eqn:Heq in H;
+      remember (ltype_st (rt:=rt) lt) as ltc eqn:Heq in H;
       simp_ltype_st Heq;
       subst ltc
   end.
-- 
GitLab