(* Theory MutableRef *)

section "Semantics of mutable references"

theory MutableRef
  imports Main "HOL-Library.FSet"  (* 'a fset: finite sets used for function tables and stores *)
begin

(* Object-language types: naturals, functions (written A → B), pairs, and TRef A
   for references holding values of type A. *)
datatype ty = TNat | TFun ty ty (infix "→" 60) | TPair ty ty | TRef ty

(* Variables are natural numbers: E (EVar n) looks up position n of the
   environment list. *)
type_synonym name = nat

(* Expressions.  EPrim carries a binary primitive on naturals applied to two
   operands; ERef, ERead and EWrite allocate, read and write store cells. *)
datatype exp = EVar name | ENat nat | ELam ty exp | EApp exp exp
  | EPrim "nat ⇒ nat ⇒ nat" exp exp | EIf exp exp exp
  | EPair exp exp | EFst exp | ESnd exp
  | ERef exp | ERead exp | EWrite exp exp

subsection "Denotations (values)"
(* Semantic values.  A function value is a finite table of (argument, result)
   pairs, VAddr a is a store address, and Wrong is the result of ill-typed
   operations (e.g. applying a non-function). *)
datatype val =
    VNat nat
  | VFun "(val × val) fset"
  | VPair val val
  | VAddr nat
  | Wrong

(* A finite table of (input, output) value pairs. *)
type_synonym func = "(val × val) fset"
(* Stores reuse the table representation; entries are (VAddr a, v) pairs. *)
type_synonym store = func

(* Ordering on values: numbers, addresses and Wrong are related only to
   themselves, function values are ordered by table inclusion, and pairs are
   ordered componentwise. *)
inductive val_le :: "val ⇒ val ⇒ bool" (infix "⊑" 52) where
  vnat_le[intro!]: "(VNat n) ⊑ (VNat n)" |
  vaddr_le[intro!]: "(VAddr a) ⊑ (VAddr a)" |
  wrong_le[intro!]: "Wrong ⊑ Wrong" |
  vfun_le[intro!]: "t1 |⊆| t2 ⟹ (VFun t1) ⊑ (VFun t2)" |
  vpair_le[intro!]: "⟦ v1 ⊑ v1'; v2 ⊑ v2' ⟧ ⟹ (VPair v1 v2) ⊑ (VPair v1' v2')"

(* Size of a value: 1 per constructor.  A function value additionally counts the
   sizes of every argument and result in its table (each element is tagged with
   its size by map_prod, then the sizes are summed with ffold). *)
primrec vsize :: "val ⇒ nat" where
"vsize (VNat n) = 1" |
"vsize (VFun t) = 1 + ffold (λ((_,v), (_,u)).λr. v + u + r) 0
                            (fimage (map_prod (λ v. (v,vsize v)) (λ v. (v,vsize v))) t)" |
"vsize (VPair v1 v2) = 1 + vsize v1 + vsize v2" |
"vsize (VAddr a) = 1" |
"vsize Wrong = 1"

subsection "Non-deterministic state monad"

(* Non-deterministic state monad: a computation maps an input store to the set of
   all possible (result, output store) pairs; the empty set means no outcome. *)
type_synonym 'a M = "store ⇒ ('a × store) set"

(* Monadic bind: run m from μ1, then run f on every (v', μ2) that m can produce,
   collecting all resulting outcomes. *)
definition bind :: "'a M ⇒ ('a ⇒ 'b M) ⇒ 'b M" where
  "bind m f μ1 ≡ { (v,μ3). ∃ v' μ2. (v',μ2) ∈ m μ1 ∧ (v,μ3) ∈ f v' μ2 }"
declare bind_def[simp]

(* Sequencing notation: P ← E; F stands for bind E (λP. F). *)
syntax "_bind" :: "[pttrns,'a M,'b] ⇒ 'c" ("(_ ← _;//_)" 0)
translations "P ← E; F" ⇌ "CONST bind E (λP. F)"

(* Remove the choose infix of binomial so the name is free for the monadic
   operation defined below. *)
no_notation "binomial" (infixl "choose" 65)

(* Non-deterministically return any element of S, leaving the store unchanged. *)
definition choose :: "'a set ⇒ 'a M" where
  "choose S μ ≡ {(a,μ1). a ∈ S ∧ μ1=μ}"
declare choose_def[simp]
(* Return v as the single outcome, leaving the store unchanged. *)
definition return :: "'a ⇒ 'a M" where
  "return v μ ≡ { (v,μ) }"
declare return_def[simp]

(* The computation with no outcomes; used to discard choices that do not apply. *)
definition zero :: "'a M" where
  "zero μ ≡ {}"
declare zero_def[simp]
(* Error-propagating bind: an outcome Wrong of m is returned as Wrong; any other
   outcome is passed on to f. *)
definition err_bind :: "val M ⇒ (val ⇒ val M) ⇒ val M" where
  "err_bind m f ≡ (x ← m; if x = Wrong then return Wrong else f x)"
declare err_bind_def[simp]

(* Error-propagating sequencing: P := E; F stands for err_bind E (λP. F). *)
syntax "_errset_bind" :: "[pttrns,val M,val] ⇒ 'c" ("(_ := _;//_)" 0)
translations "P := E; F" ⇌ "CONST err_bind E (λP. F)"

(* Return any value below v in the ⊑ ordering; the store is unchanged. *)
definition down :: "val ⇒ val M" where
  "down v μ1 ≡ {(v',μ). v' ⊑ v ∧ μ = μ1 }"
declare down_def[simp]

(* Return the current store as the result, leaving it unchanged. *)
definition get_store :: "store M" where
  "get_store μ ≡ { (μ,μ) }"
declare get_store_def[simp]
(* Replace the store with μ (the incoming store is ignored) and return (). *)
definition put_store :: "store ⇒ unit M" where
  "put_store μ ≡ λ_. { ((),μ) }"
declare put_store_def[simp]

(* Run f on every element of the finite set as and collect the results in an fset.
   NOTE(review): ffold is only fully specified when the step function satisfies
   comp_fun_commute.  This step threads the store through f in fold order, so for
   sets with two or more elements the defined value may be unspecified unless f
   leaves the store untouched - confirm this is acceptable. *)
definition mapM :: "'a fset ⇒ ('a ⇒ 'b M) ⇒ ('b fset) M" where
  "mapM as f ≡ ffold (λa. λr. (b ← f a; bs ← r; return (finsert b bs))) (return {||}) as"

(* Run computation m starting from store σ, giving all (result, final store) outcomes. *)
definition run :: "store ⇒ val M ⇒ (val × store) set" where
  "run σ m ≡ m σ"
declare run_def[simp]

(* Allocated addresses: every a for which the store has an entry keyed by VAddr a. *)
definition sdom :: "store ⇒ nat set" where
  "sdom μ ≡ {a. ∃ v. (VAddr a,v) ∈ fset μ }"

(* Largest n such that the store has an entry keyed by VAddr n; 0 when there is
   no such entry.  Entries with other keys are skipped. *)
definition max_addr :: "store ⇒ nat" where
  "max_addr μ = ffold (λa.λr. case a of (VAddr n,_) ⇒ max n r | _ ⇒ r) 0 μ"
subsection "Denotational semantics"

(* Function application.  Evaluates V1 then V2 (propagating Wrong via :=).  If V1
   yields a function table f, any entry (p, p') of f is chosen non-deterministically.
   An entry (VPair v (VFun μ), VPair v' (VFun μ')) applies when v ⊑ v2 and its input
   store table μ is included in the current store μ0; the store is then replaced by
   μ' and any value below v' is returned.  Entries of any other shape, or whose
   conditions fail, contribute no outcome (zero).  A non-function V1 gives Wrong. *)
abbreviation apply_fun :: "val M ⇒ val M ⇒ val M" where
  "apply_fun V1 V2 ≡ (v1 := V1; v2 := V2;
                       case v1 of VFun f ⇒
                          (p, p') ← choose (fset f); μ0 ← get_store;
                          (case (p,p') of (VPair v (VFun μ), VPair v' (VFun μ')) ⇒
                            if v ⊑ v2 ∧ (VFun μ) ⊑ (VFun μ0) then (_ ← put_store μ'; down v')
                            else zero
                          | _ ⇒ zero)
                       | _ ⇒ return Wrong)"

(* Non-deterministically build a finite set from k arbitrary value choices.
   Duplicate choices collapse under finsert, so the set has at most k elements. *)
fun nvals :: "nat ⇒ (val fset) M" where
  "nvals 0 = return {||}" |
  "nvals (Suc k) = (v ← choose UNIV; L ← nvals k; return (finsert v L))"

(* Any finite set of values: choose a bound n, then build the set with nvals n. *)
definition vals :: "(val fset) M" where
  "vals ≡ (n ← choose UNIV; nvals n)"
declare vals_def[simp]

(* Non-deterministically build a table from k arbitrary (value, value) pair
   choices; duplicates collapse, so the table has at most k entries. *)
fun npairs :: "nat ⇒ func M" where
  "npairs 0 = return {||}" |
  "npairs (Suc k) = (v ← choose UNIV; v' ← choose {v::val. True};
                     P ← npairs k; return (finsert (v,v') P))"

(* Any finite table: choose a bound n, then build the table with npairs n. *)
definition tables :: "func M" where
  "tables ≡ (n ← choose {k::nat. True}; npairs n)"
declare tables_def[simp]

(* Read address a.  If a is allocated (a ∈ sdom μ), return any v2 such that
   (VAddr a, v2) is in the store, discarding entries with other keys via zero;
   otherwise return Wrong.  The store is not modified. *)
definition read :: "nat ⇒ val M" where
  "read a ≡ (μ ← get_store; if a ∈ sdom μ then
                           ((v1,v2) ← choose (fset μ); if v1 = VAddr a then return v2 else zero)
                       else return Wrong)"
declare read_def[simp]

(* Write v at address a: remove every entry keyed by VAddr a, insert (VAddr a, v),
   and return VAddr a.
   NOTE(review): unlike read, there is no check that a ∈ sdom μ, so writing to an
   unallocated address adds it to the store - confirm this is intended. *)
definition update :: "nat ⇒ val ⇒ val M" where
  "update a v ≡ (μ ← get_store;
                 _ ← put_store (finsert (VAddr a,v) (ffilter (λ(v,v'). v ≠ VAddr a) μ));
                 return (VAddr a))"
declare update_def[simp]

(* Environments: EVar n reads position n; ELam pushes its argument at the front. *)
type_synonym env = "val list"

(* Denotation of an expression in an environment as a non-deterministic store
   computation.  := propagates Wrong from subterms, and ill-shaped operands
   give Wrong.
   - EVar: any value below ρ!n, or Wrong if n is out of range.
   - ELam: a table built over a chosen finite set of arguments; each entry pairs
     (argument, chosen input store table) with (result, output store table),
     obtained by running the body from that input store.
   - EIf: 0 selects the else-branch e3; any other number selects e2.
   - ERef: chooses an address not in sdom (allocated choices give no outcome),
     extends the store with the value, and returns the address. *)
fun E :: "exp ⇒ env ⇒ val M" where
  Enat: "E (ENat n) ρ = return (VNat n)" |
  Evar: "E (EVar n) ρ = (if n < length ρ then down (ρ!n) else return Wrong)" |
  Elam: "E (ELam A e) ρ = (L ← vals;
                           t ← mapM L (λ v. (μ ← tables;  (v',μ') ← choose (run μ (E e (v#ρ)));
                                     return (VPair v (VFun μ),VPair v' (VFun μ'))));
                           return (VFun t))" |
  Eapp: "E (EApp e1 e2) ρ = apply_fun (E e1 ρ) (E e2 ρ)" |
  Eprim: "E (EPrim f e1 e2) ρ = (v1 := E e1 ρ; v2 := E e2 ρ;
                                 case (v1, v2) of (VNat n1,VNat n2) ⇒ return (VNat (f n1 n2))
                                 | _ ⇒ return Wrong)" |
  Eif: "E (EIf e1 e2 e3) ρ = (v1 := E e1 ρ; case v1 of VNat n ⇒ (if n = 0 then E e3 ρ else E e2 ρ)
                                            | _ ⇒ return Wrong)" |
  Epair: "E (EPair e1 e2) ρ = (v1 := E e1 ρ; v2 := E e2 ρ; return (VPair v1 v2))" |
  Efst: "E (EFst e) ρ = (v:=E e ρ; case v of VPair v1 v2 ⇒ return v1 | _ ⇒ return Wrong)" |
  Esnd: "E (ESnd e) ρ = (v:=E e ρ; case v of VPair v1 v2 ⇒ return v2 | _ ⇒ return Wrong)" |
  Eref: "E (ERef e) ρ = (v:=E e ρ; μ ← get_store; a ← choose UNIV;
                         if a ∈ sdom μ then zero
                         else (_ ← put_store (finsert (VAddr a,v) μ);
                               return (VAddr a)))" |
  Eread: "E (ERead e) ρ = (v := E e ρ; case v of VAddr a ⇒ read a | _ ⇒ return Wrong)" |
  Ewrite: "E (EWrite e1 e2) ρ = (v1 := E e1 ρ; v2 := E e2 ρ;
                                 case v1 of VAddr a ⇒ update a v2 | _ ⇒ return Wrong)"