Theory Function_At

section ‹‹Function_At› -- Function values as individual registers›

theory Function_At
imports Registers.Quantum_Extra Misc_Compressed_Oracle
begin

(* Switch off the notation bundle m_inv_syntax for the remainder of this theory.
   NOTE(review): the bundle is declared in an imported theory that is not visible here. *)
unbundle no m_inv_syntax

(* ('a,'b) punctured_function: the functions that are extensional on -{undefined}, i.e. the
   functions that map the argument undefined to the value undefined.
   It serves below as the "rest" of a function after one argument position was split off.
   The proof obligation (discharged by auto) is non-emptiness of the carrier set.
   NOTE(review): the function-type arrow in the set's type annotation appears to have been
   lost in extraction; confirm against the original theory file. *)
typedef ('a,'b) punctured_function = extensional (-{undefined}) :: ('a'b) set
  by auto
(* Register the type with the lifting package so that lift_definition / transfer work. *)
setup_lifting type_definition_punctured_function
(* punctured_function is a finite type whenever both parameter types are finite.
   Proof idea: Rep_punctured_function is injective, so finiteness of its image
   gives finiteness of the type (finite_imageD). *)
instance punctured_function :: (finite, finite) finite
  apply standard apply (rule finite_imageD[where f=Rep_punctured_function])
  by (auto simp add: Rep_punctured_function_inject inj_on_def)

(* fix_punctured_function x (y, f): rebuilds a total function from a value y and a punctured
   function f. On the representation level, the entries of f at x and at undefined are
   swapped, and afterwards the entry at x is set to y. *)
lift_definition fix_punctured_function :: 'a  ('b × ('a,'b) punctured_function)  ('a'b) is
  λx (y, f). (Fun.swap x undefined f) (x := y).

(* puncture_function x f: splits f into the pair (f x, rest).
   The rest is f with the entries at x and at undefined swapped and the entry at undefined
   then overwritten by undefined; the proof obligation shows that this lies in the carrier
   set of punctured_function. *)
lift_definition puncture_function :: 'a  ('a'b)  'b × ('a,'b) punctured_function is
  λx f. (f x, (Fun.swap x undefined f) (undefined := undefined))
  by (simp add: Compl_eq_Diff_UNIV) 

(* Pairing a new value y with the "rest" of f at x is the same as puncturing the
   updated function f(x:=y) at x. *)
lemma puncture_function_recombine:
  (y, snd (puncture_function x f)) = puncture_function x (f(x:=y))
  apply transfer
  by (auto intro!: ext simp: Transposition.transpose_def)
  
(* The "rest" component of puncture_function x does not depend on the value of the
   function at x: updating f at x leaves it unchanged. *)
lemma snd_puncture_function_upd: snd (puncture_function x (f(x:=y))) = snd (puncture_function x f)
  apply transfer
  by (auto intro!: ext simp: Transposition.transpose_def)

(* puncture_function x f written as an explicit pair: first component f x, second component
   the "rest". Obtained from puncture_function_recombine instantiated with y = f x. *)
lemma puncture_function_split: puncture_function x f = (f x, snd (puncture_function x f))
  using puncture_function_recombine[where x=x and f=f and y=f x]
  by simp

(* Round trip 1 (simp rule): puncturing f at x and then fixing at x gives back f. *)
lemma puncture_function_inverse[simp]: fix_punctured_function x (puncture_function x f) = f
  apply transfer by (auto intro!: ext simp: Transposition.transpose_def)

(* Round trip 2 (simp rule): fixing a pair yf at x and then puncturing at x gives back yf.
   The proof additionally unfolds extensional_def, i.e. it uses the carrier-set property of
   the punctured function. *)
lemma fix_punctured_function_inverse[simp]: puncture_function x (fix_punctured_function x yf) = yf
  apply transfer 
  by (auto intro!: ext simp: Transposition.transpose_def extensional_def)

(* fix_punctured_function x is a bijection; follows from the two round-trip lemmas. *)
lemma bij_fix_punctured_function[simp]: bij (fix_punctured_function x)
  by (metis bijI' fix_punctured_function_inverse puncture_function_inverse)

(* Injectivity of fix_punctured_function x, as a consequence of bijectivity. *)
lemma inj_fix_punctured_function[simp]: inj (fix_punctured_function x)
  by (simp add: bij_is_inj)

(* Surjectivity of fix_punctured_function x, as a consequence of bijectivity. *)
lemma surj_fix_punctured_function[simp]: surj (fix_punctured_function x)
  by (simp add: bij_is_surj)

text ‹The following \<^term>‹function_at_U x› provides a unitary isomorphism between \<^typ>‹('a ⇒ 'b) ell2› (superposition of functions)
and \<^typ>‹('b × ('a, 'b) punctured_function) ell2› (superposition of pairs of the value of
the function at \<^term>‹x› and the rest of the function).
This allows us to apply some operation to the first part of that pair, thereby lifting it to an operation on the whole function.
(The "rest of the function" part is to be considered opaque.)›

(* function_at_U x: the operator on ell2 induced (via classical_operator) by the total map
   fix_punctured_function x. On basis states it sends ket yf to
   ket (fix_punctured_function x yf). *)
definition function_at_U :: 'a  ('b × ('a, 'b) punctured_function) ell2 CL ('a  'b) ell2 where
 function_at_U x = classical_operator (Some o fix_punctured_function x)

(* function_at_U x is unitary (simp rule); proved via unitary_classical_operator. *)
lemma unitary_function_at_U[simp]: unitary (function_at_U x)
  by (auto simp: function_at_U_def intro!: unitary_classical_operator)

(* Action of function_at_U x on a basis state (simp rule): the pair y is recombined into
   a function by fix_punctured_function x. *)
lemma function_at_U_ket[simp]: function_at_U x *V ket y = ket (fix_punctured_function x y)
  by (simp add: function_at_U_def classical_operator_ket classical_operator_exists_inj)

(* Action of the adjoint of function_at_U x on a basis state (simp rule): the function y is
   split by puncture_function x. *)
lemma function_at_U_adj_ket[simp]: (function_at_U x)* *V ket y = ket (puncture_function x y)
  apply (simp add: function_at_U_def inv_map_total classical_operator_ket classical_operator_exists_inj)
  by (metis (no_types, lifting) bij_betw_inv_into bij_def bij_fix_punctured_function classical_operator_exists_inj classical_operator_ket inj_map_total inv_f_f o_def option.case(2) puncture_function_inverse)

text ‹The reference ‹function_at x› lifts an operation \<^term>‹U› on \<^typ>‹'b ell2› to an operation on \<^typ>‹('a ⇒ 'b) ell2› (superposition of functions).
The resulting operation applies \<^term>‹U› only to the \<^term>‹x›-output of the function.›

(* function_at x: first embeds an update on 'b via Fst (acting on the first component of the
   pair produced by puncturing), then conjugates with function_at_U x (sandwich) to obtain an
   update on the function space. *)
definition function_at :: 'a  ('b update  ('a'b) update) where
 function_at x = sandwich (function_at_U x) o Fst

(* Coefficient formula: the g-coefficient of (function_at x U) applied to ket f.
   It is nonzero only if f and g have the same "rest" when punctured at x, and in that case
   equals the (g x)-coefficient of U applied to ket (f x). *)
lemma Rep_ell2_function_at_ket:
  Rep_ell2 (function_at x U *V ket f) g = 
      of_bool (snd (puncture_function x f) = snd (puncture_function x g)) * Rep_ell2 (U *V ket (f x)) (g x)
proof -
  (* Unfold function_at: conjugation by function_at_U x around U on the first component. *)
  have Rep_ell2 (function_at x U *V ket f) g = Rep_ell2 (function_at_U x *V (U o id_cblinfun) *V ket (puncture_function x f)) g
    by (simp add: function_at_def Fst_def sandwich_apply)
  (* Express the coefficient as an inner product and move function_at_U x to the left side. *)
  also have  = (function_at_U x* *V ket g) C ((U o id_cblinfun) *V ket (puncture_function x f))
    by (metis cinner_adj_left cinner_ket_left)
  also have  = (ket (puncture_function x g)) C ((U o id_cblinfun) *V ket (puncture_function x f))
    by (simp add: function_at_def)
  (* Write both punctured functions as explicit pairs. *)
  also have  = (ket (g x, snd (puncture_function x g))) C ((U o id_cblinfun) *V ket (f x, snd (puncture_function x f)))
    by (simp flip: puncture_function_split)
  (* The inner product factorizes over the two tensor components. *)
  also have  = of_bool (snd (puncture_function x f) = snd (puncture_function x g)) * (ket (g x) C (U *V ket (f x)))
    by (simp add: tensor_op_ell2 cinner_ket flip: tensor_ell2_ket)
  also have  = of_bool (snd (puncture_function x f) = snd (puncture_function x g)) * Rep_ell2 (U *V ket (f x)) (g x)
    by (simp add: cinner_ket_left)
  finally show ?thesis
    by -
qed


(* Expansion of (function_at x U) applied to ket f as a sum over all values y:
   each term is the y-coefficient of U applied to ket (f x), times ket (f(x:=y)).
   NOTE(review): the summation binder in the statement appears to have been lost in
   extraction; the proof treats it as an infinite sum (infsum lemmas). *)
lemma function_at_ket:
  shows function_at x U *V ket f = (yUNIV. Rep_ell2 (U *V ket (f x)) y *C ket (f (x := y)))
proof -
  (* Unfold function_at and split f at x into a pair. *)
  have function_at x U *V ket f = function_at_U x *V (U o id_cblinfun) *V ket (puncture_function x f)
    by (simp add: function_at_def Fst_def sandwich_apply)
  also have  = function_at_U x *V (U o id_cblinfun) *V ket (f x, snd (puncture_function x f))
    by (metis puncture_function_split)
  (* U acts on the first tensor component only. *)
  also have  = function_at_U x *V ((U *V ket (f x)) s ket (snd (puncture_function x f)))
    by (simp add: tensor_op_ket)
  (* Decompose U applied to ket (f x) in the ket basis. *)
  also have  = function_at_U x *V ((yUNIV. Rep_ell2 (U *V ket (f x)) y *C ket y) s ket (snd (puncture_function x f)))
    by (simp flip: ell2_decompose_infsum)
  (* Pull the sum and the scalars out of the tensor product and out of function_at_U x.
     function_at_U_ket is removed from the simpset so the summands keep this shape. *)
  also have  = (yUNIV. Rep_ell2 (U *V ket (f x)) y *C (function_at_U x *V (ket y s ket (snd (puncture_function x f)))))
    by (simp del: function_at_U_ket 
        add: tensor_ell2_scaleC1 invertible_cblinfun_isometry infsum_cblinfun_apply_invertible infsum_tensor_ell2_left
        flip: cblinfun.scaleC_right)
  (* Recombine each pair (y, rest) into the updated function f(x:=y). *)
  also have  = (yUNIV. Rep_ell2 (U *V ket (f x)) y *C ket (f (x := y)))
    by (simp add: puncture_function_recombine tensor_ell2_ket)
  finally show ?thesis
    by -
qed

(* function_at x is a register (declared as simp and register rule).
   Follows from the definition via unitary_sandwich_register. *)
lemma register_function_at[simp, register]: register (function_at x :: 'b update  ('a'b) update) for x :: 'a
  by (auto simp add: function_at_def unitary_sandwich_register)

(* Operators lifted to two different argument positions x and y commute.

   Proof strategy: extend the function space by two auxiliary 'b-components and define a
   self-inverse operator reorder that exchanges the function's values at x and y with those
   two components. Conjugating by reorder turns function_at x U into U acting on the first
   auxiliary component (sandwichU) and function_at y V into V acting on the second
   (sandwichV). Those act on different tensor factors and hence commute; the result is then
   transported back and the auxiliary identity factor is cancelled.

   NOTE(review): several infix symbols (tensor products, conjunction/equivalence,
   inequality in the assumption) appear to have been lost in extraction. *)
lemma function_at_comm:
  fixes U V :: 'b ell2 CL 'b ell2 and x y :: 'a
  assumes x  y
  shows function_at x U oCL function_at y V = function_at y V oCL function_at x U
proof -
  (* reorder swaps (f x, f y) with the auxiliary pair (a, b). *)
  define reorder where reorder = classical_operator (Some o (λ(f :: 'a  'b, a, b). (f(x:=a, y:=b), f x, f y)))

  (* The underlying classical map is an involution (uses the assumption on x and y),
     hence bijective and its own inverse. *)
  have selfinv: (λ(f, a, b). (f(x := a, y := b), f x, f y)) o (λ(f, a, b). (f(x := a, y := b), f x, f y)) = id
    using assms by (auto intro!: ext)
  have bij: bij (λ(f, a, b). (f(x := a, y := b), f x, f y))
    using o_bij selfinv by blast
  have inv: inv (λ(f, a, b). (f(x := a, y := b), f x, f y)) = (λ(f, a, b). (f(x := a, y := b), f x, f y))
  using inv_unique_comp selfinv by blast
  have inj_map: inj_map (Some o (λ(f, a, b). (f(x := a, y := b), f x, f y)))
    by (simp add: inj_map_total bij_is_inj[OF bij])
  (* Same statement on the level of partial maps; shadows the previous fact named inv. *)
  have inv: inv_map (Some o (λ(f, a, b). (f(x := a, y := b), f x, f y))) = (Some o (λ(f, a, b). (f(x := a, y := b), f x, f y)))
    by (simp add: inv_map_total bij_is_surj bij inv)
  have reorder_exists: classical_operator_exists (Some o (λ(f, a, b). (f(x := a, y := b), f x, f y)))
    using inj_map by (rule classical_operator_exists_inj)

  (* Basic properties of reorder: self-adjoint, explicit action on basis states, isometry. *)
  have [simp]: reorder* = reorder
    by (simp add: reorder_def classical_operator_adjoint[OF inj_map] inv)
  have [simp]: reorder (ket f s ket a s ket b) = ket (f(x:=a, y:=b), f x, f y) for f a b
    by (simp add: reorder_def tensor_ell2_ket classical_operator_ket[OF reorder_exists])
  have [simp]: isometry reorder
    using inj_map_total isometry_classical_operator inj_map reorder_def by blast

  (* Conjugated by reorder, function_at x U becomes U on the first auxiliary component.
     Shown by comparing matrix elements between basis states fab = (f,a,b), gcd = (g,c,d). *)
  have sandwichU: sandwich reorder (function_at x U o id_cblinfun) = id_cblinfun o (U o id_cblinfun)
  proof  (rule equal_ket, rule cinner_ket_eqI, rename_tac fab gcd)
    fix fab gcd :: ('a  'b) × 'b × 'b
    obtain f a b where [simp]: fab = (f,a,b)
      by (auto simp: prod_eq_iff)
    obtain g c d where [simp]: gcd = (g,c,d)
      by (auto simp: prod_eq_iff)
    (* Reformulation of the side condition under which the matrix element is nonzero. *)
    have fg_rewrite: f = g  b = d  
        snd (puncture_function x (f(x := a, y := b))) = snd (puncture_function x (g(x := c, y := d)))  f x = g x  f y = g y
      using assms
      by (smt (verit, del_insts) array_rules(3) fun_upd_idem fun_upd_twist puncture_function_inverse puncture_function_recombine snd_puncture_function_upd)
    have ket gcd C ((sandwich reorder *V function_at x U o id_cblinfun) *V ket fab)
      = ket (g(x:=c, y:=d), g x, g y) C ((function_at x U o id_cblinfun) *V ket (f(x:=a, y:=b), f x, f y))
      by (simp add: sandwich_apply flip: cinner_adj_left tensor_ell2_ket)
    also have  = (ket (g(x:=c, y:=d)) C (function_at x U *V ket (f(x:=a, y:=b))))
                     * of_bool (f x = g x  f y = g y)
      by (auto simp add: tensor_op_ell2 simp flip: tensor_ell2_ket)
    also have  = Rep_ell2 (U *V ket a) c * of_bool (f = g  b = d)
      using assms by (auto simp add: cinner_ket_left Rep_ell2_function_at_ket fg_rewrite)
    also have  = ket gcd C ((id_cblinfun o U o id_cblinfun) *V ket fab)
      by (auto simp add: tensor_op_ell2 cinner_ket_left[of c] simp flip: tensor_ell2_ket)
    finally show ket gcd C ((sandwich reorder *V function_at x U o id_cblinfun) *V ket fab) =
                ket gcd C ((id_cblinfun o U o id_cblinfun) *V ket fab)
      by -
  qed

  (* Symmetric statement for y and V: V ends up on the second auxiliary component. *)
  have sandwichV: sandwich reorder (function_at y V o id_cblinfun) = id_cblinfun o (id_cblinfun o V)
  proof  (rule equal_ket, rule cinner_ket_eqI, rename_tac fab gcd)
    fix fab gcd :: ('a  'b) × 'b × 'b
    obtain f a b where [simp]: fab = (f,a,b)
      by (auto simp: prod_eq_iff)
    obtain g c d where [simp]: gcd = (g,c,d)
      by (auto simp: prod_eq_iff)
    have fg_rewrite: f = g  a = c  
        snd (puncture_function y (f(x := a, y := b))) = snd (puncture_function y (g(x := c, y := d)))  f x = g x  f y = g y
      using assms
      by (metis array_rules(3) fun_upd_idem fun_upd_twist puncture_function_inverse puncture_function_recombine snd_puncture_function_upd)
    have ket gcd C ((sandwich reorder *V function_at y V o id_cblinfun) *V ket fab)
      = ket (g(x:=c, y:=d), g x, g y) C ((function_at y V o id_cblinfun) *V ket (f(x:=a, y:=b), f x, f y))
      by (simp add: sandwich_apply flip: cinner_adj_left tensor_ell2_ket)
    also have  = (ket (g(x:=c, y:=d)) C (function_at y V *V ket (f(x:=a, y:=b))))
                     * of_bool (f x = g x  f y = g y)
      by (auto simp add: tensor_op_ell2 simp flip: tensor_ell2_ket)
    also have  = Rep_ell2 (V *V ket b) d * of_bool (f = g  a = c)
      using assms by (auto simp add: cinner_ket_left Rep_ell2_function_at_ket fg_rewrite)
    also have  = ket gcd C ((id_cblinfun o id_cblinfun o V) *V ket fab)
      by (auto simp add: tensor_op_ell2 cinner_ket_left[of d] simp flip: tensor_ell2_ket)
    finally show ket gcd C ((sandwich reorder *V function_at y V o id_cblinfun) *V ket fab) =
                ket gcd C ((id_cblinfun o id_cblinfun o V) *V ket fab)
      by -
  qed

  (* After conjugation the two operators act on different tensor factors, so they commute. *)
  have sandwich reorder ((function_at x U o id_cblinfun) oCL (function_at y V o id_cblinfun))
      = sandwich reorder ((function_at y V o id_cblinfun) oCL (function_at x U o id_cblinfun))
    apply (simp add: sandwichU sandwichV flip: sandwich_arg_compose)
    by (simp add: comp_tensor_op)
  (* Undo the conjugation, using that reorder is an isometry. *)
  then have (function_at x U o (id_cblinfun :: ('b × 'b) ell2 CL ('b × 'b) ell2)) oCL (function_at y V o id_cblinfun) = (function_at y V o id_cblinfun) oCL (function_at x U o id_cblinfun)
    by (smt (verit, best) isometry reorder cblinfun_compose_id_left cblinfun_compose_id_right compatible_ac_rules(2) isometryD sandwich_apply)
  then have (function_at x U oCL function_at y V) o (id_cblinfun :: ('b × 'b) ell2 CL ('b × 'b) ell2) = (function_at y V oCL function_at x U) o id_cblinfun
    by (simp add: comp_tensor_op)
  (* Cancel the auxiliary identity factor (injectivity of tensoring with it). *)
  then show function_at x U oCL function_at y V = function_at y V oCL function_at x U
    apply (rule injD[OF inj_tensor_left, rotated])
    by simp
qed

  

(* The registers function_at x and function_at y for two different positions are compatible
   (simp rule): both are registers and their images commute by function_at_comm. *)
lemma compatible_function_at[simp]: 
  assumes x  y
  shows compatible (function_at x) (function_at y)
proof (rule compatibleI)
  show register (function_at x)
    by simp
  show register (function_at y)
    by simp
  fix a b :: 'b update
  show function_at x a oCL function_at y b = function_at y b oCL function_at x a
    using assms by (rule function_at_comm)
qed

(* The inverse function of fix_punctured_function x is puncture_function x (simp rule). *)
lemma inv_fix_punctured_function[simp]: inv (fix_punctured_function x) = puncture_function x
  by (simp add: inv_equality)

(* puncture_function x is a bijection (simp rule), being the inverse of a bijection. *)
lemma bij_puncture_function[simp]: bij (puncture_function x)
  by (metis bij_betw_inv_into bij_fix_punctured_function inv_fix_punctured_function)

(* The first component of puncture_function x H is the value of H at x (simp rule). *)
lemma fst_puncture_function[simp]: fst (puncture_function x H) = H x
  apply transfer by simp

subsection ‹‹apply_every››

text ‹Analogous to the classical \<^term>‹λ(M::'a set) (u::'a ⇒ 'b ⇒ 'b) (f::'a ⇒ 'b) (x::'a). if x∈M then u x (f x) else f x›.

Note that the definition only makes sense when \<^term>‹M› is finite.
In fact, a definition that works for infinite \<^term>‹M› is impossible, as the following example shows:
Let \<^term>‹H› denote the Hadamard matrix. Let \<^term>‹M=(UNIV :: nat set)›.
Then, by symmetry, a meaningful definition of \<^term>‹apply_every› would have that \<^term>‹apply_every M H (ket (λ_. 0))›
would be a vector in \<^typ>‹(nat => bit) ell2› with all coefficients equal.
But the only such vector is \<^term>‹0›. But a meaningful definition should not map \<^term>‹ket (λ_. 0)› to \<^term>‹0›.›

(* apply_every M U: for finite M, the composition of function_at x (U x) over all x in M,
   built by folding over the set starting from the identity; for infinite M it is defined
   as 0. *)
definition apply_every where apply_every M U = (if finite M then Finite_Set.fold (λx a. function_at x (U x) oCL a) id_cblinfun M else 0)

(* On the empty set, apply_every is the identity operator (simp rule). *)
lemma apply_every_empty[simp]: apply_every {} U = id_cblinfun
  by (simp add: apply_every_def)

(* The step function used in the fold of apply_every satisfies comp_fun_commute, i.e. steps
   for different elements can be exchanged. This makes the Finite_Set.fold lemmas available
   under the prefix apply_every_aux. The proof distinguishes x=y (trivial) from the case of
   different positions, where the lifted operators commute (swap_registers_left). *)
interpretation apply_every_aux: comp_fun_commute (λx. (oCL) (function_at x (U x)))
  apply standard
  apply (rule ext)
  apply (case_tac x=y)
  by (auto simp flip: cblinfun_compose_assoc swap_registers_left)

(* If M is finite and every U x with x in M is unitary, then apply_every M U is unitary.
   Proof by induction over the finite set M. *)
lemma apply_every_unitary: unitary (apply_every M U) if  finite M and [simp]: x. xM  unitary (U x)
proof -
  show ?thesis
    using that
  proof induction
    case empty
    then show ?case 
      by simp
  next
    case (insert x F)
    (* Unfold one step of the fold: the new element contributes one more factor on the left. *)
    then have *: apply_every (insert x F) U = function_at x (U x) oCL apply_every F U
      by (simp add: apply_every_def)
    (* A register maps unitaries to unitaries; compositions of unitaries are unitary. *)
    show ?case
      by (simp add: * register_unitary insert)
  qed
qed

(* If V commutes with function_at x (U x) for every x in the finite set M, then V commutes
   with apply_every M U. Proof by induction over M after unfolding the definition. *)
lemma apply_every_comm: apply_every M U oCL V = V oCL apply_every M U
  if finite M and x. xM  function_at x (U x) oCL V = V oCL function_at x (U x)
  unfolding apply_every_def using that
proof induction
  case empty
  show ?case
    by simp
next
  case (insert x F)
  (* Move V past the factor for x and past the fold over F by re-associating compositions. *)
  then show ?case
    apply (simp add: insert cblinfun_compose_assoc)
    by (simp flip: cblinfun_compose_assoc  insert.prems)
qed

(* For an infinite set M, apply_every M U is 0 by definition. *)
lemma apply_every_infinite: apply_every M U = 0 if infinite M
  using that by (simp add: apply_every_def)


(* For disjoint M and N, apply_every on the union factors as the composition of apply_every
   on M and on N. No finiteness assumption is needed: if either set is infinite, both sides
   are 0 (handled by the two wlog steps). *)
lemma apply_every_split: apply_every M U oCL apply_every N U = apply_every (M  N) U if M  N = {} for M N U
proof -
  (* Reduce to the case where both sets are finite. *)
  wlog finiteM: finite M
    using negation
    by (simp add: apply_every_infinite)
  wlog finiteN: finite N keeping finiteM
    using negation
    by (simp add: apply_every_infinite)
  (* f is the fold step; fM and fN are the folds over M and N starting from the identity. *)
  define f :: 'a  ('a  'b) update  ('a  'b) update where f x = (oCL) (function_at x (U x)) for x
  define fM fN where fM = Finite_Set.fold f id_cblinfun M and fN = Finite_Set.fold f id_cblinfun N
  have apply_every (M  N) U = apply_every (N  M) U
    by (simp add: Un_commute)
  (* Fold over a disjoint union = fold over one part, started from the fold over the other. *)
  also have  = Finite_Set.fold f (Finite_Set.fold f id_cblinfun N) M
    unfolding apply_every_def
    apply (subst apply_every_aux.fold_set_union_disj)
    using finiteM finiteN that by (auto simp add: f_def[abs_def])
  (* Starting the fold over M from fN equals composing the fold from the identity with fN. *)
  also have  = fM oCL fN
    unfolding fM_def fN_def[symmetric]
    using finiteM
    apply (induction M)
    by (auto simp add: f_def[abs_def] cblinfun_compose_assoc)
  also have  = apply_every M U oCL apply_every N U
    by (simp add: apply_every_def fN_def fM_def f_def[abs_def] finiteN finiteM)
  finally show ?thesis
    by simp
qed

(* On a singleton set, apply_every is just the lifted operator at that position (simp rule). *)
lemma apply_every_single[simp]: apply_every {x} U = function_at x (U x)
  by (simp add: apply_every_def)

(* Unfolding rule for inserting a fresh element x into a finite set M: the operator for x is
   composed on the left of apply_every M U. *)
lemma apply_every_insert: apply_every (insert x M) U = function_at x (U x) oCL apply_every M U if x  M and finite M
  using that by (simp add: apply_every_def)

(* apply_every is multiplicative in the operator family: composing apply_every M U with
   apply_every M V equals apply_every M applied to the pointwise compositions.
   Holds for arbitrary M; for infinite M both sides are 0. *)
lemma apply_every_mult: apply_every M U oCL apply_every M V = apply_every M (λx. U x oCL V x)
proof (induction rule:infinite_finite_induct)
  case (infinite M)
  then show ?case
    by (simp add: apply_every_infinite)
next
  case empty
  show ?case 
    by simp
next
  case (insert x F)
  (* Unfold both sides at the new element x. *)
  have apply_every (insert x F) U oCL apply_every (insert x F) V
      = function_at x (U x) oCL (apply_every F U oCL function_at x (V x)) oCL apply_every F V
    using insert by (simp add: apply_every_insert cblinfun_compose_assoc)
  (* Move function_at x (V x) to the left past apply_every F U; allowed because x is not
     in F, so the registers involved are compatible. *)
  also have  = (function_at x (U x) oCL function_at x (V x)) oCL (apply_every F U oCL apply_every F V)
    apply (subst apply_every_comm)
      apply (fact insert)
    using insert apply (metis (no_types, lifting) compatible_function_at swap_registers)
    by (simp add: cblinfun_compose_assoc)
  (* A register is multiplicative. *)
  also have  = (function_at x (U x oCL V x)) oCL (apply_every F U oCL apply_every F V)
    by (simp add: register_mult)
  also have  = (function_at x (U x oCL V x)) oCL (apply_every F (λx. U x oCL V x))
    using insert.IH by presburger
  also have  = (apply_every (insert x F) (λx. U x oCL V x))
    using insert.hyps by (simp add: apply_every_insert)
  finally show ?case
    by -
qed

(* Applying the identity at every position of a finite set M gives the identity (simp rule).
   Finiteness is required because apply_every is 0 on infinite sets. *)
lemma apply_every_id[simp]: apply_every M (λ_. id_cblinfun) = id_cblinfun if finite M
  using that apply induction
  by (auto simp: apply_every_insert)

(* An operator lifted to a position x outside of M commutes with apply_every M f.
   Proof by induction over M (infinite case: both sides are 0), using function_at_comm for
   the inserted element. *)
lemma apply_every_function_at_comm:
  assumes x  M
  shows function_at x U oCL apply_every M f = apply_every M f oCL function_at x U
  using assms apply (induction rule: infinite_finite_induct)
    apply (simp add: apply_every_infinite)
   apply simp
  apply (simp add: apply_every_insert function_at_comm[where x=x] 
      flip: cblinfun_compose_assoc)
  by (simp add: cblinfun_compose_assoc)

(* The adjoint of apply_every M f is apply_every M applied to the pointwise adjoints.
   Proof by induction over M; the insert case uses apply_every_function_at_comm to restore
   the order of the factors after taking the adjoint of a composition. *)
lemma apply_every_adj: (apply_every M f)* = apply_every M (λi. (f i)*)
  apply (induction rule: infinite_finite_induct)
    apply (simp add: apply_every_infinite)
   apply simp
  by (simp add: apply_every_insert apply_every_function_at_comm register_adjoint)

end