(* Theory Policy_Iteration_Fin *)

theory Policy_Iteration_Fin
  imports 
    Policy_Iteration 
    MDP_fin 
    Blinfun_To_Matrix
begin

context MDP_nat_disc begin

(* The set DD is finite (registered as a [simp] rule).
   Proof outline: ?set is a set of functions d on nat whose values on
   {0..<states} are tied to the action sets A s and whose values off that range
   are tied to 0.  The union of A s over s < states is finite by A_fin, so ?set
   is finite by finite_set_of_finite_funs.  DD is then related to ?set after
   unfolding is_dec_det_def and using A_outside, and finite_subset concludes.
   NOTE(review): the formulas below appear to have lost non-ASCII symbols
   (quantifier, membership, implication, conjunction, union, subset and the
   subscript marker); presumably the "moreover" step states that DD is a subset
   of ?set -- confirm against the upstream theory. *)
lemma finite_DD[simp]: "finite DD"
proof -
  let ?set = "{d. x :: nat. (x  {0..<states}  d x  (s{0..<states}. A s))  (x  {0..<states}  d x = 0)}"
  (* finiteness of the combined action set over all states below the bound *)
  have "finite (s<states. A s)"
    using A_fin by auto
  hence "finite ?set"
    by (intro finite_set_of_finite_funs) auto
  moreover have "DD  ?set"
    unfolding is_dec_det_def
    using A_outside
    by (auto simp: not_less)
  ultimately show ?thesis
    using finite_subset 
    by auto
qed

(* The relation "u is a deterministic decision rule whose stationary value
   strictly dominates that of the deterministic decision rule v" is finite.
   It is a subset of the set of all pairs of deterministic decision rules (aux),
   so finite_subset applies.
   The conjunctions and the subscript in ν⇩b were missing from the rendered text
   and are restored here; without them the set comprehension does not parse. *)
lemma finite_rel: "finite {(u, v). is_dec_det u ∧ is_dec_det v ∧ ν⇩b (mk_stationary_det u) > ν⇩b (mk_stationary_det v)}"
proof-
  (* all pairs of deterministic decision rules: the finite superset *)
  have aux: "finite {(u, v). is_dec_det u ∧ is_dec_det v}"
    by auto
  show ?thesis
    by (auto intro: finite_subset[OF _ aux])
qed

(* If one application of policy_step leaves the evaluation unchanged
   (policy_eval d = policy_eval (policy_step d)) and d satisfies is_dec_det,
   then d is already a fixed point of policy_step.
   The proof rewrites policy_eval d via L_ν_fix, derives an is_arg_max fact for
   d s at every state s, and concludes by unfolding policy_eval_def,
   policy_step_def and policy_improvement_def.
   NOTE(review): several identifiers below (b, La, rM, νb) look as if they lost
   a subscript marker or a leading symbol, the membership sign in the is_arg_max
   predicate seems to be missing, and the "using" step that quotes an earlier
   fact seems to have lost its surrounding cartouche -- confirm against the
   upstream theory. *)
lemma eval_eq_imp_policy_eq: 
  assumes "policy_eval d = policy_eval (policy_step d)" "is_dec_det d"
  shows "d = policy_step d"
proof -
  (* pointwise form of the first assumption *)
  have "policy_eval d s = policy_eval (policy_step d) s" for s
    using assms 
    by auto
  have "policy_eval d = L (mk_dec_det d) (policy_eval (policy_step d))"
    unfolding policy_eval_def
    using L_ν_fix 
    by (auto simp: assms(1)[symmetric, unfolded policy_eval_def])
  hence "policy_eval d = b (policy_eval d)"
    by (metis L_ν_fix policy_eval_def assms eval_policy_step_L)
  hence "L (mk_dec_det d) (policy_eval d) s = b (policy_eval d) s" for s
    using policy_eval d = L (mk_dec_det d) (policy_eval (policy_step d)) assms(1) by auto
  (* d s is an arg-max of the one-step lookahead at every state s *)
  hence "is_arg_max (λa. La a (νb (mk_stationary (mk_dec_det d))) s) (λa. a  A s) (d s)" for s
    unfolding L_eq_La_det
    unfolding policy_eval_def b.rep_eq ℒ_eq_SUP_det SUP_step_det_eq
    using assms(2) is_dec_det_def La_le
    by (auto simp del: νb.rep_eq simp: νb.rep_eq[symmetric] 
        intro!: SUP_is_arg_max boundedI[of _ "rM + l * norm _"] bounded_imp_bdd_above)
  thus ?thesis
    unfolding policy_eval_def policy_step_def policy_improvement_def
    by auto
qed


(* Termination of policy_iteration.
   The termination relation orders deterministic decision rules by strict
   increase of the value of the induced stationary policy:
     - well-foundedness: the relation is finite (finite_rel) and acyclic because
       it is induced by a strict order (finite_acyclic_wf, acyclicI_order);
     - decrease: a recursive call happens only when d is a deterministic decision
       rule and policy_step d differs from d (assumption h).  policy_eval_mon
       gives a weak increase of the value, and eval_eq_imp_policy_eq rules out
       equality, so the increase is strict.
   The logical connectives, the membership/order signs and the subscripts in
   D⇩D and ν⇩b were missing from the rendered text and are restored here.  The
   disjunction in h is forced by the proof: the final step needs both
   "d is not x" and "is_dec_det d" from h. *)
termination policy_iteration
proof (relation "{(u, v). u ∈ D⇩D ∧ v ∈ D⇩D ∧ ν⇩b (mk_stationary_det u) > ν⇩b (mk_stationary_det v)}")
  show "wf {(u, v). u ∈ D⇩D ∧ v ∈ D⇩D ∧ ν⇩b (mk_stationary_det v) < ν⇩b (mk_stationary_det u)}"
    using finite_rel
    by (auto intro!: finite_acyclic_wf acyclicI_order)
next
  fix d x
  (* h: x is the improved rule, and the recursion guard did not stop at d *)
  assume h: "x = policy_step d" "¬ (d = x ∨ ¬ is_dec_det d)"
  (* one improvement step never lowers the value *)
  have "is_dec_det d ⟹ ν⇩b (mk_stationary_det d) ≤ ν⇩b (mk_stationary_det (policy_step d))"
    using policy_eval_mon
    by (simp add: policy_eval_def)
  (* ... and strictly raises it whenever the rule actually changes *)
  hence "is_dec_det d ⟹ d ≠ policy_step d ⟹
    ν⇩b (mk_stationary_det d) < ν⇩b (mk_stationary_det (policy_step d))"
    using eval_eq_imp_policy_eq policy_eval_def
    by (force intro!: order.not_eq_order_implies_strict)
  thus "(x, d) ∈ {(u, v). u ∈ D⇩D ∧ v ∈ D⇩D ∧ ν⇩b (mk_stationary_det v) < ν⇩b (mk_stationary_det u)}"
    using is_dec_det_pi policy_step_def h
    by auto
qed

(* policy_iteration maps a deterministic decision rule to a deterministic
   decision rule.  Proved by the induction rule of policy_iteration, using
   is_dec_det_pi for the single improvement step.
   The membership sign, the implication and the subscript in D⇩D were missing
   from the rendered statement and are restored here. *)
lemma is_dec_det_pi': "d ∈ D⇩D ⟹ is_dec_det (policy_iteration d)"
  using is_dec_det_pi
  by (induction d rule: policy_iteration.induct) (auto simp: Let_def policy_step_def)

(* The result of policy_iteration is a fixed point of policy_step: one further
   improvement step changes nothing.  Registered as a [simp] rule.
   Proved by the induction rule of policy_iteration, using is_dec_det_pi.
   The membership sign, the implication and the subscript in D⇩D were missing
   from the rendered statement and are restored here. *)
lemma pi_pi[simp]: "d ∈ D⇩D ⟹ policy_step (policy_iteration d) = policy_iteration d"
  using is_dec_det_pi
  by (induction d rule: policy_iteration.induct) (auto simp: policy_step_def Let_def)

(* Correctness of policy iteration: started from a deterministic decision rule,
   the stationary policy built from the returned rule attains the optimal value.
   Proved by the induction rule of policy_iteration; the base case uses
   policy_step_eq_imp_opt, and is_dec_det_pi' keeps the rule deterministic.
   policy_iteration.simps is removed from the simpset so the recursive equation
   is not unfolded indefinitely.
   The membership sign, the implication and the subscripts in D⇩D, ν⇩b and
   ν⇩b_opt were missing from the rendered statement and are restored here. *)
lemma policy_iteration_correct:
  "d ∈ D⇩D ⟹ ν⇩b (mk_stationary_det (policy_iteration d)) = ν⇩b_opt"
  by (induction d rule: policy_iteration.induct)
    (fastforce intro!: policy_step_eq_imp_opt is_dec_det_pi' simp del: policy_iteration.simps)

(* νb p s equals 0 under a side condition relating s and states; obtained from
   ν_zero_notin after unfolding νb.rep_eq.
   NOTE(review): the relation between s and states and the implication arrow
   appear to be missing from the statement text (presumably s is outside the
   range below states) -- confirm against the upstream theory. *)
lemma νb_zero_notin:  "s  states  νb p s = 0"
  using  ν_zero_notin unfolding νb.rep_eq by auto  

(* r_decb d s equals 0 under a side condition relating s and states; follows
   from reward_zero_outside.
   NOTE(review): the relation between s and states and the implication arrow
   appear to be missing from the statement text (presumably s is outside the
   range below states) -- confirm against the upstream theory. *)
lemma r_decb_zero_notin:  "s  states  r_decb d s = 0"
  using reward_zero_outside
  by auto

(* Closed form of the value of a stationary policy: νb (mk_stationary d) is
   invL (id_blinfun - l *R 𝒫1 d) applied to r_decb d.
   This is a direct restatement of ν_stationary_inv (closed with the
   single-dot proof). *)
lemma νb_eq_inv: "νb (mk_stationary d) = invL (id_blinfun - l *R 𝒫1 d) (r_decb d)"
  using ν_stationary_inv.

(* νb (mk_stationary d) is unchanged when it is replaced by 0 at all indices
   that do not satisfy i < states (via bfun_if); this uses νb_zero_notin. *)
lemma νb_eq_bfun_if: "νb (mk_stationary d) = bfun_if (λi. i < states) (νb (mk_stationary d)) 0"
  using νb_zero_notin by (auto simp: bfun_if.rep_eq)

(* Matrix form of the policy-evaluation equation: multiplying the matrix
   (1m states) - l m (blinfun_to_mat states states (𝒫1 d)) with the vector of
   νb (mk_stationary d) gives the vector of r_decb d.
   The calculation moves the matrix expression step by step inside
   blinfun_to_mat / bfun_to_vec and then uses the fixed-point characterisation.
   NOTE(review): the matrix scalar-multiplication operator (shown as "m") and
   the left-hand-side abbreviation of each "also have" step appear to have lost
   symbols in this rendering -- confirm against the upstream theory. *)
lemma νb_vec_aux: "((1m states) - l m (blinfun_to_mat states states (𝒫1 d))) *v bfun_to_vec states (νb (mk_stationary d)) = bfun_to_vec states (r_decb d)"
proof -
  let ?to_mat = "blinfun_to_mat states states"
  let ?to_vec = "bfun_to_vec states"

  (* pull the scalar l inside blinfun_to_mat *)
  have "((1m states) - l m (?to_mat (𝒫1 d))) *v ?to_vec (νb (mk_stationary d)) =
      ((1m states) - ?to_mat (l *R (𝒫1 d))) *v ?to_vec (νb (mk_stationary d))"
    using blinfun_to_mat_scale by fastforce
  (* the identity matrix is the matrix of id_blinfun *)
  also have " = (?to_mat id_blinfun - ?to_mat (l *R (𝒫1 d))) *v ?to_vec (νb (mk_stationary d))"
    using blinfun_to_mat_id by presburger
  (* difference of matrices is the matrix of the difference *)
  also have " = ?to_mat (id_blinfun - l *R 𝒫1 d) *v ?to_vec (νb (mk_stationary d))"
    using blinfun_to_mat_sub by presburger
  (* matrix-vector product is the vector of the operator application *)
  also have " = ?to_vec ((id_blinfun - l *R 𝒫1 d) ((νb (mk_stationary d))))"
    unfolding blinfun_to_mat_mult using νb_eq_bfun_if by auto
  (* fixed-point equation for the stationary value *)
  also have " = ?to_vec (r_decb d)"
    by (metis L_ν_fix_iff L_def blinfun.diff_left blinfun.scaleR_left diff_eq_eq id_blinfun.rep_eq)
  finally show ?thesis.
qed

(* The sequence of powers of the operator l *R 𝒫1 d is summable.
   Derived from summable_inv_Q together with the norm facts norm_𝒫1 and
   norm_𝒫1_l_less. *)
lemma summable_geom_𝒫1: "summable (λk. ((l *R 𝒫1 d)^^k))"
  using summable_inv_Q norm_𝒫1
  by (metis add_diff_cancel_left' diff_add_cancel norm_𝒫1_l_less)

(* Applied version of summable_geom_𝒫1: for every v, the sequence of the powers
   of l *R 𝒫1 d applied to v is summable.
   Obtained by unfolding summable_def and sums_def and pushing the limit through
   operator application (tendsto_blinfun_apply, blinfun.sum_left). *)
lemma summable_geom_𝒫1': "summable (λk. ((l *R 𝒫1 d)^^k) v)" for v
  using  summable_geom_𝒫1 tendsto_blinfun_apply
  unfolding summable_def sums_def
  by (fastforce simp: blinfun.sum_left)
  
(* Pointwise version of summable_geom_𝒫1': for every v and s, the sequence of
   the powers of l *R 𝒫1 d applied to v and evaluated at s is summable.
   Obtained by unfolding summable_def and sums_def and pushing the limit through
   evaluation (bfun_tendsto_apply_bfun, sum_apply_bfun). *)
lemma summable_geom_𝒫1'': "summable (λk. ((l *R 𝒫1 d)^^k) v s)" for v s
  using summable_geom_𝒫1' bfun_tendsto_apply_bfun
    unfolding summable_def sums_def
    by (fastforce simp: sum_apply_bfun)

(* Closure of the transition kernel, stated with a plain bound: from a state
   s < states, every state j in the support of K (s, a) is again below states.
   Derived from K_closed by rewriting membership in the interval
   (atLeastLessThan_iff).
   The two implication arrows and the membership sign were missing from the
   rendered statement and are restored here; without them the three
   propositions were merely juxtaposed and the statement did not parse. *)
lemma K_closed': "s<states ⟹ j ∈ set_pmf (K (s, a)) ⟹ j < states"
  by (meson K_closed atLeastLessThan_iff basic_trans_rules(31))

(* The value of (l *R 𝒫1 d) v at an index j < states depends only on the values
   of v at indices below states: if v and v' agree there, the two results agree
   at j.
   The proof unfolds 𝒫1.rep_eq and K_st_def and compares the two integrals
   almost everywhere, using K_closed' (instantiated with j < states) to keep the
   support below states.
   NOTE(review): the first assumption appears to have lost its quantifier and
   implication arrow in this rendering -- confirm against the upstream theory. *)
lemma 𝒫1_indep:
  assumes "(i. i < states  apply_bfun v i = apply_bfun v' i)" "j < states"
  shows "(l *R 𝒫1 d) v j = (l *R 𝒫1 d) v' j"
  using assms K_closed'[OF assms(2)]
  by (auto simp: blinfun.scaleR_left 𝒫1.rep_eq K_st_def intro!: integral_cong_AE AE_pmfI)

(* Same independence property as 𝒫1_indep, for the inverse operator
   invL (id_blinfun - l *R 𝒫1 d): at an index j < states its result depends only
   on the input's values at indices below states.
   Proof: first show the property for every power of l *R 𝒫1 d by induction on
   the exponent (using 𝒫1_indep), then transfer it to the inverse through its
   infinite-sum representation (invL_inf_sum) and the summability lemmas.
   NOTE(review): the first assumption appears to have lost its quantifier and
   implication arrow in this rendering -- confirm against the upstream theory. *)
lemma invL_indep: 
  assumes "i. i < states  apply_bfun v i = apply_bfun v' i" "j < states" 
  shows "((invL (id_blinfun - l *R 𝒫1 d)) v) j = ((invL (id_blinfun - l *R 𝒫1 d)) v') j"
proof -
  (* independence for each power of the operator *)
  have "((l *R 𝒫1 d) ^^ n) v j = ((l *R 𝒫1 d) ^^ n) v' j" for n
    using assms 𝒫1_indep by (induction n arbitrary: j v v') fastforce+
  thus ?thesis
    using summable_geom_𝒫1 summable_geom_𝒫1'
    by (auto simp: invL_inf_sum blinfun_apply_suminf[symmetric] suminf_apply_bfun)
qed

(* Vector form of the value of a stationary policy: the vector of
   νb (mk_stationary d) is inverse_mat of
   (1m states) - l m (blinfun_to_mat states states (𝒫1 d)) multiplied with the
   vector of r_decb d.
   The calculation starts from the operator form νb_eq_inv, restricts the reward
   with bfun_if (justified by r_decb_zero_notin), converts operator application
   into a matrix-vector product, and exchanges the operator inverse for
   inverse_mat using the independence lemmas.
   NOTE(review): the matrix scalar-multiplication operator (shown as "m") and
   the left-hand-side abbreviation of each "also have" step appear to have lost
   symbols in this rendering -- confirm against the upstream theory. *)
lemma vec_νb: "bfun_to_vec states (νb (mk_stationary d)) = 
    inverse_mat ((1m states) - l m (blinfun_to_mat states states (𝒫1 d))) *v (bfun_to_vec states (r_decb d))"
proof -
  (* operator form of the value *)
  have "bfun_to_vec states (νb (mk_stationary d)) = bfun_to_vec states (invL (id_blinfun - l *R 𝒫1 d) (r_decb d))"
    using νb_eq_inv by force
  (* the reward may be cut off outside the indices below states *)
  also have " = bfun_to_vec states (invL (id_blinfun - l *R 𝒫1 d) (bfun_if (λi. i < states) (r_decb d) 0))"
    using r_decb_zero_notin
    by (subst bfun_if_eq) auto    
  (* operator application as matrix-vector product *)
  also have " = blinfun_to_mat states states (invL (id_blinfun - l *R 𝒫1 d)) *v (bfun_to_vec states (r_decb d))"
    using blinfun_to_mat_mult..
  (* matrix of the inverse operator is the inverse of the matrix *)
  also have " = inverse_mat (blinfun_to_mat states states (id_blinfun - l *R 𝒫1 d)) *v (bfun_to_vec states (r_decb d))"
       using invL_indep 𝒫1_indep
       by (fastforce simp add: inverse_blinfun_to_mat  invertibleL_inf_sum blinfun.diff_left)+
  finally show ?thesis
    using blinfun_to_mat_id blinfun_to_mat_scale blinfun_to_mat_sub by presburger
qed


(* The matrix (1m states) - l m (blinfun_to_mat states states (𝒫1 d)) is
   invertible.
   Proof: the matrix of the operator id_blinfun - l *R 𝒫1 d is invertible by
   inverse_blinfun_to_mat(2), whose side conditions are discharged with
   invertibleL_inf_sum and the independence lemmas; the matrix is then rewritten
   with blinfun_to_mat_id, blinfun_to_mat_sub and blinfun_to_mat_scale.
   NOTE(review): the matrix scalar-multiplication operator (shown as "m")
   appears to have lost symbols in this rendering -- confirm upstream. *)
lemma invertible_νb_mat: "invertible_mat ((1m states) - l m (blinfun_to_mat states states (𝒫1 d)))"
proof -
  have "invertible_mat (blinfun_to_mat states states ((id_blinfun - l *R 𝒫1 d)))"
    using 𝒫1_indep invL_indep 
    by (fastforce simp: invertibleL_inf_sum blinfun.diff_left intro!: inverse_blinfun_to_mat(2))+
  thus ?thesis
    by (auto simp: blinfun_to_mat_id blinfun_to_mat_sub blinfun_to_mat_scale)
qed

(* Congruence rule for Matrix.mat: two n-by-m matrices built from entry
   functions f and g are equal whenever f and g agree on all in-range index
   pairs (i < n, j < m).
   The meta-quantifier and the two implication arrows were missing from the
   rendered assumption and are restored here; without them the assumption did
   not parse. *)
lemma mat_cong:
  assumes "(⋀i j. i < n ⟹ j < m ⟹ f i j = g i j)"
  shows "Matrix.mat n m (λ(i, j). f i j) = Matrix.mat n m (λ(i,j). g i j)"
  using assms by auto

(* The states-by-states matrix with entry pmf (K (s, d s)) s' at position
   (s, s') equals blinfun_to_mat states states (𝒫1 (mk_dec_det d)).
   Key step: for in-range s and s', the probability mass at s' is written as
   the expectation under K (s, d s) of the indicator function of s'; the claim
   then follows by unfolding blinfun_to_mat_def, 𝒫1.rep_eq, K_st_def and
   mk_dec_det_def. *)
lemma 𝒫1_mat: "(Matrix.mat states states (λ(s, s'). pmf (K (s, d s)) s')) = blinfun_to_mat states states (𝒫1 (mk_dec_det d))"
proof -
  (* pmf value as expectation of an indicator *)
  have "pmf (K (s, d s)) s' = measure_pmf.expectation (K (s, d s)) (λk. if s' = k then 1 else 0)" 
      if "s < states" "s' < states" for s s'
    by (auto simp: integral_measure_pmf_real[of "{s'}"] split: if_splits)
  thus ?thesis
    by (auto simp: blinfun_to_mat_def 𝒫1.rep_eq K_st_def mk_dec_det_def bind_return_pmf)
qed

(* Explicit vector form of the value of the stationary policy built from a
   deterministic rule d: the matrix is spelled out with entries
   pmf (K (s, d s)) s' and the reward vector with entries r (i, d i).
   Obtained from vec_νb by rewriting the matrix with 𝒫1_mat and unfolding
   bfun_to_vec_def.
   NOTE(review): the matrix scalar-multiplication operator (shown as "m")
   appears to have lost symbols in this rendering -- confirm upstream. *)
lemma vec_νb': "bfun_to_vec states (νb (mk_stationary_det d)) = 
  inverse_mat ((1m states) - l m (Matrix.mat states states (λ(s, s'). pmf (K (s, d s)) s'))) *v 
  (vec states (λi. r (i, d i)))"
  unfolding vec_νb using 𝒫1_mat by (auto simp: bfun_to_vec_def)

(* Componentwise version of vec_νb': for s < states, the value
   νb (mk_stationary_det d) s is the s-th component of the vector obtained by
   multiplying the inverse matrix with the reward vector.
   Follows from vec_νb' by unfolding bfun_to_vec_def and indexing (index_vec).
   NOTE(review): the matrix scalar-multiplication operator (shown as "m")
   appears to have lost symbols in this rendering -- confirm upstream. *)
lemma vec_νb'': 
  assumes "s < states"
  shows "(νb (mk_stationary_det d)) s = 
  (inverse_mat ((1m states) - l m (Matrix.mat states states (λ(s, s'). pmf (K (s, d s)) s'))) *v 
  (vec states (λi. r (i, d i)))) $ s"
  using vec_νb' assms unfolding bfun_to_vec_def by (metis index_vec)

(* Invertibility of the explicit matrix
   1m states - l m Matrix.mat states states (entries pmf (K (s, d s)) y).
   Follows from invertible_νb_mat after rewriting the matrix with 𝒫1_mat.
   NOTE(review): the matrix scalar-multiplication operator (shown as "m")
   appears to have lost symbols in this rendering -- confirm upstream. *)
lemma invertible_νb_mat': 
  "invertible_mat (1m states - l m Matrix.mat states states (λ(s, y). pmf (K (s, d s)) y))"
  using invertible_νb_mat 𝒫1_mat by presburger

(* The vector obtained by multiplying the inverse matrix with the reward vector
   has dimension states.
   Uses inverse_mat_dims(2) together with invertibility of the matrix
   (invertible_νb_mat').
   NOTE(review): the matrix scalar-multiplication operator (shown as "m")
   appears to have lost symbols in this rendering -- confirm upstream. *)
lemma dim_vec_νb: "dim_vec (inverse_mat ((1m states) - 
  l m (Matrix.mat states states (λ(s, s'). pmf (K (s, d s)) s'))) *v 
  (vec states (λi. r (i, d i)))) = states"
  by (simp add: inverse_mat_dims(2) invertible_νb_mat')

end
end