Theory Value_Iteration

(* Author: Maximilian Schäffeler *)

theory Value_Iteration
  imports "MDP-Rewards.MDP_reward"
begin                    

context MDP_att_ℒ
begin

section ‹Value Iteration›
text ‹
In the previous sections we derived that repeated application of @{const "ℒ⇩b"} to any bounded 
function from states to the reals converges to the optimal value of the MDP @{const "ν⇩b_opt"}.

We can turn this procedure into an algorithm that computes not only an approximation of 
@{const "ν⇩b_opt"} but also a policy that is arbitrarily close to optimal.

Most of the proofs rely on the assumption that the supremum in @{const "ℒ⇩b"} can always be attained.
›
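
text ‹
As a complement to the formal development, the comment below contains a minimal executable 
sketch of the algorithm in Python for a finite MDP. The names P (transition probabilities), 
r (rewards) and the list representation are illustrative assumptions and not part of the 
formalization.
›
(*
def bellman(v, P, r, l):
    # One application of the Bellman optimality operator (ℒ⇩b in the theory):
    # in every state, maximize expected one-step reward plus discounted value.
    return [max(r[s][a] + l * sum(p * v[t] for t, p in enumerate(P[s][a]))
                for a in range(len(r[s])))
            for s in range(len(v))]

def dist(u, v):
    # Supremum distance between two value vectors.
    return max(abs(x - y) for x, y in zip(u, v))

def value_iteration(eps, v, P, r, l):
    # Iterate the Bellman operator until 2 * l * dist(v, ℒ⇩b v) < eps * (1 - l),
    # mirroring the stopping criterion of the function defined below.
    while True:
        w = bellman(v, P, r, l)
        if 2 * l * dist(v, w) < eps * (1 - l) or eps <= 0:
            return w
        v = w
*)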

text ‹
The following lemma shows that the relation we use to prove termination of the value iteration 
algorithm decreases in each step.
In essence, the distance of the estimate to the optimal value decreases by a factor of at 
least @{term l} per iteration.›

abbreviation "term_measure  (λ(eps, v). LEAST n. (2 * l * dist ((b^^(Suc n)) v) ((b^^n) v) < eps * (1-l)))"

lemma Least_Suc_less:
  assumes "n. P n" "¬P 0"
  shows "Least (λn. P (Suc n)) < Least P"
  using assms by (auto simp: Least_Suc)

function value_iteration :: "real ⇒ ('s ⇒⇩b real) ⇒ ('s ⇒⇩b real)" where
  "value_iteration eps v =
  (if 2 * l * dist v (ℒ⇩b v) < eps * (1-l) ∨ eps ≤ 0 then ℒ⇩b v else value_iteration eps (ℒ⇩b v))"
  by auto
termination
proof (relation "Wellfounded.measure term_measure")
  fix eps v
  assume h: "¬ (2 * l * dist v (b v) < eps * (1 - l)  eps  0)"
  show "((eps, b v), eps, v)  Wellfounded.measure term_measure"
  proof -
    have "(λn. dist ((b ^^ Suc n) v) ((b ^^ n) v))  0" 
      using dist_ℒb_tendsto
      by (auto simp: dist_commute)
    hence *: "n. dist ((b ^^ Suc n) v) ((b ^^ n) v) < eps" if "eps > 0" for eps
      unfolding LIMSEQ_def using that by auto
    have **: "0 < l * 2" if "0  l"
      using zero_le_disc that by linarith
    hence "(LEAST n. (2 * l) * dist ((b ^^ (Suc (Suc n))) v) ((b ^^ (Suc n)) v) < eps * (1 - l)) < 
      (LEAST n. (2 * l) * dist ((b ^^ Suc n) v) ((b ^^ n) v) <  eps * (1 - l))" if "0  l"
        using h *[of "eps * (1-l) / (2 * l)"] that
        by (fastforce simp: ** algebra_simps dist_commute pos_less_divide_eq intro!: Least_Suc_less)
      thus ?thesis
      using h by (cases "l = 0") (auto simp: funpow_swap1)
  qed
qed auto

text ‹
The distance between an estimate for the value and the optimal value can be bounded with respect to 
the distance between the estimate and the result of applying @{const "ℒ⇩b"} to it.
›
(* 2nd to last inequality in the proof of Thm 6.3.1 *)
lemma contraction_ℒ_dist: "(1 - l) * dist v ν⇩b_opt ≤ dist v (ℒ⇩b v)"
  using contraction_dist contraction_ℒ disc_lt_one zero_le_disc by fastforce
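
text ‹
Informally, the lemma combines the triangle inequality with the contraction property of 
@{const "ℒ⇩b"} (a proof sketch; the formal argument is above): 
dist v ν⇩b_opt ≤ dist v (ℒ⇩b v) + dist (ℒ⇩b v) ν⇩b_opt ≤ dist v (ℒ⇩b v) + l * dist v ν⇩b_opt, 
since ν⇩b_opt is the fixed point of ℒ⇩b; rearranging yields the stated bound.
›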

lemma dist_ℒ⇩b_opt_eps:
  assumes "eps > 0" "2 * l * dist v (ℒ⇩b v) < eps * (1-l)"
  shows "2 * dist (ℒ⇩b v) ν⇩b_opt < eps"
proof -
  have "2 * l * dist v ν⇩b_opt * (1 - l) ≤ 2 * l * dist v (ℒ⇩b v)"
    using contraction_ℒ_dist by (simp add: mult_left_mono mult.commute)
  hence "2 * l * dist v ν⇩b_opt * (1 - l) < eps * (1-l)"
    using assms(2) by linarith
  hence "2 * l * dist v ν⇩b_opt < eps"
    by force
  thus "2 * dist (ℒ⇩b v) ν⇩b_opt < eps"
    using contraction_ℒ[of v ν⇩b_opt] by auto
qed

lemma dist_ℒ⇩b_lt_dist_opt: "dist v (ℒ⇩b v) ≤ 2 * dist v ν⇩b_opt"
proof -
  have le1: "dist v (ℒ⇩b v) ≤ dist v ν⇩b_opt + dist (ℒ⇩b v) ν⇩b_opt"
    by (simp add: dist_triangle dist_commute)
  have le2: "dist (ℒ⇩b v) ν⇩b_opt ≤ l * dist v ν⇩b_opt"
    using ℒ⇩b_opt contraction_ℒ by metis
  show ?thesis
    using mult_right_mono[of l 1] disc_lt_one 
    by (fastforce intro!: order.trans[OF le2] order.trans[OF le1])
qed

text ‹
The estimates above allow us to bound the error of @{const value_iteration}: 
the returned estimate lies within @{term "eps / 2"} of the optimal value.
›
declare value_iteration.simps[simp del]

lemma value_iteration_error: 
  assumes "eps > 0"
  shows "2 * dist (value_iteration eps v) νb_opt < eps"
  using assms dist_ℒb_opt_eps value_iteration.simps
  by (induction eps v rule: value_iteration.induct) auto

text ‹
After the value iteration terminates, one can easily obtain a stationary deterministic 
epsilon-optimal policy.

Such a policy does not exist in general; attainment of the supremum in @{const "ℒ⇩b"} is required.
›
definition "find_policy (v :: 's b real) s = arg_max_on (λa. La a v s) (A s)"

definition "vi_policy eps v = find_policy (value_iteration eps v)"

abbreviation "vi u n  (b ^^ n) u"

lemma ℒ⇩b_iter_mono:
  assumes "u ≤ v" shows "vi u n ≤ vi v n"
  using assms ℒ⇩b_mono by (induction n) auto
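
text ‹
The next two lemmas show that the ordering of consecutive iterates persists: if one application 
of @{const "ℒ⇩b"} moves the estimate down (resp. up), then all later iterates stay ordered the 
same way. Hence value iteration started from such an estimate yields a monotone sequence.
›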

lemma 
  assumes "vi v (Suc n)  vi v n" 
  shows "vi v (Suc n + m)  vi v (n + m)"
proof -
  have "vi v (Suc n + m) = vi (vi v (Suc n)) m"
    by (simp add: Groups.add_ac(2) funpow_add funpow_swap1)
  also have "...  vi (vi v n) m"
    using b_iter_mono[OF assms] by auto
  also have "... = vi v (n + m)"
    by (simp add: add.commute funpow_add)
  finally show ?thesis .
qed


lemma 
  assumes "vi v n  vi v (Suc n)" 
  shows "vi v (n + m)  vi v (Suc n + m)"
proof -
  have "vi v (n + m)  vi (vi v n) m"
    by (simp add: Groups.add_ac(2) funpow_add funpow_swap1)
  also have "  vi v (Suc n + m)"
    using ℒ⇩b_iter_mono[OF assms] by (auto simp only: add.commute funpow_add o_apply)
  finally show ?thesis .
qed

lemma "(λn. dist (vi v (Suc n)) (vi v n))  0"
  using dist_ℒb_tendsto[of v] by (auto simp: dist_commute)

end

context MDP_att_ℒ 
begin

lemma is_arg_max_find_policy: "is_arg_max (λd. L⇩a d (apply_bfun v) s) (λd. d ∈ A s) (find_policy v s)"
  using Sup_att
  by (simp add: find_policy_def arg_max_on_def arg_max_def someI_ex max_L_ex_def has_arg_max_def)

text ‹
The error of the resulting policy is bounded by the distance from its value to the value computed 
by the value iteration plus the error in the value iteration itself.
We show that both are at most @{term "eps / 2"} when the algorithm terminates.
›
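
text ‹
Concretely, for the policy value ν⇩b p and the final estimate ℒ⇩b v, the triangle inequality gives 
dist (ν⇩b p) ν⇩b_opt ≤ dist (ν⇩b p) (ℒ⇩b v) + dist (ℒ⇩b v) ν⇩b_opt < eps / 2 + eps / 2 = eps, 
which is exactly the decomposition used in the two lemmas below.
›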
lemma find_policy_dist_ℒ⇩b:
  assumes "eps > 0" "2 * l * dist v (ℒ⇩b v) < eps * (1-l)"
  shows "2 * dist (ν⇩b (mk_stationary_det (find_policy (ℒ⇩b v)))) (ℒ⇩b v) ≤ eps"
proof -
  let ?d = "mk_dec_det (find_policy (ℒ⇩b v))"
  let ?p = "mk_stationary ?d"
  have L_eq_ℒ⇩b: "L (mk_dec_det (find_policy v)) v = ℒ⇩b v" for v
    by (auto simp: L_eq_L⇩a_det ℒ⇩b_eq_argmax_L⇩a[OF is_arg_max_find_policy])
  have "dist (ν⇩b ?p) (ℒ⇩b v) = dist (L ?d (ν⇩b ?p)) (ℒ⇩b v)"
    using L_ν_fix by force
  also have "… ≤ dist (L ?d (ν⇩b ?p)) (ℒ⇩b (ℒ⇩b v)) + dist (ℒ⇩b (ℒ⇩b v)) (ℒ⇩b v)"
    using dist_triangle by blast
  also have "… = dist (L ?d (ν⇩b ?p)) (L ?d (ℒ⇩b v)) + dist (ℒ⇩b (ℒ⇩b v)) (ℒ⇩b v)"
    by (auto simp: L_eq_ℒ⇩b)
  also have "… ≤ l * dist (ν⇩b ?p) (ℒ⇩b v) + l * dist (ℒ⇩b v) v"
    using contraction_ℒ contraction_L by (fastforce intro!: add_mono)
  finally have aux: "dist (ν⇩b ?p) (ℒ⇩b v) ≤ l * dist (ν⇩b ?p) (ℒ⇩b v) + l * dist (ℒ⇩b v) v" .
  hence "dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ l * dist (ℒ⇩b v) v"
    by (auto simp: algebra_simps)
  hence *: "2 * dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ 2 * l * dist (ℒ⇩b v) v"
    using zero_le_disc mult_left_mono by auto
  hence "2 * dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ eps * (1 - l)"
    using assms by (fastforce simp: dist_commute intro!: order.trans[OF *])
  thus "2 * dist (ν⇩b ?p) (ℒ⇩b v) ≤ eps"
    by auto
qed

lemma find_policy_error_bound:
  assumes "eps > 0" "2 * l * dist v (b v) < eps * (1-l)"
  shows "dist (νb (mk_stationary_det (find_policy (b v)))) νb_opt < eps"
proof -
  let ?p = "mk_stationary_det (find_policy (b v))"
  have "dist (νb ?p) νb_opt  dist (νb ?p) (b v) + dist (b v) νb_opt"
    using dist_triangle by blast  
  thus ?thesis
    using find_policy_dist_ℒb[OF assms] dist_ℒb_opt_eps[OF assms] by simp
qed

lemma vi_policy_opt:
  assumes "0 < eps"
  shows "dist (νb (mk_stationary_det (vi_policy eps v))) νb_opt < eps"
  unfolding vi_policy_def 
  using assms
proof (induction eps v rule: value_iteration.induct)
  case (1 v)
  then show ?case
    using find_policy_error_bound by (subst value_iteration.simps) auto
qed

lemma lemma_6_3_1_d:
  assumes "eps > 0" "2 * l * dist (vi v (Suc n)) (vi v n) < eps * (1-l)"
  shows "2 * dist (vi v (Suc n)) νb_opt < eps"
  using dist_ℒb_opt_eps assms by (simp add: dist_commute)
end

context MDP_act_disc begin
                 
definition "find_policy' (v :: 's b real) s = arb_act (opt_acts v s)"

definition "vi_policy' eps v = find_policy' (value_iteration eps v)"

lemma is_arg_max_find_policy': "is_arg_max (λd. L⇩a d (apply_bfun v) s) (λd. d ∈ A s) (find_policy' v s)"
  using is_opt_act_some by (auto simp: find_policy'_def is_opt_act_def)

lemma find_policy'_dist_ℒ⇩b:
  assumes "eps > 0" "2 * l * dist v (ℒ⇩b v) < eps * (1-l)"
  shows "2 * dist (ν⇩b (mk_stationary_det (find_policy' (ℒ⇩b v)))) (ℒ⇩b v) ≤ eps"
proof -
  let ?d = "mk_dec_det (find_policy' (ℒ⇩b v))"
  let ?p = "mk_stationary ?d"
  have L_eq_ℒ⇩b: "L (mk_dec_det (find_policy' v)) v = ℒ⇩b v" for v
    by (auto simp: L_eq_L⇩a_det ℒ⇩b_eq_argmax_L⇩a[OF is_arg_max_find_policy'])
  have "dist (ν⇩b ?p) (ℒ⇩b v) = dist (L ?d (ν⇩b ?p)) (ℒ⇩b v)"
    using L_ν_fix by force
  also have "… ≤ dist (L ?d (ν⇩b ?p)) (ℒ⇩b (ℒ⇩b v)) + dist (ℒ⇩b (ℒ⇩b v)) (ℒ⇩b v)"
    using dist_triangle by blast
  also have "… = dist (L ?d (ν⇩b ?p)) (L ?d (ℒ⇩b v)) + dist (ℒ⇩b (ℒ⇩b v)) (ℒ⇩b v)"
    by (auto simp: L_eq_ℒ⇩b)
  also have "… ≤ l * dist (ν⇩b ?p) (ℒ⇩b v) + l * dist (ℒ⇩b v) v"
    using contraction_ℒ contraction_L by (fastforce intro!: add_mono)
  finally have aux: "dist (ν⇩b ?p) (ℒ⇩b v) ≤ l * dist (ν⇩b ?p) (ℒ⇩b v) + l * dist (ℒ⇩b v) v" .
  hence "dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ l * dist (ℒ⇩b v) v"
    by (auto simp: algebra_simps)
  hence *: "2 * dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ 2 * l * dist (ℒ⇩b v) v"
    using zero_le_disc mult_left_mono by auto
  hence "2 * dist (ν⇩b ?p) (ℒ⇩b v) * (1 - l) ≤ eps * (1 - l)"
    using assms by (fastforce simp: dist_commute intro!: order.trans[OF *])
  thus "2 * dist (ν⇩b ?p) (ℒ⇩b v) ≤ eps"
    by auto
qed

lemma find_policy'_error_bound:
  assumes "eps > 0" "2 * l * dist v (b v) < eps * (1-l)"
  shows "dist (νb (mk_stationary_det (find_policy' (b v)))) νb_opt < eps"
proof -
  let ?p = "mk_stationary_det (find_policy' (b v))"
  have "dist (νb ?p) νb_opt  dist (νb ?p) (b v) + dist (b v) νb_opt"
    using dist_triangle by blast  
  thus ?thesis
    using find_policy'_dist_ℒb[OF assms] dist_ℒb_opt_eps[OF assms] by simp
qed

lemma vi_policy'_opt:
  assumes "eps > 0" "l > 0"
  shows "dist (νb (mk_stationary_det (vi_policy' eps v))) νb_opt < eps"
  unfolding vi_policy'_def 
  using assms
proof (induction eps v rule: value_iteration.induct)
  case (1 eps v)
  then show ?case
    using find_policy'_error_bound by (auto simp: value_iteration.simps[of _ v])
qed

end
end