(* Theory Definition_O2H *)

theory Definition_O2H

imports Registers.Pure_States
  O2H_Additional_Lemmas

begin

unbundle cblinfun_syntax
unbundle lattice_syntax

section ‹Definitions for the one-way to Hiding (O2H) Lemma›
text ‹Here, we first define the context of the O2H Lemma and foundations.›

text ‹First of all, we need a notion of a query to the oracle. This is defined in the unitary 
\<^term>‹Uquery›, where the input \<^term>‹H› is the (classical) oracle.›


(* Oracle query operator: the classical operator induced by the total map
   (x, y) |-> (x, y + H x) on pairs of basis indices. *)
definition Uquery :: ‹('x ⇒ ('y::plus)) ⇒ (('x × 'y) ell2 ⇒⇩C⇩L ('x × 'y) ell2)› where
  "Uquery H = classical_operator (Some o (λ(x,y). (x, y + (H x))))"


subsection ‹Locale for the general O2H setting›


text ‹Locale for O2H assumptions and setting.›

locale o2h_setting =
  ― ‹Fix types for instantiations of locales›
  fixes type_x ::"'x itself"
  fixes type_y :: "('y::group_add) itself"
  fixes type_mem :: "'mem itself"
  fixes type_l :: "'l itself"

― ‹\<^term>‹X› and \<^term>‹Y› are the embeddings of the (classical) oracle domain types. \<^typ>‹'mem› is the type of 
  the quantum memory we work on.›
fixes X :: "'x update ⇒ 'mem update"
fixes Y :: "('y::group_add) update ⇒ 'mem update" 

― ‹The embeddings \<^term>‹X› and \<^term>‹Y› must be compatible with the registers.›
assumes compat[register]: "mutually compatible (X,Y)"

― ‹We fix the query depth \<^term>‹d› of $A$. We require at least one query.›
fixes d :: nat 
assumes d_gr_0: "d > 0" 
  ― ‹The initial quantum state \<^term>‹init› of the registers. For this version of the O2H, we work with 
    a pure initial state.›
fixes init :: ‹'mem ell2›
assumes norm_init: "norm init = 1" ― ‹\<^term>‹init› is a pure state›


― ‹The type \<^typ>‹'l› represents the quantum register for the query log.
    We also need three functions depending on the type \<^typ>‹'l›, namely \<^term>‹flip›, \<^term>‹bit› and \<^term>‹valid›.
    \<^term>‹flip› is a bit-flipping operation that changes bits on the valid set and may behave like an
    identity function on the rest.
    \<^term>‹bit› is a function returning the $i$-th bit of a valid element in \<^typ>‹'l›.
    \<^term>‹valid› is a functional representation of the valid set of the query log.
    Since \<^typ>‹'l› may be (theoretically) infinitely large, we need to restrict on the valid set in 
    many lemmas.›

fixes flip:: ‹nat ⇒ 'l ⇒ 'l›
fixes bit:: ‹'l ⇒ nat ⇒ bool›
fixes valid:: ‹'l ⇒ bool›
fixes empty :: 'l

― ‹Empty is the initial state on \<^typ>‹'l› (equalling the zero state).›
assumes valid_empty: "valid empty"

― ‹Assumptions on \<^term>‹flip›, \<^term>‹bit› and \<^term>‹valid›: 
    \<^term>‹flip› is a function that takes an index \<^term>‹i› and an element \<^term>‹l::'l› and 
    "flips the $i$-th bit". However, to remain in the valid range, this flip is only performed for
    indices smaller than \<^term>‹d›, otherwise we may assume \<^term>‹flip› to be the identity.›
assumes valid_flip: "i<d ⟹ valid l ⟹ valid (flip i l)"
  ― ‹The \<^term>‹flip› operation must be injective for every index and, for indices below \<^term>‹d›, 
    an involution on valid elements (flipping twice restores the element).›
assumes inj_flip: "inj (flip i)"
assumes valid_flip_flip: "i<d ⟹ valid l ⟹ flip i (flip i l) = l"
  ― ‹The \<^term>‹flip› operation must be commutative with itself.›
assumes valid_flip_comm: "i<d ⟹ j<d ⟹ valid l ⟹ flip i (flip j l) = flip j (flip i l)"

― ‹For valid elements, the bits in the range up to \<^term>‹d› behave as in a normal 
      bit-flipping operation.›
assumes valid_bit_flip_same: "i<d ⟹ valid l ⟹ bit (flip i l) i = (¬ (bit l i))"
assumes valid_bit_flip_diff: "i<d ⟹ valid l ⟹ i≠j ⟹ bit (flip i l) j = bit l j"


begin
text ‹We introduce a set of $2^d$ valid elements for counting.
Since we need a finite set for easier proofs while counting the adversarial queries, we embed the 
set of $2^d$ elements into the valid set. 
The elements from \<^term>‹blog› can all be derived by flipping bits from the initial empty state.
We then only look at the elements with bits in the first $d$ entries.›
(* blog is the least predicate containing empty and closed under flip i for i < d. *)
inductive blog :: "'l ⇒ bool" where
  "blog empty"
| "blog l ⟹ i<d ⟹ blog (flip i l)"


(* The two introduction rules of blog, restated as named lemmas. *)
lemma blog_empty: "blog empty"
  by (rule blog.intros)

lemma blog_flip: "i<d ⟹ blog l ⟹ blog (flip i l)" (*bij_on valid flip*)
  by (rule blog.intros)

(* Every blog element is valid: by rule induction, using valid_empty and valid_flip. *)
lemma blog_valid:
  "blog l ⟹ valid l"
  by (induction rule: blog.induct) (auto simp add: valid_empty valid_flip)


(* The locale assumptions on valid elements, specialised to blog elements via blog_valid. *)
lemma flip_flip: "i<d ⟹ blog l ⟹ flip i (flip i l) = l"
  using blog_valid valid_flip_flip by auto

lemma bit_flip_same: "i<d ⟹ blog l ⟹ bit (flip i l) i = (¬ (bit l i))"
  using blog_valid valid_bit_flip_same by auto

lemma bit_flip_diff: "i<d ⟹ blog l ⟹ i≠j ⟹ bit (flip i l) j = bit l j"
  using blog_valid valid_bit_flip_diff by auto

text ‹The embedding of a boolean list (of length $d$) into the \<^term>‹blog› set.›

(* Reads the list head first: a True head flips the bit whose index is the length of the
   remaining tail; a False head leaves the value of the tail unchanged. *)
fun list_to_l :: "bool list ⇒ 'l" where  (* intended only for d-length lists *)
  "list_to_l [] = empty" |
  "list_to_l (False # list) = list_to_l list" |
  "list_to_l (True # list) = flip (length list) (list_to_l list)"


(* The set of all boolean lists of length d. *)
definition len_d_lists :: "bool list set" where
  "len_d_lists = {xs. length xs = d}"

(* There are 2^d such lists; proved by rewriting to the form expected by card_lists_length_eq. *)
lemma card_len_d_lists:
  "card (len_d_lists) = (2::nat)^d" 
proof -
  have l: "len_d_lists = {xs. set xs ⊆ {True, False} ∧ length xs = d}" 
    unfolding len_d_lists_def by auto
  show ?thesis unfolding l 
    by (subst card_lists_length_eq[of "{True, False}"])(auto simp add: numeral_2_eq_2)
qed

(* Finiteness follows because the cardinality 2^d is nonzero. *)
lemma finite_len_d_lists[simp]:
  "finite len_d_lists"
  using card_len_d_lists card.infinite by force




(* Lists of length at most d are mapped into blog by list_to_l. *)
lemma blog_list_to_l:
  assumes "length ls ≤ d"
  shows "blog (list_to_l ls)"
  using assms by (induction rule: list_to_l.induct) (auto simp add: blog.intros)

(* Flips at indices below d commute on images of list_to_l (instance of valid_flip_comm). *)
lemma flip_commute:
  assumes "i≠j" "i<d" "j<d" "length ls ≤ d"
  shows "flip i (flip j (list_to_l ls)) = flip j (flip i (list_to_l ls))"
  by (simp add: assms(2) assms(3) assms(4) blog_list_to_l blog_valid valid_flip_comm)

(* Flipping bit i of list_to_l ls equals negating the list entry at position
   length ls - i - 1 before applying list_to_l. Proof by induction along list_to_l,
   distinguishing whether i addresses the head (i = length of the tail) or the tail. *)
lemma flip_list_to_l:
  assumes "i < length ls" "length ls ≤ d"
  shows "flip i (list_to_l ls) = list_to_l (ls[length ls - i - 1 := ¬ ls ! (length ls - i - 1)])"
  using assms proof (induction ls arbitrary: i rule: list_to_l.induct)
  case (2 l)
  have "i<d" using 2 by auto
  show ?case proof (cases "i=length l")
    case True 
    have "flip i (list_to_l (False # l)) = flip i (list_to_l l)" by auto
    also have "… = list_to_l (True # l)" by (simp add: True)
    also have "… = list_to_l ((False # l)[length (False # l) - i - 1 :=
         ¬ (False # l) ! (length (False # l) - i - 1)])" unfolding ‹i=length l› by auto
    finally show ?thesis by auto
  next
    case False
    then have "i<length l" using "2"(2) by auto
    let ?l = "False # l"
    have "flip i (list_to_l (False # l)) = flip i (list_to_l l)" by auto
    also have "… = list_to_l (l[length l-i-1 := ¬l!(length l-i-1)])" using 2 
      by (simp add: ‹i < length l›)
    also have "… = list_to_l (?l[length ?l-i-1 := ¬?l!(length ?l-i-1)])"
      by (smt (verit, ccfv_threshold) "2"(2) Nat.diff_add_assoc2 Suc_diff_Suc Suc_eq_plus1 
          ‹i < length l› diff_Suc_1 diff_is_0_eq leD le_simps(3) list.size(4) list_update_code(3) 
          nth_Cons' numeral_nat(7) list_to_l.simps(2))
    finally show ?thesis by auto
  qed
next
  case (3 l)
  have "i<d" using 3 by auto
  show ?case proof (cases "i=length l")
    case True 
    have blog: "blog (list_to_l l)" using blog_list_to_l 3(3) by auto
    have "flip i (list_to_l (True # l)) = flip i (flip i (list_to_l l))" using True by auto
    also have "… = list_to_l l" using flip_flip[OF ‹i<d›] blog by auto
    finally show ?thesis using True by auto
  next
    case False
    then have "i<length l" using "3"(2) by auto
    let ?l = "True # l"
    have "flip i (list_to_l ?l) = flip i (flip (length l) (list_to_l l))" by auto
    also have "… = flip (length l) (flip i (list_to_l l))"
      by (intro flip_commute) (use 3 in ‹auto simp add: False ‹i<d››)
    also have "… = flip (length l) (list_to_l (l[length l-i-1 := ¬l!(length l-i-1)]))"
      using "3"(3) "3.IH" ‹i < length l› by auto
    also have "… = list_to_l (True # (l[length l-i-1 := ¬l!(length l-i-1)]))" by simp
    also have "… = list_to_l (?l[length ?l-i-1 := ¬?l!(length ?l-i-1)])"
      by (smt (verit, ccfv_threshold) Suc_diff_Suc ‹i < length l›
          cancel_ab_semigroup_add_class.diff_right_commute diff_Suc_eq_diff_pred length_tl list.sel(3) 
          list_update_code(3) nth_Cons_Suc)
    finally show ?thesis by auto
  qed
qed auto

text ‹The initial list corresponding to the initial value \<^term>‹empty› is the list containing only 
\<^term>‹False›.›

(* The all-False list of length d. *)
definition empty_list where
  "empty_list = replicate d False"

(* Any all-False list, of any length n, is mapped to empty. *)
lemma empty_list_to_l_replicate:
  "list_to_l (replicate n False) = empty" 
  by (induct n, auto)

lemma empty_list_to_l [simp]:
  "list_to_l empty_list = empty" 
  by (auto simp add: empty_list_to_l_replicate empty_list_def)

lemma empty_list_len_d[simp]:
  "empty_list ∈ len_d_lists"
  unfolding empty_list_def len_d_lists_def by auto

(* Consequently empty lies in the image of len_d_lists under list_to_l. *)
lemma empty_list_to_l_elem [simp]:
  "empty ∈ list_to_l ` len_d_lists"
  by (metis empty_list_len_d empty_list_to_l imageI)


text ‹Lemmas on how \<^term>‹list_to_l› works with \<^term>‹flip› and \<^term>‹bit›.›

(* Negating list position i corresponds to flipping bit length ls - 1 - i of list_to_l ls.
   Proof by induction along list_to_l with a case split on whether i addresses the head. *)
lemma list_to_l_flip:
  assumes "i < length ls" "length ls ≤ d"
  shows "list_to_l (ls[i := ¬ ls ! i]) = flip (length ls - 1 - i) (list_to_l ls)"
  using assms proof (induction ls arbitrary: i rule: list_to_l.induct)
  case (2 list)
  then show ?case proof (cases "i=0")
    case False
    then obtain j where j: "i = Suc j" using not0_implies_Suc by presburger
    have "list_to_l ((False # list)[i := ¬ (False # list) ! i]) = 
      list_to_l (False # list[i-1 := ¬ list ! (i-1)])" unfolding j by auto
    also have "… = list_to_l (list[i-1 := ¬ list ! (i-1)])" by auto
    also have "… = flip (length list - 1 - (i-1)) (list_to_l list)" using 2 "2.IH" j by auto
    also have "… = flip (length (False # list) - 1 - i) (list_to_l (False # list))"
      by (simp add: j)
    finally show ?thesis by auto
  qed auto
next
  case (3 list)
  then show ?case proof (cases "i=0")
    case True
    have False: "(True # list)[i := ¬ (True # list) ! i] = False # list" using True by auto
    have len: "length (True # list) - 1 - i = length list" using True by auto
    have len': "length list < d" using 3 by auto
    have len'': "length list ≤ d" using 3 by auto
    have "list_to_l list = flip (length list) (flip (length list) (list_to_l list))"
      using flip_flip[OF len'] by (simp add: blog_list_to_l len'') 
    then show ?thesis unfolding False len by (subst list_to_l.simps)+ auto
  next
    case False
    then obtain j where j: "i = Suc j" using not0_implies_Suc by presburger
    have len: "length list < d" using 3 by auto
    have "list_to_l ((True # list)[i := ¬ (True # list) ! i]) = 
      list_to_l (True # list[i-1 := ¬ list ! (i-1)])" unfolding j by auto
    also have "… = flip (length list) (list_to_l (list[i-1 := ¬ list ! (i-1)]))" by auto
    also have "… = flip (length list) (flip (length list - 1 - (i-1)) (list_to_l list))" 
      using 3 "3.IH" j by auto
    also have "… = flip (length list) (flip (length (True # list) - 1 - i) (list_to_l list))"
      by (simp add: j)
    also have "… = flip (length (True # list) - 1 - i) (flip (length list) (list_to_l list))"
      by (intro flip_commute) (use 3 j in ‹auto simp add: len›)
    finally show ?thesis by auto
  qed
qed auto

(* The image of len_d_lists under list_to_l is exactly the set of blog elements.
   First direction: induction along list_to_l. Second direction: rule induction on blog,
   using flip_list_to_l to realise a flip as a list update. *)
lemma surj_list_to_l: "list_to_l  ` len_d_lists = Collect blog" 
proof (safe, goal_cases)
  case (1 _ xa)
  then have "length xa ≤ d" unfolding len_d_lists_def by auto
  then show ?case proof (induct xa rule: list_to_l.induct)
    case 1
    show ?case by (auto simp add: blog.intros)
  next
    case (2 list)
    then show ?case by (subst list_to_l.simps(2), simp)
  next
    case (3 list)
    then show ?case by (subst list_to_l.simps(3), intro blog.intros, auto)
  qed
next
  case (2 x)
  then show ?case proof (induct rule: blog.induct)
    case 1
    show ?case by (subst empty_list_to_l[symmetric]) auto
  next
    case (2 l i)
    then obtain ls where ls: "list_to_l ls = l" "length ls = d" using len_d_lists_def by force
    have "blog (flip i l)" using 2 by (intro blog.intros, auto)
    define flip_ls where "flip_ls = ls [d-i-1:= ¬ ls!(d-i-1)]"
    then have "length flip_ls = d" using ‹length ls = d› by auto
    moreover have "list_to_l flip_ls = flip i l" unfolding flip_ls_def ls(2)[symmetric] 
      by (subst flip_list_to_l[symmetric]) (auto simp add:2 ls)
    ultimately show ?case unfolding len_d_lists_def by (simp add: rev_image_eqI)
  qed
qed

(* Bits at indices i with length l <= i < d are untouched by list_to_l l:
   they agree with the corresponding bit of empty. *)
lemma bit_list_to_l_over:
  assumes "length l ≤ d" "i<d" "length l ≤ i"
  shows "bit (list_to_l l) i = bit empty i"
  using assms proof (induct rule: list_to_l.induct)
  case (2 list)
  then show ?case using bit_flip_same[OF ‹i<d›] by auto
next
  case (3 list)
  then show ?case by (auto simp add: bit_flip_diff blog_list_to_l)
qed auto


(* For i < length l, bit i of list_to_l l is the bit of empty, negated exactly when the list
   entry at position length l - i - 1 is True. Each inductive case splits on i = length list
   (the head position, case c2) versus i ≠ length list (a tail position, case c1). *)
lemma bit_list_to_l:
  assumes "length l ≤ d" "i<length l"
  shows "bit (list_to_l l) i = (if l!(length l-i-1) then ¬ bit empty i else bit empty i)"
  using assms proof (induct rule: list_to_l.induct)
  case (2 list)
  let ?l = "False # list"
  have "length list ≤ d" using 2 by auto
  have "i<d" using 2 by auto
  have rew: "bit (list_to_l ?l) i = bit (list_to_l list) i" by auto
  have c1: "bit (list_to_l list) i = (if ?l!(length ?l-i-1) then ¬ bit empty i else bit empty i)"
    if "i≠length list" using "2"(1) "2"(3) ‹length list ≤ d› that by auto
  have c2: "bit (list_to_l list) i = (if ?l!(length ?l-i-1) then ¬ bit empty i else bit empty i)"
    if "i=length list" using that by (subst bit_list_to_l_over[OF ‹length list ≤ d› ‹i<d›]) auto
  show ?case by (subst rew, cases "i = length list")(use c1 c2 in auto)
next
  case (3 list)
  let ?l = "True # list"
  have "length list ≤ d" using 3 by auto
  have "i<d" using 3 by auto
  have rew: "bit (list_to_l ?l) i = bit (flip (length list) (list_to_l list)) i" (is "_ = ?right") 
    by auto
  have c1: "?right = (if ?l!(length ?l-i-1) then ¬ bit empty i else bit empty i)" 
    if "i≠length list"
  proof -
    have "?right = bit (list_to_l list) i" using that "3"(2) blog_list_to_l 
        blog_valid valid_bit_flip_diff by force
    also have "… = (if ?l!(length ?l-i-1) then ¬ bit empty i else bit empty i)" 
      using 3 ‹length list ≤ d› that by auto
    finally show ?thesis by auto
  qed
  have c2: "?right = (if ?l!(length ?l-i-1) then ¬ bit empty i else bit empty i)"
    if "i=length list" 
  proof -
    have "?right = (¬ bit (list_to_l list) i)"
      using ‹i < d› ‹length list ≤ d› bit_flip_same blog_list_to_l that by blast
    also have "… = (¬ bit empty i)"
      using ‹i < d› bit_list_to_l_over less_or_eq_imp_le that by blast 
    finally show ?thesis using that by auto
  qed
  show ?case by (subst rew, cases "i = length list")(use c1 c2 in auto)
qed auto

(* list_to_l is injective on lists of length d. Induction on d (Nat.dec_induct): split both
   lists into head and tail; equal heads reduce to the tails, different heads are contradictory
   by bit_list_to_l. *)
lemma list_to_l_eq:
  assumes "list_to_l xs = list_to_l ys" "length xs = d" "length ys = d"
  shows "xs = ys"
  using le0[of d] assms proof (induct d arbitrary: xs ys rule: Nat.dec_induct)
  case (step n)
  obtain x xs' where xs: "xs = x # xs'" "length xs' = n" using step by (meson length_Suc_conv)
  obtain y ys' where ys: "ys = y # ys'" "length ys' = n"using step by (meson length_Suc_conv)
  consider (same) "x=y" | (neq) "x≠y" by blast 
  then have "list_to_l xs' = list_to_l ys' ∧ x=y" 
  proof (cases)
    case same
    then have "list_to_l xs' = list_to_l ys'"
      by (metis (full_types) blog_list_to_l flip_flip list_to_l.simps(2,3) step(2,4) 
          not_le order.asym xs ys)
    then show ?thesis using same by blast
  next
    case neq
    have False by (metis One_nat_def Suc_leI bit_list_to_l diff_Suc_1 diff_add_inverse2 lessI 
          step(2,4,5,6) neq nth_Cons_0 plus_1_eq_Suc xs(1) ys(1))
    then show ?thesis by auto
  qed
  then show ?case unfolding xs ys by (simp add: local.step(3) xs(2) ys(2))
qed auto


(* list_to_l is injective on len_d_lists; a direct consequence of list_to_l_eq. *)
lemma inj_list_to_l: "inj_on list_to_l (len_d_lists)" 
  unfolding inj_on_def proof (safe, goal_cases)
  case (1 xs ys)
  have len: "length xs = d" "length ys = d" using 1 unfolding len_d_lists_def by auto
  show ?case using len 1 list_to_l_eq by auto
qed


(* Combining injectivity (inj_list_to_l) with the image characterisation (surj_list_to_l). *)
lemma bij_betw_list_to_l: "bij_betw list_to_l len_d_lists (Collect blog)"
  using bij_betw_def inj_list_to_l surj_list_to_l by blast

(* Hence there are exactly 2^d blog elements. *)
lemma card_blog: "card (Collect blog) = 2^d"
  by (metis card_image card_len_d_lists inj_list_to_l surj_list_to_l)





text ‹We split the $2^d$ elements into elements that have bits only in a certain set.
This is later used to argue that an adversary running up to some $n$ can only generate a 
count up to the $n$-th bit.›

(* How to address values in {2^n..<2^d} in len_d_list *)
(* has_bits A: the length-d lists having a True entry at position d-i-1 for some i in A. *)
definition has_bits :: "nat set ⇒ bool list set" where
  "has_bits A = {l. l∈len_d_lists ∧ True ∈ (λi. l!(d-i-1)) ` A}"


(* With no indices to inspect, no list qualifies. *)
lemma has_bits_empty[simp]:
  "has_bits {} = {}" 
  unfolding has_bits_def by auto

(* A list in has_bits A is not mapped to empty by list_to_l. *)
lemma has_bits_not_empty:
  assumes "y ∈ has_bits A" "A≠{}" "y∈len_d_lists"
  shows "list_to_l y ≠ empty"
proof -
  obtain x where "x∈A" "y!(d-x-1)" using assms unfolding has_bits_def by auto
  then show ?thesis
    by (smt (verit, best) assms(3) bit_list_to_l d_gr_0 diff_Suc_1 diff_diff_cancel diff_is_0_eq' 
        diff_le_self le_simps(2) len_d_lists_def mem_Collect_eq not_le)
qed

(* In particular the all-False list has no bit set in the range {0..<d}. *)
lemma has_bits_empty_list:
  "empty_list ∉ has_bits {0..<d}"
  using has_bits_not_empty by fastforce

(* has_bits is monotone in the index set. *)
lemma has_bits_incl:
  assumes "A⊆B"
  shows "has_bits A ⊆ has_bits B"
  using assms has_bits_def by auto

(* has_bits A only contains lists of length d ... *)
lemma has_bits_in_len_d_lists[simp]:
  "has_bits A ⊆ len_d_lists"
  unfolding has_bits_def by auto

(* ... and is therefore finite. *)
lemma finite_has_bits[simp]:
  "finite (has_bits A)"
  by (meson finite_len_d_lists has_bits_in_len_d_lists rev_finite_subset)

(* Negating the entry addressed by an index n outside A keeps a list inside has_bits A:
   the witness index i in A differs from n, so its entry is unaffected by the update. *)
lemma has_bits_not_elem:
  assumes "y∈has_bits A" "A≠{}" "A⊆{0..<d}" "y∈len_d_lists" "n ∉ A" "n<d"
  shows "y[d-n-1:=¬y!(d-n-1)] ∈ has_bits A"
proof -
  obtain i where i: "y ! (d - i - 1)" "i∈A" using assms has_bits_def by auto
  then have "n≠i" using assms by auto
  then have "y[d-n-1 := ¬ y!(d-n-1)]!(d-i-1)" using i assms by (subst nth_list_update_neq) auto
  moreover have "length (y[d - n - 1 := ¬ y ! (d - n - 1)]) = d" using assms len_d_lists_def by auto
  ultimately show ?thesis using ‹i∈A› unfolding has_bits_def len_d_lists_def by auto
qed

(* Splits off the lowest index n of the range {n..<d}. *)
lemma has_bits_split_Suc:
  assumes "n<d"
  shows "has_bits {n..<d} = has_bits {n} ∪ has_bits {Suc n..<d}"
proof -
  have "x ∈ len_d_lists ⟹ x ! (d - Suc xa) ⟹ ∀xa∈{Suc n..<d}. ¬ x ! (d - Suc xa) ⟹
       n ≤ xa ⟹ xa < d ⟹ x ! (d - Suc n)" for x xa
    by (metis atLeastLessThan_iff le_eq_less_or_eq le_simps(3))
  moreover have "x ∈ len_d_lists ⟹ x ! (d - Suc n) ⟹ ∃xa∈{n..<d}. x ! (d - Suc xa)" for x
    using assms atLeastLessThan_iff by blast
  ultimately show ?thesis unfolding has_bits_def by auto
qed

text ‹The function \<^term>‹has_bits_upto› looks only at elements with bits lower than some $n$.›

(* has_bits_upto n: the length-d lists that are not in has_bits {n..<d}. *)
definition has_bits_upto where
  "has_bits_upto n = len_d_lists - has_bits {n..<d}"

(* Finite, being a subset of the finite set len_d_lists. *)
lemma finite_has_bits_upto [simp]:
  "finite (has_bits_upto n)"
  unfolding has_bits_upto_def by auto

(* A length-d list outside has_bits A has a False entry at the position of every index in A. *)
lemma has_bits_elem:
  assumes "x ∈ len_d_lists - has_bits A" "a∈A"
  shows "¬x!(d-a-1)"
  using assms(1) assms(2) has_bits_def by force

(* Special case for has_bits_upto n and the index n itself. *)
lemma has_bits_upto_elem:
  assumes "x ∈ has_bits_upto n" "n<d"
  shows "¬x!(d-n-1)"
  using assms has_bits_upto_def by (intro has_bits_elem[of x "{n..<d}" n]) auto

(* has_bits_upto is monotone in its bound. *)
lemma has_bits_upto_incl:
  assumes "n ≤ m"
  shows "has_bits_upto n ⊆ has_bits_upto m"
  using assms unfolding has_bits_upto_def by (simp add: Diff_mono has_bits_incl)

(* At bound d the range {d..<d} is empty, so nothing is removed. *)
lemma has_bits_upto_d:
  "has_bits_upto d = len_d_lists"
  unfolding has_bits_upto_def by auto

(* The all-False list belongs to has_bits_upto n for every n. *)
lemma empty_list_has_bits_upto:
  "empty_list ∈ has_bits_upto n" 
  using empty_list_to_l has_bits_not_empty has_bits_upto_def by fastforce

lemma empty_list_to_l_has_bits_upto:
  "empty ∈ list_to_l ` has_bits_upto n"
  using empty_list_has_bits_upto empty_list_to_l by (metis image_eqI)

(* Removing the all-False list from len_d_lists leaves exactly has_bits {0..<d}.
   The explicit goal case shows that a length-d list outside has_bits {0..<d} has only
   False entries and hence equals empty_list. *)
lemma len_d_empty_has_bits:
  "len_d_lists - {empty_list} = has_bits {0..<d}"
proof (safe, goal_cases)
  case (1 x)
  then have "¬ x!(d-i-1)" if "i<d" for i using has_bits_elem that by auto
  then have "¬ x!i" if "i<d" for i
    by (metis Suc_leI d_gr_0 diff_Suc_less diff_add_inverse diff_diff_cancel plus_1_eq_Suc that)
  then have "x = empty_list" unfolding empty_list_def
    by (smt (verit, best) "1"(1) in_set_conv_nth len_d_lists_def mem_Collect_eq replicate_eqI)
  then show ?case by auto
qed (auto simp add: has_bits_def has_bits_empty_list empty_list_def)






text ‹Properties of \<^term>‹d››
(* Since d > 0 (locale assumption d_gr_0), 2^d exceeds 1. *)
lemma two_d_gr_1:
  "2^d > (1::nat)"
  by (meson d_gr_0 one_less_power rel_simps(49) semiring_norm(76))

text ‹Lemmas on \<^term>‹flip›, \<^term>‹bit› and \<^term>‹valid›.›
(* Complements of the valid / blog sets written as preimages of False. *)
lemma valid_inv: "- Collect valid = valid -` {False}" by auto

lemma blog_inv: "- Collect blog = blog -` {False}" by auto

(* Flipping at an index below d does not lead from outside blog into blog. *)
lemma not_blog_flip: "i<d ⟹ (¬ blog l) ⟹ (¬ blog (flip i l))"
  by (metis blog.intros(2) blog_valid inj_def inj_flip valid_flip_flip)




text ‹Lemmas on \<^term>‹X› and \<^term>‹Y›. 
\<^term>‹X› and \<^term>‹Y› are embeddings of the classical memory parts of input and output registers to the 
oracle function into the quantum register \<^typ>‹'mem›.›
(* X and Y are registers; discharged by auto (the locale assumption compat is
   declared with the [register] attribute). *)
lemma register_X:
  "register X"
  by auto

lemma register_Y:
  "register Y"
  by auto

(* X maps the zero operator to zero, by linearity of registers. *)
lemma X_0:
  "X 0 = 0" using clinear_register complex_vector.linear_0 register_X by blast


text ‹How to check that no qubit in \<^typ>‹'x› is in the set \<^term>‹S› in a quantum setting.
This is more complicated, since we cannot just ask if $x\in S$. We need to ask for the embedding 
of the projection of the classical set $S$ in the register \<^term>‹X›.›

(* Projector onto the closed span of the basis states ket m for m in M. *)
definition "proj_classical_set M = Proj (ccspan (ket ` M))" 
  (* This definition was taken from https://github.com/dominique-unruh/qrhl-tool/blob/
ecffff7667ab1e9b2cf957de82dfa7d22a8bd91a/isabelle-thys/QRHL_Core.thy#LL1864C1-L1864C69 *)
(* Embeddings via X of the projector onto the set described by S' and onto its complement. *)
definition "S_embed S' = X (proj_classical_set (Collect S'))"
definition "not_S_embed S' = X (proj_classical_set (- (Collect S')))"


(* proj_classical_set M is a projector. *)
lemma is_Proj_proj_classical_set:
  "is_Proj (proj_classical_set M)"
  unfolding proj_classical_set_def by auto

(* The projectors onto a classical set and onto its complement sum to the identity. *)
lemma proj_classical_set_split_id:
  "id_cblinfun = proj_classical_set M + proj_classical_set (-M)"
  unfolding proj_classical_set_def
  by (smt (verit) Compl_iff Proj_orthog_ccspan_union Proj_top boolean_algebra_class.sup_compl_top 
      ccspan_range_ket imageE image_Un orthogonal_ket)

(* For a finite set the projector is the sum of the butterflies ket i ket i^*.
   Induction on finiteness; the insert step uses Proj_orthog_ccspan_insert. *)
lemma proj_classical_set_sum_ket_finite:
  assumes "finite A"
  shows "proj_classical_set A = (∑i∈A. selfbutter (ket i))"
  using assms proof (induction A rule: Finite_Set.finite.induct)
  case (insertI A a)
  show ?case proof (cases "a∈A")
    case False
    have insert: "ket ` (insert a A) = insert (ket a) (ket ` A)" by auto
    have Proj: "Proj (ccspan (ket ` A)) = (∑i∈A. proj (ket i))" 
      using insertI unfolding proj_classical_set_def by (auto simp add: butterfly_eq_proj)
    show ?thesis unfolding proj_classical_set_def insert 
    proof (subst Proj_orthog_ccspan_insert, goal_cases)
      case 2
      then show ?case unfolding Proj
        by (simp add: False butterfly_eq_proj local.insertI(1))
    qed  (auto simp add: False)
  qed (simp add: insertI insert_absorb)
qed (auto simp add: proj_classical_set_def)


(* Basis states outside A are annihilated by the projector ... *)
lemma proj_classical_set_not_elem:
  assumes "i∉A"
  shows "proj_classical_set A *⇩V ket i = 0"
  by (metis Compl_iff Proj_fixes_image add_cancel_right_left assms cblinfun.add_left ccspan_superset' 
      id_cblinfun_apply proj_classical_set_def proj_classical_set_split_id rev_image_eqI)

(* ... and basis states inside A are fixed. *)
lemma proj_classical_set_elem:
  assumes "i∈A"
  shows "proj_classical_set A *⇩V ket i = ket i"
  using assms by (simp add: Proj_fixes_image ccspan_superset' proj_classical_set_def)


(* Instance of proj_classical_set_not_elem for the upward-closed set {j..} over nat. *)
lemma proj_classical_set_upto:
  assumes "i<j"
  shows "proj_classical_set {j..} *⇩V ket (i::nat) = 0"
  by (intro proj_classical_set_not_elem) (use assms in auto)

(* For finite A the projector truncates a vector to its coordinates in A. *)
lemma proj_classical_set_apply:
  assumes "finite A"
  shows "proj_classical_set A *⇩V y = (∑i∈A. Rep_ell2 y i *⇩C ket i)"
  unfolding proj_classical_set_def trunc_ell2_as_Proj[symmetric]
  by (intro trunc_ell2_finite_sum, simp add: assms)

(* Splits the projector onto {n..} into the projector onto ket n plus the projector
   onto {Suc n..}. *)
lemma proj_classical_set_split_Suc:
  "proj_classical_set {n..} = proj (ket n) + proj_classical_set {Suc n..}"
proof -
  have " ket ` {n..} = insert (ket n) (ket ` {Suc n..})" by fastforce
  then show ?thesis unfolding proj_classical_set_def 
    by (subst Proj_orthog_ccspan_insert[symmetric]) auto 
qed

(* Projectors onto two sets with mutually orthogonal basis states add up to the
   projector onto the union. *)
lemma proj_classical_set_union:
  assumes "⋀x y. x ∈ ket ` A ⟹ y ∈ ket ` B ⟹ is_orthogonal x y"
  shows "proj_classical_set (A ∪ B) = proj_classical_set A + proj_classical_set B"
  unfolding proj_classical_set_def 
  by (subst image_Un, intro Proj_orthog_ccspan_union)(auto simp add: assms)


text ‹Later, we need to project only on the second part of the register (the counting part).›

(* Acts as the identity on the first tensor factor and as proj_classical_set A on the second. *)
definition Proj_ket_set :: "'a set ⇒ ('mem × 'a) update" where
  "Proj_ket_set A = id_cblinfun ⊗⇩o proj_classical_set A"

(* Product vectors whose second factor is a basis state in A are fixed. *)
lemma Proj_ket_set_vec:
  assumes "y ∈ A"
  shows "Proj_ket_set A *⇩V (v ⊗⇩s ket y) = v ⊗⇩s ket y"
  unfolding Proj_ket_set_def using proj_classical_set_elem[OF assms] 
  by (auto simp add: tensor_op_ell2)


(* Proj_ket_set for the image of a set of boolean lists under list_to_l. *)
definition Proj_ket_upto :: "bool list set ⇒ ('mem × 'l) update" where
  "Proj_ket_upto A = Proj_ket_set (list_to_l ` A)"

lemma Proj_ket_upto_vec:
  assumes "y ∈ A"
  shows "Proj_ket_upto A *⇩V (v ⊗⇩s ket (list_to_l y)) = v ⊗⇩s ket (list_to_l y)"
  unfolding Proj_ket_upto_def using assms by (auto intro!: Proj_ket_set_vec)


text ‹We can split a state into two parts: the part in \<^term>‹S› and the one not in \<^term>‹S›.›
(* S_embed S' and not_S_embed S' sum to the identity: the two classical projectors sum to
   the identity, and the register X is linear. *)
lemma S_embed_not_S_embed_id [simp]:
  "S_embed S' + not_S_embed S' = id_cblinfun"
proof -
  have "proj_classical_set (Collect S') + proj_classical_set (- (Collect S')) = id_cblinfun"
    unfolding proj_classical_set_def 
    by (subst Proj_orthog_ccspan_union[symmetric]) (auto simp add: image_Un[symmetric])
  then have *: "X (proj_classical_set (Collect S') + proj_classical_set (- (Collect S'))) = 
    X id_cblinfun" by auto
  have "X (proj_classical_set (Collect S')) + X (proj_classical_set (- (Collect S'))) = 
    X id_cblinfun" unfolding *[symmetric]
    using clinear_register[OF register_X] by (simp add: clinear_iff)
  then show ?thesis unfolding S_embed_def not_S_embed_def by auto
qed

(* Pointwise version of S_embed_not_S_embed_id, applied to a basis state ket a. *)
lemma S_embed_not_S_embed_add:
  "S_embed S' (ket a) + not_S_embed S' (ket a) = ket a"
  using S_embed_not_S_embed_id
  by (metis cblinfun_id_cblinfun_apply plus_cblinfun.rep_eq)

(* S_embed S' and not_S_embed S' are idempotent and self-adjoint; each proof pulls the
   property through the register X back to the underlying Proj. *)
lemma S_embed_idem [simp]:
  "S_embed S' o⇩C⇩L S_embed S' = S_embed S'"
  unfolding S_embed_def Axioms_Quantum.register_mult[OF register_X] proj_classical_set_def by auto

lemma S_embed_adj:
  "(S_embed S')* = S_embed S'"
  unfolding S_embed_def register_adj[OF register_X, symmetric] proj_classical_set_def adj_Proj
  by auto

lemma not_S_embed_idem:
  "not_S_embed S' o⇩C⇩L not_S_embed S' = not_S_embed S'"
  unfolding not_S_embed_def Axioms_Quantum.register_mult[OF register_X] proj_classical_set_def by auto

lemma not_S_embed_adj:
  "(not_S_embed S')* = not_S_embed S'"
  unfolding not_S_embed_def register_adj[OF register_X, symmetric] proj_classical_set_def adj_Proj
  by auto


(* The two embedded projectors compose to zero: the underlying spans are orthogonal,
   and X preserves products and maps 0 to 0. *)
lemma not_S_embed_S_embed [simp]:
  "not_S_embed S' o⇩C⇩L S_embed S' = 0"
proof -
  have "orthogonal_spaces (ccspan (ket ` (- Collect S'))) (ccspan (ket ` Collect S'))" 
    using orthogonal_spaces_ccspan by fastforce
  then have "proj_classical_set (- Collect S') o⇩C⇩L proj_classical_set (Collect S') = 0"
    unfolding proj_classical_set_def using orthogonal_projectors_orthogonal_spaces by auto
  then show ?thesis unfolding not_S_embed_def S_embed_def Axioms_Quantum.register_mult[OF register_X] 
    using X_0 by auto
qed

(* The reverse order follows by taking adjoints of the previous lemma. *)
lemma S_embed_not_S_embed [simp]:
  "S_embed S' o⇩C⇩L not_S_embed S' = 0"
  by (metis S_embed_adj adj_0 adj_cblinfun_compose not_S_embed_S_embed not_S_embed_adj)

(* not_S_embed S is the projector onto its own range ... *)
lemma not_S_embed_Proj:
  "not_S_embed S = Proj (not_S_embed S *⇩S ⊤)"
  unfolding not_S_embed_def using register_projector[OF register_X is_Proj_proj_classical_set]
  by (simp add: Proj_on_own_range)

(* ... and therefore annihilates vectors in the complement of that range. *)
lemma not_S_embed_in_X_image:
  assumes "a ∈ space_as_set (- (not_S_embed S *⇩S ⊤))" 
  shows "(not_S_embed S)*⇩V a = 0"
  using register_projector[OF register_X is_Proj_proj_classical_set]
    Proj_0_compl[OF assms] unfolding not_S_embed_def by (simp add: Proj_on_own_range)

text ‹In the register for the adversary runs \<^term>‹run_B› and \<^term>‹run_B_count›, 
we want to look at the \<^typ>‹'mem› part only.
\<^term>‹Ψs› lets us look at the \<^typ>‹'mem› part that is tensored with the \<^term>‹i›-th ket state.›
(* Ψs i v applies the adjoint of tensor_ell2_right (ket i) to v. *)
definition Ψs where "Ψs i v = (tensor_ell2_right (ket i)*) *⇩V v"

(* An operator acting only on the first tensor factor commutes past the adjoint of
   tensor_ell2_right (ket a). *)
lemma tensor_ell2_right_compose_id_cblinfun:
  "tensor_ell2_right (ket a)* o⇩C⇩L A ⊗⇩o id_cblinfun = A o⇩C⇩L tensor_ell2_right (ket a)*"
  by (intro equal_ket)(auto simp add: tensor_ell2_ket[symmetric] tensor_op_ell2 cblinfun.scaleC_right)

(* The same statement phrased with Ψs. *)
lemma Ψs_id_cblinfun:
  "Ψs a ((A ⊗⇩o id_cblinfun) *⇩V v) = A *⇩V (Ψs a v)"
  unfolding Ψs_def 
  by (auto simp add: cblinfun_apply_cblinfun_compose[symmetric] tensor_ell2_right_compose_id_cblinfun
      simp del: cblinfun_apply_cblinfun_compose)


text ‹Additional Lemmas›

(* Decomposes the identity on the product space along a finite set A of basis states of the
   second factor: one term per i in A, plus Proj_ket_set of the complement of A. *)
lemma id_cblinfun_tensor_split_finite:
  assumes "finite A"
  shows "(id_cblinfun:: ('mem × 'a) ell2 ⇒⇩C⇩L ('mem × 'a) ell2) = 
 (∑i∈A. (tensor_ell2_right (ket i)) o⇩C⇩L (tensor_ell2_right (ket i)*)) + 
  Proj_ket_set (-A)" 
proof -
  have "(id_cblinfun:: ('mem × 'a) update) = id_cblinfun ⊗⇩o id_cblinfun" by auto
  also have "… = id_cblinfun ⊗⇩o (proj_classical_set A + proj_classical_set (-A))"
    by (subst proj_classical_set_split_id[of "A"]) (auto)
  also have "… = id_cblinfun ⊗⇩o (∑i∈A. selfbutter (ket i)) + 
    Proj_ket_set (-A)" unfolding Proj_ket_set_def
    by (subst proj_classical_set_sum_ket_finite[OF assms])(auto simp add: tensor_op_right_add)
  also have "… = (∑i∈A. id_cblinfun ⊗⇩o selfbutter (ket i)) +
    Proj_ket_set (-A)"
    using clinear_tensor_right complex_vector.linear_sum by (smt (verit, best) sum.cong)
  also have "… = (∑i∈A. (tensor_ell2_right (ket i)) o⇩C⇩L (tensor_ell2_right (ket i)*)) +
    Proj_ket_set (-A)" 
    by (simp add: tensor_ell2_right_butterfly)
  finally show ?thesis by auto
qed


text ‹Lemmas on sums of butterflys›

(* The sum of butterflies |0><i| over i < d+1 maps every basis state ket y with y < d+1 to ket 0. *)
lemma sum_butterfly_ket0:
  assumes "(y::nat)<d+1"
  shows "(∑i<d+1. butterfly (ket 0) (ket i)) *⇩V (ket y) = ket 0"
proof -
  have "(∑i<d+1. butterfly (ket 0) (ket i)) *⇩V ket y = (∑i<d+1. if i=y then ket 0 else 0)"
    by (subst cblinfun.sum_left, intro sum.cong, auto)
  also have "… = ket 0"  by (subst sum.delta, use assms in auto)
  finally show ?thesis by auto
qed

(* On an arbitrary vector truncated to indices below d+1, the same operator collects the
   sum of the truncated coordinates as the coefficient of ket 0. *)
lemma sum_butterfly_ket0':
  "(∑i<d+1. butterfly (ket 0) (ket i))*⇩V proj_classical_set {..<d+1} *⇩V y =
 (∑i<d+1. Rep_ell2 y i) *⇩C ket 0" 
  for y::"nat ell2"
proof -
  have "(∑i<d+1. butterfly (ket 0) (ket i)) *⇩V proj_classical_set {..<d+1} *⇩V y =
        (∑i<d+1. Rep_ell2 y i *⇩C (∑i<d+1. butterfly (ket 0) (ket i)) *⇩V ket i)"
    by (subst proj_classical_set_apply, simp) 
      (subst cblinfun.sum_right, intro sum.cong, auto simp add: cblinfun.scaleC_right)
  also have "… = (∑i<d+1. Rep_ell2 y i *⇩C ket 0)" 
    by (intro sum.cong) (use sum_butterfly_ket0 in auto)
  also have "… = (∑i<d+1. Rep_ell2 y i) *⇩C ket 0" by (rule scaleC_sum_left[symmetric]) 
  finally show ?thesis by auto
qed




text ‹The oracle query is a unitary.›

(* The classical map underlying Uquery is injective (y ranges over a group_add type). *)
lemma inj_Uquery_map:
  "inj (λ(x, (y::'y)). (x, y + H x))"
  unfolding inj_def by auto

(* Hence the classical operator for this map exists. *)
lemma classical_operator_exists_Uquery:
  "classical_operator_exists (Some o (λ(x,(y::'y)). (x, y + (H x))))"
  by (intro classical_operator_exists_inj, subst inj_map_total)
    (auto simp add: inj_Uquery_map)

(* Action of Uquery on product basis states: the second component b becomes b + F a. *)
lemma Uquery_ket:
  "Uquery F *⇩V ket (a::'x) ⊗⇩s ket (b::'y) = ket a ⊗⇩s ket (b + F a)"
  unfolding Uquery_def tensor_ell2_ket 
  by (subst classical_operator_ket[OF classical_operator_exists_Uquery]) auto



(* Uquery H is unitary because the underlying classical map is a bijection. *)
lemma unitary_H: "unitary (Uquery (H::'x ⇒ 'y))"
proof -
  have inj: "inj (λ(x, y). (x, y + H x))" by (auto simp add: inj_on_def)
  have surj: "surj (λ(x, y). (x, y + H x))" 
    by (metis (mono_tags, lifting) case_prod_Pair_iden diff_add_cancel split_conv split_def surj_def)
  show ?thesis unfolding Uquery_def 
    by (intro unitary_classical_operator) (auto simp add: inj surj bij_def)
qed

end

unbundle no cblinfun_syntax
unbundle no lattice_syntax

end