(* Theory Trace_Space_Equals_Markov_Processes *)

(* Author: Johannes Hölzl <hoelzl@in.tum.de> *)

subsection ‹Trace Space equal to Markov Chains›

theory Trace_Space_Equals_Markov_Processes
  imports Discrete_Time_Markov_Chain
begin

text ‹
  For each time-homogeneous discrete-time Markov chain we can construct a corresponding
  probability space on traces using @{theory Markov_Models.Discrete_Time_Markov_Chain}. The constructed
  probability space assigns the same probabilities to the trajectories of the chain as the original one.
›

(* A discrete-time stochastic process X on the probability space M.

   S   : a countable set in which every X n takes its values almost everywhere.
   X   : each X t is measurable into the discrete space count_space UNIV.
   MC  : Markov property.  Whenever the history "X t = s t for all t <= n" has
         non-zero probability, conditioning X (Suc n) on that whole history gives
         the same probability as conditioning on X n = s n alone.
   TH  : time-homogeneity.  The one-step conditional probability of moving from
         t to s is the same at any two times n and m at which state t has
         non-zero probability. *)
locale Time_Homogeneous_Discrete_Markov_Process = M?: prob_space +
  fixes S :: "'s set" and X :: "nat ⇒ 'a ⇒ 's"
  assumes X [measurable]: "⋀t. X t ∈ measurable M (count_space UNIV)"
  assumes S: "countable S" "⋀n. AE x in M. X n x ∈ S"
  assumes MC: "⋀n s s'.
    𝒫(ω in M. ∀t≤n. X t ω = s t ) ≠ 0 ⟹
    𝒫(ω in M. X (Suc n) ω = s' ¦ ∀t≤n. X t ω = s t ) =
    𝒫(ω in M. X (Suc n) ω = s' ¦ X n ω = s n )"
  assumes TH: "⋀n m s t.
    𝒫(ω in M. X n ω = t) ≠ 0 ⟹ 𝒫(ω in M. X m ω = t) ≠ 0 ⟹
    𝒫(ω in M. X (Suc n) ω = s ¦ X n ω = t) = 𝒫(ω in M. X (Suc m) ω = s ¦ X m ω = t)"
begin

context
begin

interpretation pmf_as_measure .

(* Initial distribution: the law of X 0 under M, lifted to a pmf.
   The lifting obligation asks for a probability measure on all of UNIV that
   gives non-zero mass to almost every point; the countable set S from the
   locale assumptions serves as the countable support. *)
lift_definition I :: "'s pmf" is "distr M (count_space UNIV) (X 0)"
proof -
  let ?X = "distr M (count_space UNIV) (X 0)"
  interpret X: prob_space ?X
    by (auto simp: prob_space_distr)
  (* almost every point has non-zero mass, witnessed by the countable set S *)
  have "AE x in ?X. measure ?X {x} ≠ 0"
    using S by (subst X.AE_support_countable) (auto simp: AE_distr_iff intro!: exI[of _ S])
  then show "prob_space ?X ∧ sets ?X = UNIV ∧ (AE x in ?X. measure ?X {x} ≠ 0)"
    by (simp add: prob_space_distr AE_support_countable)
qed

(* Every state with non-zero initial probability lies in S. *)
lemma I_in_S:
  assumes "pmf I s ≠ 0" shows "s ∈ S"
proof -
  (* unfold the pmf I back to the probability of the event X 0 = s *)
  from ‹pmf I s ≠ 0› have "0 ≠ 𝒫(x in M. X 0 x = s)"
    by transfer (auto simp: measure_distr vimage_def Int_def conj_commute)
  (* X 0 is in S almost everywhere, so the side condition s : S may be added *)
  also have "𝒫(x in M. X 0 x = s) = 𝒫(x in M. X 0 x = s ∧ s ∈ S)"
    using S(2)[of 0] by (intro M.finite_measure_eq_AE) auto
  finally show ?thesis
    by (cases "s ∈ S") auto
qed

(* Transition kernel, lifted to a pmf for every state s.
   The combinator "with" selects between two cases (its definition is not in
   this theory; the two cases are characterised by pmf_K and pmf_K2):
     - if some time n has non-zero probability of X n = s, the law of X (Suc n)
       under M conditioned on X n = s (expressed via uniform_measure);
     - otherwise the point measure on s.
   The proof shows that both alternatives satisfy the pmf invariant. *)
lift_definition K :: "'s ⇒ 's pmf" is
  "λs. with (λn. 𝒫(ω in M. X n ω = s) ≠ 0)
     (λn. distr (uniform_measure M {ω∈space M. X n ω = s}) (count_space UNIV) (X (Suc n)))
     (uniform_measure (count_space UNIV) {s})"
proof (rule withI)
  (* case: state s is visited with non-zero probability at time n *)
  fix s n assume *: "𝒫(ω in M. X n ω = s) ≠ 0"
  let ?D = "distr (uniform_measure M {ω∈space M. X n ω = s}) (count_space UNIV) (X (Suc n))"
  have D: "prob_space ?D"
    by (intro prob_space.prob_space_distr prob_space_uniform_measure)
       (auto simp: M.emeasure_eq_measure *)
  then interpret D: prob_space ?D .
  have sets_D: "sets ?D = UNIV"
    by simp
  (* countable support: again witnessed by S, using X (Suc n) in S almost everywhere *)
  moreover have "AE x in ?D. measure ?D {x} ≠ 0"
    unfolding D.AE_support_countable[OF sets_D]
  proof (intro exI[of _ S] conjI)
    show "countable S" by (rule S)
    show "AE x in ?D. x ∈ S"
      using * S(2)[of "Suc n"] by (auto simp add: AE_distr_iff AE_uniform_measure M.emeasure_eq_measure)
  qed
  ultimately show "prob_space ?D ∧ sets ?D = UNIV ∧ (AE x in ?D. measure ?D {x} ≠ 0)"
    using D by blast
  (* remaining case: the point measure on s *)
qed (auto intro!: prob_space_uniform_measure AE_uniform_measureI)

(* Characterisation of K on visited states: if s has positive probability at
   time n, then K s is the one-step conditional distribution at time n.
   Time-homogeneity (TH) makes the result independent of which witness time
   is picked inside the definition of K. *)
lemma pmf_K:
  assumes n: "0 < 𝒫(ω in M. X n ω = s)"
  shows "pmf (K s) t = 𝒫(ω in M. X (Suc n) ω = t ¦ X n ω = s)"
proof (transfer fixing: n s t)
  let ?P = "λn. 𝒫(ω in M. X n ω = s) ≠ 0"
  let ?D = "λn. distr (uniform_measure M {ω∈space M. X n ω = s}) (count_space UNIV) (X (Suc n))"
  let ?U = "uniform_measure (count_space UNIV) {s}"
  show "measure (with ?P ?D ?U) {t} = 𝒫(ω in M. X (Suc n) ω = t ¦ X n ω = s)"
  proof (rule withI)
    (* n' is an arbitrary witness time; TH relates it to the given n *)
    fix n' assume "?P n'"
    moreover have "X (Suc n') -` {t} ∩ space M = {x∈space M. X (Suc n') x = t}"
      by auto
    ultimately show "measure (?D n') {t} = 𝒫(ω in M. X (Suc n) ω = t ¦ X n ω = s)"
      using n M.measure_uniform_measure_eq_cond_prob[of "λx. X (Suc n') x = t" "λx. X n' x = s"]
      by (auto simp: measure_distr M.emeasure_eq_measure simp del: measure_uniform_measure intro!: TH)
    (* the no-witness case contradicts assumption n *)
  qed (insert n, simp)
qed

(* Characterisation of K on never-visited states: if s has probability zero at
   every time, K s is the point mass on s. *)
lemma pmf_K2:
  "(⋀n. 𝒫(ω in M. X n ω = s) = 0) ⟹ pmf (K s) t = indicator {t} s"
  apply (transfer fixing: s t)
  apply (rule withI)
  apply (auto split: split_indicator)
  done

end

sublocale K: MC_syntax K .

(* Main result of the locale: the measure K.T' I built from the kernel K and
   the initial distribution I equals the distribution, under M, of the stream
   of observations  to_stream (λn. X n ω).

   Proof outline (rule stream_space_eq_sstart): both sides are probability
   spaces on the stream space, both are concentrated on streams S, and they
   agree on every set sstart S xs' for a non-empty list xs' over S.  The
   agreement is shown by computing both sides as the product
   pmf I s * (product of one-step transition probabilities along xs). *)
lemma bind_I_K_eq_M: "K.T' I = distr M K.S (λω. to_stream (λn. X n ω))" (is "_ = ?D")
proof (rule stream_space_eq_sstart)
  note streams_sets[measurable]
  note measurable_abs_UNIV[measurable (raw)]
  note sstart_sets[measurable]

  (* Starting in S, almost every trace of the chain stays in S. *)
  { fix s assume "s ∈ S"
    from K.AE_T_enabled[of s] have "AE ω in K.T s. ω ∈ streams S"
    proof eventually_elim
      fix ω assume "K.enabled s ω" from this ‹s∈S› show "ω ∈ streams S"
      proof (coinduction arbitrary: s ω)
        case streams
        then have 1: "pmf (K s) (shd ω) ≠ 0"
          by (simp add: K.enabled.simps[of s] set_pmf_iff)
        have "shd ω ∈ S"
        proof cases
          (* s is visited at some time n: use pmf_K and X (Suc n) in S a.e. *)
          assume "∃n. 0 < 𝒫(ω in M. X n ω = s)"
          then obtain n where "0 < 𝒫(ω in M. X n ω = s)" by auto
          with 1 have 2: "𝒫(ω' in M. X (Suc n) ω' = shd ω ∧ X n ω' = s) ≠ 0"
            by (simp add: pmf_K cond_prob_def)
          show "shd ω ∈ S"
          proof (rule ccontr)
            assume "shd ω ∉ S"
            with S(2)[of "Suc n"] have "𝒫(ω' in M. X (Suc n) ω' = shd ω ∧ X n ω' = s) = 0"
              by (intro M.prob_eq_0_AE) auto
            with 2 show False by contradiction
          qed
        next
          (* s is never visited: K s is the point mass on s, so shd ω = s *)
          assume "¬ (∃n. 0 < 𝒫(ω in M. X n ω = s))"
          then have "pmf (K s) (shd ω) = indicator {shd ω} s"
            by (intro pmf_K2) (auto simp: not_less measure_le_0_iff)
          with 1 ‹s∈S› show ?thesis
            by (auto split: split_indicator_asm)
        qed
        with streams show ?case
          by (cases ω) (auto simp: K.enabled.simps[of s])
      qed
    qed }
  note AE_streams = this

  show "prob_space (K.T' I)"
    by (rule K.prob_space_T')
  show "prob_space ?D"
    by (rule M.prob_space_distr) simp

  show "AE x in K.T' I. x ∈ streams S"
    by (auto simp add: K.AE_T' set_pmf_iff I_in_S AE_distr_iff streams_Stream intro!: AE_streams)
  show "AE x in ?D. x ∈ streams S"
    by (simp add: AE_distr_iff to_stream_in_streams AE_all_countable S)
  show "sets (K.T' I) = sets (stream_space (count_space UNIV))"
    by (simp add: K.sets_T')
  show "sets ?D = sets (stream_space (count_space UNIV))"
    by simp

  (* Agreement on the generating sets sstart S xs' with xs' = s # xs. *)
  fix xs' assume "xs' ≠ []" "xs' ∈ lists S"
  then obtain s xs where xs': "xs' = s # xs" and s: "s ∈ S" and xs: "xs ∈ lists S"
    by (auto simp: neq_Nil_conv del: in_listsD)

  (* Step 1: split off the first state s via the initial distribution I. *)
  have "emeasure (K.T' I) (sstart S xs') = (∫⇧+s. emeasure (K.T s) {ω∈space K.S. s ## ω ∈ sstart S xs'} ∂I)"
    by (rule K.emeasure_T') measurable
  also have "… = (∫⇧+s'. emeasure (K.T s) (sstart S xs) * indicator {s} s' ∂I)"
    by (intro arg_cong2[where f=emeasure] nn_integral_cong)
       (auto split: split_indicator simp: emeasure_distr vimage_def space_stream_space neq_Nil_conv xs')
  also have "… = pmf I s * emeasure (K.T s) (sstart S xs)"
    by (auto simp add: max_def emeasure_pmf_single intro: mult_ac)
  (* Step 2: the chain side is the product of transition probabilities along xs. *)
  also have "emeasure (K.T s) (sstart S xs) = ennreal (∏i<length xs. pmf (K ((s#xs)!i)) (xs!i))"
    using xs s
  proof (induction arbitrary: s)
    case Nil then show ?case
      by (simp add: K.T.emeasure_eq_1_AE AE_streams)
  next
    case (Cons t xs)
    have "emeasure (K.T s) (sstart S (t # xs)) =
      emeasure (K.T s) {x∈space (K.T s). shd x = t ∧ stl x ∈ sstart S xs}"
      by (intro arg_cong2[where f=emeasure]) (auto simp: space_stream_space)
    also have "… = (∫⇧+t'. emeasure (K.T t') {x∈space K.S. t' = t ∧ x ∈ sstart S xs} ∂K s)"
      by (subst K.emeasure_Collect_T) auto
    also have "… = (∫⇧+t'. emeasure (K.T t) (sstart S xs) * indicator {t} t' ∂K s)"
      by (intro nn_integral_cong) (auto split: split_indicator simp: space_stream_space)
    also have "… = emeasure (K.T t) (sstart S xs) * pmf (K s) t"
      by (simp add: emeasure_pmf_single max_def)
    finally show ?case
      by (simp add: lessThan_Suc_eq_insert_0 zero_notin_Suc_image prod.reindex Cons
        prod_nonneg ennreal_mult[symmetric])
  qed
  (* Step 3: the same product is the probability, under M, that the process
     follows s # xs; proved by induction from the right using MC and pmf_K. *)
  also have "pmf I s * ennreal (∏i<length xs. pmf (K ((s#xs)!i)) (xs!i)) =
    𝒫(x in M. ∀i≤length xs. X i x = (s # xs) ! i)"
    using xs s
  proof (induction xs rule: rev_induct)
    case Nil
    have "pmf I s = prob {x ∈ space M. X 0 x = s}"
      by transfer (simp add: vimage_def Int_def measure_distr conj_commute)
    then show ?case
      by simp
  next
    case (snoc t xs)
    let ?l = "length xs" and ?lt = "length (xs @ [t])" and ?xs' = "s # xs @ [t]"
    have "ennreal (pmf I s) * (∏i<?lt. pmf (K ((?xs') ! i)) ((xs @ [t]) ! i)) =
      (ennreal (pmf I s) * (∏i<?l. pmf (K ((s # xs) ! i)) (xs ! i))) * pmf (K ((s # xs) ! ?l)) t"
      by (simp add: lessThan_Suc mult_ac nth_append append_Cons[symmetric] prod_nonneg ennreal_mult[symmetric]
               del: append_Cons)
    also have "… = 𝒫(x in M. ∀i≤?l. X i x = (s # xs) ! i) * pmf (K ((s # xs) ! ?l)) t"
      using snoc by (simp add: ennreal_mult[symmetric])
    also have "… = 𝒫(x in M. ∀i≤?lt. X i x = (?xs') ! i)"
    proof cases
      (* the prefix already has probability zero: both sides vanish *)
      assume "𝒫(ω in M. ∀i≤?l. X i ω = (s # xs) ! i) = 0"
      moreover have "𝒫(x in M. ∀i≤?lt. X i x = (?xs') ! i) ≤ 𝒫(ω in M. ∀i≤?l. X i ω = (s # xs) ! i)"
        by (intro M.finite_measure_mono) (auto simp: nth_append nth_Cons split: nat.split)
      moreover have "𝒫(x in M. ∀i≤?l. X i x = (s # xs) ! i) ≤ 𝒫(ω in M. ∀i≤?l. X i ω = (s # xs) ! i)"
        by (intro M.finite_measure_mono) (auto simp: nth_append nth_Cons split: nat.split)
      ultimately show ?thesis
        by (simp add: measure_le_0_iff)
    next
      (* the prefix has positive probability: rewrite K via pmf_K, then apply MC *)
      assume "𝒫(ω in M. ∀i≤?l. X i ω = (s # xs) ! i) ≠ 0"
      then have *: "0 < 𝒫(ω in M. ∀i≤?l. X i ω = (s # xs) ! i)"
        unfolding less_le by simp
      moreover have "𝒫(ω in M. ∀i≤?l. X i ω = (s # xs) ! i) ≤ 𝒫(ω in M. X ?l ω = (s # xs) ! ?l)"
        by (intro M.finite_measure_mono) (auto simp: nth_append nth_Cons split: nat.split)
      ultimately have "𝒫(ω in M. X ?l ω = (s # xs) ! ?l) ≠ 0"
        by auto
      then have "pmf (K ((s # xs) ! ?l)) t = 𝒫(ω in M. X ?lt ω = ?xs' ! ?lt ¦ X ?l ω = (s # xs) ! ?l)"
        by (subst pmf_K) (auto simp: less_le)
      also have "… = 𝒫(ω in M. X ?lt ω = ?xs' ! ?lt ¦ ∀i≤?l. X i ω = (s # xs) ! i)"
        using * MC[of ?l "λi. (s # xs) ! i" "?xs' ! ?lt"] by simp
      also have "… = 𝒫(ω in M. ∀i≤?lt. X i ω = ?xs' ! i) / 𝒫(ω in M. ∀i≤?l. X i ω = (s # xs) ! i)"
        unfolding cond_prob_def
        by (intro arg_cong2[where f="(/)"] arg_cong2[where f=measure]) (auto simp: nth_Cons nth_append split: nat.splits)
      finally show ?thesis
        using * by simp
    qed
    finally show ?case .
  qed
  (* Step 4: that probability is the measure of sstart S xs' under ?D, because
     the observation stream lies in streams S almost everywhere. *)
  also have "… = emeasure ?D (sstart S xs')"
  proof -
    have "AE x in M. ∀i. X i x ∈ S"
      using S(2) by (simp add: AE_all_countable)
    then have "AE x in M. (∀i≤length xs. X i x = (s # xs) ! i) = (to_stream (λn. X n x) ∈ sstart S xs')"
    proof eventually_elim
      fix x assume "∀i. X i x ∈ S"
      then have "to_stream (λn. X n x) ∈ streams S"
        by (auto simp: streams_iff_snth to_stream_def)
      then show "(∀i≤length xs. X i x = (s # xs) ! i) = (to_stream (λn. X n x) ∈ sstart S xs')"
        by (simp add: sstart_eq xs' to_stream_def less_Suc_eq_le del: sstart.simps(1) in_sstart)
    qed
    then show ?thesis
      by (auto simp: emeasure_distr M.emeasure_eq_measure intro!: M.finite_measure_eq_AE)
  qed
  finally show "emeasure (K.T' I) (sstart S xs') = emeasure ?D (sstart S xs')" .
qed (rule S)

end

(* Converse direction, stated in MC_syntax: for any initial distribution I, the
   measure T' I together with the coordinate maps  λn ω. ω !! n  is an instance
   of Time_Homogeneous_Discrete_Markov_Process.  The state set U is the image of
   I under the reflexive-transitive closure of the relation SIGMA s:UNIV. K s.

   Proof outline: start_eq, TH and MC compute the joint and conditional
   probabilities of consecutive coordinates as products with pmf (K s) t;
   MC' transfers non-zero probability from a whole history to its last state;
   the final block discharges the locale assumptions. *)
lemma (in MC_syntax) is_THDTMC:
  fixes I :: "'s pmf"
  defines "U ≡ (SIGMA s:UNIV. K s)⇧* `` I"
  shows "Time_Homogeneous_Discrete_Markov_Process (T' I) U (λn ω. ω !! n)"
proof -
  have [measurable]: "U ∈ sets (count_space UNIV)"
    by auto

  interpret prob_space "T' I"
    by (rule prob_space_T')

  (* joint probability of the first two coordinates *)
  { fix s t I
    have "⋀t s. 𝒫(ω in T s. s = t) = indicator {t} s"
      using T.prob_space by (auto split: split_indicator)
    moreover then have "⋀t t' s. 𝒫(ω in T s. shd ω = t' ∧ s = t) = pmf (K t) t' * indicator {t} s"
      by (subst prob_T) (auto split: split_indicator simp: pmf.rep_eq)
    ultimately have "𝒫(ω in T' I. shd (stl ω) = t ∧ shd ω = s) = 𝒫(ω in T' I. shd ω = s) * pmf (K s) t"
      by (simp add: prob_T' pmf.rep_eq) }
  note start_eq = this

  (* one-step conditional probability at an arbitrary time n is pmf (K s) t *)
  { fix n s t assume "𝒫(ω in T' I. ω !! n = s) ≠ 0"
    moreover have "𝒫(ω in T' I. ω !! (Suc n) = t ∧ ω !! n = s) = 𝒫(ω in T' I. ω !! n = s) * pmf (K s) t"
    proof (induction n arbitrary: I)
      case (Suc n) then show ?case
        by (subst (1 2) prob_T') (simp_all del: space_T add: T_eq_T')
    qed (simp add: start_eq)
    ultimately have "𝒫(ω in T' I. stl ω !! n = t ¦ ω !! n = s) = pmf (K s) t"
      by (simp add: cond_prob_def field_simps) }
  note TH = this

  (* the same value results when conditioning on the whole history up to n *)
  { fix n ω' t assume "𝒫(ω in T' I. ∀i≤n. ω !! i = ω' i) ≠ 0"
    moreover have "𝒫(ω in T' I. ω !! (Suc n) = t ∧ (∀i≤n. ω !! i = ω' i)) =
      𝒫(ω in T' I. ∀i≤n. ω !! i = ω' i) * pmf (K (ω' n)) t"
    proof (induction n arbitrary: I ω')
      case (Suc n)
      have *[simp]: "⋀s P. measure (T' (K s)) {x. s = ω' 0 ∧ P x} =
        measure (T' (K (ω' 0))) {x. P x} * indicator {ω' 0} s"
        by (auto split: split_indicator)
      from Suc[of _ "λi. ω' (Suc i)"] show ?case
        by (subst (1 2) prob_T')
           (simp_all add: T_eq_T' all_Suc_split[where P="λi. i ≤ Suc n ⟶ Q i" for n Q] conj_commute conj_left_commute sets_eq_imp_space_eq[OF sets_T'])
    qed (simp add: start_eq)
    ultimately have "𝒫(ω in T' I. stl ω !! n = t ¦ ∀i≤n. ω !! i = ω' i) = pmf (K (ω' n)) t"
      by (simp add: cond_prob_def field_simps) }
  note MC = this

  (* a history of non-zero probability has a last state of non-zero probability *)
  { fix n ω' assume "𝒫(ω in T' I. ∀t≤n. ω !! t = ω' t) ≠ 0"
    moreover have "𝒫(ω in T' I. ∀t≤n. ω !! t = ω' t) ≤ 𝒫(ω in T' I. ω !! n = ω' n)"
      by (auto intro!: finite_measure_mono_AE simp: sets_T' sets_eq_imp_space_eq[OF sets_T'])
    ultimately have "𝒫(ω in T' I. ω !! n = ω' n) ≠ 0"
      by (auto simp: neq_iff not_less measure_le_0_iff) }
  note MC' = this

  show ?thesis
  proof
    show "countable U"
      unfolding U_def by (rule countable_reachable countable_Image countable_set_pmf)+
    show "⋀t. (λω. ω !! t) ∈ measurable (T' I) (count_space UNIV)"
      by (subst measurable_cong_sets[OF sets_T' refl]) simp
  next
    (* every coordinate lies in U almost everywhere, by induction on the time n *)
    fix n
    have "∀x∈I. AE y in T x. (x ## y) !! n ∈ U"
      unfolding U_def
    proof (induction n arbitrary: I)
      case 0 then show ?case
        by auto
    next
      case (Suc n)
      { fix x assume "x ∈ I"
        have "AE y in T x. y !! n ∈ (SIGMA x:UNIV. K x)⇧* `` K x"
          apply (subst AE_T_iff)
          apply (rule measurable_compose[OF measurable_snth], simp)
          apply (rule Suc)
          done
        moreover have "(SIGMA x:UNIV. K x)⇧* `` K x ⊆ (SIGMA x:UNIV. K x)⇧* `` I"
          using ‹x ∈ I› by (auto intro: converse_rtrancl_into_rtrancl)
        ultimately have "AE y in T x. y !! n ∈ (SIGMA x:UNIV. K x)⇧* `` I"
          by (auto simp: subset_eq) }
      then show ?case
        by simp
    qed
    then show "AE x in T' I. x !! n ∈ U"
      by (simp add: AE_T')
  qed (simp_all add: TH MC MC')
qed

end