Theory CO_Operations

section CO_Operations› – Definition of the compressed oracle and related unitaries›

theory CO_Operations imports
  Complex_Bounded_Operators.Complex_L2 
  HOL.Map
  Registers.Quantum_Extra2
    
  Misc_Compressed_Oracle
  Function_At
begin

unbundle cblinfun_syntax

subsection function_oracle› - Querying a fixed function›

(* The (uncompressed) oracle for a fixed function h: the classical operator that maps
   the basis vector ket (x, y) to ket (x, y + h x). *)
definition function_oracle :: ('x  'y::ab_group_add)  (('x × 'y) ell2 CL ('x × 'y) ell2) where
  function_oracle h = classical_operator (λ(x,y). Some (x, y + h x))

(* Action of function_oracle on a basis vector: the second register is shifted by h x. *)
lemma function_oracle_apply: function_oracle h (ket (x, y)) = ket (x, y + h x)
  unfolding function_oracle_def
  (* Evaluate the classical operator on a ket; the remaining existence side goal is
     discharged through injectivity of the underlying map. *)
  apply (subst classical_operator_ket)
  by (auto intro!: classical_operator_exists_inj injI simp: inj_map_total[unfolded o_def] case_prod_unfold)

(* The adjoint of function_oracle h undoes the addition: it maps ket (x, y) to ket (x, y - h x). *)
lemma function_oracle_adj_apply: function_oracle h* *V ket (x, y) = ket (x, y - h x)
proof -
  (* f is the classical map underlying function_oracle h, and g is its inverse. *)
  define f where f = (λ(x,y). (x, y + h x))
  define g where g = (λ(x,y). (x, y - h x))
  have gf: g  f = id and fg: f  g = id
    by (auto simp: f_def g_def)
  have [iff]: inj f
    by (metis fg gf injI isomorphism_expand)
  have inv f = g
    using fg gf inv_unique_comp by blast
  have inv_map_f: inv_map (Some o f) = (Some o g)
    by (metis inj f inv f = g fg fun.set_map inj_imp_surj_inv inv_map_total surj_id)
  (* Hence the adjoint is the classical operator of g, which is then evaluated on the ket. *)
  have function_oracle h* = classical_operator (Some o f)*
    by (simp add: function_oracle_def f_def case_prod_unfold o_def)
  also have  = classical_operator (Some o g)
    using inv_map_f by (simp add: classical_operator_adjoint function_oracle_def)
  also have  *V ket (x,y) = ket (x, y - h x)
    apply (subst classical_operator_ket)
     apply (metis classical_operator_exists_inj inj_map_inv_map inv_map_f) 
    by (simp add: g_def)
  finally show ?thesis
    by -
qed

(* function_oracle h is unitary, because (x, y) to (x, y + h x) is a bijection
   with inverse (x, y) to (x, y - h x). *)
lemma unitary_function_oracle[iff]: unitary (function_oracle h)
proof -
  have bij (λx. (fst x, snd x + h (fst x)))
    apply (rule o_bij[where g=(λx. (fst x, snd x - h (fst x)))])
    by auto
  then show ?thesis
    by (auto intro!: unitary_classical_operator[unfolded o_def]
        simp add: function_oracle_def case_prod_unfold )
qed

(* The oracle has operator norm 1: it is unitary, hence an isometry. *)
lemma norm_function_oracle[simp]: norm (function_oracle h) = 1
  by (intro norm_isometry unitary_isometry unitary_function_oracle)

(* The adjoint of the oracle for h is the oracle for the pointwise negation of h.
   Both sides are compared on basis vectors (equal_ket). *)
lemma function_oracle_adj[simp]: function_oracle h* = function_oracle (λx. - h x) for h :: 'x  'y::ab_group_add
  apply (rule equal_ket)
  by (auto simp: function_oracle_apply function_oracle_adj_apply)


subsection ‹Setup for compressed oracles›

(* trafo: a fixed unitary on 'a ell2 that maps ket 0 to the uniform superposition over
   all of 'a. The specification proof below shows that such a unitary exists by
   constructing one explicitly. *)
consts trafo :: 'a ell2 CL 'a::{zero,finite} ell2
specification (trafo) 
  unitary_trafo[simp]: unitary trafo
  trafo_0[simp]: trafo *V ket 0 = uniform_superpos UNIV
proof -
  (* Degenerate case: a type with a single element. Then ket 0 already is the uniform
     superposition, and the identity serves as the witness. *)
  wlog CARD('a)  2
  proof -
    have CARD('a)  0
      by simp
    with negation have CARD('a) = 1
      by presburger
    then have [simp]: UNIV = {0::'a}
      by (metis UNIV_I card_1_singletonE singletonD)
    have uniform_superpos UNIV = ket (0::'a)
      by (simp add: uniform_superpos_def2)
    then show ?thesis
      by (auto intro!: exI[where x=id_cblinfun])
  qed

  (* Main case: T is a combination of butterflies of ket 0 and the uniform superposition,
     plus the identity. The coefficients a and c are defined from p, the inner product of
     the uniform superposition with ket 0, which is shown below to equal α. *)
  let ?uniform = uniform_superpos (UNIV :: 'a set)
  define α where α = complex_of_real (1 / sqrt (of_nat CARD('a)))
  define p p2 p4 a c where p = cinner ?uniform (ket (0::'a)) and p2 = 1 - p * p
    and p4 = p2 * p2 and a = (1+p) / p2 and c = (-1-p) / p2
  define T :: 'a update where
    T = a *C butterfly (ket 0) ?uniform + a *C butterfly ?uniform (ket 0)
       + c *C selfbutter (ket 0) + c *C selfbutter ?uniform + id_cblinfun
  have : p = α
    apply (simp add: p_def cinner_ket_right α_def)
    apply transfer
    by simp
  (* p2 is nonzero because there are at least two elements, so the divisions are well-behaved. *)
  have p20: p2  0
    unfolding α_def p2_def  using CARD('a)  2 apply auto
    by (smt (verit) complex_of_real_leq_1_iff numeral_nat_le_iff of_real_1 of_real_power power2_eq_square real_sqrt_pow2
        rel_simps(26) semiring_norm(69))
  (* h1 and h2 are the coefficient identities that make T map ket 0 onto ?uniform (fact 1). *)
  have h1: a * p + c + 1 = 0
    using p20 apply (simp add: a_def c_def)
    by (metis add.assoc add_divide_distrib add_neg_numeral_special(8) diff_add_cancel divide_eq_minus_1_iff minus_diff_eq mult.commute mult_1 p2_def ring_class.ring_distribs(2) uminus_add_conv_diff)
  have h2: a + c * p = 1
    using p20 apply (simp add: c_def)
    by (metis a_def ab_group_add_class.ab_diff_conv_add_uminus add.commute add.inverse_inverse add_neg_numeral_special(8) add_right_cancel c_def divide_minus_left h1 minus_add_distrib mult_minus_left times_divide_eq_left)

  have [simp]: ?uniform C ?uniform = 1
    by (simp add: cdot_square_norm norm_uniform_superpos)
  have [simp]: ket (0::'a) C ?uniform = cnj p
    by (simp add: p_def)

  (* Fact 1: T maps ket 0 to the uniform superposition. *)
  have T *V ket 0 = (a * p + c + 1) *C ket 0 + (a + c * p) *C ?uniform
    unfolding T_def
    by (auto simp: cblinfun.add_left scaleC_add_left simp flip: p_def)
  also have  = ?uniform
    by (simp add: h1 h2)
  finally have 1: T *V ket 0 = ?uniform
    by -

  (* Rewriting aids used to collect like butterfly terms when expanding T oCL T-adjoint. *)
  have scaleC_add_left': v + scaleC x w + scaleC y w = v + scaleC (x+y) w for x y and v w :: 'a update
    by (simp add: scaleC_add_left)
  have sort: 
    v + x *C butterfly ?uniform (ket 0) + y *C selfbutter (ket 0) = v + y *C selfbutter (ket 0) + x *C butterfly ?uniform (ket 0) 
    v + x *C selfbutter ?uniform + y *C butterfly (ket 0) ?uniform = v + y *C butterfly (ket 0) ?uniform + x *C selfbutter ?uniform
    v + x *C selfbutter ?uniform + y *C butterfly ?uniform (ket 0) = v + y *C butterfly ?uniform (ket 0) + x *C selfbutter ?uniform
    v + x *C selfbutter ?uniform + y *C selfbutter (ket 0) = v + y *C selfbutter (ket 0) + x *C selfbutter ?uniform
    v + x *C butterfly (ket 0) ?uniform + y *C selfbutter (ket 0) = v + y *C selfbutter (ket 0) + x *C butterfly (ket 0) ?uniform
    v + x *C butterfly (ket 0) ?uniform + y *C butterfly ?uniform (ket 0) = v + y *C butterfly ?uniform (ket 0) + x *C butterfly (ket 0) ?uniform
    for v :: 'a update and x y
    by auto

  (* Multiplying by the nonzero p4 clears denominators in the identities h3 and h4. *)
  have aux: x = 0  x * p4 = 0 for x
    by (simp add: p20 p4_def)

  have [simp]: cnj p = p
    by (simp add: α_def )
  have [simp]: cnj c = c
    by (simp add: c_def p2_def)
  have [simp]: cnj a = a
    by (simp add: a_def p2_def)
  have [simp]: p4  0
    by (simp add: p20 p4_def)
  have [simp]: x * p4 / p2 = x * p2 for x
    by (simp add: p4_def)

  (* h3 and h4 are the coefficient identities that make the non-identity butterfly terms
     of T oCL T-adjoint vanish. *)
  have h3: 2 * c + (2 * (a * (c * p)) + (a * a + c * c)) = 0
    apply (subst aux)
    apply (simp add: a_def c_def distrib_right distrib_left p20 add_divide_distrib 
        right_diff_distrib left_diff_distrib diff_divide_distrib
        flip: p4_def add.assoc
        del: mult_eq_0_iff vector_space_over_itself.scale_eq_0_iff)
    by (simp add: p4_def p2_def right_diff_distrib left_diff_distrib flip: add.assoc)
  have h4: 2 * a + (2 * (a * c) + (a * a * p + c * c * p)) = 0
    apply (subst aux)
    apply (simp add: a_def c_def distrib_right distrib_left p20 add_divide_distrib 
        right_diff_distrib left_diff_distrib diff_divide_distrib
        flip: p4_def add.assoc
        del: mult_eq_0_iff vector_space_over_itself.scale_eq_0_iff)
    by (simp add: p4_def p2_def right_diff_distrib left_diff_distrib flip: add.assoc)

  (* Fact 2: T composed with its adjoint is the identity. *)
  have 2: T oCL T* = id_cblinfun
    unfolding T_def
    apply (simp add: cblinfun_compose_add_left cblinfun_compose_add_right adj_plus
        scaleC_add_right flip: p_def add.assoc mult.assoc)
    apply (simp add: sort scaleC_add_left' flip: scaleC_add_left)
    by (simp add: h3 h4)

  (* Fact 3: T is self-adjoint (all coefficients are real). *)
  have 3: T* = T
    unfolding T_def
    by (auto simp: adj_plus)

  from 2 3 have 4: unitary T
    by (simp add: unitary_def)

  from 1 4 show ?thesis
    by auto
qed

text ‹Set of total functions›
(* Partial maps that never return None, i.e. total functions viewed as partial maps. *)
definition total_functions = {f::'x'y. None  range f}

(* Alternative characterization: the total functions are exactly the maps of the form
   Some composed with an ordinary function. *)
lemma total_functions_def2: total_functions = (comp Some) ` UNIV
proof -
  have x  range ((∘) Some) if None  range x for x :: 'x  'y option
    by (metis function_factors_right option.collapse range_eqI that)
  then show ?thesis
    unfolding total_functions_def by auto
qed

(* Alternative characterization: the total functions are the maps whose domain is all of UNIV. *)
lemma total_functions_def3: total_functions = {f. dom f = UNIV}
  by (force simp add: total_functions_def)

(* There are CARD('y) ^ CARD('x) total functions. The proof uses the injective image
   description from total_functions_def2 and the cardinality of the function space. *)
lemma card_total_functions: card (total_functions :: ('x  'y option) set) = CARD('y) ^ CARD('x::finite)
proof -
  have card (total_functions :: ('x  'y option) set) = CARD ('x  'y)
    unfolding total_functions_def2
    by (simp add: card_image fun.inj_map) 
  also have  = CARD('y) ^ CARD('x)
    by (simp add: card_fun)
  finally show ?thesis
    by -
qed

(* Uniform superposition over all total functions: the initial state of the uncompressed oracle. *)
abbreviation superpos_total :: ('x::finite  'y::finite option) ell2 where superpos_total  uniform_superpos total_functions


text ‹Sets up the locale for defining the compressed oracle.
We use a locale because the compressed oracle can depend on some arbitrary unitary trafo›.
The choice of trafo› usually doesn't matter; in such cases, the default transformation consttrafo above
can be used.›

(* Parameters: a type witness fixing the input type 'x and the output type 'y, and the
   transformation trafo on 'y ell2.
   Assumptions: trafo is unitary and maps ket 0 to the uniform superposition, and every
   element of 'y is its own additive inverse (y + y = 0), as for XOR on bit strings. *)
locale compressed_oracle =
  fixes dummy_constant :: ('x::finite × 'y::{finite,ab_group_add}) itself
  fixes trafo :: 'y::{finite,ab_group_add} ell2 CL 'y ell2
  assumes unitary_trafo[simp]: unitary trafo
  assumes trafo_0: trafo *V ket 0 = uniform_superpos UNIV
  assumes y_cancel[simp]: (y::'y) + y = 0
begin

(* This is a hack for defining N. If we define term‹N› directly as term‹‹N = CARD('y)››, Isabelle adds a type parameter to the constant N.
   Instead, N is defined through dummy2, which ignores its first argument (the locale
   parameter trafo) and returns its second. The lemma N_def recovers N = CARD('y). *)
definition dummy2 :: 'y update  ('y set  nat)  ('y set  nat)
  where dummy2 x y = y
definition N_def0: N = dummy2 trafo card UNIV
text termN is the cardinality of the oracle outputs. (Intuitively, termN = 2^n for an termn-bit output.)›
(* The intended characterization of N: the number of possible oracle outputs. *)
lemma N_def: N = CARD('y)
  by (simp add: dummy2_def N_def0)

(* N is never zero, since 'y is a finite type. *)
lemma Nneq0[iff]: N  0
  by (simp add: N_def)

(* α = 1 / sqrt N, the amplitude of each basis vector in the uniform superposition over 'y. *)
definition α = complex_of_real (1 / sqrt (of_nat N))
  ― ‹We use this term very often, so this definition comes in handy.›

(* Consequence of y_cancel: every element of 'y is its own negation.
   NOTE(review): this lemma sits inside the locale's begin block, so the explicit
   (in compressed_oracle) target looks redundant; confirm before relying on it. *)
lemma (in compressed_oracle) uminus_y[simp]: - y = y for y :: 'y
  by (metis add.right_inverse group_cancel.add1 group_cancel.rule0 y_cancel)

subsection switch0› - Operator exchanging termket (Some 0) and termket None

text termswitch0 maps termket None to termket (Some 0) and vice versa. 
  It leaves all other termket (Some y) unchanged.›
(* Classical permutation operator: Fun.swap (Some 0) None id exchanges Some 0 and None
   and fixes every other element. *)
definition switch0 :: 'y option update where
  switch0 = classical_operator (Some o Fun.swap (Some 0) None id)

(* switch0 sends ket None to ket (Some 0). *)
lemma switch0_None[simp]: switch0 *V ket None = ket (Some 0)
  unfolding switch0_def classical_operator_ket_finite
  by auto

(* switch0 sends ket (Some 0) to ket None. *)
lemma switch0_0[simp]: switch0 *V ket (Some 0) = ket None
  unfolding switch0_def classical_operator_ket_finite
  by auto

(* switch0 fixes ket (Some x) for every nonzero x. *)
lemma switch0_other: switch0 *V ket (Some x) = ket (Some x) if x  0
  unfolding switch0_def classical_operator_ket_finite
  using that by auto

(* switch0 is unitary: it is the classical operator of a permutation. The side condition of
   unitary_classical_operator is discharged by auto. *)
lemma unitary_switch0[simp]: unitary switch0
  unfolding switch0_def
  apply (rule unitary_classical_operator)
  by auto

(* switch0 is self-adjoint: the adjoint of a classical operator is the classical operator of
   the inverse map, and the swap is its own inverse. *)
lemma switch0_adj[simp]: switch0* = switch0
  unfolding switch0_def
  apply (subst classical_operator_adjoint)
   apply simp
  by (simp add: inv_map_total)


subsection compress1› - Operator to compress a single RO-output›

text ‹This unitary maps termket None onto the uniform superposition of all termket (Some y) and vice versa,
and leaves everything orthogonal to these unchanged.

This is the operation that deals with compressing a single oracle output.›

(* compress1 is switch0 conjugated by lift_op trafo: apply the adjoint of lift_op trafo,
   then switch0, then lift_op trafo.
   NOTE(review): lift_op presumably extends trafo from 'y ell2 to 'y option ell2
   (compress1_Some rewrites its adjoint via lift_ell2); confirm in Misc_Compressed_Oracle. *)
definition compress1 :: 'y option ell2 CL 'y option ell2 where
  compress1 = lift_op trafo oCL switch0 oCL (lift_op trafo)*

(* The uniform superposition over 'y as an explicit basis expansion with amplitude α. *)
lemma uniform_superpos_y_sum: uniform_superpos UNIV = (dUNIV. α *C ket (d::'y))
  apply (subst ell2_sum_ket)
  by (simp add: uniform_superpos.rep_eq α_def N_def)

(* compress1 maps ket None to the uniform superposition over all ket (Some d). *)
lemma compress1_None[simp]: compress1 *V ket None = (dUNIV. α *C ket (Some d))
  by (auto simp: cblinfun.sum_right compress1_def lift_op_adj trafo_0 uniform_superpos_y_sum  cblinfun.scaleC_right)

(* compress1 on ket (Some d): the result keeps ket (Some d), subtracts α² times the sum of all
   ket (Some e) (the bound variable of the sum shadows d), and adds α ket None. *)
lemma compress1_Some[simp]: compress1 *V ket (Some d) = 
      ket (Some d) - (dUNIV. α2 *C ket (Some d)) + α *C ket None
proof -
  (* c e is the coefficient at ket e of trafo-adjoint applied to ket d; c 0 = α follows from trafo_0. *)
  define c where c e = cinner (ket e) (trafo* *V ket d) for e
  have c0: c 0 = α
    apply (simp add: c_def cinner_adj_right trafo_0)
    by (simp add: α_def N_def cinner_ket_right uniform_superpos.rep_eq)

  (* Expand in the basis, apply switch0 termwise (only the Some 0 component moves, to None),
     then map back with trafo, using trafo_0 and unitarity of trafo. *)
  have compress1 *V ket (Some d) = lift_op trafo *V switch0 *V lift_ell2 *V trafo* *V ket d
    by (auto simp: compress1_def lift_op_adj)
  also have  = lift_op trafo *V switch0 *V lift_ell2 *V (eUNIV. c e *C ket e)
    by (simp add: c_def cinner_ket_left flip: ell2_sum_ket)
  also have  = lift_op trafo *V switch0 *V (eUNIV. c e *C ket (Some e))
    by (auto simp: cblinfun.sum_right cblinfun.scaleC_right)
  also have  = lift_op trafo *V switch0 *V ((e-{0}. c e *C ket (Some e)) + c 0 *C ket (Some 0))
    apply (subst asm_rl[of UNIV = insert 0 (-{0})])
    by (auto simp add: add.commute)
  also have  = lift_op trafo *V ((e-{0}. c e *C (switch0 *V ket (Some e))) + c 0 *C switch0 *V ket (Some 0))
    by (simp add: cblinfun.add_right cblinfun.scaleC_right cblinfun.sum_right)
  also have  = lift_op trafo *V ((e-{0}. c e *C ket (Some e)) + c 0 *C ket None)
    by (simp add: switch0_other)
  also have  = lift_op trafo *V ((eUNIV. c e *C ket (Some e)) - c 0 *C ket (Some 0) + c 0 *C ket None)
    by (simp add: Compl_eq_Diff_UNIV sum_diff)
  also have  = (eUNIV. c e *C lift_ell2 *V trafo *V ket e) - c 0 *C lift_ell2 *V trafo *V ket 0 + c 0 *C ket None
    by (simp add: cblinfun.add_right cblinfun.diff_right cblinfun.scaleC_right cblinfun.sum_right)
  also have  = lift_ell2 *V trafo *V (eUNIV. c e *C ket e) - c 0 *C lift_ell2 *V uniform_superpos UNIV + c 0 *C ket None
    by (simp add: trafo_0 cblinfun.scaleC_right cblinfun.sum_right)
  also have  = lift_ell2 *V trafo *V trafo* *V ket d - c 0 *C lift_ell2 *V uniform_superpos UNIV + c 0 *C ket None
    by (simp add: c_def cinner_ket_left flip: ell2_sum_ket)
  also have  = lift_ell2 *V (trafo oCL trafo*) *V ket d - c 0 *C lift_ell2 *V uniform_superpos UNIV + c 0 *C ket None
    by (metis cblinfun_apply_cblinfun_compose)
  also have  = lift_ell2 *V ket d - c 0 *C lift_ell2 *V uniform_superpos UNIV + c 0 *C ket None
    by auto 
  also have  = ket (Some d) - c 0 *C (dUNIV. α *C ket (Some d)) + c 0 *C ket None
    by (auto simp: uniform_superpos_y_sum mult.commute scaleC_sum_right cblinfun.scaleC_right cblinfun.sum_right)
  also have  = ket (Some d) - (dUNIV. α2 *C ket (Some d)) + α *C ket None
    by (simp add: c0 power2_eq_square scaleC_sum_right)

  finally show ?thesis
    by -
qed

(* compress1 is unitary; simp proves this from the unitarity of its factors. *)
lemma unitary_compress1[simp]: unitary compress1
  by (simp add: compress1_def)

(* compress1 is self-adjoint: it is the self-adjoint switch0 conjugated by lift_op trafo. *)
lemma compress1_adj[simp]: compress1* = compress1
  by (simp add: compress1_def cblinfun_compose_assoc)

(* Being unitary and self-adjoint, compress1 is its own inverse. *)
lemma compress1_square: compress1 oCL compress1 = id_cblinfun
  by (metis compress1_adj unitary_compress1 unitary_def)

subsection compress› - Operator for compressing the RO›

text ‹This is the unitary that maps between the compressed representation of the random oracle
(in which the initial state is termket (λ_. None)) and the uncompressed one
(in which the initial state is the uniform superposition of all total functions).

It works by simply applying constcompress1 to each output separately.›

(* Applies compress1 at every input position of the function register, via apply_every over UNIV. *)
definition compress :: ('x  'y) update
  where compress = apply_every UNIV (λ_. compress1)

(* compress is unitary, since it applies the unitary compress1 at every position. *)
lemma unitary_compress[simp]: unitary compress
  by (simp add: compress_def apply_every_unitary)

(* compress is its own inverse, inherited from compress1_square. *)
lemma compress_selfinverse: compress oCL compress = id_cblinfun
  by (simp add: compress_def apply_every_mult compress1_square)

(* compress is self-adjoint, inherited from compress1_adj. *)
lemma compress_adj: compress* = compress
  by (simp add: compress_def apply_every_adj)

(* Decompressing the empty database ket Map.empty yields the uniform superposition over all
   total functions. *)
lemma compress_empty: compress *V ket Map.empty = superpos_total
proof -
  (* Generalization, proved by induction over the finite set M: applying compress1 at the
     points of M to the empty database gives the normalized uniform superposition over all
     maps whose domain is exactly M. *)
  have *: apply_every M (λ_. compress1) *V ket Map.empty = 
      (f|dom f = M. ket f /R sqrt (CARD('y) ^ card M)) for M :: 'x set
  proof (use finite[of M] in induction)
    case empty
    then show ?case
      by simp
  next
    case (insert x F)
    (* Peel off the compress1 at x and push it into the sum from the induction hypothesis. *)
    have apply_every (insert x F) (λ_. compress1) *V ket Map.empty
      = function_at x compress1 *V apply_every F (λ_. compress1) *V ket Map.empty
      using insert.hyps by (simp add: apply_every_insert)
    also have  = function_at x compress1 *V (f | dom f = F. ket f /R sqrt (real (CARD('y) ^ card F)))
      by (simp add: insert.IH)
    also have  = (f | dom f = F. (function_at x compress1 *V ket f) /R sqrt (real (CARD('y) ^ card F)))
      by (simp add: cblinfun.real.scaleR_right cblinfun.sum_right)
    also have  = (f | dom f = F. (yUNIV. Rep_ell2 (compress1 *V ket (f x)) y *C ket (f(x := y))) /R sqrt (real (CARD('y) ^ card F)))
      by (simp add: function_at_ket)
    (* x is not in F, so every summand has f x = None, and compress1 acts on ket None. *)
    also have  = (f | dom f = F. (yUNIV. Rep_ell2 (compress1 *V ket None) y *C ket (f(x := y))) /R sqrt (real (CARD('y) ^ card F)))
      by (smt (verit) Finite_Cartesian_Product.sum_cong_aux domIff local.insert(2) mem_Collect_eq)
    also have  = (f | dom f = F. (yUNIV. Rep_ell2 (dUNIV. α *C ket (Some d)) y *C ket (f(x := y))) /R sqrt (real (CARD('y) ^ card F)))
      by simp
    also have  = (f | dom f = F. (yUNIV. (dUNIV. α *C Rep_ell2 (ket (Some d)) y) *C ket (f(x := y))) /R sqrt (real (CARD('y) ^ card F)))
      apply (subst complex_vector.linear_sum[where f=λx. Rep_ell2 x _])
       apply (simp add: clinearI plus_ell2.rep_eq scaleC_ell2.rep_eq)
      apply (subst clinear.scaleC[where f=λx. Rep_ell2 x _])
      by (simp_all add: clinearI plus_ell2.rep_eq scaleC_ell2.rep_eq)
    (* The coefficient is 0 at None and α at every Some d. *)
    also have  = (f | dom f = F. (yUNIV. (if y = None then 0 else α) *C ket (f(x := y))) /R sqrt (real (CARD('y) ^ card F)))
      apply (rule sum.cong, simp)
      subgoal for f
        apply (rule arg_cong[where f=λx. x /R _])
        apply (rule sum.cong, simp)
        subgoal for y
          apply (subst sum_single[where i=the y])
          by (auto simp: ket.rep_eq)
        by -
      by -
    also have  = (f | dom f = F. (yrange Some. α *C ket (f(x := y))) /R sqrt (real (CARD('y) ^ card F)))
      apply (rule sum.cong, simp)
      apply (subst sum.mono_neutral_cong_right[where S=range Some and h=λy. α *C ket (_(x := y))])
      by auto
    also have  = (f | dom f = F. yrange Some. α *C ket (f(x := y)) /R sqrt (real (CARD('y) ^ card F)))
      by (simp add: scaleR_right.sum)
    (* Merge the double sum into a sum over pairs (f, y), then reindex the pairs by the
       maps whose domain is insert x F. *)
    also have  = ((f, y){f. dom f = F} × range Some. 
                         α *C ket (f(x := y)) /R sqrt (real (CARD('y) ^ card F)))
      by (simp add: sum.cartesian_product)
    also have  = ((f, y)(λf. (f(x:=None), f x)) ` {f. dom f = insert x F}. 
                         α *C ket (f(x := y))) /R sqrt (real (CARD('y) ^ card F))
    proof -
      have 1: {f. dom f = F} × range Some = (λf. (f(x := None), f x)) ` {f. dom f = insert x F}
      proof (rule Set.set_eqI, rule iffI)
        fix z :: ('x  'a option) × 'a option
        assume asm: z  {f. dom f = F} × range Some
        define f where f = (fst z)(x := snd z)
        have f  {f. dom f = insert x F}
          using asm by (auto simp: f_def)
        moreover have (λf. (f(x := None), f x)) f = z
          using asm insert.hyps by (auto simp add: f_def)
        ultimately show z  (λf. (f(x := None), f x)) ` {f. dom f = insert x F}
          by auto
      next
        fix z :: ('x  'a option) × 'a option
        assume z  (λf. (f(x := None), f x)) ` {f. dom f = insert x F}
        then obtain f where dom f = insert x F and z = (λf. (f(x := None), f x)) f
          by auto
        then show z  {f. dom f = F} × range Some
          using insert.hyps by auto
      qed
      show ?thesis
        apply (subst scaleR_right.sum)
        apply (rule sum.cong)
        using 1 by auto
    qed
    also have  = (f| dom f = insert x F. α *C ket f) /R sqrt (real (CARD('y) ^ card F))
      apply (subst sum.reindex)
       apply auto
      by (smt (verit) fun_upd_idem_iff fun_upd_upd inj_on_def prod.simps(1))
    (* Absorb α = 1/sqrt N into the normalization factor. *)
    also have  = (f| dom f = insert x F. ket f /R sqrt (real (CARD('y) ^ card (insert x F))))
      by (simp add: card_insert_disjoint insert.hyps real_sqrt_mult α_def N_def scaleR_scaleC
        divide_inverse_commute flip: scaleC_sum_right)
    finally show ?case
      by -
  qed

  (* For M = UNIV, the sum is exactly the uniform superposition over total functions. *)
  have (f|dom f = UNIV. ket f /R sqrt (CARD('y) ^ CARD('x))) = (superpos_total :: ('x  'y option) ell2)
    unfolding uniform_superpos_def2
    apply (rule sum.cong)
     apply (simp add: total_functions_def3)
    by (simp add: card_total_functions scaleR_scaleC)

  with *[of UNIV]
  show ?thesis
    by (simp flip: compress_def)
qed

subsection standard_query1› - Operator for uncompressed query of a single RO-output›

text ‹We define the operation termstandard_query1 of querying the oracle, but first in the special case
  of an oracle that has no input register. That is, the oracle state consists of just one output value (or termNone)
  and this value is simply added to the query output register.

  Roughly speaking, it thus is the unitary $\lvert y,h⟩ ↦ \lvert y⊕h, h⟩$.
  In comparison, a ``normal'' oracle query would be defined by $\lvert x,y,h⟩ ↦ \lvert x, y⊕h(x), h⟩$.
    
  That is: If one starts with a three-partite state termψ s ket 0 s superpos_total and keeps performing
  operations $U_i$ on the parts 1, 2 of the state, interleaved with termstandard_query1 invocations on parts 2, 3,
  this is a simulation of starting with state termψ s 0 and performing $U_i$ interleaved with
  invocations of the unitary $\ket{y} ↦ \ket{y ⊕ h}$  on part 2 where $h$ is chosen uniformly at random
  in the beginning.

  When termh=None, there are various natural choices for how to define the behavior of termstandard_query1.
  This is because, intuitively, this case should not happen: this operation is intended to be
  applied to uncompressed oracles, which are superpositions of total functions.
  Yet, due to errors introduced by projecting onto invariants, one can get situations where this is not perfectly the case,
  so the behavior on termNone matters. Here, we choose to let termstandard_query1 be the identity
  in that case.
›

(* Classical operator on basis vectors: (y, Some z) goes to (y + z, Some z), and (y, None)
   is left unchanged. *)
definition standard_query1 :: ('y × 'y option) update where
  standard_query1 = classical_operator (Some o (λ(y,z). case z of None  (y,None) | Some z'  (y + z', z)))

text ‹The operation termstandard_query1' is defined like termstandard_query1
  (and the motivation and properties mentioned there also hold here),
  except that in the case termh=None (see discussion for termstandard_query1), instead of being the
  identity, termstandard_query1' returns the 0-vector (not termket 0!).
  In particular, this operation is not unitary, which can make some things more awkward.
  But on the plus side, we can achieve better bounds in some situations when using termstandard_query1'.›

(* Classical operator of a partial map: (y, Some z) goes to (y + z, Some z). The map
   returns None on (y, None), so those basis vectors are sent to 0. *)
definition standard_query1' :: ('y × 'y option) update where
  standard_query1' = classical_operator (λ(y,z). case z of None  None | Some z'  Some (y + z', z))

(* With a defined oracle value z, the output register is shifted by z. *)
lemma standard_query1_Some[simp]: standard_query1 *V ket (y, Some z) = ket (y + z, Some z)
  by (simp add: standard_query1_def classical_operator_ket_finite)

(* With no oracle value (None), standard_query1 acts as the identity. *)
lemma standard_query1_None[simp]: standard_query1 *V ket (y, None) = ket (y, None)
  by (simp add: standard_query1_def classical_operator_ket_finite)

(* Same as standard_query1_Some, for the primed variant. *)
lemma standard_query1'_Some[simp]: standard_query1' *V ket (y, Some z) = ket (y + z, Some z)
  by (simp add: standard_query1'_def classical_operator_ket_finite)

(* With no oracle value (None), the primed variant returns the zero vector. *)
lemma standard_query1'_None[simp]: standard_query1' *V ket (y, None) = 0
  by (simp add: standard_query1'_def classical_operator_ket_finite)

(* standard_query1 is unitary: its underlying map is a bijection, with inverse
   (y, Some z) to (y - z, Some z) and (y, None) fixed. *)
lemma unitary_standard_query1[simp]: unitary standard_query1
  unfolding standard_query1_def
  apply (rule unitary_classical_operator)
  apply (rule o_bij[where g=λ(y,z). case z of None  (y,None) | Some z'  (y - z', z)])
  by (auto intro!: ext simp: case_prod_beta cong del: option.case_cong split!: option.split option.split_asm)

(* standard_query1' has operator norm exactly 1.
   Upper bound: it is the classical operator of an injective partial map.
   Lower bound: it preserves the norm of the basis vector ket (undefined, Some undefined). *)
lemma norm_standard_query1'[simp]: norm standard_query1' = 1
proof (rule order.antisym)
  show norm standard_query1'  1
    unfolding standard_query1'_def
    apply (rule classical_operator_norm_inj)
    by (auto simp: inj_map_def split!: option.split_asm)
  show norm standard_query1'  1
    apply (rule cblinfun_norm_geqI[where x=ket (undefined, Some undefined)])
    by simp
qed


(* standard_query1 is an involution: applying the underlying map twice adds z twice, which
   cancels because z + z = 0 (locale assumption y_cancel). The map composition is shown to
   be Some by case analysis on the oracle value. *)
lemma standard_query1_selfinverse[simp]: standard_query1 oCL standard_query1 = id_cblinfun
proof -
  have *: (Some  (λ(y::'y, z). case z of None  (y, None) | Some z'  (y + z', z)) m
      (Some  (λ(y, z). case z of None  (y, None) | Some z'  (y + z', z)))) = Some
    by (auto intro!: ext, rename_tac a b, case_tac b, auto)
  show ?thesis
    by (auto simp: standard_query1_def classical_operator_mult *)
qed

subsection standard_query› - Operator for uncompressed query of the RO›

text ‹We can now define the operation of querying the (non-compressed) oracle,
  i.e., the operation $\lvert x,y,h⟩ ↦ \lvert x, y⊕h(x), h⟩$.
  Most of the work has already been done when defining conststandard_query1.
  We just need to apply conststandard_query1 to the termY-register and the termx-output
  of the termH-register, where termx is the content of the termX-register (in the computational basis).

  The various lemmas below (e.g., standard_query_ket›) show that this definition actually achieves this.

  That is: If one starts with a four-partite state termψ s ket 0 s ket 0 s superpos_total and keeps performing
  operations $U_i$ on the parts 1--3 of the state, interleaved with termstandard_query invocations on parts 2--4,
  this is a simulation of starting with state termψ s 0 and performing $U_i$ interleaved with
  invocations of the unitary $\ket{x, y} ↦ \ket{x, y ⊕ h(x)}$  on parts 2, 3 where $h$ is a function 
  chosen uniformly at random in the beginning.›

(* Controlled on the first register holding ket x, apply standard_query1 to the pair made of
   the output register (Fst) and the x-th output of the function register
   (Snd composed with function_at x). *)
definition standard_query :: ('x × 'y × ('x  'y)) ell2 CL ('x × 'y × ('x  'y)) ell2 where
  standard_query = controlled_op (λx. (Fst; Snd o function_at x) standard_query1)

text ‹Analogous to conststandard_query but using the variant conststandard_query1'.›

(* Same construction as standard_query, but using standard_query1', which sends basis
   vectors with an undefined oracle value to 0. *)
definition standard_query' :: ('x × 'y × ('x  'y)) ell2 CL ('x × 'y × ('x  'y)) ell2 where
  standard_query' = controlled_op (λx. (Fst; Snd o function_at x) standard_query1')

(* Unfolding of the controlled operation for a basis state ket x of the input register. *)
lemma standard_query_ket: standard_query *V (ket x s ψ) = ket x s ((Fst; Snd o function_at x) standard_query1 *V ψ)
  by (auto simp: standard_query_def)

(* If the oracle function is defined at x with value z, then standard_query adds z to the
   output register and leaves the input and oracle registers unchanged. *)
lemma standard_query_ket_full_Some: 
  assumes H x = Some z
  shows standard_query *V (ket (x,y,H)) = ket (x, y + z, H)
proof -
  (* Split H into its value at x and the punctured remainder H'. *)
  obtain H' where pf_xH: puncture_function x H = (H x, H')
    by (metis fst_puncture_function prod.collapse)
  (* Rewrite the register application as a sandwich by function_at_U x, reducing the claim
     to the action of standard_query1 on ket y and ket (H x). *)
  have standard_query *V (ket (x,y,H)) = ket x s sandwich (id_cblinfun o function_at_U x) ((id r Fst) standard_query1) *V ket y s ket H
    by (simp add: standard_query_ket function_at_def pair_o_tensor_right pair_Fst_Snd
        pair_o_tensor_right unitary_sandwich_register pair_o_tensor_right
        register_tensor_distrib_right id_tensor_sandwich
        flip: tensor_ell2_ket)
  also have  = ket x s (id_cblinfun o function_at_U x) *V (id r Fst) standard_query1 *V (ket y s ket (H x) s ket H')
            (is _ = _ s _ *V ?R standard_query1 *V _)
    by (simp add: sandwich_apply' tensor_op_adjoint tensor_op_ell2 pf_xH assms flip: tensor_ell2_ket)
  also have  = ket x s (id_cblinfun o function_at_U x) *V (ket (y + z) s ket (H x) s ket H')
    apply (subst asm_rl[of (id r Fst) = assoc o Fst])
    subgoal by (auto intro!: tensor_extensionality simp add: register_tensor_is_register Fst_def)
    apply (simp add: Fst_def assoc_ell2_sandwich sandwich_apply' assoc_ell2'_tensor tensor_op_ell2 assms)
    apply (simp add: tensor_ell2_ket del: function_at_U_ket)
    by (simp add: assoc_ell2_tensor tensor_op_ell2 flip: tensor_ell2_ket)
  (* Reassemble H from (H x, H'). *)
  also have  = ket x s ket (y + z) s ket H
    apply (simp add: tensor_op_ell2 flip: tensor_ell2_ket)
    by (simp flip: pf_xH add: tensor_ell2_ket)
  finally show ?thesis
    by (simp add: tensor_ell2_ket)
qed

(* If the oracle function is undefined at x, standard_query acts as the identity on the basis state. *)
lemma standard_query_ket_full_None: 
  assumes H x = None
  shows standard_query *V (ket (x,y,H)) = ket (x, y, H)
proof -
  (* Split H into its value at x and the punctured remainder H'. *)
  obtain H' where pf_xH: puncture_function x H = (H x, H')
    by (metis fst_puncture_function prod.collapse)
  (* Same reduction as in standard_query_ket_full_Some; here standard_query1 fixes ket y and ket None. *)
  have standard_query *V (ket (x,y,H)) = ket x s sandwich (id_cblinfun o function_at_U x) ((id r Fst) standard_query1) *V ket y s ket H
    by (simp add: standard_query_ket function_at_def pair_o_tensor_right pair_Fst_Snd
        pair_o_tensor_right unitary_sandwich_register pair_o_tensor_right
        register_tensor_distrib_right id_tensor_sandwich
        flip: tensor_ell2_ket)
  also have  = ket x s (id_cblinfun o function_at_U x) *V (id r Fst) standard_query1 *V ket y s ket (H x) s ket H'
    by (simp add: sandwich_apply' tensor_op_adjoint tensor_op_ell2 pf_xH assms flip: tensor_ell2_ket)
  also have  = ket x s (id_cblinfun o function_at_U x) *V ket y s ket (H x) s ket H'
    apply (subst asm_rl[of (id r Fst) = assoc o Fst])
    subgoal by (auto intro!: tensor_extensionality simp add: register_tensor_is_register Fst_def)
    apply (simp add: Fst_def assoc_ell2_sandwich sandwich_apply' assoc_ell2'_tensor tensor_op_ell2 assms)
    apply (simp add: tensor_ell2_ket del: function_at_U_ket)
    by (simp add: assoc_ell2_tensor tensor_op_ell2 flip: tensor_ell2_ket)
  also have  = ket x s ket y s ket H
    apply (simp add: tensor_op_ell2 flip: tensor_ell2_ket)
    by (simp flip: pf_xH add: tensor_ell2_ket)
  finally show ?thesis
    by (simp add: tensor_ell2_ket)
qed

(* Unfolding of the controlled operation for a basis state ket x of the input register. *)
lemma standard_query'_ket: standard_query' *V (ket x s ψ) = ket x s ((Fst; Snd o function_at x) standard_query1' *V ψ)
  by (auto simp: standard_query'_def)

(* If the oracle function is defined at x with value z, then standard_query' adds z to the
   output register, exactly like standard_query. *)
lemma standard_query'_ket_full_Some: 
  assumes H x = Some z
  shows standard_query' *V (ket (x,y,H)) = ket (x, y + z, H)
proof -
  (* Split H into its value at x and the punctured remainder H'. *)
  obtain H' where pf_xH: puncture_function x H = (H x, H')
    by (metis fst_puncture_function prod.collapse)
  (* Reduce to the action of standard_query1' on ket y and ket (H x), conjugated by function_at_U x. *)
  have standard_query' *V (ket (x,y,H)) = ket x s sandwich (id_cblinfun o function_at_U x) ((id r Fst) standard_query1') *V ket y s ket H
    by (simp add: standard_query'_ket function_at_def pair_o_tensor_right pair_Fst_Snd
        pair_o_tensor_right unitary_sandwich_register pair_o_tensor_right
        register_tensor_distrib_right id_tensor_sandwich
        flip: tensor_ell2_ket)
  also have  = ket x s (id_cblinfun o function_at_U x) *V (id r Fst) standard_query1' *V (ket y s ket (H x) s ket H')
            (is _ = _ s _ *V ?R standard_query1' *V _)
    by (simp add: sandwich_apply' tensor_op_adjoint tensor_op_ell2 pf_xH assms flip: tensor_ell2_ket)
  also have  = ket x s (id_cblinfun o function_at_U x) *V (ket (y + z) s ket (H x) s ket H')
    apply (subst asm_rl[of (id r Fst) = assoc o Fst])
    subgoal by (auto intro!: tensor_extensionality simp add: register_tensor_is_register Fst_def)
    apply (simp add: Fst_def assoc_ell2_sandwich sandwich_apply' assoc_ell2'_tensor tensor_op_ell2 assms)
    apply (simp add: tensor_ell2_ket del: function_at_U_ket)
    by (simp add: assoc_ell2_tensor tensor_op_ell2 flip: tensor_ell2_ket)
  also have  = ket x s ket (y + z) s ket H
    apply (simp add: tensor_op_ell2 flip: tensor_ell2_ket)
    by (simp flip: pf_xH add: tensor_ell2_ket)
  finally show ?thesis
    by (simp add: tensor_ell2_ket)
qed

(* If the oracle function is undefined at x, standard_query' maps the basis state to 0. *)
lemma standard_query'_ket_full_None: 
  assumes H x = None
  shows standard_query' *V (ket (x,y,H)) = 0
proof -
  (* Split H into its value at x and the punctured remainder H'. *)
  obtain H' where pf_xH: puncture_function x H = (H x, H')
    by (metis fst_puncture_function prod.collapse)
  have standard_query' *V (ket (x,y,H)) = ket x s sandwich (id_cblinfun o function_at_U x) ((id r Fst) standard_query1') *V ket y s ket H
    by (simp add: standard_query'_ket function_at_def pair_o_tensor_right pair_Fst_Snd
        pair_o_tensor_right unitary_sandwich_register pair_o_tensor_right
        register_tensor_distrib_right id_tensor_sandwich
        flip: tensor_ell2_ket)
  also have  = ket x s (id_cblinfun o function_at_U x) *V (id r Fst) standard_query1' *V ket y s ket (H x) s ket H'
    by (simp add: sandwich_apply' tensor_op_adjoint tensor_op_ell2 pf_xH assms flip: tensor_ell2_ket)
  (* standard_query1' sends ket (y, None) to 0, so the whole tensor product vanishes. *)
  also have  = 0
    apply (subst asm_rl[of (id r Fst) = assoc o Fst])
    subgoal by (auto intro!: tensor_extensionality simp add: register_tensor_is_register Fst_def)
    apply (simp add: Fst_def assoc_ell2_sandwich sandwich_apply' assoc_ell2'_tensor tensor_op_ell2 assms)
    by (simp add: tensor_ell2_ket del: function_at_U_ket)
  finally show ?thesis
    by -
qed


(* standard_query is its own inverse, inherited from standard_query1_selfinverse through
   controlled_op and register multiplicativity. *)
lemma standard_query_selfinverse[simp]: standard_query oCL standard_query = id_cblinfun
  by (simp add: standard_query_def controlled_op_compose register_mult)

(* standard_query is unitary: a controlled family of register applications of the unitary standard_query1. *)
lemma unitary_standard_query[simp]: unitary standard_query
  by (auto simp: standard_query_def intro!: controlled_op_unitary register_unitary[of (_;_)])

(* Despite its name, this lemma states that the operator norm of standard_query' is exactly 1.
   The upper bound comes from controlled_op_norm_leq and the norm of standard_query1'
   (norm_standard_query1'). The lower bound is witnessed by the basis state
   ket (undefined, undefined, %_. Some undefined): the oracle there is defined everywhere,
   so standard_query'_ket_full_Some applies. *)
lemma contracting_standard'_query[simp]: norm standard_query' = 1
proof (rule antisym)
  show norm standard_query'  1
    unfolding standard_query'_def
    apply (rule controlled_op_norm_leq)
    by (smt (verit) norm_standard_query1' norm_zero register_norm register_pair_def register_pair_is_register)
  show norm standard_query'  1
    apply (rule cblinfun_norm_geqI[where x=ket (undefined, undefined, λ_. Some undefined)])
    apply (subst standard_query'_ket_full_Some)
    by auto
qed

subsection query1› - Query the compressed oracle at a single output›

text ‹Before we formulate the compressed oracle itself, we define a scaled-down version in which
  the oracle function has only a single output (and there is no input register).
  Cf.~conststandard_query1. This is done by decompressing the oracle register,
  applying conststandard_query1, and then recompressing the oracle register.

  That is: if one starts with a three-partite state termψ s ket 0 s ket None and repeatedly performs
  operations $U_i$ on parts 1 and 2 of the state, interleaved with termquery1 invocations on parts 2 and 3,
  then this simulates starting with state termψ s 0 and performing the $U_i$ interleaved with
  invocations of the unitary $\ket{y} ↦ \ket{y ⊕ h}$ on part 2, where $h$ is chosen uniformly at random
  at the beginning.›

(* query1: decompress the single oracle output (Snd compress1), apply standard_query1,
   and recompress. The same operator Snd compress1 is used on both sides. *)
definition query1 where query1 = Snd compress1 oCL standard_query1 oCL Snd compress1

text ‹The operation termquery1' is defined like termquery1
  (and the motivation and properties mentioned there also hold here),
  except that it is based on termstandard_query1' instead of termstandard_query1.
  See the comment at termstandard_query1' for a discussion of the difference.›


(* query1': like query1, but with standard_query1' in the middle instead of standard_query1. *)
definition query1' where query1' = Snd compress1 oCL standard_query1' oCL Snd compress1



(* query1 is unitary, as a composition of unitaries (register_unitary, unitary_cblinfun_compose). *)
lemma unitary_query1[simp]: unitary query1
  by (auto simp: query1_def register_unitary intro!: unitary_cblinfun_compose)

(* query1' has operator norm 1. Snd compress1 is an isometry (shown via compress1_square),
   so norm_isometry_compose' and norm_isometry_compose remove the two compress factors.
   That leaves the norm of standard_query1', which the final simp closes
   (presumably via the simp lemma norm_standard_query1'). *)
lemma norm_query1'[simp]: norm query1' = 1
  unfolding query1'_def
  apply (subst norm_isometry_compose')
   apply (simp add: Snd_def comp_tensor_op compress1_square isometry_def tensor_op_adjoint)
  apply (subst norm_isometry_compose)
   apply (simp add: Snd_def comp_tensor_op compress1_square isometry_def tensor_op_adjoint)
  by simp

text ‹The following lemmas give explicit formulas for the result of applying constquery1 and constquery1'
  to computational basis states (termket ). While the definitions of constquery1 and constquery1'
  are useful for showing structural properties of these operations (e.g., the fact that
  they actually simulate a random oracle), for doing computations in concrete cases (e.g.,
  the preservation of an invariant), the explicit formulas can be more useful.›

(* Explicit formula for query1 on a basis state whose oracle register is empty (None).
   The proof follows the definition of query1 in three steps:
   - decompress with Snd compress1, which turns ket (y, None) into the superposition
     of the ket (y, Some d) with amplitude α;
   - apply standard_query1, which maps ket (y, Some d) to ket (y + d, Some d);
   - recompress with Snd compress1.
   The result is matched against the stated right-hand side summand by summand (aux),
   re-indexing sums along the bijection d -> y + d where needed. *)
lemma query1_None: query1 *V ket (y,None) = 
                        α *C (dUNIV. ket (y + d, Some d))
                        - α^3 *C (y'UNIV. dUNIV. ket (y', Some d))
                        + α2 *C (dUNIV. ket (d, None)) (is _ = ?rhs)
proof -
  (* Collect powers of α produced by the two compress1 applications. *)
  have [simp]: α * α = α2 α * α2 = α^3
    by (simp_all add: power2_eq_square numeral_2_eq_2 numeral_3_eq_3)

  (* Lets the final equation be proved one summand at a time. *)
  have aux: a = a'  b = b'  c = c'  a - b + c = a' - b' + c' for a b c a' b' c' :: 'z::group_add
    by simp

  have Snd compress1 *V ket (y, None) = (dUNIV. α *C ket (y, Some d))
    by (simp add: query1_def tensor_ell2_scaleC2 tensor_ell2_sum_right flip: tensor_ell2_ket)
  also have standard_query1 *V  = (dUNIV. α *C ket (y + d, Some d))
    by (simp add: cblinfun.scaleC_right cblinfun.sum_right)
  also have Snd compress1 *V  = 
          α *C (dUNIV. (ket (y + d) s ket (Some d)))
        - α^3 *C (zUNIV. dUNIV. (ket (y + z) s ket (Some d)))
        + α2 *C (zUNIV. (ket (y + z) s ket None))
    by (simp add: tensor_ell2_diff2 tensor_ell2_add2 scaleC_add_right sum.distrib tensor_ell2_sum_right
        tensor_ell2_scaleC2 sum_subtractf scaleC_diff_right scaleC_sum_right cblinfun.scaleC_right
        cblinfun.sum_right
        flip: tensor_ell2_ket)
  also have  = ?rhs
    apply (rule aux)
    subgoal 
      by (simp add: tensor_ell2_ket)
    subgoal
      apply (subst sum.reindex_bij_betw[where h=λd. y + d and T=UNIV]) 
      by (simp_all add: tensor_ell2_ket)
    subgoal
      apply simp
      apply (subst sum.reindex_bij_betw[where h=λd. y + d and T=UNIV])
      by (simp_all add: tensor_ell2_ket)
    by -
  finally show ?thesis
    unfolding query1_def by simp
qed

(* Explicit formula for query1 on a basis state whose oracle register holds Some d.
   This is the same three-step computation as in query1_None (decompress, standard_query1,
   recompress). The seven resulting summands are matched one by one against the right-hand
   side (aux); two of them require re-indexing the sum along d -> y + d. *)
lemma query1_Some: query1 *V ket (y, Some d) = 
        ket (y + d, Some d)
        + α *C ket (y + d, None)
        - α^3 *C (y'UNIV. ket (y', None))
        - α2 *C (d'UNIV. ket (y + d', Some d'))
        - α2 *C (d'UNIV. ket (y + d, Some d'))
        + α2 *C (d'UNIV. ket (y, Some d'))
        + α^4 *C (y'UNIV. d'UNIV. ket (y', Some d'))
    (is _ = ?rhs)
proof -
  have [simp]: α * α = α2 α2 * α = α^3
    by (simp_all add: power2_eq_square numeral_2_eq_2 numeral_3_eq_3)

  (* Rearranges the seven computed summands into the order of the right-hand side. *)
  have aux: a=a'  b=b'  c=c'  d=d'  e=e'  f=f'  g=g'
          a' - e' + b' + g' - d' - c' + f' = a + b - c - d - e + f + g 
    for a b c d e f g a' b' c' d' e' f' g' :: 'z::ab_group_add
    by simp

  have Snd compress1 *V ket (y, Some d) =
              ket (y, Some d) - α2 *C (d'UNIV. ket (y, Some d')) + α *C ket (y, None)
    by (simp add: query1_def tensor_ell2_scaleC2 tensor_ell2_diff2 tensor_ell2_add2 tensor_ell2_sum_right
             flip: tensor_ell2_ket scaleC_sum_right)
  also have standard_query1 *V  = ket (y + d, Some d) - α2 *C (d'UNIV. ket (y + d', Some d')) + α *C ket (y, None)
    by (simp add: cblinfun.add_right cblinfun.diff_right cblinfun.scaleC_right cblinfun.sum_right)
  also have Snd compress1 *V  = 
                ket (y + d, Some d)
                - α2 *C (d'UNIV. ket (y + d, Some d'))
                + α *C ket (y + d, None)
                + α^4 *C (zUNIV. d'UNIV. ket (y + z, Some d'))
                - α2 *C (d'UNIV. ket (y + d', Some d'))
                - α^3 *C (zUNIV. ket (y + z, None))
                + α2 *C (d'UNIV. ket (y, Some d'))
    by (simp add: tensor_ell2_diff2 tensor_ell2_add2 scaleC_add_right sum.distrib
        tensor_ell2_scaleC2 sum_subtractf scaleC_diff_right scaleC_sum_right tensor_ell2_sum_right
        cblinfun.add_right cblinfun.diff_right diff_diff_eq2 cblinfun.scaleC_right cblinfun.sum_right
        flip: tensor_ell2_ket diff_diff_eq scaleC_sum_right)
  also have  = ?rhs
    apply (rule aux)
    subgoal by rule
    subgoal by rule
    subgoal
      apply (subst sum.reindex_bij_betw[where h=λd. y + d and T=UNIV])
      by simp_all
    subgoal by rule
    subgoal by rule
    subgoal by rule
    subgoal
      apply (subst sum.reindex_bij_betw[where h=λd. y + d and T=UNIV])
      by simp_all
    by -
  finally show ?thesis
    unfolding query1_def by simp
qed

(* query1_None and query1_Some combined into a single case distinction on the oracle
   register, for basis states given as a pair yd. *)
lemma query1: 
  shows query1 *V (ket yd) = (case yd of
    (y, None)  
        α *C (dUNIV. ket (y + d, Some d))
        - α^3 *C (y'UNIV. dUNIV. ket (y', Some d))
        + α2 *C (dUNIV. ket (d, None))
    | (y, Some d) 
        ket (y + d, Some d)
        + α *C ket (y + d, None)
        - α^3 *C (y'UNIV. ket (y', None))
        - α2 *C (d'UNIV. ket (y + d', Some d'))
        - α2 *C (d'UNIV. ket (y + d, Some d'))
        + α2 *C (d'UNIV. ket (y, Some d'))
        + α^4 *C (y'UNIV. d'UNIV. ket (y', Some d')))
  apply (cases yd, rename_tac y d) apply (case_tac d)
   apply (simp_all add: )
   apply (subst query1_None)
   apply simp
  apply (subst query1_Some)
  by simp


(* Explicit formula for query1' on a basis state whose oracle register is empty (None).
   The right-hand side is the same as in query1_None. The proof follows the definition
   of query1':
   - decompress with Snd compress1;
   - apply standard_query1', which on ket (y, Some d) yields ket (y + d, Some d);
   - recompress.
   It then matches the result summand by summand, re-indexing sums along d -> y + d.
   The decompression step only involves Snd compress1, so no definition of query1 or
   query1' is needed there. *)
lemma query1'_None: query1' *V ket (y,None) = 
                        α *C (dUNIV. ket (y + d, Some d))
                        - α^3 *C (y'UNIV. dUNIV. ket (y', Some d))
                        + α2 *C (dUNIV. ket (d, None)) (is _ = ?rhs)
proof -
  (* Collect powers of α produced by the two compress1 applications. *)
  have [simp]: α * α = α2 α * α2 = α^3
    by (simp_all add: power2_eq_square numeral_2_eq_2 numeral_3_eq_3)

  (* Lets the final equation be proved one summand at a time. *)
  have aux: a = a'  b = b'  c = c'  a - b + c = a' - b' + c' for a b c a' b' c' :: 'z::group_add
    by simp

  have Snd compress1 *V ket (y, None) = (dUNIV. α *C ket (y, Some d))
    by (simp add: tensor_ell2_scaleC2 tensor_ell2_sum_right flip: tensor_ell2_ket)
  also have standard_query1' *V  = (dUNIV. α *C ket (y + d, Some d))
    by (simp add: cblinfun.scaleC_right cblinfun.sum_right)
  also have Snd compress1 *V  = 
          α *C (dUNIV. (ket (y + d) s ket (Some d)))
        - α^3 *C (zUNIV. dUNIV. (ket (y + z) s ket (Some d)))
        + α2 *C (zUNIV. (ket (y + z) s ket None))
    by (simp add: tensor_ell2_diff2 tensor_ell2_add2 scaleC_add_right sum.distrib tensor_ell2_sum_right
        tensor_ell2_scaleC2 sum_subtractf scaleC_diff_right scaleC_sum_right cblinfun.scaleC_right
        cblinfun.sum_right
        flip: tensor_ell2_ket)
  also have  = ?rhs
    apply (rule aux)
    subgoal 
      by (simp add: tensor_ell2_ket)
    subgoal
      apply (subst sum.reindex_bij_betw[where h=λd. y + d and T=UNIV]) 
      by (simp_all add: tensor_ell2_ket)
    subgoal
      apply simp
      apply (subst sum.reindex_bij_betw[where h=λd. y + d and T=UNIV])
      by (simp_all add: tensor_ell2_ket)
    by -
  finally show ?thesis
    unfolding query1'_def by simp
qed

(* Explicit formula for query1' on a basis state whose oracle register holds Some d.
   Compared with query1_Some, the component α *C ket (y, None) produced by the first
   decompression does not survive standard_query1' (compare the second computation step
   here with the one in query1_Some). Consequently the summand
   α² *C (sum over d' of ket (y, Some d')) that appears in query1_Some is absent here.
   The six resulting summands are matched against the right-hand side (aux); two of them
   require re-indexing the sum along d -> y + d. *)
lemma query1'_Some: query1' *V ket (y, Some d) = 
        ket (y + d, Some d)
        + α *C ket (y + d, None)
        - α^3 *C (y'UNIV. ket (y', None))
        - α2 *C (d'UNIV. ket (y + d', Some d'))
        - α2 *C (d'UNIV. ket (y + d, Some d'))
        + α^4 *C (y'UNIV. d'UNIV. ket (y', Some d'))
    (is _ = ?rhs)
proof -
  have [simp]: α * α = α2 α2 * α = α^3
    by (simp_all add: power2_eq_square numeral_2_eq_2 numeral_3_eq_3)

  (* Rearranges the six computed summands into the order of the right-hand side. *)
  have aux: a=a'  b=b'  c=c'  d=d'  e=e'  g=g'
          a' - e' + b' + g' - d' - c' = a + b - c - d - e + g 
    for a b c d e g a' b' c' d' e' g' :: 'z::ab_group_add
    by simp

  have Snd compress1 *V ket (y, Some d) =
              ket (y, Some d) - α2 *C (d'UNIV. ket (y, Some d')) + α *C ket (y, None)
    by (simp add: tensor_ell2_scaleC2 tensor_ell2_diff2 tensor_ell2_add2 tensor_ell2_sum_right
             flip: tensor_ell2_ket scaleC_sum_right)
  also have standard_query1' *V  = ket (y + d, Some d) - α2 *C (d'UNIV. ket (y + d', Some d'))
    by (simp add: cblinfun.add_right cblinfun.diff_right cblinfun.scaleC_right cblinfun.sum_right)
  also have Snd compress1 *V  = 
                ket (y + d, Some d)
                - α2 *C (d'UNIV. ket (y + d, Some d'))
                + α *C ket (y + d, None)
                + α^4 *C (zUNIV. d'UNIV. ket (y + z, Some d'))
                - α2 *C (d'UNIV. ket (y + d', Some d'))
                - α^3 *C (zUNIV. ket (y + z, None))
    by (simp add: tensor_ell2_diff2 tensor_ell2_add2 scaleC_add_right sum.distrib tensor_ell2_sum_right
        tensor_ell2_scaleC2 sum_subtractf scaleC_diff_right scaleC_sum_right cblinfun.sum_right
        cblinfun.add_right cblinfun.diff_right diff_diff_eq2 cblinfun.scaleC_right
        flip: tensor_ell2_ket diff_diff_eq scaleC_sum_right)
  also have  = ?rhs
    apply (rule aux)
    subgoal by rule
    subgoal by rule
    subgoal
      apply (subst sum.reindex_bij_betw[where h=λd. y + d and T=UNIV])
      by simp_all
    subgoal by rule
    subgoal by rule
    subgoal
      apply (subst sum.reindex_bij_betw[where h=λd. y + d and T=UNIV])
      by simp_all
    by -
  finally show ?thesis
    unfolding query1'_def by simp
qed

(* query1'_None and query1'_Some combined into a single case distinction on the oracle
   register, for basis states given as a pair yd. *)
lemma query1': 
  shows query1' *V (ket yd) = (case yd of
    (y, None)  
        α *C (dUNIV. ket (y + d, Some d))
        - α^3 *C (y'UNIV. dUNIV. ket (y', Some d))
        + α2 *C (dUNIV. ket (d, None))
    | (y, Some d) 
        ket (y + d, Some d)
        + α *C ket (y + d, None)
        - α^3 *C (y'UNIV. ket (y', None))
        - α2 *C (d'UNIV. ket (y + d', Some d'))
        - α2 *C (d'UNIV. ket (y + d, Some d'))
        + α^4 *C (y'UNIV. d'UNIV. ket (y', Some d')))
  apply (cases yd, rename_tac y d) apply (case_tac d)
   apply (simp_all add: )
   apply (subst query1'_None)
   apply simp
  apply (subst query1'_Some)
  by simp

subsection query› - Query the compressed oracle›


text ‹
  We define the compressed oracle itself.
  
  Analogous to the definition of constquery1 above (decompress, conststandard_query1, recompress), 
  the compressed oracle is defined by decompressing the oracle register (now a superposition of functions),
  applying conststandard_query, and recompressing.

  That is: if one starts with a four-partite state termψ s ket 0 s ket 0 s ket None and repeatedly performs
  operations $U_i$ on parts 1--3 of the state, interleaved with termquery invocations on parts 2--4,
  then this simulates starting with state termψ s 0 and performing the $U_i$ interleaved with
  invocations of the unitary $\ket{x, y} ↦ \ket{x, y ⊕ h(x)}$ on parts 2 and 3, where $h$ is a function
  chosen uniformly at random at the beginning.

  Note that there is an alternative way of defining the compressed oracle, namely by decompressing
  not the whole oracle register, but only the specific oracle output that we are querying.
  This is closer to an efficient implementation of the compressed oracle.
  We show that this definition is equivalent below (lemma query_local›).›

(* The compressed oracle: decompress the oracle register with reg_3_3 compress, apply
   standard_query, and recompress with reg_3_3 compress. On states ket x ⊗ psi,
   reg_3_3 compress acts as Snd compress on psi (see the proof of query_local_generic). *)
definition query where query = reg_3_3 compress oCL standard_query oCL reg_3_3 compress

text termquery' is defined like constquery, except that it's based on 
  conststandard_query1' instead of conststandard_query1.
  See the discussion of conststandard_query1' for the difference.›

(* Like query, but with standard_query' in the middle instead of standard_query. *)
definition query' where query' = reg_3_3 compress oCL standard_query' oCL reg_3_3 compress


(* The compressed oracle is unitary, as a composition of unitaries
   (register_unitary, unitary_cblinfun_compose). *)
lemma unitary_query[simp]: unitary query
  by (auto simp: query_def register_unitary intro!: unitary_cblinfun_compose)

(* The compressed oracle has operator norm 1. This is derived from unitary_query via
   unitary_isometry and norm_isometry. *)
lemma norm_query[simp]: norm query = 1
  using norm_isometry unitary_isometry unitary_query by blast

(* query' has operator norm 1. Both reg_3_3 compress factors are removed using
   norm_isometry_compose' and norm_isometry_compose. The side conditions are discharged
   with register_isometry for register_3_3; the first one is rewritten with register_adjoint
   beforehand. The remaining goal, norm standard_query' = 1, is the simp lemma
   contracting_standard'_query. *)
lemma norm_query'[simp]: norm query' = 1
  unfolding query'_def
  apply (subst norm_isometry_compose')
   apply (subst register_adjoint[OF register_3_3, symmetric])
   apply (rule register_isometry[OF register_3_3])
   apply simp
  apply (subst norm_isometry_compose)
   apply (rule register_isometry[OF register_3_3])
   apply simp
  by simp

(* Generic locality lemma. The assumptions are:
   - query and query1 are defined by decompress / standard query / recompress
     (query_def, query1_def);
   - standard_query acts on ket x ⊗ psi by applying standard_query1 to the output
     register and the oracle entry at x.
   The conclusion is that query equals the operator controlled by the input x that applies
   query1 to the output register and the oracle entry at x.
   Proof outline:
   - First prove the equation on states ket x ⊗ psi; the claim follows at the end via equal_ket.
   - aux: for x not in M, Snd ((Fst; Snd o function_at x) Q) commutes with
     reg_3_3 (apply_every M R). This is proved by induction over the finite set M, swapping
     the registers at the distinct positions y and x.
   - compress is split into its factors at {x} and at -{x}. The factors at -{x} commute past
     the query (aux) and cancel (compress1_square), leaving only function_at x compress1 on
     both sides.
   - What remains is query1 transported to the registers at position x. *)
lemma query_local_generic: 
  ― ‹A generalization of lemmas query_local› and query'_local› below.
      We prove this first because it avoids a duplication of the proof because query_local› and query'_local›
      have very similar proofs.›
  fixes query :: ('x × 'y × ('x  'y)) update and query1
    and standard_query and standard_query1
  assumes query_def: query = reg_3_3 compress oCL standard_query oCL reg_3_3 compress
  assumes query1_def: query1 = Snd compress1 oCL standard_query1 oCL Snd compress1
  assumes standard_query_ket: x ψ. standard_query *V (ket x s ψ) = ket x s ((Fst; Snd o function_at x) standard_query1 *V ψ)
  shows query = controlled_op (λx. (Fst; Snd o function_at x) query1)
proof -
  have query *V ket x s ψ = controlled_op (λx. (Fst;Snd  function_at x) query1) *V ket x s ψ for x ψ
  proof -
    (* An operator on the oracle entry at x commutes with operators on entries at positions
       in M, provided x is not in M. *)
    have aux: (Snd ((Fst;Snd  function_at x) Q) oCL reg_3_3 (apply_every M R) :: ('x×'y×('x'y)) update)
              = reg_3_3 (apply_every M R) oCL Snd ((Fst;Snd  function_at x) Q) 
      if xM for M and Q :: ('y × 'y option) update and R
      using finite[of M] that
    proof induction
      case empty
      show ?case 
        by simp
    next
      case (insert y F)
      have (Snd ((Fst;Snd  function_at x) Q) oCL reg_3_3 (apply_every (insert y F) R) :: ('x×'y×('x'y)) update) =
             ((Snd o (Fst;Snd  function_at x)) Q oCL (reg_3_3 o function_at y) (R y)) oCL reg_3_3 (apply_every F R)
        by (simp add: apply_every_insert insert register_mult[of reg_3_3, symmetric] cblinfun_compose_assoc)
      also have   = (reg_3_3  function_at y) (R y) oCL ((Snd ((Fst;Snd  function_at x) Q)) oCL reg_3_3 (apply_every F R))
        apply (subst swap_registers[of Snd o _ reg_3_3 o _])
        using insert apply (simp add: reg_3_3_def add: comp_assoc)
        by (simp add: cblinfun_compose_assoc)
      also have  = ((reg_3_3  function_at y) (R y) oCL reg_3_3 (apply_every F R)) oCL Snd ((Fst;Snd  function_at x) Q)
        apply (subst insert.IH)
        using insert by (auto simp: cblinfun_compose_assoc)
      also have  = (reg_3_3 (apply_every (insert y F) R)) oCL Snd ((Fst;Snd  function_at x) Q)
        by (simp add: apply_every_insert insert register_mult[of reg_3_3, symmetric] cblinfun_compose_assoc)
      finally show ?case
        by -
    qed

    (* Main calculation: push the decompression through ket x, apply standard_query_ket,
       and pull everything back into a single operator. *)
    have query *V (ket x s ψ) = reg_3_3 compress *V standard_query *V reg_3_3 compress *V (ket x s ψ)
      by (simp add: query_def)
    also have  = reg_3_3 compress *V
            standard_query *V (ket x s Snd compress *V ψ)
      apply (rule arg_cong[where f=λx. _ *V _ *V x])
      by (auto simp: reg_3_3_def)
    also have  = reg_3_3 compress *V
            (ket x s (((Fst; Snd o function_at x) standard_query1 *V Snd compress *V ψ)))
      by (simp add: standard_query_ket)
    also have  = reg_3_3 compress *V
            (Snd ((Fst; Snd o function_at x) standard_query1)) *V (ket x s Snd compress *V ψ)
      by auto
    also have  = reg_3_3 compress *V
            (Snd ((Fst; Snd o function_at x) standard_query1)) *V reg_3_3 compress *V (ket x s ψ)
      apply (rule arg_cong[where f=λx. _ *V _ *V x])
      by (auto simp: reg_3_3_def)
    also have  = (reg_3_3 compress oCL (Snd ((Fst; Snd o function_at x) standard_query1)) oCL reg_3_3 compress) *V (ket x s ψ)
      by auto
    also have  = (reg_3_3 (function_at x compress1) oCL (Snd ((Fst; Snd o function_at x) standard_query1)) oCL reg_3_3 (function_at x compress1)) *V (ket x s ψ)
      (is ?lhs *V _ = ?rhs *V _)
    proof -
      have [simp]: insert x (- {x}) = UNIV for x :: 'x
        by auto
      (* Split compress into its factor at x and its factors at all other positions. *)
      have ?lhs = reg_3_3 (apply_every ({x}  -{x}) (λ_. compress1))
             oCL Snd ((Fst;Snd  function_at x) standard_query1)
             oCL reg_3_3 (apply_every (-{x}  {x}) (λ_. compress1))
        by (simp add: compress_def)
      also have  = reg_3_3 (function_at x compress1) oCL reg_3_3 (apply_every (- {x}) (λ_. compress1))
         oCL (  Snd ((Fst;Snd  function_at x) standard_query1) oCL reg_3_3 (apply_every (- {x}) (λ_. compress1))  )
         oCL reg_3_3 (function_at x compress1)
        apply (subst apply_every_split[symmetric], simp)
        apply (subst apply_every_split[symmetric], simp)
        by (simp add: register_mult cblinfun_compose_assoc)
      (* Commute the factors outside x past the query (aux). *)
      also have  = reg_3_3 (function_at x compress1)
          oCL (  reg_3_3 (apply_every (- {x}) (λ_. compress1)) oCL reg_3_3 (apply_every (- {x}) (λ_. compress1))  )
          oCL Snd ((Fst;Snd  function_at x) standard_query1)
          oCL reg_3_3 (function_at x compress1)
        apply (subst aux)
        by (auto simp add: cblinfun_compose_assoc)
      also have  = reg_3_3 (function_at x compress1)
          oCL (reg_3_3 (apply_every (- {x}) (λ_. compress1 oCL compress1)))
          oCL Snd ((Fst;Snd  function_at x) standard_query1)
          oCL reg_3_3 (function_at x compress1)
        by (simp add: register_mult[of reg_3_3] apply_every_mult)
      (* The doubled factors cancel by compress1_square. *)
      also have  = reg_3_3 (function_at x compress1)
          oCL Snd ((Fst;Snd  function_at x) standard_query1)
          oCL reg_3_3 (function_at x compress1)
        by (simp add: compress1_square)
      finally show ?thesis
        by auto
    qed
    also have  = ket x s ((Snd (function_at x compress1) oCL ((Fst; Snd o function_at x) standard_query1) oCL Snd (function_at x compress1)) *V ψ)
      by (simp add: reg_3_3_def)
    also have  = controlled_op (λx. Snd (function_at x compress1) oCL ((Fst; Snd o function_at x) standard_query1) oCL Snd (function_at x compress1)) *V
                  (ket x s ψ)
      by simp
    also have  = controlled_op (λx. (Fst; Snd o function_at x) query1) (ket x s ψ)
      by (auto simp: query1_def register_mult[symmetric] register_pair_Snd[unfolded o_def, THEN fun_cong])
    finally show ?thesis
      by -
  qed

  (* Specialize psi to basis states and conclude by equality on basis vectors. *)
  from this[of _ ket _]
  show ?thesis
    by (auto intro!: equal_ket simp: tensor_ell2_ket)
qed

text ‹We give an alternate (equivalent) definition of the compressed oracle constquery.
  Instead of decompressing the whole oracle, we decompress only the output we need.
  Specifically, this is implemented by -- if the query register contains termket x --
  performing constquery1 on the output register and on
  the register $H_x$ which is the part of the oracle register which corresponds to the 
  output for input $x$.

  The analogous statement holds for the primed variant based on constquery1'.›

(* Local characterization of the compressed oracle. Controlled on the input x, query applies
   query1 to the output register and the oracle entry at x. This is an instance of
   query_local_generic. *)
lemma query_local: query = controlled_op (λx. (Fst; Snd o function_at x) query1)
  using query_def query1_def standard_query_ket by (rule query_local_generic)

(* Local characterization of query'. Controlled on the input x, query' applies query1' to the
   output register and the oracle entry at x. This is an instance of query_local_generic. *)
lemma query'_local: query' = controlled_op (λx. (Fst; Snd o function_at x) query1')
  using query'_def query1'_def standard_query'_ket by (rule query_local_generic)

(* In locale compressed_oracle: applying standard_query after reg_3_3 compress equals applying
   reg_3_3 compress after query. This follows from query_def, register_mult and
   compress_selfinverse. *)
lemma (in compressed_oracle) standard_query_compress: standard_query oCL reg_3_3 compress = reg_3_3 compress oCL query
  by (simp add: query_def register_mult compress_selfinverse flip: cblinfun_compose_assoc)

(* In locale compressed_oracle: applying standard_query' after reg_3_3 compress equals applying
   reg_3_3 compress after query'. This follows from query'_def, register_mult and
   compress_selfinverse. *)
lemma (in compressed_oracle) standard_query'_compress: standard_query' oCL reg_3_3 compress = reg_3_3 compress oCL query'
  by (simp add: query'_def register_mult compress_selfinverse flip: cblinfun_compose_assoc)



end (* locale compressed_oracle *)

end