(* Theory Backward_Induction *)

theory Backward_Induction
imports "MDP-Rewards.MDP_reward"
begin

(* Locale for an MDP with rewards and a terminal reward.
   It extends discrete_MDP A K with
     - r     : a real-valued reward on state-action pairs,
     - r_fin : a real-valued reward on states,
     - N     : a natural number (used further down in this theory as the
               number of steps after which r_fin is collected).
   Both r and r_fin are assumed to have bounded range. *)
locale MDP_reward_fin = discrete_MDP A K
  for  
    A and 
    K :: "'s ::countable × 'a ::countable  's pmf" +
  fixes
    r :: "('s × 'a)  real" and
    r_fin :: "'s  real" and
    N :: "nat"    
  assumes
    r_fin_bounded: "bounded (range r_fin)" and
    r_bounded_fin: "bounded (range r)"
begin

(* Interprets the locale MDP_reward for A, K, r with its fourth parameter set
   to 1 (presumably the discount factor; confirm against MDP_reward), making
   its facts available in this context. The rewrites clauses simplify the
   resulting "1 * x" and "1 ^ x" terms. The boundedness obligation is
   discharged from r_bounded_fin. *)
interpretation MDP_reward A K r 1
  rewrites "1 * (x::real) = x" and "x.(1::real)^(x::nat)=1"
  using r_bounded_fin
  by unfold_locales (auto simp: algebra_simps)

(* Value of policy p from start state s: the integral with respect to the
   measure 𝒯 p s of the rewards r (t !! i) for i < N plus r_fin of the state
   component at position N of the trace t.
   NOTE(review): the integral and sum binders appear to be missing from this
   copy of the source; the reading above follows the lemma νN_eq. *)
definition "νN p s = (t. (i<N. r (t !! i)) +  (r_fin (fst(t !! N))) 𝒯 p s)"

(* Measurability (on S) of the terminal reward evaluated at the state
   component of the i-th element of a trace. The function is stated with
   fst (t !! i), i.e. exactly the integrand used in the definition of νN and
   in integrable_r_fin_nth, so that this [measurable] rule applies to those
   terms. *)
lemma measurable_r_fin_nth [measurable]: "(λt. r_fin (fst(t !! i)))  borel_measurable S"
  by measurable  

(* The terminal reward read at position i of a trace is integrable with
   respect to 𝒯 p s. Follows from boundedness of r_fin (r_fin_bounded).
   Declared as a simp rule. *)
lemma integrable_r_fin_nth [simp]: "integrable (𝒯 p s) (λt. r_fin (fst(t !! i)))"
  using bounded_range_subset[OF r_fin_bounded]
  by (auto simp: range_composition[of r_fin])

(* Rewrites νN p s as a combination over i < N of the expectations of r under
   Pn' p s i, plus the expectation of r_fin under Xn' p s N.
   Proof outline: split the integral of the sum into two integrals
   (integral_add), then rewrite each part separately. *)
lemma νN_eq: "νN p s = (i < N. measure_pmf.expectation (Pn' p s i) r) + measure_pmf.expectation (Xn' p s N) r_fin"
proof -
  (* Step 1: separate the running rewards from the terminal reward. *)
  have "νN p s = (t. (i<N. r (t !! i)) 𝒯 p s) + (t. (r_fin (fst(t !! N))) 𝒯 p s)"
    unfolding νN_def
    by (auto intro: Bochner_Integration.integral_add)
  (* Step 2: running rewards, via the ν_fin facts from the interpreted locale. *)
  moreover have " (t. (i<N. r (t !! i)) 𝒯 p s) = (i < N. measure_pmf.expectation (Pn' p s i) r)"
    using ν_fin_Suc ν_fin_eq_Pn by force
  (* Step 3: terminal reward, expressed through the distribution Xn' p s N. *)
  moreover have "(t. (r_fin (fst(t !! N))) 𝒯 p s) = measure_pmf.expectation (Xn' p s N) r_fin"
    by (auto simp: Xn'_Pn' Pn'_eq_𝒯 integral_distr)
  ultimately show ?thesis by auto
qed

(* Recursive evaluation of a policy p after history h in state s:
     - length h = N : the terminal reward r_fin s,
     - length h > N : 0,
     - otherwise    : the expectation, over actions a drawn from p h s, of
                      r (s,a) plus the expectation over successor states s'
                      drawn from K (s,a) of the value with the extended
                      history h @ [(s,a)]. *)
function νN_eval where
  "νN_eval p h s = (
    if length h = N then r_fin s else 
    if length h > N then 0 else
      measure_pmf.expectation (p h s) (λa. r (s,a) +
        measure_pmf.expectation (K (s,a)) (λs'. νN_eval p (h@[(s,a)]) s'))) "
  by auto

(* Termination: the remaining number of steps N - length h decreases. *)
termination
  by (relation "Wellfounded.measure (λ(_,h,s). N - length h)") auto

(* Remove these rules from the simpset; they are unfolded explicitly where
   needed in the proofs below. *)
lemmas abs_disc_eq[simp del]
lemmas νN_eval.simps[simp del]

(* A real-valued function with bounded range is integrable with respect to
   any pmf measure. Proved with the constant bound given by the supremum of
   the absolute values of f. *)
lemma pmf_bounded_integrable: "bounded (range (f::_  real))  integrable (measure_pmf p) f"
  using bounded_norm_le_SUP_norm[of f]
  by (intro measure_pmf.integrable_const_bound[of _ "x. ¦f x¦"]) auto

(* A uniform bound c on the absolute values of f yields boundedness of
   range f. Declared as a destruction rule.
   NOTE(review): quantifier and relation symbols are missing in this copy;
   presumably the premise reads "for all x, |f x| <= c". *)
lemma abs_boundedD[dest]: "(x. ¦f x¦  (c::real))  bounded (range f)"
  using bounded_real by auto

(* Under a uniform bound c on the absolute values of f, the absolute value of
   the expectation of f under a pmf p is bounded by c as well (via
   integral_abs_bound and integral_le_const). Declared as an intro rule. *)
lemma abs_integral_le[intro]: "(x. ¦f x¦  (c::real))  abs (measure_pmf.expectation p f)  c"
  by (fastforce intro!: pmf_bounded_integrable abs_boundedD measure_pmf.integral_le_const order.trans[OF integral_abs_bound])

(* Bound on the absolute value of νN_eval p h s: (N - length h) times rM plus
   the supremum of the absolute terminal rewards.
   Proof by induction on N - length h:
     - base case: unfold νN_eval and bound r_fin by its supremum;
     - step case: unfold one step, split the expectation of the sum, bound
       the immediate reward by rM (abs_r_le_rM) and the continuation by the
       induction hypothesis. *)
lemma abs_νN_eval_le: "¦νN_eval p h s¦  (N - length h) * rM + (s. ¦r_fin s¦)"
proof (induction "(N - length h)" arbitrary: h s)
  case 0
  then show ?case 
    using r_fin_bounded
    by (auto simp: νN_eval.simps intro!: bounded_imp_bdd_above cSUP_upper2)
next
  case (Suc x)
  have "N > length h"
    using Suc(2) by linarith
  hence Suc_le: "Suc (length h)  N"
    by auto
  (* Induction hypothesis instantiated for the history extended by (s, a). *)
  have *: "¦νN_eval p (h @ [(s, a)]) s'¦  real (N - length h - 1) * rM + (s. ¦r_fin s¦)" for a s'
    using Suc.hyps(1)[of "h @[(s,a)]"] Suc.hyps(2)
    by (auto simp: of_nat_diff[OF Suc_le] algebra_simps)
  (* The same bound carries over to the nested expectation. *)
  hence **: "¦measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s, a)) (νN_eval p (h @ [(s, a)])))¦
     real (N - length h - 1) * rM + (s. ¦r_fin s¦)"
    using Suc by auto
  have "¦measure_pmf.expectation (p h s) (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_eval p (h @ [(s, a)])))¦
       ¦measure_pmf.expectation (p h s) (λa. r (s, a))  + measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s, a)) (νN_eval p (h @ [(s, a)])))¦"
    using abs_r_le_rM
    by (subst Bochner_Integration.integral_add) (auto intro!: abs_boundedD * pmf_bounded_integrable)     
  also have "  ¦measure_pmf.expectation (p h s) (λa. r (s, a))¦  + ¦ measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s, a)) (νN_eval p (h @ [(s, a)])))¦"
    by auto
  also have "  rM  + ¦ measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s, a)) (νN_eval p (h @ [(s, a)])))¦"
    using abs_r_le_rM by auto
  also have "  rM  + (N - length h - 1) * rM + (s. ¦r_fin s¦)"
    using "**" by force
  also have "  (N - length h) * rM + (s. ¦r_fin s¦)"
    using Suc Suc_le by (auto simp: of_nat_diff algebra_simps)
  finally show ?case
    using νN_eval.simps length h < N by force
qed

(* History-independent version of abs_νN_eval_le: the factor N - length h is
   replaced by N (using rM_nonneg). *)
lemma abs_νN_eval_le': "¦νN_eval p h s¦  N * rM + (s. ¦r_fin s¦)"
  by (simp add: mult_left_mono rM_nonneg algebra_simps order.trans[OF abs_νN_eval_le[of p h s]])

(* For a function f with bounded range, the expectation under a pmf built
   from p and q equals the iterated expectation: first over x from p, then
   of f under q x.
   NOTE(review): the operator between p and q is missing in this copy; the
   proof unfolds measure_pmf_bind, so it is presumably the pmf bind. *)
lemma measure_pmf_expectation_bind: 
  assumes "bounded (range f)" 
  shows "measure_pmf.expectation (p  q) (f::_  real) = measure_pmf.expectation p (λx. measure_pmf.expectation (q x) f)"
  unfolding measure_pmf_bind
  using assms measure_pmf_in_subprob_space
  by (fastforce intro!: Giry_Monad.integral_bind[of _ "count_space UNIV" "x. ¦f x¦"] bounded_imp_bdd_above cSUP_upper)+

(* Shift lemma for Pn' and bounded f: taking one step from s (action from
   p h s, successor s' from K (s, a)) and then the expectation of f under Pn'
   at index n for the policy that sees the history h @ (s, a) # h', equals
   the expectation of f under Pn' at index Suc n for the policy
   λh'. p (h @ h') started in s. *)
lemma Pn'_shift: "bounded (range (f :: _  real))  measure_pmf.expectation (p h s)
     (λa. measure_pmf.expectation (K (s, a))
            (λs'. measure_pmf.expectation (Pn' (λh'. p ((h @ (s, a)# h'))) s' n) f))
    = measure_pmf.expectation (Pn' (λh'. p (h @ h')) s (Suc n)) f"
  unfolding PSuc' π_Suc_def K0'_def
  by (auto simp: measure_pmf_expectation_bind)

(* For a fixed state s, the rewards r (s, a) over any set X of actions form a
   bounded set (from r_bounded'). *)
lemma bounded_r_snd': "bounded ((λa. r (s, a)) ` X)"
  using r_bounded' image_image
  by metis

(* Special case of bounded_r_snd' for the whole range of actions. *)
lemma bounded_r_snd: "bounded (range (λa. r (s, a)))"
  using bounded_r_snd'.

(* Closed form of νN_eval for histories of length at most N (relation symbol
   missing in this copy; presumably "length h <= N"): νN_eval p h s equals
   the combination over i in {length h..<N} of the expectations of r under
   Pn' of the history-shifted policy λh'. p (h @ h') at index i - length h,
   plus the expectation of r_fin under Xn' at index N - length h.

   Proof by induction on N - length h. In the step case:
     1. unfold one step of νN_eval and split the outer expectation of the sum;
     2. apply the induction hypothesis inside the nested expectations;
     3. move additions and the finite sum out of the inner and then the
        outer expectation (integral_add, integral_sum), with integrability
        coming from the boundedness lemmas;
     4. collapse "one step, then Pn'" with Pn'_shift, and re-attach the
        first summand of the sum (sum.atLeast_Suc_lessThan). *)
lemma νN_eval_eq: "length h  N  νN_eval p h s =
  (i {length h..< N}.
  measure_pmf.expectation (Pn' (λh'. p (h@h')) s (i - length h)) r) + measure_pmf.expectation (Xn' (λh'. p (h@h')) s (N - length h)) r_fin"
proof (induction "N - length h" arbitrary: h s)
  case 0
  then show ?case
    using νN_eval.simps by auto
next
  case (Suc x)
  hence "length h < N"
    by auto
  hence
    "νN_eval p h s =
      measure_pmf.expectation (p h s) (λa. r (s,a) +
        measure_pmf.expectation (K (s,a)) (λs'. νN_eval p (h@[(s,a)]) s'))"
    by (auto simp: νN_eval.simps[of p h] split: if_splits)
  also have " = 
      measure_pmf.expectation (p h s) (λa. r (s,a)) + measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s,a)) (λs'. νN_eval p (h@[(s,a)]) s'))"
    using abs_νN_eval_le' bounded_r_snd
    by (fastforce simp: bounded_real intro!: Bochner_Integration.integral_add pmf_bounded_integrable abs_integral_le)
  also have " =
    (i = length h..<N. measure_pmf.expectation (Pn' (λh'. p (h @ h')) s (i - length h)) r) + measure_pmf.expectation (Xn' (λh'. p (h @ h')) s (N - length h)) r_fin"
  proof -
    (* Induction hypothesis applied under the two nested expectations. *)
    have "measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s,a)) (λs'. νN_eval p (h@[(s,a)]) s')) = 
    measure_pmf.expectation (p h s)
     (λa. measure_pmf.expectation (K (s, a))
            (λs'. (i = length (h @ [(s, a)])..<N. measure_pmf.expectation (Pn' (λh'. p ((h @ [(s, a)]) @ h')) s' (i - length (h @ [(s, a)]))) r) +
                   measure_pmf.expectation (Xn' (λh'. p ((h @ [(s, a)]) @ h')) s' (N - length (h @ [(s, a)]))) r_fin))"
      using Suc length h < N
      by auto
    (* Simplify length (h @ [(s, a)]) to length h + 1. *)
    also have " = 
    measure_pmf.expectation (p h s)
     (λa. measure_pmf.expectation (K (s, a))
            (λs'. (i = length h + 1..<N. measure_pmf.expectation (Pn' (λh'. p ((h @ [(s, a)]) @ h')) s' (i - length h - 1)) r) +
                   measure_pmf.expectation (Xn' (λh'. p ((h @ [(s, a)]) @ h')) s' (N - length h - 1)) r_fin))"
      using Suc length h < N K0'_def
      by auto
    (* Split the inner expectation of the sum of the two parts. *)
    also have " = 
    measure_pmf.expectation (p h s)
     (λa. measure_pmf.expectation (K (s, a))
            (λs'. (i = length h + 1..<N. measure_pmf.expectation (Pn' (λh'. p ((h @ [(s, a)]) @ h')) s' (i - length h - 1)) r)) +
          measure_pmf.expectation (K (s, a))
            (λs'. measure_pmf.expectation (Xn' (λh'. p ((h @ [(s, a)]) @ h')) s' (N - length h - 1)) r_fin))"
      using abs_exp_r_le r_fin_bounded
      by (fastforce intro!: Bochner_Integration.integral_cong[OF refl] Bochner_Integration.integral_add 
          pmf_bounded_integrable Bochner_Integration.integrable_sum simp: bounded_real)+
    (* Split the outer expectation likewise. *)
    also have " = 
    measure_pmf.expectation (p h s)
     (λa. measure_pmf.expectation (K (s, a))
            (λs'. (i = length h + 1..<N. measure_pmf.expectation (Pn' (λh'. p ((h @ [(s, a)]) @ h')) s' (i - length h - 1)) r))) +
      measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s, a)) (λs'. measure_pmf.expectation 
        (Xn' (λh'. p ((h @ [(s, a)]) @ h')) s' (N - length h - 1)) r_fin))"
      using abs_r_le_rM r_fin_bounded
      by (fastforce intro!:
         Bochner_Integration.integral_add Bochner_Integration.integrable_sum pmf_bounded_integrable
         abs_integral_le order.trans[OF sum_abs] order.trans[OF sum_bounded_above[of _ _ "rM"]] simp: bounded_real)
    (* Pull the finite sum out of the inner expectation ... *)
    also have " = measure_pmf.expectation (p h s) (λa. (i = length h + 1..<N. measure_pmf.expectation (K (s, a))
            (λs'. measure_pmf.expectation (Pn' (λh'. p ((h @ [(s, a)]) @ h')) s' (i - length h - 1)) r))) +
      measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s, a)) (λs'. measure_pmf.expectation 
        (Xn' (λh'. p ((h @ [(s, a)]) @ h')) s' (N - length h - 1)) r_fin))"
      using abs_r_le_rM
      by (subst Bochner_Integration.integral_sum) (auto intro!: pmf_bounded_integrable boundedI[of _ "rM"] abs_integral_le)
    (* ... and then out of the outer expectation. *)
    also have " = (i = length h + 1..<N. measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s, a))
            (λs'. measure_pmf.expectation (Pn' (λh'. p ((h @ [(s, a)]) @ h')) s' (i - length h - 1)) r))) +
      measure_pmf.expectation (p h s) (λa. measure_pmf.expectation (K (s, a)) (λs'. measure_pmf.expectation 
        (Xn' (λh'. p ((h @ [(s, a)]) @ h')) s' (N - length h - 1)) r_fin))"
      using abs_r_le_rM
      by (subst Bochner_Integration.integral_sum) (auto intro!: pmf_bounded_integrable boundedI[of _ "rM"] abs_integral_le)
    (* Collapse "one step, then index n" into "index Suc n" with Pn'_shift. *)
    also have " =
      (i = length h + 1..<N. (measure_pmf.expectation (Pn' (λh'. p (h @ h')) s (i - length h))) r) + 
      measure_pmf.expectation (Xn' (λh'. p (h @ h')) s (N - length h)) r_fin"
      using r_bounded r_fin_bounded length h < N
      by (auto simp add: Pn'_shift Xn'_Pn' Suc_diff_Suc range_composition)
    finally show ?thesis
      unfolding sum.atLeast_Suc_lessThan[OF length h < N] r_dec_eq_r_K0
      by auto
  qed
  finally show ?case .
qed

(* Correctness of the recursive evaluation: started with the empty history,
   νN_eval coincides with νN. Follows from νN_eq and νN_eval_eq. *)
lemma νN_eval_correct: "νN_eval p [] s = νN p s"
  using lessThan_atLeast0
  by (auto simp: νN_eq νN_eval_eq)

(* Lifts νN, for a fixed policy, to the result type written in the signature
   (presumably a bounded-function type; the type arrows are garbled in this
   copy). The proof obligation is discharged via bfun_normI with the bound
   rM * N plus the supremum of the absolute terminal rewards. *)
lift_definition νNb :: "('s, 'a) pol  's b real" is νN
  using r_fin_bounded
  by (intro bfun_normI[of _ "rM * N + (x. ¦r_fin x¦)"]) 
    (auto simp add: νN_eq rM_def r_bounded bounded_abs_range intro!: add_mono 
      order.trans[OF integral_abs_bound] pmf_bounded_integrable  lemma_4_3_1
      order.trans[OF sum_abs] order.trans[OF abs_triangle_ineq] order.trans[OF sum_bounded_above[of _ _ rM]])
 
(* Optimal value in state s, taken over the policies p in ΠHR of νN p s.
   NOTE(review): the binder is missing in this copy; presumably a supremum,
   as for νN_eval_opt, whose lemmas use cSUP rules. *)
definition "νN_opt s = (p  ΠHR. νN p s)"
(* Optimal value after history h in state s, taken over the policies p in ΠHR
   of νN_eval p h s. The binder is missing in this copy; the lemmas about
   this constant use cSUP rules, so it is a supremum. *)
definition "νN_eval_opt h s = (p  ΠHR. νN_eval p h s)"

(* Optimality recursion indexed by a history h and a state s:
     - length h = N : the terminal reward r_fin s,
     - length h > N : 0,
     - otherwise    : over the actions a in A s, the value of r (s,a) plus the
                      expectation over successors s' from K (s,a) of the
                      recursive value with history h @ [(s,a)]
                      (the binder over a is missing in this copy; the lemmas
                      below treat it as a supremum). *)
function νN_opt_eqn where
  "νN_opt_eqn h s = (
    if length h = N then r_fin s else 
    if length h > N then 0 else
      a  A s. (r (s,a) +
        measure_pmf.expectation (K (s,a)) (λs'. νN_opt_eqn (h@[(s,a)]) s'))) "
  by auto

(* Termination: the remaining number of steps N - length h decreases. *)
termination
  by (relation "Wellfounded.measure (λ(h,s). N - length h)") auto

(* The defining equation is unfolded explicitly in the proofs below. *)
lemmas νN_opt_eqn.simps[simp del]

(* Bound on the absolute value of νN_opt_eqn h s, of the same shape as the
   one for νN_eval: (N - length h) times rM plus the supremum of the absolute
   terminal rewards.
   Proof by induction on N - length h. In the step case the bound is first
   established for every single action (induction hypothesis, abs_r_le_rM,
   triangle inequality) and then transferred to the supremum over A s
   (abs_cSUP_le, cSUP_least; A_ne supplies non-emptiness). *)
lemma abs_νN_opt_eqn_le: "¦νN_opt_eqn h s¦  (N - length h) * rM + (s. ¦r_fin s¦)"
proof (induction "(N - length h)" arbitrary: h s)
  case 0
  then show ?case 
    using r_fin_bounded
    by (auto simp: νN_opt_eqn.simps intro!: bounded_imp_bdd_above cSUP_upper2)
next
  case (Suc x)
  have "N > length h"
    using Suc(2) by linarith
  (* Induction hypothesis for the extended history. *)
  have *: "¦νN_opt_eqn (h @ [(s, a)]) s'¦  real (N - length h - 1) * rM + (s. ¦r_fin s¦)" for a s'
    using Suc(1)[of "(h @ [(s, a)])"] Suc(2)
    by auto
  hence "¦measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)]))¦
     real (N - length h - 1) * rM + (s. ¦r_fin s¦)" for a
    using Suc by auto
  hence **: "rM + ¦measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)]))¦
     real (N - length h) * rM + (s. ¦r_fin s¦)" for a
    using Suc
    by (auto simp: of_nat_diff algebra_simps)
  hence *: "¦r (s, a)¦ + ¦measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)]))¦  real (N - length h) * rM + (s. ¦r_fin s¦)" for a
    using abs_r_le_rM
    by (meson add_le_cancel_right  order.trans)
  (* Per-action bound on the term inside the supremum. *)
  hence *: "¦r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)]))¦  real (N - length h) * rM + (s. ¦r_fin s¦)" for a
    using order.trans[OF abs_triangle_ineq] by auto
  have "¦νN_opt_eqn h s¦ = ¦aA s. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)]))¦"
    unfolding νN_opt_eqn.simps[of h] using length h < N
    by auto
  also have "  ¦aA s. measure_pmf.expectation (return_pmf a) (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)])))¦"
    by auto
  also have "  (aA s. ¦ measure_pmf.expectation (return_pmf a) (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)])))¦)"
    using length h < N A_ne  *
    by (auto intro!: boundedI abs_cSUP_le)
  also have "  real (N - length h) * rM + (s. ¦r_fin s¦)"
    using * A_ne
    by (auto intro!: cSUP_least)
  finally show ?case.
qed

(* History-independent version of abs_νN_opt_eqn_le: the factor
   N - length h is replaced by N (using rM_nonneg). *)
lemma abs_νN_opt_eqn_le': "¦νN_opt_eqn h s¦  N * rM + (s. ¦r_fin s¦)"
  by (simp add: mult_left_mono rM_nonneg algebra_simps order.trans[OF abs_νN_opt_eqn_le[of h s]])

(* The uniform bound of abs_νN_eval_le' carries over to the supremum over
   policies, νN_eval_opt (policies_ne supplies a policy, so the supremum is
   over a non-empty set). *)
lemma abs_νN_eval_opt_le': "¦νN_eval_opt h s¦  N * rM + (s. ¦r_fin s¦)"
  unfolding νN_eval_opt_def
  using policies_ne abs_νN_eval_le'
  by (auto intro!: order.trans[OF abs_cSUP_le] boundedI cSUP_least)

(* The expectation of νN_eval_opt h under a transition K (s, a) obeys the same
   uniform bound as νN_eval_opt itself. *)
lemma exp_νN_eval_opt_le: "¦measure_pmf.expectation (K (s, a)) (νN_eval_opt h)¦  N * rM + (s. ¦r_fin s¦)"
  by (metis abs_νN_eval_opt_le' abs_integral_le)

(* As a function of the action a (with an action-dependent history h a), the
   expected continuation value of νN_eval_opt is bounded on any set X. *)
lemma bounded_exp_νN_eval_opt: "(bounded ((λa. measure_pmf.expectation (K (s, a)) (νN_eval_opt (h a))) ` X))"
  using exp_νN_eval_opt_le
  by (auto intro!: boundedI)

(* Immediate reward plus expected continuation value of νN_eval_opt, as a
   function of the action, is bounded on any set X (sum of two bounded
   functions, bounded_plus_comp). *)
lemma bounded_r_exp_νN_eval_opt: "(bounded ((λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_eval_opt (h a))) ` X))"
  using bounded_exp_νN_eval_opt r_bounded abs_r_le_rM
  by (intro bounded_plus_comp) (auto intro!:  boundedI)

(* Consequently that function of the action is integrable with respect to
   any pmf q over actions. *)
lemma integrable_r_exp_νN_eval_opt: "(integrable (measure_pmf q) ((λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_eval_opt (h a)))))"
  using bounded_r_exp_νN_eval_opt pmf_bounded_integrable by blast


(* The expectation of νN_eval p h under a transition K (s, a) obeys the
   uniform bound of abs_νN_eval_le'. *)
lemma exp_νN_eval_le: "¦measure_pmf.expectation (K (s, a)) (νN_eval p h)¦  N * rM + (s. ¦r_fin s¦)"
  by (metis abs_νN_eval_le' abs_integral_le)

(* As a function of the action a (with an action-dependent history h a), the
   expected continuation value of νN_eval p is bounded on any set X. *)
lemma bounded_exp_νN_eval: "(bounded ((λa. measure_pmf.expectation (K (s, a)) (νN_eval p (h a))) ` X))"
  using exp_νN_eval_le
  by (auto intro!: boundedI)

(* Immediate reward plus expected continuation value of νN_eval p, as a
   function of the action, is bounded on any set X. *)
lemma bounded_r_exp_νN_eval: "(bounded ((λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_eval p (h a))) ` X))"
  using bounded_exp_νN_eval r_bounded abs_r_le_rM
  by (intro bounded_plus_comp) (auto intro!:  boundedI)

(* Consequently that function of the action is integrable with respect to
   any pmf q over actions. *)
lemma integrable_r_exp_νN_eval: "(integrable (measure_pmf q) ((λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_eval p (h a)))))"
  using bounded_r_exp_νN_eval pmf_bounded_integrable by blast

(* The expectation of νN_opt_eqn h under a transition K (s, a) obeys the
   uniform bound of abs_νN_opt_eqn_le'. *)
lemma exp_νN_opt_eqn_le: "¦measure_pmf.expectation (K (s, a)) (νN_opt_eqn h)¦  N * rM + (s. ¦r_fin s¦)"
  by (metis abs_νN_opt_eqn_le' abs_integral_le)

(* As a function of the action a (with an action-dependent history h a), the
   expected continuation value of νN_opt_eqn is bounded on any set X. *)
lemma bounded_exp_νN_opt_eqn: "(bounded ((λa. measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h a))) ` X))"
  using exp_νN_opt_eqn_le
  by (auto intro!: boundedI)

(* Immediate reward plus expected continuation value of νN_opt_eqn, as a
   function of the action, is bounded on any set X. *)
lemma bounded_r_exp_νN_opt_eqn: "(bounded ((λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h a))) ` X))"
  using bounded_exp_νN_opt_eqn r_bounded abs_r_le_rM
  by (intro bounded_plus_comp) (auto intro!:  boundedI)

(* Consequently that function of the action is integrable with respect to
   any pmf q over actions. *)
lemma integrable_r_exp_νN_opt_eqn: "(integrable (measure_pmf q) ((λa. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h a)))))"
  using bounded_r_exp_νN_opt_eqn pmf_bounded_integrable by blast

(* For a policy p in ΠHR, the value νN_eval p h s is compared with the
   optimality recursion νN_opt_eqn h s (relation symbol missing in this copy;
   by the lemma name and the use of integral_mono / cSUP_mono, νN_eval is the
   smaller side).
   Proof by the induction rule of νN_eval: the continuation values are
   compared pointwise, then the expectation over the policy's action
   distribution is bounded by the supremum over A s (lemma_4_3_1), using that
   the support of p h s lies in A s. *)
lemma νN_eval_le_opt_eqn: "p  ΠHR  νN_eval p h s  νN_opt_eqn h s"
proof (induction p h s rule: νN_eval.induct)
  case (1 p h s)
  have "νN_eval p (h @ [(s, a)]) s'  νN_opt_eqn (h @[(s,a)]) s'" if "length h < N" for a s'
    using that 1 by fastforce
  hence *: "r (s, a) + measure_pmf.expectation (K (s, a)) (νN_eval p (h @ [(s, a)]))  r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)]))" if "length h < N" for a
    using abs_νN_eval_le' abs_νN_opt_eqn_le' that
    by (fastforce intro!: integral_mono pmf_bounded_integrable simp: bounded_real)
  (* Actions chosen by the policy are admissible. *)
  have **: "a  set_pmf (p h s)  a  A s" for a
    using 1 is_dec_def is_policy_def by blast
  then show ?case
    unfolding νN_eval.simps[of p h] νN_opt_eqn.simps[of h]
    using integrable_r_exp_νN_eval bounded_r_exp_νN_eval bounded_r_exp_νN_opt_eqn *
    by (auto simp: set_pmf_not_empty intro!: order.trans[OF lemma_4_3_1] cSUP_mono bexI bounded_imp_bdd_above)
  qed

(* For a policy p in ΠHR, νN_eval p h s is dominated by the supremum over
   policies, νN_eval_opt h s (proved with cSUP_upper; the relation symbol is
   missing in this copy). *)
lemma νN_eval_le_opt: "pΠHR  νN_eval_opt h s  νN_eval p h s"
  unfolding νN_eval_opt_def
  using bounded_subset_range[OF abs_boundedD[OF abs_νN_eval_le']]
  by (force intro!: cSUP_upper abs_boundedD bounded_imp_bdd_above)

(* νN_opt_eqn h has bounded image on every set X of states.
   Declared simp and intro. *)
lemma νN_opt_eqn_bounded[simp, intro]: "bounded ((νN_opt_eqn h) ` X)"
  by (meson Blinfun_Util.bounded_subset abs_νN_opt_eqn_le' abs_boundedD subset_UNIV)

(* νN_eval_opt h has bounded image on every set X of states.
   Declared simp and intro. *)
lemma νN_eval_opt_bounded[simp, intro]: "bounded ((νN_eval_opt h) ` X)"
  by (meson Blinfun_Util.bounded_subset abs_νN_eval_opt_le' abs_boundedD subset_UNIV)

(* νN_eval p h has bounded image on every set X of states.
   Declared simp and intro. *)
lemma νN_eval_bounded[simp, intro]: "bounded ((νN_eval p h) ` X)"
  by (meson Blinfun_Util.bounded_subset abs_νN_eval_le' abs_boundedD subset_UNIV)

(* One direction of the optimality equation: for histories no longer than N
   (relation symbol missing; presumably "length h <= N"), νN_opt_eqn h s
   dominates νN_eval_opt h s.
   Proof by induction on N - length h. In the step case, for an arbitrary
   policy p in ΠHR the value νN_eval p h s is bounded step by step:
     expectation over the policy's actions
       -> supremum over A s (lemma_4_3_1)
       -> continuation replaced by νN_eval_opt and then by νN_opt_eqn
          (induction hypothesis)
       -> which is νN_opt_eqn h s by definition.
   Taking the supremum over p (cSUP_least) concludes. *)
lemma νN_opt_ge: "length h  N  νN_opt_eqn h s  νN_eval_opt h s"
proof (induction "N - length h" arbitrary: h s)
  case 0
  then show ?case
    unfolding νN_eval_opt_def νN_opt_eqn.simps[of h]
    using policies_ne
    by (subst νN_eval_eq) auto
next
  case (Suc x)
  hence "length h < N"
    by linarith
  {
    fix p assume "p  ΠHR"
    have "νN_eval p h s = measure_pmf.expectation (p h s) (λa. (r (s,a) +
        measure_pmf.expectation (K (s,a)) (λs'. νN_eval p (h@[(s,a)]) s')))"
      unfolding νN_eval.simps[of p h]
      using length h < N
      by auto
    also have "  (a  A s. (r (s,a) +
        measure_pmf.expectation (K (s,a)) (λs'. νN_eval p (h@[(s,a)]) s')))"
      using p  ΠHR is_dec_def is_policy_def bounded_r_snd' bounded_exp_νN_eval
      by (auto intro!: lemma_4_3_1 bounded_plus_comp pmf_bounded_integrable simp: r_bounded')
    also have "  (a  A s. (r (s,a) +
        measure_pmf.expectation (K (s,a)) (λs'. νN_opt_eqn (h@[(s,a)]) s')))"
    proof -
      (* Continuation under p versus the optimal continuation. *)
      have "a  A s  
        r (s,a) + measure_pmf.expectation (K (s,a)) (νN_eval p (h @ [(s,a)])) 
        r (s,a) + measure_pmf.expectation (K (s,a)) (νN_eval_opt (h@[(s,a)]))" for a
        using abs_boundedD[OF abs_νN_eval_opt_le'] abs_boundedD[OF abs_νN_eval_le']
        using νN_eval_le_opt p  ΠHR
        by (force intro!: integral_mono pmf_bounded_integrable)
      (* Optimal continuation versus the optimality recursion. *)
      moreover have "a  A s  
        r (s,a) + measure_pmf.expectation (K (s,a)) (νN_eval_opt (h@[(s,a)])) 
        r (s,a) + measure_pmf.expectation (K (s,a)) (νN_opt_eqn (h@[(s,a)]))" for a
        using νN_eval_le_opt_eqn policies_ne Suc
        by (auto intro!: integral_mono pmf_bounded_integrable cSUP_least)
      ultimately show ?thesis
        using A_ne bounded_imp_bdd_above bounded_r_exp_νN_opt_eqn
        by (fastforce intro!: cSUP_mono)+
    qed
    also have " = νN_opt_eqn h s"
      unfolding νN_opt_eqn.simps[of h]  
      using length h < N
      by auto
    finally have "νN_opt_eqn h s  νN_eval p h s".
  }
  then show ?case    
    unfolding νN_eval_opt_def
    using policies_ne
    by (auto intro!: cSUP_least)
qed

(* Approximate witness for a supremum: for d > 0, a non-empty set X and a
   function f that is bounded above on X, some element x of X has f x within
   d of the supremum of f over X. Uses less_cSUP_iff. *)
lemma Sup_wit_ex:
  assumes "(d ::real)> 0"
  assumes "X  {}"
  assumes "bdd_above (f ` X)"
  shows "x  X. (x  X. f x) < f x + d"
proof -
  have "x X. (x  X. f x) - d < f x"
    using assms
    by (auto simp: less_cSUP_iff[symmetric])
  thus ?thesis
    by force
qed


(* νN_opt_eqn depends on the history only through its length: two histories
   h and h' of equal length (with length h bounded by N; relation symbol
   missing in this copy) give the same function of the state.
   Proof by induction on N - length h, rewriting under the supremum and the
   expectation with congruence rules. *)
lemma νN_opt_eqn_markov: "length h  N  length h = length h'  νN_opt_eqn h = νN_opt_eqn h'"
proof (induction "N - length h" arbitrary: h h')
  case 0
  then show ?case
    by (auto simp: νN_opt_eqn.simps)
next
  case (Suc x)
  {
    fix s
    have "νN_opt_eqn h s = (a  A s. r (s, a) + measure_pmf.expectation (K(s,a)) (νN_opt_eqn (h@[(s,a)])))"
      using  Suc by (fastforce simp: νN_opt_eqn.simps)
    (* Induction hypothesis applied to the two extended histories. *)
    also have " = (a  A s. r (s, a) + measure_pmf.expectation (K(s,a)) (νN_opt_eqn (h'@[(s,a)])))"
      using Suc
      by (auto intro!: SUP_cong Bochner_Integration.integral_cong Suc(1)[THEN cong])
    also have " = νN_opt_eqn h' s"
      using Suc νN_opt_eqn.simps by fastforce
    finally have "νN_opt_eqn h s = νN_opt_eqn h' s ".
  }
  thus ?case by auto
qed

(* Existence of a near-optimal Markovian deterministic policy: for eps > 0
   there is a p in ΠMD such that, for every history h of length at most N and
   every state s, νN_opt_eqn h s is bounded by
   νN_eval (mk_markovian_det p) h s plus (N - length h) times eps.

   Construction: for each step n < N and state s, p n s is chosen (Hilbert
   choice, SOME) as an action in A s whose one-step value is within eps of
   νN_opt_eqn on a canonical history of length n (replicate n ...). An
   arbitrary action of A s is used for n beyond N. Sup_wit_ex guarantees
   such an action exists, and νN_opt_eqn_markov transfers the property from
   the canonical history to any history of the same length.

   The final induction on N - length h accumulates one eps per step. *)
lemma νN_opt_le:
  fixes eps :: real
  assumes "eps > 0"
  shows "p  ΠMD. h s. length h  N  νN_eval (mk_markovian_det p) h s + real (N - length h) * eps  νN_opt_eqn h s" 
proof -
  define p where "p = (λn s. if n  N then SOME a. a  A s else
      SOME a. a  A s 
        r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (replicate n (s, SOME a. a  A s) @ [(s,a)])) + eps > νN_opt_eqn (replicate n (s,SOME a. a  A s)) s)"
  (* An eps-optimal action exists for every history shorter than N. *)
  have *: "a . a  A s 
        r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h@[(s,a)])) + eps > νN_opt_eqn h s" 
    if "length h < N"
    for h s
    using that Sup_wit_ex[OF assms A_ne, unfolded Bex_def] bounded_imp_bdd_above bounded_r_exp_νN_opt_eqn 
    by (auto simp: νN_opt_eqn.simps)
  (* Specialised to the canonical history of length n. *)
  hence **: "a . a  A s 
        r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn ((replicate n (s,SOME a. a  A s))@[(s,a)])) + eps > νN_opt_eqn (replicate n (s,SOME a. a  A s)) s" 
    if "n < N" for n s
    using that by simp
  (* Hence the chosen action p n s has the eps-optimality property. *)
  have p_prop: "p n s  A s  r (s, p n s) + measure_pmf.expectation (K (s, p n s)) (νN_opt_eqn ((replicate n (s,SOME a. a  A s))@[(s,p n s)])) + eps > νN_opt_eqn ((replicate n (s,SOME a. a  A s))) s"
    if "n < N" for n s
    using someI_ex[OF **[OF that], of s] that
    by (auto simp: p_def)
  (* Transfer from the canonical history to an arbitrary one of equal length. *)
  hence p_prop': "p (length h) s  A s  r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_opt_eqn (h@[(s,p (length h) s)])) + eps > νN_opt_eqn h s"
    if "length h < N" for h s
    using that
    by (auto simp: 
        νN_opt_eqn_markov[of h "(replicate (length h) (s,SOME a. a  A s))"] 
        νN_opt_eqn_markov[of "(h@[(s,p (length h) s)])" "(replicate (length h) (s, SOME a. a  A s) @ [(s, p (length h) s)])"])
  have "p n s  A s" for n s
    using SOME_is_dec_det is_dec_det_def p_def p_prop by auto
  hence p:"p  ΠMD"
    using is_dec_det_def by force
  {
    (* Generic argument: any ΠMD policy with the one-step eps-optimality
       property satisfies the claimed bound. *)
    fix h s p
    assume "p  ΠMD"
      and 
      p: "h s. length h < N  r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_opt_eqn (h@[(s,p (length h) s)])) + eps > νN_opt_eqn h s"
    have "length h  N  νN_eval (mk_markovian_det p) h s + real (N - length h) * eps  νN_opt_eqn h s"
    proof (induction "N - length h" arbitrary: h s)
      case 0
      hence *: "length h = N"
        by auto
      thus ?case
        by (auto simp: νN_opt_eqn.simps νN_eval.simps)
    next
      case (Suc x)
      hence *: "length h < N"
        by auto
      have "νN_opt_eqn h s - real (N - length h) * eps < r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_opt_eqn (h@[(s,p (length h) s)])) - real (N - length h) * eps + eps"
        using p[OF *, of s] by auto        
      also have " = r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_opt_eqn (h@[(s,p (length h) s)])) - real (N - length h - 1) * eps"
      proof -
        have "real (N - length h - 1) = real (N - length h) - 1"
          using * by (auto simp: algebra_simps) 
        thus ?thesis 
          by algebra
      qed
      (* Move the constant inside the expectation. *)
      also have " = r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (λs'. νN_opt_eqn (h@[(s,p (length h) s)]) s' - real (N - length h - 1) * eps)"  
        by (subst Bochner_Integration.integral_diff) (auto intro: pmf_bounded_integrable) 
      (* Induction hypothesis under the expectation. *)
      also have "  r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_eval (mk_markovian_det p) (h@[(s,p (length h) s)]))"
        using Suc(1)[of "h@[_]"] Suc *
        by (auto simp: algebra_simps intro!: integral_mono pmf_bounded_integrable bounded_minus_comp)
      also have " = νN_eval (mk_markovian_det p) h s"
        using Suc
        by (auto simp: mk_markovian_det_def νN_eval.simps)        
      finally show ?case 
        by auto
    qed
  }
  thus ?thesis
    using p p_prop' by blast
qed

(* Variant of νN_opt_le with a history-independent slack: for eps > 0 there
   is a p in ΠMD such that νN_opt_eqn h s is bounded by
   νN_eval (mk_markovian_det p) h s + eps for every history of length at
   most N.
   Proof: apply νN_opt_le with eps / N (case distinction on N = 0, where the
   division degenerates), rewrite (N - length h) * (eps / N) as
   eps - eps * length h / N, and drop the non-negative subtracted term.
   NOTE(review): the policy is obtained as a member of ΠMD but the last step
   cites the fact as written with ΠHD; confirm that both notations denote the
   same fact here. *)
lemma νN_opt_le':
  fixes eps :: real
  assumes "eps > 0"
  shows "p  ΠMD. h s. length h  N  νN_eval (mk_markovian_det p) h s + eps  νN_opt_eqn h s" 
proof -
  obtain p where "pΠMD" and "h s. length h  N  νN_opt_eqn h s  νN_eval (mk_markovian_det p) h s + real (N - length h) * (eps/N)"
    using νN_opt_le[of "eps / N"] νN_opt_le assms
    by (cases "N = 0") force+   
  hence **: "h s. length h  N  νN_opt_eqn h s  νN_eval (mk_markovian_det p) h s + eps - ((eps * length h) / N)"
    using assms
    by (cases "N = 0") (auto simp: algebra_simps of_nat_diff intro: add_increasing)
  moreover have *:"eps * real (length h) / N  0" for h 
    using assms by auto
  ultimately have "h s. length h  N  νN_opt_eqn h s  νN_eval (mk_markovian_det p) h s + eps"
    by (auto intro!: order.trans[OF **])
  thus ?thesis
    using p  ΠHD by blast
qed

(* Turning a policy from ΠHD into a randomised one with mk_det yields a
   member of ΠHR (unfolding is_policy, mk_det, is_dec and is_dec_det). *)
lemma mk_det_preserves: "p  ΠHD  (mk_det p)  ΠHR"
  unfolding is_policy_def mk_det_def
  by (auto simp: is_dec_def is_dec_det_def)

(* Turning a policy from ΠMD into a history-dependent randomised one with
   mk_markovian_det yields a member of ΠHR. *)
lemma mk_markovian_det_preserves: "p  ΠMD  (mk_markovian_det p)  ΠHR"
  unfolding is_policy_def mk_markovian_det_def
  by (auto simp: is_dec_def is_dec_det_def)

(* Optimality equation: for histories of length at most N, the recursion
   νN_opt_eqn h s equals the supremum over policies νN_eval_opt h s.
   One inequality is νN_opt_ge. The other is obtained by showing, for every
   eps > 0, that νN_opt_eqn is within eps of νN_eval_opt (via the
   eps-optimal policy from νN_opt_le') and then letting eps vanish
   (field_le_epsilon). *)
lemma νN_opt_eq:
  assumes "length h  N"
  shows "νN_opt_eqn h s = νN_eval_opt h s"
proof -
  {
    fix eps :: real
    assume "0 < eps"
    hence "pΠHR. h s. length h  N  νN_opt_eqn h s  νN_eval p h s + eps"
    using mk_markovian_det_preserves νN_opt_le'[of eps]
    by auto
  then obtain p where "pΠHR" and **: "length h  N  νN_opt_eqn h s  νN_eval p h s + eps" for h s
    by auto
  hence "length h  N  νN_opt_eqn h s  νN_eval_opt h s + eps" for h s
    using νN_eval_le_opt[of p]
    by (auto intro:  order.trans[OF **])
  }
  hence "length h  N  νN_opt_eqn h s  νN_eval_opt h s"
    by (meson field_le_epsilon)
  thus ?thesis
    using νN_opt_ge assms antisym by auto
qed

(* The optimal value νN_opt s is computed by the optimality recursion started
   with the empty history. *)
lemma νN_opt_eqn_correct: "νN_opt s = νN_opt_eqn [] s"
  using νN_eval_correct νN_eval_opt_def νN_opt_def νN_opt_eq by force

(* Sufficient condition for (eps-)optimality of a Markovian deterministic
   policy. Assumptions: eps is non-negative (the lemma is later instantiated
   with eps = 0), p is in ΠMD, and for every history shorter than N the
   action p (length h) s comes within eps of the supremum over A s in the
   optimality recursion.
   Conclusions:
     1. for histories of length at most N, νN_opt_eqn h s is bounded by
        νN_eval (mk_markovian_det p) h s plus (N - length h) times eps;
     2. νN_opt s is bounded by νN (mk_markovian_det p) s plus N times eps.
   Part 1 is by induction on N - length h; part 2 is part 1 for the empty
   history. *)
lemma thm_4_3_4:
  assumes "eps  0" "p  ΠMD"
    and "h s. length h < N  r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_opt_eqn (h@[(s, p (length h) s)])) + eps
     (a  A s. r (s, a) + measure_pmf.expectation (K (s,a)) (νN_opt_eqn (h@[(s, a)])))"
  shows "h s.  length h  N  νN_eval (mk_markovian_det p) h s + (N - length h) * eps  νN_opt_eqn h s"
    "s. νN (mk_markovian_det p) s + N * eps  νN_opt s"
proof -
  show "νN_eval (mk_markovian_det p) h s + (N - length h) * eps  νN_opt_eqn h s" if "length h  N" for h s
    using assms that
  proof (induction "N - length h" arbitrary: h s)
    case 0
    then show ?case
      using νN_eval.simps νN_opt_eqn.simps by force
  next
    case (Suc x)
    have "νN_opt_eqn h s = (a  A s. r (s, a) + measure_pmf.expectation (K (s,a)) (νN_opt_eqn (h@[(s, a)])))"
      using Suc.hyps(2) νN_opt_eqn.simps by fastforce
    (* The action chosen by p is within eps of the supremum. *)
    also have "  r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_opt_eqn (h@[(s, p (length h) s)])) + eps"
      using Suc.hyps(2) Suc.prems(3)
      by simp
    (* Induction hypothesis under the expectation. *)
    also have "  r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (λs'. νN_eval (mk_markovian_det p) (h@[(s, p (length h) s)]) s' + 
    (N - length (h@[(s,p (length h) s)])) * eps) + eps"
      using Suc(1)[of "(h@[(s,p (length h) s)])"] Suc.hyps(2) assms
      by (auto intro!: integral_mono pmf_bounded_integrable bounded_plus_comp)
    (* Pull the constant out of the expectation and collect the eps terms. *)
    also have " = r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_eval (mk_markovian_det p) (h@[(s, p (length h) s)])) + (N - length h) * eps"
      using Suc
      by (subst Bochner_Integration.integral_add) (auto simp: of_nat_diff left_diff_distrib distrib_right intro!: pmf_bounded_integrable)
    also have " = νN_eval (mk_markovian_det p) h s + (N - length h) * eps"
      using Suc
      by (auto simp add: νN_eval.simps mk_markovian_det_def)
    finally show ?case.
  qed
  from this[of  "[]"] show "νN (mk_markovian_det p) s + N * eps  νN_opt s" for s
    using νN_eval_correct νN_opt_eqn_correct
    by auto
qed

(* Existence of eps-optimal policies: for every eps > 0 there is a p in ΠMD
   whose value νN (mk_markovian_det p) s is within eps of νN_opt s in every
   state s. Obtained from νN_opt_le' at the empty history.
   NOTE(review): the policy is obtained as a member of ΠMD but the final step
   cites the fact as written with ΠHD; confirm that both notations denote the
   same fact here. *)
lemma νN_has_eps_opt_pol:
  assumes "eps > 0"
  shows "p  ΠMD. s. νN (mk_markovian_det p) s + eps  νN_opt s"
proof -
  obtain p where "pΠMD" and 
    P: "h s. length h  N  νN_opt_eqn h s  νN_eval (mk_markovian_det p) h s + eps"
    using νN_opt_le'[of eps] assms by auto
  from P[of "[]"] have "νN_opt_eqn [] s  νN_eval (mk_markovian_det p) [] s + eps" for s
    by auto
  thus ?thesis
    unfolding νN_opt_eqn_correct
    using νN_eval_correct p  ΠHD by auto 
qed

(* The value of any policy in ΠHR is bounded by the optimal value νN_opt
   (from νN_eval_le_opt_eqn at the empty history). *)
lemma νN_le_opt: "p  ΠHR  νN p s  νN_opt s"
  by (metis νN_eval_correct νN_eval_le_opt_eqn νN_opt_eqn_correct)

(* Existence of an optimal Markovian deterministic policy when the supremum
   is attained: if for every history shorter than N some action in A s
   realises the supremum in the optimality recursion, then there is a p in
   ΠMD with νN (mk_markovian_det p) s = νN_opt s for all s.

   Construction: p n s is chosen (SOME) as an action attaining the value of
   νN_opt_eqn on a canonical history of length n; an arbitrary action of
   A s is used for n beyond N. The inequality "optimal value bounded by the
   value of p" comes from thm_4_3_4 with eps = 0, where νN_opt_eqn_markov
   moves between an arbitrary history and the canonical one. The converse is
   νN_le_opt. *)
lemma νN_has_opt_pol:
  assumes "h s. 
    length h < N 
     a  A s. r (s, a) + measure_pmf.expectation (K (s,a)) (νN_opt_eqn (h@[(s,a)]))
    = (a  A s. r (s, a) + measure_pmf.expectation (K (s,a)) (νN_opt_eqn (h@[(s,a)])))" 
  shows "p  ΠMD. s. νN (mk_markovian_det p) s = νN_opt s"
proof -
  define p where "p = (λn s. if n  N then SOME a. a  A s else
      SOME a. a  A s 
        r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (replicate n (s,SOME a. a  A s)@[(s,a)])) = νN_opt_eqn (replicate n (s,SOME a. a  A s)) s
)"
  (* For n < N only the second branch of the definition is relevant. *)
  have p_short: "p n s = (
      SOME a. a  A s 
        r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (replicate n (s,SOME a. a  A s)@[(s,a)])) = νN_opt_eqn (replicate n (s,SOME a. a  A s)) s)"
    if "n < N" for n s
    unfolding p_def using that by auto
  (* p always picks admissible actions, and for n < N it attains the
     supremum on the canonical history. *)
  have *: "p n s  A s" 
    "(n < N   r (s, p n s) + measure_pmf.expectation (K (s,p n s)) (νN_opt_eqn ((replicate n (s,SOME a. a  A s))@[(s,p n s)]))
    = (a  A s. r (s, a) + measure_pmf.expectation (K (s,a)) (νN_opt_eqn ((replicate n (s,SOME a. a  A s))@[(s,a)]))))" for n s
      using someI_ex[OF assms[unfolded Bex_def]] SOME_is_dec_det
      by (auto simp: νN_opt_eqn.simps is_dec_det_def p_def)
    have "νN (mk_markovian_det p) s  νN_opt s" for s
    proof (intro thm_4_3_4(2)[of 0 p, simplified])
      show "n. is_dec_det (p n)"
        using *
        by (auto simp: is_dec_det_def)
    next
      {
        fix h :: "('s × 'a) list" and s
        assume "length h < N"
        (* Replace h by the canonical history of the same length ... *)
        have "(aA s. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)]))) = 
          (a  A s. r (s, a) + measure_pmf.expectation (K (s,a)) (νN_opt_eqn ((replicate (length h) (s,SOME a. a  A s))@[(s,a)])))"
          using length h < N
          by (auto intro!: SUP_cong Bochner_Integration.integral_cong νN_opt_eqn_markov[THEN cong])
        (* ... use that p attains the supremum there ... *)
        also have " = r (s, p (length h) s) + measure_pmf.expectation (K (s,p (length h) s)) (νN_opt_eqn ((replicate (length h) (s,SOME a. a  A s))@[(s,p (length h) s)]))"
          using * length h < N by presburger
        (* ... and go back to the original history. *)
        also have " = r (s, p (length h) s) + measure_pmf.expectation (K (s,p (length h) s)) (νN_opt_eqn (h@[(s,p (length h) s)]))"
          using length h < N
          by (auto intro!: Bochner_Integration.integral_cong νN_opt_eqn_markov[THEN cong])
        finally show "(aA s. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)])))
            r (s, p (length h) s) + measure_pmf.expectation (K (s, p (length h) s)) (νN_opt_eqn (h @ [(s, p (length h) s)]))"
          by auto
      }
    qed
    hence "νN (mk_markovian_det p) s = νN_opt s" for s
      using νN_le_opt *(1) mk_markovian_det_preserves
      by (simp add: is_dec_det_def order_antisym)
    thus ?thesis
      using *(1)
      by (auto simp: is_dec_det_def)
qed

(* On a finite non-empty set X, some element attains the maximum of f over
   X. *)
lemma ex_Max: "finite X  X  {}  x  X.  f x = Max (f ` X)"
  by (metis (mono_tags, opaque_lifting) Max_in empty_is_image finite_imageI imageE)

(* With finite action sets A s, an optimal Markovian deterministic policy
   exists: the supremum over A s is a maximum (cSup_eq_Max, ex_Max), so the
   hypothesis of νN_has_opt_pol is met. *)
lemma fin_A_imp_opt_pol:
  assumes "s. finite (A s)"
  shows "pΠMD. s. νN (mk_markovian_det p) s = νN_opt s"
  using A_ne assms νN_has_opt_pol
  by (fastforce simp: cSup_eq_Max intro!: ex_Max)


section ‹Backward Induction›

(* Backward-induction value, indexed by the step n and the state s:
     - at n = N it is the terminal reward r_fin s;
     - beyond the horizon (n > N) it is 0;
     - otherwise it aggregates, over the actions a in A s, the immediate reward
       r (s,a) plus the expectation under K (s,a) of the value at step Suc n.
   The aggregation over A s is the supremum (bw_ind_aux_eq handles it with
   SUP_cong).
   Unlike νN_opt_eqn, which takes a history h, this function takes only a step
   counter; bw_ind_aux_eq relates the two via length h. *)
function bw_ind_aux where
  "bw_ind_aux n s = (
    if n = N then r_fin s else 
    if n > N then 0 else
      a  A s. (r (s,a) +
        measure_pmf.expectation (K (s,a)) (λs'. bw_ind_aux (Suc n) s'))) "
  by auto

(* Termination measure: N minus the first argument (bound as h in the lambda
   below, although it is the step counter n of the definition). The only
   recursive call is on Suc n in the branch where n is neither N nor above N. *)
termination
  by (relation "Wellfounded.measure (λ(h,s). N - h)") auto

(* Take the defining equation out of the simp set; proofs below pass
   bw_ind_aux.simps explicitly where unfolding is wanted. *)
lemmas bw_ind_aux.simps[simp del]


(* The step-indexed value agrees with the history-indexed optimality equation:
   bw_ind_aux evaluated at length h equals νN_opt_eqn h s.
   Proved along the induction rule of νN_opt_eqn, unfolding both defining
   equations and comparing the suprema/integrals pointwise via the
   congruence rules SUP_cong and integral_cong. *)
lemma bw_ind_aux_eq: "bw_ind_aux (length h) s = νN_opt_eqn h s"
  by (induction h s rule: νN_opt_eqn.induct)
    (auto simp: bw_ind_aux.simps νN_opt_eqn.simps split: if_splits intro!: Bochner_Integration.integral_cong SUP_cong)

(* Table-passing variant of backward induction.
   m is a table of value functions indexed by step i and state s.
   bw_ind_aux' (Suc n) m builds m', which agrees with m everywhere except at
   index n, where the entry becomes the aggregate over a in A s of
   r (s,a) plus the expectation under K (s,a) of the already available entry
   m (Suc n); it then continues with n and m'.
   bw_ind_aux' 0 m returns the table unchanged.
   So a call with first argument k rewrites the entries k-1, k-2, ..., 0 in
   that order and leaves all others alone (see bw_ind_aux'_const). *)
fun bw_ind_aux' where
  "bw_ind_aux' (Suc n) m = (
    let m' = (λi s. 
      if i = n 
      then (a  A s. (r (s,a) + measure_pmf.expectation (K (s,a)) (m (Suc n))))
      else m i s) in
    bw_ind_aux' n m')" |
  "bw_ind_aux' 0 m = m"

(* Backward induction over the whole horizon: run bw_ind_aux' for N steps on
   the initial table that holds r_fin at index N and 0 at every other index.
   bw_ind_correct relates the resulting table to bw_ind_aux. *)
definition "bw_ind = bw_ind_aux' N (λi s. if i = N then r_fin s else 0)"


(* Entries outside the rewritten range are left untouched: under the
   assumption relating i and n, bw_ind_aux' n m i = m i.
   NOTE(review): the comparison in the assumption reads as "i is at least n";
   this matches the definition, where bw_ind_aux' n only rewrites indices
   strictly below n.
   Declared [simp]. Proof by induction on n, generalising m and i. *)
lemma bw_ind_aux'_const[simp]:
  assumes "i  n"
  shows "bw_ind_aux' n m i = m i"
  using assms
proof (induction n arbitrary: m i)
  case 0
  (* bw_ind_aux' 0 m = m by definition. *)
  then show ?case by (auto simp: bw_ind_aux'.simps)
next
  case (Suc n)
  (* One unfolding step changes only index n, which differs from i here;
     the induction hypothesis handles the remaining call. *)
  then show ?case
    by auto
qed

(* The computed entry at an index i < n depends only on the input table at
   indices above i: if m and m' agree at every j > i, then bw_ind_aux' n
   yields the same value at (i, s) for both.
   Proof by induction on n, generalising both tables and the index. *)
lemma bw_ind_aux'_indep:
  assumes "i < n" and
    "j. j > i  m j = m' j"
  shows "bw_ind_aux' n m i s = bw_ind_aux' n m' i s"
  using assms
proof (induction n arbitrary: m i m')
  case 0
  (* i < 0 is impossible. *)
  then show ?case
    by fastforce
next
  case (Suc n)
  show ?case
  proof (cases "i < n")
    case True
    (* i is still rewritten by the inner call: apply the induction hypothesis
       to the two updated tables, which again agree above i. *)
    then show ?thesis
      by (auto intro!: Suc(1) ext simp: Suc(2,3))
  next
    case False
    (* Together with i < Suc n this forces i = n, the index rewritten in this
       very step from entry Suc n, where m and m' agree. *)
    then show ?thesis
      using Suc.prems(1) less_Suc_eq
      by (auto simp: Suc)
  qed  
  qed

(* Recursion equation for the computed table: for i < n, the entry of
   bw_ind_aux' n m at (i, s) is the aggregate over a in A s of r (s,a) plus the
   expectation under K (s,a) of the computed entry at index Suc i.
   This has the same shape as the recursive branch of bw_ind_aux and is what
   bw_ind_correct unfolds in its step case. *)
lemma bw_ind_aux'_simps': "i < n  bw_ind_aux' n m i s = (a  A s. (r (s,a) + measure_pmf.expectation (K (s,a)) (bw_ind_aux'  n m (Suc i))))"
proof (induction n arbitrary: m i s)
  case 0
  (* No index lies below 0. *)
  then show ?case by auto
next
  case (Suc n)
  (* Unfold one step of the definition (the let-bound table written out). *)
  have "bw_ind_aux' (Suc n) m i s = bw_ind_aux' n (λi s. if i = n then aA s. r (s, a) + measure_pmf.expectation (K (s, a)) (m (Suc n)) else m i s) i s"
    by auto
  (* Case split on how n compares with i: either i is the index rewritten in
     this step (bw_ind_aux'_const applies), or the induction hypothesis does. *)
  also have " = (aA s. r (s, a) + measure_pmf.expectation (K (s, a)) ((bw_ind_aux' (Suc n) m (Suc i))))"
    using Suc.prems le_less_Suc_eq
    by (cases "n  i") (auto simp: Suc.IH bw_ind_aux'_const)
  finally show ?case.
qed

(* Correctness of the table-based computation: under the premise relating n
   and N, the table entry bw_ind n coincides with the recursive definition
   bw_ind_aux n.
   NOTE(review): the premise reads as "n does not exceed N", consistent with
   the induction on N - n below.
   bw_ind is unfolded first, so the goal is about bw_ind_aux' N applied to the
   initial table. *)
lemma bw_ind_correct: "n  N  bw_ind n = bw_ind_aux n"
  unfolding bw_ind_def
proof (induction "N - n" arbitrary: n)
  case 0
  (* N - n = 0: unfold bw_ind_aux once on the right-hand side; the left-hand
     side is an untouched entry of the initial table. *)
  show ?case
    using 0
    by (subst bw_ind_aux.simps) (auto)
next
  case (Suc x)
  (* N - n = Suc x: both sides satisfy the same one-step equation
     (bw_ind_aux'_simps' and bw_ind_aux.simps); the induction hypothesis
     identifies the values at Suc n. *)
  thus ?case
    by (auto simp: bw_ind_aux'_simps' bw_ind_aux.simps intro!: ext)
qed

(* Decision rule extracted from backward induction, parameterised by a
   selector d that picks one action out of a set of actions.
   In the first branch (guard comparing n with N) the rule returns d (A s).
   Otherwise it returns d applied to the set of actions that are arg-max,
   among a in A s, of r (s, a) plus the expectation under K (s, a) of
   bw_ind_aux (Suc n).
   NOTE(review): the guard reads as "n is at least N", i.e. the first branch
   covers steps at or past the horizon - confirm.
   The lemmas about this definition assume that d X lies in X for non-empty X. *)
definition "bw_ind_pol_gen (d :: 'a set  'a) n s = (
  if n  N then d (A s)
  else
    d ({a . is_arg_max (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (bw_ind_aux (Suc n))) (λa. a  A s) a}))"

(* The action selected by d from the arg-max set is itself an arg-max of
   r (s, a) plus the expectation under K (s, a) of bw_ind_aux (Suc n) over A s.
   Assumptions: d returns a member of every non-empty set it is given, and
   every action set A s is finite. *)
lemma bw_ind_pol_is_arg_max:
  assumes "X. X  {}  d X  X" "s. finite (A s)"
  shows "is_arg_max (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (bw_ind_aux (Suc n))) (λa. a  A s) (d ({a . is_arg_max (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (bw_ind_aux (Suc n))) (λa. a  A s) a}))"
proof -
  (* ?s abbreviates the set of maximising actions. *)
  let ?s = "{a. is_arg_max (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (bw_ind_aux (Suc n))) (λa. a  A s) a}"
  (* The selector assumption is instantiated at the arg-max set; that set is
     non-empty by finite_is_arg_max, using finiteness and A_ne. *)
  have d ?s  ?s
    using assms(1)[of " {a. is_arg_max (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (bw_ind_aux (Suc n))) (λa. a  A s) a}"]
    using  finite_is_arg_max A_ne assms
    by (auto simp add: finite_is_arg_max)
  (* Membership in the set comprehension is exactly the arg-max property. *)
  thus ?thesis
    by auto
qed

(* Well-formedness of the extracted decision rule: bw_ind_pol_gen d belongs to
   ΠMD, provided d selects a member of every non-empty set and all action sets
   are finite.
   This fact shares its name with the definition bw_ind_pol_gen; the defining
   equation is referred to as bw_ind_pol_gen_def. *)
lemma bw_ind_pol_gen:
  assumes "X. X  {}  d X  X" "s. finite (A s)"
  shows "bw_ind_pol_gen d  ΠMD"
proof -
  (* A selection from a non-empty subset X of Y lands in Y. *)
  have ***:"X  {}  X  Y  d X  Y" for X Y 
    using assms
    by auto
  (* The arg-max set is non-empty for every step n and state s. *)
  have "a. is_arg_max (λa. r (s, a) + measure_pmf.expectation (K (s, a)) (bw_ind_aux (Suc n))) (λa. a  A s) a" for n s
    using finite_is_arg_max[OF assms(2) A_ne]
    by auto
  (* After unfolding the definition and is_dec_det, each branch is an instance
     of *** . *)
  thus ?thesis
    unfolding bw_ind_pol_gen_def is_dec_det_def
    by (force intro!: ***)
qed

(* Optimality of the backward-induction policy.
   Assumptions: d selects a member of every non-empty set, every action set
   A s is finite, and the history h satisfies the length bound against N.
   Conclusion: evaluating the markovian policy built from bw_ind_pol_gen d
   after history h in state s gives the optimal value.
   The lemma was previously anonymous and so could not be cited; it now
   carries the name bw_ind_pol_gen_opt. Statement and proof are unchanged. *)
lemma bw_ind_pol_gen_opt:
  assumes "X. X  {}  d X  X"  "s. finite (A s)" "length h  N"
  shows "νN_eval (mk_markovian_det (bw_ind_pol_gen d)) h s = νN_eval_opt h s"
proof -
  (* Step 1: for histories shorter than N, compare the supremum of the
     one-step lookahead over A s with the lookahead at the action chosen by
     bw_ind_pol_gen d. νN_opt_eqn is rewritten to bw_ind_aux (bw_ind_aux_eq,
     used right-to-left) so that bw_ind_pol_is_arg_max applies; cSUP_least
     bounds the supremum. *)
  have "(h s. length h < N 
            (aA s. r (s, a) + measure_pmf.expectation (K (s, a)) (νN_opt_eqn (h @ [(s, a)])))
             r (s, bw_ind_pol_gen d (length h) s) +
               measure_pmf.expectation (K (s, bw_ind_pol_gen d (length h) s))
                (νN_opt_eqn (h @ [(s, bw_ind_pol_gen d (length h) s)])))"
    using A_ne bw_ind_pol_is_arg_max[OF assms(1,2)]
    unfolding bw_ind_aux_eq[symmetric]
    by (auto intro!: cSUP_least simp: bw_ind_pol_gen_def)
  (* Step 2: thm_4_3_4 turns the stepwise comparison into a comparison between
     νN_opt_eqn and the value of the policy, for every admissible history. *)
  hence "length h  N  νN_opt_eqn h s  νN_eval (mk_markovian_det (bw_ind_pol_gen d)) h s" for h s
    using assms bw_ind_pol_gen thm_4_3_4[of 0 "bw_ind_pol_gen d", simplified]
    by auto
  (* Step 3: the converse bound is νN_eval_le_opt; antisymmetry closes the goal. *)
  thus ?thesis
    using νN_opt_eq νN_eval_le_opt assms bw_ind_pol_gen mk_markovian_det_preserves
    by (auto intro!: antisym)
qed


(* bw_ind_correct restated with the definition of bw_ind unfolded: under the
   same premise on n and N, the table produced by bw_ind_aux' from the initial
   table (r_fin at index N, 0 elsewhere) equals bw_ind_aux at index n. *)
lemma bw_ind_aux'_eq: "n  N  bw_ind_aux' N (λi s. if i = N then r_fin s else 0) n = bw_ind_aux n"
  using bw_ind_def bw_ind_correct by presburger
end

end