(* Theory State *)

theory State
imports Stores "HOL-Library.Word"
begin

section ‹Value types›

(* Byte sequences are represented as HOL strings (lists of characters). *)
type_synonym bytes = string
(* Identifiers are represented as string literals. *)
type_synonym id = String.literal

(* Basic values, parameterised over an address type 'a of class address.
   Constructors (with selectors where declared):
     Bool    - a boolean            (selector: bool)
     Uint    - a 256-bit word       (selector: uint)
     Address - an address of type 'a (selector: ad)
     Bytes   - a byte sequence *)
datatype ('a::address) valtype =
  Bool (bool: bool)
| Uint (uint: "256 word")
| Address (ad: 'a)
| Bytes bytes ―‹bytes1, ..., bytes32›

(* Instantiate class vtype for valtype (over any address type). *)
instantiation valtype :: (address) vtype
begin

(* Natural-number reading of a value: defined only for Uint, where it is
   the unsigned value of the 256-bit word; every other constructor gives None. *)
fun to_nat_valtype::"'a valtype ⇒ nat option" where
  "to_nat_valtype (Uint x) = Some (unat x)"
| "to_nat_valtype _ = None"

instance ..

end

section ‹Common functions›

(* Lift a unary boolean operator to values: applies op under Bool,
   returns None for any non-Bool argument. *)
fun lift_bool_unary::"(bool ⇒ bool) ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "lift_bool_unary op (Bool b) = Some (Bool (op b))"
| "lift_bool_unary _ _ = None"

(* Negation on values: Not lifted through lift_bool_unary. *)
definition vtnot :: "('a::address) valtype ⇒ 'a valtype option" where
  "vtnot = lift_bool_unary Not"

(* Lift a binary boolean operator to values: applies op when both arguments
   are Bool, returns None for every other combination. *)
fun lift_bool_binary::"(bool ⇒ bool ⇒ bool) ⇒ ('a::address) valtype ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "lift_bool_binary op (Bool l) (Bool r) = Some (Bool (op l r))"
| "lift_bool_binary _ _ _ = None"

(* Conjunction on values: (∧) lifted through lift_bool_binary. *)
definition vtand :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtand = lift_bool_binary (∧)"

(* Disjunction on values: (∨) lifted through lift_bool_binary. *)
definition vtor :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtor = lift_bool_binary (∨)"

(* Equality test on values: two values built with the same constructor are
   compared on their payloads and the result is wrapped in Bool;
   arguments with different constructors give None. *)
fun vtequals where
  "vtequals (Uint x) (Uint y) = Some (Bool (x = y))"
| "vtequals (Address x) (Address y) = Some (Bool (x = y))"
| "vtequals (Bool x) (Bool y) = Some (Bool (x = y))"
| "vtequals (Bytes x) (Bytes y) = Some (Bool (x = y))"
| "vtequals _ _ = None"

(* Lift a comparison on 256-bit words to values: applies op when both
   arguments are Uint and wraps the result in Bool; None otherwise. *)
fun lift_int_comp::"(256 word ⇒ 256 word ⇒ bool) ⇒ ('a::address) valtype ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "lift_int_comp op (Uint l) (Uint r) = Some (Bool (op l r))"
| "lift_int_comp _ _ _ = None"

(* Strict less-than on values: (<) on 256-bit words lifted through lift_int_comp. *)
definition vtless :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtless = lift_int_comp (<)"

(* Lift a binary operator on 256-bit words to values: applies op when both
   arguments are Uint and wraps the result in Uint; None otherwise. *)
fun lift_int_binary::"(256 word ⇒ 256 word ⇒ 256 word) ⇒ ('a::address) valtype ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "lift_int_binary op (Uint l) (Uint r) = Some (Uint (op l r))"
| "lift_int_binary _ _ _ = None"

(* Addition on values: word addition (+) lifted through lift_int_binary. *)
definition vtplus :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtplus = lift_int_binary (+)"

(* Checked addition: for two Uint arguments the sum is returned only if the
   sum of their unsigned values is below 2^256; otherwise, and for
   non-Uint arguments, the result is None. *)
fun vtplus_safe::"('a::address) valtype ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "vtplus_safe (Uint l) (Uint r) = (if unat l + unat r < 2^256 then Some (Uint (l + r)) else None)"
| "vtplus_safe _ _ = None"

(* The defining equations are removed from the default simpset. *)
declare vtplus_safe.simps[simp del]

(* Subtraction on values: word subtraction (-) lifted through lift_int_binary. *)
definition vtminus :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtminus = lift_int_binary (-)"

(* Checked subtraction: for two Uint arguments the difference l - r is
   returned only if r ≤ l; otherwise, and for non-Uint arguments,
   the result is None. *)
fun vtminus_safe::"('a::address) valtype ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "vtminus_safe (Uint l) (Uint r) = (if r ≤ l then Some (Uint (l - r)) else None)"
| "vtminus_safe _ _ = None"

(* The defining equations are removed from the default simpset. *)
declare vtminus_safe.simps[simp del]

(* Multiplication on values: word multiplication lifted through lift_int_binary. *)
definition vtmult :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtmult = lift_int_binary (*)"

(* Checked multiplication: for two Uint arguments the product is returned
   only if the product of their unsigned values is below 2^256; otherwise,
   and for non-Uint arguments, the result is None. *)
fun vtmult_safe::"('a::address) valtype ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "vtmult_safe (Uint l) (Uint r) = (if unat l * unat r < 2^256 then Some (Uint (l * r)) else None)"
| "vtmult_safe _ _ = None"

(* The defining equations are removed from the default simpset. *)
declare vtmult_safe.simps[simp del]

(* Remainder on values: word mod lifted through lift_int_binary. *)
definition vtmod :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtmod = lift_int_binary (mod)"

section ‹Operations on bytes›

(* Indexing bytes_n returns bytes_1: for a Bytes value and a Uint index that
   is within bounds, the result is the one-element byte sequence at that
   position; out-of-bounds indices and other argument kinds give None. *)
fun vtbytes_index :: "('a::address) valtype ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "vtbytes_index (Bytes xs) (Uint i) = (if unat i < length xs then Some (Bytes [xs ! unat i]) else None)"
| "vtbytes_index _ _ = None"

(* Combine two lists element-wise with op; as with zip, the result has the
   length of the shorter input. *)
definition zipWith :: "('a ⇒ 'b ⇒ 'c) ⇒ 'a list ⇒ 'b list ⇒ 'c list" where
  "zipWith op xs ys = map (λ (x, y). op x y) (zip xs ys)"

(* Lift a binary operator on characters to Bytes values: applied position-wise
   when both arguments are Bytes of equal length; None if the lengths differ
   or either argument is not Bytes. *)
fun lift_bytes_binary::"(char ⇒ char ⇒ char) ⇒ ('a::address) valtype ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "lift_bytes_binary op (Bytes l) (Bytes r) = (if length l = length r then Some (Bytes (zipWith op l r)) else None)"
| "lift_bytes_binary _ _ _ = None"

(* Lift a unary operator on characters to Bytes values: maps op over every
   byte; None for any non-Bytes argument. *)
fun lift_bytes_unary::"(char ⇒ char) ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "lift_bytes_unary op (Bytes l) = Some (Bytes (map op l))"
| "lift_bytes_unary _ _ = None"

(* Convert an 8-bit word to the character with the same unsigned value. *)
definition word8_to_char :: "8 word ⇒ char" where
  "word8_to_char w = char_of (unat w)"

(* Convert a character to the 8-bit word holding its character code. *)
definition char_to_word8 :: "char ⇒ 8 word" where
  "char_to_word8 c = of_nat (of_char c)"

(* Turn a binary operator on 8-bit words into one on characters: both
   characters are converted to words, op is applied, and the result is
   converted back to a character. *)
definition op_word8_to_char :: "(8 word ⇒ 8 word ⇒ 8 word) ⇒ (char ⇒ char ⇒ char)" where
  "op_word8_to_char op x y = word8_to_char (op (char_to_word8 x) (char_to_word8 y))"

(* Bitwise operations on Bytes values. Each one converts bytes to 8-bit words,
   applies the word-level bit operation, and converts back; the AND/OR/XOR/NOT
   notation is brought into scope by the bit_operations_syntax bundle. *)
context
  includes bit_operations_syntax
begin

(* Position-wise bitwise AND of two byte sequences. *)
definition vtbytes_and :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtbytes_and = lift_bytes_binary (op_word8_to_char (AND))"

(* Position-wise bitwise OR of two byte sequences. *)
definition vtbytes_or :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtbytes_or = lift_bytes_binary (op_word8_to_char (OR))"

(* Position-wise bitwise XOR of two byte sequences. *)
definition vtbytes_xor :: "('a::address) valtype ⇒ 'a valtype ⇒ 'a valtype option" where
  "vtbytes_xor = lift_bytes_binary (op_word8_to_char (XOR))"

(* Bitwise complement of every byte in a byte sequence. *)
definition vtbytes_not :: "('a::address) valtype ⇒ 'a valtype option" where
  "vtbytes_not = lift_bytes_unary (λ c. word8_to_char ((NOT) (char_to_word8 c)))"

end

(* Resize a list to exactly m elements: shorter lists are extended at the end
   with copies of pad, longer lists are truncated to their first m elements. *)
fun resize_list :: "nat ⇒ 'a ⇒ 'a list ⇒ 'a list" where
  "resize_list m pad xs =
     (if length xs < m
      then xs @ replicate (m - length xs) pad
      else take m xs)"

(* Cast a Bytes value to length m: the byte sequence is resized with
   resize_list, using the zero character as padding; non-Bytes arguments
   give None. *)
fun vtbytes_cast :: "nat ⇒ ('a::address) valtype ⇒ ('a::address) valtype option" where
  "vtbytes_cast m (Bytes xs) = Some (Bytes (resize_list m (CHR 0x00) xs))"
| "vtbytes_cast m _ = None"

section ‹State›

subsection ‹Definition›

(* Stack: finite map from identifiers to kdata entries. *)
type_synonym 'v stack = "(id, 'v kdata) fmap"
(* Balances: total function from addresses to natural numbers. *)
type_synonym 'a balances = "'a ⇒ nat"
(* Storage: per address, a total function from identifiers to storage data. *)
type_synonym ('a, 'v) storage = "'a ⇒ id ⇒ 'v storage_data"
(* Calldata: finite map from identifiers to call_data entries. *)
type_synonym 'v calldata = "(id, 'v call_data) fmap"

(* The program state, parameterised over the address type 'a.
   Fields: Memory, Calldata, Storage and Stack hold valtype-based data of the
   correspondingly named types; Balances has type 'a balances. *)
record ('a::address) state =
  Memory:: "('a::address valtype) memory"
  Calldata:: "('a::address valtype) calldata"
  Storage:: "('a::address, 'a::address valtype) storage"
  Stack:: "('a::address valtype) stack"
  Balances::"('a::address) balances"

(* Two states agree on their volatile parts: Stack, Memory and Calldata are
   equal. Storage and Balances are not compared. *)
definition sameState where "sameState s s' ≡ state.Stack s' = state.Stack s ∧ state.Memory s' = state.Memory s ∧ state.Calldata s' = state.Calldata s"

subsection ‹Update Function›

(* Single-constructor type ex with the one value Err.
   NOTE(review): presumably the error/exception value of the state monad;
   confirm at its use sites. *)
datatype ex = Err

(* Set the balance of address i to n, leaving all other balances and all
   other state fields unchanged. *)
definition balances_update:: "('a::address) ⇒ nat ⇒ ('a::address) state ⇒ ('a::address) state" where
  "balances_update i n s = s⦇Balances := (Balances s)(i := n)⦈"

(* Bind identifier i to d in the Calldata map of a state (via fmupd). *)
definition calldata_update:: "id ⇒ ('a::address valtype) call_data ⇒ ('a::address) state ⇒ ('a::address) state" where
  "calldata_update i d = Calldata_update (fmupd i d)"

(* Bind identifier i to d in the Stack map of a state (via fmupd). *)
definition stack_update:: "id ⇒ ('a::address valtype) kdata ⇒ ('a::address) state ⇒ ('a::address) state" where
  "stack_update i d = Stack_update (fmupd i d)"

(* Replace the memory entry at location i with d, using list update on the
   Memory field; all other state fields are unchanged. *)
definition memory_update:: "location ⇒ ('a::address valtype) mdata ⇒ ('a::address) state ⇒ ('a::address) state" where
  "memory_update i d s = s⦇Memory := (Memory s)[i := d]⦈"

(* Writing back the balance an address already has leaves the state unchanged. *)
lemma balances_update_id[simp]: "balances_update x (Balances s x) s = s"
  by (simp add: balances_update_def)

end