# Theory Generative_Probabilistic_Value

```(* Title: Generative_Probabilistic_Value.thy
Author: Andreas Lochbihler, ETH Zurich *)

theory Generative_Probabilistic_Value imports
Resumption
Generat
"HOL-Types_To_Sets.Types_To_Sets"
begin

(* Hide the qualified-free name Done of an existing constant (presumably from an
   imported theory such as Resumption) so that this theory can define its own
   Done constructor function below. *)
hide_const (open) Done

subsection ‹Type definition›

(* bnf_internals keeps the BNF package's internal constructions accessible
   while the codatatype is being defined. *)
context notes [[bnf_internals]] begin

(* A generative probabilistic value: each step is a sub-probability
   distribution (spmf) over generat values; the continuations of an IO step
   map an input back to a gpv. results'_gpv / outs'_gpv collect the results
   and outputs that may occur. *)
codatatype (results'_gpv: 'a, outs'_gpv: 'out, 'in) gpv
= GPV (the_gpv: "('a, 'out, 'in ⇒ ('a, 'out, 'in) gpv) generat spmf")

end

(* Register rel_gpv (=) (=) = (=) with the relator_eq attribute. *)
declare gpv.rel_eq [relator_eq]

text ‹Reactive values are like generative values, except that they take an input first.›

(* An rpv is a function from the first input to a gpv. *)
type_synonym ('a, 'out, 'in) rpv = "'in ⇒ ('a, 'out, 'in) gpv"
(* Print a function type 'in ⇒ ('a, 'out, 'in) gpv as an rpv whenever the
   argument type coincides with the gpv's input type; otherwise fall through. *)
print_translation ― ‹pretty printing for @{typ "('a, 'out, 'in) rpv"}› ‹
let
fun tr' [in1, Const (@{type_syntax gpv}, _) $ a $ out $ in2] =
if in1 = in2 then Syntax.const @{type_syntax rpv} $ a $ out $ in1
else raise Match;
in [(@{type_syntax "fun"}, K tr')]
end
›
typ "('a, 'out, 'in) rpv"
text ‹
Effectively, @{typ "('a, 'out, 'in) gpv"} and @{typ "('a, 'out, 'in) rpv"} are mutually recursive.
›

(* Equality with a GPV constructor term, expressed via the selector. *)
lemma eq_GPV_iff: "f = GPV g ⟷ the_gpv f = g"
by(cases f) auto

(* Take the generated set equations out of the simpset; the [simp] equations
   results'_GPV / outs'_GPV below are stated with ⤜ instead. *)
declare gpv.set[simp del]

declare gpv.set_map[simp]

(* rel_gpv characterised by a witness gpv over pairs: its results/outputs are
   related pairs and its two projections are the related gpvs. *)
lemma rel_gpv_def':
"rel_gpv A B gpv gpv' ⟷
(∃gpv''. (∀(x, y) ∈ results'_gpv gpv''. A x y) ∧ (∀(x, y) ∈ outs'_gpv gpv''. B x y) ∧
map_gpv fst fst gpv'' = gpv ∧ map_gpv snd snd gpv'' = gpv')"
unfolding rel_gpv_def by(auto simp add: BNF_Def.Grp_def)

(* Results of an rpv: the results of rpv input, collected over all inputs. *)
definition results'_rpv :: "('a, 'out, 'in) rpv ⇒ 'a set"
where "results'_rpv rpv = range rpv ⤜ results'_gpv"

(* Outputs of an rpv: the outputs of rpv input, collected over all inputs. *)
definition outs'_rpv :: "('a, 'out, 'in) rpv ⇒ 'out set"
where "outs'_rpv rpv = range rpv ⤜ outs'_gpv"

(* Relator for rpvs: equal inputs must yield rel_gpv-related gpvs. *)
abbreviation rel_rpv
:: "('a ⇒ 'b ⇒ bool) ⇒ ('out ⇒ 'out' ⇒ bool)
⇒ ('in ⇒ ('a, 'out, 'in) gpv) ⇒ ('in ⇒ ('b, 'out', 'in) gpv) ⇒ bool"
where "rel_rpv A B ≡ rel_fun (=) (rel_gpv A B)"

(* Membership in results'_rpv: the result occurs for some input. Follows by
   unfolding the definition (set bind is a union over the range). *)
lemma in_results'_rpv [iff]: "x ∈ results'_rpv rpv ⟷ (∃input. x ∈ results'_gpv (rpv input))"
by(simp add: results'_rpv_def)

(* Membership in outs'_rpv: the output occurs for some input. *)
lemma in_outs_rpv [iff]: "out ∈ outs'_rpv rpv ⟷ (∃input. out ∈ outs'_gpv (rpv input))"
by(simp add: outs'_rpv_def)

(* Results of a GPV step: the direct Pure results of the step, plus the results
   of any continuation (for any input). *)
lemma results'_GPV [simp]:
"results'_gpv (GPV r) =
(set_spmf r ⤜ generat_pures) ∪
((set_spmf r ⤜ generat_conts) ⤜ results'_rpv)"
by(auto simp add: gpv.set bind_UNION set_spmf_def)

(* Outputs of a GPV step: the direct outputs of the step, plus the outputs of
   any continuation (for any input). *)
lemma outs'_GPV [simp]:
"outs'_gpv (GPV r) =
(set_spmf r ⤜ generat_outs) ∪
((set_spmf r ⤜ generat_conts) ⤜ outs'_rpv)"
by(auto simp add: gpv.set bind_UNION set_spmf_def)

(* Same as outs'_GPV, stated for an arbitrary gpv via the_gpv. *)
lemma outs'_gpv_unfold:
"outs'_gpv r =
(set_spmf (the_gpv r) ⤜ generat_outs) ∪
((set_spmf (the_gpv r) ⤜ generat_conts) ⤜ outs'_rpv)"
by(cases r) simp

(* Induction over the derivation of x ∈ outs'_gpv gpv: either x is output
   directly by a step (Out) or it is an output of some continuation (Cont). *)
lemma outs'_gpv_induct [consumes 1, case_names Out Cont, induct set: outs'_gpv]:
assumes x: "x ∈ outs'_gpv gpv"
and Out: "⋀generat gpv. ⟦ generat ∈ set_spmf (the_gpv gpv); x ∈ generat_outs generat ⟧ ⟹ P gpv"
and Cont: "⋀generat gpv c input.
⟦ generat ∈ set_spmf (the_gpv gpv); c ∈ generat_conts generat; x ∈ outs'_gpv (c input); P (c input) ⟧ ⟹ P gpv"
shows "P gpv"
using x
apply(induction y≡"x" gpv)
(* first case: direct output *)
apply(rule Out, simp add: in_set_spmf, simp)
(* second case: output reached through a continuation *)
apply(erule imageE, rule Cont, simp add: in_set_spmf, simp, simp, simp)
.

(* Case analysis on x ∈ outs'_gpv gpv with the same two cases. *)
lemma outs'_gpv_cases [consumes 1, case_names Out Cont, cases set: outs'_gpv]:
assumes "x ∈ outs'_gpv gpv"
obtains (Out) generat where "generat ∈ set_spmf (the_gpv gpv)" "x ∈ generat_outs generat"
| (Cont) generat c input where "generat ∈ set_spmf (the_gpv gpv)" "c ∈ generat_conts generat" "x ∈ outs'_gpv (c input)"
using assms by cases(auto simp add: in_set_spmf)

(* Introduction rules matching the two cases above. *)
lemma outs'_gpvI [intro?]:
shows outs'_gpv_Out: "⟦ generat ∈ set_spmf (the_gpv gpv); x ∈ generat_outs generat ⟧ ⟹ x ∈ outs'_gpv gpv"
and outs'_gpv_Cont: "⟦ generat ∈ set_spmf (the_gpv gpv); c ∈ generat_conts generat; x ∈ outs'_gpv (c input) ⟧ ⟹ x ∈ outs'_gpv gpv"
by(auto intro: gpv.set_sel simp add: in_set_spmf)

(* Induction over the derivation of x ∈ results'_gpv gpv, mirroring
   outs'_gpv_induct: either x is a Pure result of a step (Pure) or it is a
   result of some continuation (Cont). *)
lemma results'_gpv_induct [consumes 1, case_names Pure Cont, induct set: results'_gpv]:
assumes x: "x ∈ results'_gpv gpv"
and Pure: "⋀generat gpv. ⟦ generat ∈ set_spmf (the_gpv gpv); x ∈ generat_pures generat ⟧ ⟹ P gpv"
and Cont: "⋀generat gpv c input.
⟦ generat ∈ set_spmf (the_gpv gpv); c ∈ generat_conts generat; x ∈ results'_gpv (c input); P (c input) ⟧ ⟹ P gpv"
shows "P gpv"
using x
apply(induction y≡"x" gpv)
(* first case: direct Pure result (previously missing, leaving the goal open) *)
apply(rule Pure, simp add: in_set_spmf, simp)
(* second case: result reached through a continuation *)
apply(erule imageE, rule Cont, simp add: in_set_spmf, simp, simp, simp)
.

(* Case analysis on x ∈ results'_gpv gpv: Pure result of a step, or result of
   a continuation. *)
lemma results'_gpv_cases [consumes 1, case_names Pure Cont, cases set: results'_gpv]:
assumes "x ∈ results'_gpv gpv"
obtains (Pure) generat where "generat ∈ set_spmf (the_gpv gpv)" "x ∈ generat_pures generat"
| (Cont) generat c input where "generat ∈ set_spmf (the_gpv gpv)" "c ∈ generat_conts generat" "x ∈ results'_gpv (c input)"
using assms by cases(auto simp add: in_set_spmf)

(* Introduction rules matching the two cases above. *)
lemma results'_gpvI [intro?]:
shows results'_gpv_Pure: "⟦ generat ∈ set_spmf (the_gpv gpv); x ∈ generat_pures generat ⟧ ⟹ x ∈ results'_gpv gpv"
and results'_gpv_Cont: "⟦ generat ∈ set_spmf (the_gpv gpv); c ∈ generat_conts generat; x ∈ results'_gpv (c input) ⟧ ⟹ x ∈ results'_gpv gpv"
by(auto intro: gpv.set_sel simp add: in_set_spmf)

(* Transfer side conditions: rel_gpv inherits uniqueness and totality from its
   component relations. Each property is rewritten as an inequality between
   relators (via rel_conversep / rel_compp / rel_eq) and closed by rel_mono. *)
lemma left_unique_rel_gpv [transfer_rule]:
"⟦ left_unique A; left_unique B ⟧ ⟹ left_unique (rel_gpv A B)"
unfolding left_unique_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

lemma right_unique_rel_gpv [transfer_rule]:
"⟦ right_unique A; right_unique B ⟧ ⟹ right_unique (rel_gpv A B)"
unfolding right_unique_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

lemma bi_unique_rel_gpv [transfer_rule]:
"⟦ bi_unique A; bi_unique B ⟧ ⟹ bi_unique (rel_gpv A B)"
unfolding bi_unique_alt_def by(simp add: left_unique_rel_gpv right_unique_rel_gpv)

lemma left_total_rel_gpv [transfer_rule]:
"⟦ left_total A; left_total B ⟧ ⟹ left_total (rel_gpv A B)"
unfolding left_total_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

lemma right_total_rel_gpv [transfer_rule]:
"⟦ right_total A; right_total B ⟧ ⟹ right_total (rel_gpv A B)"
unfolding right_total_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

lemma bi_total_rel_gpv [transfer_rule]: "⟦ bi_total A; bi_total B ⟧ ⟹ bi_total (rel_gpv A B)"
unfolding bi_total_alt_def by(simp add: left_total_rel_gpv right_total_rel_gpv)

declare gpv.map_transfer[transfer_rule]

(* Push map_gpv into both branches of a conditional. *)
lemma if_distrib_map_gpv [if_distribs]:
"map_gpv f g (if b then gpv else gpv') = (if b then map_gpv f g gpv else map_gpv f g gpv')"
by simp

(* Strong monotonicity of pred_gpv: the implications only need to hold for the
   results / outputs that actually occur in x. *)
lemma gpv_pred_mono_strong:
"⟦ pred_gpv P Q x; ⋀a. ⟦ a ∈ results'_gpv x; P a ⟧ ⟹ P' a; ⋀b. ⟦ b ∈ outs'_gpv x; Q b ⟧ ⟹ Q' b ⟧ ⟹ pred_gpv P' Q' x"
(* NOTE(review): no proof follows this statement in this copy — TODO restore. *)

lemma pred_gpv_top [simp]:
"pred_gpv (λ_. True) (λ_. True) = (λ_. True)"
(* NOTE(review): proof missing here — TODO restore. *)

(* pred_gpv distributes over conjunction in either predicate argument. *)
lemma pred_gpv_conj [simp]:
shows pred_gpv_conj1: "⋀P Q R. pred_gpv (λx. P x ∧ Q x) R = (λx. pred_gpv P R x ∧ pred_gpv Q R x)"
and pred_gpv_conj2: "⋀P Q R. pred_gpv P (λx. Q x ∧ R x) = (λx. pred_gpv P Q x ∧ pred_gpv P R x)"
(* NOTE(review): proof missing here — TODO restore. *)

(* Restrict the relations of rel_gpv to elements satisfying the predicates. *)
lemma rel_gpv_restrict_relp1I [intro?]:
"⟦ rel_gpv R R' x y; pred_gpv P P' x; pred_gpv Q Q' y ⟧ ⟹ rel_gpv (R ↿ P ⊗ Q) (R' ↿ P' ⊗ Q') x y"
(* NOTE(review): proof missing here — TODO restore. *)

(* Converse of rel_gpv_restrict_relp1I. *)
lemma rel_gpv_restrict_relpE [elim?]:
assumes "rel_gpv (R ↿ P ⊗ Q) (R' ↿ P' ⊗ Q') x y"
obtains "rel_gpv R R' x y" "pred_gpv P P' x" "pred_gpv Q Q' y"
proof
show "rel_gpv R R' x y" using assms by(auto elim!: gpv.rel_mono_strong)
(* The left-hand side satisfies the domain of the restricted relations ... *)
have "pred_gpv (Domainp (R ↿ P ⊗ Q)) (Domainp (R' ↿ P' ⊗ Q')) x" using assms by(fold gpv.Domainp_rel) blast
then show "pred_gpv P P' x" by(rule gpv_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)+
(* ... and the right-hand side that of the converse relations. *)
have "pred_gpv (Domainp (R ↿ P ⊗ Q)¯¯) (Domainp (R' ↿ P' ⊗ Q')¯¯) y" using assms
by(fold gpv.Domainp_rel)(auto simp only: gpv.rel_conversep Domainp_conversep)
then show "pred_gpv Q Q' y" by(rule gpv_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

lemma gpv_pred_map [simp]: "pred_gpv P Q (map_gpv f g gpv) = pred_gpv (P ∘ f) (Q ∘ g) gpv"
(* NOTE(review): proof missing here — TODO restore. *)

subsection ‹Generalised mapper and relator›

(* Enables the ===> / ---> notation; the context is closed by an end further
   down, after map_gpv_parametric'. *)
context includes lifting_syntax begin

(* Generalised mapper: besides results (f) and outputs (g), it also maps the
   input type, contravariantly, using h : 'ret' ⇒ 'ret before each continuation. *)
primcorec map_gpv' :: "('a ⇒ 'b) ⇒ ('out ⇒ 'out') ⇒ ('ret' ⇒ 'ret) ⇒ ('a, 'out, 'ret) gpv ⇒ ('b, 'out', 'ret') gpv"
where
"map_gpv' f g h gpv =
GPV (map_spmf (map_generat f g ((∘) (map_gpv' f g h))) (map_spmf (map_generat id id (map_fun h id)) (the_gpv gpv)))"

(* Use the more readable selector equation map_gpv'_sel instead. *)
declare map_gpv'.sel [simp del]

lemma map_gpv'_sel [simp]:
"the_gpv (map_gpv' f g h gpv) = map_spmf (map_generat f g (h ---> map_gpv' f g h)) (the_gpv gpv)"
by(simp add: map_gpv'.sel spmf.map_comp o_def generat.map_comp map_fun_def[abs_def])

lemma map_gpv'_GPV [simp]:
"map_gpv' f g h (GPV p) = GPV (map_spmf (map_generat f g (h ---> map_gpv' f g h)) p)"
by(rule gpv.expand) simp

(* Functor laws: identity ... *)
lemma map_gpv'_id: "map_gpv' id id id = id"
apply(rule ext)
apply(coinduction)
apply(auto simp add: spmf_rel_map generat.rel_map rel_fun_def intro!: rel_spmf_reflI generat.rel_refl)
done

(* ... and composition (note the reversed order h' ∘ h on inputs). *)
lemma map_gpv'_comp: "map_gpv' f g h (map_gpv' f' g' h' gpv) = map_gpv' (f ∘ f') (g ∘ g') (h' ∘ h) gpv"
by(coinduction arbitrary: gpv)(auto simp add: spmf.map_comp spmf_rel_map generat.rel_map rel_fun_def intro!: rel_spmf_reflI generat.rel_refl)

functor gpv: map_gpv' by(simp_all add: map_gpv'_comp map_gpv'_id o_def)

(* map_gpv is map_gpv' with the identity on inputs. *)
lemma map_gpv_conv_map_gpv': "map_gpv f g = map_gpv' f g id"
apply(rule ext)
apply(coinduction)
apply(auto simp add: gpv.map_sel spmf_rel_map generat.rel_map rel_fun_def intro!: generat.rel_refl_strong rel_spmf_reflI)
done

(* Generalised relator: continuations need only map R-related inputs to
   related gpvs, so the two gpvs may have different input types. *)
coinductive rel_gpv'' :: "('a ⇒ 'b ⇒ bool) ⇒ ('out ⇒ 'out' ⇒ bool) ⇒ ('ret ⇒ 'ret' ⇒ bool) ⇒ ('a, 'out, 'ret) gpv ⇒ ('b, 'out', 'ret') gpv ⇒ bool"
for A C R
where
"rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) (the_gpv gpv) (the_gpv gpv')
⟹ rel_gpv'' A C R gpv gpv'"

(* Coinduction rule in which the continuations may be related either by the
   coinduction predicate X or by rel_gpv'' itself. *)
lemma rel_gpv''_coinduct [consumes 1, case_names rel_gpv'', coinduct pred: rel_gpv'']:
"⟦X gpv gpv';
⋀gpv gpv'. X gpv gpv'
⟹ rel_spmf (rel_generat A C (R ===> (λgpv gpv'. X gpv gpv' ∨ rel_gpv'' A C R gpv gpv')))
(the_gpv gpv) (the_gpv gpv') ⟧
⟹ rel_gpv'' A C R gpv gpv'"
by(erule rel_gpv''.coinduct) blast

(* Destruction rule: one unfolding step of rel_gpv''. *)
lemma rel_gpv''D:
"rel_gpv'' A C R gpv gpv'
⟹ rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) (the_gpv gpv) (the_gpv gpv')"
by(auto elim: rel_gpv''.cases)

(* Unfolding equation for rel_gpv'' on GPV constructors; subst is used (as in
   rel_gpv''_Done) so that only the outermost occurrence is unfolded. *)
lemma rel_gpv''_GPV [simp]:
"rel_gpv'' A C R (GPV p) (GPV q) ⟷
rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) p q"
by(subst rel_gpv''.simps) simp

(* rel_gpv is the special case of rel_gpv'' with equality on inputs. *)
lemma rel_gpv_conv_rel_gpv'': "rel_gpv A C = rel_gpv'' A C (=)"
proof(rule ext iffI)+
show "rel_gpv A C gpv gpv'" if "rel_gpv'' A C (=) gpv gpv'" for gpv :: "('a, 'b, 'c) gpv" and gpv' :: "('d, 'e, 'c) gpv"
using that by(coinduct)(blast dest: rel_gpv''D)
show "rel_gpv'' A C (=) gpv gpv'" if "rel_gpv A C gpv gpv'" for gpv :: "('a, 'b, 'c) gpv" and gpv' :: "('d, 'e, 'c) gpv"
using that by(coinduct)(auto elim!: gpv.rel_cases rel_spmf_mono generat.rel_mono_strong rel_fun_mono)
qed

(* rel_gpv'' on equalities is equality; derived from rel_gpv_conv_rel_gpv''
   and gpv.rel_eq. *)
lemma rel_gpv''_eq (* [relator_eq] do not use this attribute unless all transfer rules for gpv have been changed to rel_gpv'' *):
"rel_gpv'' (=) (=) (=) = (=)"
by(simp add: rel_gpv_conv_rel_gpv''[symmetric] gpv.rel_eq)

(* Monotonicity of rel_gpv'': covariant in A and C, contravariant in R. *)
lemma rel_gpv''_mono:
assumes "A ≤ A'" "C ≤ C'" "R' ≤ R"
shows "rel_gpv'' A C R ≤ rel_gpv'' A' C' R'"
proof
show "rel_gpv'' A' C' R' gpv gpv'" if "rel_gpv'' A C R gpv gpv'" for gpv gpv' using that
by(coinduct)(auto dest: rel_gpv''D elim!: rel_spmf_mono generat.rel_mono_strong rel_fun_mono intro: assms[THEN predicate2D])
qed

(* Converse of rel_gpv'' is rel_gpv'' of the converse relations. *)
lemma rel_gpv''_conversep: "rel_gpv'' A¯¯ C¯¯ R¯¯ = (rel_gpv'' A C R)¯¯"
proof(intro ext iffI; simp)
show "rel_gpv'' A C R gpv gpv'" if "rel_gpv'' A¯¯ C¯¯ R¯¯ gpv' gpv"
for A :: "'a1 ⇒ 'a2 ⇒ bool" and C :: "'c1 ⇒ 'c2 ⇒ bool" and R :: "'r1 ⇒ 'r2 ⇒ bool" and gpv gpv'
using that apply(coinduct)
apply(drule rel_gpv''D)
apply(rewrite in ⌑ conversep_iff[symmetric])
apply(subst spmf_rel_conversep[symmetric])
apply(erule rel_spmf_mono)
apply(subst generat.rel_conversep[symmetric])
apply(erule generat.rel_mono_strong)
done
(* The other direction is the first one instantiated with the converses. *)
from this[of "A¯¯" "C¯¯" "R¯¯"]
show "rel_gpv'' A¯¯ C¯¯ R¯¯ gpv' gpv" if "rel_gpv'' A C R gpv gpv'" for gpv gpv' using that by simp
qed

(* Positive distributivity of rel_gpv'' over relational composition. *)
lemma rel_gpv''_pos_distr:
"rel_gpv'' A C R OO rel_gpv'' A' C' R' ≤ rel_gpv'' (A OO A') (C OO C') (R OO R')"
proof(rule predicate2I; erule relcomppE)
show "rel_gpv'' (A OO A') (C OO C') (R OO R') gpv gpv''"
if "rel_gpv'' A C R gpv gpv'" "rel_gpv'' A' C' R' gpv' gpv''"
for gpv gpv' gpv'' using that
apply(coinduction arbitrary: gpv gpv' gpv'')
apply(drule rel_gpv''D)+
apply(drule (1) rel_spmf_pos_distr[THEN predicate2D, OF relcomppI])
apply(erule spmf_rel_mono_strong)
apply(subst (asm) generat.rel_compp[symmetric])
apply(erule generat.rel_mono_strong, assumption, assumption)
apply(drule pos_fun_distr[THEN predicate2D])
(* The continuations are now related by a composition of gpv relations; the
   intermediate gpv serves as the witness of the coinduction hypothesis. *)
apply(auto simp add: rel_fun_def)
done
qed

(* Uniqueness of rel_gpv''. Because inputs are contravariant, uniqueness of the
   result needs totality of the input relation R. Proof: express uniqueness via
   conversep/OO, then use positive distributivity and monotonicity. *)
lemma left_unique_rel_gpv'':
"⟦ left_unique A; left_unique C; left_total R ⟧ ⟹ left_unique (rel_gpv'' A C R)"
unfolding left_unique_alt_def left_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[OF rel_gpv''_pos_distr])
apply(erule (2) rel_gpv''_mono)
done

lemma right_unique_rel_gpv'':
"⟦ right_unique A; right_unique C; right_total R ⟧ ⟹ right_unique (rel_gpv'' A C R)"
unfolding right_unique_alt_def right_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[OF rel_gpv''_pos_distr])
apply(erule (2) rel_gpv''_mono)
done

lemma bi_unique_rel_gpv'' [transfer_rule]:
"⟦ bi_unique A; bi_unique C; bi_total R ⟧ ⟹ bi_unique (rel_gpv'' A C R)"
unfolding bi_unique_alt_def bi_total_alt_def by(blast intro: left_unique_rel_gpv'' right_unique_rel_gpv'')

(* Mapping the left gpv is the same as precomposing the relations A and C. *)
lemma rel_gpv''_map_gpv1:
"rel_gpv'' A C R (map_gpv f g gpv) gpv' = rel_gpv'' (λa. A (f a)) (λc. C (g c)) R gpv gpv'" (is "?lhs = ?rhs")
proof
show ?rhs if ?lhs using that
apply(coinduction arbitrary: gpv gpv')
apply(drule rel_gpv''D)
apply(erule rel_spmf_mono)
by(auto simp add: generat.rel_map rel_fun_comp elim!: generat.rel_mono_strong rel_fun_mono)
show ?lhs if ?rhs using that
apply(coinduction arbitrary: gpv gpv')
apply(drule rel_gpv''D)
apply(erule rel_spmf_mono)
by(auto simp add: generat.rel_map rel_fun_comp elim!: generat.rel_mono_strong rel_fun_mono)
qed

(* Same for the right gpv, obtained from the left version via conversep. *)
lemma rel_gpv''_map_gpv2:
"rel_gpv'' A C R gpv (map_gpv f g gpv') = rel_gpv'' (λa b. A a (f b)) (λc d. C c (g d)) R gpv gpv'"
using rel_gpv''_map_gpv1[of "conversep A" "conversep C" "conversep R" f g gpv' gpv]
apply(rewrite in "⌑ = _" conversep_iff[symmetric])
apply(rewrite in "_ = ⌑" conversep_iff[symmetric])
apply(simp only: rel_gpv''_conversep)
apply(simp only: rel_gpv''_conversep[symmetric])
apply(simp only: conversep_iff[abs_def])
done

lemmas rel_gpv''_map_gpv = rel_gpv''_map_gpv1[abs_def] rel_gpv''_map_gpv2

(* Simp rules pulling the result/output maps of map_gpv' into the relations;
   NO_MATCH stops them from firing once f and g are already id. *)
lemma rel_gpv''_map_gpv' [simp]:
shows "⋀f g h gpv. NO_MATCH id f ∨ NO_MATCH id g
⟹ rel_gpv'' A C R (map_gpv' f g h gpv) = rel_gpv'' (λa. A (f a)) (λc. C (g c)) R (map_gpv' id id h gpv)"
and "⋀f g h gpv gpv'. NO_MATCH id f ∨ NO_MATCH id g
⟹ rel_gpv'' A C R gpv (map_gpv' f g h gpv') = rel_gpv'' (λa b. A a (f b)) (λc d. C c (g d)) R gpv (map_gpv' id id h gpv')"
proof (goal_cases)
case (1 f g h gpv)
then show ?case using map_gpv'_comp[of f g id id id h gpv, symmetric] by(simp add: rel_gpv''_map_gpv[unfolded map_gpv_conv_map_gpv'])
next
case (2 f g h gpv gpv')
then show ?case using map_gpv'_comp[of f g id id id h gpv', symmetric] by(simp add: rel_gpv''_map_gpv[unfolded map_gpv_conv_map_gpv'])
qed

(* Specialisation to rel_gpv (equality on inputs). *)
lemmas rel_gpv_map_gpv' = rel_gpv''_map_gpv'[where R="(=)", folded rel_gpv_conv_rel_gpv'']

(* Corecursively builds, from two gpvs related by rel_gpv'' (A) (C) (R OO R'),
   a gpv over pairs that witnesses the relation; its input type is the middle
   type 'g of R OO R'. Used for negative distributivity below. *)
definition rel_witness_gpv :: "('a ⇒ 'd ⇒ bool) ⇒ ('b ⇒ 'e ⇒ bool) ⇒ ('c ⇒ 'g ⇒ bool) ⇒ ('g ⇒ 'f ⇒ bool) ⇒ ('a, 'b, 'c) gpv × ('d, 'e, 'f) gpv ⇒ ('a × 'd, 'b × 'e, 'g) gpv" where
"rel_witness_gpv A C R R' = corec_gpv (
map_spmf (map_generat id id (λ(rpv, rpv'). (Inr ∘ rel_witness_fun R R' (rpv, rpv'))) ∘ rel_witness_generat) ∘
rel_witness_spmf (rel_generat A C (rel_fun (R OO R') (rel_gpv'' A C (R OO R')))) ∘ map_prod the_gpv the_gpv)"

(* Selector equation for rel_witness_gpv. *)
lemma rel_witness_gpv_sel [simp]:
"the_gpv (rel_witness_gpv A C R R' (gpv, gpv')) =
map_spmf (map_generat id id (λ(rpv, rpv'). (rel_witness_gpv A C R R' ∘ rel_witness_fun R R' (rpv, rpv'))) ∘ rel_witness_generat)
(rel_witness_spmf (rel_generat A C (rel_fun (R OO R') (rel_gpv'' A C (R OO R')))) (the_gpv gpv, the_gpv gpv'))"
unfolding rel_witness_gpv_def
by(auto simp add: spmf.map_comp generat.map_comp o_def intro!: map_spmf_cong generat.map_cong)

(* The witness gpv is related to the left gpv (via R) and to the right gpv
   (via R'), provided R / R' satisfy the stated uniqueness/totality conditions. *)
lemma assumes "rel_gpv'' A C (R OO R') gpv gpv'"
and R: "left_unique R" "right_total R"
and R': "right_unique R'" "left_total R'"
shows rel_witness_gpv1: "rel_gpv'' (λa (a', b). a = a' ∧ A a' b) (λc (c', d). c = c' ∧ C c' d) R gpv (rel_witness_gpv A C R R' (gpv, gpv'))" (is "?thesis1")
and rel_witness_gpv2: "rel_gpv'' (λ(a, b') b. b = b' ∧ A a b') (λ(c, d') d. d = d' ∧ C c d') R' (rel_witness_gpv A C R R' (gpv, gpv')) gpv'" (is "?thesis2")
proof -
show ?thesis1 using assms(1)
proof(coinduction arbitrary: gpv gpv')
case rel_gpv''
from this[THEN rel_gpv''D] show ?case
by(auto simp add: spmf_rel_map generat.rel_map rel_fun_comp elim!: rel_fun_mono[OF rel_witness_fun1[OF _ R R']]
rel_spmf_mono[OF rel_witness_spmf1] generat.rel_mono[THEN predicate2D, rotated -1, OF rel_witness_generat1])
qed
show ?thesis2 using assms(1)
proof(coinduction arbitrary: gpv gpv')
case rel_gpv''
(* Push the maps out of rel_spmf first, then proceed as for ?thesis1. The
   initial method was missing, leaving a dangling method expression. *)
from this[THEN rel_gpv''D] show ?case
by(simp add: spmf_rel_map)
(erule rel_spmf_mono[OF rel_witness_spmf2]
, auto simp add: generat.rel_map rel_fun_comp elim!: rel_fun_mono[OF rel_witness_fun2[OF _ R R']]
generat.rel_mono[THEN predicate2D, rotated -1, OF rel_witness_generat2])
qed
qed

(* Negative distributivity of rel_gpv'' over relational composition, using
   rel_witness_gpv (mapped with relcompp_witness) as the intermediate gpv. *)
lemma rel_gpv''_neg_distr:
assumes R: "left_unique R" "right_total R"
and R': "right_unique R'" "left_total R'"
shows "rel_gpv'' (A OO A') (C OO C') (R OO R') ≤ rel_gpv'' A C R OO rel_gpv'' A' C' R'"
proof(rule predicate2I relcomppI)+
fix gpv gpv''
assume *: "rel_gpv'' (A OO A') (C OO C') (R OO R') gpv gpv''"
let ?gpv' = "map_gpv (relcompp_witness A A') (relcompp_witness C C') (rel_witness_gpv (A OO A') (C OO C') R R' (gpv, gpv''))"
show "rel_gpv'' A C R gpv ?gpv'" using rel_witness_gpv1[OF * R R'] unfolding rel_gpv''_map_gpv
by(rule rel_gpv''_mono[THEN predicate2D, rotated -1]; clarify del: relcomppE elim!: relcompp_witness)
show "rel_gpv'' A' C' R' ?gpv' gpv''" using rel_witness_gpv2[OF * R R'] unfolding rel_gpv''_map_gpv
by(rule rel_gpv''_mono[THEN predicate2D, rotated -1]; clarify del: relcomppE elim!: relcompp_witness)
qed

(* Monotonicity in the form required by the [mono] attribute. *)
lemma rel_gpv''_mono' [mono]:
assumes "⋀x y. A x y ⟶ A' x y"
and "⋀x y. C x y ⟶ C' x y"
and "⋀x y. R' x y ⟶ R x y"
shows "rel_gpv'' A C R gpv gpv' ⟶ rel_gpv'' A' C' R' gpv gpv'"
using rel_gpv''_mono[of A A' C C' R' R] assms by(blast)

(* Totality of rel_gpv''. After rewriting (=) as rel_gpv'' (=) (=) (=), the
   goal (=) ≤ rel_gpv'' … OO (its converse) is split by transitivity:
   first negative distributivity (which needs uniqueness/totality of the input
   relation and its converse), then monotonicity. *)
lemma left_total_rel_gpv':
"⟦ left_total A; left_total C; left_unique R; right_total R ⟧ ⟹ left_total (rel_gpv'' A C R)"
unfolding left_unique_alt_def left_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[rotated])
apply(rule rel_gpv''_neg_distr; simp add: left_unique_alt_def)
apply(rule rel_gpv''_mono; assumption)
done

lemma right_total_rel_gpv':
"⟦ right_total A; right_total C; right_unique R; left_total R ⟧ ⟹ right_total (rel_gpv'' A C R)"
unfolding right_unique_alt_def right_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[rotated])
apply(rule rel_gpv''_neg_distr; simp add: right_unique_alt_def)
apply(rule rel_gpv''_mono; assumption)
done

(* Bi-totality of rel_gpv'' from the left/right totality lemmas. *)
lemma bi_total_rel_gpv' [transfer_rule]:
"⟦ bi_total A; bi_total C; bi_unique R; bi_total R ⟧ ⟹ bi_total (rel_gpv'' A C R)"
unfolding bi_total_alt_def bi_unique_alt_def by(blast intro: left_total_rel_gpv' right_total_rel_gpv')

(* Function relator of a converse graph and a graph is again a graph. *)
lemma rel_fun_conversep_grp_grp:
"rel_fun (conversep (BNF_Def.Grp UNIV f)) (BNF_Def.Grp B g) = BNF_Def.Grp {x. (x ∘ f) ` UNIV ⊆ B} (map_fun f g)"
unfolding rel_fun_def Grp_def simp_thms fun_eq_iff conversep_iff by auto

(* Quotient theorem for gpv with the generalised relator. The input component
   is contravariant: Rep3 appears in the abstraction map and Abs3 in the
   representation map. *)
lemma Quotient_gpv:
assumes Q1: "Quotient R1 Abs1 Rep1 T1"
and Q2: "Quotient R2 Abs2 Rep2 T2"
and Q3: "Quotient R3 Abs3 Rep3 T3"
shows "Quotient (rel_gpv'' R1 R2 R3) (map_gpv' Abs1 Abs2 Rep3) (map_gpv' Rep1 Rep2 Abs3) (rel_gpv'' T1 T2 T3)"
(is "Quotient ?R ?abs ?rep ?T")
unfolding Quotient_alt_def2
proof(intro conjI strip iffI; (elim conjE exE)?)
note [simp] = spmf_rel_map generat.rel_map
and [elim!] = rel_spmf_mono generat.rel_mono_strong
and [rule del] = rel_funI and [intro!] = rel_funI
(* Component facts extracted from the three quotient assumptions. *)
have Abs1 [simp]: "Abs1 x = y" if "T1 x y" for x y using Q1 that by(simp add: Quotient_alt_def)
have Abs2 [simp]: "Abs2 x = y" if "T2 x y" for x y using Q2 that by(simp add: Quotient_alt_def)
have Abs3 [simp]: "Abs3 x = y" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def)
have Rep1: "T1 (Rep1 x) x" for x using Q1 by(simp add: Quotient_alt_def)
have Rep2: "T2 (Rep2 x) x" for x using Q2 by(simp add: Quotient_alt_def)
have Rep3: "T3 (Rep3 x) x" for x using Q3 by(simp add: Quotient_alt_def)
have T1: "T1 x (Abs1 y)" if "R1 x y" for x y using Q1 that by(simp add: Quotient_alt_def2)
have T2: "T2 x (Abs2 y)" if "R2 x y" for x y using Q2 that by(simp add: Quotient_alt_def2)
have T1': "T1 x (Abs1 y)" if "R1 y x" for x y using Q1 that by(simp add: Quotient_alt_def2)
have T2': "T2 x (Abs2 y)" if "R2 y x" for x y using Q2 that by(simp add: Quotient_alt_def2)
have R3: "R3 x (Rep3 y)" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def2 Abs3[OF Rep3])
have R3': "R3 (Rep3 y) x" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def2 Abs3[OF Rep3])
have r1: "R1 = T1 OO T1¯¯" using Q1 by(simp add: Quotient_alt_def4)
have r2: "R2 = T2 OO T2¯¯" using Q2 by(simp add: Quotient_alt_def4)
have r3: "R3 = T3 OO T3¯¯" using Q3 by(simp add: Quotient_alt_def4)
show abs: "?abs gpv = gpv'" if "?T gpv gpv'" for gpv gpv' using that
by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 4 intro: Rep3 dest: rel_funD)
show "?T (?rep gpv) gpv" for gpv
by(coinduction arbitrary: gpv)(auto simp add: Rep1 Rep2 intro!: rel_spmf_reflI generat.rel_refl_strong)
show "?T gpv (?abs gpv')" if "?R gpv gpv'" for gpv gpv' using that
by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 3 simp add: T1 T2 intro!: R3 dest: rel_funD)
show "?T gpv (?abs gpv')" if "?R gpv' gpv" for gpv gpv'
proof -
from that have "rel_gpv'' R1¯¯ R2¯¯ R3¯¯ gpv gpv'" unfolding rel_gpv''_conversep by simp
then show ?thesis
by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 3 simp add: T1' T2' intro!: R3' dest: rel_funD)
qed
(* Both gpvs have the same abstraction, so they are related by T OO T¯¯,
   which positive distributivity turns into ?R. *)
show "?R gpv gpv'" if "?T gpv (?abs gpv')" "?T gpv' (?abs gpv)" for gpv gpv'
proof -
from that[THEN abs] have "?abs gpv' = ?abs gpv" by simp
with that have "(?T OO ?T¯¯) gpv gpv'" by(auto simp del: rel_gpv''_map_gpv')
hence "rel_gpv'' (T1 OO T1¯¯) (T2 OO T2¯¯) (T3 OO T3¯¯) gpv gpv'"
unfolding rel_gpv''_conversep[symmetric]
by(rule rel_gpv''_pos_distr[THEN predicate2D])
thus ?thesis by(simp add: r1 r2 r3)
qed
qed

(* Parametricity of the selector, constructor and corecursor with respect to
   the generalised relator rel_gpv''. *)
lemma the_gpv_parametric':
"(rel_gpv'' A C R ===> rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R))) the_gpv the_gpv"
by(rule rel_funI)(auto elim: rel_gpv''.cases)

lemma GPV_parametric':
"(rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) ===> rel_gpv'' A C R) GPV GPV"
by(rule rel_funI)(auto)

lemma corec_gpv_parametric':
"((S ===> rel_spmf (rel_generat A C (R ===> rel_sum (rel_gpv'' A C R) S))) ===> S ===> rel_gpv'' A C R)
corec_gpv corec_gpv"
proof(rule rel_funI)+
fix f g s1 s2
assume fg: "(S ===> rel_spmf (rel_generat A C (R ===> rel_sum (rel_gpv'' A C R) S))) f g"
and s: "S s1 s2"
(* Coinduction over the S-related seeds; an Inl continuation is already a
   related gpv, an Inr one continues with S-related seeds. *)
from s show "rel_gpv'' A C R (corec_gpv f s1) (corec_gpv g s2)"
apply(coinduction arbitrary: s1 s2)
apply(drule fg[THEN rel_funD])
apply(erule rel_spmf_mono)
apply(erule generat.rel_mono_strong; clarsimp simp add: o_def)
apply(rule rel_funI)
apply(drule (1) rel_funD)
apply(auto 4 3 elim!: rel_sum.cases)
done
qed

(* map_gpv' is parametric; proved by the transfer prover from its definition. *)
lemma map_gpv'_parametric [transfer_rule]:
"((A ===> A') ===> (C ===> C') ===> (R' ===> R) ===> rel_gpv'' A C R ===> rel_gpv'' A' C' R') map_gpv' map_gpv'"
unfolding map_gpv'_def
supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
by(transfer_prover)

lemma map_gpv_parametric': "((A ===> A') ===> (C ===> C') ===> rel_gpv'' A C R ===> rel_gpv'' A' C' R) map_gpv map_gpv"
unfolding map_gpv_conv_map_gpv'[abs_def] by transfer_prover

(* closes the lifting_syntax context opened before map_gpv' *)
end

subsection ‹Simple, derived operations›

(* Terminate immediately with result a. *)
primcorec Done :: "'a ⇒ ('a, 'out, 'in) gpv"
where "the_gpv (Done a) = return_spmf (Pure a)"

(* Emit output out and continue with c applied to the next input. *)
primcorec Pause :: "'out ⇒ ('in ⇒ ('a, 'out, 'in) gpv) ⇒ ('a, 'out, 'in) gpv"
where "the_gpv (Pause out c) = return_spmf (IO out c)"

(* Embed a sub-probability distribution of results as a gpv without outputs. *)
primcorec lift_spmf :: "'a spmf ⇒ ('a, 'out, 'in) gpv"
where "the_gpv (lift_spmf p) = map_spmf Pure p"

(* The gpv whose step distribution is return_pmf None, i.e. it never yields
   a generat. *)
definition Fail :: "('a, 'out, 'in) gpv"
where "Fail = GPV (return_pmf None)"

(* Reactive value that, on each input, pauses with the output and continuation
   computed by f. *)
definition React :: "('in ⇒ 'out × ('a, 'out, 'in) rpv) ⇒ ('a, 'out, 'in) rpv"
where "React f input = case_prod Pause (f input)"

(* Reactive value that fails on every input. *)
definition rFail :: "('a, 'out, 'in) rpv"
where "rFail = (λ_. Fail)"

(* Injectivity and distinctness of the basic gpv constructors. *)
lemma Done_inject [simp]: "Done x = Done y ⟷ x = y"
(* NOTE(review): proof missing here — TODO restore. *)

lemma Pause_inject [simp]: "Pause out c = Pause out' c' ⟷ out = out' ∧ c = c'"
(* NOTE(review): proof missing here — TODO restore. *)

lemma [simp]:
shows Done_neq_Pause: "Done x ≠ Pause out c"
and Pause_neq_Done: "Pause out c ≠ Done x"
(* NOTE(review): proof missing here — TODO restore. *)

(* Results and outputs of Done, Pause and lift_spmf. *)
lemma outs'_gpv_Done [simp]: "outs'_gpv (Done x) = {}"
by(auto elim: outs'_gpv_cases)

lemma results'_gpv_Done [simp]: "results'_gpv (Done x) = {x}"
by(auto intro: results'_gpvI elim: results'_gpv_cases)

lemma pred_gpv_Done [simp]: "pred_gpv P Q (Done x) = P x"
(* NOTE(review): proof missing here — TODO restore. *)

lemma outs'_gpv_Pause [simp]: "outs'_gpv (Pause out c) = insert out (⋃input. outs'_gpv (c input))"
by(auto 4 4 intro: outs'_gpvI elim: outs'_gpv_cases)

lemma results'_gpv_Pause [simp]: "results'_gpv (Pause out rpv) = results'_rpv rpv"
by(auto 4 4 intro: results'_gpvI elim: results'_gpv_cases)

lemma pred_gpv_Pause [simp]: "pred_gpv P Q (Pause x c) = (Q x ∧ All (pred_gpv P Q ∘ c))"
(* NOTE(review): proof missing here — TODO restore. *)

lemma lift_spmf_return [simp]: "lift_spmf (return_spmf x) = Done x"
(* NOTE(review): proof missing here — TODO restore. *)

lemma lift_spmf_None [simp]: "lift_spmf (return_pmf None) = Fail"
(* NOTE(review): proof missing here — TODO restore. *)

lemma the_gpv_lift_spmf [simp]: "the_gpv (lift_spmf r) = map_spmf Pure r"
by(simp)

lemma outs'_gpv_lift_spmf [simp]: "outs'_gpv (lift_spmf p) = {}"
by(auto 4 3 elim: outs'_gpv_cases)

lemma results'_gpv_lift_spmf [simp]: "results'_gpv (lift_spmf p) = set_spmf p"
by(auto 4 3 elim: results'_gpv_cases intro: results'_gpvI)

lemma pred_gpv_lift_spmf [simp]: "pred_gpv P Q (lift_spmf p) = pred_spmf P p"
(* NOTE(review): proof missing here — TODO restore. *)

lemma lift_spmf_inject [simp]: "lift_spmf p = lift_spmf q ⟷ p = q"
by(auto simp add: lift_spmf.code dest!: pmf.inj_map_strong[rotated] option.inj_map_strong[rotated])

(* lift_spmf commutes with mapping the result. *)
lemma map_lift_spmf: "map_gpv f g (lift_spmf p) = lift_spmf (map_spmf f p)"
by(rule gpv.expand)(simp add: gpv.map_sel spmf.map_comp o_def)

lemma lift_map_spmf: "lift_spmf (map_spmf f p) = map_gpv f id (lift_spmf p)"
by(rule gpv.expand)(simp add: gpv.map_sel spmf.map_comp o_def)

lemma [simp]:
shows Fail_neq_Pause: "Fail ≠ Pause out c"
and Pause_neq_Fail: "Pause out c ≠ Fail"
and Fail_neq_Done: "Fail ≠ Done x"
and Done_neq_Fail: "Done x ≠ Fail"
(* NOTE(review): proof missing here — TODO restore. *)

text ‹Add a @{typ unit} closure to circumvent the SML value restriction in generated code.›

definition Fail' :: "unit ⇒ ('a, 'out, 'in) gpv"
where [code del]: "Fail' _ = Fail"

(* Code generation replaces Fail by the closure application Fail' (). *)
lemma Fail_code [code_unfold]: "Fail = Fail' ()"
by(simp add: Fail'_def)

lemma Fail'_code [code]:
"Fail' x = GPV (return_pmf None)"
by(simp add: Fail'_def Fail_def)

(* Selector and constructor characterisation of Fail, from Fail_def. *)
lemma Fail_sel [simp]:
"the_gpv Fail = return_pmf None"
by(simp add: Fail_def)

lemma Fail_eq_GPV_iff [simp]: "Fail = GPV f ⟷ f = return_pmf None"
by(auto simp add: Fail_def)

(* Fail has neither outputs nor results. *)
lemma outs'_gpv_Fail [simp]: "outs'_gpv Fail = {}"
by(auto elim: outs'_gpv_cases)

lemma results'_gpv_Fail [simp]: "results'_gpv Fail = {}"
by(auto elim: results'_gpv_cases)

lemma pred_gpv_Fail [simp]: "pred_gpv P Q Fail"
(* NOTE(review): proof missing here — TODO restore. *)

lemma React_inject [iff]: "React f = React f' ⟷ f = f'"
by(auto simp add: React_def fun_eq_iff split_def intro: prod.expand)

(* Applying React / rFail, by their definitions. *)
lemma React_apply [simp]: "f input = (out, c) ⟹ React f input = Pause out c"
by(simp add: React_def)

lemma rFail_apply [simp]: "rFail input = Fail"
by(simp add: rFail_def)

lemma [simp]:
shows rFail_neq_React: "rFail ≠ React f"
and React_neq_rFail: "React f ≠ rFail"
(* NOTE(review): proof missing here — TODO restore. *)

(* rel_gpv / rel_gpv'' on the basic constructors. *)
lemma rel_gpv_FailI [simp]: "rel_gpv A C Fail Fail"
by(subst gpv.rel_sel) simp

lemma rel_gpv_Done [iff]: "rel_gpv A C (Done x) (Done y) ⟷ A x y"
by(subst gpv.rel_sel) simp

lemma rel_gpv''_Done [iff]: "rel_gpv'' A C R (Done x) (Done y) ⟷ A x y"
by(subst rel_gpv''.simps) simp

lemma rel_gpv_Pause [iff]:
"rel_gpv A C (Pause out c) (Pause out' c') ⟷ C out out' ∧ (∀x. rel_gpv A C (c x) (c' x))"
(* NOTE(review): proof missing here — TODO restore. *)

lemma rel_gpv''_Pause [iff]:
"rel_gpv'' A C R (Pause out c) (Pause out' c') ⟷ C out out' ∧ (∀x x'. R x x' ⟶ rel_gpv'' A C R (c x) (c' x'))"
(* NOTE(review): proof missing here — TODO restore. *)

lemma rel_gpv_lift_spmf [iff]: "rel_gpv A C (lift_spmf p) (lift_spmf q) ⟷ rel_spmf A p q"
(* NOTE(review): proof missing here — TODO restore. *)

lemma rel_gpv''_lift_spmf [iff]:
"rel_gpv'' A C R (lift_spmf p) (lift_spmf q) ⟷ rel_spmf A p q"
(* NOTE(review): proof missing here — TODO restore. *)

(* Parametricity of the constructors, for rel_gpv (transfer rules) and for
   the generalised relator rel_gpv''. *)
context includes lifting_syntax begin
lemmas Fail_parametric [transfer_rule] = rel_gpv_FailI

lemma Fail_parametric' [simp]: "rel_gpv'' A C R Fail Fail"
unfolding Fail_def by simp

lemma Done_parametric [transfer_rule]: "(A ===> rel_gpv A C) Done Done"
by(rule rel_funI) simp

lemma Done_parametric': "(A ===> rel_gpv'' A C R) Done Done"
by(rule rel_funI) simp

lemma Pause_parametric [transfer_rule]:
"(C ===> ((=) ===> rel_gpv A C) ===> rel_gpv A C) Pause Pause"
(* NOTE(review): proof missing here — TODO restore. *)

lemma Pause_parametric':
"(C ===> (R ===> rel_gpv'' A C R) ===> rel_gpv'' A C R) Pause Pause"
(* NOTE(review): proof missing here — TODO restore. *)

lemma lift_spmf_parametric [transfer_rule]:
"(rel_spmf A ===> rel_gpv A C) lift_spmf lift_spmf"
(* NOTE(review): proof missing here — TODO restore. *)

lemma lift_spmf_parametric':
"(rel_spmf A ===> rel_gpv'' A C R) lift_spmf lift_spmf"
(* NOTE(review): proof missing here — TODO restore. *)
end

(* map_gpv / map_gpv' on the basic constructors; map_gpv' additionally
   precomposes each continuation with the input map h. *)
lemma map_gpv_Done [simp]: "map_gpv f g (Done x) = Done (f x)"
(* NOTE(review): proof missing here — TODO restore. *)

lemma map_gpv'_Done [simp]: "map_gpv' f g h (Done x) = Done (f x)"
(* NOTE(review): proof missing here — TODO restore. *)

lemma map_gpv_Pause [simp]: "map_gpv f g (Pause x c) = Pause (g x) (map_gpv f g ∘ c)"
(* NOTE(review): proof missing here — TODO restore. *)

lemma map_gpv'_Pause [simp]: "map_gpv' f g h (Pause x c) = Pause (g x) (map_gpv' f g h ∘ c ∘ h)"
(* NOTE(review): proof missing here — TODO restore. *)

lemma map_gpv_Fail [simp]: "map_gpv f g Fail = Fail"
(* NOTE(review): proof missing here — TODO restore. *)

lemma map_gpv'_Fail [simp]: "map_gpv' f g h Fail = Fail"
(* NOTE(review): proof missing here — TODO restore. *)

(* Monadic bind: run r; on a Pure result x continue as f x, on an IO step
   keep the output and bind each continuation. Inl marks a gpv that is taken
   as is, Inr a gpv on which the corecursion continues. *)
primcorec bind_gpv :: "('a, 'out, 'in) gpv ⇒ ('a ⇒ ('b, 'out, 'in) gpv) ⇒ ('b, 'out, 'in) gpv"
where
"the_gpv (bind_gpv r f) =
map_spmf (map_generat id id ((∘) (case_sum id (λr. bind_gpv r f))))
(the_gpv r ⤜
(case_generat
(λx. map_spmf (map_generat id id ((∘) Inl)) (the_gpv (f x)))
(λout c. return_spmf (IO out (λinput. Inr (c input))))))"

(* The selector equations bind_gpv_sel / bind_gpv_sel' below are used instead. *)
declare bind_gpv.sel [simp del]

(* Code equation for bind_gpv in do-notation. *)
lemma bind_gpv_unfold [code]:
"r ⤜ f = GPV (
do {
generat ← the_gpv r;
case generat of Pure x ⇒ the_gpv (f x)
| IO out c ⇒ return_spmf (IO out (λinput. c input ⤜ f))
})"
unfolding bind_gpv_def
apply(rule gpv.expand)
apply(rule arg_cong[where f="bind_spmf (the_gpv r)"])
apply(auto split: generat.split simp add: map_spmf_bind_spmf fun_eq_iff spmf.map_comp o_def generat.map_comp id_def[symmetric] generat.map_id pmf.map_id option.map_id)
done

lemma bind_gpv_code_cong: "f = f' ⟹ bind_gpv f g = bind_gpv f' g" by simp

(* Selector equation for bind_gpv with a case distinction on the generat. *)
lemma bind_gpv_sel:
"the_gpv (r ⤜ f) =
do {
generat ← the_gpv r;
case generat of Pure x ⇒ the_gpv (f x)
| IO out c ⇒ return_spmf (IO out (λinput. bind_gpv (c input) f))
}"
by(subst bind_gpv_unfold) simp

(* Same equation with discriminator/selectors instead of case; this is the
   form registered as [simp]. *)
lemma bind_gpv_sel' [simp]:
"the_gpv (r ⤜ f) =
do {
generat ← the_gpv r;
if is_Pure generat then the_gpv (f (result generat))
else return_spmf (IO (output generat) (λinput. bind_gpv (continuation generat input) f))
}"
unfolding bind_gpv_sel
by(rule arg_cong[where f="bind_spmf (the_gpv r)"])(simp add: fun_eq_iff split: generat.split)

(* Monad laws: left unit ... *)
lemma Done_bind_gpv [simp]: "Done a ⤜ f = f a"
by(rule gpv.expand)(simp)

(* ... and right unit, by coinduction. *)
lemma bind_gpv_Done [simp]: "f ⤜ Done = f"
proof(coinduction arbitrary: f rule: gpv.coinduct)
case (Eq_gpv f)
have *: "the_gpv f ⤜ (case_generat (λx. return_spmf (Pure x)) (λout c. return_spmf (IO out (λinput. Inr (c input))))) =
map_spmf (map_generat id id ((∘) Inr)) (bind_spmf (the_gpv f) return_spmf)"
unfolding map_spmf_bind_spmf
by(rule arg_cong2[where f=bind_spmf])(auto simp add: fun_eq_iff split: generat.split)
show ?case
by(auto simp add: * bind_gpv.simps pmf.rel_map option.rel_map[abs_def] generat.rel_map[abs_def] simp del: bind_gpv_sel' intro!: rel_generatI rel_spmf_reflI)
qed

(* Pull a conditional in the continuation of bind_gpv outwards. *)
lemma if_distrib_bind_gpv2 [if_distribs]:
"bind_gpv gpv (λy. if b then f y else g y) = (if b then bind_gpv gpv f else bind_gpv gpv g)"
by simp

(* Interaction of bind_gpv with lift_spmf and bind_spmf. *)
lemma lift_spmf_bind: "lift_spmf r ⤜ f = GPV (r ⤜ the_gpv ∘ f)"
by(coinduction arbitrary: r f rule: gpv.coinduct_strong)(auto simp add: bind_map_spmf o_def intro: rel_pmf_reflI rel_optionI rel_generatI)

lemma the_gpv_bind_gpv_lift_spmf [simp]:
"the_gpv (bind_gpv (lift_spmf p) f) = bind_spmf p (the_gpv ∘ f)"
(* NOTE(review): proof missing here — TODO restore. *)

lemma lift_spmf_bind_spmf: "lift_spmf (p ⤜ f) = lift_spmf p ⤜ (λx. lift_spmf (f x))"
by(rule gpv.expand)(simp add: lift_spmf_bind o_def map_spmf_bind_spmf)

lemma lift_bind_spmf: "lift_spmf (bind_spmf p f) = bind_gpv (lift_spmf p) (lift_spmf ∘ f)"
by(rule gpv.expand)(simp add: bind_map_spmf map_spmf_bind_spmf o_def)

(* bind_gpv on an explicit GPV constructor, in case and in if form. *)
lemma GPV_bind:
"GPV f ⤜ g =
GPV (f ⤜ (λgenerat. case generat of Pure x ⇒ the_gpv (g x) | IO out c ⇒ return_spmf (IO out (λinput. c input ⤜ g))))"
by(subst bind_gpv_unfold) simp

lemma GPV_bind':
"GPV f ⤜ g = GPV (f ⤜ (λgenerat. if is_Pure generat then the_gpv (g (result generat)) else return_spmf (IO (output generat) (λinput. continuation generat input ⤜ g))))"
unfolding GPV_bind gpv.inject
by(rule arg_cong[where f="bind_spmf f"])(simp add: fun_eq_iff split: generat.split)

(* Monad law: associativity of bind_gpv, by strong coinduction. *)
lemma bind_gpv_assoc:
fixes f :: "('a, 'out, 'in) gpv"
shows "(f ⤜ g) ⤜ h = f ⤜ (λx. g x ⤜ h)"
proof(coinduction arbitrary: f g h rule: gpv.coinduct_strong)
case (Eq_gpv f g h)
show ?case
apply(simp cong del: if_weak_cong)
apply(rule rel_spmf_bindI[where R="(=)"])
apply(fastforce intro: rel_pmf_return_pmfI rel_generatI rel_spmf_reflI)
done
qed

(* map_gpv distributes over bind_gpv: outputs are mapped in the first gpv,
   results and outputs in the continuation. *)
lemma map_gpv_bind_gpv: "map_gpv f g (bind_gpv gpv h) = bind_gpv (map_gpv id g gpv) (λx. map_gpv f g (h x))"
apply(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
apply(simp add: bind_gpv.sel gpv.map_sel spmf_rel_map generat.rel_map o_def bind_map_spmf del: bind_gpv_sel')
apply(rule rel_spmf_bind_reflI)
apply(auto simp add: spmf_rel_map generat.rel_map split: generat.split del: rel_funI intro!: rel_spmf_reflI generat.rel_refl rel_funI)
done

lemma map_gpv_id_bind_gpv: "map_gpv f id (bind_gpv gpv g) = bind_gpv gpv (map_gpv f id ∘ g)"
(* NOTE(review): proof missing here — TODO restore. *)

(* Mapping the results only is binding with Done ∘ f. *)
lemma map_gpv_conv_bind:
"map_gpv f (λx. x) x = bind_gpv x (λx. Done (f x))"
using map_gpv_bind_gpv[of f "λx. x" x Done] by(simp add: id_def[symmetric] gpv.map_id)

lemma bind_map_gpv: "bind_gpv (map_gpv f id gpv) g = bind_gpv gpv (g ∘ f)"
by(simp add: map_gpv_conv_bind id_def bind_gpv_assoc o_def)

(* The outputs of x ⤜ f are the outputs of x together with the outputs of f y for every
   possible result y of x. *)
lemma outs_bind_gpv:
"outs'_gpv (bind_gpv x f) = outs'_gpv x ∪ (⋃x ∈ results'_gpv x. outs'_gpv (f x))"
(is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
(* ⊆: induction over out ∈ outs'_gpv (x ⤜ f), generalising x.  In each case, split on
   whether the generat of x that produced the step is Pure, i.e. already in some f y,
   or IO, i.e. still inside x. *)
fix out
assume "out ∈ ?lhs"
then show "out ∈ ?rhs"
proof(induction g≡"x ⤜ f" arbitrary: x)
case (Out generat)
then obtain generat' where *: "generat' ∈ set_spmf (the_gpv x)"
and **: "generat ∈ set_spmf (if is_Pure generat' then the_gpv (f (result generat'))
else return_spmf (IO (output generat') (λinput. continuation generat' input ⤜ f)))"
by(auto)
show ?case
proof(cases "is_Pure generat'")
case True
then have "out ∈ outs'_gpv (f (result generat'))" using Out(2) ** by(auto intro: outs'_gpvI)
moreover have "result generat' ∈ results'_gpv x" using * True
by(auto intro: results'_gpvI generat.set_sel)
ultimately show ?thesis by blast
next
case False
hence "out ∈ outs'_gpv x" using * ** Out(2) by(auto intro: outs'_gpvI generat.set_sel)
thus ?thesis by blast
qed
next
case (Cont generat c input)
then obtain generat' where *: "generat' ∈ set_spmf (the_gpv x)"
and **: "generat ∈ set_spmf (if is_Pure generat' then the_gpv (f (generat.result generat'))
else return_spmf (IO (generat.output generat') (λinput. continuation generat' input ⤜ f)))"
by(auto)
show ?case
proof(cases "is_Pure generat'")
case True
then have "out ∈ outs'_gpv (f (result generat'))" using Cont(2-3) ** by(auto intro: outs'_gpvI)
moreover have "result generat' ∈ results'_gpv x" using * True
by(auto intro: results'_gpvI generat.set_sel)
ultimately show ?thesis by blast
next
case False
then have generat: "generat = IO (output generat') (λinput. continuation generat' input ⤜ f)"
using ** by simp
with Cont(2) have "c input = continuation generat' input ⤜ f" by auto
(* induction hypothesis for the continuation of x *)
hence "out ∈ outs'_gpv (continuation generat' input) ∪ (⋃x∈results'_gpv (continuation generat' input). outs'_gpv (f x))"
by(rule Cont)
thus ?thesis
proof
assume "out ∈ outs'_gpv (continuation generat' input)"
with * ** False have "out ∈ outs'_gpv x" by(auto intro: outs'_gpvI generat.set_sel)
thus ?thesis ..
next
assume "out ∈ (⋃x∈results'_gpv (continuation generat' input). outs'_gpv (f x))"
then obtain y where "y ∈ results'_gpv (continuation generat' input)" "out ∈ outs'_gpv (f y)" ..
from ‹y ∈ _› * ** False have "y ∈ results'_gpv x"
by(auto intro: results'_gpvI generat.set_sel)
with ‹out ∈ outs'_gpv (f y)› show ?thesis by blast
qed
qed
qed
next
(* ⊇: an output of x is found by induction on out ∈ outs'_gpv x.  An output of some f y
   is found by induction on y ∈ results'_gpv x. *)
fix out
assume "out ∈ ?rhs"
then show "out ∈ ?lhs"
proof
assume "out ∈ outs'_gpv x"
thus ?thesis
proof(induction)
case (Out generat gpv)
then show ?case
by(cases generat)(fastforce intro: outs'_gpvI rev_bexI)+
next
case (Cont generat gpv gpv')
then show ?case
by(cases generat)(auto 4 4 intro: outs'_gpvI rev_bexI simp add: in_set_spmf set_pmf_bind_spmf simp del: set_bind_spmf)
qed
next
assume "out ∈ (⋃x∈results'_gpv x. outs'_gpv (f x))"
then obtain y where "y ∈ results'_gpv x" "out ∈ outs'_gpv (f y)" ..
from ‹y ∈ _› show ?thesis
proof(induction)
case (Pure generat gpv)
thus ?case using ‹out ∈ outs'_gpv _›
by(cases generat)(auto 4 5 intro: outs'_gpvI rev_bexI elim: outs'_gpv_cases)
next
case (Cont generat gpv gpv')
thus ?case
by(cases generat)(auto 4 4 simp add: in_set_spmf simp add: set_pmf_bind_spmf intro: outs'_gpvI rev_bexI simp del: set_bind_spmf)
qed
qed
qed

(* Failure is a left zero of bind: binding Fail with any continuation fails. *)
lemma bind_gpv_Fail [simp]: "Fail ⤜ f = Fail"
by(rule gpv.expand) simp

(* A bind is Fail iff two conditions hold: every generat in the support of the first
   computation is Pure, and the continuation fails on every possible result of it. *)
lemma bind_gpv_eq_Fail:
"bind_gpv gpv f = Fail ⟷ (∀x∈set_spmf (the_gpv gpv). is_Pure x) ∧ (∀x∈results'_gpv gpv. f x = Fail)"
(is "?lhs = ?rhs")
proof(intro iffI conjI strip)
show ?lhs if ?rhs using that
by(intro gpv.expand)(auto 4 4 simp add: bind_eq_return_pmf_None intro: results'_gpv_Pure generat.set_sel dest: bspec)

assume ?lhs
hence *: "the_gpv (bind_gpv gpv f) = return_pmf None" by simp
(* An IO generat in the support would contribute return_spmf (IO ...) to the bind.  That
   is not return_pmf None, so every generat in the support must be Pure. *)
from * show "is_Pure x" if "x ∈ set_spmf (the_gpv gpv)" for x using that
by(cases x)(auto 4 3 simp add: bind_eq_return_pmf_None dest: bspec)
show "f x = Fail" if "x ∈ results'_gpv gpv" for x using that *
by(cases)(auto 4 3 simp add: bind_eq_return_pmf_None elim!: generat.set_cases intro: gpv.expand dest: bspec)
qed

context includes lifting_syntax begin

(* Parametricity of bind_gpv with respect to rel_gpv, derived from the definition by
   transfer_prover. *)
lemma bind_gpv_parametric [transfer_rule]:
"(rel_gpv A C ===> (A ===> rel_gpv B C) ===> rel_gpv B C) bind_gpv bind_gpv"
unfolding bind_gpv_def by transfer_prover

(* The same statement for the generalised relator rel_gpv'' with an extra relation R
   (presumably on inputs; confirm against rel_gpv'' definition).  It is not declared
   [transfer_rule]; the rel_gpv'' rules for corec_gpv and the_gpv are supplied locally. *)
lemma bind_gpv_parametric':
"(rel_gpv'' A C R ===> (A ===> rel_gpv'' B C R) ===> rel_gpv'' B C R) bind_gpv bind_gpv"
unfolding bind_gpv_def supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
by(transfer_prover)

end

(* NOTE(review): this proof has no goal in front of it; the previous command is the `end`
   of a context.  The statement it proves seems to have been lost; given unfold_locales it
   was presumably a locale interpretation or locale witness.  TODO: restore it.  An
   isolated `by` outside proof mode stops the theory from processing. *)
by unfold_locales auto

(* Introduction rule for rel_gpv on binds: related first computations and pointwise
   related continuations give related binds.  Read off bind_gpv_parametric. *)
lemma rel_gpv_bindI:
"⟦ rel_gpv A C gpv gpv'; ⋀x y. A x y ⟹ rel_gpv B C (f x) (g y) ⟧
⟹ rel_gpv B C (bind_gpv gpv f) (bind_gpv gpv' g)"
by(fact bind_gpv_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI])

(* Congruence rule for bind_gpv.  The continuations only need to agree on the possible
   results of the first computation.  It is proved via rel_gpv_bindI with the relation
   eq_onp restricted to those results. *)
lemma bind_gpv_cong:
"⟦ gpv = gpv'; ⋀x. x ∈ results'_gpv gpv' ⟹ f x = g x ⟧ ⟹ bind_gpv gpv f = bind_gpv gpv' g"
apply(subst gpv.rel_eq[symmetric])
apply(rule rel_gpv_bindI[where A="eq_onp (λx. x ∈ results'_gpv gpv')"])
apply(subst (asm) gpv.rel_eq[symmetric])
apply(erule gpv.rel_mono_strong)
apply simp
done

(* Bind for reactive values: for each input, bind the gpv returned for that input with f.
   NOTE(review): the type variables are written ('a, 'in, 'out) here, while the rpv type
   synonym is declared as ('a, 'out, 'in).  This only affects naming. *)
definition bind_rpv :: "('a, 'in, 'out) rpv ⇒ ('a ⇒ ('b, 'in, 'out) gpv) ⇒ ('b, 'in, 'out) rpv"
where "bind_rpv rpv f = (λinput. bind_gpv (rpv input) f)"

(* Applying a bound reactive value to an input binds the gpv for that input. *)
lemma bind_rpv_apply [simp]: "bind_rpv rpv f input = bind_gpv (rpv input) f"
by(simp add: bind_rpv_def)

(* Congruence rule for bind_rpv that only rewrites the reactive value.  Judging by the
   name it is intended for code generation; confirm. *)
lemma bind_rpv_code_cong: "rpv = rpv' ⟹ bind_rpv rpv f = bind_rpv rpv' f" by simp

(* Done is a right unit for bind_rpv. *)
lemma bind_rpv_rDone [simp]: "bind_rpv rpv Done = rpv"
by(simp add: bind_rpv_def)

(* Binding a Pause binds its reactive continuation. *)
lemma bind_gpv_Pause [simp]: "bind_gpv (Pause out rpv) f = Pause out (bind_rpv rpv f)"
by(rule gpv.expand)(simp add: fun_eq_iff)

(* Binding a React reactive value maps bind_rpv over the second components of the pairs
   returned by f. *)
lemma bind_rpv_React [simp]: "bind_rpv (React f) g = React (apsnd (λrpv. bind_rpv rpv g) ∘ f)"
by(simp add: React_def split_beta fun_eq_iff)

(* Associativity of bind_rpv, reduced pointwise to bind_gpv_assoc. *)
lemma bind_rpv_assoc: "bind_rpv (bind_rpv rpv f) g = bind_rpv rpv ((λgpv. bind_gpv gpv g) ∘ f)"
by(simp add: fun_eq_iff bind_gpv_assoc o_def)

(* Done, used as a reactive value, is a left unit for bind_rpv. *)
lemma bind_rpv_Done [simp]: "bind_rpv Done f = f"
by(simp add: bind_rpv_def)

(* Done, used as a reactive value, can return any input as its result. *)
lemma results'_rpv_Done [simp]: "results'_rpv Done = UNIV"
by auto

subsection ‹ Embedding @{typ "'a spmf"} as a monad ›

(* Distributing rel_fun over relation composition.  If R is left-unique and right-total
   and S is right-unique and left-total, then (R OO R' ===> S OO S') is contained in
   (R ===> S) OO (R' ===> S').  The witnessing function comes from choice. *)
lemma neg_fun_distr3:
includes lifting_syntax
assumes 1: "left_unique R" "right_total R"
assumes 2: "right_unique S" "left_total S"
shows "(R OO R' ===> S OO S') ≤ ((R ===> S) OO (R' ===> S'))"
using functional_relation[OF 2] functional_converse_relation[OF 1]
unfolding rel_fun_def OO_def
apply clarify
apply (subst all_comm)
apply (subst all_conj_distrib[symmetric])
apply (intro choice)
by metis

(* Transfer setup for embedding 'a spmf into ('a, 'out, 'in) gpv via lift_spmf.  The
   locale has no parameters; its rules come into scope where it is interpreted (for
   example in option_to_gpv and option_le_gpv). *)
locale spmf_to_gpv begin

text ‹
The lifting package cannot handle free term variables in the merging of transfer rules,
so for the embedding we define a specialised relator ‹rel_gpv'›
which acts only on the returned values.
›

(* rel_gpv' relates the results by A and requires equal outputs. *)
definition rel_gpv' :: "('a ⇒ 'b ⇒ bool) ⇒ ('a, 'out, 'in) gpv ⇒ ('b, 'out, 'in) gpv ⇒ bool"
where "rel_gpv' A = rel_gpv A (=)"

(* Relator, uniqueness and totality properties of rel_gpv', each inherited from the
   corresponding property of rel_gpv. *)
lemma rel_gpv'_eq [relator_eq]: "rel_gpv' (=) = (=)"
unfolding rel_gpv'_def gpv.rel_eq ..

lemma rel_gpv'_mono [relator_mono]: "A ≤ B ⟹ rel_gpv' A ≤ rel_gpv' B"
unfolding rel_gpv'_def by(rule gpv.rel_mono; simp)

lemma rel_gpv'_distr [relator_distr]: "rel_gpv' A OO rel_gpv' B = rel_gpv' (A OO B)"
unfolding rel_gpv'_def by (metis OO_eq gpv.rel_compp)

lemma left_unique_rel_gpv' [transfer_rule]: "left_unique A ⟹ left_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: left_unique_rel_gpv left_unique_eq)

lemma right_unique_rel_gpv' [transfer_rule]: "right_unique A ⟹ right_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: right_unique_rel_gpv right_unique_eq)

lemma bi_unique_rel_gpv' [transfer_rule]: "bi_unique A ⟹ bi_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: bi_unique_rel_gpv bi_unique_eq)

lemma left_total_rel_gpv' [transfer_rule]: "left_total A ⟹ left_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: left_total_rel_gpv left_total_eq)

lemma right_total_rel_gpv' [transfer_rule]: "right_total A ⟹ right_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: right_total_rel_gpv right_total_eq)

lemma bi_total_rel_gpv' [transfer_rule]: "bi_total A ⟹ bi_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: bi_total_rel_gpv bi_total_eq)

text ‹
We cannot use ‹setup_lifting› because @{typ "('a, 'out, 'in) gpv"} contains
type variables which do not appear in @{typ "'a spmf"}.
›

(* Correspondence relation: gpv corresponds to p iff gpv is the lift of p. *)
definition cr_spmf_gpv :: "'a spmf ⇒ ('a, 'out, 'in) gpv ⇒ bool"
where "cr_spmf_gpv p gpv ⟷ gpv = lift_spmf p"

(* Inverse of lift_spmf, defined by definite description. *)
definition spmf_of_gpv :: "('a, 'out, 'in) gpv ⇒ 'a spmf"
where "spmf_of_gpv gpv = (THE p. gpv = lift_spmf p)"

lemma spmf_of_gpv_lift_spmf [simp]: "spmf_of_gpv (lift_spmf p) = p"
unfolding spmf_of_gpv_def by auto

(* Every element in the support of the right spmf is related to some element in the
   support of the left spmf. *)
lemma rel_spmf_setD2:
"⟦ rel_spmf A p q; y ∈ set_spmf q ⟧ ⟹ ∃x∈set_spmf p. A x y"
by(erule rel_spmfE) force

(* A gpv related to a lifted spmf is itself a lifted spmf, and the two spmfs are related;
   the lemma ending in 2 is the mirrored version. *)
lemma rel_gpv_lift_spmf1: "rel_gpv A B (lift_spmf p) gpv ⟷ (∃q. gpv = lift_spmf q ∧ rel_spmf A p q)"
apply(subst gpv.rel_sel)
apply safe
apply(rule exI[where x="map_spmf result (the_gpv gpv)"])
apply(rule conjI)
apply(rule gpv.expand)
apply(subst map_spmf_cong[OF refl, where g=id])
apply(drule (1) rel_spmf_setD2)
apply clarsimp
apply simp
apply(erule rel_spmf_mono)
apply clarsimp
done

lemma rel_gpv_lift_spmf2: "rel_gpv A B gpv (lift_spmf q) ⟷ (∃p. gpv = lift_spmf p ∧ rel_spmf A p q)"
by(subst gpv.rel_flip[symmetric])(simp add: rel_gpv_lift_spmf1 pmf.rel_flip option.rel_conversep)

(* Parametrised correspondence relation: cr_spmf_gpv composed with rel_gpv A (=). *)
definition pcr_spmf_gpv :: "('a ⇒ 'b ⇒ bool) ⇒ 'a spmf ⇒ ('b, 'out, 'in) gpv ⇒ bool"
where "pcr_spmf_gpv A = cr_spmf_gpv OO rel_gpv A (=)"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma pcr_cr_eq_spmf_gpv: "pcr_spmf_gpv (=) = cr_spmf_gpv"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma left_unique_cr_spmf_gpv: "left_unique cr_spmf_gpv"

lemma left_unique_pcr_spmf_gpv [transfer_rule]:
"left_unique A ⟹ left_unique (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro left_unique_OO left_unique_cr_spmf_gpv left_unique_rel_gpv left_unique_eq)

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma right_unique_cr_spmf_gpv: "right_unique cr_spmf_gpv"

lemma right_unique_pcr_spmf_gpv [transfer_rule]:
"right_unique A ⟹ right_unique (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro right_unique_OO right_unique_cr_spmf_gpv right_unique_rel_gpv right_unique_eq)

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma bi_unique_cr_spmf_gpv: "bi_unique cr_spmf_gpv"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma bi_unique_pcr_spmf_gpv [transfer_rule]: "bi_unique A ⟹ bi_unique (pcr_spmf_gpv A)"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma left_total_cr_spmf_gpv: "left_total cr_spmf_gpv"

lemma left_total_pcr_spmf_gpv [transfer_rule]: "left_total A ==> left_total (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro left_total_OO left_total_cr_spmf_gpv left_total_rel_gpv left_total_eq)

context includes lifting_syntax begin

(* Transfer rules relating spmf operations to their gpv counterparts.  Each primed lemma is
   stated for cr_spmf_gpv.  The unprimed lemma lifts it to pcr_spmf_gpv by composing it
   with a rel_gpv rule obtained by transfer_prover. *)
(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma return_spmf_gpv_transfer':
"((=) ===> cr_spmf_gpv) return_spmf Done"

lemma return_spmf_gpv_transfer [transfer_rule]:
"(A ===> pcr_spmf_gpv A) return_spmf Done"
unfolding pcr_spmf_gpv_def
apply(rewrite in "(⌑ ===> _) _ _" eq_OO[symmetric])
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
apply(rule return_spmf_gpv_transfer')
apply transfer_prover
done

(* NOTE(review): gpv.expand concludes an equation between gpvs, but this goal is a rel_fun
   statement.  The steps that unfold rel_fun/cr_spmf_gpv before it and the simplification
   after it appear to be missing; TODO restore. *)
lemma bind_spmf_gpv_transfer':
"(cr_spmf_gpv ===> ((=) ===> cr_spmf_gpv) ===> cr_spmf_gpv) bind_spmf bind_gpv"
apply(rule gpv.expand)
done

lemma bind_spmf_gpv_transfer [transfer_rule]:
"(pcr_spmf_gpv A ===> (A ===> pcr_spmf_gpv B) ===> pcr_spmf_gpv B) bind_spmf bind_gpv"
unfolding pcr_spmf_gpv_def
apply(rewrite in "(_ ===> (⌑ ===> _) ===> _) _ _" eq_OO[symmetric])
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule order.refl)
apply(rule fun_mono)
apply(rule neg_fun_distr3[OF left_unique_eq right_total_eq right_unique_cr_spmf_gpv left_total_cr_spmf_gpv])
apply(rule order.refl)
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule order.refl)
apply(rule pos_fun_distr)
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
apply(rule bind_spmf_gpv_transfer')
apply transfer_prover
done

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma lift_spmf_gpv_transfer':
"((=) ===> cr_spmf_gpv) (λx. x) lift_spmf"

lemma lift_spmf_gpv_transfer [transfer_rule]:
"(rel_spmf A ===> pcr_spmf_gpv A) (λx. x) lift_spmf"
unfolding pcr_spmf_gpv_def
apply(rewrite in "(⌑ ===> _) _ _" eq_OO[symmetric])
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
apply(rule lift_spmf_gpv_transfer')
apply transfer_prover
done

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma fail_spmf_gpv_transfer': "cr_spmf_gpv (return_pmf None) Fail"

lemma fail_spmf_gpv_transfer [transfer_rule]: "pcr_spmf_gpv A (return_pmf None) Fail"
unfolding pcr_spmf_gpv_def
apply(rule relcomppI)
apply(rule fail_spmf_gpv_transfer')
apply transfer_prover
done

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma map_spmf_gpv_transfer':
"((=) ===> R ===> cr_spmf_gpv ===> cr_spmf_gpv) (λf g. map_spmf f) map_gpv"

lemma map_spmf_gpv_transfer [transfer_rule]:
"((A ===> B) ===> R ===> pcr_spmf_gpv A ===> pcr_spmf_gpv B) (λf g. map_spmf f) map_gpv"
unfolding pcr_spmf_gpv_def
apply(rewrite in "((⌑ ===> _) ===> _)  _ _" eq_OO[symmetric])
apply(rewrite in "((_ ===> ⌑) ===> _)  _ _" eq_OO[symmetric])
apply(rewrite in "(_ ===> ⌑ ===> _)  _ _" OO_eq[symmetric])
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule neg_fun_distr3[OF left_unique_eq right_total_eq right_unique_eq left_total_eq])
apply(rule fun_mono[OF order.refl])
apply(rule pos_fun_distr)
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule order.refl)
apply(rule pos_fun_distr)
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
apply(unfold rel_fun_eq)
apply(rule map_spmf_gpv_transfer')
apply(unfold rel_fun_eq[symmetric])
apply transfer_prover
done

end

end

subsection ‹ Embedding @{typ "'a option"} as a monad ›

(* Transfer setup embedding 'a option into gpvs.  An option corresponds to the lift of
   return_pmf applied to it: Some x gives a terminating Pure x and None gives failure. *)
locale option_to_gpv begin

interpretation option_to_spmf .
interpretation spmf_to_gpv .

definition cr_option_gpv :: "'a option ⇒ ('a, 'out, 'in) gpv ⇒ bool"
where "cr_option_gpv x gpv ⟷ gpv = (lift_spmf ∘ return_pmf) x"

(* cr_option_gpv factors through the spmf embedding. *)
lemma cr_option_gpv_conv_OO:
"cr_option_gpv = cr_spmf_option¯¯ OO cr_spmf_gpv"
by(simp add: fun_eq_iff relcompp.simps cr_option_gpv_def cr_spmf_gpv_def cr_spmf_option_def)

context includes lifting_syntax begin

text ‹These transfer rules should follow from merging the transfer rules, but this has not yet been implemented.›

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma return_option_gpv_transfer [transfer_rule]:
"((=) ===> cr_option_gpv) Some Done"

(* NOTE(review): the apply step before `subgoal for x f g` appears to be missing.  Presumably
   it unfolds rel_fun_def and cr_option_gpv_def so that a subgoal with three parameters
   remains; TODO restore. *)
lemma bind_option_gpv_transfer [transfer_rule]:
"(cr_option_gpv ===> ((=) ===> cr_option_gpv) ===> cr_option_gpv) Option.bind bind_gpv"
subgoal for x f g by(cases x; simp)
done

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma fail_option_gpv_transfer [transfer_rule]: "cr_option_gpv None Fail"

lemma map_option_gpv_transfer [transfer_rule]:
"((=) ===> R ===> cr_option_gpv ===> cr_option_gpv) (λf g. map_option f) map_gpv"
unfolding rel_fun_eq by(simp add: rel_fun_def cr_option_gpv_def map_lift_spmf)

end

end

(* Variant of the option embedding in which None is related to every gpv:
   cr_option_le_gpv x gpv also holds whenever x = None. *)
locale option_le_gpv begin

interpretation option_le_spmf .
interpretation spmf_to_gpv .

definition cr_option_le_gpv :: "'a option ⇒ ('a, 'out, 'in) gpv ⇒ bool"
where "cr_option_le_gpv x gpv ⟷ gpv = (lift_spmf ∘ return_pmf) x ∨ x = None"

context includes lifting_syntax begin

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma return_option_le_gpv_transfer [transfer_rule]:
"((=) ===> cr_option_le_gpv) Some Done"

lemma bind_option_gpv_transfer [transfer_rule]:
"(cr_option_le_gpv ===> ((=) ===> cr_option_le_gpv) ===> cr_option_le_gpv) Option.bind bind_gpv"
apply(clarsimp simp add: cr_option_le_gpv_def rel_fun_def bind_eq_Some_conv)
subgoal for f g x y by(erule allE[where x=y]) auto
done

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma fail_option_gpv_transfer [transfer_rule]:
"cr_option_le_gpv None Fail"

lemma map_option_gpv_transfer [transfer_rule]:
"(((=) ===> (=)) ===> cr_option_le_gpv ===> cr_option_le_gpv) map_option (λf. map_gpv f id)"
unfolding rel_fun_eq by(simp add: rel_fun_def cr_option_le_gpv_def map_lift_spmf)

end

end

subsection ‹Embedding resumptions›

(* Embedding of resumptions into gpvs.  Done None becomes failure and Done (Some x')
   becomes a Pure result.  Pause out c becomes an IO step whose continuations are lifted
   corecursively. *)
primcorec lift_resumption :: "('a, 'out, 'in) resumption ⇒ ('a, 'out, 'in) gpv"
where
"the_gpv (lift_resumption r) =
(case r of resumption.Done None ⇒ return_pmf None
| resumption.Done (Some x') => return_spmf (Pure x')
| resumption.Pause out c => map_spmf (map_generat id id ((∘) lift_resumption)) (return_spmf (IO out c)))"

(* Selector equation for lift_resumption written with the resumption discriminators and
   selectors instead of a case expression. *)
lemma the_gpv_lift_resumption:
"the_gpv (lift_resumption r) =
(if is_Done r then if Option.is_none (resumption.result r) then return_pmf None else return_spmf (Pure (the (resumption.result r)))
else return_spmf (IO (resumption.output r) (lift_resumption ∘ resume r)))"
by(simp split: option.split resumption.split)

(* The primcorec equations are removed from the simpset; the constructor-specific
   equations in this block are used instead. *)
declare lift_resumption.simps [simp del]

lemma lift_resumption_Done [code]:
"lift_resumption (resumption.Done x) = (case x of None ⇒ Fail | Some x' ⇒ Done x')"
by(rule gpv.expand)(simp add: the_gpv_lift_resumption split: option.split)

(* NOTE(review): no proof follows this statement; TODO restore.  lift_resumption_Done_Some
   below relies on this lemma. *)
lemma lift_resumption_DONE [simp]:
"lift_resumption (DONE x) = Done x"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma lift_resumption_ABORT [simp]:
"lift_resumption ABORT = Fail"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma lift_resumption_Pause [simp, code]:
"lift_resumption (resumption.Pause out c) = Pause out (lift_resumption ∘ c)"

lemma lift_resumption_Done_Some [simp]: "lift_resumption (resumption.Done (Some x)) = Done x"
using lift_resumption_DONE unfolding DONE_def by simp

(* lift_resumption preserves the possible results. *)
lemma results'_gpv_lift_resumption [simp]:
"results'_gpv (lift_resumption r) = results r" (is "?lhs = ?rhs")
proof(rule set_eqI iffI)+
show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
by(induction gpv≡"lift_resumption r" arbitrary: r)
(auto intro: resumption.set_sel simp add: lift_resumption.sel split: resumption.split_asm option.split_asm)
show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that by induction(auto simp add: lift_resumption.sel)
qed

(* lift_resumption preserves the possible outputs. *)
lemma outs'_gpv_lift_resumption [simp]:
"outs'_gpv (lift_resumption r) = outputs r" (is "?lhs = ?rhs")
proof(rule set_eqI iffI)+
show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
by(induction gpv≡"lift_resumption r" arbitrary: r)
(auto simp add: lift_resumption.sel split: resumption.split_asm option.split_asm)
show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that by induction auto
qed

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma pred_gpv_lift_resumption [simp]:
"⋀A. pred_gpv A C (lift_resumption r) = pred_resumption A C r"

(* lift_resumption commutes with bind; proved by strong coinduction. *)
lemma lift_resumption_bind: "lift_resumption (r ⤜ f) = lift_resumption r ⤜ lift_resumption ∘ f"
by(coinduction arbitrary: r rule: gpv.coinduct_strong)
(auto simp add: lift_resumption.sel Done_bind split: resumption.split option.split del: rel_funI intro!: rel_funI)

subsection ‹Assertions›

(* assert_gpv b terminates with () if b holds and fails otherwise. *)
definition assert_gpv :: "bool ⇒ (unit, 'out, 'in) gpv"
where "assert_gpv b = (if b then Done () else Fail)"

(* Evaluation of assert_gpv on the two truth values. *)
lemma assert_gpv_simps [simp]:
"assert_gpv True = Done ()"
"assert_gpv False = Fail"
by(simp_all add: assert_gpv_def)

(* Equations between assert_gpv and the gpv constructors, reduced to the condition b. *)
lemma [simp]:
shows assert_gpv_eq_Done: "assert_gpv b = Done x ⟷ b"
and Done_eq_assert_gpv: "Done x = assert_gpv b ⟷ b"
and Pause_neq_assert_gpv: "Pause out rpv ≠ assert_gpv b"
and assert_gpv_neq_Pause: "assert_gpv b ≠ Pause out rpv"
and assert_gpv_eq_Fail: "assert_gpv b = Fail ⟷ ¬ b"
and Fail_eq_assert_gpv: "Fail = assert_gpv b ⟷ ¬ b"
by(simp_all add: assert_gpv_def)

(* assert_gpv is injective in its condition. *)
lemma assert_gpv_inject [simp]: "assert_gpv b = assert_gpv b' ⟷ b = b'"
by(simp add: assert_gpv_def)

(* The distribution of assert_gpv b is assert_spmf b with the result wrapped in Pure. *)
lemma assert_gpv_sel [simp]:
"the_gpv (assert_gpv b) = map_spmf Pure (assert_spmf b)"
by(cases b) simp_all

(* Binding after an assertion is assert_spmf followed by the continuation. *)
lemma the_gpv_bind_assert [simp]:
"the_gpv (bind_gpv (assert_gpv b) f) =
bind_spmf (assert_spmf b) (the_gpv ∘ f)"
by(cases b) simp_all

(* An assertion satisfies pred_gpv P Q iff P () holds whenever the assertion can succeed. *)
lemma pred_gpv_assert [simp]: "pred_gpv P Q (assert_gpv b) = (b ⟶ P ())"
by(cases b) simp_all

(* TRY gpv ELSE gpv'.  Where the first step of gpv fails, the first step of gpv' is taken
   instead, via try_spmf.  Continuations coming from gpv are tagged Inl and those from gpv'
   are tagged Inr.  Corecursion therefore continues with try_gpv only on continuations of
   gpv; continuations of gpv' are used as they are. *)
primcorec try_gpv :: "('a, 'call, 'ret) gpv ⇒ ('a, 'call, 'ret) gpv ⇒ ('a, 'call, 'ret) gpv" ("TRY _ ELSE _" [0,60] 59)
where
"the_gpv (TRY gpv ELSE gpv') =
map_spmf (map_generat id id (λc input. case c input of Inl gpv ⇒ try_gpv gpv gpv' | Inr gpv' ⇒ gpv'))
(try_spmf (map_spmf (map_generat id id (map_fun id Inl)) (the_gpv gpv))
(map_spmf (map_generat id id (map_fun id Inr)) (the_gpv gpv')))"

(* Selector equation without the Inl/Inr tagging. *)
lemma try_gpv_sel:
"the_gpv (TRY gpv ELSE gpv') =
TRY map_spmf (map_generat id id (λc input. TRY c input ELSE gpv')) (the_gpv gpv) ELSE the_gpv gpv'"
by(simp add: try_gpv_def map_try_spmf spmf.map_comp o_def generat.map_comp generat.map_ident id_def)

(* Behaviour of TRY on the basic constructors and on Fail as the fallback. *)
lemma try_gpv_Done [simp]: "TRY Done x ELSE gpv' = Done x"
by(rule gpv.expand)(simp)

lemma try_gpv_Fail [simp]: "TRY Fail ELSE gpv' = gpv'"
by(rule gpv.expand)(simp add: spmf.map_comp o_def generat.map_comp generat.map_ident)

lemma try_gpv_Pause [simp]: "TRY Pause out c ELSE gpv' = Pause out (λinput. TRY c input ELSE gpv')"
by(rule gpv.expand) simp

lemma try_gpv_Fail2 [simp]: "TRY gpv ELSE Fail = gpv"
by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
(auto simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI generat.rel_refl)

(* lift_spmf commutes with TRY. *)
lemma lift_try_spmf: "lift_spmf (TRY p ELSE q) = TRY lift_spmf p ELSE lift_spmf q"
by(rule gpv.expand)(simp add: map_try_spmf spmf.map_comp o_def)

lemma try_assert_gpv: "TRY assert_gpv b ELSE gpv' = (if b then Done () else gpv')"
by(simp)

context includes lifting_syntax begin
(* Parametricity of try_gpv, derived from its corecursive definition. *)
lemma try_gpv_parametric [transfer_rule]:
"(rel_gpv A C ===> rel_gpv A C ===> rel_gpv A C) try_gpv try_gpv"
unfolding try_gpv_def by transfer_prover

(* The same for rel_gpv''; the rel_gpv'' rules for corec_gpv and the_gpv are supplied
   locally. *)
lemma try_gpv_parametric':
"(rel_gpv'' A C R ===> rel_gpv'' A C R ===> rel_gpv'' A C R) try_gpv try_gpv"
unfolding try_gpv_def
supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
by transfer_prover
end

(* map_gpv distributes over TRY, obtained from try_gpv_parametric. *)
lemma map_try_gpv: "map_gpv f g (TRY gpv ELSE gpv') = TRY map_gpv f g gpv ELSE map_gpv f g gpv'"
by(simp add: gpv.rel_map try_gpv_parametric[THEN rel_funD, THEN rel_funD] gpv.rel_refl gpv.rel_eq[symmetric])

(* map_gpv' distributes over TRY; proved by strong coinduction. *)
lemma map'_try_gpv: "map_gpv' f g h (TRY gpv ELSE gpv') = TRY map_gpv' f g h gpv ELSE map_gpv' f g h gpv'"
by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)(auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI generat.rel_refl rel_funI rel_spmf_try_spmf)

lemma try_bind_assert_gpv:
"TRY (assert_gpv b ⤜ f) ELSE gpv = (if b then TRY (f ()) ELSE gpv else gpv)"
by(simp)

subsection ‹Order for @{typ "('a, 'out, 'in) gpv"}›

(* ord_gpv f g holds when the distributions of f and g are related by ord_spmf.  Results
   and outputs must be equal, and continuations must be pointwise related by ord_gpv
   again (coinductively). *)
coinductive ord_gpv :: "('a, 'out, 'in) gpv ⇒ ('a, 'out, 'in) gpv ⇒ bool"
where
"ord_spmf (rel_generat (=) (=) (rel_fun (=) ord_gpv)) f g ⟹ ord_gpv (GPV f) (GPV g)"

inductive_simps ord_gpv_simps [simp]:
"ord_gpv (GPV f) (GPV g)"

(* Coinduction rule for ord_gpv phrased with the_gpv instead of the GPV constructor. *)
lemma ord_gpv_coinduct [consumes 1, case_names ord_gpv, coinduct pred: ord_gpv]:
assumes "X f g"
and step: "⋀f g. X f g ⟹ ord_spmf (rel_generat (=) (=) (rel_fun (=) X)) (the_gpv f) (the_gpv g)"
shows "ord_gpv f g"
using ‹X f g›
by(coinduct)(auto dest: step simp add: eq_GPV_iff intro: ord_spmf_mono rel_generat_mono rel_fun_mono)

lemma ord_gpv_the_gpvD:
"ord_gpv f g ⟹ ord_spmf (rel_generat (=) (=) (rel_fun (=) ord_gpv)) (the_gpv f) (the_gpv g)"
by(erule ord_gpv.cases) simp

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma reflp_equality: "reflp (=)"

(* ord_gpv is reflexive, transitive and antisymmetric. *)
lemma ord_gpv_reflI [simp]: "ord_gpv f f"
by(coinduction arbitrary: f)(auto intro: ord_spmf_reflI simp add: rel_generat_same rel_fun_def)

lemma reflp_ord_gpv: "reflp ord_gpv"
by(rule reflpI)(rule ord_gpv_reflI)

lemma ord_gpv_trans:
assumes "ord_gpv f g" "ord_gpv g h"
shows "ord_gpv f h"
using assms
proof(coinduction arbitrary: f g h)
case (ord_gpv f g h)
(* NOTE(review): the justification of * is missing before `then show`; the `have`
   cannot be closed as written.  TODO restore. *)
have *: "ord_spmf (rel_generat (=) (=) (rel_fun (=) (λf h. ∃g. ord_gpv f g ∧ ord_gpv g h))) (the_gpv f) (the_gpv h) =
ord_spmf (rel_generat ((=) OO (=)) ((=) OO (=)) (rel_fun (=) (ord_gpv OO ord_gpv))) (the_gpv f) (the_gpv h)"
then show ?case using ord_gpv
by(auto elim!: ord_gpv.cases simp add: generat.rel_compp ord_spmf_compp fun.rel_compp)
qed

lemma ord_gpv_compp: "(ord_gpv OO ord_gpv) = ord_gpv"
by(auto simp add: fun_eq_iff intro: ord_gpv_trans)

lemma transp_ord_gpv [simp]: "transp ord_gpv"
by(blast intro: transpI ord_gpv_trans)

lemma ord_gpv_antisym:
"⟦ ord_gpv f g; ord_gpv g f ⟧ ⟹ f = g"
proof(coinduction arbitrary: f g)
case (Eq_gpv f g)
let ?R = "rel_generat (=) (=) (rel_fun (=) ord_gpv)"
from ‹ord_gpv f g› have "ord_spmf ?R (the_gpv f) (the_gpv g)" by cases simp
moreover
from ‹ord_gpv g f› have "ord_spmf ?R (the_gpv g) (the_gpv f)" by cases simp
ultimately have "rel_spmf (inf ?R ?R¯¯) (the_gpv f) (the_gpv g)"
by(rule rel_spmf_inf)(auto 4 3 intro: transp_rel_generatI transp_ord_gpv reflp_ord_gpv reflp_equality reflp_fun1 is_equality_eq transp_rel_fun)
(* NOTE(review): the `unfolding` step below has no terminal proof method, so this `also
   have` is not closed before `finally`.  TODO restore. *)
also have "inf ?R ?R¯¯ = rel_generat (inf (=) (=)) (inf (=) (=)) (rel_fun (=) (inf ord_gpv ord_gpv¯¯))"
unfolding rel_generat_inf[symmetric] rel_fun_inf[symmetric]
finally show ?case by(simp add: inf_fun_def)
qed

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma RFail_least [simp]: "ord_gpv Fail f"

subsection ‹Bounds on interaction›

context
fixes "consider" :: "'out ⇒ bool"
notes monotone_SUP[partial_function_mono] [[function_internals]]
begin
(* Registers the "lfp_strong" mode for partial_function, based on least fixpoints. *)
declaration ‹Partial_Function.init "lfp_strong" @{term lfp.fixp_fun} @{term lfp.mono_body}
@{thm lfp.fixp_rule_uc} @{thm lfp.fixp_induct_strong2_uc} NONE›

(* interaction_bound gpv is defined as a least fixpoint.  It is the supremum, over all
   generats in the support, of 0 for Pure results and, for IO out c, of the bounds of the
   continuations over all inputs.  The IO value is incremented by one iff consider out
   holds. *)
partial_function (lfp_strong) interaction_bound :: "('a, 'out, 'in) gpv ⇒ enat"
where
"interaction_bound gpv =
(SUP generat∈set_spmf (the_gpv gpv). case generat of Pure _ ⇒ 0
| IO out c ⇒ if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input)))"

(* Fixpoint induction rule for interaction_bound with explicit admissibility, bottom and
   step cases.
   NOTE(review): no proof follows this statement; TODO restore. *)
lemma interaction_bound_fixp_induct [case_names adm bottom step]:
"⟦ ccpo.admissible (fun_lub Sup) (fun_ord (≤)) P;
P (λ_. 0);
⋀interaction_bound'.
⟦ P interaction_bound';
⋀gpv. interaction_bound' gpv ≤ interaction_bound gpv;
⋀gpv. interaction_bound' gpv ≤ (SUP generat∈set_spmf (the_gpv gpv). case generat of Pure _ ⇒ 0
| IO out c ⇒ if consider out then eSuc (SUP input. interaction_bound' (c input)) else (SUP input. interaction_bound' (c input)))
⟧
⟹ P (λgpv. ⨆generat∈set_spmf (the_gpv gpv). case generat of Pure x ⇒ 0
| IO out c ⇒ if consider out then eSuc (⨆input. interaction_bound' (c input)) else (⨆input. interaction_bound' (c input))) ⟧
⟹ P interaction_bound"

(* The bound of any continuation of an IO step in the support, plus one if the output is
   considered, is at most the bound of gpv. *)
lemma interaction_bound_IO:
"IO out c ∈ set_spmf (the_gpv gpv)
⟹ (if consider out then eSuc (interaction_bound (c input)) else interaction_bound (c input)) ≤ interaction_bound gpv"
by(rewrite in "_ ≤ ⌑" interaction_bound.simps)(auto intro!: SUP_upper2)

lemma interaction_bound_IO_consider:
"⟦ IO out c ∈ set_spmf (the_gpv gpv); consider out ⟧
⟹ eSuc (interaction_bound (c input)) ≤ interaction_bound gpv"
by(drule interaction_bound_IO) simp

lemma interaction_bound_IO_ignore:
"⟦ IO out c ∈ set_spmf (the_gpv gpv); ¬ consider out ⟧
⟹ interaction_bound (c input) ≤ interaction_bound gpv"
by(drule interaction_bound_IO) simp

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma interaction_bound_Done [simp]: "interaction_bound (Done x) = 0"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma interaction_bound_Fail [simp]: "interaction_bound Fail = 0"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma interaction_bound_Pause [simp]:
"interaction_bound (Pause out c) =
(if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input)))"

(* NOTE(review): no proof follows this statement; TODO restore. *)
lemma interaction_bound_lift_spmf [simp]: "interaction_bound (lift_spmf p) = 0"

lemma interaction_bound_assert_gpv [simp]: "interaction_bound (assert_gpv b) = 0"
by(cases b) simp_all

(* Step case of the fixpoint induction used in interaction_bound_bind.  It assumes the
   bind bound (IH) for the approximation interaction_bound' and the unfolding inequality
   for interaction_bound'. *)
lemma interaction_bound_bind_step:
assumes IH: "⋀p. interaction_bound' (p ⤜ f) ≤ interaction_bound p + (⨆x∈results'_gpv p. interaction_bound' (f x))"
and unfold: "⋀gpv. interaction_bound' gpv ≤ (⨆generat∈set_spmf (the_gpv gpv). case generat of Pure x ⇒ 0
| IO out c ⇒ if consider out then eSuc (⨆input. interaction_bound' (c input)) else ⨆input. interaction_bound' (c input))"
shows "(⨆generat∈set_spmf (the_gpv (p ⤜ f)).
case generat of Pure x ⇒ 0
| IO out c ⇒
if consider out then eSuc (⨆input. interaction_bound' (c input))
else ⨆input. interaction_bound' (c input))
≤ interaction_bound p +
(⨆x∈results'_gpv p.
⨆generat∈set_spmf (the_gpv (f x)).
case generat of Pure x ⇒ 0
| IO out c ⇒
if consider out then eSuc (⨆input. interaction_bound' (c input))
else ⨆input. interaction_bound' (c input))"
(is "(SUP generat'∈?bind. ?g generat') ≤ ?p + ?f")
proof(rule SUP_least)
fix generat'
assume "generat' ∈ ?bind"
then obtain generat where generat: "generat ∈ set_spmf (the_gpv p)"
and *: "case generat of Pure x ⇒ generat' ∈ set_spmf (the_gpv (f x))
| IO out c ⇒ generat' = IO out (λinput. c input ⤜ f)"
by(clarsimp simp add: bind_gpv.sel simp del: bind_gpv_sel')
(clarsimp split: generat.split_asm simp add: generat.map_comp o_def generat.map_id[unfolded id_def])
show "?g generat' ≤ ?p + ?f"
proof(cases generat)
case (Pure x)
have "?g generat' ≤ (SUP generat'∈set_spmf (the_gpv (f x)). (case generat' of Pure x ⇒ 0 | IO out c ⇒ if consider out then eSuc (⨆input. interaction_bound' (c input)) else ⨆input. interaction_bound' (c input)))"
using * Pure by(auto intro: SUP_upper)
also have "… ≤ 0 + ?f" using generat Pure
by(auto 4 3 intro: SUP_upper results'_gpv_Pure)
also have "… ≤ ?p + ?f" by simp
finally show ?thesis .
next
case (IO out c)
with * have "?g generat' = (if consider out then eSuc (SUP input. interaction_bound' (c input ⤜ f)) else (SUP input. interaction_bound' (c input ⤜ f)))" by simp
also have "… ≤ (if consider out then eSuc (SUP input. interaction_bound (c input) + (⨆x∈results'_gpv (c input). interaction_bound' (f x))) else (SUP input. interaction_bound (c input) + (⨆x∈results'_gpv (c input). interaction_bound' (f x))))"
by(auto intro: SUP_mono IH)
(* NOTE(review): the justification of the following `also have` is missing; TODO
   restore. *)
also have "… ≤ (case IO out c of Pure (x :: 'a) ⇒ 0 | IO out c ⇒ if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input))) + (SUP input. SUP x∈results'_gpv (c input). interaction_bound' (f x))"
also have "… ≤ ?p + ?f"
apply(rewrite in "_ ≤ ⌑" interaction_bound.simps)
apply(rule add_mono SUP_least SUP_upper generat[unfolded IO])+
apply(rule order_trans[OF unfold])
apply(auto 4 3 intro: results'_gpv_Cont[OF generat] SUP_upper simp add: IO)
done
finally show ?thesis .
qed
qed

(* The bound of a bind is at most the bound of the first computation plus the supremum of
   the continuation bounds over its possible results. *)
lemma interaction_bound_bind:
defines "ib1 ≡ interaction_bound"
shows "interaction_bound (p ⤜ f) ≤ ib1 p + (SUP x∈results'_gpv p. interaction_bound (f x))"
proof(induction arbitrary: p rule: interaction_bound_fixp_induct)
case adm show ?case by simp
case bottom show ?case by simp
case (step interaction_bound') then show ?case unfolding ib1_def by -(rule interaction_bound_bind_step)
qed

lemma interaction_bound_bind_lift_spmf [simp]:
"interaction_bound (lift_spmf p ⤜ f) = (SUP x∈set_spmf p. interaction_bound (f x))"
by(subst (1 2) interaction_bound.simps)(simp add: bind_UNION SUP_UNION)

end

(* map_gpv' with a surjective input map h and output map g does not change the interaction
   bound, provided consider is precomposed with g.  Proved by parallel fixpoint induction
   over both interaction_bound instances. *)
lemma interaction_bound_map_gpv':
assumes "surj h"
shows "interaction_bound consider (map_gpv' f g h gpv) = interaction_bound (consider ∘ g) gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF lattice_partial_function_definition lattice_partial_function_definition interaction_bound.mono interaction_bound.mono interaction_bound_def interaction_bound_def, case_names adm bottom step])
case (step interaction_bound' interaction_bound'' gpv)
(* by surjectivity of h, every input is reached as h x for some x *)
have *: "IO out c ∈ set_spmf (the_gpv gpv) ⟹  x ∈ UNIV ⟹ interaction_bound'' (c x) ≤ (⨆x. interaction_bound'' (c (h x)))" for out c x
using assms[THEN surjD, of x] by (clarsimp intro!: SUP_upper)

show ?case
by (auto simp add: * step.IH image_comp split: generat.split
intro!: SUP_cong [OF refl] antisym SUP_upper SUP_least)
qed simp_all

(* Interaction bound that counts every output. *)
abbreviation interaction_any_bound :: "('a, 'out, 'in) gpv ⇒ enat"
where "interaction_any_bound ≡ interaction_bound (λ_. True)"

(* Coinduction-style rule for proving interaction_any_bound gpv ≤ n.  For every IO step in
   the support and every input, there must be some n' with eSuc n' ≤ n such that the
   continuation either satisfies X with n' again or already has bound at most n'. *)
lemma interaction_any_bound_coinduct [consumes 1, case_names interaction_bound]:
assumes X: "X gpv n"
and *: "⋀gpv n out c input. ⟦ X gpv n; IO out c ∈ set_spmf (the_gpv gpv) ⟧
⟹ ∃n'. (X (c input) n' ∨ interaction_any_bound (c input) ≤ n') ∧ eSuc n' ≤ n"
shows "interaction_any_bound gpv ≤ n"
using X
proof(induction arbitrary: gpv n rule: interaction_bound_fixp_induct)
case adm show ?case by(intro cont_intro)
case bottom show ?case by simp
next
case (step interaction_bound')
{ fix out c
assume IO: "IO out c ∈ set_spmf (the_gpv gpv)"
(* an IO step exists, so n must be a successor *)
from *[OF step.prems IO] obtain n' where n: "n = eSuc n'"
by(cases n rule: co.enat.exhaust) auto
moreover
{ fix input
have "∃n''. (X (c input) n'' ∨ interaction_any_bound (c input) ≤ n'') ∧ eSuc n'' ≤ n"
using step.prems IO ‹n = eSuc n'› by(auto 4 3 dest: *)
then have "interaction_bound' (c input) ≤ n'" using n
by(auto dest: step.IH intro: step.hyps[THEN order_trans] elim!: order_trans simp add: neq_zero_conv_eSuc) }
ultimately have "eSuc (⨆input. interaction_bound' (c input)) ≤ n"
by(auto intro: SUP_least) }
then show ?case by(auto intro!: SUP_least split: generat.split)
qed

context includes lifting_syntax begin
(* Parametricity of interaction_bound for rel_gpv'', which additionally relates inputs by R.
   Requires R to be bi-total. Proved via fixp_lfp_parametric_eq and the transfer prover. *)
lemma interaction_bound_parametric':
assumes [transfer_rule]: "bi_total R"
shows "((C ===> (=)) ===> rel_gpv'' A C R ===> (=)) interaction_bound interaction_bound"
unfolding interaction_bound_def[abs_def]
apply(rule rel_funI)
apply(rule fixp_lfp_parametric_eq[OF interaction_bound.mono interaction_bound.mono])
subgoal premises [transfer_rule]
supply the_gpv_parametric'[transfer_rule] rel_gpv''_eq[relator_eq]
by transfer_prover
done

(* Transfer rule for rel_gpv: the special case of the above with equality on inputs,
   which is bi-total. *)
lemma interaction_bound_parametric [transfer_rule]:
"((C ===> (=)) ===> rel_gpv A C ===> (=)) interaction_bound interaction_bound"
unfolding rel_gpv_conv_rel_gpv'' by(rule interaction_bound_parametric')(rule bi_total_eq)
end

text ‹
There is no convenient equation for @{const interaction_bound} applied to a @{const bind_gpv},
because @{const interaction_bound} computes an exact bound, whereas we only need an upper bound.
As @{typ enat} is hard to work with (and @{term ∞} does not constrain a gpv in any way),
we work with @{typ nat}.
›

(* Upper-bound predicate: interaction_bounded_by consider gpv n holds iff the exact interaction
   bound of gpv (counting only outputs that satisfy consider) is at most n. *)
inductive interaction_bounded_by :: "('out ⇒ bool) ⇒ ('a, 'out, 'in) gpv ⇒ enat ⇒ bool"
for "consider" gpv n where
interaction_bounded_by: "⟦ interaction_bound consider gpv ≤ n ⟧ ⟹ interaction_bounded_by consider gpv n"

(* Export the introduction rule under a descriptive name and hide the short fact name. *)
lemmas interaction_bounded_byI = interaction_bounded_by
hide_fact (open) interaction_bounded_by

context includes lifting_syntax begin
(* interaction_bounded_by is parametric w.r.t. rel_gpv. *)
lemma interaction_bounded_by_parametric [transfer_rule]:
"((C ===> (=)) ===> rel_gpv A C ===> (=) ===> (=)) interaction_bounded_by interaction_bounded_by"
unfolding interaction_bounded_by.simps[abs_def] by transfer_prover

(* Generalisation to rel_gpv'' with a bi-total input relation R. *)
lemma interaction_bounded_by_parametric':
notes interaction_bound_parametric'[transfer_rule]
assumes [transfer_rule]: "bi_total R"
shows "((C ===> (=)) ===> rel_gpv'' A C R ===> (=) ===> (=))
interaction_bounded_by interaction_bounded_by"
unfolding interaction_bounded_by.simps[abs_def] by transfer_prover
end

(* Bounds can be weakened. *)
lemma interaction_bounded_by_mono:
"⟦ interaction_bounded_by consider gpv n; n ≤ m ⟧ ⟹ interaction_bounded_by consider gpv m"
unfolding interaction_bounded_by.simps by(erule order_trans) simp

(* Suppose a bounded gpv can emit a considered output out with continuation c. Then the bound is
   positive and every continuation is bounded by n - 1. *)
lemma interaction_bounded_by_contD:
"⟦ interaction_bounded_by consider gpv n; IO out c ∈ set_spmf (the_gpv gpv); consider out ⟧
⟹ n > 0 ∧ interaction_bounded_by consider (c input) (n - 1)"
unfolding interaction_bounded_by.simps
by(subst (asm) interaction_bound.simps)(auto simp add: SUP_le_iff eSuc_le_iff enat_eSuc_iff dest!: bspec)

(* Every continuation of an IO step is bounded by n, whether or not the output is considered. *)
lemma interaction_bounded_by_contD_ignore:
"⟦ interaction_bounded_by consider gpv n; IO out c ∈ set_spmf (the_gpv gpv) ⟧
⟹ interaction_bounded_by consider (c input) n"
unfolding interaction_bounded_by.simps
by(subst (asm) interaction_bound.simps)(auto 4 4 simp add: SUP_le_iff eSuc_le_iff enat_eSuc_iff dest!: bspec split: if_split_asm elim: order_trans)

(* Introduction rule by inspecting one step. Considered outputs need n ≠ 0 and leave n - 1
   for the continuations; other outputs leave n. *)
lemma interaction_bounded_byI_epred:
assumes "⋀out c. ⟦ IO out c ∈ set_spmf (the_gpv gpv); consider out ⟧ ⟹ n ≠ 0 ∧ (∀input. interaction_bounded_by consider (c input) (n - 1))"
and "⋀out c input. ⟦ IO out c ∈ set_spmf (the_gpv gpv); ¬ consider out ⟧ ⟹ interaction_bounded_by consider (c input) n"
shows "interaction_bounded_by consider gpv n"
unfolding interaction_bounded_by.simps
by(subst interaction_bound.simps)(auto 4 5 intro!: SUP_least split: generat.split dest: assms simp add: eSuc_le_iff enat_eSuc_iff gr0_conv_Suc neq_zero_conv_eSuc interaction_bounded_by.simps)

(* Counterpart of interaction_bounded_by_contD stated with n ≠ 0; derived from interaction_bound_IO. *)
lemma interaction_bounded_by_IO:
"⟦ IO out c ∈ set_spmf (the_gpv gpv); interaction_bounded_by consider gpv n; consider out ⟧
⟹ n ≠ 0 ∧ interaction_bounded_by consider (c input) (n - 1)"
by(drule interaction_bound_IO[where input=input and ?consider="consider"])(auto simp add: interaction_bounded_by.simps epred_conv_minus eSuc_le_iff enat_eSuc_iff)

(* A bound of 0 holds exactly when the interaction bound is 0.
   Proof: unfold the predicate to interaction_bound consider gpv ≤ 0; on enat, ≤ 0 simplifies
   to = 0. *)
lemma interaction_bounded_by_0: "interaction_bounded_by consider gpv 0 ⟷ interaction_bound consider gpv = 0"
by(simp add: interaction_bounded_by.simps)

(* Variant of interaction_bounded_by with a finite bound of type nat (embedded via enat). *)
abbreviation interaction_bounded_by' :: "('out ⇒ bool) ⇒ ('a, 'out, 'in) gpv ⇒ nat ⇒ bool"
where "interaction_bounded_by' consider gpv n ≡ interaction_bounded_by consider gpv (enat n)"

(* Rule collection used by the interaction_bound proof methods below; rules tagged
   [interaction_bound] are added to it. *)
named_theorems interaction_bound

(* Starting rule: reduce a bound goal to establishing some bound and comparing it with the
   target (interaction_bounded_by_mono). *)
lemmas interaction_bounded_by_start = interaction_bounded_by_mono

method interaction_bound_start = (rule interaction_bounded_by_start)
(* One step. Goals that are not interaction_bounded_by goals are first tried with clarsimp
   (using the supplied simp rules), which must solve them completely. Otherwise, a rule from
   `add` or from the interaction_bound collection is applied. *)
method interaction_bound_step uses add simp =
((match conclusion in "interaction_bounded_by _ _ _" ⇒ fail ¦ _ ⇒ ‹solves ‹clarsimp simp add: simp››) | rule add interaction_bound)
(* NOTE(review): in this copy the bodies of the two method definitions below appear to be
   missing (each `=` is followed directly by the next command), so this part will not parse.
   Restore them from the original source. *)
method interaction_bound_rec uses add simp =
method interaction_bound uses add simp =

(* Done x is bounded by every n.
   NOTE(review): no proof follows this statement in this copy. Restore it from the original
   source; presumably a simp proof via interaction_bounded_by.simps and a simp rule for
   interaction_bound of Done. *)
lemma interaction_bounded_by_Done [simp]: "interaction_bounded_by consider (Done x) n"

(* Variant with the tightest bound 0, registered for the interaction_bound method. *)
lemma interaction_bounded_by_DoneI [interaction_bound]:
"interaction_bounded_by consider (Done x) 0"
by simp

(* Fail is bounded by every n.
   NOTE(review): no proof follows this statement in this copy; restore it. *)
lemma interaction_bounded_by_Fail [simp]: "interaction_bounded_by consider Fail n"

(* Variant with the bound 0, registered for the interaction_bound method. *)
lemma interaction_bounded_by_FailI [interaction_bound]: "interaction_bounded_by consider Fail 0"
by simp

(* lift_spmf p is bounded by every n.
   NOTE(review): no proof follows this statement in this copy; restore it. *)
lemma interaction_bounded_by_lift_spmf [simp]: "interaction_bounded_by consider (lift_spmf p) n"

(* Variant with the bound 0, registered for the interaction_bound method. *)
lemma interaction_bounded_by_lift_spmfI [interaction_bound]:
"interaction_bounded_by consider (lift_spmf p) 0"
by simp

(* assert_gpv b is bounded by every n; proved by case distinction on b. *)
lemma interaction_bounded_by_assert_gpv [simp]: "interaction_bounded_by consider (assert_gpv b) n"
by(cases b) simp_all

(* Variant with the bound 0, registered for the interaction_bound method. *)
lemma interaction_bounded_by_assert_gpvI [interaction_bound]:
"interaction_bounded_by consider (assert_gpv b) 0"
by simp

(* Characterisation for a single Pause.
   - A considered output needs 0 < n and leaves n - 1 for every continuation.
   - An ignored output leaves n. *)
lemma interaction_bounded_by_Pause [simp]:
"interaction_bounded_by consider (Pause out c) n ⟷
(if consider out then 0 < n ∧ (∀input. interaction_bounded_by consider (c input) (n - 1)) else (∀input. interaction_bounded_by consider (c input) n))"
by(cases n rule: co.enat.exhaust)
(auto 4 3 simp add: interaction_bounded_by.simps eSuc_le_iff enat_eSuc_iff gr0_conv_Suc intro: SUP_least dest: order_trans[OF SUP_upper, rotated])

(* Rule for the interaction_bound method. The bound of Pause out c is the supremum of the
   continuation bounds, plus one if out is considered.
   NOTE(review): no proof follows this statement in this copy; restore it. *)
lemma interaction_bounded_by_PauseI [interaction_bound]:
"(⋀input. interaction_bounded_by consider (c input) (n input))
⟹ interaction_bounded_by consider (Pause out c) (if consider out then 1 + (SUP input. n input) else (SUP input. n input))"

(* Bounds add up under bind: n for gpv plus the supremum of the bounds m x over the results x
   of gpv. *)
lemma interaction_bounded_by_bindI [interaction_bound]:
"⟦ interaction_bounded_by consider gpv n; ⋀x. x ∈ results'_gpv gpv ⟹ interaction_bounded_by consider (f x) (m x) ⟧
⟹ interaction_bounded_by consider (gpv ⤜ f) (n + (SUP x∈results'_gpv gpv. m x))"
unfolding interaction_bounded_by.simps plus_enat_simps(1)[symmetric]
by(rule interaction_bound_bind[THEN order_trans])(auto intro: add_mono SUP_mono)

(* Rule for Pause out c ⤜ f. The bound is the supremum of the bounds of c input ⤜ f, plus one
   if out is considered.
   NOTE(review): no proof follows this statement in this copy; restore it. *)
lemma interaction_bounded_by_bind_PauseI [interaction_bound]:
"(⋀input. interaction_bounded_by consider (c input ⤜ f) (n input))
⟹ interaction_bounded_by consider (Pause out c ⤜ f) (if consider out then SUP input. n input + 1 else SUP input. n input)"

(* A sampling step followed by f is bounded by n iff f x is, for every x in the support of p.
   NOTE(review): no proof follows this statement in this copy; restore it. *)
lemma interaction_bounded_by_bind_lift_spmf [simp]:
"interaction_bounded_by consider (lift_spmf p ⤜ f) n ⟷ (∀x∈set_spmf p. interaction_bounded_by consider (f x) n)"

(* Rule for the interaction_bound method: the bound is the supremum over the support of p. *)
lemma interaction_bounded_by_bind_lift_spmfI [interaction_bound]:
"(⋀x. x ∈ set_spmf p ⟹ interaction_bounded_by consider (f x) (n x))
⟹ interaction_bounded_by consider (lift_spmf p ⤜ f) (SUP x∈set_spmf p. n x)"
by(auto intro: interaction_bounded_by_mono SUP_upper)

(* Done x ⤜ f has the bound of f x. *)
lemma interaction_bounded_by_bind_DoneI [interaction_bound]:
"interaction_bounded_by consider (f x) n ⟹ interaction_bounded_by consider (Done x ⤜ f) n"
by(simp)

(* Congruence-style rules for the interaction_bound method: the bound follows the same case
   distinction as the gpv. *)
lemma interaction_bounded_by_if [interaction_bound]:
"⟦ b ⟹ interaction_bounded_by consider gpv1 n; ¬ b ⟹ interaction_bounded_by consider gpv2 m ⟧
⟹ interaction_bounded_by consider (if b then gpv1 else gpv2) (if b then n else m)"
by(auto 4 3 simp add: max_def not_le elim: interaction_bounded_by_mono)

lemma interaction_bounded_by_case_bool [interaction_bound]:
"⟦ b ⟹ interaction_bounded_by consider t bt; ¬ b ⟹ interaction_bounded_by consider f bf ⟧
⟹ interaction_bounded_by consider (case_bool t f b) (if b then bt else bf)"
by(cases b)(auto)

lemma interaction_bounded_by_case_sum [interaction_bound]:
"⟦ ⋀y. x = Inl y ⟹ interaction_bounded_by consider (l y) (bl y);
⋀y. x = Inr y ⟹ interaction_bounded_by consider (r y) (br y) ⟧
⟹ interaction_bounded_by consider (case_sum l r x) (case_sum bl br x)"
by(cases x)(auto)

lemma interaction_bounded_by_case_prod [interaction_bound]:
"(⋀a b. x = (a, b) ⟹ interaction_bounded_by consider (f a b) (n a b))
⟹ interaction_bounded_by consider (case_prod f x) (case_prod n x)"
by(simp split: prod.split)

(* NOTE(review): no proof follows this statement in this copy; restore it. *)
lemma interaction_bounded_by_let [interaction_bound]: ― ‹This rule unfolds let's›
"interaction_bounded_by consider (f t) m ⟹ interaction_bounded_by consider (Let t f) m"

(* Mapping results (with id on outputs) preserves the bound. Proved by rewriting map_gpv into a
   bind and running the interaction_bound method. *)
lemma interaction_bounded_by_map_gpv_id [interaction_bound]:
assumes [interaction_bound]: "interaction_bounded_by P gpv n"
shows "interaction_bounded_by P (map_gpv f id gpv) n"
unfolding id_def map_gpv_conv_bind by interaction_bound simp

(* Bound that counts every output (the predicate is constantly True). *)
abbreviation interaction_any_bounded_by :: "('a, 'out, 'in) gpv ⇒ enat ⇒ bool"
where "interaction_any_bounded_by ≡ interaction_bounded_by (λ_. True)"

(* map_gpv' with a surjective input map h preserves such bounds (via interaction_bound_map_gpv'). *)
lemma interaction_any_bounded_by_map_gpv':
assumes "interaction_any_bounded_by gpv n"
and "surj h"
shows "interaction_any_bounded_by (map_gpv' f g h gpv) n"
using assms by(simp add: interaction_bounded_by.simps interaction_bound_map_gpv' o_def)

subsection ‹Typing›

subsubsection ‹Interface between gpvs and rpvs / callees›

(* Set.is_empty is parametric in the element relation. *)
lemma is_empty_parametric [transfer_rule]: "rel_fun (rel_set A) (=) Set.is_empty Set.is_empty" (* Move *)
by(auto simp add: rel_fun_def Set.is_empty_def dest: rel_setD1 rel_setD2)

(* An interface ℐ assigns to each call the set of admissible responses. Every such function is
   allowed, so the representing set is UNIV and non-emptiness is trivial. *)
typedef ('call, 'ret) ℐ = "UNIV :: ('call ⇒ 'ret set) set" ..

setup_lifting type_definition_ℐ

(* Parametricity of "calls with a nonempty response set", used for outs_ℐ below. Proved by
   folding Set.is_empty (presumably relying on is_empty_parametric as a transfer rule). *)
lemma outs_ℐ_tparametric:
includes lifting_syntax
assumes [transfer_rule]: "bi_total A"
shows "((A ===> rel_set B) ===> rel_set A) (λresps. {out. resps out ≠ {}}) (λresps. {out. resps out ≠ {}})"
by(fold Set.is_empty_def) transfer_prover

(* The enabled calls are those with at least one response. *)
lift_definition outs_ℐ :: "('call, 'ret) ℐ ⇒ 'call set" is "λresps. {out. resps out ≠ {}}" parametric outs_ℐ_tparametric .
(* The response function itself. *)
lift_definition responses_ℐ :: "('call, 'ret) ℐ ⇒ 'call ⇒ 'ret set" is "λx. x" parametric id_transfer[unfolded id_def] .

(* Relator: relates the sets of enabled calls by C and the response functions by C ===> rel_set R. *)
lift_definition rel_ℐ :: "('call ⇒ 'call' ⇒ bool) ⇒ ('ret ⇒ 'ret' ⇒ bool) ⇒ ('call, 'ret) ℐ ⇒ ('call', 'ret') ℐ ⇒ bool"
is "λC R resp1 resp2. rel_set C {out. resp1 out ≠ {}} {out. resp2 out ≠ {}} ∧ rel_fun C (rel_set R) resp1 resp2"
.

(* Introduction rule for rel_ℐ in terms of outs_ℐ and responses_ℐ.
   NOTE(review): no proof follows this statement in this copy; restore it from the original
   source. *)
lemma rel_ℐI [intro?]:
"⟦ rel_set C (outs_ℐ ℐ1) (outs_ℐ ℐ2); ⋀x y. C x y ⟹ rel_set R (responses_ℐ ℐ1 x) (responses_ℐ ℐ2 y) ⟧
⟹ rel_ℐ C R ℐ1 ℐ2"

(* rel_ℐ with equality in both components is equality. *)
lemma rel_ℐ_eq [relator_eq]: "rel_ℐ (=) (=) = (=)"
unfolding fun_eq_iff by transfer(auto simp add: relator_eq)

(* rel_ℐ commutes with converse. *)
lemma rel_ℐ_conversep [simp]: "rel_ℐ C¯¯ R¯¯ = (rel_ℐ C R)¯¯"
unfolding fun_eq_iff conversep_iff
apply transfer
apply(rewrite in "rel_fun ⌑" conversep_iff[symmetric])
apply(rewrite in "rel_set ⌑" conversep_iff[symmetric])
apply(rewrite in "rel_fun _ ⌑" conversep_iff[symmetric])
apply(simp)
done

(* Special cases with equality in one component. They are derived from rel_ℐ_conversep by
   rewriting (=) as its own converse. *)
lemma rel_ℐ_conversep1_eq [simp]: "rel_ℐ C¯¯ (=) = (rel_ℐ C (=))¯¯"
by(rewrite in "⌑ = _" conversep_eq[symmetric])(simp del: conversep_eq)

lemma rel_ℐ_conversep2_eq [simp]: "rel_ℐ (=) R¯¯ = (rel_ℐ (=) R)¯¯"
by(rewrite in "⌑ = _" conversep_eq[symmetric])(simp del: conversep_eq)

(* A call has no responses exactly when it is not enabled. *)
lemma responses_ℐ_empty_iff: "responses_ℐ ℐ out = {} ⟷ out ∉ outs_ℐ ℐ"
including ℐ.lifting by transfer auto

(* Reformulation of responses_ℐ_empty_iff: a call is enabled iff it has some response. *)
lemma in_outs_ℐ_iff_responses_ℐ: "out ∈ outs_ℐ ℐ ⟷ responses_ℐ ℐ out ≠ {}"
by(simp add: responses_ℐ_empty_iff)

(* The full interface: every call is allowed and may receive any response. *)
lift_definition ℐ_full :: "('call, 'ret) ℐ" is "λ_. UNIV" .

lemma ℐ_full_sel [simp]:
shows outs_ℐ_full: "outs_ℐ ℐ_full = UNIV"
and responses_ℐ_full: "responses_ℐ ℐ_full x = UNIV"
by(transfer; simp; fail)+

context includes lifting_syntax begin
(* Transfer rules for the two selectors of an interface. *)
lemma outs_ℐ_parametric [transfer_rule]: "(rel_ℐ C R ===> rel_set C) outs_ℐ outs_ℐ"
unfolding rel_fun_def by transfer simp

lemma responses_ℐ_parametric [transfer_rule]:
"(rel_ℐ C R ===> C ===> rel_set R) responses_ℐ responses_ℐ"
unfolding rel_fun_def by transfer(auto dest: rel_funD)

end

(* An interface is trivial if it enables every call. *)
definition ℐ_trivial :: "('out, 'in) ℐ ⇒ bool"
where "ℐ_trivial ℐ ⟷ outs_ℐ ℐ = UNIV"

lemma ℐ_trivialI [intro?]: "(⋀x. x ∈ outs_ℐ ℐ) ⟹ ℐ_trivial ℐ"
by(auto simp add: ℐ_trivial_def)

lemma ℐ_trivialD: "ℐ_trivial ℐ ⟹ outs_ℐ ℐ = UNIV"
by(simp add: ℐ_trivial_def)

(* Follows from outs_ℐ_full. *)
lemma ℐ_trivial_ℐ_full [simp]: "ℐ_trivial ℐ_full"
by(simp add: ℐ_trivial_def)

(* Store the lifting setup and hide ℐ's representation. Later proofs re-enable it with
   `includes ℐ.lifting` / `including ℐ.lifting`. *)
lifting_update ℐ.lifting
lifting_forget ℐ.lifting

context includes ℐ.lifting begin

(* ℐ_uniform A B allows exactly the calls in A, each with the response set B. *)
lift_definition ℐ_uniform :: "'out set ⇒ 'in set ⇒ ('out, 'in) ℐ" is "λA B x. if x ∈ A then B else {}" .

(* Calls without any response are not enabled, hence the special case B = {}. *)
lemma outs_ℐ_uniform [simp]: "outs_ℐ (ℐ_uniform A B) = (if B = {} then {} else A)"
by transfer simp

lemma responses_ℐ_uniform [simp]: "responses_ℐ (ℐ_uniform A B) x = (if x ∈ A then B else {})"
by transfer simp

lemma ℐ_uniform_UNIV [simp]: "ℐ_uniform UNIV UNIV = ℐ_full" (* TODO: make ℐ_full an abbreviation *)
by transfer simp

(* map_ℐ f g maps calls contravariantly with f and responses covariantly with g. *)
lift_definition map_ℐ :: "('out' ⇒ 'out) ⇒ ('in ⇒ 'in') ⇒ ('out, 'in) ℐ ⇒ ('out', 'in') ℐ"
is "λf g resp x. g ` resp (f x)" .

lemma outs_ℐ_map_ℐ [simp]:
"outs_ℐ (map_ℐ f g ℐ) = f -` outs_ℐ ℐ"
by transfer simp

lemma responses_ℐ_map_ℐ [simp]:
"responses_ℐ (map_ℐ f g ℐ) x = g ` responses_ℐ ℐ (f x)"
by transfer simp

(* On the representation both sides are λx. if f x ∈ A then g ` B else {}
   (split the if; g ` {} = {}). *)
lemma map_ℐ_ℐ_uniform [simp]:
"map_ℐ f g (ℐ_uniform A B) = ℐ_uniform (f -` A) (g ` B)"
by transfer(auto simp add: fun_eq_iff)

lemma map_ℐ_id [simp]: "map_ℐ id id ℐ = ℐ"
by transfer simp

(* Point-free form of map_ℐ_id. *)
lemma map_ℐ_id0: "map_ℐ id id = id"
by(simp add: fun_eq_iff)

lemma map_ℐ_comp [simp]: "map_ℐ f g (map_ℐ f' g' ℐ) = map_ℐ (f' ∘ f) (g ∘ g') ℐ"
by transfer auto

(* Congruence rule: g only matters on responses to enabled calls. *)
lemma map_ℐ_cong: "map_ℐ f g ℐ = map_ℐ f' g' ℐ'"
if "ℐ = ℐ'" and f: "f = f'" and "⋀x y. ⟦ x ∈ outs_ℐ ℐ'; y ∈ responses_ℐ ℐ' x ⟧ ⟹ g y = g' y"
unfolding that(1,2) using that(3-)
by transfer(auto simp add: fun_eq_iff intro!: image_cong)

lifting_update ℐ.lifting
lifting_forget ℐ.lifting
end

(* Extensionality: interfaces with the same enabled calls and the same responses on those calls
   are equal. *)
lemma ℐ_eqI: "⟦ outs_ℐ ℐ = outs_ℐ ℐ'; ⋀x. x ∈ outs_ℐ ℐ' ⟹ responses_ℐ ℐ x = responses_ℐ ℐ' x ⟧ ⟹ ℐ = ℐ'"
including ℐ.lifting by transfer auto

instantiation ℐ :: (type, type) order begin

(* ℐ ≤ ℐ' holds when ℐ' enables at least the calls of ℐ and, on the calls enabled by ℐ,
   allows at most the responses of ℐ. *)
definition less_eq_ℐ :: "('a, 'b) ℐ ⇒ ('a, 'b) ℐ ⇒ bool"
where le_ℐ_def: "less_eq_ℐ ℐ ℐ' ⟷ outs_ℐ ℐ ⊆ outs_ℐ ℐ' ∧ (∀x∈outs_ℐ ℐ. responses_ℐ ℐ' x ⊆ responses_ℐ ℐ x)"

definition less_ℐ :: "('a, 'b) ℐ ⇒ ('a, 'b) ℐ ⇒ bool"
where "less_ℐ = mk_less (≤)"

instance
proof
show "ℐ < ℐ' ⟷ ℐ ≤ ℐ' ∧ ¬ ℐ' ≤ ℐ" for ℐ ℐ' :: "('a, 'b) ℐ" by(simp add: less_ℐ_def mk_less_def)
show "ℐ ≤ ℐ" for ℐ :: "('a, 'b) ℐ" by(simp add: le_ℐ_def)
(* Transitivity: chain the inclusions of calls and the reverse inclusions of responses. *)
show "ℐ ≤ ℐ''" if "ℐ ≤ ℐ'" "ℐ' ≤ ℐ''" for ℐ ℐ' ℐ'' :: "('a, 'b) ℐ" using that
by(auto simp add: le_ℐ_def)
show "ℐ = ℐ'" if "ℐ ≤ ℐ'" "ℐ' ≤ ℐ" for ℐ ℐ' :: "('a, 'b) ℐ" using that
by(auto simp add: le_ℐ_def intro!: ℐ_eqI)
qed
end

(* The least interface enables no call: ℐ_uniform {} UNIV. *)
instantiation ℐ :: (type, type) order_bot begin
definition bot_ℐ :: "('a, 'b) ℐ" where "bot_ℐ = ℐ_uniform {} UNIV"
instance by standard(auto simp add: bot_ℐ_def le_ℐ_def)
end

(* The bottom interface enables no call and offers no responses; both follow by unfolding
   bot_ℐ to ℐ_uniform {} UNIV. *)
lemma outs_ℐ_bot [simp]: "outs_ℐ bot = {}"
by(simp add: bot_ℐ_def)

(* NOTE(review): the fact name misspells "responses"; it is kept unchanged for existing uses. *)
lemma respones_ℐ_bot [simp]: "responses_ℐ bot x = {}"
by(simp add: bot_ℐ_def)

(* Larger interfaces enable more calls ... *)
lemma outs_ℐ_mono: "ℐ ≤ ℐ' ⟹ outs_ℐ ℐ ⊆ outs_ℐ ℐ'"
by(simp add: le_ℐ_def)

(* ... but allow fewer responses on the calls enabled by the smaller interface. *)
lemma responses_ℐ_mono: "⟦ ℐ ≤ ℐ'; x ∈ outs_ℐ ℐ ⟧ ⟹ responses_ℐ ℐ' x ⊆ responses_ℐ ℐ x"
by(simp add: le_ℐ_def)

(* ℐ_uniform with no enabled calls is the bottom interface. *)
lemma ℐ_uniform_empty [simp]: "ℐ_uniform {} A = bot"
unfolding bot_ℐ_def including ℐ.lifting by transfer simp

(* Monotonicity of ℐ_uniform: more calls and fewer (but not no) responses give a larger interface. *)
lemma ℐ_uniform_mono:
"ℐ_uniform A B ≤ ℐ_uniform C D" if "A ⊆ C" "D ⊆ B" "D = {} ⟶ B = {}"
unfolding le_ℐ_def using that by auto

context begin
(* resultsp_gpv Γ x gpv: gpv can terminate with result x when it is only fed responses allowed
   by Γ. *)
qualified inductive resultsp_gpv :: "('out, 'in) ℐ ⇒ 'a ⇒ ('a, 'out, 'in) gpv ⇒ bool"
for Γ x
where
Pure: "Pure x ∈ set_spmf (the_gpv gpv) ⟹ resultsp_gpv Γ x gpv"
| IO:
"⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ Γ out; resultsp_gpv Γ x (c input) ⟧
⟹ resultsp_gpv Γ x gpv"

(* Set version of resultsp_gpv. *)
definition results_gpv :: "('out, 'in) ℐ ⇒ ('a, 'out, 'in) gpv ⇒ 'a set"
where "results_gpv Γ gpv ≡ {x. resultsp_gpv Γ x gpv}"

(* Predicate/set conversion, registered as [pred_set_conv] for the [to_set] attribute below. *)
lemma resultsp_gpv_results_gpv_eq [pred_set_conv]: "resultsp_gpv Γ x gpv ⟷ x ∈ results_gpv Γ gpv"
by(simp add: results_gpv_def)

context begin
(* Register the set-based rules under the results_gpv.* name space. *)
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "results_gpv")›

lemmas intros [intro?] = resultsp_gpv.intros[to_set]
and Pure = Pure[to_set]
and IO = IO[to_set]
and induct [consumes 1, case_names Pure IO, induct set: results_gpv] = resultsp_gpv.induct[to_set]
and cases [consumes 1, case_names Pure IO, cases set: results_gpv] = resultsp_gpv.cases[to_set]
and simps = resultsp_gpv.simps[to_set]
end

(* Unfolding equation for results_gpv on the constructor GPV. *)
inductive_simps results_gpv_GPV [to_set, simp]: "resultsp_gpv Γ x (GPV gpv)"

end

(* results_gpv on the basic constructors.
   NOTE(review): the first four lemmas have no proof in this copy; restore them from the
   original source. *)
lemma results_gpv_Done [iff]: "results_gpv Γ (Done x) = {x}"

lemma results_gpv_Fail [iff]: "results_gpv Γ Fail = {}"

lemma results_gpv_Pause [simp]:
"results_gpv Γ (Pause out c) = (⋃input∈responses_ℐ Γ out. results_gpv Γ (c input))"

lemma results_gpv_lift_spmf [iff]: "results_gpv Γ (lift_spmf p) = set_spmf p"

lemma results_gpv_assert_gpv [simp]: "results_gpv Γ (assert_gpv b) = (if b then {()} else {})"
by auto

(* The results of gpv ⤜ f (w.r.t. Γ) are the results of f x for the results x of gpv.
   - ⊆: induction on the derivation of a result of the bind.
   - ⊇: induction on the derivation of the intermediate result y of gpv. *)
lemma results_gpv_bind_gpv [simp]:
"results_gpv Γ (gpv ⤜ f) = (⋃x∈results_gpv Γ gpv. results_gpv Γ (f x))"
(is "?lhs = ?rhs")
proof(intro set_eqI iffI)
fix x
assume "x ∈ ?lhs"
then show "x ∈ ?rhs"
proof(induction gpv'≡"gpv ⤜ f" arbitrary: gpv)
case Pure thus ?case
by(auto 4 3 split: if_split_asm intro: results_gpv.intros rev_bexI)
next
case (IO out c input)
(* Find the step of gpv that produced the IO step of the bind. *)
from ‹IO out c ∈ _›
obtain generat where generat: "generat ∈ set_spmf (the_gpv gpv)"
and *: "IO out c ∈ set_spmf (if is_Pure generat then the_gpv (f (result generat))
else return_spmf (IO (output generat) (λinput. continuation generat input ⤜ f)))"
by(auto)
thus ?case
proof(cases generat)
case (Pure y)
(* gpv already returned y, so the IO step comes from f y. *)
with generat have "y ∈ results_gpv Γ gpv" by(auto intro: results_gpv.intros)
thus ?thesis using * Pure ‹input ∈ responses_ℐ Γ out› ‹x ∈ results_gpv Γ (c input)›
by(auto intro: results_gpv.IO)
next
case (IO out' c')
(* The IO step is gpv's own, and its continuations are bound to f. *)
hence [simp]: "out' = out"
and c: "⋀input. c input = bind_gpv (c' input) f" using * by simp_all
from IO.hyps(4)[OF c] obtain y where y: "y ∈ results_gpv Γ (c' input)"
and "x ∈ results_gpv Γ (f y)" by blast
from y IO generat have "y ∈ results_gpv Γ gpv" using ‹input ∈ responses_ℐ Γ out›
by(auto intro: results_gpv.IO)
with ‹x ∈ results_gpv Γ (f y)› show ?thesis by blast
qed
qed
next
fix x
assume "x ∈ ?rhs"
then obtain y where y: "y ∈ results_gpv Γ gpv"
and x: "x ∈ results_gpv Γ (f y)" by blast
from y show "x ∈ ?lhs"
proof(induction)
case (Pure gpv)
with x show ?case
by cases(auto 4 4 intro: results_gpv.intros rev_bexI)
qed(auto 4 4 intro: rev_bexI results_gpv.IO)
qed

(* Under the full interface, results_gpv coincides with the BNF set results'_gpv. *)
lemma results_gpv_ℐ_full: "results_gpv ℐ_full = results'_gpv"
proof(intro ext set_eqI iffI)
show "x ∈ results'_gpv gpv" if "x ∈ results_gpv ℐ_full gpv" for x gpv
using that by induction(auto intro: results'_gpvI)
show "x ∈ results_gpv ℐ_full gpv" if "x ∈ results'_gpv gpv" for x gpv
using that by induction(auto intro: results_gpv.intros elim!: generat.set_cases)
qed

(* Bind equation for results'_gpv, obtained from results_gpv_bind_gpv via results_gpv_ℐ_full. *)
lemma results'_bind_gpv [simp]:
"results'_gpv (bind_gpv gpv f) = (⋃x∈results'_gpv gpv. results'_gpv (f x))"
unfolding results_gpv_ℐ_full[symmetric] by simp

(* Mapping results with f (and outputs with the identity) maps the result set.
   NOTE(review): the three lemmas in this group have no proof in this copy; restore them. *)
lemma results_gpv_map_gpv_id [simp]: "results_gpv ℐ (map_gpv f id gpv) = f ` results_gpv ℐ gpv"

(* Same as results_gpv_map_gpv_id, with id written as a lambda. *)
lemma results_gpv_map_gpv_id' [simp]: "results_gpv ℐ (map_gpv f (λx. x) gpv) = f ` results_gpv ℐ gpv"

(* pred_gpv of a bind: the predicate on results becomes pred_gpv P Q ∘ f. *)
lemma pred_gpv_bind [simp]: "pred_gpv P Q (bind_gpv gpv f) = pred_gpv (pred_gpv P Q ∘ f) Q gpv"

(* Results of monad.bind_option Fail x f, by case distinction on x. *)
lemma results'_gpv_bind_option [simp]:
"results'_gpv (monad.bind_option Fail x f) = (⋃y∈set_option x. results'_gpv (f y))"
by(cases x) simp_all

(* If h is surjective, results'_gpv of map_gpv' f g h gpv is the image of results'_gpv gpv
   under f. *)
lemma results'_gpv_map_gpv':
assumes "surj h"
shows "results'_gpv (map_gpv' f g h gpv) = f ` results'_gpv gpv" (is "?lhs = ?rhs")
proof -
(* Helper: if f x is a result of the mapped continuation c input, it is a result of the mapped
   gpv. Surjectivity of h provides a preimage of input. *)
have *:"IO z c ∈ set_spmf (the_gpv gpv) ⟹ x ∈ results'_gpv (c input) ⟹
f x ∈ results'_gpv (map_gpv' f g h (c input)) ⟹ f x ∈ results'_gpv (map_gpv' f g h gpv)" for x z gpv c input
using surjD[OF assms, of input] by(fastforce intro: results'_gpvI elim!: generat.set_cases intro: rev_image_eqI simp add: map_fun_def o_def)

show ?thesis
proof(intro Set.set_eqI iffI; (elim imageE; hypsubst)?)
show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
by(induction gpv'≡"map_gpv' f g h gpv" arbitrary: gpv)(fastforce elim!: generat.set_cases intro: results'_gpvI)+
show "f x ∈ ?lhs" if "x ∈ results'_gpv gpv" for x using that
by induction (fastforce intro: results'_gpvI elim!: generat.set_cases intro: rev_image_eqI simp add: map_fun_def o_def
, clarsimp simp add: *  elim!: generat.set_cases)
qed
qed

(* Associativity of bind_gpv over monad.bind_option Fail, by case distinction on x. *)
lemma bind_gpv_bind_option_assoc:
"bind_gpv (monad.bind_option Fail x f) g = monad.bind_option Fail x (λx. bind_gpv (f x) g)"
by(cases x) simp_all

context begin
(* outsp_gpv ℐ x gpv: gpv can emit output x when it is only fed responses allowed by ℐ. *)
qualified inductive outsp_gpv :: "('out, 'in) ℐ ⇒ 'out ⇒ ('a, 'out, 'in) gpv ⇒ bool"
for ℐ x where
IO: "IO x c ∈ set_spmf (the_gpv gpv) ⟹ outsp_gpv ℐ x gpv"
| Cont: "⟦ IO out rpv ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ ℐ out; outsp_gpv ℐ x (rpv input) ⟧
⟹ outsp_gpv ℐ x gpv"

(* Set version of outsp_gpv. *)
definition outs_gpv :: "('out, 'in) ℐ ⇒ ('a, 'out, 'in) gpv ⇒ 'out set"
where "outs_gpv ℐ gpv ≡ {x. outsp_gpv ℐ x gpv}"

(* Predicate/set conversion, registered as [pred_set_conv] for the [to_set] attribute below. *)
lemma outsp_gpv_outs_gpv_eq [pred_set_conv]: "outsp_gpv ℐ x = (λgpv. x ∈ outs_gpv ℐ gpv)"
by(simp add: outs_gpv_def)

context begin
(* Register the set-based rules under the outs_gpv.* name space. *)
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "outs_gpv")›

lemmas intros [intro?] = outsp_gpv.intros[to_set]
and IO = IO[to_set]
and Cont = Cont[to_set]
and induct [consumes 1, case_names IO Cont, induct set: outs_gpv] = outsp_gpv.induct[to_set]
and cases [consumes 1, case_names IO Cont, cases set: outs_gpv] = outsp_gpv.cases[to_set]
and simps = outsp_gpv.simps[to_set]
end

(* Unfolding equation for outs_gpv on the constructor GPV. *)
inductive_simps outs_gpv_GPV [to_set, simp]: "outsp_gpv ℐ x (GPV gpv)"

end

(* outs_gpv on the basic constructors.
   NOTE(review): the first four lemmas have no proof in this copy; restore them from the
   original source. *)
lemma outs_gpv_Done [iff]: "outs_gpv ℐ (Done x) = {}"

lemma outs_gpv_Fail [iff]: "outs_gpv ℐ Fail = {}"

lemma outs_gpv_Pause [simp]:
"outs_gpv ℐ (Pause out c) = insert out (⋃input∈responses_ℐ ℐ out. outs_gpv ℐ (c input))"

lemma outs_gpv_lift_spmf [iff]: "outs_gpv ℐ (lift_spmf p) = {}"

lemma outs_gpv_assert_gpv [simp]: "outs_gpv ℐ (assert_gpv b) = {}"
by(cases b)auto

(* Outputs of gpv ⤜ f: the outputs of gpv, plus the outputs of f x for every result x of gpv.
   - ⊆: induction on the derivation of an output of the bind.
   - ⊇: induction on the output of gpv, or on the result leading to an output of f. *)
lemma outs_gpv_bind_gpv [simp]:
"outs_gpv ℐ (gpv ⤜ f) = outs_gpv ℐ gpv ∪ (⋃x∈results_gpv ℐ gpv. outs_gpv ℐ (f x))"
(is "?lhs = ?rhs")
proof(intro Set.set_eqI iffI)
fix x
assume "x ∈ ?lhs"
then show "x ∈ ?rhs"
proof(induction gpv'≡"gpv ⤜ f" arbitrary: gpv)
case IO thus ?case
proof(clarsimp split: if_split_asm elim!: is_PureE not_is_PureE, goal_cases)
case (1 generat)
then show ?case by(cases generat)(auto intro: results_gpv.Pure outs_gpv.intros)
qed
next
case (Cont out rpv input)
thus ?case
proof(clarsimp split: if_split_asm, goal_cases)
case (1 generat)
then show ?case by(cases generat)(auto 4 3 split: if_split_asm intro: results_gpv.intros outs_gpv.intros)
qed
qed
next
fix x
assume "x ∈ ?rhs"
(* Split the right-hand side into its two parts. *)
then consider (out) "x ∈ outs_gpv ℐ gpv" | (result) y where "y ∈ results_gpv ℐ gpv" "x ∈ outs_gpv ℐ (f y)" by auto
then show "x ∈ ?lhs"
proof cases
case out then show ?thesis
by(induction) (auto 4 4 intro: outs_gpv.IO  outs_gpv.Cont rev_bexI)
next
case result then show ?thesis
by induction ((erule outs_gpv.cases | rule outs_gpv.Cont),
auto 4 4 intro: outs_gpv.intros rev_bexI elim: outs_gpv.cases)+
qed
qed

(* Under the full interface, outs_gpv coincides with the BNF set outs'_gpv. *)
lemma outs_gpv_ℐ_full: "outs_gpv ℐ_full = outs'_gpv"
proof(intro ext Set.set_eqI iffI)
show "x ∈ outs'_gpv gpv" if "x ∈ outs_gpv ℐ_full gpv" for x gpv
using that by induction(auto intro: outs'_gpvI)
show "x ∈ outs_gpv ℐ_full gpv" if "x ∈ outs'_gpv gpv" for x gpv
using that by induction(auto intro: outs_gpv.intros elim!: generat.set_cases)
qed

(* Bind equation for outs'_gpv, derived from outs_gpv_bind_gpv under the full interface. *)
lemma outs'_bind_gpv [simp]:
"outs'_gpv (bind_gpv gpv f) = outs'_gpv gpv ∪ (⋃x∈results'_gpv gpv. outs'_gpv (f x))"
unfolding outs_gpv_ℐ_full[symmetric] results_gpv_ℐ_full[symmetric] by simp

(* Mapping only the results does not change the outputs.
   NOTE(review): the next two lemmas have no proof in this copy; restore them. *)
lemma outs_gpv_map_gpv_id [simp]: "outs_gpv ℐ (map_gpv f id gpv) = outs_gpv ℐ gpv"

lemma outs_gpv_map_gpv_id' [simp]: "outs_gpv ℐ (map_gpv f (λx. x) gpv) = outs_gpv ℐ gpv"

(* Outputs of monad.bind_option Fail x f, by case distinction on x. *)
lemma outs'_gpv_bind_option [simp]:
"outs'_gpv (monad.bind_option Fail x f) = (⋃y∈set_option x. outs'_gpv (f y))"
by(cases x) simp_all

(* rel_gpv'' of graph relations (with the converse graph of h on inputs) is the graph of
   map_gpv' f g h. Its domain is restricted to the gpvs whose results reachable with inputs in
   range h lie in A, and whose such outputs lie in B. *)
lemma rel_gpv''_Grp: includes lifting_syntax shows
"rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯ =
BNF_Def.Grp {x. results_gpv (ℐ_uniform UNIV (range h)) x ⊆ A ∧ outs_gpv (ℐ_uniform UNIV (range h)) x ⊆ B} (map_gpv' f g h)"
(is "?lhs = ?rhs")
proof(intro ext GrpI iffI CollectI conjI subsetI)
let ?ℐ = "ℐ_uniform UNIV (range h)"
fix gpv gpv'
assume *: "?lhs gpv gpv'"
(* The related gpv is the mapped one (by coinduction). *)
then show "map_gpv' f g h gpv = gpv'"
by(coinduction arbitrary: gpv gpv')
(drule rel_gpv''D
, auto 4 5 simp add: spmf_rel_map generat.rel_map elim!: rel_spmf_mono generat.rel_mono_strong GrpE intro!: GrpI dest: rel_funD)
(* Reachable results lie in A: induction over results_gpv, reading off the domain of the
   one-step relation. *)
show "x ∈ A" if "x ∈ results_gpv ?ℐ gpv" for x using that *
proof(induction arbitrary: gpv')
case (Pure gpv)
have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
using Pure.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] ..
with Pure.hyps show ?case by(simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Domainp_Grp)
next
case (IO out c gpv input)
have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
using IO.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
with IO.hyps show ?case
by(auto simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Grp_iff dest: rel_funD intro: IO.IH dest!: bspec)
qed
(* Reachable outputs lie in B, analogously by induction over outs_gpv. *)
show "x ∈ B" if "x ∈ outs_gpv ?ℐ gpv" for x using that *
proof(induction arbitrary: gpv')
case (IO c gpv)
have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
using IO.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
with IO.hyps show ?case by(simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Domainp_Grp)
next
case (Cont out rpv gpv input)
have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
using Cont.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
with Cont.hyps show ?case
by(auto simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Grp_iff dest: rel_funD intro: Cont.IH dest!: bspec)
qed
next
(* Converse direction: by coinduction, using the reachability assumptions. *)
fix gpv gpv'
assume "?rhs gpv gpv'"
then have gpv': "gpv' = map_gpv' f g h gpv"
and *: "results_gpv (ℐ_uniform UNIV (range h)) gpv ⊆ A" "outs_gpv (ℐ_uniform UNIV (range h)) gpv ⊆ B" by(auto simp add: Grp_iff)
show "?lhs gpv gpv'" using * unfolding gpv'
by(coinduction arbitrary: gpv)
(fastforce simp add: spmf_rel_map generat.rel_map Grp_iff intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI elim!: generat.set_cases intro: results_gpv.intros outs_gpv.intros)
qed

(* pred_gpv' P Q X gpv: every result reachable with inputs from X satisfies P, and every such
   output satisfies Q. *)
inductive pred_gpv' :: "('a ⇒ bool) ⇒ ('out ⇒ bool) ⇒ 'in set ⇒ ('a, 'out, 'in) gpv ⇒ bool" for P Q X gpv where
"pred_gpv' P Q X gpv"
if "⋀x. x ∈ results_gpv (ℐ_uniform UNIV X) gpv ⟹ P x" "⋀out. out ∈ outs_gpv (ℐ_uniform UNIV X) gpv ⟹ Q out"

(* pred_gpv is the special case X = UNIV. *)
lemma pred_gpv_conv_pred_gpv': "pred_gpv P Q = pred_gpv' P Q UNIV"
by(auto simp add: fun_eq_iff pred_gpv_def pred_gpv'.simps results_gpv_ℐ_full outs_gpv_ℐ_full)

(* Moving the converse graph of h from the input relation into map_gpv' on the left gpv. *)
lemma rel_gpv''_map_gpv'1:
"rel_gpv'' A C (BNF_Def.Grp UNIV h)¯¯ gpv gpv' ⟹ rel_gpv'' A C (=) (map_gpv' id id h gpv) gpv'"
apply(coinduction arbitrary: gpv gpv')
apply(drule rel_gpv''D)
apply(erule rel_spmf_mono)
apply(erule generat.rel_mono_strong; simp?)
apply(subst map_fun2_id)
by(auto simp add: rel_fun_comp intro!: rel_fun_map_fun1 elim: rel_fun_mono)

(* Suppose inputs are related by equality restricted to range h. Then mapping the inputs of the
   right gpv with h yields the converse graph of h as input relation. *)
lemma rel_gpv''_map_gpv'2:
"rel_gpv'' A C (eq_on (range h)) gpv gpv' ⟹ rel_gpv'' A C (BNF_Def.Grp UNIV h)¯¯ gpv (map_gpv' id id h gpv')"
apply(coinduction arbitrary: gpv gpv')
apply(drule rel_gpv''D)
apply(erule rel_spmf_mono_strong)
apply(erule generat.rel_mono_strong; simp?)
apply(subst map_fun_id2_in)
apply(rule rel_fun_map_fun2)
by (auto simp add: rel_fun_comp  elim: rel_fun_mono)

context
fixes A :: "'a ⇒ 'd ⇒ bool"
and C :: "'c ⇒ 'g ⇒ bool"
and R :: "'b ⇒ 'e ⇒ bool"
begin

(* The private lemmas f11, f21, f12 and f22 are helpers for Domainp_rel_gpv''_le. *)

(* f11: if Pure x is in the domain of the one-step relation, then x is in Domainp A. *)
private lemma f11:" Pure x ∈ set_spmf (the_gpv gpv) ⟹
Domainp (rel_generat A C (rel_fun R (rel_gpv'' A C R))) (Pure x) ⟹ Domainp A x"
by (auto simp add: pred_generat_def elim:bspec dest: generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])

(* f21: the output of an IO step related by the one-step relation is in Domainp C. *)
private lemma f21: "IO out c ∈ set_spmf (the_gpv gpv) ⟹
rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out c) ba ⟹ Domainp C out"
by (auto simp add: pred_generat_def elim:bspec dest: generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])

(* f12: let gpv be in the domain of rel_gpv'' and perform an IO step. Then its continuation for
   an input from Domainp R is again in that domain. Stated with the premises of the IO case of
   the results_gpv induction. *)
private lemma f12:
assumes "IO out c ∈ set_spmf (the_gpv gpv)"
and "input ∈ responses_ℐ (ℐ_uniform UNIV {x. Domainp R x}) out"
and "x ∈ results_gpv (ℐ_uniform UNIV {x. Domainp R x}) (c input)"
and "Domainp (rel_gpv'' A C R) gpv"
shows "Domainp (rel_gpv'' A C R) (c input)"
proof -
obtain b1 where o1:"rel_gpv'' A C R gpv b1" using assms(4) by clarsimp
obtain b2 where o2:"rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out c) b2"
using assms(1) o1[THEN rel_gpv''D, THEN spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
unfolding pred_spmf_def by - (drule (1) bspec, auto)

have "Ball (generat_conts (IO out c)) (Domainp (rel_fun R (rel_gpv'' A C R)))"
using o2[THEN generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
unfolding pred_generat_def by simp

with assms(2) show ?thesis
apply -
apply(drule bspec)
apply simp
apply clarify
apply(drule Domainp_rel_fun_le[THEN predicate1D, OF Domainp_iff[THEN iffD2], OF exI])
by simp
qed

(* f22: the same as f12, stated with the premises of the Cont case of the outs_gpv induction. *)
private lemma f22:
assumes "IO out' rpv ∈ set_spmf (the_gpv gpv)"
and "input ∈ responses_ℐ (ℐ_uniform UNIV {x. Domainp R x}) out'"
and "out ∈ outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) (rpv input)"
and "Domainp (rel_gpv'' A C R) gpv"
shows "Domainp (rel_gpv'' A C R) (rpv input)"
proof -
obtain b1 where o1:"rel_gpv'' A C R gpv b1" using assms(4) by auto
obtain b2 where o2:"rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out' rpv) b2"
using assms(1) o1[THEN rel_gpv''D, THEN spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
unfolding pred_spmf_def by - (drule (1) bspec, auto)

have "Ball (generat_conts (IO out' rpv)) (Domainp (rel_fun R (rel_gpv'' A C R)))"
using o2[THEN generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
unfolding pred_generat_def by simp

with assms(2) show ?thesis
apply -
apply(drule bspec)
apply simp
apply clarify
apply(drule Domainp_rel_fun_le[THEN predicate1D, OF Domainp_iff[THEN iffD2], OF exI])
by simp
qed

(* The domain of rel_gpv'' is included in pred_gpv' with the domains of A, C and R: results
   reachable with inputs from Domainp R lie in Domainp A, and such outputs lie in Domainp C.
   NOTE(review): this proof opens three nested proofs, but only one `qed` follows in this copy.
   The `qed` closing the results induction (IO case, presumably via f12), the one closing the
   outs induction (Cont case, presumably via f22) and the final `qed` appear to be missing.
   Restore them from the original source. *)
lemma Domainp_rel_gpv''_le:
"Domainp (rel_gpv'' A C R) ≤ pred_gpv' (Domainp A) (Domainp C) {x. Domainp R x}"
proof(rule predicate1I pred_gpv'.intros)+
show "Domainp A x" if "x ∈ results_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv" "Domainp (rel_gpv'' A C R) gpv" for x gpv using that
proof(induction)
case (Pure gpv)
then show ?case
by (clarify) (drule rel_gpv''D
, auto simp add: f11 pred_spmf_def dest: spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])
show "Domainp C out" if "out ∈ outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv" "Domainp (rel_gpv'' A C R) gpv" for out gpv using that
proof( induction)
case (IO c gpv)
then show ?case
by (clarify) (drule rel_gpv''D
, auto simp add: f21 pred_spmf_def dest!: bspec spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])
qed

end

(* map_gpv' factors into map_gpv on results/outputs followed by map_gpv' id id h on inputs. *)
lemma map_gpv'_id12: "map_gpv' f g h gpv = map_gpv' id id h (map_gpv f g gpv)"
unfolding map_gpv_conv_map_gpv' map_gpv'_comp by simp

(* rel_gpv'' is reflexive when A and C contain equality and R is contained in equality. *)
lemma rel_gpv''_refl: "⟦ (=) ≤ A; (=) ≤ C; R ≤ (=) ⟧ ⟹ (=) ≤ rel_gpv'' A C R"
by(subst rel_gpv''_eq[symmetric])(rule rel_gpv''_mono)

context
fixes A A' :: "'a ⇒ 'b ⇒ bool"
and C C' :: "'c ⇒ 'd ⇒ bool"
and R R' :: "'e ⇒ 'f ⇒ bool"

begin

(* Local shorthand for the proof of rel_gpv''_mono_strong. foo fx fy gpvx gpvy g1 g2 states that
   g1 x y implies g2 x y for all x ∈ fx … gpvx reachable with inputs from Domainp R' and all
   y ∈ fy … gpvy reachable with inputs from Rangep R'. *)
private abbreviation foo where
"foo ≡ (λ fx fy gpvx gpvy g1 g2.
∀x y. x ∈ fx (ℐ_uniform UNIV (Collect (Domainp R'))) gpvx ⟶
y ∈ fy (ℐ_uniform UNIV (Collect (Rangep R'))) gpvy ⟶ g1 x y ⟶ g2 x y)"

(* f1: the foo property for results is inherited by the continuations a a' and b b', where a'
   is in the domain and b' is in the range of R'. *)
private lemma f1: "foo results_gpv results_gpv gpv gpv' A A' ⟹
x ∈ set_spmf (the_gpv gpv) ⟹ y ∈ set_spmf (the_gpv gpv') ⟹
a ∈ generat_conts x ⟹ b ∈ generat_conts y ⟹  R' a' α ⟹ R' β b' ⟹
foo results_gpv results_gpv (a a') (b b') A A'"
by (fastforce elim: generat.set_cases intro: results_gpv.IO)

(* f2: the same inheritance property for outputs. *)
private lemma f2: "foo outs_gpv outs_gpv gpv gpv' C C' ⟹
x ∈ set_spmf (the_gpv gpv) ⟹ y ∈ set_spmf (the_gpv gpv') ⟹
a ∈ generat_conts x ⟹ b ∈ generat_conts y ⟹ R' a' α ⟹ R' β b' ⟹
foo outs_gpv outs_gpv (a a') (b b') C C'"
by (fastforce elim: generat.set_cases intro: outs_gpv.Cont)

(* Strong monotonicity of rel_gpv''. A and C need only imply A' and C' on reachable results and
   outputs, and the input relation may be strengthened to R' ≤ R. Proved by coinduction: the
   Pure and IO cases use the hypotheses on the current step, and the continuations use f1/f2. *)
lemma rel_gpv''_mono_strong:
"⟦ rel_gpv'' A C R gpv gpv';
⋀x y. ⟦ x ∈ results_gpv (ℐ_uniform UNIV {x. Domainp R' x}) gpv; y ∈ results_gpv (ℐ_uniform UNIV {x. Rangep R' x}) gpv'; A x y ⟧ ⟹ A' x y;
⋀x y. ⟦ x ∈ outs_gpv (ℐ_uniform UNIV {x. Domainp R' x}) gpv; y ∈ outs_gpv (ℐ_uniform UNIV {x. Rangep R' x}) gpv'; C x y ⟧ ⟹ C' x y;
R' ≤ R ⟧
⟹ rel_gpv'' A' C' R' gpv gpv'"
apply(coinduction arbitrary: gpv gpv')
apply(drule rel_gpv''D)
apply(erule rel_spmf_mono_strong)
apply(erule generat.rel_mono_strong)
apply(erule generat.set_cases)+
apply(erule allE, rotate_tac -1)
apply(erule allE)
apply(erule impE)
apply(rule results_gpv.Pure)
apply simp
apply(erule impE)
apply(rule results_gpv.Pure)
apply simp
apply simp
apply(erule generat.set_cases)+
apply(rotate_tac 1)
apply(erule allE, rotate_tac -1)
apply(erule allE)
apply(erule impE)
apply(rule outs_gpv.IO)
apply simp
apply(erule impE)
apply(rule outs_gpv.IO)
apply simp
apply simp
apply(erule (1) rel_fun_mono_strong)
by (fastforce simp add: f1[simplified] f2[simplified])

end

(* Reflexivity of rel_gpv'' A C R on a single gpv, assuming:
   - A is reflexive on the results of gpv reachable with responses from the
     domain of R;
   - C is reflexive on the corresponding outputs;
   - R ≤ (=).
   Derived from rel_gpv'' (=) (=) (=) gpv gpv by rel_gpv''_mono_strong. *)
lemma rel_gpv''_refl_strong:
assumes "⋀x. x ∈ results_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv ⟹ A x x"
and "⋀x. x ∈ outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv ⟹ C x x"
and "R ≤ (=)"
shows "rel_gpv'' A C R gpv gpv"
proof -
have "rel_gpv'' (=) (=) (=) gpv gpv" unfolding rel_gpv''_eq by simp
then show ?thesis using _ _ assms(3) by(rule rel_gpv''_mono_strong)(auto intro: assms(1-2))
qed

(* Instance of rel_gpv''_refl_strong with the response relation eq_on X. *)
lemma rel_gpv''_refl_eq_on:
"⟦ ⋀x. x ∈ results_gpv (ℐ_uniform UNIV X) gpv ⟹ A x x; ⋀out. out ∈ outs_gpv (ℐ_uniform UNIV X) gpv ⟹ B out out ⟧
⟹ rel_gpv'' A B (eq_on X) gpv gpv"
by(rule rel_gpv''_refl_strong) (auto elim: eq_onE)

(* pred_gpv' is monotone in its result predicate and its output predicate.
   The lemma is registered with the [mono] attribute. *)
lemma pred_gpv'_mono' [mono]:
"pred_gpv' A C R gpv ⟶ pred_gpv' A' C' R gpv"
if "⋀x. A x ⟶ A' x" "⋀x. C x ⟶ C' x"
using that unfolding pred_gpv'.simps
by auto

subsubsection ‹Type judgements›

(* Well-typedness Γ ⊢g gpv √ of a gpv with respect to interface Γ, defined
   coinductively. Every output issued must be in outs_ℐ Γ. Every continuation
   applied to a response in responses_ℐ Γ out must again be well-typed. *)
coinductive WT_gpv :: "('out, 'in) ℐ ⇒ ('a, 'out, 'in) gpv ⇒ bool" ("((_)/ ⊢g (_) √)" [100, 0] 99)
for Γ
where
"(⋀out c. IO out c ∈ set_spmf gpv ⟹ out ∈ outs_ℐ Γ ∧ (∀input∈responses_ℐ Γ out. Γ ⊢g c input √))
⟹ Γ ⊢g GPV gpv √"

(* Coinduction principle for WT_gpv, stated with the_gpv instead of the GPV
   constructor. In the step, a continuation may either satisfy X again or
   already be well-typed. Declared as the coinduction rule for WT_gpv. *)
lemma WT_gpv_coinduct [consumes 1, case_names WT_gpv, case_conclusion WT_gpv out cont, coinduct pred: WT_gpv]:
assumes *: "X gpv"
and step: "⋀gpv out c.
⟦ X gpv; IO out c ∈ set_spmf (the_gpv gpv) ⟧
⟹ out ∈ outs_ℐ Γ ∧ (∀input ∈ responses_ℐ Γ out. X (c input) ∨ Γ ⊢g c input √)"
shows "Γ ⊢g gpv √"
using * by(coinduct)(auto dest: step simp add: eq_GPV_iff)

(* Unfolding equation for WT_gpv on the GPV constructor. *)
lemma WT_gpv_simps:
"Γ ⊢g GPV gpv √ ⟷
(∀out c. IO out c ∈ set_spmf gpv ⟶ out ∈ outs_ℐ Γ ∧ (∀input∈responses_ℐ Γ out. Γ ⊢g c input √))"
by(subst WT_gpv.simps) simp

(* Introduction rule for WT_gpv, stated in terms of the_gpv.
   Proved by splitting gpv into its GPV constructor and unfolding with
   WT_gpv_simps; the premise discharges the resulting universal statement. *)
lemma WT_gpvI:
"(⋀out c. IO out c ∈ set_spmf (the_gpv gpv) ⟹ out ∈ outs_ℐ Γ ∧ (∀input∈responses_ℐ Γ out. Γ ⊢g c input √))
⟹ Γ ⊢g gpv √"
by(cases gpv)(simp add: WT_gpv_simps)

(* Destruction rules for a well-typed gpv. Every output it issues is in
   outs_ℐ Γ (WT_gpv_OutD). Every continuation applied to an admissible
   response is well-typed (WT_gpv_ContD). *)
lemma WT_gpvD:
assumes "Γ ⊢g gpv √"
shows WT_gpv_OutD: "IO out c ∈ set_spmf (the_gpv gpv) ⟹ out ∈ outs_ℐ Γ"
and WT_gpv_ContD: "⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ Γ out ⟧ ⟹ Γ ⊢g c input √"
using assms by(cases, fastforce)+

(* Well-typedness is preserved when the set of allowed outputs grows and, for
   outputs of ℐ1, the set of possible responses shrinks. *)
lemma WT_gpv_mono:
assumes WT: "ℐ1 ⊢g gpv √"
and outs: "outs_ℐ ℐ1 ⊆ outs_ℐ ℐ2"
and responses: "⋀x. x ∈ outs_ℐ ℐ1 ⟹ responses_ℐ ℐ2 x ⊆ responses_ℐ ℐ1 x"
shows "ℐ2 ⊢g gpv √"
using WT
proof coinduct
case (WT_gpv gpv out c)
with outs show ?case by(auto 6 4 dest: responses WT_gpvD)
qed

(* Done and Fail are well-typed under every interface. *)
lemma WT_gpv_Done [iff]: "Γ ⊢g Done x √"
by(rule WT_gpvI) simp_all

lemma WT_gpv_Fail [iff]: "Γ ⊢g Fail √"
by(rule WT_gpvI) simp_all

(* Pause out c is well-typed if out is an allowed output and c is
   well-typed on all responses to out; WT_gpv_Pause is the iff version. *)
lemma WT_gpv_PauseI:
"⟦ out ∈ outs_ℐ Γ; ⋀input. input ∈ responses_ℐ Γ out ⟹ Γ ⊢g c input √ ⟧
⟹ Γ ⊢g Pause out c √"
by(rule WT_gpvI) simp_all

lemma WT_gpv_Pause [iff]:
"Γ ⊢g Pause out c √ ⟷ out ∈ outs_ℐ Γ ∧ (∀input ∈ responses_ℐ Γ out. Γ ⊢g c input √)"
by(auto intro: WT_gpv_PauseI dest: WT_gpvD)

(* gpv ⤜ f is well-typed if gpv is well-typed and f x is well-typed for
   every result x of gpv reachable under Γ. Coinduction over the bind: an IO
   node of the bind either stems from f applied to a Pure result of gpv, or
   from an IO node of gpv whose continuation is then bound with f. *)
lemma WT_gpv_bindI:
"⟦ Γ ⊢g gpv √; ⋀x. x ∈ results_gpv Γ gpv ⟹ Γ ⊢g f x √ ⟧
⟹ Γ ⊢g gpv ⤜ f √"
proof(coinduction arbitrary: gpv)
case [rule_format]: (WT_gpv out c gpv)
from ‹IO out c ∈ _›
obtain generat where generat: "generat ∈ set_spmf (the_gpv gpv)"
and *: "IO out c ∈ set_spmf (if is_Pure generat then the_gpv (f (result generat))
else return_spmf (IO (output generat) (λinput. continuation generat input ⤜ f)))"
by(auto)
show ?case
proof(cases generat)
(* the IO node comes from f y for a result y of gpv *)
case (Pure y)
with generat have "y ∈ results_gpv Γ gpv" by(auto intro: results_gpv.Pure)
hence "Γ ⊢g f y √" by(rule WT_gpv)
with * Pure show ?thesis by(auto dest: WT_gpvD)
next
(* the IO node is an IO node of gpv with continuation bound to f *)
case (IO out' c')
hence [simp]: "out' = out"
and c: "⋀input. c input = bind_gpv (c' input) f" using * by simp_all
from generat IO have **: "IO out c' ∈ set_spmf (the_gpv gpv)" by simp
with ‹Γ ⊢g gpv √› have ?out by(auto dest: WT_gpvD)
moreover {
fix input
assume input: "input ∈ responses_ℐ Γ out"
with ‹Γ ⊢g gpv √› ** have "Γ ⊢g c' input √" by(rule WT_gpvD)
moreover {
fix y
assume "y ∈ results_gpv Γ (c' input)"
with ** input have "y ∈ results_gpv Γ gpv" by(rule results_gpv.IO)
hence "Γ ⊢g f y √" by(rule WT_gpv) }
moreover note calculation }
hence ?cont using c by blast
ultimately show ?thesis ..
qed
qed

(* If gpv ⤜ f is well-typed, then f x is well-typed for every result x of
   gpv reachable under Γ. Induction on x ∈ results_gpv Γ gpv. *)
lemma WT_gpv_bindD2:
assumes WT: "Γ ⊢g gpv ⤜ f √"
and x: "x ∈ results_gpv Γ gpv"
shows "Γ ⊢g f x √"
using x WT
proof induction
case (Pure gpv)
show ?case
proof(rule WT_gpvI)
fix out c
assume "IO out c ∈ set_spmf (the_gpv (f x))"
with Pure have "IO out c ∈ set_spmf (the_gpv (gpv ⤜ f))" by(auto intro: rev_bexI)
with ‹Γ ⊢g gpv ⤜ f √› show "out ∈ outs_ℐ Γ ∧ (∀input∈responses_ℐ Γ out. Γ ⊢g c input √)"
by(auto dest: WT_gpvD simp del: set_bind_spmf)
qed
next
case (IO out c gpv input)
from ‹IO out c ∈ _›
have "IO out (λinput. bind_gpv (c input) f) ∈ set_spmf (the_gpv (gpv ⤜ f))"
by(auto intro: rev_bexI)
with IO.prems have "Γ ⊢g c input ⤜ f √" using ‹input ∈ _› by(rule WT_gpv_ContD)
thus ?case by(rule IO.IH)
qed

(* If gpv ⤜ f is well-typed, then so is gpv: each IO node of gpv induces an
   IO node of the bind with the same output. *)
lemma WT_gpv_bindD1: "Γ ⊢g gpv ⤜ f √ ⟹ Γ ⊢g gpv √"
proof(coinduction arbitrary: gpv)
case (WT_gpv out c gpv)
from ‹IO out c ∈ _›
have "IO out (λinput. bind_gpv (c input) f) ∈ set_spmf (the_gpv (gpv ⤜ f))"
by(auto intro: rev_bexI)
with ‹Γ ⊢g gpv ⤜ f √› show ?case
by(auto simp del: bind_gpv_sel' dest: WT_gpvD)
qed

(* Characterisation of well-typedness of bind, combining WT_gpv_bindI,
   WT_gpv_bindD1 and WT_gpv_bindD2. *)
lemma WT_gpv_bind [simp]: "Γ ⊢g gpv ⤜ f √ ⟷ Γ ⊢g gpv √ ∧ (∀x∈results_gpv Γ gpv. Γ ⊢g f x √)"
by(blast intro: WT_gpv_bindI dest: WT_gpv_bindD1 WT_gpv_bindD2)

(* Every gpv is well-typed with respect to the full interface ℐ_full. *)
lemma WT_gpv_full [simp, intro!]: "ℐ_full ⊢g gpv √"
by(coinduction arbitrary: gpv)(auto)

(* lift_spmf p is well-typed under any interface. *)
lemma WT_gpv_lift_spmf [simp, intro!]: "ℐ ⊢g lift_spmf p √"
by(rule WT_gpvI) auto

(* Coinduction up to bind for WT_gpv. In the step, a continuation c input may:
   - satisfy X again;
   - already be well-typed; or
   - be a bind gpv' ⤜ f of a well-typed gpv' whose reachable results are all
     mapped by f into X.
   Proof: view gpv as Done x ⤜ (λ_. gpv) and run WT_gpv_coinduct on binds
   gpv' ⤜ f whose continuations f land in X. *)
lemma WT_gpv_coinduct_bind [consumes 1, case_names WT_gpv, case_conclusion WT_gpv out cont]:
assumes *: "X gpv"
and step: "⋀gpv out c. ⟦ X gpv; IO out c ∈ set_spmf (the_gpv gpv) ⟧
⟹ out ∈ outs_ℐ ℐ ∧ (∀input∈responses_ℐ ℐ out.
X (c input) ∨
ℐ ⊢g c input √ ∨
(∃(gpv' :: ('b, 'call, 'ret) gpv) f. c input = gpv' ⤜ f ∧ ℐ ⊢g gpv' √ ∧ (∀x∈results_gpv ℐ gpv'. X (f x))))"
shows "ℐ ⊢g gpv √"
proof -
fix x
define gpv' :: "('b, 'call, 'ret) gpv" and f :: "'b ⇒ ('a, 'call, 'ret) gpv"
where "gpv' = Done x" and "f = (λ_. gpv)"
with * have "ℐ ⊢g gpv' √" and "⋀x. x ∈ results_gpv ℐ gpv' ⟹ X (f x)" by simp_all
then have "ℐ ⊢g gpv' ⤜ f √"
proof(coinduction arbitrary: gpv' f rule: WT_gpv_coinduct)
case [rule_format]: (WT_gpv out c gpv')
from ‹IO out c ∈ _›
obtain generat where generat: "generat ∈ set_spmf (the_gpv gpv')"
and *: "IO out c ∈ set_spmf (if is_Pure generat
then the_gpv (f (result generat))
else return_spmf (IO (output generat) (λinput. continuation generat input ⤜ f)))"
by(clarsimp)
show ?case
proof(cases generat)
(* the IO node comes from f x, which satisfies X: apply step *)
case (Pure x)
from Pure * have IO: "IO out c ∈ set_spmf (the_gpv (f x))" by simp
from generat Pure have "x ∈ results_gpv ℐ gpv'" by (simp add: results_gpv.Pure)
then have "X (f x)" by(rule WT_gpv)
from step[OF this IO] show ?thesis by(auto 4 4 intro: exI[where x="Done _"])
next
(* the IO node comes from gpv'; its continuations stay binds with f *)
case (IO out' c')
with * have [simp]: "out' = out"
and c: "c = (λinput. c' input ⤜ f)" by simp_all
from IO generat have IO: "IO out c' ∈ set_spmf (the_gpv gpv')" by simp
then have "⋀input. input ∈ responses_ℐ ℐ out ⟹ results_gpv ℐ (c' input) ⊆ results_gpv ℐ gpv'"
by(auto intro: results_gpv.IO)
with WT_gpvD[OF ‹ℐ ⊢g gpv' √› IO] show ?thesis unfolding c using WT_gpv(2) by blast
qed
qed
then show ?thesis unfolding gpv'_def f_def by simp
qed

(* Under a trivial interface every gpv is well-typed (via WT_gpv_full and
   WT_gpv_mono). *)
lemma ℐ_trivial_WT_gpvD [simp]: "ℐ_trivial ℐ ⟹ ℐ ⊢g gpv √"
using WT_gpv_full by(rule WT_gpv_mono)(simp_all add: ℐ_trivial_def)

(* Conversely, if every gpv is well-typed, then ℐ is trivial: Pause x (λ_. Fail)
   is well-typed only if x is an allowed output. *)
lemma ℐ_trivial_WT_gpvI:
assumes "⋀gpv :: ('a, 'out, 'in) gpv. ℐ ⊢g gpv √"
shows "ℐ_trivial ℐ"
proof
fix x
have "ℐ ⊢g Pause x (λ_. Fail :: ('a, 'out, 'in) gpv) √" by(rule assms)
thus "x ∈ outs_ℐ ℐ" by(simp)
qed

(* Well-typedness is monotone in the interface order ≤. *)
lemma WT_gpv_ℐ_mono: "⟦ ℐ ⊢g gpv √; ℐ ≤ ℐ' ⟧ ⟹ ℐ' ⊢g gpv √"
by(erule WT_gpv_mono; rule outs_ℐ_mono responses_ℐ_mono)

(* For a gpv that is well-typed under the smaller interface ℐ', every result
   reachable under ℐ is already reachable under ℐ'. *)
lemma results_gpv_mono:
assumes le: "ℐ' ≤ ℐ" and WT: "ℐ' ⊢g gpv √"
shows "results_gpv ℐ gpv ⊆ results_gpv ℐ' gpv"
proof(rule subsetI, goal_cases)
case (1 x)
show ?case using 1 WT by(induction)(auto 4 3 intro: results_gpv.intros responses_ℐ_mono[OF le, THEN subsetD] intro: WT_gpvD)
qed

(* All outputs of a well-typed gpv that are reachable under ℐ lie in outs_ℐ ℐ. *)
lemma WT_gpv_outs_gpv:
assumes "ℐ ⊢g gpv √"
shows "outs_gpv ℐ gpv ⊆ outs_ℐ ℐ"
proof
show "x ∈ outs_ℐ ℐ" if "x ∈ outs_gpv ℐ gpv" for x using that assms
by(induction)(blast intro: WT_gpv_OutD WT_gpv_ContD)+
qed

(* Well-typedness transfers along map_gpv' f g h when the interface is mapped
   with map_ℐ g h; map_gpv is the special case h = id. *)
lemma WT_gpv_map_gpv': "ℐ ⊢g map_gpv' f g h gpv √" if "map_ℐ g h ℐ ⊢g gpv √"
using that by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)

lemma WT_gpv_map_gpv: "ℐ ⊢g map_gpv f g gpv √" if "map_ℐ g id ℐ ⊢g gpv √"
unfolding map_gpv_conv_map_gpv' using that by(rule WT_gpv_map_gpv')

(* Results of map_gpv' f g h gpv under ℐ are the f-image of the results of gpv
   under the mapped interface map_ℐ g h ℐ. Each inclusion is proved by
   induction on results_gpv. *)
lemma results_gpv_map_gpv' [simp]:
"results_gpv ℐ (map_gpv' f g h gpv) = f ` (results_gpv (map_ℐ g h ℐ) gpv)"
proof(intro Set.set_eqI iffI; (elim imageE; hypsubst)?)
show "x ∈ f ` results_gpv (map_ℐ g h ℐ) gpv" if "x ∈ results_gpv ℐ (map_gpv' f g h gpv)" for x using that
by(induction gpv'≡"map_gpv' f g h gpv" arbitrary: gpv)(fastforce intro: results_gpv.intros rev_image_eqI)+
show "f x ∈ results_gpv ℐ (map_gpv' f g h gpv)" if "x ∈ results_gpv (map_ℐ g h ℐ) gpv" for x using that
by(induction)(fastforce intro: results_gpv.intros)+
qed

(* Parametricity of WT_gpv with respect to rel_ℐ C R and rel_gpv'' A C R,
   assuming C is bi-unique. The right-to-left direction (labelled * ) is proved
   by coinduction, transferring each IO node and response across the
   relations. The left-to-right direction is meant to follow from * instantiated
   with the converse relations.
   NOTE(review): the final "using *[of ...] that" has no closing proof method.
   As written, the second "show" is not finished. It presumably needs lemmas
   that rewrite the converse relations (rel_ℐ, rel_gpv'', bi_unique) back to
   the original ones -- restore the missing "by ...". *)
lemma WT_gpv_parametric': includes lifting_syntax shows
"bi_unique C ⟹ (rel_ℐ C R ===> rel_gpv'' A C R ===> (=)) WT_gpv WT_gpv"
proof(rule rel_funI iffI)+
note [transfer_rule] = the_gpv_parametric'
show *: "ℐ ⊢g gpv √" if [transfer_rule]: "rel_ℐ C R ℐ ℐ'" "bi_unique C"
and *: "ℐ' ⊢g gpv' √" "rel_gpv'' A C R gpv gpv'" for ℐ ℐ' gpv gpv' A C R
using *
proof(coinduction arbitrary: gpv gpv')
case (WT_gpv out c gpv gpv')
note [transfer_rule] = WT_gpv(2)
have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf (the_gpv gpv)) (set_spmf (the_gpv gpv'))"
by transfer_prover
from rel_setD1[OF this WT_gpv(3)] obtain out' c'
where [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
and out': "IO out' c' ∈ set_spmf (the_gpv gpv')"
by(auto elim: generat.rel_cases)
have "out ∈ outs_ℐ ℐ ⟷ out' ∈ outs_ℐ ℐ'" by transfer_prover
with WT_gpvD(1)[OF WT_gpv(1) out'] have ?out by simp
moreover have ?cont
proof(standard; goal_cases cont)
case (cont input)
have "rel_set R (responses_ℐ ℐ out) (responses_ℐ ℐ' out')" by transfer_prover
from rel_setD1[OF this cont] obtain input' where [transfer_rule]: "R input input'"
and input': "input' ∈ responses_ℐ ℐ' out'" by blast
have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
with WT_gpvD(2)[OF WT_gpv(1) out' input'] show ?case by auto
qed
ultimately show ?case ..
qed

show "ℐ' ⊢g gpv' √" if "rel_ℐ C R ℐ ℐ'" "bi_unique C" "ℐ ⊢g gpv √" "rel_gpv'' A C R gpv gpv'"
for ℐ ℐ' gpv gpv'
using *[of "conversep C" "conversep R" ℐ' ℐ gpv "conversep A" gpv'] that
qed

(* Mapping only the results (outputs mapped with id) does not affect
   well-typedness. Derived from WT_gpv_parametric' instantiated with graph
   relations (BNF_Def.Grp). *)
lemma WT_gpv_map_gpv_id [simp]: "ℐ ⊢g map_gpv f id gpv √ ⟷ ℐ ⊢g gpv √"
using WT_gpv_parametric'[of "BNF_Def.Grp UNIV id" "(=)" "BNF_Def.Grp UNIV f", folded rel_gpv_conv_rel_gpv'']
unfolding gpv.rel_Grp unfolding eq_alt[symmetric] relator_eq
by(auto simp add: rel_fun_def Grp_def bi_unique_eq)

(* A gpv is well-typed if all its outputs reachable under ℐ are allowed by ℐ.
   Together with WT_gpv_outs_gpv this gives an equivalence. *)
lemma WT_gpv_outs_gpvI:
assumes "outs_gpv ℐ gpv ⊆ outs_ℐ ℐ"
shows "ℐ ⊢g gpv √"
using assms by(coinduction arbitrary: gpv)(auto intro: outs_gpv.intros)

lemma WT_gpv_iff_outs_gpv:
"ℐ ⊢g gpv √ ⟷ outs_gpv ℐ gpv ⊆ outs_ℐ ℐ"
by(blast intro: WT_gpv_outs_gpvI dest: WT_gpv_outs_gpv)

subsection ‹Sub-gpvs›

context begin
(* sub_gpvsp ℐ x gpv: x is reachable from gpv by applying continuations of IO
   nodes to responses admissible under ℐ, one or more times. *)
qualified inductive sub_gpvsp :: "('out, 'in) ℐ ⇒ ('a, 'out, 'in) gpv ⇒ ('a, 'out, 'in) gpv ⇒ bool"
for ℐ x
where
base:
"⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ ℐ out; x = c input ⟧
⟹ sub_gpvsp ℐ x gpv"
| cont:
"⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ ℐ out; sub_gpvsp ℐ x (c input) ⟧
⟹ sub_gpvsp ℐ x gpv"

(* Variant of the base rule without the equation x = c input. *)
qualified lemma sub_gpvsp_base:
"⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ ℐ out ⟧
⟹ sub_gpvsp ℐ (c input) gpv"
by(rule base) simp_all

(* Set version of sub_gpvsp. *)
definition sub_gpvs :: "('out, 'in) ℐ ⇒  ('a, 'out, 'in) gpv ⇒ ('a, 'out, 'in) gpv set"
where "sub_gpvs ℐ gpv ≡ {x. sub_gpvsp ℐ x gpv}"

(* Predicate/set correspondence used by the [to_set] conversions below. *)
lemma sub_gpvsp_sub_gpvs_eq [pred_set_conv]: "sub_gpvsp ℐ x gpv ⟷ x ∈ sub_gpvs ℐ gpv"
by(simp add: sub_gpvs_def)

context begin
(* Register the following facts under the "sub_gpvs" name prefix. *)
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "sub_gpvs")›

lemmas intros [intro?] = sub_gpvsp.intros[to_set]
and base = sub_gpvsp_base[to_set]
and cont = cont[to_set]
and induct [consumes 1, case_names Pure IO, induct set: sub_gpvs] = sub_gpvsp.induct[to_set]
and cases [consumes 1, case_names Pure IO, cases set: sub_gpvs] = sub_gpvsp.cases[to_set]
and simps = sub_gpvsp.simps[to_set]
end
end

(* Sub-gpvs of a well-typed gpv are well-typed (induction on sub_gpvs). *)
lemma WT_sub_gpvsD:
assumes "ℐ ⊢g gpv √" and "gpv' ∈ sub_gpvs ℐ gpv"
shows "ℐ ⊢g gpv' √"
using assms(2,1) by(induction)(auto dest: WT_gpvD)

(* A gpv is well-typed if its immediate outputs are allowed and all its
   sub-gpvs are well-typed. *)
lemma WT_sub_gpvsI:
"⟦ ⋀out c. IO out c ∈ set_spmf (the_gpv gpv) ⟹ out ∈ outs_ℐ Γ;
⋀gpv'. gpv' ∈ sub_gpvs Γ gpv ⟹ Γ ⊢g gpv' √ ⟧
⟹ Γ ⊢g gpv √"
by(rule WT_gpvI)(auto intro: sub_gpvs.base)

subsection ‹Losslessness›

text ‹A gpv is lossless iff it is guaranteed to produce a result after a finite number of interactions
that respect the interface. It is colossless if the interactions may go on forever, but no single
step fails to terminate, i.e., every step is a lossless distribution.›

text ‹We define both notions of losslessness simultaneously by mimicking what the (co)inductive
package would do internally. Thus, we get a constant that is parametrised by the choice of the
fixpoint; for non-recursive gpvs, we can therefore state and prove both versions of losslessness
in one go.›

(* Generic losslessness: F is the one-step functional; gen_lossless_gpv is
   its greatest fixpoint if co' (= co) is True and its least fixpoint otherwise. *)
context
fixes co :: bool and ℐ :: "('out, 'in) ℐ"
and F :: "(('a, 'out, 'in) gpv ⇒ bool) ⇒ (('a, 'out, 'in) gpv ⇒ bool)"
and co' :: bool
defines "F ≡ λgen_lossless_gpv gpv. ∃pa. gpv = GPV pa ∧
lossless_spmf pa ∧ (∀out c input. IO out c ∈ set_spmf pa ⟶ input ∈ responses_ℐ ℐ out ⟶ gen_lossless_gpv (c input))"
and "co' ≡ co" ― ‹We use a copy of @{term co} such that we can do case distinctions on @{term co'} without
the simplifier rewriting the @{term co} in the local abbreviations for the constants.›
begin

(* F is monotone, so both fixpoints are well-behaved. *)
lemma gen_lossless_gpv_mono: "mono F"
unfolding F_def
apply(rule monoI le_funI le_boolI')+
apply(tactic ‹REPEAT (resolve_tac @{context} (Inductive.get_monos @{context}) 1)›)
apply(erule le_funE)
apply(erule le_boolD)
done

definition gen_lossless_gpv :: "('a, 'out, 'in) gpv ⇒ bool"
where "gen_lossless_gpv = (if co' then gfp else lfp) F"

(* Fixpoint equation, valid for both choices of co'. *)
lemma gen_lossless_gpv_unfold: "gen_lossless_gpv = F gen_lossless_gpv"
by(simp add: gen_lossless_gpv_def gfp_unfold[OF gen_lossless_gpv_mono, symmetric] lfp_unfold[OF gen_lossless_gpv_mono, symmetric])

(* Meta-level definitional equations for each choice of co', in the form
   expected by def_coinduct / def_lfp_induct. atomize_eq turns the meta
   equality into a HOL equality so that simp can close it. *)
lemma gen_lossless_gpv_True: "co' = True ⟹ gen_lossless_gpv ≡ gfp F"
and gen_lossless_gpv_False: "co' = False ⟹ gen_lossless_gpv ≡ lfp F"
by(simp_all add: gen_lossless_gpv_def atomize_eq)

(* Case analysis on gen_lossless_gpv, obtained from the fixpoint equation. *)
lemma gen_lossless_gpv_cases [elim?, cases pred]:
assumes "gen_lossless_gpv gpv"
obtains (gen_lossless_gpv) p where "gpv = GPV p" "lossless_spmf p"
"⋀out c input. ⟦IO out c ∈ set_spmf p; input ∈ responses_ℐ ℐ out⟧ ⟹ gen_lossless_gpv (c input)"
proof -
from assms show ?thesis
by(rewrite in asm gen_lossless_gpv_unfold)(auto simp add: F_def intro: that)
qed

lemma gen_lossless_gpvD:
assumes "gen_lossless_gpv gpv"
shows gen_lossless_gpv_lossless_spmfD: "lossless_spmf (the_gpv gpv)"
and gen_lossless_gpv_continuationD:
"⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ ℐ out ⟧ ⟹ gen_lossless_gpv (c input)"
using assms by(auto elim: gen_lossless_gpv_cases)

(* Introduction rule on the GPV constructor: fold the fixpoint equation
   (instantiated pointwise via fun_cong) and unfold F. *)
lemma gen_lossless_gpv_intros:
"⟦ lossless_spmf p;
⋀out c input. ⟦IO out c ∈ set_spmf p; input ∈ responses_ℐ ℐ out ⟧ ⟹ gen_lossless_gpv (c input) ⟧
⟹ gen_lossless_gpv (GPV p)"
by(rule gen_lossless_gpv_unfold[THEN fun_cong, THEN iffD2])(simp add: F_def)

lemma gen_lossless_gpvI [intro?]:
"⟦ lossless_spmf (the_gpv gpv);
⋀out c input. ⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ ℐ out ⟧
⟹ gen_lossless_gpv (c input) ⟧
⟹ gen_lossless_gpv gpv"
by(cases gpv)(auto intro: gen_lossless_gpv_intros)

(* One-step unfolding of gen_lossless_gpv at gpv: the fixpoint equation is
   instantiated at gpv only, so simp cannot loop on the continuations. *)
lemma gen_lossless_gpv_simps:
"gen_lossless_gpv gpv ⟷
(∃p. gpv = GPV p ∧ lossless_spmf p ∧ (∀out c input.
IO out c ∈ set_spmf p ⟶ input ∈ responses_ℐ ℐ out ⟶ gen_lossless_gpv (c input)))"
using gen_lossless_gpv_unfold[THEN fun_cong, of gpv] by(simp add: F_def)

lemma gen_lossless_gpv_Done [iff]: "gen_lossless_gpv (Done x)"
by(rule gen_lossless_gpvI) auto

lemma gen_lossless_gpv_Fail [iff]: "¬ gen_lossless_gpv Fail"
by(auto dest: gen_lossless_gpvD)

lemma gen_lossless_gpv_Pause [simp]:
"gen_lossless_gpv (Pause out c) ⟷ (∀input ∈ responses_ℐ ℐ out. gen_lossless_gpv (c input))"
by(auto dest: gen_lossless_gpvD intro: gen_lossless_gpvI)

lemma gen_lossless_gpv_lift_spmf [iff]: "gen_lossless_gpv (lift_spmf p) ⟷ lossless_spmf p"
by(auto dest: gen_lossless_gpvD intro: gen_lossless_gpvI)

end

(* assert_gpv b is (co)lossless exactly when b holds. *)
lemma gen_lossless_gpv_assert_gpv [iff]: "gen_lossless_gpv co ℐ (assert_gpv b) ⟷ b"
by(cases b) simp_all

(* Inductive (least fixpoint) losslessness. *)
abbreviation lossless_gpv :: "('out, 'in) ℐ ⇒ ('a, 'out, 'in) gpv ⇒ bool"
where "lossless_gpv ≡ gen_lossless_gpv False"

(* Coinductive (greatest fixpoint) losslessness. *)
abbreviation colossless_gpv :: "('out, 'in) ℐ ⇒ ('a, 'out, 'in) gpv ⇒ bool"
where "colossless_gpv ≡ gen_lossless_gpv True"

(* Induction rule for lossless_gpv, derived from def_lfp_induct on the least
   fixpoint of F; declared as the induction rule for the predicate. *)
lemma lossless_gpv_induct [consumes 1, case_names lossless_gpv, induct pred]:
assumes *: "lossless_gpv ℐ gpv"
and step: "⋀p. ⟦ lossless_spmf p;
⋀out c input. ⟦IO out c ∈ set_spmf p; input ∈ responses_ℐ ℐ out⟧ ⟹ lossless_gpv ℐ (c input);
⋀out c input. ⟦IO out c ∈ set_spmf p; input ∈ responses_ℐ ℐ out⟧ ⟹ P (c input) ⟧
⟹ P (GPV p)"
shows "P gpv"
proof -
have "lossless_gpv ℐ ≤ P"
by(rule def_lfp_induct[OF gen_lossless_gpv_False gen_lossless_gpv_mono])(auto intro!: le_funI step)
then show ?thesis using * by auto
qed

(* Coinduction rule for colossless_gpv, derived from def_coinduct on the
   greatest fixpoint of F; declared as the coinduction rule for the predicate. *)
lemma colossless_gpv_coinduct
[consumes 1, case_names colossless_gpv, case_conclusion colossless_gpv lossless_spmf continuation, coinduct pred]:
assumes *: "X gpv"
and step: "⋀gpv. X gpv ⟹ lossless_spmf (the_gpv gpv) ∧ (∀out c input.
IO out c ∈ set_spmf (the_gpv gpv) ⟶ input ∈ responses_ℐ ℐ out ⟶ X (c input) ∨ colossless_gpv ℐ (c input))"
shows "colossless_gpv ℐ gpv"
proof -
have "X ≤ colossless_gpv ℐ"
by(rule def_coinduct[OF gen_lossless_gpv_True gen_lossless_gpv_mono])
(auto 4 4 intro!: le_funI dest!: step intro: exI[where x="the_gpv _"])
then show ?thesis using * by auto
qed

(* Specialisations of the generic intro/destruction rules to lossless_gpv
   (co = False) and colossless_gpv (co = True). *)
lemmas lossless_gpvI = gen_lossless_gpvI[where co=False]
and lossless_gpvD = gen_lossless_gpvD[where co=False]
and lossless_gpv_lossless_spmfD = gen_lossless_gpv_lossless_spmfD[where co=False]
and lossless_gpv_continuationD = gen_lossless_gpv_continuationD[where co=False]

lemmas colossless_gpvI = gen_lossless_gpvI[where co=True]
and colossless_gpvD = gen_lossless_gpvD[where co=True]
and colossless_gpv_lossless_spmfD = gen_lossless_gpv_lossless_spmfD[where co=True]
and colossless_gpv_continuationD = gen_lossless_gpv_continuationD[where co=True]

(* Bind preserves (co)losslessness if gpv is (co)lossless and f is
   (co)lossless on all results of gpv reachable under ℐ. Proved separately
   for co = False (induction) and co = True (coinduction up to bind). *)
lemma gen_lossless_bind_gpvI:
assumes "gen_lossless_gpv co ℐ gpv" "⋀x. x ∈ results_gpv ℐ gpv ⟹ gen_lossless_gpv co ℐ (f x)"
shows "gen_lossless_gpv co ℐ (gpv ⤜ f)"
proof(cases co)
case False
hence eq: "co = False" by simp
show ?thesis using assms unfolding eq
proof(induction)
case (lossless_gpv p)
{ fix x
assume "Pure x ∈ set_spmf p"
hence "x ∈ results_gpv ℐ (GPV p)" by simp
hence "lossless_gpv ℐ (f x)" by(rule lossless_gpv.prems) }
with ‹lossless_spmf p› show ?case unfolding GPV_bind
apply(intro gen_lossless_gpv_intros)
apply(fastforce dest: lossless_gpvD split: generat.split)
apply(clarsimp; split generat.split_asm)
apply(auto dest: lossless_gpvD intro!: lossless_gpv)
done
qed
next
case True
hence eq: "co = True" by simp
show ?thesis using assms unfolding eq
proof(coinduction arbitrary: gpv rule: colossless_gpv_coinduct)
case * [rule_format]: (colossless_gpv gpv)
from *(1) have ?lossless_spmf
by(auto 4 3 dest: colossless_gpv_lossless_spmfD elim!: is_PureE intro: *(2)[THEN colossless_gpv_lossless_spmfD] results_gpv.Pure)
moreover have ?continuation
proof(intro strip)
fix out c input
assume IO: "IO out c ∈ set_spmf (the_gpv (gpv ⤜ f))"
and input: "input ∈ responses_ℐ ℐ out"
from IO obtain generat where generat: "generat ∈ set_spmf (the_gpv gpv)"
and IO: "IO out c ∈ set_spmf (if is_Pure generat then the_gpv (f (result generat))
else return_spmf (IO (output generat) (λinput. continuation generat input ⤜ f)))"
by(auto)
show "(∃gpv. c input = gpv ⤜ f ∧ colossless_gpv ℐ gpv ∧ (∀x. x ∈ results_gpv ℐ gpv ⟶ colossless_gpv ℐ (f x))) ∨
colossless_gpv ℐ (c input)"
proof(cases generat)
(* the IO node lies inside f x: c input is already colossless *)
case (Pure x)
hence "x ∈ results_gpv ℐ gpv" using generat by(auto intro: results_gpv.Pure)
from *(2)[OF this] have "colossless_gpv ℐ (c input)"
using IO Pure input by(auto intro: colossless_gpv_continuationD)
thus ?thesis ..
next
(* the IO node comes from gpv: c input is again a bind with f *)
case **: (IO out' c')
with input generat IO have "colossless_gpv ℐ (f x)" if "x ∈ results_gpv ℐ (c' input)" for x
using that by(auto intro: * results_gpv.IO)
then show ?thesis using IO input ** *(1) generat by(auto dest: colossless_gpv_continuationD)
qed
qed
ultimately show ?case ..
qed
qed

lemmas lossless_bind_gpvI = gen_lossless_bind_gpvI[where co=False]
and colossless_bind_gpvI = gen_lossless_bind_gpvI[where co=True]

(* If gpv ⤜ f is (co)lossless, so is gpv. Induction over the lossless bind for
   co = False; coinduction for co = True. *)
lemma gen_lossless_bind_gpvD1:
assumes "gen_lossless_gpv co ℐ (gpv ⤜ f)"
shows "gen_lossless_gpv co ℐ gpv"
proof(cases co)
case False
hence eq: "co = False" by simp
show ?thesis using assms unfolding eq
proof(induction gpv'≡"gpv ⤜ f" arbitrary: gpv)
case (lossless_gpv p)
obtain p' where gpv: "gpv = GPV p'" by(cases gpv)
from lossless_gpv.hyps gpv have "lossless_spmf p'" by(simp add: GPV_bind)
then show ?case unfolding gpv
proof(rule gen_lossless_gpv_intros)
fix out c input
assume "IO out c ∈ set_spmf p'" "input ∈ responses_ℐ ℐ out"
hence "IO out (λinput. c input ⤜ f) ∈ set_spmf p" using lossless_gpv.hyps gpv
by(auto simp add: GPV_bind intro: rev_bexI)
thus "lossless_gpv ℐ (c input)" using ‹input ∈ _› by(rule lossless_gpv.hyps) simp
qed
qed
next
case True
hence eq: "co = True" by simp
show ?thesis using assms unfolding eq
by(coinduction arbitrary: gpv)(auto 4 3 intro: rev_bexI elim!: colossless_gpv_continuationD dest: colossless_gpv_lossless_spmfD)
qed

lemmas lossless_bind_gpvD1 = gen_lossless_bind_gpvD1[where co=False]
and colossless_bind_gpvD1 = gen_lossless_bind_gpvD1[where co=True]

(* If gpv ⤜ f is (co)lossless, then f x is (co)lossless for every result x
   of gpv reachable under ℐ. Induction on x ∈ results_gpv ℐ gpv. *)
lemma gen_lossless_bind_gpvD2:
assumes "gen_lossless_gpv co ℐ (gpv ⤜ f)"
and "x ∈ results_gpv ℐ gpv"
shows "gen_lossless_gpv co ℐ (f x)"
using assms(2,1)
proof(induction)
case (Pure gpv)
thus ?case
by -(rule gen_lossless_gpvI, auto 4 4 dest: gen_lossless_gpvD intro: rev_bexI)
qed(auto 4 4 dest: gen_lossless_gpvD intro: rev_bexI)

lemmas lossless_bind_gpvD2 = gen_lossless_bind_gpvD2[where co=False]
and colossless_bind_gpvD2 = gen_lossless_bind_gpvD2[where co=True]

(* Characterisation of (co)losslessness of bind, combining the I/D1/D2 lemmas. *)
lemma gen_lossless_bind_gpv [simp]:
"gen_lossless_gpv co ℐ (gpv ⤜ f) ⟷ gen_lossless_gpv co ℐ gpv ∧ (∀x∈results_gpv ℐ gpv. gen_lossless_gpv co ℐ (f x))"
by(blast intro: gen_lossless_bind_gpvI dest: gen_lossless_bind_gpvD1 gen_lossless_bind_gpvD2)

lemmas lossless_bind_gpv = gen_lossless_bind_gpv[where co=False]
and colossless_bind_gpv = gen_lossless_bind_gpv[where co=True]

(* Transfer of (co)losslessness along rel_gpv'' / rel_gpv and rel_ℐ. *)
context includes lifting_syntax begin

(* Losslessness transfers from left to right along rel_gpv'' A C R, given
   related interfaces. Induction on lossless_gpv; each IO node and response
   on the right is matched with a related one on the left via transfer_prover. *)
lemma rel_gpv''_lossless_gpvD1:
assumes rel: "rel_gpv'' A C R gpv gpv'"
and gpv: "lossless_gpv ℐ gpv"
and [transfer_rule]: "rel_ℐ C R ℐ ℐ'"
shows "lossless_gpv ℐ' gpv'"
using gpv rel
proof(induction arbitrary: gpv')
case (lossless_gpv p)
from lossless_gpv.prems obtain q where q: "gpv' = GPV q"
and [transfer_rule]: "rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) p q"
by(cases gpv') auto
show ?case
proof(rule lossless_gpvI)
have "lossless_spmf p = lossless_spmf q" by transfer_prover
with lossless_gpv.hyps(1) q show "lossless_spmf (the_gpv gpv')" by simp

fix out' c' input'
assume IO': "IO out' c' ∈ set_spmf (the_gpv gpv')"
and input': "input' ∈ responses_ℐ ℐ' out'"
have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf p) (set_spmf q)"
by transfer_prover
with IO' q obtain out c where IO: "IO out c ∈ set_spmf p"
and [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
by(auto dest!: rel_setD2 elim: generat.rel_cases)
have "rel_set R (responses_ℐ ℐ out) (responses_ℐ ℐ' out')" by transfer_prover
moreover
from this[THEN rel_setD2, OF input'] obtain input
where [transfer_rule]: "R input input'" and input: "input ∈ responses_ℐ ℐ out" by blast
have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
ultimately show "lossless_gpv ℐ' (c' input')" using input IO by(auto intro: lossless_gpv.IH)
qed
qed

(* Right-to-left direction, intended to follow from D1 with converse relations.
   NOTE(review): the "using" line below has no closing proof method, so the
   lemma is left unproved as written. It presumably needs rewriting of the
   converse relations (rel_gpv'' and rel_ℐ with ¯¯) back to the originals. *)
lemma rel_gpv''_lossless_gpvD2:
"⟦ rel_gpv'' A C R gpv gpv'; lossless_gpv ℐ' gpv'; rel_ℐ C R ℐ ℐ' ⟧
⟹ lossless_gpv ℐ gpv"
using rel_gpv''_lossless_gpvD1[of "A¯¯" "C¯¯" "R¯¯" gpv' gpv ℐ' ℐ]

(* rel_gpv version: rel_gpv is rel_gpv'' with R = (=). *)
lemma rel_gpv_lossless_gpvD1:
"⟦ rel_gpv A C gpv gpv'; lossless_gpv ℐ gpv; rel_ℐ C (=) ℐ ℐ' ⟧ ⟹ lossless_gpv ℐ' gpv'"
using rel_gpv''_lossless_gpvD1[of A C "(=)" gpv gpv' ℐ ℐ'] by(simp add: rel_gpv_conv_rel_gpv'')

(* NOTE(review): as above, the terminal proof method after "using" is missing. *)
lemma rel_gpv_lossless_gpvD2:
"⟦ rel_gpv A C gpv gpv'; lossless_gpv ℐ' gpv'; rel_ℐ C (=) ℐ ℐ' ⟧
⟹ lossless_gpv ℐ gpv"
using rel_gpv_lossless_gpvD1[of "A¯¯" "C¯¯" gpv' gpv ℐ' ℐ]

(* Colosslessness transfers from left to right along rel_gpv'' A C R, by
   coinduction; each continuation on the right is related to one on the left. *)
lemma rel_gpv''_colossless_gpvD1:
assumes rel: "rel_gpv'' A C R gpv gpv'"
and gpv: "colossless_gpv ℐ gpv"
and [transfer_rule]: "rel_ℐ C R ℐ ℐ'"
shows "colossless_gpv ℐ' gpv'"
using gpv rel
proof(coinduction arbitrary: gpv gpv')
case (colossless_gpv gpv gpv')
note [transfer_rule] = ‹rel_gpv'' A C R gpv gpv'› the_gpv_parametric'
and co = ‹colossless_gpv ℐ gpv›
have "lossless_spmf (the_gpv gpv) = lossless_spmf (the_gpv gpv')" by transfer_prover
with co have "?lossless_spmf" by(auto dest: colossless_gpv_lossless_spmfD)
moreover have "?continuation"
proof(intro strip disjI1)
fix out' c' input'
assume IO': "IO out' c' ∈ set_spmf (the_gpv gpv')"
and input': "input' ∈ responses_ℐ ℐ' out'"
have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf (the_gpv gpv)) (set_spmf (the_gpv gpv'))"
by transfer_prover
with IO' obtain out c where IO: "IO out c ∈ set_spmf (the_gpv gpv)"
and [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
by(auto dest!: rel_setD2 elim: generat.rel_cases)
have "rel_set R (responses_ℐ ℐ out) (responses_ℐ ℐ' out')" by transfer_prover
moreover
from this[THEN rel_setD2, OF input'] obtain input
where [transfer_rule]: "R input input'" and input: "input ∈ responses_ℐ ℐ out" by blast
have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
ultimately show "∃gpv gpv'. c' input' = gpv' ∧ colossless_gpv ℐ gpv ∧ rel_gpv'' A C R gpv gpv'"
using input IO co by(auto dest: colossless_gpv_continuationD)
qed
ultimately show ?case ..
qed

(* NOTE(review): the terminal proof method after "using" is missing. *)
lemma rel_gpv''_colossless_gpvD2:
"⟦ rel_gpv'' A C R gpv gpv'; colossless_gpv ℐ' gpv'; rel_ℐ C R ℐ ℐ' ⟧
⟹ colossless_gpv ℐ gpv"
using rel_gpv''_colossless_gpvD1[of "A¯¯" "C¯¯" "R¯¯" gpv' gpv ℐ' ℐ]

lemma rel_gpv_colossless_gpvD1:
"⟦ rel_gpv A C gpv gpv'; colossless_gpv ℐ gpv; rel_ℐ C (=) ℐ ℐ' ⟧ ⟹ colossless_gpv ℐ' gpv'"
using rel_gpv''_colossless_gpvD1[of A C "(=)" gpv gpv' ℐ ℐ'] by(simp add: rel_gpv_conv_rel_gpv'')

(* NOTE(review): the terminal proof method after "using" is missing. *)
lemma rel_gpv_colossless_gpvD2:
"⟦ rel_gpv A C gpv gpv'; colossless_gpv ℐ' gpv'; rel_ℐ C (=) ℐ ℐ' ⟧
⟹ colossless_gpv ℐ gpv"
using rel_gpv_colossless_gpvD1[of "A¯¯" "C¯¯" gpv' gpv ℐ' ℐ]

(* Parametricity of gen_lossless_gpv for rel_gpv'' (and, below, rel_gpv):
   case split on the fixpoint flag b and use the D1/D2 lemmas. *)
lemma gen_lossless_gpv_parametric':
"((=) ===> rel_ℐ C R ===> rel_gpv'' A C R ===> (=))
gen_lossless_gpv gen_lossless_gpv"
proof(rule rel_funI; hypsubst)
show "(rel_ℐ C R ===> rel_gpv'' A C R ===> (=)) (gen_lossless_gpv b) (gen_lossless_gpv b)" for b
by(cases b)(auto intro!: rel_funI dest: rel_gpv''_colossless_gpvD1 rel_gpv''_colossless_gpvD2 rel_gpv''_lossless_gpvD1 rel_gpv''_lossless_gpvD2)
qed

lemma gen_lossless_gpv_parametric [transfer_rule]:
"((=) ===> rel_ℐ C (=) ===> rel_gpv A C ===> (=))
gen_lossless_gpv gen_lossless_gpv"
proof(rule rel_funI; hypsubst)
show "(rel_ℐ C (=) ===> rel_gpv A C ===> (=)) (gen_lossless_gpv b) (gen_lossless_gpv b)" for b
by(cases b)(auto intro!: rel_funI dest: rel_gpv_colossless_gpvD1 rel_gpv_colossless_gpvD2 rel_gpv_lossless_gpvD1 rel_gpv_lossless_gpvD2)
qed

end

(* Under the full interface, map_gpv f g does not change (co)losslessness.
   Coinduction in each direction for b = True; induction for b = False. *)
lemma gen_lossless_gpv_map_full [simp]:
"gen_lossless_gpv b ℐ_full (map_gpv f g gpv) = gen_lossless_gpv b ℐ_full gpv"
(is "?lhs = ?rhs")
proof(cases "b = True")
case True
show "?lhs = ?rhs"
proof
show ?rhs if ?lhs using that unfolding True
by(coinduction arbitrary: gpv)(auto 4 3 dest: colossless_gpvD simp add: gpv.map_sel intro!: rev_image_eqI)
show ?lhs if ?rhs using that unfolding True
by(coinduction arbitrary: gpv)(auto 4 4 dest: colossless_gpvD simp add: gpv.map_sel intro!: rev_image_eqI)
qed
next
case False
hence False: "b = False" by simp
show "?lhs = ?rhs"
proof
show ?rhs if ?lhs using that unfolding False
apply(induction gpv'≡"map_gpv f g gpv" arbitrary: gpv)
subgoal for p gpv by(cases gpv)(rule lossless_gpvI; fastforce intro: rev_image_eqI)
done
show ?lhs if ?rhs using that unfolding False
by induction(auto 4 4 intro: lossless_gpvI)
qed
qed

(* Mapping only the results (outputs mapped with id) does not change
   (co)losslessness under any interface. Derived from
   gen_lossless_gpv_parametric with graph relations, following the same
   pattern as WT_gpv_map_gpv_id: rel_Grp turns the gpv relation into the
   graph of map_gpv f id, and eq_alt/relator_eq turn the remaining
   interface relation into equality. *)
lemma gen_lossless_gpv_map_id [simp]:
"gen_lossless_gpv b ℐ (map_gpv f id gpv) = gen_lossless_gpv b ℐ gpv"
using gen_lossless_gpv_parametric[of "BNF_Def.Grp UNIV id" "BNF_Def.Grp UNIV f"] unfolding gpv.rel_Grp
unfolding eq_alt[symmetric] relator_eq
by(auto simp add: rel_fun_def Grp_def)

(* Results of TRY gpv ELSE gpv': the results of gpv, plus the results of gpv'
   unless gpv is colossless (then the fallback is never taken).
   ⊆: induction on results of the try. ⊇: the left part by induction; the
   right part by contraposition, showing by coinduction that gpv would be
   colossless if x were not a result of the try. *)
lemma results_gpv_try_gpv [simp]:
"results_gpv ℐ (TRY gpv ELSE gpv') =
results_gpv ℐ gpv ∪ (if colossless_gpv ℐ gpv then {} else results_gpv ℐ gpv')"
(is "?lhs = ?rhs")
proof(intro set_eqI iffI)
show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
proof(induction gpv''≡"try_gpv gpv gpv'" arbitrary: gpv)
case Pure thus ?case
by(auto split: if_split_asm intro: results_gpv.Pure dest: colossless_gpv_lossless_spmfD)
next
case (IO out c input)
then show ?case
apply(auto dest: colossless_gpv_lossless_spmfD split: if_split_asm)
apply(force intro: results_gpv.IO dest: colossless_gpv_continuationD split: if_split_asm)+
done
qed
next
fix x
assume "x ∈ ?rhs"
then consider (left) "x ∈ results_gpv ℐ gpv" | (right) "¬ colossless_gpv ℐ gpv" "x ∈ results_gpv ℐ gpv'"
by(auto split: if_split_asm)
thus "x ∈ ?lhs"
proof cases
case left
thus ?thesis
by(induction)(auto 4 4 intro: results_gpv.intros rev_image_eqI split del: if_split)
next
case right
from right(1) show ?thesis
proof(rule contrapos_np)
assume "x ∉ ?lhs"
with right(2) show "colossless_gpv ℐ gpv"
proof(coinduction arbitrary: gpv)
case (colossless_gpv gpv)
then have ?lossless_spmf
apply(rewrite in asm try_gpv.code)
apply(rule ccontr)
apply(erule results_gpv.cases)
apply(fastforce simp add: image_Un image_image generat.map_comp o_def)+
done
moreover have "?continuation" using colossless_gpv
by(auto 4 4 split del: if_split simp add: image_Un image_image generat.map_comp o_def intro: rev_image_eqI results_gpv.IO)
ultimately show ?case ..
qed
qed
qed
qed

(* results'_gpv analogue of results_gpv_try_gpv, with ℐ_full as the interface
   for the colosslessness test.
   NOTE(review): this statement has no proof; the next "lemma" begins
   immediately. The proof is missing -- presumably either a direct
   induction/coinduction as in outs'_gpv_try_gpv below, or a reduction to
   results_gpv_try_gpv for ℐ_full. Restore it. *)
lemma results'_gpv_try_gpv [simp]:
"results'_gpv (TRY gpv ELSE gpv') =
results'_gpv gpv ∪ (if colossless_gpv ℐ_full gpv then {} else results'_gpv gpv')"

(* Outputs of TRY gpv ELSE gpv': the outputs of gpv, plus those of gpv' unless
   gpv is colossless under ℐ_full. Same proof structure as
   results_gpv_try_gpv, using the outs'_gpv rules. *)
lemma outs'_gpv_try_gpv [simp]:
"outs'_gpv (TRY gpv ELSE gpv') =
outs'_gpv gpv ∪ (if colossless_gpv ℐ_full gpv then {} else outs'_gpv gpv')"
(is "?lhs = ?rhs")
proof(intro set_eqI iffI)
show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
proof(induction gpv''≡"try_gpv gpv gpv'" arbitrary: gpv)
case Out thus ?case
by(auto 4 3 simp add: generat.map_comp o_def elim!: generat.set_cases(2) intro: outs'_gpv_Out split: if_split_asm dest: colossless_gpv_lossless_spmfD)
next
case (Cont generat c input)
then show ?case
apply(auto dest: colossless_gpv_lossless_spmfD split: if_split_asm elim!: generat.set_cases(3))
apply(auto 4 3 dest: colossless_gpv_continuationD split: if_split_asm intro: outs'_gpv_Cont elim!: meta_allE meta_impE[OF _ refl])+
done
qed
next
fix x
assume "x ∈ ?rhs"
then consider (left) "x ∈ outs'_gpv gpv" | (right) "¬ colossless_gpv ℐ_full gpv" "x ∈ outs'_gpv gpv'"
by(auto split: if_split_asm)
thus "x ∈ ?lhs"
proof cases
case left
thus ?thesis
by(induction)(auto elim!: generat.set_cases(2,3) intro: outs'_gpvI intro!: rev_image_eqI split del: if_split simp add: image_Un image_image generat.map_comp o_def)
next
case right
from right(1) show ?thesis
proof(rule contrapos_np)
assume "x ∉ ?lhs"
with right(2) show "colossless_gpv ℐ_full gpv"
proof(coinduction arbitrary: gpv)
case (colossless_gpv gpv)
then have ?lossless_spmf
apply(rewrite in asm try_gpv.code)
apply(erule contrapos_np)
apply(erule gpv.set_cases)
apply(auto 4 3 simp add: image_Un image_image generat.map_comp o_def generat.set_map in_set_spmf[symmetric] bind_UNION generat.map_id[unfolded id_def] elim!: generat.set_cases)
done
moreover have "?continuation" using colossless_gpv
by(auto simp add: image_Un image_image generat.map_comp o_def split del: if_split intro!: rev_image_eqI intro: outs'_gpv_Cont)
ultimately show ?case ..
qed
qed
qed
qed

(* The predicator on TRY gpv ELSE gpv' holds iff it holds on gpv and, unless
   gpv is colossless under ℐ_full, also on gpv'. Proved by unfolding
   pred_gpv into bounded quantifiers over results'_gpv/outs'_gpv
   (gpv.pred_set) and using the try lemmas results'_gpv_try_gpv and
   outs'_gpv_try_gpv (both [simp]). *)
lemma pred_gpv_try [simp]:
"pred_gpv P Q (try_gpv gpv gpv') = (pred_gpv P Q gpv ∧ (¬ colossless_gpv ℐ_full gpv ⟶ pred_gpv P Q gpv'))"
by(auto simp add: gpv.pred_set)

(* Induction over gpvs that are both lossless and well-typed: the step may
   additionally assume that outputs are allowed and that continuations are
   lossless, well-typed and satisfy P. Derived from lossless_gpv_induct. *)
lemma lossless_WT_gpv_induct [consumes 2, case_names lossless_gpv]:
assumes lossless: "lossless_gpv ℐ gpv"
and WT: "ℐ ⊢g gpv √"
and step: "⋀p. ⟦
lossless_spmf p;
⋀out c. IO out c ∈ set_spmf p ⟹ out ∈ outs_ℐ ℐ;
⋀out c input. ⟦IO out c ∈ set_spmf p; out ∈ outs_ℐ ℐ ⟹ input ∈ responses_ℐ ℐ out⟧ ⟹ lossless_gpv ℐ (c input);
⋀out c input. ⟦IO out c ∈ set_spmf p; out ∈ outs_ℐ ℐ ⟹ input ∈ responses_ℐ ℐ out⟧ ⟹ ℐ ⊢g c input √;
⋀out c input. ⟦IO out c ∈ set_spmf p; out ∈ outs_ℐ ℐ ⟹ input ∈ responses_ℐ ℐ out⟧ ⟹ P (c input)⟧
⟹ P (GPV p)"
shows "P gpv"
using lossless WT
apply(induction)
apply(erule step)
apply(auto elim: WT_gpvD simp add: WT_gpv_simps)
done

lemma lossless_gpv_induct_strong [consumes 1, case_names lossless_gpv]:
assumes gpv: "lossless_gpv ℐ gpv"
and step:
"⋀p. ⟦ lossless_spmf p;
⋀gpv. gpv ∈ sub_gpvs ℐ (GPV p) ⟹ lossless_gpv ℐ gpv;
⋀gpv. gpv ∈ sub_gpvs ℐ (GPV p) ⟹ P gpv ⟧
⟹ P (GPV p)"
shows "P gpv"
proof -
define gpv' where "gpv' = gpv"
then have "gpv' ∈ insert gpv (sub_gpvs ℐ gpv)" by simp
with gpv have "lossless_gpv ℐ gpv' ∧ P gpv'"
proof(induction arbitrary: gpv')
case (lossless_gpv p)
from ```