# Theory Quantum

```
(*
Authors:

Anthony Bordg, University of Cambridge, apdb3@cam.ac.uk
Yijun He, University of Cambridge, yh403@cam.ac.uk
with contributions by Hanna Lachnitt
*)

section ‹Qubits and Quantum Gates›

theory Quantum
imports
Jordan_Normal_Form.Matrix
"HOL-Library.Nonpos_Ints"
Basics
Binary_Nat
begin

subsection ‹Qubits›

text‹In this theory @{text cpx} stands for @{text complex}.›

(* The Euclidean norm of a complex vector, written ∥v∥: the square root of the sum of the
   squared moduli of its entries. *)
definition cpx_vec_length :: "complex vec ⇒ real" ("∥_∥") where
"cpx_vec_length v ≡ sqrt(∑i<dim_vec v. (cmod (v \$ i))⇧2)"

(* The norm of a vector built with vec_of_list, expressed directly through the list entries. *)
lemma cpx_length_of_vec_of_list [simp]:
"∥vec_of_list l∥ = sqrt(∑i<length l. (cmod (l ! i))⇧2)"
by (auto simp: cpx_vec_length_def vec_of_list_def vec_of_list_index)
(metis (no_types, lifting) dim_vec_of_list sum.cong vec_of_list.abs_eq vec_of_list_index)

(* Every in-range entry of unit_vec n i other than the i-th one has modulus 0. *)
lemma norm_vec_index_unit_vec_is_0 [simp]:
assumes "j < n" and "j ≠ i"
shows "cmod ((unit_vec n i) \$ j) = 0"
using assms by (simp add: unit_vec_def)

(* The i-th entry of unit_vec n i (with i < n) has modulus 1. *)
lemma norm_vec_index_unit_vec_is_1 [simp]:
  assumes "j < n" and "j = i"
  shows "cmod ((unit_vec n i) \$ j) = 1"
proof -
  have f:"(unit_vec n i) \$ j = 1"
    using assms by simp
  (* rewriting with f leaves cmod 1 = 1 *)
  thus ?thesis
    by simp
qed

(* A standard basis vector unit_vec n i with i < n has norm 1. *)
lemma unit_cpx_vec_length [simp]:
assumes "i < n"
shows "∥unit_vec n i∥ = 1"
proof -
(* only the i-th summand of the squared norm is nonzero, and it equals 1 *)
have "(∑j<n. (cmod((unit_vec n i) \$ j))⇧2) = (∑j<n. if j = i then 1 else 0)"
using norm_vec_index_unit_vec_is_0 norm_vec_index_unit_vec_is_1
by (smt lessThan_iff one_power2 sum.cong zero_power2)
also have "… = 1"
using assms by simp
finally have "sqrt (∑j<n. (cmod((unit_vec n i) \$ j))⇧2) = 1"
by simp
thus ?thesis
using cpx_vec_length_def by simp
qed

(* Scaling a vector by a nonnegative real x scales its norm by x. *)
lemma smult_vec_length [simp]:
  assumes "x ≥ 0"
  shows "∥complex_of_real(x) ⋅⇩v v∥ = x * ∥v∥"
proof-
  have "(λi::nat.(cmod (complex_of_real x * v \$ i))⇧2) = (λi::nat. (cmod (v \$ i))⇧2 * x⇧2)"
    by (auto simp: norm_mult power_mult_distrib)
  then have "(∑i<dim_vec v. (cmod (complex_of_real x * v \$ i))⇧2) =
             (∑i<dim_vec v. (cmod (v \$ i))⇧2 * x⇧2)" by meson
  moreover have "(∑i<dim_vec v. (cmod (v \$ i))⇧2 * x⇧2) = x⇧2 * (∑i<dim_vec v. (cmod (v \$ i))⇧2)"
    by (metis (no_types) mult.commute sum_distrib_right)
  moreover have "sqrt(x⇧2 * (∑i<dim_vec v. (cmod (v \$ i))⇧2)) =
                 sqrt(x⇧2) * sqrt (∑i<dim_vec v. (cmod (v \$ i))⇧2)"
    using real_sqrt_mult by blast
  (* the remaining step is sqrt (x⇧2) = x, which needs the assumption x ≥ 0 *)
  ultimately show ?thesis
    by (simp add: cpx_vec_length_def assms)
qed

(* A state of n qubits: a complex matrix with one column and 2^n rows whose column has norm 1. *)
locale state =
fixes n:: nat and v:: "complex mat"
assumes is_column [simp]: "dim_col v = 1"
and dim_row [simp]: "dim_row v = 2^n"
and is_normal [simp]: "∥col v 0∥ = 1"

text‹
Below, the natural number n codes for the dimension @{text "2^n"} of the complex vector space whose
elements of norm 1 we call states.
›

(* A standard basis vector of dimension 2^n is a unit vector of dimension 2^n. The n in the set
   comprehension is the n of the assumption; it is deliberately not re-bound by the set binder,
   otherwise the statement would only assert dimension 2^m for some unrelated m. *)
lemma unit_vec_of_right_length_is_state [simp]:
  assumes "i < 2^n"
  shows "unit_vec (2^n) i ∈ {v| v::complex vec. dim_vec v = 2^n ∧ ∥v∥ = 1}"
proof-
  have "dim_vec (unit_vec (2^n) i) = 2^n"
    by simp
  moreover have "∥unit_vec (2^n) i∥ = 1"
    using assms by simp
  ultimately show ?thesis
    by simp
qed

(* The set of complex vectors of dimension 2^n with norm 1. *)
definition state_qbit :: "nat ⇒ complex vec set" where
"state_qbit n ≡ {v| v:: complex vec. dim_vec v = 2^n ∧ ∥v∥ = 1}"

(* Inside the state locale: the column of a state, viewed as a vector, lies in state_qbit n. *)
lemma (in state) state_to_state_qbit [simp]:
shows "col v 0 ∈ state_qbit n"
using state_def state_qbit_def by simp

subsection "The Hermitian Conjugation"

text ‹The Hermitian conjugate of a complex matrix is the complex conjugate of its transpose. ›

(* The Hermitian conjugate M⇧†: a dim_col M × dim_row M matrix whose (i,j) entry is the complex
   conjugate of the (j,i) entry of M. *)
definition dagger :: "complex mat ⇒ complex mat" ("_⇧†") where
"M⇧† ≡ mat (dim_col M) (dim_row M) (λ(i,j). cnj(M \$\$ (j,i)))"

text ‹We introduce the type of complex square matrices.›

(* The type of square complex matrices. *)
typedef cpx_sqr_mat = "{M | M::complex mat. square_mat M}"
proof-
(* the carrier set is nonempty: every identity matrix is square *)
have "square_mat (1⇩m n)" for n
using one_mat_def by simp
thus ?thesis by blast
qed

(* The underlying complex matrix of a cpx_sqr_mat (its representation). *)
definition cpx_sqr_mat_to_cpx_mat :: "cpx_sqr_mat => complex mat" where
"cpx_sqr_mat_to_cpx_mat M ≡ Rep_cpx_sqr_mat M"

text ‹
We introduce a coercion from the type of complex square matrices to the type of complex
matrices.
›

(* Lets a cpx_sqr_mat be used wherever a complex mat is expected. *)
declare [[coercion cpx_sqr_mat_to_cpx_mat]]

(* Hermitian conjugation swaps the dimensions: M⇧† has as many rows as M has columns. *)
lemma dim_row_of_dagger [simp]:
  "dim_row (M⇧†) = dim_col M"
  by (simp add: dagger_def)

(* Hermitian conjugation swaps the dimensions: M⇧† has as many columns as M has rows. *)
lemma dim_col_of_dagger [simp]:
  "dim_col (M⇧†) = dim_row M"
  by (simp add: dagger_def)

(* Column j of M⇧† is the entrywise conjugate of row j of M. *)
lemma col_of_dagger [simp]:
assumes "j < dim_row M"
shows "col (M⇧†) j = vec (dim_col M) (λi. cnj (M \$\$ (j,i)))"
using assms col_def dagger_def by simp

(* Row i of M⇧† is the entrywise conjugate of column i of M. *)
lemma row_of_dagger [simp]:
assumes "i < dim_col M"
shows "row (M⇧†) i = vec (dim_row M) (λj. cnj (M \$\$ (j,i)))"
using assms row_def dagger_def by simp

(* Hermitian conjugation is an involution: (M⇧†)⇧† = M. Proved by comparing dimensions and
   entries. *)
lemma dagger_of_dagger_is_id:
fixes M :: "complex Matrix.mat"
shows "(M⇧†)⇧† = M"
proof
show "dim_row ((M⇧†)⇧†) = dim_row M" by simp
show "dim_col ((M⇧†)⇧†) = dim_col M" by simp
fix i j assume a0:"i < dim_row M" and a1:"j < dim_col M"
then show "(M⇧†)⇧† \$\$ (i,j) = M \$\$ (i,j)"
proof-
show ?thesis
using dagger_def a0 a1 by auto
qed
qed

(* The Hermitian conjugate of a square matrix (a cpx_sqr_mat coerced to complex mat) is square. *)
lemma dagger_of_sqr_is_sqr [simp]:
"square_mat ((M::cpx_sqr_mat)⇧†)"
proof-
have "square_mat M"
using cpx_sqr_mat_to_cpx_mat_def Rep_cpx_sqr_mat by simp
then have "dim_row M = dim_col M" by simp
then have "dim_col (M⇧†) = dim_row (M⇧†)" by simp
thus "square_mat (M⇧†)" by simp
qed

(* The identity matrix is its own Hermitian conjugate. *)
lemma dagger_of_id_is_id [simp]:
"(1⇩m n)⇧† = 1⇩m n"
using dagger_def one_mat_def by auto

subsection "Unitary Matrices and Quantum Gates"

(* M is unitary when M⇧† * M and M * M⇧† are both identity matrices (of the matching sizes). *)
definition unitary :: "complex mat ⇒ bool" where
"unitary M ≡ (M⇧†) * M = 1⇩m (dim_col M) ∧ M * (M⇧†) = 1⇩m (dim_row M)"

(* Every identity matrix is unitary: it is its own Hermitian conjugate and 1 * 1 = 1. *)
lemma id_is_unitary [simp]:
  "unitary (1⇩m n)"
  by (simp add: unitary_def)

(* A quantum gate on n qubits: a square, unitary complex matrix with 2^n rows. *)
locale gate =
fixes n:: nat and A:: "complex mat"
assumes dim_row [simp]: "dim_row A = 2^n"
and square_mat [simp]: "square_mat A"
and unitary [simp]: "unitary A"

text ‹
We prove that a quantum gate is invertible and its inverse is given by its Hermitian conjugate.
›

(* For unitary M, inverts_mat M (M⇧†) holds, i.e. M * M⇧† is the identity. *)
lemma mat_unitary_mat [intro]:
assumes "unitary M"
shows "inverts_mat M (M⇧†)"
using assms by (simp add: unitary_def inverts_mat_def)

(* For unitary M, inverts_mat (M⇧†) M holds, i.e. M⇧† * M is the identity. *)
lemma unitary_mat_mat [intro]:
assumes "unitary M"
shows "inverts_mat (M⇧†) M"
using assms by (simp add: unitary_def inverts_mat_def)

(* Inside the gate locale: the gate matrix is invertible. *)
lemma (in gate) gate_is_inv:
"invertible_mat A"
using square_mat unitary invertible_mat_def by blast

subsection "Relations Between Complex Conjugation, Hermitian Conjugation, Transposition and Unitarity"

(* Postfix notation M⇧t for the transpose of M. *)
notation transpose_mat ("(_⇧t)")

(* Column i of M⇧t is row i of M, for i < dim_row M. The lemma name keeps its original
   spelling, since other proofs may refer to it. *)
lemma col_tranpose [simp]:
assumes "dim_row M = n" and "i < n"
shows "col (M⇧t) i = row M i"
proof
show "dim_vec (col (M⇧t) i) = dim_vec (row M i)"
by (simp add: row_def col_def transpose_mat_def)
next
show "⋀j. j < dim_vec (row M i) ⟹ col M⇧t i \$ j = row M i \$ j"
using assms by (simp add: transpose_mat_def)
qed

(* Row i of M⇧t is column i of M, for i < dim_col M. *)
lemma row_transpose [simp]:
assumes "dim_col M = n" and "i < n"
shows "row (M⇧t) i = col M i"
using assms by simp

(* M⇧⋆: the entrywise complex conjugate of M, with the same dimensions as M. *)
definition cpx_mat_cnj :: "complex mat ⇒ complex mat" ("(_⇧⋆)") where
"cpx_mat_cnj M ≡ mat (dim_row M) (dim_col M) (λ(i,j). cnj (M \$\$ (i,j)))"

(* Conjugating the identity matrix leaves it unchanged. *)
lemma cpx_mat_cnj_id [simp]:
"(1⇩m n)⇧⋆ = 1⇩m n"
by (auto simp: cpx_mat_cnj_def)

(* Entrywise conjugation is an involution. *)
lemma cpx_mat_cnj_cnj [simp]:
"(M⇧⋆)⇧⋆ = M"
by (auto simp: cpx_mat_cnj_def)

(* A product of conjugated matrices takes its row count from the left factor. *)
lemma dim_row_of_cjn_prod [simp]:
  "dim_row ((M⇧⋆) * (N⇧⋆)) = dim_row M"
  by (simp add: cpx_mat_cnj_def)

(* A product of conjugated matrices takes its column count from the right factor. *)
lemma dim_col_of_cjn_prod [simp]:
  "dim_col ((M⇧⋆) * (N⇧⋆)) = dim_col N"
  by (simp add: cpx_mat_cnj_def)

(* Entrywise conjugation distributes over matrix products: (M * N)⇧⋆ = M⇧⋆ * N⇧⋆ when the
   product is defined. Proved by comparing dimensions and entries. *)
lemma cpx_mat_cnj_prod:
  assumes "dim_col M = dim_row N"
  shows "(M * N)⇧⋆ = (M⇧⋆) * (N⇧⋆)"
proof
  show "dim_row (M * N)⇧⋆ = dim_row ((M⇧⋆) * (N⇧⋆))"
    by (simp add: cpx_mat_cnj_def)
next
  show "dim_col ((M * N)⇧⋆) = dim_col ((M⇧⋆) * (N⇧⋆))"
    by (simp add: cpx_mat_cnj_def)
next
  fix i j::nat
  assume a1:"i < dim_row ((M⇧⋆) * (N⇧⋆))" and a2:"j < dim_col ((M⇧⋆) * (N⇧⋆))"
  then have "(M * N)⇧⋆ \$\$ (i,j) = cnj (∑k<(dim_row N). M \$\$ (i,k) * N \$\$ (k,j))"
    using assms cpx_mat_cnj_def index_mat times_mat_def scalar_prod_def row_def col_def
      dim_row_of_cjn_prod dim_col_of_cjn_prod
    by (smt case_prod_conv dim_col index_mult_mat(2) index_mult_mat(3) index_vec lessThan_atLeast0
        lessThan_iff sum.cong)
  also have "… = (∑k<(dim_row N). cnj(M \$\$ (i,k)) * cnj(N \$\$ (k,j)))" by simp
  also have "((M⇧⋆) * (N⇧⋆)) \$\$ (i,j) =
    (∑k<(dim_row N). cnj(M \$\$ (i,k)) * cnj(N \$\$ (k,j)))"
    using assms a1 a2 cpx_mat_cnj_def index_mat times_mat_def scalar_prod_def row_def col_def
    by (smt case_prod_conv dim_col dim_col_mat(1) dim_row_mat(1) index_vec lessThan_atLeast0
        lessThan_iff sum.cong)
  finally show "(M * N)⇧⋆ \$\$ (i, j) = ((M⇧⋆) * (N⇧⋆)) \$\$ (i, j)" by simp
qed

(* Transposition reverses products: (M * N)⇧t = N⇧t * M⇧t when the product is defined. *)
lemma transpose_of_prod:
  fixes M N::"complex Matrix.mat"
  assumes "dim_col M = dim_row N"
  shows "(M * N)⇧t = N⇧t * (M⇧t)"
proof
  fix i j::nat
  assume a0: "i < dim_row (N⇧t * (M⇧t))" and a1: "j < dim_col (N⇧t * (M⇧t))"
  then have "(M * N)⇧t \$\$ (i,j) = (M * N) \$\$ (j,i)" by auto
  also have "... = (∑k<dim_row M⇧t. M \$\$ (j,k) * N \$\$ (k,i))"
    using assms a0 a1 by auto
  (* complex multiplication is commutative, so the two factors can be swapped under the sum *)
  also have "... = (∑k<dim_row M⇧t. N \$\$ (k,i) * M \$\$ (j,k))"
    by (simp add: mult.commute)
  also have "... = (∑k<dim_row M⇧t. ((N⇧t) \$\$ (i,k)) * (M⇧t) \$\$ (k,j))"
    using assms a0 a1 by auto
  finally show "((M * N)⇧t) \$\$ (i,j) = (N⇧t * (M⇧t)) \$\$ (i,j)"
    using assms a0 a1 by auto
next
  show "dim_row ((M * N)⇧t) = dim_row (N⇧t * (M⇧t))" by auto
next
  show "dim_col ((M * N)⇧t) = dim_col (N⇧t * (M⇧t))" by auto
qed

(* Conjugating the transpose gives the Hermitian conjugate: (M⇧t)⇧⋆ = M⇧†. *)
lemma transpose_cnj_is_dagger [simp]:
"(M⇧t)⇧⋆ = (M⇧†)"
proof
show f1:"dim_row ((M⇧t)⇧⋆) = dim_row (M⇧†)"
by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
next
show f2:"dim_col ((M⇧t)⇧⋆) = dim_col (M⇧†)"
by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
next
fix i j::nat
assume "i < dim_row M⇧†" and "j < dim_col M⇧†"
then show "M⇧t⇧⋆ \$\$ (i, j) = M⇧† \$\$ (i, j)"
by (simp add: cpx_mat_cnj_def transpose_mat_def dagger_def)
qed

(* Transposing the conjugate also gives the Hermitian conjugate: (M⇧⋆)⇧t = M⇧†. *)
lemma cnj_transpose_is_dagger [simp]:
"(M⇧⋆)⇧t = (M⇧†)"
proof
show "dim_row ((M⇧⋆)⇧t) = dim_row (M⇧†)"
by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
next
show "dim_col ((M⇧⋆)⇧t) = dim_col (M⇧†)"
by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
next
fix i j::nat
assume "i < dim_row M⇧†" and "j < dim_col M⇧†"
then show "M⇧⋆⇧t \$\$ (i, j) = M⇧† \$\$ (i, j)"
by (simp add: transpose_mat_def cpx_mat_cnj_def dagger_def)
qed

(* The Hermitian conjugate of the transpose is the entrywise conjugate: (M⇧t)⇧† = M⇧⋆. *)
lemma dagger_of_transpose_is_cnj [simp]:
"(M⇧t)⇧† = (M⇧⋆)"
by (metis transpose_transpose transpose_cnj_is_dagger)

(* Hermitian conjugation reverses products: (M * N)⇧† = N⇧† * M⇧† when the product is
   defined. The proof writes † as conjugation followed by transposition and uses
   cpx_mat_cnj_prod and transpose_of_prod. *)
lemma dagger_of_prod:
fixes M N::"complex Matrix.mat"
assumes "dim_col M = dim_row N"
shows "(M * N)⇧† = N⇧† * (M⇧†)"
proof-
have "(M * N)⇧† = ((M * N)⇧⋆)⇧t" by auto
also have "... = ((M⇧⋆) * (N⇧⋆))⇧t" using assms cpx_mat_cnj_prod by auto
also have "... = (N⇧⋆)⇧t * ((M⇧⋆)⇧t)" using assms transpose_of_prod
by (metis cnj_transpose_is_dagger dim_col_of_dagger dim_row_of_dagger index_transpose_mat(2) index_transpose_mat(3))
finally show "(M * N)⇧† = N⇧† * (M⇧†)" by auto
qed

text ‹The product of two quantum gates is a quantum gate.›

(* The product of two gates on n qubits is a gate on n qubits. The left-inverse identity is
   obtained by reassociating (G1 * G2)⇧† * (G1 * G2); the right-inverse identity is then derived
   from it with mat_mult_left_right_inverse. *)
lemma prod_of_gate_is_gate:
assumes "gate n G1" and "gate n G2"
shows "gate n (G1 * G2)"
proof
show "dim_row (G1 * G2) = 2^n" using assms by (simp add: gate_def)
next
show "square_mat (G1 * G2)"
using assms gate.dim_row gate.square_mat by simp
next
show "unitary (G1 * G2)"
proof-
have "((G1 * G2)⇧†) * (G1 * G2) = 1⇩m (dim_col (G1 * G2))"
proof-
(* all matrices involved are 2^n × 2^n, which the associativity steps below need *)
have f0: "G1 ∈ carrier_mat (2^n) (2^n) ∧ G2 ∈ carrier_mat (2^n) (2^n)
∧ G1⇧† ∈ carrier_mat (2^n) (2^n) ∧ G2⇧† ∈ carrier_mat (2^n) (2^n)
∧ G1 * G2 ∈ carrier_mat (2^n) (2^n)"
using assms gate.dim_row gate.square_mat by auto
have "((G1 * G2)⇧†) * (G1 * G2) = ((G2⇧†) * (G1⇧†)) * (G1 * G2)"
using assms dagger_of_prod gate.dim_row gate.square_mat by simp
also have "... = (G2⇧†) * ((G1⇧†) * (G1 * G2))"
using assms f0 by auto
also have "... = (G2⇧†) * (((G1⇧†) * G1) * G2)"
using assms f0 f0 by auto
also have "... = (G2⇧†) * ((1⇩m (dim_col G1)) * G2)"
using gate.unitary[of n G1] assms unitary_def[of G1] by simp
also have "... = (G2⇧†) * ((1⇩m (dim_col G2)) * G2)"
using assms f0 by (metis carrier_matD(2))
also have "... = (G2⇧†) * G2"
using f0 by (metis carrier_matD(2) left_mult_one_mat)
finally show "((G1 * G2)⇧†) * (G1 * G2) = 1⇩m (dim_col (G1 * G2))"
using assms gate.unitary unitary_def by simp
qed
moreover have "(G1 * G2) * ((G1 * G2)⇧†) = 1⇩m (dim_row (G1 * G2))"
using assms calculation
by (smt carrier_matI dim_col_of_dagger dim_row_of_dagger gate.dim_row gate.square_mat index_mult_mat(2) index_mult_mat(3)
mat_mult_left_right_inverse square_mat.elims(2))
ultimately show ?thesis using unitary_def by simp
qed
qed

(* For unitary U, (U⇧t)⇧† is a left inverse of U⇧t. Proof: conjugate the identity
   U * U⇧† = 1, using (U⇧t)⇧⋆ = U⇧†. *)
lemma left_inv_of_unitary_transpose [simp]:
assumes "unitary U"
shows "(U⇧t)⇧† * (U⇧t) =  1⇩m(dim_row U)"
proof -
have "dim_col U = dim_row ((U⇧t)⇧⋆)" by simp
then have "(U * ((U⇧t)⇧⋆))⇧⋆ = (U⇧⋆) * (U⇧t)"
using cpx_mat_cnj_prod cpx_mat_cnj_cnj by presburger
also have "… = (U⇧t)⇧† * (U⇧t)" by simp
finally show ?thesis
using assms by (metis transpose_cnj_is_dagger cpx_mat_cnj_id unitary_def)
qed

(* For unitary U, (U⇧t)⇧† is a right inverse of U⇧t. Proof: conjugate the identity
   U⇧† * U = 1. *)
lemma right_inv_of_unitary_transpose [simp]:
assumes "unitary U"
shows "U⇧t * ((U⇧t)⇧†) = 1⇩m(dim_col U)"
proof -
have "dim_col ((U⇧t)⇧⋆) = dim_row U" by simp
then have "U⇧t * ((U⇧t)⇧†) = (((U⇧t)⇧⋆ * U)⇧⋆)"
using cpx_mat_cnj_cnj cpx_mat_cnj_prod dagger_of_transpose_is_cnj by presburger
also have "… = (U⇧† * U)⇧⋆" by simp
finally show ?thesis
using assms by (metis cpx_mat_cnj_id unitary_def)
qed

(* The transpose of a unitary matrix is unitary. *)
lemma transpose_of_unitary_is_unitary [simp]:
assumes "unitary U"
shows "unitary (U⇧t)"
using unitary_def assms left_inv_of_unitary_transpose right_inv_of_unitary_transpose by simp

subsection "The Inner Product"

text ‹We introduce a coercion between complex vectors and (column) complex matrices.›

(* |v⟩: the column matrix with dim_vec v rows and one column, whose entries are those of v. *)
definition ket_vec :: "complex vec ⇒ complex mat" ("|_⟩") where
"|v⟩ ≡ mat (dim_vec v) 1 (λ(i,j). v \$ i)"

(* Entry (i,0) of the ket |v⟩ is the i-th entry of v, for i within range. *)
lemma ket_vec_index [simp]:
  assumes "i < dim_vec v"
  shows "|v⟩ \$\$ (i,0) = v \$ i"
  using assms by (simp add: ket_vec_def)

(* The only column of |v⟩ is v. *)
lemma ket_vec_col [simp]:
"col |v⟩ 0 = v"
by (auto simp: col_def ket_vec_def)

(* Scalar multiplication commutes with taking the ket. *)
lemma smult_ket_vec [simp]:
"|x ⋅⇩v v⟩ = x ⋅⇩m |v⟩"
by (auto simp: ket_vec_def)

(* Column-matrix form of smult_vec_length: scaling |v⟩ by a nonnegative real x scales the norm
   of its column by x. *)
lemma smult_vec_length_bis [simp]:
assumes "x ≥ 0"
shows "∥col (complex_of_real(x) ⋅⇩m |v⟩) 0∥ = x * ∥v∥"
using assms smult_ket_vec smult_vec_length ket_vec_col by metis

(* From here on, a complex vec may be used where a complex mat is expected; it is coerced to its
   ket. *)
declare [[coercion ket_vec]]

(* row_vec v: the 1 × dim_vec v matrix whose only row is v. *)
definition row_vec :: "complex vec ⇒ complex mat" where
"row_vec v ≡ mat 1 (dim_vec v) (λ(i,j). v \$ j)"

(* bra_vec v: the row matrix of v with every entry conjugated. *)
definition bra_vec :: "complex vec ⇒ complex mat" where
"bra_vec v ≡ (row_vec v)⇧⋆"

(* The only row of bra_vec v is the entrywise conjugate of v. *)
lemma row_bra_vec [simp]:
"row (bra_vec v) 0 = vec (dim_vec v) (λi. cnj(v \$ i))"
by (auto simp: row_def bra_vec_def cpx_mat_cnj_def row_vec_def)

text ‹We introduce a definition called @{term "bra"} to see a vector as a column matrix.›

(* ⟨v|: a 1 × dim_row v matrix. Since its only row index is 0, its j-th entry is the conjugate
   of the (j,0) entry of v, i.e. the conjugate transpose of the first column of v. *)
definition bra :: "complex mat ⇒ complex mat" ("⟨_|") where
"⟨v| ≡ mat 1 (dim_row v) (λ(i,j). cnj(v \$\$ (j,i)))"

text ‹The relation between @{term "bra"}, @{term "bra_vec"} and @{term "ket_vec"} is given as follows.›

(* The bra of the ket of v is bra_vec v. *)
lemma bra_bra_vec [simp]:
"bra (ket_vec v) = bra_vec v"
by (auto simp: bra_def ket_vec_def bra_vec_def cpx_mat_cnj_def row_vec_def)

(* For a vector v (coerced to its ket), the only row of ⟨v| is the entrywise conjugate of v. *)
lemma row_bra [simp]:
fixes v::"complex vec"
shows "row ⟨v| 0 = vec (dim_vec v) (λi. cnj (v \$ i))" by simp

text ‹We introduce the inner product of two complex vectors in @{text "ℂ⇧n"}.›

(* ⟨u|v⟩ = ∑ cnj (u_i) * v_i: conjugate-linear in u, linear in v. The sum ranges over
   0 ..< dim_vec v; lemmas below that need it assume dim_vec u = dim_vec v. *)
definition inner_prod :: "complex vec ⇒ complex vec ⇒ complex" ("⟨_|_⟩") where
"inner_prod u v ≡ ∑ i ∈ {0..< dim_vec v}. cnj(u \$ i) * (v \$ i)"

(* For equal-length vectors, ⟨u|v⟩ is the scalar product of the row of bra_vec u with v. *)
lemma inner_prod_with_row_bra_vec [simp]:
assumes "dim_vec u = dim_vec v"
shows "⟨u|v⟩ = row (bra_vec u) 0 ∙ v"
using assms inner_prod_def scalar_prod_def row_bra_vec index_vec
by (smt lessThan_atLeast0 lessThan_iff sum.cong)

(* The same, phrased with the row of ⟨u| and the column of |v⟩. *)
lemma inner_prod_with_row_bra_vec_col_ket_vec [simp]:
assumes "dim_vec u = dim_vec v"
shows "⟨u|v⟩ = (row ⟨u| 0) ∙ (col |v⟩ 0)"
using assms by (simp add: inner_prod_def scalar_prod_def)

(* ⟨u|v⟩ is the single entry of the 1 × 1 matrix product ⟨u| * |v⟩. *)
lemma inner_prod_with_times_mat [simp]:
assumes "dim_vec u = dim_vec v"
shows "⟨u|v⟩ = (⟨u| * |v⟩) \$\$ (0,0)"
using assms inner_prod_with_row_bra_vec_col_ket_vec
by (simp add: inner_prod_def times_mat_def ket_vec_def bra_def)

(* Distinct standard basis vectors of the same dimension are orthogonal. *)
lemma orthogonal_unit_vec [simp]:
assumes "i < n" and "j < n" and "i ≠ j"
shows "⟨unit_vec n i|unit_vec n j⟩ = 0"
proof-
(* basis vectors have real entries, so the inner product is the plain scalar product *)
have "⟨unit_vec n i|unit_vec n j⟩ = unit_vec n i ∙ unit_vec n j"
using assms unit_vec_def inner_prod_def scalar_prod_def
by (smt complex_cnj_zero index_unit_vec(3) index_vec inner_prod_with_row_bra_vec row_bra_vec
scalar_prod_right_unit)
thus ?thesis
using assms scalar_prod_def unit_vec_def by simp
qed

text ‹We prove that our inner product is linear in its second argument.›

(* Entry j of the linear combination k ⋅ u + l ⋅ v of two equal-length vectors. *)
lemma vec_index_is_linear [simp]:
assumes "dim_vec u = dim_vec v" and "j < dim_vec u"
shows "(k ⋅⇩v u + l ⋅⇩v v) \$ j = k * (u \$ j) + l * (v \$ j)"
using assms vec_index_def smult_vec_def plus_vec_def by simp

(* The inner product is linear in its second argument. It is stated for a combination of the two
   vectors v 0 and v 1, with coefficients l 0 and l 1, both of the same dimension as u. *)
lemma inner_prod_is_linear [simp]:
  fixes u::"complex vec" and v::"nat ⇒ complex vec" and l::"nat ⇒ complex"
  assumes "∀i∈{0, 1}. dim_vec u = dim_vec (v i)"
  shows "⟨u|l 0 ⋅⇩v v 0 + l 1 ⋅⇩v v 1⟩ = (∑i≤1. l i * ⟨u|v i⟩)"
proof -
  have f1:"dim_vec (l 0 ⋅⇩v v 0 + l 1 ⋅⇩v v 1) = dim_vec u"
    using assms by simp
  (* unfold the inner product; f1 lets the summation range be expressed through u *)
  then have "⟨u|l 0 ⋅⇩v v 0 + l 1 ⋅⇩v v 1⟩ = (∑i∈{0 ..< dim_vec u}. cnj (u \$ i) * ((l 0 ⋅⇩v v 0 + l 1 ⋅⇩v v 1) \$ i))"
    by (simp add: inner_prod_def)
  also have "… = (∑i∈{0 ..< dim_vec u}. cnj (u \$ i) * (l 0 * v 0 \$ i + l 1 * v 1 \$ i))"
    using assms by simp
  also have "… = l 0 * (∑i∈{0 ..< dim_vec u}. cnj(u \$ i) * (v 0 \$ i)) + l 1 * (∑i∈{0 ..< dim_vec u}. cnj(u \$ i) * (v 1 \$ i))"
    by (auto simp: algebra_simps)
  also have "… = l 0 * ⟨u|v 0⟩ + l 1 * ⟨u|v 1⟩"
    using assms inner_prod_def by auto
  finally show ?thesis by simp
qed

(* Swapping the arguments of the inner product conjugates it (for equal-length vectors). *)
lemma inner_prod_cnj:
assumes "dim_vec u = dim_vec v"
shows "⟨v|u⟩ = cnj (⟨u|v⟩)"
by (simp add: assms inner_prod_def algebra_simps)

(* ⟨u|u⟩ has zero imaginary part, since it equals its own conjugate. *)
lemma inner_prod_with_itself_Im [simp]:
"Im (⟨u|u⟩) = 0"
using inner_prod_cnj by (metis Reals_cnj_iff complex_is_Real_iff)

(* ⟨u|u⟩ is a real number. *)
lemma inner_prod_with_itself_real [simp]:
"⟨u|u⟩ ∈ ℝ"
using inner_prod_with_itself_Im by (simp add: complex_is_Real_iff)

(* The zero vector has zero inner product with itself. *)
lemma inner_prod_with_itself_eq0 [simp]:
assumes "u = 0⇩v (dim_vec u)"
shows "⟨u|u⟩ = 0"
using assms inner_prod_def zero_vec_def
by (smt atLeastLessThan_iff complex_cnj_zero index_zero_vec(1) mult_zero_left sum.neutral)

(* The real part of ⟨u|u⟩ is nonnegative: it is a sum of squares (Re u_i)² + (Im u_i)². *)
lemma inner_prod_with_itself_Re:
  "Re (⟨u|u⟩) ≥ 0"
proof -
  have "Re (⟨u|u⟩) = (∑i<dim_vec u. Re (cnj(u \$ i) * (u \$ i)))"
    using inner_prod_def by (simp add: lessThan_atLeast0)
  moreover have "… = (∑i<dim_vec u. (Re (u \$ i))⇧2 + (Im (u \$ i))⇧2)"
    using complex_mult_cnj
    by (metis (no_types, lifting) Re_complex_of_real semiring_normalization_rules(7))
  ultimately show "Re (⟨u|u⟩) ≥ 0" by (simp add: sum_nonneg)
qed

(* ⟨u|u⟩ is a nonnegative real number. *)
lemma inner_prod_with_itself_nonneg_reals:
fixes u::"complex vec"
shows "⟨u|u⟩ ∈ nonneg_Reals"
using inner_prod_with_itself_real inner_prod_with_itself_Re complex_nonneg_Reals_iff
inner_prod_with_itself_Im by auto

(* For a nonzero vector, Re ⟨u|u⟩ is strictly positive: some entry is nonzero, so its summand is
   positive, and every summand is nonnegative. *)
lemma inner_prod_with_itself_Re_non0:
assumes "u ≠ 0⇩v (dim_vec u)"
shows "Re (⟨u|u⟩) > 0"
proof -
obtain i where a1:"i < dim_vec u" and "u \$ i ≠ 0"
using assms zero_vec_def by (metis dim_vec eq_vecI index_zero_vec(1))
then have f1:"Re (cnj (u \$ i) * (u \$ i)) > 0"
by (metis Re_complex_of_real complex_mult_cnj complex_neq_0 mult.commute)
moreover have f2:"Re (⟨u|u⟩) = (∑i<dim_vec u. Re (cnj(u \$ i) * (u \$ i)))"
using inner_prod_def by (simp add: lessThan_atLeast0)
moreover have f3:"∀i<dim_vec u. Re (cnj(u \$ i) * (u \$ i)) ≥ 0"
using complex_mult_cnj by simp
ultimately show ?thesis
using a1 inner_prod_def lessThan_iff
by (metis (no_types, lifting) finite_lessThan sum_pos2)
qed

(* A nonzero vector has a nonzero inner product with itself. *)
lemma inner_prod_with_itself_nonneg_reals_non0:
assumes "u ≠ 0⇩v (dim_vec u)"
shows "⟨u|u⟩ ≠ 0"
using assms inner_prod_with_itself_Re_non0 by fastforce

(* The squared norm of v, coerced to a complex number, equals ⟨v|v⟩. *)
lemma cpx_vec_length_inner_prod [simp]:
"∥v∥⇧2 = ⟨v|v⟩"
proof -
have "∥v∥⇧2 = (∑i<dim_vec v. (cmod (v \$ i))⇧2)"
using cpx_vec_length_def complex_of_real_def
by (metis (no_types, lifting) real_sqrt_power real_sqrt_unique sum_nonneg zero_le_power2)
also have "… = (∑i<dim_vec v. cnj (v \$ i) * (v \$ i))"
using complex_norm_square mult.commute by (smt of_real_sum sum.cong)
finally show ?thesis
using inner_prod_def by (simp add: lessThan_atLeast0)
qed

(* The complex square root of ⟨v|v⟩ is the norm of v (coerced to complex). *)
lemma inner_prod_csqrt [simp]:
"csqrt ⟨v|v⟩ = ∥v∥"
using inner_prod_with_itself_Re inner_prod_with_itself_Im csqrt_of_real_nonneg cpx_vec_length_def
by (metis (no_types, lifting) Re_complex_of_real cpx_vec_length_inner_prod real_sqrt_ge_0_iff
real_sqrt_unique sum_nonneg zero_le_power2)

subsection "Unitary Matrices and Length-Preservation"

subsubsection "Unitary Matrices are Length-Preserving"

text ‹The bra-vector @{text "⟨A * v|"} is given by @{text "⟨v| * A⇧†"}.›

(* The Hermitian conjugate of a ket is the corresponding bra. *)
lemma dagger_of_ket_is_bra:
fixes v:: "complex vec"
shows "( |v⟩ )⇧† = ⟨v|"
by (simp add: bra_def dagger_def ket_vec_def)

(* The bra of A * v (v coerced to its ket) is ⟨v| * A⇧†. Proved by comparing dimensions and
   then the entries of the single row. *)
lemma bra_mat_on_vec:
  fixes v::"complex vec" and A::"complex mat"
  assumes "dim_col A = dim_vec v"
  shows "⟨A * v| = ⟨v| * (A⇧†)"
proof
  show "dim_row ⟨A * v| = dim_row (⟨v| * (A⇧†))"
    using bra_def times_mat_def by simp
next
  show "dim_col ⟨A * v| = dim_col (⟨v| * (A⇧†))"
    using bra_def times_mat_def by simp
next
  fix i j::nat
  assume a1:"i < dim_row (⟨v| * (A⇧†))" and a2:"j < dim_col (⟨v| * (A⇧†))"
  then have "cnj((A * v) \$\$ (j,0)) = cnj (row A j ∙ v)"
    using bra_def times_mat_def ket_vec_col ket_vec_def by simp
  also have f7:"…= (∑i∈{0 ..< dim_vec v}. cnj(v \$ i) * cnj(A \$\$ (j,i)))"
    using row_def scalar_prod_def cnj_sum complex_cnj_mult mult.commute
    by (smt assms index_vec lessThan_atLeast0 lessThan_iff sum.cong)
  moreover have f8:"(row ⟨v| 0) ∙ (col (A⇧†) j) =
    vec (dim_vec v) (λi. cnj (v \$ i)) ∙ vec (dim_col A) (λi. cnj (A \$\$ (j,i)))"
    using a2 by simp
  ultimately have "cnj((A * v) \$\$ (j,0)) = (row ⟨v| 0) ∙ (col (A⇧†) j)"
    using assms scalar_prod_def
    by (smt dim_vec index_vec lessThan_atLeast0 lessThan_iff sum.cong)
  then have "⟨A * v| \$\$ (0,j) = (⟨v| * (A⇧†)) \$\$ (0,j)"
    using bra_def times_mat_def a2 by simp
  (* the only row index is 0, so this covers entry (i,j) *)
  thus "⟨A * |v⟩| \$\$ (i, j) = (⟨v| * (A⇧†)) \$\$ (i, j)"
    using a1 by (simp add: times_mat_def bra_def)
qed

(* A * |v⟩ equals the ket of its own first column (the A * v inside col is coerced to
   A * |v⟩). *)
lemma mat_on_ket:
fixes v:: "complex vec" and A:: "complex mat"
assumes "dim_col A = dim_vec v"
shows "A * |v⟩ = |col (A * v) 0⟩"
using assms ket_vec_def by auto

(* (A * |v⟩)⇧† = ⟨v| * A⇧†. *)
lemma dagger_of_mat_on_ket:
fixes v:: "complex vec" and A :: "complex mat"
assumes "dim_col A = dim_vec v"
shows "(A * |v⟩ )⇧† = ⟨v| * (A⇧†)"
using assms by (metis bra_mat_on_vec dagger_of_ket_is_bra mat_on_ket)

(* col_fst A: the vector of entries (i,0) of A, i.e. its first column. *)
definition col_fst :: "'a mat ⇒ 'a vec" where
"col_fst A = vec (dim_row A) (λ i. A \$\$ (i,0))"

(* col_fst M is the column of M with index 0; after unfolding both definitions the two sides are
   the same vector. *)
lemma col_fst_is_col [simp]:
  "col_fst M = col M 0"
  by (simp add: col_fst_def col_def)

text ‹
We need to declare @{term "col_fst"} as a coercion from matrices to vectors in order to see a column
matrix as a vector.
›

(* Replace the ket_vec coercion by col_fst: from here on a matrix (typically a column matrix) can
   be used where a vector is expected. *)
declare
[[coercion_delete ket_vec]]
[[coercion col_fst]]

(* Column i of A is A applied to the i-th standard basis vector. *)
lemma unit_vec_to_col:
assumes "dim_col A = n" and "i < n"
shows "col A i = A * |unit_vec n i⟩"
proof
show "dim_vec (col A i) = dim_vec (A * |unit_vec n i⟩)"
using col_def times_mat_def by simp
next
fix j::nat
assume "j < dim_vec (col_fst (A * |unit_vec n i⟩))"
then show "col A i \$ j = (A * |unit_vec n i⟩) \$ j"
using assms times_mat_def ket_vec_def
by (smt col_fst_is_col dim_col dim_col_mat(1) index_col index_mult_mat(1) index_mult_mat(2)
index_row(1) ket_vec_col less_numeral_extra(1) scalar_prod_right_unit)
qed

(* With col_fst as the coercion, the ket of (the first column of) A * |v⟩ is A * |v⟩ itself. *)
lemma mult_ket_vec_is_ket_vec_of_mult:
fixes A::"complex mat" and v::"complex vec"
assumes "dim_col A = dim_vec v"
shows "|A * |v⟩ ⟩ = A * |v⟩"
using assms ket_vec_def
by (metis One_nat_def col_fst_is_col dim_col dim_col_mat(1) index_mult_mat(3) ket_vec_col less_Suc0
mat_col_eqI)

(* A unitary U preserves the squared norm of every vector of matching dimension:
   ⟨Uv|Uv⟩ = ⟨v| U⇧† U |v⟩ = ⟨v|v⟩. *)
lemma unitary_is_sq_length_preserving [simp]:
assumes "unitary U" and "dim_vec v = dim_col U"
shows "∥U * |v⟩∥⇧2 = ∥v∥⇧2"
proof -
have "⟨U * |v⟩|U * |v⟩ ⟩ = (⟨|v⟩| * (U⇧†) * |U * |v⟩⟩) \$\$ (0,0)"
using assms(2) bra_mat_on_vec
by (metis inner_prod_with_times_mat mult_ket_vec_is_ket_vec_of_mult)
then have "⟨U * |v⟩|U * |v⟩ ⟩ = (⟨|v⟩| * (U⇧†) * (U * |v⟩)) \$\$ (0,0)"
using assms(2) mult_ket_vec_is_ket_vec_of_mult by simp
moreover have f1:"dim_col ⟨|v⟩| = dim_vec v"
using ket_vec_def bra_def by simp
moreover have "dim_row (U⇧†) = dim_vec v"
using assms(2) by simp
(* reassociate so that U⇧† * U appears *)
ultimately have "⟨U * |v⟩|U * |v⟩ ⟩ = (⟨|v⟩| * ((U⇧†) * U) * |v⟩) \$\$ (0,0)"
using assoc_mult_mat
by(smt carrier_mat_triv dim_row_mat(1) dagger_def ket_vec_def mat_carrier times_mat_def)
then have "⟨U * |v⟩|U * |v⟩ ⟩ = (⟨|v⟩| * |v⟩) \$\$ (0,0)"
using assms f1 unitary_def by simp
thus ?thesis
using cpx_vec_length_inner_prod by(metis Re_complex_of_real inner_prod_with_times_mat)
qed

(* A single-column matrix is the ket of its only column. *)
lemma col_ket_vec [simp]:
assumes "dim_col M = 1"
shows "|col M 0⟩ = M"
using eq_matI assms ket_vec_def by auto

(* The ket of the column of a one-qubit state is again a one-qubit state. *)
lemma state_col_ket_vec:
assumes "state 1 v"
shows "state 1 |col v 0⟩"
using assms by (simp add: state_def)

(* Entry (i,0) of the ket of the first column of v is entry (i,0) of v. *)
lemma col_ket_vec_index [simp]:
assumes "i < dim_row v"
shows "|col v 0⟩ \$\$ (i,0) = v \$\$ (i,0)"
using assms ket_vec_def by (simp add: col_def)

(* For a single-column matrix v, entry i of its column (as a vector) is entry (i,0) of v. *)
lemma col_index_of_mat_col [simp]:
  assumes "dim_col v = 1" and "i < dim_row v"
  shows "col v 0 \$ i = v \$\$ (i,0)"
  by (simp add: assms)

(* Variant of unitary_is_sq_length_preserving for a single-column matrix v. *)
lemma unitary_is_sq_length_preserving_bis [simp]:
assumes "unitary U" and "dim_row v = dim_col U" and "dim_col v = 1"
shows "∥col (U * v) 0∥⇧2 = ∥col v 0∥⇧2"
proof -
have "dim_vec (col v 0) = dim_col U"
using assms(2) by simp
then have "∥col_fst (U * |col v 0⟩)∥⇧2 = ∥col v 0∥⇧2"
using unitary_is_sq_length_preserving[of "U" "col v 0"] assms(1) by simp
thus ?thesis
using assms(3) by simp
qed

text ‹
A unitary matrix is length-preserving, i.e. it acts on a vector to produce another vector of the
same length.
›

(* A unitary matrix preserves the norm of the column of a single-column matrix. *)
lemma unitary_is_length_preserving_bis [simp]:
fixes U::"complex mat" and v::"complex mat"
assumes "unitary U" and "dim_row v = dim_col U" and "dim_col v = 1"
shows "∥col (U * v) 0∥ = ∥col v 0∥"
using assms unitary_is_sq_length_preserving_bis
by (metis cpx_vec_length_inner_prod inner_prod_csqrt of_real_hom.injectivity)

(* A unitary matrix preserves the norm of every vector of matching dimension. *)
lemma unitary_is_length_preserving [simp]:
fixes U:: "complex mat" and v:: "complex vec"
assumes "unitary U" and "dim_vec v = dim_col U"
shows "∥U * |v⟩∥ = ∥v∥"
using assms unitary_is_sq_length_preserving
by (metis cpx_vec_length_inner_prod inner_prod_csqrt of_real_hom.injectivity)

subsubsection "Length-Preserving Matrices are Unitary"

(* For a square B with dim_row B = dim_col A: if A * B is the identity, then so is B * A.
   Proof outline: det B ≠ 0, since otherwise some nonzero v with B v = 0 would give
   v = (A * B) v = 0. Hence B has the right inverse (1/det B) ⋅ adj_mat B, which must equal A. *)
lemma inverts_mat_sym:
fixes A B:: "complex mat"
assumes "inverts_mat A B" and "dim_row B = dim_col A" and "square_mat B"
shows "inverts_mat B A"
proof-
define n where d0:"n = dim_row B"
have "A * B = 1⇩m (dim_row A)" using assms(1) inverts_mat_def by auto
moreover have "dim_col B = dim_col (A * B)" using times_mat_def by simp
ultimately have "dim_col B = dim_row A" by simp
then have c0:"A ∈ carrier_mat n n" using assms(2,3) d0 by auto
have c1:"B ∈ carrier_mat n n" using assms(3) d0 by auto
have f0:"A * B = 1⇩m n" using inverts_mat_def c0 c1 assms(1) by auto
(* B is nonsingular *)
have f1:"det B ≠ 0"
proof
assume "det B = 0"
then have "∃v. v ∈ carrier_vec n ∧ v ≠ 0⇩v n ∧ B *⇩v v = 0⇩v n"
using det_0_iff_vec_prod_zero assms(3) c1 by blast
then obtain v where d1:"v ∈ carrier_vec n ∧ v ≠ 0⇩v n ∧ B *⇩v v = 0⇩v n" by auto
then have d2:"dim_vec v = n" by simp
have "B * |v⟩ = |0⇩v n⟩"
proof
show "dim_row (B * |v⟩) = dim_row |0⇩v n⟩" using ket_vec_def d0 by simp
next
show "dim_col (B * |v⟩) = dim_col |0⇩v n⟩" using ket_vec_def d0 by simp
next
fix i j assume "i < dim_row |0⇩v n⟩" and "j < dim_col |0⇩v n⟩"
then have f2:"i < n ∧ j = 0" using ket_vec_def by simp
moreover have "vec (dim_row B) ((\$) v) = v" using d0 d1 by auto
moreover have "(B *⇩v v) \$ i = (∑ia = 0..<dim_row B. row B i \$ ia * v \$ ia)"
using d0 d2 f2 by (auto simp add: scalar_prod_def)
ultimately show "(B * |v⟩) \$\$ (i, j) = |0⇩v n⟩ \$\$ (i, j)"
using ket_vec_def d0 d1 times_mat_def mult_mat_vec_def by (auto simp add: scalar_prod_def)
qed
moreover have "|v⟩ ∈ carrier_mat n 1" using d2 ket_vec_def by simp
ultimately have "(A * B) * |v⟩ = A * |0⇩v n⟩" using c0 c1 by simp
then have f3:"|v⟩ = A * |0⇩v n⟩" using d2 f0 ket_vec_def by auto
have "v = 0⇩v n"
proof
show "dim_vec v = dim_vec (0⇩v n)" using d2 by simp
next
fix i assume f4:"i < dim_vec (0⇩v n)"
then have "|v⟩ \$\$ (i,0) = v \$ i" using d2 ket_vec_def by simp
moreover have "(A * |0⇩v n⟩) \$\$ (i, 0) = 0"
using ket_vec_def times_mat_def scalar_prod_def f4 c0 by auto
ultimately show "v \$ i = 0⇩v n \$ i" using f3 f4 by simp
qed
then show False using d1 by simp
qed
(* build a right inverse of B from its adjugate and show it coincides with A *)
have f5:"adj_mat B ∈ carrier_mat n n ∧ B * adj_mat B = det B ⋅⇩m 1⇩m n" using c1 adj_mat by auto
then have c2:"((1/det B) ⋅⇩m adj_mat B) ∈ carrier_mat n n" by simp
have f6:"B * ((1/det B) ⋅⇩m adj_mat B) = 1⇩m n" using c1 f1 f5 mult_smult_distrib[of "B"] by auto
then have "A = (A * B) * ((1/det B) ⋅⇩m adj_mat B)" using c0 c1 c2 by simp
then have "A = (1/det B) ⋅⇩m adj_mat B" using f0 c2 by auto
then show ?thesis using c0 c1 f6 inverts_mat_def by auto
qed

(* For distinct indices i, j < n, the squared norm of e_i + c ⋅ e_j is 1 + cnj c * c. *)
lemma sum_of_unit_vec_length:
fixes i j n:: nat and c:: complex
assumes "i < n" and "j < n" and "i ≠ j"
shows "∥unit_vec n i + c ⋅⇩v unit_vec n j∥⇧2 = 1 + cnj(c) * c"
proof-
define v where d0:"v = unit_vec n i + c ⋅⇩v unit_vec n j"
have "∀k<n. v \$ k = (if k = i then 1 else (if k = j then c else 0))"
using d0 assms(1,2,3) by auto
then have "∀k<n. cnj (v \$ k) * v \$ k = (if k = i then 1 else 0) + (if k = j then cnj(c) * c else 0)"
using assms(3) by auto
moreover have "∥v∥⇧2 = (∑k = 0..<n. cnj (v \$ k) * v \$ k)"
using d0 assms cpx_vec_length_inner_prod inner_prod_def by simp
ultimately show ?thesis
using d0 assms by (auto simp add: sum.distrib)
qed

(* A applied to e_i + c ⋅ e_j is col A i + c ⋅ col A j, for i, j < dim_col A. *)
lemma sum_of_unit_vec_to_col:
assumes "dim_col A = n" and "i < n" and "j < n"
shows "col A i + c ⋅⇩v col A j = A * |unit_vec n i + c ⋅⇩v unit_vec n j⟩"
proof
show "dim_vec (col A i + c ⋅⇩v col A j) = dim_vec (col_fst (A * |unit_vec n i + c ⋅⇩v unit_vec n j⟩))"
using assms(1) by auto
next
fix k assume "k < dim_vec (col_fst (A * |unit_vec n i + c ⋅⇩v unit_vec n j⟩))"
then have f0:"k < dim_row A" using assms(1) by auto
have "(col A i + c ⋅⇩v col A j) \$ k = A \$\$ (k, i) + c * A \$\$ (k, j)"
using f0 assms(1-3) by auto
(* split the k-th entry of the product into the contributions of e_i and c ⋅ e_j *)
moreover have "(∑x<n. A \$\$ (k, x) * ((if x = i then 1 else 0) + c * (if x = j then 1 else 0))) =
(∑x<n. A \$\$ (k, x) * (if x = i then 1 else 0)) +
(∑x<n. A \$\$ (k, x) * c * (if x = j then 1 else 0))"
by (auto simp add: sum.distrib algebra_simps)
moreover have "∀x<n. A \$\$ (k, x) * (if x = i then 1 else 0) = (if x = i then A \$\$ (k, x) else 0)"
by simp
moreover have "∀x<n. A \$\$ (k, x) * c * (if x = j then 1 else 0) = (if x = j then A \$\$ (k, x) * c else 0)"
by simp
ultimately show "(col A i + c ⋅⇩v col A j) \$ k = col_fst (A * |unit_vec n i + c ⋅⇩v unit_vec n j⟩) \$ k"
using f0 assms(1-3) times_mat_def scalar_prod_def ket_vec_def by auto
qed

(* The inner product is sesquilinear: conjugate-linear in its first argument and linear in its
   second, for vectors of a common dimension n. *)
lemma inner_prod_is_sesquilinear:
  fixes u1 u2 v1 v2:: "complex vec" and c1 c2 c3 c4:: complex and n:: nat
  assumes "dim_vec u1 = n" and "dim_vec u2 = n" and "dim_vec v1 = n" and "dim_vec v2 = n"
  shows "⟨c1 ⋅⇩v u1 + c2 ⋅⇩v u2|c3 ⋅⇩v v1 + c4 ⋅⇩v v2⟩ = cnj (c1) * c3 * ⟨u1|v1⟩ + cnj (c2) * c3 * ⟨u2|v1⟩ +
    cnj (c1) * c4 * ⟨u1|v2⟩ + cnj (c2) * c4 * ⟨u2|v2⟩"
proof-
  (* linearity in the second argument *)
  have "⟨c1 ⋅⇩v u1 + c2 ⋅⇩v u2|c3 ⋅⇩v v1 + c4 ⋅⇩v v2⟩ = c3 * ⟨c1 ⋅⇩v u1 + c2 ⋅⇩v u2|v1⟩ + c4 * ⟨c1 ⋅⇩v u1 + c2 ⋅⇩v u2|v2⟩"
    using inner_prod_is_linear[of "c1 ⋅⇩v u1 + c2 ⋅⇩v u2" "λi. if i = 0 then v1 else v2"
        "λi. if i = 0 then c3 else c4"] assms
    by simp
  (* swap the arguments to move the first-argument combination into the linear position *)
  also have "... = c3 * cnj(⟨v1|c1 ⋅⇩v u1 + c2 ⋅⇩v u2⟩) + c4 * cnj(⟨v2|c1 ⋅⇩v u1 + c2 ⋅⇩v u2⟩)"
    using assms inner_prod_cnj[of "v1" "c1 ⋅⇩v u1 + c2 ⋅⇩v u2"] inner_prod_cnj[of "v2" "c1 ⋅⇩v u1 + c2 ⋅⇩v u2"]
    by simp
  also have "... = c3 * cnj(c1 * ⟨v1|u1⟩ + c2 * ⟨v1|u2⟩) + c4 * cnj(c1 * ⟨v2|u1⟩ + c2 * ⟨v2|u2⟩)"
    using inner_prod_is_linear[of "v1" "λi. if i = 0 then u1 else u2" "λi. if i = 0 then c1 else c2"]
      inner_prod_is_linear[of "v2" "λi. if i = 0 then u1 else u2" "λi. if i = 0 then c1 else c2"] assms
    by simp
  also have "... = c3 * (cnj(c1) * ⟨u1|v1⟩ + cnj(c2) * ⟨u2|v1⟩) +
    c4 * (cnj(c1) * ⟨u1|v2⟩ + cnj(c2) * ⟨u2|v2⟩)"
    using inner_prod_cnj[of "v1" "u1"] inner_prod_cnj[of "v1" "u2"]
      inner_prod_cnj[of "v2" "u1"] inner_prod_cnj[of "v2" "u2"] assms
    by simp
  (* what remains is distributing c3 and c4 and reordering the products *)
  finally show ?thesis
    by (simp add: algebra_simps)
qed

text ‹
A length-preserving matrix is unitary. So, unitary matrices are exactly the length-preserving
matrices.
›

(* A square complex matrix U that preserves the length of every vector of dimension dim_col U
   is unitary.
   Proof outline: first show U⇧† * U = 1⇩m (dim_col U) entry by entry (fact f0).
   - Diagonal entries (f1): the squared length of col U l equals that of unit_vec n l, i.e. 1.
   - Off-diagonal entries: U preserves the lengths of e_i + e_j and e_i + 𝗂 e_j (unit vectors,
     i ≠ j). Expanding both squared lengths by sesquilinearity gives
     ⟨col U j|col U i⟩ + ⟨col U i|col U j⟩ = 0 and ⟨col U j|col U i⟩ - ⟨col U i|col U j⟩ = 0.
   Then U * U⇧† = 1⇩m (dim_row U) follows from inverts_mat_sym, since U is square. *)
lemma length_preserving_is_unitary:
fixes U:: "complex mat"
assumes "square_mat U" and "∀v::complex vec. dim_vec v = dim_col U ⟶ ∥U * |v⟩∥ = ∥v∥"
shows "unitary U"
proof-
define n where "n = dim_col U"
then have c0:"U ∈ carrier_mat n n" using assms(1) by auto
then have c1:"U⇧† ∈ carrier_mat n n" using assms(1) dagger_def by auto
have f0:"(U⇧†) * U = 1⇩m (dim_col U)"
proof
show "dim_row (U⇧† * U) = dim_row (1⇩m (dim_col U))" using c0 by simp
next
show "dim_col (U⇧† * U) = dim_col (1⇩m (dim_col U))" using c0 by simp
next
fix i j assume "i < dim_row (1⇩m (dim_col U))" and "j < dim_col (1⇩m (dim_col U))"
then have a0:"i < n ∧ j < n" using c0 by simp
(* Diagonal entries of U⇧† * U: every column of U is the image of a unit vector, so has length 1. *)
have f1:"⋀l. l<n ⟶ (∑k<n. cnj (U \$\$ (k, l)) * U \$\$ (k, l)) = 1"
proof
fix l assume a1:"l<n"
define v::"complex vec" where d1:"v = unit_vec n l"
have "∥col U l∥⇧2 = (∑k<n. cnj (U \$\$ (k, l)) * U \$\$ (k, l))"
using c0 a1 cpx_vec_length_inner_prod inner_prod_def lessThan_atLeast0 by simp
moreover have "∥col U l∥⇧2 = ∥v∥⇧2" using c0 d1 a1 assms(2) unit_vec_to_col by simp
moreover have "∥v∥⇧2 = 1" using d1 a1 cpx_vec_length_inner_prod by simp
ultimately show "(∑k<n. cnj (U \$\$ (k, l)) * U \$\$ (k, l)) = 1" by simp
qed
(* Off-diagonal entries of U⇧† * U vanish. *)
moreover have "i ≠ j ⟶ (∑k<n. cnj (U \$\$ (k, i)) * U \$\$ (k, j)) = 0"
proof
assume a2:"i ≠ j"
define v1::"complex vec" where d1:"v1 = unit_vec n i + 1 ⋅⇩v unit_vec n j"
define v2::"complex vec" where d2:"v2 = unit_vec n i + 𝗂 ⋅⇩v unit_vec n j"
(* First test vector v1 = e_i + e_j: its squared length is 2 before and after applying U. *)
have "∥v1∥⇧2 = 1 + cnj 1 * 1" using d1 a0 a2 sum_of_unit_vec_length by blast
then have "∥v1∥⇧2 = 2"
by (metis complex_cnj_one cpx_vec_length_inner_prod mult.left_neutral of_real_eq_iff
of_real_numeral one_add_one)
then have "∥U * |v1⟩∥⇧2 = 2" using c0 d1 assms(2) unit_vec_to_col by simp
moreover have "col U i + 1 ⋅⇩v col U j = U * |v1⟩"
using c0 d1 a0 sum_of_unit_vec_to_col by blast
moreover have "col U i + 1 ⋅⇩v col U j = col U i + col U j" by simp
ultimately have "⟨col U i + col U j|col U i + col U j⟩ = 2"
using cpx_vec_length_inner_prod by (metis of_real_numeral)
moreover have "⟨col U i + col U j|col U i + col U j⟩ =
⟨col U i|col U i⟩ + ⟨col U j|col U i⟩ + ⟨col U i|col U j⟩ + ⟨col U j|col U j⟩"
using inner_prod_is_sesquilinear[of "col U i" "dim_row U" "col U j" "col U i" "col U j" "1" "1" "1" "1"]
by simp
ultimately have f2:"⟨col U j|col U i⟩ + ⟨col U i|col U j⟩ = 0"
using c0 a0 f1 inner_prod_def lessThan_atLeast0 by simp

(* Second test vector v2 = e_i + 𝗂 e_j: yields the difference of the cross terms. *)
have "∥v2∥⇧2 = 1 + cnj 𝗂 * 𝗂" using a0 a2 d2 sum_of_unit_vec_length by simp
then have "∥v2∥⇧2 = 2"
by (metis Re_complex_of_real complex_norm_square mult.commute norm_ii numeral_Bit0
numeral_One numeral_eq_one_iff of_real_numeral one_power2)
moreover have "∥U * |v2⟩∥⇧2 = ∥v2∥⇧2" using c0 d2 assms(2) unit_vec_to_col by simp
moreover have "⟨col U i + 𝗂 ⋅⇩v col U j|col U i + 𝗂 ⋅⇩v col U j⟩ = ∥U * |v2⟩∥⇧2"
using c0 a0 d2 sum_of_unit_vec_to_col cpx_vec_length_inner_prod by auto
moreover have "⟨col U i + 𝗂 ⋅⇩v col U j|col U i + 𝗂 ⋅⇩v col U j⟩ =
⟨col U i|col U i⟩ + (-𝗂) * ⟨col U j|col U i⟩ + 𝗂 * ⟨col U i|col U j⟩ + ⟨col U j|col U j⟩"
using inner_prod_is_sesquilinear[of "col U i" "dim_row U" "col U j" "col U i" "col U j" "1" "𝗂" "1" "𝗂"]
by simp
ultimately have "⟨col U j|col U i⟩ - ⟨col U i|col U j⟩ = 0"
using c0 a0 f1 inner_prod_def lessThan_atLeast0 by auto
then show "(∑k<n. cnj (U \$\$ (k, i)) * U \$\$ (k, j)) = 0"
using c0 a0 f2 lessThan_atLeast0 inner_prod_def by auto
qed
ultimately show "(U⇧† * U) \$\$ (i, j) = 1⇩m (dim_col U) \$\$ (i, j)"
using c0 assms(1) a0 one_mat_def dagger_def by auto
qed
(* For square matrices a left inverse is also a right inverse. *)
then have "(U⇧†) * U = 1⇩m n" using c0 by simp
then have "inverts_mat (U⇧†) U" using c1 inverts_mat_def by auto
then have "inverts_mat U (U⇧†)" using c0 c1 inverts_mat_sym by simp
then have "U * (U⇧†) = 1⇩m (dim_row U)" using c0 inverts_mat_def by auto
then show ?thesis using f0 unitary_def by simp
qed

(* A unitary matrix preserves the inner product: for u and v of dimension dim_col U,
   the inner product of U * |u⟩ and U * |v⟩ equals ⟨u|v⟩.
   Both inner products are written as the (0,0) entry of a bra-ket matrix product. The product
   is reassociated so that U⇧† * U appears, and that factor is cancelled using unitary_def. *)
lemma inner_prod_with_unitary_mat [simp]:
assumes "unitary U" and "dim_vec u = dim_col U" and "dim_vec v = dim_col U"
shows "⟨U * |u⟩|U * |v⟩⟩ = ⟨u|v⟩"
proof -
(* Left-hand side as the (0,0) entry of the bra of u, times U⇧†, times U, times the ket of v. *)
have f1:"⟨U * |u⟩|U * |v⟩⟩ = (⟨|u⟩| * (U⇧†) * U * |v⟩) \$\$ (0,0)"
using assms(2-3) bra_mat_on_vec mult_ket_vec_is_ket_vec_of_mult
by (smt assoc_mult_mat carrier_mat_triv col_fst_def dim_vec dim_col_of_dagger index_mult_mat(2)
index_mult_mat(3) inner_prod_with_times_mat ket_vec_def mat_carrier)
(* Carrier facts needed to reassociate the matrix product. *)
moreover have f2:"⟨|u⟩| ∈ carrier_mat 1 (dim_vec v)"
using bra_def ket_vec_def assms(2-3) by simp
moreover have f3:"U⇧† ∈ carrier_mat (dim_col U) (dim_row U)"
using dagger_def by simp
ultimately have "⟨U * |u⟩|U * |v⟩⟩ = (⟨|u⟩| * (U⇧† * U) * |v⟩) \$\$ (0,0)"
using assms(3) assoc_mult_mat by (metis carrier_mat_triv)
(* U⇧† * U is the identity because U is unitary. *)
also have "… = (⟨|u⟩| * |v⟩) \$\$ (0,0)"
using assms(1) unitary_def
by (simp add: assms(2) bra_def ket_vec_def)
finally show ?thesis
using assms(2-3) inner_prod_with_times_mat by presburger
qed

text ‹As a consequence, we prove that the columns and the rows of a unitary matrix are orthonormal vectors.›

(* Every column of a unitary matrix has length 1. Column i is the image of unit_vec n i
   (unit_vec_to_col), and unitary matrices preserve length. *)
lemma unitary_unit_col [simp]:
assumes "unitary U" and "dim_col U = n" and "i < n"
shows "∥col U i∥ = 1"
using assms unit_vec_to_col unitary_is_length_preserving by simp

(* Every row of a unitary matrix has length 1. Row i of U is column i of the transpose, which is
   again unitary, so unitary_unit_col applies. *)
lemma unitary_unit_row [simp]:
assumes "unitary U" and "dim_row U = n" and "i < n"
shows "∥row U i∥ = 1"
proof -
have "row U i = col (U⇧t) i"
using  assms(2-3) by simp
thus ?thesis
using assms transpose_of_unitary_is_unitary unitary_unit_col
by (metis index_transpose_mat(3))
qed

(* Distinct columns of a unitary matrix are orthogonal. The columns are the images of two distinct
   unit vectors, U preserves the inner product, and distinct unit vectors have inner product 0. *)
lemma orthogonal_col_of_unitary [simp]:
assumes "unitary U" and "dim_col U = n" and "i < n" and "j < n" and "i ≠ j"
shows "⟨col U i|col U j⟩ = 0"
proof -
have "⟨col U i|col U j⟩ = ⟨U * |unit_vec n i⟩| U * |unit_vec n j⟩⟩"
using assms(2-4) unit_vec_to_col by simp
also have "… = ⟨unit_vec n i |unit_vec n j⟩"
using assms(1-2) inner_prod_with_unitary_mat index_unit_vec(3) by simp
finally show ?thesis
using assms(3-5) by simp
qed

(* Distinct rows of a unitary matrix are orthogonal. Rows of U are columns of the unitary
   transpose, so orthogonal_col_of_unitary applies. *)
lemma orthogonal_row_of_unitary [simp]:
fixes U::"complex mat"
assumes "unitary U" and "dim_row U = n" and "i < n" and "j < n" and "i ≠ j"
shows "⟨row U i|row U j⟩ = 0"
using assms orthogonal_col_of_unitary transpose_of_unitary_is_unitary col_transpose
by (metis index_transpose_mat(3))

text‹
As a consequence, we prove that a quantum gate acting on a state of a system of n qubits gives
another state of that same system.
›

(* Applying a gate on n qubits to a state of n qubits yields a state of n qubits.
   The product has 2^n rows and a single column. Because A is unitary and its column count
   matches the row count of v, the length of that single column is preserved and so is still 1. *)
lemma gate_on_state_is_state [intro, simp]:
assumes a1:"gate n A" and a2:"state n v"
shows "state n (A * v)"
proof
show "dim_row (A * v) = 2^n"
using gate_def state_def a1 by simp
next
show "dim_col (A * v) = 1"
using state_def a2 by simp
next
(* Normalisation: establish the dimension match needed by unitary_is_length_preserving_bis. *)
have "square_mat A"
using a1 gate_def by simp
then have "dim_col A = 2^n"
using a1 gate.dim_row by simp
then have "dim_col A = dim_row v"
using a2 state.dim_row by simp
then have "∥col (A * v) 0∥ = ∥col v 0∥"
using unitary_is_length_preserving_bis assms gate_def state_def by simp
thus"∥col (A * v) 0∥ = 1"
using a2 state.is_normal by simp
qed

subsection ‹A Few Well-known Quantum Gates›

text ‹
Any unitary operation on n qubits can be implemented exactly by composing single-qubit gates and
CNOT gates (controlled-NOT gates). However, no straightforward method is known to implement these
gates in a fashion that is resistant to errors. By contrast, the Hadamard gate, the phase gate, the
CNOT gate and the @{text "π/8"} gate are also universal for quantum computation, i.e. any quantum
circuit on n qubits can be approximated to arbitrary accuracy using only these gates, and these
gates can be implemented in a fault-tolerant way.
›

text ‹We introduce a coercion from real matrices to complex matrices.›

(* Entry-wise conversion of a real matrix into a complex matrix with the same dimensions.
   NOTE(review): the real entry is turned into a complex number by an implicit coercion
   (presumably complex_of_real); confirm against the coercion setup of the theory. *)
definition real_to_cpx_mat:: "real mat ⇒ complex mat" where
"real_to_cpx_mat A ≡ mat (dim_row A) (dim_col A) (λ(i,j). A \$\$ (i,j))"

text ‹Our first quantum gate: the identity matrix! Arguably, not a very interesting one though!›

(* The identity gate on n qubits: the identity matrix of size 2^n. *)
definition Id :: "nat ⇒ complex mat" where
"Id n ≡ 1⇩m (2^n)"

(* Id n is a gate on n qubits: it has 2^n rows, is square, and is unitary. *)
lemma id_is_gate [simp]:
"gate n (Id n)"
proof
show "dim_row (Id n) = 2^n"
using Id_def by simp
next
show "square_mat (Id n)"
using Id_def by simp
next
(* After unfolding Id_def the goal is unitarity of an identity matrix.
   NOTE(review): this relies on a simp rule for that fact (presumably id_is_unitary, stated
   earlier in the theory); confirm. *)
show "unitary (Id n)"
using Id_def by simp
qed

text ‹More interesting: the Pauli matrices.›

(* Pauli X: the 2x2 matrix with 0 on the diagonal and 1 off the diagonal. *)
definition X ::"complex mat" where
"X ≡ mat 2 2 (λ(i,j). if i=j then 0 else 1)"

text‹
Be aware that @{text "gate n A"} means that the matrix A has dimension @{text "2^n * 2^n"}.
For instance, with this convention a 2-by-2 unitary matrix A satisfies @{text "gate 1 A"}
but not @{text "gate 2 A"}, as one might have expected.
›

(* X is self-adjoint; proved entry-wise via cong_mat. *)
lemma dagger_of_X [simp]:
"X⇧† = X"
using dagger_def by (simp add: X_def cong_mat)

(* X is its own inverse. The product is compared with the identity entry by entry after cong_mat. *)
lemma X_inv [simp]:
"X * X = 1⇩m 2"
apply(rule cong_mat)
by(auto simp: scalar_prod_def)

(* X is a one-qubit gate.
   The first simp call unfolds gate and unitary while X is still folded, so the simp rules
   dagger_of_X and X_inv can fire. The second call unfolds X_def to compute the dimensions of X. *)
lemma X_is_gate [simp]:
"gate 1 X"
by (simp add: gate_def unitary_def) (simp add: X_def)

(* Pauli Y: zero diagonal, entry (0,1) is -𝗂 and entry (1,0) is 𝗂. *)
definition Y ::"complex mat" where
"Y ≡ mat 2 2 (λ(i,j). if i=j then 0 else (if i=0 then -𝗂 else 𝗂))"

(* Y is self-adjoint; proved entry-wise via cong_mat. *)
lemma dagger_of_Y [simp]:
"Y⇧† = Y"
using dagger_def by (simp add: Y_def cong_mat)

(* Y is its own inverse; proved entry-wise after cong_mat. *)
lemma Y_inv [simp]:
"Y * Y = 1⇩m 2"
apply(rule cong_mat)
by(auto simp: scalar_prod_def)

(* Y is a one-qubit gate.
   The first simp call unfolds gate and unitary while Y is still folded, so dagger_of_Y and Y_inv
   can fire. The second call unfolds Y_def to compute the dimensions of Y. *)
lemma Y_is_gate [simp]:
"gate 1 Y"
by (simp add: gate_def unitary_def) (simp add: Y_def)

(* Pauli Z: the diagonal matrix with entries 1 and -1. *)
definition Z ::"complex mat" where
"Z ≡ mat 2 2 (λ(i,j). if i≠j then 0 else (if i=0 then 1 else -1))"

(* Z is self-adjoint; proved entry-wise via cong_mat. *)
lemma dagger_of_Z [simp]:
"Z⇧† = Z"
using dagger_def by (simp add: Z_def cong_mat)

(* Z is its own inverse; proved entry-wise after cong_mat. *)
lemma Z_inv [simp]:
"Z * Z = 1⇩m 2"
apply(rule cong_mat)
by(auto simp: scalar_prod_def)

(* Z is a one-qubit gate.
   The first simp call unfolds gate and unitary while Z is still folded, so dagger_of_Z and Z_inv
   can fire. The second call unfolds Z_def to compute the dimensions of Z. *)
lemma Z_is_gate [simp]:
"gate 1 Z"
by (simp add: gate_def unitary_def) (simp add: Z_def)

(* Hadamard gate: 1/sqrt 2 times the 2x2 matrix with entries 1, 1 / 1, -1. *)
definition H ::"complex mat" where
"H ≡ 1/sqrt(2) ⋅⇩m (mat 2 2 (λ(i,j). if i≠j then 1 else (if i=0 then 1 else -1)))"

(* H with the scalar factor 1/sqrt 2 pushed into the entries. *)
lemma H_without_scalar_prod:
"H = mat 2 2 (λ(i,j). if i≠j then 1/sqrt(2) else (if i=0 then 1/sqrt(2) else -(1/sqrt(2))))"
using cong_mat by (auto simp: H_def)

(* H is self-adjoint; proved entry-wise via cong_mat. *)
lemma dagger_of_H [simp]:
"H⇧† = H"
using dagger_def by (auto simp: H_def cong_mat)

(* H is its own inverse. complex_eqI is needed because the entries involve products of 1/sqrt 2. *)
lemma H_inv [simp]:
"H * H = 1⇩m 2"
apply(rule cong_mat)
by(auto simp: scalar_prod_def complex_eqI)

(* H is a one-qubit gate.
   The first simp call unfolds gate and unitary while H is still folded, so dagger_of_H and H_inv
   can fire. The second call unfolds H_def to compute the dimensions of H. *)
lemma H_is_gate [simp]:
"gate 1 H"
by (simp add: gate_def unitary_def) (simp add: H_def)

(* Every entry of H other than the bottom-right one, i.e. (i,j) ≠ (1,1), equals 1/sqrt 2.
   The proof first bounds i and j by 2 and then case-splits on their values. *)
lemma H_values:
fixes i j:: nat
assumes "i < dim_row H" and "j < dim_col H" and "i ≠ 1 ∨ j ≠ 1"
shows "H \$\$ (i,j) = 1/sqrt 2"
proof-
have "i < 2"
using assms(1) by (simp add: H_without_scalar_prod less_2_cases)
moreover have "j < 2"
using assms(2) by (simp add: H_without_scalar_prod less_2_cases)
ultimately show ?thesis
using assms(3) H_without_scalar_prod by(smt One_nat_def index_mat(1) less_2_cases old.prod.case)
qed

(* The bottom-right entry (1,1) of H is - 1/sqrt 2. *)
lemma H_values_right_bottom:
fixes i j:: nat
assumes "i = 1 ∧ j = 1"
shows "H \$\$ (i,j) = - 1/sqrt 2"
using assms by (simp add: H_without_scalar_prod)

text ‹The controlled-NOT gate›

(* The 4x4 matrix with 1 at (0,0), (1,1), (2,3) and (3,2) and 0 elsewhere: the permutation
   matrix that fixes basis indices 0 and 1 and swaps indices 2 and 3. *)
definition CNOT ::"complex mat" where
"CNOT ≡ mat 4 4
(λ(i,j). if i=0 ∧ j=0 then 1 else
(if i=1 ∧ j=1 then 1 else
(if i=2 ∧ j=3 then 1 else
(if i=3 ∧ j=2 then 1 else 0))))"

(* CNOT is self-adjoint; proved entry-wise via cong_mat. *)
lemma dagger_of_CNOT [simp]:
"CNOT⇧† = CNOT"
using dagger_def by (simp add: CNOT_def cong_mat)

(* CNOT is its own inverse; proved entry-wise after cong_mat. *)
lemma CNOT_inv [simp]:
"CNOT * CNOT = 1⇩m 4"
apply(rule cong_mat)
by(auto simp: scalar_prod_def)

(* CNOT is a two-qubit gate (a 4x4 unitary, 4 = 2^2).
   The first simp call unfolds gate and unitary while CNOT is still folded, so dagger_of_CNOT and
   CNOT_inv can fire. The second call unfolds CNOT_def to compute the dimensions of CNOT. *)
lemma CNOT_is_gate [simp]:
"gate 2 CNOT"
by (simp add: gate_def unitary_def) (simp add: CNOT_def)

text ‹The phase gate, also known as the S-gate›

(* Phase gate S: the diagonal matrix with entries 1 and 𝗂. *)
definition S ::"complex mat" where
"S ≡ mat 2 2 (λ(i,j). if i=0 ∧ j=0 then 1 else (if i=1 ∧ j=1 then 𝗂 else 0))"

text ‹The @{text "π/8"} gate, also known as the T-gate›

(* T gate: the diagonal matrix with entries 1 and exp(𝗂 π/4). *)
definition T ::"complex mat" where
"T ≡ mat 2 2 (λ(i,j). if i=0 ∧ j=0 then 1 else (if i=1 ∧ j=1 then exp(𝗂*(pi/4)) else 0))"

text ‹A few relations between the Hadamard gate and the Pauli matrices›

(* Conjugating X by the Hadamard gate gives Z.
   Unfold all three matrices and the product, reduce to entry-wise equalities with cong_mat,
   and discharge them as in H_inv. *)
lemma HXH_is_Z [simp]:
"H * X * H = Z"
apply(simp add: X_def Z_def H_def times_mat_def)
apply(rule cong_mat)
by(auto simp: scalar_prod_def complex_eqI)

(* Conjugating Y by the Hadamard gate gives -Y.
   Unfold the matrices and the product, reduce to entry-wise equalities with eq_matI
   (which also handles the negated matrix), and discharge them as in H_inv. *)
lemma HYH_is_minusY [simp]:
"H * Y * H = - Y"
apply(simp add: Y_def H_def times_mat_def)
apply(rule eq_matI)
by(auto simp: scalar_prod_def complex_eqI)

(* Conjugating Z by the Hadamard gate gives X.
   Unfold all three matrices and the product, reduce to entry-wise equalities with cong_mat,
   and discharge them as in H_inv. *)
lemma HZH_is_X [simp]:
shows "H * Z * H = X"
apply(simp add: X_def Z_def H_def times_mat_def)
apply(rule cong_mat)
by(auto simp: scalar_prod_def complex_eqI)

subsection ‹The Bell States›

text ‹
We introduce below the so-called Bell states, also known as EPR pairs (EPR stands for Einstein,
Podolsky and Rosen).
›

(* Bell state β00: 1/sqrt 2 times the column (ket) of the length-4 vector with 1 at indices
   0 and 3 and 0 elsewhere. *)
definition bell00 ::"complex mat" ("|β⇩0⇩0⟩") where
"bell00 ≡ 1/sqrt(2) ⋅⇩m |vec 4 (λi. if i=0 ∨ i=3 then 1 else 0)⟩"

(* Bell state β01: 1/sqrt 2 times the ket of the length-4 vector with 1 at indices 1 and 2. *)
definition bell01 ::"complex mat" ("|β⇩0⇩1⟩") where
"bell01 ≡ 1/sqrt(2) ⋅⇩m |vec 4 (λi. if i=1 ∨ i=2 then 1 else 0)⟩"

(* Bell state β10: 1/sqrt 2 times the ket of the length-4 vector with 1 at index 0 and -1 at
   index 3. *)
definition bell10 ::"complex mat" ("|β⇩1⇩0⟩") where
"bell10 ≡ 1/sqrt(2) ⋅⇩m |vec 4 (λi. if i=0 then 1 else if i=3 then -1 else 0)⟩"

(* Bell state β11: 1/sqrt 2 times the ket of the length-4 vector with 1 at index 1 and -1 at
   index 2. *)
definition bell11 ::"complex mat" ("|β⇩1⇩1⟩") where
"bell11 ≡ 1/sqrt(2) ⋅⇩m |vec 4 (λi. if i=1 then 1 else if i=2 then -1 else 0)⟩"

(* Each of the four Bell states is a state of 2 qubits. The first auto call unfolds the
   definitions; the second computes the vector lengths. *)
lemma
shows bell00_is_state [simp]:"state 2 |β⇩0⇩0⟩" and bell01_is_state [simp]:"state 2 |β⇩0⇩1⟩" and
bell10_is_state [simp]:"state 2 |β⇩1⇩0⟩" and bell11_is_state [simp]:"state 2 |β⇩1⇩1⟩"
by (auto simp: state_def bell00_def bell01_def bell10_def bell11_def ket_vec_def)
(auto simp: cpx_vec_length_def Set_Interval.lessThan_atLeast0 cmod_def power2_eq_square)

(* Entries of the single column of β00. *)
lemma bell00_index [simp]:
shows "|β⇩0⇩0⟩ \$\$ (0,0) = 1/sqrt 2" and "|β⇩0⇩0⟩ \$\$ (1,0) = 0" and "|β⇩0⇩0⟩ \$\$ (2,0) = 0" and
"|β⇩0⇩0⟩ \$\$ (3,0) = 1/sqrt 2"
by (auto simp: bell00_def ket_vec_def)

(* Entries of the single column of β01. *)
lemma bell01_index [simp]:
shows "|β⇩0⇩1⟩ \$\$ (0,0) = 0" and "|β⇩0⇩1⟩ \$\$ (1,0) = 1/sqrt 2" and "|β⇩0⇩1⟩ \$\$ (2,0) = 1/sqrt 2" and
"|β⇩0⇩1⟩ \$\$ (3,0) = 0"
by (auto simp: bell01_def ket_vec_def)

(* Entries of the single column of β10. *)
lemma bell10_index [simp]:
shows "|β⇩1⇩0⟩ \$\$ (0,0) = 1/sqrt 2" and "|β⇩1⇩0⟩ \$\$ (1,0) = 0" and "|β⇩1⇩0⟩ \$\$ (2,0) = 0" and
"|β⇩1⇩0⟩ \$\$ (3,0) = - 1/sqrt 2"
by (auto simp: bell10_def ket_vec_def)

(* Entries of the single column of β11. NOTE(review): the name bell_11_index differs from the
   bellXY_index pattern of its siblings; renaming would break external references, so it is kept. *)
lemma bell_11_index [simp]:
shows "|β⇩1⇩1⟩ \$\$ (0,0) = 0" and "|β⇩1⇩1⟩ \$\$ (1,0) = 1/sqrt 2" and "|β⇩1⇩1⟩ \$\$ (2,0) = - 1/sqrt 2" and
"|β⇩1⇩1⟩ \$\$ (3,0) = 0"
by (auto simp: bell11_def ket_vec_def)

subsection ‹The Bitwise Inner Product›

(* Bitwise inner product of i and j written with n bits: the sum, over positions k < n, of the
   product of the k-th entries of bin_rep n i and bin_rep n j. The result is the plain
   natural-number sum (no reduction modulo 2 happens here). *)
definition bitwise_inner_prod:: "nat ⇒ nat ⇒ nat ⇒ nat" where
"bitwise_inner_prod n i j = (∑k∈{0..<n}. (bin_rep n i) ! k * (bin_rep n j) ! k)"

abbreviation bip:: "nat ⇒ nat ⇒ nat ⇒ nat" ("_ ⋅⇘_⇙  _") where
"```