Theory Query

(*  Title:   Query.thy
    Author:  Michikazu Hirata, Tokyo Institute of Technology
*)

subsection ‹Query›

theory Query
  imports "Monad_QuasiBorel"
begin

(* Shared notation for this theory.
   - The coercion declaration lets a qbs_measure appear where a measure is
     expected; qbs_l is inserted automatically.  For example, the probability
     statements below write the measure inside the probability notation
     without an explicit qbs_l.
   - The abbreviations name the standard quasi-Borel spaces on real, ennreal,
     nat and bool.  The first two use qbs_borel; nat uses
     qbs_count_space UNIV; bool uses the count-space QBS over UNIV.
   NOTE(review): the mixfix notation symbols and the definitional equality
   symbol appear stripped in this copy (e.g. the real abbreviation reads
   "Q  qbs_borel").  Verify against the upstream theory before editing. *)
declare [[coercion qbs_l]]
abbreviation qbs_real :: "real quasi_borel"       ("Q") where "Q  qbs_borel"
abbreviation qbs_ennreal :: "ennreal quasi_borel" ("Q0") where "Q0  qbs_borel"
abbreviation qbs_nat :: "nat quasi_borel"         ("Q") where "Q  qbs_count_space UNIV"
abbreviation qbs_bool :: "bool quasi_borel"       ("𝔹Q") where "𝔹Q  count_spaceQ UNIV"


(* query s f: reweight the measure s by the ennreal-valued weight f
   (density_qbs), then normalize the result (normalize_qbs).
   It is used as a posterior update later in this theory, e.g. in
   GaussLearn' with a normal density as the weight. *)
definition query :: "['a qbs_measure, 'a  ennreal]  'a qbs_measure" where
"query  (λs f. normalize_qbs (density_qbs s f))"

(* query is a qbs-morphism jointly in the measure and the weight function.
   It is registered under the qbs attribute, presumably so that later
   morphism goals can be closed by simp. *)
lemma query_qbs_morphism[qbs]: "query  monadM_qbs X Q (X Q qbs_borel) Q monadM_qbs X"
  by(simp add: query_def)

(* condition s P: restrict s to the event P and renormalize.
   It is the special case of query whose weight is the 0/1 function of P. *)
definition "condition  (λs P. query s (λx. if P x then 1 else 0))"

(* condition is a qbs-morphism in the measure and the (boolean) predicate. *)
lemma condition_qbs_morphism[qbs]: "condition  monadM_qbs X Q (X Q 𝔹Q) Q monadM_qbs X"
  by(simp add: condition_def)

(* Conditioning preserves the probability monad.
   Suppose that, for every x in X, the event P x has non-zero probability
   under s x.  Then x maps to condition (s x) (P x) is a morphism into
   monadP_qbs Y, not just into monadM_qbs.  The non-zero hypothesis is what
   makes the normalization well defined. *)
lemma condition_morphismP:
  assumes "x. x  qbs_space X  𝒫(y in qbs_l (s x). P x y)  0"
      and [qbs]: "s  X Q monadP_qbs Y" "P  X Q Y Q qbs_count_space UNIV"
    shows "(λx. condition (s x) (P x))  X Q monadP_qbs Y"
(* Switch to a pointwise-equal formulation.  It uses the indicator of
   {y in qbs_space Y. P x y} instead of the 0/1 if-expression inside
   condition, which fits the available density/normalize lemmas. *)
proof(rule qbs_morphism_cong'[where f="λx. normalize_qbs (density_qbs (s x) (indicator {yqbs_space Y. P x y}))"])
  fix x
  assume x[qbs]:"x  qbs_space X"
  (* On qbs_space Y the indicator weight and the 0/1 weight agree, so the
     two densities coincide (density_qbs_cong). *)
  have "density_qbs (s x) (indicator {y  qbs_space Y. P x y}) = density_qbs (s x) (λy. if P x y then 1 else 0)"
    by(auto intro!: density_qbs_cong[OF qbs_space_monadPM[OF qbs_morphism_space[OF assms(2) x]]] indicator_qbs_morphism'')
  thus "normalize_qbs (density_qbs (s x) (indicator {y  qbs_space Y. P x y})) = condition (s x) (P x)"
    unfolding condition_def query_def by simp
next
  (* normalize_qbs_morphismP needs two things:
     - the unnormalized family is a morphism into monadM_qbs Y;
     - its total mass is neither 0 nor infinite at every point. *)
  show "(λx. normalize_qbs (density_qbs (s x) (indicator {y  qbs_space Y. P x y})))  X Q monadP_qbs Y"
  proof(rule normalize_qbs_morphismP[of "λx. density_qbs (s x) (indicator {y  qbs_space Y. P x y})"])
    show "(λx. density_qbs (s x) (indicator {y  qbs_space Y. P x y}))  X Q monadM_qbs Y"
      using qbs_morphism_monadPD[OF assms(2)] by simp
  next
    fix x
    assume x:"x  qbs_space X"
    (* The total mass of the indicator density is the probability of P x
       under s x.  It is non-zero by assumption and at most 1 by
       qbs_l_monadP_le1, hence finite. *)
    show "emeasure (qbs_l (density_qbs (s x) (indicator {y  qbs_space Y. P x y}))) (qbs_space Y)  0"
         "emeasure (qbs_l (density_qbs (s x) (indicator {y  qbs_space Y. P x y}))) (qbs_space Y)  "
      using assms(1)[OF x] qbs_l_monadP_le1[OF qbs_morphism_space[OF assms(2) x]]
      by(auto simp add: qbs_l_density_qbs_indicator[OF qbs_space_monadPM[OF qbs_morphism_space[OF assms(2) x]] qbs_morphism_space[OF assms(3) x]] measure_def space_qbs_l_in[OF qbs_space_monadPM[OF qbs_morphism_space[OF assms(2) x]]])
  qed
qed

(* Bayes' rule for condition: measuring Q after conditioning s on P equals
   the conditional probability of Q given P under s (cond_prob).
   s must lie in the probability monad, and P and Q must be qbs predicates. *)
lemma query_Bayes:
  assumes [qbs]: "s  qbs_space (monadP_qbs X)" "qbs_pred X P" "qbs_pred X Q"
  shows "𝒫(x in condition s P. Q x) = 𝒫(x in s. Q x ¦ P x)" (is "?lhs = ?pq")
proof -
  (* A probability measure cannot live on an empty space. *)
  have X: "qbs_space X  {}"
    using assms(1) by(simp only: monadP_qbs_empty_iff[of X]) blast
  note s[qbs] = qbs_space_monadPM[OF assms(1)]
  (* The 0/1 weight used by condition agrees with the indicator of P on X. *)
  have density_eq: "density_qbs s (λx. if P x then 1 else 0) = density_qbs s (indicator {xqbs_space X. P x})"
    by(auto intro!: density_qbs_cong[of _ X] indicator_qbs_morphism'')
  (* Case split on whether the mass of P (the normalizing constant) is 0. *)
  consider "emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X) = 0" | "emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)  0" by auto
  then show ?thesis
  proof cases
    case 1
    (* Zero mass: normalization yields the null measure (normalize_qbs0),
       and the probability of P is 0, so both sides reduce to 0. *)
    have 2:"normalize_qbs (density_qbs s (λx. if P x then 1 else 0)) = qbs_null_measure X"
      by(rule normalize_qbs0) (auto simp: 1)
    have "𝒫(ω in qbs_l s. P ω) = measure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)"
      by(simp add: space_qbs_l_in[OF s] measure_def density_eq qbs_l_density_qbs_indicator[OF s])
    also have "... = 0"
      by(simp add: measure_def 1)
    finally show ?thesis
      by(auto simp: condition_def query_def cond_prob_def 2 1 qbs_null_measure_null_measure[OF X])
  next
    case 1[simp]:2
    (* Non-zero mass: fix a representation (alpha, mu) of s and compute both
       sides as measures on mu. *)
    from rep_qbs_space_monadP[OF assms(1)]
    obtain α μ where hs: "s = X, α, μsfin" "qbs_prob X α μ" by auto
    then interpret qp: qbs_prob X α μ by simp
    have [measurable]:"Measurable.pred (qbs_to_measure X) P" "Measurable.pred (qbs_to_measure X) Q"
      using assms(2,3) by(simp_all add: lr_adjunction_correspondence)
    (* The normalizing constant is finite: it is bounded by the integral of
       the constant 1 over a probability measure. *)
    have 2[simp]: "emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)  "
      by(simp add: hs(1) qp.density_qbs qbs_s_finite.qbs_l[OF qp.density_qbs_s_finite] emeasure_distr emeasure_distr[where N="qbs_to_measure X",OF _ sets.top,simplified space_L] emeasure_density,rule order.strict_implies_not_eq[OF order.strict_trans1[OF qp.nn_integral_le_const[of 1] ennreal_one_less_top]]) auto
    (* Non-zero and finite, so its real-valued measure is positive. *)
    have 3: "measure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X) > 0"
      using 2 emeasure_eq_ennreal_measure zero_less_measure_iff by fastforce
    (* Unfold normalization as a second density by 1/Z, then merge the two
       densities into one. *)
    have "query s (λx. if P x then 1 else 0) = density_qbs (density_qbs s (λx. if P x then 1 else 0)) (λx. 1 / emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X))"
      unfolding query_def by(rule normalize_qbs) auto
    also have "... = density_qbs s (λx. (if P x then 1 else 0) * (1 / emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)))"
      by(simp add: density_qbs_density_qbs_eq[OF qbs_space_monadPM[OF assms(1)]])
    finally have query:"query s (λx. if P x then 1 else 0) = ..." .
    (* Compute the left-hand side: rewrite it as a density over qbs_l s, pull
       it back to mu, and factor out the constant 1/Z. *)
    have "?lhs = measure (density (qbs_l s) (λx. (if P x then 1 else 0) * (1 / emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)))) {x  space (qbs_l s). Q x}"
      by(simp add: condition_def query qbs_l_density_qbs[OF qbs_space_monadPM[OF assms(1)]])
    also have "... = measure (density μ (λx. (if P (α x) then 1 else 0) * (1 / emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)))) {y. α y  space (qbs_to_measure X)  Q (α y)}"
      by(simp add: hs(1) qp.qbs_l  density_distr measure_def emeasure_distr)
    also have "... = measure (density μ (λx. indicator {r. P (α r)} x * (1 / emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)))) {y. Q (α y)}"
    proof -
      have [simp]:"(if P (α r) then 1 else 0) = indicator {r. P (α r)} r " for r
        by auto
      thus ?thesis by(simp add: space_L)
    qed
    also have "... = enn2real (1 / emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)) * measure μ {r. P (α r)  Q (α r)}"
    proof -
      (* 1/Z is finite because Z is non-zero. *)
      have n_inf: "1 / emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)  "
        using 1 by(auto simp: ennreal_divide_eq_top_iff)
      show ?thesis
        by(simp add: measure_density_times[OF _ _ n_inf] Collect_conj_eq)
    qed
    (* Move 1/Z from ennreal to real arithmetic, using 2 and 3. *)
    also have "... = (1 / measure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)) * qp.prob {r. P (α r)  Q (α r)}"
    proof -
      have "1 / emeasure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X) = ennreal (1 / measure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X))"
        by(auto simp add: emeasure_eq_ennreal_measure[OF 2] ennreal_1[symmetric] simp del: ennreal_1 intro!: divide_ennreal) (simp_all add: 3)
      thus ?thesis by simp
    qed  also have "... = ?pq"
    proof -
      (* Identify the numerator and the denominator of cond_prob with the
         quantities computed above. *)
      have qp:"𝒫(x in s. Q x  P x) = qp.prob {r. P (α r)  Q (α r)}"
        by(auto simp: hs(1) qp.qbs_l measure_def emeasure_distr, simp add: space_L) meson
      note sets = sets_qbs_l[OF qbs_space_monadPM[OF assms(1)],measurable_cong]
      have [simp]: "density (qbs_l s) (λx. if P x then 1 else 0) = density (qbs_l s) (indicator {xspace (qbs_to_measure X). P x})"
        by(auto intro!: density_cong) (auto simp: indicator_def space_L sets_eq_imp_space_eq[OF sets])
      have p: "𝒫(x in s. P x) = measure (qbs_l (density_qbs s (λx. if P x then 1 else 0))) (qbs_space X)"
        by(auto simp: qbs_l_density_qbs[OF qbs_space_monadPM[OF assms(1),qbs]]) (auto simp: measure_restricted[of "{x  space (qbs_to_measure X). P x}" "qbs_l s",simplified sets,OF _ sets.top,simplified,simplified space_L] space_L sets_eq_imp_space_eq[OF sets])
      thus ?thesis
        by(simp add: qp p cond_prob_def)
    qed
    finally show ?thesis .
  qed
qed

(* On a countable type, conditioning a pmf-induced measure agrees with the
   pmf library's cond_pmf, provided P has positive support in p (the
   assumption says set_pmf p meets {x. P x}). *)
lemma qbs_pmf_cond_pmf:
  fixes p :: "'a :: countable pmf"
  assumes "set_pmf p  {x. P x}  {}"
  shows "condition (qbs_pmf p) P = qbs_pmf (cond_pmf p {x. P x})"
(* qbs_l is injective on measures over the count-space QBS, so it suffices
   to show that both sides are in that monad and have equal qbs_l images. *)
proof(rule inj_onD[OF qbs_l_inj[of "count_space UNIV"]])
  note count_space_count_space_qbs_morphism[of P,qbs]
  show g1:"condition (qbs_pmf p) P  qbs_space (monadM_qbs (count_spaceQ UNIV))" "qbs_pmf (cond_pmf p {x. P x})  qbs_space (monadM_qbs (count_spaceQ UNIV))"
    by auto
  (* Over a countable space, it suffices to compare the measures on
     singletons {a} (measure_eqI_countable). *)
  show "qbs_l (condition (qbs_pmf p) P) = qbs_l (qbs_pmf (cond_pmf p {x. P x}))"
  proof(safe intro!: measure_eqI_countable)
    fix a
    have "condition (qbs_pmf p) P = normalize_qbs (density_qbs (qbs_pmf p) (λx. if P x then 1 else 0))"
      by(auto simp: condition_def query_def)
    (* Normalization is a density by 1/Z, valid because the mass Z of P is
       non-zero (2) and finite (3). *)
    also have "... = density_qbs (density_qbs (qbs_pmf p) (λx. if P x then 1 else 0)) (λx. 1 / emeasure (qbs_l (density_qbs (qbs_pmf p) (λx. if P x then 1 else 0))) (qbs_space (count_spaceQ UNIV)))"
    proof -
      have 1:"(+ x. ennreal (pmf p x) * (if P x then 1 else 0) count_space UNIV) = (+ x{x. P x}. ennreal (pmf p x) count_space UNIV)"
        by(auto intro!: nn_integral_cong)
      have "... > 0"
        using assms(1) by(force intro!: nn_integral_less[of "λx. 0",simplified] simp: AE_count_space set_pmf_eq' indicator_def)
      hence 2:"(+x{x. P x}. ennreal (pmf p x) count_space UNIV)  0"
        by auto
      (* Bounded by the total pmf mass, which is 1. *)
      have 3:"(+ x{x. P x}. ennreal (pmf p x) count_space UNIV)  "
      proof -
        have "(+ x{x. P x}. ennreal (pmf p x) count_space UNIV)  (+ x. ennreal (pmf p x) count_space UNIV)"
          by(auto intro!: nn_integral_mono simp: indicator_def)
        also have "... = 1"
          by (simp add: nn_integral_pmf_eq_1)
        finally show ?thesis
          using ennreal_one_neq_top neq_top_trans by fastforce
      qed
      show ?thesis
        by(rule normalize_qbs) (auto simp: qbs_l_density_qbs[of _ "count_space UNIV"] emeasure_density nn_integral_measure_pmf 1 2 3)
    qed
    (* Merge the two densities into a single weight. *)
    also have "... = density_qbs (qbs_pmf p) (λx. (if P x then 1 else 0) * (1 / (+ x. ennreal (pmf p x) * (if P x then 1 else 0) count_space UNIV)))"
      by(simp add: density_qbs_density_qbs_eq[of _ "count_space UNIV"] qbs_l_density_qbs[of _ "count_space UNIV"] emeasure_density nn_integral_measure_pmf)
    (* Recognize the normalizing constant as the p-measure of Collect P. *)
    also have "... = density_qbs (qbs_pmf p) (λx. (if P x then 1 else 0) * (1 / (emeasure (measure_pmf p) (Collect P))))"
    proof -
      have [simp]: "(+ x. ennreal (pmf p x) * (if P x then 1 else 0) count_space UNIV) = emeasure (measure_pmf p) (Collect P)" (is "?l = ?r")
      proof -
        have "?l = (+ x. ennreal (pmf p x) * (if P x then 1 else 0) count_space {x. P x})"
          by(rule  nn_integral_count_space_eq) auto
        also have "... = ?r"
          by(auto simp: nn_integral_pmf[symmetric] intro!: nn_integral_cong)
        finally show ?thesis .
      qed
      show ?thesis by simp
    qed 
    (* Evaluate on {a} and compare with the pmf of cond_pmf (cond_pmf.rep_eq). *)
    finally show "emeasure (qbs_l (condition (qbs_pmf p) P)) {a} = emeasure (qbs_l (qbs_pmf (cond_pmf p {x. P x}))) {a}"
      by(simp add: ennreal_divide_times qbs_l_density_qbs[of _ "count_space UNIV"] emeasure_density cond_pmf.rep_eq[OF assms(1)])
  qed(auto simp: sets_qbs_l[OF g1(1)])
qed

subsubsection ‹\texttt{twoUs}›
text ‹ Example from Section~2 in @{cite Sato_2019}.›
(* Uniform a b: the uniform distribution on the open interval (a, b),
   built from lborel_qbs with uniform_qbs. *)
definition "Uniform  (λa b::real. uniform_qbs lborel_qbs {a<..<b})"

(* Uniform is a qbs-morphism in both interval endpoints. *)
lemma Uniform_qbs[qbs]: "Uniform  Q Q Q Q monadM_qbs Q"
  unfolding Uniform_def by (rule interval_uniform_qbs)

(* twoUs: draw two independent Uniform 0 1 samples (combined with the pair
   measure constructor) and condition on the event "x < 0.5 or y > 0.5".
   The disjunction symbol is stripped in this copy; it is confirmed as a
   disjunction by the union decomposition in twoUs_prob1 below.
   The let-bindings are only names; the program unfolds to a single call
   of condition. *)
definition twoUs :: "(real × real) qbs_measure" where
"twoUs  do {
              let u1 = Uniform 0 1;
              let u2 = Uniform 0 1;
              let y = u1 Qmes u2;
              condition y (λ(x,y). x < 0.5  y > 0.5)
             }"

(* twoUs is a measure on the product space real x real. *)
lemma twoUs_qbs: "twoUs  monadM_qbs (Q Q Q)"
  by(simp add: twoUs_def)

(* The product Borel space on real x real is a non-empty standard Borel
   space.  Presumably this is needed so that the standard-Borel lemmas about
   qbs_l apply to pairs of reals. *)
interpretation rr: standard_borel_ne "borel M borel :: (real × real) measure"
  by(simp add: borel_prod)

(* For a non-empty interval, the measure underlying Uniform a b is the
   classical uniform_measure on (a, b) with respect to lborel. *)
lemma qbs_l_Uniform[simp]: "a < b  qbs_l (Uniform a b) = uniform_measure lborel {a<..<b}"
  by(simp add: standard_borel_ne.qbs_l_uniform_qbs[of borel lborel_qbs] Uniform_def)

(* For a < b, Uniform a b is a probability measure, i.e. an element of
   monadP_qbs. *)
lemma Uniform_qbsP:
  assumes [arith]: "a < b"
  shows "Uniform a b  monadP_qbs Q"
  by(auto simp: monadP_qbs_def sub_qbs_space intro!: prob_space_uniform_measure)

(* Two copies of the uniform measure on (0, 1) form a pair of probability
   spaces.  This gives access to UniformP_pair.prob and the product-measure
   lemmas used in twoUs_prob1 and twoUs_prob2. *)
interpretation UniformP_pair: pair_prob_space "uniform_measure lborel {0<..<1::real}" "uniform_measure lborel {0<..<1::real}"
  by(auto simp: pair_prob_space_def pair_sigma_finite_def intro!: prob_space_imp_sigma_finite prob_space_uniform_measure)

(* The QBS pair measure of two Uniform a b distributions corresponds, under
   qbs_l, to the classical product of the two uniform measures. *)
lemma qbs_l_Uniform_pair: "a < b  qbs_l (Uniform a b Qmes Uniform a b) = uniform_measure lborel {a<..<b} M uniform_measure lborel {a<..<b}"
  by(auto intro!: qbs_l_qbs_pair_measure[of borel borel] standard_borel_ne.standard_borel simp: qbs_l_Uniform[symmetric] simp del: qbs_l_Uniform)

(* The pair of two Uniform a b distributions is a probability measure on
   real x real.  It follows from Uniform_qbsP and the fact that the pair
   measure preserves monadP. *)
lemma Uniform_pair_qbs[qbs]:
  assumes "a < b"
  shows "Uniform a b Qmes Uniform a b  qbs_space (monadP_qbs (Q Q Q))"
proof -
  note [qbs] = qbs_pair_measure_morphismP Uniform_qbsP[OF assms]
  show ?thesis
    by simp
qed

(* Prior probability of the conditioning event of twoUs:
   P(x < 1/2 or y > 1/2) = 3/4 for independent x, y uniform on (0, 1). *)
lemma twoUs_prob1: "𝒫(z in Uniform 0 1 Qmes Uniform 0 1. fst z < 0.5  snd z > 0.5) = 3 / 4"
proof -
  (* The event (after simp rewrites 0.5 comparisons into "* 2" form) is the
     union A u B, with A = UNIV x (1/2, oo) and B = (-oo, 1/2) x UNIV. *)
  have [simp]:"{z  space (uniform_measure lborel {0<..<1::real} M uniform_measure lborel {0<..<1::real}). fst z * 2 < 1  1 < snd z * 2} = UNIV × {1/2<..}  {..<1/2} × UNIV"
    by(auto simp: space_pair_measure)
  (* P(A) = 1 * 1/2. *)
  have 1:"UniformP_pair.prob (UNIV × {1 / 2<..}) = 1 / 2"
  proof -
    have [simp]:"{0<..<1}  {1 / 2<..} = {1/2<..<1::real}" by auto
    thus ?thesis
      by(auto simp: UniformP_pair.M1.measure_times)
  qed
  (* P(B - A) = P(x < 1/2, y <= 1/2) = 1/2 * 1/2. *)
  have 2:"UniformP_pair.prob ({..<1 / 2} × UNIV - UNIV × {1 / 2<..}) = 1 / 4"
  proof -
    have [simp]: "{..<1/2::real} × UNIV - UNIV × {1/2::real<..} = {..<1/2} × {..1/2}" "{0<..<1}  {..<1/2} = {0<..<1/2::real}" "{0<..<1}  {..1/2::real} = {0<..1/2}"
      by auto
    show ?thesis
      by(auto simp: UniformP_pair.M1.measure_times)
  qed
  (* Combine via finite_measure_Union' (presumably P(A u B) = P(A) + P(B - A)):
     1/2 + 1/4 = 3/4. *)
  show ?thesis
    by(auto simp: qbs_l_Uniform_pair UniformP_pair.P.finite_measure_Union' 1 2)
qed

(* Joint prior probability of "x > 1/2" and the conditioning event of twoUs.
   Given x > 1/2, only y > 1/2 can hold, so the event is the square
   (1/2, oo) x (1/2, oo), whose measure is 1/2 * 1/2 = 1/4. *)
lemma twoUs_prob2:"𝒫(z in Uniform 0 1 Qmes Uniform 0 1. 1/2 < fst z  (fst z < 1/2  snd z > 1/2)) = 1 / 4"
proof -
  have [simp]:"{z  space (uniform_measure lborel {0<..<1::real} M uniform_measure lborel {0<..<1::real}). 1 < fst z * 2  (fst z * 2 < 1  1 < snd z * 2)} = {1/2<..} × {1/2<..}"
    by(auto simp: space_pair_measure)
  have [simp]: "{0<..<1::real}  {1/2<..} = {1/2<..<1}" by auto
  show ?thesis
    by(auto simp: qbs_l_Uniform_pair UniformP_pair.M1.measure_times)
qed

(* twoUs is a probability measure.  The conditioning event has probability
   3/4, which is non-zero (twoUs_prob1).  So condition_morphismP, instantiated
   with a constant family over qbs_borel, places the conditioned measure in
   monadP_qbs. *)
lemma twoUs_qbs_prob: "twoUs  qbs_space (monadP_qbs (Q Q Q))" 
proof -
  have "𝒫(z in Uniform 0 1 Qmes Uniform 0 1. fst z < 0.5  snd z > 0.5)  0"
    unfolding twoUs_prob1 by simp
  note qbs_morphism_space[OF condition_morphismP[of qbs_borel "λx. Uniform 0 1 Qmes Uniform 0 1" "λx z. fst z < 0.5  snd z > 0.5" "Q Q Q",OF this],simplified,qbs]
  note Uniform_pair_qbs[of 0 1,simplified,qbs]
  show ?thesis
    by(simp add: twoUs_def split_beta')
qed

(* Posterior probability of x > 1/2 under twoUs is 1/3.
   By query_Bayes, it is a conditional probability under the prior:
   twoUs_prob2 / twoUs_prob1 = (1/4) / (3/4). *)
lemma "𝒫((x,y) in twoUs. 1/2 < x) = 1 / 3"
proof -
  have "𝒫((x,y) in twoUs. 1/2 < x) = 𝒫(z in twoUs. 1/2 < fst z)"
    by (simp add: split_beta')
  also have "... = 𝒫(z in Uniform 0 1 Qmes Uniform 0 1. 1/2 < fst z ¦ fst z < 0.5  snd z > 0.5)"
    by(simp add: twoUs_def split_beta',rule query_Bayes[OF Uniform_pair_qbs[of 0 1,simplified,qbs]]) auto
  also have "... = 𝒫(z in Uniform 0 1 Qmes Uniform 0 1. 1/2 < fst z  (fst z < 1/2  snd z > 1/2)) / 𝒫(z in Uniform 0 1 Qmes Uniform 0 1. fst z < 0.5  snd z > 0.5)"
    by(simp add: cond_prob_def)
  also have "... = 1 / 3"
    by(simp only: twoUs_prob2 twoUs_prob1) simp
  finally show ?thesis .
qed

subsubsection ‹ Two Dice ›
text ‹ Example from Adrian~\cite[Sect.~2.3]{Adrian_PL}.›
(* die: a fair six-sided die, i.e. the uniform pmf on {1..6} (written with
   Suc 0 as the lower bound), lifted to a QBS measure with qbs_pmf. *)
abbreviation "die  qbs_pmf (pmf_of_set {Suc 0..6})"

lemma die_qbs[qbs]: "die  monadM_qbs Q"
  by simp

(* two_dice: roll two independent dice, condition on "at least one die
   shows 4", and return the sum of the two faces.  The disjunction symbol is
   stripped in this copy; the outcome pairs listed in dice_program_prob,
   e.g. (1,4) and (4,1), confirm the "or" reading. *)
definition two_dice :: "nat qbs_measure" where
 "two_dice  do {
                let die1 = die;
                let die2 = die;
                let twodice = die1 Qmes die2;
                (x,y)  condition twodice
                        (λ(x,y). x = 4  y = 4);
                return_qbs Q (x + y)
              }"

lemma two_dice_qbs: "two_dice  monadM_qbs Q"
  by(simp add: two_dice_def)

(* Probability of any event P under two independent dice: the number of
   outcomes in {1..6} x {1..6} satisfying P, divided by 36. *)
lemma prob_die2: "𝒫(x in qbs_l (die Qmes die). P x) = real (card ({x. P x}  ({1..6} × {1..6}))) / 36" (is "?P = ?rhs")
proof -
  (* Move from the QBS pair measure to the product pmf (qbs_pair_pmf). *)
  have "?P = measure_pmf.prob (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) {x. P x}"
    by(auto simp: qbs_pair_pmf)
  (* Restrict the event to the support of the product pmf. *)
  also have "... = measure_pmf.prob (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) ({x. P x}  set_pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})))"
    by(rule measure_Int_set_pmf[symmetric])
  also have "... = measure_pmf.prob (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) ({x. P x}  ({Suc 0..6} × {Suc 0..6}))"
    by simp
  (* On a finite set, the probability is the sum of point masses, each of
     which is 1/6 * 1/6 = 1/36. *)
  also have "... = (z{x. P x}  ({Suc 0..6} × {Suc 0..6}). pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) z)"
    by(simp add: measure_measure_pmf_finite)
  also have "... = (z{x. P x}  ({Suc 0..6} × {Suc 0..6}). 1 / 36)"
    by(rule Finite_Cartesian_Product.sum_cong_aux) (auto simp: pmf_pair)
  also have "... = ?rhs"
    by auto
  finally show ?thesis .
qed

(* Probability that at least one of two dice shows 4 is 11/36. *)
lemma dice_prob1: "𝒫(z in qbs_l (die Qmes die). fst z = 4  snd z = 4) = 11 / 36"
proof -
  (* Split the event within {1..6}^2 into three disjoint rectangles:
     {1..6} x {4} (6 outcomes), {4} x {1..3} (3 outcomes) and
     {4} x {5..6} (2 outcomes).  Numerals are spelled as iterated Suc,
     presumably so that the cardinality computations evaluate by simp. *)
  have 1:"Restr {z. fst z = 4  snd z = 4} {Suc 0..6::nat} = {Suc 0..Suc (Suc (Suc (Suc (Suc (Suc 0)))))} × {Suc (Suc (Suc (Suc 0)))}  {Suc (Suc (Suc (Suc 0)))} × {Suc 0..(Suc (Suc (Suc 0)))}  {Suc (Suc (Suc (Suc 0)))} × {Suc (Suc (Suc (Suc (Suc 0))))..Suc (Suc (Suc (Suc (Suc (Suc 0)))))}"
    by fastforce
  (* Cardinality of a disjoint union is the sum of cardinalities. *)
  have "card ... = card ({Suc 0..Suc (Suc (Suc (Suc (Suc (Suc 0)))))} × {Suc (Suc (Suc (Suc 0)))}  {Suc (Suc (Suc (Suc 0)))} × {Suc 0..(Suc (Suc (Suc 0)))}) + card ({Suc (Suc (Suc (Suc 0)))} × {Suc (Suc (Suc (Suc (Suc 0))))..Suc (Suc (Suc (Suc (Suc (Suc 0)))))})"
    by(rule card_Un_disjnt) (auto simp: disjnt_def)
  also have "... = card ({Suc 0..Suc (Suc (Suc (Suc (Suc (Suc 0)))))} × {Suc (Suc (Suc (Suc 0)))}) + card ({Suc (Suc (Suc (Suc 0)))} × {Suc 0..(Suc (Suc (Suc 0)))}) + card ({Suc (Suc (Suc (Suc 0)))} × {Suc (Suc (Suc (Suc (Suc 0))))..Suc (Suc (Suc (Suc (Suc (Suc 0)))))})"
  proof -
    have "card ({Suc 0..Suc (Suc (Suc (Suc (Suc (Suc 0)))))} × {Suc (Suc (Suc (Suc 0)))}  {Suc (Suc (Suc (Suc 0)))} × {Suc 0..(Suc (Suc (Suc 0)))}) = card ({Suc 0..Suc (Suc (Suc (Suc (Suc (Suc 0)))))} × {Suc (Suc (Suc (Suc 0)))}) + card ({Suc (Suc (Suc (Suc 0)))} × {Suc 0..(Suc (Suc (Suc 0)))})"
      by(rule card_Un_disjnt) (auto simp: disjnt_def)
    thus ?thesis by simp
  qed
  (* 6 + 3 + 2 = 11. *)
  also have "... = 11" by auto
  finally show ?thesis
    by(auto simp: prob_die2 1)
qed

(* Output distribution of two_dice: each of the sums 5, 6, 7, 9, 10 has
   probability 2/11 (two outcomes each) and the sum 8 has probability 1/11
   (only (4,4)).  Stated for an arbitrary predicate P on the result. *)
lemma dice_program_prob:"𝒫(x in two_dice. P x) = 2 * (n{5,6,7,9,10}. of_bool (P n) / 11) + of_bool (P 8) / 11" (is "?P = ?rp")
proof -
  (* The set of reachable sums is exactly {5..10}; each is witnessed by an
     explicit pair containing a 4. *)
  have 0: "(x{Suc 0..6} × {Suc 0..6}  {(x, y). x = 4  y = 4}. {fst x + snd x}) = {5,6,7,8,9,10}"
  proof safe
    show " 5  (x{Suc 0..6} × {Suc 0..6}  {(x, y). x = 4  y = 4}. {fst x + snd x})"
      by(auto intro!: bexI[where x="(1,4)"])
    show "6  (x{Suc 0..6} × {Suc 0..6}  {(x, y). x = 4  y = 4}. {fst x + snd x})"
      by(auto intro!: bexI[where x="(2,4)"])
    show "7  (x{Suc 0..6} × {Suc 0..6}  {(x, y). x = 4  y = 4}. {fst x + snd x})"
      by(auto intro!: bexI[where x="(3,4)"])
    show "8  (x{Suc 0..6} × {Suc 0..6}  {(x, y). x = 4  y = 4}. {fst x + snd x})"
      by(auto intro!: bexI[where x="(4,4)"])
    show "9  (x{Suc 0..6} × {Suc 0..6}  {(x, y). x = 4  y = 4}. {fst x + snd x})"
      by(auto intro!: bexI[where x="(5,4)"])
    show "10  (x{Suc 0..6} × {Suc 0..6}  {(x, y). x = 4  y = 4}. {fst x + snd x})"
      by(auto intro!: bexI[where x="(6,4)"])
  qed auto
   
  (* The conditioning event meets the support (witness (1,4)); this is the
     side condition needed for cond_pmf and qbs_pmf_cond_pmf. *)
  have 1:"{Suc 0..6} × {Suc 0..6}  {x. fst x = 4  snd x = 4}  {}"
  proof - 
    have "(1,4)  {Suc 0..6} × {Suc 0..6}  {x. fst x = 4  snd x = 4}"
      by auto
    thus ?thesis by blast
  qed
  hence 2: "set_pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6}))  {(x, y). x = 4  y = 4}  {}"
    by(auto simp: split_beta')
  (* Conditioning in the QBS monad coincides with cond_pmf. *)
  have ceq:"condition (die Qmes die) (λ(x,y). x = 4  y = 4) = qbs_pmf (cond_pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) {(x,y). x = 4  y = 4})"
    by(auto simp: split_beta' qbs_pair_pmf 1 intro!: qbs_pmf_cond_pmf)
  (* Rewrite the whole program as a single pmf: cond_pmf bound to the sum. *)
  have "two_dice = condition (die Qmes die) (λ(x,y). x = 4  y = 4)  (λ(x,y). return_qbs Q (x + y))"
    by(simp add: two_dice_def)
  also have "... = qbs_pmf (cond_pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) {(x,y). x = 4  y = 4})  (λz. qbs_pmf (return_pmf (fst z + snd z)))"
    by(simp add: ceq) (simp add: qbs_pmf_return_pmf split_beta')
  also have "... = qbs_pmf (cond_pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) {(x,y). x = 4  y = 4}  (λz. return_pmf (fst z + snd z)))"
    by(rule qbs_pmf_bind_pmf[symmetric])
  finally have two_dice_eq:"two_dice = qbs_pmf (cond_pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) {(x,y). x = 4  y = 4}  (λz. return_pmf (fst z + snd z)))" .

  (* Probability of the conditioning event, transferred from dice_prob1. *)
  have 3:"measure_pmf.prob (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) {(x, y). x = 4  y = 4} = 11 / 36"
    using dice_prob1 by(auto simp: split_beta' qbs_pair_pmf)

  (* Main calculation: restrict to the support {5..10}, expand the bind as
     an expectation, enumerate the 11 conditioned outcomes, and use that each
     has conditional probability (1/36) / (11/36) = 1/11. *)
  have "?P = measure_pmf.prob (cond_pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) {(x, y). x = 4  y = 4}  (λz. return_pmf (fst z + snd z))) {x. P x}" (is "_ = measure_pmf.prob ?bind _")
    by(simp add: two_dice_eq)
  also have "... = measure_pmf.prob ?bind ({x. P x}  set_pmf ?bind)"
    by(rule measure_Int_set_pmf[symmetric])
  also have "... = sum (pmf ?bind) ({x. P x}  set_pmf ?bind)"
    by(rule measure_measure_pmf_finite) (auto simp: set_cond_pmf[OF 2])
  also have "... = sum (pmf ?bind) ({x. P x}  {5, 6, 7, 8, 9, 10})"
    by(auto simp: set_cond_pmf[OF 2] 0)
  also have "... = (n{n. P n}{5, 6, 7, 8, 9, 10}. measure_pmf.expectation (cond_pmf (pair_pmf (pmf_of_set {Suc 0..6}) (pmf_of_set {Suc 0..6})) {(x, y). x = 4  y = 4}) (λx. indicat_real {n} (fst x + snd x)))" (is "_ = (__. measure_pmf.expectation ?cond _ )")
    by(simp add: pmf_bind)
  (* The expectation becomes a finite sum over the 11 outcome pairs; terms
     with a non-zero indicator must be among the listed pairs. *)
  also have "... = (n{n. P n}{5, 6, 7, 8, 9, 10}. (m{(1,4),(2,4),(3,4),(4,4),(5,4),(6,4),(4,1),(4,2),(4,3),(4,5),(4,6)}. indicat_real {n} (fst m + snd m) * pmf ?cond m))"
  proof(intro Finite_Cartesian_Product.sum_cong_aux integral_measure_pmf_real)
    fix n m
    assume h:"n  {n. P n}{5, 6, 7, 8, 9, 10}" "m  set_pmf ?cond" "indicat_real {n} (fst m + snd m)  0"
    then have nm:"fst m + snd m = n"
      by(auto simp: indicator_def)
    have m: "fst m  0" "snd m  0" "fst m = 4  snd m = 4"
      using h(2) by(auto simp: set_cond_pmf[OF 2])
    show "m  {(1, 4), (2, 4), (3, 4), (4,4), (5, 4), (6, 4), (4, 1), (4, 2), (4, 3), (4, 5), (4, 6)}"
      using h(1) nm m by(auto, metis prod.collapse)+
  qed simp
  (* Each listed outcome has conditional probability 1/11. *)
  also have "... = (n{n. P n}{5, 6, 7, 8, 9, 10}. (m{(1,4),(2,4),(3,4),(4,4),(5,4),(6,4),(4,1),(4,2),(4,3),(4,5),(4,6)}. indicat_real {n} (fst m + snd m) * 1 / 11))"
  proof(rule Finite_Cartesian_Product.sum_cong_aux[OF Finite_Cartesian_Product.sum_cong_aux])
    fix n m
    assume h:"n  {n. P n}{5, 6, 7, 8, 9, 10}" "m  {(1,4),(2,4),(3,4),(4,4),(5,4),(6,4),(4,1),(4,2),(4,3),(4,5),(4::nat,6::nat)}"
    have "pmf ?cond m = 1 / 11"
      using h(2) by(auto simp add: pmf_cond[OF 2] 3 pmf_pair)
    thus " indicat_real {n} (fst m + snd m) * pmf ?cond m = indicat_real {n} (fst m + snd m) * 1 / 11"
      by simp
  qed
  also have "... = ?rp"
    by fastforce
  finally show ?thesis .
qed

(* Concrete output probabilities of two_dice, obtained by instantiating
   dice_program_prob with equality predicates. *)
corollary
 "𝒫(x in two_dice. x = 5)  = 2 / 11"
 "𝒫(x in two_dice. x = 6)  = 2 / 11"
 "𝒫(x in two_dice. x = 7)  = 2 / 11"
 "𝒫(x in two_dice. x = 8)  = 1 / 11"
 "𝒫(x in two_dice. x = 9)  = 2 / 11"
 "𝒫(x in two_dice. x = 10) = 2 / 11"

  unfolding dice_program_prob by simp_all

subsubsection ‹ Gaussian Mean Learning ›
text ‹ Example from Section~8.2 in Sato et al.~@{cite Sato_2019}.›

(* Gauss mu sigma: the Gaussian measure, given as Lebesgue measure
   (lborelQ) reweighted by normal_density mu sigma. *)
definition "Gauss  (λμ σ. density_qbs lborelQ (normal_density μ σ))"

lemma Gauss_qbs[qbs]: "Gauss  Q Q Q Q monadM_qbs Q"
  by(simp add: Gauss_def)

(* GaussLearn' sigma p ys: starting from the prior p, apply query once per
   observation y, with likelihood normal_density y sigma.
   Note the recursion order: the head observation is applied last, to the
   result of processing the tail. *)
primrec GaussLearn' :: "[real, real qbs_measure, real list]
                                      real qbs_measure" where
  "GaussLearn' _ p [] = p"
| "GaussLearn' σ p (y#ls) = query (GaussLearn' σ p ls)
                                  (normal_density y σ)"

lemma GaussLearn'_qbs[qbs]:"GaussLearn'  Q Q monadM_qbs Q Q list_qbs Q Q monadM_qbs Q"
  by(simp add: GaussLearn'_def)

context
  fixes σ :: real
  assumes [arith]: "σ > 0"
begin


(* GaussLearn: GaussLearn' specialized to the context's fixed,
   strictly positive observation noise sigma. *)
abbreviation "GaussLearn  GaussLearn' σ"

lemma GaussLearn_qbs[qbs]: "GaussLearn  qbs_space (monadM_qbs Q Q list_qbs Q Q monadM_qbs Q)"
  by simp

(* Total l: the sum of the elements of l, via a right fold with 0. *)
definition Total :: "real list  real" where "Total = (λl. foldr (+) l 0)"

(* Recursive characterization of Total, used in the induction of
   GaussLearn_Total. *)
lemma Total_simp: "Total [] = 0" "Total (y#ls) = y + Total ls"
  by(simp_all add: Total_def)

lemma Total_qbs[qbs]: "Total  list_qbs Q Q Q"
  by(simp add: Total_def)

(* Closed form of Gaussian mean learning.  Start from the prior
   Gauss delta xi and observe L (length n) with noise sigma.  The result is
     mean (Total L * xi^2 + delta * sigma^2) / (n * xi^2 + sigma^2),
     sd   sqrt (xi^2 * sigma^2 / (n * xi^2 + sigma^2)).
   (The superscript-2 symbols appear as a plain "2" in this copy.)
   Proved by induction on L. *)
lemma GaussLearn_Total:
  assumes [arith]: "ξ > 0" "n = length L"
  shows "GaussLearn (Gauss δ ξ) L = Gauss ((Total L*ξ2+δ*σ2)/(n*ξ2+σ2)) (sqrt ((ξ2*σ2)/(n*ξ2+σ2)))"
  using assms(2)
proof(induction L arbitrary: n)
  case Nil
  then show ?case
    by(simp add: Total_def)
next
  case ih:(Cons a L)
  then obtain n' where n':"n = Suc n'" "n' = length L"
    by auto
  (* The posterior variance after n' observations is positive. *)
  have 1:"ξ2 * σ2 / (real n' * ξ2 + σ2) > 0"
    by(auto intro!: divide_pos_pos add_nonneg_pos)
  (* One-step update of the standard deviation.  The sd produced by the
     conjugate step from the n'-observation posterior equals the closed
     form for n = Suc n'.  Proved by comparing squares. *)
  have sigma:"(sqrt (ξ2 * σ2 / (real n' * ξ2 + σ2)) * σ / sqrt (ξ2 * σ2 / (real n' * ξ2 + σ2) + σ2)) = (sqrt (ξ2 * σ2 / (real n * ξ2 + σ2)))"
  proof(rule power2_eq_imp_eq)
    show "(sqrt (ξ2 * σ2 / (real n' * ξ2 + σ2)) * σ / sqrt (ξ2 * σ2 / (real n' * ξ2 + σ2) + σ2))2 = (sqrt (ξ2 * σ2 / (real n * ξ2 + σ2)))2" (is "?lhs = ?rhs")
    proof -
      have "?lhs = (ξ2 * σ2 / (real n' * ξ2 + σ2)) * (σ2 / (ξ2 * σ2 / (real n' * ξ2 + σ2) + σ2))"
        by (simp add: power_divide power_mult_distrib)
      also have "... = ξ2 * σ2 / (real n' * ξ2 + σ2) * (σ2 / ((ξ2 / (real n' * ξ2 + σ2) + 1) * σ2))"
        by (simp add: distrib_left mult.commute)
      also have "... = ξ2 * σ2 / (real n' * ξ2 + σ2) * (1 / (ξ2 / (real n' * ξ2 + σ2) + 1))"
        by simp
      also have "... = ξ2 * σ2 / (real n' * ξ2 + σ2) * (1 / ((ξ2 + (real n' * ξ2 + σ2)) / (real n' * ξ2 + σ2)))"
        by(simp only: add_divide_distrib[of "ξ2"]) auto
      also have "... = ξ2 * σ2 / (real n' * ξ2 + σ2) * ((real n' * ξ2 + σ2) / (ξ2 + (real n' * ξ2 + σ2)))"
        by simp
      also have "... = ξ2 * σ2 / (ξ2 + (real n' * ξ2 + σ2))"
        using "1" by force
      also have "... = ?rhs"
        by(simp add: n'(1) distrib_right)
      finally show ?thesis .
    qed
  qed simp_all
  (* One-step update of the mean.  The mean produced by the conjugate step,
     with the new observation a, equals the closed form with Total (a # L)
     = a + Total L and n = Suc n'. *)
  have mu: "((Total L * ξ2 + δ * σ2) * σ2 / (real n' * ξ2 + σ2) + a * (ξ2 * σ2) / (real n' * ξ2 + σ2)) / (ξ2 * σ2 / (real n' * ξ2 + σ2) + σ2) = ((a + Total L) * ξ2 + δ * σ2) / (real n * ξ2 + σ2)" (is "?lhs = ?rhs")
  proof -
    have "?lhs = (((Total L * ξ2 + δ * σ2) * σ2 + a * (ξ2 * σ2))/ (real n' * ξ2 + σ2)) / (ξ2 * σ2 / (real n' * ξ2 + σ2) + σ2)"
      by(simp add: add_divide_distrib)
    also have "... = (((Total L * ξ2 + δ * σ2) + a * ξ2) * σ2 / (real n' * ξ2 + σ2)) / (ξ2 * σ2 / (real n' * ξ2 + σ2) + σ2)"
      by (simp add: distrib_left mult.commute)
    also have "... = (((Total L * ξ2 + δ * σ2) + a * ξ2) * σ2 / (real n' * ξ2 + σ2)) / ((ξ2 * σ2 + (real n' * ξ2 + σ2) * σ2) / (real n' * ξ2 + σ2))"
      by (simp add: add_divide_distrib)
    also have "... = (((Total L * ξ2 + δ * σ2) + a * ξ2) * σ2) / (ξ2 * σ2 + (real n' * ξ2 + σ2) * σ2)"
      using 1 by auto
    also have "... = (((Total L * ξ2 + δ * σ2) + a * ξ2) * σ2) / ((ξ2 + (real n' * ξ2 + σ2)) * σ2)"
      by(simp only: distrib_right)
    also have "... = ((Total L * ξ2 + δ * σ2) + a * ξ2) / (ξ2 + (real n' * ξ2 + σ2))"
      by simp
    also have "... = ((Total L * ξ2 + δ * σ2) + a * ξ2) / (real n * ξ2 + σ2)"
      by(simp add: n'(1) distrib_right)
    also have "... = ?rhs"
      by (simp add: distrib_right)
    finally show ?thesis .
  qed
  (* Apply the induction hypothesis to the tail.  The remaining query step
     is handled by qbs_normal_posterior (presumably the single-observation
     conjugate Gaussian update), then simplified with sigma and mu. *)
  show ?case
    by(simp add: ih(1)[OF n'(2)]) (simp add: query_def qbs_normal_posterior[OF real_sqrt_gt_zero[OF 1]] Gauss_def Total_simp sigma mu)
qed

(* For positive a, b, c, d, the sequence ln ((b*(n*d + c)) / (d*(n*b + a)))
   tends to 0 as n grows: the ratio inside ln tends to (b*d)/(d*b) = 1.
   The name suggests a helper for a KL-divergence convergence argument. *)
lemma GaussLearn_KL_divergence_lem1:
  fixes a :: real
  assumes [arith]: "a > 0" "b > 0" "c > 0" "d > 0"
  shows "(λn. ln ((b * (n * d + c)) / (d * (n * b + a))))  0"
proof -
  (* Work with the shifted sequence (index Suc n), so that numerator and
     denominator can be divided by the non-zero Suc n. *)
  have "(λn::nat. ln ( (b * (Suc n * d + c)) / (d * (Suc n * b + a)))) = (λn. ln ( (b * (d + c / Suc n)) / (d * (b + a / Suc n))))"
  proof
    fix n
    show "ln (b * (real (Suc n) * d + c) / (d * (real (Suc n) * b + a))) = ln (b * (d + c / real (Suc n)) / (d * (b + a / real (Suc n))))" (is "ln ?l = ln ?r")
    proof -
      have "?l = b * (d + c / real (Suc n)) / (d * (b + a / real (Suc n))) * (Suc n / Suc n)"
        unfolding times_divide_times_eq distrib_left distrib_right by (simp add: mult.assoc mult.commute)
      also have "... = ?r" by simp
      finally show ?thesis by simp
    qed
  qed
  (* The argument of ln tends to 1: numerator and denominator both tend
     to b * d.
     NOTE(review): positional references into tendsto_eq_intros (33, 25, 18)
     depend on the order of that rule collection and may break across
     library versions. *)
  also have "...  0"
    apply(rule tendsto_eq_intros(33)[of _ 1])
      apply(rule Topological_Spaces.tendsto_eq_intros(25)[of _ "b * d" _ _ "b * d",OF LIMSEQ_Suc[OF Topological_Spaces.tendsto_eq_intros(18)[of _ b _ _ d]] LIMSEQ_Suc[OF Topological_Spaces.tendsto_eq_intros(18)[of _ d _ _ b]]])
             apply(intro Topological_Spaces.tendsto_eq_intros | auto)+  
    done
  (* Transfer the limit back from the shifted sequence. *)
  finally show ?thesis
    by(rule LIMSEQ_imp_Suc)
qed

(* The log-ratio of two closed-form posterior standard deviations (with
   prior sds b and d and noise s, in the GaussLearn_Total form) tends to 0.
   Reduced to GaussLearn_KL_divergence_lem1 via ln (sqrt x) = ln x / 2. *)
lemma GaussLearn_KL_divergence_lem1':
  fixes b :: real
  assumes [arith]: "b > 0" "d > 0" "s > 0"
  shows "(λn. ln (sqrt (b2 * s2 / (real n * b2 + s2)) / sqrt (d2 * s2 / (real n * d2 + s2))))  0" (is "?f  0")
proof -
  (* Combine the two square roots into one. *)
  have "?f = (λn. ln (sqrt ((b2 * (n * d2 + s2))/ (d2 * (n * b2 + s2)))))"
    by(simp add: real_sqrt_divide real_sqrt_mult mult.commute)
  (* ln of a square root of a positive number is half the ln. *)
  also have "... = (λn. ln ((b2 * (n * d2 + s2) / (d2 * (n * b2 + s2)))) / 2)"
    by (standard, rule ln_sqrt) (auto intro!: divide_pos_pos mult_pos_pos add_nonneg_pos)
  also have "...  0"
    using GaussLearn_KL_divergence_lem1 by auto
  finally show ?thesis .
qed

(* For positive s, b, d:
   (d*s/(n*d + s)) / (2 * (b*s/(n*b + s))) tends to 1/2 as n grows.
   The limit is (d*b) / (2*b*d) after dividing through by n. *)
lemma GaussLearn_KL_divergence_lem2:
  fixes s :: real
  assumes [arith]: "s > 0" "b > 0" "d > 0"
  shows "(λn. ((d * s) / (n * d + s)) / (2 * ((b * s) / (n * b + s))))  1 / 2"
proof -
  (* As in lem1, rewrite the shifted sequence (index Suc n) so that every
     n-dependent term has the form const / Suc n. *)
  have "(λn::nat. ((d * s) / (Suc n * d + s)) / (2 * ((b * s) / (Suc n * b + s)))) = (λn. (d * b + d * s / Suc n) / (2 * b * d + 2 * b * s / Suc n))"
  proof
    fix n
    show "d * s / (real (Suc n) * d + s) / (2 * (b * s / (real (Suc n) * b + s))) = (d * b + d * s / real (Suc n)) / (2 * b * d + 2 * b * s / real (Suc n))" (is "?l = ?r")
    proof -
      have "?l = d * (Suc n * b + s) / ((2 * b) * (Suc n * d + s))"
        by(simp add: divide_divide_times_eq)
      also have "... = d * (b + s / Suc n) / ((2 * b) * (d + s / Suc n)) * (Suc n / Suc n)"
      proof -
        have 1:"(2 * b * d * real (Suc n) + 2 * b * (s / real (Suc n)) * real (Suc n))= (2 * b) * (Suc n * d + s)"
          unfolding distrib_left distrib_right by(simp add: mult.assoc mult.commute)
        show ?thesis
          unfolding times_divide_times_eq distrib_left distrib_right 1
          by (simp add: mult.assoc mult.commute)
      qed
      also have "... = ?r"
        by(auto simp:  distrib_right distrib_left mult.commute)
      finally show ?thesis .
    qed
  qed
  (* Numerator tends to d * b, denominator to 2 * b * d.
     NOTE(review): positional reference tendsto_eq_intros(25) is fragile
     across library versions. *)
  also have "...  1 / 2"
    by(rule Topological_Spaces.tendsto_eq_intros(25)[of _ "d * b" _ _ "2 * b * d",OF LIMSEQ_Suc LIMSEQ_Suc]) (intro Topological_Spaces.tendsto_eq_intros | auto)+
  finally show ?thesis
    by(rule LIMSEQ_imp_Suc)
qed

(* Squared-parameter version of GaussLearn_KL_divergence_lem2, stated as
   "the difference from 1/2 tends to 0".  Obtained by instantiating lem2
   with s^2, b^2, d^2 and applying LIM_zero. *)
lemma GaussLearn_KL_divergence_lem2':
  fixes s :: real
  assumes [arith]: "s > 0" "b > 0" "d > 0"
  shows "(λn. ((d^2 * s^2) / (n * d^2 + s^2)) / (2 * ((b^2 * s^2) / (n * b^2 + s^2))) - 1 / 2)  0"
  using GaussLearn_KL_divergence_lem2[of "s^2" "b^2" "d^2"]
  by(rule LIM_zero) auto

lemma GaussLearn_KL_divergence_lem3:
  fixes a b c d s K L :: real
  assumes [arith]: "b > 0" "d > 0" "s > 0"
  shows "((K * d + c * s) / (n * d + s) - (L * b + a * s) / (n * b + s))^2 / (2 * ((b * s) / (n * b + s))) = ((((((K - L) * d * b * real n + c * s * b * real n + K * d * s + c * s * s) - a * s * d * real n - L * b * s - a * s * s))2 / (d * d * b * (real n * real n * real n) + s * s * b * real n + 2 * d * s * b * (real n * real n) + d * d * (real n * real n) * s + s * s * s + 2 * d * s * s * real n))) / (2 * (b * s))" (is "?lhs = ?rhs")
proof -
  have 0:"real n * d + s > 0" "real n * b + s > 0"
    by(auto intro!: add_nonneg_pos)
  hence 1:"real n * d + s  0" "real n * b + s  0" by simp_all
  have "?lhs = (((K * d + c * s) * (n * b + s) - (L * b + a * s) * (n * d + s)) / ((n * d + s) * (n * b + s)))2 / (2 * (b * s / (n * b + s)))"
    unfolding diff_frac_eq[OF 1] by simp
  also have "... = (((((K * d + c * s) * (n * b + s) - (L * b + a * s) * (n * d + s)))2 / ((n * d + s)^2 * (n * b + s)))) / (2 * (b * s))"
    by(auto simp: power2_eq_square)
  also have "... = (((((K * d * (n * b) + c * s * (n * b) + K * d * s + c * s * s) - ((L * b * (n * d) + a * s * (n * d) + L * b * s + a * s * s))))2 / ((n * d)^2 * (n * b) + s^2 * (n * b) + 2 * (n * d) * s * (n * b) + (n * d)^2 * s + s^2 * s + 2 * (n * d) * s * s))) / (2 * (b * s))"
    by(simp add: power2_sum distrib_left distrib_right is_num_normalize(1))
  also have "... = (((((K * d * b * real n + c * s * b * real n + K * d * s + c * s * s) - ((L * b * d * real n + a * s * d * real n + L * b * s + a * s * s))))2 / (d * d * b * (real n * real n * real n) + s * s * b *real n + 2 * d * s * b * (real n * real n) + d * d * (real n * real n) * s + s * s * s + 2 * d * s * s * real n))) / (2 * (b * s))"
    by (simp add: mult.commute mult.left_commute power2_eq_square)
  also have "... = ((((((K - L) * d * b * real n + c * s * b * real n + K * d * s + c * s * s) - ((a * s * d * real n + L * b * s + a * s * s))))2 / (d * d