(* Theory Monomorphic_Monad *)

(*  Title:      Monomorphic_Monad.thy
    Author:     Andreas Lochbihler, ETH Zurich *)

theory Monomorphic_Monad imports
  "HOL-Probability.Probability"
  "HOL-Library.Multiset"
  "HOL-Library.Countable_Set_Type"
begin

section ‹Preliminaries›

(* Folding a comp_fun_idem function over the union of two finite sets is the same
   as folding over the first set and then continuing with the second one. *)
lemma (in comp_fun_idem) fold_set_union:
  "⟦ finite A; finite B ⟧ ⟹ Finite_Set.fold f x (A ∪ B) = Finite_Set.fold f (Finite_Set.fold f x A) B"
by(induction A arbitrary: x rule: finite_induct)(simp_all add: fold_insert_idem2 del: fold_insert_idem)

(* fset counterpart of fold_set_union, lifted via transfer. *)
lemma (in comp_fun_idem) ffold_set_union: "ffold f x (A |∪| B) = ffold f (ffold f x A) B"
including fset.lifting by(transfer fixing: f)(rule fold_set_union)

(* Composing the total relation with itself yields the total relation. *)
lemma relcompp_top_top [simp]: "top OO top = top"
by(auto simp add: fun_eq_iff)

(* Attribute [locale_witness]: registers the theorem via Locale.witness_add.
   It is attached to the locale interpretation lemmas further below. *)
attribute_setup locale_witness = ‹Scan.succeed Locale.witness_add›

(* Collection of defining equations for the overloaded monad operations;
   for example, bind_option's equation is tagged [monad_unfold] below. *)
named_theorems monad_unfold "Defining equations for overloaded monad operations"

(* Parametricity (transfer) rules for type tokens, multisets and finite sets. *)
context includes lifting_syntax begin

(* Relator for type tokens: any two type tokens are related. *)
inductive rel_itself :: "'a itself ⇒ 'b itself ⇒ bool"
where "rel_itself TYPE(_) TYPE(_)"

lemma type_parametric [transfer_rule]: "rel_itself TYPE('a) TYPE('b)"
by(simp add: rel_itself.simps)

(* Multiset addition preserves rel_mset. *)
lemma plus_multiset_parametric [transfer_rule]:
  "(rel_mset A ===> rel_mset A ===> rel_mset A) (+) (+)"
  apply(rule rel_funI)+
  subgoal premises prems using prems by induction(auto intro: rel_mset_Plus)
  done

lemma Mempty_parametric [transfer_rule]: "rel_mset A {#} {#}"
  by(fact rel_mset_Zero)

(* fold_mset is parametric for related functions that satisfy comp_fun_commute. *)
lemma fold_mset_parametric:
  assumes 12: "(A ===> B ===> B) f1 f2"
  and "comp_fun_commute f1" "comp_fun_commute f2"
  shows "(B ===> rel_mset A ===> B) (fold_mset f1) (fold_mset f2)"
proof(rule rel_funI)+
  interpret f1: comp_fun_commute f1 by fact
  interpret f2: comp_fun_commute f2 by fact

  show "B (fold_mset f1 z1 X) (fold_mset f2 z2 Y)"
    if "rel_mset A X Y" "B z1 z2" for z1 z2 X Y
    using that(1) by(induction R≡A X Y)(simp_all add: that(2) 12[THEN rel_funD, THEN rel_funD])
qed

(* Induction over rel_fset: a related pair can be built up from the empty pair by
   inserting related elements, at least one of which is new to its side. *)
lemma rel_fset_induct [consumes 1, case_names empty step, induct pred: rel_fset]:
  assumes XY: "rel_fset A X Y"
    and empty: "P {||} {||}"
    and step: "⋀X Y x y. ⟦ rel_fset A X Y; P X Y; A x y; x |∉| X ∨ y |∉| Y ⟧ ⟹ P (finsert x X) (finsert y Y)"
  shows "P X Y"
proof -
  from XY obtain Z where X: "X = fst |`| Z" and Y: "Y = snd |`| Z" and Z: "fBall Z (λ(x, y). A x y)"
    unfolding fset.in_rel by auto
  from Z show ?thesis unfolding X Y
  proof(induction Z)
    case (insert xy Z)
    obtain x y where [simp]: "xy = (x, y)" by(cases xy)
    show ?case using insert
      (* If both components are already present, insertion is absorbed;
         otherwise at least one is new and the step rule applies. *)
      apply(cases "x |∈| fst |`| Z ∧ y |∈| snd |`| Z")
       apply(simp add: finsert_absorb)
      apply(auto intro!: step simp add: fset.in_rel; blast)
      done
  qed(simp add: assms)
qed

(* ffold is parametric for related functions that satisfy comp_fun_idem. *)
lemma ffold_parametric:
  assumes 12: "(A ===> B ===> B) f1 f2"
  and "comp_fun_idem f1" "comp_fun_idem f2"
  shows "(B ===> rel_fset A ===> B) (ffold f1) (ffold f2)"
proof(rule rel_funI)+
  interpret f1: comp_fun_idem f1 by fact
  interpret f2: comp_fun_idem f2 by fact

  show "B (ffold f1 z1 X) (ffold f2 z2 Y)"
    if "rel_fset A X Y" "B z1 z2" for z1 z2 X Y
    using that(1) by(induction)(simp_all add: that(2) 12[THEN rel_funD, THEN rel_funD])
qed

end

(* The set relator maps a graph relation to the graph of the image function,
   restricted to subsets of A. *)
lemma rel_set_Grp: "rel_set (BNF_Def.Grp A f) = BNF_Def.Grp {X. X ⊆ A} (image f)"
  by(auto simp add: fun_eq_iff Grp_def rel_set_def)

(* Laws for unions over countable sets (cUnion / cUNION), plus a parametricity
   rule for cUNION; several of them are proved by transfer to set. *)
context includes cset.lifting begin

lemma cUNION_assoc: "cUNION (cUNION A f) g = cUNION A (λx. cUNION (f x) g)"
  by transfer auto

lemma cUnion_cempty [simp]: "cUnion cempty = cempty"
  by transfer simp

lemma cUNION_cempty [simp]: "cUNION cempty f = cempty"
  by simp

lemma cUnion_cinsert: "cUnion (cinsert x A) = cUn x (cUnion A)"
  by transfer simp

lemma cUNION_cinsert: "cUNION (cinsert x A) f = cUn (f x) (cUNION A f)"
  by (simp add: cUnion_cinsert)

lemma cUnion_csingle [simp]: "cUnion (csingle x) = x"
  by (simp add: cUnion_cinsert)

lemma cUNION_csingle [simp]: "cUNION (csingle x) f = f x"
  by simp

lemma cUNION_csingle2 [simp]: "cUNION A csingle = A"
  by (fact cUN_csingleton)

lemma cUNION_cUn: "cUNION (cUn A B) f = cUn (cUNION A f) (cUNION B f)"
  by simp

(* cUNION relates related index sets and related families. *)
lemma cUNION_parametric [transfer_rule]: includes lifting_syntax shows
  "(rel_cset A ===> (A ===> rel_cset B) ===> rel_cset B) cUNION cUNION"
  unfolding rel_fun_def by transfer(blast intro: rel_set_UNION)

end

(* Types with at least three pairwise distinct elements. Hilbert choice picks
   three such elements; they are written with the bold-digit symbols \<one>, \<two>,
   \<three> so that the notation does not clash with the numerals 1, 2, 3. *)
locale three =
  fixes tytok :: "'a itself"
  assumes ex_three: "∃x y z :: 'a. x ≠ y ∧ x ≠ z ∧ y ≠ z"
begin

definition threes :: "'a × 'a × 'a" where
  "threes = (SOME (x, y, z). x ≠ y ∧ x ≠ z ∧ y ≠ z)"
definition three1 :: 'a ("\<one>") where "\<one> = fst threes"
definition three2 :: 'a ("\<two>") where "\<two> = fst (snd threes)"
definition three3 :: 'a ("\<three>") where "\<three> = snd (snd (threes))"

lemma three_neq_aux: "\<one> ≠ \<two>" "\<one> ≠ \<three>" "\<two> ≠ \<three>"
proof -
  have "\<one> ≠ \<two> ∧ \<one> ≠ \<three> ∧ \<two> ≠ \<three>"
    unfolding three1_def three2_def three3_def threes_def split_def
    by(rule someI_ex)(use ex_three in auto)
  then show "\<one> ≠ \<two>" "\<one> ≠ \<three>" "\<two> ≠ \<three>" by simp_all
qed

lemmas three_neq [simp] = three_neq_aux three_neq_aux[symmetric]

(* Shift relation: \<one> ↦ \<two>, \<two> ↦ \<three>. *)
inductive rel_12_23 :: "'a ⇒ 'a ⇒ bool" where
  "rel_12_23 \<one> \<two>"
| "rel_12_23 \<two> \<three>"

lemma bi_unique_rel_12_23 [simp, transfer_rule]: "bi_unique rel_12_23"
  by(auto simp add: bi_unique_def rel_12_23.simps)

(* Swap relation between \<one> and \<two>. *)
inductive rel_12_21 :: "'a ⇒ 'a ⇒ bool" where
  "rel_12_21 \<one> \<two>"
| "rel_12_21 \<two> \<one>"

lemma bi_unique_rel_12_21 [simp, transfer_rule]: "bi_unique rel_12_21"
  by(auto simp add: bi_unique_def rel_12_21.simps)

end

(* Bernoulli distributions with parameters 0 and 1 are point masses. *)
lemma bernoulli_pmf_0: "bernoulli_pmf 0 = return_pmf False"
  by(rule pmf_eqI)(simp split: split_indicator)

lemma bernoulli_pmf_1: "bernoulli_pmf 1 = return_pmf True"
  by(rule pmf_eqI)(simp split: split_indicator)

(* Negating the outcome of a Bernoulli trial swaps the parameter r with 1 - r. *)
lemma bernoulli_Not: "map_pmf Not (bernoulli_pmf r) = bernoulli_pmf (1 - r)"
  apply(rule pmf_eqI)
  (* rewrite the point argument i to ¬ ¬ i so that pmf_map_inj' applies *)
  apply(rewrite in "pmf _ □ = _" not_not[symmetric])
  apply(subst pmf_map_inj')
  apply(simp_all add: inj_on_def bernoulli_pmf.rep_eq min_def max_def)
  done

(* Two pmfs are equal if their probabilities agree at every point other than x.
   The mass at x is then determined as 1 minus the mass of the complement. *)
lemma pmf_eqI_avoid: "p = q" if "⋀i. i ≠ x ⟹ pmf p i = pmf q i"
proof(rule pmf_eqI)
  show "pmf p i = pmf q i" for i
  proof(cases "i = x")
    case [simp]: True
    have "pmf p i = measure_pmf.prob p {i}" by(simp add: measure_pmf_single)
    also have "… = 1 - measure_pmf.prob p (UNIV - {i})"
      by(subst measure_pmf.prob_compl[unfolded space_measure_pmf]) simp_all
    also have "measure_pmf.prob p (UNIV - {i}) = measure_pmf.prob q (UNIV - {i})"
      unfolding integral_pmf[symmetric] by(rule Bochner_Integration.integral_cong)(auto intro: that)
    also have "1 - … = measure_pmf.prob q {i}"
      by(subst measure_pmf.prob_compl[unfolded space_measure_pmf]) simp_all
    also have "… = pmf q i" by(simp add: measure_pmf_single)
    finally show ?thesis .
  next
    case False
    then show ?thesis by(rule that)
  qed
qed

section ‹Locales for monomorphic monads›

subsection ‹Plain monad›

(* Signatures of bind and return for a monomorphic monad 'm over values 'a. *)
type_synonym ('a, 'm) bind = "'m ⇒ ('a ⇒ 'm) ⇒ 'm"
type_synonym ('a, 'm) return = "'a ⇒ 'm"

(* Signature of a monomorphic monad, together with operations derived from it. *)
locale monad_base =
  fixes return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
begin

(* sequence xs f binds the computations in xs from left to right and passes
   the list of their results, in order, to f. *)
primrec sequence :: "'m list ⇒ ('a list ⇒ 'm) ⇒ 'm"
where
  "sequence [] f = f []"
| "sequence (x # xs) f = bind x (λa. sequence xs (f ∘ (#) a))"

(* lift f x applies f to the result of x. *)
definition lift :: "('a ⇒ 'a) ⇒ 'm ⇒ 'm"
where "lift f x = bind x (λx. return (f x))"

end

(* Use the equations of the derived monad_base operations for code generation. *)
declare
  monad_base.sequence.simps [code]
  monad_base.lift_def [code]

(* Parametricity of sequence and lift in the underlying monad operations. *)
context includes lifting_syntax begin

lemma sequence_parametric [transfer_rule]:
  "((M ===> (A ===> M) ===> M) ===> list_all2 M ===> (list_all2 A ===> M) ===> M) monad_base.sequence monad_base.sequence"
unfolding monad_base.sequence_def[abs_def] by transfer_prover

lemma lift_parametric [transfer_rule]:
  "((A ===> M) ===> (M ===> (A ===> M) ===> M) ===> (A ===> A) ===> M ===> M) monad_base.lift monad_base.lift"
unfolding monad_base.lift_def by transfer_prover

end

(* The three monad laws for a monomorphic monad. *)
locale monad = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  assumes bind_assoc: "⋀(x :: 'm) f g. bind (bind x f) g = bind x (λy. bind (f y) g)"
  and return_bind: "⋀x f. bind (return x) f = f x"
  and bind_return: "⋀x. bind x return = x"
begin

lemma bind_lift [simp]: "bind (lift f x) g = bind x (g ∘ f)"
by(simp add: lift_def bind_assoc return_bind o_def)

lemma lift_bind [simp]: "lift f (bind m g) = bind m (λx. lift f (g x))"
by(simp add: lift_def bind_assoc)

end

subsection ‹State›

(* Continuation-passing signatures of the state operations:
   get passes a state to its continuation; put takes a new state and a continuation. *)
type_synonym ('s, 'm) get = "('s ⇒ 'm) ⇒ 'm"
type_synonym ('s, 'm) put = "'s ⇒ 'm ⇒ 'm"

(* Signature of a state monad. *)
locale monad_state_base = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  fixes get :: "('s, 'm) get"
  and put :: "('s, 'm) put"
begin

(* update f m replaces the current state s by f s and continues with m. *)
definition update :: "('s ⇒ 's) ⇒ 'm ⇒ 'm"
where "update f m = get (λs. put (f s) m)"

end

(* Use the definition of update as a code equation. *)
declare monad_state_base.update_def [code]

(* Parametricity of update in the get and put operations. *)
lemma update_parametric [transfer_rule]: includes lifting_syntax shows  
  "(((S ===> M) ===> M) ===> (S ===> M ===> M) ===> (S ===> S) ===> M ===> M)
   monad_state_base.update monad_state_base.update"
unfolding monad_state_base.update_def by transfer_prover

(* State monad laws: how get and put interact with each other and with bind. *)
locale monad_state = monad_state_base return bind get put + monad return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and get :: "('s, 'm) get"
  and put :: "('s, 'm) put"
  +
  assumes put_get: "⋀f. put s (get f) = put s (f s)"
  and get_get: "⋀f. get (λs. get (f s)) = get (λs. f s s)"
  and put_put: "put s (put s' m) = put s' m"
  and get_put: "get (λs. put s m) = m"
  and get_const: "⋀m. get (λ_. m) = m"
  and bind_get: "⋀f g. bind (get f) g = get (λs. bind (f s) g)"
  and bind_put: "⋀f. bind (put s m) f = put s (bind m f)"
begin

lemma put_update: "put s (update f m) = put (f s) m"
by(simp add: update_def put_get put_put)

lemma update_put: "update f (put s m) = put s m"
by(simp add: update_def put_put get_const)

lemma bind_update: "bind (update f m) g = update f (bind m g)"
by(simp add: update_def bind_get bind_put)

lemma update_get: "update f (get g) = get (update f ∘ g ∘ f)"
by(simp add: update_def put_get get_get o_def)

lemma update_const: "update (λ_. s) m = put s m"
by(simp add: update_def get_const)

lemma update_update: "update f (update g m) = update (g ∘ f) m"
by(simp add: update_def put_get put_put)

lemma update_id: "update id m = m"
by(simp add: update_def get_put)

end

subsection ‹Failure›

(* A failure is simply a distinguished monadic value. *)
type_synonym 'm fail = "'m"

(* Signature of a monad with failure. *)
locale monad_fail_base = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  fixes fail :: "'m fail"
begin

(* assert P m fails unless the result of m satisfies P. *)
definition assert :: "('a ⇒ bool) ⇒ 'm ⇒ 'm"
where "assert P m = bind m (λx. if P x then return x else fail)"

end

(* Use the definition of assert as a code equation. *)
declare monad_fail_base.assert_def [code]

(* Parametricity of assert in return, bind and fail. *)
lemma assert_parametric [transfer_rule]: includes lifting_syntax shows
  "((A ===> M) ===> (M ===> (A ===> M) ===> M) ===> M ===> (A ===> (=)) ===> M ===> M)
   monad_fail_base.assert monad_fail_base.assert"
unfolding monad_fail_base.assert_def by transfer_prover

(* Failure laws: fail is a left zero of bind. *)
locale monad_fail = monad_fail_base return bind fail + monad return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and fail :: "'m fail"
  +
  assumes fail_bind: "⋀f. bind fail f = fail"
begin

lemma assert_fail: "assert P fail = fail"
by(simp add: assert_def fail_bind)

end

subsection ‹Exception›

(* catch m h runs m and falls back to handler h; see the monad_catch laws. *)
type_synonym 'm catch = "'m ⇒ 'm ⇒ 'm"

(* Signature of a failure monad extended with a catch operation. *)
locale monad_catch_base = monad_fail_base return bind fail
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and fail :: "'m fail"
  +
  fixes catch :: "'m catch"

(* Exception laws: a returned value is not caught, a failure is replaced by the
   handler, fail is a neutral handler, and catch is associative. *)
locale monad_catch = monad_catch_base return bind fail catch + monad_fail return bind fail
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and fail :: "'m fail"
  and catch :: "'m catch"
  +
  assumes catch_return: "catch (return x) m = return x"
  and catch_fail: "catch fail m = m"
  and catch_fail2: "catch m fail = m"
  and catch_assoc: "catch (catch m m') m'' = catch m (catch m' m'')"

(* Exceptions combined with state: catch commutes with get and put, and hence with update. *)
locale monad_catch_state = monad_catch return bind fail catch + monad_state return bind get put
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and fail :: "'m fail"
  and catch :: "'m catch"
  and get :: "('s, 'm) get"
  and put :: "('s, 'm) put"
  +
  assumes catch_get: "catch (get f) m = get (λs. catch (f s) m)"
  and catch_put: "catch (put s m) m' = put s (catch m m')"
begin

lemma catch_update: "catch (update f m) m' = update f (catch m m')"
by(simp add: update_def catch_get catch_put)

end

subsection ‹Reader›

text ‹As ask takes a continuation, we have to restate the monad laws for ask›

(* ask passes the environment of type 'r to a continuation. *)
type_synonym ('r, 'm) ask = "('r ⇒ 'm) ⇒ 'm"

(* Signature of a reader monad. *)
locale monad_reader_base = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  fixes ask :: "('r, 'm) ask"

(* Reader monad laws: repeated asks collapse, a constant continuation ignores
   the environment, and ask commutes with bind on either side. *)
locale monad_reader = monad_reader_base return bind ask + monad return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and ask :: "('r, 'm) ask"
  +
  assumes ask_ask: "⋀f. ask (λr. ask (f r)) = ask (λr. f r r)"
  and ask_const: "ask (λ_. m) = m"
  and bind_ask: "⋀f g. bind (ask f) g = ask (λr. bind (f r) g)"
  and bind_ask2: "⋀f. bind m (λx. ask (f x)) = ask (λr. bind m (λx. f x r))"
begin

lemma ask_bind: "ask (λr. bind (f r) (g r)) = bind (ask f) (λx. ask (λr. g r x))"
by(simp add: bind_ask bind_ask2 ask_ask)

end

(* Reader combined with state: ask commutes with get and with put. *)
locale monad_reader_state =
  monad_reader return bind ask +
  monad_state return bind get put
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and ask :: "('r, 'm) ask"
  and get :: "('s, 'm) get"
  and put :: "('s, 'm) put"
  +
  assumes ask_get: "⋀f. ask (λr. get (f r)) = get (λs. ask (λr. f r s))"
  and put_ask: "⋀f. put s (ask f) = ask (λr. put s (f r))"

subsection ‹Probability›

(* sample p f draws from the distribution p and continues with f. *)
type_synonym ('p, 'm) sample = "'p pmf ⇒ ('p ⇒ 'm) ⇒ 'm"

(* Signature of a probabilistic monad with a sampling operation. *)
locale monad_prob_base = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  fixes sample :: "('p, 'm) sample"

(* Laws for sampling: compatibility with return_pmf and bind_pmf, commutation
   of samples, distribution of bind, and parametricity in the sample space. *)
locale monad_prob = monad return bind + monad_prob_base return bind sample
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and sample :: "('p, 'm) sample"
  +
  assumes sample_const: "⋀p m. sample p (λ_. m) = m"
  and sample_return_pmf: "⋀x f. sample (return_pmf x) f = f x"
  and sample_bind_pmf: "⋀p f g. sample (bind_pmf p f) g = sample p (λx. sample (f x) g)"
  and sample_commute: "⋀p q f. sample p (λx. sample q (f x)) = sample q (λy. sample p (λx. f x y))"
  ― ‹We'd like to state that we can combine independent samples rather than just commute them, but that's not possible with a monomorphic sampling operation›
  and bind_sample1: "⋀p f g. bind (sample p f) g = sample p (λx. bind (f x) g)"
  and bind_sample2: "⋀m f p. bind m (λy. sample p (f y)) = sample p (λx. bind m (λy. f y x))"
  and sample_parametric: "⋀R. bi_unique R ⟹ rel_fun (rel_pmf R) (rel_fun (rel_fun R (=)) (=)) sample sample"
begin

(* Congruence rule: the continuation matters only on the support of p. *)
lemma sample_cong: "(⋀x. x ∈ set_pmf p ⟹ f x = g x) ⟹ sample p f = sample q g" if "p = q"
  by(rule sample_parametric[where R="eq_onp (λx. x ∈ set_pmf p)", THEN rel_funD, THEN rel_funD])
    (simp_all add: bi_unique_def eq_onp_def rel_fun_def pmf.rel_refl_strong that)

end

text ‹We can implement binary probabilistic choice using @{term sample} provided that the sample space
  contains at least three elements.›

(* Binary probabilistic choice implemented by sample over a sample space with at
   least three distinct elements. The elements are referred to by the constant
   names three1/three2/three3 rather than by notation, so that this locale does
   not depend on the notation declared in locale three. *)
locale monad_prob3 = monad_prob return bind sample + three "TYPE('p)"
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and sample :: "('p, 'm) sample"
begin

(* pchoose r m m' continues with m if a Bernoulli(r) trial succeeds and with m' otherwise. *)
definition pchoose :: "real ⇒ 'm ⇒ 'm ⇒ 'm" where
  "pchoose r m m' = sample (map_pmf (λb. if b then three1 else three2) (bernoulli_pmf r)) (λx. if x = three1 then m else m')"

abbreviation pchoose_syntax :: "'m ⇒ real ⇒ 'm ⇒ 'm" ("_ ◃ _ ▹ _" [100, 0, 100] 99) where
  "m ◃ r ▹ m' ≡ pchoose r m m'"

lemma pchoose_0: "m ◃ 0 ▹ m' = m'"
  by(simp add: pchoose_def bernoulli_pmf_0 sample_return_pmf)

lemma pchoose_1: "m ◃ 1 ▹ m' = m"
  by(simp add: pchoose_def bernoulli_pmf_1 sample_return_pmf)

lemma pchoose_idemp: "m ◃ r ▹ m = m"
  by(simp add: pchoose_def sample_const)

lemma pchoose_bind1: "bind (m ◃ r ▹ m') f = bind m f ◃ r ▹ bind m' f"
  by(simp add: pchoose_def bind_sample1 if_distrib[where f="λm. bind m _"])

lemma pchoose_bind2: "bind m (λx. f x ◃ p ▹ g x) = bind m f ◃ p ▹ bind m g"
  by(auto simp add: pchoose_def bind_sample2 intro!: arg_cong2[where f=sample])

(* Swapping the branches complements the probability; proved by relating the
   two sample spaces via the swap relation rel_12_21. *)
lemma pchoose_commute: "m ◃ 1 - r ▹ m' = m' ◃ r ▹ m"
  apply(simp add: pchoose_def bernoulli_Not[symmetric] pmf.map_comp o_def)
  apply(rule sample_parametric[where R=rel_12_21, THEN rel_funD, THEN rel_funD])
  subgoal by(simp)
  subgoal by(rule pmf.map_transfer[where Rb="(=)", THEN rel_funD, THEN rel_funD])
            (simp_all add: rel_fun_def rel_12_21.simps pmf.rel_eq)
  subgoal by(simp add: rel_fun_def rel_12_21.simps)
  done

(* Associativity of probabilistic choice for parameters clipped to [0, 1].
   Both sides are written as a single sample over a three-point distribution;
   the two distributions agree everywhere except possibly at three2, which
   suffices by pmf_eqI_avoid. *)
lemma pchoose_assoc: "m ◃ p ▹ (m' ◃ q ▹ m'') = (m ◃ r ▹ m') ◃ s ▹ m''" (is "?lhs = ?rhs")
  if "min 1 (max 0 p) = min 1 (max 0 r) * min 1 (max 0 s)"
  and "1 - min 1 (max 0 s) = (1 - min 1 (max 0 p)) * (1 - min 1 (max 0 q))"
proof -
  let ?f = "(λx. if x = three1 then m else if x = three2 then m' else m'')"
  let ?p = "bind_pmf (map_pmf (λb. if b then three1 else three2) (bernoulli_pmf p))
     (λx. if x = three1 then return_pmf three1 else map_pmf (λb. if b then three2 else three3) (bernoulli_pmf q))"
  let ?q = "bind_pmf (map_pmf (λb. if b then three1 else three2) (bernoulli_pmf s))
     (λx. if x = three1 then map_pmf (λb. if b then three1 else three2) (bernoulli_pmf r) else return_pmf three3)"

  have [simp]: "{x. ¬ x} = {False}" "{x. x} = {True}" by auto

  have "?lhs = sample ?p ?f"
    by(auto simp add: pchoose_def sample_bind_pmf if_distrib[where f="λx. sample x _"] sample_return_pmf rel_fun_def rel_12_23.simps pmf.rel_eq cong: if_cong intro!: sample_cong[OF refl] sample_parametric[where R="rel_12_23", THEN rel_funD, THEN rel_funD] pmf.map_transfer[where Rb="(=)", THEN rel_funD, THEN rel_funD])
  also have "?p = ?q"
  proof(rule pmf_eqI_avoid)
    fix i :: "'p"
    assume "i ≠ three2"
    then consider (one) "i = three1" | (three) "i = three3" | (other) "i ≠ three1" "i ≠ three2" "i ≠ three3" by metis
    then show "pmf ?p i = pmf ?q i"
    proof cases
      case [simp]: one
      have "pmf ?p i = measure_pmf.expectation (map_pmf (λb. if b then three1 else three2) (bernoulli_pmf p)) (indicator {three1})"
        unfolding pmf_bind
        by(rule arg_cong2[where f=measure_pmf.expectation, OF refl])(auto simp add: fun_eq_iff pmf_eq_0_set_pmf)
      also have "… = min 1 (max 0 p)"
        by(simp add: vimage_def)(simp add: measure_pmf_single bernoulli_pmf.rep_eq)
      also have "… = min 1 (max 0 s) * min 1 (max 0 r)" using that(1) by simp
      also have "… = measure_pmf.expectation (bernoulli_pmf s)
            (λx. indicator {True} x * pmf (map_pmf (λb. if b then three1 else three2) (bernoulli_pmf r)) three1)"
        by(simp add: pmf_map vimage_def measure_pmf_single)(simp add:  bernoulli_pmf.rep_eq)
      also have "… = pmf ?q i"
        unfolding pmf_bind integral_map_pmf
        by(rule arg_cong2[where f=measure_pmf.expectation, OF refl])(auto simp add: fun_eq_iff pmf_eq_0_set_pmf)
      finally show ?thesis .
    next
      case [simp]: three
      have "pmf ?p i = measure_pmf.expectation (bernoulli_pmf p)
            (λx. indicator {False} x * pmf (map_pmf (λb. if b then three2 else three3) (bernoulli_pmf q)) three3)"
        unfolding pmf_bind integral_map_pmf
        by(rule arg_cong2[where f=measure_pmf.expectation, OF refl])(auto simp add: fun_eq_iff pmf_eq_0_set_pmf)
      also have "… = (1 - min 1 (max 0 p)) * (1 - min 1 (max 0 q))"
        by(simp add: pmf_map vimage_def measure_pmf_single)(simp add:  bernoulli_pmf.rep_eq)
      also have "… = 1 - min 1 (max 0 s)" using that(2) by simp
      also have "… = measure_pmf.expectation (map_pmf (λb. if b then three1 else three2) (bernoulli_pmf s)) (indicator {three2})"
        by(simp add: vimage_def)(simp add: measure_pmf_single bernoulli_pmf.rep_eq)
      also have "… = pmf ?q i"
        unfolding pmf_bind
        by(rule Bochner_Integration.integral_cong_AE)(auto simp add: fun_eq_iff pmf_eq_0_set_pmf AE_measure_pmf_iff)
      finally show ?thesis .
    next
      case other
      then have "pmf ?p i = 0" "pmf ?q i = 0" by(auto simp add: pmf_eq_0_set_pmf)
      then show ?thesis by simp
    qed
  qed
  also have "sample ?q ?f = ?rhs"
    by(auto simp add: pchoose_def sample_bind_pmf if_distrib[where f="λx. sample x _"] sample_return_pmf cong: if_cong intro!: sample_cong[OF refl])
  finally show ?thesis .
qed

(* Variant of pchoose_assoc for parameters that are known to lie in [0, 1]. *)
lemma pchoose_assoc': "m ◃ p ▹ (m' ◃ q ▹ m'') = (m ◃ r ▹ m') ◃ s ▹ m''"
  if "p = r * s" and "1 - s = (1 - p) * (1 - q)"
  and "0 ≤ p" "p ≤ 1" "0 ≤ q" "q ≤ 1" "0 ≤ r" "r ≤ 1" "0 ≤ s" "s ≤ 1"
  by(rule pchoose_assoc; use that in simp add: min_def max_def)

end

(* Probability combined with state: sample commutes with get (assumed), and
   hence with put and update (derived). *)
locale monad_state_prob = monad_state return bind get put + monad_prob return bind sample
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and get :: "('s, 'm) get"
  and put :: "('s, 'm) put"
  and sample :: "('p, 'm) sample"
  +
  assumes sample_get: "sample p (λx. get (f x)) = get (λs. sample p (λx. f x s))"
begin

(* Derived from bind_sample2: express put s as a bind whose continuation ignores
   the dummy result UU. *)
lemma sample_put: "sample p (λx. put s (m x)) = put s (sample p m)"
proof -
  fix UU
  have "sample p (λx. put s (m x)) = sample p (λx. bind (put s (return UU)) (λ_. m x))"
    by(simp add: bind_put return_bind)
  also have "… = bind (put s (return UU)) (λ_. sample p m)"
    by(simp add: bind_sample2)
  also have "… = put s (sample p m)"
    by(simp add: bind_put return_bind)
  finally show ?thesis .
qed

lemma sample_update: "sample p (λx. update f (m x)) = update f (sample p m)"
by(simp add: update_def sample_get sample_put)

end

subsection ‹Nondeterministic choice›

subsubsection ‹Binary choice›

(* Binary nondeterministic choice between two computations. *)
type_synonym 'm alt = "'m ⇒ 'm ⇒ 'm"

(* Signature of a monad with binary nondeterministic choice. *)
locale monad_alt_base = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  fixes alt :: "'m alt"

(* Choice laws: alt is associative, and bind distributes over alt in its first argument. *)
locale monad_alt = monad return bind + monad_alt_base return bind alt
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and alt :: "'m alt"
  + ― ‹Laws taken from Gibbons, Hinze: Just do it›
  assumes alt_assoc: "alt (alt m1 m2) m3 = alt m1 (alt m2 m3)"
  and bind_alt1: "bind (alt m m') f = alt (bind m f) (bind m' f)"

(* Failure combined with choice: fail is a left and a right unit of alt. *)
locale monad_fail_alt = monad_fail return bind fail + monad_alt return bind alt
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and fail :: "'m fail"
  and alt :: "'m alt"
  +
  assumes alt_fail1: "alt fail m = m"
  and alt_fail2: "alt m fail = m"
begin

lemma assert_alt: "assert P (alt m m') = alt (assert P m) (assert P m')"
by(simp add: assert_def bind_alt1)

end

(* State combined with choice: alt commutes with get and put, and hence with update. *)
locale monad_state_alt = monad_state return bind get put + monad_alt return bind alt
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and get :: "('s, 'm) get"
  and put :: "('s, 'm) put"
  and alt :: "'m alt"
  +
  assumes alt_get: "alt (get f) (get g) = get (λx. alt (f x) (g x))"
  and alt_put: "alt (put s m) (put s m') = put s (alt m m')"
  ― ‹Unlike for @{term sample}, we must require both @{text alt_get} and @{text alt_put} because
  we do not require that @{term bind} right-distributes over @{term alt}.›
begin

lemma alt_update: "alt (update f m) (update f m') = update f (alt m m')"
by(simp add: update_def alt_get alt_put)

end

subsubsection ‹Countable choice›

(* Countable nondeterministic choice: pick an element of a countable set and continue. *)
type_synonym ('c, 'm) altc = "'c cset ⇒ ('c ⇒ 'm) ⇒ 'm"

(* Signature of a monad with countable choice; failure is defined as choice
   over the empty set. *)
locale monad_altc_base = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  fixes altc :: "('c, 'm) altc"
begin

definition fail :: "'m fail" where "fail = altc cempty (λ_. undefined)"

end

(* Use the definition of fail as a code equation. *)
declare monad_altc_base.fail_def [code]

(* Laws for countable choice: distribution of bind, choice over a singleton
   and over a union of sets, and parametricity in the choice type. *)
locale monad_altc = monad return bind + monad_altc_base return bind altc
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and altc :: "('c, 'm) altc"
  +
  assumes bind_altc1: "⋀C g f. bind (altc C g) f = altc C (λc. bind (g c) f)"
  and altc_single: "⋀x f. altc (csingle x) f = f x"
  and altc_cUNION: "⋀C f g. altc (cUNION C f) g = altc C (λx. altc (f x) g)"
  ― ‹We do not assume @{text altc_const} like for @{text sample} because the choice set might be empty›
  and altc_parametric: "⋀R. bi_unique R ⟹ rel_fun (rel_cset R) (rel_fun (rel_fun R (=)) (=)) altc altc"
begin

(* Congruence rule: the continuation matters only on the elements of C. *)
lemma altc_cong: "cBall C (λx. f x = g x) ⟹ altc C f = altc C g"
  apply(rule altc_parametric[where R="eq_onp (λx. cin x C)", THEN rel_funD, THEN rel_funD])
  subgoal by(simp add: bi_unique_def eq_onp_def)
  subgoal by(simp add: cset.rel_eq_onp eq_onp_same_args pred_cset_def cin_def)
  subgoal by(simp add: rel_fun_def eq_onp_def cBall_def cin_def)
  done

(* Choice over the empty set is a failure that is a left zero of bind. *)
lemma monad_fail [locale_witness]: "monad_fail return bind fail"
proof
  show "bind fail f = fail" for f
    by(simp add: fail_def bind_altc1 cong: altc_cong)
qed

end

text ‹We can implement @{text alt} via @{text altc} only if we know that there are sufficiently
  many elements in the choice type @{typ 'c}. For the associativity law, we need at least
  three elements.›

(* Binary choice implemented as countable choice over two of three distinct
   elements of 'c. The elements are referred to by the constant names
   three1/three2/three3 so that this locale does not depend on the notation
   declared in locale three. *)
locale monad_altc3 = monad_altc return bind altc + three "TYPE('c)"
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and altc :: "('c, 'm) altc"
begin

definition alt :: "'m alt"
where "alt m1 m2 = altc (cinsert three1 (csingle three2)) (λc. if c = three1 then m1 else m2)"

(* Associativity: both nestings equal one altc over {three1, three2, three3},
   written as a cUNION of two-element and one-element sets; the inner choice of
   the right nesting is renamed via the converse of rel_12_23. *)
lemma monad_alt: "monad_alt return bind alt"
proof
  show "bind (alt m m') f = alt (bind m f) (bind m' f)" for m m' f
    by(simp add: alt_def bind_altc1 if_distrib[where f="λm. bind m _"])

  fix m1 m2 m3 :: 'm
  let ?C = "cUNION (cinsert three1 (csingle three2)) (λc. if c = three1 then cinsert three1 (csingle three2) else csingle three3)"
  let ?D = "cUNION (cinsert three1 (csingle three2)) (λc. if c = three1 then csingle three1 else cinsert three2 (csingle three3))"
  let ?f = "λc. if c = three1 then m1 else if c = three2 then m2 else m3"
  have "alt (alt m1 m2) m3 = altc ?C ?f"
    by (simp only: altc_cUNION) (auto simp add: alt_def altc_single intro!: altc_cong)
  also have "?C = ?D" including cset.lifting by transfer(auto simp add: insert_commute)
  also have "altc ?D ?f = alt m1 (alt m2 m3)"
    apply (simp only: altc_cUNION)
    apply (clarsimp simp add: alt_def altc_single intro!: altc_cong)
    apply (rule altc_parametric [where R="conversep rel_12_23", THEN rel_funD, THEN rel_funD])
    subgoal by simp
    subgoal including cset.lifting by transfer
      (simp add: rel_set_def rel_12_23.simps)
    subgoal by (simp add: rel_fun_def rel_12_23.simps)
    done
  finally show "alt (alt m1 m2) m3 = alt m1 (alt m2 m3)" .
qed

end

(* State combined with countable choice: altc commutes with get and put. *)
locale monad_state_altc =
  monad_state return bind get put +
  monad_altc return bind altc
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and get :: "('s, 'm) get"
  and put :: "('s, 'm) put"
  and altc :: "('c, 'm) altc"
  +
  assumes altc_get: "⋀C f. altc C (λc. get (f c)) = get (λs. altc C (λc. f c s))"
  and altc_put: "⋀C f. altc C (λc. put s (f c)) = put s (altc C f)"

subsection ‹Writer monad›

(* tell w m emits w and continues with m. *)
type_synonym ('w, 'm) tell = "'w ⇒ 'm ⇒ 'm"

(* Signature of a writer monad. *)
locale monad_writer_base = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  fixes tell :: "('w, 'm) tell"

(* Writer law: bind pushes into the continuation of tell. *)
locale monad_writer = monad_writer_base return bind tell + monad return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and tell :: "('w, 'm) tell"
  +
  assumes bind_tell: "⋀w m f. bind (tell w m) f = tell w (bind m f)"

subsection ‹Resumption monad›

(* pause out c emits out and continues with c applied to the next input. *)
type_synonym ('o, 'i, 'm) pause = "'o ⇒ ('i ⇒ 'm) ⇒ 'm"

(* Signature of a resumption monad. *)
locale monad_resumption_base = monad_base return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  fixes pause :: "('o, 'i, 'm) pause"

(* Resumption law: bind pushes into every continuation of pause. *)
locale monad_resumption = monad_resumption_base return bind pause + monad return bind 
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  and pause :: "('o, 'i, 'm) pause"
  +
  assumes bind_pause: "bind (pause out c) f = pause out (λi. bind (c i) f)"

subsection ‹Commutative monad›

(* Commutative monads: two computations whose results are both used can be swapped. *)
locale monad_commute = monad return bind 
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  assumes bind_commute: "bind m (λx. bind m' (f x)) = bind m' (λy. bind m (λx. f x y))"

subsection ‹Discardable monad›

(* Discardable monads: a computation whose result is ignored can be dropped. *)
locale monad_discard = monad return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  assumes bind_const: "bind m (λ_. m') = m'"

subsection ‹Duplicable monad›

(* Duplicable monads: running m twice and using both results equals running it
   once and using its result twice. *)
locale monad_duplicate = monad return bind
  for return :: "('a, 'm) return"
  and bind :: "('a, 'm) bind"
  +
  assumes bind_duplicate: "bind m (λx. bind m (f x)) = bind m (λx. f x x)"

section ‹Monad implementations›

subsection ‹Identity monad›

text ‹We need a type constructor such that we can overload the monad operations›

(* Identity monad as a one-constructor datatype; extract projects the wrapped value. *)
datatype 'a id = return_id ("extract": 'a)

lemmas return_id_parametric = id.ctr_transfer

(* Characterisation of rel_id when one side is a return_id value. *)
lemma rel_id_unfold:
  "rel_id A (return_id x) m' ⟷ (∃x'. m' = return_id x' ∧ A x x')"
  "rel_id A m (return_id x') ⟷ (∃x. m = return_id x ∧ A x x')"
  subgoal by(cases m'; simp)
  subgoal by(cases m; simp)
  done

(* Two id values are related if their extracted values are related. *)
lemma rel_id_expand: "M (extract m) (extract m') ⟹ rel_id M m m'"
  by(cases m; cases m'; simp)

subsubsection ‹Plain monad›

(* bind of the identity monad: apply the continuation to the wrapped value. *)
primrec bind_id :: "('a, 'a id) bind"
where "bind_id (return_id x) f = f x"

lemma extract_bind [simp]: "extract (bind_id x f) = extract (f (extract x))"
by(cases x) simp

lemma bind_id_parametric [transfer_rule]: includes lifting_syntax shows
  "(rel_id A ===> (A ===> rel_id A) ===> rel_id A) bind_id bind_id"
unfolding bind_id_def by transfer_prover

(* The identity monad satisfies the monad laws; each law is checked via id.expand. *)
lemma monad_id [locale_witness]: "monad return_id bind_id"
proof
  show "bind_id (bind_id x f) g = bind_id x (λx. bind_id (f x) g)"
    for x :: "'a id" and f :: "'a ⇒ 'a id" and g :: "'a ⇒ 'a id"
    by(rule id.expand) simp
  show "bind_id (return_id x) f = f x" for f :: "'a ⇒ 'a id" and x
    by(rule id.expand) simp
  show "bind_id x return_id = x" for x :: "'a id"
    by(rule id.expand) simp
qed

(* The identity monad is commutative, discardable and duplicable. *)
lemma monad_commute_id [locale_witness]: "monad_commute return_id bind_id"
proof
  show "bind_id m (λx. bind_id m' (f x)) = bind_id m' (λy. bind_id m (λx. f x y))" for m m' :: "'a id" and f
    by(rule id.expand) simp
qed

lemma monad_discard_id [locale_witness]: "monad_discard return_id bind_id"
proof
  show "bind_id m (λ_. m') = m'" for m m' :: "'a id" by(rule id.expand) simp
qed

lemma monad_duplicate_id [locale_witness]: "monad_duplicate return_id bind_id"
proof
  show "bind_id m (λx. bind_id m (f x)) = bind_id m (λx. f x x)" for m :: "'a id" and f
    by(rule id.expand) simp
qed

subsection ‹Probability monad›

text ‹We don't know of a sensible probability monad transformer, so we define the plain probability monad.›

(* The probability monad is the pmf type with return_pmf and bind_pmf. *)
type_synonym 'a prob = "'a pmf"

lemma monad_prob [locale_witness]: "monad return_pmf bind_pmf"
by unfold_locales(simp_all add: bind_assoc_pmf bind_return_pmf bind_return_pmf')

(* pmf is a probabilistic monad in which sampling is bind_pmf itself. *)
lemma monad_prob_prob [locale_witness]: "monad_prob return_pmf bind_pmf bind_pmf"
  including lifting_syntax
proof
  show "bind_pmf p (λ_. m) = m" for p :: "'b pmf" and m :: "'a prob"
    by(rule bind_pmf_const)
  show "bind_pmf (return_pmf x) f = f x" for f :: "'b ⇒ 'a prob" and x by(rule bind_return_pmf)
  show "bind_pmf (bind_pmf p f) g = bind_pmf p (λx. bind_pmf (f x) g)"
    for p :: "'b pmf" and f :: "'b ⇒ 'b pmf" and g :: "'b ⇒ 'a prob"
    by(rule bind_assoc_pmf)
  show "bind_pmf p (λx. bind_pmf q (f x)) = bind_pmf q (λy. bind_pmf p (λx. f x y))"
    for p q :: "'b pmf" and f :: "'b ⇒ 'b ⇒ 'a prob" by(rule bind_commute_pmf)
  show "bind_pmf (bind_pmf p f) g = bind_pmf p (λx. bind_pmf (f x) g)"
    for p :: "'b pmf" and f :: "'b ⇒ 'a prob" and g :: "'a ⇒ 'a prob"
    by(simp add: bind_assoc_pmf)
  show "bind_pmf m (λy. bind_pmf p (f y)) = bind_pmf p (λx. bind_pmf m (λy. f y x))"
    for m :: "'a prob" and p :: "'b pmf" and f :: "'a ⇒ 'b ⇒ 'a prob"
    by(rule bind_commute_pmf)
  show "(rel_pmf R ===> (R ===> (=)) ===> (=)) bind_pmf bind_pmf" for R :: "'b ⇒ 'b ⇒ bool"
    by transfer_prover
qed

(* The probability monad is commutative. *)
lemma monad_commute_prob [locale_witness]: "monad_commute return_pmf bind_pmf"
proof
  show "bind_pmf m (λx. bind_pmf m' (f x)) = bind_pmf m' (λy. bind_pmf m (λx. f x y))"
    for m m' :: "'a prob" and f :: "'a ⇒ 'a ⇒ 'a prob"
    by(rule bind_commute_pmf)
qed

(* The probability monad is discardable: binding into a constant continuation drops the sample. *)
lemma monad_discard_prob [locale_witness]: "monad_discard return_pmf bind_pmf"
proof
  show "bind_pmf m (λ_. m') = m'" for m m' :: "'a pmf" by(simp)
qed

subsection ‹Resumption›

text ‹
  We cannot define a resumption monad transformer because the codatatype recursion would have to
  go through a type variable. If we plug in something like unbounded non-determinism, then the
  HOL type does not exist.
›

(* Possibly infinite resumptions: either a final result, or an output together
   with a continuation that takes the next input. *)
codatatype ('o, 'i, 'a) resumption = is_Done: Done (result: 'a) | Pause ("output": 'o) (resume: "'i ⇒ ('o, 'i, 'a) resumption")

subsubsection ‹Plain monad›

(* Monad operations on resumptions: return is Done; bind continues with f at
   Done and otherwise re-emits the output and binds inside the continuation;
   pause is Pause. *)
definition return_resumption :: "'a ⇒ ('o, 'i, 'a) resumption"
where "return_resumption = Done"

primcorec bind_resumption :: "('o, 'i, 'a) resumption ⇒ ('a ⇒ ('o, 'i, 'a) resumption) ⇒ ('o, 'i, 'a) resumption"
where "bind_resumption m f = (if is_Done m then f (result m) else Pause (output m) (λi. bind_resumption (resume m i) f))"

definition pause_resumption :: "'o ⇒ ('i ⇒ ('o, 'i, 'a) resumption) ⇒ ('o, 'i, 'a) resumption"
where "pause_resumption = Pause"

(* Selector simplification rules for return_resumption. *)
lemma is_Done_return_resumption [simp]: "is_Done (return_resumption x)"
by(simp add: return_resumption_def)

lemma result_return_resumption [simp]: "result (return_resumption x) = x"
by(simp add: return_resumption_def)

(* Resumptions form a monad; associativity and right identity are proved by coinduction. *)
lemma monad_resumption [locale_witness]: "monad return_resumption bind_resumption"
proof
  show "bind_resumption (bind_resumption x f) g = bind_resumption x (λy. bind_resumption (f y) g)"
    for x :: "('o, 'i, 'a) resumption" and f g
    by(coinduction arbitrary: x f g rule: resumption.coinduct_strong) auto
  show "bind_resumption (return_resumption x) f = f x" for x and f :: "'a ⇒ ('o, 'i, 'a) resumption"
    by(rule resumption.expand)(simp_all add: return_resumption_def)
  show "bind_resumption x return_resumption = x" for x :: "('o, 'i, 'a) resumption"
    by(coinduction arbitrary: x rule: resumption.coinduct_strong) auto
qed

(* bind distributes into the continuation of pause_resumption. *)
lemma monad_resumption_resumption [locale_witness]:
  "monad_resumption return_resumption bind_resumption pause_resumption"
proof
  show "bind_resumption (pause_resumption out c) f = pause_resumption out (λi. bind_resumption (c i) f)"
    for out c and f :: "'a ⇒ ('o, 'i, 'a) resumption"
    by(rule resumption.expand)(simp_all add: pause_resumption_def)
qed

subsection ‹Failure and exception monad transformer›

text ‹
  The phantom type variable @{typ 'a} is needed to avoid hidden polymorphism when overloading the
  monad operations for the failure monad transformer.
›

(* Wrapper around an inner monadic value of type 'm; 'a is a phantom type variable.
   The transfer plugin is disabled, and a custom relator/mapper is defined below. *)
datatype (plugins del: transfer) (phantom_optionT: 'a, set_optionT: 'm) optionT =
  OptionT (run_option: 'm)
  for rel: rel_optionT' 
      map: map_optionT'

text ‹
  We define our own relator and mapper such that the phantom variable does not need any relation.
›

(* The phantom type variable never contributes any elements. *)
lemma phantom_optionT [simp]: "phantom_optionT x = {}"
by(cases x) simp

context includes lifting_syntax begin

lemma rel_optionT'_phantom: "rel_optionT' A = rel_optionT' top"
by(auto 4 4 intro: optionT.rel_mono antisym optionT.rel_mono_strong)

lemma map_optionT'_phantom: "map_optionT' f = map_optionT' undefined"
by(auto 4 4 intro: optionT.map_cong)

definition map_optionT :: "('m  'm')  ('a, 'm) optionT  ('b, 'm') optionT"
where "map_optionT = map_optionT' undefined"

definition rel_optionT :: "('m  'm'  bool)  ('a, 'm) optionT  ('b, 'm') optionT  bool"
where "rel_optionT = rel_optionT' top"

(* Elimination rule: two rel_optionT-related values are both built with OptionT,
   and their wrapped inner values are related by M. *)
lemma rel_optionTE:
  assumes "rel_optionT M m m'"
  obtains x y where "m = OptionT x" "m' = OptionT y" "M x y"
using assms by(cases m; cases m'; simp add: rel_optionT_def)

(* On constructor applications, rel_optionT reduces to M on the wrapped values. *)
lemma rel_optionT_simps [simp]: "rel_optionT M (OptionT m) (OptionT m')  M m m'"
by(simp add: rel_optionT_def)

(* The relator maps equality to equality (registered with the Lifting package as relator_eq). *)
lemma rel_optionT_eq [relator_eq]: "rel_optionT (=) = (=)"
by(auto simp add: fun_eq_iff rel_optionT_def intro: optionT.rel_refl_strong elim: optionT.rel_cases)

(* The relator is monotone in its relation argument (registered as relator_mono). *)
lemma rel_optionT_mono [relator_mono]: "rel_optionT A  rel_optionT B" if "A  B"
by(simp add: rel_optionT_def optionT.rel_mono that)

(* The relator distributes over relation composition (registered as relator_distr).
   After unfolding rel_optionT_def, optionT.rel_compp -- used left to right -- turns the
   left-hand side into rel_optionT' (top OO top) (A OO B). The simp rule relcompp_top_top
   then collapses top OO top to top, which matches the right-hand side.
   The rule must not be used [symmetric]: its left-hand side would then be
   rel_optionT' (?R OO ?S) (...), which never matches the rel_optionT' top (...) in the goal. *)
lemma rel_optionT_distr [relator_distr]: "rel_optionT A OO rel_optionT B = rel_optionT (A OO B)"
by(simp add: rel_optionT_def optionT.rel_compp)

(* The relator applied to a graph relation is the graph of map_optionT f, restricted to values
   whose inner value lies in A. The phantom relation top is first rewritten into a graph,
   so that optionT.rel_Grp applies. *)
lemma rel_optionT_Grp: "rel_optionT (BNF_Def.Grp A f) = BNF_Def.Grp {x. set_optionT x  A} (map_optionT f)"
by(simp add: rel_optionT_def rel_optionT'_phantom[of "BNF_Def.Grp UNIV undefined", symmetric] optionT.rel_Grp map_optionT_def)

(* Manual transfer rules (the datatype's transfer plugin is disabled):
   the constructor, the selector, the case combinator and the recursor are
   parametric with respect to rel_optionT. *)
lemma OptionT_parametric [transfer_rule]: "(M ===> rel_optionT M) OptionT OptionT"
by(simp add: rel_fun_def rel_optionT_def)

lemma run_option_parametric [transfer_rule]: "(rel_optionT M ===> M) run_option run_option"
by(auto simp add: rel_fun_def rel_optionT_def elim: optionT.rel_cases)

lemma case_optionT_parametric [transfer_rule]:
  "((M ===> X) ===> rel_optionT M ===> X) case_optionT case_optionT"
by(auto simp add: rel_fun_def rel_optionT_def split: optionT.split)

lemma rec_optionT_parametric [transfer_rule]:
  "((M ===> X) ===> rel_optionT M ===> X) rec_optionT rec_optionT"
by(auto simp add: rel_fun_def elim: rel_optionTE)

end

subsubsection ‹Plain monad, failure, and exceptions›

(* Failure monad transformer operations, parametrised by return and bind
   of an inner monad 'm whose results are of type 'a option. *)
context
  fixes return :: "('a option, 'm) return"
  and bind :: "('a option, 'm) bind"
begin

(* Return wraps the value in Some inside the inner monad. *)
definition return_option :: "('a, ('a, 'm) optionT) return"
where "return_option x = OptionT (return (Some x))"

(* Bind runs the inner computation. A None result is propagated with return None;
   a Some y result continues with the inner computation of f y. *)
primrec bind_option :: "('a, ('a, 'm) optionT) bind"
where [code_unfold, monad_unfold]:
  "bind_option (OptionT x) f = 
   OptionT (bind x (λx. case x of None  return (None :: 'a option) | Some y  run_option (f y)))"

(* Failure is represented by returning None in the inner monad. *)
definition fail_option :: "('a, 'm) optionT fail"
where [code_unfold, monad_unfold]: "fail_option = OptionT (return None)"

(* catch m h runs m. If m yields None, it continues with h;
   otherwise it returns m's result unchanged. *)
definition catch_option :: "('a, 'm) optionT catch"
where "catch_option m h = OptionT (bind (run_option m) (λx. if x = None then run_option h else return x))"

(* Characterisations of the operations through the selector run_option.
   These are the equations used after optionT.expand in the law proofs.
   run_bind_option is not declared [simp]; proofs pass it explicitly. *)
lemma run_bind_option:
  "run_option (bind_option x f) = bind (run_option x) (λx. case x of None  return None | Some y  run_option (f y))"
by(cases x) simp

lemma run_return_option [simp]: "run_option (return_option x) = return (Some x)"
by(simp add: return_option_def)

lemma run_fail_option [simp]: "run_option fail_option = return None"
by(simp add: fail_option_def)

lemma run_catch_option [simp]: 
  "run_option (catch_option m1 m2) = bind (run_option m1) (λx. if x = None then run_option m2 else return x)"
by(simp add: catch_option_def)

(* From here on, return and bind are assumed to form a monad. *)
context
  assumes monad: "monad return bind"
begin

(* Makes the inner monad laws (bind_assoc, return_bind, bind_return) available by name. *)
interpretation monad return bind by(rule monad)

(* The transformer operations form a monad. Each law is reduced via optionT.expand
   to an equation between the run_option projections of both sides. *)
lemma monad_optionT [locale_witness]: "monad return_option bind_option" (is "monad ?return ?bind")
proof
  (* Associativity: case split on the option result of the first computation. *)
  show "?bind (?bind x f) g = ?bind x (λx. ?bind (f x) g)"  for x f g
    by(rule optionT.expand)(auto simp add: bind_assoc run_bind_option return_bind intro!: arg_cong2[where f=bind] split: option.split)
  (* Left unit. *)
  show "?bind (?return x) f = f x" for f x
    by(rule optionT.expand)(simp add: run_bind_option return_bind return_option_def)
  (* Right unit: the case expression collapses to return of the option value itself. *)
  show "?bind x ?return = x" for x
    by(rule optionT.expand)(simp add: run_bind_option option.case_distrib[symmetric] case_option_id bind_return cong del: option.case_cong)
qed

(* fail_option is a left zero of bind_option: binding after return None
   yields return None again. *)
lemma monad_fail_optionT [locale_witness]:
  "monad_fail return_option bind_option fail_option"
proof
  show "bind_option fail_option f = fail_option" for f
    by(rule optionT.expand)(simp add: run_bind_option return_bind)
qed

(* catch_option satisfies the monad_catch laws. Each law is reduced via optionT.expand
   to an equation between run_option projections. *)
lemma monad_catch_optionT [locale_witness]:
  "monad_catch return_option bind_option fail_option catch_option"
proof
  (* A successful computation never runs the handler. *)
  show "catch_option (return_option x) m = return_option x" for x m
    by(rule optionT.expand)(simp add: return_bind)
  (* A failure is replaced by the handler. *)
  show "catch_option fail_option m = m" for m
    by(rule optionT.expand)(simp add: return_bind)
  (* fail_option is a neutral handler. The symmetric if_distrib folds both branches into a
     single return, so that bind_return applies. Deleting if_weak_cong lets simp also
     rewrite inside the branches of the conditional. *)
  show "catch_option m fail_option = m" for m
    by(rule optionT.expand)(simp add: bind_return if_distrib[where f="return", symmetric] cong del: if_weak_cong)
  (* Associativity of catch. *)
  show "catch_option (catch_option m m') m'' = catch_option m (catch_option m' m'')" for m m' m''
    by(rule optionT.expand)(auto simp add: bind_assoc fun_eq_iff return_bind intro!: arg_cong2[where f=bind])
qed

end

subsubsection ‹Reader›

(* Reader operation of the inner monad, lifted through optionT. *)
context
  fixes ask :: "('r, 'm) ask"
begin

(* ask is lifted by running each continuation's inner computation under the inner ask. *)
definition ask_option :: "('r, ('a, 'm) optionT) ask" 
where [code_unfold, monad_unfold]: "ask_option f = OptionT (ask (λr. run_option (f r)))"

lemma run_ask_option [simp]: "run_option (ask_option f) = ask (λr. run_option (f r))"
by(simp add: ask_option_def)

(* If the inner monad is a reader monad, so is its optionT transformer. *)
lemma monad_reader_optionT [locale_witness]:
  assumes "monad_reader return bind ask"
  shows "monad_reader return_option bind_option ask_option"
proof -
  interpret monad_reader return bind ask by(fact assms)
  show ?thesis
  proof
    show "ask_option (λr. ask_option (f r)) = ask_option (λr. f r r)" for f
      by(rule optionT.expand)(simp add: ask_ask)
    show "ask_option (λ_. x) = x" for x
      by(rule optionT.expand)(simp add: ask_const)
    show "bind_option (ask_option f) g = ask_option (λr. bind_option (f r) g)" for f g
      by(rule optionT.expand)(simp add: bind_ask run_bind_option)
    (* The None branch of bind_option does not depend on the environment;
       ask_const removes the superfluous ask there. *)
    show "bind_option m (λx. ask_option (f x)) = ask_option (λr. bind_option m (λx. f x r))" for m f
      by(rule optionT.expand)(auto simp add: bind_ask2[symmetric] run_bind_option ask_const del: ext intro!: arg_cong2[where f=bind] ext split: option.split)
  qed
qed

end

subsubsection ‹State›

(* State operations of the inner monad, lifted through optionT. *)
context
  fixes get :: "('s, 'm) get"
  and put :: "('s, 'm) put"
begin

(* get is lifted by running each continuation's inner computation under the inner get. *)
definition get_option :: "('s, ('a, 'm) optionT) get"
where "get_option f = OptionT (get (λs. run_option (f s)))"

(* put s is applied to the wrapped inner computation. *)
primrec put_option :: "('s, ('a, 'm) optionT) put"
where "put_option s (OptionT m) = OptionT (put s m)"

lemma run_get_option [simp]:
  "run_option (get_option f) = get (λs. run_option (f s))"
by(simp add: get_option_def)

lemma run_put_option [simp]:
  "run_option (put_option s m) = put s (run_option m)"
by(cases m)(simp)

(* Within this context the inner operations are assumed to form a state monad. *)
context
  assumes state: "monad_state return bind get put"
begin

interpretation monad_state return bind get put by(fact state)

(* Every state law for optionT follows from the corresponding inner law, after
   projecting with run_option (optionT.expand). *)
lemma monad_state_optionT [locale_witness]:
  "monad_state return_option bind_option get_option put_option"
proof
  show "put_option s (get_option f) = put_option s (f s)" for s f
    by(rule optionT.expand)(simp add: put_get)
  show "get_option (λs. get_option (f s)) = get_option (λs. f s s)" for f
    by(rule optionT.expand)(simp add: get_get)
  show "put_option s (put_option s' m) = put_option s' m" for s s' m
    by(rule optionT.expand)(simp add: put_put)
  show "get_option (λs. put_option s m) = m" for m
    by(rule optionT.expand)(simp add: get_put)
  show "get_option (λ_. m) = m" for m
    by(rule optionT.expand)(simp add: get_const)
  show "bind_option (get_option f) g = get_option (λs. bind_option (f s) g)" for f g
    by(rule optionT.expand)(simp add: bind_get run_bind_option)
  show "bind_option (put_option s m) f = put_option s (bind_option m f)" for s m f
    by(rule optionT.expand)(simp add: bind_put run_bind_option)
qed

(* catch_option commutes with get_option and put_option, because catch_option is
   defined via the inner bind, which commutes with the inner get and put
   (bind_get, bind_put). *)
lemma monad_catch_state_optionT [locale_witness]:
  "monad_catch_state return_option bind_option fail_option catch_option get_option put_option"
proof
  show "catch_option (get_option f) m = get_option (λs. catch_option (f s) m)" for f m
    by(rule optionT.expand)(simp add: bind_get)
  show "catch_option (put_option s m) m' = put_option s (catch_option m m')" for s m m'
    by(rule optionT.expand)(simp add: bind_put)
qed

end

subsubsection ‹Probability›

(* Generic lifting of an operation that takes a parameter p and a continuation
   (such as sample or altc) through optionT: each continuation's inner computation
   is passed to the inner operation. It is shared by the abbreviations sample_option
   and altc_option. *)
definition altc_sample_option :: "('x  ('b  'm)  'm)  'x  ('b  ('a, 'm) optionT)  ('a, 'm) optionT"
  where "altc_sample_option altc_sample p f = OptionT (altc_sample p (λx. run_option (f x)))"

lemma run_altc_sample_option [simp]: "run_option (altc_sample_option altc_sample p f) = altc_sample p (λx. run_option (f x))"
by(simp add: altc_sample_option_def)

(* Probabilistic sampling of the inner monad, lifted through optionT. *)
context
  fixes sample :: "('p, 'm) sample"
begin

abbreviation sample_option :: "('p, ('a, 'm) optionT) sample"
where "sample_option  altc_sample_option sample"

(* If the inner monad is a probability monad, so is its optionT transformer. *)
lemma monad_prob_optionT [locale_witness]:
  assumes "monad_prob return bind sample"
  shows "monad_prob return_option bind_option sample_option"
proof -
  interpret monad_prob return bind sample by(fact assms)
  (* Makes the inner parametricity of sample available to transfer_prover below. *)
  note sample_parametric[transfer_rule]
  show ?thesis including lifting_syntax
  proof
    show "sample_option p (λ_. x) = x" for p x
      by(rule optionT.expand)(simp add: sample_const)
    show "sample_option (return_pmf x) f = f x" for f x
      by(rule optionT.expand)(simp add: sample_return_pmf)
    show "sample_option (bind_pmf p f) g = sample_option p (λx. sample_option (f x) g)" for p f g
      by(rule optionT.expand)(simp add: sample_bind_pmf)
    show "sample_option p (λx. sample_option q (f x)) = sample_option q (λy. sample_option p (λx. f x y))" for p q f
      by(rule optionT.expand)(auto intro!: sample_commute)
    show "bind_option (sample_option p f) g = sample_option p (λx. bind_option (f x) g)" for p f g
      by(rule optionT.expand)(auto simp add: bind_sample1 run_bind_option)
    (* The None branch does not depend on the sample; sample_const removes it there. *)
    show "bind_option m (λy. sample_option p (f y)) = sample_option p (λx. bind_option m (λy. f y x))" for m p f
      by(rule optionT.expand)(auto simp add: bind_sample2[symmetric] run_bind_option sample_const del: ext intro!: arg_cong2[where f=bind] ext split: option.split)
    (* Parametricity: after unfolding, transfer_prover combines sample_parametric with the
       OptionT/run_option transfer rules and the local bi_unique R assumption. *)
    show  "(rel_pmf R ===> (R ===> (=)) ===> (=)) sample_option sample_option" 
      if [transfer_rule]: "bi_unique R" for R
      unfolding altc_sample_option_def by transfer_prover
  qed
qed

(* The commutation of sample with get carries over from the inner monad to optionT. *)
lemma monad_state_prob_optionT [locale_witness]:
  assumes "monad_state_prob return bind get put sample"
  shows "monad_state_prob return_option bind_option get_option put_option sample_option"
proof -
  interpret monad_state_prob return bind get put sample by fact
  show ?thesis
  proof
    show "sample_option p (λx. get_option (f x)) = get_option (λs. sample_option p (λx. f x s))" for p f
      by(rule optionT.expand)(simp add: sample_get)
  qed
qed

end

subsubsection ‹Writer›

(* Writer operation of the inner monad, lifted through optionT. *)
context
  fixes tell :: "('w, 'm) tell"
begin

(* tell w is applied to the wrapped inner computation. *)
definition tell_option :: "('w, ('a, 'm) optionT) tell" 
where "tell_option w m = OptionT (tell w (run_option m))"

lemma run_tell_option [simp]: "run_option (tell_option w m) = tell w (run_option m)"
by(simp add: tell_option_def)

(* If the inner monad is a writer monad, so is its optionT transformer. *)
lemma monad_writer_optionT [locale_witness]:
  assumes "monad_writer return bind tell"
  shows "monad_writer return_option bind_option tell_option"
proof -
  interpret monad_writer return bind tell by fact
  show ?thesis
  proof
    show "bind_option (tell_option w m) f = tell_option w (bind_option m f)" for w m f
      by(rule optionT.expand)(simp add: run_bind_option bind_tell)
  qed
qed

end

subsubsection ‹Binary Non-determinism›

(* Binary choice of the inner monad, lifted through optionT. *)
context
  fixes alt :: "'m alt"
begin

(* Choice between two computations is the inner choice between their wrapped values. *)
definition alt_option :: "('a, 'm) optionT alt"
where "alt_option m1 m2 = OptionT (alt (run_option m1) (run_option m2))"

lemma run_alt_option [simp]: "run_option (alt_option m1 m2) = alt (run_option m1) (run_option m2)"
by(simp add: alt_option_def)

(* If the inner monad has an associative choice that bind distributes over,
   then so does optionT. *)
lemma monad_alt_optionT [locale_witness]:
  assumes "monad_alt return bind alt"
  shows "monad_alt return_option bind_option alt_option"
proof -
  interpret monad_alt return bind alt by fact
  show ?thesis
  proof
    show "alt_option (alt_option m1 m2) m3 = alt_option m1 (alt_option m2 m3)" for m1 m2 m3
      by(rule optionT.expand)(simp add: alt_assoc)
    show "bind_option (alt_option m m') f = alt_option (bind_option m f) (bind_option m' f)" for m m' f
      by(rule optionT.expand)(simp add: bind_alt1 run_bind_option)
  qed
qed

text ‹
  The @{term fail} of @{typ "(_, _) optionT"} does not combine with @{term "alt"} of the inner monad
  because @{typ "(_, _) optionT"} injects failures with @{term "return None"} into the inner monad.
›

(* The commutation of alt with get and put carries over from the inner monad to optionT. *)
lemma monad_state_alt_optionT [locale_witness]:
  assumes "monad_state_alt return bind get put alt"
  shows "monad_state_alt return_option bind_option get_option put_option alt_option"
proof -
  interpret monad_state_alt return bind get put alt by fact
  show ?thesis
  proof
    show "alt_option (get_option f) (get_option g) = get_option (λx. alt_option (f x) (g x))"
      for f g by(rule optionT.expand)(simp add: alt_get)
    show "alt_option (put_option s m) (put_option s m') = put_option s (alt_option m m')"
      for s m m' by(rule optionT.expand)(simp add: alt_put)
  qed
qed

end

subsubsection ‹Countable Non-determinism›

(* Countable choice of the inner monad, lifted through optionT with the same generic
   lifting (altc_sample_option) as probabilistic sampling. *)
context
  fixes altc :: "('c, 'm) altc"
begin

abbreviation altc_option :: "('c, ('a, 'm) optionT) altc"
where "altc_option  altc_sample_option altc"

lemma monad_altc_optionT [locale_witness]:
  assumes "monad_altc return bind altc"
  shows "monad_altc return_option bind_option altc_option"
proof -
  interpret monad_altc return bind altc by fact
  note altc_parametric[transfer_rule]
  show ?thesis including lifting_syntax
  proof
    show "bind_option (altc_option C g) f = altc_option C (λc</