(* Theory Isabelle_Marries_Dirac.Quantum *)
section ‹Qubits and Quantum Gates›
theory Quantum
imports
  Jordan_Normal_Form.Matrix
  "HOL-Library.Nonpos_Ints"
  Basics
  Binary_Nat
begin
subsection ‹Qubits›
text‹In this theory @{text cpx} stands for @{text complex}.›
(* The Euclidean norm of a complex vector: the square root of the sum of the
   squared moduli of its entries. *)
definition cpx_vec_length :: "complex vec ⇒ real" (‹∥_∥›) where
"cpx_vec_length v ≡ sqrt(∑i<dim_vec v. (cmod (v $ i))⇧2)"
(* The norm of a vector built from a list, written as a sum over the list itself. *)
lemma cpx_length_of_vec_of_list [simp]:
  "∥vec_of_list l∥ = sqrt(∑i<length l. (cmod (l ! i))⇧2)"
  by (auto simp: cpx_vec_length_def vec_of_list_def vec_of_list_index)
    (metis (no_types, lifting) dim_vec_of_list sum.cong vec_of_list.abs_eq vec_of_list_index)
(* Every entry of unit_vec n i other than the i-th one has modulus 0. *)
lemma norm_vec_index_unit_vec_is_0 [simp]:
  assumes "j < n" and "j ≠ i"
  shows "cmod ((unit_vec n i) $ j) = 0"
  using assms by (simp add: unit_vec_def)
(* The i-th entry of unit_vec n i has modulus 1. *)
lemma norm_vec_index_unit_vec_is_1 [simp]:
  assumes "j < n" and "j = i"
  shows "cmod ((unit_vec n i) $ j) = 1"
proof -
  have f:"(unit_vec n i) $ j = 1"
    using assms by simp
  thus ?thesis
    by (simp add: f cmod_def) 
qed
(* Every standard basis vector unit_vec n i with i < n has norm 1. *)
lemma unit_cpx_vec_length [simp]:
  assumes "i < n"
  shows "∥unit_vec n i∥ = 1"
proof -
  (* the sum of squared moduli has a single nonzero summand, at index i, equal to 1 *)
  have "(∑j<n. (cmod((unit_vec n i) $ j))⇧2) = (∑j<n. if j = i then 1 else 0)"
    using norm_vec_index_unit_vec_is_0 norm_vec_index_unit_vec_is_1
    by (smt (verit) lessThan_iff one_power2 sum.cong zero_power2) 
  also have "… = 1"
    using assms by simp
  finally have "sqrt (∑j<n. (cmod((unit_vec n i) $ j))⇧2) = 1" 
    by simp
  thus ?thesis
    using cpx_vec_length_def by simp
qed
(* Scaling a vector by a nonnegative real x multiplies its norm by x.
   The hypothesis x ≥ 0 is what turns sqrt(x⇧2) into x in the last step. *)
lemma smult_vec_length [simp]:
  assumes "x ≥ 0"
  shows "∥complex_of_real(x) ⋅⇩v v∥ = x * ∥v∥"
proof-
  have "(λi::nat.(cmod (complex_of_real x * v $ i))⇧2) = (λi::nat. (cmod (v $ i))⇧2 * x⇧2)" 
    by (auto simp: norm_mult power_mult_distrib)
  then have "(∑i<dim_vec v. (cmod (complex_of_real x * v $ i))⇧2) = 
             (∑i<dim_vec v. (cmod (v $ i))⇧2 * x⇧2)" by meson
  moreover have "(∑i<dim_vec v. (cmod (v $ i))⇧2 * x⇧2) = x⇧2 * (∑i<dim_vec v. (cmod (v $ i))⇧2)"
    by (metis (no_types) mult.commute sum_distrib_right)
  moreover have "sqrt(x⇧2 * (∑i<dim_vec v. (cmod (v $ i))⇧2)) = 
                 sqrt(x⇧2) * sqrt (∑i<dim_vec v. (cmod (v $ i))⇧2)" 
    using real_sqrt_mult by blast
  ultimately show ?thesis
    by(simp add: cpx_vec_length_def assms)
qed
(* A state: a complex matrix v with exactly one column and 2^n rows, whose
   (single) column has norm 1. *)
locale state =
  fixes n:: nat and v:: "complex mat"
  assumes is_column [simp]: "dim_col v = 1"
    and dim_row [simp]: "dim_row v = 2^n"
    and is_normal [simp]: "∥col v 0∥ = 1"
text‹ 
In the locale above, the natural number n can be read as the number of qubits: it determines the
dimension @{text "2⇧n"} of the complex vector space whose elements of norm 1 we call states.
›
(* A standard basis vector of dimension 2^n has the right dimension and norm 1.
   Only v is bound in the set comprehension. This keeps the dimension 2^n tied to the
   lemma's own n. Binding n there as well would shadow it and weaken the claim to
   "the dimension is some power of two". *)
lemma unit_vec_of_right_length_is_state [simp]:
  assumes "i < 2^n"
  shows "unit_vec (2^n) i ∈ {v| v::complex vec. dim_vec v = 2^n ∧ ∥v∥ = 1}"
proof-
  have "dim_vec (unit_vec (2^n) i) = 2^n" 
    by simp
  moreover have "∥unit_vec (2^n) i∥ = 1"
    using assms by simp
  ultimately show ?thesis 
    by simp
qed
(* The set of complex vectors of dimension 2^n and norm 1. *)
definition state_qbit :: "nat ⇒ complex vec set" where
"state_qbit n ≡ {v| v:: complex vec. dim_vec v = 2^n ∧ ∥v∥ = 1}"
(* The column of a state in the sense of the locale lies in state_qbit n. *)
lemma (in state) state_to_state_qbit [simp]:
  shows "col v 0 ∈ state_qbit n"
  using state_def state_qbit_def by simp
subsection "The Hermitian Conjugation"
text ‹The Hermitian conjugate of a complex matrix is the complex conjugate of its transpose. ›
(* Conjugate transpose: entry (i,j) of M⇧† is cnj (M $$ (j,i)), and the row and
   column counts of M are swapped. *)
definition dagger :: "complex mat ⇒ complex mat" (‹_⇧†›) where
  "M⇧† ≡ mat (dim_col M) (dim_row M) (λ(i,j). cnj(M $$ (j,i)))"
text ‹We introduce the type of complex square matrices.›
(* Complex matrices satisfying square_mat. The type is non-empty because every
   identity matrix 1⇩m n is square. *)
typedef cpx_sqr_mat = "{M | M::complex mat. square_mat M}"
proof-
  have "square_mat (1⇩m n)" for n
    using one_mat_def by simp
  thus ?thesis by blast
qed
(* Projection from a square matrix to the underlying complex matrix; it is
   registered as a coercion below. *)
definition cpx_sqr_mat_to_cpx_mat :: "cpx_sqr_mat => complex mat" where
"cpx_sqr_mat_to_cpx_mat M ≡ Rep_cpx_sqr_mat M"
text ‹
We introduce a coercion from the type of complex square matrices to the type of complex 
matrices.
›
declare [[coercion cpx_sqr_mat_to_cpx_mat]]
(* Dimensions, columns and rows of M⇧†, read off from dagger_def. *)
lemma dim_row_of_dagger [simp]:
  "dim_row (M⇧†) = dim_col M"
  using dagger_def by simp
lemma dim_col_of_dagger [simp]:
  "dim_col (M⇧†) = dim_row M"
  using dagger_def by simp
(* Column j of M⇧† is the conjugate of row j of M. *)
lemma col_of_dagger [simp]:
  assumes "j < dim_row M"
  shows "col (M⇧†) j = vec (dim_col M) (λi. cnj (M $$ (j,i)))"
  using assms col_def dagger_def by simp
(* Row i of M⇧† is the conjugate of column i of M. *)
lemma row_of_dagger [simp]:
  assumes "i < dim_col M"
  shows "row (M⇧†) i = vec (dim_row M) (λj. cnj (M $$ (j,i)))"
  using assms row_def dagger_def by simp
(* Taking the Hermitian conjugate twice returns the original matrix. *)
lemma dagger_of_dagger_is_id:
  fixes M :: "complex Matrix.mat"
  shows "(M⇧†)⇧† = M"
proof
  show "dim_row ((M⇧†)⇧†) = dim_row M" by simp
  show "dim_col ((M⇧†)⇧†) = dim_col M" by simp
  fix i j assume a0:"i < dim_row M" and a1:"j < dim_col M"
  (* entrywise: the two conjugations and the two index swaps cancel *)
  show "(M⇧†)⇧† $$ (i,j) = M $$ (i,j)"
    using dagger_def a0 a1 by auto
qed
(* For M of type cpx_sqr_mat, used as a complex mat through the coercion, M⇧† is
   square as well. *)
lemma dagger_of_sqr_is_sqr [simp]:
  "square_mat ((M::cpx_sqr_mat)⇧†)"
proof-
  have "square_mat M"
    using cpx_sqr_mat_to_cpx_mat_def Rep_cpx_sqr_mat by simp
  then have "dim_row M = dim_col M" by simp
  then have "dim_col (M⇧†) = dim_row (M⇧†)" by simp
  thus "square_mat (M⇧†)" by simp
qed
(* The identity matrix is its own Hermitian conjugate. *)
lemma dagger_of_id_is_id [simp]:
  "(1⇩m n)⇧† = 1⇩m n"
  using dagger_def one_mat_def by auto
subsection "Unitary Matrices and Quantum Gates"
(* M is unitary when M⇧† is both a left and a right inverse of M, using identity
   matrices of the appropriate sizes. *)
definition unitary :: "complex mat ⇒ bool" where
"unitary M ≡ (M⇧†) * M = 1⇩m (dim_col M) ∧ M * (M⇧†) = 1⇩m (dim_row M)"
(* Identity matrices are unitary. *)
lemma id_is_unitary [simp]:
  "unitary (1⇩m n)"
  by (simp add: unitary_def)
(* A gate: a square, unitary complex matrix with 2^n rows. *)
locale gate =
  fixes n:: nat and A:: "complex mat"
  assumes dim_row [simp]: "dim_row A = 2^n"
    and square_mat [simp]: "square_mat A"
    and unitary [simp]: "unitary A"
text ‹
We prove that a quantum gate is invertible and its inverse is given by its Hermitian conjugate.
›
(* If M is unitary, M⇧† is a right inverse of M ... *)
lemma mat_unitary_mat [intro]:
  assumes "unitary M"
  shows "inverts_mat M (M⇧†)"
  using assms by (simp add: unitary_def inverts_mat_def)
(* ... and also a left inverse. *)
lemma unitary_mat_mat [intro]:
  assumes "unitary M"
  shows "inverts_mat (M⇧†) M"
  using assms by (simp add: unitary_def inverts_mat_def)
(* Every gate is invertible. Presumably blast uses the [intro] rules above to
   supply A⇧† as the inverse required by invertible_mat_def. *)
lemma (in gate) gate_is_inv:
  "invertible_mat A"
  using square_mat unitary invertible_mat_def by blast
subsection "Relations Between Complex Conjugation, Hermitian Conjugation, Transposition and Unitarity"
(* Postfix notation M⇧t for the transpose of M. *)
notation transpose_mat (‹(_⇧t)›)
(* Column i of the transpose is row i of the original matrix.
   Note: the lemma name is spelled "tranpose". *)
lemma col_tranpose [simp]:
  assumes "dim_row M = n" and "i < n"
  shows "col (M⇧t) i = row M i"
proof
  show "dim_vec (col (M⇧t) i) = dim_vec (row M i)"
    by (simp add: row_def col_def transpose_mat_def)
next
  show "⋀j. j < dim_vec (row M i) ⟹ col M⇧t i $ j = row M i $ j"
    using assms by (simp add: transpose_mat_def)
qed
(* Row i of the transpose is column i of the original matrix. *)
lemma row_transpose [simp]:
  assumes "dim_col M = n" and "i < n"
  shows "row (M⇧t) i = col M i"
  using assms by simp
(* Entrywise complex conjugation M⇧⋆; the dimensions of M are unchanged. *)
definition cpx_mat_cnj :: "complex mat ⇒ complex mat" (‹(_⇧⋆)›) where
"cpx_mat_cnj M ≡ mat (dim_row M) (dim_col M) (λ(i,j). cnj (M $$ (i,j)))"
(* The identity matrix is real, so conjugation fixes it. *)
lemma cpx_mat_cnj_id [simp]:
  "(1⇩m n)⇧⋆ = 1⇩m n" 
  by (auto simp: cpx_mat_cnj_def)
(* Conjugation is an involution. *)
lemma cpx_mat_cnj_cnj [simp]:
  "(M⇧⋆)⇧⋆ = M"
  by (auto simp: cpx_mat_cnj_def)
(* Dimensions of a product of conjugated matrices. *)
lemma dim_row_of_cjn_prod [simp]: 
  "dim_row ((M⇧⋆) * (N⇧⋆)) = dim_row M"
  by (simp add: cpx_mat_cnj_def)
lemma dim_col_of_cjn_prod [simp]: 
  "dim_col ((M⇧⋆) * (N⇧⋆)) = dim_col N"
  by (simp add: cpx_mat_cnj_def)
(* Conjugation distributes over a product of compatible matrices, keeping the order
   of the factors. *)
lemma cpx_mat_cnj_prod:
  assumes "dim_col M = dim_row N"
  shows "(M * N)⇧⋆ = (M⇧⋆) * (N⇧⋆)"
proof
  show "dim_row (M * N)⇧⋆ = dim_row ((M⇧⋆) * (N⇧⋆))" 
    by (simp add: cpx_mat_cnj_def)
next
  show "dim_col ((M * N)⇧⋆) = dim_col ((M⇧⋆) * (N⇧⋆))" 
    by (simp add: cpx_mat_cnj_def)
next 
  fix i j::nat
  assume a1:"i < dim_row ((M⇧⋆) * (N⇧⋆))" and a2:"j < dim_col ((M⇧⋆) * (N⇧⋆))"
  (* write both (i,j) entries as explicit sums over k < dim_row N and compare *)
  then have "(M * N)⇧⋆ $$ (i,j) = cnj (∑k<(dim_row N). M $$ (i,k) * N $$ (k,j))"
    using assms cpx_mat_cnj_def index_mat times_mat_def scalar_prod_def row_def col_def 
dim_row_of_cjn_prod dim_col_of_cjn_prod
    by simp
  also have "… = (∑k<(dim_row N). cnj(M $$ (i,k)) * cnj(N $$ (k,j)))" by simp
  also have "((M⇧⋆) * (N⇧⋆)) $$ (i,j) = 
    (∑k<(dim_row N). cnj(M $$ (i,k)) * cnj(N $$ (k,j)))"
    using assms a1 a2 cpx_mat_cnj_def index_mat times_mat_def scalar_prod_def row_def col_def
    by simp
  finally show "(M * N)⇧⋆ $$ (i, j) = ((M⇧⋆) * (N⇧⋆)) $$ (i, j)" by simp
qed
(* The transpose reverses the order of a product of compatible matrices; the proof
   is entrywise. *)
lemma transpose_of_prod:
  fixes M N::"complex Matrix.mat"
  assumes "dim_col M = dim_row N"
  shows "(M * N)⇧t = N⇧t * (M⇧t)"
proof
  fix i j::nat
  assume a0: "i < dim_row (N⇧t * (M⇧t))" and a1: "j < dim_col (N⇧t * (M⇧t))"  
  then have "(M * N)⇧t $$ (i,j) = (M * N) $$ (j,i)" by auto
  also have "... = (∑k<dim_row M⇧t.  M $$ (j,k) * N $$ (k,i))"
    using assms a0 a1 by auto
  also have "... = (∑k<dim_row M⇧t. N $$ (k,i) * M $$ (j,k))"
   by (simp add: semiring_normalization_rules(7))
  also have "... = (∑k<dim_row M⇧t. ((N⇧t) $$ (i,k)) * (M⇧t) $$ (k,j))" 
    using assms a0 a1 by auto
  finally show "((M * N)⇧t) $$ (i,j) = (N⇧t * (M⇧t)) $$ (i,j)" 
    using assms a0 a1 by auto
next
  show "dim_row ((M * N)⇧t) = dim_row (N⇧t * (M⇧t))" by auto
next
  show "dim_col ((M * N)⇧t) = dim_col (N⇧t * (M⇧t))" by auto
qed
(* Transposing and then conjugating gives the Hermitian conjugate. *)
lemma transpose_cnj_is_dagger [simp]:
  "(M⇧t)⇧⋆ = (M⇧†)"
proof
  show f1:"dim_row ((M⇧t)⇧⋆) = dim_row (M⇧†)"
    by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
next
  show f2:"dim_col ((M⇧t)⇧⋆) = dim_col (M⇧†)" 
    by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
next
  fix i j::nat
  assume "i < dim_row M⇧†" and "j < dim_col M⇧†"
  then show "M⇧t⇧⋆ $$ (i, j) = M⇧† $$ (i, j)" 
    by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
qed
(* Conjugating and then transposing also gives the Hermitian conjugate. *)
lemma cnj_transpose_is_dagger [simp]:
  "(M⇧⋆)⇧t = (M⇧†)"
proof
  show "dim_row ((M⇧⋆)⇧t) = dim_row (M⇧†)" 
    by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
next
  show "dim_col ((M⇧⋆)⇧t) = dim_col (M⇧†)" 
    by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
next
  fix i j::nat
  assume "i < dim_row M⇧†" and "j < dim_col M⇧†"
  then show "M⇧⋆⇧t $$ (i, j) = M⇧† $$ (i, j)" 
    by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
qed
(* Consequently, the Hermitian conjugate of a transpose is the entrywise conjugate. *)
lemma dagger_of_transpose_is_cnj [simp]:
  "(M⇧t)⇧† = (M⇧⋆)"
  by (metis transpose_transpose transpose_cnj_is_dagger)
(* The Hermitian conjugate reverses the order of a product. The proof writes the
   dagger as conjugate-then-transpose and combines cpx_mat_cnj_prod with
   transpose_of_prod. *)
lemma dagger_of_prod:
  fixes M N::"complex Matrix.mat"
  assumes "dim_col M = dim_row N"
  shows "(M * N)⇧† = N⇧† * (M⇧†)"
proof-
  have "(M * N)⇧† = ((M * N)⇧⋆)⇧t" by auto
  also have "... = ((M⇧⋆) * (N⇧⋆))⇧t" using assms cpx_mat_cnj_prod by auto
  also have "... = (N⇧⋆)⇧t * ((M⇧⋆)⇧t)" using assms transpose_of_prod 
    by (metis cnj_transpose_is_dagger dim_col_of_dagger dim_row_of_dagger index_transpose_mat(2) index_transpose_mat(3))
  finally show "(M * N)⇧† = N⇧† * (M⇧†)" by auto
qed
text ‹The product of two quantum gates is a quantum gate.›
(* The product of two gates of the same size is a gate: it has 2^n rows, is square,
   and is unitary. Left-unitarity is computed directly. Right-unitarity then follows
   from mat_mult_left_right_inverse for square matrices. *)
lemma prod_of_gate_is_gate: 
  assumes "gate n G1" and "gate n G2"
  shows "gate n (G1 * G2)"
proof
  show "dim_row (G1 * G2) = 2^n" using assms by (simp add: gate_def)
next
  show "square_mat (G1 * G2)" 
    using assms gate.dim_row gate.square_mat by simp
next
  show "unitary (G1 * G2)" 
  proof-
    (* left inverse: (G1 * G2)⇧† * (G1 * G2) reduces to G2⇧† * G2, which is the identity *)
    have "((G1 * G2)⇧†) * (G1 * G2) = 1⇩m (dim_col (G1 * G2))" 
    proof-
      (* every factor is a 2^n by 2^n matrix; the associativity steps below rely on this *)
      have f0: "G1 ∈ carrier_mat (2^n) (2^n) ∧ G2 ∈ carrier_mat (2^n) (2^n)
              ∧ G1⇧† ∈ carrier_mat (2^n) (2^n) ∧ G2⇧† ∈ carrier_mat (2^n) (2^n)
              ∧ G1 * G2 ∈ carrier_mat (2^n) (2^n)" 
        using assms gate.dim_row gate.square_mat by auto
      have "((G1 * G2)⇧†) * (G1 * G2) = ((G2⇧†) * (G1⇧†)) * (G1 * G2)" 
        using assms dagger_of_prod gate.dim_row gate.square_mat by simp
      also have "... = (G2⇧†) * ((G1⇧†) * (G1 * G2))" 
        using assms f0 by auto
      also have "... = (G2⇧†) * (((G1⇧†) * G1) * G2)" 
        using assms f0 by auto
      also have "... = (G2⇧†) * ((1⇩m (dim_col G1)) * G2)" 
        using gate.unitary[of n G1] assms unitary_def[of G1] by simp
      also have "... = (G2⇧†) * ((1⇩m (dim_col G2)) * G2)" 
        using assms f0 by (metis carrier_matD(2))
      also have "... = (G2⇧†) * G2" 
        using f0 by (metis carrier_matD(2) left_mult_one_mat)
      finally show "((G1 * G2)⇧†) * (G1 * G2) = 1⇩m (dim_col (G1 * G2))" 
        using assms gate.unitary unitary_def by simp
    qed
    moreover have "(G1 * G2) * ((G1 * G2)⇧†) = 1⇩m (dim_row (G1 * G2))"
      using assms calculation
      by (smt (verit) carrier_matI dim_col_of_dagger dim_row_of_dagger gate.dim_row gate.square_mat index_mult_mat(2) index_mult_mat(3) 
          mat_mult_left_right_inverse square_mat.elims(2))
    ultimately show ?thesis using unitary_def by simp
  qed
qed
(* For unitary U, (U⇧t)⇧† is a left inverse of U⇧t. It is obtained by conjugating
   U * U⇧† = 1, since (U⇧t)⇧⋆ = U⇧†. *)
lemma left_inv_of_unitary_transpose [simp]:
  assumes "unitary U"
  shows "(U⇧t)⇧† * (U⇧t) =  1⇩m(dim_row U)"
proof -
  have "dim_col U = dim_row ((U⇧t)⇧⋆)" by simp
  then have "(U * ((U⇧t)⇧⋆))⇧⋆ = (U⇧⋆) * (U⇧t)"
    using cpx_mat_cnj_prod cpx_mat_cnj_cnj by presburger
  also have "… = (U⇧t)⇧† * (U⇧t)" by simp
  finally show ?thesis 
    using assms by (metis transpose_cnj_is_dagger cpx_mat_cnj_id unitary_def)
qed
(* For unitary U, (U⇧t)⇧† is also a right inverse of U⇧t, obtained by conjugating
   U⇧† * U = 1. *)
lemma right_inv_of_unitary_transpose [simp]:
  assumes "unitary U"
  shows "U⇧t * ((U⇧t)⇧†) = 1⇩m(dim_col U)"
proof -
  have "dim_col ((U⇧t)⇧⋆) = dim_row U" by simp
  then have "U⇧t * ((U⇧t)⇧†) = (((U⇧t)⇧⋆ * U)⇧⋆)"
    using cpx_mat_cnj_cnj cpx_mat_cnj_prod dagger_of_transpose_is_cnj by presburger
  also have "… = (U⇧† * U)⇧⋆" by simp
  finally show ?thesis
    using assms by (metis cpx_mat_cnj_id unitary_def)
qed
(* The transpose of a unitary matrix is unitary. *)
lemma transpose_of_unitary_is_unitary [simp]:
  assumes "unitary U"
  shows "unitary (U⇧t)" 
  using unitary_def assms left_inv_of_unitary_transpose right_inv_of_unitary_transpose by simp
subsection "The Inner Product"
text ‹We introduce a coercion between complex vectors and (column) complex matrices.›
(* |v⟩: the one-column matrix whose entries are those of v. *)
definition ket_vec :: "complex vec ⇒ complex mat" (‹|_⟩›) where
"|v⟩ ≡ mat (dim_vec v) 1 (λ(i,j). v $ i)"
lemma ket_vec_index [simp]:
  assumes "i < dim_vec v"
  shows "|v⟩ $$ (i,0) = v $ i"
  using assms ket_vec_def by simp
(* Taking the ket and then its first column gives back the vector. *)
lemma ket_vec_col [simp]:
  "col |v⟩ 0 = v"
  by (auto simp: col_def ket_vec_def)
(* The ket commutes with scalar multiplication. *)
lemma smult_ket_vec [simp]:
  "|x ⋅⇩v v⟩ = x ⋅⇩m |v⟩"
  by (auto simp: ket_vec_def)
(* Matrix-side version of smult_vec_length. *)
lemma smult_vec_length_bis [simp]:
  assumes "x ≥ 0"
  shows "∥col (complex_of_real(x) ⋅⇩m |v⟩) 0∥ = x * ∥v∥"
  using assms smult_ket_vec smult_vec_length ket_vec_col by metis
(* From here on, complex vectors are coerced to kets where a matrix is expected. *)
declare [[coercion ket_vec]]
(* The one-row matrix whose entries are those of v. *)
definition row_vec :: "complex vec ⇒ complex mat" where
"row_vec v ≡ mat 1 (dim_vec v) (λ(i,j). v $ j)" 
(* The one-row matrix of the conjugated entries of v. *)
definition bra_vec :: "complex vec ⇒ complex mat" where
"bra_vec v ≡ (row_vec v)⇧⋆"
lemma row_bra_vec [simp]:
  "row (bra_vec v) 0 = vec (dim_vec v) (λi. cnj(v $ i))"
  by (auto simp: row_def bra_vec_def cpx_mat_cnj_def row_vec_def)
text ‹We introduce a definition called @{term "bra"} to see a (column) vector as a row matrix whose entries are complex conjugated.›
(* ⟨v|: a one-row matrix whose j-th entry is cnj (v $$ (j,0)). *)
definition bra :: "complex mat ⇒ complex mat" (‹⟨_|›) where
"⟨v| ≡ mat 1 (dim_row v) (λ(i,j). cnj(v $$ (j,i)))"
text ‹The relation between @{term "bra"}, @{term "bra_vec"} and @{term "ket_vec"} is given as follows.›
lemma bra_bra_vec [simp]:
  "bra (ket_vec v) = bra_vec v"
  by (auto simp: bra_def ket_vec_def bra_vec_def cpx_mat_cnj_def row_vec_def)
(* The row of ⟨v| (v coerced to a ket) is the conjugate of v. *)
lemma row_bra [simp]:
  fixes v::"complex vec"
  shows "row ⟨v| 0 = vec (dim_vec v) (λi. cnj (v $ i))" by simp
text ‹We introduce the inner product of two complex vectors in @{text "ℂ⇧n"}.›
(* ⟨u|v⟩ is the sum of cnj (u $ i) * v $ i. The index range is that of v, so the
   lemmas below generally assume dim_vec u = dim_vec v. *)
definition inner_prod :: "complex vec ⇒ complex vec ⇒ complex" (‹⟨_|_⟩›) where
"inner_prod u v ≡ ∑ i ∈ {0..< dim_vec v}. cnj(u $ i) * (v $ i)"
(* The inner product as a scalar product of the conjugated row with v. *)
lemma inner_prod_with_row_bra_vec [simp]:
  assumes "dim_vec u = dim_vec v"
  shows "⟨u|v⟩ = row (bra_vec u) 0 ∙ v"
  using assms inner_prod_def scalar_prod_def row_bra_vec index_vec
  by (smt (verit) lessThan_atLeast0 lessThan_iff sum.cong)
lemma inner_prod_with_row_bra_vec_col_ket_vec [simp]:
  assumes "dim_vec u = dim_vec v"
  shows "⟨u|v⟩ = (row ⟨u| 0) ∙ (col |v⟩ 0)"
  using assms by (simp add: inner_prod_def scalar_prod_def)
(* The inner product as the single entry of the matrix product bra times ket. *)
lemma inner_prod_with_times_mat [simp]:
  assumes "dim_vec u = dim_vec v"
  shows "⟨u|v⟩ = (⟨u| * |v⟩) $$ (0,0)"
  using assms inner_prod_with_row_bra_vec_col_ket_vec 
  by (simp add: inner_prod_def times_mat_def ket_vec_def bra_def)
(* Distinct standard basis vectors are orthogonal. *)
lemma orthogonal_unit_vec [simp]:
  assumes "i < n" and "j < n" and "i ≠ j"
  shows "⟨unit_vec n i|unit_vec n j⟩ = 0"
proof-
  have "⟨unit_vec n i|unit_vec n j⟩ = unit_vec n i ∙ unit_vec n j"
    using assms unit_vec_def inner_prod_def scalar_prod_def
    by (smt (verit) complex_cnj_zero index_unit_vec(3) index_vec inner_prod_with_row_bra_vec row_bra_vec 
        scalar_prod_right_unit)
  thus ?thesis
    using assms scalar_prod_def unit_vec_def by simp 
qed
text ‹We prove that our inner product is linear in its second argument.›
(* Entry j of a linear combination of two vectors of equal dimension. *)
lemma vec_index_is_linear [simp]:
  assumes "dim_vec u = dim_vec v" and "j < dim_vec u"
  shows "(k ⋅⇩v u + l ⋅⇩v v) $ j = k * (u $ j) + l * (v $ j)"
  using assms vec_index_def smult_vec_def plus_vec_def by simp
(* Linearity in the second argument for a combination l 0 ⋅ v 0 + l 1 ⋅ v 1, where
   both v 0 and v 1 have the dimension of u. *)
lemma inner_prod_is_linear [simp]:
  fixes u::"complex vec" and v::"nat ⇒ complex vec" and l::"nat ⇒ complex"
  assumes "∀i∈{0, 1}. dim_vec u = dim_vec (v i)"
  shows "⟨u|l 0 ⋅⇩v v 0 + l 1 ⋅⇩v v 1⟩ = (∑i≤1. l i * ⟨u|v i⟩)"
proof -
  have f1:"dim_vec (l 0 ⋅⇩v v 0 + l 1 ⋅⇩v v 1) = dim_vec u"
    using assms by simp
  then have "⟨u|l 0 ⋅⇩v v 0 + l 1 ⋅⇩v v 1⟩ = (∑i∈{0 ..< dim_vec u}. cnj (u $ i) * ((l 0 ⋅⇩v v 0 + l 1 ⋅⇩v v 1) $ i))"
    by (simp add: inner_prod_def)
  also have "… = (∑i∈{0 ..< dim_vec u}. cnj (u $ i) * (l 0 * v 0 $ i + l 1 * v 1 $ i))"
    using assms by simp
  also have "… = l 0 * (∑i∈{0 ..< dim_vec u}. cnj(u $ i) * (v 0 $ i)) + l 1 * (∑i∈{0 ..< dim_vec u}. cnj(u $ i) * (v 1 $ i))"
    by (auto simp: algebra_simps)
      (simp add: sum.distrib sum_distrib_left)
  also have "… = l 0 * ⟨u|v 0⟩ + l 1 * ⟨u|v 1⟩"
    using assms inner_prod_def by auto
  finally show ?thesis by simp
qed
(* Conjugate symmetry of the inner product. *)
lemma inner_prod_cnj:
  assumes "dim_vec u = dim_vec v"
  shows "⟨v|u⟩ = cnj (⟨u|v⟩)"
  by (simp add: assms inner_prod_def algebra_simps)
(* ⟨u|u⟩ is real: its imaginary part vanishes. *)
lemma inner_prod_with_itself_Im [simp]:
  "Im (⟨u|u⟩) = 0"
  using inner_prod_cnj by (metis Reals_cnj_iff complex_is_Real_iff)
lemma inner_prod_with_itself_real [simp]:
  "⟨u|u⟩ ∈ ℝ"
  using inner_prod_with_itself_Im by (simp add: complex_is_Real_iff)
(* The zero vector has zero inner product with itself. *)
lemma inner_prod_with_itself_eq0 [simp]:
  assumes "u = 0⇩v (dim_vec u)"
  shows "⟨u|u⟩ = 0"
  using assms inner_prod_def zero_vec_def
  by (smt (verit) atLeastLessThan_iff complex_cnj_zero index_zero_vec(1) mult_zero_left sum.neutral)
(* Re ⟨u|u⟩ is nonnegative: it is a sum of squares of real and imaginary parts. *)
lemma inner_prod_with_itself_Re:
  "Re (⟨u|u⟩) ≥ 0"
proof -
  have "Re (⟨u|u⟩) = (∑i<dim_vec u. Re (cnj(u $ i) * (u $ i)))"
    by (simp add: inner_prod_def lessThan_atLeast0)
  moreover have "… = (∑i<dim_vec u. (Re (u $ i))⇧2 + (Im (u $ i))⇧2)"
    using complex_mult_cnj
    by (metis (no_types, lifting) Re_complex_of_real semiring_normalization_rules(7))
  ultimately show "Re (⟨u|u⟩) ≥ 0" by (simp add: sum_nonneg)
qed
(* Hence ⟨u|u⟩ is a nonnegative real. *)
lemma inner_prod_with_itself_nonneg_reals:
  fixes u::"complex vec"
  shows "⟨u|u⟩ ∈ nonneg_Reals"
  using inner_prod_with_itself_real inner_prod_with_itself_Re complex_nonneg_Reals_iff 
inner_prod_with_itself_Im by auto
(* For a nonzero vector, Re ⟨u|u⟩ is strictly positive. Some summand is positive
   (at an index where u is nonzero), and every summand is nonnegative. *)
lemma inner_prod_with_itself_Re_non0:
  assumes "u ≠ 0⇩v (dim_vec u)"
  shows "Re (⟨u|u⟩) > 0"
proof -
  obtain i where a1:"i < dim_vec u" and "u $ i ≠ 0"
    using assms zero_vec_def by (metis dim_vec eq_vecI index_zero_vec(1))
  then have f1:"Re (cnj (u $ i) * (u $ i)) > 0"
    by (metis Re_complex_of_real complex_mult_cnj complex_neq_0 mult.commute)
  moreover have f2:"Re (⟨u|u⟩) = (∑i<dim_vec u. Re (cnj(u $ i) * (u $ i)))"
    using inner_prod_def by (simp add: lessThan_atLeast0)
  moreover have f3:"∀i<dim_vec u. Re (cnj(u $ i) * (u $ i)) ≥ 0"
    using complex_mult_cnj by simp
  ultimately show ?thesis
    using a1 inner_prod_def lessThan_iff
    by (metis (no_types, lifting) finite_lessThan sum_pos2)
qed
(* The inner product is definite: a nonzero vector has nonzero ⟨u|u⟩. *)
lemma inner_prod_with_itself_nonneg_reals_non0:
  assumes "u ≠ 0⇩v (dim_vec u)"
  shows "⟨u|u⟩ ≠ 0"
  using assms inner_prod_with_itself_Re_non0 by fastforce
(* The squared norm, coerced to a complex number, equals ⟨v|v⟩. *)
lemma cpx_vec_length_inner_prod [simp]:
  "∥v∥⇧2 = ⟨v|v⟩"
proof -
  have "∥v∥⇧2 = (∑i<dim_vec v. (cmod (v $ i))⇧2)"
    using cpx_vec_length_def complex_of_real_def
    by (metis (no_types, lifting) real_sqrt_power real_sqrt_unique sum_nonneg zero_le_power2)
  also have "… = (∑i<dim_vec v. cnj (v $ i) * (v $ i))"
    using complex_norm_square mult.commute by (smt (verit) of_real_sum sum.cong)
  finally show ?thesis
    using inner_prod_def by (simp add: lessThan_atLeast0)
qed
(* The complex square root of ⟨v|v⟩ is the norm of v. *)
lemma inner_prod_csqrt [simp]:
  "csqrt ⟨v|v⟩ = ∥v∥"
  using inner_prod_with_itself_Re inner_prod_with_itself_Im csqrt_of_real_nonneg cpx_vec_length_def
  by (metis (no_types, lifting) Re_complex_of_real cpx_vec_length_inner_prod real_sqrt_ge_0_iff 
      real_sqrt_unique sum_nonneg zero_le_power2)
subsection "Unitary Matrices and Length-Preservation"
subsubsection "Unitary Matrices are Length-Preserving"
text ‹The bra-vector @{text "⟨A * v|"} is given by @{text "⟨v| * A⇧†"}›
(* The Hermitian conjugate of a ket is the corresponding bra. *)
lemma dagger_of_ket_is_bra:
  fixes v:: "complex vec"
  shows "( |v⟩ )⇧† = ⟨v|"
  by (simp add: bra_def dagger_def ket_vec_def)
(* ⟨A v| = ⟨v| A⇧†, proved entrywise on the single row of both sides. *)
lemma bra_mat_on_vec:
  fixes v::"complex vec" and A::"complex mat"
  assumes "dim_col A = dim_vec v"
  shows "⟨A * v| = ⟨v| * (A⇧†)"
proof
  show "dim_row ⟨A * v| = dim_row (⟨v| * (A⇧†))"
    by (simp add: bra_def times_mat_def)
next
  show "dim_col ⟨A * v| = dim_col (⟨v| * (A⇧†))"
    by (simp add: bra_def times_mat_def)
next
  fix i j::nat
  assume a1:"i < dim_row (⟨v| * (A⇧†))" and a2:"j < dim_col (⟨v| * (A⇧†))" 
  then have "cnj((A * v) $$ (j,0)) = cnj (row A j ∙ v)"
    using bra_def times_mat_def ket_vec_col ket_vec_def by simp
  also have f7:"…= (∑i∈{0 ..< dim_vec v}. cnj(v $ i) * cnj(A $$ (j,i)))"
    by (simp add: row_def scalar_prod_def mult.commute) (use assms in auto)
  moreover have f8:"(row ⟨v| 0) ∙ (col (A⇧†) j) = 
    vec (dim_vec v) (λi. cnj (v $ i)) ∙ vec (dim_col A) (λi. cnj (A $$ (j,i)))"
    using a2 by simp 
  ultimately have "cnj((A * v) $$ (j,0)) = (row ⟨v| 0) ∙ (col (A⇧†) j)"
    using assms scalar_prod_def
    by (smt (verit) dim_vec index_vec lessThan_atLeast0 lessThan_iff sum.cong)
  then have "⟨A * v| $$ (0,j) = (⟨v| * (A⇧†)) $$ (0,j)"
    using bra_def times_mat_def a2 by simp
  thus "⟨A * |v⟩| $$ (i, j) = (⟨v| * (A⇧†)) $$ (i, j)" 
    using a1 by (simp add: times_mat_def bra_def)
qed
(* A * |v⟩ is the ket of its own first column. *)
lemma mat_on_ket:
  fixes v:: "complex vec" and A:: "complex mat"
  assumes "dim_col A = dim_vec v"
  shows "A * |v⟩ = |col (A * v) 0⟩"
  using assms ket_vec_def by auto
(* Matrix form of bra_mat_on_vec: the dagger of A * |v⟩ is ⟨v| * A⇧†. *)
lemma dagger_of_mat_on_ket:
  fixes v:: "complex vec" and A :: "complex mat"
  assumes "dim_col A = dim_vec v"
  shows "(A * |v⟩ )⇧† = ⟨v| * (A⇧†)"
  using assms by (metis bra_mat_on_vec dagger_of_ket_is_bra mat_on_ket)
(* The first column of a matrix, as a vector. *)
definition col_fst :: "'a mat ⇒ 'a vec" where 
  "col_fst A = vec (dim_row A) (λ i. A $$ (i,0))"
(* col_fst coincides with taking column 0. *)
lemma col_fst_is_col [simp]:
  "col_fst M = col M 0"
  by (simp add: col_def col_fst_def)
text ‹
We need to declare @{term "col_fst"} as a coercion from matrices to vectors in order to see a column 
matrix as a vector. 
›
(* Swap coercions: vectors are no longer coerced to kets. Instead, matrices are
   coerced to their first column. *)
declare 
  [[coercion_delete ket_vec]]
  [[coercion col_fst]]
(* Multiplying A by the basis ket |unit_vec n i⟩ extracts column i of A (viewed as a
   vector through the col_fst coercion). *)
lemma unit_vec_to_col:
  assumes "dim_col A = n" and "i < n"
  shows "col A i = A * |unit_vec n i⟩"
proof
  show "dim_vec (col A i) = dim_vec (A * |unit_vec n i⟩)"
    using col_def times_mat_def by simp
next
  fix j::nat
  assume "j < dim_vec (col_fst (A * |unit_vec n i⟩))"
  then show "col A i $ j = (A * |unit_vec n i⟩) $ j"
    using assms times_mat_def ket_vec_def
    by (smt (verit) col_fst_is_col dim_col dim_col_mat(1) index_col index_mult_mat(1) index_mult_mat(2) 
index_row(1) ket_vec_col less_numeral_extra(1) scalar_prod_right_unit)
qed
(* Re-wrapping the first column of A * |v⟩ as a ket gives A * |v⟩ back. *)
lemma mult_ket_vec_is_ket_vec_of_mult:
  fixes A::"complex mat" and v::"complex vec"
  assumes "dim_col A = dim_vec v"
  shows "|A * |v⟩ ⟩ = A * |v⟩"
  using assms ket_vec_def
  by (metis One_nat_def col_fst_is_col dim_col dim_col_mat(1) index_mult_mat(3) ket_vec_col less_Suc0 
mat_col_eqI)
(* Unitary matrices preserve the squared norm. The proof writes ⟨U v|U v⟩ as entry
   (0,0) of ⟨v| * U⇧† * U * |v⟩, reassociates, and cancels U⇧† * U using unitary_def. *)
lemma unitary_is_sq_length_preserving [simp]:
  assumes "unitary U" and "dim_vec v = dim_col U"
  shows "∥U * |v⟩∥⇧2 = ∥v∥⇧2"
proof -
  have "⟨U * |v⟩|U * |v⟩ ⟩ = (⟨|v⟩| * (U⇧†) * |U * |v⟩⟩) $$ (0,0)"
    using assms(2) bra_mat_on_vec
    by (metis inner_prod_with_times_mat mult_ket_vec_is_ket_vec_of_mult)
  then have "⟨U * |v⟩|U * |v⟩ ⟩ = (⟨|v⟩| * (U⇧†) * (U * |v⟩)) $$ (0,0)"
    using assms(2) mult_ket_vec_is_ket_vec_of_mult by simp
  moreover have f1:"dim_col ⟨|v⟩| = dim_vec v"
    using ket_vec_def bra_def by simp
  moreover have "dim_row (U⇧†) = dim_vec v"
    using assms(2) by simp
  ultimately have "⟨U * |v⟩|U * |v⟩ ⟩ = (⟨|v⟩| * ((U⇧†) * U) * |v⟩) $$ (0,0)"
    using assoc_mult_mat
    by(smt (verit, ccfv_threshold) carrier_mat_triv dim_row_mat(1) dagger_def ket_vec_def mat_carrier times_mat_def)
  then have "⟨U * |v⟩|U * |v⟩ ⟩ = (⟨|v⟩| * |v⟩) $$ (0,0)"
    using assms f1 unitary_def by simp
  thus ?thesis
    using cpx_vec_length_inner_prod by(metis Re_complex_of_real inner_prod_with_times_mat)
qed
(* A one-column matrix is the ket of its first column. *)
lemma col_ket_vec [simp]:
  assumes "dim_col M = 1"
  shows "|col M 0⟩ = M"
  using eq_matI assms ket_vec_def by auto
(* For n = 1, re-wrapping the column of a state as a ket yields a state again. *)
lemma state_col_ket_vec:
  assumes "state 1 v"
  shows "state 1 |col v 0⟩"
  using assms by (simp add: state_def)
lemma col_ket_vec_index [simp]:
  assumes "i < dim_row v"
  shows "|col v 0⟩ $$ (i,0) = v $$ (i,0)"
  using assms ket_vec_def by (simp add: col_def)
lemma col_index_of_mat_col [simp]:
  assumes "dim_col v = 1" and "i < dim_row v"
  shows "col v 0 $ i = v $$ (i,0)"
  using assms by simp
(* Version of unitary_is_sq_length_preserving for one-column matrices v. *)
lemma unitary_is_sq_length_preserving_bis [simp]:
  assumes "unitary U" and "dim_row v = dim_col U" and "dim_col v = 1"
  shows "∥col (U * v) 0∥⇧2 = ∥col v 0∥⇧2"
proof -
  have "dim_vec (col v 0) = dim_col U"
    using assms(2) by simp
  then have "∥col_fst (U * |col v 0⟩)∥⇧2 = ∥col v 0∥⇧2"
    using unitary_is_sq_length_preserving[of "U" "col v 0"] assms(1) by simp
  thus ?thesis
    using assms(3) by simp
qed
text ‹ 
A unitary matrix is length-preserving, i.e. it acts on a vector to produce another vector of the 
same length. 
›
(* Norm (not squared) versions, derived from the squared versions through
   cpx_vec_length_inner_prod and inner_prod_csqrt. *)
lemma unitary_is_length_preserving_bis [simp]:
  fixes U::"complex mat" and v::"complex mat"
  assumes "unitary U" and "dim_row v = dim_col U" and "dim_col v = 1"
  shows "∥col (U * v) 0∥ = ∥col v 0∥"
  using assms unitary_is_sq_length_preserving_bis
  by (metis cpx_vec_length_inner_prod inner_prod_csqrt of_real_hom.injectivity)
lemma unitary_is_length_preserving [simp]:
  fixes U:: "complex mat" and v:: "complex vec"
  assumes "unitary U" and "dim_vec v = dim_col U"
  shows "∥U * |v⟩∥ = ∥v∥"
  using assms unitary_is_sq_length_preserving
  by (metis cpx_vec_length_inner_prod inner_prod_csqrt of_real_hom.injectivity)
subsubsection "Length-Preserving Matrices are Unitary"
(* For a square B of matching size, a one-sided inverse is two-sided: from A * B = 1
   we get B * A = 1. The proof first shows det B ≠ 0, since otherwise some nonzero v
   with B v = 0 would give v = A (B v) = 0. It then identifies A with
   (1/det B) ⋅ adj_mat B, which is a right inverse of B. *)
lemma inverts_mat_sym:
  fixes A B:: "complex mat"
  assumes "inverts_mat A B" and "dim_row B = dim_col A" and "square_mat B"
  shows "inverts_mat B A"
proof-
  define n where d0:"n = dim_row B"
  have "A * B = 1⇩m (dim_row A)" using assms(1) inverts_mat_def by auto
  moreover have "dim_col B = dim_col (A * B)" using times_mat_def by simp
  ultimately have "dim_col B = dim_row A" by simp
  then have c0:"A ∈ carrier_mat n n" using assms(2,3) d0 by auto
  have c1:"B ∈ carrier_mat n n" using assms(3) d0 by auto
  have f0:"A * B = 1⇩m n" using inverts_mat_def c0 c1 assms(1) by auto
  (* B is nonsingular: argue by contradiction via a nonzero kernel vector *)
  have f1:"det B ≠ 0"
  proof
    assume "det B = 0"
    then have "∃v. v ∈ carrier_vec n ∧ v ≠ 0⇩v n ∧ B *⇩v v = 0⇩v n"
      using det_0_iff_vec_prod_zero assms(3) c1 by blast
    then obtain v where d1:"v ∈ carrier_vec n ∧ v ≠ 0⇩v n ∧ B *⇩v v = 0⇩v n" by auto
    then have d2:"dim_vec v = n" by simp
    have "B * |v⟩ = |0⇩v n⟩"
    proof
      show "dim_row (B * |v⟩) = dim_row |0⇩v n⟩" using ket_vec_def d0 by simp
    next
      show "dim_col (B * |v⟩) = dim_col |0⇩v n⟩" using ket_vec_def d0 by simp
    next
      fix i j assume "i < dim_row |0⇩v n⟩" and "j < dim_col |0⇩v n⟩"
      then have f2:"i < n ∧ j = 0" using ket_vec_def by simp
      moreover have "vec (dim_row B) (($) v) = v" using d0 d1 by auto
      moreover have "(B *⇩v v) $ i = (∑ia = 0..<dim_row B. row B i $ ia * v $ ia)"
        using d0 d2 f2 by (auto simp add: scalar_prod_def)
      ultimately show "(B * |v⟩) $$ (i, j) = |0⇩v n⟩ $$ (i, j)"
        using ket_vec_def d0 d1 times_mat_def mult_mat_vec_def by (auto simp add: scalar_prod_def)
    qed
    moreover have "|v⟩ ∈ carrier_mat n 1" using d2 ket_vec_def by simp
    ultimately have "(A * B) * |v⟩ = A * |0⇩v n⟩" using c0 c1 by simp
    then have f3:"|v⟩ = A * |0⇩v n⟩" using d2 f0 ket_vec_def by auto
    have "v = 0⇩v n"
    proof
      show "dim_vec v = dim_vec (0⇩v n)" using d2 by simp
    next
      fix i assume f4:"i < dim_vec (0⇩v n)"
      then have "|v⟩ $$ (i,0) = v $ i" using d2 ket_vec_def by simp
      moreover have "(A * |0⇩v n⟩) $$ (i, 0) = 0"
        using ket_vec_def times_mat_def scalar_prod_def f4 c0 by auto
      ultimately show "v $ i = 0⇩v n $ i" using f3 f4 by simp
    qed
    then show False using d1 by simp
  qed
  (* (1/det B) ⋅ adj_mat B is a right inverse of B, and A must equal it *)
  have f5:"adj_mat B ∈ carrier_mat n n ∧ B * adj_mat B = det B ⋅⇩m 1⇩m n" using c1 adj_mat by auto
  then have c2:"((1/det B) ⋅⇩m adj_mat B) ∈ carrier_mat n n" by simp
  have f6:"B * ((1/det B) ⋅⇩m adj_mat B) = 1⇩m n" using c1 f1 f5 mult_smult_distrib[of "B"] by auto
  then have "A = (A * B) * ((1/det B) ⋅⇩m adj_mat B)" using c0 c1 c2 by simp
  then have "A = (1/det B) ⋅⇩m adj_mat B" using f0 c2 by auto
  then show ?thesis using c0 c1 f6 inverts_mat_def by auto
qed
(* For distinct indices i, j < n, the squared norm of unit_vec n i + c ⋅ unit_vec n j
   is 1 + cnj c * c. *)
lemma sum_of_unit_vec_length:
  fixes i j n:: nat and c:: complex
  assumes "i < n" and "j < n" and "i ≠ j"
  shows "∥unit_vec n i + c ⋅⇩v unit_vec n j∥⇧2 = 1 + cnj(c) * c"
proof-
  define v where d0:"v = unit_vec n i + c ⋅⇩v unit_vec n j"
  have "∀k<n. v $ k = (if k = i then 1 else (if k = j then c else 0))"
    using d0 assms(1,2,3) by auto
  then have "∀k<n. cnj (v $ k) * v $ k = (if k = i then 1 else 0) + (if k = j then cnj(c) * c else 0)"
    using assms(3) by auto
  moreover have "∥v∥⇧2 = (∑k = 0..<n. cnj (v $ k) * v $ k)"
    using d0 assms cpx_vec_length_inner_prod inner_prod_def by simp
  ultimately show ?thesis
    using d0 assms by (auto simp add: sum.distrib)
qed
(* A applied to unit_vec n i + c ⋅ unit_vec n j gives col A i + c ⋅ col A j (the
   matrix product is read as a vector through the col_fst coercion). *)
lemma sum_of_unit_vec_to_col:
  assumes "dim_col A = n" and "i < n" and "j < n"
  shows "col A i + c ⋅⇩v col A j = A * |unit_vec n i + c ⋅⇩v unit_vec n j⟩"
proof
  show "dim_vec (col A i + c ⋅⇩v col A j) = dim_vec (col_fst (A * |unit_vec n i + c ⋅⇩v unit_vec n j⟩))"
    using assms(1) by auto
next
  fix k assume "k < dim_vec (col_fst (A * |unit_vec n i + c ⋅⇩v unit_vec n j⟩))"
  then have f0:"k < dim_row A" using assms(1) by auto
  have "(col A i + c ⋅⇩v col A j) $ k = A $$ (k, i) + c * A $$ (k, j)"
    using f0 assms(1-3) by auto
  moreover have "(∑x<n. A $$ (k, x) * ((if x = i then 1 else 0) + c * (if x = j then 1 else 0))) = 
                 (∑x<n. A $$ (k, x) * (if x = i then 1 else 0)) + 
                 (∑x<n. A $$ (k, x) * c * (if x = j then 1 else 0))"
    by (auto simp add: sum.distrib algebra_simps)
  moreover have "∀x<n. A $$ (k, x) * (if x = i then 1 else 0) = (if x = i then A $$ (k, x) else 0)"
    by simp
  moreover have "∀x<n. A $$ (k, x) * c * (if x = j then 1 else 0) = (if x = j then A $$ (k, x) * c else 0)"
    by simp
  ultimately show "(col A i + c ⋅⇩v col A j) $ k = col_fst (A * |unit_vec n i + c ⋅⇩v unit_vec n j⟩) $ k"
    using f0 assms(1-3) times_mat_def scalar_prod_def ket_vec_def by auto
qed
(* Expansion of the inner product of two linear combinations of vectors of a common
   dimension n. Coefficients on the left come out conjugated and coefficients on the
   right come out unchanged. Proof: linearity in the second argument, then conjugate
   symmetry to handle the first argument. *)
lemma inner_prod_is_sesquilinear:
  fixes u1 u2 v1 v2:: "complex vec" and c1 c2 c3 c4:: complex and n:: nat
  assumes "dim_vec u1 = n" and "dim_vec u2 = n" and "dim_vec v1 = n" and "dim_vec v2 = n"
  shows "⟨c1 ⋅⇩v u1 + c2 ⋅⇩v u2|c3 ⋅⇩v v1 + c4 ⋅⇩v v2⟩ = cnj (c1) * c3 * ⟨u1|v1⟩ + cnj (c2) * c3 * ⟨u2|v1⟩ + 
                                                 cnj (c1) * c4 * ⟨u1|v2⟩ + cnj (c2) * c4 * ⟨u2|v2⟩"
proof-
  have "⟨c1 ⋅⇩v u1 + c2 ⋅⇩v u2|c3 ⋅⇩v v1 + c4 ⋅⇩v v2⟩ = c3 * ⟨c1 ⋅⇩v u1 + c2 ⋅⇩v u2|v1⟩ + c4 * ⟨c1 ⋅⇩v u1 + c2 ⋅⇩v u2|v2⟩"
    using inner_prod_is_linear[of "c1 ⋅⇩v u1 + c2 ⋅⇩v u2" "λi. if i = 0 then v1 else v2" 
                                  "λi. if i = 0 then c3 else c4"] assms
    by simp
  also have "... = c3 * cnj(⟨v1|c1 ⋅⇩v u1 + c2 ⋅⇩v u2⟩) + c4 * cnj(⟨v2|c1 ⋅⇩v u1 + c2 ⋅⇩v u2⟩)"
    using assms inner_prod_cnj[of "v1" "c1 ⋅⇩v u1 + c2 ⋅⇩v u2"] inner_prod_cnj[of "v2" "c1 ⋅⇩v u1 + c2 ⋅⇩v u2"] 
    by simp
  also have "... = c3 * cnj(c1 * ⟨v1|u1⟩ + c2 * ⟨v1|u2⟩) + c4 * cnj(c1 * ⟨v2|u1⟩ + c2 * ⟨v2|u2⟩)"
    using inner_prod_is_linear[of "v1" "λi. if i = 0 then u1 else u2" "λi. if i = 0 then c1 else c2"] 
          inner_prod_is_linear[of "v2" "λi. if i = 0 then u1 else u2" "λi. if i = 0 then c1 else c2"] assms
    by simp
  also have "... = c3 * (cnj(c1) * ⟨u1|v1⟩ + cnj(c2) * ⟨u2|v1⟩) + 
                   c4 * (cnj(c1) * ⟨u1|v2⟩ + cnj(c2) * ⟨u2|v2⟩)"
    using inner_prod_cnj[of "v1" "u1"] inner_prod_cnj[of "v1" "u2"] 
          inner_prod_cnj[of "v2" "u1"] inner_prod_cnj[of "v2" "u2"] assms
    by simp
  finally show ?thesis
    by (auto simp add: algebra_simps)
qed
text ‹
A length-preserving square matrix is unitary. Together with the converse (unitary matrices preserve
lengths), this shows that unitary matrices are exactly the length-preserving square matrices.
›
(* A square matrix that preserves the length of every vector of matching dimension is unitary.
   Strategy: first prove U⇧† * U = 1⇩m by comparing dimensions and then entries.
   - Diagonal entries: column l of U is the image of unit_vec n l, hence has length 1.
   - Off-diagonal entries (polarization): length preservation applied to
     unit_vec n i + unit_vec n j gives ⟨col U j|col U i⟩ + ⟨col U i|col U j⟩ = 0, and applied to
     unit_vec n i + 𝗂 ⋅ unit_vec n j gives ⟨col U j|col U i⟩ - ⟨col U i|col U j⟩ = 0;
     together they force ⟨col U i|col U j⟩ = 0.
   Finally, for square matrices a left inverse is also a right inverse (inverts_mat_sym). *)
lemma length_preserving_is_unitary:
  fixes U:: "complex mat"
  assumes "square_mat U" and "∀v::complex vec. dim_vec v = dim_col U ⟶ ∥U * |v⟩∥ = ∥v∥"
  shows "unitary U"
proof-
  define n where "n = dim_col U"
  then have c0:"U ∈ carrier_mat n n" using assms(1) by auto
  then have c1:"U⇧† ∈ carrier_mat n n" using assms(1) dagger_def by auto
  (* Main step: U⇧† * U is the identity, shown via its dimensions and its entries. *)
  have f0:"(U⇧†) * U = 1⇩m (dim_col U)"
  proof
    show "dim_row (U⇧† * U) = dim_row (1⇩m (dim_col U))" using c0 by simp
  next
    show "dim_col (U⇧† * U) = dim_col (1⇩m (dim_col U))" using c0 by simp
  next
    fix i j assume "i < dim_row (1⇩m (dim_col U))" and "j < dim_col (1⇩m (dim_col U))"
    then have a0:"i < n ∧ j < n" using c0 by simp
    (* Diagonal entries: each column of U has squared length 1, being U applied to a unit vector. *)
    have f1:"⋀l. l<n ⟶ (∑k<n. cnj (U $$ (k, l)) * U $$ (k, l)) = 1"
    proof
      fix l assume a1:"l<n"
      define v::"complex vec" where d1:"v = unit_vec n l"
      have "∥col U l∥⇧2 = (∑k<n. cnj (U $$ (k, l)) * U $$ (k, l))"
        using c0 a1 cpx_vec_length_inner_prod inner_prod_def lessThan_atLeast0 by simp
      moreover have "∥col U l∥⇧2 = ∥v∥⇧2" using c0 d1 a1 assms(2) unit_vec_to_col by simp
      moreover have "∥v∥⇧2 = 1" using d1 a1 cpx_vec_length_inner_prod by simp
      ultimately show "(∑k<n. cnj (U $$ (k, l)) * U $$ (k, l)) = 1" by simp
    qed
    (* Off-diagonal entries vanish, by polarization with two test vectors v1 and v2. *)
    moreover have "i ≠ j ⟶ (∑k<n. cnj (U $$ (k, i)) * U $$ (k, j)) = 0"
    proof
      assume a2:"i ≠ j"
      define v1::"complex vec" where d1:"v1 = unit_vec n i + 1 ⋅⇩v unit_vec n j"
      define v2::"complex vec" where d2:"v2 = unit_vec n i + 𝗂 ⋅⇩v unit_vec n j"
      (* First test vector v1: yields ⟨col U j|col U i⟩ + ⟨col U i|col U j⟩ = 0. *)
      have "∥v1∥⇧2 = 1 + cnj 1 * 1" using d1 a0 a2 sum_of_unit_vec_length by blast
      then have "∥v1∥⇧2 = 2"
        by (metis complex_cnj_one cpx_vec_length_inner_prod mult.left_neutral of_real_eq_iff 
            of_real_numeral one_add_one)
      then have "∥U * |v1⟩∥⇧2 = 2" using c0 d1 assms(2) unit_vec_to_col by simp
      moreover have "col U i + 1 ⋅⇩v col U j = U * |v1⟩"
        using c0 d1 a0 sum_of_unit_vec_to_col by blast
      moreover have "col U i + 1 ⋅⇩v col U j = col U i + col U j" by simp
      ultimately have "⟨col U i + col U j|col U i + col U j⟩ = 2"
        using cpx_vec_length_inner_prod by (metis of_real_numeral)
      moreover have "⟨col U i + col U j|col U i + col U j⟩ = 
               ⟨col U i|col U i⟩ + ⟨col U j|col U i⟩ + ⟨col U i|col U j⟩ + ⟨col U j|col U j⟩"
        using inner_prod_is_sesquilinear[of "col U i" "dim_row U" "col U j" "col U i" "col U j" "1" "1" "1" "1"]
        by simp
      ultimately have f2:"⟨col U j|col U i⟩ + ⟨col U i|col U j⟩ = 0"
        using c0 a0 f1 inner_prod_def lessThan_atLeast0 by simp
      (* Second test vector v2: yields ⟨col U j|col U i⟩ - ⟨col U i|col U j⟩ = 0. *)
      have "∥v2∥⇧2 = 1 + cnj 𝗂 * 𝗂" using a0 a2 d2 sum_of_unit_vec_length by simp
      then have "∥v2∥⇧2 = 2"
        by (metis Re_complex_of_real complex_norm_square mult.commute norm_ii numeral_Bit0 
            numeral_One numeral_eq_one_iff of_real_numeral one_power2)
      moreover have "∥U * |v2⟩∥⇧2 = ∥v2∥⇧2" using c0 d2 assms(2) unit_vec_to_col by simp
      moreover have "⟨col U i + 𝗂 ⋅⇩v col U j|col U i + 𝗂 ⋅⇩v col U j⟩ = ∥U * |v2⟩∥⇧2"
        using c0 a0 d2 sum_of_unit_vec_to_col cpx_vec_length_inner_prod by auto
      moreover have "⟨col U i + 𝗂 ⋅⇩v col U j|col U i + 𝗂 ⋅⇩v col U j⟩ = 
                     ⟨col U i|col U i⟩ + (-𝗂) * ⟨col U j|col U i⟩ + 𝗂 * ⟨col U i|col U j⟩ + ⟨col U j|col U j⟩"
        using inner_prod_is_sesquilinear[of "col U i" "dim_row U" "col U j" "col U i" "col U j" "1" "𝗂" "1" "𝗂"]
        by simp
      ultimately have "⟨col U j|col U i⟩ - ⟨col U i|col U j⟩ = 0"
        using c0 a0 f1 inner_prod_def lessThan_atLeast0 by auto
      (* Combining this with f2 gives ⟨col U i|col U j⟩ = 0, i.e. the (i,j) entry vanishes. *)
      then show "(∑k<n. cnj (U $$ (k, i)) * U $$ (k, j)) = 0"
        using c0 a0 f2 lessThan_atLeast0 inner_prod_def by auto
    qed
    ultimately show "(U⇧† * U) $$ (i, j) = 1⇩m (dim_col U) $$ (i, j)"
      using c0 assms(1) a0 one_mat_def dagger_def by auto
  qed
  (* For square matrices a left inverse is also a right inverse, giving U * U⇧† = 1⇩m. *)
  then have "(U⇧†) * U = 1⇩m n" using c0 by simp
  then have "inverts_mat (U⇧†) U" using c1 inverts_mat_def by auto
  then have "inverts_mat U (U⇧†)" using c0 c1 inverts_mat_sym by simp
  then have "U * (U⇧†) = 1⇩m (dim_row U)" using c0 inverts_mat_def by auto
  then show ?thesis using f0 unitary_def by simp
qed
(* Unitary matrices preserve the inner product. The inner product is expressed as the (0,0)
   entry of the matrix product of the bra of u, U⇧†, U and the ket of v. Re-associating the
   product exposes U⇧† * U, which is the identity by unitarity. *)
lemma inner_prod_with_unitary_mat [simp]:
  assumes "unitary U" and "dim_vec u = dim_col U" and "dim_vec v = dim_col U"
  shows "⟨U * |u⟩|U * |v⟩⟩ = ⟨u|v⟩"
proof -
  (* Rewrite the inner product of the images as a single entry of a matrix product. *)
  have f1:"⟨U * |u⟩|U * |v⟩⟩ = (⟨|u⟩| * (U⇧†) * U * |v⟩) $$ (0,0)"
    using assms(2-3) bra_mat_on_vec mult_ket_vec_is_ket_vec_of_mult
    by (smt (verit, ccfv_threshold) assoc_mult_mat carrier_mat_triv col_fst_def dim_vec dim_col_of_dagger index_mult_mat(2) 
        index_mult_mat(3) inner_prod_with_times_mat ket_vec_def mat_carrier)
  (* Dimension facts needed to re-associate the product. *)
  moreover have f2:"⟨|u⟩| ∈ carrier_mat 1 (dim_vec v)"
    using bra_def ket_vec_def assms(2-3) by simp
  moreover have f3:"U⇧† ∈ carrier_mat (dim_col U) (dim_row U)"
    using dagger_def by simp
  (* Regroup so that U⇧† * U appears as a factor. *)
  ultimately have "⟨U * |u⟩|U * |v⟩⟩ = (⟨|u⟩| * (U⇧† * U) * |v⟩) $$ (0,0)"
    using assms(3) assoc_mult_mat by (metis carrier_mat_triv)
  (* Unitarity: U⇧† * U is the identity. *)
  also have "… = (⟨|u⟩| * |v⟩) $$ (0,0)"
    using assms(1) unitary_def
    by (simp add: assms(2) bra_def ket_vec_def)
  finally show ?thesis
    using assms(2-3) inner_prod_with_times_mat by presburger
qed
text ‹As a consequence, we prove that the columns and the rows of a unitary matrix are orthonormal vectors.›
(* Every column of a unitary matrix has length 1. Column i is the image under U of the unit
   vector unit_vec n i (unit_vec_to_col), and U preserves lengths. *)
lemma unitary_unit_col [simp]:
  assumes "unitary U" and "dim_col U = n" and "i < n"
  shows "∥col U i∥ = 1"
  using assms unit_vec_to_col unitary_is_length_preserving by simp
(* Every row of a unitary matrix has length 1. Row i of U is column i of the transpose U⇧t,
   which is unitary as well (transpose_of_unitary_is_unitary), so unitary_unit_col applies. *)
lemma unitary_unit_row [simp]:
  assumes "unitary U" and "dim_row U = n" and "i < n"
  shows "∥row U i∥ = 1"
proof -
  have "row U i = col (U⇧t) i"
    using  assms(2-3) by simp
  thus ?thesis
    using assms transpose_of_unitary_is_unitary unitary_unit_col
    by (metis index_transpose_mat(3))
qed
(* Distinct columns of a unitary matrix are orthogonal. They are the images under U of two
   distinct unit vectors, and U preserves inner products (inner_prod_with_unitary_mat). *)
lemma orthogonal_col_of_unitary [simp]:
  assumes "unitary U" and "dim_col U = n" and "i < n" and "j < n" and "i ≠ j"
  shows "⟨col U i|col U j⟩ = 0"
proof -
  (* Columns as images of unit vectors. *)
  have "⟨col U i|col U j⟩ = ⟨U * |unit_vec n i⟩| U * |unit_vec n j⟩⟩"
    using assms(2-4) unit_vec_to_col by simp
  (* U preserves the inner product. *)
  also have "… = ⟨unit_vec n i |unit_vec n j⟩"
    using assms(1-2) inner_prod_with_unitary_mat index_unit_vec(3) by simp
  (* Distinct unit vectors are orthogonal. *)
  finally show ?thesis
    using assms(3-5) by simp
qed
(* Distinct rows of a unitary matrix are orthogonal. This follows from
   orthogonal_col_of_unitary applied to the transpose, which is also unitary. *)
lemma orthogonal_row_of_unitary [simp]:
  fixes U::"complex mat"
  assumes "unitary U" and "dim_row U = n" and "i < n" and "j < n" and "i ≠ j"
  shows "⟨row U i|row U j⟩ = 0"
  using assms orthogonal_col_of_unitary transpose_of_unitary_is_unitary col_transpose
  by (metis index_transpose_mat(3))
text‹
As a consequence, we prove that a quantum gate acting on a state of a system of n qubits gives 
another state of that same system.
›
(* Applying an n-qubit gate to an n-qubit state yields an n-qubit state. The three goals
   below are those of state: the product has 2^n rows and one column, and its single column
   has length 1 because the unitary gate preserves the length of the column of v. *)
lemma gate_on_state_is_state [intro, simp]:
  assumes a1:"gate n A" and a2:"state n v"
  shows "state n (A * v)"
proof
  (* The product has as many rows as the gate. *)
  show "dim_row (A * v) = 2^n"
    using gate_def state_def a1 by simp
next
  (* The product has as many columns as the state, namely one. *)
  show "dim_col (A * v) = 1"
    using state_def a2 by simp
next
  (* Dimensions match, so the unitary A preserves the length of column 0 of v. *)
  have "square_mat A"
    using a1 gate_def by simp
  then have "dim_col A = 2^n"
    using a1 gate.dim_row by simp
  then have "dim_col A = dim_row v"
    using a2 state.dim_row by simp
  then have "∥col (A * v) 0∥ = ∥col v 0∥"
    using unitary_is_length_preserving_bis assms gate_def state_def by simp
  thus"∥col (A * v) 0∥ = 1"
    using a2 state.is_normal by simp
qed
subsection ‹A Few Well-known Quantum Gates›
text ‹
Any unitary operation on n qubits can be implemented exactly by composing single-qubit gates and
CNOT-gates (controlled-NOT gates). However, no straightforward method is known to implement all 
these gates in a fashion which is resistant to errors. Fortunately, the Hadamard gate, the phase gate,
the CNOT-gate and the @{text "π/8"} gate are also universal for quantum computation, i.e. any quantum
circuit on n qubits can be approximated to an arbitrary accuracy by using only these gates, and these
gates can be implemented in a fault-tolerant way.
›
text ‹We introduce a coercion from real matrices to complex matrices.›
(* Same dimensions as A; each real entry is coerced to a complex number. *)
definition real_to_cpx_mat:: "real mat ⇒ complex mat" where
"real_to_cpx_mat A ≡ mat (dim_row A) (dim_col A) (λ(i,j). A $$ (i,j))"
text ‹Our first quantum gate: the identity matrix! Arguably, not a very interesting one though!›
(* The identity gate on n qubits: the identity matrix of size 2^n. *)
definition Id :: "nat ⇒ complex mat" where
"Id n ≡ 1⇩m (2^n)"
(* The identity matrix of size 2^n is an n-qubit gate: correct dimension, square, unitary. *)
lemma id_is_gate [simp]:
  "gate n (Id n)"
proof
  show "dim_row (Id n) = 2^n"
    using Id_def by simp
next
  show "square_mat (Id n)"
    using Id_def by simp
next
  show "unitary (Id n)" 
    by (simp add: Id_def)
qed
text ‹More interesting: the Pauli matrices.›
(* Pauli X (the NOT gate): 0 on the diagonal and 1 off the diagonal. *)
definition X ::"complex mat" where
"X ≡ mat 2 2 (λ(i,j). if i=j then 0 else 1)"
text‹ 
Be aware that @{text "gate n A"} means that the matrix A has dimension @{text "2^n * 2^n"}. 
For instance, with this convention a unitary 2 × 2 matrix A satisfies @{text "gate 1 A"},
but not @{text "gate 2 A"} as one might have expected.
›
(* X is self-adjoint. *)
lemma dagger_of_X [simp]:
  "X⇧† = X"
  using dagger_def by (simp add: X_def cong_mat)
(* X is its own inverse; the product is expanded entrywise. *)
lemma X_inv [simp]:
  "X * X = 1⇩m 2"
  apply(simp add: X_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def)
(* X is a one-qubit gate, i.e. a unitary 2 × 2 matrix (cf. dagger_of_X and X_inv). *)
lemma X_is_gate [simp]:
  "gate 1 X"
  by (simp add: gate_def unitary_def)
    (simp add: X_def)
(* Pauli Y: zero diagonal, -𝗂 in entry (0,1) and 𝗂 in entry (1,0). *)
definition Y ::"complex mat" where
"Y ≡ mat 2 2 (λ(i,j). if i=j then 0 else (if i=0 then -𝗂 else 𝗂))"
(* Y is self-adjoint. *)
lemma dagger_of_Y [simp]:
  "Y⇧† = Y"
  using dagger_def by (simp add: Y_def cong_mat)
(* Y is its own inverse; the product is expanded entrywise. *)
lemma Y_inv [simp]:
  "Y * Y = 1⇩m 2"
  apply(simp add: Y_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def)
(* Y is a one-qubit gate, i.e. a unitary 2 × 2 matrix (cf. dagger_of_Y and Y_inv). *)
lemma Y_is_gate [simp]:
  "gate 1 Y"
  by (simp add: gate_def unitary_def)
    (simp add: Y_def)
(* Pauli Z: the diagonal matrix with entries 1 and -1. *)
definition Z ::"complex mat" where
"Z ≡ mat 2 2 (λ(i,j). if i≠j then 0 else (if i=0 then 1 else -1))"
(* Z is self-adjoint. *)
lemma dagger_of_Z [simp]:
  "Z⇧† = Z"
  using dagger_def by (simp add: Z_def cong_mat)
(* Z is its own inverse; the product is expanded entrywise. *)
lemma Z_inv [simp]:
  "Z * Z = 1⇩m 2"
  apply(simp add: Z_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def)
(* Z is a one-qubit gate, i.e. a unitary 2 × 2 matrix (cf. dagger_of_Z and Z_inv). *)
lemma Z_is_gate [simp]:
  "gate 1 Z"
  by (simp add: gate_def unitary_def)
    (simp add: Z_def)
text ‹The Hadamard gate›
(* Hadamard gate: 1/sqrt 2 times the matrix with entries 1, 1, 1 and -1 (bottom-right). *)
definition H ::"complex mat" where
"H ≡ 1/sqrt(2) ⋅⇩m (mat 2 2 (λ(i,j). if i≠j then 1 else (if i=0 then 1 else -1)))"
(* The same matrix with the scalar factor pushed into the entries. *)
lemma H_without_scalar_prod:
  "H = mat 2 2 (λ(i,j). if i≠j then 1/sqrt(2) else (if i=0 then 1/sqrt(2) else -(1/sqrt(2))))"
  using cong_mat by (auto simp: H_def)
(* H is self-adjoint. *)
lemma dagger_of_H [simp]:
  "H⇧† = H"
  using dagger_def by (auto simp: H_def cong_mat)
(* H is its own inverse; the product is expanded entrywise. *)
lemma H_inv [simp]:
  "H * H = 1⇩m 2"
  apply(simp add: H_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def complex_eqI)
(* H is a one-qubit gate, i.e. a unitary 2 × 2 matrix (cf. dagger_of_H and H_inv). *)
lemma H_is_gate [simp]:
  "gate 1 H"
  by (simp add: gate_def unitary_def)
    (simp add: H_def)
(* Every entry of H other than the bottom-right entry (1,1) equals 1/sqrt 2. *)
lemma H_values:
  fixes i j:: nat
  assumes "i < dim_row H" and "j < dim_col H" and "i ≠ 1 ∨ j ≠ 1" 
  shows "H $$ (i,j) = 1/sqrt 2"
proof-
  (* H is 2 × 2, so both indices are 0 or 1. *)
  have "i < 2"
    using assms(1) by (simp add: H_without_scalar_prod less_2_cases)
  moreover have "j < 2"
    using assms(2) by (simp add: H_without_scalar_prod less_2_cases)
  ultimately show ?thesis 
    using assms(3) H_without_scalar_prod by (smt (verit) One_nat_def index_mat(1) less_2_cases old.prod.case)
qed
(* The bottom-right entry (1,1) of H equals -1/sqrt 2. *)
lemma H_values_right_bottom:
  fixes i j:: nat
  assumes "i = 1 ∧ j = 1"
  shows "H $$ (i,j) = - 1/sqrt 2"     
  using assms by (simp add: H_without_scalar_prod)
text ‹The controlled-NOT gate›
(* The controlled-NOT gate as a 4 × 4 permutation matrix: it fixes basis indices 0 and 1 and
   swaps indices 2 and 3. Reading index k as the two-bit binary expansion of k with the first
   (control) qubit as the high bit (an assumption to confirm against bin_rep), this flips the
   second qubit exactly when the first qubit is 1. *)
definition CNOT ::"complex mat" where
"CNOT ≡ mat 4 4 
  (λ(i,j). if i=0 ∧ j=0 then 1 else 
    (if i=1 ∧ j=1 then 1 else 
      (if i=2 ∧ j=3 then 1 else 
        (if i=3 ∧ j=2 then 1 else 0))))"
(* CNOT is self-adjoint. *)
lemma dagger_of_CNOT [simp]:
  "CNOT⇧† = CNOT"
  using dagger_def by (simp add: CNOT_def cong_mat)
(* CNOT is its own inverse; the product is expanded entrywise. *)
lemma CNOT_inv [simp]:
  "CNOT * CNOT = 1⇩m 4"
  apply(simp add: CNOT_def times_mat_def one_mat_def)
  apply(rule cong_mat)
  by(auto simp: scalar_prod_def)
(* CNOT is a two-qubit gate, i.e. a unitary 4 × 4 matrix (cf. dagger_of_CNOT and CNOT_inv). *)
lemma CNOT_is_gate [simp]:
  "gate 2 CNOT"
  by (simp add: gate_def unitary_def)
    (simp add: CNOT_def)
text ‹The phase gate, also known as the S-gate›
(* The diagonal matrix with entries 1 and 𝗂. *)
definition S ::"complex mat" where
"S ≡ mat 2 2 (λ(i,j). if i=0 ∧ j=0 then 1 else (if i=1 ∧ j=1 then 𝗂 else 0))"
text ‹The @{text "π/8"} gate, also known as the T-gate›
(* The diagonal matrix with entries 1 and exp(𝗂 * pi/4). *)
definition T ::"complex mat" where
"T ≡ mat 2 2 (λ(i,j). if i=0 ∧ j=0 then 1 else (if i=1 ∧ j=1 then exp(𝗂*(pi/4)) else 0))"
text ‹A few relations between the Hadamard gate and the Pauli matrices›
(* Conjugation by the Hadamard gate exchanges X and Z: H * X * H = Z. *)
lemma HXH_is_Z [simp]:
  "H * X * H = Z" 
  apply(simp add: X_def Z_def H_def times_mat_def)
  apply(rule cong_mat)
  by(auto simp add: scalar_prod_def complex_eqI)
(* Conjugation by the Hadamard gate negates Y.
   NOTE(review): this proof uses eq_matI where its siblings use cong_mat, presumably because
   the right-hand side -Y is not literally a mat 2 2 term; confirm before harmonizing. *)
lemma HYH_is_minusY [simp]:
  "H * Y * H = - Y" 
  apply(simp add: Y_def H_def times_mat_def)
  apply(rule eq_matI)
  by(auto simp add: scalar_prod_def complex_eqI)
(* Conjugation by the Hadamard gate exchanges Z and X: H * Z * H = X. *)
lemma HZH_is_X [simp]:
  shows "H * Z * H = X"  
  apply(simp add: X_def Z_def H_def times_mat_def)
  apply(rule cong_mat)
  by(auto simp add: scalar_prod_def complex_eqI)
subsection ‹The Bell States›
text ‹
We introduce below the so-called Bell states, also known as EPR pairs (EPR stands for Einstein,
Podolsky and Rosen).
›
(* The four Bell states as kets of length-4 vectors, each scaled by 1/sqrt 2. Writing e_k for
   the k-th standard basis vector of length 4:
     bell00 = (e_0 + e_3)/sqrt 2,   bell01 = (e_1 + e_2)/sqrt 2,
     bell10 = (e_0 - e_3)/sqrt 2,   bell11 = (e_1 - e_2)/sqrt 2.
   Reading index k as the two-qubit basis state given by the binary digits of k (presumably
   first qubit = high bit; confirm), bell00 is (|00⟩ + |11⟩)/sqrt 2, and so on. *)
definition bell00 ::"complex mat" (‹|β⇩0⇩0⟩›) where
"bell00 ≡ 1/sqrt(2) ⋅⇩m |vec 4 (λi. if i=0 ∨ i=3 then 1 else 0)⟩"
definition bell01 ::"complex mat" (‹|β⇩0⇩1⟩›) where
"bell01 ≡ 1/sqrt(2) ⋅⇩m |vec 4 (λi. if i=1 ∨ i=2 then 1 else 0)⟩"
definition bell10 ::"complex mat" (‹|β⇩1⇩0⟩›) where
"bell10 ≡ 1/sqrt(2) ⋅⇩m |vec 4 (λi. if i=0 then 1 else if i=3 then -1 else 0)⟩"
definition bell11 ::"complex mat" (‹|β⇩1⇩1⟩›) where
"bell11 ≡ 1/sqrt(2) ⋅⇩m |vec 4 (λi. if i=1 then 1 else if i=2 then -1 else 0)⟩"
(* Each Bell state is a 2-qubit state: a single column of 2^2 = 4 entries with length 1. The
   first auto pass unfolds the definitions; the second computes the norm explicitly. *)
lemma
  shows bell00_is_state [simp]:"state 2 |β⇩0⇩0⟩" and bell01_is_state [simp]:"state 2 |β⇩0⇩1⟩" and 
    bell10_is_state [simp]:"state 2 |β⇩1⇩0⟩" and bell11_is_state [simp]:"state 2 |β⇩1⇩1⟩"
  by (auto simp: state_def bell00_def bell01_def bell10_def bell11_def ket_vec_def)
    (auto simp: cpx_vec_length_def Set_Interval.lessThan_atLeast0 cmod_def power2_eq_square) 
(* Explicit entries of the Bell states, as simp rules. *)
lemma bell00_index [simp]:
  shows "|β⇩0⇩0⟩ $$ (0,0) = 1/sqrt 2" and "|β⇩0⇩0⟩ $$ (1,0) = 0" and "|β⇩0⇩0⟩ $$ (2,0) = 0" and 
    "|β⇩0⇩0⟩ $$ (3,0) = 1/sqrt 2"
  by (auto simp: bell00_def ket_vec_def)
lemma bell01_index [simp]:
  shows "|β⇩0⇩1⟩ $$ (0,0) = 0" and "|β⇩0⇩1⟩ $$ (1,0) = 1/sqrt 2" and "|β⇩0⇩1⟩ $$ (2,0) = 1/sqrt 2" and 
    "|β⇩0⇩1⟩ $$ (3,0) = 0"
  by (auto simp: bell01_def ket_vec_def)
lemma bell10_index [simp]:
  shows "|β⇩1⇩0⟩ $$ (0,0) = 1/sqrt 2" and "|β⇩1⇩0⟩ $$ (1,0) = 0" and "|β⇩1⇩0⟩ $$ (2,0) = 0" and 
    "|β⇩1⇩0⟩ $$ (3,0) = - 1/sqrt 2"
  by (auto simp: bell10_def ket_vec_def)
(* NOTE(review): named bell_11_index, unlike bell00_index, bell01_index and bell10_index;
   renaming it may break references elsewhere, so check users before harmonizing. *)
lemma bell_11_index [simp]:
  shows "|β⇩1⇩1⟩ $$ (0,0) = 0" and "|β⇩1⇩1⟩ $$ (1,0) = 1/sqrt 2" and "|β⇩1⇩1⟩ $$ (2,0) = - 1/sqrt 2" and 
    "|β⇩1⇩1⟩ $$ (3,0) = 0"
  by (auto simp: bell11_def ket_vec_def)
subsection ‹The Bitwise Inner Product›
(* Bitwise inner product of i and j on n bits: the sum over positions k < n of the product of
   the k-th entries of bin_rep n i and bin_rep n j, i.e. (assuming bin_rep yields the binary
   digits) the number of positions where both i and j have a 1. *)
definition bitwise_inner_prod:: "nat ⇒ nat ⇒ nat ⇒ nat" where 
"bitwise_inner_prod n i j = (∑k∈{0..<n}. (bin_rep n i) ! k * (bin_rep n j) ! k)"
(* Infix notation: i ⋅⇘n⇙ j stands for bitwise_inner_prod n i j. *)
abbreviation bip:: "nat ⇒ nat ⇒ nat ⇒ nat" (‹_ ⋅⇘_⇙  _›) where
"bip i n j ≡ bitwise_inner_prod n i j"
(* If i or j is below 2^n, its leading digit in bin_rep (Suc n) is 0 (bin_rep_index_0). The
   first summand of the (Suc n)-bit inner product therefore vanishes, and the remaining
   digits are those of i mod 2^n and j mod 2^n. *)
lemma bitwise_inner_prod_fst_el_0: 
  assumes "i < 2^n ∨ j < 2^n" 
  shows "(i ⋅⇘Suc n⇙ j) = (i mod 2^n) ⋅⇘n⇙ (j mod 2^n)" 
proof-
  have "bip i (Suc n) j = (∑k∈{0..<(Suc n)}. (bin_rep (Suc n) i) ! k * (bin_rep (Suc n) j) ! k)" 
    using bitwise_inner_prod_def by simp
  (* Split off the k = 0 summand. *)
  also have "... = bin_rep (Suc n) i ! 0 * bin_rep (Suc n) j ! 0 + 
             (∑k∈{1..<(Suc n)}. bin_rep (Suc n) i ! k * bin_rep (Suc n) j ! k)"
    by (simp add: sum.atLeast_Suc_lessThan)
  (* The leading digit of i or of j is 0. *)
  also have "... = (∑k∈{1..<(Suc n)}. bin_rep (Suc n) i ! k * bin_rep (Suc n) j ! k)"
    using bin_rep_index_0[of i n] bin_rep_index_0[of j n] assms by auto
  (* Reindex the remaining sum to start at 0. *)
  also have "... = (∑k∈{0..<n}. bin_rep (Suc n) i !(k+1) * bin_rep (Suc n) j ! (k+1))" 
     using sum.shift_bounds_Suc_ivl[of "λk. bin_rep (Suc n) i ! k * bin_rep (Suc n) j ! k" "0" "n"] 
     by (metis (no_types, lifting) One_nat_def add.commute plus_1_eq_Suc sum.cong)
  finally have "bip i (Suc n) j = (∑k∈{0..<n}. bin_rep (Suc n) i ! (k+1) * bin_rep (Suc n) j ! (k+1))" 
    by simp
  (* Dropping the leading digit of bin_rep (Suc n) m gives bin_rep n (m mod 2^n). *)
  moreover have "k∈{0..n} ⟶ bin_rep (Suc n) i ! (k+1) = bin_rep n (i mod 2^n) ! k" for k
    using bin_rep_def by (simp add: bin_rep_aux_neq_nil)
  moreover have "k∈{0..n} ⟶ bin_rep (Suc n) j !(k+1) = bin_rep n (j mod 2^n) ! k" for k 
    using bin_rep_def by (simp add: bin_rep_aux_neq_nil)
  ultimately show ?thesis
    using assms bin_rep_index_0 bitwise_inner_prod_def by simp
qed
(* If both i and j lie in [2^n, 2^(n+1)), their leading digit in bin_rep (n+1) is 1
   (bin_rep_index_0_geq). The first summand therefore contributes 1, and the remaining
   digits are those of i mod 2^n and j mod 2^n. *)
lemma bitwise_inner_prod_fst_el_is_1:
  fixes n i j:: nat
  assumes "i ≥ 2^n ∧ j ≥ 2^n" and "i < 2^(n+1) ∧ j < 2^(n+1)"
  shows "(i ⋅⇘(n+1)⇙ j) = 1 + ((i mod 2^n) ⋅⇘n⇙ (j mod 2^n))" 
proof-
  have "bip i (Suc n) j = (∑k∈{0..<(Suc n)}. bin_rep (Suc n) i ! k * bin_rep (Suc n) j ! k)" 
    using bitwise_inner_prod_def by simp
  (* Split off the k = 0 summand. *)
  also have "... = bin_rep (Suc n) i ! 0 * bin_rep (Suc n) j ! 0 + 
            (∑k∈{1..<(Suc n)}. bin_rep (Suc n) i ! k * bin_rep (Suc n) j ! k)"
    by (simp add: sum.atLeast_Suc_lessThan)
  (* Both leading digits are 1. *)
  also have "... = 1 + (∑k∈{1..<(Suc n)}. bin_rep (Suc n) i ! k * bin_rep (Suc n) j ! k)"
    using bin_rep_index_0_geq[of n i] bin_rep_index_0_geq[of n j] assms by simp
  (* Reindex the remaining sum to start at 0. *)
  also have "... = 1 + (∑k ∈ {0..<n}. bin_rep (Suc n) i ! (k+1) * bin_rep (Suc n) j ! (k+1))" 
    using sum.shift_bounds_Suc_ivl[of "λk. (bin_rep (Suc n) i)!k * (bin_rep (Suc n) j)!k" "0" "n"] 
    by (metis (no_types, lifting) One_nat_def Suc_eq_plus1 sum.cong)
  finally have f0:"bip i (Suc n) j = 1 + (∑k∈{0..<n}. bin_rep (Suc n) i ! (k+1) * bin_rep (Suc n) j ! (k+1))"
    by simp
  (* Dropping the leading digit gives the n-digit representation of the residue mod 2^n.
     The range {0..n} is wider than the {0..<n} used above; the identity is proved via
     nth_Cons_Suc and does not depend on k being in range. *)
  moreover have "k∈{0..n} ⟶ bin_rep (Suc n) i ! (k+1) = bin_rep n (i mod 2^n) ! k
∧ bin_rep (Suc n) j ! (k+1) = bin_rep n (j mod 2^n) ! k" for k
    using bin_rep_def by(metis Suc_eq_plus1 bin_rep_aux.simps(2) bin_rep_aux_neq_nil butlast.simps(2) nth_Cons_Suc)
  ultimately show ?thesis
    using bitwise_inner_prod_def by simp
qed
(* The bitwise inner product of 0 with any m is 0, since every digit of bin_rep n 0 is 0.
   NOTE(review): the statement does not seem to need the hypothesis m < 2^n (every summand
   becomes 0 * _). It may only help the bin_rep_index rewriting; confirm before removing,
   since callers may instantiate the lemma with OF. *)
lemma bitwise_inner_prod_with_zero:
  assumes "m < 2^n"
  shows "(0 ⋅⇘n⇙  m) = 0" 
proof-
  have "(0 ⋅⇘n⇙  m) = (∑j∈{0..<n}. bin_rep n 0 ! j * bin_rep n m ! j)" 
    using bitwise_inner_prod_def by simp
  (* Each digit of bin_rep n 0 rewrites to 0. *)
  moreover have "(∑j∈{0..<n}. bin_rep n 0 ! j * bin_rep n m ! j) 
               = (∑j∈{0..<n}. 0 * (bin_rep n m) ! j)"
    by (simp add: bin_rep_index)
  ultimately show "?thesis" 
    by simp
qed
end