(* Theory DetMonadLemmas *)

(*
 * Copyright 2014, NICTA
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(NICTA_BSD)
 *)

(* Zhe Hou: I removed the lemmas about definitions that were removed from
 * DetMonad, so only the lemmas for deterministic monads are kept here.
 *)

theory DetMonadLemmas
imports DetMonad
begin

section "General Lemmas Regarding the Deterministic State Monad"

subsection "Congruence Rules for the Function Package"

(* Congruence rule for bind, registered with the function package
   (fundef_cong): the continuations g and g' only need to agree on the
   pairs (v, s') that arise as fst (f' s) for some state s. *)
lemma bind_cong[fundef_cong]:
  "\<lbrakk> f = f'; \<And>v s s'. (v, s') = fst (f' s) \<Longrightarrow> g v s' = g' v s' \<rbrakk>
   \<Longrightarrow> f >>= g = f' >>= g'"
  apply (rule ext)
  apply (auto simp: bind_def h1_def h2_def Let_def split_def intro: rev_image_eqI)
  done

(* Pointwise congruence rule for bind: the binds agree at the states s and s'
   when f s = f' s' and the continuations agree on the pair fst (f' s'). *)
lemma bind_apply_cong [fundef_cong]:
  "\<lbrakk> f s = f' s'; \<And>rv st. (rv, st) = fst (f' s') \<Longrightarrow> g rv st = g' rv st \<rbrakk>
        \<Longrightarrow> (f >>= g) s = (f' >>= g') s'"
  apply (simp add: bind_def h1_def h2_def)
  apply (auto simp: split_def intro: SUP_cong [OF refl] intro: rev_image_eqI)
  done

(* Congruence rule for bindE: the continuations N and N' only need to agree on
   Inr results (Inr v, s') = fst (M' s). The Inl case is discharged through
   lift_def. *)
lemma bindE_cong[fundef_cong]:
  "\<lbrakk> M = M' ; \<And>v s s'. (Inr v, s') = fst (M' s) \<Longrightarrow> N v s' = N' v s' \<rbrakk>
   \<Longrightarrow> bindE M N = bindE M' N'"
  apply (simp add: bindE_def)
  apply (rule bind_cong)
   apply (rule refl)
  apply (unfold lift_def)
  apply (case_tac v, simp_all)
  done

(* Pointwise congruence rule for bindE (>>=E): the continuations only need to
   agree on Inr results (Inr rv, st) = fst (f' s'). *)
lemma bindE_apply_cong[fundef_cong]:
  "\<lbrakk> f s = f' s'; \<And>rv st. (Inr rv, st) = fst (f' s') \<Longrightarrow> g rv st = g' rv st \<rbrakk>
  \<Longrightarrow> (f >>=E g) s = (f' >>=E g') s'"
  apply (simp add: bindE_def)
  apply (rule bind_apply_cong)
   apply assumption
  apply (case_tac rv, simp_all add: lift_def)
  done

(* Congruence rule for K_bind: only f st = f' st' is required; the arguments
   arg and arg' are left unconstrained. *)
lemma K_bind_apply_cong[fundef_cong]:
  "\<lbrakk> f st = f' st' \<rbrakk> \<Longrightarrow> K_bind f arg st = K_bind f' arg' st'"
  by simp

(* NOTE(review): despite its name, when_apply_cong is stated for whenE and is
   identical to whenE_apply_cong; likewise unless_apply_cong is identical to
   unlessE_apply_cong. Rules for when1/unless may have been intended here.
   The statements are kept unchanged so existing references keep working. *)
lemma when_apply_cong[fundef_cong]:
  "\<lbrakk> C = C'; s = s'; C' \<Longrightarrow> m s' = m' s' \<rbrakk> \<Longrightarrow> whenE C m s = whenE C' m' s'"
  by (simp add: whenE_def)

lemma unless_apply_cong[fundef_cong]:
  "\<lbrakk> C = C'; s = s'; \<not> C' \<Longrightarrow> m s' = m' s' \<rbrakk> \<Longrightarrow> unlessE C m s = unlessE C' m' s'"
  by (simp add: unlessE_def)

(* Congruence rules for whenE / unlessE: the body m only needs to agree with m'
   under the condition in which it is actually run (C' for whenE, \<not> C' for
   unlessE). *)
lemma whenE_apply_cong[fundef_cong]:
  "\<lbrakk> C = C'; s = s'; C' \<Longrightarrow> m s' = m' s' \<rbrakk> \<Longrightarrow> whenE C m s = whenE C' m' s'"
  by (simp add: whenE_def)

lemma unlessE_apply_cong[fundef_cong]:
  "\<lbrakk> C = C'; s = s'; \<not> C' \<Longrightarrow> m s' = m' s' \<rbrakk> \<Longrightarrow> unlessE C m s = unlessE C' m' s'"
  by (simp add: unlessE_def)

subsection "Simplifying Monads"

(* Flattens a nested do-block whose inner computation ends in return (g y):
   the outer binder x is replaced by g y. *)
lemma nested_bind [simp]:
  "do x \<leftarrow> do y \<leftarrow> f; return (g y) od; h x od =
   do y \<leftarrow> f; h (g y) od"
  apply (clarsimp simp add: bind_def h1_def h2_def)
  apply (rule ext)
  apply (clarsimp simp add: Let_def split_def return_def)
  done

(* Simplifying a bind whose first action has a constant True/False guard. *)

(* assert True in front of a bind reduces to calling the continuation on (). *)
lemma assert_True [simp]:
  "assert True >>= f = f ()"
  apply (simp add: assert_def)
  done

(* when1 True g in front of a bind behaves as g. *)
lemma when_True_bind [simp]:
  "when1 True g >>= f = g >>= f"
  apply (simp add: return_def bind_def when1_def)
  done

(* whenE False g in front of a bindE reduces to calling f on (). *)
lemma whenE_False_bind [simp]:
  "whenE False g >>=E f = f ()"
  apply (simp add: lift_def returnOk_def bindE_def whenE_def)
  done

(* whenE True g in front of a bindE behaves as g. *)
lemma whenE_True_bind [simp]:
  "whenE True g >>=E f = g >>=E f"
  apply (simp add: lift_def returnOk_def bindE_def whenE_def)
  done

(* when1 / unless with a constant condition either run X or do nothing. *)
lemma when_True [simp]: "when1 True X = X"
  apply (clarsimp simp: when1_def)
  done

lemma when_False [simp]: "when1 False X = return ()"
  apply (clarsimp simp: when1_def)
  done

lemma unless_False [simp]: "unless False X = X"
  apply (clarsimp simp: unless_def)
  done

lemma unless_True [simp]: "unless True X = return ()"
  apply (clarsimp simp: unless_def)
  done

(* unless-style combinators are when-style combinators with a negated guard;
   proved by extensionality followed by unfolding both definitions. *)
lemma unlessE_whenE:
  "unlessE P = whenE (~P)"
  apply (rule ext)+
  apply (simp add: whenE_def unlessE_def)
  done

lemma unless_when:
  "unless P = when1 (~P)"
  apply (rule ext)+
  apply (simp add: when1_def unless_def)
  done

(* gets of a constant function is just return of that constant. *)
lemma gets_to_return [simp]: "gets (λs. v) = return v"
  apply (clarsimp simp: return_def bind_def h1_def h2_def get_def put_def gets_def)
  done

(* A handler attached to a liftE'd computation is never used. *)
lemma liftE_handleE' [simp]: "((liftE a) <handle2> b) = liftE a"
  by (clarsimp simp: handleE'_def liftE_def)

lemma liftE_handleE [simp]: "((liftE a) <handle> b) = liftE a"
  by (unfold handleE_def) simp

(* Case split on condition: P holds of the result exactly when it holds of the
   branch selected by C s. *)
lemma condition_split:
  "P (condition C a b s) = ((((C s) \<longrightarrow> P (a s)) \<and> (\<not> (C s) \<longrightarrow> P (b s))))"
  apply (clarsimp simp: condition_def)
  done

(* Form of condition_split for occurrences in assumptions. *)
lemma condition_split_asm:
  "P (condition C a b s) = (\<not> (C s \<and> \<not> P (a s) \<or> \<not> C s \<and> \<not> P (b s)))"
  apply (clarsimp simp: condition_def)
  done

(* Both forms together, for use as "split: condition_splits". *)
lemmas condition_splits = condition_split condition_split_asm

(* A condition on a constant guard collapses to the selected branch. *)
lemma condition_true_triv [simp]:
  "condition (λ_. True) A B = A"
  by (rule ext) (clarsimp split: condition_splits)

lemma condition_false_triv [simp]:
  "condition (λ_. False) A B = B"
  by (rule ext) (clarsimp split: condition_splits)

(* Evaluating condition at a state where the guard's truth value is known. *)
lemma condition_true: "\<lbrakk> P s \<rbrakk> \<Longrightarrow> condition P A B s = A s"
  apply (clarsimp simp: condition_def)
  done

lemma condition_false: "\<lbrakk> \<not> P s \<rbrakk> \<Longrightarrow> condition P A B s = B s"
  apply (clarsimp simp: condition_def)
  done

section "Low-level monadic reasoning"

(* These lemmas turn a family of Hoare triples indexed by an auxiliary value s0
   into a single triple whose postcondition (Q', and E' for the exceptional
   case) is a free variable. The index s0 is quantified existentially in the
   precondition, together with the obligation that the original postcondition
   implies the new one. *)
lemma valid_make_schematic_post:
  "(\<forall>s0. \<lbrace> \<lambda>s. P s0 s \<rbrace> f \<lbrace> \<lambda>rv s. Q s0 rv s \<rbrace>) \<Longrightarrow>
   \<lbrace> \<lambda>s. \<exists>s0. P s0 s \<and> (\<forall>rv s'. Q s0 rv s' \<longrightarrow> Q' rv s') \<rbrace> f \<lbrace> Q' \<rbrace>"
  by (auto simp add: valid_def no_fail_def split: prod.splits)

lemma validNF_make_schematic_post:
  "(\<forall>s0. \<lbrace> \<lambda>s. P s0 s \<rbrace> f \<lbrace> \<lambda>rv s. Q s0 rv s \<rbrace>!) \<Longrightarrow>
   \<lbrace> \<lambda>s. \<exists>s0. P s0 s \<and> (\<forall>rv s'. Q s0 rv s' \<longrightarrow> Q' rv s') \<rbrace> f \<lbrace> Q' \<rbrace>!"
  by (auto simp add: valid_def validNF_def no_fail_def split: prod.splits)

lemma validE_make_schematic_post:
  "(\<forall>s0. \<lbrace> \<lambda>s. P s0 s \<rbrace> f \<lbrace> \<lambda>rv s. Q s0 rv s \<rbrace>, \<lbrace> \<lambda>rv s. E s0 rv s \<rbrace>) \<Longrightarrow>
   \<lbrace> \<lambda>s. \<exists>s0. P s0 s \<and> (\<forall>rv s'. Q s0 rv s' \<longrightarrow> Q' rv s')
        \<and> (\<forall>rv s'. E s0 rv s' \<longrightarrow> E' rv s') \<rbrace> f \<lbrace> Q' \<rbrace>, \<lbrace> E' \<rbrace>"
  by (auto simp add: validE_def valid_def no_fail_def split: prod.splits sum.splits)

lemma validE_NF_make_schematic_post:
  "(\<forall>s0. \<lbrace> \<lambda>s. P s0 s \<rbrace> f \<lbrace> \<lambda>rv s. Q s0 rv s \<rbrace>, \<lbrace> \<lambda>rv s. E s0 rv s \<rbrace>!) \<Longrightarrow>
   \<lbrace> \<lambda>s. \<exists>s0. P s0 s \<and> (\<forall>rv s'. Q s0 rv s' \<longrightarrow> Q' rv s')
        \<and> (\<forall>rv s'. E s0 rv s' \<longrightarrow> E' rv s') \<rbrace> f \<lbrace> Q' \<rbrace>, \<lbrace> E' \<rbrace>!"
  by (auto simp add: validE_NF_def validE_def valid_def no_fail_def split: prod.splits sum.splits)

(* Projections: a validNF triple with a conjunctive postcondition yields a
   validNF triple for each conjunct. *)
lemma validNF_conjD1: "\<lbrace> P \<rbrace> f \<lbrace> \<lambda>rv s. Q rv s \<and> Q' rv s \<rbrace>! \<Longrightarrow> \<lbrace> P \<rbrace> f \<lbrace> Q \<rbrace>!"
  by (fastforce simp: validNF_def valid_def no_fail_def)

lemma validNF_conjD2: "\<lbrace> P \<rbrace> f \<lbrace> \<lambda>rv s. Q rv s \<and> Q' rv s \<rbrace>! \<Longrightarrow> \<lbrace> P \<rbrace> f \<lbrace> Q' \<rbrace>!"
  by (fastforce simp: validNF_def valid_def no_fail_def)

(* Running gets f then m from state s: m receives f s and the unchanged s. *)
lemma exec_gets:
  "(gets f >>= m) s = m (f s) s"
  apply (simp add: bind_def h1_def h2_def simpler_gets_def)
  done

(* The (result, state) pair produced by gets f at s is (f s, s). *)
lemma in_gets:
  "(r, s') = fst (gets f s) = (r = f s \<and> s' = s)"
  by (simp add: simpler_gets_def)

end