(* Theory Probability_Tools *)

(*******************************************************************************

  Project: Sumcheck Protocol

  Authors: Azucena Garvia Bosshard <zucegb@gmail.com>
           Christoph Sprenger, ETH Zurich <sprenger@inf.ethz.ch>
           Jonathan Bootle, IBM Research Europe <jbt@zurich.ibm.com>

*******************************************************************************)

section ‹Auxiliary Lemmas Related to Probability Theory›

theory Probability_Tools
  imports "HOL-Probability.Probability"
begin

subsection ‹Tuples›

(* tuples S n: the set of all lists of length n whose elements are all drawn from S. *)
definition tuples :: ‹'a set ⇒ nat ⇒ 'a list set› where
  ‹tuples S n = {xs. set xs ⊆ S ∧ length xs = n}›

(* Introduction rule: a list over S of length n is a member of tuples S n. *)
lemma tuplesI: ‹⟦ set xs ⊆ S; length xs = n ⟧ ⟹ xs ∈ tuples S n› 
  by (simp add: tuples_def)

(* Elimination rule (declared [elim]): membership in tuples S n provides both
   defining facts, set xs ⊆ S and length xs = n. *)
lemma tuplesE [elim]: ‹⟦ xs ∈ tuples S n; ⟦ set xs ⊆ S; length xs = n ⟧ ⟹ P ⟧ ⟹ P› 
  by (simp add: tuples_def)

(* The only tuple of length 0 is the empty list, for any S. *)
lemma tuples_Zero: ‹tuples S 0 = {[]}›
  by (auto simp add: tuples_def)

(* Recursive characterisation: a tuple of length Suc n is a head from S consed
   onto a tuple of length n (image of S × tuples S n under Cons). *)
lemma tuples_Suc: ‹tuples S (Suc n) = (λ(x, xs). x # xs) ` (S × tuples S n)›
  by (fastforce simp add: tuples_def image_def Suc_length_conv dest: sym)

(* For non-empty S there is at least one tuple of every length n
   (induction on n using tuples_Zero and tuples_Suc). *)
lemma tuples_non_empty [simp]: ‹S ≠ {} ⟹ tuples S n ≠ {}›
  by (induction n) (auto simp add: tuples_Zero tuples_Suc)

(* There are only finitely many tuples of length n over a finite set S.
   NOTE(review): the proof goes through finite_lists_length_eq; the S ≠ {}
   hypothesis may not be needed -- confirm before weakening the statement. *)
lemma tuples_finite [simp]: ‹⟦ finite (S::'a set); S ≠ {} ⟧ ⟹ finite (tuples S n :: 'a list set)›
  by (auto simp add: tuples_def dest: finite_lists_length_eq)


subsection ‹Congruence and monotonicity›

(* Congruence: two events that agree on every point of the support of M
   have the same probability under M. *)
lemma prob_cong:                        ― ‹adapted from Joshua›
  assumes ‹⋀x. x ∈ set_pmf M ⟹ x ∈ A ⟷ x ∈ B› 
  shows ‹measure_pmf.prob M A = measure_pmf.prob M B›
  using assms
  by (simp add: measure_pmf.finite_measure_eq_AE AE_measure_pmf_iff)

(* Monotonicity: if A is contained in B on the support of M, then the
   probability of A is at most that of B. *)
lemma prob_mono: 
  assumes ‹⋀x. x ∈ set_pmf M ⟹ x ∈ A ⟹ x ∈ B› 
  shows ‹measure_pmf.prob M A ≤ measure_pmf.prob M B›
  using assms
  by (simp add: measure_pmf.finite_measure_mono_AE AE_measure_pmf_iff)


subsection ‹Some simple derived lemmas›

(* The empty event has probability zero. *)
lemma prob_empty: 
  assumes ‹A = {}›
  shows ‹measure_pmf.prob M A = 0›
  using assms
  by (simp)   ― ‹uses @{thm [source] "measure_empty"}: @{thm "measure_empty"}›

(* Under the uniform distribution on a finite, non-empty S, an event has
   probability at least 1 (hence exactly 1) iff it contains all of S. *)
lemma prob_pmf_of_set_geq_1:
  assumes "finite S" and "S ≠ {}"
  shows "measure_pmf.prob (pmf_of_set S) A ≥ 1 ⟷ S ⊆ A" using assms
  by (auto simp add: measure_pmf.measure_ge_1_iff measure_pmf.prob_eq_1 AE_measure_pmf_iff)


subsection ‹Intersection and union lemmas›

(* Additivity for two disjoint events. *)
lemma prob_disjoint_union:
  assumes ‹A ∩ B = {}› 
  shows ‹measure_pmf.prob M (A ∪ B) = measure_pmf.prob M A + measure_pmf.prob M B›
  using assms
  by (fact measure_pmf.finite_measure_Union[simplified])

(* Additivity for a finite, pairwise-disjoint family of events indexed by I. *)
lemma prob_finite_Union:
  assumes ‹disjoint_family_on A I› ‹finite I›
  shows ‹measure_pmf.prob M (⋃i∈I. A i) = (∑i∈I. measure_pmf.prob M (A i))›
  using assms
  by (intro measure_pmf.finite_measure_finite_Union) (simp_all)

(* Case split: if A is the disjoint union of B and C, then the probability
   of A is the sum of the probabilities of B and C. *)
lemma prob_disjoint_cases:
  assumes ‹B ∪ C = A› ‹B ∩ C = {}›
  shows ‹measure_pmf.prob M A = measure_pmf.prob M B + measure_pmf.prob M C›
proof - 
  have ‹measure_pmf.prob M A = measure_pmf.prob M (B ∪ C)› using ‹B ∪ C = A› 
    by (auto intro: prob_cong)
  also have ‹... = measure_pmf.prob M B + measure_pmf.prob M C› using ‹B ∩ C = {}›
    by (simp add: prob_disjoint_union)
  finally show ?thesis .
qed

(* Finite case split: if A is the union of a finite, pairwise-disjoint family
   B indexed by I, then the probability of A is the sum over I. *)
lemma prob_finite_disjoint_cases:
  assumes ‹(⋃i∈I. B i) = A› ‹disjoint_family_on B I› ‹finite I›
  shows ‹measure_pmf.prob M A = (∑i∈I. measure_pmf.prob M (B i))›
proof - 
  have ‹measure_pmf.prob M A = measure_pmf.prob M (⋃i∈I. B i)› using assms(1)
    by (auto intro: prob_cong) 
  also have ‹... = (∑i∈I. measure_pmf.prob M (B i))› using assms(2,3) 
    by (intro prob_finite_Union) 
  finally show ?thesis .
qed


subsection ‹Independent probabilities for head and tail of a tuple›

(* The uniform distribution on a product of finite, non-empty sets is the
   product (pair_pmf) of the uniform distributions on the factors. *)
lemma pmf_of_set_Times:   ― ‹by Andreas Lochbihler›
  "pmf_of_set (A × B) = pair_pmf (pmf_of_set A) (pmf_of_set B)"
  if "finite A" "finite B" "A ≠ {}" "B ≠ {}"   
  by(rule pmf_eqI)(auto simp add: that pmf_pair indicator_def)


(* Independence of head and tail: for a list drawn uniformly from
   tuples S (Suc n), the probability that the head satisfies P and the tail
   satisfies Q is the product of the two marginal probabilities. *)
lemma prob_tuples_hd_tl_indep:
  assumes ‹S ≠ {}› 
  shows
    ‹measure_pmf.prob (pmf_of_set (tuples S (Suc n))) {(r::'a::finite) # rs | r rs. P r ∧ Q rs}
   = measure_pmf.prob (pmf_of_set (S::'a set)) {r. P r} * 
     measure_pmf.prob (pmf_of_set (tuples S n)) {rs. Q rs}› 
    (is "?lhs = ?rhs")
proof -                  ― ‹mostly by Andreas Lochbihler›
  text ‹
    Step 1: Split the random variable @{term "pmf_of_set (tuples S (Suc n))"} into
    two independent (@{term "pair_pmf"}) random variables, one producing the head and one 
    producing the tail of the list, and then @{term Cons} the two random variables using 
    @{term "map_pmf"}.
  ›
  have *: "pmf_of_set (tuples S (Suc n)) 
         = map_pmf (λ(x :: 'a, xs). x # xs) (pair_pmf (pmf_of_set S) (pmf_of_set (tuples S n)))"
    unfolding tuples_Suc using ‹S ≠ {}› 
    by (auto simp add: map_pmf_of_set_inj[symmetric] inj_on_def pmf_of_set_Times) 
  text ‹
    Step 2: Transform the event by moving the @{term Cons} from the random variable into the event.
    This corresponds to using @{term distr} on measures.
  ›
  have "?lhs = measure_pmf.prob (pair_pmf (pmf_of_set S) (pmf_of_set (tuples S n))) 
                                ((λ(x :: 'a, xs). x # xs) -` {r # rs | r rs. P r ∧ Q rs})"
    unfolding * measure_map_pmf by (rule refl)

  text ‹
    Step 3: Rewrite the event as a pair of events. Then we apply independence of the head 
    from the tail.
  ›
  also have "((λ(x, xs). x # xs) -` {r # rs | r rs. P r ∧ Q rs}) = {r. P r} × {rs. Q rs}" by auto
  also have "measure_pmf.prob (pair_pmf (pmf_of_set S) (pmf_of_set (tuples S n))) … =
               measure_pmf.prob (pmf_of_set S) {r. P r} 
             * measure_pmf.prob (pmf_of_set (tuples S n)) {rs. Q rs}"
   by(rule measure_pmf_prob_product) simp_all

  finally show ?thesis .
qed

(* Total probability over the head: for a list drawn uniformly from all
   tuples of length Suc n over the finite type 'a, the probability of P is the
   average, over every possible head a, of the probability that a # rs
   satisfies P for a uniformly drawn tail rs of length n. *)
lemma prob_tuples_fixed_hd:
  ‹measure_pmf.prob (pmf_of_set (tuples UNIV (Suc n))) {rs::'a list. P rs} 
   = (∑a ∈ UNIV. measure_pmf.prob (pmf_of_set (tuples UNIV n)) {rs. P (a # rs)}) / real(CARD('a::finite))›
  (is "?lhs = ?rhs")
proof -
  {
    fix a
    have ‹measure_pmf.prob (pmf_of_set (tuples UNIV (Suc n))) ({rs. P rs} ∩ {rs. hd rs = a})
        = measure_pmf.prob (pmf_of_set (tuples UNIV (Suc n))) ({r#rs | r rs. r = a ∧ P (a#rs)})›
      by (intro prob_cong) (auto simp add: tuples_Suc)
    also have ‹... = measure_pmf.prob (pmf_of_set (UNIV::'a set)) {r. r = a} * 
                     measure_pmf.prob (pmf_of_set (tuples UNIV n)) {rs. P (a#rs)}› 
      by (intro prob_tuples_hd_tl_indep) simp
    also have ‹... = measure_pmf.prob (pmf_of_set (tuples UNIV n)) {rs. P (a#rs)} / real (CARD ('a))›
      by (simp add: measure_pmf_single)
    finally 
    have ‹measure_pmf.prob (pmf_of_set (tuples UNIV (Suc n))) ({rs. P rs} ∩ {rs. hd rs = a}) 
        = measure_pmf.prob (pmf_of_set (tuples UNIV n)) {rs. P (a#rs)} / real (CARD ('a))› .
  }
  note A1 = this

  ― ‹partition the event by the value of the head and sum over all heads›
  have ‹?lhs = (∑a ∈ UNIV. measure_pmf.prob (pmf_of_set (tuples UNIV (Suc n))) ({rs. P rs} ∩ {rs. hd rs = a}))›
    by (intro prob_finite_disjoint_cases) (auto simp add: disjoint_family_on_def)
  also have ‹... = ?rhs› using A1 
    by (simp add: sum_divide_distrib)
  finally show ?thesis .
qed


end