(* Theory UnboxedNats *)

(*<*)
(*
 * The worker/wrapper transformation, following Gill and Hutton.
 * (C)opyright 2009-2011, Peter Gammie, peteg42 at gmail.com.
 * License: BSD
 *)

theory UnboxedNats
imports
  HOLCF
  Nats
  WorkerWrapperNew
begin
(*>*)

section‹Unboxing types.›

text‹The original application of the worker/wrapper transformation
was the unboxing of flat types by \<^citet>‹"SPJ-JL:1991"›. We can model
the boxed and unboxed types as (respectively) pointed and unpointed
domains in HOLCF. Concretely @{typ "UNat"} denotes the discrete domain
of naturals, @{typ "UNat⇩⊥"} the lifted (flat and pointed) variant, and
@{typ "Nat"} the standard boxed domain, isomorphic to @{typ
"UNat⇩⊥"}. This latter distinction helps us keep the boxed naturals and
lifted function codomains separated; applications of @{term "unbox"}
should be thought of in the same way as Haskell's @{term "newtype"}
constructors, i.e. operationally equivalent to @{term "ID"}.

The divergence monad is used to handle the unboxing, see below.›

subsection‹Factorial example.›

text‹Standard definition of factorial.›

(* Boxed factorial, defined by general recursion with fixrec: the result is 1
   when the argument tests equal to 0, and n times the recursive call on
   n - 1 otherwise.
   NOTE(review): the subscript in =⇩B is restored on the assumption that the
   imported Nats theory declares the lifted equality test with that
   spelling -- confirm. *)
fixrec fac :: "Nat → Nat"
where
  "fac⋅n = If n =⇩B 0 then 1 else n * fac⋅(n - 1)"

(* Remove the defining equation from the simp set so that simp does not
   unfold the recursion on its own. *)
declare fac.simps[simp del]

(* fac is strict: applied to bottom it yields bottom.  Declared as a simp
   rule; discharged by the fixrec_simp method. *)
lemma fac_strict[simp]: "fac⋅⊥ = ⊥"
by fixrec_simp

(* Non-recursive body of factorial: the recursive call is abstracted out as
   the parameter r, so that fac can be expressed as a fixed point of this
   functional.
   NOTE(review): =⇩B spelling assumed from the imported Nats theory. *)
definition
  fac_body :: "(Nat → Nat) → Nat → Nat" where
  "fac_body ≡ Λ r n. If n =⇩B 0 then 1 else n * r⋅(n - 1)"

(* For any r, fac_body applied to r is strict in its numeric argument.
   Proved by unfolding the definition and simplifying. *)
lemma fac_body_strict[simp]: "fac_body⋅r⋅⊥ = ⊥"
  unfolding fac_body_def by simp

(* fac coincides with the least fixed point of fac_body.  The proof is
   pointwise (cfun_eqI), rewrites fac with its fixrec-generated definition
   fac_def, and finishes by simplification. *)
lemma fac_fac_body_eq: "fac = fix⋅fac_body"
  unfolding fac_body_def by (rule cfun_eqI, subst fac_def, simp)

text‹Wrap / unwrap functions. Note the explicit lifting of the
co-domain. For some reason the published version of
\<^citet>‹"GillHutton:2009"› does not discuss this point: if we're going to
handle recursive functions, we need a bottom.

@{term "unbox"} simply removes the tag, yielding a possibly-divergent
unboxed value, the result of the function.›

(* Unwrap for the boxing transformation: turn a function on boxed naturals
   into one that takes an unboxed argument (boxing it first) and returns a
   lifted unboxed result (unboxing the answer).
   NOTE(review): the lifted codomain UNat⇩⊥ is reconstructed from the
   surrounding text ("explicit lifting of the co-domain") -- confirm. *)
definition
  unwrapB :: "(Nat → Nat) → UNat → UNat⇩⊥" where
  "unwrapB ≡ Λ f. unbox oo f oo box"

text‹Note that the monadic bind operator @{term "(>>=)"} here stands
in for the \textsf{case} construct in the paper.›

(* Wrap for the boxing transformation: unbox the argument, feed the unboxed
   value to f through the monadic bind, and re-box the result through a
   second bind.
   NOTE(review): the lifted codomain UNat⇩⊥ of the worker type is
   reconstructed from the surrounding text -- confirm. *)
definition
  wrapB :: "(UNat → UNat⇩⊥) → Nat → Nat" where
  "wrapB ≡ Λ f x . unbox⋅x >>= f >>= box"

(* For a strict f, wrapping after unwrapping gives f back.  The proof is
   pointwise in x: unfold wrapB, then unwrapB, and finally split on the
   cases of x, using strictness of f for the bottom case. *)
lemma wrapB_unwrapB_body:
  assumes strictF: "f⋅⊥ = ⊥"
  shows "(wrapB oo unwrapB)⋅f = f" (is "?lhs = ?rhs")
proof(rule cfun_eqI)
  fix x :: Nat
  (* Expose the outer bind introduced by wrapB. *)
  have "?lhs⋅x = unbox⋅x >>= (Λ x'. unwrapB⋅f⋅x' >>= box)"
    unfolding wrapB_def by simp
  (* Expand unwrapB inside the continuation. *)
  also have "… = unbox⋅x >>= (Λ x'. unbox⋅(f⋅(box⋅x')) >>= box)"
    unfolding unwrapB_def by simp
  (* Case analysis on x; the strictness assumption is chained in. *)
  also from strictF have "… = f⋅x" by (cases x, simp_all)
  finally show "?lhs⋅x = ?rhs⋅x" .
qed

text‹Apply worker/wrapper.›

(* Worker for factorial on unboxed naturals: the fixed point of the body
   conjugated by unwrapB / wrapB.
   NOTE(review): the lifted result type UNat⇩⊥ is reconstructed -- confirm. *)
definition
  fac_work :: "UNat → UNat⇩⊥" where
  "fac_work ≡ fix⋅(unwrapB oo fac_body oo wrapB)"

(* Wrapper for factorial: the boxed interface obtained by applying wrapB to
   the unboxed worker fac_work. *)
definition
  fac_wrap :: "Nat → Nat" where
  "fac_wrap ≡ wrapB⋅fac_work"

(* The original fac equals the wrapped worker.  First establish the
   worker/wrapper body condition for fac_body (from wrapB_unwrapB_body and
   strictness of fac_body), then instantiate worker_wrapper_body and unfold
   the worker and wrapper definitions. *)
lemma fac_fac_ww_eq: "fac = fac_wrap" (is "?lhs = ?rhs")
proof -
  (* Body condition: wrap-after-unwrap leaves fac_body unchanged. *)
  have body_fixed: "wrapB oo unwrapB oo fac_body = fac_body"
    using wrapB_unwrapB_body[OF fac_body_strict]
    by - (rule cfun_eqI, simp)
  (* Chain the body condition with the instantiated worker/wrapper theorem. *)
  from body_fixed
       worker_wrapper_body[where computation=fac and body=fac_body and wrap=wrapB and unwrap=unwrapB]
  show ?thesis
    unfolding fac_work_def fac_wrap_def by (simp add: fac_fac_body_eq)
qed

text‹This is not entirely faithful to the paper, as they don't
explicitly handle the lifting of the codomain.›

(* Factorial body over unboxed naturals with box / unbox still explicit:
   the boxed test and boxed arithmetic are kept, but the recursive call r
   works on unboxed values and is reached through the monadic bind.
   NOTE(review): the lifted codomains UNat⇩⊥ and the =⇩B spelling are
   reconstructed -- confirm against the imported Nats theory. *)
definition
  fac_body' :: "(UNat → UNat⇩⊥) → UNat → UNat⇩⊥" where
  "fac_body' ≡ Λ r n.
     unbox⋅(If box⋅n =⇩B 0
              then 1
              else unbox⋅(box⋅n - 1) >>= r >>= (Λ b. box⋅n * box⋅b))"

(* fac_body' is exactly the conjugated body unwrapB oo fac_body oo wrapB.
   Proved pointwise in both arguments; two instances of
   bbind_case_distr_strict are supplied to move the multiplication by
   box x across the binds before simplifying. *)
lemma fac_body'_fac_body: "fac_body' = unwrapB oo fac_body oo wrapB" (is "?lhs = ?rhs")
proof(rule cfun_eqI)+
  fix r x
  show "?lhs⋅r⋅x = ?rhs⋅r⋅x"
    using bbind_case_distr_strict[where f="Λ y. box⋅x * y" and g="unbox⋅(box⋅x - 1)"]
          bbind_case_distr_strict[where f="Λ y. box⋅x * y" and h="box"]
    unfolding fac_body'_def fac_body_def unwrapB_def wrapB_def by simp
qed

text‹The @{term "up"} constructors here again mediate the
isomorphism, operationally doing nothing. Note the switch to the
machine-oriented \emph{if} construct: the test @{term "n = 0"} cannot
diverge.›

(* Final unboxed factorial body: a plain HOL if-test on the unboxed
   argument, unboxed subtraction and multiplication, and up to inject
   results into the lifted type.
   NOTE(review): the lifted codomains UNat⇩⊥ and the subscripted operator
   spellings (minus and times with subscript #) are reconstructed --
   confirm against the imported Nats theory. *)
definition
  fac_body_final :: "(UNat → UNat⇩⊥) → UNat → UNat⇩⊥" where
  "fac_body_final ≡ Λ r n.
     if n = 0 then up⋅1 else r⋅(n -⇩# 1) >>= (Λ b. up⋅(n *⇩# b))"

(* The final body agrees with fac_body'.  Proved pointwise: an instance of
   bbind_case_distr_strict pushes unbox through the bind, then the unboxed
   arithmetic and the boxed constants 0 and 1 are unfolded and simp
   finishes.
   NOTE(review): subscripted operator spellings are reconstructed. *)
lemma fac_body_final_fac_body': "fac_body_final = fac_body'" (is "?lhs = ?rhs")
proof(rule cfun_eqI)+
  fix r x
  show "?lhs⋅r⋅x = ?rhs⋅r⋅x"
    using bbind_case_distr_strict[where f="unbox" and g="r⋅(x -⇩# 1)" and h="(Λ b. box⋅(x *⇩# b))"]
    unfolding fac_body_final_def fac_body'_def uMinus_def uMult_def zero_Nat_def one_Nat_def
    by simp
qed

(* Final unboxed worker: the least fixed point of fac_body_final.
   NOTE(review): the lifted result type UNat⇩⊥ is reconstructed -- confirm. *)
definition
  fac_work_final :: "UNat → UNat⇩⊥" where
  "fac_work_final ≡ fix⋅fac_body_final"

(* Final boxed factorial: unbox the argument, run the unboxed worker, and
   box the result, sequencing both steps with the monadic bind. *)
definition
  fac_final :: "Nat → Nat" where
  "fac_final ≡ Λ n. unbox⋅n >>= fac_work_final >>= box"

(* fac equals fac_final.  The calculation walks through the intermediate
   forms: the wrapped worker, its unfolded fixed point, the fac_body'
   version, the final worker, and finally the definition of fac_final. *)
lemma fac_fac_final: "fac = fac_final" (is "?lhs=?rhs")
proof -
  have "?lhs = fac_wrap" by (rule fac_fac_ww_eq)
  also have "… = wrapB⋅fac_work" by (simp only: fac_wrap_def)
  also have "… = wrapB⋅(fix⋅(unwrapB oo fac_body oo wrapB))" by (simp only: fac_work_def)
  also have "… = wrapB⋅(fix⋅fac_body')" by (simp only: fac_body'_fac_body)
  also have "… = wrapB⋅fac_work_final" by (simp only: fac_body_final_fac_body' fac_work_final_def)
  also have "… = fac_final" by (simp add: fac_final_def wrapB_def)
  finally show ?thesis .
qed

(* **************************************** *)

subsection‹Introducing an accumulator.›

text‹

The final version of factorial uses unboxed naturals but is not
tail-recursive. We can apply worker/wrapper once more to introduce an
accumulator, similar to \S\ref{sec:accum}.

The monadic machinery complicates things slightly here. We use
\emph{Kleisli composition}, denoted @{term "(>=>)"}, in the
homomorphism.

Firstly we introduce an ``accumulator'' monoid and show the
homomorphism.

›

(* Accumulator type: continuous functions from unboxed naturals to lifted
   unboxed naturals.
   NOTE(review): the lifted codomain UNat⇩⊥ is reconstructed -- confirm. *)
type_synonym UNatAcc = "UNat → UNat⇩⊥"

(* Embed a number m into the accumulator type as "multiply by m", with the
   product injected into the lifted type by up.
   NOTE(review): the subscripted multiplication spelling is reconstructed. *)
definition
  n2a :: "UNat → UNatAcc" where
  "n2a ≡ Λ m n. up⋅(m *⇩# n)"

(* Project an accumulator back to a (lifted) number by applying it to 1.
   NOTE(review): the lifted result type UNat⇩⊥ is reconstructed -- confirm. *)
definition
  a2n :: "UNatAcc → UNat⇩⊥" where
  "a2n ≡ Λ a. a⋅1"

(* a2n is strict: the bottom accumulator projects to bottom.  Declared as a
   simp rule. *)
lemma a2n_strict[simp]: "a2n⋅⊥ = ⊥"
  unfolding a2n_def by simp

(* Projecting an embedded number recovers it (injected by up).  Follows from
   the definitions together with the uMult_arithmetic facts. *)
lemma a2n_n2a: "a2n⋅(n2a⋅u) = up⋅u"
  unfolding a2n_def n2a_def by (simp add: uMult_arithmetic)

(* Homomorphism property: embedding a product equals the Kleisli composition
   of the embeddings of the two factors.
   NOTE(review): the subscripted multiplication spelling is reconstructed. *)
lemma A_hom_mult: "n2a⋅(x *⇩# y) = (n2a⋅x >=> n2a⋅y)"
  unfolding n2a_def bKleisli_def by (simp add: uMult_arithmetic)

(* Unwrap for the accumulator transformation: run f on n and embed its
   result into the accumulator type via bind with n2a.
   NOTE(review): the lifted codomain UNat⇩⊥ in the argument type is
   reconstructed -- confirm. *)
definition
  unwrapA :: "(UNat → UNat⇩⊥) → UNat → UNatAcc" where
  "unwrapA ≡ Λ f n. f⋅n >>= n2a"

(* unwrapA is strict.  Proved pointwise after unfolding the definition. *)
lemma unwrapA_strict[simp]: "unwrapA⋅⊥ = ⊥"
  unfolding unwrapA_def by (rule cfun_eqI) simp

(* Wrap for the accumulator transformation: post-compose the accumulator
   valued function with the projection a2n.
   NOTE(review): the lifted result type UNat⇩⊥ is reconstructed -- confirm. *)
definition
  wrapA :: "(UNat → UNatAcc) → UNat → UNat⇩⊥" where
  "wrapA ≡ Λ f. a2n oo f"

(* wrapA after unwrapA is the identity.  Proved pointwise in both arguments
   (cfun_eqI introduces the variables x and xa), then by case analysis on
   the lifted value x applied to xa, finishing with a2n_n2a. *)
lemma wrapA_unwrapA_id: "wrapA oo unwrapA = ID"
  unfolding wrapA_def unwrapA_def
  apply (rule cfun_eqI)+
  apply (case_tac "x⋅xa")
  apply (simp_all add: a2n_n2a)
  done

text‹Some steps along the way.›

(* First accumulator body: the final unboxed body with the recursive call
   routed through wrapA and every result embedded with n2a.
   NOTE(review): subscripted operator spellings are reconstructed. *)
definition
  fac_acc_body1 :: "(UNat → UNatAcc) → UNat → UNatAcc" where
  "fac_acc_body1 ≡ Λ r n.
     if n = 0 then n2a⋅1 else wrapA⋅r⋅(n -⇩# 1) >>= (Λ res. n2a⋅(n *⇩# res))"

(* fac_acc_body1 is the final unboxed body conjugated by unwrapA / wrapA.
   After unfolding all four definitions the claim is shown pointwise and
   closed by simplification. *)
lemma fac_acc_body1_fac_body_final_eq: "fac_acc_body1 = unwrapA oo fac_body_final oo wrapA"
  unfolding fac_acc_body1_def fac_body_final_def wrapA_def unwrapA_def
  apply (rule cfun_eqI)+
  apply simp
  done

text‹Use the homomorphism.›

(* Second accumulator body: as fac_acc_body1, but the embedded product is
   written as a Kleisli composition of the two embedded factors.
   NOTE(review): subscripted operator spellings are reconstructed. *)
definition
  fac_acc_body2 :: "(UNat → UNatAcc) → UNat → UNatAcc" where
  "fac_acc_body2 ≡ Λ r n.
     if n = 0 then n2a⋅1 else wrapA⋅r⋅(n -⇩# 1) >>= (Λ res. n2a⋅n >=> n2a⋅res)"

(* The second and first accumulator bodies coincide: shown pointwise, with
   the homomorphism lemma A_hom_mult given to the simplifier. *)
lemma fac_acc_body2_body1_eq: "fac_acc_body2 = fac_acc_body1"
  unfolding fac_acc_body1_def fac_acc_body2_def
  apply (rule cfun_eqI)+
  apply (simp add: A_hom_mult)
  done

text‹Apply worker/wrapper.›

(* Third accumulator body: the recursive call r is used directly (no wrapA)
   and is Kleisli-composed after the embedding of n.
   NOTE(review): the subscripted subtraction spelling is reconstructed. *)
definition
  fac_acc_body3 :: "(UNat → UNatAcc) → UNat → UNatAcc" where
  "fac_acc_body3 ≡ Λ r n.
     if n = 0 then n2a⋅1 else n2a⋅n >=> r⋅(n -⇩# 1)"

(* Pre-composing fac_acc_body3 with unwrapA oo wrapA gives fac_acc_body2.
   Proved pointwise in r, n and the accumulator argument; the symmetric
   instance of bbind_case_distr_strict moves the Kleisli composition with
   the embedding of n across the bind. *)
lemma fac_acc_body3_body2: "fac_acc_body3 oo (unwrapA oo wrapA) = fac_acc_body2" (is "?lhs=?rhs")
proof(rule cfun_eqI)+
  fix r n acc
  show "((fac_acc_body3 oo (unwrapA oo wrapA))⋅r⋅n⋅acc) = fac_acc_body2⋅r⋅n⋅acc"
    unfolding fac_acc_body2_def fac_acc_body3_def unwrapA_def
    using bbind_case_distr_strict[where f="Λ y. n2a⋅n >=> y" and h="n2a", symmetric]
    by simp
qed

(* The final unboxed worker equals wrapA applied to the fixed point of
   fac_acc_body3.  Obtained from worker_wrapper_fusion_new, instantiated
   with wrapA_unwrapA_id and unwrapA_strict; the remaining side condition is
   discharged with the chain of body equalities. *)
lemma fac_work_final_body3_eq: "fac_work_final = wrapA⋅(fix⋅fac_acc_body3)"
  unfolding fac_work_final_def
  by (rule worker_wrapper_fusion_new[OF wrapA_unwrapA_id unwrapA_strict])
     (simp add: fac_acc_body3_body2 fac_acc_body2_body1_eq fac_acc_body1_fac_body_final_eq)

(* Final accumulator body with the accumulator as an explicit third
   argument: return the accumulator when n is 0, otherwise recurse on
   n - 1 with the accumulator multiplied by n.
   NOTE(review): subscripted operator spellings are reconstructed. *)
definition
  fac_acc_body_final :: "(UNat → UNatAcc) → UNat → UNatAcc" where
  "fac_acc_body_final ≡ Λ r n acc.
     if n = 0 then up⋅acc else r⋅(n -⇩# 1)⋅(n *⇩# acc)"

(* Final accumulator-passing worker: the fixed point of fac_acc_body_final
   started with accumulator 1.
   NOTE(review): the lifted result type UNat⇩⊥ is reconstructed -- confirm. *)
definition
  fac_acc_work_final :: "UNat → UNat⇩⊥" where
  "fac_acc_work_final ≡ Λ x. fix⋅fac_acc_body_final⋅x⋅1"

(* The explicit-accumulator body agrees with fac_acc_body3.  After unfolding
   both bodies together with n2a and the Kleisli composition, the equation
   is shown pointwise using the uMult_arithmetic facts. *)
lemma fac_acc_work_final_fac_acc_work3_eq: "fac_acc_body_final = fac_acc_body3" (is "?lhs=?rhs")
  unfolding fac_acc_body3_def fac_acc_body_final_def n2a_def bKleisli_def
  apply (rule cfun_eqI)+
  apply (simp add: uMult_arithmetic)
  done

(* The accumulator-passing worker equals the final unboxed worker.  The
   calculation starts from the right-hand side: rewrite it with
   fac_work_final_body3_eq, swap in the explicit-accumulator body, and
   unfold the worker, wrapA and a2n to reach the left-hand side. *)
lemma fac_acc_work_final_fac_work: "fac_acc_work_final = fac_work_final" (is "?lhs=?rhs")
proof -
  have "?rhs = wrapA⋅(fix⋅fac_acc_body3)" by (rule fac_work_final_body3_eq)
  also have "… = wrapA⋅(fix⋅fac_acc_body_final)"
    using fac_acc_work_final_fac_acc_work3_eq by simp
  also have "… = ?lhs"
    unfolding fac_acc_work_final_def wrapA_def a2n_def
    by (simp add: cfcomp1)
  (* The chain proves rhs = lhs; simp reorients it to the stated goal. *)
  finally show ?thesis by simp
qed

(*<*)
end
(*>*)