(* Theory Probabilistic_While.While_SPMF *)
theory While_SPMF imports
  MFMC_Countable.Rel_PMF_Characterisation
  "HOL-Types_To_Sets.Types_To_Sets"
  "HOL-Library.Complete_Partial_Order2"
begin
text ‹
  This theory defines a probabilistic while combinator for discrete (sub-)probabilities and
  formalises rules for probabilistic termination similar to those by Hurd \<^cite>‹"Hurd2002TPHOLs"›
  and McIver and Morgan \<^cite>‹"McIverMorgan2005"›.
›
section ‹Miscellaneous library additions›
(* Lift a function f, which maps 'a to sets of 'b options, to option arguments:
   None is sent to the singleton {None} and Some x is sent to f x.
   Used in set_pmf_bind_spmf below to describe the support of bind_spmf. *)
fun map_option_set :: "('a ⇒ 'b option set) ⇒ 'a option ⇒ 'b option set"
where
  "map_option_set f None = {None}"
| "map_option_set f (Some x) = f x"
(* Basic simplification and introduction rules for map_option_set and Option.bind.
   Each is proved by simp or by case distinction on the option argument. *)
lemma None_in_map_option_set:
  "None ∈ map_option_set f x ⟷ None ∈ Set.bind (set_option x) f ∨ x = None"
by(cases x) simp_all
lemma None_in_map_option_set_None [intro!]: "None ∈ map_option_set f None"
by simp
lemma None_in_map_option_set_Some [intro!]: "None ∈ f x ⟹ None ∈ map_option_set f (Some x)"
by simp
lemma Some_in_map_option_set [intro!]: "Some y ∈ f x ⟹ Some y ∈ map_option_set f (Some x)"
by simp
(* Lifting a singleton-valued function is the same as taking Option.bind. *)
lemma map_option_set_singleton [simp]: "map_option_set (λx. {f x}) y = {Option.bind y f}"
by(cases y) simp_all
lemma Some_eq_bind_conv: "Some y = Option.bind x f ⟷ (∃z. x = Some z ∧ f z = Some y)"
by(cases x) auto
(* map_option_set after Option.bind is map_option_set of the composed function. *)
lemma map_option_set_bind: "map_option_set f (Option.bind x g) = map_option_set (map_option_set f ∘ g) x"
by(cases x) simp_all
lemma Some_in_map_option_set_conv: "Some y ∈ map_option_set f x ⟷ (∃z. x = Some z ∧ Some y ∈ f z)"
by(cases x) auto
(* Instantiate the rel_spmf_characterisation locale, discharging its assumption
   with rel_pmf_measureI. Afterwards the unqualified name rel_pmf_measureI is hidden;
   "open" keeps the qualified name accessible.
   NOTE(review): presumably this avoids a name clash with a fact of the same name
   exported by the interpretation — confirm. *)
interpretation rel_spmf_characterisation by unfold_locales(rule rel_pmf_measureI)
hide_fact (open) rel_pmf_measureI
(* On function spaces, Sup and (≤) are the pointwise liftings fun_lub Sup and
   fun_ord (≤). Used to switch between the two formulations,
   e.g. in mcont_emeasure_spmf' below. *)
lemma Sup_conv_fun_lub: "Sup = fun_lub Sup"
  by(auto simp add: Sup_fun_def fun_eq_iff fun_lub_def intro: arg_cong[where f=Sup])
lemma le_conv_fun_ord: "(≤) = fun_ord (≤)"
  by(auto simp add: fun_eq_iff fun_ord_def le_fun_def)
(* Parallel fixpoint induction for two fixpoints, where the first function takes
   two curried arguments and the second takes one. The first is converted with
   case_prod/curry, the second is used unchanged (λx. x), and the predicate is
   restated over the curried function. *)
lemmas parallel_fixp_induct_2_1 = parallel_fixp_induct_uc[
  of _ _ _ _ "case_prod" _ "curry" "λx. x" _ "λx. x",
  where P="λf g. P (curry f) g",
  unfolded case_prod_curry curry_case_prod curry_K,
  OF _ _ _ _ _ _ refl refl]
  for P
(* Pairing two functions preserves monotonicity, continuity and mcont,
   taking the product order rel_prod and the product lub prod_lub on the target. *)
lemma monotone_Pair:
  "⟦ monotone ord orda f; monotone ord ordb g ⟧
  ⟹ monotone ord (rel_prod orda ordb) (λx. (f x, g x))"
by(simp add: monotone_def)
lemma cont_Pair:
  "⟦ cont lub ord luba orda f; cont lub ord lubb ordb g ⟧
  ⟹ cont lub ord (prod_lub luba lubb) (rel_prod orda ordb) (λx. (f x, g x))"
by(rule contI)(auto simp add: prod_lub_def image_image dest!: contD)
lemma mcont_Pair:
  "⟦ mcont lub ord luba orda f; mcont lub ord lubb ordb g ⟧
  ⟹ mcont lub ord (prod_lub luba lubb) (rel_prod orda ordb) (λx. (f x, g x))"
by(rule mcontI)(simp_all add: monotone_Pair mcont_mono cont_Pair)
(* The map p ↦ emeasure (measure_spmf p) is monotone and continuous with respect
   to ord_spmf (=) and lub_spmf. The [THEN lfp.mono2mono] / [THEN lfp.mcont2mcont]
   attributes turn the plain statements (named after "shows") into composition
   rules; the mcont versions are declared [cont_intro]. *)
lemma mono2mono_emeasure_spmf [THEN lfp.mono2mono]:
  shows monotone_emeasure_spmf:
  "monotone (ord_spmf (=)) (≤) (λp. emeasure (measure_spmf p))"
  by(rule monotoneI le_funI ord_spmf_eqD_emeasure)+
lemma cont_emeasure_spmf: "cont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p))"
  by (rule contI) (simp add: emeasure_lub_spmf fun_eq_iff image_comp)
lemma mcont2mcont_emeasure_spmf [THEN lfp.mcont2mcont, cont_intro]:
  shows mcont_emeasure_spmf: "mcont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p))"
  by(simp add: mcont_def monotone_emeasure_spmf cont_emeasure_spmf)
(* The same for a fixed set A, obtained from the function-valued version by
   rewriting to fun_lub/fun_ord and applying mcont_fun_lub_apply. *)
lemma mcont2mcont_emeasure_spmf' [THEN lfp.mcont2mcont, cont_intro]:
  shows mcont_emeasure_spmf': "mcont lub_spmf (ord_spmf (=)) Sup (≤) (λp. emeasure (measure_spmf p) A)"
  using mcont_emeasure_spmf[unfolded Sup_conv_fun_lub le_conv_fun_ord]
  by(subst (asm) mcont_fun_lub_apply) blast
(* Binding a fixed pmf p preserves mcont in the continuation. Derived from
   mcont_bind_spmf by viewing p as spmf_of_pmf p. *)
lemma mcont_bind_pmf [cont_intro]:
  assumes g: "⋀y. mcont luba orda lub_spmf (ord_spmf (=)) (g y)"
  shows "mcont luba orda lub_spmf (ord_spmf (=)) (λx. bind_pmf p (λy. g y x))"
using mcont_bind_spmf[where f="λ_. spmf_of_pmf p" and g=g, OF _ assms] by(simp)
(* On ennreal, being below ⊤ is the same as being different from ⊤. *)
lemma ennreal_less_top_iff: "x < ⊤ ⟷ x ≠ (⊤ :: ennreal)"
  by(cases x) simp_all
(* For a type definition with correspondence relation T (x is related to y iff
   x = Rep y), the domain of T is exactly the carrier set A. Used below to
   supply the transfer_domain_rule for local type definitions. *)
lemma type_definition_Domainp: 
  fixes Rep Abs A T
  assumes type: "type_definition Rep Abs A"
  assumes T_def: "T ≡ (λ(x::'a) (y::'b). x = Rep y)"
  shows "Domainp T = (λx. x ∈ A)"
proof -
  interpret type_definition Rep Abs A by(rule type)
  show ?thesis unfolding Domainp_iff[abs_def] T_def fun_eq_iff by(metis Abs_inverse Rep)
qed
(* Parametricity (transfer) rules: weight_spmf and lossless_spmf are invariant
   under rel_spmf for any element relation A, and UNIV is related to itself by
   rel_pred R. *)
context includes lifting_syntax begin
lemma weight_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> (=)) weight_spmf weight_spmf"
by(simp add: rel_fun_def rel_spmf_weightD)
lemma lossless_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> (=)) lossless_spmf lossless_spmf"
by(simp add: rel_fun_def lossless_spmf_def rel_spmf_weightD)
lemma UNIV_parametric_pred: "rel_pred R UNIV UNIV"
  by(auto intro!: rel_predI)
end
(* For a finite non-empty A, binding on the uniform spmf over A is binding on
   the uniform pmf over A. *)
lemma bind_spmf_spmf_of_set:
  "⋀A. ⟦ finite A; A ≠ {} ⟧ ⟹ bind_spmf (spmf_of_set A) = bind_pmf (pmf_of_set A)"
by(simp add: spmf_of_set_def fun_eq_iff del: spmf_of_pmf_pmf_of_set)
(* Support of bind_spmf in terms of map_option_set: None stays None, Some x
   continues with the support of f x. *)
lemma set_pmf_bind_spmf: "set_pmf (bind_spmf M f) = set_pmf M ⤜ map_option_set (set_pmf ∘ f)"
by(auto 4 3 simp add: bind_spmf_def split: option.splits intro: rev_bexI)
(* Support of spmf_of_set: Some ` A if A is finite and non-empty, otherwise only None. *)
lemma set_pmf_spmf_of_set:
  "set_pmf (spmf_of_set A) = (if finite A ∧ A ≠ {} then Some ` A else {None})"
by(simp add: spmf_of_set_def spmf_of_pmf_def del: spmf_of_pmf_pmf_of_set)
(* Name for measure (measure_spmf p), so that transfer can treat the composition
   as a single constant with its own parametricity rule (used in
   termination_0_1_invar below via unfolding measure_measure_spmf_def[symmetric]). *)
definition measure_measure_spmf :: "'a spmf ⇒ 'a set ⇒ real"
where [simp]: "measure_measure_spmf p = measure (measure_spmf p)"
lemma measure_measure_spmf_parametric [transfer_rule]:
  includes lifting_syntax shows
  "(rel_spmf A ===> rel_pred A ===> (=)) measure_measure_spmf measure_measure_spmf"
unfolding measure_measure_spmf_def[abs_def] by(rule measure_spmf_parametric)
(* Small arithmetic and set facts: casting nat to real preserves "≤ 1", the ceiling
   is below r + 1, and {..<x} is contained in Collect P iff every y < x satisfies P. *)
lemma of_nat_le_one_cancel_iff [simp]:
  fixes n :: nat shows "real n ≤ 1 ⟷ n ≤ 1"
by linarith
lemma of_int_ceiling_less_add_one [simp]: "of_int ⌈r⌉ < r + 1"
  by linarith
lemma lessThan_subset_Collect: "{..<x} ⊆ Collect P ⟷ (∀y<x. P y)"
  by(auto simp add: lessThan_def)
(* Tightness of upper bounds: if f bounds spmf p pointwise from above and the
   integral of f over the whole type equals weight_spmf p, then f equals spmf p
   everywhere. Proof by contradiction: a point x with spmf p x < f x would make
   the total mass of f strictly exceed the weight of p. *)
lemma spmf_ub_tight:
  assumes ub: "⋀x. spmf p x ≤ f x"
  and sum: "(∫⇧+ x. f x ∂count_space UNIV) = weight_spmf p"
  shows "spmf p x = f x"
proof -
  have [rule_format]: "∀x. f x ≤ spmf p x"
  proof(rule ccontr)
    assume "¬ ?thesis"
    then obtain x where x: "spmf p x < f x" by(auto simp add: not_le)
    (* The mass of f outside x is finite, because it is bounded by the weight of p. *)
    have *: "(∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) ≠ ⊤"
      by(rule neq_top_trans[where y="weight_spmf p"], simp)(auto simp add: sum[symmetric] intro!: nn_integral_mono split: split_indicator)
      
    (* Split the weight into the mass outside x and the mass at x, then bound
       each part by the corresponding part of f; the bound at x is strict. *)
    have "weight_spmf p = ∫⇧+ y. spmf p y ∂count_space UNIV"
      by(simp add: nn_integral_spmf space_measure_spmf measure_spmf.emeasure_eq_measure)
    also have "… = (∫⇧+ y. ennreal (spmf p y) * indicator (- {x}) y ∂count_space UNIV) +
      (∫⇧+ y. spmf p y * indicator {x} y ∂count_space UNIV)"
      by(subst nn_integral_add[symmetric])(auto intro!: nn_integral_cong split: split_indicator)
    also have "… ≤ (∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) + spmf p x"
      using ub by(intro add_mono nn_integral_mono)(auto split: split_indicator intro: ennreal_leI)
    also have "… < (∫⇧+ y. ennreal (f y) * indicator (- {x}) y ∂count_space UNIV) + (∫⇧+ y. f y * indicator {x} y ∂count_space UNIV)"
      using * x by(simp add: ennreal_less_iff)
    also have "… = (∫⇧+ y. ennreal (f y) ∂count_space UNIV)"
      by(subst nn_integral_add[symmetric])(auto intro: nn_integral_cong split: split_indicator)
    also have "… = weight_spmf p" using sum by simp
    finally show False by simp
  qed
  from this[of x] ub[of x] show ?thesis by simp
qed
section ‹Probabilistic while loop›
(* A probabilistic while loop with a deterministic guard and a probabilistic body.
   while s runs body as long as guard holds and returns the state once the guard
   fails. It is defined as the least fixpoint in the spmf ccpo, so missing
   probability mass (None) corresponds to non-termination; the bottom element in
   the fixpoint induction rule is return_pmf None. *)
locale loop_spmf = 
  fixes guard :: "'a ⇒ bool"
  and body :: "'a ⇒ 'a spmf"
begin
(* NOTE(review): function_internals presumably keeps the internal definition
   facts (while_def, while.mono) accessible; they are used in
   while_spmf_parametric below. *)
context notes [[function_internals]] begin
partial_function (spmf) while :: "'a ⇒ 'a spmf"
where "while s = (if guard s then bind_spmf (body s) while else return_spmf s)"
end
(* Fixpoint induction for while with named cases adm / bottom / step. *)
lemma while_fixp_induct [case_names adm bottom step]:
  assumes "spmf.admissible P"
  and "P (λwhile. return_pmf None)"
  and "⋀while'. P while' ⟹ P (λs. if guard s then body s ⤜ while' else return_spmf s)"
  shows "P while"
  using assms by(rule while.fixp_induct)
(* One-step unfolding equations for the guarded and unguarded cases. *)
lemma while_simps:
  "guard s ⟹ while s = bind_spmf (body s) while"
  "¬ guard s ⟹ while s = return_spmf s"
by(rewrite while.simps; simp; fail)+
end
(* loop_spmf.while is parametric in the state type: related guards and bodies
   (under a state relation S) yield related loops. Proved by parametricity of
   the spmf fixpoint operator applied to the loop's monotone functional. *)
lemma while_spmf_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> (=)) ===> (S ===> rel_spmf S) ===> S ===> rel_spmf S) loop_spmf.while loop_spmf.while"
unfolding loop_spmf.while_def[abs_def]
apply(rule rel_funI)
apply(rule rel_funI)
apply(rule fixp_spmf_parametric[OF loop_spmf.while.mono loop_spmf.while.mono])
subgoal premises [transfer_rule] by transfer_prover
done
(* Congruence: the loop depends on body only at states where the guard holds. *)
lemma loop_spmf_while_cong:
  "⟦ guard = guard'; ⋀s. guard' s ⟹ body s = body' s ⟧
  ⟹ loop_spmf.while guard body = loop_spmf.while guard' body'"
unfolding loop_spmf.while_def[abs_def] by(simp cong: if_cong)
section ‹Rules for probabilistic termination›
context loop_spmf begin
subsection ‹0/1 termination laws›
(* Immediate 0/1 law: the body leaves the loop (reaches a state where guard is
   false) in one step with probability at least p > 0 from every guarded state,
   and the body is lossless on guarded states. Then the loop terminates with
   probability 1 from every start state. *)
lemma termination_0_1_immediate:
  assumes p: "⋀s. guard s ⟹ spmf (map_spmf guard (body s)) False ≥ p"
  and p_pos: "0 < p"
  and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
  shows "lossless_spmf (while s)"
proof -
  (* Prove losslessness for all start states at once, by contradiction. *)
  have "∀s. lossless_spmf (while s)"
  proof(rule ccontr)
    assume "¬ ?thesis"
    then obtain s where s: "¬ lossless_spmf (while s)" by blast
    hence True: "guard s" by(simp add: while.simps split: if_split_asm)
    from p[OF this] have p_le_1: "p ≤ 1" using pmf_le_1 by(rule order_trans)
    (* Key step: if k ∈ [0, 1] is a lower bound on the termination probability
       from every state, then so is p * (1 - k) + k. *)
    have new_bound: "p * (1 - k) + k ≤ weight_spmf (while s)" 
      if k: "0 ≤ k" "k ≤ 1" and k_le: "⋀s. k ≤ weight_spmf (while s)" for k s
    proof(cases "guard s")
      case False
      have "p * (1 - k) + k ≤ 1 * (1 - k) + k" using p_le_1 k by(intro mult_right_mono add_mono; simp)
      also have "… ≤ 1" by simp
      finally show ?thesis using False by(simp add: while.simps)
    next
      case True
      let ?M = "λs. measure_spmf (body s)"
      have bounded: "¦∫ s''. weight_spmf (while s'') ∂?M s'¦ ≤ 1" for s'
        using integral_nonneg_AE[of "λs''. weight_spmf (while s'')" "?M s'"]
        by(auto simp add: weight_spmf_nonneg weight_spmf_le_1 intro!: measure_spmf.nn_integral_le_const integral_real_bounded)
      have "p ≤ measure (?M s) {s'. ¬ guard s'}" using p[OF True]
        by(simp add: spmf_conv_measure_spmf measure_map_spmf vimage_def)
      hence "p * (1 - k) + k ≤ measure (?M s) {s'. ¬ guard s'} * (1 - k) + k"
        using k by(intro add_mono mult_right_mono)(simp_all)
      also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' * (1 - k) +  k ∂?M s"
        using True by(simp add: ennreal_less_top_iff lossless lossless_weight_spmfD)
      also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * k ∂?M s"
        by(rule Bochner_Integration.integral_cong)(simp_all split: split_indicator)
      also have "… = ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * ∫ s''. k ∂?M s' ∂?M s"
        by(rule Bochner_Integration.integral_cong)(auto simp add: lossless lossless_weight_spmfD split: split_indicator)
      also have "… ≤ ∫ s'. indicator {s'. ¬ guard s'} s' + indicator {s'. guard s'} s' * ∫ s''. weight_spmf (while s'') ∂?M s' ∂?M s"
        using k bounded
        by(intro integral_mono integrable_add measure_spmf.integrable_const_bound[where B=1] add_mono mult_left_mono)
          (simp_all add: weight_spmf_nonneg weight_spmf_le_1 mult_le_one k_le split: split_indicator)
      also have "… = ∫s'. (if ¬ guard s' then 1 else ∫ s''. weight_spmf (while s'') ∂?M s') ∂?M s"
        by(rule Bochner_Integration.integral_cong)(simp_all split: split_indicator)
      also have "… = ∫ s'. weight_spmf (while s') ∂measure_spmf (body s)"
        by(rule Bochner_Integration.integral_cong; simp add: while.simps weight_bind_spmf o_def)
      also have "… = weight_spmf (while s)" using True
        by(simp add: while.simps weight_bind_spmf o_def)
      finally show ?thesis .
    qed
    (* Let k be the infimum of all termination probabilities. The failing state s
       gives k < 1, and new_bound then yields the strictly larger lower bound k',
       which contradicts k being the infimum. *)
    define k where "k ≡ INF s. weight_spmf (while s)"
    define k' where "k' ≡ p * (1 - k) + k"
    from s have "weight_spmf (while s) < 1"
      using weight_spmf_le_1[of "while s"] by(simp add: lossless_spmf_def)
    then have "k < 1"
      unfolding k_def by(rewrite cINF_less_iff)(auto intro!: bdd_belowI2 weight_spmf_nonneg)
    have "0 ≤ k" unfolding k_def by(auto intro: cINF_greatest simp add: weight_spmf_nonneg)
    moreover from ‹k < 1› have "k ≤ 1" by simp
    moreover have "k ≤ weight_spmf (while s)" for s unfolding k_def
      by(rule cINF_lower)(auto intro!: bdd_belowI2 weight_spmf_nonneg)
    ultimately have "⋀s. k' ≤ weight_spmf (while s)"
      unfolding k'_def by(rule new_bound)
    hence "k' ≤ k" unfolding k_def by(auto intro: cINF_greatest)
    also have "k < k'" using p_pos ‹k < 1› by(auto simp add: k'_def)
    finally show False by simp
  qed
  thus ?thesis by blast
qed
(* iter n s unrolls the loop at most n times. It stops early (returns the state)
   once the guard fails; after n steps it returns the current state even if the
   guard still holds. *)
primrec iter :: "nat ⇒ 'a ⇒ 'a spmf"
where
  "iter 0 s = return_spmf s"
| "iter (Suc n) s = (if guard s then bind_spmf (body s) (iter n) else return_spmf s)"
(* From a state where the guard fails, every unrolling returns the state itself. *)
lemma iter_unguarded [simp]: "¬ guard s ⟹ iter n s = return_spmf s"
  by(induction n) simp_all
  
(* Running m unrollings and then n unrollings is running m + n unrollings. *)
lemma iter_bind_iter: "bind_spmf (iter m s) (iter n) = iter (m + n) s"
  by(induction m arbitrary: s) simp_all
(* Alternative unfolding of iter (Suc n): the extra guarded step happens last. *)
lemma iter_Suc2: "iter (Suc n) s = bind_spmf (iter n s) (λs. if guard s then body s else return_spmf s)"
  using iter_bind_iter[of n s 1, symmetric]
  by(simp del: iter.simps)(rule bind_spmf_cong; simp cong: bind_spmf_cong)
(* If the body is lossless on guarded states, every finite unrolling is lossless. *)
lemma lossless_iter: "(⋀s. guard s ⟹ lossless_spmf (body s)) ⟹ lossless_spmf (iter n s)"
  by(induction n arbitrary: s) simp_all
(* The probability of having left the loop (ending in a state where guard is false)
   within n unrollings does not decrease from n to Suc n. iter (Suc n) is iter n
   followed by one more guarded step (iter_Suc2), and a state where guard is false
   is returned unchanged by that step. If the start state is unguarded, simp
   closes the goal via iter_unguarded. *)
lemma iter_mono_emeasure1:
  "emeasure (measure_spmf (iter n s)) {s. ¬ guard s} ≤ emeasure (measure_spmf (iter (Suc n) s)) {s. ¬ guard s}"
  (is "?lhs ≤ ?rhs")
proof(cases "guard s")
  case True
  have "?lhs = emeasure (measure_spmf (bind_spmf (iter n s) return_spmf)) {s. ¬ guard s}" by simp
  also have "… = ∫⇧+ s'. emeasure (measure_spmf (return_spmf s')) {s. ¬ guard s} ∂measure_spmf (iter n s)"
    by(simp del: bind_return_spmf add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
  also have "… ≤ ∫⇧+ s'. emeasure (measure_spmf (if guard s' then body s' else return_spmf s')) {s. ¬ guard s} ∂measure_spmf (iter n s)"
    by(rule nn_integral_mono)(simp add: measure_spmf_return_spmf)
  also have "… = ?rhs"
    by(simp add: iter_Suc2 measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra del: iter.simps)
  finally show ?thesis .
qed simp
(* The termination probability of the loop is the supremum over n of the
   probability that iter n s ends in a state where the guard fails. *)
lemma weight_while_conv_iter:
  "weight_spmf (while s) = (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
  (is "?lhs = ?rhs")
proof(rule antisym)
  (* Direction ≤: fixpoint induction over while, carried out in ennreal, using
     monotone convergence for the guarded step. *)
  have "emeasure (measure_spmf (while s)) UNIV ≤ (SUP n. emeasure (measure_spmf (iter n s)) {s. ¬ guard s})"
    (is "_ ≤ (SUP n. ?f n s)")
  proof(induction arbitrary: s rule: while_fixp_induct)
    case adm show ?case by simp
    case bottom show ?case by simp
    case (step while')
    show ?case (is "?lhs' ≤ ?rhs'")
    proof(cases "guard s")
      case True
      have inc: "incseq ?f" by(rule incseq_SucI le_funI iter_mono_emeasure1)+
      from True have "?lhs' = ∫⇧+ s'. emeasure (measure_spmf (while' s')) UNIV ∂measure_spmf (body s)"
        by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
      also have "… ≤ ∫⇧+ s'. (SUP n. ?f n s') ∂measure_spmf (body s)"
        by(rule nn_integral_mono)(rule step.IH)
      also have "… = (SUP n. ∫⇧+ s'. ?f n s' ∂measure_spmf (body s))" using inc
        by(subst nn_integral_monotone_convergence_SUP) simp_all
      also have "… = (SUP n. ?f (Suc n) s)" using True
        by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
      also have "… ≤ (SUP n. ?f n s)"
        by(rule SUP_mono)(auto intro: exI[where x="Suc _"])
      finally show ?thesis .
    next
      case False
      then have "?lhs' = emeasure (measure_spmf (iter 0 s)) {s. ¬ guard s}" 
        by(simp add: measure_spmf_return_spmf)
      also have ‹… ≤ ?rhs'› by(rule SUP_upper) simp
      finally show ?thesis .
    qed
  qed
  (* Move the supremum from ennreal back to real measures. *)
  also have "… = ennreal (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
    by(subst ennreal_SUP)(fold measure_spmf.emeasure_eq_measure, auto simp add: not_less measure_spmf.subprob_emeasure_le_1 intro!: exI[where x="1"])
  also have "0 ≤ (SUP n. measure (measure_spmf (iter n s)) {s. ¬ guard s})"
    by(rule cSUP_upper2)(auto intro!: bdd_aboveI[where M=1] simp add: measure_spmf.subprob_measure_le_1)
  ultimately show "?lhs ≤ ?rhs" by(simp add: measure_spmf.emeasure_eq_measure space_measure_spmf)
  
  (* Direction ≥: each finite unrolling terminates with probability at most
     weight_spmf (while s), by induction on n. *)
  show "?rhs ≤ ?lhs"
  proof(rule cSUP_least)
    show "measure (measure_spmf (iter n s)) {s. ¬ guard s} ≤ weight_spmf (while s)" (is "?f n s ≤ _") for n
    proof(induction n arbitrary: s)
      case 0 show ?case
        by(simp add: measure_spmf_return_spmf measure_return while_simps split: split_indicator)
    next
      case (Suc n)
      show ?case
      proof(cases "guard s")
        case True
        have "?f (Suc n) s = ∫⇧+ s'. ?f n s' ∂measure_spmf (body s)"
          using True unfolding measure_spmf.emeasure_eq_measure[symmetric]
          by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
        also have "… ≤ ∫⇧+ s'. weight_spmf (while s') ∂measure_spmf (body s)"
          by(rule nn_integral_mono ennreal_leI Suc.IH)+
        also have "… = weight_spmf (while s)"
          using True unfolding measure_spmf.emeasure_eq_measure[symmetric] space_measure_spmf
          by(simp add: while_simps measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
        finally show ?thesis by(simp)
      next
        case False then show ?thesis
          by(simp add: measure_spmf_return_spmf measure_return while_simps split: split_indicator)
      qed
    qed
  qed simp
qed
(* 0/1 law: from every guarded state the loop terminates with probability at
   least p > 0, and the body is lossless on guarded states. Then the loop
   terminates with probability 1. Proof: fuse N s unrollings into a single body
   step, so that termination_0_1_immediate applies to the fused loop, and then
   compare the fused loop with the original unrollings. *)
lemma termination_0_1:
  assumes p: "⋀s. guard s ⟹ p ≤ weight_spmf (while s)"
    and p_pos: "0 < p"
    and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
  shows "lossless_spmf (while s)"
  unfolding lossless_spmf_def
proof(rule antisym)
  let ?X = "{s. ¬ guard s}"
  show "weight_spmf (while s) ≤ 1" by(rule weight_spmf_le_1)
  
  (* For every guarded state s, choose N s such that iter (N s) s has left the
     loop with probability at least p' = p / 2 > 0. *)
  define p' where "p' ≡ p / 2"
  have p'_pos: "p' > 0" and "p' < p" using p_pos by(simp_all add: p'_def)
  
  have "∃n. p' < measure (measure_spmf (iter n s)) ?X" if "guard s" for s using p[OF that] ‹p' < p›
    unfolding weight_while_conv_iter
    by(subst (asm) le_cSUP_iff)(auto intro!: measure_spmf.subprob_measure_le_1)
  then obtain N where p': "p' ≤ measure (measure_spmf (iter (N s) s)) ?X" if "guard s" for s
    using p by atomize_elim(rule choice, force dest: order.strict_implies_order)
  (* The fused loop uses iter (N s) s as its body. It satisfies the hypotheses of
     termination_0_1_immediate, so it is lossless. *)
  interpret fuse: loop_spmf guard "λs. iter (N s) s" .
  
  have "1 = weight_spmf (fuse.while s)"
    by(rule lossless_weight_spmfD[symmetric])
      (rule fuse.termination_0_1_immediate; auto simp add: spmf_map vimage_def intro: p' p'_pos lossless_iter lossless)
  (* Each unrolling of the fused loop terminates with probability at most the
     supremum over the original unrollings (induction on n, using iter_bind_iter). *)
  also have "… ≤ (⨆n. measure (measure_spmf (iter n s)) ?X)"
    unfolding fuse.weight_while_conv_iter
  proof(rule cSUP_least)
    fix n
    have "emeasure (measure_spmf (fuse.iter n s)) ?X ≤ (SUP n. emeasure (measure_spmf (iter n s)) ?X)"
    proof(induction n arbitrary: s)
      case 0 show ?case by(auto intro!: SUP_upper2[where i=0])
    next
      case (Suc n)
      have inc: "incseq (λn s'. emeasure (measure_spmf (iter n s')) ?X)"
        by(rule incseq_SucI le_funI iter_mono_emeasure1)+
      have "emeasure (measure_spmf (fuse.iter (Suc n) s)) ?X = emeasure (measure_spmf (iter (N s) s ⤜ fuse.iter n)) ?X"
        by simp
      also have "… = ∫⇧+ s'. emeasure (measure_spmf (fuse.iter n s')) ?X ∂measure_spmf (iter (N s) s)"
        by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
      also have "… ≤ ∫⇧+ s'. (SUP n. emeasure (measure_spmf (iter n s')) ?X) ∂measure_spmf (iter (N s) s)"
        by(rule nn_integral_mono Suc.IH)+
      also have "… = (SUP n. ∫⇧+ s'. emeasure (measure_spmf (iter n s')) ?X ∂measure_spmf (iter (N s) s))"
        by(rule nn_integral_monotone_convergence_SUP[OF inc]) simp
      also have "… = (SUP n. emeasure (measure_spmf (bind_spmf (iter (N s) s) (iter n))) ?X)"
        by(simp add: measure_spmf_bind o_def emeasure_bind[where N="measure_spmf _"] space_measure_spmf Pi_def space_subprob_algebra)
      also have "… = (SUP n. emeasure (measure_spmf (iter (N s + n) s)) ?X)" by(simp add: iter_bind_iter)
      also have "… ≤ (SUP n. emeasure (measure_spmf (iter n s)) ?X)" by(rule SUP_mono) auto
      finally show ?case .
    qed
    also have "… = ennreal (SUP n. measure (measure_spmf (iter n s)) ?X)"
      by(subst ennreal_SUP)(fold measure_spmf.emeasure_eq_measure, auto simp add: not_less measure_spmf.subprob_emeasure_le_1 intro!: exI[where x="1"])
    also have "0 ≤ (SUP n. measure (measure_spmf (iter n s)) ?X)"
      by(rule cSUP_upper2)(auto intro!: bdd_aboveI[where M=1] simp add: measure_spmf.subprob_measure_le_1)
    ultimately show "measure (measure_spmf (fuse.iter n s)) ?X ≤ …"
      by(simp add: measure_spmf.emeasure_eq_measure)
  qed simp
  finally show  "1 ≤ weight_spmf (while s)" unfolding weight_while_conv_iter .
qed
end
(* Invariant version of termination_0_1_immediate: the hypotheses only need to
   hold on guarded states satisfying an invariant I that the body preserves, and
   I must hold for the start state. Proof: restrict the state space to {s. I s}
   with a local type definition (Types_To_Sets), transfer the unrestricted rule
   to that type, and cancel the type definition afterwards. *)
lemma termination_0_1_immediate_invar:
  fixes I :: "'s ⇒ bool"
  assumes p: "⋀s. ⟦ guard s; I s ⟧ ⟹ spmf (map_spmf guard (body s)) False ≥ p"
  and p_pos: "0 < p"
  and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
  and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
  and I: "I s"
  shows "lossless_spmf (loop_spmf.while guard body s)"
  including lifting_syntax
proof -
  { assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    (* cr relates a state to its representative in the restricted type 's'. *)
    define cr where "cr ≡ λx y. x = Rep y"
    have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp
    define guard' where "guard' ≡ (Rep ---> id) guard"
    have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
    (* body1 agrees with body on guarded states and is return_pmf None elsewhere;
       this keeps its results inside {s. I s} via invar. *)
    define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
    define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
    have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
      by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
    define s' where "s' ≡ Abs s"
    have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)
    have "⋀s. guard' s ⟹ p ≤ spmf (map_spmf guard' (body1' s)) False"
      by(transfer fixing: p)(simp add: body1_def p)
    moreover note p_pos
    moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)" by transfer(simp add: lossless body1_def)
    ultimately have "lossless_spmf (loop_spmf.while guard' body1' s')" by(rule loop_spmf.termination_0_1_immediate)
    hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
  (* Cancel the local type definition (I s shows the carrier is non-empty) and
     replace body1 by body using loop_spmf_while_cong. *)
  from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed
(* Invariant version of termination_0_1: the hypotheses only need to hold on
   guarded states satisfying an invariant I that the body preserves, and I must
   hold for the start state. Same local-type-definition and transfer technique as
   termination_0_1_immediate_invar. *)
lemma termination_0_1_invar:
  fixes I :: "'s ⇒ bool"
  assumes p: "⋀s. ⟦ guard s; I s ⟧ ⟹ p ≤ weight_spmf (loop_spmf.while guard body s)"
    and p_pos: "0 < p"
    and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
    and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
    and I: "I s"
  shows "lossless_spmf (loop_spmf.while guard body s)"
  including lifting_syntax
proof-
  { assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    define cr where "cr ≡ λx y. x = Rep y"
    have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp
    define guard' where "guard' ≡ (Rep ---> id) guard"
    have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
    define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
    define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
    have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
      by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
    define s' where "s' ≡ Abs s"
    have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)
    
    interpret loop_spmf guard' body1' .
    note UNIV_parametric_pred[transfer_rule]
    (* The termination-probability hypothesis is transferred through
       measure_measure_spmf, which has its own parametricity rule. *)
    have "⋀s. guard' s ⟹ p ≤ weight_spmf (while s)"
      unfolding measure_measure_spmf_def[symmetric] space_measure_spmf
      by(transfer fixing: p)(simp add: body1_def p[simplified space_measure_spmf] cong: loop_spmf_while_cong)
    moreover note p_pos
    moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)" by transfer(simp add: lossless body1_def)
    ultimately have "lossless_spmf (while s')" by(rule termination_0_1)
    hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
  (* Cancel the local type definition and replace body1 by body. *)
  from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed
subsection ‹Variant rule›
context loop_spmf begin
(* Variant rule: f is a natural-number variant bounded by `bound` on guarded
   states, each body step from a guarded state decreases f with probability at
   least p > 0, and the body is lossless on guarded states. Then the loop
   terminates with probability 1. *)
lemma termination_variant:
  fixes bound :: nat
  assumes bound: "⋀s. guard s ⟹ f s ≤ bound"
  and step: "⋀s. guard s ⟹ p ≤ spmf (map_spmf (λs'. f s' < f s) (body s)) True"
  and p_pos: "0 < p"
  and lossless: "⋀s. guard s ⟹ lossless_spmf (body s)"
  shows "lossless_spmf (while s)"
proof -
  (* p' = min p 1 is a decrease probability in (0, 1]; n = bound + 1 strictly
     exceeds f on every guarded state. *)
  define p' and n where "p' ≡ min p 1" and "n ≡ bound + 1"
  have p'_pos: "0 < p'" and p'_le_1: "p' ≤ 1" 
    and step': "guard s ⟹ p' ≤ measure (measure_spmf (body s)) {s'. f s' < f s}" for s
    using p_pos step[of s] by(simp_all add: p'_def spmf_map vimage_def)
  (* From any state with f s < n, the loop terminates with probability at least
     p' ^ n (induction on n). *)
  have "p' ^ n ≤ weight_spmf (while s)" if "f s < n" for s using that
  proof(induction n arbitrary: s)
    case 0 thus ?case by simp
  next
    case (Suc n)
    show ?case
    proof(cases "guard s")
      case False
      hence "weight_spmf (while s) = 1" by(simp add: while.simps)
      thus ?thesis using p'_le_1 p_pos 
        by simp(meson less_eq_real_def mult_le_one p'_pos power_le_one zero_le_power)
    next
      case True
      let ?M = "measure_spmf (body s)"
      have "p' ^ Suc n ≤ (∫ s'. indicator {s'. f s' < f s} s' ∂?M) * p' ^ n"
        using step'[OF True] p'_pos by(simp add: mult_right_mono)
      also have "… = (∫ s'. indicator {s'. f s' < f s} s' * p' ^ n ∂?M)" by simp
      also have "… ≤ (∫ s'. indicator {s'. f s' < f s} s' * weight_spmf (while s') ∂?M)"
        using Suc.prems p'_le_1 p'_pos
        by(intro integral_mono)(auto simp add: Suc.IH power_le_one weight_spmf_le_1 split: split_indicator intro!: measure_spmf.integrable_const_bound[where B=1])
      also have "… ≤ … + (∫ s'. indicator {s'. f s' ≥ f s} s' * weight_spmf (while s') ∂?M)"
        by(simp add: integral_nonneg_AE weight_spmf_nonneg)
      also have "… = ∫ s'. weight_spmf (while s') ∂?M"
        by(subst Bochner_Integration.integral_add[symmetric])
          (auto intro!: Bochner_Integration.integral_cong measure_spmf.integrable_const_bound[where B=1] weight_spmf_le_1 split: split_indicator)
      also have "… = weight_spmf (while s)"
        using True by(subst (1 2) while.simps)(simp add: weight_bind_spmf o_def)
      finally show ?thesis .
    qed
  qed
  moreover have "0 < p' ^ n" using p'_pos by simp
  (* Conclude with the invariant-based 0/1 law; the obligations below correspond
     to the invariant guard s ⟶ f s < n. *)
  ultimately show ?thesis using lossless
  proof(rule termination_0_1_invar)
    show "f s < n" if "guard s" "guard s ⟶ f s < n" for s using that by simp
    show "guard s ⟶ f s < n" using bound[of s] by(auto simp add: n_def)
    show "guard s' ⟶ f s' < n" for s' using bound[of s'] by(clarsimp simp add: n_def)
  qed
qed
end
(* Invariant version of termination_variant: bound, decrease and losslessness only
   need to hold on guarded states satisfying an invariant I that the body
   preserves, and I must hold for the start state. Same local-type-definition and
   transfer technique as termination_0_1_immediate_invar; the variant f is
   transferred as f'. *)
lemma termination_variant_invar:
  fixes bound :: nat and I :: "'s ⇒ bool"
  assumes bound: "⋀s. ⟦ guard s; I s ⟧ ⟹ f s ≤ bound"
  and step: "⋀s. ⟦ guard s; I s ⟧ ⟹ p ≤ spmf (map_spmf (λs'. f s' < f s) (body s)) True"
  and p_pos: "0 < p"
  and lossless: "⋀s. ⟦ guard s; I s ⟧ ⟹ lossless_spmf (body s)"
  and invar: "⋀s s'. ⟦ s' ∈ set_spmf (body s); I s; guard s ⟧ ⟹ I s'"
  and I: "I s"
  shows "lossless_spmf (loop_spmf.while guard body s)"
  including lifting_syntax
proof -
  { assume "∃(Rep :: 's' ⇒ 's) Abs. type_definition Rep Abs {s. I s}"
    then obtain Rep :: "'s' ⇒ 's" and Abs where td: "type_definition Rep Abs {s. I s}" by blast
    then interpret td: type_definition Rep Abs "{s. I s}" .
    define cr where "cr ≡ λx y. x = Rep y"
    have [transfer_rule]: "bi_unique cr" "right_total cr" using td cr_def by(rule typedef_bi_unique typedef_right_total)+
    have [transfer_domain_rule]: "Domainp cr = I" using type_definition_Domainp[OF td cr_def] by simp
    define guard' where "guard' ≡ (Rep ---> id) guard"
    have [transfer_rule]: "(cr ===> (=)) guard guard'" by(simp add: rel_fun_def cr_def guard'_def)
    define body1 where "body1 ≡ λs. if guard s then body s else return_pmf None"
    define body1' where "body1' ≡ (Rep ---> map_spmf Abs) body1"
    have [transfer_rule]: "(cr ===> rel_spmf cr) body1 body1'"
      by(auto simp add: rel_fun_def body1'_def body1_def cr_def spmf_rel_map td.Rep[simplified] invar td.Abs_inverse intro!: rel_spmf_reflI)
    define s' where "s' ≡ Abs s"
    have [transfer_rule]: "cr s s'" by(simp add: s'_def cr_def I td.Abs_inverse)
    define f' where "f' ≡ (Rep ---> id) f"
    have [transfer_rule]: "(cr ===> (=)) f f'" by(simp add: rel_fun_def cr_def f'_def)
    (* The calculation collects, in order, the four hypotheses of
       loop_spmf.termination_variant: bound, step, p_pos, lossless. *)
    have "⋀s. guard' s ⟹ f' s ≤ bound" by(transfer fixing: bound)(rule bound)
    moreover have "⋀s. guard' s ⟹ p ≤ spmf (map_spmf (λs'. f' s' < f' s) (body1' s)) True"
      by(transfer fixing: p)(simp add: step body1_def)
    note this p_pos
    moreover have "⋀s. guard' s ⟹ lossless_spmf (body1' s)"
      by transfer(simp add: body1_def lossless)
    ultimately have "lossless_spmf (loop_spmf.while guard' body1' s')" by(rule loop_spmf.termination_variant)
    hence "lossless_spmf (loop_spmf.while guard body1 s)" by transfer }
  (* Cancel the local type definition and replace body1 by body. *)
  from this[cancel_type_definition] I show ?thesis by(auto cong: loop_spmf_while_cong)
qed
end