# Theory Prob_Theory

theory Prob_Theory
imports Probability
```
(*  Title:       Definition of Expectation and Distribution of uniformly distributed bit vectors
Author:      Max Haslbeck
*)

section "Probability Theory"

theory Prob_Theory
imports "HOL-Probability.Probability"
begin

(* Change of variables for pmfs: integrating f against the image
   distribution map_pmf g M equals integrating (λx. f (g x)) against M.
   Declared [simp], so later proofs rewrite with it automatically. *)
lemma integral_map_pmf[simp]:
fixes f::"real ⇒ real"
shows "(∫x. f x ∂(map_pmf g M)) = (∫x. f (g x) ∂M)"
unfolding map_pmf_rep_eq
using integral_distr[of g "(measure_pmf M)" "(count_space UNIV)" f] by auto

subsection "function ‹E›"

(* E M is the expectation of a real-valued pmf M: the integral of the
   identity function with respect to measure_pmf M. *)
definition E :: "real pmf ⇒ real"  where
"E M = (∫x. x ∂ measure_pmf M)"

(* Print-only translation (direction <=): display applications of
   lebesgue_integral using the integral binder notation. *)
translations
"∫ x. f ∂M" <= "CONST lebesgue_integral M (λx. f)"

(* In the "latex output" print mode, show E M as E[M]. *)
notation (latex output) E  ("E[_]"  100)

(* The expectation of the point distribution on a is a itself.
   After unfolding E and the measure representation of return_pmf, the
   goal is an integral over a return measure, which integral_return
   evaluates. *)
lemma E_const[simp]: "E (return_pmf a) = a"
unfolding E_def
unfolding return_pmf.rep_eq
by (simp add: integral_return)

(* Special case of the constant lemma for the value 0. *)
lemma E_null[simp]: "E (return_pmf 0) = 0"
by auto

(* For a pmf with finite support the expectation is the finite weighted
   sum of the support points, each weighted by its probability mass. *)
lemma E_finite_sum: "finite (set_pmf X) ⟹ E X = (∑x∈(set_pmf X). pmf X x * x)"
unfolding E_def by (subst integral_measure_pmf) simp_all

(* Mapping every outcome of X to the constant y gives expectation y. *)
lemma E_of_const: "E(map_pmf (λx. y) (X::real pmf)) = y" by auto

(* If every point in the support of X is non-negative, so is E X. *)
lemma E_nonneg:
shows "(∀x∈set_pmf X. 0≤ x) ⟹ 0 ≤ E X"
unfolding E_def
using integral_nonneg by (simp add: AE_measure_pmf_iff integral_nonneg_AE)

(* Function version of E_nonneg: if f is non-negative on the support of
   X then the expectation of f under X is non-negative. *)
lemma E_nonneg_fun: fixes f::"'a⇒real"
shows "(∀x∈set_pmf X. 0≤f x) ⟹ 0 ≤ E (map_pmf f X)"
using E_nonneg by auto

(* Congruence for expectations: if f and u agree on the support of X,
   their expectations under X coincide.
   integral_cong_AE leaves two measurability goals and one
   almost-everywhere goal; for a pmf, "almost everywhere" means "on the
   support" (AE_measure_pmf_iff), so all three are closed by auto. *)
lemma E_cong:
fixes f::"'a ⇒ real"
shows "finite (set_pmf X) ⟹ (∀x∈ set_pmf X. (f x) = (u x)) ⟹ E (map_pmf f X) = E (map_pmf u X)"
unfolding E_def integral_map_pmf apply(rule integral_cong_AE)
by (auto simp add: integrable_measure_pmf_finite AE_measure_pmf_iff)

(* Monotonicity of expectation for integrable functions: if f ≤ u
   pointwise on the support of X, then E (map_pmf f X) ≤ E (map_pmf u X).
   integral_mono_AE needs integrability of both functions (the first two
   premises) and the almost-everywhere inequality, which for a pmf is the
   inequality on the support (AE_measure_pmf_iff). *)
lemma E_mono3:
fixes f::"'a ⇒ real"
shows " integrable (measure_pmf X) f ⟹  integrable (measure_pmf X) u ⟹ (∀x∈ set_pmf X. (f x) ≤ (u x)) ⟹ E (map_pmf f X) ≤ E (map_pmf u X)"
unfolding E_def integral_map_pmf apply(rule integral_mono_AE)
by (auto simp add: AE_measure_pmf_iff)

(* Monotonicity of expectation for finitely supported X: if f ≤ u on the
   support, then E (map_pmf f X) ≤ E (map_pmf u X).
   Finite support gives integrability of both functions
   (integrable_measure_pmf_finite); the almost-everywhere inequality is
   the inequality on the support (AE_measure_pmf_iff). *)
lemma E_mono2:
fixes f::"'a ⇒ real"
shows "finite (set_pmf X) ⟹ (∀x∈ set_pmf X. (f x) ≤ (u x)) ⟹ E (map_pmf f X) ≤ E (map_pmf u X)"
unfolding E_def integral_map_pmf apply(rule integral_mono_AE)
by (auto simp add: integrable_measure_pmf_finite AE_measure_pmf_iff)

(* Linearity (difference): over a finitely supported A the difference of
   two expectations is the expectation of the pointwise difference.
   The rule leaves the two integrability side conditions, which follow
   from finite support. *)
lemma E_linear_diff2: "finite (set_pmf A) ⟹ E (map_pmf f A) - E (map_pmf g A) = E (map_pmf (λx. (f x) - (g x)) A)"
unfolding E_def integral_map_pmf apply(rule Bochner_Integration.integral_diff[of "measure_pmf A" f g, symmetric])
by (simp_all add: integrable_measure_pmf_finite)

(* Linearity (sum): over a finitely supported A the sum of two
   expectations is the expectation of the pointwise sum.
   The rule leaves the two integrability side conditions, which follow
   from finite support. *)
lemma E_linear_plus2: "finite (set_pmf A) ⟹ E (map_pmf f A) + E (map_pmf g A) = E (map_pmf (λx. (f x) + (g x)) A)"
unfolding E_def integral_map_pmf apply(rule Bochner_Integration.integral_add[of "measure_pmf A" f g, symmetric])
by (simp_all add: integrable_measure_pmf_finite)

(* Linearity for sums indexed by i < up: over a finitely supported D the
   expectation of the sum is the sum of the expectations. *)
lemma E_linear_sum2: "finite (set_pmf D) ⟹ E(map_pmf (λx. (∑i<up. f i x)) D)
= (∑i<(up::nat). E(map_pmf (f i) D))"
unfolding E_def integral_map_pmf apply(rule Bochner_Integration.integral_sum) by (simp add: integrable_measure_pmf_finite)

(* Same as E_linear_sum2, but for a sum over an arbitrary index set A. *)
lemma E_linear_sum_allg: "finite (set_pmf D) ⟹ E(map_pmf (λx. (∑i∈ A. f i x)) D)
= (∑i∈ (A::'a set). E(map_pmf (f i) D))"
unfolding E_def integral_map_pmf apply(rule Bochner_Integration.integral_sum) by (simp add: integrable_measure_pmf_finite)

(* For finitely supported X, the expectation of f under X is the finite
   sum of f x weighted by the probability mass of x. *)
lemma E_finite_sum_fun: "finite (set_pmf X) ⟹
E (map_pmf f X) = (∑x∈set_pmf X. pmf X x * f x)"
proof -
assume finite: "finite (set_pmf X)"
(* Step 1: rewrite the expectation as an integral of f against X. *)
have "E (map_pmf f X) = (∫x. f x ∂measure_pmf X)"
unfolding E_def by auto
(* Step 2: over a finite support that integral is a finite sum. *)
also have "… = (∑x∈set_pmf X. pmf X x * f x)"
by (subst integral_measure_pmf) (auto simp add: finite)
finally show ?thesis .
qed

(* Expectation of f under a Bernoulli distribution with parameter p in
   [0,1]: weight p on f True and weight 1-p on f False. *)
lemma E_bernoulli: "0≤p ⟹ p≤1 ⟹
E (map_pmf f (bernoulli_pmf p)) = p*(f True) + (1-p)*(f False)"
unfolding E_def by (auto)

subsection "function ‹bv›"

(* bv n is a distribution over bool lists built recursively:
   bv 0 is the point distribution on the empty list; bv (Suc n) draws a
   list xs from bv n and a bit x from bernoulli_pmf 0.5 and returns
   x # xs. *)
fun bv:: "nat ⇒ bool list pmf" where
"bv 0 = return_pmf []"
| "bv (Suc n) =  do {
(xs::bool list) ← bv n;
(x::bool) ← (bernoulli_pmf 0.5);
return_pmf (x#xs)
}"

(* The support of bv n is finite (bv n is used as a set here, i.e. its
   set_pmf). Proved by induction on n. *)
lemma bv_finite: "finite (bv n)"
by (induct  n) auto

(* Every list in the support of bv n has length exactly n;
   induction on n, both cases closed by auto. *)
lemma len_bv_n: "∀xs ∈ set_pmf (bv n). length xs = n"
by (induct n) auto

(* The support of bv n is exactly the set of bool lists of length n. *)
lemma bv_set: "set_pmf (bv n) = {x::bool list. length x = n}"
proof (induct n)
  case (Suc n)
  (* Unfold one step of bv: by the induction hypothesis the support of
     bv n is the lists of length n, and the fair coin has full support
     {True, False}. *)
  then have "set_pmf (bv (Suc n)) = (⋃x∈{x. length x = n}. {True # x, False # x})"
    by (simp add: set_pmf_bernoulli UNIV_bool)
  also have "… = {x#xs| x xs. length xs = n}" by auto
  also have "… = {x. length x = Suc n} " using Suc_length_conv by fastforce
  finally show ?case .
qed (simp)

(* A list whose length differs from n is not in the support of bv n. *)
lemma len_not_in_bv: "length xs  ≠ n ⟹ xs ∉ set_pmf (bv n)"
by(auto simp: len_bv_n)

(* A list whose length differs from n has probability mass 0 under bv n:
   it lies outside the support (len_not_in_bv), and mass 0 is equivalent
   to being outside the support (pmf_eq_0_set_pmf). *)
lemma not_n_bv_0: "length xs ≠ n ⟹ pmf (bv n) xs = 0"
by (simp add: len_not_in_bv pmf_eq_0_set_pmf)

(* Each component of a bv-distributed list is a fair coin: for n < l,
   projecting bv l onto index n gives bernoulli_pmf (5 / 10).
   Induction on the index n with the list length l generalised. *)
lemma bv_comp_bernoulli: "n < l
⟹ map_pmf (λy. y!n) (bv l) = bernoulli_pmf (5 / 10)"
proof (induct n arbitrary: l)
case 0
(* l > 0, so l is a successor and the head is the freshly drawn bit. *)
then obtain m where "l = Suc m" by (metis Suc_pred)
then show "map_pmf (λy. y!0) (bv l) =  bernoulli_pmf (5 / 10)" by (auto simp: map_pmf_def bind_return_pmf bind_assoc_pmf bind_return_pmf')
next
case (Suc n)
then have "0 < l" by auto
then obtain m where lsm: "l = Suc m" by (metis Suc_pred)
with Suc(2) have nltm: "n < m" by auto

(* Index Suc n of x # xs is index n of the tail xs, which is drawn
   from bv m; then apply the induction hypothesis at m. *)
from lsm have "map_pmf (λy. y ! Suc n) (bv l)
=  map_pmf (λx. x!n) (bind_pmf (bv m) (λt. (return_pmf t)))" by (auto simp: map_bind_pmf)
also
have "… =  map_pmf (λx. x!n) (bv m)" by (auto simp: bind_return_pmf')
also
have "… = bernoulli_pmf (5 / 10)" by (auto simp add: Suc(1)[of m, OF nltm])
finally
show ?case .
qed

(* Mass of a one-element list under bv (Suc 0): the mass of [] under
   bv 0 times the mass of the single bit under the fair coin. *)
lemma pmf_2elemlist: "pmf (bv (Suc 0)) ([x]) = pmf (bv 0) [] * pmf (bernoulli_pmf (5 / 10)) x"
unfolding bv.simps(2)[where n=0] pmf_bind pmf_return
apply (subst integral_measure_pmf[where A="{[]}"])
apply (auto) by (cases x) auto

(* Product formula for bv: the mass of x # xs under bv (Suc n) is the
   mass of the tail xs under bv n times the mass of the head bit x under
   the fair coin. The integral is restricted to the singleton {xs}. *)
lemma pmf_moreelemlist: "pmf (bv (Suc n)) (x#xs) = pmf (bv n) xs * pmf (bernoulli_pmf (5 / 10)) x"
unfolding bv.simps(2) pmf_bind pmf_return
apply (subst integral_measure_pmf[where A="{xs}"])
apply auto apply (cases x) apply(auto)
apply (meson indicator_simps(2) list.inject singletonD)
apply (meson indicator_simps(2) list.inject singletonD)
apply (cases x) by(auto)

(* bv n is uniform on lists of length n: each such list has mass
   (1 / 2)^n. Induction on n with the list generalised. *)
lemma list_pmf: "length xs = n ⟹ pmf (bv n) xs = (1 / 2)^n"
proof(induct n arbitrary: xs)
case 0
then have "xs = []" by auto
then show "pmf (bv 0) xs = (1 / 2) ^ 0" by(auto)
next
case (Suc n xs)
(* A list of length Suc n is a cons; apply the hypothesis to its tail. *)
then obtain a as where split: "xs = a#as" by (metis Suc_length_conv)
have "length as = n" using Suc(2) split by auto
with Suc(1) have 1: "pmf (bv n) as = (1 / 2) ^ n" by auto

(* Combine with the product formula pmf_moreelemlist. *)
from split pmf_moreelemlist[where n=n and x=a and xs=as] have
"pmf (bv (Suc n)) xs = pmf (bv n) as * pmf (bernoulli_pmf (5 / 10)) a" by auto
then have "pmf (bv (Suc n)) xs = (1 / 2) ^ n * 1 / 2" using 1 by auto
then show "pmf (bv (Suc n)) xs = (1 / 2) ^ Suc n" by auto
qed

(* Contrapositive use of list_pmf: a list with mass 0 under bv n cannot
   have length n. *)
lemma bv_0_notlen: "pmf (bv n) xs = 0 ⟹ length xs ≠ n "
by(auto simp: list_pmf)

(* Unnamed lemma: a list strictly longer than n has mass 0 under bv n.
   Induction on n, using the product formula pmf_moreelemlist in the
   step case. *)
lemma "length xs > n ⟹ pmf (bv n) xs = 0"
proof (induct n arbitrary: xs)
case (Suc n xs)
then obtain a as where split: "xs = a#as" by (metis Suc_length_conv Suc_lessE)
have "length as > n" using Suc(2) split by auto
with Suc(1) have 1: "pmf (bv n) as = 0" by auto
from split pmf_moreelemlist[where n=n and x=a and xs=as] have
"pmf (bv (Suc n)) xs = pmf (bv n) as * pmf (bernoulli_pmf (5 / 10)) a" by auto
then have "pmf (bv (Suc n)) xs = 0 * 1 / 2" using 1 by auto
then show "pmf (bv (Suc n)) xs = 0" by auto
qed simp

(* The head of a list drawn from bv (Suc n) is a fair coin. *)
lemma map_hd_list_pmf: "map_pmf hd (bv (Suc n)) = bernoulli_pmf (5 / 10)"
by (simp add: map_pmf_def bind_assoc_pmf bind_return_pmf bind_return_pmf')

(* The tail of a list drawn from bv (Suc n) is distributed as bv n. *)
lemma map_tl_list_pmf: "map_pmf tl (bv (Suc n)) = bv n"
by (simp add: map_pmf_def bind_assoc_pmf bind_return_pmf bind_return_pmf' )

subsection "function ‹flip›"

(* flip i xs negates the bit at index i of xs and leaves the other
   positions unchanged. The recursion walks down the list while
   decrementing the index; running out of list (first equation) returns
   [] for the remainder, so an index past the end changes nothing. *)
fun flip :: "nat ⇒ bool list ⇒ bool list" where
"flip _ [] = []"
| "flip 0 (x#xs) = (¬x)#xs"
| "flip (Suc n) (x#xs) = x#(flip n xs)"

(* flip preserves the length of the list. Induction on the list with a
   case split on the index in the cons case. *)
lemma flip_length[simp]: "length (flip i xs) = length xs"
apply(induct xs arbitrary: i) apply(simp) apply(case_tac i) by(simp_all)

(* Flipping at an index at or beyond the end of the list is the
   identity. *)
lemma flip_out_of_bounds: "y ≥ length X ⟹ flip y X = X"
apply(induct X arbitrary: y)
proof -
case (Cons X Xs)
(* A non-empty list forces y > 0, so y = Suc y' with y' out of bounds
   for the tail; recurse on the tail. *)
hence "y > 0" by auto
with Cons obtain y' where y1: "y = Suc y'" and y2: "y' ≥ length Xs" by (metis Suc_pred' length_Cons not_less_eq_eq)
then have "flip y (X # Xs) = X#(flip y' Xs)" by auto
moreover from Cons y2 have "flip y' Xs = Xs" by auto
ultimately show ?case by auto
qed simp

(* Flipping index z leaves every other in-bounds index y unchanged.
   Induction on y; the base case is discharged by metis, the step case
   splits on whether z is 0 or a successor. *)
lemma flip_other: "y < length X ⟹ z < length X ⟹ z ≠ y ⟹ flip z X ! y = X ! y"
apply(induct y arbitrary: X z)
apply(simp) apply (metis flip.elims neq0_conv nth_Cons_0)
proof (case_tac z, goal_cases)
case (1 y X z)
(* z = 0: only the head changes, and Suc y addresses the tail. *)
then obtain a as where "X=a#as" using length_greater_0_conv by (metis (full_types) flip.elims)
with 1(5) show ?case by(simp)
next
case (2 y X z z')
(* z = Suc z': peel off the head and use the induction hypothesis on
   the tail at indices y and z'. *)
from 2 have 3: "z' ≠ y" by auto
from 2(2) have "length X > 0" by auto
then obtain a as where aas: "X = a#as" by (metis (full_types) flip.elims length_greater_0_conv)
then have a: "flip (Suc z') X ! Suc y = flip z' as ! y"
and b : "(X ! Suc y) = (as !  y)" by auto
from 2(2) aas have 1: "y < length as" by auto
from 2(3,5) aas have f2: "z' < length as" by auto
note c=2(1)[OF 1 f2 3]

have "flip z X ! Suc y = flip (Suc z') X ! Suc y" using 2 by auto
also have "… = flip z' as ! y" by (rule a)
also have "… = as ! y" by (rule c)
also have "… = (X ! Suc y)" by (rule b[symmetric])
finally show "flip z X ! Suc y = (X ! Suc y)" .
qed

(* Flipping an in-bounds index y negates the bit at that index.
   Induction on y; the step case peels off the head of the list and
   applies the hypothesis to the tail. *)
lemma flip_itself: "y < length X ⟹ flip y X ! y = (¬ X ! y)"
apply(induct y arbitrary: X)
apply(simp) apply (metis flip.elims nth_Cons_0 old.nat.distinct(2))
proof -
fix y
fix X::"bool list"
assume iH: "(⋀X. y < length X ⟹ flip y X ! y = (¬ X ! y))"
assume len: "Suc y < length X"
from len have "y < length X" by auto
from len have "length X > 0" by auto
then obtain z zs where zzs: "X = z#zs" by (metis (full_types) flip.elims length_greater_0_conv)
then have a: "flip (Suc y) X ! Suc y = flip y zs ! y"
and b : "(¬ X ! Suc y) = (¬ zs !  y)" by auto
from len zzs have "y < length zs" by auto
note c=iH[OF this]
from a b c show "flip (Suc y) X ! Suc y = (¬ X ! Suc y)" by auto
qed

(* flip i is an involution: flipping the same index twice restores the
   list. *)
lemma flip_twice: "flip i (flip i b) = b"
proof (cases "i < length b")
  case True
  (* In bounds: compare the lists pointwise. Index i is negated twice
     (flip_itself), every other index is untouched (flip_other). *)
  then have A: "i < length (flip i b)" by simp
  show ?thesis apply(simp add: list_eq_iff_nth_eq) apply(clarify)
  proof (goal_cases)
    case (1 j)
    then show ?case
      apply(cases "i=j")
      using flip_itself[OF A] flip_itself[OF True] apply(simp)
      using flip_other True 1 by auto
  qed
  (* Out of bounds: both flips are the identity (flip_out_of_bounds). *)
qed (simp add: flip_out_of_bounds)

(* Pointwise description of flip for in-bounds indices: position e is
   negated, every other position is unchanged. The case e = y is
   flip_itself, the case e ≠ y is flip_other. *)
lemma flipidiflip: "y < length X ⟹ e < length X  ⟹ flip e X ! y = (if e=y then ~ (X ! y) else X ! y)"
apply(cases "e=y")
apply(simp add: flip_itself)
by(simp add: flip_other)

(* A fair coin is invariant under negation. Shown by comparing the mass
   functions at each Boolean value i; in each case i is written as the
   negation of the other value, and injectivity of Not transfers the
   mass through map_pmf. *)
lemma bernoulli_Not: "map_pmf Not (bernoulli_pmf (1 / 2)) = (bernoulli_pmf (1 / 2))"
apply(rule pmf_eqI)
proof (case_tac i, goal_cases)
case (1 i)
then have "pmf (map_pmf Not (bernoulli_pmf (1 / 2))) i =
pmf (map_pmf Not (bernoulli_pmf (1 / 2))) (Not False)" by auto
also have "… = pmf (bernoulli_pmf (1 / 2)) False" apply (rule pmf_map_inj') apply(rule injI) by auto
also have "… = pmf (bernoulli_pmf (1 / 2)) i" by auto
finally show ?case .
next
case (2 i)
then have "pmf (map_pmf Not (bernoulli_pmf (1 / 2))) i =
pmf (map_pmf Not (bernoulli_pmf (1 / 2))) (Not True)" by auto
also have "… = pmf (bernoulli_pmf (1 / 2)) True" apply (rule pmf_map_inj') apply(rule injI) by auto
also have "… = pmf (bernoulli_pmf (1 / 2)) i" by auto
finally show ?case .
qed

(* The distribution bv n is invariant under flipping any fixed index i.
   Induction on n with i generalised. In the step case the two binds of
   bv (Suc n) are commuted so that the head bit is drawn first; then
   - i = 0: the head bit is negated, which is absorbed by bernoulli_Not;
   - i = Suc i': the flip moves into the tail, which is absorbed by the
     induction hypothesis.
   bind_map_pmf is used to move the mapped function in and out of the
   bound distribution. *)
lemma inv_flip_bv: "map_pmf (flip i) (bv n) = (bv n)"
proof(induct n arbitrary: i)
  case (Suc n i)
  note iH=this
  have "bind_pmf (bv n) (λx. bind_pmf (bernoulli_pmf (1 / 2)) (λxa. map_pmf (flip i) (return_pmf (xa # x))))
      = bind_pmf (bernoulli_pmf (1 / 2)) (λxa .bind_pmf (bv n) (λx. map_pmf (flip i) (return_pmf (xa # x))))"
    by(rule bind_commute_pmf)
  also have "… = bind_pmf (bernoulli_pmf (1 / 2)) (λxa . bind_pmf (bv n) (λx. return_pmf (xa # x)))"
  proof (cases i)
    case 0
    then have "bind_pmf (bernoulli_pmf (1 / 2)) (λxa. bind_pmf (bv n) (λx. map_pmf (flip i) (return_pmf (xa # x))))
        = bind_pmf (bernoulli_pmf (1 / 2)) (λxa. bind_pmf (bv n) (λx. return_pmf ((¬ xa) # x)))" by auto
    also have "…  = bind_pmf (bv n) (λx. bind_pmf (bernoulli_pmf (1 / 2)) (λxa. return_pmf ((¬ xa) # x)))"
      by(rule bind_commute_pmf)
    (* Pull the negation out of the continuation into the coin. *)
    also have "…
        = bind_pmf (bv n) (λx. bind_pmf (map_pmf Not (bernoulli_pmf (1 / 2))) (λxa. return_pmf (xa # x)))"
      by(auto simp add: bind_map_pmf)
    also have "… = bind_pmf (bv n) (λx. bind_pmf (bernoulli_pmf (1 / 2)) (λxa. return_pmf (xa # x)))" by (simp only: bernoulli_Not)
    also have "… = bind_pmf (bernoulli_pmf (1 / 2)) (λxa. bind_pmf (bv n) (λx. return_pmf (xa # x)))"
      by(rule bind_commute_pmf)
    finally show ?thesis .
  next
    case (Suc i')
    have "bind_pmf (bernoulli_pmf (1 / 2)) (λxa. bind_pmf (bv n) (λx. map_pmf (flip i) (return_pmf (xa # x))))
        = bind_pmf (bernoulli_pmf (1 / 2)) (λxa. bind_pmf (bv n) (λx. return_pmf (xa # flip i' x)))" unfolding Suc by(simp)
    (* Pull the tail flip out of the continuation into bv n. *)
    also have "… = bind_pmf (bernoulli_pmf (1 / 2)) (λxa. bind_pmf (map_pmf (flip i') (bv n)) (λx. return_pmf (xa # x)))"
      by(auto simp add: bind_map_pmf)
    also have "… =  bind_pmf (bernoulli_pmf (1 / 2)) (λxa. bind_pmf (bv n) (λx. return_pmf (xa # x)))"
      using iH[of "i'"] by simp
    finally show ?thesis .
  qed
  also have "… = bind_pmf (bv n) (λx. bind_pmf (bernoulli_pmf (1 / 2)) (λxa. return_pmf (xa # x)))"
    by(rule bind_commute_pmf)
  finally show ?case by(simp add: map_pmf_def bind_assoc_pmf)
qed simp

subsection "Example for pmf"

(* Example distribution: toss two independent coins with success
   probabilities 0.4 and 0.5 and return the disjunction of the results. *)
definition "twocoins =
do {
x ← (bernoulli_pmf 0.4);
y ← (bernoulli_pmf 0.5);
return_pmf (x ∨ y)
}"

(* At least one of the two coins of twocoins shows True with
   probability 0.7. The integral is evaluated over the full Boolean
   support {True, False}. *)
lemma experiment0_7: "pmf twocoins True = 0.7"
unfolding twocoins_def
unfolding pmf_bind pmf_return
apply (subst integral_measure_pmf[where A="{True, False}"])
by auto

subsection "Sum Distribution"

(* Sum_pmf p Da Db is a distribution on a sum type: toss a coin with
   parameter p; on True draw from Da and tag the result with Inl, on
   False draw from Db and tag the result with Inr. *)
definition "Sum_pmf p Da Db = (bernoulli_pmf p) ⤜ (%b. if b then map_pmf Inl Da else map_pmf Inr Db )"

(* A Bernoulli coin with parameter 0 always yields False; compare the
   mass functions at both Boolean values. *)
lemma b0: "bernoulli_pmf 0 = return_pmf False"
by (rule pmf_eqI, case_tac i) simp_all
(* A Bernoulli coin with parameter 1 always yields True; compare the
   mass functions at both Boolean values. *)
lemma b1: "bernoulli_pmf 1 = return_pmf True"
by (rule pmf_eqI, case_tac i) simp_all

(* With coin parameter 0 the sum distribution is just the right
   component: the coin collapses to return_pmf False (b0) and the bind
   over a point distribution reduces to its continuation
   (bind_return_pmf). *)
lemma Sum_pmf_0: "Sum_pmf 0 Da Db = map_pmf Inr Db"
unfolding Sum_pmf_def
apply(rule pmf_eqI)
by(simp add: b0 bind_return_pmf)

(* With coin parameter 1 the sum distribution is just the left
   component: the coin collapses to return_pmf True (b1) and the bind
   over a point distribution reduces to its continuation
   (bind_return_pmf). *)
lemma Sum_pmf_1: "Sum_pmf 1 Da Db = map_pmf Inl Da"
unfolding Sum_pmf_def
apply(rule pmf_eqI)
by(simp add: b1 bind_return_pmf)

(* Proj1_pmf D conditions D on the outcomes of the form Inl e and then
   strips the Inl tag. The case expression has no Inr branch, so it is
   unspecified on Inr values. *)
definition "Proj1_pmf D = map_pmf (%a. case a of Inl e ⇒ e) (cond_pmf D {f. (∃e. Inl e = f)})"

(* The untagging function returns e when applied to Inl e. *)
lemma A: "(case_sum (λe. e) (λa. undefined)) (Inl e) = e"
by(simp)

(* Attempted statement that the untagging function is injective;
   abandoned with oops, so no fact named B is added to the theory. *)
lemma B: "inj (case_sum (λe. e) (λa. undefined))"
oops

(* For 0 < p < 1 the sum distribution puts mass on some Inl value, i.e.
   conditioning on the Inl outcomes is well defined.
   The simp step first computes the support of the bind (the coin has
   full support {True, False}); fast then only needs a witness from the
   non-empty support of Da. *)
lemma none: "p >0 ⟹ p < 1 ⟹ (set_pmf (bernoulli_pmf p ⤜
(λb. if b then map_pmf Inl Da else map_pmf Inr Db))
∩ {f. (∃e. Inl e = f)}) ≠ {}"
apply(simp add: UNIV_bool)
using set_pmf_not_empty by fast
(* For 0 < p < 1 the sum distribution puts mass on some Inr value, i.e.
   conditioning on the Inr outcomes is well defined.
   The simp step first computes the support of the bind (the coin has
   full support {True, False}); fast then only needs a witness from the
   non-empty support of Db. *)
lemma none2: "p >0 ⟹ p < 1 ⟹  (set_pmf (bernoulli_pmf p ⤜
(λb. if b then map_pmf Inl Da else map_pmf Inr Db))
∩ {f. (∃e. Inr e = f)}) ≠ {}"
apply(simp add: UNIV_bool)
using set_pmf_not_empty by fast

(* Support-level statement for the fair coin: projecting the sum
   distribution back to its left component has the same support as Da.
   Uses lemma none to justify unfolding the support of cond_pmf. *)
lemma C: "set_pmf (Proj1_pmf (Sum_pmf 0.5 Da Db)) = set_pmf Da"
proof -
show ?thesis
unfolding Sum_pmf_def Proj1_pmf_def
using none[of "0.5" Da Db] apply(simp add: set_cond_pmf UNIV_bool)
by force
qed

(* Diagnostic commands: print the named facts (and pmf_cond instantiated
   with lemma none). They add nothing to the theory. *)
thm integral_measure_pmf

thm pmf_cond pmf_cond[OF none]

(* For 0 < p < 1, projecting the sum distribution onto its left
   component recovers Da.
   Outline:
     kl - an Inl value has mass 0 under map_pmf Inr Db;
     ll - the event "outcome is an Inl" has probability p;
     E  - conditioning the sum distribution on that event yields
          map_pmf Inl Da (masses are p * mass / p, using kl and ll);
     ID - untagging after Inl is the identity.
   The result follows by composing the two maps (pmf.map_comp). *)
lemma proj1_pmf: assumes "p>0" "p<1" shows "Proj1_pmf (Sum_pmf p Da Db) =  Da"
proof -

  have kl: "⋀e. pmf (map_pmf Inr Db) (Inl e) = 0"
    apply(simp only: pmf_eq_0_set_pmf)
    apply(simp) by blast

  have ll: "measure_pmf.prob
      (bernoulli_pmf p ⤜
        (λb. if b then map_pmf Inl Da else map_pmf Inr Db))
      {f. ∃e. Inl e = f} = p"
    using assms
    (* Write the probability as an integral of the mass function and
       split it over the two coin outcomes. *)
    apply(simp add: integral_pmf[symmetric] pmf_bind)
    apply(subst Bochner_Integration.integral_add)
    using integrable_pmf apply fast
    using integrable_pmf apply fast
    by(simp add: integral_pmf)

  have E: "(cond_pmf
      (bernoulli_pmf p ⤜
        (λb. if b then map_pmf Inl Da else map_pmf Inr Db))
      {f. ∃e. Inl e = f}) =
    map_pmf Inl Da"
    apply(rule pmf_eqI)
    apply(subst pmf_cond)
    using none[of p Da Db] assms apply (simp)
    using assms apply(auto)
    apply(subst pmf_bind)
    apply(simp add: kl ll )
    apply(simp only: pmf_eq_0_set_pmf) by auto

  have ID: "case_sum (λe. e) (λa. undefined) ∘ Inl = id"
    by fastforce
  show ?thesis
    unfolding Sum_pmf_def Proj1_pmf_def
    apply(simp only: E)
    apply(simp add: pmf.map_comp ID)
    done

qed

(* Proj2_pmf D conditions D on the outcomes of the form Inr e and then
   strips the Inr tag. The case expression has no Inl branch, so it is
   unspecified on Inl values. *)
definition "Proj2_pmf D = map_pmf (%a. case a of Inr e ⇒ e) (cond_pmf D {f. (∃e. Inr e = f)})"

(* For 0 < p < 1, projecting the sum distribution onto its right
   component recovers Db.
   Outline:
     kl - an Inr value has mass 0 under map_pmf Inl Da;
     ll - the event "outcome is an Inr" has probability 1 - p;
     E  - conditioning the sum distribution on that event yields
          map_pmf Inr Db (using kl and ll);
     ID - untagging after Inr is the identity.
   The result follows by composing the two maps (pmf.map_comp). *)
lemma proj2_pmf: assumes "p>0" "p<1" shows "Proj2_pmf (Sum_pmf p Da Db) =  Db"
proof -

  have kl: "⋀e. pmf (map_pmf Inl Da) (Inr e) = 0"
    apply(simp only: pmf_eq_0_set_pmf)
    apply(simp) by blast

  have ll: "measure_pmf.prob
      (bernoulli_pmf p ⤜
        (λb. if b then map_pmf Inl Da else map_pmf Inr Db))
      {f. ∃e. Inr e = f} = 1-p"
    using assms
    (* Write the probability as an integral of the mass function and
       split it over the two coin outcomes. *)
    apply(simp add: integral_pmf[symmetric] pmf_bind)
    apply(subst Bochner_Integration.integral_add)
    using integrable_pmf apply fast
    using integrable_pmf apply fast
    by(simp add: integral_pmf)

  have E: "(cond_pmf
      (bernoulli_pmf p ⤜
        (λb. if b then map_pmf Inl Da else map_pmf Inr Db))
      {f. ∃e. Inr e = f}) =
    map_pmf Inr Db"
    apply(rule pmf_eqI)
    apply(subst pmf_cond)
    using none2[of p Da Db] assms apply (simp)
    using assms apply(auto)
    apply(subst pmf_bind)
    apply(simp add: kl ll )
    apply(simp only: pmf_eq_0_set_pmf) by auto

  have ID: "case_sum (λe. undefined) (λa. a) ∘ Inr = id"
    by fastforce
  show ?thesis
    unfolding Sum_pmf_def Proj2_pmf_def
    apply(simp only: E)
    apply(simp add: pmf.map_comp ID)
    done

qed

(* invSum invA invB D x i holds when invA holds for the left projection
   of D and invB holds for the right projection of D (both with the
   same extra arguments x and i). *)
definition "invSum invA invB D x i == invA (Proj1_pmf D) x i ∧ invB (Proj2_pmf D) x i"

(* For 0 < p < 1, invariants on the components Da and Db combine into
   the sum invariant on Sum_pmf p Da Db: the projections of the sum
   distribution are Da and Db again (proj1_pmf, proj2_pmf). *)
lemma invSum_split: "p>0 ⟹ p<1 ⟹ invA Da x i ⟹ invB Db x i ⟹ invSum invA invB (Sum_pmf p Da Db) x i"
by(simp add: invSum_def proj1_pmf proj2_pmf)

(* Diagnostic: type-check the deterministic sum-mapping function. *)
term "(%a. case a of Inl e ⇒ Inl (fa e) | Inr e ⇒ Inr (fb e))"
(* f_on2 fa fb lifts two probabilistic functions to sum types: an Inl
   input is passed to fa and the result re-tagged with Inl, an Inr input
   is passed to fb and the result re-tagged with Inr. *)
definition "f_on2 fa fb = (%a. case a of Inl e ⇒ map_pmf Inl (fa e) | Inr e ⇒ map_pmf Inr (fb e))"

(* Diagnostic: print the type of bind_pmf. *)
term "bind_pmf"

(* Binding a sum distribution with a component-wise function (f_on2)
   binds each component separately: if Da bound with fa is Da' and Db
   bound with fb is Db', the result is Sum_pmf p Da' Db'.
   The helper fact gr handles one branch of the coin at a time; the
   final step reassociates the binds and rewrites with gr. *)
lemma Sum_bind_pmf: assumes a: "bind_pmf Da fa = Da'" and b: "bind_pmf Db fb = Db'"
  shows "bind_pmf (Sum_pmf p Da Db) (f_on2 fa fb)
      = Sum_pmf p Da' Db'"
proof -
  { fix x
    (* Distribute the bind over the if-then-else on the coin result. *)
    have "(if x then map_pmf Inl Da else map_pmf Inr Db) ⤜
          case_sum (λe. map_pmf Inl (fa e))
            (λe. map_pmf Inr (fb e))
        =
        (if x then map_pmf Inl Da ⤜ case_sum (λe. map_pmf Inl (fa e))
            (λe. map_pmf Inr (fb e))
         else map_pmf Inr Db ⤜ case_sum (λe. map_pmf Inl (fa e))
            (λe. map_pmf Inr (fb e)))"
      apply(simp) done
    also
    have "… = (if x then map_pmf Inl (bind_pmf Da fa) else map_pmf Inr (bind_pmf Db fb))"
      by(auto simp add: map_pmf_def bind_assoc_pmf bind_return_pmf)
    also
    have "… = (if x then map_pmf Inl Da' else map_pmf Inr Db')"
      using a b by simp
    finally
    have "(if x then map_pmf Inl Da else map_pmf Inr Db) ⤜
          case_sum (λe. map_pmf Inl (fa e))
            (λe. map_pmf Inr (fb e)) = (if x then map_pmf Inl Da' else map_pmf Inr Db')" .
  } note gr=this

  show ?thesis
    unfolding Sum_pmf_def f_on2_def
    apply(rule pmf_eqI)
    apply(case_tac i)
    by(simp_all add: bind_assoc_pmf gr)
qed

(* sum_map_pmf fa fb maps over a sum type component-wise: fa under Inl,
   fb under Inr. *)
definition "sum_map_pmf fa fb = (%a. case a of Inl e ⇒ Inl (fa e) | Inr e ⇒ Inr (fb e))"

lemma Sum_map_pmf: assumes a: "map_pmf fa Da = Da'" and b: "map_pmf fb Db = Db'"
shows "map_pmf (sum_map_pmf fa fb) (Sum_pmf p Da Db)
= Sum_pmf p Da' Db'"
proof -
have "map_pmf (sum_map_pmf fa fb) (Sum_pmf p Da Db)
= bind_pmf (Sum_pmf p Da Db) (f_on2 (λx. return_pmf (fa x)) (λx. return_pmf (fb x)))"
using a b
unfolding map_pmf_def sum_map_pmf_def f_on2_def