Theory Policy_Iteration

(* Author: Maximilian Schäffeler *)

theory Policy_Iteration
  imports "MDP-Rewards.MDP_reward"

begin

section ‹Policy Iteration›
text ‹
The Policy Iteration algorithm provides another way to find optimal policies under the expected 
total discounted reward criterion.
It differs from Value Iteration in that it continuously improves an initial guess for an optimal 
decision rule. Its execution can be subdivided into two alternating steps: policy evaluation and 
policy improvement.

Policy evaluation means the calculation of the value of the current decision rule.

During the improvement phase, we choose a decision rule that maximizes L applied to the value of
the current decision rule; in case of ties, we prefer to keep the action selected by the old rule.
›

context MDP_att_ℒ begin
(* Policy evaluation: the value of a deterministic decision rule d is the value
   of the stationary deterministic policy mk_stationary_det d. *)
definition "policy_eval d = νb (mk_stationary_det d)"
end

context MDP_act_disc
begin

(* Policy improvement in state s w.r.t. the value v.
   - If the current action d s already maximizes La a (apply_bfun v) s over the
     actions in A s, it is kept.
   - Otherwise arb_act picks some element of opt_acts v s.
   Keeping d s on ties is what makes "d = policy_step d" a usable stopping test
   for policy_iteration below. *)
definition "policy_improvement d v s = (
  if is_arg_max (λa. La a (apply_bfun v) s) (λa. a  A s) (d s) 
  then d s
  else arb_act (opt_acts v s))"

(* One round of policy iteration: evaluate d, then improve d w.r.t. that value. *)
definition "policy_step d = policy_improvement d (policy_eval d)"

(* todo: move check is_dec_det outside the recursion *)
(* Repeatedly apply policy_step until the decision rule no longer changes.
   If d is not a deterministic decision rule (¬ is_dec_det d), d is returned
   unchanged.  "by auto" discharges the pattern obligations of the function
   command.  Termination is proved separately, later in this theory, under the
   finiteness assumptions of MDP_PI_finite. *)
function policy_iteration :: "('s  'a)  ('s  'a)" where
  "policy_iteration d = (
  let d' = policy_step d in
  if d = d'  ¬is_dec_det d then d else policy_iteration d')"
  by auto

text ‹
The policy iteration algorithm as stated above does require that the supremum in @{const b} is
always attained.
›

text ‹
Each policy improvement returns a valid decision rule.
›
(* Holds for any d and v.  In the kept branch, d s is an arg-max over the actions
   in A s.  The fallback branch is handled by some_opt_acts_in_A. *)
lemma is_dec_det_pi: "is_dec_det (policy_improvement d v)"
  unfolding policy_improvement_def is_dec_det_def is_arg_max_def
  by (auto simp: some_opt_acts_in_A)

(* Set-based variant of is_dec_det_pi: improvement stays inside DD.
   NOTE(review): the premise on d looks unnecessary given is_dec_det_pi; it is
   kept because existing proofs use the lemma in this form. *)
lemma policy_improvement_is_dec_det: "d  DD  policy_improvement d v  DD"
  unfolding policy_improvement_def is_dec_det_def
  using some_opt_acts_in_A
  by auto

(* The improved rule is ν_improving for v.  The key step shows that, in every
   state x, L of the improved rule at v equals b v x; ν_improving_alt then turns
   this pointwise equality into the claim. *)
lemma policy_improvement_improving: 
  assumes "d  DD" 
  shows "ν_improving v (mk_dec_det (policy_improvement d v))"
proof -
  (* Both branches of policy_improvement select an action attaining the
     supremum that defines b (via is_opt_act_some and arg_max_SUP). *)
  have "b v x = L (mk_dec_det (policy_improvement d v)) v x" for x
    using is_opt_act_some
    by (fastforce simp: b_eq_argmax_La L_eq_La_det is_opt_act_def policy_improvement_def arg_max_SUP)
  thus ?thesis
    using policy_improvement_is_dec_det assms by (auto simp: ν_improving_alt)
qed

(* Applying L of the improved rule to the value of d gives exactly
   b (policy_eval d).  This follows from policy_improvement_improving. *)
lemma eval_policy_step_L:
 "is_dec_det d  L (mk_dec_det (policy_step d)) (policy_eval d) = b (policy_eval d)"
  by (auto simp: policy_step_def ν_improving_imp_ℒb[OF policy_improvement_improving])

text ‹ The sequence of policies generated by policy iteration has monotonically increasing 
discounted reward.›
(* Monotonicity: a policy step never decreases the value.
   Proof idea:
   - The value of d satisfies v <= r + l·P v for the improved rule ?d'.
   - Applying the series ?P to both sides (presumably monotone, by
     lemma_6_1_2_b) yields v <= ?P r, which is the value of ?d'. *)
lemma policy_eval_mon:
  assumes "is_dec_det d"
  shows "policy_eval d  policy_eval (policy_step d)"
proof -
  let ?d' = "mk_dec_det (policy_step d)"
  let ?dp = "mk_stationary_det d"
  (* ?P: the series over t of l ^ t *R 𝒫1 ?d' ^^ t *)
  let ?P = "t. l ^ t *R 𝒫1 ?d' ^^ t"

  (* One step with the improved rule is at least as good as one step with d. *)
  have "L (mk_dec_det d) (policy_eval d)  L ?d' (policy_eval d)"
    using assms by (auto simp: L_le_ℒb eval_policy_step_L)
  (* policy_eval d is a fixed point of L (mk_dec_det d), by L_ν_fix. *)
  hence "policy_eval d  L ?d' (policy_eval d)"
    using L_ν_fix policy_eval_def by auto
  hence "νb ?dp  r_decb ?d' + l *R 𝒫1 ?d' (νb ?dp)"
    unfolding policy_eval_def L_def by auto
  (* Rearranged: (id - l *R 𝒫1 ?d') applied to the value is <= the reward. *)
  hence "(id_blinfun - l *R 𝒫1 ?d') (νb ?dp)  r_decb ?d'"
    by (simp add: blinfun.diff_left diff_le_eq scaleR_blinfun.rep_eq)
  (* Apply ?P to both sides; presumably lemma_6_1_2_b says ?P preserves <=. *)
  hence "?P ((id_blinfun - l *R 𝒫1 ?d') (νb ?dp))  ?P (r_decb ?d')"
    using lemma_6_1_2_b by auto
  (* ?P inverts id - l *R 𝒫1 ?d' because its norm bound is below 1
     (norm_𝒫1_l_less), so the left-hand side collapses to the value of d. *)
  hence "νb ?dp  ?P (r_decb ?d')"
    using inv_norm_le'(2)[OF norm_𝒫1_l_less] by (auto simp: blincomp_scaleR_right)
  (* ?P (r_decb ?d') is the value of the stationary policy of ?d' (ν_stationary). *)
  thus ?thesis
    by (auto simp: policy_eval_def ν_stationary)
qed

text ‹
If policy iteration terminates, i.e. @{term "d = policy_step d"}, then it does so with optimal value.
›
(* A fixed point of policy_step is optimal.
   - eval_policy_step_L and L_ν_fix show that its value is a fixed point of the
     operator b.
   - ℒ_fix_imp_opt concludes optimality. *)
lemma policy_step_eq_imp_opt:
  assumes "is_dec_det d" "d = policy_step d" 
  shows "νb (mk_stationary_det d) = νb_opt"
  using L_ν_fix assms eval_policy_step_L[unfolded policy_eval_def] 
  by (fastforce intro: ℒ_fix_imp_opt)

end

text ‹We prove termination of policy iteration only if both the state and action sets are finite.›
(* Setting for the termination proof of policy_iteration:
   - fin_states: the state type is finite.
   - fin_actions: every action set A s is finite.
   arb_act is the action-selection function passed on to MDP_act_disc. *)
locale MDP_PI_finite = MDP_act_disc arb_act A K r l 
  for
    A and
    K :: "'s ::countable × 'a ::countable  's pmf" and r l arb_act +
  assumes fin_states: "finite (UNIV :: 's set)" and fin_actions: "s. finite (A s)"
begin

text ‹If the state and action sets are both finite, 
  then so is the set of deterministic decision rules @{const "DD"}.›
(* Proof: DD is contained in ?set, the set of functions that map every state into
   the union of all action sets.  ?set is finite because both the states and
   that union are finite (finite_set_of_finite_funs). *)
lemma finite_DD[simp]: "finite DD"
proof -
  let ?set = "{d. x :: 's. (x  UNIV  d x  (s. A s))  (x  UNIV  d x = undefined)}"
  have "finite (s. A s)"
    using fin_actions fin_states by blast
  hence "finite ?set"
    using fin_states by (fastforce intro: finite_set_of_finite_funs)
  moreover have "DD  ?set"
    unfolding is_dec_det_def by auto
  ultimately show ?thesis
    using finite_subset by auto
qed

(* The "strictly better value" relation on deterministic decision rules is finite.
   It is a subset of the finite set of pairs of deterministic rules
   (presumably shown by auto via finite_DD).  Used in the termination proof below. *)
lemma finite_rel: "finite {(u, v). is_dec_det u  is_dec_det v  νb (mk_stationary_det u) > 
  νb (mk_stationary_det v)}"
proof-
  have aux: "finite {(u, v). is_dec_det u  is_dec_det v}"
    by auto
  show ?thesis
    by (auto intro: finite_subset[OF _ aux])
qed

text ‹
This auxiliary lemma shows that policy iteration terminates if no improvement to the value of 
the policy could be made, as then the policy remains unchanged.
›
(* If a policy step does not change the value, it does not change the rule.
   The value of d is then a fixed point of b, so d s already attains the maximum
   in every state.  policy_improvement therefore keeps d s (ties favour d s). *)
lemma eval_eq_imp_policy_eq: 
  assumes "policy_eval d = policy_eval (policy_step d)" "is_dec_det d"
  shows "d = policy_step d"
proof -
  (* policy_eval d is the fixed point of L (mk_dec_det d) (L_ν_fix); by assms(1)
     it can be rewritten with the value of policy_step d. *)
  have "policy_eval d = L (mk_dec_det d) (policy_eval (policy_step d))"
    unfolding policy_eval_def
    using L_ν_fix 
    by (auto simp: assms(1)[symmetric, unfolded policy_eval_def])
  (* With eval_policy_step_L, the value of d is a fixed point of b. *)
  hence "policy_eval d = b (policy_eval d)"
    by (metis L_ν_fix policy_eval_def assms eval_policy_step_L)
  hence "L (mk_dec_det d) (policy_eval d) s = b (policy_eval d) s" for s
    using policy_eval d = L (mk_dec_det d) (policy_eval (policy_step d)) assms(1) by auto
  (* Pointwise: d s attains the supremum of La over A s, i.e. it is an arg-max. *)
  hence "is_arg_max (λa. La a (νb (mk_stationary (mk_dec_det d))) s) (λa. a  A s) (d s)" for s
    unfolding L_eq_La_det
    unfolding policy_eval_def b.rep_eq ℒ_eq_SUP_det SUP_step_det_eq
    using assms(2) is_dec_det_def La_le
    by (auto intro!: SUP_is_arg_max boundedI bounded_imp_bdd_above)
  (* So the first branch of policy_improvement applies in every state and keeps d. *)
  thus ?thesis
    unfolding policy_eval_def policy_step_def policy_improvement_def
    by auto
qed

text ‹
We are now ready to prove termination in the context of finite state-action spaces.
Intuitively, the algorithm terminates as there are only finitely many decision rules,
and in each recursive call the value of the decision rule increases.
›
(* Termination via a well-founded relation.
   - The relation: pairs of rules in DD where the first has a strictly larger
     stationary value.
   - Well-foundedness: the relation is finite (finite_rel) and, being a strict
     order, acyclic (finite_acyclic_wf, acyclicI_order).
   - Recursive calls: each call from d to policy_step d moves strictly up in
     this relation. *)
termination policy_iteration
proof (relation "{(u, v). u  DD  v  DD  νb (mk_stationary_det u) > νb (mk_stationary_det v)}")
  show "wf {(u, v). u  DD  v  DD  νb (mk_stationary_det v) < νb (mk_stationary_det u)}"
    using finite_rel by (auto intro!: finite_acyclic_wf acyclicI_order)
next
  (* Recursive call obligation: the loop continues only if d is deterministic
     and changes.  Then the new rule's value is strictly larger:
     - it is >= by policy_eval_mon;
     - it is <> since equal values would force d = policy_step d
       (eval_eq_imp_policy_eq). *)
  fix d x
  assume h: "x = policy_step d" "¬ (d = x  ¬ is_dec_det d)"
  have "is_dec_det d  νb (mk_stationary_det d)  νb (mk_stationary_det (policy_step d))"
    using policy_eval_mon by (simp add: policy_eval_def)
  hence "is_dec_det d  d  policy_step d 
    νb (mk_stationary_det d) < νb (mk_stationary_det (policy_step d))"
    using eval_eq_imp_policy_eq policy_eval_def
    by (force intro!: order.not_eq_order_implies_strict)
  thus "(x, d)  {(u, v). u  DD  v  DD  νb (mk_stationary_det v) < νb (mk_stationary_det u)}"
    using is_dec_det_pi policy_step_def h by auto
qed

text ‹
The termination proof gives us access to the induction rule/simplification lemmas associated 
with the @{const policy_iteration} definition.
Thus we can prove that the algorithm finds an optimal policy.
›

(* Starting from a rule in DD, policy_iteration returns a deterministic decision
   rule.  The proof is by the induction rule generated by the termination proof. *)
lemma is_dec_det_pi': "d  DD  is_dec_det (policy_iteration d)"
  using is_dec_det_pi
  by (induction d rule: policy_iteration.induct) (auto simp: Let_def policy_step_def)

(* Starting from a rule in DD, the result of policy_iteration is a fixed point of
   policy_step.  Declared [simp], so it rewrites policy_step (policy_iteration d)
   when d is in DD. *)
lemma pi_pi[simp]: "d  DD  policy_step (policy_iteration d) = policy_iteration d"
  using is_dec_det_pi
  by (induction d rule: policy_iteration.induct) (auto simp: policy_step_def Let_def)

(* Correctness: started from any rule in DD, policy iteration returns a rule whose
   stationary policy has the optimal value νb_opt.  The proof combines
   policy_step_eq_imp_opt with is_dec_det_pi'. *)
lemma policy_iteration_correct: 
  "d  DD  νb (mk_stationary_det (policy_iteration d)) = νb_opt" 
  by (induction d rule: policy_iteration.induct)
    (fastforce intro!: policy_step_eq_imp_opt is_dec_det_pi' simp del: policy_iteration.simps)
end

context MDP_finite_type begin
text ‹
The following proofs concern code generation, i.e. how to represent @{const 𝒫1} as a matrix.
›

(* In the finite-type setting, the arg-max assumption of MDP_att_ℒ holds, so its
   results (e.g. policy_eval) are available here.  The proof uses
   finite_is_arg_max, A_ne (presumably: action sets are nonempty) and
   has_arg_max_def. *)
sublocale MDP_att_ℒ
  by (auto simp: A_ne finite_is_arg_max MDP_att_ℒ_def MDP_att_ℒ_axioms_def max_L_ex_def 
      has_arg_max_def MDP_reward_disc_axioms) 
 
(* fun_to_matrix f: the matrix of the map v ↦ (χ j. f (vec_nth v) j).  The
   vector v is read as a function, f is applied, and the result is read back
   componentwise. *)
definition "fun_to_matrix f = matrix (λv. (χ j. f (vec_nth v) j))"
(* Matrix of 𝒫1 d, acting on vectors via Bfun. *)
definition "Ek_mat d = fun_to_matrix (λv. ((𝒫1 d) (Bfun v)))"
(* Matrix of id_blinfun - l *R 𝒫1 d. *)
definition "nu_inv_mat d = fun_to_matrix ((λv. ((id_blinfun - l *R 𝒫1 d) (Bfun v))))"
(* Matrix of the series over i of (l *R 𝒫1 d) ^^ i. *)
definition "nu_mat d = fun_to_matrix (λv. ((i. (l *R 𝒫1 d) ^^ i) (Bfun v)))"

(* Applying id_blinfun - l *R 𝒫1 d to a bounded function v is the same as
   multiplying the vector of v by nu_inv_mat d.  The main step is linearity of
   the vectorized map; presumably this is what lets the matrix of the map
   reproduce it. *)
lemma apply_nu_inv_mat: 
  "(id_blinfun - l *R 𝒫1 d) v = Bfun (λi. ((nu_inv_mat d) *v (vec_lambda v)) $ i)"
proof -
  have eq_onpI: "P x  eq_onp P x x" for P x
    by(simp add: eq_onp_def)

  (* Linearity of the map on vectors induced by id_blinfun - l *R 𝒫1 d. *)
  have "Real_Vector_Spaces.linear (λv. vec_lambda (((id_blinfun - l *R 𝒫1 d) (bfun.Bfun (($) v)))))"
    by (auto simp del: real_scaleR_def intro: linearI
        simp: scaleR_vec_def eq_onpI plus_vec_def vec_lambda_inverse plus_bfun.abs_eq[symmetric] 
        scaleR_bfun.abs_eq[symmetric] blinfun.scaleR_right blinfun.add_right)
  thus ?thesis
    unfolding Ek_mat_def fun_to_matrix_def nu_inv_mat_def
    by (auto simp: apply_bfun_inverse vec_lambda_inverse)
qed

(* Converting a bounded function over the finite type 's to a vector is bounded
   linear, with bound CARD('s).  The argument: the L2 norm is at most the sum of
   absolute values, which is at most card UNIV times the supremum norm. *)
lemma bounded_linear_vec_lambda: "bounded_linear (λx. vec_lambda (x :: 's b real))"
proof (intro bounded_linear_intro)
  fix x :: "'s b real"
  (* L2 norm <= sum of absolute values. *)
  have "sqrt ( i  UNIV . (apply_bfun x i)2)  ( i  UNIV . ¦(apply_bfun x i)¦)"
    using L2_set_le_sum_abs 
    unfolding L2_set_def
    by auto
  (* Each summand is bounded by the supremum of the absolute values. *)
  also have "( i  UNIV . ¦(apply_bfun x i)¦)  (card (UNIV :: 's set) * (xa. ¦apply_bfun x xa¦))"
    by (auto intro!: cSup_upper sum_bounded_above)
  finally show "norm (vec_lambda (apply_bfun x))  norm x * CARD('s)"
    unfolding norm_vec_def norm_bfun_def dist_bfun_def L2_set_def
    by (auto simp add: mult.commute)
qed (auto simp: plus_vec_def scaleR_vec_def)

(* Any bounded linear operator f on bounded functions induces a bounded linear
   map on vectors.  The map converts the vector to a bfun (Bfun), applies f, and
   converts back (vec_lambda).  The proof composes the bounded linear pieces. *)
lemma bounded_linear_vec_lambda_blinfun: 
  fixes f :: "('s b real) L ('s b real)"
  shows "bounded_linear (λv. vec_lambda (apply_bfun (blinfun_apply f (bfun.Bfun (($) v)))))" 
  using blinfun.bounded_linear_right
  by (fastforce intro: bounded_linear_compose[OF bounded_linear_vec_lambda] 
      bounded_linear_bfun_nth bounded_linear_compose[of f])

(* nu_inv_mat d, the matrix of id_blinfun - l *R 𝒫1 d, is invertible.  The
   inverse witness given to exI is the linear map underlying nu_mat d: the
   series over i of (l *R 𝒫1 d) ^^ i. *)
lemma invertible_nu_inv_max: "invertible (nu_inv_mat d)"
  unfolding nu_inv_mat_def fun_to_matrix_def
  by (auto simp: matrix_invertible inv_norm_le' vec_lambda_inverse apply_bfun_inverse 
      bounded_linear.linear[OF bounded_linear_vec_lambda_blinfun]
      intro!: exI[of _ "λv. (χ j. (λv. (i. (l *R 𝒫1 d) ^^ i) (Bfun v)) (vec_nth v) j)"])

(* The name above misspells "mat" as "max".  This correctly named alias is for
   new uses; the original name is kept so existing references still work. *)
lemmas invertible_nu_inv_mat = invertible_nu_inv_max
end
      
(* A finite MDP (MDP_finite_type) whose state and action types are also
   well-ordered.  The well-order gives a concrete action-selection function for
   MDP_PI_finite: the least element of a set (see the sublocale below). *)
locale MDP_ord = MDP_finite_type A K r l
  for A and                
    K :: "'s :: {finite, wellorder} × 'a :: {finite, wellorder}  's pmf"
    and r l
begin

(* Here ℒ v s reduces to the supremum, over the actions a in A s, of La a v s:
   deterministic one-step choices suffice (SUP_step_det_eq, ℒ_eq_SUP_det). *)
lemma ℒ_fin_eq_det: " v s = (a  A s. La a v s)"
  by (simp add: SUP_step_det_eq ℒ_eq_SUP_det)

(* The same reduction for b, obtained from its representation (b.rep_eq). *)
lemma b_fin_eq_det: "b v s = (a  A s. La a v s)"
  by (simp add: SUP_step_det_eq b.rep_eq ℒ_eq_SUP_det)

(* Instantiate policy iteration with arb_act := "the least element of X" w.r.t.
   the well-order.  LeastI presumably discharges the requirement that the chosen
   action lies in the given set. *)
sublocale MDP_PI_finite A K r l "λX. Least (λx. x  X)"
  by unfold_locales (auto intro: LeastI)

end

end