(* Theory Product_PMF_Ext *)

section ‹Indexed Products of Probability Mass Functions›

theory Product_PMF_Ext
  imports Main Probability_Ext "HOL-Probability.Product_PMF" Universal_Hash_Families.Universal_Hash_Families_More_Independent_Families
begin

(* Hide Isolated.discrete so that the unqualified name "discrete" used in the
   statements and proofs below does not refer to that constant. *)
hide_const "Isolated.discrete"

text ‹This section introduces a restricted version of @{term "Pi_pmf"} where the default value is @{term "undefined"}
and collects further results about that case that complement @{theory "HOL-Probability.Product_PMF"}.›

(* prod_pmf I M: the indexed product of the PMFs M i over I, i.e. Pi_pmf with
   the default value fixed to undefined for indices outside I. *)
abbreviation prod_pmf where "prod_pmf I M ≡ Pi_pmf I undefined M"

(* Point mass of prod_pmf: for x extensional on I it is the product of the
   component masses pmf (M i) (x i); every other x has mass 0. *)
lemma pmf_prod_pmf: 
  assumes "finite I"
  shows "pmf (prod_pmf I M) x = (if x ∈ extensional I then ∏i ∈ I. (pmf (M i)) (x i) else 0)"
  by (simp add:  pmf_Pi[OF assms(1)] extensional_def)

(* With default value undefined, PiE_dflt I undefined M coincides with PiE I M.
   NOTE(review): "defaut" in the name looks like a typo for "default"; the name
   is kept because set_prod_pmf below refers to it. *)
lemma PiE_defaut_undefined_eq: "PiE_dflt I undefined M = PiE I M" 
  by (simp add:PiE_dflt_def PiE_def extensional_def Pi_def set_eq_iff) blast

(* Support of prod_pmf over a finite index set: the extensional functions on I
   whose value at each i lies in the support of M i. *)
lemma set_prod_pmf:
  assumes "finite I"
  shows "set_pmf (prod_pmf I M) = PiE I (set_pmf ∘ M)"
  by (simp add:set_Pi_pmf[OF assms] PiE_defaut_undefined_eq)

text ‹A more general version of @{thm [source] "measure_Pi_pmf_Pi"}.›
(* Probability of the product event Pi J A under Pi_pmf I d M, where J may be
   any subset of I: it factors into the probabilities of the events A i. *)
lemma prob_prod_pmf': 
  assumes "finite I"
  assumes "J ⊆ I"
  shows "measure (measure_pmf (Pi_pmf I d M)) (Pi J A) = (∏ i ∈ J. measure (M i) (A i))"
proof -
  (* Extend the constraint from J to all of I, using UNIV (no constraint) for
     the indices outside J, so that measure_Pi_pmf_Pi applies. *)
  have a:"Pi J A = Pi I (λi. if i ∈ J then A i else UNIV)"
    using assms by (simp add:Pi_def set_eq_iff, blast)
  show ?thesis
    using assms
    by (simp add:if_distrib  a measure_Pi_pmf_Pi[OF assms(1)] prod.If_cases[OF assms(1)] Int_absorb1)
qed

(* Probability of an event that depends only on component i of the sample: it
   equals the probability of the corresponding event under M i. *)
lemma prob_prod_pmf_slice:
  assumes "finite I"
  assumes "i ∈ I"
  shows "measure (measure_pmf (prod_pmf I M)) {ω. P (ω i)} = measure (M i) {ω. P ω}"
  (* Special case J = {i} of prob_prod_pmf'. *)
  using prob_prod_pmf'[OF assms(1), where J="{i}" and M="M" and A="λ_. Collect P"]
  by (simp add:assms Pi_def)


(* restrict_dfl f A d agrees with f on A and returns the explicit default d
   outside A (a variant of restrict, which uses undefined instead). *)
definition restrict_dfl where "restrict_dfl f A d = (λx. if x ∈ A then f x else d)"

(* Decomposition of Pi_pmf along the fibres of f.  Sampling a function on I is
   the same as two steps: first, independently for every j in f ` I, sample a
   function on the fibre f -` {j} ∩ I; then glue these functions together,
   using the default d outside I. *)
lemma pi_pmf_decompose:
  assumes "finite I"
  shows "Pi_pmf I d M = map_pmf (λω. restrict_dfl (λi. ω (f i) i) I d) (Pi_pmf (f ` I) (λ_. d) (λj. Pi_pmf (f -` {j} ∩ I) d M))"
proof -
  have fin_F_I:"finite (f ` I)" using assms by blast

  (* Induction over the finite image f ` I, generalising over I. *)
  have "finite I ⟹ ?thesis"
    using fin_F_I
  proof (induction "f ` I" arbitrary: I rule:finite_induct)
    case empty
    then show ?case by (simp add:restrict_dfl_def)
  next
    case (insert x F)
    (* Split I into the fibre over x and the preimage of the remaining F. *)
    have a: "(f -` {x} ∩ I) ∪ (f -` F ∩ I) = I"
      using insert(4) by blast
    have b: "F = f `  (f -` F ∩ I) " using insert(2,4) 
      by (auto simp add:set_eq_iff image_def vimage_def) 
    have c: "finite (f -` F ∩ I)" using insert by blast
    have d: "⋀j. j ∈ F ⟹ (f -` {j} ∩ (f -` F ∩ I)) = (f -` {j} ∩ I)"
      using insert(4) by blast 

    have " Pi_pmf I d M = Pi_pmf ((f -` {x} ∩ I) ∪ (f -` F ∩ I)) d M"
      by (simp add:a)
    (* The two parts are disjoint, so Pi_pmf_union splits the product. *)
    also have "... = map_pmf (λ(g, h) i. if i ∈ f -` {x} ∩ I then g i else h i) 
      (pair_pmf (Pi_pmf (f -` {x} ∩ I) d M) (Pi_pmf (f -` F ∩ I) d M))"
      using insert by (subst Pi_pmf_union) auto
    (* Apply the induction hypothesis (via b and c) to the part over F. *)
    also have "... = map_pmf (λ(g,h) i. if f i = x ∧ i ∈ I then g i else if f i ∈ F ∧ i ∈ I then h (f i) i else d)
      (pair_pmf (Pi_pmf (f -` {x} ∩ I) d M) (Pi_pmf F (λ_. d) (λj. Pi_pmf (f -` {j} ∩ (f -` F ∩ I)) d M)))"
      by (simp add:insert(3)[OF b c] map_pmf_comp case_prod_beta' apsnd_def map_prod_def 
          pair_map_pmf2 b[symmetric] restrict_dfl_def) (metis fst_conv snd_conv)
    (* Simplify the fibres using d and merge both branches into h(x := g). *)
    also have "... = map_pmf (λ(g,h) i. if i ∈ I then (h(x := g)) (f i) i else d) 
      (pair_pmf (Pi_pmf (f -` {x} ∩ I) d M) (Pi_pmf F (λ_. d) (λj. Pi_pmf (f -` {j} ∩ I) d M)))" 
      using insert(4) d
      by (intro arg_cong2[where f="map_pmf"] ext) (auto simp add:case_prod_beta' cong:Pi_pmf_cong) 
    (* Fold the pair back into a single Pi_pmf over insert x F. *)
    also have "... = map_pmf (λω i. if i ∈ I then ω (f i) i else d) (Pi_pmf (insert x F) (λ_. d) (λj. Pi_pmf (f -` {j} ∩ I) d M))"
      by (simp add:Pi_pmf_insert[OF insert(1,2)] map_pmf_comp case_prod_beta')
    finally show ?case by (simp add:insert(4) restrict_dfl_def)
  qed
  thus ?thesis using assms by blast
qed

(* Restricting first to I and then to J, with the same default d, is the same
   as restricting once to I ∩ J. *)
lemma restrict_dfl_iter: "restrict_dfl (restrict_dfl f I d) J d = restrict_dfl f (I ∩ J) d"
  by (rule ext, simp add:restrict_dfl_def)

(* Under Pi_pmf I d M, the restrictions of the sample to the fibres
   f -` {i} ∩ I, for i in f ` I, are independent random variables. *)
lemma indep_vars_restrict':
  assumes "finite I"
  shows "prob_space.indep_vars (Pi_pmf I d M) (λ_. discrete) (λi ω. restrict_dfl ω (f -` {i} ∩ I) d) (f ` I)"
proof -
  (* The two-level product from pi_pmf_decompose; its components are independent. *)
  let ?Q = "(Pi_pmf (f ` I) (λ_. d) (λi. Pi_pmf (I ∩ f -` {i}) d M))"
  have a:"prob_space.indep_vars ?Q (λ_. discrete) (λi ω. ω i) (f ` I)"
    using assms by (intro indep_vars_Pi_pmf, blast)
  (* Almost surely, each component equals the corresponding restriction of
     the glued function, so independence transfers (indep_vars_cong_AE). *)
  have b: "AE x in measure_pmf ?Q. ∀i∈f ` I. x i = restrict_dfl (λi. x (f i) i) (I ∩ f -` {i}) d"
    using assms
    by (auto simp add:PiE_dflt_def restrict_dfl_def AE_measure_pmf_iff set_Pi_pmf comp_def Int_commute)
  have "prob_space.indep_vars ?Q (λ_. discrete) (λi x. restrict_dfl (λi. x (f i) i) (I ∩ f -` {i}) d) (f ` I)"
    by (rule prob_space.indep_vars_cong_AE[OF prob_space_measure_pmf b a],  simp)
  (* Push the result through the gluing map of pi_pmf_decompose. *)
  thus ?thesis
    using prob_space_measure_pmf 
    by (auto intro!:prob_space.indep_vars_distr simp:pi_pmf_decompose[OF assms, where f="f"]  
        map_pmf_rep_eq comp_def restrict_dfl_iter Int_commute) 
qed

(* Introduction rule for independence.  If each X' i (for i in J = f ` I) only
   depends on the restriction of ω to the fibre f -` {i} ∩ I (second
   assumption) and takes values in space (M' i), then the X' i are independent
   under Pi_pmf I d p. *)
lemma indep_vars_restrict_intro':
  assumes "finite I"
  assumes "⋀i ω. i ∈ J ⟹ X' i ω = X' i (restrict_dfl ω (f -` {i} ∩ I) d)"
  assumes "J = f ` I"
  assumes "⋀ω i. i ∈ J ⟹  X' i ω ∈ space (M' i)"
  shows "prob_space.indep_vars (measure_pmf (Pi_pmf I d p)) M' (λi ω. X' i ω) J"
proof -
  define M where "M ≡ measure_pmf (Pi_pmf I d p)"
  interpret prob_space "M"
    using M_def prob_space_measure_pmf by blast
  have "indep_vars (λ_. discrete) (λi x. restrict_dfl x (f -` {i} ∩ I) d) (f ` I)" 
    unfolding M_def  by (rule indep_vars_restrict'[OF assms(1)])
  (* Compose the independent restrictions with X'. *)
  hence "indep_vars M' (λi ω. X' i (restrict_dfl ω ( f -` {i} ∩ I) d)) (f ` I)"
    using assms(4)
    by (intro indep_vars_compose2[where Y="X'" and N="M'" and M'="λ_. discrete"])  (auto simp:assms(3))
  (* Drop the restriction again, using the second assumption. *)
  hence "indep_vars M' (λi ω. X' i ω) (f ` I)"
    using assms(2)[symmetric]
    by (simp add:assms(3) cong:indep_vars_cong)
  thus ?thesis
    unfolding M_def using assms(3) by simp 
qed

(* Integrability and expectation of a function of a single component i ∈ I.
   Under Pi_pmf I d M, the component x i is distributed according to M i
   (Pi_pmf_component), so both properties reduce to those of f w.r.t. M i. *)
lemma  
  fixes f :: "'b ⇒ ('c :: {second_countable_topology,banach,real_normed_field})"
  assumes "finite I"
  assumes "i ∈ I"
  assumes "integrable (measure_pmf (M i)) f"
  shows  integrable_Pi_pmf_slice: "integrable (Pi_pmf I d M) (λx. f (x i))"
  and expectation_Pi_pmf_slice: "integral⇧L (Pi_pmf I d M) (λx. f (x i)) = integral⇧L (M i) f"
proof -
  (* Swap the target measure of distr for discrete (via distr_cong), so that
     the distribution can be rewritten as a map_pmf. *)
  have a:"distr (Pi_pmf I d M) (M i) (λω. ω i) = distr (Pi_pmf I d M) discrete (λω. ω i)"
    by (rule distr_cong, auto)

  have b: "measure_pmf.random_variable (M i) borel f"
    using assms(3) by simp

  have c:"integrable (distr (Pi_pmf I d M) (M i) (λω. ω i)) f" 
    using assms(1,2,3)
    by (subst a, subst map_pmf_rep_eq[symmetric], subst Pi_pmf_component, auto)

  show "integrable (Pi_pmf I d M) (λx. f (x i))"
    by (rule integrable_distr[where f="f" and M'="M i"])  (auto intro: c)

  have "integral⇧L (Pi_pmf I d M) (λx. f (x i)) = integral⇧L (distr (Pi_pmf I d M) (M i) (λω. ω i)) f"
    using b by (intro integral_distr[symmetric], auto)
  also have "... =  integral⇧L (map_pmf (λω. ω i) (Pi_pmf I d M)) f"
    by (subst a, subst map_pmf_rep_eq[symmetric], simp)
  also have "... =  integral⇧L (M i) f"
    using assms(1,2) by (simp add: Pi_pmf_component)
  finally show "integral⇧L (Pi_pmf I d M) (λx. f (x i)) = integral⇧L (M i) f" by simp
qed

text ‹This is an improved version of @{thm [source] "expectation_prod_Pi_pmf"}.
It works for functions into general normed fields instead of only non-negative real functions.›

(* The expectation of a product of functions of the individual components is
   the product of their expectations.  Each f i must be integrable w.r.t. M i. *)
lemma expectation_prod_Pi_pmf: 
  fixes f :: "'a ⇒ 'b ⇒ ('c :: {second_countable_topology,banach,real_normed_field})"
  assumes "finite I"
  assumes "⋀i. i ∈ I ⟹ integrable (measure_pmf (M i)) (f i)"
  shows   "integral⇧L (Pi_pmf I d M) (λx. (∏i ∈ I. f i (x i))) = (∏ i ∈ I. integral⇧L (M i) (f i))"
proof -
  (* The components are independent (indep_vars_Pi_pmf), hence so are the f i (ω i). *)
  have a: "prob_space.indep_vars (measure_pmf (Pi_pmf I d M)) (λ_. borel) (λi ω. f i (ω i)) I"
    by (intro prob_space.indep_vars_compose2[where Y="f" and M'="λ_. discrete"] 
        prob_space_measure_pmf indep_vars_Pi_pmf assms(1)) auto
  have "integral⇧L (Pi_pmf I d M) (λx. (∏i ∈ I. f i (x i))) = (∏ i ∈ I. integral⇧L (Pi_pmf I d M) (λx. f i (x i)))"
    by (intro prob_space.indep_vars_lebesgue_integral prob_space_measure_pmf assms(1,2) 
        a integrable_Pi_pmf_slice) auto
  (* Each factor reduces to an expectation w.r.t. M i. *)
  also have "... = (∏ i ∈ I. integral⇧L (M i) (f i))"
    by (intro prod.cong expectation_Pi_pmf_slice assms(1,2)) auto
  finally show ?thesis by simp
qed

(* The variance of a function of component i under Pi_pmf I d M equals the
   variance of that function under M i.  Requires f squared to be integrable
   w.r.t. M i. *)
lemma variance_prod_pmf_slice:
  fixes f :: "'a ⇒ real"
  assumes "i ∈ I" "finite I"
  assumes "integrable (measure_pmf (M i)) (λω. f ω^2)"
  shows "prob_space.variance (Pi_pmf I d M) (λω. f (ω i)) = prob_space.variance (M i) f"
proof -
  have a:"integrable (measure_pmf (M i)) f"
    using assms(3) measure_pmf.square_integrable_imp_integrable by auto
  have b:" integrable (measure_pmf (Pi_pmf I d M)) (λx. (f (x i))⇧2)"
    by (rule integrable_Pi_pmf_slice[OF assms(2) assms(1)], metis assms(3))
  have c:" integrable (measure_pmf (Pi_pmf I d M)) (λx. (f (x i)))"
    by (rule integrable_Pi_pmf_slice[OF assms(2) assms(1)], metis a)

  (* Rewrite the second moment and the mean via expectation_Pi_pmf_slice. *)
  have "measure_pmf.expectation (Pi_pmf I d M) (λx. (f (x i))⇧2) - (measure_pmf.expectation (Pi_pmf I d M) (λx. f (x i)))⇧2 =
      measure_pmf.expectation (M i) (λx. (f x)⇧2) - (measure_pmf.expectation (M i) f)⇧2"
    using assms a b c by ((subst expectation_Pi_pmf_slice[OF assms(2,1)])?, simp)+

  (* variance_eq expresses the variance as that difference on both sides. *)
  thus ?thesis
    using assms a b c by (simp add: measure_pmf.variance_eq)
qed

(* Applying the deterministic map f i to each component M i inside Pi_pmf is
   the same as sampling from Pi_pmf I d' M (for any default d') and then
   applying f i to each component i ∈ I, using the default d outside I. *)
lemma Pi_pmf_bind_return:
  assumes "finite I"
  shows "Pi_pmf I d (λi. M i ⤜ (λx. return_pmf (f i x))) = Pi_pmf I d' M ⤜ (λx. return_pmf (λi. if i ∈ I then f i (x i) else d))"
  using assms by (simp add: Pi_pmf_bind[where d'="d'"])

end