(* Theory Binary_Tree_Monad *)

section ‹Binary tree monad›

theory Binary_Tree_Monad
imports Monad
begin

subsection ‹Type definition›

text ‹Binary trees that carry values only at the leaves. Every constructor
  argument is declared lazy.›

tycondef 'a⋅btree =
  Leaf (lazy "'a") | Node (lazy "'a⋅btree") (lazy "'a⋅btree")

text ‹Coercion commutes with the abstraction function btree_abs.›

lemma coerce_btree_abs [simp]: "coerce⋅(btree_abs⋅x) = btree_abs⋅(coerce⋅x)"
apply (simp add: btree_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_btree)
done

text ‹Coercing a leaf coerces the value stored in it.›

lemma coerce_Leaf [simp]: "coerce⋅(Leaf⋅x) = Leaf⋅(coerce⋅x)"
unfolding Leaf_def by simp

text ‹Coercing a node coerces both of its subtrees.›

lemma coerce_Node [simp]: "coerce⋅(Node⋅xs⋅ys) = Node⋅(coerce⋅xs)⋅(coerce⋅ys)"
unfolding Node_def by simp

text ‹Simplification rules for fmapU on trees over udom: it is strict in the
  tree argument, applies the function to the value at a leaf, and maps over
  both subtrees of a node.›

lemma fmapU_btree_simps [simp]:
  "fmapU⋅f⋅(⊥::udom⋅btree) = ⊥"
  "fmapU⋅f⋅(Leaf⋅x) = Leaf⋅(f⋅x)"
  "fmapU⋅f⋅(Node⋅xs⋅ys) = Node⋅(fmapU⋅f⋅xs)⋅(fmapU⋅f⋅ys)"
unfolding fmapU_btree_def btree_map_def
apply (subst fix_eq, simp)
apply (subst fix_eq, simp add: Leaf_def)
apply (subst fix_eq, simp add: Node_def)
done

subsection ‹Class instance proofs›

text ‹The functor instance: the goals produced by the standard introduction
  step are discharged by induction over the tree followed by simplification.›

instance btree :: "functor"
by standard (induct_tac xs rule: btree.induct, simp_all)

instantiation btree :: monad
begin

text ‹The unit of the monad is the leaf constructor.›

definition
  "returnU = Leaf"

text ‹Bind replaces each leaf by the tree obtained from applying the
  continuation to the leaf's value, and recurses into both subtrees of a node.›

fixrec bindU_btree :: "udom⋅btree → (udom → udom⋅btree) → udom⋅btree"
  where "bindU_btree⋅(Leaf⋅x)⋅k = k⋅x"
  | "bindU_btree⋅(Node⋅xs⋅ys)⋅k =
      Node⋅(bindU_btree⋅xs⋅k)⋅(bindU_btree⋅ys⋅k)"

text ‹Bind is strict in its tree argument.›

lemma bindU_btree_strict [simp]: "bindU⋅⊥⋅k = (⊥::udom⋅btree)"
by fixrec_simp

text ‹The three proof obligations of the instance: fmapU expressed through
  bindU and returnU, returnU followed by bindU, and nesting of bindU.›

instance proof
  fix x :: "udom"
  fix f :: "udom → udom"
  fix h k :: "udom → udom⋅btree"
  fix xs :: "udom⋅btree"
  show "fmapU⋅f⋅xs = bindU⋅xs⋅(Λ x. returnU⋅(f⋅x))"
    by (induct xs rule: btree.induct, simp_all add: returnU_btree_def)
  show "bindU⋅(returnU⋅x)⋅k = k⋅x"
    by (simp add: returnU_btree_def)
  show "bindU⋅(bindU⋅xs⋅h)⋅k = bindU⋅xs⋅(Λ x. bindU⋅(h⋅x)⋅k)"
    by (induct xs rule: btree.induct) simp_all
qed

end

subsection ‹Transfer properties to polymorphic versions›

text ‹Polymorphic counterparts of the fmapU rules: fmap is strict, applies the
  function at a leaf, and maps over both subtrees of a node.›

lemma fmap_btree_simps [simp]:
  "fmap⋅f⋅(⊥::'a⋅btree) = ⊥"
  "fmap⋅f⋅(Leaf⋅x) = Leaf⋅(f⋅x)"
  "fmap⋅f⋅(Node⋅xs⋅ys) = Node⋅(fmap⋅f⋅xs)⋅(fmap⋅f⋅ys)"
unfolding fmap_def by simp_all

text ‹Polymorphic rules for bind on trees: it is strict in the tree, applies
  the continuation at a leaf, and binds both subtrees of a node.›

lemma bind_btree_simps [simp]:
  "bind⋅(⊥::'a⋅btree)⋅k = ⊥"
  "bind⋅(Leaf⋅x)⋅k = k⋅x"
  "bind⋅(Node⋅xs⋅ys)⋅k = Node⋅(bind⋅xs⋅k)⋅(bind⋅ys⋅k)"
unfolding bind_def
by (simp_all add: coerce_simp)

text ‹On trees, the polymorphic return coincides with the leaf constructor.›

lemma return_btree_def:
  "return = Leaf"
apply (unfold return_def returnU_btree_def)
apply (simp add: coerce_simp eta_cfun)
done

text ‹Rules for join on trees of trees: it is strict, returns the tree stored
  in a leaf, and joins both subtrees of a node.›

lemma join_btree_simps [simp]:
  "join⋅(⊥::'a⋅btree⋅btree) = ⊥"
  "join⋅(Leaf⋅xs) = xs"
  "join⋅(Node⋅xss⋅yss) = Node⋅(join⋅xss)⋅(join⋅yss)"
unfolding join_def by simp_all

end