Theory Bayesian_Linear_Regression

(*  Title:   Bayesian_Linear_Regression.thy
    Author:  Michikazu Hirata, Tokyo Institute of Technology
*)

subsection ‹ Bayesian Linear Regression ›

theory Bayesian_Linear_Regression
  imports "Measure_as_QuasiBorel_Measure"
begin

text ‹ We formalize the Bayesian linear regression example presented in Section VI of @{cite "Heunen_2017"}.›
subsubsection ‹ Prior ›
(* Prior distribution used for both the slope and the intercept: the normal
   distribution with mean 0 and standard deviation 3, given as a density with
   respect to the Lebesgue measure lborel. *)
abbreviation "ν  density lborel (λx. ennreal (normal_density 0 3 x))"

(* ν is a standard Borel probability measure.  Among other things this
   interpretation provides ν.as_qbs_measure, whose type is checked by the term
   command below. *)
interpretation ν: standard_borel_prob_space ν
  by(simp add: standard_borel_prob_space_def prob_space_normal_density)

(* Sanity check: ν.as_qbs_measure is a probability measure on the reals as a QBS. *)
term "ν.as_qbs_measure :: real qbs_prob_space"
(* Prior over linear functions: sample a slope s and then an intercept b from ν
   (the second sample does not depend on s), and return the function
   r ↦ s * r + b. *)
definition prior :: "(real  real) qbs_prob_space" where
 "prior  do { s  ν.as_qbs_measure ;
                b  ν.as_qbs_measure ;
                qbs_return (Q Q Q) (λr. s * r + b)}"

(* Concrete representation of ν.as_qbs_measure: the real QBS, the identity
   random element and the measure ν itself. *)
lemma ν_as_qbs_measure_eq:
 "ν.as_qbs_measure = qbs_prob_space (Q,id,ν)"
  by(simp add: ν.as_qbs_measure_retract[of id id] distr_id' measure_to_qbs_cong_sets[OF sets_density] measure_to_qbs_cong_sets[OF sets_lborel])

(* Two copies of the representation (real QBS, id, ν) viewed as a pair.  Its
   facts (e.g. ν_qp.qbs_bind_bind_return, ν_qp.Fubini) are used below to compute
   with two successive samples from ν. *)
interpretation ν_qp: pair_qbs_prob "Q" id ν "Q" id ν
  by(auto intro!: qbs_probI prob_space_normal_density simp: pair_qbs_prob_def)

(* ν.as_qbs_measure is an element of monadP_qbs_Px over the reals, i.e. of the
   carrier of the probability monad (NOTE(review): reading based on the name
   monadP_qbs_Px; confirm against the library). *)
lemma ν_as_qbs_measure_in_Pr:
 "ν.as_qbs_measure  monadP_qbs_Px Q"
  by(simp add: ν_as_qbs_measure_eq ν_qp.qp1.qbs_prob_space_in_Px)

(* The measurable sets of the measurable space induced by the nested product
   QBS of three copies of the reals are those of (borel ⨂ borel) ⨂ borel.
   Registered as [measurable_cong] so that the measurability prover can switch
   between the two descriptions. *)
lemma sets_real_real_real[measurable_cong]:
  "sets (qbs_to_measure ((Q Q Q) Q Q)) = sets ((borel M borel) M borel)"
  by (metis pair_standard_borel.l_r_r_sets pair_standard_borel_def r_preserves_product real.standard_borel_axioms real_real.standard_borel_axioms)

(* The map (s, b) ↦ (r ↦ s * r + b), from slope/intercept pairs to linear
   functions, is a QBS morphism into the real function space.  The proof curries
   the uncurried map (x, r) ↦ fst x * r + snd x. *)
lemma lin_morphism:
 "(λ(s, b) r. s * r + b)  Q Q Q Q Q Q Q"
  apply(simp add: split_beta')
  apply(rule curry_preserves_morphisms[of "λ(x,r). fst x * r + snd x",simplified curry_def split_beta',simplified])
  by auto

(* Measurable counterpart of lin_morphism: the same map is measurable from
   real_borel ⨂ real_borel into the measurable space induced by the function-
   space QBS.  Registered for the measurability prover. *)
lemma lin_measurable[measurable]:
 "(λ(s, b) r. s * r + b)  real_borel M real_borel M qbs_to_measure (Q Q Q)"
  using lin_morphism l_preserves_morphisms[of "Q Q Q" "exp_qbs Q Q"]
  by auto

(* Concrete representation of the prior.  The random element is the linear-
   function map composed with real_real.g, and the measure is ν ⨂ ν pushed
   forward along real_real.f (presumably the encoding of real pairs as reals
   and its inverse; TODO confirm).  The first conjunct shows this triple is a
   valid qbs_prob; the second shows it represents prior. *)
lemma prior_computation:
 "qbs_prob (Q Q Q) ((λ(s, b) r. s * r + b)  real_real.g) (distr (ν M ν) real_borel real_real.f)" 
 "prior = qbs_prob_space (Q Q Q, (λ(s, b) r. s * r + b)  real_real.g, distr (ν M ν) real_borel real_real.f)"
  using ν_qp.qbs_bind_bind_return[OF lin_morphism]
  by(simp_all add: prior_def ν_as_qbs_measure_eq map_prod_def)

text ‹ The following lemma corresponds to equation (5). ›
(* The measure underlying the prior is the push-forward of ν ⨂ ν along
   (s, b) ↦ (r ↦ s * r + b). *)
lemma prior_measure:
  "qbs_prob_measure prior = distr (ν M ν) (qbs_to_measure (exp_qbs Q Q)) (λ(s,b) r. s * r + b)"
  by(simp add: prior_computation(2) qbs_prob.qbs_prob_measure_computation[OF prior_computation(1)])    (simp add: distr_distr comp_def)

(* The prior belongs to the space of the probability monad over the real
   function space. *)
lemma prior_in_space:
 "prior  qbs_space (monadP_qbs (Q Q Q))"
  using qbs_prob.qbs_prob_space_in_Px[OF prior_computation(1)]
  by(simp add: prior_computation(2))


subsubsection ‹ Likelihood ›
(* Observation noise: d μ x is the normal density with mean μ and standard
   deviation 1/2, evaluated at x. *)
abbreviation "d μ x  normal_density μ (1/2) x"

(* The noise density is strictly positive everywhere. *)
lemma d_positive : "0 < d μ x"
  by(simp add: normal_density_pos)

(* Likelihood of the observed data (1, 2.5), (2, 3.8), (3, 4.5), (4, 6.2), (5, 8)
   under a candidate function f.  It is the product over the points (x_i, y_i)
   of d (f x_i) y_i, i.e. Gaussian noise with standard deviation 1/2 around
   f x_i. *)
definition obs :: "(real  real)  ennreal" where
"obs f  d (f 1) 2.5 * d (f 2) 3.8 * d (f 3) 4.5 * d (f 4) 6.2 * d (f 5) 8"

(* obs is a QBS morphism from the real function space into the QBS of
   nonnegative extended reals. *)
lemma obs_morphism:
 "obs  Q Q Q Q Q0"
proof(rule qbs_morphismI)
  fix α
  assume "α  qbs_Mx (Q Q Q)"
  (* A random function has a Borel-measurable uncurried form. *)
  then have [measurable]:"(λ(x,y). α x y)  real_borel M real_borel M real_borel"
    by(auto simp: exp_qbs_Mx_def)
  (* obs only evaluates α at fixed points and composes the results with
     normal_density, so obs ∘ α is measurable. *)
  show "obs  α  qbs_Mx Q0"
    by(auto simp: comp_def obs_def normal_density_def)
qed

(* Measurable counterpart of obs_morphism, registered for the measurability
   prover. *)
lemma obs_measurable[measurable]:
 "obs  qbs_to_measure (exp_qbs Q Q) M ennreal_borel"
  using obs_morphism by auto


subsubsection ‹ Posterior ›
(* Pairing a function with its likelihood, f ↦ (f, obs f), is a QBS morphism.
   This follows from the identity morphism and obs_morphism. *)
lemma id_obs_morphism:
 "(λf. (f,obs f))  Q Q Q Q (Q Q Q) Q Q0"
  by(rule qbs_morphism_tuple[OF qbs_morphism_ident' obs_morphism])

(* Pushing the prior forward along f ↦ (f, obs f) yields an element of the
   probability monad over pairs (function, likelihood weight). *)
lemma push_forward_measure_in_space:
 "monadP_qbs_Pf (Q Q Q) ((Q Q Q) Q Q0) (λf. (f,obs f)) prior  qbs_space (monadP_qbs ((Q Q Q) Q Q0))"
  by(rule qbs_morphismE(2)[OF monadP_qbs_Pf_morphism[OF id_obs_morphism] prior_in_space])

(* Concrete representation of the pushed-forward prior.  It reuses the random
   element and measure from prior_computation, paired with the likelihood of
   the sampled function.  The first conjunct is validity; the second is the
   equation. *)
lemma push_forward_measure_computation:
 "qbs_prob ((Q Q Q) Q Q0) (λl. (((λ(s, b) r. s * r + b)  real_real.g) l, ((obs  (λ(s, b) r. s * r + b))  real_real.g) l)) (distr (ν M ν) real_borel real_real.f)"
 "monadP_qbs_Pf (Q Q Q) ((Q Q Q) Q Q0) (λf. (f, obs f)) prior = qbs_prob_space ((Q Q Q) Q Q0, (λl. (((λ(s, b) r. s * r + b)  real_real.g) l, ((obs  (λ(s, b) r. s * r + b))  real_real.g) l)), distr (ν M ν) real_borel real_real.f)"
  using qbs_prob.monadP_qbs_Pf_computation[OF prior_computation id_obs_morphism] by(auto simp: comp_def)

subsubsection ‹ Normalizer ›
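text ‹ We use the unit space to signal an error: the normalizer returns Inr () when the total weight (the normalizing constant) is zero or infinite. ›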
(* Normalizer for weighted probability measures.  Given a probability measure
   on pairs (value, weight), take an arbitrary representative (XR, αβ, ν) and
   compute the total weight Z = emeasure (density ν (snd ∘ αβ)) UNIV.
   If Z is 0 or ∞, the result is the error value Inr ().
   Otherwise the result is Inl of the measure on the first components, with ν
   reweighted by weight / Z.
   norm_qbs_measure_computation shows the result does not depend on the chosen
   representative. *)
definition norm_qbs_measure :: "('a × ennreal) qbs_prob_space  'a qbs_prob_space + unit" where
"norm_qbs_measure p  (let (XR,αβ,ν) = rep_qbs_prob_space p in
                          if emeasure (density ν (snd  αβ)) UNIV = 0 then Inr ()
                          else if emeasure (density ν (snd  αβ)) UNIV =  then Inr ()
                          else Inl (qbs_prob_space (map_qbs fst XR, fst  αβ, density ν (λr. snd (αβ r) / emeasure (density ν (snd  αβ)) UNIV))))"


(* If the total weight of (α, β) w.r.t. μ is neither 0 nor ∞, then α together
   with μ reweighted by β / total weight is again a valid qbs_prob on X.  This
   is the well-definedness of the Inl branch of norm_qbs_measure. *)
lemma norm_qbs_measure_qbs_prob:
  assumes "qbs_prob (X Q Q0) (λr. (α r, β r)) μ"
          "emeasure (density μ β) UNIV  0"
      and "emeasure (density μ β) UNIV  "
    shows "qbs_prob X α (density μ (λr. (β r) / emeasure (density μ β) UNIV))"
proof -
  interpret qp: qbs_prob "X Q Q0" "λr. (α r, β r)" μ
    by fact
  (* Split the pair-valued random element into a random element of X and a
     measurable weight function. *)
  have ha[simp]: "α  qbs_Mx X"
   and hb[measurable] :"β  real_borel M ennreal_borel"
    using assms by(simp_all add: qbs_prob_def in_Mx_def pair_qbs_Mx_def comp_def)
  show ?thesis
  proof(rule qbs_probI)
    show "prob_space (density μ (λr. β r / emeasure (density μ β) UNIV))"
    proof(rule prob_spaceI)
      (* The normalized density has total mass 1: pull the constant divisor out
         of the integral; the remaining quotient is 1 because the total weight
         is finite and nonzero. *)
      show "emeasure (density μ (λr. β r / emeasure (density μ β) UNIV)) (space (density μ (λr. β r / emeasure (density μ β) UNIV))) = 1"
             (is "?lhs = ?rhs")
      proof -
        have "?lhs = emeasure (density μ (λr. β r / emeasure (density μ β) UNIV)) UNIV"
          by simp
        also have "... = (+rUNIV. (β r / emeasure (density μ β) UNIV) μ)"
          by(intro emeasure_density) auto
        also have "... =  integralN μ (λr. β r / emeasure (density μ β) UNIV)"
          by simp
        also have "... = (integralN μ β) / emeasure (density μ β) UNIV"
          by(simp add: nn_integral_divide)
        also have "... = (+rUNIV. β r μ) / emeasure (density μ β) UNIV"
          by(simp add: emeasure_density)
        also have "... = 1"
          using assms(2,3) by(simp add: emeasure_density divide_eq_1_ennreal)
        finally show ?thesis .
      qed
    qed
  qed simp_all
qed

(* Computation rule for norm_qbs_measure on an explicitly given representative
   (pair space, λr. (α r, β r), μ).  The result is Inr () when the total weight
   emeasure (density μ β) UNIV is 0 or ∞.  Otherwise it is α with μ reweighted
   by β / total weight.  The proof shows that the definition, which picks an
   arbitrary representative, gives this value for every representative. *)
lemma norm_qbs_measure_computation:
  assumes "qbs_prob (X Q Q0) (λr. (α r, β r)) μ"
  shows "norm_qbs_measure (qbs_prob_space (X Q Q0, (λr. (α r, β r)), μ)) = (if emeasure (density μ β) UNIV = 0 then Inr ()
                                                                                else if emeasure (density μ β) UNIV =  then Inr ()
                                                                                else Inl (qbs_prob_space (X, α, density μ (λr. (β r) / emeasure (density μ β) UNIV))))"
proof -
  interpret qp: qbs_prob "X Q Q0" "λr. (α r, β r)" μ
    by fact
  have ha: "α  qbs_Mx X"
   and hb[measurable] :"β  real_borel M ennreal_borel"
    using assms by(simp_all add: qbs_prob_def in_Mx_def pair_qbs_Mx_def comp_def)
  (* Reason about an arbitrary representative (XR, αβ', μ') of the
     equivalence class chosen by rep_qbs_prob_space. *)
  show ?thesis
    unfolding norm_qbs_measure_def
  proof(rule qp.in_Rep_induct)
    fix XR αβ' μ'
    assume "(XR,αβ',μ')  Rep_qbs_prob_space (qbs_prob_space (X Q Q0, λr. (α r, β r), μ))"
    from qp.if_in_Rep[OF this]
    have h:"XR = X Q Q0"
           "qbs_prob XR αβ' μ'"
           "qbs_prob_eq (X Q Q0, λr. (α r, β r), μ) (XR, αβ', μ')"
      by auto
    (* Equivalent representatives give equal integrals of every morphism f out
       of the pair space (the qbs_prob_eq4 characterization). *)
    have hint: "f. f  X Q Q0 Q Q0  (+ x. f (α x, β x) μ) = (+ x. f (αβ' x) μ')"
      using h(3)[simplified qbs_prob_eq_equiv14] by(simp add: qbs_prob_eq4_def)
    interpret qp': qbs_prob XR αβ' μ'
      by fact
    have ha': "fst  αβ'  qbs_Mx X" "(λx. fst (αβ' x))  qbs_Mx X"
     and hb'[measurable]: "snd  αβ'  real_borel M ennreal_borel" "(λx. snd (αβ' x))  real_borel M ennreal_borel" "(λx. fst (αβ' x))  real_borel M qbs_to_measure X"
      using h by(simp_all add: qbs_prob_def in_Mx_def pair_qbs_Mx_def comp_def)
    have fstX: "map_qbs fst XR = X"
      by(simp add: h(1) pair_qbs_fst)
    (* Instantiating hint with snd: both representatives have the same total
       weight, so the case distinctions in the two sides agree. *)
    have he:"emeasure (density μ β) UNIV = emeasure (density μ' (snd  αβ')) UNIV"
      using hint[OF snd_qbs_morphism] by(simp add: emeasure_density)

    show "(let a = (XR,αβ',μ') in case a of (XR, αβ, ν')  if emeasure (density ν' (snd  αβ)) UNIV = 0 then Inr ()
                                                else if emeasure (density ν' (snd  αβ)) UNIV =  then Inr ()
                                                else Inl (qbs_prob_space (map_qbs fst XR, fst  αβ, density ν' (λr. snd (αβ r) / emeasure (density ν' (snd  αβ)) UNIV))))
         = (if emeasure (density μ β) UNIV = 0 then Inr ()
            else if emeasure (density μ β) UNIV =  then Inr ()
            else Inl (qbs_prob_space (X, α, density μ (λr. β r / emeasure (density μ β) UNIV))))"
    proof(auto simp: he[symmetric] fstX)
      (* Only the case with a finite, nonzero total weight remains. *)
      assume het0:"emeasure (density μ β) UNIV  "
                  "emeasure (density μ β) UNIV  0"
      (* Both normalized measures are valid qbs_prob's (norm_qbs_measure_qbs_prob),
         so they can be compared through pair_qbs_prob. *)
      interpret pqp: pair_qbs_prob X "fst  αβ'" "density μ' (λr. snd (αβ' r) / emeasure (density μ β) UNIV)" X α "density μ (λr. β r / emeasure (density μ β) UNIV)"
        apply(auto intro!: norm_qbs_measure_qbs_prob  simp: pair_qbs_prob_def assms het0)
        using het0
        by(auto intro!: norm_qbs_measure_qbs_prob[of X "fst  αβ'" "snd  αβ'",simplified,OF h(2)[simplified h(1)]] simp: he)

      (* The two normalized measures coincide because their integrals of every
         measurable f agree; this reduces to hint applied to
         (x, w) ↦ w / total weight * f x. *)
      show "qbs_prob_space (X, fst  αβ', density μ' (λr. snd (αβ' r) / emeasure (density μ β) UNIV)) = qbs_prob_space (X, α, density μ (λr. β r / emeasure (density μ β) UNIV))"
      proof(rule pqp.qbs_prob_space_eq4)
        fix f
        assume hf[measurable]:"f  qbs_to_measure X M ennreal_borel"
        show "(+ x. f ((fst  αβ') x) density μ' (λr. snd (αβ' r) / emeasure (density μ β) UNIV)) = (+ x. f (α x) density μ (λr. β r / emeasure (density μ β) UNIV))"
             (is "?lhs = ?rhs")
        proof -
          have "?lhs =  (+ x. (λxr. (snd xr) / emeasure (density μ β) UNIV * f (fst xr)) (αβ' x) μ')"
            by(auto simp: nn_integral_density)
          also have "... = (+ x. (λxr. (snd xr) / emeasure (density μ β) UNIV * f (fst xr)) (α x,β x) μ)"
            by(intro hint[symmetric]) (auto intro!: pair_qbs_morphismI)
          also have "... = ?rhs"
            by(simp add: nn_integral_density)
          finally show ?thesis .
        qed
      qed simp
    qed
  qed
qed

(* norm_qbs_measure is a QBS morphism from the probability monad over
   X paired with weights into the sum of the probability monad over X and the
   unit QBS. *)
lemma norm_qbs_measure_morphism:
 "norm_qbs_measure  monadP_qbs (X Q Q0) Q monadP_qbs X <+>Q 1Q"
proof(rule qbs_morphismI)
  fix γ
  assume "γ  qbs_Mx (monadP_qbs (X Q Q0))"
  (* Represent γ by a single random element α of the pair space and a
     measurable family g of probability measures on the reals. *)
  then obtain α g where hc:
   "α  qbs_Mx (X Q Q0)" "g  real_borel M prob_algebra real_borel"
      "γ = (λr. qbs_prob_space (X Q Q0, α, g r))"
    using rep_monadP_qbs_MPx[of "γ" "(X Q Q0)"] by auto
  note [measurable] = hc(2) measurable_prob_algebraD[OF hc(2)]
  have setsg[measurable_cong]:"r. sets (g r) = sets real_borel"
    using measurable_space[OF hc(2)] by(simp add: space_prob_algebra)
  then have ha: "fst  α  qbs_Mx X"
   and hb[measurable]: "snd  α  real_borel M ennreal_borel" "(λx. snd (α x))  real_borel M ennreal_borel" "r. snd  α  g r  M ennreal_borel" "r. (λx. snd (α x))  g r  M ennreal_borel"
    using hc(1) by(auto simp add: pair_qbs_Mx_def measurable_cong_sets[OF setsg refl] comp_def)
  (* The total weight, as a function of the parameter r, is measurable. *)
  have emeas_den_meas[measurable]: "U. U  sets real_borel  (λr. emeasure (density (g r) (snd  α)) U)  real_borel M ennreal_borel"
    by(simp add: emeasure_density)
  (* The parameters whose total weight is not in the error set form a Borel
     set; this is where normalization succeeds. *)
  have S_setsc:"UNIV - (λr. emeasure (density (g r) (snd  α)) UNIV) -` {0,}  sets real_borel"
    using measurable_sets_borel[OF emeas_den_meas] by simp
  (* Non-emptiness of the probability monad over X, needed in the copair
     characterization of the sum space below. *)
  have space_non_empty:"qbs_space (monadP_qbs X)  {}"
    using ha qbs_empty_equiv monadP_qbs_empty_iff[of X] by auto
  (* The normalized kernel: on the good set, reweight g r by
     weight / total weight; elsewhere use the placeholder return real_borel 0.
     It is measurable into prob_algebra real_borel. *)
  have g_meas:"(λr. if r  (UNIV - (λr. emeasure (density (g r) (snd  α)) UNIV) -` {0,}) then density (g r) (λl. ((snd  α) l) / emeasure (density (g r) (snd  α)) UNIV) else return real_borel 0)  real_borel M prob_algebra real_borel"
  proof -
    (* Gluing lemma: a function defined by cases on a measurable set is
       measurable if it is measurable on the restriction to that set and
       constant elsewhere. *)
    have H:"Ω M N c f. Ω  space M  sets M  c  space N 
             f  measurable (restrict_space M Ω) N  (λx. if x  Ω then f x else c)  measurable M N"
      by(simp add: measurable_restrict_space_iff)
    show ?thesis
    proof(rule H)
      show "(UNIV - (λr. emeasure (density (g r) (snd  α)) UNIV) -` {0, })  space real_borel  sets real_borel"
        using S_setsc by simp
    next
      show "(λr. density (g r) (λl. ((snd  α) l) / emeasure (density (g r) (snd  α)) UNIV))  restrict_space real_borel (UNIV - (λr. emeasure (density (g r) (snd  α)) UNIV) -` {0,}) M prob_algebra real_borel"
      proof(rule measurable_prob_algebra_generated[where Ω=UNIV and G="sets real_borel"])

        (* On the good set each normalized measure is a probability measure. *)
        fix a
        assume "a  space (restrict_space real_borel (UNIV - (λr. emeasure (density (g r) (snd  α)) UNIV) -` {0, }))"
        then have 1:"(+ x. snd (α x) g a)  0" "(+ x. snd (α x) g a)  "
          by(simp_all add: space_restrict_space emeasure_density)
        show "prob_space (density (g a) (λl. (snd  α) l / emeasure (density (g a) (snd  α)) UNIV))"
          using 1
          by(auto intro!: prob_spaceI simp: emeasure_density nn_integral_divide divide_eq_1_ennreal)
      next
        (* The measure of each fixed Borel set depends measurably on a. *)
        fix U
        assume 1:"U  sets real_borel"
        then have 2:"a. U  sets (g a)" by auto
        show "(λa. emeasure (density (g a) (λl. (snd  α) l / emeasure (density (g a) (snd  α)) UNIV)) U)  (restrict_space real_borel (UNIV - (λr. emeasure (density (g r) (snd  α)) UNIV) -` {0, })) M ennreal_borel"
          using 1
          by(auto intro!: measurable_restrict_space1 nn_integral_measurable_subprob_algebra2[where N=real_borel] simp: emeasure_density emeasure_density[OF _ 2])
      qed (simp_all add: setsg sets.Int_stable sets.sigma_sets_eq[of real_borel,simplified])
    qed (simp add:space_prob_algebra prob_space_return)
  qed

  (* Conclude with the copair characterization of random elements of the sum
     space.  The split set is S_setsc, the constant unit function is used on its
     complement, and on S_setsc the normalized measures are given by g_meas and
     norm_qbs_measure_computation. *)
  show "norm_qbs_measure  γ  qbs_Mx (monadP_qbs X <+>Q unit_quasi_borel)"
    apply(auto intro!: bexI[OF _ S_setsc] bexI[where x="(λr. ())"] bexI[where x="λr. qbs_prob_space (X,fst  α,if r  (UNIV - (λr. emeasure (density (g r) (snd  α)) UNIV) -` {0,}) then density (g r) (λl. ((snd  α) l) / emeasure (density (g r) (snd  α)) UNIV) else return real_borel 0)"]
                 simp: copair_qbs_Mx_equiv copair_qbs_Mx2_def space_non_empty[simplified])
     apply standard
     apply(simp add: hc(3) norm_qbs_measure_computation[of _ "fst  α" "snd  α",simplified,OF qbs_prob_MPx[OF hc(1,2)]])
    apply(simp add: monadP_qbs_MPx_def in_MPx_def)
    apply(auto intro!: bexI[OF _ ha] bexI[OF _ g_meas])
    done
qed


text ‹ The following is the semantics of the entire program. ›
(* The whole program: push the prior forward along f ↦ (f, obs f) and
   normalize.  The result is either the posterior (Inl) or the error value
   Inr (). *)
definition program :: "(real  real) qbs_prob_space + unit" where
 "program  norm_qbs_measure (monadP_qbs_Pf (Q Q Q) ((Q Q Q) Q Q0) (λf. (f,obs f)) prior)"

(* The program denotes an element of the sum of the probability monad over the
   real function space and the unit QBS. *)
lemma program_in_space:
 "program  qbs_space (monadP_qbs (Q Q Q) <+>Q 1Q)"
  unfolding program_def
  by(rule qbs_morphismE(2)[OF norm_qbs_measure_morphism push_forward_measure_in_space])


text ‹ We calculate the normalizing constant. ›
(* Completing the square for a quadratic with nonzero leading coefficient:
   a x² + b x + c = a (x + b/(2a))² − (b² − 4ac)/(4a). *)
lemma complete_the_square:
  fixes a b c x :: real
  assumes "a  0"
  shows "a*x2 + b * x + c = a * (x + (b / (2*a)))2 - ((b2 - 4* a * c)/(4*a))"
  using assms by(simp add: comm_semiring_1_class.power2_sum power2_eq_square[of "b / (2 * a)"] ring_class.ring_distribs(1) division_ring_class.diff_divide_distrib power2_eq_square[of b])

(* Variant of complete_the_square for a linear coefficient written as −2b:
   a x² − 2 b x + c = a (x − b/a)² − (b² − a c)/a.  Obtained by instantiating
   complete_the_square with −2b. *)
lemma complete_the_square2':
  fixes a b c x :: real
  assumes "a  0"
  shows "a*x2 - 2 * b * x + c = a * (x - (b / a))2 - ((b2 - a*c)/a)"
  using complete_the_square[OF assms,where b="-2 * b" and x=x and c=c]
  by(simp add: division_ring_class.diff_divide_distrib assms)


(* The normal density is symmetric in its mean and its argument. *)
lemma normal_density_mu_x_swap:
   "normal_density μ σ x = normal_density x σ μ"
  by(simp add: normal_density_def power2_commute)

(* Shifting the argument by x is the same as shifting the mean by −x. *)
lemma normal_density_plus_shift:
 "normal_density μ σ (x + y) = normal_density (μ - x) σ y"
  by(simp add: normal_density_def add.commute diff_diff_eq2)

(* Product of two normal densities in the same variable x, for positive
   standard deviations σ and σ'.  It equals a constant factor, independent of x,
   times a normal density in x with
     mean               (μ σ'² + μ' σ²) / (σ² + σ'²)
     standard deviation σ σ' / sqrt (σ² + σ'²).
   The constant factor is 1 / sqrt (2π(σ² + σ'²)) · exp (−(μ − μ')² / (2(σ² + σ'²))). *)
lemma normal_density_times:
  assumes "σ > 0" "σ' > 0"
  shows "normal_density μ σ x * normal_density μ' σ' x = (1 / sqrt (2 * pi * (σ2 + σ'2))) * exp (- (μ - μ')2 / (2 * (σ2 + σ'2))) * normal_density ((μ*σ'2 + μ'*σ2)/(σ2 + σ'2)) (σ * σ' / sqrt (σ2 + σ'2)) x"
        (is "?lhs = ?rhs")
proof -
  have non0: "2*σ2  0" "2*σ'2  0" "σ2 + σ'2  0"
    using assms by auto
  (* Unfold both densities and merge the two exponentials. *)
  have "?lhs = exp (- ((x - μ)2 / (2 * σ2))) * exp (- ((x - μ')2 / (2 * σ'2))) / (sqrt (2 * pi * σ2) * sqrt (2 * pi * σ'2)) "
    by(simp add: normal_density_def)
  also have "... = exp (- ((x - μ)2 / (2 * σ2)) - ((x - μ')2 / (2 * σ'2))) / (sqrt (2 * pi * σ2) * sqrt (2 * pi * σ'2))"
    by(simp add: exp_add[of "- ((x - μ)2 / (2 * σ2))" "- ((x - μ')2 / (2 * σ'2))",simplified add_uminus_conv_diff])
  (* Core algebraic step: complete the square in x in the sum of the two
     exponents, using complete_the_square2'. *)
  also have "... = exp (- (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2 / (2 * (σ * σ' / sqrt (σ2 + σ'2))2) - (μ - μ')2 / (2 * (σ2 + σ'2)))  / (sqrt (2 * pi * σ2) * sqrt (2 * pi * σ'2))"
  proof -
    have "((x - μ)2 / (2 * σ2)) + ((x - μ')2 / (2 * σ'2)) = (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2 / (2 * (σ * σ' / sqrt (σ2 + σ'2))2) + (μ - μ')2 / (2 * (σ2 + σ'2))"
         (is "?lhs' = ?rhs'")
    proof -
      have "?lhs' = (2 * ((x - μ)2 * σ'2) + 2 * ((x - μ')2 * σ2)) / (4 * (σ2 * σ'2))"
        by(simp add: field_class.add_frac_eq[OF non0(1,2)])
      also have "... = ((x - μ)2 * σ'2 + (x - μ')2 * σ2) / (2 * (σ2 * σ'2))"
        by(simp add: power2_eq_square division_ring_class.add_divide_distrib)
      also have "... = ((σ2 + σ'2) * x2 - 2 * (μ * σ'2 + μ' * σ2) * x  + (μ'2 * σ2 + μ2 * σ'2)) / (2 * (σ2 * σ'2))"
        by(simp add: comm_ring_1_class.power2_diff ring_class.left_diff_distrib semiring_class.distrib_right)
       also have "... = ((σ2 + σ'2) * (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2 - ((μ * σ'2 + μ' * σ2)2 - (σ2 + σ'2) * (μ'2 * σ2 + μ2 * σ'2)) / (σ2 + σ'2)) / (2 * (σ2 * σ'2))"
        by(simp only: complete_the_square2'[OF non0(3),of x "(μ * σ'2 + μ' * σ2)" "(μ'2 * σ2 + μ2 * σ'2)"])
      also have "... = ((σ2 + σ'2) * (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2) / (2 * (σ2 * σ'2)) - (((μ * σ'2 + μ' * σ2)2 - (σ2 + σ'2) * (μ'2 * σ2 + μ2 * σ'2)) / (σ2 + σ'2)) / (2 * (σ2 * σ'2))"
        by(simp add: division_ring_class.diff_divide_distrib)
      also have "... = (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2 / (2 * ((σ * σ') / sqrt (σ2 + σ'2))2) - (((μ * σ'2 + μ' * σ2)2 - (σ2 + σ'2) * (μ'2 * σ2 + μ2 * σ'2)) / (σ2 + σ'2)) / (2 * (σ2 * σ'2))"
        by(simp add: monoid_mult_class.power2_eq_square[of "(σ * σ') / sqrt (σ2 + σ'2)"] ab_semigroup_mult_class.mult.commute[of "σ2 + σ'2"] )
          (simp add: monoid_mult_class.power2_eq_square[of σ] monoid_mult_class.power2_eq_square[of σ'])
      also have "... =  (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2 / (2 * (σ * σ' / sqrt (σ2 + σ'2))2) - ((μ * σ'2)2 + (μ' * σ2)2 + 2 * (μ * σ'2) * (μ' * σ2) - (σ2 * (μ'2 * σ2) + σ2 * (μ2 * σ'2) + (σ'2 * (μ'2 * σ2) + σ'2 * (μ2 * σ'2)))) / ((σ2 + σ'2) * (2 * (σ2 * σ'2)))"
        by(simp add: comm_semiring_1_class.power2_sum[of "μ * σ'2" "μ' * σ2"] semiring_class.distrib_right[of "σ2" "σ'2" "μ'2 * σ2 + μ2 * σ'2"] )
          (simp add: semiring_class.distrib_left[of _ "μ'2 * σ2 " "μ2 * σ'2"])
      also have "... = (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2 / (2 * (σ * σ' / sqrt (σ2 + σ'2))2) + ((σ2 * σ'2)*μ2 + (σ2 * σ'2)*μ'2 - (σ2 * σ'2) * 2 * (μ*μ')) / ((σ2 + σ'2) * (2 * (σ2 * σ'2)))"
        by(simp add: monoid_mult_class.power2_eq_square division_ring_class.minus_divide_left)
      also have "... = (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2 / (2 * (σ * σ' / sqrt (σ2 + σ'2))2) + (μ2 + μ'2 - 2 * (μ*μ')) / ((σ2 + σ'2) * 2)"
        using assms by(simp add: division_ring_class.add_divide_distrib division_ring_class.diff_divide_distrib)
      also have "... = ?rhs'"
        by(simp add: comm_ring_1_class.power2_diff ab_semigroup_mult_class.mult.commute[of 2])
      finally show ?thesis .
    qed
    thus ?thesis
      by simp
  qed
  (* Split the exponential again and recognise the normal density in x. *)
  also have "... = (exp (- (μ - μ')2 / (2 * (σ2 + σ'2))) / (sqrt (2 * pi * σ2) * sqrt (2 * pi * σ'2))) * sqrt (2 * pi * (σ * σ' / sqrt (σ2 + σ'2))2)  * normal_density ((μ * σ'2 + μ' * σ2) / (σ2 + σ'2)) (σ * σ' / sqrt (σ2 + σ'2)) x"
    by(simp add: exp_add[of "- (x - (μ * σ'2 + μ' * σ2) / (σ2 + σ'2))2 / (2 * (σ * σ' / sqrt (σ2 + σ'2))2)" "- (μ - μ')2 / (2 * (σ2 + σ'2))",simplified] normal_density_def)
  (* Simplify the remaining constant factor. *)
  also have "... = ?rhs" 
  proof -
    have "exp (- (μ - μ')2 / (2 * (σ2 + σ'2))) / (sqrt (2 * pi * σ2) * sqrt (2 * pi * σ'2)) * sqrt (2 * pi * (σ * σ' / sqrt (σ2 + σ'2))2) = 1 / sqrt (2 * pi * (σ2 + σ'2)) * exp (- (μ - μ')2 / (2 * (σ2 + σ'2)))"
      using assms by(simp add: real_sqrt_mult)
    thus ?thesis
      by simp
  qed
  finally show ?thesis .
qed

(* normal_density_times with an additional scalar factor a in front, stated in
   the left-associated form that appears in chained products. *)
lemma normal_density_times':
  assumes "σ > 0" "σ' > 0"
  shows "a * normal_density μ σ x * normal_density μ' σ' x = a * (1 / sqrt (2 * pi * (σ2 + σ'2))) * exp (- (μ - μ')2 / (2 * (σ2 + σ'2))) * normal_density ((μ*σ'2 + μ'*σ2)/(σ2 + σ'2)) (σ * σ' / sqrt (σ2 + σ'2)) x"
  using normal_density_times[OF assms,of μ x μ']
  by (simp add: mult.assoc)

(* Product of two normal densities in y whose means μ − a x and μ' − a' x depend
   linearly on x, assuming a ≠ a'.  The product factors into
     1 / |a' − a|
   · a normal density in x with mean (μ' − μ)/(a' − a) and standard deviation
     sqrt ((σ² + σ'²)/(a' − a)²)
   · the combined normal density in y given by normal_density_times.
   Used in program_normalizing_constant to merge the likelihood factors one
   observation at a time. *)
lemma normal_density_times_minusx:
  assumes "σ > 0" "σ' > 0" "a  a'"
  shows "normal_density (μ - a*x) σ y * normal_density (μ' - a'*x) σ' y = (1 / ¦a' - a¦) * normal_density ((μ'- μ)/(a'-a)) (sqrt ((σ2 + σ'2)/(a' - a)2)) x * normal_density (((μ - a*x)*σ'2 + (μ' - a'*x)*σ2)/(σ2 + σ'2)) (σ * σ' / sqrt (σ2 + σ'2)) y"
proof -
  have non0:"a' - a  0"
    using assms(3) by simp
  (* The constant factor produced by normal_density_times is, up to 1 / |a' − a|,
     itself a normal density in x. *)
  have "1 / sqrt (2 * pi * (σ2 + σ'2)) * exp (- (μ - a * x - (μ' - a' * x))2 / (2 * (σ2 + σ'2))) = 1 / ¦a' - a¦ * normal_density ((μ' - μ) / (a' - a)) (sqrt ((σ2 + σ'2) / (a' - a)2)) x"
       (is "?lhs = ?rhs")
  proof -
    have "?lhs = 1 / sqrt (2 * pi * (σ2 + σ'2)) * exp (- ((a' - a) * x - (μ' - μ))2 / (2 * (σ2 + σ'2)))"
      by(simp add: ring_class.left_diff_distrib group_add_class.diff_diff_eq2 add.commute add_diff_eq)
    also have "... = 1 / sqrt (2 * pi * (σ2 + σ'2)) * exp (- ((a' - a)2 * (x - (μ' - μ)/(a' - a))2) / (2 * (σ2 + σ'2)))"
    proof -
      have "((a' - a) * x - (μ' - μ))2 = ((a' - a) * (x - (μ' - μ)/(a' - a)))2"
        using non0 by(simp add: ring_class.right_diff_distrib[of "a'-a" x])
      also have "... = (a' - a)2 * (x - (μ' - μ)/(a' - a))2"
        by(simp add: monoid_mult_class.power2_eq_square)
      finally show ?thesis
        by simp
    qed
    also have "... = 1 / sqrt (2 * pi * (σ2 + σ'2)) * sqrt (2 * pi * (sqrt ((σ2 + σ'2)/(a' - a)2))2) * normal_density ((μ' - μ) / (a' - a)) (sqrt ((σ2 + σ'2) / (a' - a)2)) x"
      using non0 by (simp add: normal_density_def)
    also have "... = ?rhs"
    proof -
      have "1 / sqrt (2 * pi * (σ2 + σ'2)) * sqrt (2 * pi * (sqrt ((σ2 + σ'2)/(a' - a)2))2) = 1 / ¦a' - a¦"
        using assms by(simp add: real_sqrt_divide[symmetric]) (simp add: real_sqrt_divide)
      thus ?thesis
        by simp
    qed
    finally show ?thesis .
  qed
  thus ?thesis
    by(simp add:normal_density_times[OF assms(1,2),of "μ - a*x" y "μ' - a'*x"])
qed

text ‹ The following is the normalizing constant of the program. ›
(* Closed form of the normalizing constant.  program_normalizing_constant
   states that the total weight of the pushed-forward prior equals C. *)
abbreviation "C  ennreal ((4 * sqrt 2 / (pi2 * sqrt (66961 * pi))) * (exp (- (1674761 / 1674025))))"

lemma program_normalizing_constant:
 "emeasure (density (distr (ν M ν) real_borel real_real.f) (obs  (λ(s, b) r. s * r + b)  real_real.g)) UNIV = C"
  (is "?lhs = ?rhs")
proof -
  have "?lhs = (+ x. (obs  (λ(s, b) r. s * r + b)  real_real.g) x  (distr (ν M ν) real_borel real_real.f))"
    by(simp add: emeasure_density)
  also have "... = (+ z. (obs  (λ(s, b) r. s * r + b)) z (ν M ν))"
    using nn_integral_distr[of real_real.f "ν M ν" real_borel "obs  (λ(s, b) r. s * r + b)  real_real.g",simplified]
    by(simp add: comp_def)
  also have "... = (+ x. + y. (obs  (λ(s, b) r. s * r + b)) (x, y) ν ν)"
    by(simp only: ν_qp.nn_integral_snd[where f="(obs  (λ(s, b) r. s * r + b))",simplified,symmetric])
      (simp add: ν_qp.Fubini[where f="(obs  (λ(s, b) r. s * r + b))",simplified])
  also have "... = (+ x. 2 / 45 * normal_density (13 / 10) (1 / sqrt 2) x * normal_density (9 / 10) (1 / sqrt 6) x * normal_density (13 / 10) (1 / sqrt 12) x * normal_density (3 / 2) (1 / sqrt 20) x * normal_density (5 / 3) (sqrt (181 / 180)) x ν)"
  proof(rule nn_integral_cong[where M=ν,simplified])
    fix x
    have [measurable]: "(λy. obs (λr. x * r + y))  real_borel M ennreal_borel"
      using measurable_Pair2[of "obs  (λ(s, b) r. s * r + b)"] by auto
    show "(+ y. (obs  (λ(s, b) r. s * r + b)) (x, y) ν) = 2 / 45 * normal_density (13 / 10) (1 / sqrt 2) x * normal_density (9 / 10) (1 / sqrt 6) x * normal_density (13 / 10) (1 / sqrt 12) x * normal_density (3 / 2) (1 / sqrt 20) x * normal_density (5 / 3) (sqrt (181 / 180)) x"
          (is "?lhs' = ?rhs'")
    proof -
      have "?lhs' = (+ y. ennreal (d (5 / 2 - x) y * d (19 / 5 - x * 2) y * d (9 / 2 - x * 3) y * d (31 / 5 - x * 4) y * d (8 - x * 5) y * normal_density 0 3 y) lborel)"
        by(simp add: nn_integral_density obs_def normal_density_mu_x_swap[where x="5/2"] normal_density_mu_x_swap[where x="19/5"] normal_density_mu_x_swap[where x="9/2"] normal_density_mu_x_swap[where x="31/5"] normal_density_mu_x_swap[where x="8"] normal_density_plus_shift ab_semigroup_mult_class.mult.commute[of "ennreal (normal_density 0 3 _)"] ennreal_mult'[symmetric])
      also have "... = (+ y. ennreal (2 / 45 * normal_density (13 / 10) (1 / sqrt 2) x * normal_density (9 / 10) (1 / sqrt 6) x * normal_density (13 / 10) (1 / sqrt 12) x * normal_density (3 / 2) (1 / sqrt 20) x * normal_density (5 / 3) (sqrt (181 / 180)) x * normal_density (20 / 181 * 9 * (5 - 3 * x)) (3 / (2 * sqrt 5) / sqrt (181 / 20)) y) lborel)"
      proof(rule nn_integral_cong[where M=lborel,simplified])
        fix y
        have "d (5 / 2 - x) y * d (19 / 5 - x * 2) y * d (9 / 2 - x * 3) y * d (31 / 5 - x * 4) y * d (8 - x * 5) y * normal_density 0 3 y = 2 / 45 * normal_density (13 / 10) (1 / sqrt 2) x * normal_density (9 / 10) (1 / sqrt 6) x * normal_density (13 / 10) (1 / sqrt 12) x * normal_density (3 / 2) (1 / sqrt 20) x * normal_density (5 / 3) (sqrt (181 / 180)) x * normal_density (20 / 181 * 9 * (5 - 3 * x)) (3 / (2 * sqrt 5) / sqrt (181 / 20)) y"
             (is "?lhs'' = ?rhs''")
        proof -
          have "?lhs'' = normal_density (13 / 10) (1 / sqrt 2) x * normal_density (63 / 20 - (3 / 2) * x)  (sqrt 2 / 4) y * d (9 / 2 - x * 3) y * d (31 / 5 - x * 4) y * d (8 - x * 5) y * normal_density 0 3 y"
          proof -
            have "d (5 / 2 - x) y * d (19 / 5 - x * 2) y = normal_density (13 / 10) (1 / sqrt 2) x * normal_density (63 / 20 - (3 / 2) * x) (sqrt 2 / 4) y"
              by(simp add: normal_density_times_minusx[of "1/2" "1/2" 1 2 "5/2" x y "19/5",simplified ab_semigroup_mult_class.mult.commute[of 2 x],simplified])
                (simp add: monoid_mult_class.power2_eq_square real_sqrt_divide division_ring_class.diff_divide_distrib)
            thus ?thesis
              by simp
          qed
          also have "... = normal_density (13 / 10) (1 / sqrt 2) x * (2 / 3) * normal_density (9 / 10) (1 / sqrt 6) x * normal_density (18 / 5 - 2 * x) (1 / (2 * sqrt 3)) y * d (31 / 5 - x * 4) y * d (8 - x * 5) y * normal_density 0 3 y"
          proof -
            have 1:"sqrt 2 * sqrt 8 / (8 * sqrt 3) = 1 / (2 * sqrt 3)"
              by(simp add: real_sqrt_divide[symmetric] real_sqrt_mult[symmetric])
            have "normal_density (63 / 20 - 3 / 2 * x) (sqrt 2 / 4) y * d (9 / 2 - x * 3) y = (2 / 3) * normal_density (9 / 10) (1 / sqrt 6) x * normal_density (18 / 5 - 2 * x) (1 / (2 * sqrt 3)) y"
              by(simp add: normal_density_times_minusx[of "sqrt 2 / 4" "1 / 2" "3 / 2" 3 "63 / 20" x y "9 / 2",simplified ab_semigroup_mult_class.mult.commute[of 3 x],simplified])
                (simp add: monoid_mult_class.power2_eq_square real_sqrt_divide division_ring_class.diff_divide_distrib 1)
            thus ?thesis
              by simp
          qed
          also have "... = normal_density (13 / 10) (1 / sqrt 2) x * (2 / 3) * normal_density (9 / 10) (1 / sqrt 6) x * (1 / 2) * normal_density (13 / 10) (1 / sqrt 12) x * normal_density (17 / 4