(* Theory Monad *)

section ‹Monad Class›

theory Monad
imports Functor
begin

subsection ‹Class definition›

text ‹In Haskell, class \emph{Monad} is defined as follows:›

text_raw ‹
\begin{verbatim}
class Monad m where
  return :: a -> m a
  (>>=) :: m a -> (a -> m b) -> m b
\end{verbatim}
›

text ‹We formalize class ‹monad› in a manner similar to the
‹functor› class: We fix monomorphic versions of the class
constants, replacing type variables with ‹udom›, and assume
monomorphic versions of the class axioms.›

text ‹Because the monad laws imply the composition rule for ‹fmap›,
we declare ‹prefunctor› as the superclass, and separately
prove a subclass relationship with ‹functor›.›

(* Class monad: a prefunctor with monomorphic return and bind over udom.
   fmapU_eq_bindU relates the inherited fmapU to bindU/returnU;
   bindU_returnU is the left-unit law and bindU_bindU is associativity.
   The right-unit law is not assumed (see monad_right_unit). *)
class monad = prefunctor +
  fixes returnU :: "udom → udom⋅'a::tycon"
  fixes bindU :: "udom⋅'a → (udom → udom⋅'a) → udom⋅'a"
  assumes fmapU_eq_bindU:
    "⋀f xs. fmapU⋅f⋅xs = bindU⋅xs⋅(Λ x. returnU⋅(f⋅x))"
  assumes bindU_returnU:
    "⋀f x. bindU⋅(returnU⋅x)⋅f = f⋅x"
  assumes bindU_bindU:
    "⋀xs f g. bindU⋅(bindU⋅xs⋅f)⋅g = bindU⋅xs⋅(Λ x. bindU⋅(f⋅x)⋅g)"

(* Every monad is a functor: the fmapU composition law is derived from the
   three monad class assumptions. *)
instance monad ⊆ "functor"
proof
  fix f g :: "udom → udom" and xs :: "udom⋅'a"
  show "fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"
    by (simp add: fmapU_eq_bindU bindU_bindU bindU_returnU)
qed

text ‹As with ‹fmap›, we define the polymorphic ‹return›
and ‹bind› by coercion from the monomorphic ‹returnU› and
‹bindU›.›

(* Polymorphic return: the monomorphic returnU coerced to the requested type. *)
definition return :: "'a → 'a⋅'m::monad"
  where "return = coerce⋅(returnU :: udom → udom⋅'m)"

(* Polymorphic bind: the monomorphic bindU coerced to the requested type. *)
definition bind :: "'a⋅'m::monad → ('a → 'b⋅'m) → 'b⋅'m"
  where "bind = coerce⋅(bindU :: udom⋅'m → _)"

(* Infix notation for bind: m ⤜ f abbreviates bind⋅m⋅f
   (left-associative, priority 55). *)
abbreviation bind_syn :: "'a⋅'m::monad ⇒ ('a → 'b⋅'m) ⇒ 'b⋅'m" (infixl "⤜" 55)
  where "m ⤜ f ≡ bind⋅m⋅f"

subsection ‹Naturality of bind and return›

text ‹The three class axioms imply naturality properties of ‹returnU› and ‹bindU›, i.e., that both commute with ‹fmapU›.›

(* Naturality of returnU: mapping over a returned value equals returning the
   mapped value. *)
lemma fmapU_returnU [coerce_simp]:
  "fmapU⋅f⋅(returnU⋅x) = returnU⋅(f⋅x)"
by (simp add: fmapU_eq_bindU bindU_returnU)

(* Mapping over the result of a bind pushes the map into the continuation. *)
lemma fmapU_bindU [coerce_simp]:
  "fmapU⋅f⋅(bindU⋅m⋅k) = bindU⋅m⋅(Λ x. fmapU⋅f⋅(k⋅x))"
by (simp add: fmapU_eq_bindU bindU_bindU)

(* Binding over a mapped computation fuses the map into the continuation. *)
lemma bindU_fmapU:
  "bindU⋅(fmapU⋅f⋅xs)⋅k = bindU⋅xs⋅(Λ x. k⋅(f⋅x))"
by (simp add: fmapU_eq_bindU bindU_returnU bindU_bindU)

subsection ‹Polymorphic versions of class assumptions›

(* Polymorphic counterpart of fmapU_eq_bindU: fmap is bind followed by
   return. *)
lemma monad_fmap:
  fixes xs :: "'a⋅'m::monad" and f :: "'a → 'b"
  shows "fmap⋅f⋅xs = xs ⤜ (Λ x. return⋅(f⋅x))"
unfolding bind_def return_def fmap_def
by (simp add: coerce_simp fmapU_eq_bindU bindU_returnU)

(* Left unit law; polymorphic counterpart of bindU_returnU. *)
lemma monad_left_unit [simp]: "(return⋅x ⤜ f) = (f⋅x)"
unfolding bind_def return_def
by (simp add: coerce_simp bindU_returnU)

(* Associativity of bind; polymorphic counterpart of bindU_bindU. *)
lemma bind_bind:
  fixes m :: "'a⋅'m::monad"
  shows "((m ⤜ f) ⤜ g) = (m ⤜ (Λ x. f⋅x ⤜ g))"
unfolding bind_def
by (simp add: coerce_simp bindU_bindU)

subsection ‹Derived rules›

text ‹The following properties can be derived using only the
abstract monad laws.›

(* Right unit law. It is not a class assumption; it is derived from
   monad_fmap together with fmap⋅ID⋅m = m. *)
lemma monad_right_unit [simp]: "(m ⤜ return) = m"
 apply (subgoal_tac "fmap⋅ID⋅m = m")
  apply (simp only: monad_fmap)
  apply (simp add: eta_cfun)
 apply simp
done

(* Naturality of return at polymorphic types. *)
lemma fmap_return: "fmap⋅f⋅(return⋅x) = return⋅(f⋅x)"
by (simp add: monad_fmap)

(* Mapping over a bind pushes the map into the continuation. *)
lemma fmap_bind: "fmap⋅f⋅(bind⋅xs⋅k) = bind⋅xs⋅(Λ x. fmap⋅f⋅(k⋅x))"
by (simp add: monad_fmap bind_bind)

(* Binding over a mapped computation fuses the map into the continuation. *)
lemma bind_fmap: "bind⋅(fmap⋅f⋅xs)⋅k = bind⋅xs⋅(Λ x. k⋅(f⋅x))"
by (simp add: monad_fmap bind_bind)

text ‹Bind is strict in its first argument, if its second argument
is a strict function.›

(* ⊥ ⤜ k = ⊥ for strict k. Since ⊥ ⊑ return⋅⊥, monotonicity gives
   ⊥ ⤜ k ⊑ return⋅⊥ ⤜ k, which simp rewrites to k⋅⊥ = ⊥. *)
lemma bind_strict:
  assumes "k⋅⊥ = ⊥" shows "⊥ ⤜ k = ⊥"
proof -
  have "⊥ ⤜ k ⊑ return⋅⊥ ⤜ k"
    by (intro monofun_cfun below_refl minimal)
  thus "⊥ ⤜ k = ⊥"
    by (simp add: assms)
qed

(* Continuations are determined by their binds: they agree on every m iff
   they are equal. The nontrivial direction instantiates m := return⋅x and
   uses the (simp) left unit law. *)
lemma congruent_bind:
  "(∀m. m ⤜ k1 = m ⤜ k2) = (k1 = k2)"
 apply (safe, rule cfun_eqI)
 apply (drule_tac x="return⋅x" in spec, simp)
done

subsection ‹Laws for join›

(* join flattens a nested computation by binding it with the identity. *)
definition join :: "('a⋅'m)⋅'m → 'a⋅'m::monad"
  where "join ≡ Λ m. m ⤜ (Λ x. x)"

(* Naturality of join: mapping two layers deep then flattening equals
   flattening then mapping. *)
lemma join_fmap_fmap: "join⋅(fmap⋅(fmap⋅f)⋅xss) = fmap⋅f⋅(join⋅xss)"
by (simp add: join_def monad_fmap bind_bind)

(* Flattening a returned computation yields that computation. *)
lemma join_return: "join⋅(return⋅xs) = xs"
by (simp add: join_def)

(* Wrapping each element with return and then flattening is the identity. *)
lemma join_fmap_return: "join⋅(fmap⋅return⋅xs) = xs"
by (simp add: join_def monad_fmap eta_cfun bind_bind)

(* Associativity of join: flattening inner layers first equals flattening
   outer layers first. *)
lemma join_fmap_join: "join⋅(fmap⋅join⋅xsss) = join⋅(join⋅xsss)"
by (simp add: join_def monad_fmap bind_bind)

(* Bind expressed through fmap and join. *)
lemma bind_def2: "m ⤜ k = join⋅(fmap⋅k⋅m)"
by (simp add: join_def monad_fmap eta_cfun bind_bind)

subsection ‹Equivalence of monad laws and fmap/join laws›

(* Left unit law derived only from bind_def2 and the fmap/join laws. *)
lemma "(return⋅x ⤜ f) = (f⋅x)"
by (simp only: bind_def2 fmap_return join_return)

(* Right unit law derived only from bind_def2 and join_fmap_return. *)
lemma "(m ⤜ return) = m"
by (simp only: bind_def2 join_fmap_return)

(* Associativity derived from the join laws: after unfolding bind_def2, the
   two sides are related through join_fmap_join and join_fmap_fmap, and
   fmap_fmap closes the goal. *)
lemma "((m ⤜ f) ⤜ g) = (m ⤜ (Λ x. f⋅x ⤜ g))"
 apply (simp only: bind_def2)
 apply (subgoal_tac "join⋅(fmap⋅g⋅(join⋅(fmap⋅f⋅m))) =
    join⋅(fmap⋅join⋅(fmap⋅(fmap⋅g)⋅(fmap⋅f⋅m)))")
  apply (simp add: fmap_fmap)
 apply (simp add: join_fmap_join join_fmap_fmap)
done

subsection ‹Simplification of coercions›

text ‹We configure rewrite rules that push coercions inwards, and
reduce them to coercions on simpler types.›

(* Push a coercion through return onto the returned value. *)
lemma coerce_return [coerce_simp]:
  "COERCE('a⋅'m,'b⋅'m::monad)⋅(return⋅x) = return⋅(COERCE('a,'b)⋅x)"
by (simp add: coerce_functor fmap_return)

(* Push a coercion of a bind's result into the continuation. *)
lemma coerce_bind [coerce_simp]:
  fixes m :: "'a⋅'m::monad" and k :: "'a → 'b⋅'m"
  shows "COERCE('b⋅'m,'c⋅'m)⋅(m ⤜ k) = m ⤜ (Λ x. COERCE('b⋅'m,'c⋅'m)⋅(k⋅x))"
by (simp add: coerce_functor fmap_bind)

(* A coercion on the bound computation becomes a coercion of the
   continuation's argument. *)
lemma bind_coerce [coerce_simp]:
  fixes m :: "'a⋅'m::monad" and k :: "'b → 'c⋅'m"
  shows "COERCE('a⋅'m,'b⋅'m)⋅m ⤜ k = m ⤜ (Λ x. k⋅(COERCE('a,'b)⋅x))"
by (simp add: coerce_functor bind_fmap)

end