(* Theory Guessing_Many_One *)
subsection ‹Reducing games with many adversary guesses to games with single guesses›
theory Guessing_Many_One imports
  CryptHOL.Computational_Model
  CryptHOL.GPV_Bisim
begin
(* Abstract guessing game.
   - init samples a triple (c_o, c_a, s): c_o is passed to the oracle and to eval,
     c_a is handed to the adversary and to eval, and s is the initial oracle state.
   - oracle c_o s x answers the adversary's call x in state s with a response and
     a successor state.
   - eval c_o c_a s' guess yields a (possibly probabilistic) boolean verdict on
     guess in oracle state s'; True counts as a win (see advantage_single). *)
locale guessing_many_one =
  fixes init :: "('c_o × 'c_a × 's) spmf"
  and "oracle" :: "'c_o ⇒ 's ⇒ 'call ⇒ ('ret × 's) spmf"
  and "eval" :: "'c_o ⇒ 'c_a ⇒ 's ⇒ 'guess ⇒ bool spmf"
begin
(* Single-guess adversary: given c_a, it may query the oracle (calls of type 'call,
   answers of type 'ret) and finally returns exactly one 'guess. *)
type_synonym ('c_a', 'guess', 'call', 'ret') adversary_single = "'c_a' ⇒ ('guess', 'call', 'ret') gpv"
(* Single-guess game: run the adversary against oracle c_o starting from s, then
   evaluate its one guess against the final oracle state s'. *)
definition game_single :: "('c_a, 'guess, 'call, 'ret) adversary_single ⇒ bool spmf"
where
  "game_single 𝒜 = do {
    (c_o, c_a, s) ← init;
    (guess, s') ← exec_gpv (oracle c_o) (𝒜 c_a) s;
    eval c_o c_a s' guess
  }"
(* Probability that the single-guess game returns True. *)
definition advantage_single :: "('c_a, 'guess, 'call, 'ret) adversary_single ⇒ real"
where "advantage_single 𝒜 = spmf (game_single 𝒜) True"
(* Many-guess adversary: each call is either Inl (an oracle query, answered with
   Inl of a 'ret) or Inr (a guess, answered with Inr ()); its final result is unit. *)
type_synonym ('c_a', 'guess', 'call', 'ret') adversary_many = "'c_a' ⇒ (unit, 'call' + 'guess', 'ret' + unit) gpv"
(* Guess oracle of the many-guess game. Its state pairs a flag b ("some earlier
   guess won") with the current oracle state s'. Each guess is evaluated against
   s', the verdict is OR-ed into b, and s' is left unchanged. *)
definition eval_oracle :: "'c_o ⇒ 'c_a ⇒ bool × 's ⇒ 'guess ⇒ (unit × (bool × 's)) spmf"
where
  "eval_oracle c_o c_a = (λ(b, s') guess. map_spmf (λb'. ((), (b ∨ b', s'))) (eval c_o c_a s' guess))"
(* Many-guess game: queries go to oracle c_o, lifted by † (extend_state_oracle,
   cf. its spelled-out use in the proof below) to the combined state (b, s), and
   guesses go to eval_oracle. The flag starts at False and is the game's result,
   so the adversary wins iff at least one of its guesses won. *)
definition game_multi :: "('c_a, 'guess, 'call, 'ret) adversary_many ⇒ bool spmf"
where
  "game_multi 𝒜 = do {
     (c_o, c_a, s) ← init;
     (_, (b, _)) ← exec_gpv
       (†(oracle c_o) ⊕⇩O eval_oracle c_o c_a)
       (𝒜 c_a)
       (False, s);
     return_spmf b
   }"
(* Probability that the many-guess game returns True. *)
definition advantage_multi :: "('c_a, 'guess, 'call, 'ret) adversary_many ⇒ real"
where "advantage_multi 𝒜 = spmf (game_multi 𝒜) True"
(* State of the reduction: Inr j means "j more guesses are still to be skipped";
   Inl guess means that guess has been captured. *)
type_synonym 'guess' reduction_state = "'guess' + nat"
(* Oracle queries inside the reduction: while no guess has been captured (Inr j),
   the query x is forwarded to the reduction's own oracle and its answer is
   returned as Some ret; after a capture (Inl) the answer is None and nothing is
   forwarded. *)
primrec process_call :: "'guess reduction_state ⇒ 'call ⇒ ('ret option × 'guess reduction_state, 'call, 'ret) gpv"
where
  "process_call (Inr j) x = do {
    ret ← Pause x Done;
    Done (Some ret, Inr j)
  }"
| "process_call (Inl guess) x = Done (None, Inl guess)"
(* Guesses inside the reduction: in state Inr j with j > 0 the guess is skipped and
   the counter decremented; in state Inr 0 the guess is captured as Inl guess and
   the answer is None; once captured, further guesses are answered with None. *)
primrec process_guess :: "'guess reduction_state ⇒ 'guess ⇒ (unit option × 'guess reduction_state, 'call, 'ret) gpv"
where
  "process_guess (Inr j) guess = Done (if j > 0 then (Some (), Inr (j - 1)) else (None, Inl guess))"
| "process_guess (Inl guess) _ = Done (None, Inl guess)"
(* Combined reduction oracle: process_call for Inl queries, process_guess for
   Inr guesses.
   NOTE(review): the stopping behaviour comes from CryptHOL's plus_intercept_stop;
   presumably a None answer ends the inlined run. Confirm against CryptHOL. *)
abbreviation reduction_oracle :: "'guess + nat ⇒ 'call + 'guess ⇒ (('ret + unit) option × ('guess + nat), 'call, 'ret) gpv"
where "reduction_oracle ≡ plus_intercept_stop process_call process_guess"
(* Reduction from a many-guess adversary to a single-guess adversary.
   - j_star is sampled uniformly from {..<q}.
   - 𝒜 c_a is inlined with reduction_oracle starting in state Inr j_star, so
     guesses number 0 .. j_star - 1 are skipped and guess number j_star (counting
     from 0) is captured.
   - The captured guess is output.
   NOTE(review): if 𝒜 makes no more than j_star guesses, the final state is still
   Inr _. projl then yields an unspecified guess, which game_single still
   evaluates. *)
definition reduction :: "nat ⇒ ('c_a, 'guess, 'call, 'ret) adversary_many ⇒ ('c_a, 'guess, 'call, 'ret) adversary_single"
where
  "reduction q 𝒜 c_a = do {
    j_star ← lift_spmf (spmf_of_set {..<q});
    (_, s) ← inline_stop reduction_oracle (𝒜 c_a) (Inr j_star);
    Done (projl s)
  }"
(* Main result. Suppose that for every outcome of init:
   - 𝒜 makes at most q guesses, i.e. calls on which Not ∘ isl holds (Inr calls);
   - oracle and eval are lossless.
   Then the many-guess advantage is at most q times the single-guess advantage of
   the reduction. *)
lemma many_single_reduction:
  assumes bound: "⋀c_a c_o s. (c_o, c_a, s) ∈ set_spmf init ⟹ interaction_bounded_by (Not ∘ isl) (𝒜 c_a) q"
  and lossless_oracle: "⋀c_a c_o s s' x. (c_o, c_a, s) ∈ set_spmf init ⟹ lossless_spmf (oracle c_o s' x)"
  and lossless_eval: "⋀c_a c_o s s' guess. (c_o, c_a, s) ∈ set_spmf init ⟹ lossless_spmf (eval c_o c_a s' guess)"
  shows "advantage_multi 𝒜 ≤ advantage_single (reduction q 𝒜) * q"
  including lifting_syntax
proof -
  (* eval_oracle' refines eval_oracle. Instead of a flag it keeps (id, occ):
     id counts the guesses made so far, and occ is the index of the first winning
     guess (None while no guess has won). ?multi'_body runs 𝒜 with it from
     ((0, None), s). *)
  define eval_oracle'
    where "eval_oracle' = (λc_o c_a ((id, occ :: nat option), s') guess. 
    map_spmf (λb'. case occ of Some j⇩0 ⇒ ((), (Suc id, Some j⇩0), s')
                                | None ⇒ ((), (Suc id, (if b' then Some id else None)), s'))
      (eval c_o c_a s' guess))"
  let ?multi'_body = "λc_o c_a s. exec_gpv (†(oracle c_o) ⊕⇩O eval_oracle' c_o c_a) (𝒜 c_a) ((0, None), s)"
  (* game_multi' wins iff some guess has been recorded as winning. *)
  define game_multi' where "game_multi' = (λc_o c_a s. do {
    (_, ((id, j⇩0), s' :: 's)) ← ?multi'_body c_o c_a s;
    return_spmf (j⇩0 ≠ None) })"
  (* initialize body samples init, then an index j⇩s uniformly from {..<q},
     then runs body. *)
  define initialize :: "('c_o ⇒ 'c_a ⇒ 's ⇒ nat ⇒ bool spmf) ⇒ bool spmf" where
    "initialize body = do {
      (c_o, c_a, s) ← init;
      j⇩s ← spmf_of_set {..<q};
      body c_o c_a s j⇩s }" for body
  (* body2 wins iff the first winning guess has index exactly j⇩s. *)
  define body2 where "body2 c_o c_a s j⇩s = do {
    (_, (id, j⇩0), s') ← ?multi'_body c_o c_a s;
    return_spmf (j⇩0 = Some j⇩s) }" for c_o c_a s j⇩s
  let ?game2 = "initialize body2"
  (* stop_oracle keeps idgs, which is either Inr i (i more guesses to skip) or
     Inl (guess, s) (the captured guess with the oracle state at that moment).
     - Oracle calls are answered only while idgs = Inr _.
     - A guess in state Inr 0 is captured (answer None).
     - A guess in state Inr (Suc i) is skipped.
     - After capture the answer is None. *)
  define stop_oracle where "stop_oracle = (λc_o. 
     (λ(idgs, s) x. case idgs of Inr _ ⇒ map_spmf (λ(y, s). (Some y, (idgs, s))) (oracle c_o s x) | Inl _ ⇒ return_spmf (None, (idgs, s)))
     ⊕⇩O⇧S
     (λ(idgs, s) guess :: 'guess. return_spmf (case idgs of Inr 0 ⇒ (None, Inl (guess, s), s) | Inr (Suc i) ⇒ (Some (), Inr i, s) | Inl _ ⇒ (None, idgs, s))))"
  (* body3 runs 𝒜 with stop_oracle from Inr j⇩s, then evaluates the captured guess
     against the captured state. It returns False if no guess was captured. *)
  define body3 where "body3 c_o c_a s j⇩s = do {
    (_ :: unit option, idgs, _) ← exec_gpv_stop (stop_oracle c_o) (𝒜 c_a) (Inr j⇩s, s);
    (b' :: bool) ← case idgs of Inr _ ⇒ return_spmf False | Inl (g, s') ⇒ eval c_o c_a s' g;
    return_spmf b' }" for c_o c_a s j⇩s
  let ?game3 = "initialize body3"
  (* Step 1: game_multi equals init followed by game_multi'. The relation S links
     the flag of eval_oracle to the (id, occ) state of eval_oracle'
     (flag ⟷ occ ≠ None), and transfer_prover carries it through exec_gpv. *)
  { define S :: "bool ⇒ nat × nat option ⇒ bool" where "S ≡ λb' (id, occ). b' ⟷ (∃j⇩0. occ = Some j⇩0)"
    let ?S = "rel_prod S (=)"
    define initial :: "nat × nat option" where "initial = (0, None)"
    define result :: "nat × nat option ⇒ bool" where "result p = (snd p ≠ None)" for p
    have [transfer_rule]: "(S ===> (=)) (λb. b) result" by(simp add: rel_fun_def result_def S_def)
    have [transfer_rule]: "S False initial" by (simp add: S_def initial_def)
    have eval_oracle'[transfer_rule]: 
      "((=) ===> (=) ===> ?S ===> (=) ===> rel_spmf (rel_prod (=) ?S))
       eval_oracle eval_oracle'"
      unfolding eval_oracle_def[abs_def] eval_oracle'_def[abs_def]
      by (auto simp add: rel_fun_def S_def map_spmf_conv_bind_spmf intro!: rel_spmf_bind_reflI split: option.split)
    
    have game_multi': "game_multi 𝒜 = bind_spmf init (λ(c_o, c_a, s). game_multi' c_o c_a s)"
      unfolding game_multi_def game_multi'_def initial_def[symmetric]
      by (rewrite in "case_prod ⌑" in "bind_spmf _ (case_prod ⌑)" in "_ = bind_spmf _ ⌑" split_def)
         (fold result_def; transfer_prover) }
  moreover
  (* Step 2: for every outcome of init, the winning probability of game_multi'
     equals q times the probability that the first winning guess has the
     uniformly sampled index j⇩s. *)
  have "spmf (game_multi' c_o c_a s) True = spmf (bind_spmf (spmf_of_set {..<q}) (body2 c_o c_a s)) True * q"
    if "(c_o, c_a, s) ∈ set_spmf init" for c_o c_a s
  proof -
    have bnd: "interaction_bounded_by (Not ∘ isl) (𝒜 c_a) q" using bound that by blast
    (* Every recorded winning index is < q. The interaction bound gives id ≤ q
       (guesses are counted by fst ∘ fst), and the invariant ?I gives index < id. *)
    have bound_occ: "j⇩s < q" if that: "((), (id, Some j⇩s), s') ∈ set_spmf (?multi'_body c_o c_a s)" 
      for s' id j⇩s
    proof -
      have "id ≤ q" 
        by(rule oi_True.interaction_bounded_by'_exec_gpv_count[OF bnd that, where count="fst ∘ fst", simplified])
          (auto simp add: eval_oracle'_def split: plus_oracle_split_asm option.split_asm)
      moreover let ?I = "λ((id, occ), s'). case occ of None ⇒ True | Some j⇩s ⇒ j⇩s < id"
      have "callee_invariant (†(oracle c_o) ⊕⇩O eval_oracle' c_o c_a) ?I"
        by(clarsimp simp add: split_def intro!: conjI[OF callee_invariant_extend_state_oracle_const'])
          (unfold_locales; auto simp add: eval_oracle'_def split: option.split_asm)
      from callee_invariant_on.exec_gpv_invariant[OF this that] have "j⇩s < id" by simp
      ultimately show ?thesis by simp
    qed
    (* Express the winning probability as a measure. Then split the event
       occ ≠ None at q: the part with index ≥ q has probability 0 by bound_occ. *)
    let ?M = "measure (measure_spmf (?multi'_body c_o c_a s))"
    have "spmf (game_multi' c_o c_a s) True = ?M {(u, (id, j⇩0), s'). j⇩0 ≠ None}"
      by(auto simp add: game_multi'_def map_spmf_conv_bind_spmf[symmetric] split_def spmf_conv_measure_spmf measure_map_spmf vimage_def)
    also have "{(u, (id, j⇩0), s'). j⇩0 ≠ None} =
      {((), (id, Some j⇩s), s') |j⇩s s' id. j⇩s < q} ∪ {((), (id, Some j⇩s), s') |j⇩s s' id. j⇩s ≥ q}"
      (is "_ = ?A ∪ _") by auto
    also have "?M … = ?M ?A"
      by (rule measure_spmf.measure_zero_union)(auto simp add: measure_spmf_zero_iff dest: bound_occ)
    (* Pairing the run with an independent uniform j⇩s ∈ {..<q} weights each index
       by 1 / q, which accounts for the factor q. *)
    also have "… = measure (measure_spmf (pair_spmf (spmf_of_set {..< q}) (?multi'_body c_o c_a s)))
         {(j⇩s, (), (id, j⇩0), s') |j⇩s j⇩0 s' id. j⇩0 = Some j⇩s } * q"
      (is "_ = measure ?M' ?B * _")
    proof - 
      have "?B = {(j⇩s, (), (id, j⇩0), s') |j⇩s j⇩0 s' id. j⇩0 = Some j⇩s ∧ j⇩s < q} ∪
        {(j⇩s, (), (id, j⇩0), s') |j⇩s j⇩0 s' id. j⇩0 = Some j⇩s ∧ j⇩s ≥ q}" (is "_ = ?Set1 ∪ ?Set2")
        by auto
      then have "measure ?M' ?B = measure ?M' (?Set1 ∪ ?Set2)" by simp
      also have "… = measure ?M' ?Set1"
        by (rule measure_spmf.measure_zero_union) (auto simp add: measure_spmf_zero_iff)
      also have "… = (∑j∈{0..<q}. measure ?M' ({j} × {((), (id, Some j), s')|s' id. True}))"
        by(subst measure_spmf.finite_measure_finite_Union[symmetric])
          (auto intro!: arg_cong2[where f=measure] simp add: disjoint_family_on_def)
      also have "… = (∑j∈{0..<q}. 1 / q * measure (measure_spmf (?multi'_body c_o c_a s)) {((), (id, Some j), s')|s' id. True})"
        by(simp add: measure_pair_spmf_times spmf_conv_measure_spmf[symmetric] spmf_of_set)
      also have "… = 1 / q * measure (measure_spmf (?multi'_body c_o c_a s)) {((), (id, Some j⇩s), s')|j⇩s s' id. j⇩s < q}"
        unfolding sum_distrib_left[symmetric]
        by(subst measure_spmf.finite_measure_finite_Union[symmetric])
          (auto intro!: arg_cong2[where f=measure] simp add: disjoint_family_on_def)
      finally show ?thesis by simp
    qed
    (* Rewrite the event as a preimage so that the measure becomes
       spmf (... body2 ...) True. *)
    also have "?B = (λ(j⇩s, _, (_, j⇩0), _). j⇩0 = Some j⇩s) -` {True}"
      by (auto simp add: vimage_def)
    also have rw2: "measure ?M' … = spmf (bind_spmf (spmf_of_set {..<q}) (body2 c_o c_a s)) True"
      by (simp add: body2_def[abs_def] measure_map_spmf[symmetric] map_spmf_conv_bind_spmf
        split_def pair_spmf_alt_def spmf_conv_measure_spmf[symmetric])
    finally show ?thesis .
  qed
  (* Step 3: integrate Step 2 over init.
     NOTE(review): the bound variables are named (c_a, c_o, s) here but
     (c_o, c_a, s) in Step 1; the two terms are alpha-equivalent. *)
  hence "spmf (bind_spmf init (λ(c_a, c_o, s). game_multi' c_a c_o s)) True = spmf ?game2 True * q"
    unfolding initialize_def spmf_bind[where p=init]
    by (auto intro!: integral_cong_AE simp del: integral_mult_left_zero simp add: integral_mult_left_zero[symmetric])
  moreover
  (* Step 4: pointwise, body2 is below body3 in the ord_spmf (⟶) order.
     NOTE(review): the hypothesis named j⇩s (j⇩s < Suc q) is not referenced by name
     in this subproof. *)
  have "ord_spmf (⟶) (body2 c_o c_a s j⇩s) (body3 c_o c_a s j⇩s)"
    if init: "(c_o, c_a, s) ∈ set_spmf init" and j⇩s: "j⇩s < Suc q" for c_o c_a s j⇩s
  proof -
    (* oracle2' evaluates the guess with index j⇩s immediately. It stores the
       verdict as Some b' in the first state component and remembers the guess
       together with the current oracle state. *)
    define oracle2' where "oracle2' ≡ λ(b, (id, gs), s) guess. if id = j⇩s then do {
        b' :: bool ← eval c_o c_a s guess;
        return_spmf ((), (Some b', (Suc id, Some (guess, s)), s))
      } else return_spmf ((), (b, (Suc id, gs), s))"
    let ?R = "λ((id1, j⇩0), s1) (b', (id2, gs), s2). s1 = s2 ∧ id1 = id2 ∧ (j⇩0 = Some j⇩s ⟶ b' = Some True) ∧ (id2 ≤ j⇩s ⟶ b' = None)"
    from init have "rel_spmf (rel_prod (=) ?R)
      (exec_gpv (extend_state_oracle (oracle c_o) ⊕⇩O eval_oracle' c_o c_a) (𝒜 c_a) ((0, None), s))
      (exec_gpv (extend_state_oracle (extend_state_oracle (oracle c_o)) ⊕⇩O oracle2') (𝒜 c_a) (None, (0, None), s))"
      by(intro exec_gpv_oracle_bisim[where X="?R"])(auto simp add: oracle2'_def eval_oracle'_def spmf_rel_map map_spmf_conv_bind_spmf[symmetric] rel_spmf_return_spmf2 lossless_eval o_def intro!: rel_spmf_reflI split: option.split_asm plus_oracle_split if_split_asm)
    then have "rel_spmf (⟶) (body2 c_o c_a s j⇩s) 
      (do {
        (_, b', _, _) ← exec_gpv (††(oracle c_o) ⊕⇩O oracle2') (𝒜 c_a) (None, (0, None), s);
        return_spmf (b' = Some True) })"
      (is "rel_spmf _ _ ?body2'")
      
      unfolding body2_def by(rule rel_spmf_bindI) clarsimp
    also
    (* Move the evaluation out of the oracle. ?guess_oracle only records guess
       number j⇩s, and I.exec_gpv_bind_materialize lets ?f evaluate the recorded
       guess after the run. *)
    let ?guess_oracle = "λ((id, gs), s) guess. return_spmf ((), (Suc id, if id = j⇩s then Some (guess, s) else gs), s)"
    let ?I = "λ(idgs, s). case idgs of (_, None) ⇒ False | (i, Some _) ⇒ j⇩s < i"
    interpret I: callee_invariant_on "†(oracle c_o) ⊕⇩O ?guess_oracle" "?I" ℐ_full
      by(simp)(unfold_locales; auto split: option.split)
    let ?f = "λs. case snd (fst s) of None ⇒ return_spmf False | Some a ⇒ eval c_o c_a (snd a) (fst a)"
    let ?X = "λj⇩s (b1, (id1, gs1), s1) (b2, (id2, gs2), s2). b1 = b2 ∧ id1 = id2 ∧ gs1 = gs2 ∧ s1 = s2 ∧ (b2 = None ⟷ gs2 = None) ∧ (id2 ≤ j⇩s ⟶ b2 = None)"
    have "?body2' = do {
      (a, r, s) ← exec_gpv (λ(r, s) x. do {
               (y, s') ← (†(oracle c_o) ⊕⇩O ?guess_oracle) s x;
               if ?I s' ∧ r = None then map_spmf (λr. (y, Some r, s')) (?f s') else return_spmf (y, r, s')
             })
         (𝒜 c_a) (None, (0, None), s);
      case r of None ⇒ ?f s ⤜ return_spmf | Some r' ⇒ return_spmf r' }"
      unfolding oracle2'_def spmf_rel_eq[symmetric]
      by(rule rel_spmf_bindI[OF exec_gpv_oracle_bisim'[where X="?X j⇩s"]])
        (auto simp add: bind_map_spmf o_def spmf.map_comp split_beta conj_comms map_spmf_conv_bind_spmf[symmetric] spmf_rel_map rel_spmf_reflI cong: conj_cong split: plus_oracle_split)
    also have "… = do {
        us' ← exec_gpv (†(oracle c_o) ⊕⇩O ?guess_oracle) (𝒜 c_a) ((0, None), s);
        (b' :: bool) ← ?f (snd us');
        return_spmf b' }"
      (is "_ = ?body2''")
      by(rule I.exec_gpv_bind_materialize[symmetric])(auto split: plus_oracle_split_asm option.split_asm)
    (* Switch to exec_gpv_stop with the lifted oracle (by transfer). *)
    also have "… = do {
        us' ← exec_gpv_stop (lift_stop_oracle (†(oracle c_o) ⊕⇩O ?guess_oracle)) (𝒜 c_a) ((0, None), s);
        (b' :: bool) ← ?f (snd us');
        return_spmf b' }"
      supply lift_stop_oracle_transfer[transfer_rule] gpv_stop_transfer[transfer_rule] exec_gpv_parametric'[transfer_rule]
      by transfer simp
    (* Stopping once the guess has been recorded yields a run above the previous
       one in the ord_spmf (⟶) order (ord_spmf_exec_gpv_stop, relation ?S). *)
    also let ?S = "λ((id1, gs1), s1) ((id2, gs2), s2). gs1 = gs2 ∧ (gs2 = None ⟶ s1 = s2 ∧ id1 = id2) ∧ (gs1 = None ⟷ id1 ≤ j⇩s)"
    have "ord_spmf (⟶) … (exec_gpv_stop ((λ((id, gs), s) x. case gs of None ⇒ lift_stop_oracle (†(oracle c_o)) ((id, gs), s) x | Some _ ⇒ return_spmf (None, ((id, gs), s))) ⊕⇩O⇧S
            (λ((id, gs), s) guess. return_spmf (if id ≥ j⇩s then None else Some (), (Suc id, if id = j⇩s then Some (guess, s) else gs), s)))
           (𝒜 c_a) ((0, None), s) ⤜
          (λus'. case snd (fst (snd us')) of None ⇒ return_spmf False | Some a ⇒ eval c_o c_a (snd a) (fst a)))"
      unfolding body3_def stop_oracle_def
      by(rule ord_spmf_exec_gpv_stop[where stop = "λ((id, guess), _). guess ≠ None" and S="?S", THEN ord_spmf_bindI])
        (auto split: prod.split_asm plus_oracle_split_asm split!: plus_oracle_stop_split simp del: not_None_eq simp add: spmf.map_comp o_def apfst_compose ord_spmf_map_spmf1 ord_spmf_map_spmf2 split_beta ord_spmf_return_spmf2 intro!: ord_spmf_reflI)
    (* The stopping run is bisimilar to body3 via ?X:
       - (id, None) with id ≤ j⇩s corresponds to Inr (j⇩s - id);
       - a recorded guess corresponds to Inl. *)
    also let ?X = "λ((id, gs), s1) (idgs, s2). s1 = s2 ∧ (case (gs, idgs) of (None, Inr id') ⇒ id' = j⇩s - id ∧ id ≤ j⇩s | (Some gs, Inl gs') ⇒ gs = gs' ∧ id > j⇩s | _ ⇒ False)"
    have "… = body3 c_o c_a s j⇩s" unfolding body3_def spmf_rel_eq[symmetric] stop_oracle_def
      by(rule exec_gpv_oracle_bisim'[where X="?X", THEN rel_spmf_bindI])
        (auto split: option.split_asm plus_oracle_stop_split nat.splits split!: sum.split simp add: spmf_rel_map intro!: rel_spmf_reflI)
    finally show ?thesis by(rule pmf.rel_mono_strong)(auto elim!: option.rel_cases ord_option.cases)
  qed
  (* Step 5: lift Step 4 to ?game2 ≤ ?game3, then relate ?game3 to the inlined
     reduction via ?X. Taking the measure of {True} (ord_spmf_measureD) gives
     spmf ?game2 True ≤ spmf (game_single (reduction q 𝒜)) True. *)
  { then have "ord_spmf (⟶) ?game2 ?game3"
      by(clarsimp simp add: initialize_def intro!: ord_spmf_bind_reflI)
    also
    let ?X = "λ(gsid, s) (gid, s'). s = s' ∧ rel_sum (λ(g, s1) g'. g = g' ∧ s1 = s') (=) gsid gid"
    have "rel_spmf (⟶) ?game3 (game_single (reduction q 𝒜))"
      unfolding body3_def stop_oracle_def game_single_def reduction_def split_def initialize_def
      apply(clarsimp simp add: bind_map_spmf exec_gpv_bind exec_gpv_inline intro!: rel_spmf_bind_reflI)
      apply(rule rel_spmf_bindI[OF exec_gpv_oracle_bisim'[where X="?X"]])
      apply(auto split: plus_oracle_stop_split elim!: rel_sum.cases simp add: map_spmf_conv_bind_spmf[symmetric] split_def spmf_rel_map rel_spmf_reflI rel_spmf_return_spmf1 lossless_eval split: nat.split)
      done
    finally have "ord_spmf (⟶) ?game2 (game_single (reduction q 𝒜))"
      by(rule pmf.rel_mono_strong)(auto elim!: option.rel_cases ord_option.cases)
    from this[THEN ord_spmf_measureD, of "{True}"]
    have "spmf ?game2 True ≤ spmf (game_single (reduction q 𝒜)) True" unfolding spmf_conv_measure_spmf
      by(rule ord_le_eq_trans)(auto intro: arg_cong2[where f=measure]) }
  (* Combine Steps 1, 3 and 5. *)
  ultimately show ?thesis unfolding advantage_multi_def advantage_single_def 
    by(simp add: mult_right_mono)
qed
end
end