(* Theory VI_Code *)

theory VI_Code
  imports
    Code_Setup
    "../Value_Iteration" 
    "HOL-Library.Code_Target_Numeral"
begin

context MDP_Code begin

(* Tail-recursive value-iteration loop on the concrete value map.
   Each round computes v' = ℒ_code v; when check_dist v v' eps holds the loop
   stops and returns v', otherwise it continues from v'.
   partial_function (tailrec) admits the definition without a termination
   proof; its simps are registered as code equations right below. *)
partial_function (tailrec) VI_code_aux where
"VI_code_aux v eps = (
  let v' = ℒ_code v in
    if check_dist v v' eps
    then v'
    else VI_code_aux v' eps)"

lemmas VI_code_aux.simps[code]

(* Entry point for executable value iteration.  When l = 0 or eps <= 0 it
   performs a single ℒ_code step instead of entering VI_code_aux; the
   correctness lemmas for VI_code_aux below assume eps > 0 and a nonzero l. *)
definition "VI_code v eps = (if l = 0 ∨ eps ≤ 0 then ℒ_code v else VI_code_aux v eps)"


(* Main loop invariant: for eps > 0, nonzero l and an input map satisfying
   v_invar with length `states`, VI_code_aux agrees with MDP.value_iteration
   (via V_Map.map_to_fun / V_Map.map_to_bfun) and preserves both the length and
   v_invar.  The induction follows the recursion of MDP.value_iteration. *)
lemma VI_code_aux_correct_aux:
  assumes "eps > 0" "v_invar v" "v_len v = states" "l ≠ 0"
  shows "V_Map.map_to_fun (VI_code_aux v eps) = MDP.value_iteration eps (V_Map.map_to_bfun v)
   ∧ v_len (VI_code_aux v eps) = states
   ∧ v_invar (VI_code_aux v eps)"
  using assms
proof (induction eps "V_Map.map_to_bfun v" arbitrary: v rule: MDP.value_iteration.induct)
  case (1 eps)
  (* The concrete stopping test check_dist coincides with the stopping test
     used by MDP.value_iteration. *)
  have *: "(check_dist v (ℒ_code v) eps) ⟷ 2 * l * dist (V_Map.map_to_bfun v) (MDP.ℒb (V_Map.map_to_bfun v)) < eps * (1 - l)"
  proof (subst check_dist_correct)
    have "0 < l" using 1 MDP.zero_le_disc by linarith
    thus "(dist (V_Map.map_to_bfun v) (V_Map.map_to_bfun (ℒ_code v)) < eps * (1 - l) / (2 * l)) =
    (2 * l * dist (V_Map.map_to_bfun v) (MDP.ℒb (V_Map.map_to_bfun v)) < eps * (1 - l))"
      by (subst pos_less_divide_eq) (fastforce simp: ℒ_code_correct' 1 algebra_simps)+
  qed (auto simp: 1 intro: invar_ℒ_code)
  (* When the test fails, one unfolding of VI_code_aux continues from ℒ_code v. *)
  hence **: "V_Map.map_to_fun (VI_code_aux v eps) = (MDP.value_iteration eps (V_Map.map_to_bfun (ℒ_code v)))" if "¬ (check_dist v (ℒ_code v) eps)"
    using invar_ℒ_code 1 that by (auto simp: VI_code_aux.simps ℒ_code_correct')
  have "V_Map.map_to_fun (VI_code_aux v eps) = (MDP.value_iteration eps (V_Map.map_to_bfun v))"
  proof (cases "(check_dist v (ℒ_code v) eps)")
    case True
    thus ?thesis
      using 1 invar_ℒ_code
      by (auto simp: MDP.value_iteration.simps VI_code_aux.simps[of v] * map_to_bfun_eq_fun[symmetric] ℒ_code_correct')
  next
    case False
    thus ?thesis
      using 1 ℒ_code_correct' ** * MDP.value_iteration.simps by auto
  qed
  thus ?case
    using 1 VI_code_aux.simps ℒ_code_correct' * invar_ℒ_code by auto
qed

(* First conjunct of VI_code_aux_correct_aux: refinement of MDP.value_iteration. *)
lemma VI_code_aux_correct:
  assumes "eps > 0" "v_invar v" "v_len v = states" "l ≠ 0"
  shows "V_Map.map_to_fun (VI_code_aux v eps) = MDP.value_iteration eps (V_Map.map_to_bfun v)"
  using assms VI_code_aux_correct_aux by auto

(* Second conjunct of VI_code_aux_correct_aux: the result has length `states`. *)
lemma VI_code_aux_keys:
  assumes "eps > 0" "v_invar v" "v_len v = states" "l ≠ 0"
  shows "v_len (VI_code_aux v eps) = states"
  using assms VI_code_aux_correct_aux by auto

(* Third conjunct of VI_code_aux_correct_aux: the result satisfies v_invar. *)
lemma VI_code_aux_invar:
  assumes "eps > 0" "v_invar v" "v_len v = states" "l ≠ 0"
  shows "v_invar (VI_code_aux v eps)"
  using assms VI_code_aux_correct_aux by auto

(* VI_code refines MDP.value_iteration for eps > 0 and a well-formed input map:
   the l = 0 case reduces to a single ℒ_code step, the other case to
   VI_code_aux_correct. *)
lemma VI_code_correct:
  assumes "eps > 0" "v_invar v" "v_len v = states"
  shows "V_Map.map_to_fun (VI_code v eps) = MDP.value_iteration eps (V_Map.map_to_bfun v)"
proof (cases "l = 0")
  case True
  then show ?thesis
    using assms invar_ℒ_code ℒ_code_correct'
    unfolding VI_code_def MDP.value_iteration.simps[of _ "V_Map.map_to_bfun v"]
    by (fastforce simp: map_to_bfun_eq_fun)
next
  case False
  then show ?thesis
    using assms
    by (auto simp add: VI_code_def VI_code_aux_correct)
qed

(* Executable value-iteration policy: run VI_code, then extract a decision
   rule from the resulting value map with vi_find_policy_code. *)
definition "VI_policy_code v eps = vi_find_policy_code (VI_code v eps)"

(* VI_policy_code refines MDP.vi_policy': the value map agrees with
   MDP.value_iteration (first fact), and the extracted rule agrees with
   MDP.find_policy' on that value (second fact). *)
lemma VI_policy_code_correct:
  assumes "eps > 0" "v_invar v" "v_len v = states"
  shows "D_Map.map_to_fun (VI_policy_code v eps) = MDP.vi_policy' eps (V_Map.map_to_bfun v)"
proof -
  have "V_Map.map_to_bfun (VI_code v eps) = (MDP.value_iteration eps (V_Map.map_to_bfun v))"
    using assms VI_code_correct
    by (auto simp: VI_code_aux_invar map_to_bfun_eq_fun)
  moreover have "D_Map.map_to_fun (VI_policy_code v eps) = MDP.find_policy' (V_Map.map_to_bfun (VI_code v eps))"
    unfolding VI_code_def VI_policy_code_def
    using assms invar_ℒ_code keys_ℒ_code vi_find_policy_correct VI_code_aux_correct_aux
    by (cases "l = 0") auto
  ultimately show ?thesis
    unfolding MDP.vi_policy'_def
    by presburger
qed

end

context MDP_nat_disc
begin

(* Error bound for an arbitrary v: its distance to the optimal value νb_opt is
   at most the one-step residual dist v (b v) divided by (1 - l).  Derived from
   contraction_ℒ_dist. *)
lemma dist_opt_bound_ℒb: "dist v νb_opt  dist v (b v) / (1 - l)"
  using contraction_ℒ_dist
  by (simp add: mult.commute mult_imp_le_div_pos)

(* Certificate: if the residual bound of dist_opt_bound_ℒb is within ε, then
   v itself is within ε of νb_opt (chaining via order_trans). *)
lemma cert_ℒb: 
  assumes "ε  0" "dist v (b v) / (1 - l)  ε"
  shows "dist v νb_opt  ε"
  using assms dist_opt_bound_ℒb order_trans by auto

(* Boolean form of the certificate condition used by cert_ℒb. *)
definition "check_value_ℒb eps v  dist v (b v) / (1 - l)  eps"

(* Performs one step v' = b v and returns the pair
   (2 * l * dist v v' / (1 - l), find_policy' v'). *)
definition "vi_policy_bound_error v = (
  let v' = (b v); err = (2 * l) * dist v v' / (1 - l) in
  (err, find_policy' v'))"

(* The decision rule d returned by vi_policy_bound_error v, used as a
   stationary deterministic policy, has value within err of νb_opt.
   Cases: l = 0; v a fixed point of b; and the remaining case, where err > 0
   and the bound is shown for every err' > err via find_policy'_error_bound.
   NOTE(review): this lemma is anonymous and cannot be referenced elsewhere;
   consider naming it. *)
lemma 
  assumes "vi_policy_bound_error v = (err, d)"
  shows "dist (νb (mk_stationary_det d)) νb_opt  err"
proof (cases "l = 0")
  case True
  (* With l = 0 the reported error is 0. *)
  hence "vi_policy_bound_error v = (0, find_policy' (b v))"
    unfolding vi_policy_bound_error_def by auto
  have "b v = b νb_opt"
    by (auto simp: b.rep_eq L_def simp del: b_opt intro!: bfun_eqI simp: ℒ_def) (simp add: True)
  hence "b v = νb_opt"
    by auto
  (* The greedy policy for b v is therefore optimal. *)
  hence "νb (mk_stationary_det (find_policy' (b v))) = νb_opt"
    using L_ν_fix ν_improving_opt_acts conserving_imp_opt
    unfolding find_policy'_def ν_conserving_def
    by auto
  then show ?thesis
    using assms unfolding vi_policy_bound_error_def 
    by (auto simp: True)
next
  case False
  then show ?thesis
  proof (cases "b v = v")
    case True
    (* v is a fixed point of b: the greedy policy for b v is optimal. *)
  hence "νb (mk_stationary_det (find_policy' (b v))) = νb_opt"
    using L_ν_fix ν_improving_opt_acts conserving_imp_opt
    unfolding find_policy'_def ν_conserving_def
    by auto
  then show ?thesis
    using assms unfolding vi_policy_bound_error_def 
    by (auto simp: True)
  next
    case False
    (* A nonzero residual makes the reported error strictly positive. *)
    hence 1: "dist v (b v) > 0"
      by fastforce
    hence "2 * l * dist v (b v) > 0"
      using l  0 zero_le_disc by (simp add: less_le)
    hence "err > 0"
      using assms unfolding vi_policy_bound_error_def by auto
    (* Strict bound for every err' > err; the goal (bound by err) follows. *)
    hence "dist (νb (mk_stationary_det (find_policy' (b v)))) νb_opt < err'" if "err < err'" for err'
      using that assms
      unfolding vi_policy_bound_error_def
      by (auto simp:  pos_divide_less_eq[symmetric] intro: find_policy'_error_bound)
    then show ?thesis
      using assms unfolding vi_policy_bound_error_def Let_def
      by force
  qed
qed

end

context MDP_Code
begin
(* Executable counterpart of MDP.vi_policy_bound_error.
   It computes v' = ℒ_code v, takes d as the maximum over s in {..< states} of
   dist (v_lookup v s) (v_lookup v' s) (d = 0 when there are no states), and
   returns the pair (2 * l * d / (1 - l), vi_find_policy_code v'). *)
definition "vi_policy_bound_error_code v = (
  let v' = (ℒ_code v);
    d = if states = 0 then 0 else (MAX s ∈ {..< states}. dist (v_lookup v s) (v_lookup v' s));
   err = (2 * l) * d / (1 - l) in
  (err, vi_find_policy_code v'))"

(* The decision-rule component of vi_policy_bound_error_code agrees with the
   one of MDP.vi_policy_bound_error on the corresponding bounded function. *)
lemma vi_policy_bound_error_code_policy:
  assumes "v_len v = states" "v_invar v"
  shows "D_Map.map_to_fun (snd (vi_policy_bound_error_code v)) = snd (MDP.vi_policy_bound_error (V_Map.map_to_bfun v))"
  using assms ℒ_code_correct' invar_ℒ_code vi_find_policy_correct
  by (auto simp: vi_policy_bound_error_code_def MDP.vi_policy_bound_error_def)

(* Congruence for MAX: bodies that agree pointwise on X have equal maxima. *)
lemma MAX_cong:
  assumes "⋀x. x ∈ X ⟹ f x = g x"
  shows "(MAX x ∈ X. f x) = (MAX x ∈ X. g x)"
  using assms by auto

(* The error component of vi_policy_bound_error_code agrees with the one of
   MDP.vi_policy_bound_error.  dist on bounded functions is unfolded
   (dist_bfun_def) into a Sup over all indices, which the proof reduces to the
   finite MAX over {..<states} used by the code. *)
lemma vi_policy_bound_error_code_err:
  assumes "v_len v = states" "v_invar v"
  shows "(fst (vi_policy_bound_error_code v)) = fst (MDP.vi_policy_bound_error (V_Map.map_to_bfun v))"
proof-
  (* At indices x >= states the pointwise distance is 0. *)
  have dist_zero_ge: "dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun (ℒ_code v)) x) = 0" if "x ≥ states" for x
    using assms that
    by (auto simp: V_Map.map_to_bfun.rep_eq)
  have univ: "UNIV = {0..<states} ∪ {states..}" by auto
  let ?d = "λx. dist (apply_bfun (V_Map.map_to_bfun v) x) (apply_bfun (V_Map.map_to_bfun (ℒ_code v)) x)"

  (* Only finitely many distinct distances occur, so Sup is a Max. *)
  have fin: "finite (range (λx. ?d x))"
    by (auto simp: dist_zero_ge univ Set.image_Un Set.image_constant[of states])

  have r: "range (λx. ?d x) = ?d ` {..<states} ∪ ?d ` {states..}"
    by force
  hence "Sup (range ?d) = Max (range ?d)"
    using fin cSup_eq_Max by blast
  also have "… = (if states = 0 then (Max (?d ` {states..})) else max (Max (?d ` {..<states})) (Max (?d ` {states..})))"
    using r fin by (auto intro: Max_Un)
  also have "… = (if states = 0 then 0 else max (Max (?d ` {..<states})) 0)"
    using dist_zero_ge
    by (auto simp: Set.image_constant[of states] cSup_eq_Max[symmetric, of "(λ_. 0) ` {states..}"])
  also have "… = (if states = 0 then 0 else (Max (?d ` {..<states})))"
    by (auto intro!: max_absorb1 max_geI)
  finally have 1: "Sup (range ?d) = (if states = 0 then 0 else (Max (?d ` {..<states})))".
  thus ?thesis
    unfolding MDP.vi_policy_bound_error_def vi_policy_bound_error_code_def dist_bfun_def
    using assms v_lookup_map_to_bfun ℒ_code_correct' ℒ_code_correct
    by fastforce
qed

end

(* Concrete instantiation of the MDP_Code locale for a valid MDP `mdp`.
   Transitions and states come from MDP.transitions / MDP.states of
   Rep_Valid_MDP mdp, and the discount from MDP.disc of the same.
   - state map (transition system): IArray operations;
   - action map and decision rule map: RBT_Set / RBT_Map operations;
   - value map: starray operations.
   The `defines` clauses introduce global constants for the listed locale
   operations (presumably so the code generator can refer to them).  The
   locale assumptions are discharged with unfold_locales, using the
   Rep_Valid_MDP fact. *)
global_interpretation VI_Code:
  MDP_Code
  (* state map (transition system) *)
  "IArray.sub" "λn x arr. IArray ((IArray.list_of arr)[n:= x])" "IArray.length" "IArray" "IArray.list_of" "λ_. True"
  
  (* action map *)
  RBT_Set.empty RBT_Map.update RBT_Map.delete Lookup2.lookup Tree2.inorder rbt

  "MDP.transitions (Rep_Valid_MDP mdp)" "MDP.states (Rep_Valid_MDP mdp)" 

  (* value map *)
  starray_get "λi x arr. starray_set arr i x" starray_length starray_of_list "λarr. starray_foldr (λx xs. x # xs) arr []" "λ_. True"

  (* decision rule map *)
  RBT_Set.empty RBT_Map.update RBT_Map.delete Lookup2.lookup Tree2.inorder rbt

  (* discount factor *)
  "MDP.disc (Rep_Valid_MDP mdp)"

  for mdp
  defines VI_code = VI_Code.VI_code
    and vi_policy_bound_error_code = VI_Code.vi_policy_bound_error_code
    and VI_code_aux = VI_Code.VI_code_aux
    and La_code = VI_Code.La_code
    and a_lookup' = VI_Code.a_lookup'
    and d_lookup' = VI_Code.d_lookup'
    and find_policy_state_code_aux' = VI_Code.find_policy_state_code_aux'
    and find_policy_state_code_aux = VI_Code.find_policy_state_code_aux
    and check_dist = VI_Code.check_dist
    and ℒ_code = VI_Code.ℒ_code
    and VI_policy_code = VI_Code.VI_policy_code
    and ℒ_GS_code = VI_Code.ℒ_GS_code
    and v0 = VI_Code.v0
    and entries = M.entries
    and from_list' = M.from_list'
    and from_list = M.from_list
    and vi_find_policy_code = VI_Code.vi_find_policy_code
    and v_map_from_list = VI_Code.v_map_from_list
    and arr_tabulate = starray_Array.arr_tabulate
  (* Locale obligations: discharged from the validity of the represented MDP. *)
  using Rep_Valid_MDP  
  by unfold_locales 
   (fastforce simp: Ball_set_list_all[symmetric] case_prod_beta pmf_of_list_wf_def is_MDP_def RBT_Set.empty_def M.invar_def empty_def M.entries_def M.is_empty_def length_0_conv[symmetric])+

(* Register the definitions of the constants introduced by the interpretation
   above, unfolded to the underlying starray_Array / M definitions, as code
   equations. *)
lemmas arr_tabulate_def[unfolded starray_Array.arr_tabulate_def, code]
lemmas entries_def[unfolded M.entries_def, code]
lemmas from_list'_def[unfolded M.from_list'_def, code]
lemmas from_list_def[unfolded M.from_list_def, code]

(* tabulate f acc upper n applies `update k (f k)` to acc for
   k = n, n + 1, ..., upper - 1, in increasing order of k, and returns the
   accumulator once n >= upper.  NOTE(review): `update` is whichever map update
   is in scope here; from_list'_upt below relates it to M.from_list' - confirm. *)
function tabulate where
"tabulate f acc upper n = (
  if n < upper then tabulate f (update n (f n) acc) upper (Suc n) else acc)"
  by auto
termination
  (* The measure upper - n decreases with every recursive call. *)
  by (relation "Wellfounded.measure (λ(_, _, i,N). i - N)") auto

(* For j <= n': applying `update n' (f n')` to the result of tabulating the
   indices [j, n') gives the same map as tabulating [j, Suc n').
   Proved by induction on the distance n' - j. *)
lemma tabulate_Suc: "j  n'  update n' (f n') (tabulate f m n' j) = tabulate f m (Suc n') j"
proof (induction "n' - j" arbitrary: m n' j)
  case 0
  then show ?case by auto
next
  case (Suc j)
  then show ?case
    by auto
qed

(* Code-unfold rule: from_list' f [0..<n] is rewritten to tabulate f empty n 0.
   Presumably this lets the generated code avoid building the list [0..<n]
   first - confirm. *)
lemma from_list'_upt [code_unfold]: "from_list' f [0..<n] = tabulate f empty n 0"
proof -
  (* Generalisation: folding update over [j..<n] from any start map m equals
     tabulate f m n j; proved by induction on n - j using tabulate_Suc. *)
  have "j  n  foldl (λacc s. update s (f s) acc) m [j..<n] = tabulate f m n j" for m j
  proof (induction "n - j" arbitrary: m n j)
    case 0
    then show ?case by auto
  next
    case (Suc x)
    then obtain n' where n': "n = Suc n'"
      using diff_le_self Suc_le_D by metis
    then show ?case
      using Suc
      by (auto simp del: tabulate.simps simp: n' tabulate_Suc)
  qed
  (* Instantiate with m = empty, j = 0 and unfold from_list'. *)
  thus ?thesis
  unfolding from_list'_def M.from_list'_def
  by auto
qed

end