(* Theory Computational_Model *)

(* Title: Computational_Model.thy
  Author: Andreas Lochbihler, ETH Zurich *)

section ‹Oracle combinators›

theory Computational_Model imports 
  Generative_Probabilistic_Value
begin

(* Security parameters are natural numbers. *)
type_synonym security = nat
(* An advantage assigns a real number to every security parameter. *)
type_synonym advantage = "security ⇒ real"

(* A stateful oracle: given the current state and a call, it yields an spmf
   over pairs of a return value and a successor state. *)
type_synonym ('σ, 'call, 'ret) oracle' = "'σ ⇒ 'call ⇒ ('ret × 'σ) spmf"
(* A family of oracles indexed by the security parameter, each paired with
   a state value of the same state type. *)
type_synonym ('σ, 'call, 'ret) "oracle" = "security ⇒ ('σ, 'call, 'ret) oracle' × 'σ"

print_translation ― ‹pretty printing for @{typ "('σ, 'call, 'ret) oracle"}› ‹
  let
    (* Matches the expanded shape  nat => (s1 => call => (ret * s2) option pmf) * s3
       and folds it back into the oracle type syntax, but only when the three
       state positions s1, s2, s3 are the same term; otherwise Match is raised
       and the type is printed unchanged. *)
    fun tr' [Const (@{type_syntax nat}, _), 
      Const (@{type_syntax prod}, _) $ 
        (Const (@{type_syntax fun}, _) $ s1 $ 
          (Const (@{type_syntax fun}, _) $ call $
            (Const (@{type_syntax pmf}, _) $
              (Const (@{type_syntax option}, _) $
                (Const (@{type_syntax prod}, _) $ ret $ s2))))) $
        s3] =
      if s1 = s2 andalso s1 = s3 then Syntax.const @{type_syntax oracle} $ s1 $ call $ ret
      else raise Match;
  in [(@{type_syntax "fun"}, K tr')]
  end
›
(* Prints the type once so the effect of the translation can be inspected. *)
typ "('σ, 'call, 'ret) oracle"

subsection ‹Shared state›

context includes ℐ.lifting lifting_syntax begin

(* Sum of two interfaces: a call tagged Inl is answered with the Inl-image of
   the first interface's responses, a call tagged Inr with the Inr-image of
   the second interface's responses. *)
lift_definition plus_ℐ :: "('out, 'ret) ℐ ⇒ ('out', 'ret') ℐ ⇒ ('out + 'out', 'ret + 'ret') ℐ" (infix "⊕⇩ℐ" 500)
is "λresp1 resp2. λout. case out of Inl out' ⇒ Inl ` resp1 out' | Inr out' ⇒ Inr ` resp2 out'" .

(* Selector equations for plus_ℐ: the outputs are the disjoint sum of the
   component outputs and the responses are the tagged component responses. *)
lemma plus_ℐ_sel [simp]:
  shows outs_plus_ℐ: "outs_ℐ (plus_ℐ ℐl ℐr) = outs_ℐ ℐl <+> outs_ℐ ℐr"
  and responses_plus_ℐ_Inl: "responses_ℐ (plus_ℐ ℐl ℐr) (Inl x) = Inl ` responses_ℐ ℐl x"
  and responses_plus_ℐ_Inr: "responses_ℐ (plus_ℐ ℐl ℐr) (Inr y) = Inr ` responses_ℐ ℐr y"
by(transfer; auto split: sum.split_asm; fail)+

(* Preimages of a disjoint sum of sets under the two injections. *)
lemma vimage_Inl_Plus [simp]: "Inl -` (A <+> B) = A" 
  and vimage_Inr_Plus [simp]: "Inr -` (A <+> B) = B"
by auto

(* Preimage under one injection of an image under the other is empty. *)
lemma vimage_Inl_image_Inr: "Inl -` Inr ` A = {}"
  and vimage_Inr_image_Inl: "Inr -` Inl ` A = {}"
by auto

(* Parametricity of plus_ℐ with respect to rel_ℐ and rel_sum. *)
lemma plus_ℐ_parametric [transfer_rule]:
  "(rel_ℐ C R ===> rel_ℐ C' R' ===> rel_ℐ (rel_sum C C') (rel_sum R R')) plus_ℐ plus_ℐ"
apply(rule rel_funI rel_ℐI)+
subgoal premises [transfer_rule] by(simp; rule conjI; transfer_prover)
apply(erule rel_sum.cases; clarsimp simp add: inj_vimage_image_eq vimage_Inl_image_Inr empty_transfer vimage_Inr_image_Inl)
subgoal premises [transfer_rule] by transfer_prover
subgoal premises [transfer_rule] by transfer_prover
done

lifting_update ℐ.lifting
lifting_forget ℐ.lifting

(* The sum of two interfaces is trivial iff both components are. *)
lemma ℐ_trivial_plus_ℐ [simp]: "ℐ_trivial (ℐ⇩1 ⊕⇩ℐ ℐ⇩2) ⟷ ℐ_trivial ℐ⇩1 ∧ ℐ_trivial ℐ⇩2"
by(auto simp add: ℐ_trivial_def)

end

(* map_ℐ with componentwise map_sum functions distributes over the sum of
   interfaces. *)
lemma map_ℐ_plus_ℐ [simp]: 
  "map_ℐ (map_sum f1 f2) (map_sum g1 g2) (ℐ1 ⊕⇩ℐ ℐ2) = map_ℐ f1 g1 ℐ1 ⊕⇩ℐ map_ℐ f2 g2 ℐ2"
proof(rule ℐ_eqI[OF Set.set_eqI], goal_cases)
  case (1 x)
  then show ?case by(cases x) auto
qed (auto simp add: image_image)

(* The order on sums of interfaces is componentwise. *)
lemma le_plus_ℐ_iff [simp]:
  "ℐ1 ⊕⇩ℐ ℐ2 ≤ ℐ1' ⊕⇩ℐ ℐ2' ⟷ ℐ1 ≤ ℐ1' ∧ ℐ2 ≤ ℐ2'"
  by(auto 4 4 simp add: le_ℐ_def dest: bspec[where x="Inl _"] bspec[where x="Inr _"])

(* If ℐ_full is below both components, it is below their sum. *)
lemma ℐ_full_le_plus_ℐ: "ℐ_full ≤ plus_ℐ ℐ1 ℐ2" if "ℐ_full ≤ ℐ1" "ℐ_full ≤ ℐ2"
  using that by(auto simp add: le_ℐ_def top_unique)

(* plus_ℐ is monotone in both arguments. *)
lemma plus_ℐ_mono: "plus_ℐ ℐ1 ℐ2 ≤ plus_ℐ ℐ1' ℐ2'" if "ℐ1 ≤ ℐ1'" "ℐ2 ≤ ℐ2'" 
  using that by(fastforce simp add: le_ℐ_def)

context
  fixes left :: "('s, 'a, 'b) oracle'"
  and right :: "('s,'c, 'd) oracle'"
  and s :: "'s"
begin

(* Combines two oracles over the same state: an Inl call goes to left, an Inr
   call goes to right; the return value is tagged with the same injection and
   the resulting state is passed through unchanged. *)
primrec plus_oracle :: "'a + 'c ⇒ (('b + 'd) × 's) spmf"
where
  "plus_oracle (Inl a) = map_spmf (apfst Inl) (left s a)"
| "plus_oracle (Inr b) = map_spmf (apfst Inr) (right s b)"

(* plus_oracle is lossless on a call whenever the component oracle selected
   by that call is lossless. *)
lemma lossless_plus_oracleI [intro, simp]:
  "⟦ ⋀a. x = Inl a ⟹ lossless_spmf (left s a); 
     ⋀b. x = Inr b ⟹ lossless_spmf (right s b) ⟧
  ⟹ lossless_spmf (plus_oracle x)"
by(cases x) simp_all

(* Split rule for plus_oracle, for use on conclusions. *)
lemma plus_oracle_split:
  "P (plus_oracle lr) ⟷
  (∀x. lr = Inl x ⟶ P (map_spmf (apfst Inl) (left s x))) ∧
  (∀y. lr = Inr y ⟶ P (map_spmf (apfst Inr) (right s y)))"
by(cases lr) auto

(* Split rule for plus_oracle, for use on assumptions. *)
lemma plus_oracle_split_asm:
  "P (plus_oracle lr) ⟷
  ¬ ((∃x. lr = Inl x ∧ ¬ P (map_spmf (apfst Inl) (left s x))) ∨
     (∃y. lr = Inr y ∧ ¬ P (map_spmf (apfst Inr) (right s y))))"
by(cases lr) auto

end

(* Infix notation for plus_oracle. *)
notation plus_oracle (infix "⊕⇩O" 500)

context
  fixes left :: "('s, 'a, 'b) oracle'"
  and right :: "('s,'c, 'd) oracle'"
begin

(* Well-typedness of both components gives well-typedness of the combined
   oracle with respect to the sum interface. *)
lemma WT_plus_oracleI [intro!]:
  "⟦ ℐl ⊢c left s √; ℐr ⊢c right s √ ⟧ ⟹ ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s √"
by(rule WT_calleeI)(auto elim!: WT_calleeD simp add: inj_image_mem_iff)

(* Projection of well-typedness to the left component. *)
lemma WT_plus_oracleD1:
  assumes "ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s  √" (is "?ℐ ⊢c ?callee s √")
  shows "ℐl ⊢c left s √"
proof(rule WT_calleeI)
  fix call ret s'
  assume "call ∈ outs_ℐ ℐl" "(ret, s') ∈ set_spmf (left s call)"
  (* Lift the call and its result into the sum via Inl. *)
  hence "(Inl ret, s') ∈ set_spmf (?callee s (Inl call))" "Inl call ∈ outs_ℐ (ℐl ⊕⇩ℐ ℐr)"
    by(auto intro: rev_image_eqI)
  hence "Inl ret ∈ responses_ℐ ?ℐ (Inl call)" by(rule WT_calleeD[OF assms])
  then show "ret ∈ responses_ℐ ℐl call" by(simp add: inj_image_mem_iff)
qed

(* Projection of well-typedness to the right component. *)
lemma WT_plus_oracleD2:
  assumes "ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s  √" (is "?ℐ ⊢c ?callee s √")
  shows "ℐr ⊢c right s √"
proof(rule WT_calleeI)
  fix call ret s'
  assume "call ∈ outs_ℐ ℐr" "(ret, s') ∈ set_spmf (right s call)"
  (* Lift the call and its result into the sum via Inr. *)
  hence "(Inr ret, s') ∈ set_spmf (?callee s (Inr call))" "Inr call ∈ outs_ℐ (ℐl ⊕⇩ℐ ℐr)"
    by(auto intro: rev_image_eqI)
  hence "Inr ret ∈ responses_ℐ ?ℐ (Inr call)" by(rule WT_calleeD[OF assms])
  then show "ret ∈ responses_ℐ ℐr call" by(simp add: inj_image_mem_iff)
qed

(* Well-typedness of the combined oracle is equivalent to well-typedness of
   both components. *)
lemma WT_plus_oracle_iff [simp]: "ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s √ ⟷ ℐl ⊢c left s √ ∧ ℐr ⊢c right s √"
by(blast dest: WT_plus_oracleD1 WT_plus_oracleD2)

(* I is an invariant of the combined oracle on the sum interface iff it is an
   invariant of each component on its own interface. *)
lemma callee_invariant_on_plus_oracle [simp]:
  "callee_invariant_on (left ⊕⇩O right) I (ℐl ⊕⇩ℐ ℐr) ⟷
   callee_invariant_on left I ℐl ∧ callee_invariant_on right I ℐr"
   (is "?lhs ⟷ ?rhs")
proof(intro iffI conjI)
  assume ?lhs
  then interpret plus: callee_invariant_on "left ⊕⇩O right" I "ℐl ⊕⇩ℐ ℐr" .
  show "callee_invariant_on left I ℐl"
  proof
    fix s x y s'
    assume "(y, s') ∈ set_spmf (left s x)" and "I s" and "x ∈ outs_ℐ ℐl"
    then have "(Inl y, s') ∈ set_spmf ((left ⊕⇩O right) s (Inl x))"
      by(auto intro: rev_image_eqI)
    then show "I s'" using ‹I s› by(rule plus.callee_invariant)(simp add: ‹x ∈ outs_ℐ ℐl›)
  next
    show "ℐl ⊢c left s √" if "I s" for s using plus.WT_callee[OF that] by simp
  qed
  show "callee_invariant_on right I ℐr"
  proof
    fix s x y s'
    assume "(y, s') ∈ set_spmf (right s x)" and "I s" and "x ∈ outs_ℐ ℐr"
    then have "(Inr y, s') ∈ set_spmf ((left ⊕⇩O right) s (Inr x))"
      by(auto intro: rev_image_eqI)
    then show "I s'" using ‹I s› by(rule plus.callee_invariant)(simp add: ‹x ∈ outs_ℐ ℐr›)
  next
    show "ℐr ⊢c right s √" if "I s" for s using plus.WT_callee[OF that] by simp
  qed
next
  assume ?rhs
  interpret left: callee_invariant_on left I ℐl using ‹?rhs› by simp
  interpret right: callee_invariant_on right I ℐr using ‹?rhs› by simp
  show ?lhs
  proof
    fix s x y s'
    assume "(y, s') ∈ set_spmf ((left ⊕⇩O right) s x)" and "I s" and "x ∈ outs_ℐ (ℐl ⊕⇩ℐ ℐr)"
    (* Either the call went to left or to right; in both cases the projected
       call and result belong to the corresponding component. *)
    then have "(projl y, s') ∈ set_spmf (left s (projl x)) ∧ projl x ∈ outs_ℐ ℐl ∨
      (projr y, s') ∈ set_spmf (right s (projr x)) ∧ projr x ∈ outs_ℐ ℐr"
      by (cases x)  auto
    then show "I s'" using ‹I s› 
      by (auto dest: left.callee_invariant right.callee_invariant)
  next
    show "ℐl ⊕⇩ℐ ℐr ⊢c (left ⊕⇩O right) s √" if "I s" for s 
      using left.WT_callee[OF that] right.WT_callee[OF that] by simp
  qed
qed

(* Same equivalence for callee_invariant, obtained from the interface-relative
   version by instantiating both interfaces with ℐ_full. *)
lemma callee_invariant_plus_oracle [simp]:
  "callee_invariant (left ⊕⇩O right) I ⟷
   callee_invariant left I ∧ callee_invariant right I"
  (is "?lhs ⟷ ?rhs")
proof -
  have "?lhs ⟷ callee_invariant_on (left ⊕⇩O right) I (ℐ_full ⊕⇩ℐ ℐ_full)"
    by(rule callee_invariant_on_cong)(auto split: plus_oracle_split_asm)
  also have "… ⟷ ?rhs" by(rule callee_invariant_on_plus_oracle)
  finally show ?thesis .
qed

(* Parametricity of plus_oracle. *)
lemma plus_oracle_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> A ===> rel_spmf (rel_prod B S))
   ===> (S ===> C ===> rel_spmf (rel_prod D S))
   ===> S ===> rel_sum A C ===> rel_spmf (rel_prod (rel_sum B D) S))
   plus_oracle plus_oracle"
unfolding plus_oracle_def[abs_def] by transfer_prover

(* Pointwise relational rule: the component oracles only need to be related
   on the calls actually selected by q1 and q2. *)
lemma rel_spmf_plus_oracle:
  "⟦ ⋀q1' q2'. ⟦ q1 = Inl q1'; q2 = Inl q2' ⟧ ⟹ rel_spmf (rel_prod B S) (left1 s1 q1') (left2 s2 q2');
    ⋀q1' q2'. ⟦ q1 = Inr q1'; q2 = Inr q2' ⟧ ⟹ rel_spmf (rel_prod D S) (right1 s1 q1') (right2 s2 q2');
    S s1 s2; rel_sum A C q1 q2 ⟧
  ⟹ rel_spmf (rel_prod (rel_sum B D) S) ((left1 ⊕⇩O right1) s1 q1) ((left2 ⊕⇩O right2) s2 q2)"
apply(erule rel_sum.cases; clarsimp)
 apply(erule meta_allE)+
 apply(erule meta_impE, rule refl)+
 subgoal premises [transfer_rule] by transfer_prover
apply(erule meta_allE)+
apply(erule meta_impE, rule refl)+
subgoal premises [transfer_rule] by transfer_prover
done

end

subsection ‹Shared state with aborts›

context
  fixes left :: "('s, 'a, 'b option) oracle'"
  and right :: "('s,'c, 'd option) oracle'"
  and s :: "'s"
begin

(* Variant of plus_oracle for oracles whose return value is optional: the
   injection is applied underneath the option via map_option. *)
primrec plus_oracle_stop :: "'a + 'c ⇒ (('b + 'd) option × 's) spmf"
where
  "plus_oracle_stop (Inl a) = map_spmf (apfst (map_option Inl)) (left s a)"
| "plus_oracle_stop (Inr b) = map_spmf (apfst (map_option Inr)) (right s b)"

(* plus_oracle_stop is lossless on a call whenever the component oracle
   selected by that call is lossless. *)
lemma lossless_plus_oracle_stopI [intro, simp]:
  "⟦ ⋀a. x = Inl a ⟹ lossless_spmf (left s a); 
     ⋀b. x = Inr b ⟹ lossless_spmf (right s b) ⟧
  ⟹ lossless_spmf (plus_oracle_stop x)"
by(cases x) simp_all

(* Split rule for plus_oracle_stop, for use on conclusions. *)
lemma plus_oracle_stop_split:
  "P (plus_oracle_stop lr) ⟷
  (∀x. lr = Inl x ⟶ P (map_spmf (apfst (map_option Inl)) (left s x))) ∧
  (∀y. lr = Inr y ⟶ P (map_spmf (apfst (map_option Inr)) (right s y)))"
by(cases lr) auto

(* Split rule for plus_oracle_stop, for use on assumptions. *)
lemma plus_oracle_stop_split_asm:
  "P (plus_oracle_stop lr) ⟷
  ¬ ((∃x. lr = Inl x ∧ ¬ P (map_spmf (apfst (map_option Inl)) (left s x))) ∨
     (∃y. lr = Inr y ∧ ¬ P (map_spmf (apfst (map_option Inr)) (right s y))))"
by(cases lr) auto

end

(* Infix notation for plus_oracle_stop. *)
notation plus_oracle_stop (infix "⊕⇩O⇧S" 500)

subsection ‹Disjoint state›

context
  fixes left :: "('s1, 'a, 'b) oracle'"
  and right :: "('s2, 'c, 'd) oracle'"
begin

(* Combines two oracles with separate states: the combined state is a pair;
   an Inl call runs left on the first component and keeps the second, an Inr
   call runs right on the second component and keeps the first. *)
fun parallel_oracle :: "('s1 × 's2, 'a + 'c, 'b + 'd) oracle'"
where
  "parallel_oracle (s1, s2) (Inl a) = map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 a)"
| "parallel_oracle (s1, s2) (Inr b) = map_spmf (map_prod Inr (Pair s1)) (right s2 b)"

(* Closed-form characterisation of parallel_oracle; its name shadows the
   definitional fact so that it can be unfolded with [abs_def]. *)
lemma parallel_oracle_def:
  "parallel_oracle = (λ(s1, s2). case_sum (λa. map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 a)) (λb. map_spmf (map_prod Inr (Pair s1)) (right s2 b)))"
by(auto intro!: ext split: sum.split)

(* Losslessness of parallel_oracle reduces to losslessness of the component
   selected by the call, run on its own part of the state. *)
lemma lossless_parallel_oracle [simp]:
  "lossless_spmf (parallel_oracle s12 xy) ⟷
   (∀x. xy = Inl x ⟶ lossless_spmf (left (fst s12) x)) ∧
   (∀y. xy = Inr y ⟶ lossless_spmf (right (snd s12) y))"
by(cases s12; cases xy) simp_all

(* Split rule for parallel_oracle, for use on conclusions. *)
lemma parallel_oracle_split:
  "P (parallel_oracle s1s2 lr) ⟷
  (∀s1 s2 x. s1s2 = (s1, s2) ⟶ lr = Inl x ⟶ P (map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 x))) ∧
  (∀s1 s2 y. s1s2 = (s1, s2) ⟶ lr = Inr y ⟶ P (map_spmf (map_prod Inr (Pair s1)) (right s2 y)))"
by(cases s1s2; cases lr) auto

(* Split rule for parallel_oracle, for use on assumptions. *)
lemma parallel_oracle_split_asm:
  "P (parallel_oracle s1s2 lr) ⟷
  ¬ ((∃s1 s2 x. s1s2 = (s1, s2) ∧ lr = Inl x ∧ ¬ P (map_spmf (map_prod Inl (λs1'. (s1', s2))) (left s1 x))) ∨
     (∃s1 s2 y. s1s2 = (s1, s2) ∧ lr = Inr y ∧ ¬ P (map_spmf (map_prod Inr (Pair s1)) (right s2 y))))"
by(cases s1s2; cases lr) auto

(* Well-typedness of both components gives well-typedness of parallel_oracle
   with respect to the sum interface. *)
lemma WT_parallel_oracle [intro!, simp]:
  "⟦ ℐl ⊢c left sl √; ℐr ⊢c right sr √ ⟧ ⟹ plus_ℐ ℐl ℐr ⊢c parallel_oracle (sl, sr) √"
by(rule WT_calleeI)(auto elim!: WT_calleeD simp add: inj_image_mem_iff)

(* Component invariants Il and Ir combine to the invariant pred_prod Il Ir of
   parallel_oracle on the sum interface. *)
lemma callee_invariant_parallel_oracleI [simp, intro]:
  assumes "callee_invariant_on left Il ℐl" "callee_invariant_on right Ir ℐr"
  shows "callee_invariant_on parallel_oracle (pred_prod Il Ir) (ℐl ⊕⇩ℐ ℐr)"
proof
  interpret left: callee_invariant_on left Il ℐl by fact
  interpret right: callee_invariant_on right Ir ℐr by fact

  (* NOTE(review): s12 is case-split twice below; the second split looks
     redundant - confirm whether s12' was meant. Left as in the original. *)
  show "pred_prod Il Ir s12'"
    if "(y, s12') ∈ set_spmf (parallel_oracle s12 x)" and "pred_prod Il Ir s12" and "x ∈ outs_ℐ (ℐl ⊕⇩ℐ ℐr)"
    for s12 x y s12' using that
    by(cases s12; cases s12; cases x)(auto dest: left.callee_invariant right.callee_invariant)

  show "ℐl ⊕⇩ℐ ℐr ⊢c local.parallel_oracle s √" if "pred_prod Il Ir s" for s using that
    by(cases s)(simp add: left.WT_callee right.WT_callee)
qed

end

(* Parametricity of parallel_oracle in the state and call relations. Return
   values are related by equality here, and the lemma is not declared as a
   transfer_rule. The proof unfolds the closed form parallel_oracle_def and
   folds relator_eq before calling transfer_prover. *)
lemma parallel_oracle_parametric:
  includes lifting_syntax shows
  "((S1 ===> CALL1 ===> rel_spmf (rel_prod (=) S1)) 
  ===> (S2 ===> CALL2 ===> rel_spmf (rel_prod (=) S2))
  ===> rel_prod S1 S2 ===> rel_sum CALL1 CALL2 ===> rel_spmf (rel_prod (=) (rel_prod S1 S2)))
  parallel_oracle parallel_oracle"
unfolding parallel_oracle_def[abs_def] by (fold relator_eq)transfer_prover

subsection ‹Indexed oracles›

(* Indexed family of oracles: the state is a function from indices to
   component states; a call (i, x) runs oracle f i on the i-th component
   state and writes the new component state back at index i. *)
definition family_oracle :: "('i ⇒ ('s, 'a, 'b) oracle') ⇒ ('i ⇒ 's, 'i × 'a, 'b) oracle'"
where "family_oracle f s = (λ(i, x). map_spmf (λ(y, s'). (y, s(i := s'))) (f i (s i) x))"

(* Application of family_oracle to an explicit pair, phrased with apsnd and
   fun_upd. *)
lemma family_oracle_apply [simp]:
  "family_oracle f s (i, x) = map_spmf (apsnd (fun_upd s i)) (f i (s i) x)"
by(simp add: family_oracle_def apsnd_def map_prod_def)

(* Losslessness of family_oracle reduces to that of the selected member. *)
lemma lossless_family_oracle:
  "lossless_spmf (family_oracle f s ix) ⟷ lossless_spmf (f (fst ix) (s (fst ix)) (snd ix))"
by(simp add: family_oracle_def split_beta)

subsection ‹State extension›

(* Extends the state of a callee with an additional first component s' that
   the callee neither reads nor changes. *)
definition extend_state_oracle :: "('call, 'ret, 's) callee ⇒ ('call, 'ret, 's' × 's) callee" ("†_" [1000] 1000)
where "extend_state_oracle callee = (λ(s', s) x. map_spmf (λ(y, s). (y, (s', s))) (callee s x))"

(* Application of extend_state_oracle to an explicit state pair. *)
lemma extend_state_oracle_simps [simp]:
  "extend_state_oracle callee (s', s) x = map_spmf (λ(y, s). (y, (s', s))) (callee s x)"
by(simp add: extend_state_oracle_def)

context includes lifting_syntax begin
(* Parametricity of extend_state_oracle; the added state component is related
   by an arbitrary relation S'. *)
lemma extend_state_oracle_parametric [transfer_rule]:
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> rel_prod S' S ===> C ===> rel_spmf (rel_prod R (rel_prod S' S)))
  extend_state_oracle extend_state_oracle"
unfolding extend_state_oracle_def[abs_def] by transfer_prover

(* Relates an oracle (left side: the identity on oracles) to its
   state-extended version, with states related by rel_prod2 S.
   NOTE(review): rel_prod2 is defined outside this file; presumably it relates
   a value to the second component of a pair - confirm. *)
lemma extend_state_oracle_transfer:
  "((S ===> C ===> rel_spmf (rel_prod R S)) 
  ===> rel_prod2 S ===> C ===> rel_spmf (rel_prod R (rel_prod2 S)))
  (λoracle. oracle) extend_state_oracle"
unfolding extend_state_oracle_def[abs_def]
apply(rule rel_funI)+
apply clarsimp
apply(drule (1) rel_funD)+
apply(auto simp add: spmf_rel_map split_def dest: rel_funD intro: rel_spmf_mono)
done
end

(* Any predicate on the added first state component is an invariant of a
   state-extended oracle, because that component is passed through unchanged. *)
lemma callee_invariant_extend_state_oracle_const [simp]:
  "callee_invariant †oracle (λ(s', s). I s')"
by unfold_locales auto

(* The same statement with the invariant phrased via fst instead of a tuple
   pattern. *)
lemma callee_invariant_extend_state_oracle_const':
  "callee_invariant †oracle (λs. I (fst s))"
by unfold_locales auto

(* Turns a callee into one with optional return values by wrapping every
   return value in Some. *)
definition lift_stop_oracle :: "('call, 'ret, 's) callee ⇒ ('call, 'ret option, 's) callee"
where "lift_stop_oracle oracle s x = map_spmf (apfst Some) (oracle s x)"

(* The defining equation as a simp rule. *)
lemma lift_stop_oracle_apply [simp]: "lift_stop_oracle oracle s x = map_spmf (apfst Some) (oracle s x)"
  by(fact lift_stop_oracle_def)
  
context includes lifting_syntax begin

(* Relates an oracle (left side: the identity) to its lift_stop_oracle
   version, with return values related by pcr_Some R.
   NOTE(review): pcr_Some is defined outside this file - confirm its meaning
   against its definition. *)
lemma lift_stop_oracle_transfer:
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> (S ===> C ===> rel_spmf (rel_prod (pcr_Some R) S)))
   (λx. x) lift_stop_oracle"
unfolding lift_stop_oracle_def
apply(rule rel_funI)+
apply(drule (1) rel_funD)+
apply(simp add: spmf_rel_map apfst_def prod.rel_map)
done

end

(* Mirror image of extend_state_oracle: extends the state of a callee with an
   additional second component s' that the callee neither reads nor changes. *)
definition extend_state_oracle2 :: "('call, 'ret, 's) callee ⇒ ('call, 'ret, 's × 's') callee" ("_‡" [1000] 1000)
  where "extend_state_oracle2 callee = (λ(s, s') x. map_spmf (λ(y, s). (y, (s, s'))) (callee s x))"

(* Application of extend_state_oracle2 to an explicit state pair. *)
lemma extend_state_oracle2_simps [simp]:
  "extend_state_oracle2 callee (s, s') x = map_spmf (λ(y, s). (y, (s, s'))) (callee s x)"
  by(simp add: extend_state_oracle2_def)

(* Parametricity of extend_state_oracle2; the added state component is
   related by an arbitrary relation S'. *)
lemma extend_state_oracle2_parametric [transfer_rule]: includes lifting_syntax shows
  "((S ===> C ===> rel_spmf (rel_prod R S)) ===> rel_prod S S' ===> C ===> rel_spmf (rel_prod R (rel_prod S S')))
  extend_state_oracle2 extend_state_oracle2"
  unfolding extend_state_oracle2_def[abs_def] by transfer_prover

(* Any predicate on the added second state component is an invariant of an
   oracle extended with extend_state_oracle2, because that component is passed
   through unchanged. *)
lemma callee_invariant_extend_state_oracle2_const [simp]:
  "callee_invariant oracle‡ (λ(s, s'). I s')"
  by unfold_locales auto

(* The same statement with the invariant phrased via snd instead of a tuple
   pattern. *)
lemma callee_invariant_extend_state_oracle2_const':
  "callee_invariant oracle‡ (λs. I (snd s))"
  by unfold_locales auto

(* extend_state_oracle2 distributes over plus_oracle. Proved by function
   extensionality and a case split on the state pair and the call. *)
lemma extend_state_oracle2_plus_oracle: 
  "extend_state_oracle2 (plus_oracle oracle1 oracle2) = plus_oracle (extend_state_oracle2 oracle1) (extend_state_oracle2 oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: apfst_def spmf.map_comp o_def split_def)
qed

(* parallel_oracle expressed via plus_oracle: the first oracle is extended
   with the second oracle's state on the right, the second oracle with the
   first oracle's state on the left, so both act on the paired state. *)
lemma parallel_oracle_conv_plus_oracle:
  "parallel_oracle oracle1 oracle2 = plus_oracle (oracle1‡) (†oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (auto simp add: spmf.map_comp apfst_def o_def split_def map_prod_def)
qed

(* Mapping calls with map_sum f g and return values with map_sum h k over
   parallel_oracle equals parallel_oracle of the componentwise mapped oracles.
   States are mapped with id on both sides. *)
lemma map_sum_parallel_oracle: includes lifting_syntax shows
  "(id ---> map_sum f g ---> map_spmf (map_prod (map_sum h k) id)) (parallel_oracle oracle1 oracle2)
  = parallel_oracle ((id ---> f ---> map_spmf (map_prod h id)) oracle1) ((id ---> g ---> map_spmf (map_prod k id)) oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
qed

(* The analogous statement for plus_oracle; only the call needs a case split
   because the state is not a pair here. *)
lemma map_sum_plus_oracle: includes lifting_syntax shows
  "(id ---> map_sum f g ---> map_spmf (map_prod (map_sum h k) id)) (plus_oracle oracle1 oracle2)
  = plus_oracle ((id ---> f ---> map_spmf (map_prod h id)) oracle1) ((id ---> g ---> map_spmf (map_prod k id)) oracle2)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases q) (simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
qed

(* Reassociation of plus_oracle: re-bracketing calls with rsuml and return
   values with lsumr turns the right-nested combination into the left-nested
   one. *)
lemma map_rsuml_plus_oracle: includes lifting_syntax shows
  "(id ---> rsuml ---> (map_spmf (map_prod lsumr id))) (oracle1 ⊕⇩O (oracle2 ⊕⇩O oracle3)) =
   ((oracle1 ⊕⇩O oracle2) ⊕⇩O oracle3)"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case 
  proof(cases q)
    case (Inl ql)
    (* The left summand is itself a sum and needs a second case split. *)
    then show ?thesis by(cases ql)(simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
  qed (simp add: spmf.map_comp o_def apfst_def prod.map_comp id_def)
qed

(* The converse reassociation: lsumr on calls and rsuml on return values turn
   the left-nested combination into the right-nested one. *)
lemma map_lsumr_plus_oracle: includes lifting_syntax shows
  "(id ---> lsumr ---> (map_spmf (map_prod rsuml id))) ((oracle1 ⊕⇩O oracle2) ⊕⇩O oracle3) =
   (oracle1 ⊕⇩O (oracle2 ⊕⇩O oracle3))"
proof((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case 
  proof(cases q)
    case (Inr qr)
    (* The right summand is itself a sum and needs a second case split. *)
    then show ?thesis by(cases qr)(simp_all add: spmf.map_comp o_def apfst_def prod.map_comp)
  qed (simp add: spmf.map_comp o_def apfst_def prod.map_comp id_def)
qed

context includes lifting_syntax begin

(* Lifts an oracle transformer F, which changes the state type from 's to 's'
   while carrying a result component of type 't, to oracles whose state is a
   pair with first component 't. The component t is moved from the state into
   the result with lprodr before F is applied and moved back with rprodl
   afterwards. *)
definition lift_state_oracle
  :: "(('s ⇒ 'a ⇒ (('b × 't) × 's) spmf) ⇒ ('s' ⇒ 'a ⇒ (('b × 't) × 's') spmf)) 
  ⇒ ('t × 's ⇒ 'a ⇒ ('b × 't × 's) spmf) ⇒ ('t × 's' ⇒ 'a ⇒ ('b × 't × 's') spmf)" where
  "lift_state_oracle F oracle = 
   (λ(t, s') a. map_spmf rprodl (F ((Pair t ---> id ---> map_spmf lprodr) oracle) s' a))"

(* Application of lift_state_oracle to an explicit state pair. *)
lemma lift_state_oracle_simps [simp]:
  "lift_state_oracle F oracle (t, s') a = map_spmf rprodl (F ((Pair t ---> id ---> map_spmf lprodr) oracle) s' a)"
  by(simp add: lift_state_oracle_def)

(* Parametricity of lift_state_oracle. *)
lemma lift_state_oracle_parametric [transfer_rule]: includes lifting_syntax shows
  "(((S ===> A ===> rel_spmf (rel_prod (rel_prod B T) S)) ===> S' ===> A ===> rel_spmf (rel_prod (rel_prod B T) S'))
  ===> (rel_prod T S ===> A ===> rel_spmf (rel_prod B (rel_prod T S)))
  ===> rel_prod T S' ===> A ===> rel_spmf (rel_prod B (rel_prod T S')))
  lift_state_oracle lift_state_oracle"
  unfolding lift_state_oracle_def map_fun_def o_def by transfer_prover

(* If G and F are related parametrically in the result type, lifting F over a
   state-extended oracle is the same as state-extending G applied to the
   oracle. *)
lemma lift_state_oracle_extend_state_oracle:
  includes lifting_syntax
  assumes "⋀B. Transfer.Rel (((=) ===> (=) ===> rel_spmf (rel_prod B (=))) ===> (=) ===> (=) ===> rel_spmf (rel_prod B (=))) G F"
    (* TODO: implement simproc to discharge parametricity assumptions like this one *)
  shows "lift_state_oracle F (extend_state_oracle oracle) = extend_state_oracle (G oracle)"
  unfolding lift_state_oracle_def extend_state_oracle_def
  apply(clarsimp simp add: fun_eq_iff map_fun_def o_def spmf.map_comp split_def rprodl_def)
  subgoal for t s a
    apply(rule sym)
    apply(fold spmf_rel_eq)
    apply(simp add: spmf_rel_map)
    apply(rule rel_spmf_mono)
     (* Instantiate the relation B so that it pins the carried component to t. *)
     apply(rule assms[unfolded Rel_def, where B="λx (y, z). x = y ∧ z = t", THEN rel_funD, THEN rel_funD, THEN rel_funD])
       apply(auto simp add: rel_fun_def spmf_rel_map intro!: rel_spmf_reflI)
    done
  done

(* lift_state_oracle is compatible with composition of transformers. *)
lemma lift_state_oracle_compose: 
  "lift_state_oracle F (lift_state_oracle G oracle) = lift_state_oracle (F ∘ G) oracle"
  by(simp add: lift_state_oracle_def map_fun_def o_def split_def spmf.map_comp)

(* Lifting the identity transformer gives the identity. *)
lemma lift_state_oracle_id [simp]: "lift_state_oracle id = id"
  by(simp add: fun_eq_iff spmf.map_comp o_def)

(* Extending the state twice and re-bracketing the nested state pairs with
   rprodl and lprodr equals a single state extension. *)
lemma rprodl_extend_state_oracle: includes lifting_syntax shows
  "(rprodl ---> id ---> map_spmf (map_prod id lprodr)) (extend_state_oracle (extend_state_oracle oracle)) = 
  extend_state_oracle oracle"
  by(simp add: fun_eq_iff spmf.map_comp o_def split_def)

end

section ‹Combining GPVs›

subsection ‹Shared state without interrupts›

context
  fixes left :: "'s ⇒ 'x1 ⇒ ('y1 × 's, 'call, 'ret) gpv"
  and right :: "'s ⇒ 'x2 ⇒ ('y2 × 's, 'call, 'ret) gpv"
begin

(* GPV analogue of plus_oracle: an Inl input is handled by left, an Inr input
   by right; the result is tagged with the same injection and the calls made
   by the GPV are left unchanged (mapped with id). *)
primrec plus_intercept :: "'s ⇒ 'x1 + 'x2 ⇒ (('y1 + 'y2) × 's, 'call, 'ret) gpv"
where
  "plus_intercept s (Inl x) = map_gpv (apfst Inl) id (left s x)"
| "plus_intercept s (Inr x) = map_gpv (apfst Inr) id (right s x)"

end

(* Parametricity of plus_intercept. *)
lemma plus_intercept_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> X1 ===> rel_gpv (rel_prod Y1 S) C)
  ===> (S ===> X2 ===> rel_gpv (rel_prod Y2 S) C)
  ===> S ===> rel_sum X1 X2 ===> rel_gpv (rel_prod (rel_sum Y1 Y2) S) C)
  plus_intercept plus_intercept"
unfolding plus_intercept_def[abs_def] by transfer_prover

(* Interaction bound for plus_intercept: the bound is that of whichever
   component the input selects. *)
lemma interaction_bounded_by_plus_intercept [interaction_bound]:
  fixes left right
  shows "⟦ ⋀x'. x = Inl x' ⟹ interaction_bounded_by P (left s x') (n x');
    ⋀y. x = Inr y ⟹ interaction_bounded_by P (right s y) (m y) ⟧
  ⟹ interaction_bounded_by P (plus_intercept left right s x) (case x of Inl x ⇒ n x | Inr y ⇒ m y)"
by(simp split!: sum.split add: interaction_bounded_by_map_gpv_id)

subsection ‹Shared state with interrupts›

context 
  fixes left :: "'s ⇒ 'x1 ⇒ ('y1 option × 's, 'call, 'ret) gpv"
  and right :: "'s ⇒ 'x2 ⇒ ('y2 option × 's, 'call, 'ret) gpv"
begin

(* Variant of plus_intercept for components with optional results: the
   injection is applied underneath the option via map_option. *)
primrec plus_intercept_stop :: "'s ⇒ 'x1 + 'x2 ⇒ (('y1 + 'y2) option × 's, 'call, 'ret) gpv"
where
  "plus_intercept_stop s (Inl x) = map_gpv (apfst (map_option Inl)) id (left s x)"
| "plus_intercept_stop s (Inr x) = map_gpv (apfst (map_option Inr)) id (right s x)"

end

(* Parametricity of plus_intercept_stop. *)
lemma plus_intercept_stop_parametric [transfer_rule]:
  includes lifting_syntax shows
  "((S ===> X1 ===> rel_gpv (rel_prod (rel_option Y1) S) C)
  ===> (S ===> X2 ===> rel_gpv (rel_prod (rel_option Y2) S) C)
  ===> S ===> rel_sum X1 X2 ===> rel_gpv (rel_prod (rel_option (rel_sum Y1 Y2)) S) C)
  plus_intercept_stop plus_intercept_stop"
unfolding plus_intercept_stop_def by transfer_prover

subsection ‹One-sided shifts›

(* Embeds a GPV into the left summand of a sum interface: every output is
   tagged with Inl; an input tagged Inl continues the embedded GPV, any other
   input leads to Fail. *)
primcorec (transfer) left_gpv :: "('a, 'out, 'in) gpv ⇒ ('a, 'out + 'out', 'in + 'in') gpv" where
  "the_gpv (left_gpv gpv) = 
   map_spmf (map_generat id Inl (λrpv input. case input of Inl input' ⇒ left_gpv (rpv input') | _ ⇒ Fail)) (the_gpv gpv)"

(* The continuation transformation used inside left_gpv, as an abbreviation. *)
abbreviation left_rpv :: "('a, 'out, 'in) rpv ⇒ ('a, 'out + 'out', 'in + 'in') rpv" where
  "left_rpv rpv ≡ λinput. case input of Inl input' ⇒ left_gpv (rpv input') | _ ⇒ Fail"

(* Embeds a GPV into the right summand of a sum interface: every output is
   tagged with Inr; an input tagged Inr continues the embedded GPV, any other
   input leads to Fail. *)
primcorec (transfer) right_gpv :: "('a, 'out, 'in) gpv ⇒ ('a, 'out' + 'out, 'in' + 'in) gpv" where
  "the_gpv (right_gpv gpv) =
   map_spmf (map_generat id Inr (λrpv input. case input of Inr input' ⇒ right_gpv (rpv input') | _ ⇒ Fail)) (the_gpv gpv)"

(* The continuation transformation used inside right_gpv, as an abbreviation. *)
abbreviation right_rpv :: "('a, 'out, 'in) rpv ⇒ ('a, 'out' + 'out, 'in' + 'in) rpv" where
  "right_rpv rpv ≡ λinput. case input of Inr input' ⇒ right_gpv (rpv input') | _ ⇒ Fail"

(* Parametricity of left_gpv and right_gpv. The unprimed versions are the
   rules generated by the (transfer) option of primcorec; the primed versions
   are stated for rel_gpv'' and proved with the primed parametricity rules
   declared as transfer rules for this context. *)
context 
  includes lifting_syntax
  notes [transfer_rule] = corec_gpv_parametric' Fail_parametric' the_gpv_parametric'
begin

lemmas left_gpv_parametric = left_gpv.transfer

lemma left_gpv_parametric':
  "(rel_gpv'' A C R ===> rel_gpv'' A (rel_sum C C') (rel_sum R R')) left_gpv left_gpv"
  unfolding left_gpv_def by transfer_prover

lemmas right_gpv_parametric = right_gpv.transfer

lemma right_gpv_parametric':
  "(rel_gpv'' A C' R' ===> rel_gpv'' A (rel_sum C C') (rel_sum R R')) right_gpv right_gpv"
  unfolding right_gpv_def by transfer_prover

end

(* left_gpv leaves a finished GPV unchanged. *)
lemma left_gpv_Done [simp]: "left_gpv (Done x) = Done x"
  by(rule gpv.expand) simp

(* right_gpv leaves a finished GPV unchanged. *)
lemma right_gpv_Done [simp]: "right_gpv (Done x) = Done x"
  by(rule gpv.expand) simp

(* left_gpv on a Pause: the output is tagged with Inl and the continuation
   accepts only Inl inputs, failing on anything else. *)
lemma left_gpv_Pause [simp]:
  "left_gpv (Pause x rpv) = Pause (Inl x) (λinput. case input of Inl input' ⇒ left_gpv (rpv input') | _ ⇒ Fail)"
  by(rule gpv.expand) simp

(* right_gpv on a Pause: the output is tagged with Inr and the continuation
   accepts only Inr inputs, failing on anything else. *)
lemma right_gpv_Pause [simp]:
  "right_gpv (Pause x rpv) = Pause (Inr x) (λinput. case input of Inr input' ⇒ right_gpv (rpv input') | _ ⇒ Fail)"
  by(rule gpv.expand) simp

(* left_gpv commutes with map_gpv; the function h on the other summand is
   arbitrary. Derived from the transfer rule instantiated with graph relations
   (BNF_Def.Grp). *)
lemma left_gpv_map: "left_gpv (map_gpv f g gpv) = map_gpv f (map_sum g h) (left_gpv gpv)"
  using left_gpv.transfer[of "BNF_Def.Grp UNIV f" "BNF_Def.Grp UNIV g" "BNF_Def.Grp UNIV h"]
  unfolding sum.rel_Grp gpv.rel_Grp
  by(auto simp add: rel_fun_def Grp_def)

(* right_gpv commutes with map_gpv; the function h on the other summand is
   arbitrary. Same proof pattern as for left_gpv_map. *)
lemma right_gpv_map: "right_gpv (map_gpv f g gpv) = map_gpv f (map_sum h g) (right_gpv gpv)"
  using right_gpv.transfer[of "BNF_Def.Grp UNIV f" "BNF_Def.Grp UNIV g" "BNF_Def.Grp UNIV h"]
  unfolding sum.rel_Grp gpv.rel_Grp
  by(auto simp add: rel_fun_def Grp_def)

(* left_gpv does not change the set results'_gpv. Each inclusion is proved by
   induction; the first generalises over the embedded GPV by naming the
   embedded term gpv'. *)
lemma results'_gpv_left_gpv [simp]: 
  "results'_gpv (left_gpv gpv :: ('a, 'out + 'out', 'in + 'in') gpv) = results'_gpv gpv" (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
    by(induction gpv'≡"left_gpv gpv :: ('a, 'out + 'out', 'in + 'in') gpv" arbitrary: gpv)
      (fastforce simp add: elim!: generat.set_cases intro: results'_gpvI split: sum.splits)+
  show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that
    by(induction)
      (auto 4 3 elim!: generat.set_cases intro: results'_gpv_Pure rev_image_eqI results'_gpv_Cont[where input="Inl _"])
qed

(* right_gpv does not change the set results'_gpv. Same proof structure as
   for left_gpv, with Inr in place of Inl. *)
lemma results'_gpv_right_gpv [simp]: 
  "results'_gpv (right_gpv gpv :: ('a, 'out' + 'out, 'in' + 'in) gpv) = results'_gpv gpv" (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
    by(induction gpv'≡"right_gpv gpv :: ('a, 'out' + 'out, 'in' + 'in) gpv" arbitrary: gpv)
      (fastforce simp add: elim!: generat.set_cases intro: results'_gpvI split: sum.splits)+
  show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that
    by(induction)
      (auto 4 3 elim!: generat.set_cases intro: results'_gpv_Pure rev_image_eqI results'_gpv_Cont[where input="Inr _"])
qed

(* left_gpv gpv is related to gpv itself when results are related by equality
   and both outputs and inputs by the graph of Inl. *)
lemma left_gpv_Inl_transfer: "rel_gpv'' (=) (λl r. l = Inl r) (λl r. l = Inl r) (left_gpv gpv) gpv"
  by(coinduction arbitrary: gpv)
    (auto simp add: spmf_rel_map generat.rel_map del: rel_funI intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI)

(* right_gpv gpv is related to gpv itself when results are related by equality
   and both outputs and inputs by the graph of Inr. *)
lemma right_gpv_Inr_transfer: "rel_gpv'' (=) (λl r. l = Inr r) (λl r. l = Inr r) (right_gpv gpv) gpv"
  by(coinduction arbitrary: gpv)
    (auto simp add: spmf_rel_map generat.rel_map del: rel_funI intro!: rel_spmf_reflI generat.rel_refl_strong rel_funI)

(* Running a left-embedded GPV against plus_oracle is the same as running the
   GPV against the first oracle. The equation is turned into a relational
   statement and proved with the parametricity of exec_gpv, using the graph of
   Inl for calls and returns. *)
lemma exec_gpv_plus_oracle_left: "exec_gpv (plus_oracle oracle1 oracle2) (left_gpv gpv) s = exec_gpv oracle1 gpv s"
  unfolding spmf_rel_eq[symmetric] prod.rel_eq[symmetric]
  by(rule exec_gpv_parametric'[where A="(=)" and S="(=)" and CALL="λl r. l = Inl r" and R="λl r. l = Inl r", THEN rel_funD, THEN rel_funD, THEN rel_funD])
    (auto intro!: rel_funI simp add: spmf_rel_map apfst_def map_prod_def rel_prod_conv intro: rel_spmf_reflI left_gpv_Inl_transfer)

(* Running a right-embedded GPV against plus_oracle is the same as running
   the GPV against the second oracle. Same proof pattern with Inr. *)
lemma exec_gpv_plus_oracle_right: "exec_gpv (plus_oracle oracle1 oracle2) (right_gpv gpv) s = exec_gpv oracle2 gpv s"
  unfolding spmf_rel_eq[symmetric] prod.rel_eq[symmetric]
  by(rule exec_gpv_parametric'[where A="(=)" and S="(=)" and CALL="λl r. l = Inr r" and R="λl r. l = Inr r", THEN rel_funD, THEN rel_funD, THEN rel_funD])
    (auto intro!: rel_funI simp add: spmf_rel_map apfst_def map_prod_def rel_prod_conv intro: rel_spmf_reflI right_gpv_Inr_transfer)

(* left_gpv distributes over bind_gpv. *)
lemma left_gpv_bind_gpv: "left_gpv (bind_gpv gpv f) = bind_gpv (left_gpv gpv) (left_gpv ∘ f)"
  by(coinduction arbitrary:gpv f rule: gpv.coinduct_strong)
    (auto 4 4 simp add: bind_map_spmf spmf_rel_map intro!: rel_spmf_reflI rel_spmf_bindI[of "(=)"] generat.rel_refl rel_funI split: sum.splits)

(* inline1 with a left-embedded callee equals inline1 with the plain callee,
   post-processed by tagging the emitted call with Inl and embedding the
   continuation with left_rpv. Proved by parallel fixpoint induction over both
   occurrences of inline1. *)
lemma inline1_left_gpv:
  "inline1 (λs q. left_gpv (callee s q)) gpv s = 
   map_spmf (map_sum id (map_prod Inl (map_prod left_rpv id))) (inline1 callee gpv s)"
proof(induction arbitrary: gpv s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf inline1.mono inline1.mono inline1_def inline1_def, unfolded lub_spmf_empty, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step inline1' inline1'')
  then show ?case
    by(auto simp add: map_spmf_bind_spmf o_def bind_map_spmf intro!: ext bind_spmf_cong split: generat.split)
qed

(* left_gpv commutes with inline: embedding the inlined GPV equals inlining
   with the embedded callee. *)
lemma left_gpv_inline: "left_gpv (inline callee gpv s) = inline (λs q. left_gpv (callee s q)) gpv s"
  by(coinduction arbitrary: callee gpv s rule: gpv_coinduct_bind)
    (fastforce simp add: inline_sel spmf_rel_map inline1_left_gpv left_gpv_bind_gpv o_def split_def intro!: rel_spmf_reflI split: sum.split intro!: rel_funI gpv.rel_refl_strong)

(* right_gpv distributes over bind_gpv. *)
lemma right_gpv_bind_gpv: "right_gpv (bind_gpv gpv f) = bind_gpv (right_gpv gpv) (right_gpv ∘ f)"
  by(coinduction arbitrary:gpv f rule: gpv.coinduct_strong)
    (auto 4 4 simp add: bind_map_spmf spmf_rel_map intro!: rel_spmf_reflI rel_spmf_bindI[of "(=)"] generat.rel_refl rel_funI split: sum.splits)

(* inline1 with a right-embedded callee equals inline1 with the plain callee,
   post-processed by tagging the emitted call with Inr and embedding the
   continuation with right_rpv. Proved by parallel fixpoint induction over
   both occurrences of inline1. *)
lemma inline1_right_gpv:
  "inline1 (λs q. right_gpv (callee s q)) gpv s = 
   map_spmf (map_sum id (map_prod Inr (map_prod right_rpv id))) (inline1 callee gpv s)"
proof(induction arbitrary: gpv s rule: parallel_fixp_induct_2_2[OF partial_function_definitions_spmf partial_function_definitions_spmf inline1.mono inline1.mono inline1_def inline1_def, unfolded lub_spmf_empty, case_names adm bottom step])
  case adm show ?case by simp
  case bottom show ?case by simp
  case (step inline1' inline1'')
  then show ?case
    by(auto simp add: map_spmf_bind_spmf o_def bind_map_spmf intro!: ext bind_spmf_cong split: generat.split)
qed

(* right_gpv commutes with inline: embedding the inlined GPV equals inlining
   with the embedded callee. *)
lemma right_gpv_inline: "right_gpv (inline callee gpv s) = inline (λs q. right_gpv (callee s q)) gpv s"
  by(coinduction arbitrary: callee gpv s rule: gpv_coinduct_bind)
    (fastforce simp add: inline_sel spmf_rel_map inline1_right_gpv right_gpv_bind_gpv o_def split_def intro!: rel_spmf_reflI split: sum.split intro!: rel_funI gpv.rel_refl_strong)

(* A GPV well-typed for ℐ1 stays well-typed for the sum interface after
   embedding with left_gpv. *)
lemma WT_gpv_left_gpv: "ℐ1 ⊢g gpv √ ⟹ ℐ1 ⊕⇩ℐ ℐ2 ⊢g left_gpv gpv √"
  by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)

(* A GPV well-typed for ℐ2 stays well-typed for the sum interface after
   embedding with right_gpv. *)
lemma WT_gpv_right_gpv: "ℐ2 ⊢g gpv √ ⟹ ℐ1 ⊕⇩ℐ ℐ2 ⊢g right_gpv gpv √"
  by(coinduction arbitrary: gpv)(auto 4 4 dest: WT_gpvD)

(* The results of a left-embedded GPV under the sum interface are the results
   of the GPV under the left interface. Each inclusion is proved by induction
   on results_gpv. *)
lemma results_gpv_left_gpv [simp]: "results_gpv (ℐ1 ⊕⇩ℐ ℐ2) (left_gpv gpv) = results_gpv ℐ1 gpv"
  (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
    by(induction gpv'≡"left_gpv gpv :: ('a, 'b + 'c, 'd + 'e) gpv" arbitrary: gpv rule: results_gpv.induct)
      (fastforce intro: results_gpv.intros)+
  show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that
    by(induction)(fastforce intro: results_gpv.intros)+
qed

(* The results of a right-embedded GPV under the sum interface are the
   results of the GPV under the right interface. *)
lemma results_gpv_right_gpv [simp]: "results_gpv (ℐ1 ⊕⇩ℐ ℐ2) (right_gpv gpv) = results_gpv ℐ2 gpv"
  (is "?lhs = ?rhs")
proof(rule Set.set_eqI iffI)+
  show "x ∈ ?rhs" if "x ∈ ?lhs" for x using that
    by(induction gpv'≡"right_gpv gpv :: ('a, 'b + 'c, 'd + 'e) gpv" arbitrary: gpv rule: results_gpv.induct)
      (fastforce intro: results_gpv.intros)+
  show "x ∈ ?lhs" if "x ∈ ?rhs" for x using that
    by(induction)(fastforce intro: results_gpv.intros)+
qed

(* left_gpv maps Fail to Fail. *)
lemma left_gpv_Fail [simp]: "left_gpv Fail = Fail"
  by(rule gpv.expand) auto

(* right_gpv maps Fail to Fail. *)
lemma right_gpv_Fail [simp]: "right_gpv Fail = Fail"
  by(rule gpv.expand) auto

(* Reassociation lemmas for nested embeddings: mapping outputs with rsuml and
   inputs with lsumr (via map_gpv') re-brackets a nested sum interface. All
   three are proved by coinduction with the same automation; Fail is supplied
   as the witness for inputs that fall outside the embedded summand. *)

(* Two left embeddings collapse to one. *)
lemma rsuml_lsumr_left_gpv_left_gpv:"map_gpv' id rsuml lsumr (left_gpv (left_gpv gpv)) = left_gpv gpv"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])

(* Left-after-right becomes right-after-left. *)
lemma rsuml_lsumr_left_gpv_right_gpv: "map_gpv' id rsuml lsumr (left_gpv (right_gpv gpv)) = right_gpv (left_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])

(* One right embedding becomes two. *)
lemma rsuml_lsumr_right_gpv: "map_gpv' id rsuml lsumr (right_gpv gpv) = right_gpv (right_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: lsumr.elims intro: exI[where x=Fail])

(* A result-only map_gpv inside a map_gpv' can be moved outside; the two
   result functions are composed. *)
lemma map_gpv'_map_gpv_swap:
  "map_gpv' f g h (map_gpv f' id gpv) = map_gpv (f ∘ f') id (map_gpv' id g h gpv)"
  by(simp add: map_gpv_conv_map_gpv' map_gpv'_comp)

(* Reassociation lemmas in the opposite direction: outputs are mapped with
   lsumr and inputs with rsuml. *)

(* One left embedding becomes two. *)
lemma lsumr_rsuml_left_gpv: "map_gpv' id lsumr rsuml (left_gpv gpv) = left_gpv (left_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split intro: exI[where x=Fail])

(* Right-after-left becomes left-after-right. *)
lemma lsumr_rsuml_right_gpv_left_gpv:
  "map_gpv' id lsumr rsuml (right_gpv (left_gpv gpv)) = left_gpv (right_gpv gpv)"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split intro: exI[where x=Fail])

(* Two right embeddings collapse to one. *)
lemma lsumr_rsuml_right_gpv_right_gpv:
  "map_gpv' id lsumr rsuml (right_gpv (right_gpv gpv)) = right_gpv gpv"
  by(coinduction arbitrary: gpv)
    (auto 4 3 simp add: spmf_rel_map generat.rel_map intro!: rel_spmf_reflI rel_generat_reflI rel_funI split!: sum.split elim!: rsuml.elims intro: exI[where x=Fail])


(* Support of a state-extended oracle: the added state component is unchanged
   and the remaining result lies in the support of the underlying oracle. *)
lemma in_set_spmf_extend_state_oracle [simp]:
  "x ∈ set_spmf (extend_state_oracle oracle s y) ⟷
   fst (snd x) = fst s ∧ (fst x, snd (snd x)) ∈ set_spmf (oracle (snd s) y)"
  by(auto 4 4 simp add: extend_state_oracle_def split_beta intro: rev_image_eqI prod.expand)

(* extend_state_oracle distributes over plus_oracle. Proved by function
   extensionality and a case split on the state pair and the call. *)
lemma extend_state_oracle_plus_oracle: 
  "extend_state_oracle (plus_oracle oracle1 oracle2) = plus_oracle (extend_state_oracle oracle1) (extend_state_oracle oracle2)"
proof ((rule ext)+; goal_cases)
  case (1 s q)
  then show ?case by (cases s; cases q) (simp_all add: apfst_def spmf.map_comp o_def split_def)
qed


(* Turns a stateless GPV-valued callee into a stateful one that pairs every
   result with the state it was given, so the state is passed through
   unchanged. *)
definition stateless_callee :: "('a ⇒ ('b, 'out, 'in) gpv) ⇒ ('s ⇒ 'a ⇒ ('b × 's, 'out, 'in) gpv)" where
  "stateless_callee callee s = map_gpv (λb. (b, s)) id ∘ callee"

(* Parametricity of stateless_callee with respect to rel_gpv''. *)
lemma stateless_callee_parametric': 
  includes lifting_syntax notes [transfer_rule] = map_gpv_parametric' shows
    "((A ===> rel_gpv'' B C R) ===> S ===> A ===> (rel_gpv'' (rel_prod B S) C R))
   stateless_callee stateless_callee"
  unfolding stateless_callee_def by transfer_prover

(* id_oracle is the stateless callee that pauses with its argument and then
   returns the answer.
   NOTE(review): the lemma name spells oracle as oralce; it is kept because
   other theories may refer to it by this name. *)
lemma id_oralce_alt_def: "id_oracle = stateless_callee (λx. Pause x Done)"
  by(simp add: id_oracle_def fun_eq_iff stateless_callee_def)

context
  fixes left :: "'s1 ⇒ 'x1 ⇒ ('y1 × 's1, 'call1, 'ret1) gpv"
    and right :: "'s2 ⇒ 'x2 ⇒ ('y2 × 's2, 'call2, 'ret2) gpv"
begin

(* GPV analogue of parallel_oracle: the state is a pair and the two
   components use different call interfaces. An Inl input runs left on the
   first state component and embeds its calls with left_gpv; an Inr input runs
   right on the second state component and embeds its calls with right_gpv.
   The untouched state component is carried through unchanged. *)
fun parallel_intercept :: "'s1 × 's2 ⇒ 'x1 + 'x2 ⇒ (('y1 + 'y2) × ('s1 × 's2), 'call1 + 'call2, 'ret1 + 'ret2) gpv"
  where
    "parallel_intercept (s1, s2) (Inl a) = left_gpv (map_gpv (map_prod Inl (λs1'. (s1', s2))) id (left s1 a))"
  | "parallel_intercept (s1, s2) (Inr b) = right_gpv (map_gpv (map_prod Inr (Pair s1)) id (right s2 b))"

end

end