(* Theory Rel_PMF_Characterisation_MFMC *)

(* Author: Andreas Lochbihler, Digital Asset *)

theory Rel_PMF_Characterisation_MFMC
imports 
  MFMC_Bounded 
  MFMC_Unbounded
  "HOL-Library.Simps_Case_Conv"
begin

section ‹Characterisation of @{const rel_pmf} proved via MFMC›

context begin

(* Vertices of the flow network: a distinguished Source and Sink, a Left copy of
   every element of type 'a and a Right copy of every element of type 'b. *)
private datatype ('a, 'b) vertex = Source | Sink | Left 'a | Right 'b

(* The constructor Left is injective on every set. *)
private lemma inj_Left [simp]: "inj_on Left X"
  by (rule inj_onI) simp

(* The constructor Right is injective on every set. *)
private lemma inj_Right [simp]: "inj_on Right X"
  by (rule inj_onI) simp

(* Fix the two distributions p, q and the relation R whose lifting rel_pmf R p q is
   characterised below. *)
context fixes p :: "'a pmf" and q :: "'b pmf" and R :: "'a ⇒ 'b ⇒ bool" begin

(* Edge relation of the bipartite network: Source → Left x for x in the support of p,
   Left x → Right y when R x y holds and x, y lie in the supports of p and q,
   and Right y → Sink for y in the support of q. *)
private inductive edge' :: "('a, 'b) vertex ⇒ ('a, 'b) vertex ⇒ bool" where
  "edge' Source (Left x)" if "x ∈ set_pmf p"
| "edge' (Left x) (Right y)" if "R x y" "x ∈ set_pmf p" "y ∈ set_pmf q"
| "edge' (Right y) Sink" if "y ∈ set_pmf q"

(* Simplification rules that unfold edge' for the combinations of vertex
   constructors listed below; declared [simp] for use throughout this context. *)
private inductive_simps edge'_simps [simp]:
  "edge' xv (Left x)"
  "edge' (Left x) (Right y)"
  "edge' (Right y) yv"
  "edge' Source (Right y)"
  "edge' Source Sink"
  "edge' xv Source"
  "edge' Sink yv"
  "edge' (Left x) Sink"
  "edge' (Left x) Sink"

(* Elimination rules for edges out of Source, out of Left x, into Right y and into Sink;
   declared [elim!] so that the classical reasoner applies them eagerly. *)
private inductive_cases edge'_SourceE [elim!]: "edge' Source yv"
private inductive_cases edge'_LeftE [elim!]: "edge' (Left x) yv"
private inductive_cases edge'_RightE [elim!]: "edge' xv (Right y)"
private inductive_cases edge'_SinkE [elim!]: "edge' xv Sink"

(* Edge capacities.  The edge (Source, Left x) has capacity pmf p x, the edge
   (Right y, Sink) has capacity pmf q y, and a middle edge (Left x, Right y) has
   capacity pmf q y when R x y holds and x, y lie in the supports of p and q.
   Every other pair of vertices has capacity 0.  The equations overlap
   (e.g. on (Right y, Left x)) but agree there, which pat_completeness/auto checks. *)
private function cap :: "('a, 'b) vertex flow" where
  "cap (xv, Left x) = (if xv = Source then ennreal (pmf p x) else 0)"
| "cap (Left x, Right y) = 
  (if R x y ∧ x ∈ set_pmf p ∧ y ∈ set_pmf q 
   then pmf q y ― ‹Return @{term ‹pmf q y›} so that total weight of @{text ‹x›}'s neighbours is finite,
                    i.e., the network satisfies @{locale bounded_countable_network}.›
   else 0)"
| "cap (Right y, yv) = (if yv = Sink then ennreal (pmf q y) else 0)"
| "cap (Source, Right y) = 0"
| "cap (Source, Sink) = 0"
| "cap (xv, Source) = 0"
| "cap (Sink, yv) = 0"
| "cap (Left x, Sink) = 0"
  by pat_completeness auto
termination by lexicographic_order

(* The network with edge relation edge', capacities cap, source Source and sink Sink. *)
private definition Δ :: "('a, 'b) vertex network"
  where "Δ = ⦇edge = edge', capacity = cap, source = Source, sink = Sink⦈"

(* Record selectors of Δ, registered as simp rules so Δ_def need not be unfolded later. *)
private lemma Δ_sel [simp]:
  "edge Δ = edge'"
  "capacity Δ = cap"
  "source Δ = Source"
  "sink Δ = Sink"
  unfolding Δ_def by simp_all

(* The only possible predecessor of Left x is Source, and only if x is in the support of p. *)
private lemma IN_Left [simp]: "\<^bold>I\<^bold>N⇘Δ⇙ (Left x) = (if x ∈ set_pmf p then {Source} else {})"
  by(auto simp add: incoming_def)
(* The only possible successor of Right y is Sink, and only if y is in the support of q. *)
private lemma OUT_Right [simp]: "\<^bold>O\<^bold>U\<^bold>T⇘Δ⇙ (Right y) = (if y ∈ set_pmf q then {Sink} else {})"
  by(auto simp add: outgoing_def)

(* Δ is a countable network: Source ≠ Sink, capacities vanish outside the edges and are
   finite, and the edge set lies within a union of images of the supports of p and q. *)
interpretation network: countable_network Δ
proof
  show "source Δ ≠ sink Δ" by simp
  show "capacity Δ e = 0" if "e ∉ \<^bold>E⇘Δ⇙" for e using that
    by(cases e; cases "fst e"; cases "snd e")(auto simp add: pmf_eq_0_set_pmf)
  show "capacity Δ e ≠ top" for e by(cases e rule: cap.cases)(auto)
  have "\<^bold>E⇘Δ⇙ ⊆ ((Pair Source ∘ Left) ` set_pmf p) ∪ (map_prod Left Right ` (set_pmf p × set_pmf q)) ∪ ((λy. (Right y, Sink)) ` set_pmf q)"
    by(auto elim: edge'.cases)
  thus "countable \<^bold>E⇘Δ⇙" by(rule countable_subset) auto
qed

(* The total capacity leaving Source is 1: it is the sum of pmf p over all points. *)
private lemma OUT_cap_Source: "d_OUT cap Source = 1"
proof -
  have "d_OUT cap Source = (∑⇧+ y∈range Left. cap (Source, y))"
    by(auto 4 4 simp add: d_OUT_def nn_integral_count_space_indicator intro!: nn_integral_cong network.capacity_outside[simplified] split: split_indicator)
  also have "… = (∑⇧+ y. pmf p y)" by(simp add: nn_integral_count_space_reindex)
  also have "… = 1" by(simp add: nn_integral_pmf)
  finally show ?thesis .
qed
(* The capacity entering Left x equals pmf p x; computed via d_IN_alt_def, whose
   incoming set is described by IN_Left. *)
private lemma IN_cap_Left: "d_IN cap (Left x) = pmf p x"
  by(subst d_IN_alt_def[of _ Δ])(simp_all add: pmf_eq_0_set_pmf nn_integral_count_space_indicator max_def)
(* The capacity leaving Right y equals pmf q y; computed via d_OUT_alt_def, whose
   outgoing set is described by OUT_Right. *)
private lemma OUT_cap_Right: "d_OUT cap (Right y) = pmf q y"
  by(subst d_OUT_alt_def[of _ Δ])(simp_all add: pmf_eq_0_set_pmf nn_integral_count_space_indicator max_def)

(* Core of the characterisation, shared by the bounded and the unbounded variant below.
   From a flow f and a cut S of Δ that are orthogonal (assumption ex_flow) together with
   the measure inequality le, it builds a distribution z on pairs whose support lies in R
   and whose marginals are p and q.  Proof outline:
   (1) the value of f is exactly 1;
   (2) f saturates every edge (Source, Left x) and every edge (Right y, Sink);
   (3) z is read off from the flow on the middle edges (Left x, Right y). *)
private lemma rel_pmf_measureI_aux:
  assumes ex_flow: "∃f S. flow Δ f ∧ cut Δ S ∧ orthogonal Δ f S"
    and le: "⋀A. measure (measure_pmf p) A ≤ measure (measure_pmf q) {y. ∃x∈A. R x y}"
  shows "rel_pmf R p q"
proof -
  from ex_flow obtain f S
    where f: "flow Δ f" and cut: "cut Δ S" and ortho: "orthogonal Δ f S" by blast
  from cut obtain Source: "Source ∈ S" and Sink: "Sink ∉ S" by cases simp

  have f_finite [simp]: "f e < top" for e
    using network.flowD_finite[OF f, of e] by (simp_all add: less_top)

  (* Flow into Left x arrives only via Source; flow out of Right y leaves only via Sink. *)
  have IN_f_Left: "d_IN f (Left x) = f (Source, Left x)" for x
    by(subst d_IN_alt_def[of _ Δ])(simp_all add: nn_integral_count_space_indicator max_def network.flowD_outside[OF f])
  have OUT_f_Right: "d_OUT f (Right y) = f (Right y, Sink)" for y
    by(subst d_OUT_alt_def[of _ Δ])(simp_all add: nn_integral_count_space_indicator max_def network.flowD_outside[OF f])

  (* Step (1): the value of f is 1.  The upper bound is the capacity leaving Source. *)
  have "value_flow Δ f ≤ 1" using flowD_capacity_OUT[OF f, of Source] by(simp add: OUT_cap_Source)
  moreover have "1 ≤ value_flow Δ f"
  proof -
    (* ?L: support points of p whose Left vertex is in the cut S.
       ?R' / ?R'': R-neighbours of such points in the support of q whose Right vertex
       lies inside / outside S. *)
    let ?L = "Left -` S ∩ set_pmf p"
    let ?R' = "{y|y x. x ∈ set_pmf p ∧ Left x ∈ S ∧ R x y ∧ y ∈ set_pmf q ∧ Right y ∈ S}"
    let ?R'' = "{y|y x. x ∈ set_pmf p ∧ Left x ∈ S ∧ R x y ∧ y ∈ set_pmf q ∧ ¬ Right y ∈ S}"
    have "value_flow Δ f = (∑⇧+ x∈range Left. f (Source, x))" unfolding d_OUT_def
      by(auto simp add: nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
    also have "… = (∑⇧+ x. f (Source, Left x) * indicator ?L x) + (∑⇧+ x. f (Source, Left x) * indicator (- ?L) x)"
      by(subst nn_integral_add[symmetric])(auto simp add: nn_integral_count_space_reindex intro!: nn_integral_cong split: split_indicator)
    (* Source edges leaving the cut are saturated by orthogonality. *)
    also have "(∑⇧+ x. f (Source, Left x) * indicator (- ?L) x) = (∑⇧+ x∈- ?L. cap (Source, Left x))"
      using orthogonalD_out[OF ortho _ Source]
      apply(auto simp add: set_pmf_iff network.flowD_outside[OF f] nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      subgoal for x by(cases "x ∈ set_pmf p")(auto simp add: set_pmf_iff network.flowD_outside[OF f])
      done
    also have "… = (∑⇧+ x∈- ?L. pmf p x)" by simp
    also have "… = emeasure p (- ?L)" by(simp add: nn_integral_pmf)
    (* Flow entering ?L is pushed on to the Right vertices by conservation (KIR). *)
    also have "(∑⇧+ x. f (Source, Left x) * indicator ?L x) = (∑⇧+ x∈?L. d_IN f (Left x))"
      by(subst d_IN_alt_def[of _ Δ])(auto simp add: network.flowD_outside[OF f] nn_integral_count_space_indicator intro!: nn_integral_cong)
    also have "… = (∑⇧+ x∈?L. d_OUT f (Left x))"
      by(rule nn_integral_cong flowD_KIR[OF f, symmetric])+ simp_all
    also have "… = (∑⇧+ x. ∑⇧+ y. f (Left x, y) * indicator (range Right) y * indicator ?L x)"
      by(auto simp add: d_OUT_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
    also have "… = (∑⇧+ y∈range Right. ∑⇧+ x. f (Left x, y) * indicator ?L x)"
      by(subst nn_integral_fst_count_space[where f="case_prod _", simplified])
        (simp add: nn_integral_snd_count_space[where f="case_prod _", simplified] nn_integral_count_space_indicator nn_integral_cmult[symmetric] mult_ac)
    also have "… = (∑⇧+ y. ∑⇧+ x. f (Left x, Right y) * indicator ?L x)"
      by(simp add: nn_integral_count_space_reindex)
    also have "… = (∑⇧+ y. ∑⇧+ x. f (Left x, Right y) * indicator ?L x * indicator ?R' y) +
      (∑⇧+ y. ∑⇧+ x. f (Left x, Right y) * indicator ?L x * indicator ?R'' y)"
      by(subst nn_integral_add[symmetric]; simp)
        (subst nn_integral_add[symmetric]; auto intro!: nn_integral_cong split: split_indicator intro!: network.flowD_outside[OF f])
    (* Into a Right vertex inside the cut, no flow comes from Left vertices outside it (orthogonalD_in). *)
    also have "(∑⇧+ y. ∑⇧+ x. f (Left x, Right y) * indicator ?L x * indicator ?R' y) =
       (∑⇧+ y. ∑⇧+ x. f (Left x, Right y) * indicator ?R' y)"
      apply(clarsimp simp add: network.flowD_outside[OF f] intro!: nn_integral_cong split: split_indicator)
      subgoal for y x by(cases "edge Δ (Left x) (Right y)")(auto intro: orthogonalD_in[OF ortho] network.flowD_outside[OF f])
      done
    also have "… = (∑⇧+ y. ∑⇧+ x∈range Left. f (x, Right y) * indicator ?R' y)"
      by(simp add: nn_integral_count_space_reindex)
    also have "… = (∑⇧+ y∈?R'. d_IN f (Right y))"
      by(subst d_IN_alt_def[of _ Δ])(auto simp add: network.flowD_outside[OF f] nn_integral_count_space_indicator incoming_def intro!: nn_integral_cong split: split_indicator)
    also have "… = (∑⇧+ y∈?R'. d_OUT f (Right y))" using flowD_KIR[OF f] by simp
    also have "… = (∑⇧+ y∈?R'. d_OUT cap (Right y))"
      by(auto 4 3 intro!: nn_integral_cong simp add: d_OUT_def network.flowD_outside[OF f] Sink dest: intro: orthogonalD_out[OF ortho, of "Right _" Sink, simplified])
    also have "… = (∑⇧+ y∈?R'. pmf q y)" by(simp add: OUT_cap_Right)
    also have "… = emeasure q ?R'" by(simp add: nn_integral_pmf)
    (* Middle edges crossing the cut are saturated, and their capacity is at most pmf q y. *)
    also have "(∑⇧+ y. ∑⇧+ x. f (Left x, Right y) * indicator ?L x * indicator ?R'' y) ≤ emeasure q ?R''" (is "?lhs ≤ _")
    proof -
      have "?lhs = (∑⇧+ y. ∑⇧+ x. (if R x y then pmf q y else 0) * indicator ?L x * indicator ?R'' y)"
        by(rule nn_integral_cong)+(auto split: split_indicator simp add: network.flowD_outside[OF f] simp add: orthogonalD_out[OF ortho, of "Left _" "Right _", simplified])
      also have "… ≤ (∑⇧+ y. pmf q y * indicator ?R'' y)"
        by(rule nn_integral_mono)(auto split: split_indicator intro: order_trans[OF _ nn_integral_ge_point])
      also have "(∑⇧+ y. pmf q y * indicator ?R'' y) = (∑⇧+ y∈?R''. pmf q y)" 
        by(auto simp add: nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      finally show ?thesis by(simp add: nn_integral_pmf)
    qed
    ultimately have "value_flow Δ f ≤ emeasure q ?R' + emeasure q ?R'' + emeasure p (- ?L)"
      by(simp add: add_right_mono)
    (* The neighbourhood inequality le bounds the mass of the R-neighbours of ?L by p ?L. *)
    also have "emeasure q ?R' + emeasure q ?R'' = emeasure q {y|y x. x ∈ set_pmf p ∧ Left x ∈ S ∧ R x y ∧ y ∈ set_pmf q}"
      by(subst plus_emeasure)(auto intro!: arg_cong2[where f=emeasure])
    also have "… ≤ emeasure p ?L" using le[of ?L]
      by(auto elim!: order_trans simp add: measure_pmf.emeasure_eq_measure AE_measure_pmf_iff intro!: measure_pmf.finite_measure_mono_AE)
    ultimately have "value_flow Δ f ≤ emeasure (measure_pmf p) ?L + emeasure (measure_pmf p) (- ?L)"
      by (smt (verit, best) add_right_mono inf.absorb_iff2 le_inf_iff)
    also have "emeasure (measure_pmf p) ?L + emeasure (measure_pmf p) (- ?L) = emeasure (measure_pmf p) (?L ∪ - ?L)"
      by(subst plus_emeasure) auto
    also have "?L ∪ -?L = UNIV" by blast
    finally show ?thesis by simp
  qed
  ultimately have val: "value_flow Δ f = 1" by simp

  (* Step (2a): f saturates every edge out of Source; otherwise its value would fall below 1. *)
  have SAT_p: "f (Source, Left x) = pmf p x" for x
  proof(rule antisym)
    show "f (Source, Left x) ≤ pmf p x" using flowD_capacity[OF f, of "(Source, Left x)"] by simp
    show "pmf p x ≤ f (Source, Left x)"
    proof(rule ccontr)
      assume *: "¬ ?thesis"
      have finite: "(∑⇧+ y. f (Source, Left y) * indicator (- {x}) y) ≠ ⊤"
      proof -
        have "(∑⇧+ y. f (Source, Left y) * indicator (- {x}) y) ≤ (∑⇧+ y∈range Left. f (Source, y))"
          by(auto simp add: nn_integral_count_space_reindex intro!: nn_integral_mono split: split_indicator)
        also have "… = value_flow Δ f"
          by(auto simp add: d_OUT_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
        finally show ?thesis using val by (auto simp: top_unique)
      qed
      have "value_flow Δ f = (∑⇧+ y∈range Left. f (Source, y))"
        by(auto simp add: d_OUT_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
      also have "… = (∑⇧+ y. f (Source, Left y) * indicator (- {x}) y) + (∑⇧+ y. f (Source, Left y) * indicator {x} y)"
        by(subst nn_integral_add[symmetric])(auto simp add: nn_integral_count_space_reindex intro!: nn_integral_cong split: split_indicator)
      also have "… < (∑⇧+ y. f (Source, Left y) * indicator (- {x}) y) + (∑⇧+ y. pmf p y * indicator {x} y)" using * finite
        by(auto simp add:)
      also have "… ≤ (∑⇧+ y. pmf p y * indicator (- {x}) y) + (∑⇧+ y. pmf p y * indicator {x} y)"
        using flowD_capacity[OF f, of "(Source, Left _)"]
        by(auto intro!: nn_integral_mono split: split_indicator)
      also have "… = (∑⇧+ y. pmf p y)"
        by(subst nn_integral_add[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
      also have "… = 1" unfolding nn_integral_pmf by simp
      finally show False using val by simp
    qed
  qed

  (* The flow into Sink is also 1, by conservation at every Left and Right vertex. *)
  have IN_Sink: "d_IN f Sink = 1"
  proof -
    have "d_IN f Sink = (∑⇧+ x∈range Right. f (x, Sink))" unfolding d_IN_def
      by(auto intro!: nn_integral_cong network.flowD_outside[OF f] simp add: nn_integral_count_space_indicator split: split_indicator)
    also have "… = (∑⇧+ y. d_OUT f (Right y))" by(simp add: nn_integral_count_space_reindex OUT_f_Right)
    also have "… = (∑⇧+ y. d_IN f (Right y))" by(simp add: flowD_KIR[OF f])
    also have "… = (∑⇧+ y. (∑⇧+ x∈range Left. f (x, Right y)))"
      by(auto simp add: d_IN_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
    also have "… = (∑⇧+ y. ∑⇧+ x. f (Left x, Right y))" by(simp add: nn_integral_count_space_reindex)
    also have "… = (∑⇧+ x. ∑⇧+ y. f (Left x, Right y))"
      by(subst nn_integral_fst_count_space[where f="case_prod _", simplified])(simp add: nn_integral_snd_count_space[where f="case_prod _", simplified])
    also have "… = (∑⇧+ x. (∑⇧+ y∈range Right. f (Left x, y)))"
      by(simp add: nn_integral_count_space_reindex)
    also have "… = (∑⇧+ x. d_OUT f (Left x))" unfolding d_OUT_def
      by(auto intro!: nn_integral_cong network.flowD_outside[OF f] simp add: nn_integral_count_space_indicator split: split_indicator)
    also have "… = (∑⇧+ x. d_IN f (Left x))" by(simp add: flowD_KIR[OF f])
    also have "… = (∑⇧+ x. pmf p x)" by(simp add: IN_f_Left SAT_p)
    also have "… = 1" unfolding nn_integral_pmf by simp
    finally show ?thesis .
  qed

  (* Step (2b): f saturates every edge into Sink, by the same argument as for SAT_p. *)
  have SAT_q: "f (Right y, Sink) = pmf q y" for y
  proof(rule antisym)
    show "f (Right y, Sink) ≤ pmf q y" using flowD_capacity[OF f, of "(Right y, Sink)"] by simp
    show "pmf q y ≤ f (Right y, Sink)"
    proof(rule ccontr)
      assume *: "¬ ?thesis"
      have finite: "(∑⇧+ x. f (Right x, Sink) * indicator (- {y}) x) ≠ ⊤"
      proof -
        have "(∑⇧+ x. f (Right x, Sink) * indicator (- {y}) x) ≤ (∑⇧+ x∈range Right. f (x, Sink))"
          by(auto simp add: nn_integral_count_space_reindex intro!: nn_integral_mono split: split_indicator)
        also have "… = d_IN f Sink"
          by(auto simp add: d_IN_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
        finally show ?thesis using IN_Sink by (auto simp: top_unique)
      qed
      have "d_IN f Sink = (∑⇧+ x∈range Right. f (x, Sink))"
        by(auto simp add: d_IN_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
      also have "… = (∑⇧+ x. f (Right x, Sink) * indicator (- {y}) x) + (∑⇧+ x. f (Right x, Sink) * indicator {y} x)"
        by(subst nn_integral_add[symmetric])(auto simp add: nn_integral_count_space_reindex intro!: nn_integral_cong split: split_indicator)
      also have "… < (∑⇧+ x. f (Right x, Sink) * indicator (- {y}) x) + (∑⇧+ x. pmf q x * indicator {y} x)" using * finite
        by auto
      also have "… ≤ (∑⇧+ x. pmf q x * indicator (- {y}) x) + (∑⇧+ x. pmf q x * indicator {y} x)"
        using flowD_capacity[OF f, of "(Right _, Sink)"]
        by(auto intro!: nn_integral_mono split: split_indicator)
      also have "… = (∑⇧+ x. pmf q x)"
        by(subst nn_integral_add[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
      also have "… = 1" unfolding nn_integral_pmf by simp
      finally show False using IN_Sink by simp
    qed
  qed

  (* Step (3): the flow on the middle edges, converted to reals, is a probability mass
     function on pairs; z is the corresponding pmf. *)
  let ?z = "λ(x, y). enn2real (f (Left x, Right y))"
  have nonneg: "⋀xy. 0 ≤ ?z xy" by clarsimp
  have prob: "(∑⇧+ xy. ?z xy) = 1"
  proof -
    have "(∑⇧+ xy. ?z xy) = (∑⇧+ x. ∑⇧+ y. (ennreal ∘ ?z) (x, y))"
      unfolding nn_integral_fst_count_space by(simp add: split_def o_def)
    also have "… = (∑⇧+ x. (∑⇧+y∈range Right. f (Left x, y)))"
      by(auto simp add: nn_integral_count_space_reindex intro!: nn_integral_cong)
    also have "… = (∑⇧+ x. d_OUT f (Left x))"
      by(auto simp add: d_OUT_def nn_integral_count_space_indicator split: split_indicator intro!: nn_integral_cong network.flowD_outside[OF f])
    also have "… = (∑⇧+ x. d_IN f (Left x))" using flowD_KIR[OF f] by simp
    also have "… = (∑⇧+ x∈range Left. f (Source, x))" by(simp add: nn_integral_count_space_reindex IN_f_Left)
    also have "… = value_flow Δ f"
      by(auto simp add: d_OUT_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
    finally show ?thesis using val by(simp)
  qed
  note z = nonneg prob
  define z where "z = embed_pmf ?z"
  have z_sel [simp]: "pmf z (x, y) = enn2real (f (Left x, Right y))" for x y
    by(simp add: z_def pmf_embed_pmf[OF z])

  (* z witnesses rel_pmf R p q: its support lies in R, and its marginals are p and q. *)
  show ?thesis
  proof
    show "R x y" if "(x, y) ∈ set_pmf z" for x y
      using that network.flowD_outside[OF f, of "(Left x, Right y)"] unfolding set_pmf_iff
      by(auto simp add: enn2real_eq_0_iff)

    show "map_pmf fst z = p"
    proof(rule pmf_eqI)
      fix x
      have "pmf (map_pmf fst z) x = (∑⇧+ e∈range (Pair x). pmf z e)"
        by(auto simp add: ennreal_pmf_map nn_integral_measure_pmf nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = (∑⇧+ y∈range Right. f (Left x, y))" by(simp add: nn_integral_count_space_reindex)
      also have "… = d_OUT f (Left x)"
        by(auto simp add: d_OUT_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
      also have "… = d_IN f (Left x)" by(rule flowD_KIR[OF f]) simp_all
      also have "… = f (Source, Left x)" by(simp add: IN_f_Left)
      also have "… = pmf p x" by(simp add: SAT_p)
      finally show "pmf (map_pmf fst z) x = pmf p x" by simp
    qed

    show "map_pmf snd z = q"
    proof(rule pmf_eqI)
      fix y
      have "pmf (map_pmf snd z) y = (∑⇧+ e∈range (λx. (x, y)). pmf z e)"
        by(auto simp add: ennreal_pmf_map nn_integral_measure_pmf nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
      also have "… = (∑⇧+ x∈range Left. f (x, Right y))" by(simp add: nn_integral_count_space_reindex)
      also have "… = d_IN f (Right y)"
        by(auto simp add: d_IN_def nn_integral_count_space_indicator intro!: nn_integral_cong network.flowD_outside[OF f] split: split_indicator)
      also have "… = d_OUT f (Right y)" by(simp add: flowD_KIR[OF f])
      also have "… = f (Right y, Sink)" by(simp add: OUT_f_Right)
      also have "… = pmf q y" by(simp add: SAT_q)
      finally show "pmf (map_pmf snd z) y = pmf q y" by simp
    qed
  qed
qed

(* rel_pmf R p q follows from the neighbourhood inequality on measures.  The required
   orthogonal flow/cut pair is supplied by network.max_flow_min_cut. *)
proposition rel_pmf_measureI_unbounded: ― ‹Proof uses the unbounded max-flow min-cut theorem›
  assumes le: "⋀A. measure (measure_pmf p) A ≤ measure (measure_pmf q) {y. ∃x∈A. R x y}"
  shows "rel_pmf R p q"
  using assms by(rule rel_pmf_measureI_aux[OF network.max_flow_min_cut])

(* Δ is also a bounded countable network.  The capacity leaving each inner vertex is
   finite: at most 1 for a Left vertex, because each middle edge (Left x, Right y) has
   capacity at most pmf q y, and exactly pmf q y for a Right vertex (OUT_cap_Right). *)
interpretation network: bounded_countable_network Δ
proof
  have OUT_Left: "d_OUT cap (Left x) ≤ 1" for x
  proof -
    have "d_OUT cap (Left x) ≤ (∑⇧+ y∈range Right. cap (Left x, y))"
      by(subst d_OUT_alt_def[of _ Δ])(auto intro: network.capacity_outside[simplified] intro!: nn_integral_mono simp add: nn_integral_count_space_indicator outgoing_def split: split_indicator)
    also have "… = (∑⇧+ y. cap (Left x, Right y))" by(simp add: nn_integral_count_space_reindex)
    also have "… ≤ (∑⇧+ y. pmf q y)" by(rule nn_integral_mono)(simp)
    also have "… = 1" by(simp add: nn_integral_pmf)
    finally show ?thesis .
  qed
  show "d_OUT (capacity Δ) x < ⊤" if "x ∈ \<^bold>V⇘Δ⇙" "x ≠ source Δ" "x ≠ sink Δ" for x
    using that by(cases x)(auto simp add: OUT_cap_Right intro: le_less_trans[OF OUT_Left])
qed

(* The same characterisation as rel_pmf_measureI_unbounded.  Here the orthogonal
   flow/cut pair is supplied by network.max_flow_min_cut_bounded. *)
proposition rel_pmf_measureI_bounded: ― ‹Proof uses the bounded max-flow min-cut theorem›
  assumes le: "⋀A. measure (measure_pmf p) A ≤ measure (measure_pmf q) {y. ∃x∈A. R x y}"
  shows "rel_pmf R p q"
  using assms by(rule rel_pmf_measureI_aux[OF network.max_flow_min_cut_bounded])

end

end

(* Interpret the rel_spmf_characterisation locale.  Its assumptions are discharged
   with rel_pmf_measureI_bounded after unfold_locales. *)
interpretation rel_spmf_characterisation by unfold_locales(rule rel_pmf_measureI_bounded)

(* Composing two rel_pmf relations stays within rel_pmf of the composed relation.
   Proof: chain the measure inequalities from rel_pmf_measureD for both factors, then
   conclude with rel_pmf_measureI_bounded. *)
corollary rel_pmf_distr_mono: "rel_pmf R OO rel_pmf S ≤ rel_pmf (R OO S)"
― ‹This fact has already been proven for the registration of @{typ "'a pmf"} as a BNF,
  but this proof is much shorter and more elegant. See @{cite HoelzlLochbihlerTraytel2015ITP} for a
  comparison of formalisations.›
proof(intro le_funI le_boolI rel_pmf_measureI_bounded, elim relcomppE)
  fix p q r A
  assume pq: "rel_pmf R p q" and qr: "rel_pmf S q r"
  have "measure (measure_pmf p) A ≤ measure (measure_pmf q) {y. ∃x∈A. R x y}"
    (is "_ ≤ measure _ ?B") using pq by(rule rel_pmf_measureD)
  also have "… ≤ measure (measure_pmf r) {z. ∃y∈?B. S y z}"
    using qr by(rule rel_pmf_measureD)
  also have "{z. ∃y∈?B. S y z} = {z. ∃x∈A. (R OO S) x z}" by auto
  (* Here … abbreviates the right-hand side {z. ∃x∈A. (R OO S) x z} of the previous step. *)
  finally show "measure (measure_pmf p) A ≤ measure (measure_pmf r) …" .
qed

end