(* Theory Markov_Decision_Process *)

(* Author: Johannes Hölzl <hoelzl@in.tum.de> *)

section ‹Markov Decision Processes›

theory Markov_Decision_Process
  imports Discrete_Time_Markov_Chain
begin

(* Picks an arbitrary member of the set s using Hilbert's choice operator SOME. *)
definition "some_elem s = (SOME x. x ∈ s)"

(* For a non-empty set, the chosen element is a member of the set. *)
lemma some_elem_ne: "s ≠ {} ⟹ some_elem s ∈ s"
  unfolding some_elem_def by (auto intro: someI)

subsection ‹Configurations›

text ‹

We want to construct a \emph{non-free} codatatype
  ‹'s cfg = Cfg (state: 's) (action: 's pmf) (cont: 's ⇒ 's cfg)›
with the restriction
  @{term "state (cont cfg s) = s"}.

›

(* Hide any existing constant called cont; the name is reused for configurations below. *)
hide_const cont

(* Free codatatype of schedulers: a scheduler offers one action (a distribution over
   successor states) and, for every state, a scheduler to continue with. *)
codatatype 's scheduler = Scheduler (action_sch: "'s pmf") (cont_sch: "'s ⇒ 's scheduler")

(* The componentwise product of two equivalence relations is an equivalence relation. *)
lemma equivp_rel_prod: "equivp R ⟹ equivp Q ⟹ equivp (rel_prod R Q)"
  by (auto intro!: equivpI prod.rel_symp prod.rel_transp prod.rel_reflp elim: equivpE)

(* Coinductive equivalence of schedulers: both choose the same action D, and their
   continuations are again equivalent for every state in the support of D.
   Continuations for states outside the support are not compared. *)
coinductive eq_scheduler :: "'s scheduler ⇒ 's scheduler ⇒ bool"
where
  "⋀D. action_sch sc1 = D ⟹ action_sch sc2 = D ⟹
    (∀s∈D. eq_scheduler (cont_sch sc1 s) (cont_sch sc2 s)) ⟹ eq_scheduler sc1 sc2"

(* eq_scheduler is reflexive; proved by coinduction and registered as an intro rule. *)
lemma eq_scheduler_refl[intro]: "eq_scheduler sc sc"
  by (coinduction arbitrary: sc) auto

(* Configurations: pairs of a current state and a scheduler, identified when the
   states are equal and the schedulers are eq_scheduler-equivalent.  The proof
   obligation is that the relation is an equivalence: reflexivity is discharged by
   the final auto, symmetry and transitivity of eq_scheduler by coinduction. *)
quotient_type 's cfg = "'s × 's scheduler" / "rel_prod (=) eq_scheduler"
proof (intro equivp_rel_prod equivpI reflpI sympI transpI)
  show "eq_scheduler sc1 sc2 ⟹ eq_scheduler sc2 sc1" for sc1 sc2 :: "'s scheduler"
    by (coinduction arbitrary: sc1 sc2) (auto elim: eq_scheduler.cases)
  show "eq_scheduler sc1 sc2 ⟹ eq_scheduler sc2 sc3 ⟹ eq_scheduler sc1 sc3"
    for sc1 sc2 sc3 :: "'s scheduler"
    by (coinduction arbitrary: sc1 sc2 sc3)
       (subst (asm) (1 2) eq_scheduler.simps, auto)
qed auto

(* Current state of a configuration: the first component of the representing pair. *)
lift_definition state :: "'s cfg ⇒ 's" is "fst"
  by auto

(* Action chosen in a configuration: the action of the representing scheduler.
   Well-defined because equivalent schedulers offer the same action. *)
lift_definition action :: "'s cfg ⇒ 's pmf" is "λ(s, sc). action_sch sc"
  by (force elim: eq_scheduler.cases)

(* Continuation of a configuration after moving to state t; the new state is always t.
   If t lies in the support of the chosen action, the scheduler's own continuation
   for t is used.  Otherwise the continuation for an arbitrary support element
   (some_elem) is used, since eq_scheduler only relates continuations on the support. *)
lift_definition cont :: "'s cfg ⇒ 's ⇒ 's cfg" is
  "λ(s, sc) t. if t ∈ action_sch sc then (t, cont_sch sc t) else
    (t, cont_sch sc (some_elem (action_sch sc)))"
  apply (simp add: rel_prod_conv split: prod.splits)
  apply (subst (asm) eq_scheduler.simps)
  apply (auto simp: Let_def set_pmf_not_empty[THEN some_elem_ne])
  done

(* Constructor for configurations: state s, action D, and a continuation function c.
   Only the scheduler component (snd) of each c t is stored; the state of c t is dropped. *)
lift_definition Cfg :: "'s ⇒ 's pmf ⇒ ('s ⇒ 's cfg) ⇒ 's cfg" is
  "λs D c. (s, Scheduler D (λt. snd (c t)))"
  by (auto simp: rel_prod_conv split_beta' eq_scheduler.simps[of "Scheduler _  _"])

(* Corecursor for configurations: starting in state s with internal memory x, the
   action is D x and the memory after moving to state t is C x t.  The Inr injection
   means the scheduler corecursion always continues and never stops early. *)
lift_definition cfg_corec :: "'s ⇒ ('a ⇒ 's pmf) ⇒ ('a ⇒ 's ⇒ 'a) ⇒ 'a ⇒ 's cfg" is
  "λs D C x. (s, corec_scheduler D (λx s. Inr (C x s)) x)"  .

(* The continuation for state s is a configuration whose state is s. *)
lemma state_cont[simp]: "state (cont cfg s) = s"
  by transfer (simp split: prod.split)

(* Selector equation: state of a constructed configuration. *)
lemma state_Cfg[simp]: "state (Cfg s d' c') = s"
  by transfer simp

(* Selector equation: action of a constructed configuration. *)
lemma action_Cfg[simp]: "action (Cfg s d' c') = d'"
  by transfer simp

(* Selector equation for cont.  It holds only for t in the support of the action and
   only when the supplied continuation c' t really is in state t, because Cfg
   discards the state component of c' t. *)
lemma cont_Cfg[simp]: "t ∈ set_pmf d' ⟹ state (c' t) = t ⟹ cont (Cfg s d' c') t = c' t"
  by transfer (auto simp add: rel_prod_conv split: prod.split)

(* The corecursively defined configuration is in its start state s. *)
lemma state_cfg_corec[simp]: "state (cfg_corec s d c x) = s"
  by transfer auto

(* Its action is obtained by applying d to the current memory x. *)
lemma action_cfg_corec[simp]: "action (cfg_corec s d c x) = d x"
  by transfer auto

(* Unfolding of cfg_corec along a step to a state t in the support of the action:
   the new state is t and the memory is updated to c x t. *)
lemma cont_cfg_corec[simp]: "t ∈ set_pmf (d x) ⟹ cont (cfg_corec s d c x) t = cfg_corec t d c (c x t)"
  by transfer auto

(* Coinduction principle for equality of configurations.  Two configurations related
   by X are equal if X implies equal states and equal actions, and is preserved by
   continuations along states in the support of the action. *)
lemma cfg_coinduct[consumes 1, case_names state action cont, coinduct pred]:
  "X c d ⟹ (⋀c d. X c d ⟹ state c = state d) ⟹ (⋀c d. X c d ⟹ action c = action d) ⟹
    (⋀c d t. X c d ⟹ t ∈ set_pmf (action c) ⟹ X (cont c t) (cont d t)) ⟹ c = d"
proof (transfer, clarsimp)
  (* After transfer the goal is about representing pairs; what remains to show is
     equivalence of the two scheduler components. *)
  fix X :: "('a × 'a scheduler) ⇒ ('a × 'a scheduler) ⇒ bool" and B s1 s2 sc1 sc2
  assume X: "X (s1, sc1) (s2, sc2)" and "rel_fun cr_cfg (rel_fun cr_cfg (=)) X B"
    and 1: "⋀s1 sc1 s2 sc2. X (s1, sc1) (s2, sc2) ⟹ s1 = s2"
    and 2: "⋀s1 sc1 s2 sc2. X (s1, sc1) (s2, sc2) ⟹ action_sch sc1 = action_sch sc2"
    and 3: "⋀s1 sc1 s2 sc2 t. X (s1, sc1) (s2, sc2) ⟹ t ∈ set_pmf (action_sch sc2) ⟹
      X (t, cont_sch sc1 t) (t, cont_sch sc2 t)"
  from X show "eq_scheduler sc1 sc2"
    by (coinduction arbitrary: s1 s2 sc1 sc2)
       (blast dest: 2 3)
qed

(* Relator for configurations: the states are related by P, and the actions are
   related by rel_pmf in such a way that the corresponding continuations are again
   related coinductively. *)
coinductive rel_cfg :: "('a ⇒ 'b ⇒ bool) ⇒ 'a cfg ⇒ 'b cfg ⇒ bool" for P :: "'a ⇒ 'b ⇒ bool"
where
  "P (state cfg1) (state cfg2) ⟹
    rel_pmf (λs t. rel_cfg P (cont cfg1 s) (cont cfg2 t)) (action cfg1) (action cfg2) ⟹
    rel_cfg P cfg1 cfg2"

(* Destruction rule: related configurations have P-related states. *)
lemma rel_cfg_state: "rel_cfg P cfg1 cfg2 ⟹ P (state cfg1) (state cfg2)"
  by (auto elim: rel_cfg.cases)

(* Destruction rule: the actions of related configurations are coupled so that the
   continuations are related again. *)
lemma rel_cfg_cont:
  "rel_cfg P cfg1 cfg2 ⟹
    rel_pmf (λs t. rel_cfg P (cont cfg1 s) (cont cfg2 t)) (action cfg1) (action cfg2)"
  by (auto elim: rel_cfg.cases)

(* The actions of related configurations are related by rel_pmf P.  This weakens
   rel_cfg_cont by monotonicity of rel_pmf, using that related continuations have
   P-related states (rel_cfg_state together with the simp rule state_cont). *)
lemma rel_cfg_action:
  assumes P: "rel_cfg P cfg1 cfg2" shows "rel_pmf P (action cfg1) (action cfg2)"
proof (rule pmf.rel_mono_strong)
  show "rel_pmf (λs t. rel_cfg P (cont cfg1 s) (cont cfg2 t)) (action cfg1) (action cfg2)"
    using P by (rule rel_cfg_cont)
qed (auto dest: rel_cfg_state)

(* The relator instantiated with equality is equality on configurations. *)
lemma rel_cfg_eq: "rel_cfg (=) cfg1 cfg2 ⟷ cfg1 = cfg2"
proof safe
  (* Left to right: coinduction with cfg_coinduct; only the cont case needs work. *)
  show "rel_cfg (=) cfg1 cfg2 ⟹ cfg1 = cfg2"
  proof (coinduction arbitrary: cfg1 cfg2)
    case cont
    have "action cfg1 = action cfg2"
      using ‹rel_cfg (=) cfg1 cfg2› by (auto dest: rel_cfg_action simp: pmf.rel_eq)
    then have "rel_pmf (λs t. rel_cfg (=) (cont cfg1 s) (cont cfg2 t)) (action cfg1) (action cfg1)"
      using cont by (auto dest: rel_cfg_cont)
    (* Related continuations have equal states, so the coupling is on the diagonal. *)
    then have "rel_pmf (λs t. rel_cfg (=) (cont cfg1 s) (cont cfg2 t) ∧ s = t) (action cfg1) (action cfg1)"
      by (rule pmf.rel_mono_strong) (auto dest: rel_cfg_state)
    then have "pred_pmf (λs. rel_cfg (=) (cont cfg1 s) (cont cfg2 s)) (action cfg1)"
      unfolding pmf.pred_rel by (rule pmf.rel_mono_strong) (auto simp: eq_onp_def)
    with ‹t ∈ action cfg1› show ?case
      by (auto simp: pmf.pred_set)
  qed (auto dest: rel_cfg_state rel_cfg_action simp: pmf.rel_eq)
  (* Right to left: reflexivity of rel_cfg (=), again by coinduction. *)
  show "rel_cfg (=) cfg2 cfg2"
    by (coinduction arbitrary: cfg2) (auto intro!: rel_pmf_reflI)
qed

subsection ‹Configuration with Memoryless Scheduler›

(* Configuration in state s that always chooses action f applied to the current
   state: the corecursion memory is the current state itself. *)
definition "memoryless_on f s = cfg_corec s f (λ_ t. t) s"

(* Selector equations for memoryless configurations.  The continuation equation is
   stated only for successor states t in the support of f s. *)
lemma
  shows state_memoryless_on[simp]: "state (memoryless_on f s) = s"
    and action_memoryless_on[simp]: "action (memoryless_on f s) = f s"
    and cont_memoryless_on[simp]: "t ∈ (f s) ⟹ cont (memoryless_on f s) t = memoryless_on f t"
  by (simp_all add: memoryless_on_def)

(* Transition kernel on configurations: sample a successor state from the chosen
   action and move to the corresponding continuation. *)
definition K_cfg :: "'s cfg ⇒ 's cfg pmf" where
  "K_cfg cfg = map_pmf (cont cfg) (action cfg)"

(* The support of K_cfg cfg is the image of the action's support under cont cfg. *)
lemma set_K_cfg: "set_pmf (K_cfg cfg) = cont cfg ` set_pmf (action cfg)"
  by (simp add: K_cfg_def)

(* Change of variables: integrating over successor configurations equals integrating
   over successor states and applying cont. *)
lemma nn_integral_K_cfg: "(∫⇧+cfg. f cfg ∂K_cfg cfg) = (∫⇧+s. f (cont cfg s) ∂action cfg)"
  by (simp add: K_cfg_def map_pmf_rep_eq nn_integral_distr)

subsection ‹MDP Kernel and Induced Configurations›

(* An MDP is given by a kernel K that assigns to every state a set of enabled actions,
   each action being a distribution over successor states.  K_wf requires that at
   least one action is enabled in every state. *)
locale Markov_Decision_Process =
  fixes K :: "'s ⇒ 's pmf set"
  assumes K_wf: "⋀s. K s ≠ {}"
begin

(* Edge relation: pairs (s, t) where t is in the support of some action enabled in s. *)
definition "E = (SIGMA s:UNIV. ⋃D∈K s. set_pmf D)"

(* cfg_onp s cfg: cfg is a configuration of this MDP started in s.  Its state is s,
   its action is enabled in s, and every continuation along the support of the
   action is again a configuration of the MDP (coinductively). *)
coinductive cfg_onp :: "'s ⇒ 's cfg ⇒ bool" where
  "⋀s. state cfg = s ⟹ action cfg ∈ K s ⟹ (⋀t. t ∈ action cfg ⟹ cfg_onp t (cont cfg t)) ⟹
    cfg_onp s cfg"

(* Set version of cfg_onp: all configurations satisfying cfg_onp for state s. *)
definition "cfg_on s = {cfg. cfg_onp s cfg}"

(* Destruction rules (action, continuation, state) and the introduction rule for
   cfg_on, obtained from the cases and intros of cfg_onp. *)
lemma
  shows cfg_onD_action[intro, simp]: "cfg ∈ cfg_on s ⟹ action cfg ∈ K s"
    and cfg_onD_cont[intro, simp]: "cfg ∈ cfg_on s ⟹ t ∈ action cfg ⟹ cont cfg t ∈ cfg_on t"
    and cfg_onD_state[simp]: "cfg ∈ cfg_on s ⟹ state cfg = s"
    and cfg_onI: "state cfg = s ⟹ action cfg ∈ K s ⟹ (⋀t. t ∈ action cfg ⟹ cont cfg t ∈ cfg_on t) ⟹ cfg ∈ cfg_on s"
  by (auto simp: cfg_on_def intro: cfg_onp.intros elim: cfg_onp.cases)

(* Coinduction rule for membership in cfg_on: exhibit an invariant P that fixes the
   state, guarantees an enabled action, and is preserved by continuations along the
   support of the action. *)
lemma cfg_on_coinduct[coinduct set: cfg_on]:
  assumes "P s cfg"
  assumes "⋀cfg s. P s cfg ⟹ state cfg = s"
  assumes "⋀cfg s. P s cfg ⟹ action cfg ∈ K s"
  assumes "⋀cfg s t. P s cfg ⟹ t ∈ action cfg ⟹ P t (cont cfg t)"
  shows "cfg ∈ cfg_on s"
  using assms cfg_onp.coinduct[of P s cfg] by (simp add: cfg_on_def)

(* A memoryless configuration whose choice function only picks enabled actions is a
   configuration of the MDP. *)
lemma memoryless_on_cfg_onI:
  assumes "⋀s. f s ∈ K s"
  shows "memoryless_on f s ∈ cfg_on s"
  by (coinduction arbitrary: s) (auto intro: assms)

(* Building a configuration from an enabled action D and, for every state in the
   support of D, a configuration of the MDP for that state, yields a configuration
   of the MDP. *)
lemma cfg_of_cfg_onI:
  "D ∈ K s ⟹ (⋀t. t ∈ D ⟹ c t ∈ cfg_on t) ⟹ Cfg s D c ∈ cfg_on s"
  by (rule cfg_onI) auto

(* An arbitrary enabled action for state s, chosen with SOME. *)
definition "arb_act s = (SOME D. D ∈ K s)"

(* The arbitrary action is enabled, since K s is non-empty by K_wf. *)
lemma arb_actI[simp]: "arb_act s ∈ K s"
  by (simp add: arb_act_def some_in_eq K_wf)

(* Every state has at least one configuration; the memoryless configuration that
   follows arb_act is a witness. *)
lemma cfg_on_not_empty[intro, simp]: "cfg_on s ≠ {}"
  by (auto intro: memoryless_on_cfg_onI arb_actI)

(* View K_cfg as the kernel of a Markov chain whose states are configurations. *)
sublocale MC: MC_syntax K_cfg .

(* Measurable space of state streams over the discrete space of all states. *)
abbreviation St :: "'s stream measure" where
  "St ≡ stream_space (count_space UNIV)"

subsection ‹Trace Space›

(* Trace measure of a configuration: the image of the configuration-chain measure
   MC.T cfg under mapping state over the stream. *)
definition "T cfg = distr (MC.T cfg) St (smap state)"

(* T cfg is a probability space for every configuration cfg. *)
sublocale T: prob_space "T cfg" for cfg
  by (simp add: T_def MC.T.prob_space_distr)

(* T cfg has the same underlying space as St. *)
lemma space_T[simp]: "space (T cfg) = space St"
  by (simp add: T_def)

(* T cfg has the same measurable sets as St. *)
lemma sets_T[simp]: "sets (T cfg) = sets St"
  by (simp add: T_def)

(* Measurability out of T cfg coincides with measurability out of St. *)
lemma measurable_T1[simp]: "measurable (T cfg) N = measurable St N"
  by (simp add: T_def)

(* Measurability into T cfg coincides with measurability into St. *)
lemma measurable_T2[simp]: "measurable N (T cfg) = measurable N St"
  by (simp add: T_def)

(* One-step unfolding of the integral over traces: first draw the successor
   configuration cfg' from K_cfg cfg, then integrate over the traces of cfg' with
   the state of cfg' prepended. *)
lemma nn_integral_T:
  assumes [measurable]: "f ∈ borel_measurable St"
  shows "(∫⇧+X. f X ∂T cfg) = (∫⇧+cfg'. (∫⇧+x. f (state cfg' ## x) ∂T cfg') ∂K_cfg cfg)"
  by (simp add: T_def MC.nn_integral_T[of _ cfg] nn_integral_distr)

(* One-step unfolding of the trace measure as a monadic bind: draw the successor
   configuration and prepend its state to a trace of that successor. *)
lemma T_eq:
  "T cfg = (measure_pmf (K_cfg cfg) ⤜ (λcfg'. distr (T cfg') St (λω. state cfg' ## ω)))"
proof (rule measure_eqI)
  (* Both measures agree on every measurable set, via nn_integral_T on indicators. *)
  fix A assume "A ∈ sets (T cfg)"
  then show "emeasure (T cfg) A =
    emeasure (measure_pmf (K_cfg cfg) ⤜ (λcfg'. distr (T cfg') St (λω. state cfg' ## ω))) A"
    by (subst emeasure_bind[where N=St])
       (auto simp: space_subprob_algebra nn_integral_distr nn_integral_indicator[symmetric] nn_integral_T[of _ cfg]
             simp del: nn_integral_indicator intro!: prob_space_imp_subprob_space T.prob_space_distr)
qed simp

(* Under a memoryless configuration the trace measure is the trace measure of the
   Markov chain with kernel ct. *)
lemma T_memoryless_on: "T (memoryless_on ct s) = MC_syntax.T ct s"
proof -
  interpret ct: MC_syntax ct .
  (* Both sides satisfy the same one-step equation, so they agree by ct.T_bisim. *)
  have "T ∘ (memoryless_on ct) = MC_syntax.T ct"
  proof (rule ct.T_bisim[symmetric])
    fix s show "(T ∘ memoryless_on ct) s =
        measure_pmf (ct s) ⤜ (λs. distr ((T ∘ memoryless_on ct) s) St ((##) s))"
      by (auto simp add: T_eq[of "memoryless_on ct s"] K_cfg_def map_pmf_rep_eq bind_distr[where K=St]
                         space_subprob_algebra T.prob_space_distr prob_space_imp_subprob_space
               intro!: bind_measure_pmf_cong)
  qed (simp_all, intro_locales)
  then show ?thesis by (simp add: fun_eq_iff)
qed

(* Transfer of least fixed points through the trace integral: the integral of a
   function defined as a least fixed point over streams equals a least fixed point
   over configurations.  g must be measurable, sup-continuous in its second
   argument, and commute with integration (int_g). *)
lemma nn_integral_T_lfp:
  assumes [measurable]: "case_prod g ∈ borel_measurable (count_space UNIV ⨂⇩M borel)"
  assumes cont_g: "⋀s. sup_continuous (g s)"
  assumes int_g: "⋀f cfg. f ∈ borel_measurable (stream_space (count_space UNIV)) ⟹
    (∫⇧+ω. g (state cfg) (f ω) ∂T cfg) = g (state cfg) (∫⇧+ω. f ω ∂T cfg)"
  shows "(∫⇧+ω. lfp (λf ω. g (shd ω) (f (stl ω))) ω ∂T cfg) =
    lfp (λf cfg. ∫⇧+t. g (state t) (f t) ∂K_cfg cfg) cfg"
proof (rule nn_integral_lfp)
  show "⋀s. sets (T s) = sets St"
      "⋀F. F ∈ borel_measurable St ⟹ (λa. g (shd a) (F (stl a))) ∈ borel_measurable St"
    by auto
next
  (* Step equation: unfold the trace integral once and pull g out of the inner integral. *)
  fix s and F :: "'s stream ⇒ ennreal" assume "F ∈ borel_measurable St"
  then show "(∫⇧+ a. g (shd a) (F (stl a)) ∂T s) =
           (∫⇧+ cfg. g (state cfg) (integral⇧N (T cfg) F) ∂K_cfg s)"
    by (rewrite nn_integral_T) (simp_all add: int_g)
qed (auto intro!: order_continuous_intros cont_g[THEN sup_continuous_compose])

(* One-step unfolding of the probability of a measurable trace property. *)
lemma emeasure_Collect_T:
  assumes [measurable]: "Measurable.pred St P"
  shows "emeasure (T cfg) {x∈space St. P x} =
    (∫⇧+cfg'. emeasure (T cfg') {x∈space St. P (state cfg' ## x)} ∂K_cfg cfg)"
  using MC.emeasure_Collect_T[of "λx. P (smap state x)" cfg]
  by (simp add: nn_integral_distr emeasure_Collect_distr T_def)

(* Supremal expectation: the supremum, over all configurations of the MDP started in
   s, of the integral of f with respect to the trace measure. *)
definition E_sup :: "'s ⇒ ('s stream ⇒ ennreal) ⇒ ennreal"
where
  "E_sup s f = (⨆cfg∈cfg_on s. ∫⇧+x. f x ∂T cfg)"

(* The supremal expectation of a constant function is that constant. *)
lemma E_sup_const: "0 ≤ c ⟹ E_sup s (λ_. c) = c"
  using T.emeasure_space_1 by (simp add: E_sup_def)

(* A constant factor can be pulled out of the supremal expectation. *)
lemma E_sup_mult_right:
  assumes [measurable]: "f ∈ borel_measurable St" and [simp]: "0 ≤ c"
  shows "E_sup s (λx. c * f x) = c * E_sup s f"
  by (simp add: nn_integral_cmult E_sup_def SUP_mult_left_ennreal)

(* E_sup is monotone in the integrand. *)
lemma E_sup_mono:
  "(⋀ω. f ω ≤ g ω) ⟹ E_sup s f ≤ E_sup s g"
  unfolding E_sup_def by (intro SUP_subset_mono order_refl nn_integral_mono)

(* E_sup is subadditive: the two suprema on the right may be attained by different
   configurations, hence only an inequality. *)
lemma E_sup_add:
  assumes [measurable]: "f ∈ borel_measurable St" "g ∈ borel_measurable St"
  shows "E_sup s (λx. f x + g x) ≤ E_sup s f + E_sup s g"
proof -
  have "E_sup s (λx. f x + g x) = (⨆cfg∈cfg_on s. (∫⇧+x. f x ∂T cfg) + (∫⇧+x. g x ∂T cfg))"
    by (simp add: E_sup_def nn_integral_add)
  also have "… ≤ (⨆cfg∈cfg_on s. ∫⇧+x. f x ∂T cfg) + (⨆cfg∈cfg_on s. (∫⇧+x. g x ∂T cfg))"
    by (auto simp: SUP_le_iff intro!: add_mono SUP_upper)
  finally show ?thesis
    by (simp add: E_sup_def)
qed

(* Adding a constant on the right commutes with E_sup. *)
lemma E_sup_add_left:
  assumes [measurable]: "f ∈ borel_measurable St"
  shows "E_sup s (λx. f x + c) = E_sup s f + c"
  by (simp add: nn_integral_add E_sup_def T.emeasure_space_1[simplified] ennreal_SUP_add_left)

(* Adding a constant on the left commutes with E_sup; derived from E_sup_add_left by
   commutativity of addition. *)
lemma E_sup_add_right:
  "f ∈ borel_measurable St ⟹ E_sup s (λx. c + f x) = c + E_sup s f"
  using E_sup_add_left[of f s c] by (simp add: add.commute)

(* E_sup commutes with suprema of increasing sequences of measurable functions
   (monotone convergence plus exchange of the two suprema). *)
lemma E_sup_SUP:
  assumes [measurable]: "⋀i. f i ∈ borel_measurable St" and [simp]: "incseq f"
  shows "E_sup s (λx. ⨆i. f i x) = (⨆i. E_sup s (f i))"
  by (auto simp add: E_sup_def nn_integral_monotone_convergence_SUP intro: SUP_commute)

(* One-step unfolding of E_sup: the supremum over all configurations equals the
   supremum over the enabled actions D of the D-expectation of the supremal
   expectations in the successor states. *)
lemma E_sup_iterate:
  assumes [measurable]: "f ∈ borel_measurable St"
  shows "E_sup s f = (⨆D∈K s. ∫⇧+ t. E_sup t (λω. f (t ## ω)) ∂measure_pmf D)"
proof -
  (* ?v: expectation under one fixed successor configuration;
     ?p: supremal expectation for a successor state. *)
  let ?v = "λt. ∫⇧+x. f (state t ## x) ∂T t"
  let ?p = "λt. E_sup t (λω. f (t ## ω))"
  have "E_sup s f = (⨆cfg∈cfg_on s. ∫⇧+t. ?v t ∂K_cfg cfg)"
    unfolding E_sup_def by (intro SUP_cong refl) (subst nn_integral_T, simp_all add: cfg_on_def)
  also have "… = (⨆D∈K s. ∫⇧+t. ?p t ∂measure_pmf D)"
  proof (intro antisym SUP_least)
    (* Direction 1: every configuration is dominated by its own first action. *)
    fix cfg :: "'s cfg" assume cfg: "cfg ∈ cfg_on s"
    then show "(∫⇧+ t. ?v t ∂K_cfg cfg) ≤ (SUP D∈K s. ∫⇧+t. ?p t ∂measure_pmf D)"
      by (auto simp: E_sup_def nn_integral_K_cfg AE_measure_pmf_iff
               intro!: nn_integral_mono_AE SUP_upper2)
  next
    (* Direction 2: for an action D build configurations that come close to ?p. *)
    fix D assume D: "D ∈ K s" show "(∫⇧+t. ?p t ∂D) ≤ (SUP cfg ∈ cfg_on s. ∫⇧+ t. ?v t ∂K_cfg cfg)"
    proof cases
      (* Case: ?p is finite on the support of D; approximate up to an arbitrary e > 0. *)
      assume p_finite: "∀t∈D. ?p t < ∞"
      show ?thesis
      proof (rule ennreal_le_epsilon)
        fix e :: real assume "0 < e"
        have "∀t∈D. ∃cfg∈cfg_on t. ?p t ≤ ?v cfg + e"
        proof
          fix t assume "t ∈ D"
          moreover have "(SUP cfg ∈ cfg_on t. ?v cfg) = ?p t"
            unfolding E_sup_def by (simp add: cfg_on_def)
          ultimately have "(SUP cfg ∈ cfg_on t. ?v cfg) ≠ ∞"
            using p_finite by auto
          from SUP_approx_ennreal[OF ‹0<e› _ refl this]
          show "∃cfg∈cfg_on t. ?p t ≤ ?v cfg + e"
            by (auto simp add: E_sup_def intro: less_imp_le)
        qed
        (* Choose one e-optimal successor configuration per support state. *)
        then obtain cfg' where v_cfg': "⋀t. t ∈ D ⟹ ?p t ≤ ?v (cfg' t) + e" and
          cfg_on_cfg': "⋀t. t ∈ D ⟹ cfg' t ∈ cfg_on t"
          unfolding Bex_def bchoice_iff by blast

        let ?cfg = "Cfg s D cfg'"
        have cfg: "K_cfg ?cfg = map_pmf cfg' D"
          by (auto simp add: K_cfg_def fun_eq_iff cfg_on_cfg' intro!: map_pmf_cong)

        have "(∫⇧+ t. ?p t ∂D) ≤ (∫⇧+t. ?v (cfg' t) + e ∂D)"
          by (intro nn_integral_mono_AE) (simp add: v_cfg' AE_measure_pmf_iff)
        also have "… = (∫⇧+t. ?v (cfg' t) ∂D) + e"
          using ‹0 < e› measure_pmf.emeasure_space_1[of D]
          by (subst nn_integral_add) (auto intro: cfg_on_cfg' )
        also have "(∫⇧+t. ?v (cfg' t) ∂D) = (∫⇧+t. ?v t ∂K_cfg ?cfg)"
          by (simp add: cfg map_pmf_rep_eq nn_integral_distr)
        also have "… ≤ (SUP cfg∈cfg_on s. (∫⇧+t. ?v t ∂K_cfg cfg))"
          by (auto intro!: SUP_upper intro!: cfg_of_cfg_onI D cfg_on_cfg')
        finally show "(∫⇧+ t. ?p t ∂D) ≤ (SUP cfg ∈ cfg_on s. ∫⇧+ t. ?v t ∂K_cfg cfg) + e"
          by (blast intro: add_mono)
      qed
    next
      (* Case: ?p is infinite at some support state t; then the right-hand side is
         infinite as well, witnessed by configurations that are arbitrary
         (memoryless on arb_act) except at t. *)
      assume "¬ (∀t∈D. ?p t < ∞)"
      then obtain t where "t ∈ D" "?p t = ∞"
        by (auto simp: not_less top_unique)
      then have "∞ = pmf (D) t * ?p t"
        by (auto simp: ennreal_mult_top set_pmf_iff)
      also have "… = (SUP cfg ∈ cfg_on t. pmf (D) t * ?v cfg)"
        unfolding E_sup_def
        by (auto simp: SUP_mult_left_ennreal[symmetric])
      also have "… ≤ (SUP cfg ∈ cfg_on s. ∫⇧+ t. ?v t ∂K_cfg cfg)"
        unfolding E_sup_def
      proof (intro SUP_least SUP_upper2)
        fix cfg :: "'s cfg" assume cfg: "cfg ∈ cfg_on t"

        let ?cfg = "Cfg s D ((memoryless_on arb_act) (t := cfg))"
        have C: "K_cfg ?cfg = map_pmf ((memoryless_on arb_act) (t := cfg)) D"
          by (auto simp add: K_cfg_def fun_eq_iff intro!: map_pmf_cong simp: cfg)

        show "?cfg ∈ cfg_on s"
          by (auto intro!: cfg_of_cfg_onI D cfg memoryless_on_cfg_onI)
        have "ennreal (pmf (D) t) * (∫⇧+ x. f (state cfg ## x) ∂T cfg) =
          (∫⇧+t'. (∫⇧+ x. f (state cfg ## x) ∂T cfg) * indicator {t} t' ∂D)"
          by (auto simp add:  max_def emeasure_pmf_single intro: mult_ac)
        also have "… = (∫⇧+cfg. ?v cfg * indicator {t} (state cfg) ∂K_cfg ?cfg)"
          unfolding C using cfg
          by (auto simp add: nn_integral_distr map_pmf_rep_eq split: split_indicator
                   simp del: nn_integral_indicator_singleton
                   intro!: nn_integral_cong)
        also have "… ≤ (∫⇧+cfg. ?v cfg ∂K_cfg ?cfg)"
          by (auto intro!: nn_integral_mono  split: split_indicator)
        finally show "ennreal (pmf (D) t) * (∫⇧+ x. f (state cfg ## x) ∂T cfg)
           ≤ (∫⇧+ t. ∫⇧+ x. f (state t ## x) ∂T t ∂K_cfg ?cfg)" .
      qed
      finally show ?thesis
        by (simp add: top_unique del: Sup_eq_top_iff SUP_eq_top_iff)
    qed
  qed
  finally show ?thesis .
qed

(* The supremal expectation of the bottom function is 0. *)
lemma E_sup_bot: "E_sup s ⊥ = 0"
  by (auto simp add: E_sup_def bot_ennreal)

(* Fixed-point characterisation of E_sup: for a functional l on streams that applies
   g to the head and recurses on the tail, the supremal expectation of lfp l is the
   least fixed point of the corresponding one-step operator on states. *)
lemma E_sup_lfp:
  fixes g
  defines "l ≡ λf ω. g (shd ω) (f (stl ω))"
  assumes measurable_g[measurable]: "case_prod g ∈ borel_measurable (count_space UNIV ⨂⇩M borel)"
  assumes cont_g: "⋀s. sup_continuous (g s)"
  assumes int_g: "⋀f cfg. f ∈ borel_measurable St ⟹
     (∫⇧+ ω. g (state cfg) (f ω) ∂T cfg) = g (state cfg) (integral⇧N (T cfg) f)"
  shows "(λs. E_sup s (lfp l)) = lfp (λf s. ⨆D∈K s. ∫⇧+t. g t (f t) ∂measure_pmf D)"
proof (rule lfp_transfer_bounded[where α="λF s. E_sup s F" and f=l and P="λf. f ∈ borel_measurable St"])
  show "sup_continuous (λf s. ⨆x∈K s. ∫⇧+ t. g t (f t) ∂measure_pmf x)"
    using cont_g[THEN sup_continuous_compose] by (auto intro!: order_continuous_intros)
  show "sup_continuous l"
    using cont_g[THEN sup_continuous_compose] by (auto intro!: order_continuous_intros simp: l_def)
  show "⋀F. (λs. E_sup s ⊥) ≤ (λs. ⨆D∈K s. ∫⇧+ t. g t (F t) ∂measure_pmf D)"
    using K_wf by (auto simp: E_sup_bot le_fun_def intro: SUP_upper2 )
next
  (* Commutation of E_sup with one application of l, via E_sup_iterate. *)
  fix f :: "'s stream ⇒ ennreal" assume f: "f ∈ borel_measurable St"
  moreover
  have "E_sup s (λω. g s (f ω)) = g s (E_sup s f)" for s
    unfolding E_sup_def using int_g[OF f]
    by (subst SUP_sup_continuous_ennreal[OF cont_g, symmetric])
       (auto intro!: SUP_cong simp del: cfg_onD_state dest: cfg_onD_state[symmetric])
  ultimately show "(λs. E_sup s (l f)) = (λs. ⨆D∈K s. ∫⇧+ t. g t (E_sup t f) ∂measure_pmf D)"
    by (subst E_sup_iterate) (auto simp: l_def int_g fun_eq_iff intro!: SUP_cong nn_integral_cong)
qed (auto simp: bot_fun_def l_def SUP_apply[abs_def] E_sup_SUP)

(* Supremal probability of a trace property P over all configurations started in s. *)
definition "P_sup s P = (⨆cfg∈cfg_on s. emeasure (T cfg) {x∈space St. P x})"

(* The supremal probability is the supremal expectation of the indicator function. *)
lemma P_sup_eq_E_sup:
  assumes [measurable]: "Measurable.pred St P"
  shows "P_sup s P = E_sup s (indicator {x∈space St. P x})"
  by (auto simp add: P_sup_def E_sup_def intro!: SUP_cong nn_integral_cong)

(* The always-true property has supremal probability 1. *)
lemma P_sup_True[simp]: "P_sup t (λω. True) = 1"
  using T.emeasure_space_1
  by (auto simp add: P_sup_def SUP_constant)

(* The always-false property has supremal probability 0. *)
lemma P_sup_False[simp]: "P_sup t (λω. False) = 0"
  by (auto simp add: P_sup_def SUP_constant)

(* P_sup commutes with countable unions of an increasing family of properties. *)
lemma P_sup_SUP:
  fixes P :: "nat ⇒ 's stream ⇒ bool"
  assumes "mono P" and P[measurable]: "⋀i. Measurable.pred St (P i)"
  shows "P_sup s (λx. ∃i. P i x) = (⨆i. P_sup s (P i))"
proof -
  have "P_sup s (λx. ∃i. P i x) = (⨆cfg∈cfg_on s. emeasure (T cfg) (⋃i. {x∈space St. P i x}))"
    by (auto simp: P_sup_def intro!: SUP_cong arg_cong2[where f=emeasure])
  (* Continuity from below of the measure on the increasing family. *)
  also have "… = (⨆cfg∈cfg_on s. ⨆i. emeasure (T cfg) {x∈space St. P i x})"
    using ‹mono P› by (auto intro!: SUP_cong SUP_emeasure_incseq[symmetric] simp: mono_def le_fun_def)
  also have "… = (⨆i. P_sup s (P i))"
    by (subst SUP_commute) (simp add: P_sup_def)
  finally show ?thesis
    by simp
qed

(* For a sup-continuous, measurability-preserving predicate transformer Q, the
   supremal probability of lfp Q is the supremum over its finite iterations
   starting from bottom. *)
lemma P_sup_lfp:
  assumes Q: "sup_continuous Q"
  assumes f: "f ∈ measurable St M"
  assumes Q_m: "⋀P. Measurable.pred M P ⟹ Measurable.pred M (Q P)"
  shows "P_sup s (λx. lfp Q (f x)) = (⨆i. P_sup s (λx. (Q ^^ i) ⊥ (f x)))"
  unfolding sup_continuous_lfp[OF Q]
  apply simp
proof (rule P_sup_SUP)
  fix i show "Measurable.pred St (λx. (Q ^^ i) ⊥ (f x))"
    apply (intro measurable_compose[OF f])
    by (induct i) (auto intro!: Q_m)
qed (intro mono_funpow sup_continuous_mono[OF Q] mono_compose[where f=f])

(* One-step unfolding of P_sup; an instance of E_sup_iterate for indicator functions. *)
lemma P_sup_iterate:
  assumes [measurable]: "Measurable.pred St P"
  shows "P_sup s P = (⨆D∈K s. ∫⇧+ t. P_sup t (λω. P (t ## ω)) ∂measure_pmf D)"
proof -
  (* Shift the prepended state from the argument of the indicator into the set. *)
  have [simp]: "⋀x s. indicator {x ∈ space St. P x} (x ## s) = indicator {s ∈ space St. P (x ## s)} s"
    by (auto simp: space_stream_space split: split_indicator)
  show ?thesis
    using E_sup_iterate[of "indicator {x∈space St. P x}" s] by (auto simp: P_sup_eq_E_sup)
qed

(* Infimal expectation: the infimum, over all configurations of the MDP started in
   s, of the integral of f with respect to the trace measure. *)
definition "E_inf s f = (⨅cfg∈cfg_on s. ∫⇧+x. f x ∂T cfg)"

(* The infimal expectation of a constant function is that constant. *)
lemma E_inf_const: "0 ≤ c ⟹ E_inf s (λ_. c) = c"
  using T.emeasure_space_1 by (simp add: E_inf_def)

(* E_inf is monotone in the integrand. *)
lemma E_inf_mono:
  "(⋀ω. f ω ≤ g ω) ⟹ E_inf s f ≤ E_inf s g"
  unfolding E_inf_def by (intro INF_superset_mono order_refl nn_integral_mono)

(* One-step unfolding of E_inf: the infimum over all configurations equals the
   infimum over the enabled actions D of the D-expectation of the infimal
   expectations in the successor states. *)
lemma E_inf_iterate:
  assumes [measurable]: "f ∈ borel_measurable St"
  shows "E_inf s f = (⨅D∈K s. ∫⇧+ t. E_inf t (λω. f (t ## ω)) ∂measure_pmf D)"
proof -
  (* ?v: expectation under one fixed successor configuration;
     ?p: infimal expectation for a successor state. *)
  let ?v = "λt. ∫⇧+x. f (state t ## x) ∂T t"
  let ?p = "λt. E_inf t (λω. f (t ## ω))"
  have "E_inf s f = (⨅cfg∈cfg_on s. ∫⇧+t. ?v t ∂K_cfg cfg)"
    unfolding E_inf_def by (intro INF_cong refl) (subst nn_integral_T, simp_all add: cfg_on_def)
  also have "… = (⨅D∈K s. ∫⇧+t. ?p t ∂measure_pmf D)"
  proof (intro antisym INF_greatest)
    (* Direction 1: every configuration dominates the value of its own first action. *)
    fix cfg :: "'s cfg" assume cfg: "cfg ∈ cfg_on s"
    then show "(INF D∈K s. ∫⇧+t. ?p t ∂measure_pmf D) ≤ (∫⇧+ t. ?v t ∂K_cfg cfg)"
      by (auto simp add: E_inf_def nn_integral_K_cfg AE_measure_pmf_iff intro!: nn_integral_mono_AE INF_lower2)
  next
    (* Direction 2: for an action D build a configuration that is e-close to ?p. *)
    fix D assume D: "D ∈ K s" show "(INF cfg ∈ cfg_on s. ∫⇧+ t. ?v t ∂K_cfg cfg) ≤ (∫⇧+t. ?p t ∂D)"
    proof (rule ennreal_le_epsilon)
      fix e :: real assume "0 < e"
      have "∀t∈D. ∃cfg∈cfg_on t. ?v cfg ≤ ?p t + e"
      proof
        fix t assume "t ∈ D"
        show "∃cfg∈cfg_on t. ?v cfg ≤ ?p t + e"
        proof cases
          (* If the infimum is infinite, any configuration will do. *)
          assume "?p t = ∞" with cfg_on_not_empty[of t] show ?thesis
            by (auto simp: top_add simp del: cfg_on_not_empty)
        next
          assume p_finite: "?p t ≠ ∞"
          note ‹t ∈ D›
          moreover have "(INF cfg ∈ cfg_on t. ?v cfg) = ?p t"
            unfolding E_inf_def by (simp add: cfg_on_def)
          ultimately have "(INF cfg ∈ cfg_on t. ?v cfg) ≠ ∞"
            using p_finite by auto
          from INF_approx_ennreal[OF ‹0 < e› refl this]
          show "∃cfg∈cfg_on t. ?v cfg ≤ ?p t + e"
            by (auto simp: E_inf_def intro: less_imp_le)
        qed
      qed
      (* Choose one e-optimal successor configuration per support state. *)
      then obtain cfg' where v_cfg': "⋀t. t ∈ D ⟹ ?v (cfg' t) ≤ ?p t + e" and
        cfg_on_cfg': "⋀t. t ∈ D ⟹ cfg' t ∈ cfg_on t"
        unfolding Bex_def bchoice_iff by blast

      let ?cfg = "Cfg s D cfg'"

      have cfg: "K_cfg ?cfg = map_pmf cfg' D"
        by (auto simp add: K_cfg_def cfg_on_cfg' intro!: map_pmf_cong)

      have "?cfg ∈ cfg_on s"
        by (auto intro: D cfg_on_cfg' cfg_of_cfg_onI)
      then have "(INF cfg ∈ cfg_on s. ∫⇧+ t. ?v t ∂K_cfg cfg) ≤ (∫⇧+ t. ?p t + e ∂D)"
        by (rule INF_lower2) (auto simp: cfg map_pmf_rep_eq nn_integral_distr v_cfg' AE_measure_pmf_iff intro!: nn_integral_mono_AE)
      also have "… = (∫⇧+ t. ?p t ∂D) + e"
        using ‹0 < e› by (simp add: nn_integral_add measure_pmf.emeasure_space_1[simplified])
      finally show "(INF cfg ∈ cfg_on s. ∫⇧+ t. ?v t ∂K_cfg cfg) ≤ (∫⇧+ t. ?p t ∂D) + e" .
    qed
  qed
  finally show ?thesis .
qed

(* The trace measure assigns mass 1 to the whole stream space. *)
lemma emeasure_T_const[simp]: "emeasure (T s) (space St) = 1"
  using T.emeasure_space_1[of s] by simp

(* A lower bound for the expectation under every configuration is a lower bound for E_inf. *)
lemma E_inf_greatest:
  "(⋀cfg. cfg ∈ cfg_on s ⟹ x ≤ (∫⇧+x. f x ∂T cfg)) ⟹ x ≤ E_inf s f"
  unfolding E_inf_def by (rule INF_greatest)

(* An upper bound for the expectation under one configuration is an upper bound for E_inf. *)
lemma E_inf_lower2:
  "cfg ∈ cfg_on s ⟹ (∫⇧+x. f x ∂T cfg) ≤ x ⟹ E_inf s f ≤ x"
  unfolding E_inf_def by (rule INF_lower2)

text ‹
  Maybe the following statement can be generalized to infinite @{term "K s"}.
›

(* Fixed-point characterisation of E_inf, under the extra assumption that every state
   has only finitely many enabled actions (K_finite).  Finiteness is used to pick,
   in each state, an action attaining the infimum of the one-step operator. *)
lemma E_inf_lfp:
  fixes g
  defines "l ≡ λf ω. g (shd ω) (f (stl ω))"
  assumes measurable_g[measurable]: "case_prod g ∈ borel_measurable (count_space UNIV ⨂⇩M borel)"
  assumes cont_g: "⋀s. sup_continuous (g s)"
  assumes int_g: "⋀f cfg. f ∈ borel_measurable St ⟹
     (∫⇧+ ω. g (state cfg) (f ω) ∂T cfg) = g (state cfg) (integral⇧N (T cfg) f)"
  assumes K_finite: "⋀s. finite (K s)"
  shows "(λs. E_inf s (lfp l)) = lfp (λf s. ⨅D∈K s. ∫⇧+t. g t (f t) ∂measure_pmf D)"
proof (rule antisym)
  (* ?F: one-step operator on states; ?I D: value of action D at the fixed point. *)
  let ?F = "λF s. ⨅D∈K s. ∫⇧+ t. g t (F t) ∂measure_pmf D"
  let ?I = "λD. (∫⇧+t. g t (lfp ?F t) ∂measure_pmf D)"
  have mono_F: "mono ?F"
    using sup_continuous_mono[OF cont_g]
    by (force intro!: INF_mono nn_integral_mono monoI simp: mono_def le_fun_def)
  (* ct s: an enabled action whose value equals the fixed point at s. *)
  define ct where "ct s = (SOME D. D ∈ K s ∧ (lfp ?F s = ?I D))" for s
  { fix s
    have "finite (?I ` K s)"
      by (auto intro: K_finite)
    (* A finite non-empty set of values has a minimum, attained by some action D. *)
    then obtain D where "D ∈ K s" "?I D = Min (?I ` K s)"
      by (auto simp: K_wf dest!: Min_in)
    note this(2)
    also have "… = (INF D ∈ K s. ?I D)"
      using K_wf by (subst Min_Inf) (auto intro: K_finite)
    also have "… = lfp ?F s"
      by (rewrite in "_ = ⌑" lfp_unfold[OF mono_F]) auto
    finally have "∃D. D ∈ K s ∧ (lfp ?F s = ?I D)"
      using ‹D ∈ K s› by auto
    then have "ct s ∈ K s ∧ (lfp ?F s = ?I (ct s))"
      unfolding ct_def by (rule someI_ex)
    then have "ct s ∈ K s" "lfp ?F s = ?I (ct s)"
      by auto }
  note ct = this
  then have ct_cfg_on[simp]: "⋀s. memoryless_on ct s ∈ cfg_on s"
    by (intro memoryless_on_cfg_onI) simp
  (* Direction 1: the memoryless configuration following ct witnesses E_inf below lfp ?F. *)
  then show "(λs. E_inf s (lfp l)) ≤ lfp ?F"
  proof (intro le_funI, rule E_inf_lower2)
    fix s
    define P where "P f cfg = ∫⇧+ t. g (state t) (f t) ∂K_cfg cfg" for f cfg
    have "integral⇧N (T (memoryless_on ct s)) (lfp l) = lfp P (memoryless_on ct s)"
      unfolding P_def l_def using measurable_g cont_g int_g by (rule nn_integral_T_lfp)
    also have "… = (SUP i. (P ^^ i) ⊥) (memoryless_on ct s)"
      by (rewrite sup_continuous_lfp)
         (auto intro!: order_continuous_intros cont_g[THEN sup_continuous_compose] simp: P_def)
    also have "… = (SUP i. (P ^^ i) ⊥ (memoryless_on ct s))"
      by (simp add: image_comp)
    also have "… ≤ lfp ?F s"
    proof (rule SUP_least)
      (* Every finite iterate stays below the fixed point; induction on i. *)
      fix i show "(P ^^ i) ⊥ (memoryless_on ct s) ≤ lfp ?F s"
      proof (induction i arbitrary: s)
        case 0 then show ?case
          by simp
      next
        case (Suc n)
        have "(P ^^ Suc n) ⊥ (memoryless_on ct s) =
          (∫⇧+ t. g t ((P ^^ n) ⊥ (memoryless_on ct t)) ∂ct s)"
          by (auto simp add: P_def K_cfg_def AE_measure_pmf_iff intro!: nn_integral_cong_AE)
        also have "… ≤ (∫⇧+ t. g t (lfp ?F t) ∂ct s)"
          by (intro nn_integral_mono sup_continuous_mono[OF cont_g, THEN monoD] Suc)
        also have "… = lfp ?F s"
          by (rule  ct(2) [symmetric])
        finally show ?case .
      qed
    qed
    finally show "integral⇧N (T (memoryless_on ct s)) (lfp l) ≤ lfp ?F s" .
  qed

  have cont_l: "sup_continuous l"
    by (auto simp: l_def intro!: order_continuous_intros cont_g[THEN sup_continuous_compose])

  (* Direction 2: E_inf of lfp l is a pre-fixed point of ?F, hence above lfp ?F. *)
  show "lfp ?F ≤ (λs. E_inf s (lfp l))"
  proof (intro lfp_lowerbound le_funI)
    fix s show "(⨅x∈K s. ∫⇧+ t. g t (E_inf t (lfp l)) ∂measure_pmf x) ≤ E_inf s (lfp l)"
    proof (rewrite in "_ ≤ ⌑" E_inf_iterate)
      show l: "lfp l ∈ borel_measurable St"
        using cont_l by (rule borel_measurable_lfp) (simp add: l_def)
      show "(⨅D∈K s. ∫⇧+ t. g t (E_inf t (lfp l)) ∂measure_pmf D) ≤
        (⨅D∈K s. ∫⇧+ t. E_inf t (λω. lfp l (t ## ω)) ∂measure_pmf D)"
      proof (rule INF_mono nn_integral_mono bexI)+
        fix t D assume "D ∈ K s"
        { fix cfg assume "cfg ∈ cfg_on t"
          have "(∫⇧+ ω. g (state cfg) (lfp l ω) ∂T cfg) = g (state cfg) (∫⇧+ ω. (lfp l ω) ∂T cfg)"
            using l by (rule int_g)
          with ‹cfg ∈ cfg_on t› have *: "(∫⇧+ ω. g t (lfp l ω) ∂T cfg) = g t (∫⇧+ ω. (lfp l ω) ∂T cfg)"
            by simp }
        then
        (* Monotonicity of g t lets it be moved inside the infimum (as an inequality). *)
        have *: "g t (⨅cfg∈cfg_on t. integral⇧N (T cfg) (lfp l)) ≤ (⨅cfg∈cfg_on t. ∫⇧+ ω. g t (lfp l ω) ∂T cfg)"
          apply simp
          apply (rule INF_greatest)
          apply (rule sup_continuous_mono[OF cont_g, THEN monoD])
          apply (rule INF_lower)
          apply assumption
          done
        show "g t (E_inf t (lfp l)) ≤ E_inf t (λω. lfp l (t ## ω))"
          apply (rewrite in "_ ≤ ⌑" lfp_unfold[OF sup_continuous_mono[OF cont_l]])
          apply (rewrite in "_ ≤ ⌑" l_def)
          apply (simp add: E_inf_def *)
          done
      qed
    qed
  qed
qed

(* Infimal probability of a trace property P over all configurations started in s. *)
definition "P_inf s P = (⨅cfg∈cfg_on s. emeasure (T cfg) {x∈space St. P x})"

(* The infimal probability is the infimal expectation of the indicator function.
   The congruence hint is INF_cong, matching the infimum in the goal (the previous
   SUP_cong hint could never apply here). *)
lemma P_inf_eq_E_inf:
  assumes [measurable]: "Measurable.pred St P"
  shows "P_inf s P = E_inf s (indicator {x∈space St. P x})"
  by (auto simp add: P_inf_def E_inf_def intro!: INF_cong nn_integral_cong)

(* The always-true property has infimal probability 1.  INF_constant is the infimum
   counterpart of the SUP_constant hint used for P_sup_True. *)
lemma P_inf_True[simp]: "P_inf t (λω. True) = 1"
  using T.emeasure_space_1
  by (auto simp add: P_inf_def INF_constant)

(* The always-false property has infimal probability 0. *)
lemma P_inf_False[simp]: "P_inf t (λω. False) = 0"
  by (auto simp add: P_inf_def INF_constant)

(* P_inf commutes with countable intersections of a decreasing family of properties. *)
lemma P_inf_INF:
  fixes P :: "nat ⇒ 's stream ⇒ bool"
  assumes "decseq P" and P[measurable]: "⋀i. Measurable.pred St (P i)"
  shows "P_inf s (λx. ∀i. P i x) = (⨅i. P_inf s (P i))"
proof -
  have "P_inf s (λx. ∀i. P i x) = (⨅cfg∈cfg_on s. emeasure (T cfg) (⋂i. {x∈space St. P i x}))"
    by (auto simp: P_inf_def intro!: INF_cong arg_cong2[where f=emeasure])
  (* Continuity from above of the (finite) measure on the decreasing family. *)
  also have "… = (⨅cfg∈cfg_on s. ⨅i. emeasure (T cfg) {x∈space St. P i x})"
    using ‹decseq P›
    by (auto intro!: INF_cong INF_emeasure_decseq[symmetric]
        simp: decseq_def monotone_def le_fun_def)
  also have "… = (⨅i. P_inf s (P i))"
    by (subst INF_commute) (simp add: P_inf_def)
  finally show ?thesis
    by simp
qed

(* For an inf-continuous, measurability-preserving predicate transformer Q, the
   infimal probability of gfp Q is the infimum over its finite iterations starting
   from top. *)
lemma P_inf_gfp:
  assumes Q: "inf_continuous Q"
  assumes f: "f ∈ measurable St M"
  assumes Q_m: "⋀P. Measurable.pred M P ⟹ Measurable.pred M (Q P)"
  shows "P_inf s (λx. gfp Q (f x)) = (⨅i. P_inf s (λx. (Q ^^ i) ⊤ (f x)))"
  unfolding inf_continuous_gfp[OF Q]
  apply simp
proof (rule P_inf_INF)
  fix i show "Measurable.pred St (λx. (Q ^^ i) ⊤ (f x))"
    apply (intro measurable_compose[OF f])
    by (induct i) (auto intro!: Q_m)
next
  (* The iterates from top form a decreasing sequence because Q is monotone. *)
  show "decseq (λi x. (Q ^^ i) ⊤ (f x))"
    using inf_continuous_mono[OF Q, THEN funpow_increasing[rotated]]
    unfolding decseq_def monotone_def le_fun_def by auto
qed

(* One-step unfolding of P_inf; an instance of E_inf_iterate for indicator functions. *)
lemma P_inf_iterate:
  assumes [measurable]: "Measurable.pred St P"
  shows "P_inf s P = (⨅D∈K s. ∫⇧+ t. P_inf t (λω. P (t ## ω)) ∂measure_pmf D)"
proof -
  (* Shift the prepended state from the argument of the indicator into the set. *)
  have [simp]: "⋀x s. indicator {x ∈ space St. P x} (x ## s) = indicator {s ∈ space St. P (x ## s)} s"
    by (auto simp: space_stream_space split: split_indicator)
  show ?thesis
    using E_inf_iterate[of "indicator {x∈space St. P x}" s] by (auto simp: P_inf_eq_E_inf)
qed

end

subsection ‹Finite MDPs›

(* A finite MDP additionally fixes a finite, non-empty state set S that is closed
   under all enabled actions (K_closed) and on which every state has only finitely
   many enabled actions (K_finite). *)
locale Finite_Markov_Decision_Process = Markov_Decision_Process K for K :: "'s ⇒ 's pmf set" +
  fixes S :: "'s set"
  assumes S_not_empty: "S ≠ {}"
  assumes S_finite: "finite S"
  assumes K_closed: "⋀s. s ∈ S ⟹ (⋃D∈K s. set_pmf D) ⊆ S"
  assumes K_finite: "⋀s. s ∈ S ⟹ finite (K s)"
begin

(* Closure of S: successors chosen by a configuration of a state in S stay in S. *)
lemma action_closed: "s ∈ S ⟹ cfg ∈ cfg_on s ⟹ t ∈ action cfg ⟹ t ∈ S"
  using cfg_onD_action[of cfg s] K_closed[of s] by auto

(* Closure of S under the support of any enabled action. *)
lemma set_pmf_closed: "s ∈ S ⟹ D ∈ K s ⟹ t ∈ D ⟹ t ∈ S"
  using K_closed by auto

(* Closure of S under a choice function that picks enabled actions on S. *)
lemma Pi_closed: "ct ∈ Pi S K ⟹ s ∈ S ⟹ t ∈ ct s ⟹ t ∈ S"
  using set_pmf_closed by auto

(* Closure of S under the edge relation E. *)
lemma E_closed: "s ∈ S ⟹ (s, t) ∈ E ⟹ t ∈ S"
  using K_closed by (auto simp: E_def)

(* Enabled actions of states in S have finite support, being subsets of the finite S. *)
lemma set_pmf_finite: "s ∈ S ⟹ D ∈ K s ⟹ finite D"
  using K_closed by (intro finite_subset[OF _ S_finite]) auto

(* Valid configurations: configurations of the MDP started in some state of S. *)
definition "valid_cfg = (⋃s∈S. cfg_on s)"

(* Introduction rule for valid_cfg. *)
lemma valid_cfgI: "s ∈ S ⟹ cfg ∈ cfg_on s ⟹ cfg ∈ valid_cfg"
  by (auto simp: valid_cfg_def)

(* A valid configuration is a configuration of the MDP for its own state. *)
lemma valid_cfgD: "cfg ∈ valid_cfg ⟹ cfg ∈ cfg_on (state cfg)"
  by (auto simp: valid_cfg_def)

(* For a valid configuration: its state lies in S, the successors offered by its
   action lie in S, and its continuations along those successors are valid again. *)
lemma
  shows valid_cfg_state_in_S: "cfg ∈ valid_cfg ⟹ state cfg ∈ S"
    and valid_cfg_action: "cfg ∈ valid_cfg ⟹ s ∈ action cfg ⟹ s ∈ S"
    and valid_cfg_cont: "cfg ∈ valid_cfg ⟹ s ∈ action cfg ⟹ cont cfg s ∈ valid_cfg"
  by (auto simp: valid_cfg_def intro!: bexI[of _ s] intro: action_closed)

(* Validity is preserved by the configuration kernel K_cfg. *)
lemma valid_K_cfg[intro]: "cfg ∈ valid_cfg ⟹ cfg' ∈ K_cfg cfg ⟹ cfg' ∈ valid_cfg"
  by (auto simp add: K_cfg_def valid_cfg_cont)

(* Memoryless configuration induced by a choice function ct on S; outside S it falls
   back to the arbitrary enabled action arb_act. *)
definition "simple ct = memoryless_on (λs. if s ∈ S then ct s else arb_act s)"

(* If ct picks enabled actions on S, simple ct s is a configuration of the MDP. *)
lemma simple_cfg_on[simp]: "ct ∈ Pi S K ⟹ simple ct s ∈ cfg_on s"
  by (auto simp: simple_def intro!: memoryless_on_cfg_onI)

(* ... and it is valid when started in S. *)
lemma simple_valid_cfg[simp]: "ct ∈ Pi S K ⟹ s ∈ S ⟹ simple ct s ∈ valid_cfg"
  by (auto intro: valid_cfgI)

(* Selector equations for simple; the equations mentioning ct s require s in S. *)
lemma cont_simple[simp]: "s ∈ S ⟹ t ∈ set_pmf (ct s) ⟹ cont (simple ct s) t = simple ct t"
  by (simp add: simple_def)

lemma state_simple[simp]: "state (simple ct s) = s"
  by (simp add: simple_def)

lemma action_simple[simp]: "s ∈ S ⟹ action (simple ct s) = ct s"
  by (simp add: simple_def)

(* simple ct s is valid exactly when its start state is in S. *)
lemma simple_valid_cfg_iff: "ct ∈ Pi S K ⟹ simple ct s ∈ valid_cfg ⟷ s ∈ S"
  using cfg_onD_state[of "simple ct s"] by (auto simp add: valid_cfg_def intro!: bexI[of _ s])

end

end