(* Theory State_Monad *)

theory State_Monad
imports State "HOL-Library.Monad_Syntax" Utils
begin

section "State Monad with Exceptions"

(* Outcome of running a state-monad computation:
   - Normal: termination with a regular value,
   - Exception: termination with an exception value,
   - NT: no result (presumably "non-termination"); it is used as the least
     element of the flat order in the partial-function setup below. *)
datatype ('n, 'e) result =
  Normal (normal: 'n)
| Exception (ex: 'e)
| NT

(* Case distinction for results of a state-monad execution that also splits
   the Normal and Exception payloads into a value/exception and a state.
   Registered as the cases rule for type result, so "cases" on such a term
   yields the cases n, e and t with the pair components already named. *)
lemma result_cases[cases type: result]:
  fixes x :: "('a × 's, 'e × 's) result"
  obtains (n) a s where "x = Normal (a, s)"
        | (e) e s where "x = Exception (e, s)"
        | (t) "x = NT"
proof (cases x)
  case (Normal n)
  then show ?thesis using n by force
next
  case (Exception e)
  then show ?thesis using e by force
next
  case NT
  then show ?thesis using t by simp
qed

(* A computation maps an initial state to a result: Normal (value, state),
   Exception (exception, state) or NT. Every such function is a valid
   computation, hence the carrier set is UNIV. execute/create convert between
   the abstract type and the underlying function. *)
typedef ('a, 'e, 's) state_monad = "UNIV::('s ⇒ ('a × 's, 'e × 's) result) set"
  morphisms execute create
  by simp

(* Collection of rewrite rules for execute applied to the monad operations
   defined in this theory; used as "simp add: execute_simps". *)
named_theorems execute_simps "simplification rules for execute"

(* execute commutes with let-bindings. *)
lemma execute_Let [execute_simps]:
  "execute (let x = t in f x) = (let x = t in execute (f x))"
  by (simp add: Let_def)

subsection ‹Code Generator Setup›

(* Code generation represents state monads via the abstraction function create. *)
code_datatype create

(* execute undoes create (the typedef carrier is UNIV, so no side condition);
   also used as code equation. *)
lemma execute_create[execute_simps, code]: "execute (create f) = f" using create_inverse by simp

(* create (execute m) = m as a simp rule. *)
declare execute_inverse[simp]

(* Extensionality for state monads: two computations are equal if they
   behave identically on every initial state. Declared [intro] so that an
   equation between computations can be reduced to pointwise equality by the
   default proof step (as done in bind_assoc). *)
lemma execute_ext[intro]: "(⋀x. (execute m1 x = execute m2 x)) ⟹ m1 = m2" using HOL.ext
  by (metis execute_inverse)

subsection ‹Fundamental Definitions›

(* return a yields a, leaves the state unchanged and never fails. *)
definition return :: "'a ⇒ ('a, 'e, 's) state_monad"
  where "return a = create (λs. Normal (a, s))"

(* Unfolding rule for return: execute (return x) s = Normal (x, s). *)
lemma execute_return [execute_simps]:
  "execute (return x) = Normal ∘ Pair x"
  unfolding return_def by (auto simp add:execute_simps)

(* Inversion for return: a normal result (a, s') of return x started in s
   has a = x and leaves the state unchanged (s = s'). *)
lemma execute_returnE:
  assumes "execute (return x) s = Normal (a, s')"
  shows "x = a" and "s = s'"
  using assms unfolding return_def execute_create by auto

(* throw e raises exception e together with the current (unchanged) state. *)
definition throw :: "'e ⇒ ('a, 'e, 's) state_monad"
  where "throw e = create (λs. Exception (e, s))"

(* Unfolding rule for throw: the exception is paired with the current state. *)
lemma execute_throw [execute_simps]:
  "execute (throw x) s = Exception (x, s)"
  unfolding throw_def by (auto simp add:execute_simps)

(* Sequential composition: run f; on a normal result (a, s') continue with
   g a from state s'. An exception or NT produced by f is propagated
   unchanged and g is not executed. *)
definition bind :: "('a, 'e, 's) state_monad ⇒ ('a ⇒ ('b, 'e, 's) state_monad) ⇒ ('b, 'e, 's) state_monad" (infixl ">>=" 60)
where "bind f g = create (λs. (case (execute f) s of
                      Normal (a, s') ⇒ execute (g a) s'
                    | Exception e ⇒ Exception e
                    | NT ⇒ NT))"

(* Let the generic bind of Monad_Syntax (⤜ and do-notation) denote this bind. *)
adhoc_overloading Monad_Syntax.bind ⇌ bind

(* Conditional unfolding rules for bind, one per possible result of f. *)
lemma execute_bind [execute_simps]:
  "execute f s = Normal (x, s') ⟹ execute (f ⤜ g) s = execute (g x) s'"
  "execute f s = Exception e ⟹ execute (f ⤜ g) s = Exception e"
  "execute f s = NT ⟹ execute (f ⤜ g) s = NT"
  unfolding bind_def execute_create by simp_all

(* Inversion: a normal result of f ⤜ g arises only from a normal result of f
   followed by a normal result of the continuation. *)
lemma execute_bind_normal_E:
  assumes "execute (f ⤜ g) s = Normal (a, s')"
  obtains (1) s'' x where "execute f s = Normal (x, s'')" and "execute (g x) s'' = Normal (a, s')"
  using assms unfolding bind_def execute_create apply (cases "execute f s") using that by auto

(* Inversion: an exception from f ⤜ g was raised either by f itself (case 1)
   or by the continuation after f terminated normally (case 2). *)
lemma execute_bind_exception_E:
  assumes "execute (f ⤜ g) s = Exception (x, s')"
  obtains (1) "execute f s = Exception (x, s')"
        | (2) a s'' where "execute f s = Normal (a,s'')" and "execute (g a) s'' = Exception (x,s')"
  using assms unfolding bind_def execute_create apply (cases "execute f s") using that by auto

(*
  Congruence rule for bind: the continuations only need to agree on values
  and states that m2 can actually produce normally.
  This lemma is needed for termination proofs.
  NOTE(review): [cong] registers the rule with the simplifier; the function
  package looks up [fundef_cong] rules. Confirm which use is intended.
*)
lemma monad_cong[cong]:
  fixes m1 m2 m3 m4
  assumes "m1 = m2"
      and "⋀s v s'. execute m2 s = Normal (v, s') ⟹ execute (m3 v) s' = execute (m4 v) s'"
    shows "(bind m1 m3) = (bind m2 m4)"
  unfolding bind_def
proof -
  have "(λs. case execute m1 s of Normal (a, xa) ⇒ execute (m3 a) xa | Exception x ⇒ Exception x | NT ⇒ NT) =
        (λs. case execute m2 s of Normal (a, xa) ⇒ execute (m4 a) xa | Exception x ⇒ Exception x | NT ⇒ NT)"
  (is  "(λs. ?L s) = (λs. ?R s)")
  proof
    fix s
    show "?L s = ?R s"
      using assms by (cases "execute m1 s"; simp)
  qed
  then show "create (λs. ?L s) = create (λs. ?R s)" by simp
qed

(* A thrown exception skips the continuation. *)
lemma throw_left[simp]: "throw x ⤜ y = throw x" unfolding throw_def bind_def by (simp add: execute_simps)

subsection ‹The Monad Laws›

text ‹@{term return} is absorbed at the left of a @{term bind},
  applying the return value directly:›
(* Left identity monad law. *)
lemma return_bind [simp]: "(return x ⤜ f) = f x"
unfolding return_def bind_def by (simp add: execute_simps)

text ‹@{term return} is absorbed on the right of a @{term bind}›
(* Right identity monad law, proved pointwise by case analysis on the
   result of m. *)
lemma bind_return [simp]: "(m ⤜ return) = m"
proof (rule execute_ext)
  fix s
  show "execute (m ⤜ return) s = execute m s"
  proof (cases "execute m s" rule: result_cases)
    (* name the final state s' to avoid shadowing the fixed initial state s *)
    case (n a s')
    then show ?thesis by (simp add: execute_simps)
  next
    case (e e)
    then show ?thesis by (simp add: execute_simps)
  next
    case t
    then show ?thesis by (simp add: execute_simps)
  qed
qed

text ‹@{term bind} is associative›
(* Associativity monad law. The default proof step applies execute_ext
   (declared [intro]); each state is then handled by case analysis on m. *)
lemma bind_assoc:
  fixes m :: "('a,'e,'s) state_monad"
  fixes f :: "'a ⇒ ('b,'e,'s) state_monad"
  fixes g :: "'b ⇒ ('c,'e,'s) state_monad"
  shows "(m ⤜ f) ⤜ g  =  m ⤜ (λx. f x ⤜ g)"
proof
  fix s
  show "execute (m ⤜ f ⤜ g) s = execute (m ⤜ (λx. f x ⤜ g)) s"
  unfolding bind_def by (cases "execute m s" rule: result_cases; simp add: execute_simps)
qed

subsection ‹Basic Congruence Rules›

(*
  Lemma bind_case_nat_cong is required if a bind operand is a case analysis over nat.
  As a [fundef_cong] rule it lets the function package assume x = Suc a when
  analysing the Suc branch; the 0 branch g is the same on both sides.
*)
lemma bind_case_nat_cong [fundef_cong]:
  assumes "x = x'" and "⋀a. x = Suc a ⟹ f a h = f' a h"
  shows "(case x of Suc a ⇒ f a | 0 ⇒ g) h = (case x' of Suc a ⇒ f' a | 0 ⇒ g) h"
  by (metis assms(1) assms(2) old.nat.exhaust old.nat.simps(4) old.nat.simps(5))

(* Congruence rule for the function package: a conditional applied to a
   state s; each branch may be rewritten under the assumption that the
   (rewritten) condition holds, respectively fails.
   NOTE(review): the base name coincides with HOL's if_cong and presumably
   shadows it in this theory; renaming (e.g. to bind_if_cong, in line with the
   neighbouring rules) would change the public name, so it is kept. *)
lemma if_cong[fundef_cong]:
  assumes "b = b'"
    and "b' ⟹ m1 s = m1' s"
    and "¬ b' ⟹ m2 s = m2' s"
  shows "(if b then m1 else m2) s = (if b' then m1' else m2') s"
  using assms(1) assms(2) assms(3) by auto

(* Congruence rule for a pair case distinction applied to a state s. *)
lemma bind_case_pair_cong [fundef_cong]:
  assumes "x = x'" and "⋀a b. x = (a,b) ⟹ f a b s = f' a b s"
  shows "(case x of (a,b) ⇒ f a b) s = (case x' of (a,b) ⇒ f' a b) s"
  by (simp add: assms(1) assms(2) prod.case_eq_if)

(* Congruence rule for a let-binding applied to a state s; the body may
   assume the bound variable equals N. *)
lemma bind_case_let_cong [fundef_cong]:
  assumes "M = N"
      and "(⋀x. x = N ⟹ f x s = g x s)"
    shows "(Let M f) s = (Let N g) s"
  by (simp add: assms(1) assms(2))

(* Congruence rule for an option case distinction applied to a state s. *)
lemma bind_case_some_cong [fundef_cong]:
  assumes "x = x'" and "⋀a. x = Some a ⟹ f a s = f' a s" and "x = None ⟹ g s = g' s"
  shows "(case x of Some a ⇒ f a | None ⇒ g) s = (case x' of Some a ⇒ f' a | None ⇒ g') s"
  by (simp add: assms(1) assms(2) assms(3) option.case_eq_if)

(* Congruence rule for a bool case distinction applied to a state s. *)
lemma bind_case_bool_cong [fundef_cong]:
  assumes "x = x'" and "x = True ⟹ f s = f' s" and "x = False ⟹ g s = g' s"
  shows "(case x of True ⇒ f | False ⇒ g) s = (case x' of True ⇒ f' | False ⇒ g') s"
  using assms(1) assms(2) assms(3) by auto

subsection ‹Other functions›

text ‹
  The basic accessor functions of the state monad. ‹get› returns
  the current state as result, does not fail, and does not change the state.
  ‹put s› returns unit, changes the current state to ‹s› and does not fail.
›
(* get yields the current state as value and leaves it unchanged. *)
definition get :: "('s, 'e, 's) state_monad" where
  "get = create (λs. Normal (s, s))"

(* Unfolding rule for get. *)
lemma execute_get [execute_simps]:
  "execute get = (λs. Normal (s, s))"
  unfolding get_def by (auto simp add:execute_simps)

(* put s replaces the current state by s and returns unit.
   K is presumably the constant-function combinator from Utils. *)
definition put :: "'s ⇒ (unit, 'e, 's) state_monad" where
  "put s = create (K (Normal ((), s)))"

(* Unfolding rule for put. *)
lemma execute_put [execute_simps]:
  "execute (put s) = K (Normal ((), s))"
  unfolding put_def by (auto simp add:execute_simps)

(* update f applies f to the current state; f yields both the result value
   and the new state. Never fails. *)
definition update :: "('s ⇒ 'a × 's) ⇒ ('a, 'e, 's) state_monad" where
  "update f = create (λs. Normal (f s))"

(* Unfolding rule for update. *)
lemma execute_update [execute_simps]:
  "execute (update f) = (λs. Normal (f s))"
  unfolding update_def by (auto simp add:execute_simps)

text ‹Apply a function to the current state and return the result
without changing the state.›
(* Implemented by reading the state with get and returning f applied to it. *)
definition applyf :: "('s ⇒ 'a) ⇒ ('a, 'e, 's) state_monad" where
 "applyf f = get ⤜ (λs. return (f s))"

text ‹Modify the current state using the function passed in.›
(* Implemented by reading the state with get and writing back f of it with put. *)
definition modify :: "('s ⇒ 's) ⇒ (unit, 'e, 's) state_monad" where
"modify f = get ⤜ (λs::'s. put (f s))"

(* Unfolding rule for modify: never fails, returns unit and applies f to the state. *)
lemma execute_modify [execute_simps]:
  "execute (modify f) s = Normal ((), f s)"
  unfolding modify_def by (auto simp add:execute_simps)

(* mfold m n runs m n times in sequence and collects the results in
   execution order. Because the steps are composed with bind, the first
   exception or NT aborts the remaining steps. *)
primrec mfold :: "('a,'e,'s) state_monad ⇒ nat ⇒ ('a list,'e,'s) state_monad"
  where
    "mfold m 0 = return []"
  | "mfold m (Suc n) =
      do {
        l ← m;
        ls ← mfold m n;
        return (l # ls)
      }"

subsection ‹Some basic examples›

(* do-notation is syntax for nested binds: both sides denote the same term,
   so a trivial proof step closes the goal. *)
lemma "do {
        x ← return 1;
        return (2::nat);
        return x
       } =
       return 1 ⤜ (λx. return (2::nat) ⤜ (λ_. (return x)))" ..

(* The intermediate returns are eliminated by return_bind (a simp rule). *)
lemma "do {
        x ← return 1;
          return 2;
          return x
       } = return 1"
  by auto

subsection ‹Conditional Monad›

(* cond_monad c mt mf reads the current state s and continues with mt if
   c s holds, otherwise with mf; the read itself does not change the state. *)
fun cond_monad:: "('s ⇒ bool) ⇒ ('a, 'e, 's) state_monad ⇒ ('a, 'e, 's) state_monad ⇒ ('a, 'e, 's) state_monad" where
"cond_monad c mt mf =
  do {
    s ← get;
    if (c s) then mt else mf
  }"

(* option x f evaluates f on the current state: Some y returns y, None
   throws x. The state is not changed in either case. *)
definition option :: "'e ⇒ ('s ⇒ 'a option) ⇒ ('a, 'e, 's) state_monad" where
 "option x f = create (λs. (case f s of
    Some y ⇒ execute (return y) s
  | None ⇒ execute (throw x) s))"

(* Conditional unfolding rules for option, one per result of f s. *)
lemma execute_option [execute_simps]:
  "⋀y. f s = Some y ⟹ execute (option e f) s = execute (return y) s"
  "f s = None ⟹ execute (option e f) s = execute (throw e) s"
  unfolding option_def by (auto simp add:execute_simps)

(* assert x t returns unit if t holds for the current state and throws x
   otherwise; the state is unchanged. *)
definition assert :: "'e ⇒ ('s ⇒ bool) ⇒ (unit, 'e, 's) state_monad" where
 "assert x t = create (λs. if (t s) then execute (return ()) s else execute (throw x) s)"

(* Conditional unfolding rules for assert. *)
lemma execute_assert [execute_simps]:
  "t s ⟹ execute (assert e t) s = execute (return ()) s"
  "¬ t s ⟹ execute (assert e t) s = execute (throw e) s"
  unfolding assert_def by (auto simp add:execute_simps)

subsection ‹Setup for Partial Function Package›

text ‹
  We can make the state monad (via its results) into a pointed cpo:
  - The order is obtained by combining the function order with the (flat) result order.
  - The least element is NT.
›

(* effect c h r: running c from state h terminates and r encodes the
   outcome, Inl for a Normal result and Inr for an Exception result (each
   carrying the value/exception together with the final state).
   effect c h r never holds if execute c h = NT. *)
definition effect :: "('a, 'b, 'c) state_monad ⇒ 'c ⇒ 'a × 'c + 'b × 'c ⇒ bool" where
  effect_def: "effect c h r ≡ is_Normal (execute c h) ∧ r = Inl (normal (execute c h)) ∨ is_Exception (execute c h) ∧ r = Inr (ex (execute c h))"

(* Elimination rule for effect: normal termination or exception. *)
lemma effectE:
  assumes "effect c h r"
  obtains (normal) "is_Normal (execute c h) ∧ r = Inl (normal (execute c h))"
  | (exception) "is_Exception (execute c h) ∧ r = Inr (ex (execute c h))"
  using assms unfolding effect_def by auto

(* empty_result yields NT on every state. result_ord/result_lub are the
   flat order and lub on results with NT as least element. *)
abbreviation "empty_result ≡ create (λs. NT)"
abbreviation "result_ord ≡ flat_ord NT"
abbreviation "result_lub ≡ flat_lub NT"

(* sm_ord x y: for every initial state, x yields NT or the same result as y
   (the pointwise flat order, transported along execute). *)
definition sm_ord :: "('a, 'e, 's) state_monad ⇒ ('a, 'e, 's) state_monad ⇒ bool" where
  "sm_ord = img_ord execute (fun_ord result_ord)"

(* Least upper bound w.r.t. sm_ord: the pointwise flat lub of the results,
   transported along execute/create. *)
definition sm_lub :: "('a, 'e, 's) state_monad set ⇒ ('a, 'e, 's) state_monad" where
  "sm_lub = img_lub execute create (fun_lub result_lub)"

(* The lub of the empty set is the bottom computation empty_result. *)
lemma sm_lub_empty: "sm_lub {} = empty_result"
  by(simp add: sm_lub_def img_lub_def fun_lub_def flat_lub_def)

(* Introduction rule for sm_ord: pointwise, x yields NT or agrees with y. *)
lemma sm_ordI:
  assumes "⋀h. execute x h = NT ∨ execute x h = execute y h"
  shows "sm_ord x y"
  using assms unfolding sm_ord_def img_ord_def fun_ord_def flat_ord_def
  by blast

(* Elimination rule for sm_ord at a given state h: either x yields NT there,
   or x and y have the same result there. *)
lemma sm_ordE:
  assumes "sm_ord x y"
  obtains "execute x h = NT" | "execute x h = execute y h"
  using assms unfolding sm_ord_def img_ord_def fun_ord_def flat_ord_def
  by atomize_elim blast

(* sm_ord/sm_lub satisfy the partial_function_definitions axioms: lift the
   flat order on results pointwise to functions, then transfer the structure
   along the typedef morphisms execute/create. *)
lemma sm_interpretation: "partial_function_definitions sm_ord sm_lub"
proof -
  have "partial_function_definitions (fun_ord result_ord) (fun_lub result_lub)"
    by (rule partial_function_lift) (rule flat_interpretation)
  then have "partial_function_definitions (img_ord execute (fun_ord result_ord))
      (img_lub execute create (fun_lub result_lub))"
    by (rule partial_function_image) (auto simp add: execute_simps)
  then show "partial_function_definitions sm_ord sm_lub"
    by (simp only: sm_ord_def sm_lub_def)
qed

(* Instantiate the locale under prefix sm, presenting its bottom element
   sm_lub {} as empty_result. *)
interpretation sm: partial_function_definitions sm_ord sm_lub
  rewrites "sm_lub {} ≡ empty_result"
by (fact sm_interpretation)(simp add: sm_lub_empty)

(* Collection of monotonicity rules for sm, seeded with monotonicity of
   constants and of calls. It is not consumed by any proof in this theory.
   NOTE(review): the name coincides with HOL's [mono] attribute (inductive
   package) and presumably shadows it afterwards; confirm this is intended. *)
named_theorems mono

declare sm.const_mono[mono]
declare Partial_Function.call_mono[mono]

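text ‹The success predicate holds if the state monad sm, started in state s, does not yield NT, i.e. it terminates normally or with an exception. The parameters s' and a (final state and return value) are not constrained by the current definition.›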

(* success sm s s' a: running sm from state s yields a result other than NT.
   NOTE(review): s' and a do not occur on the right-hand side, so neither the
   final state nor the return value is constrained, although the parameter
   names suggest e.g. execute sm s = Normal (a, s'). The definition is kept
   unchanged because other theories may rely on it; confirm the intent. *)
definition success :: "('a, 'e, 's) state_monad ⇒ 's ⇒ 's ⇒ 'a ⇒ bool" where
  success_def: "success sm s s' a ≡ execute sm s ≠ NT"

text ‹We can show that every predicate P is admissible if we assume successful termination.›
(* Admissibility at the level of the underlying functions: a property P of
   all terminating (Normal/Exception) outcomes is preserved by lubs of
   chains. In the flat order, the lub at each state is either NT or one of
   the chain elements, so the induction hypothesis IH applies. *)
lemma sm_step_admissible:
  "ccpo.admissible (fun_lub result_lub) (fun_ord result_ord) (λxa. ∀h r. is_Normal (xa h) ∧ r = Inl (normal (xa h)) ∨ is_Exception (xa h) ∧ r = Inr (ex (xa h)) ⟶ P h r)"
proof (rule ccpo.admissibleI)
  fix A :: "('a ⇒ ('b, 'c) result) set"
  assume ch: "Complete_Partial_Order.chain (fun_ord result_ord) A"
    and IH: "∀xa∈A. ∀h r. is_Normal (xa h) ∧ r = Inl (normal (xa h)) ∨ is_Exception (xa h) ∧ r = Inr (ex (xa h)) ⟶ P h r"
  from ch have ch': "⋀x. Complete_Partial_Order.chain result_ord {y. ∃f∈A. y = f x}" by (rule chain_fun)
  show "∀h r. is_Normal (fun_lub result_lub A h) ∧ r = Inl (normal (fun_lub result_lub A h)) ∨ is_Exception (fun_lub result_lub A h) ∧ r = Inr (ex (fun_lub result_lub A h)) ⟶ P h r"
  proof (intro allI impI)
    fix h r assume "is_Normal (fun_lub result_lub A h) ∧ r = Inl (normal (fun_lub result_lub A h)) ∨ is_Exception (fun_lub result_lub A h) ∧ r = Inr (ex (fun_lub result_lub A h))"
    then show "P h r"
    proof
      assume "is_Normal (fun_lub result_lub A h) ∧ r = Inl (normal (fun_lub result_lub A h))"
      with flat_lub_in_chain[OF ch'] this[unfolded fun_lub_def]
      show "P h r" using IH
        by (smt (verit) Collect_cong mem_Collect_eq result.case_eq_if result.disc_eq_case(3))
    next
      assume "is_Exception (fun_lub result_lub A h) ∧ r = Inr (ex (fun_lub result_lub A h))"
      with flat_lub_in_chain[OF ch'] this[unfolded fun_lub_def]
      show "P h r" using IH
        by (smt (verit) Collect_cong mem_Collect_eq result.case_eq_if result.disc_eq_case(3))
    qed
  qed
qed

(* Partial-correctness properties ("every terminating outcome satisfies P")
   are admissible for sm. Lifted from sm_step_admissible through the
   pointwise function order and the typedef image. *)
lemma admissible_sm:
  "sm.admissible (λf. ∀x h r. effect (f x) h r ⟶ P x h r)"
proof (rule admissible_fun[OF sm_interpretation])
  fix x
  show "ccpo.admissible sm_lub sm_ord (λa. ∀h r. effect a h r ⟶ P x h r)"
    unfolding sm_ord_def sm_lub_def
  proof (intro admissible_image partial_function_lift flat_interpretation)
    show "ccpo.admissible (fun_lub result_lub) (fun_ord result_ord) ((λa. ∀h r. effect a h r ⟶ P x h r) ∘ create)"
      unfolding comp_def effect_def execute_create
      by (rule sm_step_admissible)
  qed (auto simp add: execute_simps)
qed

text ‹
  Now we can derive an induction rule for proving partial correctness properties.
  Note that this rule requires successful termination.
›

(* Partial-correctness induction for functions defined by partial_function
   in mode sm. To show P x h r for a terminating outcome of the fixpoint f,
   it suffices (step) to show that one unfolding F preserves P, assuming P
   for all terminating outcomes of the argument function. *)
lemma fixp_induct_sm:
  fixes F :: "'c ⇒ 'c" and
        U :: "'c ⇒ 'b ⇒ ('a, 'e, 's) state_monad" and
        C :: "('b ⇒ ('a, 'e, 's) state_monad) ⇒ 'c" and
        P :: "'b ⇒ 's ⇒ 'a × 's + 'e × 's ⇒ bool"
  assumes mono: "⋀x. monotone (fun_ord sm_ord) sm_ord (λf. U (F (C f)) x)"
  assumes eq: "f ≡ C (ccpo.fixp (fun_lub sm_lub) (fun_ord sm_ord) (λf. U (F (C f))))"
  assumes inverse2: "⋀f. U (C f) = f"
  assumes step: "⋀f x h r. (⋀x h r. effect (U f x) h r ⟹ P x h r)
    ⟹ effect (U (F f) x) h r ⟹ P x h r"
  assumes defined: "effect (U f x) h r"
  shows "P x h r"
  using step defined sm.fixp_induct_uc[of U F C, OF mono eq inverse2 admissible_sm, of P]
  unfolding effect_def execute_create by blast

text ‹We now need to set up the new sm mode for the partial function package.›

(* Register mode sm with the partial function package: fixpoint combinator,
   monotonicity predicate, unfolding rule, generic fixpoint induction, and
   the partial-correctness induction rule fixp_induct_sm. *)
declaration ‹Partial_Function.init "sm"
    \<^term>‹sm.fixp_fun›
    \<^term>‹sm.mono_body›
    @{thm sm.fixp_rule_uc}
    @{thm sm.fixp_induct_uc}
    (SOME @{thm fixp_induct_sm})›

subsection ‹Monotonicity Results›

(* Monotonicity of a functional in its recursive-call argument w.r.t. sm_ord. *)
abbreviation "mono_sm ≡ monotone (fun_ord sm_ord) sm_ord"

(* Unconditional unfolding of bind as a case distinction on the result of f. *)
lemma execute_bind_case:
  "execute (f ⤜ g) h = (case (execute f h) of
    Normal (x, h') ⇒ execute (g x) h' | Exception e ⇒ Exception e | NT ⇒ NT)"
  by (simp add: bind_def execute_simps)

(* bind is monotone in both arguments. First move from B f to B g with the
   continuation fixed, then replace the continuation, and combine both steps
   by transitivity of sm_ord. *)
lemma bind_mono [partial_function_mono,mono]:
  assumes mf: "mono_sm B" and mg: "⋀y. mono_sm (λf. C y f)"
  shows "mono_sm (λf. B f ⤜ (λy. C y f))"
proof (rule monotoneI)
  fix f g :: "'a ⇒ ('b, 'c, 'd) state_monad" assume fg: "sm.le_fun f g"
  from mf
  have 1: "sm_ord (B f) (B g)" by (rule monotoneD) (rule fg)
  from mg
  have 2: "⋀y'. sm_ord (C y' f) (C y' g)" by (rule monotoneD) (rule fg)

  have "sm_ord (B f ⤜ (λy. C y f)) (B g ⤜ (λy. C y f))" (is "sm_ord ?L ?R")
  proof (rule sm_ordI)
    fix h
    from 1 show "execute ?L h = NT ∨ execute ?L h = execute ?R h"
      by (rule sm_ordE[where h = h]) (auto simp: execute_bind_case)
  qed
  also
  have "sm_ord (B g ⤜ (λy'. C y' f)) (B g ⤜ (λy'. C y' g))" (is "sm_ord ?L ?R")
  proof (rule sm_ordI)
    fix h
    show "execute ?L h = NT ∨ execute ?L h = execute ?R h"
    proof (cases "execute (B g) h")
      case (n a s)
      then have "execute ?L h = execute (C a f) s" "execute ?R h = execute (C a g) s"
        by (auto simp: execute_bind_case)
      with 2[of a] show ?thesis by (auto elim: sm_ordE)
    next
      case (e e)
      then show ?thesis by (simp add: execute_bind_case)
    next
      case t
      then have "execute ?L h = NT" by (auto simp: execute_bind_case)
      thus ?thesis ..
    qed
  qed
  finally (sm.leq_trans)
  show "sm_ord (B f ⤜ (λy. C y f)) (B g ⤜ (λy'. C y' g))" .
qed

(* throw e ignores the recursive-call argument, hence is monotone. *)
lemma throw_monad_mono[mono]: "mono_sm (λ_. throw e)"
  by (simp add: monotoneI sm_ordI)

(* return x ignores the recursive-call argument, hence is monotone. *)
lemma return_monad_mono[mono]: "mono_sm (λ_. return x)"
  by (simp add: monotoneI sm_ordI)

(* option E x ignores the recursive-call argument, hence is monotone. *)
lemma option_monad_mono[mono]: "mono_sm (λ_. option E x)"
  by (simp add: monotoneI sm_ordI)

(* exc m behaves like m, except that on an exception the state is reset to
   the initial state s, discarding state changes made by m before it threw.
   Normal results and NT are passed through unchanged. *)
definition exc:: "('a, 'b, 'c) state_monad ⇒ ('a, 'b, 'c) state_monad"
  where "exc m ≡ create (λs. case execute m s of Normal (v,s') ⇒ Normal (v, s')
                                               | Exception (e, s') ⇒ Exception (e, s)
                                               | NT ⇒ NT)"

(* exc preserves monotonicity: pointwise, m f yields NT (and so does
   exc (m f)) or m f agrees with m g (and so do their exc versions). *)
lemma exc_mono[mono]:
  fixes m::"('b ⇒ ('c, 'e, 'f) state_monad) ⇒ ('x, 'y, 'z) state_monad"
  assumes mf: "mono_sm (λcall. (m call))"
  shows "mono_sm (λcall. (exc (m call)))"
proof (rule monotoneI)
  fix f g :: "'b ⇒ ('c, 'e, 'f) state_monad"
  assume fg: "sm.le_fun f g"
  then have 1: "sm_ord (m f) (m g)" using mf by (auto dest: monotoneD)
  show "sm_ord (exc (m f)) (exc (m g))"
  proof (rule sm_ordI)
    fix h
    show "execute (exc (m f)) h = NT ∨ execute (exc (m f)) h = execute (exc (m g)) h"
    proof (rule sm_ordE[OF 1, of h])
      assume "execute (m f) h = NT"
      then show "execute (exc (m f)) h = NT ∨ execute (exc (m f)) h = execute (exc (m g)) h" unfolding exc_def by (simp add:execute_simps)
    next
      assume "execute (m f) h = execute (m g) h"
      then show "execute (exc (m f)) h = NT ∨ execute (exc (m f)) h = execute (exc (m g)) h" unfolding exc_def by (simp add:execute_simps)
    qed
  qed
qed


end