(* Theory Error_Transformer *)

section ‹Error monad transformer›

theory Error_Transformer
imports Error_Monad
begin

subsection ‹Type definition›

text ‹The error monad transformer is defined in Haskell by composing
the given monad with a standard error monad:›

text_raw ‹
\begin{verbatim}
data Error e a = Err e | Ok a
newtype ErrorT e m a = ErrorT { runErrorT :: m (Error e a) }
\end{verbatim}
›

text ‹We can formalize this definition directly using ‹tycondef›. \medskip›

(* Type constructor errorT: the inner functor 'f applied to 'e error values,
   wrapped by constructor ErrorT with selector runErrorT. *)
tycondef 'a⋅('f::"functor",'e::"domain") errorT =
  ErrorT (runErrorT :: "('a⋅'e error)⋅'f")

(* coerce commutes with the abstraction function errorT_abs.
   Proof: unfold errorT_abs and coerce, then simplify with the emb/prj
   round-trip lemmas and DEFL_eq_errorT. *)
lemma coerce_errorT_abs [simp]: "coerce⋅(errorT_abs⋅x) = errorT_abs⋅(coerce⋅x)"
apply (simp add: errorT_abs_def coerce_def)
apply (simp add: emb_prj_emb prj_emb_prj DEFL_eq_errorT)
done

(* coerce commutes with the constructor ErrorT. *)
lemma coerce_ErrorT [simp]: "coerce⋅(ErrorT⋅k) = ErrorT⋅(coerce⋅k)"
unfolding ErrorT_def by simp

(* Case rule: every value y has the form ErrorT⋅k; the witness used in the
   proof is k = runErrorT⋅y. *)
lemma errorT_cases [case_names ErrorT]:
  obtains k where "y = ErrorT⋅k"
proof
  show "y = ErrorT⋅(runErrorT⋅y)"
    by (cases y, simp_all)
qed

(* Wrapping the result of runErrorT with ErrorT gives back the original value. *)
lemma ErrorT_runErrorT [simp]: "ErrorT⋅(runErrorT⋅m) = m"
by (cases m rule: errorT_cases, simp)

(* Induction rule: a predicate that holds for every ErrorT⋅k holds for all
   values of the type. Derived directly from errorT_cases. *)
lemma errorT_induct [case_names ErrorT]:
  fixes P :: "'a⋅('f::functor,'e) errorT ⇒ bool"
  assumes "⋀k. P (ErrorT⋅k)"
  shows "P y"
by (cases y rule: errorT_cases, simp add: assms)

(* Two values are equal exactly when their runErrorT projections are equal. *)
lemma errorT_eq_iff:
  "a = b ⟷ runErrorT⋅a = runErrorT⋅b"
apply (cases a rule: errorT_cases)
apply (cases b rule: errorT_cases)
apply simp
done

(* Introduction form of errorT_eq_iff: to prove a = b it suffices to show
   that the runErrorT projections agree. *)
lemma errorT_eqI:
  "runErrorT⋅a = runErrorT⋅b ⟹ a = b"
by (simp add: errorT_eq_iff)

(* coerce commutes with the selector runErrorT. *)
lemma runErrorT_coerce [simp]:
  "runErrorT⋅(coerce⋅k) = coerce⋅(runErrorT⋅k)"
by (induct k rule: errorT_induct, simp)

subsection ‹Functor class instance›

(* fmap coincides with error_map⋅ID. Proved by function extensionality in
   both arguments, then case analysis on the error value; the three
   remaining goals are closed by unfolding error_map (and Err / Ok). *)
lemma fmap_error_def: "fmap = error_map⋅ID"
apply (rule cfun_eqI, rename_tac f)
apply (rule cfun_eqI, rename_tac x)
apply (case_tac x rule: error.exhaust, simp_all)
apply (simp add: error_map_def fix_const)
apply (simp add: error_map_def fix_const Err_def)
apply (simp add: error_map_def fix_const Ok_def)
done

(* fmapU on an ErrorT value maps f through both layers: fmap of the inner
   functor applied to fmap on the error value. *)
lemma fmapU_ErrorT [simp]:
  "fmapU⋅f⋅(ErrorT⋅m) = ErrorT⋅(fmap⋅(fmap⋅f)⋅m)"
unfolding fmapU_errorT_def errorT_map_def fmap_error_def fix_const ErrorT_def
by simp

(* Selector form of fmapU_ErrorT. *)
lemma runErrorT_fmapU [simp]:
  "runErrorT⋅(fmapU⋅f⋅m) = fmap⋅(fmap⋅f)⋅(runErrorT⋅m)"
by (induct m rule: errorT_induct) simp

(* Functor class instance for errorT. The single proof obligation shown here
   is the composition law for fmapU, discharged by induction on xs using
   fmap_fmap of the inner functor. *)
instance errorT :: ("functor", "domain") "functor"
proof
  fix f g and xs :: "udom⋅('a, 'b) errorT"
  show "fmapU⋅f⋅(fmapU⋅g⋅xs) = fmapU⋅(Λ x. f⋅(g⋅x))⋅xs"
    apply (induct xs rule: errorT_induct)
    apply (simp add: fmap_fmap eta_cfun)
    done
qed

subsection ‹Transfer properties to polymorphic versions›

(* Polymorphic counterpart of fmapU_ErrorT, obtained by unfolding fmap_def
   at the errorT instance and simplifying the coercions. *)
lemma fmap_ErrorT [simp]:
  fixes f :: "'a → 'b" and m :: "'a⋅'e error⋅('m::functor)"
  shows "fmap⋅f⋅(ErrorT⋅m) = ErrorT⋅(fmap⋅(fmap⋅f)⋅m)"
unfolding fmap_def [where 'f="('m,'e) errorT"]
by (simp_all add: coerce_simp eta_cfun)

(* Selector form of fmap_ErrorT, obtained by instantiating it at runErrorT⋅m. *)
lemma runErrorT_fmap [simp]:
  fixes f :: "'a → 'b" and m :: "'a⋅('m::functor,'e) errorT"
  shows "runErrorT⋅(fmap⋅f⋅m) = fmap⋅(fmap⋅f)⋅(runErrorT⋅m)"
using fmap_ErrorT [of f "runErrorT⋅m"]
by simp

(* fmap is strict on errorT over a monad; follows from fmap_strict of the
   inner monad via errorT_eq_iff. *)
lemma errorT_fmap_strict [simp]:
  shows "fmap⋅f⋅(⊥::'a⋅('m::monad,'e) errorT) = ⊥"
by (simp add: errorT_eq_iff fmap_strict)

subsection ‹Monad operations›

text ‹The error monad transformer does not yield a monad in the
usual sense: We cannot prove a ‹monad› class instance, because
type ‹'a⋅('m,'e) errorT› contains values that break the monad
laws. However, it turns out that such values are inaccessible: The
monad laws are satisfied by all values constructible from the abstract
operations.›

text ‹To explore the properties of the error monad transformer
operations, we define them all as non-overloaded functions. \medskip
›

(* Unit: return Ok⋅x in the inner monad and wrap it with ErrorT. *)
definition unitET :: "'a → 'a⋅('m::monad,'e) errorT"
  where "unitET = (Λ x. ErrorT⋅(return⋅(Ok⋅x)))"

(* Bind: run m in the inner monad; an Err⋅e result is returned unchanged,
   an Ok⋅x result continues with runErrorT⋅(k⋅x). *)
definition bindET :: "'a⋅('m::monad,'e) errorT →
    ('a → 'b⋅('m,'e) errorT) → 'b⋅('m,'e) errorT"
  where "bindET = (Λ m k. ErrorT⋅(bind⋅(runErrorT⋅m)⋅
    (Λ n. case n of Err⋅e ⇒ return⋅(Err⋅e) | Ok⋅x ⇒ runErrorT⋅(k⋅x))))"

(* Lift: map Ok over a computation of the inner monad and wrap with ErrorT. *)
definition liftET :: "'a⋅'m::monad → 'a⋅('m,'e) errorT"
  where "liftET = (Λ m. ErrorT⋅(fmap⋅Ok⋅m))"

(* Throw: return Err⋅e in the inner monad and wrap it with ErrorT. *)
definition throwET :: "'e → 'a⋅('m::monad,'e) errorT"
  where "throwET = (Λ e. ErrorT⋅(return⋅(Err⋅e)))"

(* Catch: run m in the inner monad; an Err⋅e result is passed to handler h,
   an Ok⋅x result is returned unchanged. *)
definition catchET :: "'a⋅('m::monad,'e) errorT →
    ('e → 'a⋅('m,'e) errorT) → 'a⋅('m,'e) errorT"
  where "catchET = (Λ m h. ErrorT⋅(bind⋅(runErrorT⋅m)⋅(Λ n. case n of
    Err⋅e ⇒ runErrorT⋅(h⋅e) | Ok⋅x ⇒ return⋅(Ok⋅x))))"

(* Non-overloaded map, defined in terms of bindET and unitET. *)
definition fmapET :: "('a → 'b) →
    'a⋅('m::monad,'e) errorT → 'b⋅('m,'e) errorT"
  where "fmapET = (Λ f m. bindET⋅m⋅(Λ x. unitET⋅(f⋅x)))"

(* Unfolding equation: runErrorT applied to unitET. *)
lemma runErrorT_unitET [simp]:
  "runErrorT⋅(unitET⋅x) = return⋅(Ok⋅x)"
unfolding unitET_def by simp

(* Unfolding equation: runErrorT applied to bindET. *)
lemma runErrorT_bindET [simp]:
  "runErrorT⋅(bindET⋅m⋅k) = bind⋅(runErrorT⋅m)⋅
    (Λ n. case n of Err⋅e ⇒ return⋅(Err⋅e) | Ok⋅x ⇒ runErrorT⋅(k⋅x))"
unfolding bindET_def by simp

(* Unfolding equation: runErrorT applied to liftET. *)
lemma runErrorT_liftET [simp]:
  "runErrorT⋅(liftET⋅m) = fmap⋅Ok⋅m"
unfolding liftET_def by simp

(* Unfolding equation: runErrorT applied to throwET. *)
lemma runErrorT_throwET [simp]:
  "runErrorT⋅(throwET⋅e) = return⋅(Err⋅e)"
unfolding throwET_def by simp

(* Unfolding equation: runErrorT applied to catchET. *)
lemma runErrorT_catchET [simp]:
  "runErrorT⋅(catchET⋅m⋅h) =
    bind⋅(runErrorT⋅m)⋅(Λ n. case n of
      Err⋅e ⇒ runErrorT⋅(h⋅e) | Ok⋅x ⇒ return⋅(Ok⋅x))"
unfolding catchET_def by simp

(* Unfolding equation: runErrorT applied to fmapET; errors pass through,
   Ok values have f applied. *)
lemma runErrorT_fmapET [simp]:
  "runErrorT⋅(fmapET⋅f⋅m) =
    bind⋅(runErrorT⋅m)⋅(Λ n. case n of
      Err⋅e ⇒ return⋅(Err⋅e) | Ok⋅x ⇒ return⋅(Ok⋅(f⋅x)))"
unfolding fmapET_def by simp

subsection ‹Laws›

(* Left unit law for bindET / unitET. *)
lemma bindET_unitET [simp]:
  "bindET⋅(unitET⋅x)⋅k = k⋅x"
by (rule errorT_eqI, simp)

(* Catching on a unit computation ignores the handler. *)
lemma catchET_unitET [simp]:
  "catchET⋅(unitET⋅x)⋅h = unitET⋅x"
by (rule errorT_eqI, simp)

(* Catching a thrown error applies the handler to it. *)
lemma catchET_throwET [simp]:
  "catchET⋅(throwET⋅e)⋅h = h⋅e"
by (rule errorT_eqI, simp)

(* liftET maps the inner return to unitET. *)
lemma liftET_return:
  "liftET⋅(return⋅x) = unitET⋅x"
by (rule errorT_eqI, simp add: fmap_return)

(* liftET maps the inner bind to bindET. *)
lemma liftET_bind:
  "liftET⋅(bind⋅m⋅k) = bindET⋅(liftET⋅m)⋅(liftET oo k)"
by (rule errorT_eqI, simp add: fmap_bind bind_fmap)

(* Binding after a thrown error leaves the error unchanged. *)
lemma bindET_throwET:
  "bindET⋅(throwET⋅e)⋅k = throwET⋅e"
by (rule errorT_eqI, simp)

(* Associativity of bindET. After reducing to the inner monad with bind_bind,
   the two continuations are compared pointwise by case analysis on the
   intermediate error value. *)
lemma bindET_bindET:
  "bindET⋅(bindET⋅m⋅h)⋅k = bindET⋅m⋅(Λ x. bindET⋅(h⋅x)⋅k)"
apply (rule errorT_eqI)
apply simp
apply (simp add: bind_bind)
apply (rule cfun_arg_cong)
apply (rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict)
apply simp
apply simp
done

(* Composition law for fmapET, from associativity of bindET. *)
lemma fmapET_fmapET:
  "fmapET⋅f⋅(fmapET⋅g⋅m) = fmapET⋅(Λ x. f⋅(g⋅x))⋅m"
by (simp add: fmapET_def bindET_bindET)

text ‹Right unit monad law is not satisfied in general.›

(* Counterexample to the right unit law: if return of the inner monad is not
   strict, then m = ErrorT⋅(return⋅⊥) differs from bindET⋅m⋅unitET. *)
lemma bindET_unitET_right_counterexample:
  fixes m :: "'a⋅('m::monad,'e) errorT"
  assumes "m = ErrorT⋅(return⋅⊥)"
  assumes "return⋅⊥ ≠ (⊥ :: ('a⋅'e error)⋅'m)"
  shows "bindET⋅m⋅unitET ≠ m"
by (simp add: errorT_eq_iff assms)

text ‹Right unit is satisfied for inner monads with strict return.›

(* Right unit law under the assumption that return of the inner monad is
   strict. The continuation is shown equal to return pointwise, then
   monad_right_unit of the inner monad finishes the proof. *)
lemma bindET_unitET_right_restricted:
  fixes m :: "'a⋅('m::monad,'e) errorT"
  assumes "return⋅⊥ = (⊥ :: ('a⋅'e error)⋅'m)"
  shows "bindET⋅m⋅unitET = m"
unfolding errorT_eq_iff
apply simp
apply (rule trans [OF _ monad_right_unit])
apply (rule cfun_arg_cong)
apply (rule cfun_eqI)
apply (case_tac x, simp_all add: assms)
done

subsection ‹Error monad transformer invariant›

text ‹This inductively-defined invariant is supposed to represent
the set of all values constructible using the standard ‹errorT›
operations.›

(* Inductive invariant: holds for bottom, is closed under lubs of chains,
   and is closed under unitET, bindET, throwET, catchET and liftET. *)
inductive invar :: "'a⋅('m::monad, 'e) errorT ⇒ bool"
  where invar_bottom: "invar ⊥"
  | invar_lub: "⋀Y. ⟦chain Y; ⋀i. invar (Y i)⟧ ⟹ invar (⨆i. Y i)"
  | invar_unitET: "⋀x. invar (unitET⋅x)"
  | invar_bindET: "⋀m k. ⟦invar m; ⋀x. invar (k⋅x)⟧ ⟹ invar (bindET⋅m⋅k)"
  | invar_throwET: "⋀e. invar (throwET⋅e)"
  | invar_catchET: "⋀m h. ⟦invar m; ⋀e. invar (h⋅e)⟧ ⟹ invar (catchET⋅m⋅h)"
  | invar_liftET: "⋀m. invar (liftET⋅m)"

text ‹Right unit is satisfied for arguments built from standard functions.›

(* Right unit law for every value satisfying invar, by rule induction.
   The apply steps follow the order of the introduction rules of invar. *)
lemma bindET_unitET_right_invar:
  assumes "invar m"
  shows "bindET⋅m⋅unitET = m"
using assms
apply (induct set: invar)
(* bottom *)
apply (rule errorT_eqI, simp add: bind_strict)
(* lub of a chain: by admissibility *)
apply (rule admD, simp, assumption, assumption)
(* unitET *)
apply (rule errorT_eqI, simp)
(* bindET *)
apply (simp add: errorT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp add: bind_strict, simp, simp)
(* throwET *)
apply (rule errorT_eqI, simp)
(* catchET *)
apply (simp add: errorT_eq_iff bind_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x, simp add: bind_strict, simp, simp)
(* liftET *)
apply (rule errorT_eqI, simp add: monad_fmap bind_bind)
done

text ‹Monad-fmap is satisfied for arguments built from standard functions.›

(* For every value satisfying invar, the overloaded fmap agrees with the
   bindET/unitET formulation. Proved by rule induction; the apply steps
   follow the order of the introduction rules of invar. *)
lemma errorT_monad_fmap_invar:
  fixes f :: "'a → 'b" and m :: "'a⋅('m::monad,'e) errorT"
  assumes "invar m"
  shows "fmap⋅f⋅m = bindET⋅m⋅(Λ x. unitET⋅(f⋅x))"
using assms
apply (induct set: invar)
(* bottom *)
apply (rule errorT_eqI, simp add: bind_strict fmap_strict)
(* lub of a chain: by admissibility *)
apply (rule admD, simp, assumption, assumption)
(* unitET *)
apply (rule errorT_eqI, simp add: fmap_return)
(* bindET *)
apply (simp add: errorT_eq_iff bind_bind fmap_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict fmap_strict)
apply (simp add: fmap_return)
apply simp
(* throwET *)
apply (rule errorT_eqI, simp add: fmap_return)
(* catchET *)
apply (simp add: errorT_eq_iff bind_bind fmap_bind)
apply (rule cfun_arg_cong, rule cfun_eqI, simp)
apply (case_tac x)
apply (simp add: bind_strict fmap_strict)
apply simp
apply (simp add: fmap_return)
(* liftET *)
apply (rule errorT_eqI, simp add: monad_fmap bind_bind return_error_def)
done

subsection ‹Invariant expressed as a deflation›

text ‹We can also define an invariant in a more semantic way, as the
set of fixed-points of a deflation.›

(* Semantic invariant: m is a fixed point of fmapET⋅ID. *)
definition invar' :: "'a⋅('m::monad, 'e) errorT ⇒ bool"
  where "invar' m ⟷ fmapET⋅ID⋅m = m"

text ‹All standard operations preserve the invariant.›

(* unitET results satisfy invar'. *)
lemma invar'_unitET: "invar' (unitET⋅x)"
  unfolding invar'_def by (simp add: fmapET_def)

(* fmapET preserves invar'. *)
lemma invar'_fmapET: "invar' m ⟹ invar' (fmapET⋅f⋅m)"
  unfolding invar'_def
  by (erule subst, simp add: fmapET_def bindET_bindET eta_cfun)

(* bindET preserves invar' when the argument and every continuation result
   satisfy it. *)
lemma invar'_bindET: "⟦invar' m; ⋀x. invar' (k⋅x)⟧ ⟹ invar' (bindET⋅m⋅k)"
  unfolding invar'_def
  by (simp add: fmapET_def bindET_bindET eta_cfun)

(* throwET results satisfy invar'. *)
lemma invar'_throwET: "invar' (throwET⋅e)"
  unfolding invar'_def by (simp add: fmapET_def bindET_throwET eta_cfun)

(* catchET preserves invar' when the argument and every handler result
   satisfy it. After reducing to the inner monad, the continuations are
   compared by case analysis; the handler assumption is instantiated at e
   and used as a rewrite for h⋅e.
   NOTE(review): the proof relies on "back" to select the intended
   unifier for subst, which is sensitive to goal shape. *)
lemma invar'_catchET: "⟦invar' m; ⋀e. invar' (h⋅e)⟧ ⟹ invar' (catchET⋅m⋅h)"
  unfolding invar'_def
  apply (simp add: fmapET_def eta_cfun)
  apply (rule errorT_eqI)
  apply (simp add: bind_bind eta_cfun)
  apply (rule cfun_arg_cong)
  apply (rule cfun_eqI)
  apply (case_tac x)
  apply (simp add: bind_strict)
  apply simp
  apply (drule_tac x=e in meta_spec)
  apply (erule_tac t="h⋅e" in subst) back
  apply (simp add: eta_cfun)
  apply simp
  done

(* liftET results satisfy invar'. *)
lemma invar'_liftET: "invar' (liftET⋅m)"
  unfolding invar'_def
  apply (simp add: fmapET_def errorT_eq_iff)
  apply (simp add: monad_fmap bind_bind)
  done

(* Bottom satisfies invar'; uses strictness of the inner bind. *)
lemma invar'_bottom: "invar' ⊥"
  unfolding invar'_def fmapET_def
  by (simp add: errorT_eq_iff bind_strict)

(* invar' is an admissible predicate; proved by unfolding its definition
   (as a lambda abstraction, via abs_def) and simplification. *)
lemma adm_invar': "adm invar'"
  unfolding invar'_def [abs_def] by simp

text ‹All monad laws are preserved by values satisfying the invariant.›

(* Right unit holds unconditionally for values of the form fmapET⋅f⋅m. *)
lemma bindET_fmapET_unitET:
  shows "bindET⋅(fmapET⋅f⋅m)⋅unitET = fmapET⋅f⋅m"
by (simp add: fmapET_def bindET_bindET)

(* Right unit law for values satisfying invar': rewrite m as fmapET⋅ID⋅m and
   apply bindET_fmapET_unitET. *)
lemma invar'_right_unit: "invar' m ⟹ bindET⋅m⋅unitET = m"
unfolding invar'_def by (erule subst, rule bindET_fmapET_unitET)

(* Monad-fmap law for fmapET on values satisfying invar'. *)
lemma invar'_monad_fmap:
  "invar' m ⟹ fmapET⋅f⋅m = bindET⋅m⋅(Λ x. unitET⋅(f⋅x))"
  unfolding invar'_def by (erule subst, simp add: errorT_eq_iff)

(* Associativity of bindET stated under the invariant; the proof is just
   bindET_bindET and does not use the premises. *)
lemma invar'_bind_assoc:
  "⟦invar' m; ⋀x. invar' (f⋅x); ⋀y. invar' (g⋅y)⟧
    ⟹ bindET⋅(bindET⋅m⋅f)⋅g = bindET⋅m⋅(Λ x. bindET⋅(f⋅x)⋅g)"
  by (rule bindET_bindET)

end