(* Theory Refine_Monadic.RefineG_Recursion *)

section ‹Generic Recursion Combinator for Complete Lattice Structured Domains›
theory RefineG_Recursion
imports "../Refine_Misc" RefineG_Transfer RefineG_Domain
begin

text ‹
  We define recursion combinators (for partial and for total correctness) that
  assert monotonicity: if the body is not monotone, the result is top.
›


(* TODO: Move to Domain.*)
text ‹
  The following lemma allows comparing least fixed points wrt.\ different flat
  orderings. At any point, the fixed points are either equal, or each of them
  is the bottom element of its respective ordering.
›
(* Compares the fixed points of the same functional B taken wrt. two flat
   orderings with distinguished (bottom) elements b1 and b2.
   Proof idea: each fixed point is trivially a pre-fixed point for the other
   ordering, so fixp_lowerbound bounds each fixed point by the other one. *)
lemma fp_compare:
  ― ‹At any point, fixed points wrt.\ different orderings are either equal, 
    or both bottom.›
  assumes M1: "flatf_mono b1 B" and M2: "flatf_mono b2 B"
  shows "flatf_fp b1 B x = flatf_fp b2 B x 
     ∨ (flatf_fp b1 B x = b1 ∧ flatf_fp b2 B x = b2)"
proof -
  note UNF1 = flatf_ord.fixp_unfold[OF M1, symmetric]
  note UNF2 = flatf_ord.fixp_unfold[OF M2, symmetric]

  (* Rewriting with the unfold equations reduces these to reflexivity of the
     respective *other* ordering. *)
  from UNF1 have 1: "flatf_ord b2 (B (flatf_fp b1 B)) (flatf_fp b1 B)" by simp
  from UNF2 have 2: "flatf_ord b1 (B (flatf_fp b2 B)) (flatf_fp b2 B)" by simp

  (* Each fixed point is below the other one (in the other's ordering);
     unfolding the flat orderings yields the disjunction. *)
  from flatf_ord.fixp_lowerbound[OF M2 1] flatf_ord.fixp_lowerbound[OF M1 2]
    show ?thesis unfolding fun_ord_def flat_ord_def by auto
qed

(* TODO: Move *)
(* In the flat ordering flat_ord b, the distinguished element b is related to
   every element x. Registered as a simp rule. *)
lemma flat_ord_top[simp]: "flat_ord b b x"
  unfolding flat_ord_def by simp


(* TODO: Move to Domain.*)
(* Special case of fp_compare for the two flat orderings flat_le / flat_ge:
   pointwise, the flat least and greatest fixed points coincide, unless they
   are bot and top respectively. Follows directly from fp_compare. *)
lemma lfp_gfp_compare:
  ― ‹Least and greatest fixed point are either equal, or bot and top›
  assumes MLE: "flatf_mono_le B" and MGE: "flatf_mono_ge B"
  shows "flatf_lfp B x = flatf_gfp B x 
     ∨ (flatf_lfp B x = bot ∧ flatf_gfp B x = top)"
  using fp_compare[OF MLE MGE] .


(* TODO: Move to Domain *)
(* trimono B: B is monotone both wrt. the pointwise flat_ge ordering and wrt.
   the ordinary (pointwise) ≤ ordering. The flatf_mono_le conjunct is
   deliberately disabled (⌦). This is the side condition required by the
   recursion combinators REC and RECT. *)
definition trimono :: "(('a ⇒ 'b) ⇒ 'a ⇒ ('b::{bot,order,top})) ⇒ bool" 
  where "trimono B ≡ ⌦‹flatf_mono_le B ∧› flatf_mono_ge B ∧ mono B"
(* Introduction rule for trimono, registered with the monotonicity prover. *)
lemma trimonoI[refine_mono]: 
  "⟦flatf_mono_ge B; mono B⟧ ⟹ trimono B"
  unfolding trimono_def by auto

(* Trivial rule whose only purpose is to declare trimono goals as triggers
   for the monotonicity prover (see the declaration below). *)
lemma trimono_trigger: "trimono B ⟹ trimono B" .

declaration ‹Refine_Mono_Prover.declare_mono_triggers @{thms trimono_trigger}›

(*lemma trimonoD_flatf_le: "trimono B ⟹ flatf_mono_le B"
  unfolding trimono_def by auto*)

(* Destruction rules: projections of the two conjuncts of trimono. *)
lemma trimonoD_flatf_ge: "trimono B ⟹ flatf_mono_ge B"
  unfolding trimono_def by auto

lemma trimonoD_mono: "trimono B ⟹ mono B"
  unfolding trimono_def by auto

lemmas trimonoD = trimonoD_flatf_ge trimonoD_mono

(* TODO: Optimize mono-prover to only do derivations once. 
  Will cause problem with higher-order unification on ord - variable! *)
(* The orderings that trimono quantifies over: flat_ge and ≤. *)
definition "triords ≡ {flat_ge,(≤)}"

(* trimono as monotonicity wrt. the pointwise lifting of every ordering
   in triords. *)
lemma trimono_alt: 
  "trimono B ⟷ (∀ord∈fun_ord`triords. monotone ord ord B)"
  unfolding trimono_def
  by (auto simp: triords_def fun_ord_def[abs_def] le_fun_def[abs_def])

(* Uniform introduction rule: prove monotonicity once, for an arbitrary
   ordering from triords. *)
lemma trimonoI': 
  assumes "⋀ord. ord∈triords ⟹ monotone (fun_ord ord) (fun_ord ord) B"
  shows "trimono B"
  unfolding trimono_alt using assms by blast


(* TODO: Once complete_lattice and ccpo typeclass are unified,
  we should also define a REC-combinator for ccpos! *)

(* Recursion combinators. Both yield top if the body B is not trimono.
   REC:  partial correctness; the least fixed point (lfp) of B, applied to x.
   RECT: total correctness; the fixed point of B wrt. the flat ordering
         flat_ge (flatf_gfp), applied to x. *)
definition REC where "REC B x ≡ 
  if (trimono B) then (lfp B x) else (top::'a::complete_lattice)"
definition RECT ("REC⇩T") where "RECT B x ≡ 
  if (trimono B) then (flatf_gfp B x) else (top::'a::complete_lattice)"

(* Alternative characterization of RECT via the lattice greatest fixed point
   gfp. It is obtained from gfp_eq_flatf_gfp, using the two monotonicity
   conditions contained in trimono. *)
lemma RECT_gfp_def: "RECT B x = 
  (if (trimono B) then (gfp B x) else (top::'a::complete_lattice))"
  unfolding RECT_def
  by (auto simp: gfp_eq_flatf_gfp[OF trimonoD_flatf_ge trimonoD_mono])

(* Fixed-point equation for REC on a trimono body (via lfp_unfold). *)
lemma REC_unfold: "trimono B ⟹ REC B = B (REC B)"
  unfolding REC_def [abs_def]
  by (simp add: lfp_unfold[OF trimonoD_mono, symmetric])

(* Fixed-point equation for RECT on a trimono body (via the flat-ordering
   fixp_unfold). *)
lemma RECT_unfold: "trimono B ⟹ RECT B = B (RECT B)"
  unfolding RECT_def [abs_def]
  by (simp add: flatf_ord.fixp_unfold[OF trimonoD_flatf_ge, symmetric])

(* Monotonicity of REC in the body wrt. ≤. Only the smaller body B must be
   trimono: if B' is not, REC B' x is top. *)
lemma REC_mono[refine_mono]:
  assumes [simp]: "trimono B"
  assumes LE: "⋀F x. (B F x) ≤ (B' F x)"
  shows "(REC B x) ≤ (REC B' x)"
  unfolding REC_def
  apply clarsimp
  apply (rule lfp_mono[THEN le_funD])
  apply (rule LE[THEN le_funI])
  done

(* Monotonicity of RECT in the body wrt. flat_ge. Only B' must be trimono:
   if B is not, RECT B x is top. *)
lemma RECT_mono[refine_mono]:
  assumes [simp]: "trimono B'"
  assumes LE: "⋀F x. flat_ge (B F x) (B' F x)"
  shows "flat_ge (RECT B x) (RECT B' x)"
  unfolding RECT_def
  apply clarsimp
  apply (rule flatf_fp_mono, (simp_all add: trimonoD) [2])
  apply (rule LE)
  done

(* Partial-correctness recursion refines total-correctness recursion.
   The trimono case follows from lfp ≤ gfp; otherwise both sides are top. *)
lemma REC_le_RECT: "REC body x ≤ RECT body x"
  unfolding REC_def RECT_gfp_def
  apply (cases "trimono body")
  apply clarsimp
  apply (rule lfp_le_gfp[THEN le_funD])
  apply (simp add: trimonoD)
  apply simp
  done

(* Pointwise induction for lfp with a precondition/postcondition pair that is
   parameterized by an additional argument a.
   ADM1: admissibility of the invariant; ADM2: base case at bot;
   IS: induction step, which may additionally assume f ≤ lfp B. *)
theorem lfp_induct_pointwise:
  fixes a::'a
  assumes ADM1: "⋀a x. chain_admissible (λb. ∀a x. pre a x ⟶ post a x (b x))"
  assumes ADM2: "⋀a x. pre a x ⟹ post a x bot"
  assumes MONO: "mono B"
  assumes P0: "pre a x"
  assumes IS: 
    "⋀f a x.
        ⟦ ⋀a' x'. pre a' x' ⟹ post a' x' (f x'); pre a x;
          f ≤ (lfp B)
        ⟧ ⟹ post a x (B f x)"
  shows "post a x (lfp B x)"
proof -
  define u where "u = lfp B"

  have [simp]: "⋀f. f≤lfp B ⟹ B f ≤ lfp B"
    by (metis (poly_guards_query) MONO lfp_unfold monoD)

  (* Fixed-point induction on a strengthened invariant: the second conjunct
     (lfp B ≤ u, i.e. below lfp B) makes the premise f ≤ lfp B of IS
     available in the step. *)
  have "(∀a x. pre a x ⟶ post a x (lfp B x)) ∧ lfp B ≤ u"
    apply (rule lfp_cadm_induct[where f=B])
    apply (rule admissible_conj)
    apply (rule ADM1)
    apply (rule)
    apply (blast intro: Sup_least)
    apply (simp add: le_fun_def ADM2) []
    apply fact
    apply (intro conjI allI impI)
    unfolding u_def
    apply (blast intro: IS)
    apply simp
    done
  with P0 show ?thesis by blast
qed


(* Partial-correctness proof rule for REC with an extra parameter arb that may
   change between recursive calls. The step IS may assume the specification
   for all recursive calls (any arb', any x satisfying pre) and f ≤ REC body.
   Proved via lfp_induct_pointwise. *)
lemma REC_rule_arb:
  fixes x::"'x" and arb::'arb
  assumes M: "trimono body"
  assumes I0: "pre arb x"
  assumes IS: "⋀f arb x. ⟦
    ⋀arb' x. pre arb' x ⟹ f x ≤ M arb' x; pre arb x; f ≤ REC body
  ⟧ ⟹ body f x ≤ M arb x"
  shows "REC body x ≤ M arb x"
  unfolding REC_def
  apply (clarsimp simp: M)
  apply (rule lfp_induct_pointwise[where pre=pre])
  apply (auto intro!: chain_admissibleI SUP_least) [2]
  apply (simp add: trimonoD[OF M])
  apply (rule I0)
  apply (rule IS, assumption+)
  apply (auto simp: REC_def[abs_def] intro!: le_funI dest: le_funD) []
  done

(* Total-correctness proof rule for RECT with an extra parameter arb.
   Recursive calls may only be assumed correct for arguments x' with
   (x',x) ∈ V, where V is well-founded; the step may also assume f = RECT body.
   Proved by well-founded fixed-point induction (wf_fixp_induct). *)
lemma RECT_rule_arb:
  assumes M: "trimono body"
  assumes WF: "wf (V::('x×'x) set)"
  assumes I0: "pre (arb::'arb) (x::'x)"
  assumes IS: "⋀f arb x.  ⟦
      ⋀arb' x'. ⟦pre arb' x'; (x',x)∈V⟧ ⟹ f x' ≤ M arb' x'; 
      pre arb x;
      RECT body = f
    ⟧  ⟹ body f x ≤ M arb x"
  shows "RECT body x ≤ M arb x"
  apply (rule wf_fixp_induct[where fp=RECT and pre=pre and B=body])
  apply (rule RECT_unfold)
  apply (simp_all add: M) [2]
  apply (rule WF)
  apply fact
  apply (rule IS)
  apply assumption
  apply assumption
  apply assumption
  done

(* Partial-correctness proof rule for REC without extra parameter:
   REC_rule_arb with pre and M ignoring arb. *)
lemma REC_rule:
  fixes x::"'x"
  assumes M: "trimono body"
  assumes I0: "pre x"
  assumes IS: "⋀f x. ⟦ ⋀x. pre x ⟹ f x ≤ M x; pre x; f ≤ REC body ⟧ 
    ⟹ body f x ≤ M x"
  shows "REC body x ≤ M x"
  by (rule REC_rule_arb[where pre="λ_. pre" and M="λ_. M", OF assms])
    
    
(* Total-correctness proof rule for RECT without extra parameter:
   RECT_rule_arb with pre and M ignoring arb. *)
lemma RECT_rule:
  assumes M: "trimono body"
  assumes WF: "wf (V::('x×'x) set)"
  assumes I0: "pre (x::'x)"
  assumes IS: "⋀f x. ⟦ ⋀x'. ⟦pre x'; (x',x)∈V⟧ ⟹ f x' ≤ M x'; pre x; 
                        RECT body = f
    ⟧ ⟹ body f x ≤ M x"
  shows "RECT body x ≤ M x"
  by (rule RECT_rule_arb[where pre="λ_. pre" and M="λ_. M", OF assms])




(* TODO: Can we set-up induction method to work with such goals? *)
(* REC_rule_arb with two extra parameters, obtained by tupling them into the
   single parameter of REC_rule_arb (case_prod).
   Unlike REC_rule_arb, the step IS here does not receive the premise
   f ≤ REC body. *)
lemma REC_rule_arb2:
  assumes M: "trimono body"
  assumes I0: "pre (arb::'arb) (arc::'arc) (x::'x)"
  assumes IS: "⋀f arb arc x.  ⟦
      ⋀arb' arc' x'. ⟦pre arb' arc' x' ⟧ ⟹ f x' ≤ M arb' arc' x'; 
      pre arb arc x
    ⟧ ⟹ body f x ≤ M arb arc x"
  shows "REC body x ≤ M arb arc x"
  apply (rule order_trans)
  apply (rule REC_rule_arb[
    where pre="case_prod pre" and M="case_prod M" and arb="(arb, arc)", 
    OF M])
  by (auto intro: assms)

(* REC_rule_arb with three extra parameters, reduced to REC_rule_arb2 by
   tupling the first two. *)
lemma REC_rule_arb3:
  assumes M: "trimono body"
  assumes I0: "pre (arb::'arb) (arc::'arc) (ard::'ard) (x::'x)"
  assumes IS: "⋀f arb arc ard x.  ⟦
      ⋀arb' arc' ard' x'. ⟦pre arb' arc' ard' x'⟧ ⟹ f x' ≤ M arb' arc' ard' x';
      pre arb arc ard x 
    ⟧ ⟹ body f x ≤ M arb arc ard x"
  shows "REC body x ≤ M arb arc ard x"
  apply (rule order_trans)
  apply (rule REC_rule_arb2[
    where pre="case_prod pre" and M="case_prod M" and arb="(arb, arc)" and arc="ard", 
    OF M])
  by (auto intro: assms)

(* RECT_rule_arb with two extra parameters, obtained by tupling them into the
   single parameter of RECT_rule_arb (case_prod). *)
lemma RECT_rule_arb2:
  assumes M: "trimono body"
  assumes WF: "wf (V::'x rel)"
  assumes I0: "pre (arb::'arb) (arc::'arc) (x::'x)"
  assumes IS: "⋀f arb arc x.  ⟦
      ⋀arb' arc' x'. ⟦pre arb' arc' x'; (x',x)∈V⟧ ⟹ f x' ≤ M arb' arc' x'; 
      pre arb arc x;
      f ≤ RECT body
    ⟧ ⟹ body f x ≤ M arb arc x"
  shows "RECT body x ≤ M arb arc x"
  apply (rule order_trans)
  apply (rule RECT_rule_arb[
    where pre="case_prod pre" and M="case_prod M" and arb="(arb, arc)", 
    OF M WF])
  by (auto intro: assms)

(* RECT_rule_arb with three extra parameters, reduced to RECT_rule_arb2 by
   tupling the first two. *)
lemma RECT_rule_arb3:
  assumes M: "trimono body"
  assumes WF: "wf (V::'x rel)"
  assumes I0: "pre (arb::'arb) (arc::'arc) (ard::'ard) (x::'x)"
  assumes IS: "⋀f arb arc ard x.  ⟦
      ⋀arb' arc' ard' x'. ⟦pre arb' arc' ard' x'; (x',x)∈V⟧ ⟹ f x' ≤ M arb' arc' ard' x'; 
    pre arb arc ard x;
    f ≤ RECT body
  ⟧ ⟹ body f x ≤ M arb arc ard x"
  shows "RECT body x ≤ M arb arc ard x"
  apply (rule order_trans)
  apply (rule RECT_rule_arb2[
    where pre="case_prod pre" and M="case_prod M" and arb="(arb, arc)" and arc="ard", 
    OF M WF])
  by (auto intro: assms)



(* Obsolete, provide a variant to show nofail.
text {* The following lemma shows that greatest and least fixed point are equal,
  if we can provide a variant. *}
lemma RECT_eq_REC:
  assumes MONO: "flatf_mono_le body"
  assumes MONO_GE: "flatf_mono_ge body"
  assumes WF: "wf V"
  assumes I0: "I x"
  assumes IS: "⋀f x. I x ⟹ 
    body (λx'. if (I x' ∧ (x',x)∈V) then f x' else top) x ≤ body f x"
  shows "RECT body x = REC body x"
  unfolding RECT_def REC_def 
proof (simp add: MONO MONO_GE)
  have "I x ⟶ flatf_gfp body x ≤ flatf_lfp body x"
    using WF
    apply (induct rule: wf_induct_rule)
    apply (rule impI)
    apply (subst flatf_ord.fixp_unfold[OF MONO])
    apply (subst flatf_ord.fixp_unfold[OF MONO_GE])
    apply (rule order_trans[OF _ IS])
    apply (rule monoD[OF MONO,THEN le_funD])
    apply (rule le_funI)
    apply simp
    apply simp
    done
  


  from lfp_le_gfp' MONO have "lfp body x ≤ gfp body x" .
  moreover have "I x ⟶ gfp body x ≤ lfp body x"
    using WF
    apply (induct rule: wf_induct[consumes 1])
    apply (rule impI)
    apply (subst lfp_unfold[OF MONO])
    apply (subst gfp_unfold[OF MONO])
    apply (rule order_trans[OF _ IS])
    apply (rule monoD[OF MONO,THEN le_funD])
    apply (rule le_funI)
    apply simp
    apply simp
    done
  ultimately show ?thesis
    unfolding REC_def RECT_def gfp_eq_flatf_gfp[OF MONO_GE MONO, symmetric]
    apply (rule_tac antisym)
    using I0 MONO MONO_GE by auto
qed
*)

(* If RECT does not yield top at x, it coincides with REC at x.
   In the trimono case, lfp body is a flat_ge-pre-fixed point, so
   flatf_gfp body lies below it wrt. flat_ge. Pointwise this means the two
   are equal or flatf_gfp body x = top, which NT excludes. *)
lemma RECT_eq_REC: 
  ― ‹Partial- and total-correctness recursion are equal if total-correctness
    recursion does not fail.›
  assumes NT: "RECT body x ≠ top"
  shows "RECT body x = REC body x"
proof (cases "trimono body")
  case M: True 
  show ?thesis
    using NT M
    unfolding RECT_def REC_def
  proof clarsimp
    from lfp_unfold[OF trimonoD_mono[OF M], symmetric]
    have "flatf_ge (body (lfp body)) (lfp body)" by simp
    note flatf_ord.fixp_lowerbound[
      OF trimonoD_flatf_ge[OF M], of "lfp body", OF this]
    moreover assume "flatf_gfp body x ≠ top"
    ultimately show "flatf_gfp body x = lfp body x"
      by (auto simp add: fun_ord_def flat_ord_def)
  qed
next
  case False thus ?thesis unfolding RECT_def REC_def by auto
qed

(* A total-correctness proof (premises of RECT_rule_arb) against a
   specification M a x ≠ top yields both RECT body x ≤ M a x and, via
   RECT_eq_REC, RECT body x = REC body x. *)
lemma RECT_eq_REC_tproof:
  ― ‹Partial- and total-correctness recursion are equal if we can provide a 
    termination proof.›
  fixes a :: 'a
  assumes M: "trimono body"
  assumes WF: "wf V"
  assumes I0: "pre a x"
  assumes IS: "⋀f arb x.
          ⟦⋀arb' x'. ⟦pre arb' x'; (x', x) ∈ V⟧ ⟹ f x' ≤ M arb' x'; 
            pre arb x; RECT body = f⟧
           ⟹ body f x ≤ M arb x"
  assumes NT: "M a x ≠ top"
  shows "RECT body x = REC body x ∧ RECT body x ≤ M a x"
proof
  show "RECT body x ≤ M a x"
    by (rule RECT_rule_arb[OF M WF, where pre=pre, OF I0 IS])
  
  with NT have "RECT body x ≠ top" by (metis top.extremum_unique)
  thus "RECT body x = REC body x" by (rule RECT_eq_REC)
qed


subsection ‹Transfer›

(* Transfer for RECT: any fr satisfying the recursion equation fr = b fr
   pointwise, where b is related to B by REF, abstracts (via α) to below
   RECT B. The non-trimono case is trivial because RECT B x is then top.
   The trimono case uses flatf_fixp_transfer. *)
lemma (in transfer) transfer_RECT'[refine_transfer]:
  assumes REC_EQ: "⋀x. fr x = b fr x"
  assumes REF: "⋀F f x. ⟦⋀x. α (f x) ≤ F x ⟧ ⟹ α (b f x) ≤ B F x"
  shows "α (fr x) ≤ RECT B x"
  unfolding RECT_def
proof clarsimp
  assume MONO: "trimono B"
  show "α (fr x) ≤ flatf_gfp B x"
    apply (rule flatf_fixp_transfer[where B=B and fp'=fr and P="(=)", 
        OF _ trimonoD_flatf_ge[OF MONO]])
    apply simp
    apply (rule ext, fact)
    apply (simp)
    apply (simp,rule REF, blast)    
    done
qed

(* Transfer of RECT b to RECT B: instance of transfer_RECT' with
   fr := RECT b, whose recursion equation is RECT_unfold. *)
lemma (in ordered_transfer) transfer_RECT[refine_transfer]:
  assumes REF: "⋀F f x. ⟦⋀x. α (f x) ≤ F x ⟧ ⟹ α (b f x) ≤ B F x"
  assumes M: "trimono b"
  shows "α (RECT b x) ≤ RECT B x"
  apply (rule transfer_RECT')
  apply (rule RECT_unfold[OF M, THEN fun_cong])
  by fact

(* Transfer of REC b to REC B. Proved by lfp_induct_pointwise with
   pre := (=). Admissibility is discharged using α_dist from the
   dist_transfer context. *)
lemma (in dist_transfer) transfer_REC[refine_transfer]:
  assumes REF: "⋀F f x. ⟦⋀x. α (f x) ≤ F x ⟧ ⟹ α (b f x) ≤ B F x"
  assumes M: "trimono b"
  shows "α (REC b x) ≤ REC B x"
  unfolding REC_def
  (* TODO: Clean up *)
  apply (clarsimp simp: M)
  apply (rule lfp_induct_pointwise[where B=b and pre="(=)"])
  apply (rule)
  apply clarsimp
  apply (subst α_dist)
  apply (auto simp add: chain_def le_fun_def) []
  apply (rule Sup_least)
  apply auto []
  apply simp
  apply (simp add: trimonoD[OF M])
  apply (rule refl)
  apply (subst lfp_unfold)
  apply (simp add: trimonoD)
  apply (rule REF)
  apply blast
  done

(* TODO: Could we base the whole refine_transfer-stuff on arbitrary relations *)
(* TODO: For enres-breakdown, we had to do antisymmetry, in order to get TR_top.
  What is the general shape of tr-relations for that, such that we could show equality directly?
*)
(* Relational transfer between RECT F and RECT F': tr holds of the results
   if it holds with top on the right (TR_top), the start arguments are related
   by P, and F/F' preserve tr on P-related arguments (IS, which may also assume
   D = RECT F). Proved with flatf_gfp_transfer. *)
lemma RECT_transfer_rel:
  assumes [simp]: "trimono F" "trimono F'"
  assumes TR_top[simp]: "⋀x. tr x top"
  assumes P_start[simp]: "P x x'"
  assumes IS: "⋀D D' x x'. ⟦ ⋀x x'. P x x' ⟹ tr (D x) (D' x'); P x x'; RECT F = D ⟧ ⟹ tr (F D x) (F' D' x')"
  shows "tr (RECT F x) (RECT F' x')"
  unfolding RECT_def 
  apply auto
  apply (rule flatf_gfp_transfer[where tr=tr and P=P])
  apply (auto simp: trimonoD_flatf_ge)  
  apply (rule IS)
  apply (auto simp: RECT_def)
  done
  
(* Variant of RECT_transfer_rel whose step IS does not use the extra
   premise RECT F = D. *)
lemma RECT_transfer_rel':
  assumes [simp]: "trimono F" "trimono F'"
  assumes TR_top[simp]: "⋀x. tr x top"
  assumes P_start[simp]: "P x x'"
  assumes IS: "⋀D D' x x'. ⟦ ⋀x x'. P x x' ⟹ tr (D x) (D' x'); P x x' ⟧ ⟹ tr (F D x) (F' D' x')"
  shows "tr (RECT F x) (RECT F' x')"
  using RECT_transfer_rel[where tr=tr and P=P,OF assms(1,2,3,4)] IS by blast

end