(* Theory Sepref_Acconstraint *)

theory Sepref_Acconstraint
  imports Refine_Imperative_HOL.IICF Timed_Automata.Timed_Automata
begin

  subsection ‹Refinement Assertion›
  fun acconstraint_assn where
    (* Refinement assertion for acconstraint values. An abstract value and a concrete
       value are related only if both use the same constructor; the first constructor
       argument is related by A, the second by B, and both assertions are combined
       with the separating conjunction (*). *)
  "acconstraint_assn A B (LT x y) (LT x' y') = A x x' * B y y'"
  | "acconstraint_assn A B (LE x y) (LE x' y') = A x x' * B y y'"
  | "acconstraint_assn A B (EQ x y) (EQ x' y') = A x x' * B y y'"
  | "acconstraint_assn A B (GE x y) (GE x' y') = A x x' * B y y'"
  | "acconstraint_assn A B (GT x y) (GT x' y') = A x x' * B y y'"
    (* Values built from different constructors are never related. *)
  | "acconstraint_assn _ _ _ _ = false"


  (* Predicate-level relator for acconstraint: two values are related iff they use the
     same constructor and their arguments are related componentwise by A and B. *)
  fun acconstraint_relp where
    "acconstraint_relp A B (LT x y) (LT x' y') ⟷ A x x' ∧ B y y'"
  | "acconstraint_relp A B (LE x y) (LE x' y') ⟷ A x x' ∧ B y y'"
  | "acconstraint_relp A B (EQ x y) (EQ x' y') ⟷ A x x' ∧ B y y'"
  | "acconstraint_relp A B (GE x y) (GE x' y') ⟷ A x x' ∧ B y y'"
  | "acconstraint_relp A B (GT x y) (GT x' y') ⟷ A x x' ∧ B y y'"
    (* Different constructors are never related. *)
  | "acconstraint_relp _ _ _ _ ⟷ False"

  (* Set-of-pairs version of acconstraint_relp, obtained by converting the component
     relations to predicates (rel2p) and the result back to a relation (p2rel).
     The to_relAPP attribute lets it be written as ⟨A,B⟩acconstraint_rel. *)
  definition [to_relAPP]: "acconstraint_rel A B ≡ p2rel (acconstraint_relp (rel2p A) (rel2p B))"
  
  (* With pure component assertions, acconstraint_assn is itself pure, and its
     underlying relation is acconstraint_rel. *)
  lemma aconstraint_assn_pure_conv[constraint_simps]:
    "acconstraint_assn (pure A) (pure B) ≡ pure (⟨A,B⟩acconstraint_rel)"
  apply (rule eq_reflection)
  apply (intro ext)
  subgoal for a b
  (* Split on both constructors and unfold the relation/predicate conversions. *)
  apply (cases a; cases b; simp add: acconstraint_rel_def pure_def p2rel_def rel2p_def)
  done
  done

  (* The pure-conversion lemma used right to left: pure ⟨A,B⟩acconstraint_rel is
     rewritten to acconstraint_assn (pure A) (pure B). It is registered under the
     listed Sepref rule sets. *)
  lemmas [sepref_import_rewrite, sepref_frame_normrel_eqs, fcomp_norm_unfold] =
    aconstraint_assn_pure_conv[symmetric]

  text ‹You might want to prove some properties›

  text ‹A pure-rule is required to enable recovery of invalidated data that was not stored on the heap.›
  (* acconstraint_assn preserves purity of its component assertions. *)
  lemma acconstraint_assn_pure[constraint_rules]:
    "is_pure A ⟹ is_pure B ⟹ is_pure (acconstraint_assn A B)"
    apply (auto simp: is_pure_iff_pure_assn)
    apply (rename_tac x x')
    (* Split on both constructors; each case reduces by unfolding pure_def. *)
    apply (case_tac x; case_tac x'; simp add: pure_def)
    done

  text ‹An identity rule is required to easily prove trivial refinement theorems.›
  lemma acconstraint_assn_id[simp]: "acconstraint_assn id_assn id_assn = id_assn"
    (* Extensionality in both arguments, then a constructor-wise case split on the
       abstract and concrete value; every resulting case closes by unfolding pure_def. *)
    apply (rule ext)+
    subgoal for a c by (cases a; cases c) (simp_all add: pure_def)
    done

  text ‹With congruence condition›  
  (* Frame-matching congruence: an acconstraint refinement entails one with new component
     assertions, provided the component entailments hold for all contained elements
     (first arguments via set1_acconstraint, second arguments via set2_acconstraint). *)
  lemma acconstraint_match_cong[sepref_frame_match_rules]: 
    "(⋀x y. ⟦x∈set1_acconstraint e; y∈set1_acconstraint e'⟧ ⟹ hn_ctxt A x y ⟹⇩t hn_ctxt A' x y)
    ⟹ (⋀x y. ⟦x∈set2_acconstraint e; y∈set2_acconstraint e'⟧ ⟹ hn_ctxt B x y ⟹⇩t hn_ctxt B' x y)
    ⟹ hn_ctxt (acconstraint_assn A B) e e' ⟹⇩t hn_ctxt (acconstraint_assn A' B') e e'"
    by (cases e; cases e'; simp add: hn_ctxt_def entt_star_mono)
      

  (* Frame-merge congruence: the disjunction of two acconstraint refinements of the same
     values entails one with merged component assertions Am and Bm, given merge
     entailments for the contained elements. *)
  lemma acconstraint_merge_cong[sepref_frame_merge_rules]:
    assumes "⋀x y. ⟦x∈set1_acconstraint e; y∈set1_acconstraint e'⟧ ⟹ hn_ctxt A x y ∨⇩A hn_ctxt A' x y ⟹⇩t hn_ctxt Am x y"
    assumes "⋀x y. ⟦x∈set2_acconstraint e; y∈set2_acconstraint e'⟧ ⟹ hn_ctxt B x y ∨⇩A hn_ctxt B' x y ⟹⇩t hn_ctxt Bm x y"
    shows "hn_ctxt (acconstraint_assn A B) e e' ∨⇩A hn_ctxt (acconstraint_assn A' B') e e' ⟹⇩t hn_ctxt (acconstraint_assn Am Bm) e e'"
    (* Eliminate the disjunction and reduce each side to the match congruence. *)
    apply (blast intro: entt_disjE acconstraint_match_cong entt_disjD1[OF assms(1)] entt_disjD2[OF assms(1)] entt_disjD1[OF assms(2)] entt_disjD2[OF assms(2)])
    done

  text ‹Propagating invalid›  
  (* An invalidated acconstraint refinement entails a refinement whose component
     assertions are the invalidated component assertions. *)
  lemma entt_invalid_acconstraint: "hn_invalid (acconstraint_assn A B) e e' ⟹⇩t hn_ctxt (acconstraint_assn (invalid_assn A) (invalid_assn B)) e e'"
    apply (simp add: hn_ctxt_def invalid_assn_def[abs_def])
    apply (rule enttI)
    apply clarsimp
    apply (cases e; cases e'; auto simp: mod_star_conv pure_def) 
    done

  (* Merge rule for invalidated acconstraint values: gen_merge_cons instantiated with
     entt_invalid_acconstraint, registered as a frame-merge rule. *)
  lemmas invalid_acconstraint_merge[sepref_frame_merge_rules] = gen_merge_cons[OF entt_invalid_acconstraint]

  subsection ‹Constructors›  
  text ‹Constructors need to be registered›
  (* All five acconstraint constructors, listed in the branch order of case_acconstraint. *)
  sepref_register LT LE EQ GT GE
  
  text ‹Refinement rules can be proven straightforwardly on the separation logic level (method @{method sepref_to_hoare})›
  (*
  lemma [sepref_fr_rules]: "(return o E1,RETURN o E1) ∈ Ada acconstraint_assn A"
    by sepref_to_hoare sep_auto
  lemma [sepref_fr_rules]: "(return o E2,RETURN o E2) ∈ Ada acconstraint_assn A"
    by sepref_to_hoare sep_auto
  lemma [sepref_fr_rules]: "(uncurry0 (return E3),uncurry0 (RETURN E3)) ∈ unit_assnka acconstraint_assn A"
    by sepref_to_hoare sep_auto
  *)
  (* Each constructor is implemented by applying the same constructor to the concrete
     arguments; both argument refinements are consumed (⇧d). *)
  lemma [sepref_fr_rules]: "(uncurry (return oo LT),uncurry (RETURN oo LT)) ∈ A⇧d *⇩a B⇧d →⇩a acconstraint_assn A B"
    by sepref_to_hoare sep_auto
  lemma [sepref_fr_rules]: "(uncurry (return oo LE),uncurry (RETURN oo LE)) ∈ A⇧d *⇩a B⇧d →⇩a acconstraint_assn A B"
    by sepref_to_hoare sep_auto
  lemma [sepref_fr_rules]: "(uncurry (return oo EQ),uncurry (RETURN oo EQ)) ∈ A⇧d *⇩a B⇧d →⇩a acconstraint_assn A B"
    by sepref_to_hoare sep_auto
  lemma [sepref_fr_rules]: "(uncurry (return oo GE),uncurry (RETURN oo GE)) ∈ A⇧d *⇩a B⇧d →⇩a acconstraint_assn A B"
    by sepref_to_hoare sep_auto
  lemma [sepref_fr_rules]: "(uncurry (return oo GT),uncurry (RETURN oo GT)) ∈ A⇧d *⇩a B⇧d →⇩a acconstraint_assn A B"
    by sepref_to_hoare sep_auto
  (*
  lemma [sepref_fr_rules]: "(uncurry (return oo E4),uncurry (RETURN oo E4)) ∈ Ad*aAda acconstraint_assn A"
    by sepref_to_hoare sep_auto
  *)
  (*
  lemma [sepref_fr_rules]: "(uncurry (return oo E5),uncurry (RETURN oo E5)) ∈ bool_assnk*aAda acconstraint_assn A"
    by sepref_to_hoare (sep_auto simp: pure_def)
  *)

  subsection ‹Destructor›  
  text ‹There is currently no automation for destructors, so all the registration boilerplate 
    needs to be done manually›

  text ‹Sets up the operation-identification heuristics.›
  (* Register the case combinator of acconstraint with Sepref. *)
  sepref_register case_acconstraint

  text ‹In the monadify phase, this rule eta-expands the case combinator so that all required arguments become visible.›
  (* Every acconstraint constructor carries two arguments, so every branch function
     f1..f5 is eta-expanded as a binary function. This matches the binary branch
     patterns (λ⇩2x y. fi x y) expected by the monadify-comb and case-hnr rules. *)
  lemma [sepref_monadify_arity]: "case_acconstraint ≡ λ⇩2f1 f2 f3 f4 f5 x. SP case_acconstraint$(λ⇩2x y. f1$x$y)$(λ⇩2x y. f2$x$y)$(λ⇩2x y. f3$x$y)$(λ⇩2x y. f4$x$y)$(λ⇩2x y. f5$x$y)$x"
    by simp

  text ‹This determines an evaluation order for the first-order operands›  
  (* The scrutinee x is evaluated first and bound before the case distinction. *)
  lemma [sepref_monadify_comb]: "case_acconstraint$f1$f2$f3$f4$f5$x ≡ (⤜)$(EVAL$x)$(λ⇩2x. SP case_acconstraint$f1$f2$f3$f4$f5$x)" by simp

  text ‹This enables translation of the case-distinction in a non-monadic context.›  
  (* Pushes EVAL into every (binary) branch of a non-monadic case distinction, after
     evaluating and binding the scrutinee. *)
  lemma [sepref_monadify_comb]: "EVAL$(case_acconstraint$(λ⇩2x y. f1 x y)$(λ⇩2x y. f2 x y)$(λ⇩2x y. f3 x y)$(λ⇩2x y. f4 x y)$(λ⇩2x y. f5 x y)$x) 
    ≡ (⤜)$(EVAL$x)$(λ⇩2x. SP case_acconstraint$(λ⇩2x y. EVAL $ f1 x y)$(λ⇩2x y. EVAL $ f2 x y)$(λ⇩2x y. EVAL $ f3 x y)$(λ⇩2x y. EVAL $ f4 x y)$(λ⇩2x y. EVAL $ f5 x y)$x)"
    apply (rule eq_reflection)
    by (simp split: acconstraint.splits)

  text ‹Auxiliary lemma to lift a simp rule over hn_ctxt.›
  (* Turns an equation about acconstraint_assn into the same equation under hn_ctxt. *)
  lemma acconstraint_assn_ctxt: "acconstraint_assn A B x y = z ⟹ hn_ctxt (acconstraint_assn A B) x y = z"
    by (simp add: hn_ctxt_def)

  text ‹The cases lemma first extracts the refinement for the datatype from the precondition.
    Next, it generates proof obligations to refine the functions for every case.
    Finally the postconditions of the refinement are merged. 

    Note that we handle the
    destructed values separately, to allow reconstruction of the original datatype after the case-expression.

    Moreover, we provide (invalidated) versions of the original compound value to the cases,
    which allows access to pure compound values from inside the case.
    ›  
  lemma acconstraint_cases_hnr:
    fixes A B and e :: "('a,'b) acconstraint" and e' :: "('ai,'bi) acconstraint"
    (* Invalidated copy of the compound value, made available inside every branch. *)
    defines [simp]: "INVe ≡ hn_invalid (acconstraint_assn A B) e e'"
    (* The precondition must contain the refinement of the scrutinee. *)
    assumes FR: "Γ ⟹⇩t hn_ctxt (acconstraint_assn A B) e e' * F"
    (* One refinement obligation per constructor; both destructed arguments are
       passed separately, together with the invalidated compound value. *)
    assumes LT: "⋀x41 x42 x41a x42a.
       ⟦e = LT x41 x42; e' = LT x41a x42a⟧
       ⟹ hn_refine
            (hn_ctxt A x41 x41a * hn_ctxt B x42 x42a * INVe * F)
            (f1' x41a x42a)
            (hn_ctxt A1a' x41 x41a * hn_ctxt B1b' x42 x42a * hn_ctxt XX1 e e' * Γ1') R
            (f1 x41 x42)"
    assumes LE: "⋀x41 x42 x41a x42a.
       ⟦e = LE x41 x42; e' = LE x41a x42a⟧
       ⟹ hn_refine
            (hn_ctxt A x41 x41a * hn_ctxt B x42 x42a * INVe * F)
            (f2' x41a x42a)
            (hn_ctxt A2a' x41 x41a * hn_ctxt B2b' x42 x42a * hn_ctxt XX2 e e' * Γ2') R
            (f2 x41 x42)"
    assumes EQ: "⋀x41 x42 x41a x42a.
       ⟦e = EQ x41 x42; e' = EQ x41a x42a⟧
       ⟹ hn_refine
            (hn_ctxt A x41 x41a * hn_ctxt B x42 x42a * INVe * F)
            (f3' x41a x42a)
            (hn_ctxt A3a' x41 x41a * hn_ctxt B3b' x42 x42a * hn_ctxt XX3 e e' * Γ3') R
            (f3 x41 x42)"
    assumes GE: "⋀x41 x42 x41a x42a.
       ⟦e = GE x41 x42; e' = GE x41a x42a⟧
       ⟹ hn_refine
            (hn_ctxt A x41 x41a * hn_ctxt B x42 x42a * INVe * F)
            (f4' x41a x42a)
            (hn_ctxt A4a' x41 x41a * hn_ctxt B4b' x42 x42a * hn_ctxt XX4 e e' * Γ4') R
            (f4 x41 x42)"
    assumes GT: "⋀x41 x42 x41a x42a.
       ⟦e = GT x41 x42; e' = GT x41a x42a⟧
       ⟹ hn_refine
            (hn_ctxt A x41 x41a * hn_ctxt B x42 x42a * INVe * F)
            (f5' x41a x42a)
            (hn_ctxt A5a' x41 x41a * hn_ctxt B5b' x42 x42a * hn_ctxt XX5 e e' * Γ5') R
            (f5 x41 x42)"
    (* Merging of the per-branch postconditions: first arguments, second arguments,
       and the remaining frames. *)
    assumes MERGE1a[unfolded hn_ctxt_def]:
      "⋀x x'. hn_ctxt A1a' x x' ∨⇩A hn_ctxt A2a' x x' ∨⇩A hn_ctxt A3a' x x' ∨⇩A hn_ctxt A4a' x x' ∨⇩A hn_ctxt A5a' x x' ⟹⇩t hn_ctxt A' x x'"
    assumes MERGE1b[unfolded hn_ctxt_def]:
      "⋀x x'. hn_ctxt B1b' x x' ∨⇩A hn_ctxt B2b' x x' ∨⇩A hn_ctxt B3b' x x' ∨⇩A hn_ctxt B4b' x x' ∨⇩A hn_ctxt B5b' x x' ⟹⇩t hn_ctxt B' x x'"
    assumes MERGE2[unfolded hn_ctxt_def]: "Γ1' ∨⇩A Γ2' ∨⇩A Γ3' ∨⇩A Γ4' ∨⇩A Γ5' ⟹⇩t Γ'"
    (* case_acconstraint takes its branches in constructor order, where the fourth
       branch is GT and the fifth is GE (cf. the subgoal order in the proof below);
       hence f5/f5' precede f4/f4' here. *)
    shows "hn_refine Γ (case_acconstraint f1' f2' f3' f5' f4' e') (hn_ctxt (acconstraint_assn A' B') e e' * Γ') R (case_acconstraint$(λ⇩2x y. f1 x y)$(λ⇩2x y. f2 x y)$(λ⇩2x y. f3 x y)$(λ⇩2x y. f5 x y)$(λ⇩2x y. f4 x y)$e)"
    
    apply (rule hn_refine_cons_pre[OF FR])
    apply1 extract_hnr_invalids
    (* Only the five matching-constructor combinations remain after this step. *)
    apply (cases e; cases e'; simp add: acconstraint_assn.simps[THEN acconstraint_assn_ctxt])
    subgoal 
      apply (rule hn_refine_cons[OF _ LT _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      apply (rule entt_star_mono)

      apply1 (rule entt_trans[OF _ MERGE1a])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE1b])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE2])
      applyS (simp add: entt_disjI1' entt_disjI2')
    done
    subgoal 
      apply (rule hn_refine_cons[OF _ LE _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      apply (rule entt_star_mono)

      apply1 (rule entt_trans[OF _ MERGE1a])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE1b])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE2])
      applyS (simp add: entt_disjI1' entt_disjI2')
    done
    subgoal 
      apply (rule hn_refine_cons[OF _ EQ _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      apply (rule entt_star_mono)

      apply1 (rule entt_trans[OF _ MERGE1a])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE1b])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE2])
      applyS (simp add: entt_disjI1' entt_disjI2')
    done
    subgoal
      apply (rule hn_refine_cons[OF _ GT _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      apply (rule entt_star_mono)

      apply1 (rule entt_trans[OF _ MERGE1a])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE1b])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE2])
      applyS (simp add: entt_disjI1' entt_disjI2')
    done
    subgoal 
      apply (rule hn_refine_cons[OF _ GE _ entt_refl]; assumption?)
      applyS (simp add: hn_ctxt_def)
      apply (rule entt_star_mono)
      apply1 (rule entt_fr_drop)
      apply (rule entt_star_mono)

      apply1 (rule entt_trans[OF _ MERGE1a])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE1b])
      applyS (simp add: hn_ctxt_def entt_disjI1' entt_disjI2')

      apply1 (rule entt_trans[OF _ MERGE2])
      applyS (simp add: entt_disjI1' entt_disjI2')
    done
  done

  text ‹After some more preprocessing (adding extra frame-rules for non-atomic postconditions, 
    and splitting the merge-terms into binary merges), this rule can be registered›
  (* Register the preprocessed case rule (via sepref_prep_comb_rule) as a Sepref
     combinator rule for case_acconstraint. *)
  lemmas [sepref_comb_rules] = acconstraint_cases_hnr[sepref_prep_comb_rule]

  subsection ‹Regression Test›
(*

  definition "test ≡ do {
    let x = E1 True;

    _ ← case x of
      E1 _ ⇒ RETURN x  (* Access compound inside case *)
    | _ ⇒ RETURN E3;  

    (* Now test with non-pure *)
    let a = op_array_replicate 4 (3::nat);
    let x = E5 False a;
    
    _ ← case x of
      E1 _ ⇒ RETURN (0::nat)
    | E2 _ ⇒ RETURN 1
    | E3 ⇒ RETURN 0
    | E4 _ _ ⇒ RETURN 0
    | E5 _ a ⇒ mop_list_get a 0;

    (* Rely on that compound still exists (it's components are only read in the case above) *)
    case x of
      E1 a ⇒ do {mop_list_set a 0 0; RETURN (0::nat)}
    | E2 _ ⇒ RETURN 1
    | E3 ⇒ RETURN 0
    | E4 _ _ ⇒ RETURN 0
    | E5 _ _ ⇒ RETURN 0
  }"

  sepref_definition foo is "SYNTH (uncurry0 test) (unit_assnka nat_assn)"      
    unfolding test_def
    supply [[goals_limit=1]]
    by sepref
*)

end