From 5750145a79732bc38f89f48415e3f8c11fde95af Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lennard=20G=C3=A4her?= <l.gaeher@posteo.de>
Date: Mon, 10 Mar 2025 17:24:37 +0100
Subject: [PATCH] generate inclusion proofs

---
 case_studies/tests/src/trait_deps.rs          | 122 ++++++-------
 rr_frontend/radium/src/coq/command.rs         |  20 +++
 rr_frontend/radium/src/specs.rs               | 161 ++++++++++++++++--
 rr_frontend/translation/src/lib.rs            |  93 ++++++++--
 rr_frontend/translation/src/shims/registry.rs |   5 +
 theories/rust_typing/automation.v             |  16 +-
 theories/rust_typing/automation/solvers.v     |  57 ++++---
 theories/rust_typing/functions.v              |   4 +-
 8 files changed, 360 insertions(+), 118 deletions(-)

diff --git a/case_studies/tests/src/trait_deps.rs b/case_studies/tests/src/trait_deps.rs
index ac52631..8eb268a 100644
--- a/case_studies/tests/src/trait_deps.rs
+++ b/case_studies/tests/src/trait_deps.rs
@@ -61,43 +61,6 @@ mod dep2 {
     }
 }
 
-mod dep3 {
-    trait Bar {
-
-    }
-
-    trait Foo<T: Bar> {
-
-        #[rr::verify]
-        fn foofoo(x: T);
-    }
-
-    impl Bar for i32 {
-
-    }
-
-    // the `T: Bar` can be directly dispatched with a concrete instance
-    // TODO this does not work currently.
-    // We should maybe make the spec still be parametric, but then instantiate that in the lemma
-    // statement with the statically known instance.
-    #[rr::skip]
-    impl Foo<i32> for i32 {
-
-        #[rr::default_spec]
-        fn foofoo(x: i32) {
-
-        }
-    }
-
-    // parametric dispatch
-    impl<T: Bar> Foo<T> for u32 {
-
-        #[rr::default_spec]
-        fn foofoo(x: T) {
-
-        }
-    }
-}
 
 /// Check that lifetime parameters are resolved correctly.
 mod dep6 {
@@ -116,27 +79,6 @@ mod dep6 {
     }
 }
 
-mod dep7 {
-    trait Foo {
-
-    }
-
-    #[rr::verify]
-    fn bla<T : Foo>(x: T) {
-
-    }
-
-    impl<'a> Foo for &'a i32 {
-
-    }
-
-    #[rr::verify]
-    fn call_bla() {
-        let x = 42;
-        bla(&x);
-    }
-}
-
 
 /// HRTB tests
 mod dep5 {
@@ -199,6 +141,28 @@ mod dep5 {
 }
 
 
+mod dep7 {
+    trait Foo {
+
+    }
+
+    #[rr::verify]
+    fn bla<T : Foo>(x: T) {
+
+    }
+
+    impl<'a> Foo for &'a i32 {
+
+    }
+
+    #[rr::verify]
+    fn call_bla() {
+        let x = 42;
+        bla(&x);
+    }
+}
+
+/*
 mod dep4 {
 
     trait Bar {
@@ -214,3 +178,45 @@ mod dep4 {
         fn foo() -> Self::BT;
     }
 }
+
+*/
+
+/*
+mod dep3 {
+    trait Bar {
+
+    }
+
+    trait Foo<T: Bar> {
+
+        #[rr::verify]
+        fn foofoo(x: T);
+    }
+
+    impl Bar for i32 {
+
+    }
+
+    // the `T: Bar` can be directly dispatched with a concrete instance
+    // TODO this does not work currently.
+    // We should maybe make the spec still be parametric, but then instantiate that in the lemma
+    // statement with the statically known instance.
+    //#[rr::skip]
+    impl Foo<i32> for i32 {
+
+        #[rr::default_spec]
+        fn foofoo(x: i32) {
+
+        }
+    }
+
+    // parametric dispatch
+    impl<T: Bar> Foo<T> for u32 {
+
+        #[rr::default_spec]
+        fn foofoo(x: T) {
+
+        }
+    }
+}
+*/
diff --git a/rr_frontend/radium/src/coq/command.rs b/rr_frontend/radium/src/coq/command.rs
index 3b01a2f..21cc7e4 100644
--- a/rr_frontend/radium/src/coq/command.rs
+++ b/rr_frontend/radium/src/coq/command.rs
@@ -155,6 +155,12 @@ pub enum Command {
     #[display("{}", _0)]
     Definition(Definition),
 
+    /// The [`Lemma`] command.
+    ///
+    /// [`Lemma`]: https://coq.inria.fr/doc/v8.20/refman/language/core/definitions.html#coq:cmd.Lemma
+    #[display("{}", _0)]
+    Lemma(Lemma),
+
     /// The [`Section`] command.
     ///
     /// [`Section`]: https://coq.inria.fr/doc/v8.20/refman/language/core/sections.html#using-sections
@@ -230,3 +236,17 @@ impl Display for Definition {
         }
     }
 }
+
+/// A Rocq lemma declaration.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct Lemma {
+    pub name: String,
+    pub params: binder::BinderList,
+    pub ty: term::Type,
+}
+
+impl Display for Lemma {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "Lemma {} {} : {}\n", self.name, self.params, self.ty)
+    }
+}
diff --git a/rr_frontend/radium/src/specs.rs b/rr_frontend/radium/src/specs.rs
index ac90a95..ee53e91 100644
--- a/rr_frontend/radium/src/specs.rs
+++ b/rr_frontend/radium/src/specs.rs
@@ -4833,6 +4833,8 @@ pub struct LiteralTraitImpl {
     pub spec_attrs_record: String,
     /// The name of the proof that the base spec is implied by the more specific spec
     pub spec_subsumption_proof: String,
+    /// The name of the definition for the lemma statement
+    pub spec_subsumption_statement: String,
 }
 pub type LiteralTraitImplRef<'def> = &'def LiteralTraitImpl;
 
@@ -4881,18 +4883,70 @@ impl<'def> TraitRefInst<'def> {
     /// Get the term for referring to the attr record of this impl
     /// The parameters are expected to be in scope.
     #[must_use]
-    fn get_attr_record_term(&self) -> String {
+    fn get_attr_record_term(&self) -> coq::term::Gallina {
         let attr_record = &self.impl_ref.spec_attrs_record;
 
-        let mut attr_term = String::with_capacity(100);
-        write!(attr_term, "{attr_record}").unwrap();
+        let binders = self.generics.get_all_ty_params_with_assocs().get_coq_ty_rt_params();
+        let args = binders.make_using_terms();
 
-        // add the type parameters of the impl
-        for ty in self.generics.get_all_ty_params_with_assocs().params {
-            write!(attr_term, " {}", ty.refinement_type).unwrap();
+        coq::term::Gallina::App(Box::new(coq::term::App::new(
+            coq::term::Gallina::Literal(attr_record.to_owned()),
+            args,
+        )))
+    }
+
+    /// Get the term for referring to the spec record of this impl
+    /// The parameters are expected to be in scope.
+    #[must_use]
+    fn get_spec_record_term(&self) -> coq::term::Gallina {
+        let spec_record = &self.impl_ref.spec_record;
+
+        let tys = self.generics.get_all_ty_params_with_assocs();
+        let mut binders = tys.get_coq_ty_params();
+        binders.append(self.generics.get_all_attr_trait_parameters(false).0);
+        let args = binders.make_using_terms();
+
+        let mut specialized_spec = coq::term::App::new(spec_record.to_owned(), args).to_string();
+
+        // specialize to semtys
+        push_str_list!(specialized_spec, &tys.params, " ", |x| { format!("<TY> {}", x.type_term) });
+        // specialize to lfts
+        push_str_list!(specialized_spec, self.generics.get_lfts(), " ", |x| { format!("<LFT> {}", x) });
+        specialized_spec.push_str(" <INST!>");
+
+        coq::term::Gallina::Literal(specialized_spec)
+        //coq::term::Gallina::App(Box::new(coq::term::App::new(coq::term::Gallina::Literal(spec_record.
+        // to_owned()), args)))
+    }
+
+    #[must_use]
+    fn get_base_spec_term(&self) -> coq::term::Gallina {
+        let spec_record = &self.of_trait.base_spec;
+
+        let all_args = self.get_ordered_params_inst();
+
+        let mut specialized_spec = String::new();
+        specialized_spec.push_str(&format!("({spec_record} "));
+        // specialize to rts
+        push_str_list!(specialized_spec, &all_args, " ", |x| { format!("{}", x.get_rfn_type()) });
+        // specialize to sts
+        specialized_spec.push(' ');
+        push_str_list!(specialized_spec, &all_args, " ", |x| { format!("{}", SynType::from(x)) });
+
+        // specialize to further args
+        specialized_spec.push_str(&format!(" {}", self.get_attr_record_term()));
+        for req in self.trait_inst.get_direct_trait_requirements() {
+            // get attrs + spec term
+            specialized_spec.push_str(&format!(" {}", req.get_attr_term()));
         }
 
-        attr_term
+        // specialize to semtys
+        push_str_list!(specialized_spec, &all_args, " ", |x| { format!("<TY> {}", x) });
+        // specialize to lfts
+        push_str_list!(specialized_spec, self.trait_inst.get_lfts(), " ", |x| { format!("<LFT> {}", x) });
+        specialized_spec.push_str(" <INST!>)");
+
+        coq::term::Gallina::Literal(specialized_spec)
     }
 
     /// Get the term for referring to an item of the attr record of this impl.
@@ -4900,10 +4954,7 @@ impl<'def> TraitRefInst<'def> {
     #[must_use]
     pub fn get_attr_record_item_term(&self, attr: &str) -> coq::term::Gallina {
         let item_name = self.of_trait.make_spec_attr_name(attr);
-        coq::term::Gallina::RecordProj(
-            Box::new(coq::term::Gallina::Literal(self.get_attr_record_term())),
-            item_name,
-        )
+        coq::term::Gallina::RecordProj(Box::new(self.get_attr_record_term()), item_name)
     }
 }
 
@@ -4965,17 +5016,90 @@ impl<'def> TraitImplSpec<'def> {
             body: attr_record_term,
         }
     }
+
+    #[must_use]
+    pub fn generate_lemma_statement(&self) -> coq::Document {
+        let mut doc = coq::Document::default();
+
+        let spec_name = &self.trait_ref.impl_ref.spec_subsumption_statement;
+
+        // generate the lemma statement
+        // get parameters
+        // this is parametric in the rts, sts, semtys attrs of all trait deps.
+        let ty_params = self.trait_ref.generics.get_all_ty_params_with_assocs();
+        let mut params = ty_params.get_coq_ty_params();
+        params.append(self.trait_ref.generics.get_all_attr_trait_parameters(false).0);
+
+        // instantiation of the trait
+        let params_inst = self.trait_ref.get_ordered_params_inst();
+
+        let incl_name = self.trait_ref.of_trait.spec_incl_name();
+        let own_spec = self.trait_ref.get_spec_record_term();
+        let base_spec = self.trait_ref.get_base_spec_term();
+
+        let scope = &self.trait_ref.generics;
+        let mut ty_term = format!("trait_incl_marker (lift_trait_incl {incl_name} (");
+        scope.format(&mut ty_term, false, false, &[], &[], &[]).unwrap();
+        ty_term.push_str(&format!(" {own_spec}) ("));
+        scope.format(&mut ty_term, false, false, &[], &[], &[]).unwrap();
+        ty_term.push_str(&format!(" {base_spec}))"));
+
+        let lem = coq::command::Definition {
+            name: spec_name.to_owned(),
+            params,
+            ty: None,
+            body: coq::term::Gallina::Literal(ty_term),
+        };
+        doc.push(coq::command::Command::Definition(lem));
+
+        doc
+    }
+
+    #[must_use]
+    pub fn generate_proof(&self) -> coq::Document {
+        let mut doc = coq::Document::default();
+
+        let lemma_name = &self.trait_ref.impl_ref.spec_subsumption_proof;
+
+        // generate the lemma statement
+        // get parameters
+        // this is parametric in the rts, sts, semtys attrs of all trait deps.
+        let ty_params = self.trait_ref.generics.get_all_ty_params_with_assocs();
+        let mut params = ty_params.get_coq_ty_params();
+        params.append(self.trait_ref.generics.get_all_attr_trait_parameters(false).0);
+
+        let mut ty_term = format!("{} ", self.trait_ref.impl_ref.spec_subsumption_statement);
+        push_str_list!(ty_term, &params.make_using_terms(), " ");
+
+        let lem = coq::command::Lemma {
+            name: lemma_name.to_owned(),
+            params,
+            ty: coq::term::Type::Literal(ty_term),
+        };
+        doc.push(coq::command::Command::Lemma(lem));
+
+        doc.push(coq::command::Command::Proof);
+        let prelude_tac = format!(
+            "unfold {}; solve_trait_incl_prelude",
+            self.trait_ref.impl_ref.spec_subsumption_statement
+        );
+        doc.push(coq::ltac::LTac::Literal(prelude_tac));
+        doc.push(coq::ltac::LTac::Literal("all: repeat liRStep; liShow".to_owned()));
+        doc.push(coq::ltac::LTac::Literal("all: print_remaining_trait_goal".to_owned()));
+        doc.push(coq::ltac::LTac::Literal("Unshelve".to_owned()));
+        doc.push(coq::ltac::LTac::Literal("all: sidecond_solver".to_owned()));
+        doc.push(coq::ltac::LTac::Literal("Unshelve".to_owned()));
+        doc.push(coq::ltac::LTac::Literal("all: sidecond_hammer".to_owned()));
+        doc.push(coq::command::Command::Qed);
+
+        doc
+    }
 }
 
 impl<'def> Display for TraitImplSpec<'def> {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        // TODO: what's the right choice here?
-        //let assoc_types = self.trait_ref.generics.get_ass();
         let assoc_types = Vec::new();
 
-        // TODO: figure out how we deal with the surrounding scope and its trait requiremeents
-        // here.
-
         // instantiate with the parameter and associated types
         let params_inst = self.trait_ref.get_ordered_params_inst();
 
@@ -4997,6 +5121,9 @@ impl<'def> Display for TraitImplSpec<'def> {
             ),
         );
 
+        instance.append(&mut self.generate_lemma_statement().0);
+        //instance.append(&mut self.generate_proof().0);
+
         write!(
             f,
             "{}\n",
@@ -5049,7 +5176,7 @@ impl<'def> InstantiatedTraitFunctionSpec<'def> {
 
         // also instantiate with the attrs that are quantified on the outside
         let attr_term = self.trait_ref.get_attr_record_term();
-        params.push(attr_term);
+        params.push(attr_term.to_string());
 
         // instantiate with the attrs of trait requirements
         for trait_req in self.trait_ref.trait_inst.get_direct_trait_requirements() {
diff --git a/rr_frontend/translation/src/lib.rs b/rr_frontend/translation/src/lib.rs
index b8729ea..7c2b116 100644
--- a/rr_frontend/translation/src/lib.rs
+++ b/rr_frontend/translation/src/lib.rs
@@ -286,6 +286,7 @@ impl<'tcx, 'rcx> VerificationCtxt<'tcx, 'rcx> {
             spec_attrs_record: decl.trait_ref.impl_ref.spec_attrs_record.clone(),
             spec_record: decl.trait_ref.impl_ref.spec_record.clone(),
             spec_subsumption_proof: decl.trait_ref.impl_ref.spec_subsumption_proof.clone(),
+            spec_subsumption_statement: decl.trait_ref.impl_ref.spec_subsumption_statement.clone(),
         };
 
         Some(a)
@@ -681,26 +682,38 @@ impl<'tcx, 'rcx> VerificationCtxt<'tcx, 'rcx> {
     }
 
     /// Write proofs for a verification unit.
-    fn write_proofs<F>(&self, file_path: F, stem: &str)
+    fn write_proofs<F, G>(
+        &self,
+        proof_dir_path: &Path,
+        file_path: F,
+        trait_file_path: G,
+        stem: &str,
+    ) -> Vec<String>
     where
-        F: Fn(&str) -> PathBuf,
+        F: Fn(&str) -> String,
+        G: Fn(&str) -> String,
     {
         let common_imports = vec![
             coq::module::Import::new(vec!["lang", "notation"]).from(vec!["caesium"]),
             coq::module::Import::new(vec!["typing", "shims"]).from(vec!["refinedrust"]),
         ];
 
+        let mut proof_modules = Vec::new();
+
         // write proofs
         // each function gets a separate file in order to parallelize
         for (did, fun) in self.procedure_registry.iter_code() {
-            let path = file_path(fun.name());
+            let module_path = file_path(fun.name());
+            let path = proof_dir_path.join(format!("{module_path}.v"));
 
-            if path.exists() {
-                info!("Proof file for function {} already exists, skipping creation", fun.name());
+            if !self.check_function_needs_proof(*did, fun) {
                 continue;
             }
 
-            if !self.check_function_needs_proof(*did, fun) {
+            proof_modules.push(module_path);
+
+            if path.exists() {
+                info!("Proof file for function {} already exists, skipping creation", fun.name());
                 continue;
             }
 
@@ -736,6 +749,55 @@ impl<'tcx, 'rcx> VerificationCtxt<'tcx, 'rcx> {
 
             writeln!(proof_file, "End proof.").unwrap();
         }
+
+        // write file for trait incl
+        for spec in self.trait_impls.values() {
+            //writeln!(spec_file, "{spec}").unwrap();
+
+            let name = &spec.trait_ref.impl_ref.spec_subsumption_proof;
+            let module_path = trait_file_path(name.as_str());
+            let path = proof_dir_path.join(format!("{module_path}.v"));
+
+            proof_modules.push(module_path);
+
+            if path.exists() {
+                info!("Proof file for trait impl {} already exists, skipping creation", name);
+                continue;
+            }
+
+            info!("Proof file for trait impl {} does not yet exist, creating", name);
+
+            let mut proof_file = io::BufWriter::new(File::create(path.as_path()).unwrap());
+
+            let mut imports = common_imports.clone();
+
+            imports.append(&mut vec![
+                coq::module::Import::new(vec![&format!("generated_specs_{stem}")]).from(vec![
+                    &self.coq_path_prefix,
+                    stem,
+                    "generated",
+                ]),
+            ]);
+
+            writeln!(proof_file, "{}", coq::module::ImportList(&imports)).unwrap();
+
+            // Note: we do not export the self.extra_exports explicitly, as we rely on them
+            // being re-exported from the template -- we want to be stable under changes of the
+            // extras
+
+            writeln!(proof_file, "Set Default Proof Using \"Type\".").unwrap();
+            writeln!(proof_file).unwrap();
+
+            writeln!(proof_file, "Section proof.").unwrap();
+            writeln!(proof_file, "Context `{{RRGS : !refinedrustGS Σ}}.").unwrap();
+            writeln!(proof_file).unwrap();
+
+            write!(proof_file, "{}", spec.generate_proof()).unwrap();
+
+            writeln!(proof_file, "End proof.").unwrap();
+        }
+
+        proof_modules
     }
 
     /// Write Coq files for this verification unit.
@@ -909,16 +971,14 @@ impl<'tcx, 'rcx> VerificationCtxt<'tcx, 'rcx> {
             fs::create_dir_all(proof_dir_path).unwrap();
         }
 
-        self.write_proofs(|name| proof_dir_path.join(format!("proof_{name}.v")), stem);
-
         // explicitly spell out the proof modules we want to compile so we don't choke on stale
         // proof files
-        let mut proof_modules = Vec::new();
-        for (did, fun) in self.procedure_registry.iter_code() {
-            if self.check_function_needs_proof(*did, fun) {
-                proof_modules.push(format!("proof_{}", fun.name()));
-            }
-        }
+        let proof_modules = self.write_proofs(
+            proof_dir_path,
+            |name| format!("proof_{name}"),
+            |name| format!("trait_incl_{name}"),
+            stem,
+        );
 
         // write proof dune file
         let proof_dune_path = proof_dir_path.join("dune");
@@ -1036,6 +1096,7 @@ fn register_shims<'tcx>(vcx: &mut VerificationCtxt<'tcx, '_>) -> Result<(), base
             shim.spec_params_record.clone(),
             shim.spec_attrs_record.clone(),
             shim.spec_subsumption_proof.clone(),
+            shim.spec_subsumption_statement.clone(),
         );
         vcx.trait_registry.register_impl_shim(did, impl_lit)?;
 
@@ -1408,13 +1469,15 @@ fn register_trait_impls(vcx: &VerificationCtxt<'_, '_>) -> Result<(), String> {
             let spec_name = format!("{base_name}_spec");
             let spec_params_name = format!("{base_name}_spec_params");
             let spec_attrs_name = format!("{base_name}_spec_attrs");
-            let proof_name = format!("{base_name}_spec_subsumption");
+            let proof_name = format!("{base_name}_spec_subsumption_correct");
+            let proof_statement = format!("{base_name}_spec_subsumption");
 
             let impl_lit = radium::LiteralTraitImpl {
                 spec_record: spec_name,
                 spec_params_record: spec_params_name,
                 spec_attrs_record: spec_attrs_name,
                 spec_subsumption_proof: proof_name,
+                spec_subsumption_statement: proof_statement,
             };
             vcx.trait_registry
                 .register_impl_shim(did, impl_lit)
diff --git a/rr_frontend/translation/src/shims/registry.rs b/rr_frontend/translation/src/shims/registry.rs
index e693198..81d1b18 100644
--- a/rr_frontend/translation/src/shims/registry.rs
+++ b/rr_frontend/translation/src/shims/registry.rs
@@ -85,6 +85,8 @@ struct ShimTraitImplEntry {
     spec_attrs_record: String,
     /// the Coq lemma name of the spec subsumption proof
     spec_subsumption_proof: String,
+    /// the Coq definition giving the lemma statement for the subsumption
+    spec_subsumption_statement: String,
 }
 
 /// A file entry for a trait method implementation.
@@ -162,6 +164,7 @@ pub struct TraitImplShim {
     pub spec_params_record: String,
     pub spec_attrs_record: String,
     pub spec_subsumption_proof: String,
+    pub spec_subsumption_statement: String,
 }
 impl From<TraitImplShim> for ShimTraitImplEntry {
     fn from(shim: TraitImplShim) -> Self {
@@ -174,6 +177,7 @@ impl From<TraitImplShim> for ShimTraitImplEntry {
             spec_params_record: shim.spec_params_record,
             spec_attrs_record: shim.spec_attrs_record,
             spec_subsumption_proof: shim.spec_subsumption_proof,
+            spec_subsumption_statement: shim.spec_subsumption_statement,
         }
     }
 }
@@ -437,6 +441,7 @@ impl<'a> SR<'a> {
                         spec_params_record: b.spec_params_record,
                         spec_attrs_record: b.spec_attrs_record,
                         spec_subsumption_proof: b.spec_subsumption_proof,
+                        spec_subsumption_statement: b.spec_subsumption_statement,
                     };
 
                     self.trait_impl_shims.push(entry);
diff --git a/theories/rust_typing/automation.v b/theories/rust_typing/automation.v
index afb7df7..9bce5a6 100644
--- a/theories/rust_typing/automation.v
+++ b/theories/rust_typing/automation.v
@@ -1083,7 +1083,7 @@ Ltac sidecond_solver :=
 Ltac solve_function_subtype_hook ::=
   rewrite /function_subtype;
   iStartProof;
-  unshelve (repeat liRStep);
+  unshelve (repeat liRStep; solve[fail]);
   unshelve (sidecond_solver);
   sidecond_hammer
 .
@@ -1104,3 +1104,17 @@ Ltac print_remaining_sidecond :=
   | H := FUNCTION_NAME ?s |- _ =>
     print_remaining_shelved_goal s
   end.
+
+(* Prelude for trait incl files *)
+Ltac solve_trait_incl_prelude :=
+  solve_trait_incl_prepare;
+  solve_trait_incl_core;
+  first [
+    rewrite /function_subtype;
+    iStartProof
+  | fast_done].
+Ltac print_remaining_trait_goal :=
+  match goal with
+  | |- _ =>
+  idtac "Type system got stuck while proving trait inclusion"; print_goal; admit
+  end.
diff --git a/theories/rust_typing/automation/solvers.v b/theories/rust_typing/automation/solvers.v
index 4025639..f09b374 100644
--- a/theories/rust_typing/automation/solvers.v
+++ b/theories/rust_typing/automation/solvers.v
@@ -3614,34 +3614,41 @@ Ltac strip_all_applied_params a acc cont :=
       strip_all_applied_params a1 uconstr:(a2 +:: acc) cont
   | _ => cont a acc
   end.
-Ltac solve_trait_incl :=
+Ltac solve_trait_incl_prepare :=
   lazymatch goal with
   | |- trait_incl_marker ?P =>
       rewrite trait_incl_marker_unfold;
       let κs := fresh in let tys := fresh in
       intros κs tys;
       destruct_product_hypothesis κs κs;
-      destruct_product_hypothesis tys tys;
-      (* check if we can decompose the first term *)
-      lazymatch goal with
-      | |- ?incl ?spec1 ?spec2 =>
-        first [
-            decompose_instantiated_spec constr:(spec1) ltac:(fun spec1 spec1_tys spec1_lfts =>
-            (* look for an assumption we can specialize *)
-            is_var spec1;
-            prove_trait_incl_for spec1 spec1_tys spec1_lfts ltac:(fun t1 t2 H2 =>
-              (* TODO: ideally, we should use transitivity instead and then go on *)
-              apply H2
-            ))
-          | (* directly solve the inclusion *)
-            (* first unfold the inclusion *)
-            strip_all_applied_params incl (hnil id) ltac:(fun a _ =>
-              unfold a;
-              intros;
-              split_and?;
-              intros;
-              first [solve_function_subtype | done ]
-            )
-          ]
-      end
-end.
+      destruct_product_hypothesis tys tys
+  end.
+Ltac solve_trait_incl_core :=
+  lazymatch goal with
+    | |- ?incl ?spec1 ?spec2 =>
+      first [
+          decompose_instantiated_spec constr:(spec1) ltac:(fun spec1 spec1_tys spec1_lfts =>
+          (* look for an assumption we can specialize *)
+          is_var spec1;
+          prove_trait_incl_for spec1 spec1_tys spec1_lfts ltac:(fun t1 t2 H2 =>
+            (* TODO: ideally, we should use transitivity instead and then go on *)
+            apply H2
+          ))
+        | (* directly solve the inclusion *)
+          (* first unfold the inclusion *)
+          strip_all_applied_params incl (hnil id) ltac:(fun a _ =>
+            unfold a;
+            intros;
+            split_and?;
+            intros
+          )
+        ]
+    end.
+Ltac solve_trait_incl_fn :=
+  first [solve_function_subtype | done ].
+
+Ltac solve_trait_incl :=
+  solve_trait_incl_prepare;
+  solve_trait_incl_core;
+  try solve_trait_incl_fn.
+
diff --git a/theories/rust_typing/functions.v b/theories/rust_typing/functions.v
index 58b0803..8a6b84b 100644
--- a/theories/rust_typing/functions.v
+++ b/theories/rust_typing/functions.v
@@ -668,8 +668,8 @@ Section function_subsume.
         inhale (((F1 κs tys).(fn_p) a).(fp_fr) a2).(fr_R) π;
         inhale (vr ◁ᵥ{π} (((F1 κs tys).(fn_p) a).(fp_fr) a2).(fr_ref) @ (((F1 κs tys).(fn_p) a).(fp_fr) a2).(fr_ty));
         ∃ b2,
-        exhale (((F2 κs tys).(fn_p) b).(fp_fr) b2).(fr_R) π;
         exhale (vr ◁ᵥ{π} (((F2 κs tys).(fn_p) b).(fp_fr) b2).(fr_ref) @ (((F2 κs tys).(fn_p) b).(fp_fr) b2).(fr_ty));
+        exhale (((F2 κs tys).(fn_p) b).(fp_fr) b2).(fr_R) π;
         done
       | return T
     .
@@ -817,8 +817,8 @@ Section function_subsume.
         inhale (((F1 κs tys).(fn_p) a).(fp_fr) a2).(fr_R) π;
         inhale (vr ◁ᵥ{π} (((F1 κs tys).(fn_p) a).(fp_fr) a2).(fr_ref) @ (((F1 κs tys).(fn_p) a).(fp_fr) a2).(fr_ty));
         ∃ b2,
-        exhale (((F2 κs tys).(fn_p) b).(fp_fr) b2).(fr_R) π;
         exhale (vr ◁ᵥ{π} (((F2 κs tys).(fn_p) b).(fp_fr) b2).(fr_ref) @ (((F2 κs tys).(fn_p) b).(fp_fr) b2).(fr_ty));
+        exhale (((F2 κs tys).(fn_p) b).(fp_fr) b2).(fr_R) π;
         done
     | exhale ⌜l1 = l2⌝; return T.
   Proof.
-- 
GitLab