(* Theory Writer_Transformer *)

section ‹Writer monad transformer›

theory Writer_Transformer
imports Writer_Monad
begin

subsection ‹Type definition›

text ‹Below is the standard Haskell definition of a writer monad
transformer:›

text_raw ‹
\begin{verbatim}
newtype WriterT w m a = WriterT { runWriterT :: m (a, w) }
\end{verbatim}
›

text ‹In this development, since a lazy pair type is not pre-defined
in HOLCF, we will use an equivalent formulation in terms of our
previous \texttt{Writer} type:›

text_raw ‹
\begin{verbatim}
data Writer w a = Writer w a
newtype WriterT w m a = WriterT { runWriterT :: m (Writer w a) }
\end{verbatim}
›

text ‹We can translate this definition directly into HOLCF using
‹tycondef›. \medskip›

(* The writerT type constructor: a single-field wrapper (WriterT / runWriterT)
   around an inner-monad computation of type ('a⋅'w writer)⋅'m, mirroring the
   Haskell newtype WriterT w m a = WriterT { runWriterT :: m (Writer w a) }. *)
tycondef 'a⋅('m::"functor",'w) writerT =
  WriterT (runWriterT :: "('a⋅'w writer)⋅'m")

(* coerce commutes with the representation's abstraction function and with
   the WriterT constructor. *)
lemma coerce_writerT_abs [simp]:
  "coerce⋅(writerT_abs⋅x) = writerT_abs⋅(coerce⋅x)"
apply (simp add: writerT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_writerT)
done

lemma coerce_WriterT [simp]: "coerce⋅(WriterT⋅k) = WriterT⋅(coerce⋅k)"
unfolding WriterT_def by simp

(* Case analysis: every writerT value is WriterT applied to its own
   runWriterT field. *)
lemma writerT_cases [case_names WriterT]:
  obtains k where "y = WriterT⋅k"
proof
  show "y = WriterT⋅(runWriterT⋅y)"
    by (cases y, simp_all)
qed

(* WriterT and runWriterT are inverse in this direction. *)
lemma WriterT_runWriterT [simp]: "WriterT⋅(runWriterT⋅m) = m"
by (cases m rule: writerT_cases, simp)

(* "Induction" rule derived from writerT_cases: it suffices to prove P for
   values of the form WriterT⋅k. *)
lemma writerT_induct [case_names WriterT]:
  fixes P :: "'a⋅('f::functor,'e) writerT ⇒ bool"
  assumes "⋀k. P (WriterT⋅k)"
  shows "P y"
by (cases y rule: writerT_cases, simp add: assms)

(* Equality and ordering on writerT are determined by the runWriterT field. *)
lemma writerT_eq_iff:
  "a = b ⟷ runWriterT⋅a = runWriterT⋅b"
apply (cases a rule: writerT_cases)
apply (cases b rule: writerT_cases)
apply simp
done

lemma writerT_below_iff:
  "a ⊑ b ⟷ runWriterT⋅a ⊑ runWriterT⋅b"
apply (cases a rule: writerT_cases)
apply (cases b rule: writerT_cases)
apply simp
done

(* Introduction-rule forms of the two lemmas above. *)
lemma writerT_eqI:
  "runWriterT⋅a = runWriterT⋅b ⟹ a = b"
by (simp add: writerT_eq_iff)

lemma writerT_belowI:
  "runWriterT⋅a ⊑ runWriterT⋅b ⟹ a ⊑ b"
by (simp add: writerT_below_iff)

(* runWriterT commutes with coerce. *)
lemma runWriterT_coerce [simp]:
  "runWriterT⋅(coerce⋅k) = coerce⋅(runWriterT⋅k)"
by (induct k rule: writerT_induct, simp)

subsection ‹Functor class instance›

(* fmap on the writer type coincides with the domain-package map function
   writer_map applied to ID. Proved pointwise by case analysis on the
   writer value. *)
lemma fmap_writer_def: "fmap = writer_map⋅ID"
apply (rule cfun_eqI, rename_tac f)
apply (rule cfun_eqI, rename_tac x)
apply (case_tac x rule: writer.exhaust, simp_all)
apply (simp add: writer_map_def fix_const)
apply (simp add: writer_map_def fix_const Writer_def)
done

(* fmapU on writerT maps f under both the inner monad and the writer layer. *)
lemma fmapU_WriterT [simp]:
  "fmapU⋅f⋅(WriterT⋅m) = WriterT⋅(fmap⋅(fmap⋅f)⋅m)"
unfolding fmapU_writerT_def writerT_map_def fmap_writer_def fix_const
  WriterT_def by simp

(* The same fact, stated through the runWriterT projection. *)
lemma runWriterT_fmapU [simp]:
  "runWriterT⋅(fmapU⋅f⋅m) = fmap⋅(fmap⋅f)⋅(runWriterT⋅m)"
by (induct m rule: writerT_induct) simp

(* writerT is a functor whenever the inner type constructor is. The only
   obligation is the composition law for fmapU. *)
instance writerT :: ("functor", "domain") "functor"
proof
  fix f g :: "udom → udom" and xs :: "udom⋅('a,'b) writerT"
  show "fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"
    apply (induct xs rule: writerT_induct)
    apply (simp add: fmap_fmap eta_cfun)
    done
qed

subsection ‹Monad operations›

text ‹The writer monad transformer does not yield a monad in the
usual sense: We cannot prove a ‹monad› class instance, because
type ‹'a⋅('m,'w) writerT› contains values that break the monad
laws. However, it turns out that such values are inaccessible: The
monad laws are satisfied by all values constructible from the abstract
operations.›

text ‹To explore the properties of the writer monad transformer
operations, we define them all as non-overloaded functions. \medskip
›

(* unitWT x: return x in the inner monad, paired with the empty log. *)
definition unitWT :: "'a → 'a⋅('m::monad,'w::monoid) writerT"
  where "unitWT = (Λ x. WriterT⋅(return⋅(Writer⋅mempty⋅x)))"

(* bindWT m k: run m, pass its result x to k, then combine the two logs
   with mappend, with m's log w on the left. *)
definition bindWT :: "'a⋅('m::monad,'w::monoid) writerT → ('a → 'b⋅('m,'w) writerT) → 'b⋅('m,'w) writerT"
  where "bindWT = (Λ m k. WriterT⋅(bind⋅(runWriterT⋅m)⋅
    (Λ(Writer⋅w⋅x). bind⋅(runWriterT⋅(k⋅x))⋅(Λ(Writer⋅w'⋅y).
      return⋅(Writer⋅(mappend⋅w⋅w')⋅y)))))"

(* liftWT m: lift an inner-monad computation, tagging each result with
   the empty log. *)
definition liftWT :: "'a⋅'m → 'a⋅('m::monad,'w::monoid) writerT"
  where "liftWT = (Λ m. WriterT⋅(fmap⋅(Writer⋅mempty)⋅m))"

(* tellWT x w: return x while recording the log w. *)
definition tellWT :: "'a → 'w → 'a⋅('m::monad,'w::monoid) writerT"
  where "tellWT = (Λ x w. WriterT⋅(return⋅(Writer⋅w⋅x)))"

(* fmapWT is derived from bindWT and unitWT, not from the inner fmap. *)
definition fmapWT :: "('a → 'b) → 'a⋅('m::monad,'w::monoid) writerT → 'b⋅('m,'w) writerT"
  where "fmapWT = (Λ f m. bindWT⋅m⋅(Λ x. unitWT⋅(f⋅x)))"

(* Simplification rules describing runWriterT of each operation. *)
lemma runWriterT_fmap [simp]:
  "runWriterT⋅(fmap⋅f⋅m) = fmap⋅(fmap⋅f)⋅(runWriterT⋅m)"
by (subst fmap_def, simp add: coerce_simp eta_cfun)

lemma runWriterT_unitWT [simp]:
  "runWriterT⋅(unitWT⋅x) = return⋅(Writer⋅mempty⋅x)"
unfolding unitWT_def by simp

lemma runWriterT_bindWT [simp]:
  "runWriterT⋅(bindWT⋅m⋅k) = bind⋅(runWriterT⋅m)⋅
    (Λ(Writer⋅w⋅x). bind⋅(runWriterT⋅(k⋅x))⋅(Λ(Writer⋅w'⋅y).
      return⋅(Writer⋅(mappend⋅w⋅w')⋅y)))"
unfolding bindWT_def by simp

lemma runWriterT_liftWT [simp]:
  "runWriterT⋅(liftWT⋅m) = fmap⋅(Writer⋅mempty)⋅m"
unfolding liftWT_def by simp

lemma runWriterT_tellWT [simp]:
  "runWriterT⋅(tellWT⋅x⋅w) = return⋅(Writer⋅w⋅x)"
unfolding tellWT_def by simp

(* fmapWT keeps the log w unchanged and applies f to the result. *)
lemma runWriterT_fmapWT [simp]:
  "runWriterT⋅(fmapWT⋅f⋅m) =
    bind⋅(runWriterT⋅m)⋅(Λ (Writer⋅w⋅x). return⋅(Writer⋅w⋅(f⋅x)))"
by (simp add: fmapWT_def bindWT_def mempty_right)

subsection ‹Laws›

text ‹The ‹liftWT› function maps ‹return› and
‹bind› on the inner monad to ‹unitWT› and ‹bindWT›, as expected. \medskip›

(* liftWT sends the inner return to unitWT ... *)
lemma liftWT_return:
  "liftWT⋅(return⋅x) = unitWT⋅x"
by (rule writerT_eqI, simp add: fmap_return)

(* ... and the inner bind to bindWT. *)
lemma liftWT_bind:
  "liftWT⋅(bind⋅m⋅k) = bindWT⋅(liftWT⋅m)⋅(liftWT oo k)"
by (rule writerT_eqI)
   (simp add: monad_fmap bind_bind mempty_left)

text ‹The composition rule holds unconditionally for fmap. The fmap
function also interacts as expected with unit and bind. \medskip›

(* Composition law for fmapWT; holds with no side conditions. *)
lemma fmapWT_fmapWT:
  "fmapWT⋅f⋅(fmapWT⋅g⋅m) = fmapWT⋅(Λ x. f⋅(g⋅x))⋅m"
apply (simp add: writerT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp add: bind_strict, simp add: mempty_right)
done

(* fmapWT applied to unitWT. *)
lemma fmapWT_unitWT:
  "fmapWT⋅f⋅(unitWT⋅x) = unitWT⋅(f⋅x)"
by (simp add: writerT_eq_iff mempty_right)

(* fmapWT distributes into the continuation of bindWT. *)
lemma fmapWT_bindWT:
  "fmapWT⋅f⋅(bindWT⋅m⋅k) = bindWT⋅m⋅(Λ x. fmapWT⋅f⋅(k⋅x))"
apply (simp add: writerT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, rename_tac x, simp)
apply (case_tac x, simp add: bind_strict, simp add: bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, rename_tac y, simp)
apply (case_tac y, simp add: bind_strict, simp add: mempty_right)
done

(* Binding after fmapWT is the same as precomposing the continuation. *)
lemma bindWT_fmapWT:
  "bindWT⋅(fmapWT⋅f⋅m)⋅k = bindWT⋅m⋅(Λ x. k⋅(f⋅x))"
apply (simp add: writerT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, rename_tac x, simp)
apply (case_tac x, simp add: bind_strict, simp add: mempty_right)
done

text ‹The left unit monad law is not satisfied in general. \medskip›

(* A witness that left unit fails: with k⋅x = WriterT⋅(return⋅⊥) and an
   inner return that is not strict, bindWT⋅(unitWT⋅x)⋅k differs from k⋅x. *)
lemma bindWT_unitWT_counterexample:
  fixes k :: "'a → 'b⋅('m::monad,'w::monoid) writerT"
  assumes 1: "k⋅x = WriterT⋅(return⋅⊥)"
  assumes 2: "return⋅⊥ ≠ (⊥ :: ('b⋅'w writer)⋅'m::monad)"
  shows "bindWT⋅(unitWT⋅x)⋅k ≠ k⋅x"
by (simp add: writerT_eq_iff mempty_left assms)

text ‹However, left unit is satisfied for inner monads with a strict
‹return› function.›

(* Left unit, under the assumption that the inner return is strict
   (return⋅⊥ = ⊥) at the relevant type. *)
lemma bindWT_unitWT_restricted:
  fixes k :: "'a → 'b⋅('m::monad,'w::monoid) writerT"
  assumes "return⋅⊥ = (⊥ :: ('b⋅'w writer)⋅'m)"
  shows "bindWT⋅(unitWT⋅x)⋅k = k⋅x"
unfolding writerT_eq_iff
apply (simp add: mempty_left)
apply (rule trans [OF _ monad_right_unit])
apply (rule cfun_arg_cong)
apply (rule cfun_eqI)
apply (case_tac x, simp_all add: assms)
done

text ‹The associativity of ‹bindWT› holds
unconditionally. \medskip›

(* Associativity of bindWT. The proof peels off one Writer layer per nested
   bind (x, y, z), discharging the ⊥ case with bind_strict each time. The
   final step is associativity of mappend. *)
lemma bindWT_bindWT:
  "bindWT⋅(bindWT⋅m⋅h)⋅k = bindWT⋅m⋅(Λ x. bindWT⋅(h⋅x)⋅k)"
apply (rule writerT_eqI)
apply simp
apply (simp add: bind_bind)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict)
apply (simp add: bind_bind)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp, rename_tac y)
apply (case_tac y)
apply (simp add: bind_strict)
apply (simp add: bind_bind)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp, rename_tac z)
apply (case_tac z)
apply (simp add: bind_strict)
apply (simp add: mappend_assoc)
done

text ‹The right unit monad law is not satisfied in general. \medskip›

(* A witness that right unit fails: m = WriterT⋅(return⋅⊥) with a
   non-strict inner return. *)
lemma bindWT_unitWT_right_counterexample:
  fixes m :: "'a⋅('m::monad,'w::monoid) writerT"
  assumes "m = WriterT⋅(return⋅⊥)"
  assumes "return⋅⊥ ≠ (⊥ :: ('a⋅'w writer)⋅'m)"
  shows "bindWT⋅m⋅unitWT ≠ m"
by (simp add: writerT_eq_iff assms)

text ‹Right unit is satisfied for inner monads with a strict ‹return› function. \medskip›

(* Right unit, under the assumption that the inner return is strict
   (return⋅⊥ = ⊥) at the relevant type. *)
lemma bindWT_unitWT_right_restricted:
  fixes m :: "'a⋅('m::monad,'w::monoid) writerT"
  assumes "return⋅⊥ = (⊥ :: ('a⋅'w writer)⋅'m)"
  shows "bindWT⋅m⋅unitWT = m"
unfolding writerT_eq_iff
apply simp
apply (rule trans [OF _ monad_right_unit])
apply (rule cfun_arg_cong)
apply (rule cfun_eqI)
apply (case_tac x, simp_all add: assms mempty_right)
done

subsection ‹Writer monad transformer invariant›

text ‹We inductively define a predicate that includes all values
that can be constructed from the standard ‹writerT› operations.
\medskip›

(* invar: the least predicate containing ⊥, closed under lubs of chains,
   containing unitWT, tellWT and liftWT values, and closed under bindWT. *)
inductive invar :: "'a⋅('m::monad, 'w::monoid) writerT ⇒ bool"
  where invar_bottom: "invar ⊥"
  | invar_lub: "⋀Y. ⟦chain Y; ⋀i. invar (Y i)⟧ ⟹ invar (⨆i. Y i)"
  | invar_unitWT: "⋀x. invar (unitWT⋅x)"
  | invar_bindWT: "⋀m k. ⟦invar m; ⋀x. invar (k⋅x)⟧ ⟹ invar (bindWT⋅m⋅k)"
  | invar_tellWT: "⋀x w. invar (tellWT⋅x⋅w)"
  | invar_liftWT: "⋀m. invar (liftWT⋅m)"

text ‹Right unit is satisfied for arguments built from standard
functions. \medskip›

(* Right unit for every m satisfying invar, by induction over invar.
   The lub case relies on admissibility (admD). *)
lemma bindWT_unitWT_right_invar:
  fixes m :: "'a⋅('m::monad,'w::monoid) writerT"
  assumes "invar m"
  shows "bindWT⋅m⋅unitWT = m"
using assms proof (induct set: invar)
  case invar_bottom thus ?case
    by (rule writerT_eqI, simp add: bind_strict)
next
  case invar_lub thus ?case
    by - (rule admD, simp, assumption, assumption)
next
  case invar_unitWT thus ?case
    by (rule writerT_eqI, simp add: bind_bind mempty_left)
next
  case invar_bindWT thus ?case
    apply (simp add: writerT_eq_iff bind_bind)
    apply (rule cfun_arg_cong, rule cfun_eqI, simp)
    apply (case_tac x, simp add: bind_strict, simp add: bind_bind)
    apply (rule cfun_arg_cong, rule cfun_eqI, simp, rename_tac y)
    apply (case_tac y, simp add: bind_strict, simp add: mempty_right)
    done
next
  case invar_tellWT thus ?case
    by (simp add: writerT_eq_iff mempty_right)
next
  case invar_liftWT thus ?case
    by (rule writerT_eqI, simp add: monad_fmap bind_bind mempty_right)
qed

text ‹Left unit is also satisfied for arguments built from standard
functions. \medskip›

(* Helper: for m satisfying invar, binding runWriterT⋅m to the function that
   rebuilds each Writer⋅w⋅x unchanged gives back runWriterT⋅m. *)
lemma writerT_left_unit_invar_lemma:
  assumes "invar m"
  shows "bind⋅(runWriterT⋅m)⋅(Λ (Writer⋅w⋅x). return⋅(Writer⋅w⋅x)) = runWriterT⋅m"
using assms proof (induct m set: invar)
  case invar_bottom thus ?case
    by (simp add: bind_strict)
next
  case invar_lub thus ?case
    by - (rule admD, simp, assumption, assumption)
next
  case invar_unitWT thus ?case
    by simp
next
  case invar_bindWT thus ?case
    apply (simp add: bind_bind)
    apply (rule cfun_arg_cong)
    apply (rule cfun_eqI, simp, rename_tac n)
    apply (case_tac n, simp add: bind_strict)
    apply (simp add: bind_bind)
    apply (rule cfun_arg_cong)
    apply (rule cfun_eqI, simp, rename_tac p)
    apply (case_tac p, simp add: bind_strict)
    apply simp
    done
next
  case invar_tellWT thus ?case
    by simp
next
  case invar_liftWT thus ?case
    by (simp add: monad_fmap bind_bind)
qed

(* Left unit, provided the continuation's result k⋅x satisfies invar. *)
lemma bindWT_unitWT_invar:
  assumes "invar (k⋅x)"
  shows "bindWT⋅(unitWT⋅x)⋅k = k⋅x"
apply (simp add: writerT_eq_iff mempty_left)
apply (rule writerT_left_unit_invar_lemma [OF assms])
done

subsection ‹Invariant expressed as a deflation›

(* invar' m: m is a fixed point of fmapWT⋅ID. *)
definition invar' :: "'a⋅('m::monad, 'w::monoid) writerT ⇒ bool"
  where "invar' m ⟷ fmapWT⋅ID⋅m = m"

text ‹All standard operations preserve the invariant.›

(* invar' holds for ⊥, is admissible, and is satisfied by / preserved by
   each standard operation. *)
lemma invar'_bottom: "invar' ⊥"
  unfolding invar'_def by (simp add: writerT_eq_iff bind_strict)

lemma adm_invar': "adm invar'"
  unfolding invar'_def [abs_def] by simp

lemma invar'_unitWT: "invar' (unitWT⋅x)"
  unfolding invar'_def by (simp add: writerT_eq_iff)

lemma invar'_bindWT: "⟦invar' m; ⋀x. invar' (k⋅x)⟧ ⟹ invar' (bindWT⋅m⋅k)"
  unfolding invar'_def
  apply (erule subst)
  apply (simp add: writerT_eq_iff)
  apply (simp add: bind_bind)
  apply (rule cfun_arg_cong)
  apply (rule cfun_eqI, case_tac x)
  apply (simp add: bind_strict)
  apply simp
  apply (simp add: bind_bind)
  apply (rule cfun_arg_cong)
  apply (rule cfun_eqI, rename_tac x, case_tac x)
  apply (simp add: bind_strict)
  apply simp
  done

lemma invar'_tellWT: "invar' (tellWT⋅x⋅w)"
  unfolding invar'_def by (simp add: writerT_eq_iff)

lemma invar'_liftWT: "invar' (liftWT⋅m)"
  unfolding invar'_def by (simp add: writerT_eq_iff monad_fmap bind_bind)

text ‹Left unit is satisfied for arguments built from fmap.›

(* Left unit holds when the continuation's results are fmapWT values. *)
lemma bindWT_unitWT_fmapWT:
  "bindWT⋅(unitWT⋅x)⋅(Λ x. fmapWT⋅f⋅(k⋅x))
    = fmapWT⋅f⋅(k⋅x)"
apply (simp add: fmapWT_def writerT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp_all add: bind_strict mempty_left)
done

text ‹Right unit is satisfied for arguments built from fmap.›

(* Right unit holds for values of the form fmapWT⋅f⋅m. The proof follows
   from bindWT_fmapWT and the definition of fmapWT. *)
lemma bindWT_fmapWT_unitWT:
  shows "bindWT⋅(fmapWT⋅f⋅m)⋅unitWT = fmapWT⋅f⋅m"
apply (simp add: bindWT_fmapWT)
apply (simp add: fmapWT_def)
done

text ‹The right unit, fmap, and associativity laws hold for values satisfying the invariant.›

(* Right unit for values satisfying invar'. Such a value equals
   fmapWT⋅ID⋅m, so bindWT_fmapWT_unitWT applies. *)
lemma invar'_right_unit: "invar' m ⟹ bindWT⋅m⋅unitWT = m"
unfolding invar'_def by (erule subst, rule bindWT_fmapWT_unitWT)

(* fmapWT agrees with its bind/unit characterisation on invar' values. *)
lemma invar'_monad_fmap:
  "invar' m ⟹ fmapWT⋅f⋅m = bindWT⋅m⋅(Λ x. unitWT⋅(f⋅x))"
  unfolding invar'_def
  by (erule subst, simp add: writerT_eq_iff mempty_right)

(* Associativity; the invar' premises are not needed by the proof, since
   bindWT_bindWT holds unconditionally. *)
lemma invar'_bind_assoc:
  "⟦invar' m; ⋀x. invar' (f⋅x); ⋀y. invar' (g⋅y)⟧
    ⟹ bindWT⋅(bindWT⋅m⋅f)⋅g = bindWT⋅m⋅(Λ x. bindWT⋅(f⋅x)⋅g)"
  by (rule bindWT_bindWT)

end