Theory MFMC_Countable.Rel_PMF_Characterisation
theory Rel_PMF_Characterisation imports
  Matrix_For_Marginals
begin
section ‹Characterisation of @{const rel_pmf}›
(* Characterisation of rel_pmf by a measure inequality: if every set A has at most
   as much p-mass as the q-mass of its R-image {y. ∃x∈A. R x y}, then p and q are
   related by rel_pmf R.  The coupling is built from a matrix with prescribed row and
   column sums, obtained via bounded_matrix_for_marginals_ennreal. *)
proposition rel_pmf_measureI:
  fixes p :: "'a pmf" and q :: "'b pmf"
  assumes le: "⋀A. measure (measure_pmf p) A ≤ measure (measure_pmf q) {y. ∃x∈A. R x y}"
  shows "rel_pmf R p q"
proof -
  let ?A = "set_pmf p" and ?f = "λx. ennreal (pmf p x)"
    and ?B = "set_pmf q" and ?g = "λy. ennreal (pmf q y)"
  (* R restricted to pairs in the supports of p and q *)
  define R' where "R' = {(x, y)∈?A × ?B. R x y}"
  (* The row and column weights have the same total mass, and that mass is finite. *)
  have "(∑⇧+ x∈?A. ?f x) = (∑⇧+ y∈?B. ?g y)" (is "?lhs = ?rhs")
    and "(∑⇧+ y∈?B. ?g y) ≠ ⊤" (is ?bounded)
  proof -
    have "?lhs = (∑⇧+ x. ?f x)" "?rhs = (∑⇧+ y. ?g y)"
      by(auto simp add: nn_integral_count_space_indicator pmf_eq_0_set_pmf intro!: nn_integral_cong split: split_indicator)
    then show "?lhs = ?rhs" ?bounded by(simp_all add: nn_integral_pmf_eq_1)
  qed
  moreover
  (* Hall-type condition: the p-mass of any X inside the support of p is bounded by
     the q-mass of its R'-image.  This is the only place where assumption le is used. *)
  have "(∑⇧+ x∈X. ?f x) ≤ (∑⇧+ y∈R' `` X. ?g y)" (is "?lhs ≤ ?rhs") if "X ⊆ set_pmf p" for X
  proof -
    have "?lhs = measure (measure_pmf p) X"
      by(simp add: nn_integral_pmf measure_pmf.emeasure_eq_measure)
    also have "… ≤ measure (measure_pmf q) {y. ∃x∈X. R x y}" by(simp add: le)
    (* The full R-image and the R'-image agree almost everywhere w.r.t. q. *)
    also have "… = measure (measure_pmf q) (R' `` X)" using that
      by(auto 4 3 simp add: R'_def AE_measure_pmf_iff intro!: measure_eq_AE)
    also have "… = ?rhs" by(simp add: nn_integral_pmf measure_pmf.emeasure_eq_measure)
    finally show ?thesis .
  qed
  moreover have "countable ?A" "countable ?B" by simp_all
  moreover have "R' ⊆ ?A × ?B" by(auto simp add: R'_def)
  (* Obtain a finite matrix h supported on R' whose row sums over ?B are pmf p
     and whose column sums over ?A are pmf q. *)
  ultimately obtain h
    where supp: "⋀x y. 0 < h x y ⟹ (x, y) ∈ R'"
    and bound: "⋀x y. h x y ≠ ⊤"
    and p: "⋀x. x ∈ ?A ⟹ (∑⇧+ y∈?B. h x y) = ?f x"
    and q: "⋀y. y ∈ ?B ⟹ (∑⇧+ x∈?A. h x y) = ?g y"
    by(rule bounded_matrix_for_marginals_ennreal) blast+
  (* The coupling z has probability mass function (x, y) ↦ enn2real (h x y). *)
  let ?z = "λ(x, y). enn2real (h x y)"
  define z where "z = embed_pmf ?z"
  have nonneg: "⋀xy. 0 ≤ ?z xy" by clarsimp
  (* h vanishes outside the supports and outside R, by supp and the definition of R'. *)
  have outside: "h x y = 0" if "x ∉ set_pmf p ∨ y ∉ set_pmf q ∨ ¬ R x y" for x y
    using supp[of x y] that by(cases "h x y > 0")(auto simp add: R'_def)
  (* ?z sums to one, so embed_pmf really yields a pmf with density ?z. *)
  have prob: "(∑⇧+ xy. ?z xy) = 1"
  proof -
    have "(∑⇧+ xy. ?z xy) = (∑⇧+ x. ∑⇧+ y. (ennreal ∘ ?z) (x, y))"
      unfolding nn_integral_fst_count_space by(simp add: split_def o_def)
    also have "… = (∑⇧+ x. (∑⇧+y. h x y))" using bound
      by(simp add: nn_integral_count_space_reindex ennreal_enn2real_if)
    also have "… = (∑⇧+ x∈?A. (∑⇧+y∈?B. h x y))"
      by(auto intro!: nn_integral_cong nn_integral_zero' simp add: nn_integral_count_space_indicator outside split: split_indicator)
    also have "… = (∑⇧+ x∈?A. ?f x)" by(auto simp add: p intro!: nn_integral_cong)
    also have "… = (∑⇧+ x. ?f x)"
      by(auto simp add: nn_integral_count_space_indicator pmf_eq_0_set_pmf intro!: nn_integral_cong split: split_indicator)
    finally show ?thesis by(simp add: nn_integral_pmf_eq_1)
  qed
  note z = nonneg prob
  have z_sel [simp]: "pmf z (x, y) = enn2real (h x y)" for x y
    by(simp add: z_def pmf_embed_pmf[OF z])
  (* Discharge the obligations of rel_pmf for the coupling z: its support lies in R,
     and its two marginals are p and q. *)
  show ?thesis
  proof
    show "R x y" if "(x, y) ∈ set_pmf z" for x y
      using that outside[of x y] unfolding set_pmf_iff
      by(auto simp add: enn2real_eq_0_iff)
    show "map_pmf fst z = p"
    proof(rule pmf_eqI)
      fix x
      have "pmf (map_pmf fst z) x = (∑⇧+ e∈range (Pair x). pmf z e)"
        by(auto simp add: ennreal_pmf_map nn_integral_measure_pmf nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = (∑⇧+ y. h x y)"
        using bound by(simp add: nn_integral_count_space_reindex ennreal_enn2real_if)
      also have "… = (∑⇧+y∈?B. h x y)" using outside
        by(auto simp add: nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = ?f x" using p[of x] apply(cases "x ∈ set_pmf p")
        by(auto simp add: set_pmf_iff AE_count_space outside intro!: nn_integral_zero')
      finally show "pmf (map_pmf fst z) x = pmf p x" by simp
    qed
    show "map_pmf snd z = q"
    proof(rule pmf_eqI)
      fix y
      have "pmf (map_pmf snd z) y = (∑⇧+ e∈range (λx. (x, y)). pmf z e)"
        by(auto simp add: ennreal_pmf_map nn_integral_measure_pmf nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = (∑⇧+ x. h x y)"
        using bound by(simp add: nn_integral_count_space_reindex ennreal_enn2real_if)
      also have "… = (∑⇧+x∈?A. h x y)" using outside
        by(auto simp add: nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = ?g y" using q[of y] apply(cases "y ∈ set_pmf q")
        by(auto simp add: set_pmf_iff AE_count_space outside intro!: nn_integral_zero')
      finally show "pmf (map_pmf snd z) y = pmf q y" by simp
    qed
  qed
qed
subsection ‹Code generation for @{const rel_pmf}›
(* Variant of rel_pmf_measureI with a weaker hypothesis.  The inequality only needs
   to hold for sets A inside the support of p, with the R-image restricted to the
   support of q.  It is reduced to rel_pmf_measureI by intersecting an arbitrary A
   with set_pmf p and by monotonicity of measure. *)
proposition rel_pmf_measureI':
  fixes p :: "'a pmf" and q :: "'b pmf"
  assumes le: "⋀A. A ⊆ set_pmf p ⟹ measure_pmf.prob p A ≤ measure_pmf.prob q {y ∈ set_pmf q. ∃x∈A. R x y}"
  shows "rel_pmf R p q"
proof(rule rel_pmf_measureI)
  fix A
  let ?A = "A ∩ set_pmf p"
  (* Only the part of A inside the support of p carries p-mass. *)
  have "measure_pmf.prob p A = measure_pmf.prob p ?A" by(simp add: measure_Int_set_pmf)
  also have "… ≤ measure_pmf.prob q {y ∈ set_pmf q. ∃x∈?A. R x y}" by(rule le) simp
  (* Enlarging the image set can only increase its q-probability. *)
  also have "… ≤ measure_pmf.prob q {y. ∃x∈A. R x y}"
    by(rule measure_pmf.finite_measure_mono) auto
  finally show "measure_pmf.prob p A ≤ …" .
qed
(* Code equation for rel_pmf.  It checks the measure inequality for every subset A of
   the support of p, where the R-image of A is computed as the second components of
   the pairs in A × set_pmf q that satisfy R.
   NOTE(review): enumerating Pow (set_pmf p) presumably requires the support to be a
   finite, code-representable set at evaluation time; confirm against the code setup. *)
lemma rel_pmf_code [code]:
  "rel_pmf R p q ⟷
   (let B = set_pmf q in
    ∀A∈Pow (set_pmf p). measure_pmf.prob p A ≤ measure_pmf.prob q (snd ` Set.filter (case_prod R) (A × B)))"
  unfolding Let_def
proof(intro iffI strip)
  (* The filtered product computes the R-image of A intersected with the support of q. *)
  have eq: "snd ` Set.filter (case_prod R) (A × set_pmf q) = {y. ∃x∈A. R x y} ∩ set_pmf q" for A
    by(auto intro: rev_image_eqI simp add: Set.filter_def)
  (* Forward direction: via rel_pmf_measureD. *)
  show "measure_pmf.prob p A ≤ measure_pmf.prob q (snd ` Set.filter (case_prod R) (A × set_pmf q))"
    if "rel_pmf R p q" and "A ∈ Pow (set_pmf p)" for A
    using that by(auto dest: rel_pmf_measureD simp add: eq measure_Int_set_pmf)
  (* Backward direction: via the support-restricted criterion rel_pmf_measureI'. *)
  show "rel_pmf R p q" if "∀A∈Pow (set_pmf p). measure_pmf.prob p A ≤ measure_pmf.prob q (snd ` Set.filter (case_prod R) (A × set_pmf q))"
    using that by(intro rel_pmf_measureI')(auto intro: ord_le_eq_trans arg_cong2[where f=measure] simp add: eq)
qed
end