Theory Code_Setup

theory Code_Setup
  imports 
    "HOL-Library.IArray"
    "HOL-Data_Structures.Array_Braun"
    "HOL-Data_Structures.RBT_Map"

    "../MDP_fin" 
    "../Value_Iteration" 
    
    "./lib/DiffArray_ST"

begin

(* Facts about states outside the state range of a discounted MDP.
   The proofs rely on reward_zero_outside and K_closed_compl, which (judging by their
   names and use) say that rewards vanish there and that transitions from there stay
   there. *)
context MDP_nat_disc begin
(* If v vanishes on the outside region, then the one-step operator L d v also vanishes at
   any outside state s. Both ingredients are zero: every enabled action has reward 0, and
   the expectation of v under K (x, a) is 0 because v is 0 almost everywhere on the
   successors. *)
lemma L_zero:
  assumes "s. s  states  apply_bfun v s = 0" "s  states"
  shows "L d v s = 0"
  using assms
  proof (induction s rule: less_induct)
    case (less x)
    moreover have "r (x, a) = 0" if "a  A x" for a
      by (simp add: less.prems reward_zero_outside)
    moreover have "measure_pmf.expectation (K (x, a)) v = 0" for a
      using K_closed_compl assms less 
      by (blast intro: integral_eq_zero_AE AE_pmfI)
    ultimately show ?case
      by (auto simp: A_ne A_fin L_eq_La reward_zero_outside)
    qed

(* Same argument for the operator b. The final step unfolds b via b_eq_La_max', i.e. as a
   maximum over La. If v vanishes on the outside region, so does b v. *)
lemma b_zero:
  assumes "s. s  states  apply_bfun v s = 0" "s  states"
  shows "b v s = 0"
  using assms
  proof (induction s rule: less_induct)
    case (less x)
    have "r (x, a) = 0" if "a  A x" for a
      by (simp add: less.prems reward_zero_outside)
    moreover have "measure_pmf.expectation (K (x, a)) v = 0" for a
      using K_closed_compl assms less 
      by (blast intro: integral_eq_zero_AE AE_pmfI)
      ultimately show ?case
        by (auto simp: A_ne A_fin b_eq_La_max')
    qed
end

(* If some element of the finite non-empty set A is at least x, then Max A is at least x.
   This follows directly from Max_ge_iff. *)
lemma max_geI: "finite A  A  {}  (aA. x  a)  (x  Max A)" for x A
  by (simp add: Max_ge_iff)

section ‹Least argmax›

(* Single left-to-right pass over a non-empty list.
   Returns the pair (argument, value) of the first element at which f is maximal.
   The running best (am, m) is replaced only on a strict increase (m < f y), so on ties
   the earliest element is kept. The result is undefined on the empty list. *)
fun "least_arg_max_max_ne" where
  "least_arg_max_max_ne f (x#xs) = 
  (fold (λy (am, m). let fy = f y in 
   if m < fy then (y, fy) else (am, m)) xs (x, f x))" |
  "least_arg_max_max_ne a [] = undefined"

(* Argument component of least_arg_max_max_ne: the earliest element of a non-empty list
   at which f is maximal. The result is undefined on the empty list. *)
fun "least_arg_max_ne" where
"least_arg_max_ne f (x#xs) = fst (least_arg_max_max_ne f (x#xs))" |
"least_arg_max_ne a [] = undefined"

(* Keep the fold-based defining equations out of the simpset.
   The lemmas below unfold them explicitly with the .simps facts where needed. *)
lemmas 
  least_arg_max_ne.simps[simp del]
  least_arg_max_max_ne.simps[simp del]

(* One-step unfolding on lists with at least two elements.
   The first two elements are compared and only the winner is kept; if f y does not
   strictly exceed f x, the earlier element x wins. *)
lemma least_arg_max_max_ne_Cons: "least_arg_max_max_ne f (x#y#xs) = 
  (if f x < f y then least_arg_max_max_ne f (y#xs) else least_arg_max_max_ne f (x#xs))"
  by (auto simp: least_arg_max_max_ne.simps)

(* Unfolding step for the strict-increase case: the head x is dropped. *)
lemma least_arg_max_max_ne_Cons1: "f x < f y  least_arg_max_max_ne f (x#y#xs) = least_arg_max_max_ne f (y#xs)"
  by (auto simp: least_arg_max_max_ne.simps)

(* Unfolding step when f y does not strictly exceed f x: the second element y is dropped. *)
lemma least_arg_max_max_ne_Cons2: "¬ f x < f y  least_arg_max_max_ne f (x#y#xs) = least_arg_max_max_ne f (x#xs)"
  by (auto simp: least_arg_max_max_ne.simps)

(* Inserting into the finite set X an element x bounded by some element of X leaves the
   maximum unchanged. The result is stated with a case distinction on X = {}. *)
lemma Max_insert_absorb: "finite X  (y  X. x  y)  Max (Set.insert x X) = (if X = {} then x else Max X)"
  by (simp add: Max_ge_iff)

(* Variant of Max_insert_absorb that takes the dominating element y and the bound on x
   as separate premises. *)
lemma Max_insert_absorb': "finite X  yX  x  y  Max (Set.insert x X) = (if X = {} then x else Max X)"
  using Max_insert_absorb
  by blast

(* On a sorted non-empty list, the fold-based least_arg_max_max_ne computes the pair
     (least_arg_max f over the list's members, Max of f over the list).
   Sortedness makes "first maximiser in list order" coincide with the least maximiser.
   The proof is by induction on the tail. *)
lemma fold_max_eq_arg_max:
  assumes "sorted (x#xs)"
  shows "least_arg_max_max_ne f (x#xs) = (least_arg_max f (List.member (x#xs)), Max (f ` set (x#xs)))"
  using assms
proof (induction xs arbitrary: x)
  case Nil
  then show ?case
    by (auto simp:  List.member_def least_arg_max_def least_arg_max_max_ne.simps is_arg_max_def intro!: Least_equality[symmetric])
next
  case (Cons a xs)
  (* Split on whether the head x is already a maximiser of f over the whole list. *)
  then show ?case
  proof (cases "is_arg_max f (List.member (x#a#xs)) x")
    case True
    have 1: "least_arg_max f (List.member (x#a#xs)) = x" 
      using True Cons
      unfolding least_arg_max_def 
      by (fastforce intro!: Least_equality simp: in_set_member[symmetric])
    have 2: "Max (f ` set (x#a#xs)) = f x"
      using True unfolding is_arg_max_def 
      by (subst Max_eq_iff) (auto simp add: not_less in_set_member member_rec(1))
    show ?thesis
      unfolding 1 2
      using True
      by (induction xs) (auto simp: least_arg_max_max_ne.simps simp: is_arg_max_linorder member_rec)+
  next
    case False
    have "is_arg_max f (List.member (x#a#xs)) = is_arg_max f (List.member (a#xs))"
      using False by (fastforce simp: least_arg_max_max_ne.simps is_arg_max_linorder member_rec)
    hence 1: "least_arg_max f (List.member (x#a#xs)) = least_arg_max f (List.member (a#xs))"
      using Cons False unfolding least_arg_max_def by auto
    have "f a  f x  is_arg_max f (List.member (x#xs)) = is_arg_max f (List.member xs)"
      using False by (fastforce simp: is_arg_max_linorder member_rec)   
    hence 4: "f a  f x  least_arg_max f (List.member (x#xs)) = least_arg_max f (List.member xs)"
      using Cons False unfolding least_arg_max_def by auto
    have "f a  f x  is_arg_max f (List.member (a#xs)) = is_arg_max f (List.member xs)"
      using False by (fastforce simp: is_arg_max_linorder  member_rec(1))
    hence 3: "f a  f x  least_arg_max f (List.member (a#xs)) = least_arg_max f (List.member xs)"
      using Cons False unfolding least_arg_max_def by auto
    have 2: "Max (f`set (x#a#xs)) = Max (f`set (a#xs))"
      using False 
      by (fastforce simp: nle_le in_set_member is_arg_max_linorder Max_ge_iff simp: member_rec intro!:  max_absorb2 )  
    have 5: "Max (f`set (a#xs)) = Max (f`set (xs))  Max (f`set (x#xs)) = Max (f`set (xs))" if "f a  f x"
      using False that
      by (cases "xs = []") (auto simp: nle_le is_arg_max_linorder in_set_member[symmetric] intro: order.trans intro!:max_absorb2)
    (* Unfold one comparison step of the fold and combine facts 1-5 with the IH. *)
    show ?thesis
      unfolding least_arg_max_max_ne_Cons 1 2 using Cons 5 3 4 by auto
    qed
  qed

(* Correctness of the executable least_arg_max_ne: on a sorted non-empty list it agrees
   with least_arg_max over the members of the list. *)
lemma least_arg_max_ne_correct:
  assumes "sorted (x#xs)"
  shows "least_arg_max_ne (f :: _  'b ::linorder) (x#xs) = least_arg_max f (List.member (x#xs))"
  using assms
  by (auto simp: fold_max_eq_arg_max least_arg_max_ne.simps)

(* Congruence: least_arg_max_max_ne depends only on the values of f on the list's elements.
   NOTE(review): despite the name, this is stated for least_arg_max_max_ne (the
   pair-returning variant), not for least_arg_max_ne. *)
lemma least_arg_max_ne_cong: 
  assumes "x. x  set xs  g x = f x"
  shows "least_arg_max_max_ne f xs = least_arg_max_max_ne g xs"
proof (cases xs)
  case Nil
  then show ?thesis 
    by (metis least_arg_max_max_ne.elims list.discI)
next
  case (Cons a list)
  then show ?thesis 
  using assms
  by (auto simp: least_arg_max_max_ne.simps intro!: List.fold_cong)
qed

(* Mapping the list through g commutes with least_arg_max_max_ne, provided f' applied to
   g y equals f y on the list. The argument component is mapped by g and the maximum value
   is unchanged. *)
lemma least_arg_max_max_ne_app:
  assumes "y. y  set (x#xs)  f' (g y) = (f y)"
  shows "(case (least_arg_max_max_ne f (x#xs)) of (a, m)  (g a, m)) = least_arg_max_max_ne f' (map g (x#xs))"
  using assms
proof (induction xs arbitrary: x)
  case Nil
  then show ?case
    by (auto simp: least_arg_max_max_ne.simps)
next
  case (Cons a xs)
    thus ?case
      by (cases "f x <  f a") (auto simp: least_arg_max_max_ne_Cons1 least_arg_max_max_ne_Cons2)
  qed

(* least_arg_max_max_ne_app for an arbitrary non-empty list xs. *)
lemma least_arg_max_max_ne_app':
  assumes "y. y  set xs  f' (g y) = (f y)" "xs  []"
  shows "(case (least_arg_max_max_ne f xs) of (a, m)  (g a, m)) = least_arg_max_max_ne f' (map g xs)"
  using assms
  by (cases xs) (auto intro!: least_arg_max_max_ne_app[simplified])

(* fold_max_eq_arg_max for an arbitrary non-empty sorted list xs. *)
lemma fold_max_eq_arg_max': "xs  []  sorted xs  least_arg_max_max_ne f xs = (least_arg_max f (List.member xs), Max (f ` set xs))"
  using fold_max_eq_arg_max by (metis list.exhaust)

(* least_arg_max f P depends only on the values of f at elements satisfying P. *)
lemma least_arg_max_cong: "(x. P x  f x = g x)  least_arg_max f P = least_arg_max g P"
  unfolding least_arg_max_def using is_arg_max_cong' by metis

(* As least_arg_max_cong, additionally allowing the predicate to be replaced by an equal
   one (P = Q). *)
lemma least_arg_max_cong': "P = Q   (x. P x  f x = g x)  least_arg_max f P = least_arg_max g Q"
  unfolding least_arg_max_def using is_arg_max_cong' by metis

section ‹Congruence rule for fold›

(* Congruence for fold under an invariant P on the accumulator.
   Assume that, whenever P holds for the accumulator and x is a list element, f and g
   agree on (x, acc) and f re-establishes P. Then folding f and g from a start value
   satisfying P gives the same result. *)
lemma fold_cong':
  assumes "(x acc. P acc  x  set xs  f x acc = g x acc  P (f x acc))" "P a"
  shows "fold f xs a = fold g xs a"
  using assms
proof (induction xs arbitrary: a)
  case (Cons a xs y)
  show ?case
    using Cons(2)[OF Cons(3), of a] 
    by (auto intro!: Cons(2) intro: Cons.IH)
qed auto


section ‹MDP type›

(* Concrete, executable MDP representation with three fields:
     disc        : discount factor;
     states      : number of states;
     transitions : one red-black tree per state, mapping an action number to a pair
                   (reward, list of (successor state, probability)).
   The reading of the inner pair as (reward, successors) follows is_MDP_actions below and
   the way MDP_Code_raw consumes such maps; it is not enforced by the type itself. *)
datatype MDP = MDP (disc: real) (states: nat) 
  (transitions: "(((nat × (real × ((nat × real) list))) RBT.rbt)) iarray")

(* Well-formedness: there is exactly one transition map per state. *)
abbreviation "is_MDP_states mdp  
  IArray.length (transitions mdp) = states mdp"

(* Well-formedness of every per-state action map t. It must satisfy all of:
   - t is a valid red-black tree, and its inorder list is sorted1;
   - t is not the empty tree;
   - for every stored successor list probs, the probabilities sum to 1;
   - every successor index is below the number of states;
   - every probability is (presumably) non-negative. *)
abbreviation "is_MDP_actions mdp  IArray.all (λt. 
  rbt t  
  sorted1 (Tree2.inorder t)  
  t  empty  
  ((_, _, probs)  set (inorder t). sum_list (map snd probs) = 1
     (list_all (λ(s, p). p  0  s<states mdp) probs))) (transitions mdp)"

(* The discount factor must lie in the half-open interval [0, 1). *)
abbreviation "is_MDP_disc mdp  (0  disc mdp  disc mdp < 1)"

(* A concrete MDP is valid when is_MDP_states, is_MDP_disc and is_MDP_actions all hold. *)
definition is_MDP :: "MDP  bool"
  where "is_MDP mdp  is_MDP_states mdp  is_MDP_disc mdp  is_MDP_actions mdp"

(* Degenerate MDP with discount 0 and no states.
   It serves as the non-emptiness witness for Valid_MDP and as the fallback value of
   to_valid_MDP. *)
definition "trivial_MDP = MDP 0 0 (IArray [])"

(* The trivial MDP is valid. *)
lemma trivial_MDP: "is_MDP trivial_MDP"
  unfolding trivial_MDP_def is_MDP_def by auto

(* Subtype of well-formed MDPs; it is non-empty thanks to trivial_MDP.
   setup_lifting enables lift_definition on this type. *)
typedef Valid_MDP = "{mdp. is_MDP mdp}"
  using trivial_MDP by auto

setup_lifting type_definition_Valid_MDP

(* Placeholder constant, declared as code abort so that generated code raises an error
   instead of evaluating it.
   NOTE(review): it is not referenced in the visible part of this theory;
   to_valid_MDP uses Code.abort directly. *)
definition "error_mdp = trivial_MDP"

declare [[code abort: error_mdp]]

(* Converts a concrete MDP into a Valid_MDP.
   If is_MDP holds, the MDP is returned unchanged. Otherwise generated code aborts with
   the message ''not an MDP''; logically the result is trivial_MDP. *)
lift_definition to_valid_MDP :: "MDP  Valid_MDP" is 
  "λmdp. if is_MDP mdp then mdp else Code.abort (STR ''not an MDP'') (λ_. trivial_MDP)" 
  by (simp add: trivial_MDP_def is_MDP_def)

(* Derived notions and facts for maps implemented by ordered trees (locale
   Map_by_Ordered). Covers entry and key sets, emptiness, total lookup, and construction
   of maps from lists. *)
context Map_by_Ordered begin
(* Register map_specs(5) (presumably the preservation of invar by update) as an intro
   rule. *)
lemmas map_specs(5)[intro]

(* Association-list facts: a successful map_of lookup yields an element of the list; a
   failed one means the key does not occur as a first component. *)
lemma map_of_Some_in_set: "AList_Upd_Del.map_of xs k = Some v  (k, v)  set xs"
  by (induction xs) (auto split: if_splits)

lemma map_of_None_notin_set: "AList_Upd_Del.map_of xs k = None  k  fst ` set xs"
  by (induction xs) (fastforce split: if_splits)+

(* The set of stored (key, value) pairs, and the set of keys, read off the inorder
   traversal. *)
definition "entries m = set (inorder m)"
definition "keys m = fst ` set (inorder m)"

(* Relating lookup to entries and keys for maps satisfying invar. *)
lemma lookup_some_set_a_inorder: 
  assumes "invar m" "lookup m x = Some y" 
  shows "(x, y)  entries m"
  using inorder_lookup assms map_of_Some_in_set invar_def entries_def by metis

lemma lookup_None_set_inorder: 
  assumes "invar m" "lookup m x = None" 
  shows "x  keys m"
  using assms inorder_lookup map_of_None_notin_set keys_def invar_def by metis

lemma entries_imp_keys[intro]: "(x,y)  entries m  x  keys m"
  unfolding keys_def entries_def by force

lemma lookup_some_set_key: "invar m  lookup m x = Some y  x  keys m"
  using lookup_some_set_a_inorder by force

lemma lookup_in_keys: "invar m  x  keys m  y. lookup m x = Some y"
  using lookup_None_set_inorder by auto

lemma lookup_notin_keys: "invar m  x  keys m  lookup m x = None"
  by (meson lookup_some_set_key not_Some_eq)

(* Deleting the key at the head of the inorder list removes exactly that entry. *)
lemma inorder_delete: "invar m  inorder m = kv#xs  inorder ((delete (fst kv) m)) = xs"
  unfolding invar_def
  using AList_Upd_Del.del_list.simps(2)[of _ "fst kv" "snd kv"]
  by (simp add: local.inorder_delete)

(* Every stored entry is found by lookup (converse of lookup_some_set_a_inorder).
   The proof is by induction on the inorder list; the step deletes the head entry. *)
lemma inorder_lookup_Some: "invar m  (k, v)  entries m  lookup m k = Some v"
  unfolding entries_def
proof (induction "inorder m" arbitrary: m)
  case Nil thus ?case by auto 
next
  case (Cons a x)
  show ?case
  proof (cases "a = (k,v)")
    case True
    then show ?thesis 
    using inorder_lookup Cons AList_Upd_Del.map_of.simps(2) invar_def by metis
  next
    case False
    have "lookup (delete (fst a) m) k = Some v"
      using False Cons(2)[symmetric] Cons(3-4)
      by (fastforce simp: inorder_delete map_specs intro!: Cons(1))
    then show ?thesis
      by (metis map_delete fun_upd_other fun_upd_same Cons(3) option.distinct(1))
  qed
qed

(* Alternative characterisations of keys, and its behaviour under update. *)
lemma keys_eq_lookup_Some: "invar m  keys m = {k. v. lookup m k = Some v}"
  using lookup_some_set_key lookup_in_keys by auto

lemma keys_eq_fst_entries: "invar m  keys m = fst ` entries m"
  unfolding entries_def keys_def by auto

lemma keys_update[simp]: "invar m  keys (update k v m) = Set.insert k (keys m)"
  by (subst keys_eq_lookup_Some) (auto simp add: lookup_notin_keys lookup_in_keys map_specs split: if_splits)

(* A tree is empty iff its inorder traversal is empty. *)
definition "is_empty t  inorder t = []"

lemma is_empty_iff_entries_empty: "is_empty t  entries t = {}"
  by (simp add: entries_def is_empty_def)

lemma is_empty_iff_keys_empty: "is_empty t  keys t = {}"
  by (simp add: keys_def is_empty_def)

lemma finite_keys: "finite (keys t)"
  by (simp add: keys_def)

lemma finite_entries: "finite (entries t)"
  by (simp add: entries_def)

lemma keys_empty[simp]: "keys empty = {}"
  by (auto simp: keys_def inorder_empty)

(* Total lookup via the; its value is unspecified when the key is absent. *)
definition "lookup' m k = the (lookup m k)"

section ‹Converting Lists to Maps›

(* from_list' f xs stores f k under every key k of xs.
   from_list inserts (key, value) pairs in list order.
   Both fold update over the list starting from the empty map. *)
definition "from_list' f xs = foldl (λacc s. update s (f s) acc) empty xs"
definition "from_list xs = foldl (λacc (k,v). update k v acc) empty xs"

                                                                
lemmas invar_empty[simp, intro]

lemma from_list_invar[simp]: "invar (from_list' f xs)"
proof -
  have "invar t  invar (foldl (λacc s. update s (f s) acc) t xs)" for t
    by (induction xs arbitrary: t) auto
  thus ?thesis
    unfolding from_list'_def by auto
qed

lemma from_list_snoc[simp]: "(from_list' f (xs @[y])) = update y (f y) (from_list' f xs)"
  by (auto simp: from_list'_def)

lemma from_list_empty[simp]: "from_list' f [] = empty"
  unfolding from_list'_def by simp

lemma from_list_keys[simp]: "keys (from_list' f xs) = set xs"
  by (induction xs rule: List.rev_induct) (auto simp: map_update)

lemma from_list_lookup[simp]: "x  set xs  lookup (from_list' f xs)  x = Some (f x)"
  by (induction xs rule: List.rev_induct) (auto simp: map_update)

lemma from_list_lookup'[simp]: "x  set xs  lookup' (from_list' f xs)  x = f x"
  unfolding lookup'_def
  using from_list_lookup
  by auto

lemma from_list_snoc'[simp]: "(from_list (xs @[(k,v)])) = update k v (from_list xs)"
  by (auto simp: from_list_def)

lemma from_list_invar'[simp]: "invar (from_list xs)"
proof -
  have "invar t  invar (foldl (λacc (k,v). update k v acc) t xs)" for t
    by (induction xs arbitrary: t) auto
  thus ?thesis
    unfolding from_list_def by auto
qed

(* With distinct keys, every listed pair can be looked up in from_list. *)
lemma lookup_from_list_distinct: "(x,y)  set xs  distinct (map fst xs)  lookup (from_list xs) x = Some y"
  by (induction xs  arbitrary: x y rule: List.rev_induct) (auto simp: rev_image_eqI map_update)

lemma lookup'_from_list_distinct: "(x,y)  set xs  distinct (map fst xs)  lookup' (from_list xs) x = y"
  using lookup_from_list_distinct unfolding lookup'_def
  by auto

(* Keys in the inorder traversal of a valid map are distinct. *)
lemma distinct_inorder: "invar m  distinct (map fst (inorder m))"
  using invar_def strict_sorted_iff by blast

lemmas map_empty[simp]

lemma from_list_lookup_notin[simp]: "x  set xs  lookup (from_list' f xs)  x = None"
  by (induction xs rule: List.rev_induct) (auto simp: map_update)
end

(* Ordered maps from nat to a type with zero, viewed as total functions.
   map_to_fun reads absent keys as 0, and every key of a map violating invar as 0. *)
locale Map_by_Ordered_nat_zero = Map_by_Ordered empty update delete lookup inorder inv' for empty and update :: "nat  ('a::zero)  't  't" and delete lookup inorder inv'
begin

definition map_to_fun :: "'t  nat  'a" where 
  "map_to_fun m n = (if invar m then case lookup m n of None  0 | Some r  r else 0)"

(* On valid maps, update corresponds to pointwise function update. *)
lemma map_to_fun_update: "invar m  (map_to_fun (update k v m)) = (map_to_fun m)(k := v)"
  by (fastforce simp: map_to_fun_def map_update)
end

(* Ordered maps from nat to real, viewed as bounded functions (bfun).
   Absent keys, and every key of a map violating invar, read as 0. *)
locale Map_by_Ordered_nat_real = Map_by_Ordered empty update delete lookup inorder inv' for empty and update :: "nat  real  't  't" and delete lookup inorder inv'
begin

(* Boundedness proof: empty or invalid maps give the zero function. Otherwise every value
   is bounded by the maximum absolute value over the finitely many stored entries. *)
lift_definition map_to_bfun :: "'t  nat b real" is 
  "λm n. if invar m then case lookup m n of None  0 | Some r  r else 0"
proof -
  fix t
  show "(λn. if invar t then case lookup t n of None  0 | Some r  r else 0)  bfun"
  proof (cases "is_empty t  ¬ invar t")
    case True
    then show ?thesis 
      by (auto simp add: is_empty_iff_keys_empty lookup_notin_keys)
  next
    case False
    have "norm (case lookup t x of None  0 | Some r  r)  (MAX a  entries t. abs (snd a))" for x
       using False is_empty_iff_entries_empty lookup_some_set_a_inorder[of t x]
       by (fastforce simp: Max_ge_iff finite_entries split: option.splits)
     thus ?thesis
       using False by (intro bfun_normI) auto
   qed
 qed

(* On valid maps, update corresponds to pointwise update of the underlying function. *)
lemma map_to_bfun_update: "invar m  apply_bfun (map_to_bfun (update k v m)) = (map_to_bfun m)(k := v)"
  by (fastforce simp: map_to_bfun.rep_eq map_update)
  
end

(* Array interface extended by one axiom: an array built from a list answers lookups by
   list indexing, for in-range indices. *)
locale Array' = Array +
  assumes lookup_array: "i < length xs  lookup (array xs) i = xs ! i"

(* Real-valued arrays viewed as bounded functions.
   Indices at or beyond len, and all indices of arrays violating invar, read as 0. *)
locale Array_real = Array' lookup update len array list invar for lookup :: "'t  nat  real" and update len array list invar
begin

lift_definition map_to_bfun :: "'t  nat b real" is 
  "λm n. if invar m  n < len m then lookup m n else 0"
  using bounded_const by fastforce

(* In-range updates of valid arrays correspond to pointwise function update. *)
lemma map_to_bfun_update: 
  assumes "invar m" "k < len m"
  shows "apply_bfun (map_to_bfun (update k v m)) = (map_to_bfun m)(k := v)"
  using assms
  by (auto simp: invar_update map_to_bfun.rep_eq len_array lookup update)
end

(* Arrays over a type with zero, viewed as total functions on nat.
   Out-of-range indices and invalid arrays read as 0. *)
locale Array_zero = Array' lookup update len array list invar for lookup :: "'t  nat  'a::zero" and update len array list invar
begin

definition map_to_fun :: "'t  nat  'a" where 
  "map_to_fun m n = (if invar m  n < len m then lookup m n else 0)"

(* In-range updates of valid arrays correspond to pointwise function update. *)
lemma map_to_fun_update: "invar m  k < len m  (map_to_fun (update k v m)) = (map_to_fun m)(k := v)"
  by (auto simp: invar_update map_to_fun_def len_array lookup update)

end

(* Derived facts for Array': membership of looked-up elements, and array tabulation. *)
context Array' begin
(* In-range lookups of a valid array return an element of its list representation. *)
lemma lookup_in_list: "invar m  x < len m  lookup m x  set (list m)"
  using lookup len_array
  by auto

(* Array of length n holding f 0, ..., f (n - 1). *)
definition "arr_tabulate f n = array (map f [0..<n])"

lemma invar_tabulate[simp]: "invar (arr_tabulate f n)"
  by (auto simp: arr_tabulate_def invar_array)


lemma len_tabulate[simp]: "len (arr_tabulate f n) = n"
  using arr_tabulate_def array invar_tabulate len_array by auto

lemma lookup_tabulate[simp]: "i < n  lookup (arr_tabulate f n) i = f i"
  by (simp add: arr_tabulate_def lookup_array)

lemmas invar_update[intro]
end

(* Folding Cons from the right over xs onto ys appends xs to ys. *)
lemma foldr_Cons[simp]: "foldr (#) xs ys = xs@ys"
  by (induction xs) auto

(* ST arrays (starray_* from DiffArray_ST) as an instance of Array'.
   - update: starray_set with its argument order adapted;
   - list: a right fold of Cons;
   - invariant: trivially True. *)
interpretation starray_Array: 
  Array' starray_get "λi x arr. starray_set arr i x" starray_length starray_of_list 
    "λarr. starray_foldr (λx xs. x # xs) arr []" "λ_. True"
  by standard auto

(* List of all elements of an ST array, obtained by tabulating starray_get over its
   length. *)
definition "starray_to_list a = tabulate (starray_length a) (starray_get a)"


(* Support of pmf_of_list for a well-formed list ps: the first components of those pairs
   whose probability is nonzero.
   First direction: nonzero pmf mass means the filtered probabilities have a nonzero
   sum, hence some nonzero entry.
   Second direction: the pmf at a is bounded below by any single non-negative weight b
   attached to a. *)
lemma set_pmf_of_list:
  assumes "pmf_of_list_wf ps" 
  shows "set_pmf (pmf_of_list ps) = {a | a b. (a,b)  set ps  b  0}"
proof safe
  fix x 
  assume "x  set_pmf (pmf_of_list ps)"
  hence "sum_list (map snd (filter (λz. fst z = x) ps))  0"
    using assms
    by (auto simp: set_pmf_eq pmf_pmf_of_list)
  hence "y  set (map snd (filter (λz. fst z = x) ps)). y  0"
    by (metis map_idI sum_list_0)
  then obtain sp where "snd sp  0" "fst sp = x" "sp  set ps"
    by auto
  thus "a b. x = a  (a, b)  set ps  b  0"
    by force
next
  fix x a b
  assume h: "(a, b)  set ps" "b  0"
  have "(Set.insert a X) = a +  (X-{a})" if "finite X" for X and a :: real
    using that 
    by (meson sum.insert_remove)
  hence *: "b  set ps. b  0  b  set ps   b  sum_list ps" for ps
    by (induction ps ) (auto intro!: sum_list_nonneg)
  have "pmf (pmf_of_list ps) a  b"
    using assms  (a, b)  set ps 
    by (fastforce simp add: image_iff pmf_pmf_of_list pmf_of_list_wf_def intro!: * )
  thus "a  set_pmf (pmf_of_list ps)"
    unfolding set_pmf_iff
    using h assms pmf_of_list_wf_def by fastforce
qed

(* Variant of set_pmf_of_list with strictly positive probabilities.
   This is equivalent because pmf_of_list_wf requires non-negative weights. *)
lemma set_pmf_of_list':
  assumes "pmf_of_list_wf ps" 
  shows "set_pmf (pmf_of_list ps) = {a | a b. (a,b)  set ps  b > 0}"
  unfolding set_pmf_of_list[OF assms]
  using assms unfolding pmf_of_list_wf_def
  by fastforce

(* Abstract code-level MDP.

   Data:
   - mdp is an array (S_Map) with one entry per state 0 ..< states.
   - Each entry is an action map (A_Map) from action numbers to
     (reward, list of (successor state, probability)).

   Assumptions:
   - every action map is valid and non-empty;
   - all successors are below states;
   - every successor list is well-formed for pmf_of_list.

   From this data the locale defines the action sets MDP_A, the rewards MDP_r and the
   transition kernel MDP_K. *)
locale MDP_Code_raw =
  S_Map : Array' "s_lookup :: 'ts  nat  'ta"  s_update s_len s_array s_list s_invar +
  A_Map : Map_by_Ordered a_empty "a_update :: nat  (real × ((nat × real) list))  'ta  'ta" a_delete a_lookup a_inorder a_inv
  for s_lookup s_update s_len s_array s_list s_invar
  and a_empty a_update a_delete a_lookup a_inorder a_inv +
fixes
  mdp :: 'ts and
  states :: nat
assumes
  s_invar: "s_invar mdp" and
  s_len: "s_len mdp = states" and
  A_inv_locale: "am  set (s_list mdp). A_Map.invar am" and
  A_ne_locale: "am  set (s_list mdp). ¬ A_Map.is_empty am" and
  K_closed_locale: 
  "am  set (s_list mdp). (_, _, p)  A_Map.entries am. 
    list_all (λ(s', p). s' <states) p" and
  lists_are_pmfs: "am  set (s_list mdp). (_, _, p)  A_Map.entries am. pmf_of_list_wf p"
begin

(* Action lookup returning the stored (reward, successors) pair. If the action is
   missing, generated code aborts with an error message. *)
definition "a_lookup' m x = (
  case (a_lookup m x) of 
   Some v  v
  | None  Code.abort (STR ''MDP is missing action information'') (λ_. undefined))" 

(* Enabled actions: the keys of the state's action map. Outside the state range the only
   action is 0. *)
definition "MDP_A s = (if s < states then A_Map.keys (s_lookup mdp s) else {0})"

(* Reward of a (state, action) pair: the stored reward, or 0 outside the state range or
   for actions that are not stored. *)
definition "MDP_r sa = (if fst sa  states then 0 else
  let a_map = s_lookup mdp (fst sa) in 
  (case a_lookup a_map (snd sa) of Some (r, _)  r | None  0)
)"

(* Transition distribution of a (state, action) pair: pmf_of_list of the stored
   successor list. Out-of-range states and actions that are not stored stay put
   (return_pmf of the state). *)
definition "MDP_K sa = (
  if fst sa  states then 
    return_pmf (fst sa) 
  else 
    let a_map = s_lookup mdp (fst sa) in (
      case a_lookup a_map (snd sa) of 
        Some (_, p)  pmf_of_list p 
      | None  return_pmf (fst sa))
)"

(* Basic facts linking a_lookup, MDP_A and MDP_r. *)
lemma MDP_r_zero_notin_states: "s  states   MDP_r (s, a) = 0" for s a
  unfolding MDP_r_def
  by auto


lemma a_lookup_some_in_A: "s < states  a_lookup (s_lookup mdp s) a = Some (aa, b)  a  MDP_A s"
  using A_Map.lookup_some_set_key A_inv_locale S_Map.lookup_in_list s_len s_invar
  by (simp add: A_Map.keys_def MDP_A_def)

lemma a_lookup_None_notin_A: "s < states  a_lookup (s_lookup mdp s) a = None  a  MDP_A s"
  unfolding MDP_A_def
  using A_Map.lookup_None_set_inorder A_inv_locale S_Map.lookup_in_list s_invar s_len
  by auto

lemma MDP_r_zero_notin_A: "s < states  a  MDP_A s   MDP_r (s, a) = 0" for s a
  using a_lookup_some_in_A
  by (auto split: option.splits simp: MDP_r_def)

lemma MDP_r_in_A_eq: "s < states  a  MDP_A s  MDP_r (s, a) = fst ((a_lookup' (s_lookup mdp s) a))"
  using a_lookup_None_notin_A by (auto split: option.splits simp: a_lookup'_def MDP_r_def)

(* Finiteness of the state-action pairs and boundedness of the rewards. *)
lemma range_MDP_r_subs: "range (MDP_r)  {0}  {fst ((a_lookup' (s_lookup mdp s) a)) | s a. s < states  a  MDP_A s}"
  using MDP_r_in_A_eq MDP_r_zero_notin_A MDP_r_zero_notin_states 
  by (auto) (metis not_le)

lemma finite_MDP_A[simp]: "finite (MDP_A s)"
  unfolding MDP_A_def
  by (simp add: A_Map.finite_keys)

lemma finite_sa: "finite {(s,a). s < states  a  MDP_A s}"
proof -
  have "{(s,a). s < states  a  MDP_A s}  {(s,a). s < states  a  (s < states. MDP_A s)}"
    by auto
  moreover have "finite {(s,a). s < states  a  (s < states. MDP_A s)}"
    by auto
  ultimately show ?thesis
    using finite_subset by blast
qed

lemma finite_r_lookup: "finite {fst ((a_lookup' (s_lookup mdp s) a)) | s a. s < states  a  MDP_A s}"
proof -
  have aux: "{fst ((a_lookup' (s_lookup mdp s) a)) | s a. s < states  a  MDP_A s} = {fst ((a_lookup' (s_lookup mdp (fst sa)) (snd sa))) | sa. fst sa < states  snd sa  MDP_A (fst sa)}"
    by auto
  show ?thesis
    unfolding aux
    using finite_sa
    by (fastforce intro!: finite_image_set simp: case_prod_unfold)
qed

lemma bounded_MDP_r: "bounded (range MDP_r)"
  using finite_r_lookup range_MDP_r_subs
  by (simp add: finite_imp_bounded finite_subset)

(* Every state has at least one enabled action. *)
lemma MDP_A_ne[simp]: "(MDP_A s)  {}"
  using A_ne_locale s_invar s_len
  by (auto simp: MDP_A_def A_Map.is_empty_iff_keys_empty S_Map.lookup_in_list)

lemma K_closed_locale': 
  "am  set (s_list mdp)  (x, y, p)  A_Map.entries am  (s', prob)  set p  s' <states"
  using K_closed_locale
  by (fastforce simp: list.pred_set case_prod_beta)
  
(* Successors of an in-range state are in range.
   For actions outside MDP_A the kernel is return_pmf of the state. Otherwise the stored
   list is well-formed, and its positive-weight states lie below states. *)
lemma MDP_K_closed: 
  assumes "s<states" 
  shows "set_pmf (MDP_K (s, a))  {0..<states}"
proof 
  fix s'
  assume h: "s'  set_pmf (MDP_K (s, a))"
  show "s'  {0..<states}"
  proof (cases "a  MDP_A s")
    case False
    thus ?thesis
    using assms h
    using a_lookup_some_in_A 
    by (auto simp: MDP_K_def split: option.splits)
  next
    case True
    from h obtain r ps where "a_lookup (s_lookup mdp s) a = Some (r, ps)" and **:"s'  set_pmf (pmf_of_list ps)"
      unfolding MDP_K_def using assms True a_lookup_None_notin_A
      by (auto split: option.splits)
    have "pmf_of_list_wf ps"
      using lists_are_pmfs
      by (metis A_Map.Map_by_Ordered_axioms A_inv_locale Map_by_Ordered.lookup_some_set_a_inorder S_Map.lookup_in_list a_lookup (s_lookup mdp s) a = Some (r, ps) assms case_prod_conv s_invar s_len)
    have ***:"(s'', p)  set ps  p > 0  s'' < states" for s'' p
      by (metis A_Map.Map_by_Ordered_axioms A_inv_locale K_closed_locale' Map_by_Ordered.lookup_some_set_a_inorder S_Map.lookup_in_list a_lookup (s_lookup mdp s) a = Some (r, ps) assms s_invar s_len)
    have "s' < states"
      using *** ** set_pmf_of_list'[OF pmf_of_list_wf ps]
      by blast
    then show ?thesis by auto
  qed
qed

(* From out-of-range states the kernel stays out of range (it is return_pmf there). *)
lemma MDP_K_comp_closed: "s  states  set_pmf (MDP_K (s, a))  {states..}"
  unfolding MDP_K_def
  by auto

lemma MDP_A_outside: "states  s  MDP_A s = {0}"
  unfolding MDP_A_def
  by auto


(* Properties of the per-state action maps and their relation to MDP_r and MDP_K. *)
lemma invar_s_lookup: "s < states  A_Map.invar (s_lookup mdp s)"
  by (simp add: A_inv_locale S_Map.lookup_in_list s_invar s_len)

lemma ne_s_lookup: "s < states  ¬ A_Map.is_empty (s_lookup mdp s)"
  using A_ne_locale S_Map.lookup_in_list s_invar s_len by blast

lemma sa_lookup_eq:
  assumes "s < states" "a  MDP_A s" "(a_lookup (s_lookup mdp s) a) = Some (r, ps)"
  shows "r = MDP_r (s,a)" "pmf_of_list ps = MDP_K (s, a)"
  unfolding MDP_K_def 
  using assms a_lookup_None_notin_A
  by (auto split: option.splits simp: MDP_r_in_A_eq a_lookup'_def)

lemma fst_sa_lookup'_eq:
  assumes "s < states" "a  MDP_A s"
  shows "fst (a_lookup' (s_lookup mdp s) a) = MDP_r (s, a)"
  by (simp add: MDP_r_in_A_eq assms)


lemma snd_sa_lookup'_eq:
  assumes "s < states" "a  MDP_A s"
  shows "pmf_of_list (snd (a_lookup' (s_lookup mdp s) a)) = MDP_K (s, a)"
  using assms a_lookup'_def sa_lookup_eq a_lookup_None_notin_A
  by (auto split: option.splits)

lemma entries_A_eq_r: "s < states  (a, r, succs)  A_Map.entries (s_lookup mdp s)  r = MDP_r (s, a)"
  using sa_lookup_eq[OF _ a_lookup_some_in_A] A_Map.inorder_lookup_Some[OF invar_s_lookup]  
  by simp

lemma entries_A_eq_K: "s < states  (a, r, succs)  A_Map.entries (s_lookup mdp s)  pmf_of_list succs = MDP_K (s, a)"
  using sa_lookup_eq[OF _ a_lookup_some_in_A] A_Map.inorder_lookup_Some[OF invar_s_lookup]  
  by simp

lemma a_inorderD:
  assumes "s < states" "(a, r, succs)  A_Map.entries (s_lookup mdp s)"
  shows "a  MDP_A s" "r = MDP_r (s, a)" "pmf_of_list succs = MDP_K (s, a)"
  using assms A_Map.inorder_lookup_Some a_lookup_some_in_A invar_s_lookup entries_A_eq_r entries_A_eq_K 
  by auto


lemma a_map_entries_lookup: "s < states  a  MDP_A s  (a, a_lookup' (s_lookup mdp s) a)  A_Map.entries (s_lookup mdp s)"
  by (metis A_Map.lookup_in_keys A_Map.lookup_some_set_a_inorder MDP_A_def a_lookup'_def invar_s_lookup option.simps(5))

(* Restatements of lists_are_pmfs in forms convenient for the code proofs. *)
lemma lists_are_pmfs': "amset (s_list mdp)   (a,r,p)A_Map.entries am  pmf_of_list_wf p"
  using lists_are_pmfs by fastforce

lemma lists_are_pmfs'': "amset (s_list mdp)   (a,rp)A_Map.entries am  pmf_of_list_wf (snd rp)"
  using lists_are_pmfs by fastforce

lemma lists_are_pmfs''': "s < states   (a,rp)A_Map.entries (s_lookup mdp s)  pmf_of_list_wf (snd rp)"
  using S_Map.lookup_in_list lists_are_pmfs'' s_invar s_len by blast

lemma pmf_of_list_wf_mdp:
  assumes "s < states" "a  MDP_A s"
  shows "pmf_of_list_wf (snd (a_lookup' (s_lookup mdp s) a))"
  using assms a_map_entries_lookup
  by (auto intro: lists_are_pmfs'''[of s a])


(* Every successor listed for an enabled action of an in-range state is in range. *)
lemma  set_list_pmf_in_states:
   assumes "s < states" "a  MDP_A s" "(aa, b)  set (snd (a_lookup' (s_lookup mdp s) a)) "
shows 
  "aa < states"
proof -
  have"(s_lookup mdp s)  set( s_list mdp)"
    using S_Map.lookup_in_list assms(1) s_invar s_len by blast
  moreover have  "(a, (a_lookup' (s_lookup mdp s) a))  A_Map.entries (s_lookup mdp s)"
    by (metis A_Map.lookup_in_keys A_Map.lookup_some_set_a_inorder MDP_A_def a_lookup'_def assms(1) assms(2) invar_s_lookup option.case(2))
  ultimately show ?thesis
  using K_closed_locale assms
  by (fastforce simp: case_prod_beta list_all_def)
qed
end


(* Regrouping a sum over a list of pairs by first component.
   The sum of f over the whole list equals the sum, over each distinct first component a,
   of the sum of f over the pairs whose first component is a.
   The proof is by induction on the list, with a case split on whether the new head's
   key already occurs in the tail. *)
lemma sum_list_partition_fst: "(spps. f sp) = (afst ` set ps. spfilter (λz. fst z = a) ps. f sp)"
proof (induction ps)
  case Nil
  then show ?case by auto
next
    have *:"(if b then x else y) + z =  (if b then x+z else y+z)" for b x y z
      by auto
  case (Cons a ps)
  show ?case
  proof (cases "fst a  fst ` set ps")
    case True
      have "sum_list (map f (a # ps)) = f a + (afst ` set ps. sum_list (map f (filter (λz. fst z = a) ps)))"      
      by (auto dest: simp: Cons if_distrib  sum.insert_remove cong: sum.cong if_cong)
    also have " = (aafst ` set ps. (if fst a =  aa then f a else 0)) + (aafst ` set ps. sum_list (map f (filter (λz. fst z = aa) ps)))"  
      using True by auto
    also have " = (aafst ` set ps. (if fst a =  aa then f a else 0) +  sum_list (map f (filter (λz. fst z = aa) ps)))" 
      by (auto simp:  sum.distrib)
    also have " = (aafst ` set (a # ps). sum_list (map f (filter (λz. fst z = aa) (a # ps))))"
      by (auto simp: * True if_distrib[of "map _"] if_distrib[of "sum_list"] insert_absorb cong: if_cong)
    finally show ?thesis.
  next
      case False     
      have "sum_list (map f (a # ps)) = f a + (afst ` set ps. sum_list (map f (filter (λz. fst z = a) ps)))"      
      by (auto dest: simp: Cons if_distrib  sum.insert_remove cong: sum.cong if_cong)
    also have " = (aafst ` set (a#ps). (if fst a =  aa then f a else 0)) + (aafst ` set ps. sum_list (map f (filter (λz. fst z = aa) ps)))"  
      using False
      by (auto simp: )
    also have " = (aafst ` set (a#ps). (if fst a =  aa then f a else 0)) + (aafst ` set (a#ps). sum_list (map f (filter (λz. fst z = aa) ps)))"  
    proof -
      have *: "(x. x  set xs  x = 0)  sum_list xs = 0" for xs
        by (induction xs) auto
      have "(sum_list (map f (filter (λz. fst z = fst a) ps))) = 0"
        using False
        by (intro *) (auto intro: fst_eqD)
      thus ?thesis
        by (auto simp: False)
    qed
    also have " = (aafst ` set (a#ps). (if fst a =  aa then f a else 0) +  sum_list (map f (filter (λz. fst z = aa) ps)))" 
      by (auto simp:  sum.distrib)
    also have " = (aafst ` set (a # ps). sum_list (map f (filter (λz. fst z = aa) (a # ps))))"
      by (auto simp: * False if_distrib[of "map _"] if_distrib[of "sum_list"] insert_absorb cong: if_cong)
    finally show ?thesis.
  qed
qed

(* The expectation of f under pmf_of_list ps is the weighted list sum of p * f s'.

   The calculation:
   1. regroups the list sum by state (sum_list_partition_fst);
   2. factors out f a;
   3. splits the states into those with some nonzero weight and those whose weights are
      all zero;
   4. drops the zero-weight part.
   The resulting sum over the support is then matched with integral_measure_pmf.

   NOTE(review): the final step re-lists all intermediate equalities explicitly (a
   sledgehammer-style fact list) instead of relying on the calculation chain alone. *)
lemma pmf_of_list_expectation:
  assumes "pmf_of_list_wf ps"
  shows "measure_pmf.expectation (pmf_of_list ps) f = ((s', p) ps. p * f s')"
proof -
  have sumlist_cong: "sum_list (map f xs) = sum_list (map g xs)" if "x. x  set xs  f x = g x" for f g xs
    using that
    by (induction xs) auto
  have "((s', p) ps. p * f s') = sum_list (map (λsp. snd sp * f (fst sp)) ps)"
    by (metis case_prod_conv fst_def old.prod.exhaust snd_def)
  also have " = (a  fst ` (set ps). sum_list (map (λsp. snd sp * f (fst sp)) (filter (λz. fst z = a) ps)))"
    using sum_list_partition_fst
    by auto
    also have " = (a  fst ` (set ps). sum_list (map snd (filter (λz. fst z = a) ps)) * f a)"
      by (auto simp: add.commute set_filter map_eq_conv sum_list_mult_const[symmetric] intro!: sumlist_cong  sum.cong)
    also have " = (a  {u. b. (u, b)  set ps  b  0}  {u. b. (u,b)  set ps  (b. (u,b)  set ps  b = 0)}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a)"
    proof -
      have "fst ` (set ps) = {u. b. (u, b)  set ps}" 
        by force
      also have " = {u. b. (u, b)  set ps  b  0}  {u. b. (u,b)  set ps  (b. (u,b)  set ps  b = 0)}"
        by auto
      finally show ?thesis by auto
    qed
          also have " = (a  {u. b. (u, b)  set ps  b  0} . sum_list (map snd (filter (λz. fst z = a) ps)) * f a) + (a  {u. b. (u,b)  set ps  (b. (u,b)  set ps  b = 0)}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a)"
          proof -
            have "{u. (b. (u, b)  set ps)  (b. (u, b)  set ps  b = 0)}  fst ` set ps" 
              by force
            hence "finite {u. (b. (u, b)  set ps)  (b. (u, b)  set ps  b = 0)}"
              using finite_surj by blast
            thus ?thesis
              using assms finite_set_pmf_of_list set_pmf_of_list 
              by (subst sum.union_disjoint) fastforce+
          qed
        also have " = (a  {u. b. (u, b)  set ps  b  0} . sum_list (map snd (filter (λz. fst z = a) ps)) * f a)"
          by (fastforce intro!: sum.neutral iffD2[OF sum_list_nonneg_eq_0_iff])
    also have " = (a{u. b. (u, b)  set ps  b  0}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a)" by blast
  finally have "measure_pmf.expectation (pmf_of_list ps) f = (a{u. b. (u, b)  set ps  b  0}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a)"   
    using finite_set_pmf_of_list[OF assms]
    by (subst integral_measure_pmf) (fastforce simp add:  pmf_pmf_of_list set_pmf_of_list assms)+
  (* Combine the expectation formula with the intermediate equalities of the calculation
     above to reach the stated list-sum form. *)
  thus ?thesis
    using ((s', p)ps. p * f s') = (spps. snd sp * f (fst sp)) (afst ` set ps. spfilter (λz. fst z = a) ps. snd sp * f (fst sp)) = (afst ` set ps. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) (afst ` set ps. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) = (a{u. b. (u, b)  set ps  b  0}  {u. b. (u, b)  set ps  (b. (u, b)  set ps  b = 0)}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) (a{u. b. (u, b)  set ps  b  0}  {u. b. (u, b)  set ps  (b. (u, b)  set ps  b = 0)}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) = (a{u. b. (u, b)  set ps  b  0}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) + (a{u. b. (u, b)  set ps  (b. (u, b)  set ps  b = 0)}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) (a{u. b. (u, b)  set ps  b  0}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) + (a{u. b. (u, b)  set ps  (b. (u, b)  set ps  b = 0)}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) = (a{u. b. (u, b)  set ps  b  0}. sum_list (map snd (filter (λz. fst z = a) ps)) * f a) (spps. snd sp * f (fst sp)) = (afst ` set ps. spfilter (λz. fst z = a) ps. snd sp * f (fst sp)) by presburger
qed


(* Locale for the executable MDP code.
   It extends MDP_Code_raw with
   - V_Map: an Array' instance used to store value functions; v_lookup reads a real
     value at a nat index of an array of type 'tv;
   - D_Map: a Map_by_Ordered instance with nat keys and nat values (type 'td), used
     below to store policies (state to action);
   - a fixed real l, assumed to lie in [0, 1); l is passed to MDP_reward as the
     discount factor. *)
locale MDP_Code = MDP_Code_raw +
  V_Map : Array' "v_lookup :: 'tv  nat  real" v_update v_len v_array v_list v_invar +
  D_Map : Map_by_Ordered d_empty "d_update :: nat  nat  'td  'td" d_delete d_lookup d_inorder d_inv
  for v_lookup v_update v_len v_array v_list v_invar
  and d_empty d_update d_delete d_lookup d_inorder d_inv +
fixes
  l :: real
assumes
  zero_le_disc_locale: "0  l" and
  disc_lt_one_locale: "l < 1"
begin

(* Make the lemmas of the specialised array locales (Array_real, Array_zero) and of
   Map_by_Ordered_nat_zero available under the V_Map and D_Map prefixes.
   Each interpretation is discharged by unfold_locales alone. *)
sublocale V_Map: Array_real v_lookup v_update v_len v_array v_list v_invar
  by unfold_locales

sublocale V_Map: Array_zero v_lookup v_update v_len v_array v_list v_invar
  by unfold_locales
   
sublocale D_Map: Map_by_Ordered_nat_zero d_empty d_update d_delete d_lookup d_inorder d_inv
  by unfold_locales

(* Total lookup in a D_Map: unwraps the option returned by d_lookup with the,
   so the result is only meaningful for keys that are present in m. *)
definition d_lookup' where
  "d_lookup' m x = the (d_lookup m x)"

(* For a valid map f and a key s stored in f, the total-function view
   D_Map.map_to_fun agrees with the unwrapped lookup d_lookup'. *)
lemma map_to_fun_lookup: "D_Map.invar f  s  D_Map.keys f  D_Map.map_to_fun f s = d_lookup' f s"
  unfolding D_Map.map_to_fun_def d_lookup'_def
 using D_Map.lookup_None_set_inorder
  by (auto split: option.splits)

(* Interpret the abstract MDP theory for the concrete MDP described by MDP_A (actions),
   MDP_K (transitions) and MDP_r (rewards), with discount factor l.
   The axioms follow from the facts about MDP_A/MDP_r and the bounds on l. *)
sublocale MDP: MDP_reward "(MDP_A)" "(MDP_K)"  "(MDP_r)" l 
  using MDP_A_ne bounded_MDP_r zero_le_disc_locale disc_lt_one_locale
  by unfold_locales auto

(* Refine to MDP_nat_disc on the finite state space given by states, choosing actions
   by the least element of a set (LEAST). *)
sublocale MDP: MDP_nat_disc "(MDP_A)" "(MDP_K)" "(MDP_r)" l "λX. LEAST y. y  X"  states
proof -
  (* max_L_ex holds for every state and value function, made available to simp. *)
  have [simp]: "MDP_reward_disc.max_L_ex MDP_A MDP_K MDP_r l s v" for s v
    by (simp add: MDP.MDP_reward_axioms MDP_reward_disc.intro MDP_reward_disc.max_L_ex_def MDP_reward_disc_axioms.intro disc_lt_one_locale finite_is_arg_max has_arg_max_def)
  (* The LEAST-based choice function picks a member of any nonempty set of nats. *)
  have "X  {}  (LEAST (y::nat). y  X)  X" for X
    using Inf_nat_def Inf_nat_def1 by presburger
  thus "MDP_nat_disc MDP_A MDP_K MDP_r l (λX. LEAST y. y  X) states"
  using MDP_K_closed MDP_K_comp_closed MDP_r_zero_notin_states MDP_A_outside disc_lt_one_locale
  by unfold_locales auto
qed

section ‹Code for @{const MDP.La}›

(* Executable one-action backup: for rp = (r, ps), where ps is a list of
   (successor, probability) pairs, compute r + l * (sum of p * v_lookup v s'),
   accumulating the sum with a left fold starting from 0. *)
definition "La_code rp v = ( 
    let (r, ps) = rp in r + l * (foldl (λ acc (s', p). p * v_lookup v s' + acc)) 0 ps)"

(* La_code agrees with MDP.La on the bounded function represented by v when
   rps = (reward, successor list) encodes action a in state s:
   - s is a valid state, and v has length states and satisfies v_invar;
   - snd rps is a well-formed pmf list whose pmf is MDP_K (s, a) and whose
     support lies in {0..<states};
   - fst rps is the reward MDP_r (s, a). *)
lemma La_code_correct:
  assumes "s < states" "v_len v = states" "v_invar v"  "pmf_of_list (snd rps) = MDP_K (s, a)" 
    "pmf_of_list_wf (snd rps)" "fst ` set (snd rps) \<subseteq> {0..<states}" "fst rps = MDP_r (s, a)"
  shows "La_code rps v = MDP.La a (V_Map.map_to_bfun v) s"
proof -
  (* A left fold that adds f x to its accumulator computes sum_list (map f xs). *)
  have "foldl (λacc x. f x + acc) x xs = (\<Sum>x\<leftarrow>xs. f x) + x" for f xs and x :: real
    by (induction xs arbitrary: x) (auto simp: algebra_simps)
  hence *: "(\<Sum>x\<leftarrow>xs. f x) = foldl (λacc x. f x + acc) (0::real) xs" for f xs
    by (metis add.right_neutral)
  (* The fold in La_code is the expectation of the represented value function
     under the pmf described by snd rps. *)
  have "foldl (λacc (s', p). p * v_lookup v s' + acc) 0 (snd rps) = measure_pmf.expectation (MDP_K (s, a)) (apply_bfun (V_Map.map_to_bfun v))"
    unfolding assms(4)[symmetric] 
    using assms(5,6,7)
    by (auto intro!: foldl_cong simp: pmf_of_list_expectation * V_Map.map_to_bfun.rep_eq assms(2,3))
  thus ?thesis
    unfolding La_code_def
    using assms
    by (simp add: case_prod_unfold)
qed

(* La_code_correct specialised to the entry stored for action a in state s of mdp
   (looked up with a_lookup' in s_lookup mdp s). The well-formedness and support
   premises come from pmf_of_list_wf_mdp and set_list_pmf_in_states. *)
lemma L_GS_code_correct':
  assumes "s < states" "v_len v = states" "v_invar v" "a  MDP_A s"
  shows "La_code (a_lookup' (s_lookup mdp s) a) v = MDP.La a (V_Map.map_to_bfun v) s"
  using pmf_of_list_wf_mdp assms set_list_pmf_in_states
  by(intro La_code_correct) 
    (auto simp: fst_sa_lookup'_eq[symmetric] snd_sa_lookup'_eq)

(* For an in-bounds index k of a valid array m, reading the array directly agrees
   with its bounded-function view V_Map.map_to_bfun. *)
lemma v_lookup_map_to_bfun: "v_invar m  k < v_len m  v_lookup m k = V_Map.map_to_bfun m k"
  unfolding V_Map.map_to_bfun.rep_eq 
  by (force split: option.splits)

(* The bounded function represented by the array v coincides with V_Map.map_to_fun v.
   The invariant premise is stated for the same array v that occurs in the
   conclusion; the proof below does not appear to depend on it. *)
lemma map_to_bfun_eq_fun: "v_invar v ==> apply_bfun (V_Map.map_to_bfun v) = V_Map.map_to_fun v" 
  by (auto simp: V_Map.map_to_bfun.rep_eq V_Map.map_to_fun_def)

(* For a valid map d, map_to_fun sends a key k that is not stored in d to 0. *)
lemma map_to_fun_notin: "D_Map.invar d  k  D_Map.keys d  D_Map.map_to_fun d k = 0"
  by (auto simp: D_Map.map_to_fun_def D_Map.lookup_notin_keys split: option.splits)

section ‹Folding lists to maps›
(* TODO: convert to function from_list *)

(* An in-bounds v_update at index j changes only the entry at j. *)
lemma v_lookup_update: "v_invar m  k < v_len m  j < v_len m  v_lookup (v_update j x m) k  = (if j = k then x else v_lookup m k)" 
  by (auto simp add: V_Map.invar_update V_Map.len_array V_Map.update)

(* Folding in-bounds updates over xs preserves the array invariant. *)
lemma V_invar_fold: "v_invar m  set xs  {0..<v_len m}  v_invar (fold (λs v. v_update s (f s v) v) xs m)"
  by (induction xs arbitrary: m) (auto simp add: V_Map.invar_update V_Map.len_array V_Map.update)


(* Folding in-bounds updates over xs preserves the array length. *)
lemma V_len_fold: "v_invar m  set xs  {0..<v_len m}  v_len (fold (λs v. v_update s (f s v) v) xs m) = v_len m"
  by (induction xs arbitrary: m) (auto simp add: V_Map.invar_update V_Map.len_array V_Map.update)

(* A single in-bounds update preserves the array length. *)
lemma v_len_update: "v_invar m  j < v_len m  v_len (v_update j x m) = v_len m"
  by (simp add: V_Map.invar_update V_Map.len_array V_Map.update)

(* Writing f s at every index s < n (with n within the array length) makes the
   entry at each k < n equal to f k. *)
lemma v_lookup_fold: "v_invar m  n  v_len m  k < n  v_lookup (fold (λs v. v_update s (f s) v) [0..<n] m) k = (f k)"
  using V_invar_fold
  by (induction n arbitrary: m k) (auto intro!: V_invar_fold simp: v_lookup_update V_len_fold)

(* Folding d_update over xs adds exactly the elements of xs to the key set. *)
lemma keys_fold_map: "D_Map.invar m  D_Map.keys (fold (λs. d_update s (f s)) xs m) = D_Map.keys m  set xs"
  using D_Map.map_specs
  by (induction xs arbitrary: m) auto
                
(* Folding d_update over any list preserves the map invariant. *)
lemma invar_fold_update: "D_Map.invar m  D_Map.invar (fold (λs. d_update s (f s)) xs m)"
    using D_Map.map_specs by (induction xs arbitrary: m) auto

(* After inserting f s for every s < n, looking up k < n yields Some (f k). *)
lemma d_lookup_fold: "k < n  D_Map.invar m  d_lookup (fold (λs v. d_update s (f s) v) [0..<n] m) k = Some (f k)"
  using D_Map.map_update invar_fold_update by (induction n) auto

section ‹Code for @{const MDP.ℒb}›

(* Maximum of La_code over all (action, (reward, successors)) entries of the action
   map acts, i.e. the best one-step backup value for one state. *)
definition "ℒ_GS_code acts v = 
  (MAX (a, rs)  A_Map.entries acts. La_code rs v)"


(* For a valid state s and a valid value array v of length states, ℒ_GS_code on the
   action map of s equals the supremum of MDP.La over the actions MDP_A s. *)
lemma ℒ_GS_code_correct:
  assumes "s < states"  "v_invar v" "v_len v = states"
  shows "ℒ_GS_code (s_lookup mdp s) v = (a  MDP_A s. MDP.La a (V_Map.map_to_bfun v) s)"
  unfolding ℒ_GS_code_def
proof (subst cSup_eq_Max[symmetric])
  show "finite ((λ(a, rs). La_code rs v) ` A_Map.entries (s_lookup mdp s))"
    using A_Map.finite_entries by blast
  show "(λ(a, rs). La_code rs v) ` A_Map.entries (s_lookup mdp s)  {}"
    using ne_s_lookup assms A_Map.is_empty_iff_entries_empty by blast
  
  
  (* Each stored entry (a, r, s') evaluates, via La_code_correct, to MDP.La for action a. *)
  have "La_code (r,s') v = MDP.La a (V_Map.map_to_bfun v) s" if "(a, r,s')  A_Map.entries (s_lookup mdp s)" for a r s'
  proof -
    have "r = MDP_r (s, a)"
      by (metis assms(1) entries_A_eq_r that)
    moreover have "fst ` set s'  MDP.state_space"
      using K_closed_locale' S_Map.lookup_in_list assms(1) s_invar s_len that by fastforce
    moreover have "s' = (snd (a_lookup' (s_lookup mdp s) a))"
      using A_Map.inorder_lookup_Some a_lookup'_def assms(1) invar_s_lookup that by auto
    ultimately show ?thesis 
        using assms that a_inorderD pmf_of_list_wf_mdp
        by (intro La_code_correct) auto
qed
  thus "((a, rs)A_Map.entries (s_lookup mdp s). La_code rs v) = (aMDP_A s. MDP.La a (V_Map.map_to_bfun v) s)"
    using invar_s_lookup
    by (auto simp: MDP_A_def assms SUP_image A_Map.keys_eq_fst_entries intro!: SUP_cong)
qed


(* Array version of MDP.ℒb: tabulates ℒ_GS_code (s_lookup mdp i) v for every
   index i below states. *)
definition ℒ_code where
  "ℒ_code v = V_Map.arr_tabulate (λi. ℒ_GS_code (s_lookup mdp i) v) states"


(* Entry s of ℒ_code v is the per-state backup ℒ_GS_code (s_lookup mdp s) v. *)
lemma ℒ_code_lookup:
  assumes "s < states" "v_len v = states" "v_invar v"
  shows "v_lookup (ℒ_code v) s = (ℒ_GS_code (s_lookup mdp s) v)"
  using assms unfolding ℒ_code_def by auto

(* ℒ_code preserves the array length (despite the name, this is about v_len). *)
lemma keys_ℒ_code[simp]: "v_invar v  v_len v = states  v_len (ℒ_code v) = v_len v"
  unfolding ℒ_code_def by auto


(* Pointwise correctness: entry s of ℒ_code v equals MDP.ℒb applied to the
   bounded function represented by v, evaluated at s. *)
lemma ℒ_code_correct:
  assumes "s < states" "v_len v = states" "v_invar v"
  shows "v_lookup (ℒ_code v) s = MDP.ℒb (V_Map.map_to_bfun v) s"
  unfolding ℒ_code_lookup[OF assms] MDP.ℒb_eq_La_max'
  by (auto intro: cSup_eq_Max simp: assms ℒ_GS_code_correct)

(* ℒ_code produces an array satisfying v_invar. *)
lemma invar_ℒ_code: "v_invar v  v_invar (ℒ_code v)"
  using V_invar_fold unfolding ℒ_code_def
  using V_Map.arr_tabulate_def V_Map.invar_array by presburger


(* Function-level correctness: the bounded function represented by ℒ_code v is
   MDP.ℒb applied to the bounded function represented by v.
   Inside the state space this follows from ℒ_code_correct; outside it, both
   sides are zero (MDP.ℒb_zero and the out-of-range behaviour of map_to_bfun). *)
lemma ℒ_code_correct':
  assumes "v_len v = states" "v_invar v"
  shows "V_Map.map_to_bfun (ℒ_code v) = MDP.ℒb (V_Map.map_to_bfun v)"
  using MDP.ℒb_zero assms
proof (intro bfun_eqI)
  fix x
  show "apply_bfun (V_Map.map_to_bfun (ℒ_code v)) x = apply_bfun (MDP.ℒb (V_Map.map_to_bfun v)) x"
  proof (cases "x < states")
    case True
    then show ?thesis
      using assms keys_ℒ_code ℒ_code_correct invar_ℒ_code v_lookup_map_to_bfun by force
  next
    case False
    (* Out of range: a single fastforce call; the former empty dest: modifier and
       the repetition combinator were no-ops here. *)
    then show ?thesis
      using assms keys_ℒ_code MDP.ℒb_zero
      by (fastforce simp: V_Map.map_to_bfun.rep_eq split: option.splits)
  qed
qed

section ‹Code to check condition›

(* Stopping test for value iteration: every entry below states of v and v' differs
   by strictly less than eps * (1 - l) / (2 * l). *)
definition "check_dist v v' eps = (
  let m = eps * (1 - l) / (2 * l) in
    (s < states. abs (v_lookup v s - v_lookup v' s) < m))"

(* check_dist decides whether the sup-distance between the bounded functions
   represented by two valid arrays of length states is below eps * (1 - l) / (2 * l).
   It requires eps > 0 and a nonzero discount l. *)
lemma check_dist_correct:
  assumes "v_invar v" "v_invar v'" "v_len v = states" "v_len v' = states" "eps > 0" "l  0"
  shows "check_dist v v' eps  dist (V_Map.map_to_bfun v) (V_Map.map_to_bfun v') < eps * (1 - l) / (2 * l)"
proof -
  (* Outside the state space both represented functions agree, so the distance is 0. *)
  have dist_zero_ge: "dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun v') x) = 0" if "x  states" for x
    using assms that 
    by (auto simp: V_Map.map_to_bfun.rep_eq split: option.splits)
 have univ: "UNIV = {0..<states}  {states..}" by auto
 (* Hence only finitely many pointwise distances occur. *)
 have fin: "finite (range (λx. dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun v') x)))"
    by (auto simp: dist_zero_ge univ Set.image_Un Set.image_constant[of states])
  have zero_less_eps: "0 < eps * (1 - l) / (2 * l)"
    using MDP.zero_le_disc assms MDP.disc_lt_one
    by (auto intro!: mult_imp_less_div_pos simp: less_le)
  show ?thesis
  proof
    (* Direction 1: the executable check implies the distance bound. *)
    assume h: "check_dist v v' eps" 
    show "dist (V_Map.map_to_bfun v) (V_Map.map_to_bfun v') < eps * (1 - l) / (2 * l)"
      unfolding dist_bfun.rep_eq
    proof (rule finite_imp_Sup_less[OF fin])
      show "0  range (λx. dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun v') x))"
        using dist_zero_ge by fastforce
      have "dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun v') x) < eps * (1 - l) / (2 * l)" if "x < states" for x
        using assms h that
        unfolding check_dist_def V_Map.map_to_bfun.rep_eq  dist_real_def
        by (auto  split: option.splits)
      thus "x  range (λx. dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun v') x))  x < eps * (1 - l) / (2 * l)" for x
        using zero_less_eps dist_zero_ge imageE not_less 
        by (metis (no_types, lifting))
    qed
  next
  (* Direction 2: the distance bound implies the executable check. *)
  show "dist (V_Map.map_to_bfun v) (V_Map.map_to_bfun v') < eps * (1 - l) / (2 * l)  check_dist v v' eps"
    using assms fin
    by (auto simp: check_dist_def dist_bfun.rep_eq finite_Sup_less_iff dist_real_def v_lookup_map_to_bfun)
qed
qed


section ‹Find policy›
(* For state s, runs least_arg_max_max_ne over the entries (a, (r, succs)) of the
   action map s_lookup mdp s, in a_inorder order, scored by La_code. The result is
   the first entry attaining the maximal score, paired with that score. *)
definition find_policy_state_code_aux where
  "find_policy_state_code_aux v s =
    least_arg_max_max_ne (λ(_, rs). La_code rs v) (a_inorder (s_lookup mdp s))"

(* Projects the result of find_policy_state_code_aux to (action, value), dropping
   the reward and successor list stored with the maximizing action. The bound value
   is named m so that it does not shadow the value-array parameter v. *)
definition "find_policy_state_code_aux' v s = (
  case find_policy_state_code_aux v s of ((a, _, _), m) \<Rightarrow> (a, m))"


(* Restates find_policy_state_code_aux' as an arg-max over the action list
   map fst (a_inorder ...). Each action is scored by La_code of the entry obtained
   with a_lookup', instead of scoring the stored entries directly. *)
lemma find_policy_state_code_aux_eq:
  assumes "s < states"
  shows "find_policy_state_code_aux' v s = (least_arg_max_max_ne (λa.
   La_code (a_lookup' (s_lookup mdp s) a) v) ((map fst (a_inorder (s_lookup mdp s)))))"
  unfolding find_policy_state_code_aux'_def  find_policy_state_code_aux_def
  using assms ne_s_lookup invar_s_lookup A_Map.inorder_lookup_Some
  by (subst least_arg_max_max_ne_app'[symmetric]) 
    (auto simp: A_Map.entries_def a_lookup'_def case_prod_unfold A_Map.is_empty_def)

(* Full specification of find_policy_state_code_aux' for a valid state and value
   array: it returns the least arg-max of MDP.La over MDP_A s together with the
   maximal value.
   Proof chain: fold to (least_arg_max, MAX) over the action list; replace La_code by
   MDP.La; then replace list membership by membership in MDP_A s. *)
lemma find_policy_state_code_aux'_eq':
  assumes "s < states" "v_len v = states" "v_invar v"
  shows "find_policy_state_code_aux' v s = 
  (least_arg_max (λa. MDP.La a (V_Map.map_to_bfun v) s) (λa. a  MDP_A s), Max ((λa. MDP.La a (V_Map.map_to_bfun v) s) ` (MDP_A s)))"
proof -
  have "find_policy_state_code_aux' v s = least_arg_max_max_ne (λa. La_code (a_lookup' (s_lookup mdp s) a) v) (map fst (a_inorder (s_lookup mdp s)))"
    using find_policy_state_code_aux_eq assms by auto
  also have  = (least_arg_max (λa. La_code (a_lookup' (s_lookup mdp s) a) v) (List.member (map fst (a_inorder (s_lookup mdp s)))),
     MAX aset (map fst (a_inorder (s_lookup mdp s))). La_code (a_lookup' (s_lookup mdp s) a) v)
    using A_Map.is_empty_def assms(1) ne_s_lookup A_Map.invar_def A_inv_locale S_Map.lookup_in_list s_invar s_len 
    by (auto simp: fold_max_eq_arg_max')
  also have  = (least_arg_max (λa. MDP.La a (V_Map.map_to_bfun v) s) (List.member (map fst (a_inorder (s_lookup mdp s)))),
     MAX aset (map fst (a_inorder (s_lookup mdp s))). MDP.La a (V_Map.map_to_bfun v) s)
    using assms a_inorderD(1) A_Map.keys_def  MDP_A_def 
    by (auto intro!: least_arg_max_cong simp: L_GS_code_correct' in_set_member[symmetric])
  also have  = (least_arg_max (λa. MDP.La a (V_Map.map_to_bfun v) s) (λa. a  MDP_A s),
     MAX aMDP_A s. MDP.La a (V_Map.map_to_bfun v) s)
  proof -
    have *: "a  fst ` set (a_inorder (s_lookup mdp s))   List.member (map fst ((a_inorder (s_lookup mdp s)))) a" for a
      by (auto simp: List.member_def)
    show ?thesis
    using assms La_code_correct  A_Map.keys_def 
    by (auto intro!: least_arg_max_cong  simp: * MDP_A_def)
qed
  finally show ?thesis.
qed

(* Policy map extracted from the value array v: built with D_Map.from_list' over the
   states 0 ..< states, mapping each state to the action component of
   find_policy_state_code_aux' v at that state. *)
definition vi_find_policy_code where
  "vi_find_policy_code (v::'tv) =
    D_Map.from_list' (λst. fst (find_policy_state_code_aux' v st)) [0..<states]"

(* The policy map satisfies the D_Map invariant. *)
lemma d_invar_vi_find_policy_code: "D_Map.invar (vi_find_policy_code v)"
  using D_Map.from_list_invar vi_find_policy_code_def by simp

(* Its keys are exactly the states 0 ..< states. *)
lemma d_keys_vi_find_policy_code: "D_Map.keys (vi_find_policy_code v) = {0..<states}"
  using D_Map.from_list_keys vi_find_policy_code_def by simp

(* States outside the state space have no entry. *)
lemma vi_find_policy_code_notin: 
  assumes "s  states" shows "d_lookup (vi_find_policy_code v) s = None"
  using D_Map.lookup_notin_keys assms d_invar_vi_find_policy_code d_keys_vi_find_policy_code by force

(* Every state inside the state space has an entry. *)
lemma vi_find_policy_code_in: 
  assumes "s < states" shows "x. d_lookup (vi_find_policy_code v) s = Some x"
  by (simp add: D_Map.lookup_in_keys assms d_invar_vi_find_policy_code d_keys_vi_find_policy_code)
 
(* Outside the state space the total-function view of the policy is 0. *)
lemma vi_find_policy_code_ge: "s  states  D_Map.map_to_fun (vi_find_policy_code v) s = 0"
  using vi_find_policy_code_notin vi_find_policy_code_def 
  by (auto simp: D_Map.map_to_fun_def)


(* At a valid state, the extracted policy chooses the least arg-max of MDP.La over
   the available actions MDP_A s. *)
lemma vi_find_policy_code_correct:
  assumes "v_len v = states" "v_invar v" "s < states"
  shows "D_Map.map_to_fun ((vi_find_policy_code v)) s = least_arg_max (λa. MDP.La a (V_Map.map_to_bfun v) s) (λa. a  MDP_A s)"
  using assms
  by (simp add: find_policy_state_code_aux'_eq'  vi_find_policy_code_def D_Map.map_to_fun_def)


(* The extracted policy, viewed as a total function, equals MDP.find_policy' on the
   bounded function represented by v. The cases s >= states and s < states are
   handled separately. *)
lemma vi_find_policy_correct: 
  assumes "v_len v = states" "v_invar v"
  shows "D_Map.map_to_fun (vi_find_policy_code v) = (MDP.find_policy' (V_Map.map_to_bfun v))"
proof -
  have "D_Map.map_to_fun (vi_find_policy_code v) s = (MDP.find_policy' (V_Map.map_to_bfun v)) s" if "s  states" for s
    using vi_find_policy_code_ge that
    by (auto simp:  MDP.find_policy'_def MDP_A_def  MDP.is_opt_act_def intro!: Least_equality)
  moreover have "D_Map.map_to_fun (vi_find_policy_code v) s = (MDP.find_policy' (V_Map.map_to_bfun v)) s" if "s < states" for s
    using that assms
    by (auto simp: MDP.find_policy'_def vi_find_policy_code_correct least_arg_max_def MDP.is_opt_act_def)
  ultimately show ?thesis
    using not_le by blast
qed

(* Initial value array: the constant 0 tabulated over all states. *)
definition v0 where
  "v0 = V_Map.arr_tabulate (λs. 0) states"

(* The initial array is valid and has one entry per state. *)
lemma v0_correct: "v_invar v0" "v_len v0 = states"
  unfolding v0_def by auto
 
(* Builds a value array from a list of values (a thin wrapper around v_array). *)
definition v_map_from_list where
  "v_map_from_list xs = v_array xs"

end

text ‹
Hack:
@{const pmf_of_list_wf} is polymorphic, so the requirement that all probabilities sum to @{term "1"}
is checked with exact equality. This fails for floats, so we reimplement the check monomorphically
and change equality on floats to
@{term "a = b \<longleftrightarrow> dist a b < 1.0/10^8"}.
›
(* Drop the default code equation pmf_of_list_wf_code; code for pmf_of_list is
   instead given by the [code abstract] lemma pmf_of_list_code below. *)
lemmas pmf_of_list_wf_code[code del]

(* Monomorphic (real-valued) version of pmf_of_list_wf: all probabilities are
   nonnegative and they sum to 1. *)
definition
  "pmf_of_list_wf' xs  list_all (λz. snd z  0) xs  sum_list (map snd xs) = (1 :: real)"

(* Code equation for pmf_of_list. For a well-formed list, the pmf's mapping is
   tabulated over the distinct outcomes of the filtered list xs', summing the
   probabilities of each outcome; otherwise code execution aborts.
   NOTE(review): in exact arithmetic the filter on snd z*(10^8) is equivalent to
   filtering on snd z; presumably it only matters once reals are mapped to floats.
   Confirm. *)
lemma pmf_of_list_code [code abstract]:
  "mapping_of_pmf (pmf_of_list xs) = (
     if pmf_of_list_wf' xs then
       let xs' = filter (λz. snd z*(10^8)  0) xs
       in  Mapping.tabulate (remdups (map fst xs')) 
             (λx. sum_list (map snd (filter (λz. fst z = x) xs')))
     else
       Code.abort (STR ''Invalid list for pmf_of_list'') (λ_. mapping_of_pmf (pmf_of_list xs)))"
  using mapping_of_pmf_pmf_of_list'[of xs] pmf_of_list_wfI
  by (auto simp add: pmf_of_list_wf'_def list_all_def)


(* SML serialisation of immutable arrays: IArray.tabulate, IArray.sub' and
   IArray.length' are mapped to the corresponding Vector operations, converting
   between arbitrary-precision IntInf integers and SML int indices. *)
code_printing
 constant IArray.tabulate  (SML) "case _ of (n, f) => Vector.tabulate (IntInf.toInt n, fn i => f ((IntInf.fromInt i)))"
| constant IArray.sub'  (SML) "case _ of (arr, i) => Vector.sub (arr, IntInf.toInt i)"
| constant IArray.length'  (SML) "IntInf.fromInt (Vector.length _)"

(* Builds an RBT map from an association list with nat keys by folding
   RBT_Map.update from the right, starting from the empty tree. *)
definition nat_map_from_list where
  "nat_map_from_list (xs :: (nat × _) list) =
    foldr (λ(key, entry) acc. RBT_Map.update key entry acc) xs RBT_Set.empty"
(* pmf_of_list specialised to lists whose outcomes are natural numbers. *)
definition nat_pmf_of_list where
  "nat_pmf_of_list (xs :: (nat × _) list) = pmf_of_list xs"

(* Builds an MDP from d and a list xs with one entry per state. Each entry is an
   association list of (action, (r, p)) pairs that is turned into an RBT map keyed
   by action; the number of states is length xs, and the result is passed through
   to_valid_MDP. *)
definition assoc_list_to_MDP where
  "assoc_list_to_MDP d xs =
    to_valid_MDP (MDP d (length xs)
      (IArray (map (λacts. foldr (λ(act, (rw, tr)) tree. RBT_Map.update act (rw, tr) tree) acts RBT_Set.empty) xs)))"

(* Code preprocessing: building an starray from map f [0..<n] is rewritten to
   starray_tabulate n f, avoiding the intermediate list in generated code. *)
lemma starray_of_list_tabulate [code_unfold]: "starray_of_list (map f [0..<n]) = starray_tabulate n f"
  by (simp add: starray_eq_iff tabulate_def)

end