(* Theory Grover *)

section ‹Grover's algorithm›

theory Grover
  imports Partial_State Gates Quantum_Hoare
begin

subsection ‹Basic definitions›

(* Setting of Grover search over N = 2^n basis states. The predicate f marks
   the indices being searched for; dimM requires that at least one, but not
   all, of the indices below 2^n are marked (so 0 < M < N below). *)
locale grover_state =
  fixes n :: nat  (* number of qubits *)
    and f :: "nat  bool"  (* characteristic function, only need values in [0,N). *)
  assumes n: "n > 1"
    and dimM: "card {i. i < (2::nat) ^ n  f i} > 0"
              "card {i. i < (2::nat) ^ n  f i} < (2::nat) ^ n"
begin

(* Size of the search space, i.e. the dimension of the n-qubit state space. *)
definition N where
  "N = (2::nat) ^ n"

(* Number of marked indices below N. *)
definition M where
  "M = card {i. i < N  f i}"

(* Basic bounds on N and M. Despite the names, N_ge_0 and M_ge_0 state strict
   positivity, and M_le_N states the strict bound M < N. *)
lemma N_ge_0 [simp]: "0 < N" by (simp add: N_def)

lemma M_ge_0 [simp]: "0 < M" by (simp add: M_def dimM N_def)

lemma M_neq_0 [simp]: "M  0" by simp

lemma M_le_N [simp]: "M < N" by (simp add: M_def dimM N_def)

lemma M_not_ge_N [simp]: "¬ M  N" using M_le_N by arith

(* ψ: the uniform superposition; each of its N entries is 1 / sqrt N.
   It is a unit vector (ψ_inner, ψ_norm). *)
definition ψ :: "complex vec" where
  "ψ = Matrix.vec N (λi. 1 / sqrt N)"

lemma ψ_dim [simp]:
  "ψ  carrier_vec N"
  "dim_vec ψ = N"
  by (simp add: ψ_def)+

lemma ψ_eval:
  "i < N  ψ $ i = 1 / sqrt N"
  by (simp add: ψ_def)

lemma ψ_inner:
  "inner_prod ψ ψ = 1"
  apply (simp add: ψ_eval scalar_prod_def)
  by (smt of_nat_less_0_iff of_real_mult of_real_of_nat_eq real_sqrt_mult_self)
 
lemma ψ_norm:
  "vec_norm ψ = 1"
  by (simp add: ψ_eval vec_norm_def scalar_prod_def)

(* α: unit vector spread uniformly over the unmarked indices
   (entry 1 / sqrt (N - M) where f is false, 0 where f is true). *)
definition α :: "complex vec" where
  "α = Matrix.vec N (λi. if f i then 0 else 1 / sqrt (N - M))"

lemma α_dim [simp]:
  "α  carrier_vec N"
  "dim_vec α = N"
  by (simp add: α_def)+

lemma α_eval:
  "i < N  α $ i = (if f i then 0 else 1 / sqrt (N - M))"
  by (simp add: α_def)

(* α has norm 1: it has N - M nonzero entries, each equal to 1 / sqrt (N - M). *)
lemma α_inner:
  "inner_prod α α = 1"
  apply (simp add: scalar_prod_def α_eval)
  apply (subst sum.mono_neutral_cong_right[of "{0..<N}" "{0..<N}-{i. i < N  f i}"])
   apply auto
  apply (subgoal_tac "card ({0..<N} - {i. i < N  f i}) = N - M")
  subgoal by (metis of_nat_0_le_iff of_real_of_nat_eq of_real_power power2_eq_square real_sqrt_pow2)
  unfolding N_def M_def 
  by (metis (no_types, lifting) atLeastLessThan_iff card.infinite card_Diff_subset card_atLeastLessThan diff_zero dimM(1) mem_Collect_eq neq0_conv subsetI zero_order(1))

(* β: unit vector spread uniformly over the marked indices
   (entry 1 / sqrt M where f is true, 0 where f is false). *)
definition β :: "complex vec" where
  "β = Matrix.vec N (λi. if f i then 1 / sqrt M else 0)"

lemma β_dim [simp]:
  "β  carrier_vec N"
  "dim_vec β = N"
  by (simp add: β_def)+

lemma β_eval:
  "i < N  β $ i = (if f i then 1 / sqrt M else 0)"
  by (simp add: β_def)

(* β has norm 1: it has M nonzero entries, each equal to 1 / sqrt M. *)
lemma β_inner:
  "inner_prod β β = 1"  
  apply (simp add: scalar_prod_def β_eval)
  apply (subst sum.mono_neutral_cong_right[of "{0..<N}" "{i. i < N  f i}"])
   apply auto
  apply (fold M_def)
  by (metis of_nat_0_le_iff of_real_of_nat_eq of_real_power power2_eq_square real_sqrt_pow2)

(* α and β have disjoint supports, hence are orthogonal (in both orders). *)
lemma alpha_beta_orth:
  "inner_prod α β = 0"
  unfolding α_def β_def by (simp add: scalar_prod_def)

lemma beta_alpha_orth:
  "inner_prod β α = 0"
  unfolding α_def β_def by (simp add: scalar_prod_def)

(* Rotation angle of one Grover iteration, defined so that
   cos (θ / 2) = sqrt ((N - M) / N) and sin (θ / 2) = sqrt (M / N)
   (cos_theta_div_2 and sin_theta_div_2 below). *)
definition θ :: real where
  "θ = 2 * arccos (sqrt ((N - M) / N))"

lemma cos_theta_div_2:
  "cos (θ / 2) = sqrt ((N - M) / N)"
proof -
  have "θ / 2 = arccos (sqrt ((N - M) / N))" using θ_def by simp
  then show "cos (θ / 2) = sqrt ((N - M) / N)" 
    by (simp add: cos_arccos_abs)
qed

lemma sin_theta_div_2:
  "sin (θ / 2) = sqrt (M / N)"
proof -
  have a: "θ / 2 = arccos (sqrt ((N - M) / N))" using θ_def by simp
  have N: "N > 0" using N_def by auto
  have M: "M < N" using M_def dimM N_def by auto
  then show "sin (θ / 2) = sqrt (M / N)"
    unfolding a
    apply (simp add: sin_arccos_abs)
  proof -
    have eq: "real (N - M) = real N - real M" using N M 
      using M_not_ge_N nat_le_linear of_nat_diff by blast
    have "1 - real (N - M) / real N = (real N - (real N - real M)) / real N" 
      unfolding eq using N 
      by (metis diff_divide_distrib divide_self_if eq gr_implies_not0 of_nat_0_eq_iff)
    then show "1 - real (N - M) / real N = real M / real N" by auto
  qed
qed

(* θ is nonzero, because sin (θ / 2) = sqrt (M / N) is positive. *)
lemma θ_neq_0:
  "θ  0"
proof -
  {
  assume "θ = 0"
  then have "θ / 2 = 0" by auto
  then have "sin (θ / 2) = 0" by auto
  }
  note z = this
  have "sin (θ / 2) = sqrt (M / N)" using sin_theta_div_2 by auto
  moreover have "M > 0" unfolding M_def N_def using dimM by auto
  ultimately have "sin (θ / 2) > 0" by auto
  with z show ?thesis by auto
qed

(* Real cosine and sine, embedded into the complex numbers. *)
abbreviation ccos where "ccos φ  complex_of_real (cos φ)"
abbreviation csin where "csin φ  complex_of_real (sin φ)"

(* The initial state in the α/β basis: ψ = cos (θ/2) α + sin (θ/2) β.
   The two lemmas after it give the resulting inner products of ψ with α and β. *)
lemma ψ_eq:
  "ψ = ccos (θ / 2) v α + csin (θ / 2) v β"
  apply (simp add: cos_theta_div_2 sin_theta_div_2)
  apply (rule eq_vecI)
  by (auto simp add: α_def β_def ψ_def real_sqrt_divide)

lemma psi_inner_alpha:
  "inner_prod ψ α = ccos (θ / 2)"
  unfolding ψ_eq
proof -
  have "inner_prod (ccos (θ / 2) v α) α = ccos (θ / 2)"
    apply (subst inner_prod_smult_right[of _ N])
    using α_dim α_inner by auto
  moreover have "inner_prod (csin (θ / 2) v β) α = 0"
    apply (subst inner_prod_smult_right[of _ N])
    using α_dim β_dim beta_alpha_orth by auto
  ultimately show "inner_prod (ccos (θ / 2) v α + csin (θ / 2) v β) α = ccos (θ / 2)"
    apply (subst inner_prod_distrib_left[of _ N])
    using α_dim β_dim by auto
qed

lemma psi_inner_beta:
  "inner_prod ψ β = csin (θ / 2)"
  unfolding ψ_eq
proof -
  have "inner_prod (ccos (θ / 2) v α) β = 0"
    apply (subst inner_prod_smult_right[of _ N])
    using α_dim β_dim alpha_beta_orth by auto
  moreover have "inner_prod (csin (θ / 2) v β) β = csin (θ / 2)"
    apply (subst inner_prod_smult_right[of _ N])
    using β_dim β_inner by auto
  ultimately show "inner_prod (ccos (θ / 2) v α + csin (θ / 2) v β) β = csin (θ / 2)"
    apply (subst inner_prod_distrib_left[of _ N])
    using α_dim β_dim by auto
qed

(* Coefficients of the state after l iterations:
   alpha_l l = cos ((l + 1/2) θ) and beta_l l = sin ((l + 1/2) θ).
   Both are real, so each equals its own complex conjugate. *)
definition alpha_l :: "nat  complex" where
  "alpha_l l = ccos ((l + 1 / 2) * θ)"

lemma alpha_l_real:
  "alpha_l l  Reals"
  unfolding alpha_l_def by auto

lemma cnj_alpha_l:
  "conjugate (alpha_l l) = alpha_l l"
  using alpha_l_real Reals_cnj_iff by auto

definition beta_l :: "nat  complex" where
  "beta_l l = csin ((l + 1 / 2) * θ)"

lemma beta_l_real:
  "beta_l l  Reals"
  unfolding beta_l_def by auto

lemma cnj_beta_l:
  "conjugate (beta_l l) = beta_l l"
  using beta_l_real Reals_cnj_iff by auto

(* Pythagorean identity for ccos / csin, and its instance for alpha_l / beta_l. *)
lemma csin_ccos_squared_add:
  "ccos (a::real) * ccos a + csin a * csin a = 1"
  by (smt cos_diff cos_zero of_real_add of_real_hom.hom_one of_real_mult)

lemma alpha_l_beta_l_add_norm:
  "alpha_l l * alpha_l l + beta_l l * beta_l l = 1"
  using alpha_l_def beta_l_def csin_ccos_squared_add by auto

(* State after l iterations: alpha_l l * α + beta_l l * β.
   It is a unit vector (inner_psi_l). *)
definition psi_l where
  "psi_l l = (alpha_l l) v α + (beta_l l) v β"

lemma psi_l_dim:
  "psi_l l  carrier_vec N"
  unfolding psi_l_def α_def β_def by auto

(* Expand the inner product bilinearly; the cross terms vanish by orthogonality
   of α and β, and the remaining terms add up by alpha_l_beta_l_add_norm. *)
lemma inner_psi_l:
  "inner_prod (psi_l l) (psi_l l) = 1"
proof -
  have eq0: "inner_prod (psi_l l) (psi_l l) 
    = inner_prod ((alpha_l l) v α) (psi_l l) + inner_prod ((beta_l l) v β) (psi_l l)"
    unfolding psi_l_def
    apply (subst inner_prod_distrib_left)
    using α_def β_def by auto
  have "inner_prod ((alpha_l l) v α) (psi_l l) 
    = inner_prod ((alpha_l l) v α) ((alpha_l l) v α) + inner_prod ((alpha_l l) v α) ((beta_l l) v β)"
    unfolding psi_l_def
    apply (subst inner_prod_distrib_right)
    using α_def β_def by auto
  also have " = (conjugate (alpha_l l)) * (alpha_l l) * inner_prod α α 
                + (conjugate (alpha_l l)) * (beta_l l) * inner_prod α β"
    apply (subst (1 2) inner_prod_smult_left_right) using α_def β_def by auto
  also have " = conjugate (alpha_l l) * (alpha_l l) "
    by (simp add: alpha_beta_orth α_inner)
  also have " = (alpha_l l) * (alpha_l l)" using cnj_alpha_l by simp
  finally have eq1: "inner_prod (alpha_l l v α) (psi_l l) = alpha_l l * alpha_l l".

  have "inner_prod ((beta_l l) v β) (psi_l l) 
    = inner_prod ((beta_l l) v β) ((alpha_l l) v α) + inner_prod ((beta_l l) v β) ((beta_l l) v β)"
    unfolding psi_l_def
    apply (subst inner_prod_distrib_right)
    using α_def β_def by auto
  also have " = (conjugate (beta_l l)) * (alpha_l l) * inner_prod β α 
                + (conjugate (beta_l l)) * (beta_l l) * inner_prod β β"
    apply (subst (1 2) inner_prod_smult_left_right) using α_def β_def by auto
  also have " = (conjugate (beta_l l)) * (beta_l l)"  using β_inner beta_alpha_orth by auto
  also have " = (beta_l l) * (beta_l l)" using cnj_beta_l by auto
  finally have eq2: "inner_prod (beta_l l v β) (psi_l l) = beta_l l * beta_l l".

  show ?thesis unfolding eq0 eq1 eq2 using alpha_l_beta_l_add_norm by auto
qed

(* Projector onto v: the outer product of v with itself. *)
abbreviation proj :: "complex vec  complex mat" where
  "proj v  outer_prod v v"

(* psi'_l l: psi_l l with the sign of its β-component flipped. *)
definition psi'_l where
  "psi'_l l = (alpha_l l) v α - (beta_l l) v β"

lemma psi'_l_dim:
  "psi'_l l  carrier_vec N"
  unfolding psi'_l_def α_def β_def by auto

definition proj_psi'_l where
  "proj_psi'_l l = proj (psi'_l l)"

lemma proj_psi'_dim:
  "proj_psi'_l l  carrier_mat N N"
  unfolding proj_psi'_l_def using psi'_l_dim by auto

(* Inner product of ψ with psi'_l l, in terms of the α/β coefficients. *)
lemma psi_inner_psi'_l:
  "inner_prod ψ (psi'_l l) = (alpha_l l * ccos (θ / 2) - beta_l l * csin (θ / 2))"
proof -
  have "inner_prod ψ (psi'_l l) = inner_prod ψ (alpha_l l v α) - inner_prod ψ (beta_l l v β)"
    unfolding psi'_l_def apply (subst inner_prod_minus_distrib_right[of _ N]) by auto
  also have " = alpha_l l * (inner_prod ψ α) - beta_l l * (inner_prod ψ β)"
    using ψ_dim α_dim β_dim by auto
  also have " = alpha_l l * (ccos (θ / 2)) - beta_l l * (csin (θ / 2))"
    using psi_inner_alpha psi_inner_beta by auto
  finally show ?thesis by auto
qed

(* Double-angle and addition formulas, restated for ccos / csin. *)
lemma double_ccos_square:
  "2 * ccos (a::real) * ccos a = ccos (2 * a) + 1"
proof -
  have eq: "ccos (2 * a) = ccos a * ccos a - csin a * csin a"
    using cos_add[of a a] by auto
  have "csin a * csin a = 1 - ccos a * ccos a"
    using csin_ccos_squared_add[of a]
    by (metis add_diff_cancel_left')
  then have "ccos a * ccos a - csin a * csin a = 2 * ccos a * ccos a - 1"
    by simp
  with eq show ?thesis by simp 
qed

lemma double_csin_square:
  "2 * csin (a::real) * csin a = 1 - ccos (2 * a)"
proof -
  have eq: "ccos (2 * a) = ccos a * ccos a - csin a * csin a"
    using cos_add[of a a] by auto
  have "ccos a * ccos a = 1 - csin a * csin a"
    using csin_ccos_squared_add[of a]
      by (auto intro: add_implies_diff)
  then have "ccos a * ccos a - csin a * csin a = 1 - 2 * csin (a::real) * csin a"
    by simp
  with eq show ?thesis by simp
qed

lemma csin_double:
  "2 * csin (a::real) * ccos a = csin(2 * a)"
  using sin_add[of a a] by simp

lemma ccos_add:
  "ccos (x + y) = ccos x * ccos y - csin x * csin y"
  using cos_add[of x y] by simp

(* New α-coefficient after one step: reduce with the double-angle formulas,
   then apply the addition formula to the angles (l + 1/2) θ and θ. *)
lemma alpha_l_Suc_l_derive:
  "2 * (alpha_l l * ccos (θ / 2) - beta_l l * csin (θ / 2)) * ccos (θ / 2) - alpha_l l = alpha_l (l + 1)"
  (is "?lhs = ?rhs")
proof -
  have "2 * ((alpha_l l) * ccos (θ / 2) - (beta_l l) * csin (θ / 2)) * ccos (θ / 2)
    = (alpha_l l) * (2 * ccos (θ / 2)* ccos (θ / 2)) - (beta_l l) * (2 * csin (θ / 2) * ccos (θ / 2))" 
    by (simp add: left_diff_distrib)

  also have " = (alpha_l l) * (ccos (θ) + 1) - (beta_l l) * csin θ"
    using double_ccos_square csin_double by auto
  finally have "2 * ((alpha_l l) * ccos (θ / 2) - (beta_l l) * csin (θ / 2)) * ccos (θ / 2) 
    = (alpha_l l) * (ccos (θ) + 1) - (beta_l l) * csin θ".
  then have "?lhs = (alpha_l l) * ccos (θ) - (beta_l l) * csin θ" by (simp add: algebra_simps)
  also have " = (alpha_l (l + 1))"
    unfolding alpha_l_def beta_l_def 
    apply (subst ccos_add[of "(real l + 1 / 2) * θ" "θ", symmetric])
    by (simp add: algebra_simps)
  finally show ?thesis by auto
qed

lemma csin_add:
  "csin (x + y) = ccos x * csin y + csin x * ccos y"
  using sin_add[of x y] by simp

(* New β-coefficient after one step; same argument as alpha_l_Suc_l_derive,
   ending with the sine addition formula. *)
lemma beta_l_Suc_l_derive:
  "2 * (alpha_l l * ccos (θ / 2) - (beta_l l) * csin (θ / 2)) * csin (θ / 2) + beta_l l = beta_l (l + 1)"
  (is "?lhs = ?rhs")
proof -
  have "2 * ((alpha_l l) * ccos (θ / 2) - (beta_l l) * csin (θ / 2)) * csin (θ / 2)
    = (alpha_l l) * (2 * csin (θ / 2)* ccos (θ / 2)) - (beta_l l) * (2 * csin (θ / 2) * csin (θ / 2))" 
    by (simp add: left_diff_distrib)
  also have " = (alpha_l l) * (csin θ) - (beta_l l) * (1 - ccos (θ))"
    using double_csin_square csin_double by auto
  finally have "2 * ((alpha_l l) * ccos (θ / 2) - (beta_l l) * csin (θ / 2)) * csin (θ / 2)
    = (alpha_l l) * (csin θ) - (beta_l l) * (1 - ccos (θ))".
  then have "?lhs = (alpha_l l) * (csin θ) + (beta_l l) * ccos θ" by (simp add: algebra_simps)
  also have " = (beta_l (l + 1))"
    unfolding alpha_l_def beta_l_def 
    apply (subst csin_add[of "(real l + 1 / 2) * θ" "θ", symmetric])
    by (simp add: algebra_simps)
  finally show ?thesis by auto
qed

(* One Grover step in vector form: with c = inner_prod ψ (psi'_l l)
   (computed in psi_inner_psi'_l), 2 * c * ψ - psi'_l l = psi_l (l + 1). *)
lemma psi_l_Suc_l_derive:
  "2 * (alpha_l l * ccos (θ / 2) - beta_l l * csin (θ / 2)) v ψ - psi'_l l = psi_l (l + 1)"
  (is "?lhs = ?rhs")
proof -
  let ?l = "2 * ((alpha_l l) * ccos (θ / 2) - (beta_l l) * csin (θ / 2))"
  have "?l v ψ = ?l v (ccos (θ / 2) v α + csin (θ / 2) v β)" unfolding ψ_eq by auto
  also have " = ?l v (ccos (θ / 2) v α) + ?l v (csin (θ / 2) v β)" 
    apply (subst smult_add_distrib_vec[of _ N]) using α_dim β_dim by auto
  also have " = (?l * ccos (θ / 2)) v α + (?l * csin (θ / 2)) v β" by auto
  finally have "?l v ψ  = (?l * ccos (θ / 2)) v α + (?l * csin (θ / 2)) v β".
  then have "?l v ψ - (psi'_l l) = ((?l * ccos (θ / 2)) v α - (alpha_l l) v α) + ((?l * csin (θ / 2)) v β + (beta_l l) v β)"
    unfolding psi'_l_def by auto
  also have " = (?l * ccos (θ / 2) - alpha_l l) v α + (?l * csin (θ / 2) + beta_l l) v β"
    apply (subst minus_smult_vec_distrib) apply (subst add_smult_distrib_vec) by auto
  also have " = (alpha_l (l + 1)) v α + (beta_l (l + 1)) v β"
    using alpha_l_Suc_l_derive beta_l_Suc_l_derive by auto
  finally have "?l v ψ - (psi'_l l) = (alpha_l (l + 1)) v α + (beta_l (l + 1)) v β".
  then show ?thesis unfolding psi_l_def by auto
qed

subsection ‹Grover operator›

text ‹Oracle O›

(* Diagonal projection onto the marked basis states: it annihilates α and fixes β. *)
definition proj_O :: "complex mat" where
  "proj_O = mat N N (λ(i, j). if i = j then (if f i then 1 else 0) else 0)"

lemma proj_O_dim:
  "proj_O  carrier_mat N N"
  unfolding proj_O_def by auto

lemma proj_O_mult_alpha:
  "proj_O *v α = zero_vec N"
  by (auto simp add: proj_O_def α_def scalar_prod_def)

lemma proj_O_mult_beta:
  "proj_O *v β = β"
  by (auto simp add: proj_O_def β_def scalar_prod_def sum_only_one_neq_0)

(* Oracle matrix: diagonal, -1 on marked indices and 1 elsewhere.
   It fixes α, negates β, and is Hermitian and unitary. *)
definition mat_O :: "complex mat" where
  "mat_O = mat N N (λ(i,j). if i = j then (if f i then -1 else 1) else 0)"

lemma mat_O_dim:
  "mat_O  carrier_mat N N"
  unfolding mat_O_def by auto

lemma mat_O_mult_alpha:
  "mat_O *v α = α"
  by (auto simp add: mat_O_def α_def scalar_prod_def sum_only_one_neq_0)

lemma mat_O_mult_beta:
  "mat_O *v β = - β"
  by (auto simp add: mat_O_def β_def scalar_prod_def sum_only_one_neq_0)

lemma hermitian_mat_O:
  "hermitian mat_O"
  by (auto simp add: hermitian_def mat_O_def adjoint_eval)

(* Unitarity follows from hermiticity plus mat_O * mat_O being the identity,
   since every diagonal entry squares to 1. *)
lemma unitary_mat_O:
  "unitary mat_O"
proof -
  have "mat_O  carrier_mat N N" unfolding mat_O_def by auto
  moreover have "mat_O * adjoint mat_O = mat_O * mat_O" using hermitian_mat_O unfolding hermitian_def by auto
  moreover have "mat_O * mat_O = 1m N"
    apply (rule eq_matI)
    unfolding mat_O_def
      apply (simp add: scalar_prod_def)
    subgoal for i j apply (rule)
      subgoal apply (subst sum_only_one_neq_0[of "{0..<N}" "j"]) by auto
        apply (subst sum_only_one_neq_0[of "{0..<N}" "j"]) by auto
    by auto
  ultimately show ?thesis unfolding unitary_def inverts_mat_def by auto
qed

(* Phase shift: diagonal with 1 at index 0 and -1 at every other index.
   It is Hermitian and squares to the identity, hence unitary. *)
definition mat_Ph :: "complex mat" where
  "mat_Ph = mat N N (λ(i,j). if i = j then if i = 0 then 1 else -1 else 0)"

lemma hermitian_mat_Ph:
  "hermitian mat_Ph"
  unfolding hermitian_def mat_Ph_def
  apply (rule eq_matI)
  by (auto simp add: adjoint_eval)

lemma unitary_mat_Ph:
  "unitary mat_Ph"
proof -
  have "mat_Ph  carrier_mat N N" unfolding mat_Ph_def by auto
  moreover have "mat_Ph * adjoint mat_Ph = mat_Ph * mat_Ph" using hermitian_mat_Ph unfolding hermitian_def by auto
  moreover have "mat_Ph * mat_Ph = 1m N"
    apply (rule eq_matI)
    unfolding mat_Ph_def
      apply (simp add: scalar_prod_def)
    subgoal for i j apply (rule)
      subgoal apply (subst sum_only_one_neq_0[of "{0..<N}" "0"]) by auto
        apply (subst sum_only_one_neq_0[of "{0..<N}" "j"]) by auto
    by auto
  ultimately show ?thesis unfolding unitary_def inverts_mat_def by auto
qed

(* Matrix with 2/N - 1 on the diagonal and 2/N elsewhere.
   NOTE(review): entrywise this matches 2 * proj ψ minus the identity (each
   entry of proj ψ is 1/N); that identity is not proved in this section. *)
definition mat_G' :: "complex mat" where
  "mat_G' = mat N N (λ(i,j). if i = j then 2 / N - 1 else 2 / N)"

text ‹Geometrically, the Grover operator G is a rotation›
definition mat_G :: "complex mat" where
  "mat_G = mat_G' * mat_O"

end

subsection ‹State of Grover's algorithm›

text ‹The dimensions are [2, 2, ..., 2, K]: n qubits of dimension 2, followed by
  one variable of dimension K. We work with a very special case, as in the paper.›
(* Extends grover_state with a program-variable signature (state_sig).
   dims lists n variables of dimension 2 (the qubits), followed by one variable
   of dimension K. R is tied to θ by assumption R, and K > R.
   NOTE(review): presumably R is the number of Grover iterations and the last
   variable is the loop counter (cf. M0/M1 and mat_incr K in D below); confirm. *)
locale grover_state_sig = grover_state + state_sig +
  fixes R :: nat
  fixes K :: nat
  assumes dims_def: "dims = replicate n 2 @ [K]"
  assumes R: "R = pi / (2 * θ) - 1 / 2"
  assumes K: "K > R"

begin

(* K is positive because K > R and R is a natural number. *)
lemma K_gt_0:
  "K > 0"
  using K by auto

text ‹Bits q0 to q\_(n-1)›
(* The first n variable indices 0 .. n-1, i.e. the qubits. *)
definition vars1 :: "nat set" where
  "vars1 = {0 ..< n}"

text ‹Variable r (index n): the last variable, of dimension K›
(* The single last variable index n (of dimension K). *)
definition vars2 :: "nat set" where
  "vars2 = {n}"

(* Basic facts about dims = replicate n 2 @ [K]: its length, its first n
   entries, and that selecting all n + 1 indices gives dims back. *)
lemma length_dims:
  "length dims = n + 1"
  unfolding dims_def by auto

lemma dims_nth_lt_n:
  "l < n  nth dims l = 2" 
  unfolding dims_def by (simp add: nth_append)

lemma nths_Suc_n_dims:
  "nths dims {0..<(Suc n)} = dims" 
  using length_dims nths_upt_eq_take
  by (metis add_Suc_right add_Suc_shift lessThan_atLeast0 less_add_eq_less less_numeral_extra(4)
            not_less plus_1_eq_Suc take_all)

(* Split the variables into vars1 (qubits) and vars2 (last variable). *)
interpretation ps2_P: partial_state2 dims vars1 vars2
   apply unfold_locales unfolding vars1_def vars2_def by auto

(* partial_state over the combined dimension list, with vars1 relabelled. *)
interpretation ps_P: partial_state ps2_P.dims0 ps2_P.vars1'.

(* Tensor product of a matrix on vars1 with a matrix on vars2. *)
abbreviation tensor_P where
"tensor_P A B  ps2_P.ptensor_mat A B"

(* tensor_P always yields a d x d matrix, where d = prod_list dims:
   vars1 and vars2 together cover every index of dims. *)
lemma tensor_P_dim:
  "tensor_P A B  carrier_mat d d"
proof -
  have "ps2_P.d0 = prod_list (nths dims ({0..<n}  {n}))" unfolding ps2_P.d0_def ps2_P.dims0_def ps2_P.vars0_def 
    by (simp add: vars1_def vars2_def)
  also have " = prod_list (nths dims ({0..<Suc n}))"
    apply (subgoal_tac "{0..<n}  {n} = {0..<(Suc n)}") by auto
  also have " = prod_list dims" using nths_Suc_n_dims by auto
  also have " = d" unfolding d_def by auto
  finally show ?thesis  using ps2_P.ptensor_mat_carrier by auto
qed

(* The first l entries of dims (for l at most n) are all 2. *)
lemma dims_nths_le_n:
  assumes "l  n"
  shows "nths dims {0..<l} = replicate l 2"
proof (rule nth_equalityI, auto)
  have "l  n  (i < Suc n  i < l) = (i < l)" for i
    using less_trans by fastforce
  then show l: "length (nths dims {0..<l}) = l" using assms
    by (auto simp add: length_nths length_dims)

  have llt: "l < length dims" using length_dims assms by auto
  have v1: "i. i < l  {a. a < i  a  {0..<l}} = {0..<i}" unfolding vars1_def by auto
  then have "i. i < l  card {j. j < i  j  {0..<l}} = i" by auto 
  then have "nths dims {0..<l} ! i = dims ! i" if "i < l" for i
    using nth_nths_card[of i dims "{0..<l}"] that llt by auto
  moreover have "dims ! i = replicate n 2 ! i" if "i < n" for i unfolding dims_def 
    by (auto simp add: nth_append that)
  moreover have "replicate n 2 ! i = replicate l 2 ! i" if "i < l" for i using assms that by auto
  ultimately show "nths dims {0..<l} ! i = replicate l 2 ! i" if "i < length (nths dims {0..<l})" for i
    using l that assms by auto 
qed

(* Selecting a single qubit index l < n from dims gives [2]. *)
lemma dims_nths_one_lt_n: 
  assumes "l < n"
  shows "nths dims {l} = [2]"
proof -
  have "{i. i < length dims  i  {l}} = {l}" using assms length_dims by auto
  then have "nths dims {l} = [dims ! l]" using nths_only_one[of dims "{l}" l] by auto
  moreover have "dims ! l = 2" unfolding dims_def using assms by (simp add: nth_append)
  ultimately show ?thesis by auto
qed

(* Dimensions of vars1 (n copies of 2) and vars2 ([K]). *)
lemma dims_vars1:
  "nths dims vars1 = replicate n 2"
proof (rule nth_equalityI, auto)
  show l: "length (nths dims vars1) = n"
    apply (auto simp add: length_nths vars1_def length_dims)
    by (metis (no_types, lifting) Collect_cong Suc_lessD card_Collect_less_nat not_less_eq)

  have v1: "i. i < n  {a. a < i  a  vars1} = {0..<i}" unfolding vars1_def by auto
  then have "i. i < n  card {j. j < i  j  vars1} = i" by auto 
  then have "nths dims vars1 ! i = dims ! i" if "i < n" for i
    using nth_nths_card[of i dims vars1] that length_dims vars1_def by auto
  moreover have "dims ! i = replicate n 2 ! i" if "i < n" for i unfolding dims_def 
    by (simp add: nth_append that)
  ultimately show "nths dims vars1 ! i = replicate n 2 ! i" if "i < length (nths dims vars1)" for i
    using l that by auto 
qed

lemma nths_rep_2_n:
  "nths (replicate n 2) {n} = []"
  by (metis (no_types, lifting) Collect_empty_eq card.empty length_0_conv length_replicate less_Suc_eq not_less_eq nths_replicate singletonD)

lemma dims_vars2:
  "nths dims vars2 = [K]"
  unfolding dims_def vars2_def
  apply (subst nths_append)
  apply (subst nths_rep_2_n)
  by simp

(* The qubit part has total dimension N = 2^n. *)
lemma d_vars1:
  "prod_list (nths dims vars1) = N"
proof -
  have eq: "{0..<n} = {..<n}"  by auto
  have "nths (replicate n 2 @ [K]) {0..<n} = (replicate n 2)"
    apply (subst eq)
    using nths_upt_eq_take by simp
  then show ?thesis unfolding dims_def vars1_def N_def by auto
qed

(* Unfolding the ps2_P / ps_P interpretation constants: the combined dimension
   list is dims itself, relabelling leaves vars1 unchanged, and the dimensions
   are d (whole state), N (vars1) and K (vars2). *)
lemma ps2_P_dims0:
  "ps2_P.dims0 = dims"
proof -
  have "vars1  vars2 = {0..<Suc n}" unfolding vars1_def vars2_def by auto
  then have dims: "nths dims (vars1  vars2) = dims" unfolding vars1_def vars2_def using nths_Suc_n_dims by auto
  then show ?thesis unfolding ps2_P.dims0_def ps2_P.vars0_def apply (subst dims) by auto
qed

(* ind_in_set is the identity on indices below Suc n, hence on vars1. *)
lemma ps2_P_vars1':
  "ps2_P.vars1' = vars1"
  unfolding ps2_P.vars1'_def ps2_P.vars0_def  
proof -
  have eq: "vars1  vars2 = {0..<(Suc n)}" unfolding vars1_def vars2_def by auto
  have "x < Suc n  {i  {0..<Suc n}. i < x} = {i. i < x}" for x by auto
  then have "x < Suc n  ind_in_set {0..<(Suc n)} x = x" for x unfolding ind_in_set_def by auto
  then have "x  vars1  ind_in_set {0..<(Suc n)} x = x" for x unfolding vars1_def by auto
  then have "ind_in_set {0..<(Suc n)} ` vars1 = vars1" by force
  with eq show "ind_in_set (vars1  vars2) ` vars1 = vars1" by auto
qed

lemma ps2_P_d0:
  "ps2_P.d0 = d"
  unfolding ps2_P.d0_def using ps2_P_dims0 d_def by auto

lemma ps2_P_d1:
  "ps2_P.d1 = N"
  unfolding ps2_P.d1_def ps2_P.dims1_def by (simp add: dims_vars1 N_def)

lemma ps2_P_d2:
  "ps2_P.d2 = K"
  unfolding ps2_P.d2_def ps2_P.dims2_def by (simp add: dims_vars2)

lemma ps_P_d:
  "ps_P.d = d"
  unfolding ps_P.d_def ps2_P_dims0 by auto

lemma ps_P_d1:
  "ps_P.d1 = N"
  unfolding ps_P.d1_def ps_P.dims1_def ps2_P.nths_vars1' using ps2_P_d1 unfolding ps2_P.d1_def by auto

lemma ps_P_d2:
  "ps_P.d2 = K"
  unfolding ps_P.d2_def ps_P.dims2_def ps2_P.nths_vars2' using ps2_P_d2 unfolding ps2_P.d2_def by auto

lemma nths_uminus_vars1:
  "nths dims (- vars1) = nths dims vars2"
  using ps2_P.nths_vars2' unfolding ps2_P_dims0 ps2_P_vars1' ps2_P.dims2_def by auto

(* Multiplication of tensor_P products factor by factor: an (N x N) part on
   vars1 (N = 2^n) and a (K x K) part on vars2. *)
lemma tensor_P_mult:
  assumes "m1 ∈ carrier_mat (2^n) (2^n)"
    and "m2 ∈ carrier_mat (2^n) (2^n)"
    and "m3 ∈ carrier_mat K K"
    and "m4 ∈ carrier_mat K K"
  shows "(tensor_P m1 m3) * (tensor_P m2 m4) = tensor_P (m1 * m2) (m3 * m4)"
proof -
  (* ps2_P.ptensor_mat_mult only needs the two factor dimensions:
     2^n for the qubit part and K for the last variable. *)
  have "ps2_P.d1 = 2^n" unfolding ps2_P.d1_def ps2_P.dims1_def using d_vars1 N_def by auto
  moreover have "ps2_P.d2 = K" unfolding ps2_P.d2_def ps2_P.dims2_def using dims_vars2 by auto
  ultimately show ?thesis apply (subst ps2_P.ptensor_mat_mult) using assms by auto
qed

(* Extending a matrix on vars1 to the whole state space is tensor_P with the
   identity on vars2; this also gives the Utrans_P form. *)
lemma mat_ext_vars1:
  shows "mat_extension dims vars1 A = tensor_P A (1m K)"
  unfolding Utrans_P_def ps2_P.ptensor_mat_def partial_state.mat_extension_def
    partial_state.d2_def partial_state.dims2_def ps2_P.nths_vars2'[simplified ps2_P_dims0 ps2_P_vars1'] 
  using ps2_P_d2 unfolding ps2_P.d2_def using ps2_P_dims0 ps2_P_vars1' by auto

lemma Utrans_P_is_tensor_P1:
  "Utrans_P vars1 A = Utrans (tensor_P A (1m K))"
  unfolding Utrans_P_def ps2_P.ptensor_mat_def partial_state.mat_extension_def
    partial_state.d2_def partial_state.dims2_def ps2_P.nths_vars2'[simplified ps2_P_dims0 ps2_P_vars1'] 
  using ps2_P_d2 unfolding ps2_P.d2_def using ps2_P_dims0 ps2_P_vars1' by auto

(* The complement of vars2 selects the same dimensions as vars1. *)
lemma nths_dims_uminus_vars2:
  "nths dims (-vars2) = nths dims vars1"
proof -
  have "nths dims (-vars2) = nths dims ({0..<length dims} - vars2)"
    using nths_minus_eq by auto
  also have " = nths dims vars1" unfolding vars1_def vars2_def length_dims
    apply (subgoal_tac "{0..<n + 1} - {n} = {0..<n}") by auto
  finally show ?thesis by auto
qed

(* Extending a K x K matrix on vars2 is tensor_P with the identity on vars1;
   the proof swaps the tensor factors with tensor_mat_comm. *)
lemma mat_ext_vars2:
  assumes "A  carrier_mat K K"
  shows "mat_extension dims vars2 A = tensor_P (1m N) A"
proof -
  have "mat_extension dims vars2 A = tensor_mat dims vars2 A (1m N)"
    unfolding Utrans_P_def partial_state.mat_extension_def
      partial_state.d2_def partial_state.dims2_def
      nths_dims_uminus_vars2 dims_vars1 N_def by auto
  also have " = tensor_mat dims vars1 (1m N) A" 
    apply (subst tensor_mat_comm[of vars1 vars2])
    subgoal unfolding vars1_def vars2_def by auto
    subgoal unfolding length_dims vars1_def vars2_def by auto
    subgoal unfolding dims_vars1 N_def by auto
    unfolding dims_vars2 using assms by auto
  finally show "mat_extension dims vars2 A = tensor_P (1m N) A"
    unfolding ps2_P.ptensor_mat_def ps2_P_dims0 ps2_P_vars1' by auto
qed

lemma Utrans_P_is_tensor_P2:
  assumes "A  carrier_mat K K"
  shows "Utrans_P vars2 A = Utrans (tensor_P (1m N) A)"
  unfolding Utrans_P_def using mat_ext_vars2 assms by auto


subsection ‹Grover's algorithm›

text ‹Apply the Hadamard operator to each of the first n variables (the qubits)›
(* Hadamard gate on qubit i, extended to the whole qubit part vars1
   (an N x N matrix for i < n, see hadamard_on_i_dim). *)
definition hadamard_on_i :: "nat  complex mat" where
  "hadamard_on_i i = pmat_extension dims {i} (vars1 - {i}) hadamard"
declare hadamard_on_i_def [simp]

(* hadamard_n k applies hadamard_on_i to qubits 0 .. k-1 in turn, each time
   tensored with the identity on the last variable. *)
fun hadamard_n :: "nat  com" where
  "hadamard_n 0 = SKIP"
| "hadamard_n (Suc i) = hadamard_n i ;; Utrans (tensor_P (hadamard_on_i i) (1m K))"

text ‹Body of the loop›
(* Oracle, Hadamards, phase shift, Hadamards, then mat_incr K applied to the
   last variable. *)
definition D :: com where
  "D = Utrans_P vars1 mat_O ;;
       hadamard_n n ;;
       Utrans_P vars1 mat_Ph ;;
       hadamard_n n ;;
       Utrans_P vars2 (mat_incr K)"

(* Unitarity of the lifted gates used in D and hadamard_n: each is a tensor
   product of a unitary matrix with an identity (tensor_mat_unitary). *)
lemma unitary_ex_mat_O:
  "unitary (tensor_P mat_O (1m K))"
  unfolding ps2_P.ptensor_mat_def
  apply (subst ps_P.tensor_mat_unitary)
  subgoal using ps_P_d1 mat_O_def by auto
  subgoal using ps_P_d2 by auto
  subgoal using unitary_mat_O by auto
  using unitary_one by auto

lemma unitary_ex_mat_Ph:
  "unitary (tensor_P mat_Ph (1m K))"
  unfolding ps2_P.ptensor_mat_def
  apply (subst ps_P.tensor_mat_unitary)
  subgoal using ps_P_d1 mat_Ph_def by auto
  subgoal using ps_P_d2 by auto
  subgoal using unitary_mat_Ph by auto
  using unitary_one by auto

lemma unitary_hadamard_on_i:
  assumes "k < n"
  shows "unitary (hadamard_on_i k)"
proof -
  interpret st2: partial_state2 dims "{k}" "vars1 - {k}"
    apply unfold_locales by auto
  show ?thesis unfolding hadamard_on_i_def st2.pmat_extension_def st2.ptensor_mat_def
    apply (rule partial_state.tensor_mat_unitary)
    subgoal unfolding partial_state.d1_def partial_state.dims1_def st2.nths_vars1' st2.dims1_def
      using dims_nths_one_lt_n assms hadamard_dim by auto
    subgoal unfolding st2.d2_def st2.dims2_def partial_state.d2_def partial_state.dims2_def st2.nths_vars2' st2.dims1_def
      by auto
    subgoal using unitary_hadamard by auto
    subgoal using unitary_one by auto
    done
qed

(* d1 relates the dimension of the {k} / vars1 - {k} split to that of vars1,
   so hadamard_on_i k has the right size to be tensored with 1m K. *)
lemma unitary_exhadamard_on_i:
  assumes "k < n"
  shows "unitary (tensor_P (hadamard_on_i k) (1m K))"
proof -
  interpret st2: partial_state2 dims "{k}" "vars1 - {k}"
    apply unfold_locales by auto
  have d1: "st2.d0 = partial_state.d1 ps2_P.dims0 ps2_P.vars1'"
    unfolding partial_state.d1_def partial_state.dims1_def ps2_P.nths_vars1' ps2_P.dims1_def
      st2.d0_def st2.dims0_def st2.vars0_def using assms
    apply (subgoal_tac "{k}  (vars1 - {k}) = vars1") apply simp
    unfolding vars1_def by auto
  show ?thesis
  unfolding ps2_P.ptensor_mat_def
  apply (rule partial_state.tensor_mat_unitary)
  subgoal unfolding hadamard_on_i_def st2.pmat_extension_def 
    using st2.ptensor_mat_carrier[of hadamard "1m st2.d2"]
    using d1 by auto
  subgoal unfolding partial_state.d2_def partial_state.dims2_def ps2_P.nths_vars2' ps2_P.dims2_def dims_vars2 by auto
  using unitary_hadamard_on_i unitary_one assms by auto
qed

lemma hadamard_on_i_dim:
  assumes "k < n"
  shows "hadamard_on_i k  carrier_mat N N"
proof -
  interpret st: partial_state2 dims "{k}" "(vars1 - {k})"
    apply unfold_locales by auto
  have vars1: "{k}  (vars1 - {k}) = vars1" unfolding vars1_def using assms by auto
  show ?thesis unfolding hadamard_on_i_def N_def using st.pmat_extension_carrier unfolding st.d0_def st.dims0_def st.vars0_def
    using vars1 dims_vars1 by auto
qed

(* hadamard_n k is well-formed for every k up to the number of qubits n:
   each step applies a unitary d x d matrix. *)
lemma well_com_hadamard_k:
  "k ≤ n ⟹ well_com (hadamard_n k)"
proof (induct k)
  case 0
  then show ?case by auto
next
  (* Name the induction variable k, not n: "case (Suc n)" would shadow the
     locale parameter n (the number of qubits) for the rest of this case. *)
  case (Suc k)
  then have "well_com (hadamard_n k)" by auto
  then show ?case unfolding hadamard_n.simps well_com.simps using tensor_P_dim unitary_exhadamard_on_i Suc by auto
qed

(* Every command used by the algorithm is well-formed (well_com): each applies
   a unitary d x d matrix (tensor_P_dim plus the unitarity lemmas above). *)
lemma well_com_hadamard_n:
  "well_com (hadamard_n n)"
  using well_com_hadamard_k by auto

lemma well_com_mat_O:
  "well_com (Utrans_P vars1 mat_O)"
  apply (subst Utrans_P_is_tensor_P1)
  apply simp using tensor_P_dim unitary_ex_mat_O by auto

lemma well_com_mat_Ph:
  "well_com (Utrans_P vars1 mat_Ph)"
  apply (subst Utrans_P_is_tensor_P1)
  apply simp using tensor_P_dim unitary_ex_mat_Ph by auto

lemma unitary_exmat_incr:
  "unitary (tensor_P (1m N) (mat_incr K))"
  unfolding ps2_P.ptensor_mat_def
  apply (subst ps_P.tensor_mat_unitary)
  using  unitary_mat_incr K unitary_one by (auto simp add: ps_P_d1 ps_P_d2 mat_incr_def)

lemma well_com_mat_incr:
  "well_com (Utrans_P vars2 (mat_incr K))"
  apply (subst Utrans_P_is_tensor_P2)
  apply (simp add: mat_incr_def) using tensor_P_dim unitary_exmat_incr by auto

lemma well_com_D: "well_com D"
  unfolding D_def apply auto
  using well_com_hadamard_n well_com_mat_incr well_com_mat_O well_com_mat_Ph 
  by auto

text ‹Measurement used as the test of the while loop›

(* Two-outcome measurement on the last variable: M0 projects onto values at
   least R and M1 onto values below R. Both are Hermitian idempotents, and
   they sum to the identity (M1_add_M0).
   NOTE(review): presumably M1 is the branch in which While_P continues;
   confirm against the definition of While_P. *)
definition M0 :: "complex mat" where
  "M0 = mat K K (λ(i,j). if i = j  i  R then 1 else 0)"

lemma hermitian_M0:
  "hermitian M0"
  by (auto simp add: hermitian_def M0_def adjoint_eval)

lemma M0_dim:
  "M0  carrier_mat K K"
  unfolding M0_def by auto

lemma M0_mult_M0:
  "M0 * M0 = M0"
  by (auto simp add: M0_def scalar_prod_def sum_only_one_neq_0)

definition M1 :: "complex mat" where
  "M1 = mat K K (λ(i,j). if i = j  i < R then 1 else 0)"

lemma M1_dim:
  "M1  carrier_mat K K"
  unfolding M1_def by auto

lemma hermitian_M1:
  "hermitian M1"
  by (auto simp add: hermitian_def M1_def adjoint_eval)

lemma M1_mult_M1:
  "M1 * M1 = M1"
  by (auto simp add: M1_def scalar_prod_def sum_only_one_neq_0)

lemma M1_add_M0:
  "M1 + M0 = 1m K"
  unfolding M0_def M1_def by auto

text ‹Measurement of the first n variables at the end of the algorithm›

(* testN k: the N x N matrix whose only nonzero entry is 1 at position (k, k),
   i.e. the projector onto basis state k of the qubit part. *)
definition testN :: "nat  complex mat" where
  "testN k = mat N N (λ(i,j). if i = k  j = k then 1 else 0)"

lemma hermitian_testN:
  "hermitian (testN k)"
  unfolding hermitian_def testN_def
  by (auto simp add: scalar_prod_def adjoint_eval)

lemma testN_mult_testN:
  "testN k * testN k = testN k"
  unfolding testN_def
  by (auto simp add: scalar_prod_def sum_only_one_neq_0)

lemma testN_dim:
  "testN k  carrier_mat N N"
  unfolding testN_def by auto

(* test_fst_k k: diagonal 0/1 matrix with ones exactly at indices below k. *)
definition test_fst_k :: "nat  complex mat" where
  "test_fst_k k = mat N N (λ(i, j). if (i = j  i < k) then 1 else 0)"

(* Summing testN over the first m indices gives test_fst_k m (induction on m). *)
lemma sum_test_k:
  assumes "m  N"
  shows "matrix_sum N (λk. testN k) m = test_fst_k m"
proof -
  have "m  N  matrix_sum N (λk. testN k) m = mat N N (λ(i, j). if (i = j  i < m) then 1 else 0)" for m
  proof (induct m)
    case 0
    then show ?case apply simp apply (rule eq_matI) by auto
  next
    case (Suc m)
    then have m: "m < N" by auto
    then have m': "m  N" by auto
    have "matrix_sum N testN (Suc m) = testN m + matrix_sum N testN m" by simp
    also have " = mat N N (λ(i, j). if (i = j  i < (Suc m)) then 1 else 0)"
      unfolding testN_def Suc(1)[OF m'] apply (rule eq_matI) by auto
    finally show ?case by auto
  qed
  then show ?thesis unfolding test_fst_k_def using assms by auto
qed

(* Hence the testN k for k < N sum to the identity. *)
lemma test_fst_kN:
  "test_fst_k N = 1m N"
  apply (rule eq_matI)
  unfolding test_fst_k_def by auto

(* matrix_sum distributes over tensor_P in its first argument. *)
lemma matrix_sum_tensor_P1:
  "(k. k < m  g k  carrier_mat N N)  (A  carrier_mat K K) 
   matrix_sum d (λk. tensor_P (g k) A) m = tensor_P (matrix_sum N g m) A"
proof (induct m)
  case 0
  show ?case apply (simp) unfolding ps2_P.ptensor_mat_def 
    using ps_P.tensor_mat_zero1[simplified ps_P_d ps_P_d1, of A] by auto
next
  case (Suc m)
  then have ind: "matrix_sum d (λk. tensor_P (g k) A) m = tensor_P (matrix_sum N g m) A" 
    and dk: "k. k < m  g k  carrier_mat N N" and "A  carrier_mat K K" by auto
  have ds: "matrix_sum N g m  carrier_mat N N" apply (subst matrix_sum_dim)
    using dk by auto
  show ?case apply simp
    apply (subst ind)
    unfolding ps2_P.ptensor_mat_def apply (subst ps_P.tensor_mat_add1)
    unfolding ps_P_d1 ps_P_d2 using Suc ds by auto
qed

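text ‹Grover's algorithm: apply Hadamard to the qubits, run the main While loop with body D, then measure the first register. The program is assumed to start in the zero state.›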
(* The algorithm as a program, in three steps:
   1. hadamard_n n;
   2. a While_P loop on vars2 with measurement operators M0/M1 and body D;
   3. a final Measure_P on vars1 with the N tests testN k, doing SKIP in every
      branch. *)
definition Grover :: com where
  "Grover = hadamard_n n ;;
            While_P vars2 M0 M1 D ;;
            Measure_P vars1 N testN (replicate N SKIP)"

(* The final measurement step (Measure_P on vars1 with the tests testN k and
   SKIP in every branch) is well-formed. *)
lemma well_com_if:
  "well_com (Measure_P vars1 N testN (replicate N SKIP))"
  unfolding Measure_P_def apply auto
proof -
  (* On vars1, mat_extension is tensor_P with the K × K identity. *)
  have eq0: "⋀n. mat_extension dims vars1 (testN n) = tensor_P (testN n) (1⇩m K)"
    unfolding mat_ext_vars1 by auto 
  (* Each operator E satisfies adjoint E * E = E, because testN j is
     Hermitian and idempotent. *)
  have eq1: "adjoint (tensor_P (testN j) (1⇩m K)) * tensor_P (testN j) (1⇩m K) = tensor_P (testN j) (1⇩m K)" for j
    unfolding ps2_P.ptensor_mat_def
    apply (subst ps_P.tensor_mat_adjoint)
      apply (auto simp add: ps_P_d1 ps_P_d2 testN_dim hermitian_testN[unfolded hermitian_def] hermitian_one[unfolded hermitian_def])
    apply (subst ps_P.tensor_mat_mult[symmetric])
    by (auto simp add: ps_P_d1 ps_P_d2 testN_dim testN_mult_testN)
  (* The operators sum to the identity: pull the sum inside tensor_P, reduce
     the sum of tests to test_fst_k N = 1⇩m N, then use tensor_mat_id. *)
  have "measurement d N (λn. tensor_P (testN n) (1⇩m K))"
    unfolding measurement_def
    apply (simp add: tensor_P_dim)
    apply (subst eq1)
    apply (subst matrix_sum_tensor_P1)
      apply (auto simp add: testN_dim)
    apply (subst sum_test_k, simp)
    apply (subst test_fst_kN)
    unfolding ps2_P.ptensor_mat_def
    using ps_P.tensor_mat_id ps_P_d ps_P_d1 ps_P_d2 by auto
  then show "measurement d N (λn. mat_extension dims vars1 (testN n))" using eq0 by auto

  show "list_all well_com (replicate N SKIP)" 
    apply (subst list_all_length) by simp
qed

(* The main loop is well-formed. The two operators tensor_P (1⇩m N) M0 and
   tensor_P (1⇩m N) M1 are self-adjoint and idempotent, and they sum to 1⇩m d,
   so they form a two-outcome measurement. The body D is well-formed by
   well_com_D. *)
lemma well_com_while:
  "well_com (While_P vars2 M0 M1 D)"
  unfolding While_P_def apply auto
   apply (subst (1 2) mat_ext_vars2)
  apply (auto simp add: M1_dim M0_dim)
proof -
  have 2: "2 = Suc (Suc 0)" by auto
  (* self-adjointness of both operators *)
  have ad0: "adjoint (tensor_P (1⇩m N) M0) = (tensor_P (1⇩m N) M0)"
    unfolding ps2_P.ptensor_mat_def apply (subst ps_P.tensor_mat_adjoint)
    unfolding ps_P_d1 ps_P_d2 by (auto simp add: M0_dim adjoint_one hermitian_M0[unfolded hermitian_def])
  have ad1: "adjoint (tensor_P (1⇩m N) M1) = (tensor_P (1⇩m N) M1)"
    unfolding ps2_P.ptensor_mat_def apply (subst ps_P.tensor_mat_adjoint)
    unfolding ps_P_d1 ps_P_d2 by (auto simp add: M1_dim adjoint_one hermitian_M1[unfolded hermitian_def])
  (* idempotence of both operators *)
  have m0: "tensor_P (1⇩m N) M0 * tensor_P (1⇩m N) M0 = tensor_P (1⇩m N) M0"
    unfolding ps2_P.ptensor_mat_def apply (subst ps_P.tensor_mat_mult[symmetric])
    unfolding ps_P_d1 ps_P_d2 using M0_dim M0_mult_M0 by auto
  have m1: "tensor_P (1⇩m N) M1 * tensor_P (1⇩m N) M1 = tensor_P (1⇩m N) M1"
    unfolding ps2_P.ptensor_mat_def apply (subst ps_P.tensor_mat_mult[symmetric])
    unfolding ps_P_d1 ps_P_d2 using M1_dim M1_mult_M1 by auto
  (* the two operators add up to the identity on the whole space *)
  have s: "tensor_P (1⇩m N) M1 + tensor_P (1⇩m N) M0 = 1⇩m d"
    unfolding ps2_P.ptensor_mat_def apply (subst ps_P.tensor_mat_add2[symmetric])
    unfolding ps_P_d1 ps_P_d2 
    by (auto simp add: M1_dim M0_dim M1_add_M0 ps_P.tensor_mat_id[simplified ps_P_d1 ps_P_d2 ps_P_d])
  show "measurement d 2 (λn. if n = 0 then tensor_P (1⇩m N) M0 else if n = 1 then tensor_P (1⇩m N) M1 else undefined)"
    unfolding measurement_def apply (auto simp add: tensor_P_dim) apply (subst 2)
    apply (simp add: ad0 ad1 m0 m1)
    apply (subst assoc_add_mat[symmetric, of _ d d]) using tensor_P_dim s by auto
  show "well_com D" using well_com_D by auto
qed

(* The whole program is well-formed, from the well-formedness of its three
   parts: hadamard_n, the While_P loop and the final Measure_P. *)
lemma well_com_Grover:
  "well_com Grover"
  unfolding Grover_def apply auto
  using well_com_hadamard_n well_com_if well_com_while by auto

subsection ‹Correctness›

text ‹Pre-condition: the N-dimensional register is in the zero basis state.›

(* The basis vector for index 0 in dimension N: entry 1 at index 0, 0 elsewhere. *)
definition ket_pre :: "complex vec" where
  "ket_pre = Matrix.vec N (λk. if k = 0 then 1 else 0)"

(* ket_pre has dimension N. *)
lemma ket_pre_dim:
  "ket_pre ∈ carrier_vec N" using ket_pre_def by auto

(* The pre-condition: proj applied to ket_pre. pre_trace evaluates its trace
   via trace_outer_prod. *)
definition pre :: "complex mat" where
  "pre = proj ket_pre"

(* pre is an N × N matrix. *)
lemma pre_dim:
  "pre ∈ carrier_mat N N"
  using pre_def ket_pre_def by auto

(* ket_pre has unit norm. The inner product is a sum over {0..<N} in which only
   the i = 0 term is nonzero (sum_only_one_neq_0). *)
lemma norm_pre:
  "inner_prod ket_pre ket_pre = 1"
  unfolding ket_pre_def scalar_prod_def
  using sum_only_one_neq_0[of "{0..<N}" 0 "λi. (if i = 0 then 1 else 0) * cnj (if i = 0 then 1 else 0)"] by auto

(* pre has trace 1. The trace is rewritten with trace_outer_prod (the first
   subgoal is discharged by unfolding ket_pre_def) and then closed by norm_pre. *)
lemma pre_trace:
  "trace pre = 1"
  unfolding pre_def
  apply (subst trace_outer_prod[of _ N])
  subgoal unfolding ket_pre_def by auto using norm_pre by auto

(* pre is positive, by positive_same_outer_prod. *)
lemma positive_pre:
  "positive pre"
  using positive_same_outer_prod unfolding pre_def ket_pre_def by auto

(* pre is below the identity in the Löwner order, since ket_pre has unit norm. *)
lemma pre_le_one:
  "pre ≤⇩L 1⇩m N"
  unfolding pre_def using outer_prod_le_one norm_pre ket_pre_def by auto

text ‹Post-condition: the register should be in a basis state i for which f i holds.›

(* The post-condition: the N × N diagonal 0/1 matrix with 1 at (i, i) exactly
   when f i holds. *)
definition post :: "complex mat" where
  "post = mat N N (λ(i, j). if (i = j ∧ f i) then 1 else 0)"

(* post is an N × N matrix. *)
lemma post_dim:
  "post ∈ carrier_mat N N"
  unfolding post_def by auto

(* post is self-adjoint; proved entrywise using adjoint_eval. *)
lemma hermitian_post:
  "hermitian post"
  unfolding hermitian_def post_def
  by (auto simp add: adjoint_eval)

text ‹Definitions and lemmas for the Hoare triples of the initialization step›

(* Single-qubit basis vector for index 0 (dimension 2). *)
definition ket_zero :: "complex vec" where
  "ket_zero = Matrix.vec 2 (λk. if k = 0 then 1 else 0)"

(* ket_zero has dimension 2. *)
lemma ket_zero_dim:
  "ket_zero ∈ carrier_vec 2" unfolding ket_zero_def by auto

(* proj of the single-qubit vector ket_zero. *)
definition proj_zero where
  "proj_zero = proj ket_zero"

(* Single-qubit basis vector for index 1 (dimension 2). *)
definition ket_one where
  "ket_one = Matrix.vec 2 (λk. if k = 1 then 1 else 0)"

(* proj of the single-qubit vector ket_one. *)
definition proj_one where
  "proj_one = proj ket_one"

(* Single-qubit uniform vector: both entries equal 1 / csqrt 2. *)
definition ket_plus where
  "ket_plus = Matrix.vec 2 (λk.1 / csqrt 2) "

(* ket_plus has dimension 2. *)
lemma ket_plus_dim:
  "ket_plus ∈ carrier_vec 2" unfolding ket_plus_def by auto

(* Every in-range entry of ket_plus is 1 / csqrt 2. *)
lemma ket_plus_eval [simp]:
  "i < 2 ⟹ ket_plus $ i = 1 / csqrt 2"
  apply (simp only: ket_plus_def)
  using index_vec less_2_cases by force

(* The real square root of 2, embedded into the complex numbers, squares to 2.
   Declared [simp] for the 1 / csqrt 2 arithmetic above. *)
lemma csqrt_2_sq [simp]:
  "complex_of_real (sqrt 2) * complex_of_real (sqrt 2) = 2"
  by (smt of_real_add of_real_hom.hom_one of_real_power one_add_one power2_eq_square real_sqrt_pow2)

(* The tensor product of ket_plus with itself, over dims [2, 2] with the first
   factor on variable {0}, is the length-4 vector with every entry 1/2.
   NOTE(review): despite the "_n" suffix, this covers exactly two qubits. *)
lemma ket_plus_tensor_n:
  "partial_state.tensor_vec [2, 2] {0} ket_plus ket_plus = Matrix.vec 4 (λk. 1 / 2)"
  unfolding partial_state.tensor_vec_def state_sig.d_def
proof (rule eq_vecI, auto)
  fix i :: nat assume i: "i < 4"
  interpret st: partial_state "[2, 2]" "{0}" .
  (* both encoded sub-indices lie below 2, so ket_plus_eval applies to each *)
  have d1_eq: "st.d1 = 2"
    by (simp add: st.d1_def st.dims1_def nths_def)
  have "st.encode1 i < st.d1"
    by (simp add: st.d_def i)
  then have i1_lt: "st.encode1 i < 2"
    using d1_eq by auto
  have d2_eq: "st.d2 = 2"
    by (simp add: st.d2_def st.dims2_def nths_def)
  have "st.encode2 i < st.d2"
    by (simp add: st.d_def i)
  then have i2_lt: "st.encode2 i < 2"
    using d2_eq by auto
  show "ket_plus $ st.encode1 i * ket_plus $ st.encode2 i * 2 = 1"
    by (auto simp add: i1_lt i2_lt)
qed

(* proj of the single-qubit vector ket_plus. *)
definition proj_plus where
  "proj_plus = proj ket_plus"

(* Applying hadamard to ket_zero gives ket_plus. The proof is entrywise: each
   index i < 2 is split with less_2_cases and the length-2 sums are expanded
   with sum_le_2. *)
lemma hadamard_on_zero:
  "hadamard *v ket_zero = ket_plus"
  unfolding hadamard_def ket_zero_def ket_plus_def mat_of_rows_list_def  
  apply (rule eq_vecI, auto simp add: scalar_prod_def)
  subgoal for i
    apply (drule less_2_cases)
    apply (drule disjE, auto)
    by (subst sum_le_2, auto)+.

(* exH_k k is the product hadamard_on_i 0 * ... * hadamard_on_i k, built by
   multiplying on the right. *)
fun exH_k :: "nat ⇒ complex mat" where
  "exH_k 0 = hadamard_on_i 0"
| "exH_k (Suc k) = exH_k k * hadamard_on_i (Suc k)"

(* H_k k tensors hadamard k + 1 times. Each step combines H_k k on qubits
   {0..<Suc k} with hadamard on qubit Suc k via ptensor_mat. *)
fun H_k :: "nat ⇒ complex mat" where
  "H_k 0 = hadamard"
| "H_k (Suc k) = ptensor_mat dims {0..<Suc k} {Suc k} (H_k k) hadamard"

(* For k < n, H_k k is a 2^(Suc k) × 2^(Suc k) matrix. In the step case the
   combined variable set {0..<Suc (Suc k)} has dimension 2^(Suc (Suc k)),
   because every entry of dims in that range is 2 (dims_nths_le_n). *)
lemma H_k_dim:
  "k < n ⟹ H_k k ∈ carrier_mat (2^(Suc k)) (2^(Suc k))"
proof (induct k)
  case 0
  then show ?case using hadamard_dim by auto
next
  case (Suc k)
  interpret st: partial_state2 dims "{0..<(Suc k)}" "{Suc k}"
    apply unfold_locales by auto
  have "Suc (Suc k) ≤ n" using Suc by auto
  then have "nths dims ({0..<Suc (Suc k)}) = replicate (Suc (Suc k)) 2" using dims_nths_le_n by auto
  moreover have "prod_list (replicate l 2) = 2^l" for l by simp
  moreover have "{0..<Suc k} ∪ {Suc k} = {0..<(Suc (Suc k))}" by auto
  ultimately have plssk: "prod_list (nths dims ({0..<Suc k} ∪ {Suc k})) = 2^(Suc (Suc k))" by auto
  have "dim_col (H_k (Suc k)) = 2^(Suc (Suc k))" using st.ptensor_mat_dim_col unfolding st.d0_def st.dims0_def st.vars0_def using plssk by auto
  moreover have "dim_row (H_k (Suc k)) = 2^(Suc (Suc k))" using st.ptensor_mat_dim_row unfolding st.d0_def st.dims0_def st.vars0_def using plssk by auto
  ultimately show ?case by auto
qed

(* For k < n, the product exH_k k equals H_k k on qubits {0..<Suc k}, extended
   by the identity on the remaining qubits {Suc k..<n}. The step case:
   - expands exH_k (Suc k) into two factors (eq0);
   - pads the identity on {Suc k..<n} into 1⇩m 2 on qubit Suc k tensored with
     the identity on {Suc (Suc k)..<n};
   - reassociates and commutes the tensor factors (ptensor_mat_assoc and
     ptensor_mat_comm) so that the product becomes H_k (Suc k) tensored with
     the identity.
   NOTE(review): eq0 relies on hadamard_on_i (Suc k) unfolding to
   st2.pmat_extension hadamard; confirm against its definition. *)
lemma exH_k_eq_H_k:
  "k < n ⟹ exH_k k = pmat_extension dims {0..<(Suc k)} {(Suc k)..<n} (H_k k)"
proof(induct k)
  case 0
  have "{(Suc 0)..<n} = vars1 - {0..<(Suc 0)}" using vars1_def by fastforce
  then show ?case unfolding exH_k.simps using vars1_def by auto
next
  case (Suc k)
  interpret st: partial_state2 dims "{0..<Suc k}" "{(Suc k)..<n}"
    apply unfold_locales by auto
  interpret st1: partial_state2 dims "{Suc k}" "{(Suc (Suc k))..<n}"
    apply unfold_locales by auto
  interpret st2: partial_state2 dims "{Suc k}" "vars1 - {Suc k}"
    apply unfold_locales by auto
  interpret st3: partial_state2 dims "{0..<Suc k}" "{Suc (Suc k)..<n}"
    apply unfold_locales by auto
  interpret st4: partial_state2 dims "{0..<Suc (Suc k)}" "{Suc (Suc k)..<n}"
    apply unfold_locales by auto

  from Suc have eq0: "exH_k (Suc k) 
    = (st.pmat_extension (H_k k)) * (st2.pmat_extension hadamard)" by auto
  have "vars1 - {0..<Suc k} = {(Suc k)..<n}" using vars1_def by auto

  then have eql1: "st.pmat_extension (H_k k) = st.ptensor_mat (H_k k) (1⇩m st.d2)"
    using st.pmat_extension_def by auto

  (* split the identity on {Suc k..<n} as 1⇩m 2 on qubit Suc k tensored with
     the identity on the rest *)
  from dims_nths_one_lt_n[OF Suc(2)] have st1d1: "st1.d1 = 2" unfolding st1.d1_def st1.dims1_def by fastforce
  have "{Suc k} ∪ {Suc (Suc k)..<n} = {Suc k..<n}" using Suc by auto
  then have "st1.d0 = st.d2" unfolding st1.d0_def st1.dims0_def st1.vars0_def st.d2_def st.dims2_def by fastforce
  then have eql2: "st1.ptensor_mat (1⇩m 2) (1⇩m st1.d2) = 1⇩m st.d2"
    using st1.ptensor_mat_id st1d1 by auto
  have eql3: "st.ptensor_mat (H_k k) (1⇩m st.d2) = st.ptensor_mat (H_k k) (st1.ptensor_mat (1⇩m 2) (1⇩m st1.d2))"
    apply (subst eql2[symmetric]) by auto

  have eqr1: "(st2.pmat_extension hadamard) = st2.ptensor_mat hadamard (1⇩m st2.d2)" using st2.pmat_extension_def by auto
  have splitset: "{0..<Suc k} ∪ {Suc (Suc k)..<n} = vars1 - {Suc k}" unfolding vars1_def using Suc(2) by auto

  have Sksplit: "{Suc k} ∪ {Suc (Suc k)..<n} = {Suc k..<n}" using Suc(2) by auto
  have Sksplit1: "{0..<Suc k}∪{Suc k} = {0..<Suc (Suc k)}" by auto
  (* regroup the left factor so that qubit Suc k comes first *)
  have "st.ptensor_mat (H_k k) (st1.ptensor_mat (1⇩m 2) (1⇩m st1.d2)) 
    = ptensor_mat dims ({0..<Suc k}∪{Suc k}) {Suc (Suc k)..<n} (ptensor_mat dims {0..<Suc k} {Suc k} (H_k k) (1⇩m 2)) (1⇩m st1.d2)"
    apply (subst ptensor_mat_assoc[symmetric, of "{0..<Suc k}" "{Suc k}" "{Suc (Suc k)..<n}" "H_k k" "1⇩m 2" "1⇩m st1.d2", simplified Sksplit])
    using Suc length_dims by auto
  also have "… = ptensor_mat dims ({0..<Suc k}∪{Suc k}) {Suc (Suc k)..<n} (ptensor_mat dims {Suc k} {0..<Suc k} (1⇩m 2) (H_k k)) (1⇩m st1.d2)"
    using ptensor_mat_comm[of "{0..<Suc k}" "{Suc k}"] by auto
  also have "… = ptensor_mat dims {Suc k} ({0..<Suc k} ∪ {Suc (Suc k)..<n})
                  (1⇩m 2) 
                  (ptensor_mat dims {0..<Suc k} {Suc (Suc k)..<n} (H_k k) (1⇩m st1.d2))"
    apply (subst sup_commute)
    apply (subst ptensor_mat_assoc[of "{Suc k}" "{0..<Suc k}" "{Suc (Suc k)..<n}" "(1⇩m 2)" "H_k k" "1⇩m st1.d2"])
    using Suc length_dims by auto
  finally have eql4: "st.pmat_extension (H_k k) 
    = st2.ptensor_mat (1⇩m 2) (st3.ptensor_mat (H_k k) (1⇩m st3.d2))" using eql1 eql3 splitset by auto

  (* multiply factorwise, then regroup into H_k (Suc k) tensored with the identity *)
  have "st2.ptensor_mat (1⇩m 2) (st3.ptensor_mat (H_k k) (1⇩m st3.d2)) * st2.ptensor_mat hadamard (1⇩m st2.d2)
        = st2.ptensor_mat ((1⇩m 2)*hadamard) ((st3.ptensor_mat (H_k k) (1⇩m st3.d2))*(1⇩m st2.d2))"
    apply (rule st2.ptensor_mat_mult[symmetric, of "1⇩m 2" "hadamard" "(st3.ptensor_mat (H_k k) (1⇩m st3.d2))" "(1⇩m st2.d2)"])
    subgoal unfolding st2.d1_def st2.dims1_def
      by (simp add: dims_nths_one_lt_n Suc(2))
    subgoal unfolding st2.d1_def st2.dims1_def
      apply (simp add: dims_nths_one_lt_n Suc(2)) using hadamard_dim by auto
    subgoal unfolding st2.d2_def[unfolded st2.dims2_def]
      using st3.ptensor_mat_dim_col[unfolded st3.d0_def st3.dims0_def st3.vars0_def, simplified splitset]
        st3.ptensor_mat_dim_row[unfolded st3.d0_def st3.dims0_def st3.vars0_def, simplified splitset] by auto
    by auto
  also have "… = st2.ptensor_mat (hadamard) (st3.ptensor_mat (H_k k) (1⇩m st3.d2))"
    unfolding st2.d2_def[unfolded st2.dims2_def]
    using hadamard_dim st3.ptensor_mat_dim_col[unfolded st3.d0_def st3.dims0_def st3.vars0_def, simplified splitset]
        st3.ptensor_mat_dim_row[unfolded st3.d0_def st3.dims0_def st3.vars0_def, simplified splitset] by auto
  also have "… = ptensor_mat dims ({0..<Suc k}∪{Suc k}) {Suc (Suc k)..<n} (ptensor_mat dims {Suc k} {0..<Suc k} hadamard (H_k k)) (1⇩m st3.d2)"
    apply (subst ptensor_mat_assoc[symmetric, of "{Suc k}" "{0..<Suc k}" "{Suc (Suc k)..<n}" "hadamard" "H_k k" "1⇩m st3.d2", simplified splitset]) 
    using Suc length_dims by auto
  also have "… = ptensor_mat dims ({0..<Suc k}∪{Suc k}) {Suc (Suc k)..<n} (H_k (Suc k)) (1⇩m st3.d2)"
    using ptensor_mat_comm[of "{Suc k}"] Sksplit1 by auto
  also have "… = ptensor_mat dims ({0..<Suc (Suc k)}) {Suc (Suc k)..<n} (H_k (Suc k)) (1⇩m st3.d2)" using Sksplit1 by auto
  also have "… = pmat_extension dims {0..<Suc (Suc k)} {Suc (Suc k)..<n} (H_k (Suc k))" 
    unfolding st4.pmat_extension_def by auto
  finally show ?case using eq0 eql4 eqr1 by auto
qed

lemma mult_exH_k_left:
  assumes "Suc k < n"
  shows "hadamard_on_i (Suc k) * exH_k k = exH_k (Suc k)"
proof -
  interpret st: partial_state2 dims "{0..<Suc k}" "{(Suc k)..<n}"
    apply unfold_locales by auto
  interpret st1: partial_state2 dims "{Suc k}" "{(Suc (Suc k))..<n}"
    apply unfold_locales by auto
  interpret st2: partial_state2 dims "{Suc k}" "vars1 - {Suc k}"
    apply unfold_locales by auto
  interpret st3: partial_state2 dims "{0..<Suc k}" "{Suc (Suc k)..<n}"
    apply unfold_locales by auto
  interpret st4: partial_state2 dims "{0..<Suc (Suc k)}" "{Suc (Suc k)..<n}"
    apply unfold_locales by auto

  from exH_k_eq_H_k assms have eq0: "exH_k (Suc k) 
    = (st.pmat_extension (H_k k)) * (st2.pmat_extension hadamard)" by auto
  have "vars1 - {0..<Suc k} = {(Suc k)..<n}" using vars1_def by auto

  then have eql1: "st.pmat_extension (H_k k) = st.ptensor_mat (H_k k) (1m st.d2)"
    using st.pmat_extension_def by auto

  from dims_nths_one_lt_n[OF assms] have st1d1: "st1.d1 = 2" unfolding st1.d1_def st1.dims1_def by fastforce
  have "{Suc k}  {Suc (Suc k)..<n} = {Suc k..<n}" using assms by auto
  then have "st1.d0 = st.d2" unfolding st1.d0_def st1.dims0_def st1.vars0_def st.d2_def st.dims2_def by fastforce
  then have eql2: "st1.ptensor_mat (1m 2) (1m st1.d2) = 1m st.d2"
    using st1.ptensor_mat_id st1d1 by auto
  have eql3: "st.ptensor_mat (H_k k) (1m st.d2) = st.ptensor_mat (H_k k) (st1.ptensor_mat (1m 2) (1m st1.d2))"
    apply (subst eql2[symmetric]) by auto

  have eqr1: "(st2.pmat_extension hadamard) = st2.ptensor_mat hadamard (1m st2.d2)" using st2.pmat_extension_def by auto
  have splitset: "{0..<Suc k}  {Suc (Suc k)..<n} = vars1 - {Suc k}" unfolding vars1_def using assms by auto

  have Sksplit: "{Suc k}  {Suc (Suc k)..<n} = {Suc k..<n}" using assms by auto
  have Sksplit1: "{0..<Suc k}{Suc k} = {0..<Suc (Suc k)}" by auto
  have "st.ptensor_mat (H_k k) (st1.ptensor_mat (1m 2) (1m st1.d2)) 
    = ptensor_mat dims ({0..<Suc k}{Suc k}) {Suc (Suc k)..<n} (ptensor_mat dims {0..<Suc k} {Suc k} (H_k k) (1m 2)) (1m st1.d2)"
    apply (subst ptensor_mat_assoc[symmetric, of "{0..<Suc k}" "{Suc k}" "{Suc (Suc k)..<n}" "H_k k" "1m 2" "1m st1.d2", simplified Sksplit])
    using assms length_dims by auto
  also have " = ptensor_mat dims ({0..<Suc k}{Suc k}) {Suc (Suc k)..<n} (ptensor_mat dims {Suc k} {0..<Suc k} (1m 2) (H_k k)) (1m st1.d2