(* Theory Generative_Probabilistic_Value *)

(* Title: Generative_Probabilistic_Value.thy
  Author: Andreas Lochbihler, ETH Zurich *)

theory Generative_Probabilistic_Value imports
  Resumption
  Generat
  "HOL-Types_To_Sets.Types_To_Sets"
begin

(* Hide the short name Done of an already existing constant; this theory
   introduces its own constant named Done in the section on derived operations. *)
hide_const (open) Done

subsection ‹Type definition›

(* Generative probabilistic values (GPVs).
   A GPV wraps a sub-probability distribution (spmf) over generat values whose
   third component maps an input of type 'in to a further GPV.
   The BNF set functions for the live type arguments 'a and 'out are named
   results'_gpv and outs'_gpv.  bnf_internals is switched on only for the
   duration of the codatatype definition. *)
context notes [[bnf_internals]] begin

codatatype (results'_gpv: 'a, outs'_gpv: 'out, 'in) gpv
  = GPV (the_gpv: "('a, 'out, 'in ⇒ ('a, 'out, 'in) gpv) generat spmf")

end

(* Register the relator-equality theorem of gpv under the relator_eq attribute. *)
declare gpv.rel_eq [relator_eq]

text ‹Reactive values are like generative values, except that they take an input first.›

(* A reactive probabilistic value is a function from an input to a GPV. *)
type_synonym ('a, 'out, 'in) rpv = "'in ⇒ ('a, 'out, 'in) gpv"
print_translation ― ‹pretty printing for @{typ "('a, 'out, 'in) rpv"}› ‹
  let
    (* Print "in => (a, out, in) gpv" as "(a, out, in) rpv", but only if both
       occurrences of the input type are equal; otherwise decline via Match. *)
    fun tr' [in1, Const (@{type_syntax gpv}, _) $ a $ out $ in2] =
      if in1 = in2 then Syntax.const @{type_syntax rpv} $ a $ out $ in1
      else raise Match;
  in [(@{type_syntax "fun"}, K tr')]
  end
›
typ "('a, 'out, 'in) rpv"
text ‹
  Effectively, @{typ "('a, 'out, 'in) gpv"} and @{typ "('a, 'out, 'in) rpv"} are mutually recursive.
›

(* A GPV equals "GPV g" exactly when its selector the_gpv yields g. *)
lemma eq_GPV_iff: "f = GPV g ⟷ the_gpv f = g"
by(cases f) auto

(* Take the generated set equations of gpv out of the simpset. *)
declare gpv.set[simp del]

(* Use the set-of-map equations of gpv as simp rules. *)
declare gpv.set_map[simp]

(* Witness characterisation of rel_gpv: two GPVs are related iff they are the
   two projections of a GPV over pairs whose results satisfy A and whose
   outputs satisfy B. *)
lemma rel_gpv_def':
  "rel_gpv A B gpv gpv' ⟷
  (∃gpv''. (∀(x, y) ∈ results'_gpv gpv''. A x y) ∧ (∀(x, y) ∈ outs'_gpv gpv''. B x y) ∧
           map_gpv fst fst gpv'' = gpv ∧ map_gpv snd snd gpv'' = gpv')"
unfolding rel_gpv_def by(auto simp add: BNF_Def.Grp_def)

(* Results of a reactive value: the results of all GPVs in its range. *)
definition results'_rpv :: "('a, 'out, 'in) rpv ⇒ 'a set"
where "results'_rpv rpv = range rpv ⤜ results'_gpv"

(* Outputs of a reactive value: the outputs of all GPVs in its range. *)
definition outs'_rpv :: "('a, 'out, 'in) rpv ⇒ 'out set"
where "outs'_rpv rpv = range rpv ⤜ outs'_gpv"

(* Relator for reactive values: equal inputs are mapped to rel_gpv-related GPVs. *)
abbreviation rel_rpv
  :: "('a ⇒ 'b ⇒ bool) ⇒ ('out ⇒ 'out' ⇒ bool)
  ⇒ ('in ⇒ ('a, 'out, 'in) gpv) ⇒ ('in ⇒ ('b, 'out', 'in) gpv) ⇒ bool"
where "rel_rpv A B ≡ rel_fun (=) (rel_gpv A B)"

(* Membership in results'_rpv: some input leads to a GPV having x as a result. *)
lemma in_results'_rpv [iff]: "x ∈ results'_rpv rpv ⟷ (∃input. x ∈ results'_gpv (rpv input))"
by(simp add: results'_rpv_def)

(* Membership in outs'_rpv: some input leads to a GPV having out as an output. *)
lemma in_outs_rpv [iff]: "out ∈ outs'_rpv rpv ⟷ (∃input. out ∈ outs'_gpv (rpv input))"
by(simp add: outs'_rpv_def)

(* Unfolding of results'_gpv at the constructor: the pure results in the
   support of r together with the results of all continuations in the support. *)
lemma results'_GPV [simp]:
  "results'_gpv (GPV r) =
   (set_spmf r ⤜ generat_pures) ∪
   ((set_spmf r ⤜ generat_conts) ⤜ results'_rpv)"
by(auto simp add: gpv.set bind_UNION set_spmf_def)

(* Unfolding of outs'_gpv at the constructor: the outputs in the support of r
   together with the outputs of all continuations in the support. *)
lemma outs'_GPV [simp]:
  "outs'_gpv (GPV r) =
   (set_spmf r ⤜ generat_outs) ∪
   ((set_spmf r ⤜ generat_conts) ⤜ outs'_rpv)"
by(auto simp add: gpv.set bind_UNION set_spmf_def)

(* The same unfolding of outs'_gpv, phrased with the selector the_gpv. *)
lemma outs'_gpv_unfold:
  "outs'_gpv r =
   (set_spmf (the_gpv r) ⤜ generat_outs) ∪
   ((set_spmf (the_gpv r) ⤜ generat_conts) ⤜ outs'_rpv)"
by(cases r) simp

(* Induction rule for membership in outs'_gpv.
   Case Out:  x is an output of some generat in the support of the_gpv gpv.
   Case Cont: x is an output of c input for a continuation c of such a generat,
              and the induction hypothesis holds for c input. *)
lemma outs'_gpv_induct [consumes 1, case_names Out Cont, induct set: outs'_gpv]:
  assumes x: "x ∈ outs'_gpv gpv"
  and Out: "⋀generat gpv. ⟦ generat ∈ set_spmf (the_gpv gpv); x ∈ generat_outs generat ⟧ ⟹ P gpv"
  and Cont: "⋀generat gpv c input.
    ⟦ generat ∈ set_spmf (the_gpv gpv); c ∈ generat_conts generat; x ∈ outs'_gpv (c input); P (c input) ⟧ ⟹ P gpv"
  shows "P gpv"
using x
apply(induction y≡"x" gpv)
 apply(rule Out, simp add: in_set_spmf, simp)
apply(erule imageE, rule Cont, simp add: in_set_spmf, simp, simp, simp)
.

(* Case analysis for membership in outs'_gpv: x is either an immediate output
   (Out) or an output of a continuation applied to some input (Cont). *)
lemma outs'_gpv_cases [consumes 1, case_names Out Cont, cases set: outs'_gpv]:
  assumes "x ∈ outs'_gpv gpv"
  obtains (Out) generat where "generat ∈ set_spmf (the_gpv gpv)" "x ∈ generat_outs generat"
    | (Cont) generat c input where "generat ∈ set_spmf (the_gpv gpv)" "c ∈ generat_conts generat" "x ∈ outs'_gpv (c input)"
using assms by cases(auto simp add: in_set_spmf)

(* Introduction rules for membership in outs'_gpv, one for an immediate output
   and one for an output reached through a continuation. *)
lemma outs'_gpvI [intro?]:
  shows outs'_gpv_Out: "⟦ generat ∈ set_spmf (the_gpv gpv); x ∈ generat_outs generat ⟧ ⟹ x ∈ outs'_gpv gpv"
  and outs'_gpv_Cont: "⟦ generat ∈ set_spmf (the_gpv gpv); c ∈ generat_conts generat; x ∈ outs'_gpv (c input) ⟧ ⟹ x ∈ outs'_gpv gpv"
by(auto intro: gpv.set_sel simp add: in_set_spmf)

(* Induction rule for membership in results'_gpv.
   Case Pure: x is a pure result of some generat in the support of the_gpv gpv.
   Case Cont: x is a result of c input for a continuation c of such a generat,
              and the induction hypothesis holds for c input. *)
lemma results'_gpv_induct [consumes 1, case_names Pure Cont, induct set: results'_gpv]:
  assumes x: "x ∈ results'_gpv gpv"
  and Pure: "⋀generat gpv. ⟦ generat ∈ set_spmf (the_gpv gpv); x ∈ generat_pures generat ⟧ ⟹ P gpv"
  and Cont: "⋀generat gpv c input.
    ⟦ generat ∈ set_spmf (the_gpv gpv); c ∈ generat_conts generat; x ∈ results'_gpv (c input); P (c input) ⟧ ⟹ P gpv"
  shows "P gpv"
using x
apply(induction y≡"x" gpv)
 apply(rule Pure; simp add: in_set_spmf)
apply(erule imageE, rule Cont, simp add: in_set_spmf, simp, simp, simp)
.

(* Case analysis for membership in results'_gpv: x is either an immediate pure
   result (Pure) or a result of a continuation applied to some input (Cont). *)
lemma results'_gpv_cases [consumes 1, case_names Pure Cont, cases set: results'_gpv]:
  assumes "x ∈ results'_gpv gpv"
  obtains (Pure) generat where "generat ∈ set_spmf (the_gpv gpv)" "x ∈ generat_pures generat"
    | (Cont) generat c input where "generat ∈ set_spmf (the_gpv gpv)" "c ∈ generat_conts generat" "x ∈ results'_gpv (c input)"
using assms by cases(auto simp add: in_set_spmf)

(* Introduction rules for membership in results'_gpv, one for an immediate
   pure result and one for a result reached through a continuation. *)
lemma results'_gpvI [intro?]:
  shows results'_gpv_Pure: "⟦ generat ∈ set_spmf (the_gpv gpv); x ∈ generat_pures generat ⟧ ⟹ x ∈ results'_gpv gpv"
  and results'_gpv_Cont: "⟦ generat ∈ set_spmf (the_gpv gpv); c ∈ generat_conts generat; x ∈ results'_gpv (c input) ⟧ ⟹ x ∈ results'_gpv gpv"
by(auto intro: gpv.set_sel simp add: in_set_spmf)

(* Uniqueness and totality properties of the argument relations carry over to
   rel_gpv.  Each one-sided lemma rewrites the property into an inequality on
   relations and closes it with monotonicity of rel_gpv; the bi_* variants
   combine the corresponding left and right lemmas. *)
lemma left_unique_rel_gpv [transfer_rule]:
  "⟦ left_unique A; left_unique B ⟧ ⟹ left_unique (rel_gpv A B)"
unfolding left_unique_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

lemma right_unique_rel_gpv [transfer_rule]:
  "⟦ right_unique A; right_unique B ⟧ ⟹ right_unique (rel_gpv A B)"
unfolding right_unique_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

lemma bi_unique_rel_gpv [transfer_rule]:
  "⟦ bi_unique A; bi_unique B ⟧ ⟹ bi_unique (rel_gpv A B)"
unfolding bi_unique_alt_def by(simp add: left_unique_rel_gpv right_unique_rel_gpv)

lemma left_total_rel_gpv [transfer_rule]:
  "⟦ left_total A; left_total B ⟧ ⟹ left_total (rel_gpv A B)"
unfolding left_total_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

lemma right_total_rel_gpv [transfer_rule]:
  "⟦ right_total A; right_total B ⟧ ⟹ right_total (rel_gpv A B)"
unfolding right_total_alt_def gpv.rel_conversep[symmetric] gpv.rel_compp[symmetric]
by(subst gpv.rel_eq[symmetric])(rule gpv.rel_mono)

lemma bi_total_rel_gpv [transfer_rule]: "⟦ bi_total A; bi_total B ⟧ ⟹ bi_total (rel_gpv A B)"
unfolding bi_total_alt_def by(simp add: left_total_rel_gpv right_total_rel_gpv)

(* Make the parametricity theorem of map_gpv available as a transfer rule. *)
declare gpv.map_transfer[transfer_rule]

(* map_gpv distributes over if-then-else in its GPV argument. *)
lemma if_distrib_map_gpv [if_distribs]:
  "map_gpv f g (if b then gpv else gpv') = (if b then map_gpv f g gpv else map_gpv f g gpv')"
by simp

(* Strong monotonicity of pred_gpv: the implications P ==> P' and Q ==> Q'
   are only needed on the results and outputs that actually occur in x. *)
lemma gpv_pred_mono_strong:
  "⟦ pred_gpv P Q x; ⋀a. ⟦ a ∈ results'_gpv x; P a ⟧ ⟹ P' a; ⋀b. ⟦ b ∈ outs'_gpv x; Q b ⟧ ⟹ Q' b ⟧ ⟹ pred_gpv P' Q' x"
by(simp add: pred_gpv_def)

(* pred_gpv with both predicates constantly True is itself constantly True. *)
lemma pred_gpv_top [simp]:
  "pred_gpv (λ_. True) (λ_. True) = (λ_. True)"
by(simp add: pred_gpv_def)

(* pred_gpv distributes over conjunction, in the result predicate (conj1) and
   in the output predicate (conj2). *)
lemma pred_gpv_conj [simp]:
  shows pred_gpv_conj1: "⋀P Q R. pred_gpv (λx. P x ∧ Q x) R = (λx. pred_gpv P R x ∧ pred_gpv Q R x)"
  and pred_gpv_conj2: "⋀P Q R. pred_gpv P (λx. Q x ∧ R x) = (λx. pred_gpv P Q x ∧ pred_gpv P R x)"
by(auto simp add: pred_gpv_def)

(* Introduction rule: related GPVs that additionally satisfy pred_gpv on each
   side are related by the correspondingly restricted relations.
   NOTE(review): the restriction notation "R ↿ P ⊗ Q" is assumed to be the
   restrict_relp syntax from the imported theories; confirm against its definition. *)
lemma rel_gpv_restrict_relp1I [intro?]:
  "⟦ rel_gpv R R' x y; pred_gpv P P' x; pred_gpv Q Q' y ⟧ ⟹ rel_gpv (R ↿ P ⊗ Q) (R' ↿ P' ⊗ Q') x y"
by(erule gpv.rel_mono_strong)(simp_all add: pred_gpv_def)

(* Elimination rule, converse of the introduction rule for restricted
   relations: from rel_gpv over restricted relations one obtains the
   unrestricted rel_gpv and pred_gpv for both sides.
   NOTE(review): the restriction notation "R ↿ P ⊗ Q" is assumed to be the
   restrict_relp syntax from the imported theories; confirm against its definition. *)
lemma rel_gpv_restrict_relpE [elim?]:
  assumes "rel_gpv (R ↿ P ⊗ Q) (R' ↿ P' ⊗ Q') x y"
  obtains "rel_gpv R R' x y" "pred_gpv P P' x" "pred_gpv Q Q' y"
proof
  show "rel_gpv R R' x y" using assms by(auto elim!: gpv.rel_mono_strong)
  (* left predicate: via the domain of the restricted relations *)
  have "pred_gpv (Domainp (R ↿ P ⊗ Q)) (Domainp (R' ↿ P' ⊗ Q')) x" using assms by(fold gpv.Domainp_rel) blast
  then show "pred_gpv P P' x" by(rule gpv_pred_mono_strong)(blast dest!: restrict_relp_DomainpD)+
  (* right predicate: same argument on the converse relations *)
  have "pred_gpv (Domainp (R ↿ P ⊗ Q)¯¯) (Domainp (R' ↿ P' ⊗ Q')¯¯) y" using assms
    by(fold gpv.Domainp_rel)(auto simp only: gpv.rel_conversep Domainp_conversep)
  then show "pred_gpv Q Q' y" by(rule gpv_pred_mono_strong)(auto dest!: restrict_relp_DomainpD)
qed

(* pred_gpv on a mapped GPV is pred_gpv with the predicates precomposed with
   the mapped functions. *)
lemma gpv_pred_map [simp]: "pred_gpv P Q (map_gpv f g gpv) = pred_gpv (P ∘ f) (Q ∘ g) gpv"
by(simp add: pred_gpv_def)

subsection ‹Generalised mapper and relator›

context includes lifting_syntax begin

(* Generalised mapper that also transforms the input type: f acts on results,
   g on outputs, and h (contravariantly, from 'ret' to 'ret) on the inputs fed
   to continuations. *)
primcorec map_gpv' :: "('a ⇒ 'b) ⇒ ('out ⇒ 'out') ⇒ ('ret' ⇒ 'ret) ⇒ ('a, 'out, 'ret) gpv ⇒ ('b, 'out', 'ret') gpv"
where
  "map_gpv' f g h gpv =
   GPV (map_spmf (map_generat f g ((∘) (map_gpv' f g h))) (map_spmf (map_generat id id (map_fun h id)) (the_gpv gpv)))"

(* The generated selector equation is not used as a simp rule. *)
declare map_gpv'.sel [simp del]

(* Selector equation for map_gpv' with the two nested maps fused into one. *)
lemma map_gpv'_sel [simp]:
  "the_gpv (map_gpv' f g h gpv) = map_spmf (map_generat f g (h ---> map_gpv' f g h)) (the_gpv gpv)"
by(simp add: map_gpv'.sel spmf.map_comp o_def generat.map_comp map_fun_def[abs_def])

(* Constructor form of the same equation. *)
lemma map_gpv'_GPV [simp]:
  "map_gpv' f g h (GPV p) = GPV (map_spmf (map_generat f g (h ---> map_gpv' f g h)) p)"
by(rule gpv.expand) simp

(* Functor identity law for map_gpv', proved by coinduction. *)
lemma map_gpv'_id: "map_gpv' id id id = id"
apply(rule ext)
apply(coinduction)
apply(auto simp add: spmf_rel_map generat.rel_map rel_fun_def intro!: rel_spmf_reflI generat.rel_refl)
done

(* Functor composition law for map_gpv'; the input functions compose in the
   opposite order (h' ∘ h) because inputs are mapped contravariantly. *)
lemma map_gpv'_comp: "map_gpv' f g h (map_gpv' f' g' h' gpv) = map_gpv' (f ∘ f') (g ∘ g') (h' ∘ h) gpv"
by(coinduction arbitrary: gpv)(auto simp add: spmf.map_comp spmf_rel_map generat.rel_map rel_fun_def intro!: rel_spmf_reflI generat.rel_refl)

(* Register map_gpv' as the functor for gpv, using the identity and
   composition laws. *)
functor gpv: map_gpv' by(simp_all add: map_gpv'_comp map_gpv'_id o_def) 

(* The BNF mapper map_gpv is map_gpv' with the identity on inputs. *)
lemma map_gpv_conv_map_gpv': "map_gpv f g = map_gpv' f g id"
apply(rule ext)
apply(coinduction)
apply(auto simp add: gpv.map_sel spmf_rel_map generat.rel_map rel_fun_def intro!: generat.rel_refl_strong rel_spmf_reflI)
done

(* Generalised relator with a third relation R on inputs: two GPVs are related
   if their underlying spmfs are related, where continuations must map
   R-related inputs to (coinductively) related GPVs. *)
coinductive rel_gpv'' :: "('a ⇒ 'b ⇒ bool) ⇒ ('out ⇒ 'out' ⇒ bool) ⇒ ('ret ⇒ 'ret' ⇒ bool) ⇒ ('a, 'out, 'ret) gpv ⇒ ('b, 'out', 'ret') gpv ⇒ bool"
  for A C R
where
  "rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) (the_gpv gpv) (the_gpv gpv')
  ⟹ rel_gpv'' A C R gpv gpv'"

(* Coinduction rule for rel_gpv'' in selector style: a candidate relation X
   suffices if X-related GPVs have spmfs related up to "X or rel_gpv''" on
   the continuations. *)
lemma rel_gpv''_coinduct [consumes 1, case_names rel_gpv'', coinduct pred: rel_gpv'']:
  "⟦X gpv gpv';
    ⋀gpv gpv'. X gpv gpv'
     ⟹ rel_spmf (rel_generat A C (R ===> (λgpv gpv'. X gpv gpv' ∨ rel_gpv'' A C R gpv gpv')))
           (the_gpv gpv) (the_gpv gpv') ⟧
   ⟹ rel_gpv'' A C R gpv gpv'"
by(erule rel_gpv''.coinduct) blast

(* Destruction rule: related GPVs have related underlying spmfs. *)
lemma rel_gpv''D:
  "rel_gpv'' A C R gpv gpv'
  ⟹ rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) (the_gpv gpv) (the_gpv gpv')"
by(simp add: rel_gpv''.simps)

(* rel_gpv'' on constructor applications unfolds to the relation on the spmfs. *)
lemma rel_gpv''_GPV [simp]:
  "rel_gpv'' A C R (GPV p) (GPV q) ⟷
   rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) p q"
by(simp add: rel_gpv''.simps)

(* The BNF relator rel_gpv is rel_gpv'' with equality on inputs; each
   direction is shown by coinduction. *)
lemma rel_gpv_conv_rel_gpv'': "rel_gpv A C = rel_gpv'' A C (=)"
proof(rule ext iffI)+
  show "rel_gpv A C gpv gpv'" if "rel_gpv'' A C (=) gpv gpv'" for gpv :: "('a, 'b, 'c) gpv" and gpv' :: "('d, 'e, 'c) gpv"
    using that by(coinduct)(blast dest: rel_gpv''D)
  show "rel_gpv'' A C (=) gpv gpv'" if "rel_gpv A C gpv gpv'" for gpv :: "('a, 'b, 'c) gpv" and gpv' :: "('d, 'e, 'c) gpv"
    using that by(coinduct)(auto elim!: gpv.rel_cases rel_spmf_mono generat.rel_mono_strong rel_fun_mono)
qed

(* rel_gpv'' on three equalities is equality. *)
lemma rel_gpv''_eq (* [relator_eq] do not use this attribute unless all transfer rules for gpv have been changed to rel_gpv'' *):
  "rel_gpv'' (=) (=) (=) = (=)"
by(simp add: rel_gpv_conv_rel_gpv''[symmetric] gpv.rel_eq)

(* Monotonicity of rel_gpv'': covariant in A and C, contravariant in the
   input relation R (note the reversed assumption R' ≤ R). *)
lemma rel_gpv''_mono:
  assumes "A ≤ A'" "C ≤ C'" "R' ≤ R"
  shows "rel_gpv'' A C R ≤ rel_gpv'' A' C' R'"
proof
  show "rel_gpv'' A' C' R' gpv gpv'" if "rel_gpv'' A C R gpv gpv'" for gpv gpv' using that
    by(coinduct)(auto dest: rel_gpv''D elim!: rel_spmf_mono generat.rel_mono_strong rel_fun_mono intro: assms[THEN predicate2D])
qed

(* rel_gpv'' commutes with taking converses of all three relations.
   One direction is proved by coinduction for arbitrary relations; the other
   direction is obtained by instantiating it with the converse relations. *)
lemma rel_gpv''_conversep: "rel_gpv'' A¯¯ C¯¯ R¯¯ = (rel_gpv'' A C R)¯¯"
proof(intro ext iffI; simp)
  show "rel_gpv'' A C R gpv gpv'" if "rel_gpv'' A¯¯ C¯¯ R¯¯ gpv' gpv"
    for A :: "'a1 ⇒ 'a2 ⇒ bool" and C :: "'c1 ⇒ 'c2 ⇒ bool" and R :: "'r1 ⇒ 'r2 ⇒ bool" and gpv gpv'
    using that apply(coinduct)
    apply(drule rel_gpv''D)
    apply(rewrite in ⌑ conversep_iff[symmetric])
    apply(subst spmf_rel_conversep[symmetric])
    apply(erule rel_spmf_mono)
    apply(subst generat.rel_conversep[symmetric])
    apply(erule generat.rel_mono_strong)
    apply(auto simp add: rel_fun_def conversep_iff[abs_def])
    done
  from this[of "A¯¯" "C¯¯" "R¯¯"]
  show "rel_gpv'' A¯¯ C¯¯ R¯¯ gpv' gpv" if "rel_gpv'' A C R gpv gpv'" for gpv gpv' using that by simp
qed


(* Positive distributivity: the composition of two rel_gpv'' relations is
   contained in rel_gpv'' of the componentwise compositions. *)
lemma rel_gpv''_pos_distr:
  "rel_gpv'' A C R OO rel_gpv'' A' C' R' ≤ rel_gpv'' (A OO A') (C OO C') (R OO R')"
proof(rule predicate2I; erule relcomppE)
  show "rel_gpv'' (A OO A') (C OO C') (R OO R') gpv gpv''"
    if "rel_gpv'' A C R gpv gpv'" "rel_gpv'' A' C' R' gpv' gpv''"
    for gpv gpv' gpv'' using that
    apply(coinduction arbitrary: gpv gpv' gpv'')
    apply(drule rel_gpv''D)+
    apply(drule (1) rel_spmf_pos_distr[THEN predicate2D, OF relcomppI])
    apply(erule spmf_rel_mono_strong)
    apply(subst (asm) generat.rel_compp[symmetric])
    apply(erule generat.rel_mono_strong, assumption, assumption)
    apply(drule pos_fun_distr[THEN predicate2D])
    apply(auto simp add: rel_fun_def)
    done
qed

(* Uniqueness of rel_gpv''.  Because the input relation occurs
   contravariantly, uniqueness of the result and output relations must be
   paired with totality (not uniqueness) of R.  The proofs reduce to positive
   distributivity and monotonicity. *)
lemma left_unique_rel_gpv'':
  "⟦ left_unique A; left_unique C; left_total R ⟧ ⟹ left_unique (rel_gpv'' A C R)"
unfolding left_unique_alt_def left_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[OF rel_gpv''_pos_distr])
apply(erule (2) rel_gpv''_mono)
done

lemma right_unique_rel_gpv'':
  "⟦ right_unique A; right_unique C; right_total R ⟧ ⟹ right_unique (rel_gpv'' A C R)"
unfolding right_unique_alt_def right_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[OF rel_gpv''_pos_distr])
apply(erule (2) rel_gpv''_mono)
done

lemma bi_unique_rel_gpv'' [transfer_rule]:
  "⟦ bi_unique A; bi_unique C; bi_total R ⟧ ⟹ bi_unique (rel_gpv'' A C R)"
unfolding bi_unique_alt_def bi_total_alt_def by(blast intro: left_unique_rel_gpv'' right_unique_rel_gpv'')

(* map_gpv on the left argument of rel_gpv'' can be moved into the result and
   output relations.  Both directions use the same coinductive proof script. *)
lemma rel_gpv''_map_gpv1:
  "rel_gpv'' A C R (map_gpv f g gpv) gpv' = rel_gpv'' (λa. A (f a)) (λc. C (g c)) R gpv gpv'" (is "?lhs = ?rhs")
proof
  show ?rhs if ?lhs using that
    apply(coinduction arbitrary: gpv gpv')
    apply(drule rel_gpv''D)
    apply(simp add: gpv.map_sel spmf_rel_map)
    apply(erule rel_spmf_mono)
    by(auto simp add: generat.rel_map rel_fun_comp elim!: generat.rel_mono_strong rel_fun_mono)
  show ?lhs if ?rhs using that
    apply(coinduction arbitrary: gpv gpv')
    apply(drule rel_gpv''D)
    apply(simp add: gpv.map_sel spmf_rel_map)
    apply(erule rel_spmf_mono)
    by(auto simp add: generat.rel_map rel_fun_comp elim!: generat.rel_mono_strong rel_fun_mono)
qed

(* map_gpv on the right argument of rel_gpv'' can be moved into the relations.
   Derived from the left-argument version instantiated with the converse
   relations, then converting converses back on both sides of the equation. *)
lemma rel_gpv''_map_gpv2:
  "rel_gpv'' A C R gpv (map_gpv f g gpv') = rel_gpv'' (λa b. A a (f b)) (λc d. C c (g d)) R gpv gpv'"
  using rel_gpv''_map_gpv1[of "conversep A" "conversep C" "conversep R" f g gpv' gpv]
  apply(rewrite in "⌑ = _" conversep_iff[symmetric])
  apply(rewrite in "_ = ⌑" conversep_iff[symmetric])
  apply(simp only: rel_gpv''_conversep)
  apply(simp only: rel_gpv''_conversep[symmetric])
  apply(simp only: conversep_iff[abs_def])
  done

(* Both map rules bundled; the left-argument rule is stated as a function equation. *)
lemmas rel_gpv''_map_gpv = rel_gpv''_map_gpv1[abs_def] rel_gpv''_map_gpv2

(* Simp rules that move the result/output functions of map_gpv' into the
   relations, leaving only the input function h in the mapper.  The NO_MATCH
   guards restrict the rules to mappers where f or g is not syntactically id,
   so the rewritten right-hand side (which uses id id) is not rewritten again. *)
lemma rel_gpv''_map_gpv' [simp]:
  shows "⋀f g h gpv. NO_MATCH id f ∨ NO_MATCH id g
    ⟹ rel_gpv'' A C R (map_gpv' f g h gpv) = rel_gpv'' (λa. A (f a)) (λc. C (g c)) R (map_gpv' id id h gpv)"
    and "⋀f g h gpv gpv'. NO_MATCH id f ∨ NO_MATCH id g
    ⟹ rel_gpv'' A C R gpv (map_gpv' f g h gpv') = rel_gpv'' (λa b. A a (f b)) (λc d. C c (g d)) R gpv (map_gpv' id id h gpv')"
proof (goal_cases)
  case (1 f g h gpv)
  then show ?case using map_gpv'_comp[of f g id id id h gpv, symmetric] by(simp add: rel_gpv''_map_gpv[unfolded map_gpv_conv_map_gpv'])
next
  case (2 f g h gpv gpv')
  then show ?case using map_gpv'_comp[of f g id id id h gpv', symmetric] by(simp add: rel_gpv''_map_gpv[unfolded map_gpv_conv_map_gpv'])
qed

(* The same rules specialised to equality on inputs and folded back to rel_gpv. *)
lemmas rel_gpv_map_gpv' = rel_gpv''_map_gpv'[where R="(=)", folded rel_gpv_conv_rel_gpv'']

(* Corecursively builds a GPV over pairs from a pair of GPVs, by composing the
   witness constructions for spmf, generat and functions.  The lemmas
   rel_witness_gpv1/2 relate it to its two arguments. *)
definition rel_witness_gpv :: "('a ⇒ 'd ⇒ bool) ⇒ ('b ⇒ 'e ⇒ bool) ⇒ ('c ⇒ 'g ⇒ bool) ⇒ ('g ⇒ 'f ⇒ bool) ⇒ ('a, 'b, 'c) gpv × ('d, 'e, 'f) gpv ⇒ ('a × 'd, 'b × 'e, 'g) gpv" where
  "rel_witness_gpv A C R R' = corec_gpv (
     map_spmf (map_generat id id (λ(rpv, rpv'). (Inr ∘ rel_witness_fun R R' (rpv, rpv'))) ∘ rel_witness_generat) ∘
     rel_witness_spmf (rel_generat A C (rel_fun (R OO R') (rel_gpv'' A C (R OO R')))) ∘ map_prod the_gpv the_gpv)"

(* Selector equation for rel_witness_gpv: one unfolding of the corecursion,
   with the recursive call applied to the witnesses of the continuations. *)
lemma rel_witness_gpv_sel [simp]:
  "the_gpv (rel_witness_gpv A C R R' (gpv, gpv')) =
    map_spmf (map_generat id id (λ(rpv, rpv'). (rel_witness_gpv A C R R' ∘ rel_witness_fun R R' (rpv, rpv'))) ∘ rel_witness_generat)
     (rel_witness_spmf (rel_generat A C (rel_fun (R OO R') (rel_gpv'' A C (R OO R')))) (the_gpv gpv, the_gpv gpv'))"
  unfolding rel_witness_gpv_def
  by(auto simp add: spmf.map_comp generat.map_comp o_def intro!: map_spmf_cong generat.map_cong)

(* Correctness of the witness: if gpv and gpv' are related w.r.t. R OO R',
   then gpv is related to the witness (first components equal, pairs satisfy
   A resp. C, inputs related by R), and the witness is related to gpv'
   (second components equal, inputs related by R'). *)
lemma assumes "rel_gpv'' A C (R OO R') gpv gpv'"
  and R: "left_unique R" "right_total R"
  and R': "right_unique R'" "left_total R'"
shows rel_witness_gpv1: "rel_gpv'' (λa (a', b). a = a' ∧ A a' b) (λc (c', d). c = c' ∧ C c' d) R gpv (rel_witness_gpv A C R R' (gpv, gpv'))" (is "?thesis1")
  and rel_witness_gpv2: "rel_gpv'' (λ(a, b') b. b = b' ∧ A a b') (λ(c, d') d. d = d' ∧ C c d') R' (rel_witness_gpv A C R R' (gpv, gpv')) gpv'" (is "?thesis2")
proof -
  show ?thesis1 using assms(1)
  proof(coinduction arbitrary: gpv gpv')
    case rel_gpv''
    from this[THEN rel_gpv''D] show ?case
      by(auto simp add: spmf_rel_map generat.rel_map rel_fun_comp elim!: rel_fun_mono[OF rel_witness_fun1[OF _ R R']]
          rel_spmf_mono[OF rel_witness_spmf1] generat.rel_mono[THEN predicate2D, rotated -1, OF rel_witness_generat1])
  qed
  show ?thesis2 using assms(1)
  proof(coinduction arbitrary: gpv gpv')
    case rel_gpv''
    from this[THEN rel_gpv''D] show ?case
      by(simp add: spmf_rel_map)
        (erule rel_spmf_mono[OF rel_witness_spmf2]
          , auto simp add: generat.rel_map rel_fun_comp elim!: rel_fun_mono[OF rel_witness_fun2[OF _ R R']]
          generat.rel_mono[THEN predicate2D, rotated -1, OF rel_witness_generat2])
  qed
qed

(* Negative distributivity: under the stated uniqueness/totality conditions on
   R and R', rel_gpv'' of composed relations is contained in the composition
   of the rel_gpv'' relations.  The intermediate GPV is the witness GPV mapped
   with relcompp_witness. *)
lemma rel_gpv''_neg_distr:
  assumes R: "left_unique R" "right_total R"
    and R': "right_unique R'" "left_total R'"
  shows "rel_gpv'' (A OO A') (C OO C') (R OO R') ≤ rel_gpv'' A C R OO rel_gpv'' A' C' R'"
proof(rule predicate2I relcomppI)+
  fix gpv gpv''
  assume *: "rel_gpv'' (A OO A') (C OO C') (R OO R') gpv gpv''"
  let ?gpv' = "map_gpv (relcompp_witness A A') (relcompp_witness C C') (rel_witness_gpv (A OO A') (C OO C') R R' (gpv, gpv''))"
  show "rel_gpv'' A C R gpv ?gpv'" using rel_witness_gpv1[OF * R R'] unfolding rel_gpv''_map_gpv
    by(rule rel_gpv''_mono[THEN predicate2D, rotated -1]; clarify del: relcomppE elim!: relcompp_witness)
  show "rel_gpv'' A' C' R' ?gpv' gpv''" using rel_witness_gpv2[OF * R R'] unfolding rel_gpv''_map_gpv
    by(rule rel_gpv''_mono[THEN predicate2D, rotated -1]; clarify del: relcomppE elim!: relcompp_witness)
qed

(* Pointwise form of monotonicity, registered as a mono rule; the input
   relation is contravariant (R' implies R). *)
lemma rel_gpv''_mono' [mono]:
  assumes "⋀x y. A x y ⟶ A' x y"
    and "⋀x y. C x y ⟶ C' x y"
    and "⋀x y. R' x y ⟶ R x y"
  shows "rel_gpv'' A C R gpv gpv' ⟶ rel_gpv'' A' C' R' gpv gpv'"
  using rel_gpv''_mono[of A A' C C' R' R] assms by(blast)

(* Totality of rel_gpv''.  The proofs go through negative distributivity,
   hence the additional uniqueness and totality requirements on the input
   relation R. *)
lemma left_total_rel_gpv':
  "⟦ left_total A; left_total C; left_unique R; right_total R ⟧ ⟹ left_total (rel_gpv'' A C R)"
unfolding left_unique_alt_def left_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[rotated])
apply(rule rel_gpv''_neg_distr; simp add: left_unique_alt_def)
apply(rule rel_gpv''_mono; assumption)
done

lemma right_total_rel_gpv':
  "⟦ right_total A; right_total C; right_unique R; left_total R ⟧ ⟹ right_total (rel_gpv'' A C R)"
unfolding right_unique_alt_def right_total_alt_def rel_gpv''_conversep[symmetric]
apply(subst rel_gpv''_eq[symmetric])
apply(rule order_trans[rotated])
apply(rule rel_gpv''_neg_distr; simp add: right_unique_alt_def)
apply(rule rel_gpv''_mono; assumption)
done

lemma bi_total_rel_gpv' [transfer_rule]:
  "⟦ bi_total A; bi_total C; bi_unique R; bi_total R ⟧ ⟹ bi_total (rel_gpv'' A C R)"
unfolding bi_total_alt_def bi_unique_alt_def by(blast intro: left_total_rel_gpv' right_total_rel_gpv')

(* rel_fun between the converse graph of f and the graph of g (restricted to
   B) is the graph of map_fun f g, restricted to functions x for which the
   image of x ∘ f lies in B. *)
lemma rel_fun_conversep_grp_grp:
  "rel_fun (conversep (BNF_Def.Grp UNIV f)) (BNF_Def.Grp B g) = BNF_Def.Grp {x. (x ∘ f) ` UNIV ⊆ B} (map_fun f g)"
unfolding rel_fun_def Grp_def simp_thms fun_eq_iff conversep_iff by auto

(* Quotient theorem for gpv: quotients on the result, output and input types
   lift to a quotient on GPVs.  Abstraction maps results and outputs with Abs
   and inputs with Rep3 (inputs are contravariant); representation does the
   opposite. *)
lemma Quotient_gpv:
  assumes Q1: "Quotient R1 Abs1 Rep1 T1"
  and Q2: "Quotient R2 Abs2 Rep2 T2"
  and Q3: "Quotient R3 Abs3 Rep3 T3"
  shows "Quotient (rel_gpv'' R1 R2 R3) (map_gpv' Abs1 Abs2 Rep3) (map_gpv' Rep1 Rep2 Abs3) (rel_gpv'' T1 T2 T3)"
  (is "Quotient ?R ?abs ?rep ?T")
unfolding Quotient_alt_def2
proof(intro conjI strip iffI; (elim conjE exE)?)
  (* local proof setup for the coinductive arguments below *)
  note [simp] = spmf_rel_map generat.rel_map
    and [elim!] = rel_spmf_mono generat.rel_mono_strong
    and [rule del] = rel_funI and [intro!] = rel_funI
  (* facts extracted from the three component quotients *)
  have Abs1 [simp]: "Abs1 x = y" if "T1 x y" for x y using Q1 that by(simp add: Quotient_alt_def)
  have Abs2 [simp]: "Abs2 x = y" if "T2 x y" for x y using Q2 that by(simp add: Quotient_alt_def)
  have Abs3 [simp]: "Abs3 x = y" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def)
  have Rep1: "T1 (Rep1 x) x" for x using Q1 by(simp add: Quotient_alt_def)
  have Rep2: "T2 (Rep2 x) x" for x using Q2 by(simp add: Quotient_alt_def)
  have Rep3: "T3 (Rep3 x) x" for x using Q3 by(simp add: Quotient_alt_def)
  have T1: "T1 x (Abs1 y)" if "R1 x y" for x y using Q1 that by(simp add: Quotient_alt_def2)
  have T2: "T2 x (Abs2 y)" if "R2 x y" for x y using Q2 that by(simp add: Quotient_alt_def2)
  have T1': "T1 x (Abs1 y)" if "R1 y x" for x y using Q1 that by(simp add: Quotient_alt_def2)
  have T2': "T2 x (Abs2 y)" if "R2 y x" for x y using Q2 that by(simp add: Quotient_alt_def2)
  have R3: "R3 x (Rep3 y)" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def2 Abs3[OF Rep3])
  have R3': "R3 (Rep3 y) x" if "T3 x y" for x y using Q3 that by(simp add: Quotient_alt_def2 Abs3[OF Rep3])
  have r1: "R1 = T1 OO T1¯¯" using Q1 by(simp add: Quotient_alt_def4)
  have r2: "R2 = T2 OO T2¯¯" using Q2 by(simp add: Quotient_alt_def4)
  have r3: "R3 = T3 OO T3¯¯" using Q3 by(simp add: Quotient_alt_def4)
  (* the transfer relation determines the abstraction *)
  show abs: "?abs gpv = gpv'" if "?T gpv gpv'" for gpv gpv' using that
    by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 4 intro: Rep3 dest: rel_funD)
  (* representatives are related to what they represent *)
  show "?T (?rep gpv) gpv" for gpv
    by(coinduction arbitrary: gpv)(auto simp add: Rep1 Rep2 intro!: rel_spmf_reflI generat.rel_refl_strong)
  (* ?R-related GPVs are ?T-related to each other's abstraction, in both orientations *)
  show "?T gpv (?abs gpv')" if "?R gpv gpv'" for gpv gpv' using that
    by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 3 simp add: T1 T2 intro!: R3 dest: rel_funD)
  show "?T gpv (?abs gpv')" if "?R gpv' gpv" for gpv gpv'
  proof -
    from that have "rel_gpv'' R1¯¯ R2¯¯ R3¯¯ gpv gpv'" unfolding rel_gpv''_conversep by simp
    then show ?thesis
      by(coinduction arbitrary: gpv gpv')(drule rel_gpv''D; auto 4 3 simp add: T1' T2' intro!: R3' dest: rel_funD)
  qed
  (* conversely, ?R follows via ?T OO ?T¯¯ and positive distributivity *)
  show "?R gpv gpv'" if "?T gpv (?abs gpv')" "?T gpv' (?abs gpv)" for gpv gpv'
  proof -
    from that[THEN abs] have "?abs gpv' = ?abs gpv" by simp
    with that have "(?T OO ?T¯¯) gpv gpv'" by(auto simp del: rel_gpv''_map_gpv')
    hence "rel_gpv'' (T1 OO T1¯¯) (T2 OO T2¯¯) (T3 OO T3¯¯) gpv gpv'"
      unfolding rel_gpv''_conversep[symmetric]
      by(rule rel_gpv''_pos_distr[THEN predicate2D])
    thus ?thesis by(simp add: r1 r2 r3)
  qed
qed

(* Parametricity of the selector the_gpv with respect to rel_gpv''. *)
lemma the_gpv_parametric':
  "(rel_gpv'' A C R ===> rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R))) the_gpv the_gpv"
by(rule rel_funI)(auto elim: rel_gpv''.cases)

(* Parametricity of the constructor GPV with respect to rel_gpv''. *)
lemma GPV_parametric':
  "(rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) ===> rel_gpv'' A C R) GPV GPV"
by(rule rel_funI)(auto)

(* Parametricity of the corecursor corec_gpv with respect to rel_gpv'':
   S-related seeds with related step functions yield related GPVs.
   Proved by coinduction over the seeds. *)
lemma corec_gpv_parametric':
  "((S ===> rel_spmf (rel_generat A C (R ===> rel_sum (rel_gpv'' A C R) S))) ===> S ===> rel_gpv'' A C R)
  corec_gpv corec_gpv"
proof(rule rel_funI)+
  fix f g s1 s2
  assume fg: "(S ===> rel_spmf (rel_generat A C (R ===> rel_sum (rel_gpv'' A C R) S))) f g"
    and s: "S s1 s2"
  from s show "rel_gpv'' A C R (corec_gpv f s1) (corec_gpv g s2)"
    apply(coinduction arbitrary: s1 s2)
    apply(drule fg[THEN rel_funD])
    apply(simp add: spmf_rel_map)
    apply(erule rel_spmf_mono)
    apply(simp add: generat.rel_map)
    apply(erule generat.rel_mono_strong; clarsimp simp add: o_def)
    apply(rule rel_funI)
    apply(drule (1) rel_funD)
    apply(auto 4 3 elim!: rel_sum.cases)
    done
qed

(* Parametricity of map_gpv' with respect to rel_gpv''; the input function is
   related contravariantly (R' ===> R).  Proved by transfer_prover after
   unfolding the definition, with the corecursor and selector rules supplied
   locally. *)
lemma map_gpv'_parametric [transfer_rule]:
  "((A ===> A') ===> (C ===> C') ===> (R' ===> R) ===> rel_gpv'' A C R ===> rel_gpv'' A' C' R') map_gpv' map_gpv'"
  unfolding map_gpv'_def
  supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
  by(transfer_prover)

(* Parametricity of map_gpv with respect to rel_gpv'', via map_gpv' with id on inputs. *)
lemma map_gpv_parametric': "((A ===> A') ===> (C ===> C') ===> rel_gpv'' A C R ===> rel_gpv'' A' C' R) map_gpv map_gpv"
  unfolding map_gpv_conv_map_gpv'[abs_def] by transfer_prover

end

subsection ‹Simple, derived operations›

(* Done a: the GPV that immediately returns the result a. *)
primcorec Done :: "'a ⇒ ('a, 'out, 'in) gpv"
where "the_gpv (Done a) = return_spmf (Pure a)"

(* Pause out c: the GPV that emits the output out and continues with c on the
   received input. *)
primcorec Pause :: "'out ⇒ ('in ⇒ ('a, 'out, 'in) gpv) ⇒ ('a, 'out, 'in) gpv"
where "the_gpv (Pause out c) = return_spmf (IO out c)"

(* lift_spmf p: embeds an spmf as a GPV whose elements are all pure results. *)
primcorec lift_spmf :: "'a spmf ⇒ ('a, 'out, 'in) gpv"
where "the_gpv (lift_spmf p) = map_spmf Pure p"

(* Fail: the GPV whose underlying distribution is return_pmf None. *)
definition Fail :: "('a, 'out, 'in) gpv"
where "Fail = GPV (return_pmf None)"

(* React f: the reactive value that, on an input, applies f to obtain an
   output together with a follow-up reactive value and pauses with them. *)
definition React :: "('in ⇒ 'out × ('a, 'out, 'in) rpv) ⇒ ('a, 'out, 'in) rpv"
where "React f input = case_prod Pause (f input)"

(* rFail: the reactive value that is Fail for every input. *)
definition rFail :: "('a, 'out, 'in) rpv"
where "rFail = (λ_. Fail)"

(* Done is injective. *)
lemma Done_inject [simp]: "Done x = Done y ⟷ x = y"
by(simp add: Done.ctr)

(* Pause is injective in both arguments. *)
lemma Pause_inject [simp]: "Pause out c = Pause out' c' ⟷ out = out' ∧ c = c'"
by(simp add: Pause.ctr)

(* Done and Pause are distinct. *)
lemma [simp]:
  shows Done_neq_Pause: "Done x ≠ Pause out c"
  and Pause_neq_Done: "Pause out c ≠ Done x"
by(simp_all add: Done.ctr Pause.ctr)

(* Done has no outputs. *)
lemma outs'_gpv_Done [simp]: "outs'_gpv (Done x) = {}"
by(auto elim: outs'_gpv_cases)

(* The only result of Done x is x. *)
lemma results'_gpv_Done [simp]: "results'_gpv (Done x) = {x}"
by(auto intro: results'_gpvI elim: results'_gpv_cases)

(* pred_gpv on Done x only checks the result predicate at x. *)
lemma pred_gpv_Done [simp]: "pred_gpv P Q (Done x) = P x"
by(simp add: pred_gpv_def)

(* The outputs of Pause out c are out together with the outputs of c input
   for every input. *)
lemma outs'_gpv_Pause [simp]: "outs'_gpv (Pause out c) = insert out (⋃input. outs'_gpv (c input))"
by(auto 4 4 intro: outs'_gpvI elim: outs'_gpv_cases)

(* The results of a Pause are exactly the results of its continuation. *)
lemma results'_gpv_Pause [simp]: "results'_gpv (Pause out rpv) = results'_rpv rpv"
by(auto 4 4 intro: results'_gpvI elim: results'_gpv_cases)

(* pred_gpv on Pause x c: Q holds for the output x and pred_gpv holds for
   c applied to every input. *)
lemma pred_gpv_Pause [simp]: "pred_gpv P Q (Pause x c) = (Q x ∧ All (pred_gpv P Q ∘ c))"
by(auto simp add: pred_gpv_def o_def)

(* Lifting a point distribution gives Done. *)
lemma lift_spmf_return [simp]: "lift_spmf (return_spmf x) = Done x"
by(simp add: lift_spmf.ctr Done.ctr)

(* Lifting return_pmf None gives Fail. *)
lemma lift_spmf_None [simp]: "lift_spmf (return_pmf None) = Fail"
by(rule gpv.expand)(simp add: Fail_def)

(* Selector equation for lift_spmf. *)
lemma the_gpv_lift_spmf [simp]: "the_gpv (lift_spmf r) = map_spmf Pure r"
by(simp)

(* A lifted spmf has no outputs. *)
lemma outs'_gpv_lift_spmf [simp]: "outs'_gpv (lift_spmf p) = {}"
by(auto 4 3 elim: outs'_gpv_cases)

(* The results of a lifted spmf are its support. *)
lemma results'_gpv_lift_spmf [simp]: "results'_gpv (lift_spmf p) = set_spmf p"
by(auto 4 3 elim: results'_gpv_cases intro: results'_gpvI)

(* pred_gpv on a lifted spmf reduces to pred_spmf for the result predicate. *)
lemma pred_gpv_lift_spmf [simp]: "pred_gpv P Q (lift_spmf p) = pred_spmf P p"
by(simp add: pred_gpv_def pred_spmf_def)

(* lift_spmf is injective. *)
lemma lift_spmf_inject [simp]: "lift_spmf p = lift_spmf q ⟷ p = q"
by(auto simp add: lift_spmf.code dest!: pmf.inj_map_strong[rotated] option.inj_map_strong[rotated])

(* Mapping a lifted spmf equals lifting the mapped spmf; the output function
   g is irrelevant. *)
lemma map_lift_spmf: "map_gpv f g (lift_spmf p) = lift_spmf (map_spmf f p)"
by(rule gpv.expand)(simp add: gpv.map_sel spmf.map_comp o_def)

(* The same equation read right to left, with id as the output function. *)
lemma lift_map_spmf: "lift_spmf (map_spmf f p) = map_gpv f id (lift_spmf p)"
by(rule gpv.expand)(simp add: gpv.map_sel spmf.map_comp o_def)

(* Fail is distinct from Pause and from Done, in both orientations. *)
lemma [simp]:
  shows Fail_neq_Pause: "Fail ≠ Pause out c"
  and Pause_neq_Fail: "Pause out c ≠ Fail"
  and Fail_neq_Done: "Fail ≠ Done x"
  and Done_neq_Fail: "Done x ≠ Fail"
by(simp_all add: Fail_def Pause.ctr Done.ctr)

text ‹Add @{typ unit} closure to circumvent SML value restriction›

(* Fail as a function of a unit argument; the defining equation is excluded
   from code generation ([code del]). *)
definition Fail' :: "unit ⇒ ('a, 'out, 'in) gpv"
where [code del]: "Fail' _ = Fail"

(* Code preprocessing replaces Fail by the application of Fail' to unit. *)
lemma Fail_code [code_unfold]: "Fail = Fail' ()"
by(simp add: Fail'_def)

(* Code equation for Fail'. *)
lemma Fail'_code [code]:
  "Fail' x = GPV (return_pmf None)"
by(simp add: Fail'_def Fail_def)

(* Selector equation for Fail. *)
lemma Fail_sel [simp]:
  "the_gpv Fail = return_pmf None"
by(simp add: Fail_def)

(* Fail equals GPV f exactly when f is return_pmf None. *)
lemma Fail_eq_GPV_iff [simp]: "Fail = GPV f ⟷ f = return_pmf None"
by(auto simp add: Fail_def)

(* Fail has no outputs. *)
lemma outs'_gpv_Fail [simp]: "outs'_gpv Fail = {}"
by(auto elim: outs'_gpv_cases)

(* Fail has no results. *)
lemma results'_gpv_Fail [simp]: "results'_gpv Fail = {}"
by(auto elim: results'_gpv_cases)

(* Every pred_gpv holds for Fail. *)
lemma pred_gpv_Fail [simp]: "pred_gpv P Q Fail"
by(simp add: pred_gpv_def)

(* React is injective. *)
lemma React_inject [iff]: "React f = React f' ⟷ f = f'"
by(auto simp add: React_def fun_eq_iff split_def intro: prod.expand)

(* Applying React f to an input pauses with the pair that f returns. *)
lemma React_apply [simp]: "f input = (out, c) ⟹ React f input = Pause out c"
by(simp add: React_def)

(* rFail applied to any input is Fail. *)
lemma rFail_apply [simp]: "rFail input = Fail"
by(simp add: rFail_def)

(* rFail and React f are distinct, in both orientations. *)
lemma [simp]:
  shows rFail_neq_React: "rFail ≠ React f"
  and React_neq_rFail: "React f ≠ rFail"
by(simp_all add: React_def fun_eq_iff split_beta)

(* Fail is related to Fail by rel_gpv for any relations. *)
lemma rel_gpv_FailI [simp]: "rel_gpv A C Fail Fail"
by(subst gpv.rel_sel) simp

(* Characterisations of rel_gpv and rel_gpv'' on the derived constructors. *)

(* Done: the results must be related by A. *)
lemma rel_gpv_Done [iff]: "rel_gpv A C (Done x) (Done y) ⟷ A x y"
by(subst gpv.rel_sel) simp

lemma rel_gpv''_Done [iff]: "rel_gpv'' A C R (Done x) (Done y) ⟷ A x y"
by(subst rel_gpv''.simps) simp

(* Pause: outputs related by C, continuations related on equal inputs
   (rel_gpv) or on R-related inputs (rel_gpv''). *)
lemma rel_gpv_Pause [iff]:
  "rel_gpv A C (Pause out c) (Pause out' c') ⟷ C out out' ∧ (∀x. rel_gpv A C (c x) (c' x))"
by(subst gpv.rel_sel)(simp add: rel_fun_def)

lemma rel_gpv''_Pause [iff]:
  "rel_gpv'' A C R (Pause out c) (Pause out' c') ⟷ C out out' ∧ (∀x x'. R x x' ⟶ rel_gpv'' A C R (c x) (c' x'))"
by(subst rel_gpv''.simps)(simp add: rel_fun_def)

(* lift_spmf: the underlying spmfs must be related by rel_spmf A. *)
lemma rel_gpv_lift_spmf [iff]: "rel_gpv A C (lift_spmf p) (lift_spmf q) ⟷ rel_spmf A p q"
by(subst gpv.rel_sel)(simp add: spmf_rel_map)

lemma rel_gpv''_lift_spmf [iff]:
  "rel_gpv'' A C R (lift_spmf p) (lift_spmf q) ⟷ rel_spmf A p q"
by(subst rel_gpv''.simps)(simp add: spmf_rel_map)

(* Parametricity of Fail, Done, Pause and lift_spmf.  The unprimed versions
   are stated for rel_gpv and registered as transfer rules; the primed
   versions are stated for rel_gpv'' and are not registered. *)
context includes lifting_syntax begin
lemmas Fail_parametric [transfer_rule] = rel_gpv_FailI

lemma Fail_parametric' [simp]: "rel_gpv'' A C R Fail Fail"
unfolding Fail_def by simp

lemma Done_parametric [transfer_rule]: "(A ===> rel_gpv A C) Done Done"
by(rule rel_funI) simp

lemma Done_parametric': "(A ===> rel_gpv'' A C R) Done Done"
by(rule rel_funI) simp

lemma Pause_parametric [transfer_rule]:
  "(C ===> ((=) ===> rel_gpv A C) ===> rel_gpv A C) Pause Pause"
by(simp add: rel_fun_def)

lemma Pause_parametric':
  "(C ===> (R ===> rel_gpv'' A C R) ===> rel_gpv'' A C R) Pause Pause"
by(simp add: rel_fun_def)

lemma lift_spmf_parametric [transfer_rule]:
  "(rel_spmf A ===> rel_gpv A C) lift_spmf lift_spmf"
by(simp add: rel_fun_def)

lemma lift_spmf_parametric':
  "(rel_spmf A ===> rel_gpv'' A C R) lift_spmf lift_spmf"
by(simp add: rel_fun_def)
end

(* Mapping Done x applies the result function f to x. *)
lemma map_gpv_Done [simp]: "map_gpv f g (Done x) = Done (f x)"
by(simp add: Done.code)

lemma map_gpv'_Done [simp]: "map_gpv' f g h (Done x) = Done (f x)"
by(simp add: Done.code)

(* Mapping a Pause maps the output with g and the continuation pointwise. *)
lemma map_gpv_Pause [simp]: "map_gpv f g (Pause x c) = Pause (g x) (map_gpv f g ∘ c)"
by(simp add: Pause.code)

(* For map_gpv', the continuation is additionally precomposed with the input
   function h. *)
lemma map_gpv'_Pause [simp]: "map_gpv' f g h (Pause x c) = Pause (g x) (map_gpv' f g h ∘ c ∘ h)"
by(simp add: Pause.code map_fun_def)

(* Mapping Fail yields Fail, for both mappers. *)
lemma map_gpv_Fail [simp]: "map_gpv f g Fail = Fail"
by(simp add: Fail_def)

lemma map_gpv'_Fail [simp]: "map_gpv' f g h Fail = Fail"
by(simp add: Fail_def)

subsection ‹Monad structure›

(* Monadic bind for GPVs.  The underlying spmf of r is bound to a case
   distinction: a pure result x continues with the_gpv (f x), whose
   continuations are tagged Inl (already finished with r); an IO step keeps
   its output and tags the continuation Inr (still inside r).  The outer map
   turns Inl continuations into themselves and Inr continuations into
   corecursive calls of bind_gpv. *)
primcorec bind_gpv :: "('a, 'out, 'in) gpv ⇒ ('a ⇒ ('b, 'out, 'in) gpv) ⇒ ('b, 'out, 'in) gpv"
where
  "the_gpv (bind_gpv r f) =
   map_spmf (map_generat id id ((∘) (case_sum id (λr. bind_gpv r f))))
     (the_gpv r ⤜
      (case_generat
        (λx. map_spmf (map_generat id id ((∘) Inl)) (the_gpv (f x)))
        (λout c. return_spmf (IO out (λinput. Inr (c input))))))"

(* The raw selector equation is not used as a simp rule. *)
declare bind_gpv.sel [simp del]

(* Make the monad bind syntax available for bind_gpv. *)
adhoc_overloading Monad_Syntax.bind bind_gpv

(* Readable unfolding of bind, used as the code equation: a pure result x
   continues with f x, an IO step keeps its output and binds f behind the
   continuation. *)
lemma bind_gpv_unfold [code]:
  "r ⤜ f = GPV (
   do {
     generat ← the_gpv r;
     case generat of Pure x ⇒ the_gpv (f x)
       | IO out c ⇒ return_spmf (IO out (λinput. c input ⤜ f))
   })"
unfolding bind_gpv_def
apply(rule gpv.expand)
apply(simp add: map_spmf_bind_spmf)
apply(rule arg_cong[where f="bind_spmf (the_gpv r)"])
apply(auto split: generat.split simp add: map_spmf_bind_spmf fun_eq_iff spmf.map_comp o_def generat.map_comp id_def[symmetric] generat.map_id pmf.map_id option.map_id)
done

(* Congruence rule that only allows rewriting the first argument of bind_gpv.
   It is added as a congruence rule to the Code_Simp simpset. *)
lemma bind_gpv_code_cong: "f = f'  bind_gpv f g = bind_gpv f' g" by simp
setup Code_Simp.map_ss (Simplifier.add_cong @{thm bind_gpv_code_cong})

(* Selector form of bind_gpv_unfold: characterises the_gpv of a bind using a
   case distinction on the first step of r. Not a simp rule. *)
lemma bind_gpv_sel:
  "the_gpv (r  f) =
   do {
     generat  the_gpv r;
     case generat of Pure x  the_gpv (f x)
       | IO out c  return_spmf (IO out (λinput. bind_gpv (c input) f))
   }"
by(subst bind_gpv_unfold) simp

(* Variant of bind_gpv_sel that uses the discriminator is_Pure and the selectors
   result/output/continuation instead of a case expression. This is the version
   declared as a simp rule. *)
lemma bind_gpv_sel' [simp]:
  "the_gpv (r  f) =
   do {
     generat  the_gpv r;
     if is_Pure generat then the_gpv (f (result generat))
     else return_spmf (IO (output generat) (λinput. bind_gpv (continuation generat input) f))
   }"
unfolding bind_gpv_sel
by(rule arg_cong[where f="bind_spmf (the_gpv r)"])(simp add: fun_eq_iff split: generat.split)

(* Left identity: binding a Done a with f yields f a. *)
lemma Done_bind_gpv [simp]: "Done a  f = f a"
by(rule gpv.expand)(simp)

(* Right identity: binding with Done leaves a gpv unchanged. Proved by coinduction. *)
lemma bind_gpv_Done [simp]: "f  Done = f"
proof(coinduction arbitrary: f rule: gpv.coinduct)
  case (Eq_gpv f)
  (* Auxiliary rewrite: the inner bind of the bind_gpv definition, specialised to
     Done, is the first step of f with all continuations tagged Inr. *)
  have *: "the_gpv f  (case_generat (λx. return_spmf (Pure x)) (λout c. return_spmf (IO out (λinput. Inr (c input))))) =
           map_spmf (map_generat id id ((∘) Inr)) (bind_spmf (the_gpv f) return_spmf)"
    unfolding map_spmf_bind_spmf
    by(rule arg_cong2[where f=bind_spmf])(auto simp add: fun_eq_iff split: generat.split)
  show ?case
    by(auto simp add: * bind_gpv.simps pmf.rel_map option.rel_map[abs_def] generat.rel_map[abs_def] simp del: bind_gpv_sel' intro!: rel_generatI rel_spmf_reflI)
qed

(* An if-then-else on a condition b that does not depend on the bound variable
   can be pulled out of the second argument of bind_gpv. Registered in if_distribs. *)
lemma if_distrib_bind_gpv2 [if_distribs]:
  "bind_gpv gpv (λy. if b then f y else g y) = (if b then bind_gpv gpv f else bind_gpv gpv g)"
by simp

(* Interaction of lift_spmf with bind. *)

(* Binding a lifted spmf r with f: the first step is obtained from r and the_gpv of f. *)
lemma lift_spmf_bind: "lift_spmf r  f = GPV (r  the_gpv  f)"
by(coinduction arbitrary: r f rule: gpv.coinduct_strong)(auto simp add: bind_map_spmf o_def intro: rel_pmf_reflI rel_optionI rel_generatI)

(* Selector version of the previous lemma, stated with bind_spmf; simp rule. *)
lemma the_gpv_bind_gpv_lift_spmf [simp]:
  "the_gpv (bind_gpv (lift_spmf p) f) = bind_spmf p (the_gpv  f)"
by(simp add: bind_map_spmf o_def)

(* lift_spmf commutes with bind (stated with the overloaded bind syntax). *)
lemma lift_spmf_bind_spmf: "lift_spmf (p  f) = lift_spmf p  (λx. lift_spmf (f x))"
by(rule gpv.expand)(simp add: lift_spmf_bind o_def map_spmf_bind_spmf)

(* The same fact stated with the explicit constants bind_spmf and bind_gpv. *)
lemma lift_bind_spmf: "lift_spmf (bind_spmf p f) = bind_gpv (lift_spmf p) (lift_spmf  f)"
by(rule gpv.expand)(simp add: bind_map_spmf map_spmf_bind_spmf o_def)

(* Unfolding of bind when the first argument is given by the constructor GPV,
   using a case distinction on the generat value. *)
lemma GPV_bind:
  "GPV f  g = 
   GPV (f  (λgenerat. case generat of Pure x  the_gpv (g x) | IO out c  return_spmf (IO out (λinput. c input  g))))"
by(subst bind_gpv_unfold) simp

(* The same unfolding with is_Pure and selectors instead of the case expression. *)
lemma GPV_bind':
  "GPV f  g = GPV (f  (λgenerat. if is_Pure generat then the_gpv (g (result generat)) else return_spmf (IO (output generat) (λinput. continuation generat input  g))))"
unfolding GPV_bind gpv.inject
by(rule arg_cong[where f="bind_spmf f"])(simp add: fun_eq_iff split: generat.split)

(* Associativity of bind, proved by strong coinduction over f, g and h. *)
lemma bind_gpv_assoc:
  fixes f :: "('a, 'out, 'in) gpv"
  shows "(f  g)  h = f  (λx. g x  h)"
proof(coinduction arbitrary: f g h rule: gpv.coinduct_strong)
  case (Eq_gpv f g h)
  show ?case
    apply(simp cong del: if_weak_cong)
    apply(rule rel_spmf_bindI[where R="(=)"])
     apply(simp add: option.rel_eq pmf.rel_eq)
    apply(fastforce intro: rel_pmf_return_pmfI rel_generatI rel_spmf_reflI)
    done
qed

(* map_gpv distributes over bind: the output map g is applied to both parts,
   the result map f only to the continuation h (the first argument is mapped with id). *)
lemma map_gpv_bind_gpv: "map_gpv f g (bind_gpv gpv h) = bind_gpv (map_gpv id g gpv) (λx. map_gpv f g (h x))"
apply(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
apply(simp add: bind_gpv.sel gpv.map_sel spmf_rel_map generat.rel_map o_def bind_map_spmf del: bind_gpv_sel')
apply(rule rel_spmf_bind_reflI)
apply(auto simp add: spmf_rel_map generat.rel_map split: generat.split del: rel_funI intro!: rel_spmf_reflI generat.rel_refl rel_funI)
done

(* Special case of map_gpv_bind_gpv where outputs are not changed (g = id). *)
lemma map_gpv_id_bind_gpv: "map_gpv f id (bind_gpv gpv g) = bind_gpv gpv (map_gpv f id  g)"
by(simp add: map_gpv_bind_gpv gpv.map_id o_def)

(* Mapping only the results can be expressed as a bind followed by Done. *)
lemma map_gpv_conv_bind:
  "map_gpv f (λx. x) x = bind_gpv x (λx. Done (f x))"
using map_gpv_bind_gpv[of f "λx. x" x Done] by(simp add: id_def[symmetric] gpv.map_id)

(* A result map in the first argument of bind can be moved into the continuation. *)
lemma bind_map_gpv: "bind_gpv (map_gpv f id gpv) g = bind_gpv gpv (g  f)"
by(simp add: map_gpv_conv_bind id_def bind_gpv_assoc o_def)

(* Characterises the outputs of a bind: they are the outputs of x together with
   the outputs of f y for the results y of x. The two sides are abbreviated
   ?lhs and ?rhs and both inclusions are shown separately. *)
lemma outs_bind_gpv:
  "outs'_gpv (bind_gpv x f) = outs'_gpv x  (x  results'_gpv x. outs'_gpv (f x))"
  (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  (* Direction 1: every element of ?lhs is in ?rhs, by induction on the
     derivation of membership in outs'_gpv of the bind. *)
  fix out
  assume "out  ?lhs"
  then show "out  ?rhs"
  proof(induction g"x  f" arbitrary: x)
    case (Out generat)
    (* generat' is the first step of x from which the step generat of the bind arises. *)
    then obtain generat' where *: "generat'  set_spmf (the_gpv x)"
      and **: "generat  set_spmf (if is_Pure generat' then the_gpv (f (result generat'))
                                else return_spmf (IO (output generat') (λinput. continuation generat' input  f)))"
      by(auto)
    show ?case
    proof(cases "is_Pure generat'")
      case True
      then have "out  outs'_gpv (f (result generat'))" using Out(2) ** by(auto intro: outs'_gpvI)
      moreover have "result generat'  results'_gpv x" using * True
        by(auto intro: results'_gpvI generat.set_sel)
      ultimately show ?thesis by blast
    next
      case False
      hence "out  outs'_gpv x" using * ** Out(2) by(auto intro: outs'_gpvI generat.set_sel)
      thus ?thesis by blast
    qed
  next
    case (Cont generat c input)
    then obtain generat' where *: "generat'  set_spmf (the_gpv x)"
      and **: "generat  set_spmf (if is_Pure generat' then the_gpv (f (generat.result generat'))
                                 else return_spmf (IO (generat.output generat') (λinput. continuation generat' input  f)))"
      by(auto)
    show ?case
    proof(cases "is_Pure generat'")
      case True
      then have "out  outs'_gpv (f (result generat'))" using Cont(2-3) ** by(auto intro: outs'_gpvI)
      moreover have "result generat'  results'_gpv x" using * True
        by(auto intro: results'_gpvI generat.set_sel)
      ultimately show ?thesis by blast
    next
      case False
      then have generat: "generat = IO (output generat') (λinput. continuation generat' input  f)"
        using ** by simp
      with Cont(2) have "c input = continuation generat' input  f" by auto
      (* Apply the induction hypothesis to the continuation of generat'. *)
      hence "out  outs'_gpv (continuation generat' input)  (xresults'_gpv (continuation generat' input). outs'_gpv (f x))"
        by(rule Cont)
      thus ?thesis
      proof
        assume "out  outs'_gpv (continuation generat' input)"
        with * ** False have "out  outs'_gpv x" by(auto intro: outs'_gpvI generat.set_sel)
        thus ?thesis ..
      next
        assume "out  (xresults'_gpv (continuation generat' input). outs'_gpv (f x))"
        then obtain y where "y  results'_gpv (continuation generat' input)" "out  outs'_gpv (f y)" ..
        from y  _ * ** False have "y  results'_gpv x" 
          by(auto intro: results'_gpvI generat.set_sel)
        with out  outs'_gpv (f y) show ?thesis by blast
      qed
    qed
  qed
next
  (* Direction 2: every element of ?rhs is in ?lhs, by a case distinction on the
     two parts of ?rhs, each handled by induction. *)
  fix out
  assume "out  ?rhs"
  then show "out  ?lhs"
  proof
    (* Part 1: out is an output of x itself. *)
    assume "out  outs'_gpv x"
    thus ?thesis
    proof(induction)
      case (Out generat gpv)
      then show ?case
        by(cases generat)(fastforce intro: outs'_gpvI rev_bexI)+
    next
      case (Cont generat gpv gpv')
      then show ?case
        by(cases generat)(auto 4 4 intro: outs'_gpvI rev_bexI simp add: in_set_spmf set_pmf_bind_spmf simp del: set_bind_spmf)
    qed
  next
    (* Part 2: out is an output of f y for some result y of x. *)
    assume "out  (xresults'_gpv x. outs'_gpv (f x))"
    then obtain y where "y  results'_gpv x" "out  outs'_gpv (f y)" ..
    from y  _ show ?thesis
    proof(induction)
      case (Pure generat gpv)
      thus ?case using out  outs'_gpv _
        by(cases generat)(auto 4 5 intro: outs'_gpvI rev_bexI elim: outs'_gpv_cases)
    next
      case (Cont generat gpv gpv')
      thus ?case
        by(cases generat)(auto 4 4 simp add: in_set_spmf simp add: set_pmf_bind_spmf intro: outs'_gpvI rev_bexI simp del: set_bind_spmf)
    qed
  qed
qed

(* Fail is a left zero of bind. *)
lemma bind_gpv_Fail [simp]: "Fail  f = Fail"
by(subst bind_gpv_unfold)(simp add: Fail_def)

(* Characterises when a bind equals Fail: all elements in the support of the first
   step of gpv are Pure, and f maps every result of gpv to Fail. *)
lemma bind_gpv_eq_Fail:
  "bind_gpv gpv f = Fail  (xset_spmf (the_gpv gpv). is_Pure x)  (xresults'_gpv gpv. f x = Fail)"
  (is "?lhs = ?rhs")
proof(intro iffI conjI strip)
  (* Direction from ?rhs to ?lhs. *)
  show ?lhs if ?rhs using that
    by(intro gpv.expand)(auto 4 4 simp add: bind_eq_return_pmf_None intro: results'_gpv_Pure generat.set_sel dest: bspec)

  (* Direction from ?lhs to the two conjuncts of ?rhs. *)
  assume ?lhs
  hence *: "the_gpv (bind_gpv gpv f) = return_pmf None" by simp
  from * show "is_Pure x" if "x  set_spmf (the_gpv gpv)" for x using that
    by(simp add: bind_eq_return_pmf_None split: if_split_asm)
  show "f x = Fail" if "x  results'_gpv gpv" for x using that *
    by(cases)(auto 4 3 simp add: bind_eq_return_pmf_None elim!: generat.set_cases intro: gpv.expand dest: bspec)
qed

(* Parametricity of bind_gpv, stated with the ===> notation from lifting_syntax. *)
context includes lifting_syntax begin

(* Version for rel_gpv; registered as a transfer rule and proved by
   transfer_prover after unfolding the definition. *)
lemma bind_gpv_parametric [transfer_rule]:
  "(rel_gpv A C ===> (A ===> rel_gpv B C) ===> rel_gpv B C) bind_gpv bind_gpv"
unfolding bind_gpv_def by transfer_prover

(* Version for rel_gpv''; the primed parametricity rules for corec_gpv and the_gpv
   are supplied locally as transfer rules for this proof only. *)
lemma bind_gpv_parametric':
  "(rel_gpv'' A C R ===> (A ===> rel_gpv'' B C R) ===> rel_gpv'' B C R) bind_gpv bind_gpv"
unfolding bind_gpv_def supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
by(transfer_prover)

end

(* Done and bind_gpv satisfy the assumptions of the locale monad; registered as a locale witness. *)
lemma monad_gpv [locale_witness]: "monad Done bind_gpv"
by(unfold_locales)(simp_all add: bind_gpv_assoc)

(* Together with Fail they satisfy the assumptions of the locale monad_fail. *)
lemma monad_fail_gpv [locale_witness]: "monad_fail Done bind_gpv Fail"
by unfold_locales auto

(* Introduction rule for relating two binds, obtained directly from
   bind_gpv_parametric: related first arguments and pointwise related
   continuations give related binds. *)
lemma rel_gpv_bindI:
  " rel_gpv A C gpv gpv'; x y. A x y  rel_gpv B C (f x) (g y) 
   rel_gpv B C (bind_gpv gpv f) (bind_gpv gpv' g)"
by(fact bind_gpv_parametric[THEN rel_funD, THEN rel_funD, OF _ rel_funI])

(* Congruence rule for bind: the continuations only need to agree on the results
   of the first argument. The proof recasts equality as rel_gpv (=) (=) and uses
   rel_gpv_bindI with the relation "equal and in results'_gpv gpv'". *)
lemma bind_gpv_cong:
  " gpv = gpv'; x. x  results'_gpv gpv'  f x = g x   bind_gpv gpv f = bind_gpv gpv' g"
apply(subst gpv.rel_eq[symmetric])
apply(rule rel_gpv_bindI[where A="eq_onp (λx. x  results'_gpv gpv')"])
 apply(subst (asm) gpv.rel_eq[symmetric])
 apply(erule gpv.rel_mono_strong)
  apply(simp add: eq_onp_def)
 apply simp
apply(clarsimp simp add: gpv.rel_eq eq_onp_def)
done

(* Bind for reactive values: after receiving the input, the resulting gpv is
   bound with f. *)
definition bind_rpv :: "('a, 'in, 'out) rpv  ('a  ('b, 'in, 'out) gpv)  ('b, 'in, 'out) rpv"
where "bind_rpv rpv f = (λinput. bind_gpv (rpv input) f)"

(* Applied form of the definition; simp rule. *)
lemma bind_rpv_apply [simp]: "bind_rpv rpv f input = bind_gpv (rpv input) f"
by(simp add: bind_rpv_def fun_eq_iff)

(* Make the monad bind syntax of Monad_Syntax available for bind_rpv as well. *)
adhoc_overloading Monad_Syntax.bind bind_rpv

(* Congruence rule that only allows rewriting the first argument of bind_rpv.
   It is added as a congruence rule to the Code_Simp simpset. *)
lemma bind_rpv_code_cong: "rpv = rpv'  bind_rpv rpv f = bind_rpv rpv' f" by simp
setup Code_Simp.map_ss (Simplifier.add_cong @{thm bind_rpv_code_cong})

(* Basic simp rules and laws for bind_rpv. *)

(* Right identity: binding an rpv with Done leaves it unchanged. *)
lemma bind_rpv_rDone [simp]: "bind_rpv rpv Done = rpv"
by(simp add: bind_rpv_def)

(* Binding a Pause pushes the bind into the continuation via bind_rpv. *)
lemma bind_gpv_Pause [simp]: "bind_gpv (Pause out rpv) f = Pause out (bind_rpv rpv f)"
by(rule gpv.expand)(simp add: fun_eq_iff)

(* Binding a React pushes the bind into the second component of f's result (via apsnd). *)
lemma bind_rpv_React [simp]: "bind_rpv (React f) g = React (apsnd (λrpv. bind_rpv rpv g)  f)"
by(simp add: React_def split_beta fun_eq_iff)

(* Associativity-style law for bind_rpv, derived from bind_gpv_assoc. *)
lemma bind_rpv_assoc: "bind_rpv (bind_rpv rpv f) g = bind_rpv rpv ((λgpv. bind_gpv gpv g)  f)"
by(simp add: fun_eq_iff bind_gpv_assoc o_def)

(* Using Done itself as the reactive value: binding it with f yields f. *)
lemma bind_rpv_Done [simp]: "bind_rpv Done f = f"
by(simp add: bind_rpv_def)

(* Done, viewed as a reactive value, has every value as a possible result. *)
lemma results'_rpv_Done [simp]: "results'_rpv Done = UNIV"
by(auto simp add: results'_rpv_def)


subsection ‹ Embedding @{typ "'a spmf"} as a monad ›

(* Auxiliary distributivity property relating the function relator over composed
   relations (R OO R', S OO S') to the composition of the function relators
   (presumably an inclusion between the two; the comparison symbol is not visible
   in this text). It requires R to be left-unique and right-total and S to be
   right-unique and left-total. Used below when lifting transfer rules to the
   parametrised correspondence relation. *)
lemma neg_fun_distr3:
  includes lifting_syntax
  assumes 1: "left_unique R" "right_total R"
  assumes 2: "right_unique S" "left_total S"
  shows "(R OO R' ===> S OO S')  ((R ===> S) OO (R' ===> S'))"
using functional_relation[OF 2] functional_converse_relation[OF 1]
unfolding rel_fun_def OO_def
apply clarify
apply (subst all_comm)
apply (subst all_conj_distrib[symmetric])
apply (intro choice)
by metis

locale spmf_to_gpv begin

text ‹
  The lifting package cannot handle free term variables in the merging of transfer rules,
  so for the embedding we define a specialised relator ‹rel_gpv'›
  which acts only on the returned values.
›

(* Specialised relator: rel_gpv with the output relation fixed to equality,
   so that only the relation A on results remains as a parameter. *)
definition rel_gpv' :: "('a  'b  bool)  ('a, 'out, 'in) gpv  ('b, 'out, 'in) gpv  bool"
where "rel_gpv' A = rel_gpv A (=)"

(* rel_gpv' on equality is equality; registered as relator_eq. *)
lemma rel_gpv'_eq [relator_eq]: "rel_gpv' (=) = (=)"
unfolding rel_gpv'_def gpv.rel_eq ..

(* Monotonicity of rel_gpv' in its relation argument; registered as relator_mono. *)
lemma rel_gpv'_mono [relator_mono]: "A  B  rel_gpv' A  rel_gpv' B"
unfolding rel_gpv'_def by(rule gpv.rel_mono; simp)

(* rel_gpv' distributes over relation composition; registered as relator_distr. *)
lemma rel_gpv'_distr [relator_distr]: "rel_gpv' A OO rel_gpv' B = rel_gpv' (A OO B)"
unfolding rel_gpv'_def by (metis OO_eq gpv.rel_compp) 

(* rel_gpv' preserves uniqueness and totality properties of its argument relation.
   Each lemma follows from the corresponding property of rel_gpv and of equality,
   and is registered as a transfer rule. *)
lemma left_unique_rel_gpv' [transfer_rule]: "left_unique A  left_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: left_unique_rel_gpv left_unique_eq)

lemma right_unique_rel_gpv' [transfer_rule]: "right_unique A  right_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: right_unique_rel_gpv right_unique_eq)

lemma bi_unique_rel_gpv' [transfer_rule]: "bi_unique A  bi_unique (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: bi_unique_rel_gpv bi_unique_eq)

lemma left_total_rel_gpv' [transfer_rule]: "left_total A  left_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: left_total_rel_gpv left_total_eq)

lemma right_total_rel_gpv' [transfer_rule]: "right_total A  right_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: right_total_rel_gpv right_total_eq)

lemma bi_total_rel_gpv' [transfer_rule]: "bi_total A  bi_total (rel_gpv' A)"
unfolding rel_gpv'_def by(simp add: bi_total_rel_gpv bi_total_eq)


text ‹
  We cannot use ‹setup_lifting› because @{typ "('a, 'out, 'in) gpv"} contains
  type variables which do not appear in @{typ "'a spmf"}.
›

(* Correspondence relation between spmfs and gpvs: p corresponds to gpv
   if gpv is the lifting of p. *)
definition cr_spmf_gpv :: "'a spmf  ('a, 'out, 'in) gpv  bool"
where "cr_spmf_gpv p gpv  gpv = lift_spmf p"

(* Inverse direction, defined by definite description: the spmf p whose lifting
   is gpv. No such p need exist for an arbitrary gpv; the lemma below only covers
   arguments of the form lift_spmf p. *)
definition spmf_of_gpv :: "('a, 'out, 'in) gpv  'a spmf"
where "spmf_of_gpv gpv = (THE p. gpv = lift_spmf p)"

(* spmf_of_gpv is a left inverse of lift_spmf. *)
lemma spmf_of_gpv_lift_spmf [simp]: "spmf_of_gpv (lift_spmf p) = p"
unfolding spmf_of_gpv_def by auto

(* If p and q are related by rel_spmf A, then every element of the support of q
   has an A-related element in the support of p. *)
lemma rel_spmf_setD2:
  " rel_spmf A p q; y  set_spmf q   xset_spmf p. A x y"
by(erule rel_spmfE) force

(* A gpv related to a lifted spmf p is itself a lifted spmf q with p and q related
   by rel_spmf A. In the proof, the witness for q is map_spmf result (the_gpv gpv). *)
lemma rel_gpv_lift_spmf1: "rel_gpv A B (lift_spmf p) gpv  (q. gpv = lift_spmf q  rel_spmf A p q)"
apply(subst gpv.rel_sel)
apply(simp add: spmf_rel_map rel_generat_Pure1)
apply safe
 apply(rule exI[where x="map_spmf result (the_gpv gpv)"])
 apply(clarsimp simp add: spmf_rel_map)
 apply(rule conjI)
  apply(rule gpv.expand)
  apply(simp add: spmf.map_comp)
  apply(subst map_spmf_cong[OF refl, where g=id])
   apply(drule (1) rel_spmf_setD2)
   apply clarsimp
  apply simp
 apply(erule rel_spmf_mono)
 apply clarsimp
apply(clarsimp simp add: spmf_rel_map)
done

(* Symmetric version with the lifted spmf on the right, obtained by flipping the relation. *)
lemma rel_gpv_lift_spmf2: "rel_gpv A B gpv (lift_spmf q)  (p. gpv = lift_spmf p  rel_spmf A p q)"
by(subst gpv.rel_flip[symmetric])(simp add: rel_gpv_lift_spmf1 pmf.rel_flip option.rel_conversep)

(* Parametrised correspondence relation: cr_spmf_gpv composed with rel_gpv A (=). *)
definition pcr_spmf_gpv :: "('a  'b  bool)  'a spmf  ('b, 'out, 'in) gpv  bool"
where "pcr_spmf_gpv A = cr_spmf_gpv OO rel_gpv A (=)"

(* Instantiated with equality, the parametrised relation is cr_spmf_gpv. *)
lemma pcr_cr_eq_spmf_gpv: "pcr_spmf_gpv (=) = cr_spmf_gpv"
by(simp add: pcr_spmf_gpv_def gpv.rel_eq OO_eq)

(* Uniqueness and totality properties of cr_spmf_gpv and pcr_spmf_gpv.
   The pcr_ versions are registered as transfer rules; they combine the cr_ version
   with the corresponding property of rel_gpv. Only left-totality is shown
   (no right-totality lemma is stated here). *)
lemma left_unique_cr_spmf_gpv: "left_unique cr_spmf_gpv"
by(rule left_uniqueI)(simp add: cr_spmf_gpv_def)

lemma left_unique_pcr_spmf_gpv [transfer_rule]:
  "left_unique A  left_unique (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro left_unique_OO left_unique_cr_spmf_gpv left_unique_rel_gpv left_unique_eq)

lemma right_unique_cr_spmf_gpv: "right_unique cr_spmf_gpv"
by(rule right_uniqueI)(simp add: cr_spmf_gpv_def)

lemma right_unique_pcr_spmf_gpv [transfer_rule]:
  "right_unique A  right_unique (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro right_unique_OO right_unique_cr_spmf_gpv right_unique_rel_gpv right_unique_eq)

lemma bi_unique_cr_spmf_gpv: "bi_unique cr_spmf_gpv"
by(simp add: bi_unique_alt_def left_unique_cr_spmf_gpv right_unique_cr_spmf_gpv)

lemma bi_unique_pcr_spmf_gpv [transfer_rule]: "bi_unique A  bi_unique (pcr_spmf_gpv A)"
by(simp add: bi_unique_alt_def left_unique_pcr_spmf_gpv right_unique_pcr_spmf_gpv)

lemma left_total_cr_spmf_gpv: "left_total cr_spmf_gpv"
by(rule left_totalI)(simp add: cr_spmf_gpv_def)

lemma left_total_pcr_spmf_gpv [transfer_rule]: "left_total A ==> left_total (pcr_spmf_gpv A)"
unfolding pcr_spmf_gpv_def by(intro left_total_OO left_total_cr_spmf_gpv left_total_rel_gpv left_total_eq)

context includes lifting_syntax begin

(* Transfer rule relating return_spmf to Done.
   The primed lemma is the non-parametrised version for cr_spmf_gpv. *)
lemma return_spmf_gpv_transfer':
  "((=) ===> cr_spmf_gpv) return_spmf Done"
by(rule rel_funI)(simp add: cr_spmf_gpv_def)

(* Parametrised version: obtained by composing the primed rule with the
   parametricity of Done (found by transfer_prover) via pos_fun_distr. *)
lemma return_spmf_gpv_transfer [transfer_rule]:
  "(A ===> pcr_spmf_gpv A) return_spmf Done"
unfolding pcr_spmf_gpv_def
apply(rewrite in "( ===> _) _ _" eq_OO[symmetric])
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
 apply(rule return_spmf_gpv_transfer')
apply transfer_prover
done

(* Transfer rule relating bind_spmf to bind_gpv.
   The primed lemma is the non-parametrised version for cr_spmf_gpv. *)
lemma bind_spmf_gpv_transfer':
  "(cr_spmf_gpv ===> ((=) ===> cr_spmf_gpv) ===> cr_spmf_gpv) bind_spmf bind_gpv"
apply(clarsimp simp add: rel_fun_def cr_spmf_gpv_def)
apply(rule gpv.expand)
apply(simp add: bind_map_spmf map_spmf_bind_spmf o_def)
done

(* Parametrised version. The relation is first rearranged using fun_mono,
   neg_fun_distr3 and pos_fun_distr so that it becomes a composition of the primed
   rule with the parametricity of bind_gpv, which transfer_prover then discharges. *)
lemma bind_spmf_gpv_transfer [transfer_rule]:
  "(pcr_spmf_gpv A ===> (A ===> pcr_spmf_gpv B) ===> pcr_spmf_gpv B) bind_spmf bind_gpv"
unfolding pcr_spmf_gpv_def
apply(rewrite in "(_ ===> ( ===> _) ===> _) _ _" eq_OO[symmetric])
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
  apply(rule order.refl)
 apply(rule fun_mono)
  apply(rule neg_fun_distr3[OF left_unique_eq right_total_eq right_unique_cr_spmf_gpv left_total_cr_spmf_gpv])
 apply(rule order.refl)
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
  apply(rule order.refl)
 apply(rule pos_fun_distr)
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
 apply(rule bind_spmf_gpv_transfer')
apply transfer_prover
done

(* Transfer rule relating the identity function on spmfs to lift_spmf.
   The primed lemma is the non-parametrised version for cr_spmf_gpv. *)
lemma lift_spmf_gpv_transfer':
  "((=) ===> cr_spmf_gpv) (λx. x) lift_spmf"
by(simp add: rel_fun_def cr_spmf_gpv_def)

(* Parametrised version, composed from the primed rule and the parametricity of
   lift_spmf in the same way as for return_spmf. *)
lemma lift_spmf_gpv_transfer [transfer_rule]:
  "(rel_spmf A ===> pcr_spmf_gpv A) (λx. x) lift_spmf"
unfolding pcr_spmf_gpv_def
apply(rewrite in "( ===> _) _ _" eq_OO[symmetric])
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
 apply(rule lift_spmf_gpv_transfer')
apply transfer_prover
done

(* Transfer rule relating the failing spmf (return_pmf None) to Fail.
   The primed lemma is the non-parametrised version for cr_spmf_gpv. *)
lemma fail_spmf_gpv_transfer': "cr_spmf_gpv (return_pmf None) Fail"
by(simp add: cr_spmf_gpv_def)

(* Parametrised version: the primed rule composed with the parametricity of Fail. *)
lemma fail_spmf_gpv_transfer [transfer_rule]: "pcr_spmf_gpv A (return_pmf None) Fail"
unfolding pcr_spmf_gpv_def
apply(rule relcomppI)
 apply(rule fail_spmf_gpv_transfer')
apply transfer_prover
done

(* Transfer rule relating map_spmf to map_gpv. On the spmf side the second
   function argument g (which map_gpv applies to outputs, cf. map_gpv_Pause) is
   ignored, so it may be related by an arbitrary relation R.
   The primed lemma is the non-parametrised version for cr_spmf_gpv. *)
lemma map_spmf_gpv_transfer':
  "((=) ===> R ===> cr_spmf_gpv ===> cr_spmf_gpv) (λf g. map_spmf f) map_gpv"
by(simp add: rel_fun_def cr_spmf_gpv_def map_lift_spmf)

(* Parametrised version. As for bind, the relation is rearranged with fun_mono,
   neg_fun_distr3 and pos_fun_distr into a composition of the primed rule and the
   parametricity of map_gpv; rel_fun_eq is unfolded and folded back around the
   application of the primed rule. *)
lemma map_spmf_gpv_transfer [transfer_rule]:
  "((A ===> B) ===> R ===> pcr_spmf_gpv A ===> pcr_spmf_gpv B) (λf g. map_spmf f) map_gpv"
unfolding pcr_spmf_gpv_def
apply(rewrite in "(( ===> _) ===> _)  _ _" eq_OO[symmetric])
apply(rewrite in "((_ ===> ) ===> _)  _ _" eq_OO[symmetric])
apply(rewrite in "(_ ===>  ===> _)  _ _" OO_eq[symmetric])
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
  apply(rule neg_fun_distr3[OF left_unique_eq right_total_eq right_unique_eq left_total_eq])
 apply(rule fun_mono[OF order.refl])
 apply(rule pos_fun_distr)
apply(rule fun_mono[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
  apply(rule order.refl)
 apply(rule pos_fun_distr)
apply(rule pos_fun_distr[THEN le_funD, THEN le_funD, THEN le_boolD, THEN mp])
apply(rule relcomppI)
 apply(unfold rel_fun_eq)
 apply(rule map_spmf_gpv_transfer')
apply(unfold rel_fun_eq[symmetric])
apply transfer_prover
done

end

end

subsection ‹ Embedding @{typ "'a option"} as a monad ›

(* Locale bundling the transfer setup that embeds the option monad into gpv.
   The rules only become active where this locale is interpreted. *)
locale option_to_gpv begin

(* Bring the option-to-spmf and spmf-to-gpv setups into scope inside this locale. *)
interpretation option_to_spmf .
interpretation spmf_to_gpv .

(* Correspondence relation: an option value x corresponds to the gpv obtained by
   turning x into an spmf with return_pmf and lifting it with lift_spmf. *)
definition cr_option_gpv :: "'a option  ('a, 'out, 'in) gpv  bool"
where "cr_option_gpv x gpv  gpv = (lift_spmf  return_pmf) x"

(* The relation is the composition of the converse of cr_spmf_option with cr_spmf_gpv. *)
lemma cr_option_gpv_conv_OO:
  "cr_option_gpv = cr_spmf_option¯¯ OO cr_spmf_gpv"
by(simp add: fun_eq_iff relcompp.simps cr_option_gpv_def cr_spmf_gpv_def cr_spmf_option_def)

context includes lifting_syntax begin

text ‹These transfer rules should follow from merging the transfer rules, but this has not yet been implemented.›

(* Some corresponds to Done. *)
lemma return_option_gpv_transfer [transfer_rule]:
  "((=) ===> cr_option_gpv) Some Done"
by(simp add: cr_option_gpv_def rel_fun_def)

(* Option.bind corresponds to bind_gpv. *)
lemma bind_option_gpv_transfer [transfer_rule]:
  "(cr_option_gpv ===> ((=) ===> cr_option_gpv) ===> cr_option_gpv) Option.bind bind_gpv"
apply(clarsimp simp add: cr_option_gpv_def rel_fun_def)
subgoal for x f g by(cases x; simp)
done

(* None corresponds to Fail. *)
lemma fail_option_gpv_transfer [transfer_rule]: "cr_option_gpv None Fail"
by(simp add: cr_option_gpv_def)

(* map_option corresponds to map_gpv; the output map g is ignored on the option
   side and may be related by an arbitrary relation R. *)
lemma map_option_gpv_transfer [transfer_rule]:
  "((=) ===> R ===> cr_option_gpv ===> cr_option_gpv) (λf g. map_option f) map_gpv"
unfolding rel_fun_eq by(simp add: rel_fun_def cr_option_gpv_def map_lift_spmf)

end

end

(* Locale with a weaker transfer setup between option and gpv than option_to_gpv:
   the correspondence relation below has an additional alternative "x = None". *)
locale option_le_gpv begin

(* Bring the option_le_spmf and spmf-to-gpv setups into scope inside this locale. *)
interpretation option_le_spmf .
interpretation spmf_to_gpv .

(* Correspondence relation: x corresponds to gpv if gpv is the embedding of x
   (as in cr_option_gpv), with "x = None" as a second alternative
   (presumably a disjunction, so that None is related to every gpv;
   the connective is not visible in this text). *)
definition cr_option_le_gpv :: "'a option  ('a, 'out, 'in) gpv  bool"
where "cr_option_le_gpv x gpv  gpv = (lift_spmf  return_pmf) x  x = None"

context includes lifting_syntax begin

(* Some corresponds to Done. *)
lemma return_option_le_gpv_transfer [transfer_rule]:
  "((=) ===> cr_option_le_gpv) Some Done"
by(simp add: cr_option_le_gpv_def rel_fun_def)

(* Option.bind corresponds to bind_gpv. *)
lemma bind_option_gpv_transfer [transfer_rule]:
  "(cr_option_le_gpv ===> ((=) ===> cr_option_le_gpv) ===> cr_option_le_gpv) Option.bind bind_gpv"
apply(clarsimp simp add: cr_option_le_gpv_def rel_fun_def bind_eq_Some_conv)
subgoal for f g x y by(erule allE[where x=y]) auto
done

(* None corresponds to Fail. *)
lemma fail_option_gpv_transfer [transfer_rule]:
  "cr_option_le_gpv None Fail"
by(simp add: cr_option_le_gpv_def)

(* map_option corresponds to map_gpv with the output map fixed to id. *)
lemma map_option_gpv_transfer [transfer_rule]:
  "(((=) ===> (=)) ===> cr_option_le_gpv ===> cr_option_le_gpv) map_option (λf. map_gpv f id)"
unfolding rel_fun_eq by(simp add: rel_fun_def cr_option_le_gpv_def map_lift_spmf)

end

end

subsection ‹Embedding resumptions›

(* Embeds a resumption into a gpv, defined corecursively:
   - Done None becomes the failing step return_pmf None,
   - Done (Some x') becomes the step that returns Pure x',
   - Pause out c becomes an IO step with output out whose continuation c is
     composed with lift_resumption. *)
primcorec lift_resumption :: "('a, 'out, 'in) resumption  ('a, 'out, 'in) gpv"
where
  "the_gpv (lift_resumption r) = 
  (case r of resumption.Done None  return_pmf None
    | resumption.Done (Some x') => return_spmf (Pure x')
    | resumption.Pause out c => map_spmf (map_generat id id ((∘) lift_resumption)) (return_spmf (IO out c)))"

(* Selector equation stated with discriminators and selectors of resumption
   instead of a case expression. *)
lemma the_gpv_lift_resumption:
  "the_gpv (lift_resumption r) = 
   (if is_Done r then if Option.is_none (resumption.result r) then return_pmf None else return_spmf (Pure (the (resumption.result r)))
    else return_spmf (IO (resumption.output r) (lift_resumption  resume r)))"
by(simp split: option.split resumption.split)

(* The generated simp rules are removed from the simpset; the lemmas that follow
   provide constructor-based rules instead. *)
declare lift_resumption.simps [simp del]

(* Constructor-level equations for lift_resumption. *)

(* Code equation for Done: None becomes Fail, Some x' becomes Done x'. *)
lemma lift_resumption_Done [code]:
  "lift_resumption (resumption.Done x) = (case x of None  Fail | Some x'  Done x')"
by(rule gpv.expand)(simp add: the_gpv_lift_resumption split: option.split)

(* DONE x is lifted to Done x. *)
lemma lift_resumption_DONE [simp]:
  "lift_resumption (DONE x) = Done x"
by(simp add: DONE_def lift_resumption_Done)

(* ABORT is lifted to Fail. *)
lemma lift_resumption_ABORT [simp]:
  "lift_resumption ABORT = Fail"
by(simp add: ABORT_def lift_resumption_Done)

(* Simp and code equation for Pause: the output is kept and the continuation is lifted. *)
lemma lift_resumption_Pause [simp, code]:
  "lift_resumption (resumption.Pause out c) = Pause out (lift_resumption  c)"
by(rule gpv.expand)(simp add: the_gpv_lift_resumption)

(* Variant of lift_resumption_DONE with DONE unfolded to resumption.Done (Some x). *)
lemma lift_resumption_Done_Some [simp]: "lift_resumption (resumption.Done (Some x)) = Done x"
using lift_resumption_DONE unfolding DONE_def by simp

(* The results of a lifted resumption are exactly the results of the resumption.
   Both inclusions are shown by induction. *)
lemma results'_gpv_lift_resumption [simp]:
  "results'_gpv (lift_resumption r) = results r" (is "?lhs = ?rhs")
proof(rule set_eqI iffI)+
  show "x  ?rhs" if "x  ?lhs" for x using that
    by(induction gpv"lift_resumption r" arbitrary: r)
      (auto intro: resumption.set_sel simp add: lift_resumption.sel split: resumption.split_asm option.split_asm)
  show "x  ?lhs" if "x  ?rhs" for x using that by induction(auto simp add: lift_resumption.sel)
qed

(* The outputs of a lifted resumption are exactly the outputs of the resumption.
   Both inclusions are shown by induction. *)
lemma outs'_gpv_lift_resumption [simp]:
  "outs'_gpv (lift_resumption r) = outputs r" (is "?lhs = ?rhs")
proof(rule set_eqI iffI)+
  show "x  ?rhs" if "x  ?lhs" for x using that
    by(induction gpv"lift_resumption r" arbitrary: r)
      (auto simp add: lift_resumption.sel split: resumption.split_asm option.split_asm)
  show "x  ?lhs" if "x  ?rhs" for x using that by induction auto
qed

(* The predicator pred_gpv on a lifted resumption agrees with pred_resumption
   on the resumption itself. *)
lemma pred_gpv_lift_resumption [simp]:
  "A. pred_gpv A C (lift_resumption r) = pred_resumption A C r"
by(simp add: pred_gpv_def pred_resumption_def)

(* lift_resumption commutes with bind; proved by strong coinduction. *)
lemma lift_resumption_bind: "lift_resumption (r  f) = lift_resumption r  lift_resumption  f"
by(coinduction arbitrary: r rule: gpv.coinduct_strong)
  (auto simp add: lift_resumption.sel Done_bind split: resumption.split option.split del: rel_funI intro!: rel_funI)

subsection ‹Assertions›

(* Assertion as a gpv: returns unit if b holds and fails otherwise. *)
definition assert_gpv :: "bool  (unit, 'out, 'in) gpv"
where "assert_gpv b = (if b then Done () else Fail)"

(* Evaluation on the two truth values. *)
lemma assert_gpv_simps [simp]:
  "assert_gpv True = Done ()"
  "assert_gpv False = Fail"
by(simp_all add: assert_gpv_def)

(* Comparisons of assert_gpv b with Done, Pause and Fail (in both orientations)
   reduce to conditions on b; all are simp rules. *)
lemma [simp]:
  shows assert_gpv_eq_Done: "assert_gpv b = Done x  b"
  and Done_eq_assert_gpv: "Done x = assert_gpv b  b"
  and Pause_neq_assert_gpv: "Pause out rpv  assert_gpv b"
  and assert_gpv_neq_Pause: "assert_gpv b  Pause out rpv"
  and assert_gpv_eq_Fail: "assert_gpv b = Fail  ¬ b"
  and Fail_eq_assert_gpv: "Fail = assert_gpv b  ¬ b"
by(simp_all add: assert_gpv_def)

(* Two assertions are equal exactly when their conditions are equal. *)
lemma assert_gpv_inject [simp]: "assert_gpv b = assert_gpv b'  b = b'"
by(simp add: assert_gpv_def)

(* The first step of an assertion is assert_spmf b with its result wrapped in Pure. *)
lemma assert_gpv_sel [simp]:
  "the_gpv (assert_gpv b) = map_spmf Pure (assert_spmf b)"
by(simp add: assert_gpv_def)

(* First step of a bind whose first argument is an assertion. *)
lemma the_gpv_bind_assert [simp]:
  "the_gpv (bind_gpv (assert_gpv b) f) =
   bind_spmf (assert_spmf b) (the_gpv  f)"
by(cases b) simp_all

(* pred_gpv on an assertion, expressed in terms of b and P (); proved by cases on b. *)
lemma pred_gpv_assert [simp]: "pred_gpv P Q (assert_gpv b) = (b  P ())"
by(cases b) simp_all

(* Try-else combinator with mixfix syntax TRY _ ELSE _, defined corecursively.
   The first step is try_spmf of the first steps of gpv and gpv'. Continuations
   coming from gpv are tagged Inl and are again wrapped in try_gpv with the same
   fallback gpv'; continuations coming from gpv' are tagged Inr and used as they are. *)
primcorec try_gpv :: "('a, 'call, 'ret) gpv  ('a, 'call, 'ret) gpv  ('a, 'call, 'ret) gpv" ("TRY _ ELSE _" [0,60] 59)
where
  "the_gpv (TRY gpv ELSE gpv') = 
   map_spmf (map_generat id id (λc input. case c input of Inl gpv  try_gpv gpv gpv' | Inr gpv'  gpv'))
     (try_spmf (map_spmf (map_generat id id (map_fun id Inl)) (the_gpv gpv))
               (map_spmf (map_generat id id (map_fun id Inr)) (the_gpv gpv')))"

(* Selector equation without the Inl/Inr tagging: the try on the spmf level is
   applied to the first step of gpv, with the try pushed into its continuations,
   and to the first step of gpv'. *)
lemma try_gpv_sel:
  "the_gpv (TRY gpv ELSE gpv') =
   TRY map_spmf (map_generat id id (λc input. TRY c input ELSE gpv')) (the_gpv gpv) ELSE the_gpv gpv'"
by(simp add: try_gpv_def map_try_spmf spmf.map_comp o_def generat.map_comp generat.map_ident id_def)

(* Basic equations for TRY _ ELSE _. *)

(* A Done in the first position is returned; the fallback is not used. *)
lemma try_gpv_Done [simp]: "TRY Done x ELSE gpv' = Done x"
by(rule gpv.expand)(simp)

(* If the first argument is Fail, the result is the fallback. *)
lemma try_gpv_Fail [simp]: "TRY Fail ELSE gpv' = gpv'"
by(rule gpv.expand)(simp add: spmf.map_comp o_def generat.map_comp generat.map_ident)

(* A Pause in the first position is kept and the try is pushed into the continuation. *)
lemma try_gpv_Pause [simp]: "TRY Pause out c ELSE gpv' = Pause out (λinput. TRY c input ELSE gpv')"
by(rule gpv.expand) simp

(* Fail as fallback leaves the first argument unchanged; proved by strong coinduction. *)
lemma try_gpv_Fail2 [simp]: "TRY gpv ELSE Fail = gpv"
by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)
  (auto simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI generat.rel_refl)

(* lift_spmf commutes with try. *)
lemma lift_try_spmf: "lift_spmf (TRY p ELSE q) = TRY lift_spmf p ELSE lift_spmf q" 
by(rule gpv.expand)(simp add: map_try_spmf spmf.map_comp o_def)

(* Try on an assertion: Done () if b holds, otherwise the fallback. *)
lemma try_assert_gpv: "TRY assert_gpv b ELSE gpv' = (if b then Done () else gpv')"
by(simp)

(* Parametricity of try_gpv, stated with the ===> notation from lifting_syntax. *)
context includes lifting_syntax begin
(* Version for rel_gpv; registered as a transfer rule. *)
lemma try_gpv_parametric [transfer_rule]:
  "(rel_gpv A C ===> rel_gpv A C ===> rel_gpv A C) try_gpv try_gpv"
unfolding try_gpv_def by transfer_prover

(* Version for rel_gpv''; the primed parametricity rules for corec_gpv and the_gpv
   are supplied locally as transfer rules for this proof only. *)
lemma try_gpv_parametric':
  "(rel_gpv'' A C R ===> rel_gpv'' A C R ===> rel_gpv'' A C R) try_gpv try_gpv"
unfolding try_gpv_def
supply corec_gpv_parametric'[transfer_rule] the_gpv_parametric'[transfer_rule]
by transfer_prover
end

(* map_gpv distributes over try; derived from the parametricity of try_gpv. *)
lemma map_try_gpv: "map_gpv f g (TRY gpv ELSE gpv') = TRY map_gpv f g gpv ELSE map_gpv f g gpv'"
by(simp add: gpv.rel_map try_gpv_parametric[THEN rel_funD, THEN rel_funD] gpv.rel_refl gpv.rel_eq[symmetric])

(* map_gpv' distributes over try; proved directly by strong coinduction. *)
lemma map'_try_gpv: "map_gpv' f g h (TRY gpv ELSE gpv') = TRY map_gpv' f g h gpv ELSE map_gpv' f g h gpv'"
by(coinduction arbitrary: gpv rule: gpv.coinduct_strong)(auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI generat.rel_refl rel_funI rel_spmf_try_spmf)
  

(* Try around a bind that starts with an assertion: if b holds the assertion
   disappears, otherwise the result is the fallback gpv. *)
lemma try_bind_assert_gpv:
  "TRY (assert_gpv b  f) ELSE gpv = (if b then TRY (f ()) ELSE gpv else gpv)"
by(simp)



subsection ‹Order for @{typ "('a, 'out, 'in) gpv"}›

(* Coinductively defined order on gpvs: GPV f is below GPV g if the first steps
   are related by ord_spmf, where generat values must have equal results and
   outputs and continuations that are pointwise (on equal inputs) in ord_gpv again. *)
coinductive ord_gpv :: "('a, 'out, 'in) gpv  ('a, 'out, 'in) gpv  bool"
where
  "ord_spmf (rel_generat (=) (=) (rel_fun (=) ord_gpv)) f g  ord_gpv (GPV f) (GPV g)"

(* Simp rule that unfolds ord_gpv on two GPV constructors. *)
inductive_simps ord_gpv_simps [simp]:
  "ord_gpv (GPV f) (GPV g)"

(* Coinduction rule stated with the destructor the_gpv instead of the constructor
   GPV; set up as the coinduction rule for the predicate ord_gpv. *)
lemma ord_gpv_coinduct [consumes 1, case_names ord_gpv, coinduct pred: ord_gpv]:
  assumes "X f g"
  and step: "f g. X f g  ord_spmf (rel_generat (=) (=) (rel_fun (=) X)) (the_gpv f) (the_gpv g)"
  shows "ord_gpv f g"
using X f g
by(coinduct)(auto dest: step simp add: eq_GPV_iff intro: ord_spmf_mono rel_generat_mono rel_fun_mono)

(* Destruction rule: ord_gpv implies the ordering of the first steps. *)
lemma ord_gpv_the_gpvD:
  "ord_gpv f g  ord_spmf (rel_generat (=) (=) (rel_fun (=) ord_gpv)) (the_gpv f) (the_gpv g)"
by(erule ord_gpv.cases) simp

(* Equality is reflexive; auxiliary fact used in the antisymmetry proof below. *)
lemma reflp_equality: "reflp (=)"
by(simp add: reflp_def)

(* ord_gpv is reflexive; proved by coinduction and declared as a simp rule. *)
lemma ord_gpv_reflI [simp]: "ord_gpv f f"
by(coinduction arbitrary: f)(auto intro: ord_spmf_reflI simp add: rel_generat_same rel_fun_def)

(* The same fact stated with the predicate reflp. *)
lemma reflp_ord_gpv: "reflp ord_gpv"
by(rule reflpI)(rule ord_gpv_reflI)

(* Transitivity of ord_gpv, by coinduction with f, g and h generalised. *)
lemma ord_gpv_trans:
  assumes "ord_gpv f g" "ord_gpv g h"
  shows "ord_gpv f h"
using assms
proof(coinduction arbitrary: f g h)
  case (ord_gpv f g h)
  (* Rewrite the relation of the coinduction goal (two gpvs connected through
     a middle element g) into relational compositions written with OO, so that
     the compp lemmas for generat, ord_spmf and fun used below apply. *)
  have *: "ord_spmf (rel_generat (=) (=) (rel_fun (=) (λf h. g. ord_gpv f g  ord_gpv g h))) (the_gpv f) (the_gpv h) =
    ord_spmf (rel_generat ((=) OO (=)) ((=) OO (=)) (rel_fun (=) (ord_gpv OO ord_gpv))) (the_gpv f) (the_gpv h)"
    by(simp add: relcompp.simps[abs_def])
  then show ?case using ord_gpv
    by(auto elim!: ord_gpv.cases simp add: generat.rel_compp ord_spmf_compp fun.rel_compp)
qed

(* Composing ord_gpv with itself gives ord_gpv again.  The auto proof uses
   ord_gpv_trans as an introduction rule; reflexivity is available to it
   through the simp rule ord_gpv_reflI. *)
lemma ord_gpv_compp: "(ord_gpv OO ord_gpv) = ord_gpv"
by(auto simp add: fun_eq_iff intro: ord_gpv_trans)

(* ord_gpv_trans restated with the predicate transp; declared [simp]. *)
lemma transp_ord_gpv [simp]: "transp ord_gpv"
by(blast intro: transpI ord_gpv_trans)

(* Antisymmetry of ord_gpv: two gpvs related in both directions are equal.
   Proved by coinduction on gpv equality (case Eq_gpv). *)
lemma ord_gpv_antisym:
  " ord_gpv f g; ord_gpv g f   f = g"
proof(coinduction arbitrary: f g)
  case (Eq_gpv f g)
  let ?R = "rel_generat (=) (=) (rel_fun (=) ord_gpv)"
  (* Unfold both hypotheses one step to facts about the underlying spmfs. *)
  from ‹ord_gpv f g have "ord_spmf ?R (the_gpv f) (the_gpv g)" by cases simp
  moreover
  from ‹ord_gpv g f have "ord_spmf ?R (the_gpv g) (the_gpv f)" by cases simp
  (* rel_spmf_inf combines the two ord_spmf facts into a single rel_spmf fact
     over the infimum of ?R and its converse; its side conditions are
     discharged from reflexivity and transitivity of the component relations. *)
  ultimately have "rel_spmf (inf ?R ?R¯¯) (the_gpv f) (the_gpv g)"
    by(rule rel_spmf_inf)(auto 4 3 intro: transp_rel_generatI transp_ord_gpv reflp_ord_gpv reflp_equality reflp_fun1 is_equality_eq transp_rel_fun)
  (* Push inf and converse inside rel_generat and rel_fun, component by
     component, to reach the shape the coinduction goal expects. *)
  also have "inf ?R ?R¯¯ = rel_generat (inf (=) (=)) (inf (=) (=)) (rel_fun (=) (inf ord_gpv ord_gpv¯¯))"
    unfolding rel_generat_inf[symmetric] rel_fun_inf[symmetric]
    by(simp add: generat.rel_conversep[symmetric] fun.rel_conversep)
  finally show ?case by(simp add: inf_fun_def)
qed

(* Fail is below every gpv with respect to ord_gpv; declared [simp]. *)
lemma RFail_least [simp]: "ord_gpv Fail f"
by(coinduction arbitrary: f)(simp add: eq_GPV_iff)

subsection ‹Bounds on interaction›

context
  fixes "consider" :: "'out  bool"
  notes monotone_SUP[partial_function_mono] [[function_internals]]
begin
(* Registers a partial_function mode called "lfp_strong".  It is set up from
   the lfp fixpoint combinator (lfp.fixp_fun) and monotonicity predicate
   (lfp.mono_body), with lfp.fixp_rule_uc as unfolding rule and
   lfp.fixp_induct_strong2_uc as induction rule; the last argument is NONE. *)
declaration Partial_Function.init "lfp_strong" @{term lfp.fixp_fun} @{term lfp.mono_body}
  @{thm lfp.fixp_rule_uc} @{thm lfp.fixp_induct_strong2_uc} NONE›

(* interaction_bound gpv is an enat defined as a fixpoint in the lfp_strong
   mode.  It is a supremum over the generat values associated with
   set_spmf (the_gpv gpv):
     - Pure _ contributes 0;
     - IO out c contributes the supremum over all inputs of the bound of the
       continuation c input, incremented with eSuc when "consider out" holds
       and left unchanged otherwise.
   "consider" is the predicate on outputs fixed by the enclosing context. *)
partial_function (lfp_strong) interaction_bound :: "('a, 'out, 'in) gpv  enat"
where
  "interaction_bound gpv =
  (SUP generatset_spmf (the_gpv gpv). case generat of Pure _  0 
     | IO out c  if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input)))"

lemma interaction_bound_fixp_induct [case_names adm bottom step]:
  " ccpo.admissible (fun_lub Sup) (fun_ord (≤)) P;
     P (λ_. 0);
    interaction_bound'. 
     P interaction_bound'; 
      gpv. interaction_bound' gpv  interaction_bound gpv;
      gpv. interaction_bound' gpv  (SUP generatset_spmf (the_gpv gpv). case generat of Pure _  0 
     | IO out c  if consider out then eSuc (SUP input. interaction_bound' (c input)) else (SUP input. interaction_bound' (c input)))
      
       P (λgpv. generatset_spmf (the_gpv gpv). case generat of Pure x  0
         | IO out c  if consider out then eSuc (input. interaction_bound' (c input)) else (input. interaction_bound' (c input))) 
    P interaction_bound"
by(erule interaction_bound.fixp_induct)(simp_all add: bot_enat_def fun_ord_def)

(* Relates the bound of a single continuation to the bound of the whole gpv:
   for an IO out c belonging to set_spmf (the_gpv gpv) and any input, the
   bound of c input (with eSuc applied when "consider out" holds) is compared
   with interaction_bound gpv.
   NOTE(review): the comparison operator is not visible in this text; the
   proof uses SUP_upper2, so it is presumably "less or equal" - confirm. *)
lemma interaction_bound_IO:
   "IO out c  set_spmf (the_gpv gpv)
    (if consider out then eSuc (interaction_bound (c input)) else interaction_bound (c input))  interaction_bound gpv"
by(rewrite in "_  " interaction_bound.simps)(auto intro!: SUP_upper2)

(* Special case of interaction_bound_IO for an output that is counted
   ("consider out" holds): the eSuc of the continuation's bound is compared
   with the bound of gpv (comparison presumably "less or equal", as in
   interaction_bound_IO - confirm). *)
lemma interaction_bound_IO_consider:
   " IO out c  set_spmf (the_gpv gpv); consider out 
    eSuc (interaction_bound (c input))  interaction_bound gpv"
by(drule interaction_bound_IO) simp

(* Special case of interaction_bound_IO for an output that is not counted
   ("consider out" does not hold): the continuation's bound itself is compared
   with the bound of gpv (comparison presumably "less or equal", as in
   interaction_bound_IO - confirm). *)
lemma interaction_bound_IO_ignore:
   " IO out c  set_spmf (the_gpv gpv); ¬ consider out 
    interaction_bound (c input)  interaction_bound gpv"
by(drule interaction_bound_IO) simp

(* Done x has interaction bound 0; declared [simp]. *)
lemma interaction_bound_Done [simp]: "interaction_bound (Done x) = 0"
by(simp add: interaction_bound.simps)

(* Fail has interaction bound 0; declared [simp].  The proof additionally
   needs bot_enat_def. *)
lemma interaction_bound_Fail [simp]: "interaction_bound Fail = 0"
by(simp add: interaction_bound.simps bot_enat_def)

(* Bound of Pause out c: the supremum over all inputs of the bounds of the
   continuations c input, wrapped in eSuc exactly when "consider out" holds.
   Declared [simp]. *)
lemma interaction_bound_Pause [simp]:
  "interaction_bound (Pause out c) = 
   (if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input)))"
by(simp add: interaction_bound.simps)

(* A lifted spmf has interaction bound 0; declared [simp]. *)
lemma interaction_bound_lift_spmf [simp]: "interaction_bound (lift_spmf p) = 0"
by(simp add: interaction_bound.simps SUP_constant bot_enat_def)

(* assert_gpv b has interaction bound 0 for either value of b; declared
   [simp].  Proved by case distinction on b. *)
lemma interaction_bound_assert_gpv [simp]: "interaction_bound (assert_gpv b) = 0"
by(cases b) simp_all

lemma interaction_bound_bind_step:
  assumes IH: "p. interaction_bound' (p  f)  interaction_bound p + (xresults'_gpv p. interaction_bound' (f x))"
  and unfold: "gpv. interaction_bound' gpv  (generatset_spmf (the_gpv gpv). case generat of Pure x  0
             | IO out c  if consider out then eSuc (input. interaction_bound' (c input)) else input. interaction_bound' (c input))"
  shows "(generatset_spmf (the_gpv (p  f)).
             case generat of Pure x  0
             | IO out c 
                 if consider out then eSuc (input. interaction_bound' (c input))
                 else input. interaction_bound' (c input))
          interaction_bound p +
            (xresults'_gpv p.
                generatset_spmf (the_gpv (f x)).
                   case generat of Pure x  0
                   | IO out c 
                       if consider out then eSuc (input. interaction_bound' (c input))
                       else input. interaction_bound' (c input))"
    (is "(SUP generat'?bind. ?g generat')  ?p + ?f")
proof(rule SUP_least)
  fix generat'
  assume "generat'  ?bind"
  then obtain generat where generat: "generat  set_spmf (the_gpv p)"
    and *: "case generat of Pure x  generat'  set_spmf (the_gpv (f x)) 
         | IO out c  generat' = IO out (λinput. c input  f)"
    by(clarsimp simp add: bind_gpv.sel simp del: bind_gpv_sel')
      (clarsimp split: generat.split_asm simp add: generat.map_comp o_def generat.map_id[unfolded id_def])
  show "?g generat'  ?p + ?f"
  proof(cases generat)
    case (Pure x)
    have "?g generat'  (SUP generat'set_spmf (the_gpv (f x)). (case generat' of Pure x  0 | IO out c  if consider out then eSuc (input. interaction_bound' (c input)) else input. interaction_bound' (c input)))"
      using * Pure by(auto intro: SUP_upper)
    also have "  0 + ?f" using generat Pure
      by(auto 4 3 intro: SUP_upper results'_gpv_Pure)
    also have "  ?p + ?f" by simp
    finally show ?thesis .
  next
    case (IO out c)
    with * have "?g generat' = (if consider out then eSuc (SUP input. interaction_bound' (c input  f)) else (SUP input. interaction_bound' (c input  f)))" by simp
    also have "  (if consider out then eSuc (SUP input. interaction_bound (c input) + (xresults'_gpv (c input). interaction_bound' (f x))) else (SUP input. interaction_bound (c input) + (xresults'_gpv (c input). interaction_bound' (f x))))"
      by(auto intro: SUP_mono IH)
    also have "  (case IO out c of Pure (x :: 'a)  0 | IO out c  if consider out then eSuc (SUP input. interaction_bound (c input)) else (SUP input. interaction_bound (c input))) + (SUP input. SUP xresults'_gpv (c input). interaction_bound' (f x))"
      by(simp add: iadd_Suc SUP_le_iff)(meson SUP_upper2 UNIV_I add_mono order_refl)
    also have "  ?p + ?f"
      apply(rewrite in "_  " interaction_bound.simps)
      apply(rule add_mono SUP_least SUP_upper generat[unfolded IO])+
      apply(rule order_trans[OF unfold])
      apply(auto 4 3 intro: results'_gpv_Cont[OF generat] SUP_upper simp add: IO)
      done
    finally show ?thesis .
  qed
qed

(* Relates the bound of a gpv built from p and f to the bound of p plus the
   supremum, over the results x of p, of the bound of f x.
   NOTE(review): the operator combining p and f and the comparison operator
   are not visible in this text; presumably bind and "less or equal" (the text
   further down calls this an upper bound) - confirm.
   The bound of p is written through the local abbreviation ib1 rather than
   interaction_bound itself - presumably so that the fixpoint induction does
   not abstract over that occurrence; confirm.  The step case unfolds ib1 and
   is discharged by interaction_bound_bind_step. *)
lemma interaction_bound_bind:
  defines "ib1  interaction_bound"
  shows "interaction_bound (p  f)  ib1 p + (SUP xresults'_gpv p. interaction_bound (f x))"
proof(induction arbitrary: p rule: interaction_bound_fixp_induct)
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step interaction_bound') then show ?case unfolding ib1_def by -(rule interaction_bound_bind_step)
qed

(* When the first component is a lifted spmf p, the bound is exactly the
   supremum of the bounds of f x over the x associated with set_spmf p.
   Declared [simp].  The proof unfolds interaction_bound.simps at the first
   two matching positions. *)
lemma interaction_bound_bind_lift_spmf [simp]:
  "interaction_bound (lift_spmf p  f) = (SUP xset_spmf p. interaction_bound (f x))"
by(subst (1 2) interaction_bound.simps)(simp add: bind_UNION SUP_UNION)

end

(* For surjective h, the interaction bound of map_gpv' f g h gpv under
   "consider" equals the bound of gpv under a predicate built from consider
   and g (presumably their composition; the operator is not visible here -
   confirm).  Proved by parallel fixpoint induction over the two fixpoint
   definitions of interaction_bound. *)
lemma interaction_bound_map_gpv':
  assumes "surj h"
  shows "interaction_bound consider (map_gpv' f g h gpv) = interaction_bound (consider  g) gpv"
proof(induction arbitrary: gpv rule: parallel_fixp_induct_1_1[OF lattice_partial_function_definition lattice_partial_function_definition interaction_bound.mono interaction_bound.mono interaction_bound_def interaction_bound_def, case_names adm bottom step])
  case (step interaction_bound' interaction_bound'' gpv)
  (* Auxiliary fact for the step case; this is where surjectivity of h is
     used, via surjD instantiated at x. *)
  have *: "IO out c  set_spmf (the_gpv gpv)   x  UNIV  interaction_bound'' (c x)  (x. interaction_bound'' (c (h x)))" for out c x
    using assms[THEN surjD, of x] by (clarsimp intro!: SUP_upper)

  show ?case 
    by (auto simp add: * step.IH image_comp split: generat.split
      intro!: SUP_cong [OF refl] antisym SUP_upper SUP_least)
qed simp_all

(* interaction_bound instantiated with the constantly-True predicate, so that
   every output is counted. *)
abbreviation interaction_any_bound :: "('a, 'out, 'in) gpv  enat"
where "interaction_any_bound  interaction_bound (λ_. True)"

lemma interaction_any_bound_coinduct [consumes 1, case_names interaction_bound]:
  assumes X: "X gpv n"
  and *: "gpv n out c input.  X gpv n; IO out c  set_spmf (the_gpv gpv)  
     n'. (X (c input) n'  interaction_any_bound (c input)  n')  eSuc n'  n"
  shows "interaction_any_bound gpv  n"
using X
proof(induction arbitrary: gpv n rule: interaction_bound_fixp_induct)
  case adm show ?case by(intro cont_intro)
  case bottom show ?case by simp
next
  case (step interaction_bound')
  { fix out c
    assume IO: "IO out c  set_spmf (the_gpv gpv)"
    from *[OF step.prems IO] obtain n' where n: "n = eSuc n'"
      by(cases n rule: co.enat.exhaust) auto
    moreover 
    { fix input
      have "n''. (X (c input) n''  interaction_any_bound (c input)  n'')  eSuc n''  n"
        using step.prems IO n = eSuc n' by(auto 4 3 dest: *)
      then have "interaction_bound' (c input)  n'" using n
        by(auto dest: step.IH intro: step.hyps[THEN order_trans] elim!: order_trans simp add: neq_zero_conv_eSuc) }
    ultimately have "eSuc (input. interaction_bound' (c input))  n"
      by(auto intro: SUP_least) }
  then show ?case by(auto intro!: SUP_least split: generat.split)
qed

context includes lifting_syntax begin
(* Parametricity of interaction_bound for the relator rel_gpv'' A C R, under
   the assumption that R is bi-total: predicates related by C ===> (=) and
   gpvs related by rel_gpv'' A C R yield equal bounds.
   The proof unfolds the fixpoint definition, applies fixp_lfp_parametric_eq
   with the monotonicity facts interaction_bound.mono, and leaves the
   remaining subgoal to transfer_prover with the_gpv_parametric' added as a
   transfer rule. *)
lemma interaction_bound_parametric':
  assumes [transfer_rule]: "bi_total R"
  shows "((C ===> (=)) ===> rel_gpv'' A C R ===> (=)) interaction_bound interaction_bound"
unfolding interaction_bound_def[abs_def]
apply(rule rel_funI)
apply(rule fixp_lfp_parametric_eq[OF interaction_bound.mono interaction_bound.mono])
subgoal premises [transfer_rule]
  supply the_gpv_parametric'[transfer_rule] rel_gpv''_eq[relator_eq]
  by transfer_prover
done

(* Parametricity of interaction_bound for the relator rel_gpv A C, registered
   as a transfer rule.  Obtained from interaction_bound_parametric' after
   rewriting with rel_gpv_conv_rel_gpv''; the bi_total premise is discharged
   by bi_total_eq. *)
lemma interaction_bound_parametric [transfer_rule]:
  "((C ===> (=)) ===> rel_gpv A C ===> (=)) interaction_bound interaction_bound"
unfolding rel_gpv_conv_rel_gpv'' by(rule interaction_bound_parametric')(rule bi_total_eq)
end

text ‹
  There is no nice @{const interaction_bound} equation for @{const bind_gpv}, as it computes
  an exact bound, but we only need an upper bound.
  As @{typ enat} is hard to work with (and @{term ∞} does not constrain a gpv in any way),
  we work with @{typ nat}.
›

(* Predicate form of the interaction bound: interaction_bounded_by consider
   gpv n is introduced by a single, non-recursive rule whose premise compares
   interaction_bound consider gpv with the enat n.
   NOTE(review): the comparison operator is not visible in this text;
   presumably "less or equal" (interaction_bounded_by_mono below is proved
   with order_trans) - confirm. *)
inductive interaction_bounded_by :: "('out  bool)  ('a, 'out, 'in) gpv  enat  bool"
for "consider" gpv n where
  interaction_bounded_by: " interaction_bound consider gpv  n   interaction_bounded_by consider gpv n"

(* Makes the introduction rule available as interaction_bounded_byI and then
   hides the original fact name interaction_bounded_by (with option "open"). *)
lemmas interaction_bounded_byI = interaction_bounded_by
hide_fact (open) interaction_bounded_by

context includes lifting_syntax begin
(* Parametricity of interaction_bounded_by for rel_gpv A C, registered as a
   transfer rule.  After unfolding the predicate via
   interaction_bounded_by.simps, transfer_prover finishes the proof. *)
lemma interaction_bounded_by_parametric [transfer_rule]:
  "((C ===> (=)) ===> rel_gpv A C ===> (=) ===> (=)) interaction_bounded_by interaction_bounded_by"
unfolding interaction_bounded_by.simps[abs_def] by transfer_prover

(* Variant of interaction_bounded_by_parametric for the relator
   rel_gpv'' A C R with bi-total R.  interaction_bound_parametric' is supplied
   locally as a transfer rule; the rest is unfolding plus transfer_prover. *)
lemma interaction_bounded_by_parametric':
  notes interaction_bound_parametric'[transfer_rule]
  assumes [transfer_rule]: "bi_total R"
  shows "((C ===> (=)) ===> rel_gpv'' A C R ===> (=) ===> (=)) 
         interaction_bounded_by interaction_bounded_by"
unfolding interaction_bounded_by.simps[abs_def] by transfer_prover
end

(* Weakening of the bound: from a bound n for gpv and a relation between n
   and m, conclude the bound m.
   NOTE(review): the relation between n and m is not visible in this text;
   the proof is by order_trans, so it is presumably "n less or equal m" -
   confirm. *)
lemma interaction_bounded_by_mono:
  " interaction_bounded_by consider gpv n; n  m   interaction_bounded_by consider gpv m"
unfolding interaction_bounded_by.simps by(erule order_trans) simp

(* Destruction rule for a counted output: if gpv is bounded by n, IO out c
   belongs to set_spmf (the_gpv gpv) and "consider out" holds, the conclusion
   mentions both n > 0 and the bound n - 1 for the continuation c input
   (the connective joining the two parts is not visible here; presumably a
   conjunction - confirm). *)
lemma interaction_bounded_by_contD:
  " interaction_bounded_by consider gpv n; IO out c  set_spmf (the_gpv gpv); consider out 
   n > 0  interaction_bounded_by consider (c input) (n - 1)"
unfolding interaction_bounded_by.simps
by(subst (asm) interaction_bound.simps)(auto simp add: SUP_le_iff eSuc_le_iff enat_eSuc_iff dest!: bspec)

(* Destruction rule that makes no assumption about "consider out": whenever
   gpv is bounded by n and IO out c belongs to set_spmf (the_gpv gpv), every
   continuation c input is bounded by the same n. *)
lemma interaction_bounded_by_contD_ignore:
  " interaction_bounded_by consider gpv n; IO out c  set_spmf (the_gpv gpv) 
   interaction_bounded_by consider (c input) n"
unfolding interaction_bounded_by.simps
by(subst (asm) interaction_bound.simps)(auto 4 4 simp add: SUP_le_iff eSuc_le_iff enat_eSuc_iff dest!: bspec split: if_split_asm elim: order_trans)

lemma interaction_bounded_byI_epred:
  assumes "out c.  IO out c  set_spmf (the_gpv gpv); consider out   n  0  (input. interaction_bounded_by consider (c input) (n - 1))"
  and "out c input.  IO out c  set_spmf (the_gpv gpv); ¬ consider out   interaction_bounded_by consider (c input) n"
  shows "interaction_bounded_by consider gpv n"
unfolding interaction_bounded_by.simps
by(subst interaction_bound.simps)(auto 4 5 intro!: SUP_least split: generat.split dest: assms simp add: eSuc_le_iff enat_eSuc_iff gr0_conv_Suc neq_zero_conv_eSuc interaction_bounded_by.simps)

(* Same situation as interaction_bounded_by_contD with the premises in a
   different order and the first part of the conclusion relating n to 0
   differently (operator not visible here; presumably "n not equal 0" -
   confirm).  Derived from interaction_bound_IO instantiated with the given
   input and predicate. *)
lemma interaction_bounded_by_IO:
  " IO out c  set_spmf (the_gpv gpv); interaction_bounded_by consider gpv n; consider out 
   n  0  interaction_bounded_by consider (c input) (n - 1)"
by(drule interaction_bound_IO[where input=input and ?consider="consider"])(auto simp add: interaction_bounded_by.simps epred_conv_minus eSuc_le_iff enat_eSuc_iff)

(* Relates "bounded by 0" to "interaction_bound is equal to 0" (the connective
   between the two sides is not visible here; presumably an equivalence -
   confirm). *)
lemma interaction_bounded_by_0: "interaction_bounded_by consider gpv 0  interaction_bound consider gpv = 0"
by(simp add: interaction_bounded_by.simps zero_enat_def[symmetric])

(* Variant of interaction_bounded_by taking a natural number bound, which is
   injected into enat with the constructor enat. *)
abbreviation interaction_bounded_by' :: "('out  bool)  ('a, 'out, 'in) gpv  nat  bool"
where "interaction_bounded_by' consider gpv n  interaction_bounded_by consider gpv (enat n)"

(* Proof automation for interaction_bounded_by goals.

   interaction_bound (named_theorems): collection of rules; lemmas in this
     file are added to it with the attribute [interaction_bound].
   interaction_bounded_by_start: another name for interaction_bounded_by_mono.

   Methods:
     interaction_bound_start  applies interaction_bounded_by_start.
     interaction_bound_step   has two alternatives: a match on the conclusion
                              that fails for conclusions of the shape
                              "interaction_bounded_by _ _ _" and otherwise
                              runs clarsimp with the extra simp rules "simp";
                              or one rule application using the facts "add"
                              and the interaction_bound collection.
     interaction_bound_rec    runs interaction_bound_step and then, optionally,
                              itself on the resulting subgoals.
     interaction_bound        entry point: interaction_bound_start followed by
                              interaction_bound_rec. *)
named_theorems interaction_bound

lemmas interaction_bounded_by_start = interaction_bounded_by_mono

method interaction_bound_start = (rule interaction_bounded_by_start)
method interaction_bound_step uses add simp =
  ((match conclusion in "interaction_bounded_by _ _ _"  fail ¦ _  solvesclarsimp simp add: simp››) | rule add interaction_bound)
method interaction_bound_rec uses add simp = 
  (interaction_bound_step add: add simp: simp; (interaction_bound_rec add: add simp: simp)?)
method interaction_bound uses add simp =
  ((* use in *) interaction_bound_start, interaction_bound_rec add: add simp: simp)

(* Done x is bounded by every n; declared [simp]. *)
lemma interaction_bounded_by_Done [simp]: "interaction_bounded_by consider (Done x) n"
by(simp add: interaction_bounded_by.simps)

(* Rule for the interaction_bound collection: Done x with the concrete
   bound 0. *)
lemma interaction_bounded_by_DoneI [interaction_bound]:
  "interaction_bounded_by consider (Done x) 0"
by simp

(* Fail is bounded by every n; declared [simp]. *)
lemma interaction_bounded_by_Fail [simp]: "interaction_bounded_by consider Fail n"
by(simp add: interaction_bounded_by.simps)

(* Rule for the interaction_bound collection: Fail with the concrete bound 0. *)
lemma interaction_bounded_by_FailI [interaction_bound]: "interaction_bounded_by consider Fail 0"
by simp

(* A lifted spmf is bounded by every n; declared [simp]. *)
lemma interaction_bounded_by_lift_spmf [simp]: "interaction_bounded_by consider (lift_spmf p) n"
by(simp add: interaction_bounded_by.simps)

(* Rule for the interaction_bound collection: lift_spmf p with the concrete
   bound 0. *)
lemma interaction_bounded_by_lift_spmfI [interaction_bound]:
  "interaction_bounded_by consider (lift_spmf p) 0"
by simp

(* assert_gpv b is bounded by every n, whichever value b has; declared
   [simp]. *)
lemma interaction_bounded_by_assert_gpv [simp]: "interaction_bounded_by consider (assert_gpv b) n"
by(cases b) simp_all

(* Rule for the interaction_bound collection: assert_gpv b with the concrete
   bound 0. *)
lemma interaction_bounded_by_assert_gpvI [interaction_bound]:
  "interaction_bounded_by consider (assert_gpv b) 0"
by simp

lemma interaction_bounded_by_Pause [simp]:
  "interaction_bounded_by consider (Pause out c) n  
  (if consider out then 0 < n  (input. interaction_bounded_by consider (c input) (n - 1)) else (input. interaction_bounded_by consider (c input) n))"
by(cases n rule: co.enat.exhaust)
  (auto 4 3 simp add: interaction_bounded_by.simps eSuc_le_iff enat_eSuc_iff gr0_conv_Suc intro: SUP_least dest: order_trans[OF SUP_upper, rotated])

lemma interaction_bounded_by_PauseI [interaction_bound]:
  "(input. interaction_bounded_by consider (c input) (n input))
   interaction_bounded_by consider (Pause out c) (if consider out then 1 + (SUP input. n input) else (SUP input. n input))"
by(auto simp add: iadd_is_0 enat_add_sub_same intro: interaction_bounded_by_mono SUP_upper)

(* Rule for the interaction_bound collection, for a gpv combined with a
   continuation f (operator not visible here; presumably bind - confirm):
   given a bound n for gpv and, for every x in results'_gpv gpv, a bound m x
   for f x, the combined computation is bounded by n plus the supremum of the
   m x.  Follows from interaction_bound_bind by transitivity of the order. *)
lemma interaction_bounded_by_bindI [interaction_bound]:
  " interaction_bounded_by consider gpv n; x. x  results'_gpv gpv  interaction_bounded_by consider (f x) (m x) 
   interaction_bounded_by consider (gpv  f) (n + (SUP xresults'_gpv gpv. m x))"
unfolding interaction_bounded_by.simps plus_enat_simps(1)[symmetric]
by(rule interaction_bound_bind[THEN order_trans])(auto intro: add_mono SUP_mono)

lemma interaction_bounded_by_bind_PauseI [interaction_bound]:
  "(input. interaction_bounded_by consider (c input  f) (n input))
   interaction_bounded_by consider (Pause out c  f) (if consider out then SUP input. n input + 1 else SUP input. n input)"
by(auto 4 3 simp add: interaction_bounded_by.simps SUP_enat_add_left eSuc_plus_1 intro: SUP_least SUP_upper2)

lemma interaction_bounded_by_bind_lift_spmf [simp]:
  "interaction_bounded_by consider (lift_spmf p  f) n  (xset_spmf p. interaction_bounded_by consider (f x) n)"
by(simp add: interaction_bounded_by.simps SUP_le_iff)

lemma interaction_bounded_by_bind_lift_spmfI [interaction_bound]:
  "(x. x  set_spmf p  interaction_bounded_by consider (f x) (n x))
   interaction_bounded_by consider (lift_spmf p  f) (SUP xset_spmf p. n x)"
by(auto intro: interaction_bounded_by_mono SUP_upper)

(* Rule for the interaction_bound collection: a bound for f x is also a bound
   for Done x combined with f (operator not visible here; presumably bind -
   confirm). *)
lemma interaction_bounded_by_bind_DoneI [interaction_bound]:
  "interaction_bounded_by consider (f x) n  interaction_bounded_by consider (Done x  f) n"
by(simp)

(* Rule for the interaction_bound collection: a conditional gpv is bounded by
   the conditional of the branch bounds; each branch bound is only needed
   under the corresponding assumption on b.
   NOTE(review): the proof passes max_def and not_le to simp although the
   statement mentions no max; possibly left over from an earlier formulation. *)
lemma interaction_bounded_by_if [interaction_bound]:
  " b  interaction_bounded_by consider gpv1 n; ¬ b  interaction_bounded_by consider gpv2 m 
   interaction_bounded_by consider (if b then gpv1 else gpv2) (if b then n else m)"
by(auto 4 3 simp add: max_def not_le elim: interaction_bounded_by_mono)

(* Rule for the interaction_bound collection: the analogue of
   interaction_bounded_by_if for case_bool t f b, with bound
   "if b then bt else bf". *)
lemma interaction_bounded_by_case_bool [interaction_bound]:
  " b  interaction_bounded_by consider t bt; ¬ b  interaction_bounded_by consider f bf 
   interaction_bounded_by consider (case_bool t f b) (if b then bt else bf)"
by(cases b)(auto)

(* Rule for the interaction_bound collection: case distinction on a sum
   value x.  The bound is chosen by the same case distinction (case_sum bl br
   x); each premise covers one injection (Inl or Inr). *)
lemma interaction_bounded_by_case_sum [interaction_bound]:
  " y. x = Inl y  interaction_bounded_by consider (l y) (bl y);
     y. x = Inr y  interaction_bounded_by consider (r y) (br y) 
   interaction_bounded_by consider (case_sum l r x) (case_sum bl br x)"
by(cases x)(auto)

(* Rule for the interaction_bound collection: a gpv given by case_prod f x is
   bounded by case_prod n x, provided f a b is bounded by n a b whenever
   x = (a, b). *)
lemma interaction_bounded_by_case_prod [interaction_bound]:
  "(a b. x = (a, b)  interaction_bounded_by consider (f a b) (n a b))
   interaction_bounded_by consider (case_prod f x) (case_prod n x)"
by(simp split: prod.split)

(* Rule for the interaction_bound collection: a bound for f t is a bound for
   Let t f.  Proved by unfolding Let_def. *)
lemma interaction_bounded_by_let [interaction_bound]: ― ‹This rule unfolds let's›
  "interaction_bounded_by consider (f t) m  interaction_bounded_by consider (Let t f) m"
by(simp add: Let_def)

(* Rule for the interaction_bound collection: map_gpv f id keeps the bound.
   The proof rewrites map_gpv into a bind (map_gpv_conv_bind) and runs the
   interaction_bound method, which can use the assumption because it is
   itself added to the collection; a final simp closes what remains. *)
lemma interaction_bounded_by_map_gpv_id [interaction_bound]:
  assumes [interaction_bound]: "interaction_bounded_by P gpv n"
  shows "interaction_bounded_by P (map_gpv f id gpv) n"
unfolding id_def map_gpv_conv_bind by interaction_bound simp

(* interaction_bounded_by instantiated with the constantly-True predicate, so
   that every output is counted. *)
abbreviation interaction_any_bounded_by :: "('a, 'out, 'in) gpv  enat  bool"
where "interaction_any_bounded_by  interaction_bounded_by (λ_. True)"

(* For surjective h, map_gpv' f g h keeps any bound that counts all outputs.
   Follows by simp from interaction_bound_map_gpv'. *)
lemma interaction_any_bounded_by_map_gpv':
  assumes "interaction_any_bounded_by gpv n"
    and "surj h"
  shows "interaction_any_bounded_by (map_gpv' f g h gpv) n"
  using assms by(simp add: interaction_bounded_by.simps interaction_bound_map_gpv' o_def)

subsection ‹Typing›

subsubsection ‹Interface between gpvs and rpvs / callees›

(* Parametricity of Set.is_empty: sets related by rel_set A are either both
   empty or both non-empty.  Registered as a transfer rule; the existing
   marker below flags the lemma for relocation. *)
lemma is_empty_parametric [transfer_rule]: "rel_fun (rel_set A) (=) Set.is_empty Set.is_empty" (* Move *)
by(auto simp add: rel_fun_def Set.is_empty_def dest: rel_setD1 rel_setD2)

(* New type with parameters 'call and 'ret whose carrier is UNIV over a
   function type built from 'call and "'ret set"; since the carrier is UNIV,
   the non-emptiness obligation is closed by "..".
   NOTE(review): the name of the new type is not visible on this line; the
   following setup_lifting refers to type_definition_ℐ, so it is presumably
   ℐ - confirm. *)
typedef ('call, 'ret)= "UNIV :: ('call  'ret set) set" ..

setup_lifting type_definition_ℐ

(* Parametricity, on the representation level, of the set comprehension that
   the lift_definition of outs_ℐ uses; requires A to be bi-total.  The
   comparison with {} is folded into Set.is_empty so that transfer_prover can
   apply is_empty_parametric. *)
lemma outs_ℐ_tparametric:
  includes lifting_syntax 
  assumes [transfer_rule]: "bi_total A"
  shows "((A ===> rel_set B) ===> rel_set A) (λresps. {out. resps out  {}}) (λresps. {out. resps out  {}})"
  by(fold Set.is_empty_def) transfer_prover

(* outs_ℐ: lifted from a set comprehension over the calls "out", selected by
   comparing the response set "resps out" with {}.
   NOTE(review): the comparison operator is not visible in this text;
   outs_ℐ_tparametric folds Set.is_empty_def, which suggests "not equal to
   {}" - confirm. *)
lift_definition outs_ℐ :: "('call, 'ret) 'call set" is "λresps. {out. resps out  {}}" parametric outs_ℐ_tparametric .
(* responses_ℐ: the representing function itself (lifted from the identity),
   giving the set of 'ret values attached to a call. *)
lift_definition responses_ℐ :: "('call, 'ret) 'call  'ret set" is "λx. x" parametric id_transfer[unfolded id_def] .

lift_definition rel_ℐ :: "('call  'call'  bool)  ('ret  'ret'  bool)  ('call, 'ret) ('call', 'ret') bool"
is "λC R resp1 resp2. rel_set C {out. resp1 out  {}} {out. resp2 out  {}}  rel_fun C (rel_set R) resp1 resp2"
.

(* Introduction rule for rel_ℐ C R: it suffices that the outs_ℐ sets are
   related by rel_set C and that responses to C-related calls are related by
   rel_set R. *)
lemma rel_ℐI [intro?]:
  " rel_set C (outs_ℐ ℐ1) (outs_ℐ ℐ2); x y. C x y  rel_set R (responses_ℐ ℐ1 x) (responses_ℐ ℐ2 y) 
   rel_ℐ C R ℐ1 ℐ2"
by transfer(auto simp add: rel_fun_def)

(* rel_ℐ instantiated with equality in both arguments is equality; registered
   as a relator_eq rule. *)
lemma rel_ℐ_eq [relator_eq]: "rel_ℐ (=) (=) = (=)"
unfolding fun_eq_iff by transfer(auto simp add: relator_eq)

(* rel_ℐ commutes with taking converses of both argument relations; declared
   [simp].  After transfer, three targeted rewrites reintroduce conversep at
   chosen positions (each rewrite pattern is written with a hole that is not
   visible in this text), and the goal is then closed with rel_fun_conversep
   while conversep_iff is kept out of the simpset. *)
lemma rel_ℐ_conversep [simp]: "rel_ℐ C¯¯ R¯¯ = (rel_ℐ C R)¯¯"
unfolding fun_eq_iff conversep_iff
apply transfer
apply(rewrite in "rel_fun " conversep_iff[symmetric])
apply(rewrite in "rel_set " conversep_iff[symmetric])
apply(rewrite in "rel_fun _ " conversep_iff[symmetric])
apply(simp del: conversep_iff add: rel_fun_conversep)
apply(simp)
done

lemma rel_ℐ_conversep1_eq [simp]: "rel_ℐ C¯¯ (=) = (rel_ℐ C (=))¯¯"
by(rewrite in " = _" conversep_eq[symmetric])(simp del: conversep_eq)

lemma rel_ℐ_conversep2_eq [simp]: "rel_ℐ (=) R¯¯ = (rel_ℐ (=) R)¯¯"
by(rewrite in " = _" conversep_eq[symmetric])(simp del: conversep_eq)

lemma responses_ℐ_empty_iff: "responses_ℐ  out = {}  out  outs_ℐ "
including ℐ.lifting by transfer auto

lemma in_outs_ℐ_iff_responses_ℐ: "out  outs_ℐ   responses_ℐ  out  {}"
by(simp add: responses_ℐ_empty_iff)

(* ℐ_full: the interface whose representing function maps every call to
   UNIV. *)
lift_definition ℐ_full :: "('call, 'ret) ℐ" is "λ_. UNIV" .

(* Selector equations for ℐ_full, both declared [simp] under the joint name
   ℐ_full_sel: outs_ℐ gives UNIV, and responses_ℐ gives UNIV for every x.
   Each part is proved by transfer followed by simp. *)
lemma ℐ_full_sel [simp]:
  shows outs_ℐ_full: "outs_ℐ ℐ_full = UNIV"
  and responses_ℐ_full: "responses_ℐ ℐ_full x = UNIV"
by(transfer; simp; fail)+

context includes lifting_syntax begin
lemma outs_ℐ_parametric [transfer_rule]: "(rel_ℐ C R ===> rel_set C) outs_ℐ outs_ℐ"
unfolding rel_fun_def by transfer simp

lemma responses_ℐ_parametric [transfer_rule]: 
  "(rel_ℐ C R ===> C ===> rel_set R) responses_ℐ responses_ℐ"
unfolding rel_fun_def by transfer(auto dest: rel_funD)

end

definition ℐ_trivial :: "('out, 'in) bool"
where "ℐ_trivial   outs_ℐ  = UNIV"

lemma ℐ_trivialI [intro?]: "(x. x  outs_ℐ )  ℐ_trivial "
by(auto simp add: ℐ_trivial_def)

lemma ℐ_trivialD: "ℐ_trivial   outs_ℐ  = UNIV"
by(simp add: ℐ_trivial_def)

(* ℐ_full satisfies ℐ_trivial; declared [simp]. *)
lemma ℐ_trivial_ℐ_full [simp]: "ℐ_trivial ℐ_full"
by(simp add: ℐ_trivial_def)

lifting_update ℐ.lifting
lifting_forget ℐ.lifting

context includes ℐ.lifting begin

(* ℐ_uniform A B: interface built from a set A of outputs and a set B of
   inputs; the representing function returns B or {} depending on a test
   relating x and A (operator not visible here; presumably membership of x in
   A - confirm). *)
lift_definition ℐ_uniform :: "'out set  'in set  ('out, 'in) ℐ" is "λA B x. if x  A then B else {}" .

lemma outs_ℐ_uniform [simp]: "outs_ℐ (ℐ_uniform A B) = (if B = {} then {} else A)"
  by transfer simp

lemma responses_ℐ_uniform [simp]: "responses_ℐ (ℐ_uniform A B) x = (if x  A then B else {})"
  by transfer simp

lemma ℐ_uniform_UNIV [simp]: "ℐ_uniform UNIV UNIV = ℐ_full" (* TODO: make ℐ_full an abbreviation *)
  by transfer simp

(* map_ℐ f g: maps an interface along f on calls and g on responses.  On the
   representation, the new response set for x is the image under g of the old
   response set for f x. *)
lift_definition map_ℐ :: "('out'  'out)  ('in  'in')  ('out, 'in) ('out', 'in') ℐ"
  is "λf g resp x. g ` resp (f x)" .

lemma outs_ℐ_map_ℐ [simp]:
  "outs_ℐ (map_ℐ f g ) = f -` outs_ℐ "
  by transfer simp

lemma responses_ℐ_map_ℐ [simp]:
  "responses_ℐ (map_ℐ f g ) x = g ` responses_ℐ  (f x)"
  by transfer simp

(* Mapping a uniform interface gives a uniform interface again, on the
   preimage of A under f and the image of B under g; declared [simp]. *)
lemma map_ℐ_ℐ_uniform [simp]:
  "map_ℐ f g (ℐ_uniform A B) = ℐ_uniform (f -` A) (g ` B)"
  by transfer(auto simp add: fun_eq_iff)

lemma map_ℐ_id [simp]: "map_ℐ id id  = "
  by transfer simp

(* Point-free form of the identity law: map_ℐ id id is the identity
   function. *)
lemma map_ℐ_id0: "map_ℐ id id = id"
  by(simp add: fun_eq_iff)

lemma map_ℐ_comp [simp]: "map_ℐ f g (map_ℐ f' g' ) = map_ℐ (f'  f) (g  g') "
  by transfer auto

lemma map_ℐ_cong: "map_ℐ f g  = map_ℐ f' g' ℐ'"
  if " = ℐ'" and f: "f = f'" and "x y.  x  outs_ℐ ℐ'; y  responses_ℐ ℐ' x   g y = g' y"
  unfolding that(1,2) using that(3-)
  by transfer(auto simp add: fun_eq_iff intro!: image_cong)

lifting_update ℐ.lifting
lifting_forget ℐ.lifting
end

(* Registers map_ℐ with the functor command; the resulting proof obligations
   are discharged by simp with function extensionality. *)
functor map_ℐ by(simp_all add: fun_eq_iff)

lemma ℐ_eqI: " outs_ℐ  = outs_ℐ ℐ'; x. x  outs_ℐ ℐ'  responses_ℐ  x = responses_ℐ ℐ' x    = ℐ'"
  including ℐ.lifting by transfer auto

instantiation:: (type, type) order begin

definition less_eq_ℐ :: "('a, 'b) ('a, 'b) bool"
  where le_ℐ_def: "less_eq_ℐ  ℐ'  outs_ℐ   outs_ℐ ℐ'  (xouts_ℐ . responses_ℐ ℐ' x  responses_ℐ  x)"

definition less_ℐ :: "('a, 'b) ('a, 'b) bool"
  where "less_ℐ = mk_less (≤)"

instance
proof
  show " < ℐ'    ℐ'  ¬ ℐ'  " for  ℐ' :: "('a, 'b) ℐ" by(simp add: less_ℐ_def mk_less_def)
  show "  " for  :: "('a, 'b) ℐ" by(simp add: le_ℐ_def)
  show "  ℐ''" if "  ℐ'" "ℐ'  ℐ''" for  ℐ' ℐ'' :: "('a, 'b) ℐ" using that
    by(fastforce simp add: le_ℐ_def)
  show " = ℐ'" if "  ℐ'" "ℐ'  " for  ℐ' :: "('a, 'b) ℐ" using that
    by(auto simp add: le_ℐ_def intro!: ℐ_eqI)
qed
end

instantiation:: (type, type) order_bot begin
definition bot_ℐ :: "('a, 'b) ℐ" where "bot_ℐ = ℐ_uniform {} UNIV"
instance by standard(auto simp add: bot_ℐ_def le_ℐ_def)
end

(* The bottom interface has an empty outs_ℐ set; declared [simp]. *)
lemma outs_ℐ_bot [simp]: "outs_ℐ bot = {}"
  by(simp add: bot_ℐ_def)

(* The bottom interface has an empty response set for every x; declared
   [simp].
   NOTE(review): the lemma name is spelt "respones" (missing "s"); it is left
   unchanged here because other theories may refer to it by this name. *)
lemma respones_ℐ_bot [simp]: "responses_ℐ bot x = {}"
  by(simp add: bot_ℐ_def)

lemma outs_ℐ_mono: "  ℐ'  outs_ℐ   outs_ℐ ℐ'"
  by(simp add: le_ℐ_def)

lemma responses_ℐ_mono: "   ℐ'; x  outs_ℐ    responses_ℐ ℐ' x  responses_ℐ  x"
  by(simp add: le_ℐ_def)

lemma ℐ_uniform_empty [simp]: "ℐ_uniform {} A = bot" 
  unfolding bot_ℐ_def including ℐ.lifting by transfer simp

lemma ℐ_uniform_mono:
  "ℐ_uniform A B  ℐ_uniform C D" if "A  C" "D  B" "D = {}  B = {}"
  unfolding le_ℐ_def using that by auto


(* Results of a gpv relative to an interface.

   resultsp_gpv Γ x gpv is an inductive predicate with two rules:
     Pure: Pure x is associated with set_spmf (the_gpv gpv);
     IO:   some IO out c is associated with set_spmf (the_gpv gpv), some
           input is related to responses_ℐ Γ out, and resultsp_gpv Γ x holds
           for the continuation c input.
   (The membership operators are not visible in this text; presumably set
   membership - confirm.)

   results_gpv Γ gpv is the set of all x satisfying the predicate.  The inner
   context installs the mandatory name prefix "results_gpv" and re-exports the
   introduction, induction, case-analysis and simp rules of the predicate in
   set form (attribute to_set) under that prefix. *)
context begin
qualified inductive resultsp_gpv :: "('out, 'in) 'a  ('a, 'out, 'in) gpv  bool"
  for Γ x
where
  Pure: "Pure x  set_spmf (the_gpv gpv)  resultsp_gpv Γ x gpv"
| IO:
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ Γ out; resultsp_gpv Γ x (c input) 
   resultsp_gpv Γ x gpv"

definition results_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  'a set"
where "results_gpv Γ gpv  {x. resultsp_gpv Γ x gpv}"

(* Bridge between predicate and set form, used by the to_set attribute. *)
lemma resultsp_gpv_results_gpv_eq [pred_set_conv]: "resultsp_gpv Γ x gpv  x  results_gpv Γ gpv"
by(simp add: results_gpv_def)

context begin
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "results_gpv")

lemmas intros [intro?] = resultsp_gpv.intros[to_set]
  and Pure = Pure[to_set]
  and IO = IO[to_set]
  and induct [consumes 1, case_names Pure IO, induct set: results_gpv] = resultsp_gpv.induct[to_set]
  and cases [consumes 1, case_names Pure IO, cases set: results_gpv] = resultsp_gpv.cases[to_set]
  and simps = resultsp_gpv.simps[to_set]
end

(* Unfolding rule for results of a gpv given by the constructor GPV; declared
   [simp] in set form. *)
inductive_simps results_gpv_GPV [to_set, simp]: "resultsp_gpv Γ x (GPV gpv)"

end

(* Done x has exactly one result, x, for every interface Γ; declared [iff]. *)
lemma results_gpv_Done [iff]: "results_gpv Γ (Done x) = {x}"
by(auto simp add: Done.ctr)

(* Fail has no results, for every interface Γ; declared [iff]. *)
lemma results_gpv_Fail [iff]: "results_gpv Γ Fail = {}"
by(auto simp add: Fail_def)

lemma results_gpv_Pause [simp]:
  "results_gpv Γ (Pause out c) = (inputresponses_ℐ Γ out. results_gpv Γ (c input))"
by(auto simp add: Pause.ctr)

(* The results of a lifted spmf p are exactly set_spmf p, for every interface
   Γ; declared [iff]. *)
lemma results_gpv_lift_spmf [iff]: "results_gpv Γ (lift_spmf p) = set_spmf p"
by(auto simp add: lift_spmf.ctr)

(* assert_gpv b has the single result () when b holds and no result
   otherwise; declared [simp]. *)
lemma results_gpv_assert_gpv [simp]: "results_gpv Γ (assert_gpv b) = (if b then {()} else {})"
by auto

lemma results_gpv_bind_gpv [simp]:
  "results_gpv Γ (gpv  f) = (xresults_gpv Γ gpv. results_gpv Γ (f x))"
  (is "?lhs = ?rhs")
proof(intro set_eqI iffI)
  fix x
  assume "x  ?lhs"
  then show "x  ?rhs"
  proof(induction gpv'"gpv  f" arbitrary: gpv)
    case Pure thus ?case
      by(auto 4 3 split: if_split_asm intro: results_gpv.intros rev_bexI)
  next
    case (IO out c input)
    from ‹IO out c  _
    obtain generat where generat: "generat  set_spmf (the_gpv gpv)"
      and *: "IO out c  set_spmf (if is_Pure generat then the_gpv (f (result generat))
                                   else return_spmf (IO (output generat) (λinput. continuation generat input  f)))"
      by(auto)
    thus ?case
    proof(cases generat)
      case (Pure y)
      with generat have "y  results_gpv Γ gpv" by(auto intro: results_gpv.intros)
      thus ?thesis using * Pure input  responses_ℐ Γ out x  results_gpv Γ (c input)
        by(auto intro: results_gpv.IO)
    next
      case (IO out' c')
      hence [simp]: "out' = out"
        and c: "input. c input = bind_gpv (c' input) f" using * by simp_all
      from IO.hyps(4)[OF c] obtain y where y: "y  results_gpv Γ (c' input)"
        and "x  results_gpv Γ (f y)" by blast
      from y IO generat have "y  results_gpv Γ gpv" using input  responses_ℐ Γ out
        by(auto intro: results_gpv.IO)
      with x  results_gpv Γ (f y) show ?thesis by blast
    qed
  qed
next
  fix x
  assume "x  ?rhs"
  then obtain y where y: "y  results_gpv Γ gpv"
    and x: "x  results_gpv Γ (f y)" by blast
  from y show "x  ?lhs"
  proof(induction)
    case (Pure gpv)
    with x show ?case
      by cases(auto 4 4 intro: results_gpv.intros rev_bexI)
  qed(auto 4 4 intro: rev_bexI results_gpv.IO)
qed

(* With the interface ℐ_full, results_gpv coincides with results'_gpv.  Both
   inclusions are proved by induction on the respective membership fact. *)
lemma results_gpv_ℐ_full: "results_gpv ℐ_full = results'_gpv"
proof(intro ext set_eqI iffI)
  show "x  results'_gpv gpv" if "x  results_gpv ℐ_full gpv" for x gpv
    using that by induction(auto intro: results'_gpvI)
  show "x  results_gpv ℐ_full gpv" if "x  results'_gpv gpv" for x gpv
    using that by induction(auto intro: results_gpv.intros elim!: generat.set_cases)
qed

lemma results'_bind_gpv [simp]:
  "results'_gpv (bind_gpv gpv f) = (xresults'_gpv gpv. results'_gpv (f x))"
unfolding results_gpv_ℐ_full[symmetric] by simp

(* The results of map_gpv f id gpv are the image under f of the results of
   gpv; declared [simp].  Proved by rewriting map_gpv into a bind
   (map_gpv_conv_bind).  The interface argument of results_gpv is not visible
   in this text. *)
lemma results_gpv_map_gpv_id [simp]: "results_gpv  (map_gpv f id gpv) = f ` results_gpv  gpv"
  by(auto simp add: map_gpv_conv_bind id_def)

lemma results_gpv_map_gpv_id' [simp]: "results_gpv  (map_gpv f (λx. x) gpv) = f ` results_gpv  gpv"
  by(auto simp add: map_gpv_conv_bind id_def)

lemma pred_gpv_bind [simp]: "pred_gpv P Q (bind_gpv gpv f) = pred_gpv (pred_gpv P Q  f) Q gpv"
by(auto simp add: pred_gpv_def outs_bind_gpv)

lemma results'_gpv_bind_option [simp]:
  "results'_gpv (monad.bind_option Fail x f) = (yset_option x. results'_gpv (f y))"
by(cases x) simp_all

lemma results'_gpv_map_gpv':
  assumes "surj h"
  shows "results'_gpv (map_gpv' f g h gpv) = f ` results'_gpv gpv" (is "?lhs = ?rhs")
proof -
  have *:"IO z c  set_spmf (the_gpv gpv)  x  results'_gpv (c input) 
     f x  results'_gpv (map_gpv' f g h (c input))  f x  results'_gpv (map_gpv' f g h gpv)" for x z gpv c input
    using surjD[OF assms, of input] by(fastforce intro: results'_gpvI elim!: generat.set_cases intro: rev_image_eqI simp add: map_fun_def o_def)

  show ?thesis 
  proof(intro Set.set_eqI iffI; (elim imageE; hypsubst)?)
    show "x  ?rhs" if "x  ?lhs" for x using that
      by(induction gpv'"map_gpv' f g h gpv" arbitrary: gpv)(fastforce elim!: generat.set_cases intro: results'_gpvI)+
    show "f x  ?lhs" if "x  results'_gpv gpv" for x using that
      by induction (fastforce intro: results'_gpvI elim!: generat.set_cases intro: rev_image_eqI simp add: map_fun_def o_def
          , clarsimp simp add: *  elim!: generat.set_cases)
  qed
qed

(* Associativity between bind_gpv and monad.bind_option Fail: binding g after
   an option-bind equals the option-bind whose continuation binds g.  Proved
   by case distinction on the option value x. *)
lemma bind_gpv_bind_option_assoc:
  "bind_gpv (monad.bind_option Fail x f) g = monad.bind_option Fail x (λx. bind_gpv (f x) g)"
by(cases x) simp_all

(* Outputs of a gpv relative to an interface.

   outsp_gpv is an inductive predicate, fixed in its first two arguments (an
   interface and an output x), with two rules:
     IO:   some IO x c is associated with set_spmf (the_gpv gpv);
     Cont: some IO out rpv is associated with set_spmf (the_gpv gpv), some
           input is related to the responses of the interface for out, and
           the predicate holds for the continuation rpv input.
   (The interface variable and the membership operators are not visible in
   this text; presumably a variable of the interface type and set membership -
   confirm.)

   outs_gpv is the set of all x satisfying the predicate.  The inner context
   installs the mandatory name prefix "outs_gpv" and re-exports the
   introduction, induction, case-analysis and simp rules of the predicate in
   set form (attribute to_set) under that prefix. *)
context begin
qualified inductive outsp_gpv :: "('out, 'in) 'out  ('a, 'out, 'in) gpv  bool"
  for  x where
    IO: "IO x c  set_spmf (the_gpv gpv)  outsp_gpv  x gpv"
  | Cont: " IO out rpv  set_spmf (the_gpv gpv); input  responses_ℐ  out; outsp_gpv  x (rpv input) 
   outsp_gpv  x gpv"

definition outs_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  'out set"
  where "outs_gpv  gpv  {x. outsp_gpv  x gpv}"

(* Bridge between predicate and set form, used by the to_set attribute. *)
lemma outsp_gpv_outs_gpv_eq [pred_set_conv]: "outsp_gpv  x = (λgpv. x  outs_gpv  gpv)"
  by(simp add: outs_gpv_def)

context begin
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "outs_gpv")

lemmas intros [intro?] = outsp_gpv.intros[to_set]
  and IO = IO[to_set]
  and Cont = Cont[to_set]
  and induct [consumes 1, case_names IO Cont, induct set: outs_gpv] = outsp_gpv.induct[to_set]
  and cases [consumes 1, case_names IO Cont, cases set: outs_gpv] = outsp_gpv.cases[to_set]
  and simps = outsp_gpv.simps[to_set]
end

(* Unfolding rule for outputs of a gpv given by the constructor GPV; declared
   [simp] in set form. *)
inductive_simps outs_gpv_GPV [to_set, simp]: "outsp_gpv  x (GPV gpv)"

end

(* Done x has no outputs. *)
lemma outs_gpv_Done [iff]: "outs_gpv  (Done x) = {}"
  by(auto simp add: Done.ctr)

(* Fail has no outputs. *)
lemma outs_gpv_Fail [iff]: "outs_gpv  Fail = {}"
  by(auto simp add: Fail_def)

(* Outputs of Pause out c: out itself together with the outputs of c input for the
   inputs drawn from responses_ℐ for out. *)
lemma outs_gpv_Pause [simp]:
  "outs_gpv  (Pause out c) = insert out (inputresponses_ℐ  out. outs_gpv  (c input))"
  by(auto simp add: Pause.ctr)

(* A lifted spmf has no outputs. *)
lemma outs_gpv_lift_spmf [iff]: "outs_gpv  (lift_spmf p) = {}"
  by(auto simp add: lift_spmf.ctr)

(* assert_gpv b has no outputs, for either value of b (case distinction on b). *)
lemma outs_gpv_assert_gpv [simp]: "outs_gpv  (assert_gpv b) = {}"
  by(cases b)auto

(* Outputs of a bind: the outputs of gpv combined with the outputs of f x for every
   x in results_gpv of gpv. The two inclusions are proved separately:
   - left to right by induction on membership in outs_gpv of the bind;
   - right to left by distinguishing whether x comes from gpv itself (case out) or from
     some f y with y a result of gpv (case result), each by induction. *)
lemma outs_gpv_bind_gpv [simp]:
  "outs_gpv  (gpv  f) = outs_gpv  gpv  (xresults_gpv  gpv. outs_gpv  (f x))"
  (is "?lhs = ?rhs")
proof(intro Set.set_eqI iffI)
  fix x
  assume "x  ?lhs"
  then show "x  ?rhs"
  proof(induction gpv'"gpv  f" arbitrary: gpv)
    case IO thus ?case
    proof(clarsimp split: if_split_asm elim!: is_PureE not_is_PureE, goal_cases)
      case (1 generat)
      then show ?case by(cases generat)(auto intro: results_gpv.Pure outs_gpv.intros)
    qed
  next
    case (Cont out rpv input)
    thus ?case
    proof(clarsimp split: if_split_asm, goal_cases)
      case (1 generat)
      then show ?case by(cases generat)(auto 4 3 split: if_split_asm intro: results_gpv.intros outs_gpv.intros)
    qed
  qed
next
  fix x
  assume "x  ?rhs"
  then consider (out) "x  outs_gpv  gpv" | (result) y where "y  results_gpv  gpv" "x  outs_gpv  (f y)" by auto
  then show "x  ?lhs"
  proof cases
    case out then show ?thesis
      by(induction) (auto 4 4 intro: outs_gpv.IO  outs_gpv.Cont rev_bexI) 
  next
    case result then show ?thesis
      by induction ((erule outs_gpv.cases | rule outs_gpv.Cont), 
          auto 4 4 intro: outs_gpv.intros rev_bexI elim: outs_gpv.cases)+
  qed
qed

(* For the full interface ℐ_full, the interface-relative outs_gpv coincides with the
   BNF setter outs'_gpv. Each inclusion is proved by induction on the respective
   membership. *)
lemma outs_gpv_ℐ_full: "outs_gpv ℐ_full = outs'_gpv"
proof(intro ext Set.set_eqI iffI)
  show "x  outs'_gpv gpv" if "x  outs_gpv ℐ_full gpv" for x gpv
    using that by induction(auto intro: outs'_gpvI)
  show "x  outs_gpv ℐ_full gpv" if "x  outs'_gpv gpv" for x gpv
    using that by induction(auto intro: outs_gpv.intros elim!: generat.set_cases)
qed

(* Bind rule for outs'_gpv, obtained from the interface-relative version by rewriting
   outs'_gpv and results'_gpv into their ℐ_full counterparts (the two ..._ℐ_full
   equations used right-to-left). *)
lemma outs'_bind_gpv [simp]:
  "outs'_gpv (bind_gpv gpv f) = outs'_gpv gpv  (xresults'_gpv gpv. outs'_gpv (f x))"
  unfolding outs_gpv_ℐ_full[symmetric] results_gpv_ℐ_full[symmetric] by simp

(* Mapping only the results (the output map is id) does not change the outputs. *)
lemma outs_gpv_map_gpv_id [simp]: "outs_gpv  (map_gpv f id gpv) = outs_gpv  gpv"
  by(auto simp add: map_gpv_conv_bind id_def)

(* Same as outs_gpv_map_gpv_id, but with the identity written as an explicit lambda so
   that the simp rule also matches that form. *)
lemma outs_gpv_map_gpv_id' [simp]: "outs_gpv  (map_gpv f (λx. x) gpv) = outs_gpv  gpv"
  by(auto simp add: map_gpv_conv_bind id_def)

(* Outputs of a gpv obtained via monad.bind_option with Fail as the default: collected
   from outs'_gpv (f y) for the elements y of set_option x. Case distinction on x. *)
lemma outs'_gpv_bind_option [simp]:
  "outs'_gpv (monad.bind_option Fail x f) = (yset_option x. outs'_gpv (f y))"
  by(cases x) simp_all

(* Graph characterisation of rel_gpv'': for graph relations Grp A f, Grp B g and the
   converse of Grp UNIV h, rel_gpv'' is the graph of map_gpv' f g h restricted to those
   gpvs whose results and outputs (w.r.t. ℐ_uniform UNIV (range h)) lie in A and B.
   Structure of the proof:
   - from ?lhs: the map equation by coinduction, then the two set inclusions by
     induction on results_gpv and outs_gpv respectively, each time extracting the
     domain predicate from rel_gpv''D;
   - from ?rhs: ?lhs by coinduction after substituting gpv' by the mapped gpv. *)
lemma rel_gpv''_Grp: includes lifting_syntax shows
  "rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯ = 
   BNF_Def.Grp {x. results_gpv (ℐ_uniform UNIV (range h)) x  A  outs_gpv (ℐ_uniform UNIV (range h)) x  B} (map_gpv' f g h)"
  (is "?lhs = ?rhs")
proof(intro ext GrpI iffI CollectI conjI subsetI)
  let ?ℐ = "ℐ_uniform UNIV (range h)"
  fix gpv gpv'
  assume *: "?lhs gpv gpv'"
  then show "map_gpv' f g h gpv = gpv'"
    by(coinduction arbitrary: gpv gpv')
      (drule rel_gpv''D
        , auto 4 5 simp add: spmf_rel_map generat.rel_map elim!: rel_spmf_mono generat.rel_mono_strong GrpE intro!: GrpI dest: rel_funD)
  show "x  A" if "x  results_gpv ?ℐ gpv" for x using that *
  proof(induction arbitrary: gpv')
    case (Pure gpv)
    have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
      using Pure.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] ..
    with Pure.hyps show ?case by(simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Domainp_Grp)
  next
    case (IO out c gpv input)
    have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
      using IO.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
    with IO.hyps show ?case 
      by(auto simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Grp_iff dest: rel_funD intro: IO.IH dest!: bspec)
  qed
  show "x  B" if "x  outs_gpv ?ℐ gpv" for x using that *
  proof(induction arbitrary: gpv')
    case (IO c gpv)
    have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
      using IO.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
    with IO.hyps show ?case by(simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Domainp_Grp)
  next
    case (Cont out rpv gpv input)
    have "pred_spmf (Domainp (rel_generat (BNF_Def.Grp A f) (BNF_Def.Grp B g) ((BNF_Def.Grp UNIV h)¯¯ ===> rel_gpv'' (BNF_Def.Grp A f) (BNF_Def.Grp B g) (BNF_Def.Grp UNIV h)¯¯))) (the_gpv gpv)"
      using Cont.prems[THEN rel_gpv''D] unfolding spmf_Domainp_rel[symmetric] by(rule DomainPI)
    with Cont.hyps show ?case 
      by(auto simp add: generat.Domainp_rel pred_spmf_def pred_generat_def Grp_iff dest: rel_funD intro: Cont.IH dest!: bspec)
  qed
next
  fix gpv gpv'
  assume "?rhs gpv gpv'"
  then have gpv': "gpv' = map_gpv' f g h gpv"
    and *: "results_gpv (ℐ_uniform UNIV (range h)) gpv  A" "outs_gpv (ℐ_uniform UNIV (range h)) gpv  B" by(auto simp add: Grp_iff)
  show "?lhs gpv gpv'" using * unfolding gpv'
    by(coinduction arbitrary: gpv)
      (fastforce simp add: spmf_rel_map generat.rel_map Grp_iff intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI elim!: generat.set_cases intro: results_gpv.intros outs_gpv.intros)
qed

(* Predicate lifting for gpvs with an additional set X of admissible inputs:
   pred_gpv' P Q X gpv holds when every result of gpv w.r.t. ℐ_uniform UNIV X satisfies P
   and every output of gpv w.r.t. ℐ_uniform UNIV X satisfies Q. Stated as a
   single-rule inductive predicate. *)
inductive pred_gpv' :: "('a  bool)  ('out  bool)  'in set  ('a, 'out, 'in) gpv  bool" for P Q X gpv where
  "pred_gpv' P Q X gpv" 
if "x. x  results_gpv (ℐ_uniform UNIV X) gpv  P x" "out. out  outs_gpv (ℐ_uniform UNIV X) gpv  Q out"

(* pred_gpv is the special case of pred_gpv' with input set UNIV; the proof uses the
   ℐ_full characterisations of results_gpv and outs_gpv. *)
lemma pred_gpv_conv_pred_gpv': "pred_gpv P Q = pred_gpv' P Q UNIV"
  by(auto simp add: fun_eq_iff pred_gpv_def pred_gpv'.simps results_gpv_ℐ_full outs_gpv_ℐ_full)

(* Moves an input map from the relation to the left gpv: if gpv and gpv' are related with
   the converse graph of h as input relation, then map_gpv' id id h gpv and gpv' are
   related with equality as input relation. Proved by coinduction, unfolding one step
   with rel_gpv''D and weakening the spmf/generat relators. *)
lemma rel_gpv''_map_gpv'1:
  "rel_gpv'' A C (BNF_Def.Grp UNIV h)¯¯ gpv gpv'  rel_gpv'' A C (=) (map_gpv' id id h gpv) gpv'"
  apply(coinduction arbitrary: gpv gpv')
  apply(drule rel_gpv''D)
  apply(simp add: spmf_rel_map)
  apply(erule rel_spmf_mono)
  apply(simp add: generat.rel_map)
  apply(erule generat.rel_mono_strong; simp?)
  apply(subst map_fun2_id)
  by(auto simp add: rel_fun_comp intro!: rel_fun_map_fun1 elim: rel_fun_mono)

(* Counterpart for the right gpv: if gpv and gpv' are related with eq_on (range h) as
   input relation, then gpv and map_gpv' id id h gpv' are related with the converse graph
   of h as input relation. Proved by coinduction along the same lines as
   rel_gpv''_map_gpv'1. *)
lemma rel_gpv''_map_gpv'2:
  "rel_gpv'' A C (eq_on (range h)) gpv gpv'  rel_gpv'' A C (BNF_Def.Grp UNIV h)¯¯ gpv (map_gpv' id id h gpv')"
  apply(coinduction arbitrary: gpv gpv')
  apply(drule rel_gpv''D)
  apply(simp add: spmf_rel_map)
  apply(erule rel_spmf_mono_strong)
  apply(simp add: generat.rel_map)
  apply(erule generat.rel_mono_strong; simp?)
  apply(subst map_fun_id2_in)
  apply(rule rel_fun_map_fun2)
  by (auto simp add: rel_fun_comp  elim: rel_fun_mono)

(* Domain of rel_gpv'' for fixed relations A (results), C (outputs) and R (inputs).
   The private lemmas f11, f21, f12, f22 are the individual proof steps used by the
   two inductions in Domainp_rel_gpv''_le. *)
context
  fixes A :: "'a  'd  bool"
    and C :: "'c  'g  bool"
    and R :: "'b  'e  bool"
begin

(* Step for the Pure case of the results induction: from the domain of rel_generat at
   Pure x one obtains Domainp A x. *)
private lemma f11:" Pure x  set_spmf (the_gpv gpv) 
   Domainp (rel_generat A C (rel_fun R (rel_gpv'' A C R))) (Pure x)  Domainp A x"
  by (auto simp add: pred_generat_def elim:bspec dest: generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])

(* Step for the IO case of the outputs induction: if IO out c is rel_generat-related to
   something, then Domainp C out. *)
private lemma f21: "IO out c  set_spmf (the_gpv gpv)  
  rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out c) ba  Domainp C out"
  by (auto simp add: pred_generat_def elim:bspec dest: generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])

(* Continuation step for the results induction: being in the domain of rel_gpv'' is
   inherited by c input for inputs in the domain of R.
   Note: the proof only refers to assumptions 1, 2 and 4; the third assumption is not
   referenced (presumably it is present so the lemma matches the induction goal). *)
private lemma f12:
  assumes "IO out c  set_spmf (the_gpv gpv)"
    and "input  responses_ℐ (ℐ_uniform UNIV {x. Domainp R x}) out"
    and "x  results_gpv (ℐ_uniform UNIV {x. Domainp R x}) (c input)"
    and "Domainp (rel_gpv'' A C R) gpv"
  shows "Domainp (rel_gpv'' A C R) (c input)"
proof -
  (* b1: a gpv related to gpv; b2: a generat related to IO out c. *)
  obtain b1 where o1:"rel_gpv'' A C R gpv b1" using assms(4) by clarsimp
  obtain b2 where o2:"rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out c) b2"
    using assms(1) o1[THEN rel_gpv''D, THEN spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
    unfolding pred_spmf_def by - (drule (1) bspec, auto)

  have "Ball (generat_conts (IO out c)) (Domainp (rel_fun R (rel_gpv'' A C R)))"
    using o2[THEN generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
    unfolding pred_generat_def by simp

  with assms(2) show ?thesis 
    apply -
    apply(drule bspec)
     apply simp
    apply clarify
    apply(drule Domainp_rel_fun_le[THEN predicate1D, OF Domainp_iff[THEN iffD2], OF exI])
    by simp  
qed

(* Continuation step for the outputs induction; same argument as f12, stated with an
   outs_gpv assumption in third position (again not referenced in the proof). *)
private lemma f22:
  assumes "IO out' rpv  set_spmf (the_gpv gpv)"
    and "input  responses_ℐ (ℐ_uniform UNIV {x. Domainp R x}) out'"
    and "out  outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) (rpv input)"
    and "Domainp (rel_gpv'' A C R) gpv"
  shows "Domainp (rel_gpv'' A C R) (rpv input)"
proof -
  obtain b1 where o1:"rel_gpv'' A C R gpv b1" using assms(4) by auto
  obtain b2 where o2:"rel_generat A C (rel_fun R (rel_gpv'' A C R)) (IO out' rpv) b2"
    using assms(1) o1[THEN rel_gpv''D, THEN spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
    unfolding pred_spmf_def by - (drule (1) bspec, auto)

  have "Ball (generat_conts (IO out' rpv)) (Domainp (rel_fun R (rel_gpv'' A C R)))"
    using o2[THEN generat.Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI]]
    unfolding pred_generat_def by simp

  with assms(2) show ?thesis 
    apply -
    apply(drule bspec)
     apply simp
    apply clarify
    apply(drule Domainp_rel_fun_le[THEN predicate1D, OF Domainp_iff[THEN iffD2], OF exI])
    by simp 
qed

(* Main result: the domain of rel_gpv'' A C R is contained in pred_gpv' of the domains of
   A and C, with the domain of R as the set of admissible inputs. Results and outputs
   are handled by separate inductions, with f11/f12 and f21/f22 as the case steps. *)
lemma Domainp_rel_gpv''_le:
  "Domainp (rel_gpv'' A C R)  pred_gpv' (Domainp A) (Domainp C) {x. Domainp R x}"
proof(rule predicate1I pred_gpv'.intros)+
  show "Domainp A x" if "x  results_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv" "Domainp (rel_gpv'' A C R) gpv" for x gpv using that
  proof(induction)
    case (Pure gpv)
    then show ?case 
      by (clarify) (drule rel_gpv''D
          , auto simp add: f11 pred_spmf_def dest: spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])
  qed (simp add: f12) 
  show "Domainp C out" if "out  outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv" "Domainp (rel_gpv'' A C R) gpv" for out gpv using that
  proof( induction)
    case (IO c gpv)
    then show ?case
      by (clarify) (drule rel_gpv''D
          , auto simp add: f21 pred_spmf_def dest!: bspec spmf_Domainp_rel[THEN fun_cong, THEN iffD1, OF Domainp_iff[THEN iffD2], OF exI])
  qed (simp add: f22)
qed

end

(* Factorisation of map_gpv': first map results and outputs with map_gpv f g, then apply
   only the input map h via map_gpv' id id h. *)
lemma map_gpv'_id12: "map_gpv' f g h gpv = map_gpv' id id h (map_gpv f g gpv)"
  unfolding map_gpv_conv_map_gpv' map_gpv'_comp by simp

(* Reflexivity-style rule for rel_gpv'': under the three premises comparing A, C and R
   with equality, equality is contained in rel_gpv'' A C R. Derived from rel_gpv''_eq
   (used right-to-left) and monotonicity rel_gpv''_mono. *)
lemma rel_gpv''_refl: " (=)  A; (=)  C; R  (=)   (=)  rel_gpv'' A C R"
  by(subst rel_gpv''_eq[symmetric])(rule rel_gpv''_mono)


(* Strong monotonicity of rel_gpv'': the relations on results and outputs only need to be
   weakened on the results/outputs that actually occur in the two gpvs. *)
context
  fixes A A' :: "'a  'b  bool"
    and C C' :: "'c  'd  bool"
    and R R' :: "'e  'f  bool"
   
begin

(* Shorthand for the shape of the two side conditions of rel_gpv''_mono_strong:
   fx/fy select results_gpv or outs_gpv, evaluated w.r.t. the uniform interfaces built
   from the domain and range of R'; g1/g2 are the relations being compared. *)
private abbreviation foo where 
  "foo  (λ fx fy gpvx gpvy g1 g2. 
            x y. x  fx (ℐ_uniform UNIV (Collect (Domainp R'))) gpvx 
                  y  fy (ℐ_uniform UNIV (Collect (Rangep R'))) gpvy  g1 x y  g2 x y)"

(* The results side condition is inherited by continuations (via results_gpv.IO). *)
private lemma f1: "foo results_gpv results_gpv gpv gpv' A A' 
       x  set_spmf (the_gpv gpv)  y  set_spmf (the_gpv gpv') 
       a  generat_conts x  b  generat_conts y   R' a' α  R' β b'  
    foo results_gpv results_gpv (a a') (b b') A A'"
  by (fastforce elim: generat.set_cases intro: results_gpv.IO)

(* The outputs side condition is inherited by continuations (via outs_gpv.Cont). *)
private lemma f2: "foo outs_gpv outs_gpv gpv gpv' C C' 
       x  set_spmf (the_gpv gpv)  y  set_spmf (the_gpv gpv') 
       a  generat_conts x  b  generat_conts y  R' a' α  R' β b'  
    foo outs_gpv outs_gpv (a a') (b b') C C'"
  by (fastforce elim: generat.set_cases intro: outs_gpv.Cont)

(* Main lemma. Proof by coinduction: after unfolding one step with rel_gpv''D, the three
   subgoals of generat.rel_mono_strong are discharged in turn -- results (using
   results_gpv.Pure), outputs (using outs_gpv.IO), and continuations (using
   rel_fun_mono_strong together with f1 and f2). *)
lemma rel_gpv''_mono_strong:
  " rel_gpv'' A C R gpv gpv'; 
     x y.  x  results_gpv (ℐ_uniform UNIV {x. Domainp R' x}) gpv; y  results_gpv (ℐ_uniform UNIV {x. Rangep R' x}) gpv'; A x y   A' x y;
     x y.  x  outs_gpv (ℐ_uniform UNIV {x. Domainp R' x}) gpv; y  outs_gpv (ℐ_uniform UNIV {x. Rangep R' x}) gpv'; C x y   C' x y;
     R'  R 
   rel_gpv'' A' C' R' gpv gpv'"
  apply(coinduction arbitrary: gpv gpv')
  apply(drule rel_gpv''D)
  apply(erule rel_spmf_mono_strong)
  apply(erule generat.rel_mono_strong)
    apply(erule generat.set_cases)+
    apply(erule allE, rotate_tac -1)
    apply(erule allE)
    apply(erule impE)
     apply(rule results_gpv.Pure)
     apply simp
    apply(erule impE)
     apply(rule results_gpv.Pure)
     apply simp
    apply simp
   apply(erule generat.set_cases)+
   apply(rotate_tac 1)
   apply(erule allE, rotate_tac -1)
   apply(erule allE)
   apply(erule impE)
    apply(rule outs_gpv.IO)
    apply simp
   apply(erule impE)
    apply(rule outs_gpv.IO)
    apply simp
   apply simp
  apply(erule (1) rel_fun_mono_strong)
  by (fastforce simp add: f1[simplified] f2[simplified])

end

(* Strong reflexivity: gpv is rel_gpv''-related to itself provided A and C are reflexive
   on the results and outputs of gpv (w.r.t. the uniform interface over the domain of R)
   and the third assumption relating R and equality holds. Obtained from the trivial
   instance with equality everywhere by rel_gpv''_mono_strong. *)
lemma rel_gpv''_refl_strong:
  assumes "x. x  results_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv  A x x"
    and "x. x  outs_gpv (ℐ_uniform UNIV {x. Domainp R x}) gpv  C x x"
    and "R  (=)"
  shows "rel_gpv'' A C R gpv gpv"
proof -
  have "rel_gpv'' (=) (=) (=) gpv gpv" unfolding rel_gpv''_eq by simp
  then show ?thesis using _ _ assms(3) by(rule rel_gpv''_mono_strong)(auto intro: assms(1-2))
qed

(* Instance of rel_gpv''_refl_strong for the input relation eq_on X: reflexivity of A and
   B is only required on results and outputs w.r.t. ℐ_uniform UNIV X. *)
lemma rel_gpv''_refl_eq_on:
  " x. x  results_gpv (ℐ_uniform UNIV X) gpv  A x x; out. out  outs_gpv (ℐ_uniform UNIV X) gpv  B out out 
   rel_gpv'' A B (eq_on X) gpv gpv"
  by(rule rel_gpv''_refl_strong) (auto elim: eq_onE)

(* Monotonicity of pred_gpv' in its two predicate arguments (the input set R is kept
   fixed); registered as a mono rule for use in (co)inductive definitions. *)
lemma pred_gpv'_mono' [mono]:
  "pred_gpv' A C R gpv  pred_gpv' A' C' R gpv"
  if "x. A x  A' x" "x. C x  C' x"
  using that unfolding pred_gpv'.simps
  by auto

subsubsection ‹Type judgements›

(* Typing judgement for gpvs w.r.t. an interface Γ, defined coinductively: GPV gpv is
   well-typed if for every IO out c in the support of gpv, out belongs to outs_ℐ Γ and
   c input is again well-typed for every input in responses_ℐ Γ out. *)
coinductive WT_gpv :: "('out, 'in) ('a, 'out, 'in) gpv  bool" ("((_)/ ⊢g (_) )" [100, 0] 99)
  for Γ
where
  "(out c. IO out c  set_spmf gpv  out  outs_ℐ Γ  (inputresponses_ℐ Γ out. Γ ⊢g c input ))
   Γ ⊢g GPV gpv "

(* Coinduction rule for WT_gpv phrased with the selector the_gpv instead of the GPV
   constructor; registered as the default coinduct rule for the predicate, with the two
   conclusion parts named out and cont. *)
lemma WT_gpv_coinduct [consumes 1, case_names WT_gpv, case_conclusion WT_gpv out cont, coinduct pred: WT_gpv]:
  assumes *: "X gpv"
  and step: "gpv out c.
     X gpv; IO out c  set_spmf (the_gpv gpv) 
     out  outs_ℐ Γ  (input  responses_ℐ Γ out. X (c input)  Γ ⊢g c input )"
  shows "Γ ⊢g gpv "
using * by(coinduct)(auto dest: step simp add: eq_GPV_iff)

(* One-step unfolding of WT_gpv on an explicit GPV constructor application. *)
lemma WT_gpv_simps:
  "Γ ⊢g GPV gpv  
   (out c. IO out c  set_spmf gpv  out  outs_ℐ Γ  (inputresponses_ℐ Γ out. Γ ⊢g c input ))"
by(subst WT_gpv.simps) simp

(* Introduction rule for WT_gpv stated with the selector the_gpv, so it applies to
   arbitrary gpv terms (proved by case distinction on gpv and WT_gpv_simps). *)
lemma WT_gpvI:
  "(out c. IO out c  set_spmf (the_gpv gpv)  out  outs_ℐ Γ  (inputresponses_ℐ Γ out. Γ ⊢g c input ))
   Γ ⊢g gpv "
by(cases gpv)(simp add: WT_gpv_simps)

(* Destruction rules for a well-typed gpv, also available under separate names:
   WT_gpv_OutD  -- outputs of IO elements in the support lie in outs_ℐ Γ;
   WT_gpv_ContD -- continuations are well-typed for inputs in responses_ℐ Γ out. *)
lemma WT_gpvD:
  assumes "Γ ⊢g gpv "
  shows WT_gpv_OutD: "IO out c  set_spmf (the_gpv gpv)  out  outs_ℐ Γ"
  and WT_gpv_ContD: " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ Γ out   Γ ⊢g c input "
using assms by(cases, fastforce)+

(* Change of interface: well-typedness w.r.t. ℐ1 carries over to ℐ2 under the two
   assumptions outs (relating the output sets of ℐ1 and ℐ2) and responses (relating, for
   each output of ℐ1, the responses of ℐ2 and ℐ1). Proved by coinduction. *)
lemma WT_gpv_mono:
  assumes WT: "ℐ1 ⊢g gpv "
  and outs: "outs_ℐ ℐ1  outs_ℐ ℐ2"
  and responses: "x. x  outs_ℐ ℐ1  responses_ℐ ℐ2 x  responses_ℐ ℐ1 x"
  shows "ℐ2 ⊢g gpv "
using WT
proof coinduct
  case (WT_gpv gpv out c)
  with outs show ?case by(auto 6 4 dest: responses WT_gpvD)
qed

(* Done x is well-typed w.r.t. any interface. *)
lemma WT_gpv_Done [iff]: "Γ ⊢g Done x "
by(rule WT_gpvI) simp_all

(* Fail is well-typed w.r.t. any interface. *)
lemma WT_gpv_Fail [iff]: "Γ ⊢g Fail "
by(rule WT_gpvI) simp_all

(* Introduction rule for Pause: out must be an output of Γ and every continuation
   c input must be well-typed for the inputs in responses_ℐ Γ out. *)
lemma WT_gpv_PauseI: 
  " out  outs_ℐ Γ; input. input  responses_ℐ Γ out  Γ ⊢g c input  
    Γ ⊢g Pause out c "
by(rule WT_gpvI) simp_all

(* Characterisation of well-typedness of Pause out c, combining WT_gpv_PauseI with the
   destruction rules WT_gpvD. *)
lemma WT_gpv_Pause [iff]:
  "Γ ⊢g Pause out c   out  outs_ℐ Γ  (input  responses_ℐ Γ out. Γ ⊢g c input )"
by(auto intro: WT_gpv_PauseI dest: WT_gpvD)

(* Introduction rule for bind: if gpv is well-typed and f x is well-typed for every
   result x of gpv, then the bind of gpv and f is well-typed. Proof by coinduction; an
   IO element of the bind stems from some generat in the support of gpv, which is either
   Pure y (then the IO element comes from f y) or IO out' c' (then the continuation is
   the bind of c' input with f). *)
lemma WT_gpv_bindI:
  " Γ ⊢g gpv ; x. x  results_gpv Γ gpv  Γ ⊢g f x  
   Γ ⊢g gpv  f "
proof(coinduction arbitrary: gpv)
  case [rule_format]: (WT_gpv out c gpv)
  from ‹IO out c  _
  obtain generat where generat: "generat  set_spmf (the_gpv gpv)"
    and *: "IO out c  set_spmf (if is_Pure generat then the_gpv (f (result generat))
                                 else return_spmf (IO (output generat) (λinput. continuation generat input  f)))"
    by(auto)
  show ?case
  proof(cases generat)
    case (Pure y)
    with generat have "y  results_gpv Γ gpv" by(auto intro: results_gpv.Pure)
    hence "Γ ⊢g f y " by(rule WT_gpv)
    with * Pure show ?thesis by(auto dest: WT_gpvD) 
  next
    case (IO out' c')
    hence [simp]: "out' = out"
      and c: "input. c input = bind_gpv (c' input) f" using * by simp_all
    from generat IO have **: "IO out c'  set_spmf (the_gpv gpv)" by simp
    with Γ ⊢g gpv  have ?out by(auto dest: WT_gpvD)
    moreover {
      fix input
      assume input: "input  responses_ℐ Γ out"
      with Γ ⊢g gpv  ** have "Γ ⊢g c' input " by(rule WT_gpvD)
      moreover {
        fix y
        assume "y  results_gpv Γ (c' input)"
        with ** input have "y  results_gpv Γ gpv" by(rule results_gpv.IO)
        hence "Γ ⊢g f y " by(rule WT_gpv) }
      moreover note calculation }
    hence ?cont using c by blast
    ultimately show ?thesis ..
  qed
qed

(* Second destruction rule for bind: if the bind of gpv and f is well-typed and x is a
   result of gpv, then f x is well-typed. Proof by induction on the membership of x in
   results_gpv. *)
lemma WT_gpv_bindD2:
  assumes WT: "Γ ⊢g gpv  f "
  and x: "x  results_gpv Γ gpv"
  shows "Γ ⊢g f x "
using x WT
proof induction
  case (Pure gpv)
  show ?case
  proof(rule WT_gpvI)
    fix out c
    assume "IO out c  set_spmf (the_gpv (f x))"
    with Pure have "IO out c  set_spmf (the_gpv (gpv  f))" by(auto intro: rev_bexI)
    with Γ ⊢g gpv  f  show "out  outs_ℐ Γ  (inputresponses_ℐ Γ out. Γ ⊢g c input )"
      by(auto dest: WT_gpvD simp del: set_bind_spmf)
  qed
next
  case (IO out c gpv input)
  from ‹IO out c  _
  have "IO out (λinput. bind_gpv (c input) f)  set_spmf (the_gpv (gpv  f))"
    by(auto intro: rev_bexI)
  with IO.prems have "Γ ⊢g c input  f " using input  _ by(rule WT_gpv_ContD)
  thus ?case by(rule IO.IH)
qed

(* First destruction rule for bind: well-typedness of the bind of gpv and f implies
   well-typedness of gpv. Proof by coinduction: every IO element of gpv gives rise to an
   IO element of the bind whose continuation binds f after c input. *)
lemma WT_gpv_bindD1: "Γ ⊢g gpv  f   Γ ⊢g gpv "
proof(coinduction arbitrary: gpv)
  case (WT_gpv out c gpv)
  from ‹IO out c  _
  have "IO out (λinput. bind_gpv (c input) f)  set_spmf (the_gpv (gpv  f))"
    by(auto intro: rev_bexI)
  with Γ ⊢g gpv  f  show ?case
    by(auto simp del: bind_gpv_sel' dest: WT_gpvD)
qed

(* Simp rule characterising well-typedness of a bind in terms of gpv and of f x for the
   results x of gpv; assembled from WT_gpv_bindI, WT_gpv_bindD1 and WT_gpv_bindD2. *)
lemma WT_gpv_bind [simp]: "Γ ⊢g gpv  f   Γ ⊢g gpv   (xresults_gpv Γ gpv. Γ ⊢g f x )"
by(blast intro: WT_gpv_bindI dest: WT_gpv_bindD1 WT_gpv_bindD2)

(* Every gpv is well-typed w.r.t. the full interface ℐ_full. *)
lemma WT_gpv_full [simp, intro!]: "ℐ_full ⊢g gpv "
by(coinduction arbitrary: gpv)(auto)

(* A lifted spmf is always well-typed. *)
lemma WT_gpv_lift_spmf [simp, intro!]: " ⊢g lift_spmf p "
by(rule WT_gpvI) auto

(* Coinduction rule for WT_gpv that additionally allows a continuation c input to be
   presented as a bind of an already well-typed gpv' with a function f whose values on
   the results of gpv' satisfy X (third alternative in the step premise).
   Proof idea: generalise the goal to a bind of gpv' and f, instantiated initially with
   gpv' = Done x and f the constant function returning gpv, and run the ordinary
   coinduction rule WT_gpv_coinduct on that statement. *)
lemma WT_gpv_coinduct_bind [consumes 1, case_names WT_gpv, case_conclusion WT_gpv out cont]:
  assumes *: "X gpv"
  and step: "gpv out c.  X gpv; IO out c  set_spmf (the_gpv gpv) 
     out  outs_ℐ   (inputresponses_ℐ  out.
            X (c input) 
             ⊢g c input  
            ((gpv' :: ('b, 'call, 'ret) gpv) f. c input = gpv'  f   ⊢g gpv'   (xresults_gpv  gpv'. X (f x))))"
  shows " ⊢g gpv "
proof -
  fix x
  define gpv' :: "('b, 'call, 'ret) gpv" and f :: "'b  ('a, 'call, 'ret) gpv"
    where "gpv' = Done x" and "f = (λ_. gpv)"
  with * have " ⊢g gpv' " and "x. x  results_gpv  gpv'  X (f x)" by simp_all
  then have " ⊢g gpv'  f "
  proof(coinduction arbitrary: gpv' f rule: WT_gpv_coinduct)
    case [rule_format]: (WT_gpv out c gpv')
    from ‹IO out c  _
    obtain generat where generat: "generat  set_spmf (the_gpv gpv')"
      and *: "IO out c  set_spmf (if is_Pure generat
        then the_gpv (f (result generat))
        else return_spmf (IO (output generat) (λinput. continuation generat input  f)))"
      by(clarsimp)
    show ?case
    proof(cases generat)
      case (Pure x)
      from Pure * have IO: "IO out c  set_spmf (the_gpv (f x))" by simp
      from generat Pure have "x  results_gpv  gpv'" by (simp add: results_gpv.Pure)
      then have "X (f x)" by(rule WT_gpv)
      from step[OF this IO] show ?thesis by(auto 4 4 intro: exI[where x="Done _"])
    next
      case (IO out' c')
      with * have [simp]: "out' = out"
        and c: "c = (λinput. c' input  f)" by simp_all
      from IO generat have IO: "IO out c'  set_spmf (the_gpv gpv')" by simp
      then have "input. input  responses_ℐ  out  results_gpv  (c' input)  results_gpv  gpv'"
        by(auto intro: results_gpv.IO)
      with WT_gpvD[OF  ⊢g gpv'  IO] show ?thesis unfolding c using WT_gpv(2) by blast
    qed
  qed
  then show ?thesis unfolding gpv'_def f_def by simp
qed

(* For a trivial interface every gpv is well-typed; derived from WT_gpv_full by
   WT_gpv_mono after unfolding ℐ_trivial_def. *)
lemma ℐ_trivial_WT_gpvD [simp]: "ℐ_trivial    ⊢g gpv "
using WT_gpv_full by(rule WT_gpv_mono)(simp_all add: ℐ_trivial_def)

(* Converse direction: if every gpv (of the given type) is well-typed, the interface is
   trivial. The witness used for an arbitrary output x is Pause x with the constantly
   failing continuation. *)
lemma ℐ_trivial_WT_gpvI: 
  assumes "gpv :: ('a, 'out, 'in) gpv.  ⊢g gpv "
  shows "ℐ_trivial "
proof
  fix x
  have " ⊢g Pause x (λ_. Fail :: ('a, 'out, 'in) gpv) " by(rule assms)
  thus "x  outs_ℐ " by(simp)
qed

(* Well-typedness is preserved when passing to a related interface ℐ' (second premise);
   instance of WT_gpv_mono whose side conditions are discharged by outs_ℐ_mono and
   responses_ℐ_mono. *)
lemma WT_gpv_ℐ_mono: "  ⊢g gpv ;   ℐ'   ℐ' ⊢g gpv "
  by(erule WT_gpv_mono; rule outs_ℐ_mono responses_ℐ_mono)

(* Comparison of results_gpv for two interfaces related by assumption le, for a gpv that
   is well-typed w.r.t. ℐ'. Proved by induction on the membership in results_gpv, using
   responses_ℐ_mono (instantiated with le) and WT_gpvD. *)
lemma results_gpv_mono:
  assumes le: "ℐ'  " and WT: "ℐ' ⊢g gpv "
  shows "results_gpv  gpv  results_gpv ℐ' gpv"
proof(rule subsetI, goal_cases)
  case (1 x)
  show ?case using 1 WT by(induction)(auto 4 3 intro: results_gpv.intros responses_ℐ_mono[OF le, THEN subsetD] intro: WT_gpvD)
qed

(* The outputs of a well-typed gpv are outputs of the interface. Induction on the
   membership in outs_gpv, using the two destruction rules of WT_gpv. *)
lemma WT_gpv_outs_gpv:
  assumes " ⊢g gpv "
  shows "outs_gpv  gpv  outs_ℐ "
proof
  show "x  outs_ℐ " if "x  outs_gpv  gpv" for x using that assms
    by(induction)(blast intro: WT_gpv_OutD WT_gpv_ContD)+
qed

(* map_gpv' f g h gpv is well-typed provided gpv is well-typed w.r.t. the interface
   mapped with g and h (map_ℐ g h). Proved by coinduction. *)
lemma WT_gpv_map_gpv': " ⊢g map_gpv' f g h gpv " if "map_ℐ g h  ⊢g gpv "
  using that by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)

(* Special case of WT_gpv_map_gpv' for map_gpv, i.e. with id as the input map. *)
lemma WT_gpv_map_gpv: " ⊢g map_gpv f g gpv " if "map_ℐ g id  ⊢g gpv "
  unfolding map_gpv_conv_map_gpv' using that by(rule WT_gpv_map_gpv')

(* Results of map_gpv' f g h gpv: the f-image of the results of gpv taken w.r.t. the
   mapped interface map_ℐ g h. Both inclusions by induction on the membership in
   results_gpv. *)
lemma results_gpv_map_gpv' [simp]:
  "results_gpv  (map_gpv' f g h gpv) = f ` (results_gpv (map_ℐ g h ) gpv)"
proof(intro Set.set_eqI iffI; (elim imageE; hypsubst)?)
  show "x  f ` results_gpv (map_ℐ g h ) gpv" if "x  results_gpv  (map_gpv' f g h gpv)" for x using that
    by(induction gpv'"map_gpv' f g h gpv" arbitrary: gpv)(fastforce intro: results_gpv.intros rev_image_eqI)+
  show "f x  results_gpv  (map_gpv' f g h gpv)" if "x  results_gpv (map_ℐ g h ) gpv" for x using that
    by(induction)(fastforce intro: results_gpv.intros)+
qed

(* Parametricity of WT_gpv w.r.t. rel_ℐ C R on interfaces and rel_gpv'' A C R on gpvs,
   assuming bi_unique C. The first show (named with a star) proves one direction by
   coinduction, transferring outputs and responses between the two sides with
   transfer_prover; the second direction reuses the first one instantiated with the
   converse relations and rel_gpv''_conversep. *)
lemma WT_gpv_parametric': includes lifting_syntax shows
  "bi_unique C  (rel_ℐ C R ===> rel_gpv'' A C R ===> (=)) WT_gpv WT_gpv"
proof(rule rel_funI iffI)+
  note [transfer_rule] = the_gpv_parametric'
  show *: " ⊢g gpv " if [transfer_rule]: "rel_ℐ C R  ℐ'" "bi_unique C" 
    and *: "ℐ' ⊢g gpv' " "rel_gpv'' A C R gpv gpv'" for  ℐ' gpv gpv' A C R
    using *
  proof(coinduction arbitrary: gpv gpv')
    case (WT_gpv out c gpv gpv')
    note [transfer_rule] = WT_gpv(2)
    have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf (the_gpv gpv)) (set_spmf (the_gpv gpv'))" 
      by transfer_prover
    from rel_setD1[OF this WT_gpv(3)] obtain out' c'
      where [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
        and out': "IO out' c'  set_spmf (the_gpv gpv')"
      by(auto elim: generat.rel_cases)
    have "out  outs_ℐ   out'  outs_ℐ ℐ'" by transfer_prover
    with WT_gpvD(1)[OF WT_gpv(1) out'] have ?out by simp
    moreover have ?cont
    proof(standard; goal_cases cont)
      case (cont input)
      have "rel_set R (responses_ℐ  out) (responses_ℐ ℐ' out')" by transfer_prover
      from rel_setD1[OF this cont] obtain input' where [transfer_rule]: "R input input'"
        and input': "input'  responses_ℐ ℐ' out'" by blast
      have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
      with WT_gpvD(2)[OF WT_gpv(1) out' input'] show ?case by auto
    qed
    ultimately show ?case ..
  qed

  show "ℐ' ⊢g gpv' " if "rel_ℐ C R  ℐ'" "bi_unique C" " ⊢g gpv " "rel_gpv'' A C R gpv gpv'" 
    for  ℐ' gpv gpv'
    using *[of "conversep C" "conversep R" ℐ'  gpv "conversep A" gpv'] that
    by(simp add: rel_gpv''_conversep)
qed

(* Mapping only the results (output map id) does not affect well-typedness. Obtained by
   instantiating WT_gpv_parametric' with graph relations and rewriting with
   gpv.rel_Grp. *)
lemma WT_gpv_map_gpv_id [simp]: " ⊢g map_gpv f id gpv    ⊢g gpv "
  using WT_gpv_parametric'[of "BNF_Def.Grp UNIV id" "(=)" "BNF_Def.Grp UNIV f", folded rel_gpv_conv_rel_gpv'']
  unfolding gpv.rel_Grp unfolding eq_alt[symmetric] relator_eq
  by(auto simp add: rel_fun_def Grp_def bi_unique_eq)

(* Converse of WT_gpv_outs_gpv: a gpv is well-typed if its outputs are outputs of the
   interface. Proved by coinduction with the introduction rules of outs_gpv. *)
lemma WT_gpv_outs_gpvI:
  assumes "outs_gpv  gpv  outs_ℐ "
  shows " ⊢g gpv "
  using assms by(coinduction arbitrary: gpv)(auto intro: outs_gpv.intros)

(* Well-typedness expressed via outs_gpv; combines WT_gpv_outs_gpvI and
   WT_gpv_outs_gpv. *)
lemma WT_gpv_iff_outs_gpv:
  " ⊢g gpv   outs_gpv  gpv  outs_ℐ "
  by(blast intro: WT_gpv_outs_gpvI dest: WT_gpv_outs_gpv)

subsection ‹Sub-gpvs›

(* Sub-gpvs of a gpv relative to an interface, introduced as a qualified inductive
   predicate sub_gpvsp and its set version sub_gpvs. *)
context begin
(* sub_gpvsp: x is a sub-gpv of gpv.
   Rule base: x equals c input for an IO out c in the support of the_gpv gpv and an
              input from responses_ℐ for out.
   Rule cont: x is a sub-gpv of such a continuation c input. *)
qualified inductive sub_gpvsp :: "('out, 'in) ('a, 'out, 'in) gpv  ('a, 'out, 'in) gpv  bool"
  for  x
where
  base:
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out; x = c input  
   sub_gpvsp  x gpv"
| cont: 
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out; sub_gpvsp  x (c input) 
   sub_gpvsp  x gpv"

(* Variant of rule base with the equation x = c input already substituted. *)
qualified lemma sub_gpvsp_base:
  " IO out c  set_spmf (the_gpv gpv); input  responses_ℐ  out  
   sub_gpvsp  (c input) gpv"
by(rule base) simp_all

(* Set of all x for which sub_gpvsp holds. *)
definition sub_gpvs :: "('out, 'in)  ('a, 'out, 'in) gpv  ('a, 'out, 'in) gpv set"
where "sub_gpvs  gpv  {x. sub_gpvsp  x gpv}"

(* Predicate/set conversion used by the to_set attribute below. *)
lemma sub_gpvsp_sub_gpvs_eq [pred_set_conv]: "sub_gpvsp  x gpv  x  sub_gpvs  gpv"
by(simp add: sub_gpvs_def)

(* Re-export the rules of sub_gpvsp in set form under the name prefix "sub_gpvs".
   NOTE(review): the induct and cases rules are registered with case names Pure and IO,
   whereas the introduction rules above are called base and cont -- looks like a
   leftover from a similar block for another constant; confirm that no proof relies on
   these case names before renaming them. *)
context begin
local_setup ‹Local_Theory.map_background_naming (Name_Space.mandatory_path "sub_gpvs")

lemmas intros [intro?] = sub_gpvsp.intros[to_set]
  and base = sub_gpvsp_base[to_set]
  and cont = cont[to_set]
  and induct [consumes 1, case_names Pure IO, induct set: sub_gpvs] = sub_gpvsp.induct[to_set]
  and cases [consumes 1, case_names Pure IO, cases set: sub_gpvs] = sub_gpvsp.cases[to_set]
  and simps = sub_gpvsp.simps[to_set]
end
end

(* Sub-gpvs of a well-typed gpv are well-typed; induction on the sub_gpvs membership
   (the assumptions are passed in swapped order so that the induction applies to it). *)
lemma WT_sub_gpvsD:
  assumes " ⊢g gpv " and "gpv'  sub_gpvs  gpv"
  shows " ⊢g gpv' "
using assms(2,1) by(induction)(auto dest: WT_gpvD)

(* Introduction rule via sub-gpvs: gpv is well-typed if the outputs of its immediate IO
   elements are outputs of Γ and all its sub-gpvs are well-typed. *)
lemma WT_sub_gpvsI:
  " out c. IO out c  set_spmf (the_gpv gpv)  out  outs_ℐ Γ; 
     gpv'. gpv'  sub_gpvs Γ gpv  Γ ⊢g gpv'  
   Γ ⊢g gpv "
by(rule WT_gpvI)(auto intro: sub_gpvs.base)

subsection ‹Losslessness›

text ‹A gpv is lossless iff we are guaranteed to get a result after a finite number of interactions
  that respect the interface. It is colossless if the interactions may go on forever, but no
  individual step may fail to terminate, i.e., every spmf reached along the way is lossless.›

text ‹We define both notions of losslessness simultaneously by mimicking what the (co)inductive
  package would do internally. Thus, we get a constant which is parametrised by the choice of the
  fixpoint; in particular, for non-recursive gpvs, we can state and prove both versions of
  losslessness in one go.›

context
  fixes co :: bool and  :: "('out, 'in) ℐ"
  and F :: "(('a, 'out, 'in) gpv  bool)  (('a, 'out, 'in) gpv  bool)"
  and co' :: bool
  defines "F  λgen_lossless_gpv gpv. pa. gpv = GPV pa  
     lossless_spmf pa  (out c input. IO out c  set_spmf pa  input  responses_ℐ  out  gen_lossless_gpv (c input))"
  and "co'  co" ― ‹We use a copy of @{term co} such that we can do case distinctions on @{term co'} without
    the simplifier rewriting the @{term co} in the local abbreviations for the constants.›
begin

lemma gen_lossless_gpv_mono: "mono F"
unfolding F_def
apply(rule monoI le_funI le_boolI')+
apply(tactic ‹REPEAT (resolve_tac @{context} (Inductive.get_monos @{context}) 1))
apply(erule le_funE)
apply(erule le_boolD)
done

definition gen_lossless_gpv :: "('a, 'out, 'in) gpv  bool"
where "gen_lossless_gpv = (if co' then gfp else lfp) F"

lemma gen_lossless_gpv_unfold: "gen_lossless_gpv = F gen_lossless_gpv"
by(simp add: gen_lossless_gpv_def gfp_unfold[OF gen_lossless_gpv_mono, symmetric] lfp_unfold[OF gen_lossless_gpv_mono, symmetric])

lemma gen_lossless_gpv_True: "co' = True  gen_lossless_gpv  gfp F"
  and gen_lossless_gpv_False: "co' = False  gen_lossless_gpv  lfp F"
by(simp_all add: gen_lossless_gpv_def)

lemma gen_lossless_gpv_cases [elim?, cases pred]:
  assumes "gen_lossless_gpv gpv"
  obtains (gen_lossless_gpv) p where "gpv = GPV p" "lossless_spmf p" 
    "out c input. IO out c  set_spmf p; input  responses_ℐ  out  gen_lossless_gpv (c input)"
proof -
  from assms show ?thesis
    by(rewrite in asm gen_lossless_gpv_unfold)(auto simp add: F_def intro: that)
qed

(* Destruction rules phrased with the selector the_gpv:
   (1) the top-level spmf is lossless;
   (2) continuations on inputs admitted by ℐ are gen_lossless. *)
lemma gen_lossless_gpvD:
  assumes "gen_lossless_gpv gpv"
  shows gen_lossless_gpv_lossless_spmfD: "lossless_spmf (the_gpv gpv)"
  and gen_lossless_gpv_continuationD:
  "⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ ℐ out ⟧ ⟹ gen_lossless_gpv (c input)"
using assms by(auto elim: gen_lossless_gpv_cases)

(* Introduction rule in constructor form (conclusion about GPV p). *)
lemma gen_lossless_gpv_intros:
  "⟦ lossless_spmf p;
     ⋀out c input. ⟦IO out c ∈ set_spmf p; input ∈ responses_ℐ ℐ out ⟧ ⟹ gen_lossless_gpv (c input) ⟧
  ⟹ gen_lossless_gpv (GPV p)"
by(rewrite gen_lossless_gpv_unfold)(simp add: F_def)

(* Introduction rule in selector form (premises about the_gpv gpv);
   derived from gen_lossless_gpv_intros by case analysis on gpv. *)
lemma gen_lossless_gpvI [intro?]:
  "⟦ lossless_spmf (the_gpv gpv);
     ⋀out c input. ⟦ IO out c ∈ set_spmf (the_gpv gpv); input ∈ responses_ℐ ℐ out ⟧
     ⟹ gen_lossless_gpv (c input) ⟧
  ⟹ gen_lossless_gpv gpv"
by(cases gpv)(auto intro: gen_lossless_gpv_intros)

(* One-step characterisation of gen_lossless_gpv as an equivalence
   (the fixpoint equation with F spelled out). *)
lemma gen_lossless_gpv_simps:
  "gen_lossless_gpv gpv ⟷
   (∃p. gpv = GPV p ∧ lossless_spmf p ∧ (∀out c input.
        IO out c ∈ set_spmf p ⟶ input ∈ responses_ℐ ℐ out ⟶ gen_lossless_gpv (c input)))"
by(rewrite gen_lossless_gpv_unfold)(simp add: F_def)

(* Done x satisfies the generic losslessness predicate, for either choice
   of fixpoint.  The selector-style introduction rule is applied first and
   the resulting subgoals are closed by auto. *)
lemma gen_lossless_gpv_Done [iff]: "gen_lossless_gpv (Done x)"
proof(rule gen_lossless_gpvI)
qed auto

(* Fail never satisfies the generic losslessness predicate; proved from
   the destruction rules gen_lossless_gpvD. *)
lemma gen_lossless_gpv_Fail [iff]: "¬ gen_lossless_gpv Fail"
by(auto dest: gen_lossless_gpvD)

(* Pause out c is gen_lossless iff the continuation c is gen_lossless on
   every input that ℐ admits as a response to out. *)
lemma gen_lossless_gpv_Pause [simp]:
  "gen_lossless_gpv (Pause out c) ⟷ (∀input ∈ responses_ℐ ℐ out. gen_lossless_gpv (c input))"
by(auto dest: gen_lossless_gpvD intro: gen_lossless_gpvI)

(* A lifted spmf is gen_lossless exactly when the spmf itself is lossless. *)
lemma gen_lossless_gpv_lift_spmf [iff]: "gen_lossless_gpv (lift_spmf p) ⟷ lossless_spmf p"
by(auto dest: gen_lossless_gpvD intro: gen_lossless_gpvI)

end

(* Outside the context, gen_lossless_gpv takes the fixpoint flag co and the
   interface ℐ as explicit arguments.  An assertion is gen_lossless iff the
   asserted condition holds. *)
lemma gen_lossless_gpv_assert_gpv [iff]: "gen_lossless_gpv co ℐ (assert_gpv b) ⟷ b"
by(cases b) simp_all

(* Inductive losslessness: the least-fixpoint instance (co = False). *)
abbreviation lossless_gpv :: "('out, 'in) ℐ ⇒ ('a, 'out, 'in) gpv ⇒ bool"
where "lossless_gpv ≡ gen_lossless_gpv False"

(* Coinductive losslessness: the greatest-fixpoint instance (co = True). *)
abbreviation colossless_gpv :: "('out, 'in) ℐ ⇒ ('a, 'out, 'in) gpv ⇒ bool"
where "colossless_gpv ≡ gen_lossless_gpv True"

(* Induction principle for lossless_gpv.  In the step case one may assume,
   for GPV p, that p is lossless and that every continuation on an
   admissible input is lossless and satisfies P.  Derived from the generic
   lfp induction rule def_lfp_induct. *)
lemma lossless_gpv_induct [consumes 1, case_names lossless_gpv, induct pred]:
  assumes *: "lossless_gpv ℐ gpv"
  and step: "⋀p. ⟦ lossless_spmf p;
     ⋀out c input. ⟦IO out c ∈ set_spmf p; input ∈ responses_ℐ ℐ out⟧ ⟹ lossless_gpv ℐ (c input);
     ⋀out c input. ⟦IO out c ∈ set_spmf p; input ∈ responses_ℐ ℐ out⟧ ⟹ P (c input) ⟧
     ⟹ P (GPV p)"
  shows "P gpv"
proof -
  (* pointwise: every lossless gpv satisfies P *)
  have "lossless_gpv ℐ ≤ P"
    by(rule def_lfp_induct[OF gen_lossless_gpv_False gen_lossless_gpv_mono])(auto intro!: le_funI step)
  then show ?thesis using * by auto
qed

(* Coinduction principle for colossless_gpv.  X is the coinduction
   invariant: whenever X holds, the top-level spmf must be lossless and each
   continuation on an admissible input must again satisfy X or already be
   colossless.  Derived from the generic gfp rule def_coinduct. *)
lemma colossless_gpv_coinduct
  [consumes 1, case_names colossless_gpv, case_conclusion colossless_gpv lossless_spmf continuation, coinduct pred]:
  assumes *: "X gpv"
  and step: "⋀gpv. X gpv ⟹ lossless_spmf (the_gpv gpv) ∧ (∀out c input.
       IO out c ∈ set_spmf (the_gpv gpv) ⟶ input ∈ responses_ℐ ℐ out ⟶ X (c input) ∨ colossless_gpv ℐ (c input))"
  shows "colossless_gpv ℐ gpv"
proof -
  (* pointwise: the invariant is contained in colossless_gpv *)
  have "X ≤ colossless_gpv ℐ"
    by(rule def_coinduct[OF gen_lossless_gpv_True gen_lossless_gpv_mono])
      (auto 4 4 intro!: le_funI dest!: step intro: exI[where x="the_gpv _"])
  then show ?thesis using * by auto
qed

(* Introduction and destruction rules specialised to the inductive
   predicate, i.e. the generic rules with co instantiated to False. *)
lemmas lossless_gpvI = gen_lossless_gpvI[where co=False]
lemmas lossless_gpvD = gen_lossless_gpvD[where co=False]
lemmas lossless_gpv_lossless_spmfD = gen_lossless_gpv_lossless_spmfD[where co=False]
lemmas lossless_gpv_continuationD = gen_lossless_gpv_continuationD[where co=False]

(* Introduction and destruction rules specialised to the coinductive
   predicate, i.e. the generic rules with co instantiated to True. *)
lemmas colossless_gpvI = gen_lossless_gpvI[where co=True]
lemmas colossless_gpvD = gen_lossless_gpvD[where co=True]
lemmas colossless_gpv_lossless_spmfD = gen_lossless_gpv_lossless_spmfD[where co=True]
lemmas colossless_gpv_continuationD = gen_lossless_gpv_continuationD[where co=True]

(* Losslessness is preserved by bind: if gpv is gen_lossless and f x is
   gen_lossless for every result x of gpv (w.r.t. ℐ), then gpv ⤜ f is
   gen_lossless.  The proof splits on co: induction for the inductive
   predicate, coinduction for the coinductive one. *)
lemma gen_lossless_bind_gpvI:
  assumes "gen_lossless_gpv co ℐ gpv" "⋀x. x ∈ results_gpv ℐ gpv ⟹ gen_lossless_gpv co ℐ (f x)"
  shows "gen_lossless_gpv co ℐ (gpv ⤜ f)"
proof(cases co)
  case False
  hence eq: "co = False" by simp
  show ?thesis using assms unfolding eq
  proof(induction)
    case (lossless_gpv p)
    (* results produced directly at the top level are covered by the premise on f *)
    { fix x
      assume "Pure x ∈ set_spmf p"
      hence "x ∈ results_gpv ℐ (GPV p)" by simp
      hence "lossless_gpv ℐ (f x)" by(rule lossless_gpv.prems) }
    with ‹lossless_spmf p› show ?case unfolding GPV_bind
      apply(intro gen_lossless_gpv_intros)
       apply(fastforce dest: lossless_gpvD split: generat.split)
      apply(clarsimp; split generat.split_asm)
      apply(auto dest: lossless_gpvD intro!: lossless_gpv)
      done
  qed
next
  case True
  hence eq: "co = True" by simp
  show ?thesis using assms unfolding eq
  proof(coinduction arbitrary: gpv rule: colossless_gpv_coinduct)
    case * [rule_format]: (colossless_gpv gpv)
    from *(1) have ?lossless_spmf
      by(auto 4 3 dest: colossless_gpv_lossless_spmfD elim!: is_PureE intro: *(2)[THEN colossless_gpv_lossless_spmfD] results_gpv.Pure)
    moreover have ?continuation
    proof(intro strip)
      fix out c input
      assume IO: "IO out c ∈ set_spmf (the_gpv (gpv ⤜ f))"
        and input: "input ∈ responses_ℐ ℐ out"
      (* trace the IO step of the bind back to a step of gpv *)
      from IO obtain generat where generat: "generat ∈ set_spmf (the_gpv gpv)"
        and IO: "IO out c ∈ set_spmf (if is_Pure generat then the_gpv (f (result generat))
                 else return_spmf (IO (output generat) (λinput. continuation generat input ⤜ f)))"
        by(auto)
      show "(∃gpv. c input = gpv ⤜ f ∧ colossless_gpv ℐ gpv ∧ (∀x. x ∈ results_gpv ℐ gpv ⟶ colossless_gpv ℐ (f x))) ∨
        colossless_gpv ℐ (c input)"
      proof(cases generat)
        case (Pure x)
        (* the step stems from f x, which is colossless by assumption *)
        hence "x ∈ results_gpv ℐ gpv" using generat by(auto intro: results_gpv.Pure)
        from *(2)[OF this] have "colossless_gpv ℐ (c input)"
          using IO Pure input by(auto intro: colossless_gpv_continuationD)
        thus ?thesis ..
      next
        case **: (IO out' c')
        (* the step stems from gpv itself; stay inside the coinduction invariant *)
        with input generat IO have "colossless_gpv ℐ (f x)" if "x ∈ results_gpv ℐ (c' input)" for x
          using that by(auto intro: * results_gpv.IO)
        then show ?thesis using IO input ** *(1) generat by(auto dest: colossless_gpv_continuationD)
      qed
    qed
    ultimately show ?case ..
  qed
qed

(* gen_lossless_bind_gpvI specialised to the inductive (co = False) and
   coinductive (co = True) predicates. *)
lemmas lossless_bind_gpvI = gen_lossless_bind_gpvI[where co=False]
lemmas colossless_bind_gpvI = gen_lossless_bind_gpvI[where co=True]

(* If a bind is gen_lossless, then so is its first component. *)
lemma gen_lossless_bind_gpvD1:
  assumes "gen_lossless_gpv co ℐ (gpv ⤜ f)"
  shows "gen_lossless_gpv co ℐ gpv"
proof(cases co)
  case False
  hence eq: "co = False" by simp
  show ?thesis using assms unfolding eq
  (* induct over the bind as a whole, generalising over its first component *)
  proof(induction gpv'≡"gpv ⤜ f" arbitrary: gpv)
    case (lossless_gpv p)
    obtain p' where gpv: "gpv = GPV p'" by(cases gpv)
    from lossless_gpv.hyps gpv have "lossless_spmf p'" by(simp add: GPV_bind)
    then show ?case unfolding gpv
    proof(rule gen_lossless_gpv_intros)
      fix out c input
      assume "IO out c ∈ set_spmf p'" "input ∈ responses_ℐ ℐ out"
      hence "IO out (λinput. c input ⤜ f) ∈ set_spmf p" using lossless_gpv.hyps gpv
        by(auto simp add: GPV_bind intro: rev_bexI)
      thus "lossless_gpv ℐ (c input)" using ‹input ∈ _› by(rule lossless_gpv.hyps) simp
    qed
  qed
next
  case True
  hence eq: "co = True" by simp
  show ?thesis using assms unfolding eq
    by(coinduction arbitrary: gpv)(auto 4 3 intro: rev_bexI elim!: colossless_gpv_continuationD dest: colossless_gpv_lossless_spmfD)
qed

(* gen_lossless_bind_gpvD1 specialised to co = False and co = True. *)
lemmas lossless_bind_gpvD1 = gen_lossless_bind_gpvD1[where co=False]
lemmas colossless_bind_gpvD1 = gen_lossless_bind_gpvD1[where co=True]

(* If a bind is gen_lossless, then the second component is gen_lossless on
   every result x of the first.  Proved by induction on the derivation of
   x ∈ results_gpv ℐ gpv (hence the reordered facts assms(2,1)). *)
lemma gen_lossless_bind_gpvD2:
  assumes "gen_lossless_gpv co ℐ (gpv ⤜ f)"
  and "x ∈ results_gpv ℐ gpv"
  shows "gen_lossless_gpv co ℐ (f x)"
using assms(2,1)
proof(induction)
  case (Pure gpv)
  thus ?case
    by -(rule gen_lossless_gpvI, auto 4 4 dest: gen_lossless_gpvD intro: rev_bexI)
qed(auto 4 4 dest: gen_lossless_gpvD intro: rev_bexI)

(* gen_lossless_bind_gpvD2 specialised to co = False and co = True. *)
lemmas lossless_bind_gpvD2 = gen_lossless_bind_gpvD2[where co=False]
lemmas colossless_bind_gpvD2 = gen_lossless_bind_gpvD2[where co=True]

(* Characterisation of losslessness of a bind; combines the introduction
   rule with the two destruction rules proved above. *)
lemma gen_lossless_bind_gpv [simp]:
  "gen_lossless_gpv co ℐ (gpv ⤜ f) ⟷ gen_lossless_gpv co ℐ gpv ∧ (∀x∈results_gpv ℐ gpv. gen_lossless_gpv co ℐ (f x))"
by(blast intro: gen_lossless_bind_gpvI dest: gen_lossless_bind_gpvD1 gen_lossless_bind_gpvD2)

(* gen_lossless_bind_gpv specialised to co = False and co = True. *)
lemmas lossless_bind_gpv = gen_lossless_bind_gpv[where co=False]
lemmas colossless_bind_gpv = gen_lossless_bind_gpv[where co=True]

(* Relational (parametricity/transfer) lemmas for (co)losslessness.  The
   lifting_syntax bundle is included for the ===> notation used below. *)
context includes lifting_syntax begin

(* Losslessness transfers from left to right along rel_gpv'': if gpv and
   gpv' are related, the interfaces ℐ and ℐ' are related by rel_ℐ C R, and
   gpv is lossless w.r.t. ℐ, then gpv' is lossless w.r.t. ℐ'.
   Proof by induction on the losslessness of gpv; the side facts are
   discharged by transfer_prover. *)
lemma rel_gpv''_lossless_gpvD1:
  assumes rel: "rel_gpv'' A C R gpv gpv'"
  and gpv: "lossless_gpv ℐ gpv"
  and [transfer_rule]: "rel_ℐ C R ℐ ℐ'"
  shows "lossless_gpv ℐ' gpv'"
using gpv rel
proof(induction arbitrary: gpv')
  case (lossless_gpv p)
  from lossless_gpv.prems obtain q where q: "gpv' = GPV q"
    and [transfer_rule]: "rel_spmf (rel_generat A C (R ===> rel_gpv'' A C R)) p q"
    by(cases gpv') auto
  show ?case
  proof(rule lossless_gpvI)
    have "lossless_spmf p = lossless_spmf q" by transfer_prover
    with lossless_gpv.hyps(1) q show "lossless_spmf (the_gpv gpv')" by simp

    fix out' c' input'
    assume IO': "IO out' c' ∈ set_spmf (the_gpv gpv')"
      and input': "input' ∈ responses_ℐ ℐ' out'"
    have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf p) (set_spmf q)"
      by transfer_prover
    (* pick the related IO step on the left-hand side *)
    with IO' q obtain out c where IO: "IO out c ∈ set_spmf p"
      and [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
      by(auto dest!: rel_setD2 elim: generat.rel_cases)
    have "rel_set R (responses_ℐ ℐ out) (responses_ℐ ℐ' out')" by transfer_prover
    moreover
    (* and a related admissible input on the left-hand side *)
    from this[THEN rel_setD2, OF input'] obtain input
      where [transfer_rule]: "R input input'" and input: "input ∈ responses_ℐ ℐ out" by blast
    have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
    ultimately show "lossless_gpv ℐ' (c' input')" using input IO by(auto intro: lossless_gpv.IH)
  qed
qed

(* Right-to-left direction, obtained from rel_gpv''_lossless_gpvD1 by
   instantiating it with the converse relations. *)
lemma rel_gpv''_lossless_gpvD2:
  "⟦ rel_gpv'' A C R gpv gpv'; lossless_gpv ℐ' gpv'; rel_ℐ C R ℐ ℐ' ⟧
  ⟹ lossless_gpv ℐ gpv"
using rel_gpv''_lossless_gpvD1[of "A¯¯" "C¯¯" "R¯¯" gpv' gpv ℐ' ℐ]
by(simp add: rel_gpv''_conversep prod.rel_conversep rel_fun_eq_conversep)

(* Special case of rel_gpv''_lossless_gpvD1 for rel_gpv, i.e. with the
   input relation instantiated to equality. *)
lemma rel_gpv_lossless_gpvD1:
  "⟦ rel_gpv A C gpv gpv'; lossless_gpv ℐ gpv; rel_ℐ C (=) ℐ ℐ' ⟧ ⟹ lossless_gpv ℐ' gpv'"
using rel_gpv''_lossless_gpvD1[of A C "(=)" gpv gpv' ℐ ℐ'] by(simp add: rel_gpv_conv_rel_gpv'')

(* Right-to-left direction for rel_gpv, again via the converse relations. *)
lemma rel_gpv_lossless_gpvD2:
  "⟦ rel_gpv A C gpv gpv'; lossless_gpv ℐ' gpv'; rel_ℐ C (=) ℐ ℐ' ⟧
  ⟹ lossless_gpv ℐ gpv"
using rel_gpv_lossless_gpvD1[of "A¯¯" "C¯¯" gpv' gpv ℐ' ℐ]
by(simp add: gpv.rel_conversep prod.rel_conversep rel_fun_eq_conversep)

(* Coinductive counterpart of rel_gpv''_lossless_gpvD1: colosslessness
   transfers from left to right along rel_gpv''.  Proof by coinduction with
   the related left-hand gpv as the witness of the invariant. *)
lemma rel_gpv''_colossless_gpvD1:
  assumes rel: "rel_gpv'' A C R gpv gpv'"
  and gpv: "colossless_gpv ℐ gpv"
  and [transfer_rule]: "rel_ℐ C R ℐ ℐ'"
  shows "colossless_gpv ℐ' gpv'"
using gpv rel
proof(coinduction arbitrary: gpv gpv')
  case (colossless_gpv gpv gpv')
  note [transfer_rule] = ‹rel_gpv'' A C R gpv gpv'› the_gpv_parametric'
    and co = ‹colossless_gpv ℐ gpv›
  have "lossless_spmf (the_gpv gpv) = lossless_spmf (the_gpv gpv')" by transfer_prover
  with co have "?lossless_spmf" by(auto dest: colossless_gpv_lossless_spmfD)
  moreover have "?continuation"
  proof(intro strip disjI1)
    fix out' c' input'
    assume IO': "IO out' c' ∈ set_spmf (the_gpv gpv')"
      and input': "input' ∈ responses_ℐ ℐ' out'"
    have "rel_set (rel_generat A C (R ===> rel_gpv'' A C R)) (set_spmf (the_gpv gpv)) (set_spmf (the_gpv gpv'))"
      by transfer_prover
    (* pick the related IO step on the left-hand side *)
    with IO' obtain out c where IO: "IO out c ∈ set_spmf (the_gpv gpv)"
      and [transfer_rule]: "C out out'" "(R ===> rel_gpv'' A C R) c c'"
      by(auto dest!: rel_setD2 elim: generat.rel_cases)
    have "rel_set R (responses_ℐ ℐ out) (responses_ℐ ℐ' out')" by transfer_prover
    moreover
    (* and a related admissible input on the left-hand side *)
    from this[THEN rel_setD2, OF input'] obtain input
      where [transfer_rule]: "R input input'" and input: "input ∈ responses_ℐ ℐ out" by blast
    have "rel_gpv'' A C R (c input) (c' input')" by transfer_prover
    ultimately show "∃gpv gpv'. c' input' = gpv' ∧ colossless_gpv ℐ gpv ∧ rel_gpv'' A C R gpv gpv'"
      using input IO co by(auto dest: colossless_gpv_continuationD)
  qed
  ultimately show ?case ..
qed

(* Right-to-left direction, obtained from rel_gpv''_colossless_gpvD1 by
   instantiating it with the converse relations. *)
lemma rel_gpv''_colossless_gpvD2:
  "⟦ rel_gpv'' A C R gpv gpv'; colossless_gpv ℐ' gpv'; rel_ℐ C R ℐ ℐ' ⟧
  ⟹ colossless_gpv ℐ gpv"
using rel_gpv''_colossless_gpvD1[of "A¯¯" "C¯¯" "R¯¯" gpv' gpv ℐ' ℐ]
by(simp add: rel_gpv''_conversep prod.rel_conversep rel_fun_eq_conversep)

(* Special case of rel_gpv''_colossless_gpvD1 for rel_gpv, i.e. with the
   input relation instantiated to equality. *)
lemma rel_gpv_colossless_gpvD1:
  "⟦ rel_gpv A C gpv gpv'; colossless_gpv ℐ gpv; rel_ℐ C (=) ℐ ℐ' ⟧ ⟹ colossless_gpv ℐ' gpv'"
using rel_gpv''_colossless_gpvD1[of A C "(=)" gpv gpv' ℐ ℐ'] by(simp add: rel_gpv_conv_rel_gpv'')

(* Right-to-left direction for rel_gpv, again via the converse relations. *)
lemma rel_gpv_colossless_gpvD2:
  "⟦ rel_gpv A C gpv gpv'; colossless_gpv ℐ' gpv'; rel_ℐ C (=) ℐ ℐ' ⟧
  ⟹ colossless_gpv ℐ gpv"
using rel_gpv_colossless_gpvD1[of "A¯¯" "C¯¯" gpv' gpv ℐ' ℐ]
by(simp add: gpv.rel_conversep prod.rel_conversep rel_fun_eq_conversep)

(* Parametricity of gen_lossless_gpv w.r.t. rel_gpv'': for equal fixpoint
   flags, related interfaces and related gpvs, the predicate yields equal
   truth values.  After fixing the flag b, the two directions for each value
   of b come from the D1/D2 lemmas above. *)
lemma gen_lossless_gpv_parametric':
  "((=) ===> rel_ℐ C R ===> rel_gpv'' A C R ===> (=))
   gen_lossless_gpv gen_lossless_gpv"
proof(rule rel_funI; hypsubst)
  show "(rel_ℐ C R ===> rel_gpv'' A C R ===> (=)) (gen_lossless_gpv b) (gen_lossless_gpv b)" for b
    by(cases b)(auto intro!: rel_funI dest: rel_gpv''_colossless_gpvD1 rel_gpv''_colossless_gpvD2 rel_gpv''_lossless_gpvD1 rel_gpv''_lossless_gpvD2)
qed

(* Parametricity of gen_lossless_gpv w.r.t. rel_gpv (inputs related by
   equality); registered as a transfer rule.  Same proof scheme as the
   rel_gpv'' version, using the rel_gpv D1/D2 lemmas. *)
lemma gen_lossless_gpv_parametric [transfer_rule]:
  "((=) ===> rel_ℐ C (=) ===> rel_gpv A C ===> (=))
   gen_lossless_gpv gen_lossless_gpv"
proof(rule rel_funI; hypsubst)
  show "(rel_ℐ C (=) ===> rel_gpv A C ===> (=)) (gen_lossless_gpv b) (gen_lossless_gpv b)" for b
    by(cases b)(auto intro!: rel_funI dest: rel_gpv_colossless_gpvD1 rel_gpv_colossless_gpvD2 rel_gpv_lossless_gpvD1 rel_gpv_lossless_gpvD2)
qed

end

(* With the full interface ℐ_full, mapping over results and outputs does
   not affect (co)losslessness.  Case split on the fixpoint flag b:
   coinduction in both directions for b = True, induction for b = False. *)
lemma gen_lossless_gpv_map_full [simp]:
  "gen_lossless_gpv b ℐ_full (map_gpv f g gpv) = gen_lossless_gpv b ℐ_full gpv"
  (is "?lhs = ?rhs")
proof(cases "b = True")
  case True
  show "?lhs = ?rhs"
  proof
    show ?rhs if ?lhs using that unfolding True
      by(coinduction arbitrary: gpv)(auto 4 3 dest: colossless_gpvD simp add: gpv.map_sel intro!: rev_image_eqI)
    show ?lhs if ?rhs using that unfolding True
      by(coinduction arbitrary: gpv)(auto 4 4 dest: colossless_gpvD simp add: gpv.map_sel intro!: rev_image_eqI)
  qed
next
  case False
  hence False: "b = False" by simp
  show "?lhs = ?rhs"
  proof
    show ?rhs if ?lhs using that unfolding False
      (* induct over the mapped gpv as a whole, generalising over gpv *)
      apply(induction gpv'≡"map_gpv f g gpv" arbitrary: gpv)
      subgoal for p gpv by(cases gpv)(rule lossless_gpvI; fastforce intro: rev_image_eqI)
      done
    show ?lhs if ?rhs using that unfolding False
      by induction(auto 4 4 intro: lossless_gpvI)
  qed
qed

(* For an arbitrary interface ℐ, mapping only the results (outputs are
   mapped with id) does not affect (co)losslessness.  Obtained from the
   parametricity lemma instantiated with graph relations (BNF_Def.Grp). *)
lemma gen_lossless_gpv_map_id [simp]:
  "gen_lossless_gpv b ℐ (map_gpv f id gpv) = gen_lossless_gpv b ℐ gpv"
  using gen_lossless_gpv_parametric[of "BNF_Def.Grp UNIV id" "BNF_Def.Grp UNIV f"] unfolding gpv.rel_Grp
  by(simp add: rel_fun_def eq_alt[symmetric] rel_ℐ_eq)(auto simp add: Grp_def)

(* Results of TRY gpv ELSE gpv' w.r.t. ℐ: the results of gpv, plus the
   results of gpv' unless gpv is colossless w.r.t. ℐ. *)
lemma results_gpv_try_gpv [simp]:
  "results_gpv ℐ (TRY gpv ELSE gpv') =
   results_gpv ℐ gpv ∪ (if colossless_gpv ℐ gpv then {} else results_gpv ℐ gpv')"
  (is "?lhs = ?rhs")
proof(intro set_eqI iffI)
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
  proof(induction gpv''≡"try_gpv gpv gpv'" arbitrary: gpv)
    case Pure thus ?case
      by(auto split: if_split_asm intro: results_gpv.Pure dest: colossless_gpv_lossless_spmfD)
  next
    case (IO out c input)
    then show ?case
      apply(auto dest: colossless_gpv_lossless_spmfD split: if_split_asm)
      apply(force intro: results_gpv.IO dest: colossless_gpv_continuationD split: if_split_asm)+
      done
  qed
next
  fix x
  assume "x ∈ ?rhs"
  then consider (left) "x ∈ results_gpv ℐ gpv" | (right) "¬ colossless_gpv ℐ gpv" "x ∈ results_gpv ℐ gpv'"
    by(auto split: if_split_asm)
  thus "x ∈ ?lhs"
  proof cases
    case left
    thus ?thesis
      by(induction)(auto 4 4 intro: results_gpv.intros rev_image_eqI split del: if_split)
  next
    case right
    from right(1) show ?thesis
    (* contrapositive: if x is not a result of the TRY, then gpv is colossless *)
    proof(rule contrapos_np)
      assume "x ∉ ?lhs"
      with right(2) show "colossless_gpv ℐ gpv"
      proof(coinduction arbitrary: gpv)
        case (colossless_gpv gpv)
        then have ?lossless_spmf
          apply(rewrite in asm try_gpv.code)
          apply(rule ccontr)
          apply(erule results_gpv.cases)
          apply(fastforce simp add: image_Un image_image generat.map_comp o_def)+
          done
        moreover have "?continuation" using colossless_gpv
          by(auto 4 4 split del: if_split simp add: image_Un image_image generat.map_comp o_def intro: rev_image_eqI results_gpv.IO)
        ultimately show ?case ..
      qed
    qed
  qed
qed

(* Interface-free version of results_gpv_try_gpv, obtained by rewriting
   results'_gpv into results_gpv over ℐ_full. *)
lemma results'_gpv_try_gpv [simp]:
  "results'_gpv (TRY gpv ELSE gpv') =
   results'_gpv gpv ∪ (if colossless_gpv ℐ_full gpv then {} else results'_gpv gpv')"
by(simp add: results_gpv_ℐ_full[symmetric])

(* Outputs of TRY gpv ELSE gpv': the outputs of gpv, plus the outputs of
   gpv' unless gpv is colossless w.r.t. ℐ_full.  Same proof structure as
   results_gpv_try_gpv. *)
lemma outs'_gpv_try_gpv [simp]:
  "outs'_gpv (TRY gpv ELSE gpv') =
   outs'_gpv gpv ∪ (if colossless_gpv ℐ_full gpv then {} else outs'_gpv gpv')"
  (is "?lhs = ?rhs")
proof(intro set_eqI iffI)
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
  proof(induction gpv''≡"try_gpv gpv gpv'" arbitrary: gpv)
    case Out thus ?case
      by(auto 4 3 simp add: generat.map_comp o_def elim!: generat.set_cases(2) intro: outs'_gpv_Out split: if_split_asm dest: colossless_gpv_lossless_spmfD)
  next
    case (Cont generat c input)
    then show ?case
      apply(auto dest: colossless_gpv_lossless_spmfD split: if_split_asm elim!: generat.set_cases(3))
      apply(auto 4 3 dest: colossless_gpv_continuationD split: if_split_asm intro: outs'_gpv_Cont elim!: meta_allE meta_impE[OF _ refl])+
      done
  qed
next
  fix x
  assume "x ∈ ?rhs"
  then consider (left) "x ∈ outs'_gpv gpv" | (right) "¬ colossless_gpv ℐ_full gpv" "x ∈ outs'_gpv gpv'"
    by(auto split: if_split_asm)
  thus "x ∈ ?lhs"
  proof cases
    case left
    thus ?thesis
      by(induction)(auto elim!: generat.set_cases(2,3) intro: outs'_gpvI intro!: rev_image_eqI split del: if_split simp add: image_Un image_image generat.map_comp o_def)
  next
    case right
    from right(1) show ?thesis
    (* contrapositive: if x is not an output of the TRY, then gpv is colossless *)
    proof(rule contrapos_np)
      assume "x ∉ ?lhs"
      with right(2) show "colossless_gpv ℐ_full gpv"
      proof(coinduction arbitrary: gpv)
        case (colossless_gpv gpv)
        then have ?lossless_spmf
          apply(rewrite in asm try_gpv.code)
          apply(erule contrapos_np)
          apply(erule gpv.set_cases)
          apply(auto 4 3 simp add: image_Un image_image generat.map_comp o_def generat.set_map in_set_spmf[symmetric] bind_UNION generat.map_id[unfolded id_def] elim!: generat.set_cases)
          done
        moreover have "?continuation" using colossless_gpv
          by(auto simp add: image_Un image_image generat.map_comp o_def split del: if_split intro!: rev_image_eqI intro: outs'_gpv_Cont)
        ultimately show ?case ..
      qed
    qed
  qed
qed

(* pred_gpv on a TRY: the predicate must hold on gpv, and on gpv' unless
   gpv is colossless w.r.t. ℐ_full. *)
lemma pred_gpv_try [simp]:
  "pred_gpv P Q (try_gpv gpv gpv') = (pred_gpv P Q gpv ∧ (¬ colossless_gpv ℐ_full gpv ⟶ pred_gpv P Q gpv'))"
by(auto simp add: pred_gpv_def)

(* Induction principle for gpvs that are both lossless and well-typed
   w.r.t. ℐ.  Compared to lossless_gpv_induct, the step case additionally
   provides that every output lies in outs_ℐ ℐ and that continuations are
   well-typed; the hypotheses about continuations only require the input to
   be admissible under the assumption that the output is in outs_ℐ ℐ. *)
lemma lossless_WT_gpv_induct [consumes 2, case_names lossless_gpv]:
  assumes lossless: "lossless_gpv ℐ gpv"
  and WT: "ℐ ⊢g gpv √"
  and step: "⋀p. ⟦
      lossless_spmf p;
      ⋀out c. IO out c ∈ set_spmf p ⟹ out ∈ outs_ℐ ℐ;
      ⋀out c input. ⟦IO out c ∈ set_spmf p; out ∈ outs_ℐ ℐ ⟹ input ∈ responses_ℐ ℐ out⟧ ⟹ lossless_gpv ℐ (c input);
      ⋀out c input. ⟦IO out c ∈ set_spmf p; out ∈ outs_ℐ ℐ ⟹ input ∈ responses_ℐ ℐ out⟧ ⟹ ℐ ⊢g c input √;
      ⋀out c input. ⟦IO out c ∈ set_spmf p; out ∈ outs_ℐ ℐ ⟹ input ∈ responses_ℐ ℐ out⟧ ⟹ P (c input)⟧
      ⟹ P (GPV p)"
  shows "P gpv"
using lossless WT
apply(induction)
apply(erule step)
apply(auto elim: WT_gpvD simp add: WT_gpv_simps)
done

lemma lossless_gpv_induct_strong [consumes 1, case_names lossless_gpv]:
  assumes gpv: "lossless_gpv  gpv"
  and step:
  "p.  lossless_spmf p;
          gpv. gpv  sub_gpvs  (GPV p)  lossless_gpv  gpv;
          gpv. gpv  sub_gpvs  (GPV p)  P gpv 
        P (GPV p)"
  shows "P gpv"
proof -
  define gpv' where "gpv' = gpv"
  then have "gpv'  insert gpv (sub_gpvs  gpv)" by simp
  with gpv have "lossless_gpv  gpv'  P gpv'"
  proof(induction arbitrary: gpv')
    case (lossless_gpv p)
    from