(* Theory Linear_Algebra_Complements *)

(*
Author: 
  Mnacho Echenim, Université Grenoble Alpes
*)

theory Linear_Algebra_Complements imports 
  "Isabelle_Marries_Dirac.Tensor" 
  "Isabelle_Marries_Dirac.More_Tensor"
  "QHLProver.Gates" 
  "HOL-Types_To_Sets.Group_On_With" 
  "HOL-Probability.Probability" 


begin
(* Hide the unqualified name of the constant S; it stays reachable through its
   qualified name. Presumably this frees the name S for the carrier-set parameter S
   used in the Types-To-Sets locale contexts below -- confirm which imported theory
   declares the constant S. *)
hide_const(open) S
section ‹Preliminaries›


subsection ‹Misc›

(* If the complex number a lies in Reals, then a times the real part of b equals the
   real part of a * b. Both sides are compared as complex numbers: the real values
   Re b and Re (a * b) are embedded into the complex numbers. *)
lemma mult_real_cpx:
  fixes a::complex
  fixes b::complex
  assumes "a ∈ Reals"
  shows "a* (Re b) = Re (a * b)" using assms
  by (metis Reals_cases complex.exhaust complex.sel(1) complex_of_real_mult_Complex of_real_mult)

(* For a real-valued function f on the complex numbers with f(-1) + f 1 = 1 and both
   values nonnegative, the difference f 1 - f(-1) lies between -1 and 1. *)
lemma fct_bound:
  fixes f::"complex ⇒ real"
  assumes "f(-1) + f 1 = 1"
and "0 ≤ f 1"
and "0 ≤ f (-1)"
shows "-1 ≤ f 1  - f(-1) ∧ f 1  - f(-1) ≤ 1"
proof
  (* lower bound: rewrite f 1 as 1 - f(-1) *)
  have "f 1  - f(-1) = 1 - f(-1) - f(-1)"  using assms by simp
  also have "... ≥ -1" using assms by simp
  finally show "-1 ≤ f 1  - f(-1)" .
next
  (* upper bound: symmetric argument on f(-1) - f 1 *)
  have "f(-1) - f 1  = 1 - f 1  - f 1 " using assms by simp
  also have "... ≥ -1" using assms by simp
  finally have "-1 ≤ f(-1) - f 1" .
  thus "f 1 - f (-1) ≤ 1" by simp
qed

(* Absolute-value form of fct_bound: under the same hypotheses,
   the absolute value of f 1 - f(-1) is at most 1. *)
lemma fct_bound':
  fixes f::"complex ⇒ real"
  assumes "f(-1) + f 1 = 1"
and "0 ≤ f 1"
and "0 ≤ f (-1)"
shows "¦f 1  - f(-1)¦ ≤ 1" using assms fct_bound by auto

(* Every term of a finite family of nonnegative reals whose sum is 1 is at most 1. *)
lemma pos_sum_1_le:
  assumes "finite I"
and "∀ i ∈ I. (0::real) ≤ f i"
and "(∑i∈I. f i) = 1"
and "j∈I"
shows "f j ≤ 1"
proof (rule ccontr)
  assume "¬ f j ≤ 1"
  hence "1 < f j" by simp
  (* a single term exceeding 1 forces the nonnegative sum to exceed 1 *)
  hence "1 < (∑i∈I. f i)" using assms by (metis ‹¬ f j ≤ 1› sum_nonneg_leq_bound) 
  thus False using assms by simp
qed

(* For distinct a and b, a nonempty subset of {a, b} that is neither {a} nor {a, b}
   must be {b}. *)
lemma last_subset:
  assumes "A ⊆ {a,b}"
  and "a≠b"
and "A ≠ {a, b}"
and "A≠{}"
and "A ≠ {a}"
shows "A = {b}" using assms by blast

(* If A is a disjoint family on insert x F and x is not in F, then A x is disjoint
   from the union of the sets A a for a in F. This is the disjointness fact needed in
   the induction step of disj_family_sum. *)
lemma disjoint_Un:
  assumes "disjoint_family_on A (insert x F)"
  and "x∉ F"
shows "(A x) ∩ (⋃ a∈F. A a) = {}" 
proof -
  have "(A x) ∩ (⋃ a∈F. A a) = (⋃i∈F. (A x) ∩ A i)" using Int_UN_distrib by simp
  (* each A x ∩ A i is empty since x ≠ i for i in F *)
  also have "... = (⋃i∈F. {})" using assms disjoint_family_onD by fastforce
  also have "... = {}" using SUP_bot_conv(2) by simp
  finally show ?thesis .
qed

(* If f j vanishes for every j < n other than i, and i < n, then the sum over
   {0..<n} of f j * g j collapses to the single term f i * g i. *)
lemma sum_but_one:
  assumes "∀j < (n::nat). j ≠ i ⟶ f j = (0::'a::ring)"
  and "i < n"
  shows "(∑ j ∈ {0 ..< n}. f j * g j) = f i * g i"
proof -
  (* split off the index i from the summation range *)
  have "sum (λx. f x * g x) (({0 ..< n} - {i}) ∪ {i}) = sum (λx. f x * g x) ({0 ..< n} - {i}) + 
    sum (λx. f x * g x) {i}" by (rule sum.union_disjoint, auto)
  also have "... = sum (λx. f x * g x) {i}" using assms by auto
  also have "... = f i * g i" by simp
  finally have "sum (λx. f x * g x) (({0 ..< n} - {i}) ∪ {i}) = f i * g i" .
  moreover have "{0 ..< n} = ({0 ..< n} - {i}) ∪ {i}" using assms by auto
  ultimately show ?thesis by simp
qed

(* A sum over a two-element set {a, b} with a ≠ b is f a + f b. *)
lemma sum_2_elems:
  assumes "I = {a,b}"
    and "a≠b"
  shows "(∑a∈I. f a) = f a + f b" 
proof -
  have "(∑a∈I. f a) = (∑a∈(insert a {b}). f a)" using assms by simp
  also have "... = f a + (∑a∈{b}. f a)" 
  proof (rule sum.insert)
    show "finite {b}" by simp
    show "a∉ {b}" using assms by simp
  qed
  also have "... = f a + f b" by simp
  finally show ?thesis .
qed

(* Unfolds a sum over the natural numbers i < 4 into its four terms. *)
lemma sum_4_elems:
  shows "(∑i<(4::nat). f i) = f 0 + f 1 + f 2 + f 3" 
proof -
  have "(∑i<(4::nat). f i) = (∑i<(3::nat). f i)  + f 3"
    by (metis Suc_numeral semiring_norm(2) semiring_norm(8) sum.lessThan_Suc)
  moreover have "(∑i<(3::nat). f i) = (∑i<(2::nat). f i) + f 2"
    by (metis Suc_1 add_2_eq_Suc' nat_1_add_1 numeral_code(3) numerals(1) 
        one_plus_numeral_commute sum.lessThan_Suc)
  moreover have "(∑i<(2::nat). f i) = (∑i<(1::nat). f i) + f 1"
    by (metis Suc_1 sum.lessThan_Suc)
  ultimately show ?thesis by simp
qed

(* Summing f over the union of a pairwise disjoint family of finite sets A n, indexed
   by a finite set I, equals summing over n in I the sums of f over each A n.
   Proved by induction on the finite index set I. *)
lemma disj_family_sum:
  shows "finite I ⟹ disjoint_family_on A I ⟹ (⋀i. i ∈ I ⟹ finite (A i)) ⟹  
  (∑ i ∈ (⋃n ∈ I. A n). f i) = (∑ n∈I. (∑ i ∈ A n. f i))" 
proof (induct rule:finite_induct)
  case empty
  then show ?case by simp
next
  case (insert x F)
  hence "disjoint_family_on A F"
    by (meson disjoint_family_on_mono subset_insertI) 
  have "(⋃n ∈ (insert x F). A n) = A x ∪ (⋃n ∈ F. A n)" using insert by simp
  hence "(∑ i ∈ (⋃n ∈ (insert x F). A n). f i) = (∑ i ∈ (A x ∪ (⋃n ∈ F. A n)). f i)" by simp
  (* A x is disjoint from the remaining union (lemma disjoint_Un) *)
  also have "... = (∑ i ∈  A x. f i) + (∑ i ∈ (⋃n ∈ F. A n). f i)" 
    by (rule sum.union_disjoint, (simp add: insert disjoint_Un)+)
  (* induction hypothesis on F *)
  also have "... = (∑ i ∈  A x. f i) + (∑n∈F. sum f (A n))" using ‹disjoint_family_on A F› 
    by (simp add: insert)
  also have "... = (∑n∈(insert x F). sum f (A n))" using insert by simp
  finally show ?case .
qed

(* Integrability with respect to M is preserved by multiplying with a real constant c.
   The case c = 0 is closed by simp; otherwise integrable_mult_right is applied. *)
lemma integrable_real_mult_right:
  fixes c::real
  assumes "integrable M f"
  shows "integrable M (λw. c * f w)" 
proof (cases "c = 0")
  case True
  thus ?thesis by simp
next
  case False
  thus ?thesis using integrable_mult_right[of c] assms by simp
qed


subsection ‹Unifying notions between Isabelle Marries Dirac and QHLProver›

(* A complex number times its conjugate is its squared modulus
   (the real square embedded into the complex numbers). *)
lemma mult_conj_cmod_square:
  fixes z::complex
  shows "z * conjugate z = (cmod z)⇧2"
proof -
  have "z * conjugate z = (Re z)⇧2 + (Im z)⇧2" using  complex_mult_cnj by auto
  also have "... = (cmod z)⇧2" unfolding cmod_def by simp
  finally show ?thesis .
qed

(* The squares of the two vector-length notions vec_norm and cpx_vec_length coincide:
   both reduce to the sum of the squared moduli of the entries of v. *)
lemma vec_norm_sq_cpx_vec_length_sq:
  shows "(vec_norm v)⇧2 = (cpx_vec_length v)⇧2"
proof -
  have "(vec_norm v)⇧2 = inner_prod v v" unfolding vec_norm_def using power2_csqrt by blast
  also have "... = (∑i<dim_vec v. (cmod (Matrix.vec_index v i))⇧2)" unfolding Matrix.scalar_prod_def
  proof -
    (* entrywise: v i times its conjugate is the squared modulus of v i *)
    have "⋀i. i < dim_vec v ⟹ Matrix.vec_index v  i * conjugate (Matrix.vec_index v i) = 
      (cmod (Matrix.vec_index v i))⇧2" using mult_conj_cmod_square by simp
    thus "(∑i = 0..<dim_vec (conjugate v). Matrix.vec_index v i * 
      Matrix.vec_index (conjugate v) i) =  (∑i<dim_vec v. (cmod (Matrix.vec_index v i))⇧2)" 
      by (simp add: lessThan_atLeast0)
  qed
  finally show "(vec_norm v)⇧2 = (cpx_vec_length v)⇧2" unfolding cpx_vec_length_def
    by (simp add: sum_nonneg)
qed

(* The two vector-length notions vec_norm and cpx_vec_length agree on every vector. *)
lemma vec_norm_eq_cpx_vec_length:
  shows "vec_norm v = cpx_vec_length v" using vec_norm_sq_cpx_vec_length_sq
by (metis cpx_vec_length_inner_prod inner_prod_csqrt power2_csqrt vec_norm_def) 

(* The square of cpx_vec_length v is the sum, over the indices of v, of the squared
   moduli of its entries. *)
lemma cpx_vec_length_square:
  shows "(cpx_vec_length v)⇧2 = (∑i = 0..<dim_vec v. (cmod (Matrix.vec_index v i))⇧2)" unfolding cpx_vec_length_def
  by (simp add: lessThan_atLeast0 sum_nonneg)

(* A vector in state_qbit n has length 1, hence squared length 1. *)
lemma state_qbit_norm_sq:
  assumes "v∈ state_qbit n"
  shows "(cpx_vec_length v)⇧2 = 1"
proof -
  have "cpx_vec_length v = 1" using assms unfolding state_qbit_def by simp
  thus ?thesis by simp
qed

(* The operations dagger and Complex_Matrix.adjoint coincide; proved by unfolding
   both definitions. *)
lemma dagger_adjoint:
shows "dagger M = Complex_Matrix.adjoint M" unfolding dagger_def Complex_Matrix.adjoint_def
  by (simp add: cong_mat)


subsection ‹Types to sets lemmata transfers›

context ab_group_add_on_with begin

(* Local Types-To-Sets setup: assume a type 's represented by the carrier S.
   sum.union_disjoint, disj_family_sum and sum.reindex_cong are instantiated on 's
   (var_simplified / unoverload_type / untransferred); outside this inner context the
   assumption ltd is discharged with cancel_type_definition and carrier_ne, yielding
   versions stated over the carrier S. *)
context includes lifting_syntax assumes ltd: "∃(Rep::'s ⇒ 'a) (Abs::'a ⇒ 's). 
  type_definition Rep Abs S" begin
interpretation local_typedef_ab_group_add_on_with pls z mns um S "TYPE('s)" by unfold_locales fact

lemmas lt_sum_union_disjoint = sum.union_disjoint
  [var_simplified explicit_ab_group_add,
    unoverload_type 'c,
    OF type.comm_monoid_add_axioms,
    untransferred]

lemmas lt_disj_family_sum = disj_family_sum
  [var_simplified explicit_ab_group_add,
    unoverload_type 'd,
OF type.comm_monoid_add_axioms,
    untransferred]

lemmas lt_sum_reindex_cong = sum.reindex_cong
  [var_simplified explicit_ab_group_add,
    unoverload_type 'd,
OF type.comm_monoid_add_axioms,
    untransferred]
end

(* Discharge the local type definition: carrier_ne provides the nonemptiness of S
   required by cancel_type_definition. *)
lemmas sum_with_union_disjoint =
  lt_sum_union_disjoint
    [cancel_type_definition,
    OF carrier_ne,
    simplified pred_fun_def, simplified]

lemmas disj_family_sum_with =
  lt_disj_family_sum
    [cancel_type_definition,
    OF carrier_ne,
    simplified pred_fun_def, simplified]

lemmas sum_with_reindex_cong = 
  lt_sum_reindex_cong
    [cancel_type_definition,
    OF carrier_ne,
    simplified pred_fun_def, simplified]

end

(* Congruence rule for sum_with: over a finite index set I, two functions that agree
   on I and take values in the carrier S on I have the same sum_with.
   Proved by induction on I, using sum_with_insert on both sides. *)
lemma (in comm_monoid_add_on_with) sum_with_cong':
  shows "finite I ⟹ (⋀i. i∈ I ⟹ A i = B i) ⟹ (⋀i. i∈ I ⟹ A i ∈ S) ⟹ 
    (⋀i. i∈ I ⟹ B i ∈ S) ⟹ sum_with pls z A I = sum_with pls z B I"
proof (induct rule: finite_induct)
  case empty
  then show ?case by simp
next
  case (insert x F)
  have "sum_with pls z A (insert x F) = pls (A x) (sum_with pls z A F)" using insert 
      sum_with_insert[of A] by (simp add:  image_subset_iff) 
  also have "... = pls (B x)  (sum_with pls z B F)" using insert by simp
  also have "... = sum_with pls z B (insert x F)" using insert sum_with_insert[of B]
    by (simp add:  image_subset_iff) 
  finally show ?case .
qed


section ‹Linear algebra complements›

subsection ‹Additional properties of matrices›

(* Scalar multiplication of a matrix by 1 is the identity. *)
lemma smult_one:
  shows "(1::'a::monoid_mult) ⋅⇩m A = A" by (simp add:eq_matI)

(* For an n x n matrix A and a diagonal n x n matrix B, entry i of row j of A * B is
   the i-th diagonal entry of B times A $$ (j, i). *)
lemma times_diag_index:
  fixes A::"'a::comm_ring Matrix.mat"
  assumes "A ∈ carrier_mat n n"
and "B∈ carrier_mat n n"
and "diagonal_mat B"
and "j < n"
and "i < n"
shows "Matrix.vec_index (Matrix.row (A*B) j) i = diag_mat B ! i *A $$ (j, i)"
proof -
  have "Matrix.vec_index (Matrix.row (A*B) j) i = (A*B) $$ (j,i)" 
    using Matrix.row_def[of "A*B" ] assms by simp
  also have "... = Matrix.scalar_prod (Matrix.row A j) (Matrix.col B i)" using assms 
      times_mat_def[of A] by simp
  also have "... = Matrix.scalar_prod (Matrix.col B i) (Matrix.row A j)" 
    using comm_scalar_prod[of "Matrix.row A j" n] assms by auto
  (* B is diagonal: only the i-th term of the scalar product survives *)
  also have "... = (Matrix.vec_index (Matrix.col B i) i) * (Matrix.vec_index  (Matrix.row A j) i)" 
    unfolding Matrix.scalar_prod_def 
  proof (rule sum_but_one)
    show "i < dim_vec (Matrix.row A j)" using assms by simp
    show "∀ia<dim_vec (Matrix.row A j). ia ≠ i ⟶ Matrix.vec_index (Matrix.col B i) ia = 0" using assms
      by (metis carrier_matD(1) carrier_matD(2) diagonal_mat_def index_col index_row(2))
  qed
  also have "... = B $$(i,i) * (Matrix.vec_index  (Matrix.row A j) i)" using assms by auto
  also have "... = diag_mat B ! i * (Matrix.vec_index  (Matrix.row A j) i)" unfolding diag_mat_def 
    using assms by simp
  also have "... = diag_mat B ! i * A $$ (j, i)" using assms by simp
  finally show ?thesis .
qed

(* The inner product of column i of V with column j of U equals entry (i, j) of the
   product of the adjoint of V with U. *)
lemma inner_prod_adjoint_comp:
  assumes "(U::'a::conjugatable_field Matrix.mat) ∈ carrier_mat n n"
and "(V::'a::conjugatable_field Matrix.mat) ∈ carrier_mat n n"
and "i < n"
and "j < n"
shows "Complex_Matrix.inner_prod  (Matrix.col V i) (Matrix.col U j) = 
  ((Complex_Matrix.adjoint V) * U) $$ (i, j)"
proof -
  have "Complex_Matrix.inner_prod (Matrix.col V i) (Matrix.col U j) = 
    Matrix.scalar_prod (Matrix.col U j) (Matrix.row (Complex_Matrix.adjoint V) i)"
    using adjoint_row[of i V] assms  by simp  
  also have "... = Matrix.scalar_prod (Matrix.row (Complex_Matrix.adjoint V) i) (Matrix.col U j)"
    by (metis adjoint_row assms(1) assms(2) assms(3) carrier_matD(1) carrier_matD(2) Matrix.col_dim 
        conjugate_vec_sprod_comm)
  also have "... = ((Complex_Matrix.adjoint V) * U) $$ (i, j)" using assms 
    by (simp add:times_mat_def)
  finally show ?thesis .
qed

(* Multiplying an n x n matrix A by the i-th unit vector of dimension n extracts
   column i of A. Proved entrywise via eq_vecI. *)
lemma mat_unit_vec_col:
  assumes "(A::'a::conjugatable_field Matrix.mat) ∈ carrier_mat n n"
and "i < n"
shows "A *⇩v (unit_vec n i) = Matrix.col A i"
proof
  show "dim_vec (A *⇩v unit_vec n i) = dim_vec (Matrix.col A i)" using assms by simp
  show "⋀j. j < dim_vec (Matrix.col A i) ⟹ Matrix.vec_index (A *⇩v unit_vec n i)  j = 
    Matrix.vec_index (Matrix.col A i)  j"
  proof -
    fix j
    assume "j < dim_vec (Matrix.col A i)"
    hence "Matrix.vec_index (A *⇩v unit_vec n i)  j = 
      Matrix.scalar_prod (Matrix.row A j) (unit_vec n i)" unfolding mult_mat_vec_def by simp
    also have "... = Matrix.scalar_prod  (unit_vec n i) (Matrix.row A j)" using comm_scalar_prod
        assms by auto
    (* the unit vector vanishes outside index i *)
    also have "... = (Matrix.vec_index (unit_vec n i) i) * (Matrix.vec_index (Matrix.row A j) i)"
      unfolding Matrix.scalar_prod_def 
    proof (rule sum_but_one)
      show "i < dim_vec (Matrix.row A j)" using assms by auto
      show "∀ia<dim_vec (Matrix.row A j). ia ≠ i ⟶ Matrix.vec_index (unit_vec n i) ia = 0" 
        using assms unfolding unit_vec_def by auto
    qed
    also have "... = (Matrix.vec_index (Matrix.row A j) i)" using assms by simp
    also have "... = A $$ (j, i)" using assms ‹j < dim_vec (Matrix.col A i)› by simp
    also have "... = Matrix.vec_index (Matrix.col A i)  j" using assms ‹j < dim_vec (Matrix.col A i)› by simp
    finally show "Matrix.vec_index (A *⇩v unit_vec n i)  j = 
      Matrix.vec_index (Matrix.col A i)  j" .
  qed
qed

(* Two n x n matrices that agree on every unit vector of dimension n are equal:
   by mat_unit_vec_col they have the same columns. *)
lemma mat_prod_unit_vec_cong:
  assumes "(A::'a::conjugatable_field Matrix.mat) ∈ carrier_mat n n"
and "B∈ carrier_mat n n"
and "⋀i. i < n ⟹ A *⇩v (unit_vec n i) = B *⇩v (unit_vec n i)"
shows "A = B"
proof
  show "dim_row A = dim_row B" using assms by simp
  show "dim_col A = dim_col B" using assms by simp
  show "⋀i j. i < dim_row B ⟹ j < dim_col B ⟹ A $$ (i, j) = B $$ (i, j)"
  proof -
    fix i j
    assume ij: "i < dim_row B" "j < dim_col B"
    hence "A $$ (i,j) = Matrix.vec_index (Matrix.col A j) i" using assms by simp
    also have "... = Matrix.vec_index (A *⇩v (unit_vec n j)) i" using mat_unit_vec_col[of A] ij assms 
      by simp
    also have "... = Matrix.vec_index (B *⇩v (unit_vec n j)) i" using assms ij by simp
    also have "... = Matrix.vec_index (Matrix.col B j) i" using mat_unit_vec_col ij assms by simp
    also have "... = B $$ (i,j)" using assms ij by simp
    finally show "A $$ (i, j) = B $$ (i, j)" .
  qed
qed

(* Iterated scalar multiplication of a matrix: scaling by k and then by a is the same
   as scaling by a * k. Proved entrywise using associativity of multiplication. *)
lemma smult_smult_times:
  fixes a::"'a::semigroup_mult"
  shows "a⋅⇩m (k ⋅⇩m A) = (a * k)⋅⇩m A"
proof
  show r:"dim_row (a ⋅⇩m (k ⋅⇩m A)) = dim_row (a * k ⋅⇩m A)" by simp
  show c:"dim_col (a ⋅⇩m (k ⋅⇩m A)) = dim_col (a * k ⋅⇩m A)" by simp
  show "⋀i j. i < dim_row (a * k ⋅⇩m A) ⟹
           j < dim_col (a * k ⋅⇩m A) ⟹ (a ⋅⇩m (k ⋅⇩m A)) $$ (i, j) = (a * k ⋅⇩m A) $$ (i, j)"
  proof -
    fix i j
    assume "i < dim_row (a * k ⋅⇩m A)" and "j < dim_col (a * k ⋅⇩m A)" note ij = this
    hence "(a ⋅⇩m (k ⋅⇩m A)) $$ (i, j) = a * (k ⋅⇩m A) $$ (i, j)"  by simp
    also have "... = a * (k * A $$ (i,j))" using ij by simp
    also have "... = (a * k) * A $$ (i,j)"
      by (simp add: semigroup_mult_class.mult.assoc)
    also have "... = (a * k ⋅⇩m A) $$ (i, j)" using r c ij by simp
    finally show "(a ⋅⇩m (k ⋅⇩m A)) $$ (i, j) = (a * k ⋅⇩m A) $$ (i, j)" .
  qed
qed

(* For matrices of the same dimensions n x m, A - (B - C) = A - B + C.
   Proved entrywise. *)
lemma mat_minus_minus:
  fixes A :: "'a :: ab_group_add Matrix.mat"
  assumes "A ∈ carrier_mat n m"
  and "B∈ carrier_mat n m"
  and "C∈ carrier_mat n m"
shows "A - (B - C) = A - B + C"
proof
  show "dim_row (A - (B - C)) = dim_row (A - B + C)" using assms by simp
  show "dim_col (A - (B - C)) = dim_col (A - B + C)" using assms by simp
  show "⋀i j. i < dim_row (A - B + C) ⟹ j < dim_col (A - B + C) ⟹ 
    (A - (B - C)) $$ (i, j) = (A - B + C) $$ (i, j)" 
  proof -
    fix i j
    assume "i < dim_row (A - B + C)" and "j < dim_col (A - B + C)" note ij = this
    have "(A - (B - C)) $$ (i, j) = (A $$ (i,j) - B $$ (i,j) + C $$ (i,j))" using ij assms by simp
    also have "... = (A - B + C) $$ (i, j)" using assms ij by simp
    finally show "(A - (B - C)) $$ (i, j) = (A - B + C) $$ (i, j)" .
  qed
qed


subsection ‹Complements on complex matrices›

(* A Hermitian matrix is square: it has as many columns as rows. *)
lemma hermitian_square:
  assumes "hermitian M"
  shows "M ∈ carrier_mat (dim_row M) (dim_row M)"
proof -
  have "dim_col M = dim_row M" using assms unfolding hermitian_def adjoint_def
    by (metis adjoint_dim_col)
  thus ?thesis by auto
qed

(* Hermitian n x n matrices are closed under addition. *)
lemma hermitian_add:
  assumes "A∈ carrier_mat n n"
  and "B∈ carrier_mat n n"
and "hermitian A"
and "hermitian B"
shows "hermitian (A + B)" unfolding hermitian_def
  by (metis adjoint_add assms hermitian_def)

(* Hermitian n x n matrices are closed under subtraction. *)
lemma hermitian_minus:
  assumes "A∈ carrier_mat n n"
  and "B∈ carrier_mat n n"
and "hermitian A"
and "hermitian B"
shows "hermitian (A - B)" unfolding hermitian_def
  by (metis adjoint_minus assms hermitian_def)

(* Scaling a Hermitian complex n x n matrix by a real number a yields a Hermitian
   matrix. Proved entrywise using adjoint_scale. *)
lemma hermitian_smult:
  fixes a::real
  fixes A::"complex Matrix.mat"
  assumes "A ∈ carrier_mat n n"
and "hermitian A"
shows "hermitian (a ⋅⇩m  A)"    
proof -
  have dim: "Complex_Matrix.adjoint A ∈ carrier_mat n n" using assms by (simp add: adjoint_dim) 
  {
    fix i j
    assume "i < n" and "j < n"
    hence "Complex_Matrix.adjoint (a ⋅⇩m A) $$ (i,j) = a * (Complex_Matrix.adjoint A $$ (i,j))" 
      using adjoint_scale[of a A] assms by simp
    also have "... = a * (A $$ (i,j))" using assms unfolding hermitian_def by simp
    also have "... = (a ⋅⇩m A) $$ (i,j)" using ‹i < n› ‹j < n› assms by simp
    finally have "Complex_Matrix.adjoint (a ⋅⇩m A) $$ (i,j) = (a ⋅⇩m A) $$ (i,j)" .
  }
  thus ?thesis using dim assms unfolding hermitian_def by auto
qed

(* Every eigenvalue k of a unitary n x n complex matrix U satisfies
   conjugate k * k = 1. The proof normalizes an eigenvector v to vn and compares the
   inner product of k scaled vn with itself against that of U applied to vn. *)
lemma unitary_eigenvalues_norm_square:
  fixes U::"complex Matrix.mat"
  assumes "unitary U"
  and "U ∈ carrier_mat n n"
  and "eigenvalue U k"
shows "conjugate k * k = 1"
proof -
  have "∃v. eigenvector U v k" using assms unfolding eigenvalue_def by simp
  from this obtain v where "eigenvector U v k" by auto
  define vn where "vn = vec_normalize v"
  have "eigenvector U vn k" using normalize_keep_eigenvector ‹eigenvector U v k›
    using assms(2) eigenvector_def vn_def by blast
  have "vn ∈ carrier_vec n"
    using ‹eigenvector U v k› assms(2) eigenvector_def normalized_vec_dim vn_def by blast
  have "Complex_Matrix.inner_prod vn vn = 1" using ‹vn = vec_normalize v› ‹eigenvector U v k› 
        eigenvector_def normalized_vec_norm by blast
  hence "conjugate k * k = conjugate k * k * Complex_Matrix.inner_prod vn vn" by simp
  also have "... = conjugate k * Complex_Matrix.inner_prod vn (k ⋅⇩v vn)"
  proof -
    have "k * Complex_Matrix.inner_prod vn vn = Complex_Matrix.inner_prod vn (k ⋅⇩v vn)" 
      using inner_prod_smult_left[of vn n vn k] ‹vn ∈ carrier_vec n› by simp
    thus ?thesis by simp
  qed
  also have "... = Complex_Matrix.inner_prod (k ⋅⇩v vn) (k ⋅⇩v vn)"
    using inner_prod_smult_right[of vn n _ k] by (simp add: ‹vn ∈ carrier_vec n›)
  (* vn is an eigenvector: U applied to vn is k scaled vn *)
  also have "... = Complex_Matrix.inner_prod (U *⇩v vn) (U *⇩v vn)" 
    using ‹eigenvector U vn k› unfolding eigenvector_def by simp
  also have "... =  
    Complex_Matrix.inner_prod (Complex_Matrix.adjoint U *⇩v (U *⇩v vn)) vn" 
    using adjoint_def_alter[of "U *⇩v vn" n vn n U] assms
    by (metis ‹eigenvector U vn k› carrier_matD(1) carrier_vec_dim_vec dim_mult_mat_vec 
        eigenvector_def)
  (* unitarity: the adjoint of U times U is the identity *)
  also have "... = Complex_Matrix.inner_prod vn vn"
  proof -
    have "Complex_Matrix.adjoint U *⇩v (U *⇩v vn) = (Complex_Matrix.adjoint U * U) *⇩v vn"
      using assms
      by (metis ‹eigenvector U vn k› adjoint_dim assoc_mult_mat_vec carrier_matD(1) eigenvector_def)
    also have "... = vn" using assms unfolding unitary_def inverts_mat_def
      by (metis ‹eigenvector U vn k› assms(1) eigenvector_def one_mult_mat_vec unitary_simps(1))
    finally show ?thesis by simp
  qed
  also have "... = 1" using ‹vn = vec_normalize v› ‹eigenvector U v k› eigenvector_def 
      normalized_vec_norm by blast
  finally show ?thesis .
qed

(* Scaling the first argument of outer_prod by a scales the outer product by a.
   Proved by comparing dimensions and then entries. *)
lemma outer_prod_smult_left:
  fixes v::"complex Matrix.vec"
  shows "outer_prod (a ⋅⇩v v) w = a ⋅⇩m outer_prod v w" 
proof -
  define paw where "paw = outer_prod (a ⋅⇩v v) w"
  define apw where "apw = a ⋅⇩m outer_prod v w"
  have "paw = apw"
  proof
    have "dim_row paw = dim_vec v" unfolding paw_def using outer_prod_dim
      by (metis carrier_matD(1) carrier_vec_dim_vec index_smult_vec(2))
    also have "... = dim_row apw" unfolding apw_def using outer_prod_dim
      by (metis carrier_matD(1) carrier_vec_dim_vec index_smult_mat(2))
    finally show dr: "dim_row paw = dim_row apw" .
    have "dim_col paw = dim_vec w" unfolding paw_def using outer_prod_dim
      using carrier_vec_dim_vec by blast
    also have "... = dim_col apw" unfolding apw_def using outer_prod_dim
      by (metis apw_def carrier_matD(2) carrier_vec_dim_vec smult_carrier_mat)
    finally show dc: "dim_col paw = dim_col apw" .
    show "⋀i j. i < dim_row apw ⟹ j < dim_col apw ⟹ paw $$ (i, j) = apw $$ (i, j)"
    proof -
      fix i j
      assume  "i < dim_row apw" and "j < dim_col apw" note ij = this
      hence "paw $$ (i,j) = a * (Matrix.vec_index v i) * cnj (Matrix.vec_index w j)" 
        using dr dc unfolding  paw_def outer_prod_def by simp
      also have "... = apw $$ (i,j)" using dr dc ij unfolding apw_def outer_prod_def by simp
      finally show "paw $$ (i, j) = apw $$ (i, j)" .
    qed
  qed
  thus ?thesis unfolding paw_def apw_def by simp
qed

(* Scaling the second argument of outer_prod by a scales the outer product by the
   conjugate of a (the second argument enters through cnj). *)
lemma outer_prod_smult_right:
  fixes v::"complex Matrix.vec"
  shows "outer_prod v (a ⋅⇩v w) = (conjugate a) ⋅⇩m outer_prod v w" 
proof -
  define paw where "paw = outer_prod v (a ⋅⇩v w)"
  define apw where "apw = (conjugate a) ⋅⇩m outer_prod v w"
  have "paw = apw"
  proof
    have "dim_row paw = dim_vec v" unfolding paw_def using outer_prod_dim
      by (metis carrier_matD(1) carrier_vec_dim_vec)
    also have "... = dim_row apw" unfolding apw_def using outer_prod_dim
      by (metis carrier_matD(1) carrier_vec_dim_vec index_smult_mat(2))
    finally show dr: "dim_row paw = dim_row apw" .
    have "dim_col paw = dim_vec w" unfolding paw_def using outer_prod_dim
      using carrier_vec_dim_vec by (metis carrier_matD(2) index_smult_vec(2)) 
    also have "... = dim_col apw" unfolding apw_def using outer_prod_dim
      by (metis apw_def carrier_matD(2) carrier_vec_dim_vec smult_carrier_mat)
    finally show dc: "dim_col paw = dim_col apw" .
    show "⋀i j. i < dim_row apw ⟹ j < dim_col apw ⟹ paw $$ (i, j) = apw $$ (i, j)"
    proof -
      fix i j
      assume  "i < dim_row apw" and "j < dim_col apw" note ij = this
      hence "paw $$ (i,j) = (conjugate a) * (Matrix.vec_index v i) * cnj (Matrix.vec_index w j)" 
        using dr dc unfolding  paw_def outer_prod_def by simp
      also have "... = apw $$ (i,j)" using dr dc ij unfolding apw_def outer_prod_def by simp
      finally show "paw $$ (i, j) = apw $$ (i, j)" .
    qed
  qed
  thus ?thesis unfolding paw_def apw_def by simp
qed

(* outer_prod distributes over a sum of two vectors of equal dimension in its first
   argument. Proved by comparing dimensions and then entries. *)
lemma outer_prod_add_left:
  fixes v::"complex Matrix.vec"
  assumes "dim_vec v = dim_vec x"
  shows "outer_prod (v + x) w = outer_prod v w + (outer_prod x w)" 
proof -
  define paw where "paw = outer_prod (v+x) w"
  define apw where "apw = outer_prod v w + (outer_prod x w)"
  have "paw = apw"
  proof
    have rv: "dim_row paw = dim_vec v" unfolding paw_def using outer_prod_dim assms
      by (metis carrier_matD(1) carrier_vec_dim_vec index_add_vec(2) paw_def)
    also have "... = dim_row apw" unfolding apw_def using outer_prod_dim assms
      by (metis carrier_matD(1) carrier_vec_dim_vec index_add_mat(2))
    finally show dr: "dim_row paw = dim_row apw" .
    have cw: "dim_col paw = dim_vec w" unfolding paw_def using outer_prod_dim assms
      using carrier_vec_dim_vec by (metis carrier_matD(2)) 
    also have "... = dim_col apw" unfolding apw_def using outer_prod_dim
      by (metis apw_def carrier_matD(2) carrier_vec_dim_vec add_carrier_mat)
    finally show dc: "dim_col paw = dim_col apw" .
    show "⋀i j. i < dim_row apw ⟹ j < dim_col apw ⟹ paw $$ (i, j) = apw $$ (i, j)"
    proof -
      fix i j
      assume  "i < dim_row apw" and "j < dim_col apw" note ij = this
      hence "paw $$ (i,j) = (Matrix.vec_index v i + Matrix.vec_index x i) * 
        cnj (Matrix.vec_index w j)" 
        using dr dc unfolding  paw_def outer_prod_def by simp
      also have "... = Matrix.vec_index v i * cnj (Matrix.vec_index w j) + 
        Matrix.vec_index x i * cnj (Matrix.vec_index w j)"
        by (simp add: ring_class.ring_distribs(2))
      also have "... = (outer_prod v w) $$ (i,j) + (outer_prod x w) $$ (i,j)" 
        using rv cw dr dc ij assms unfolding outer_prod_def by auto
      also have "... = apw $$ (i,j)" using dr dc ij unfolding apw_def outer_prod_def by simp
      finally show "paw $$ (i, j) = apw $$ (i, j)" .
    qed
  qed
  thus ?thesis unfolding paw_def apw_def by simp
qed

(* outer_prod distributes over a sum of two vectors of equal dimension in its second
   argument. Proved by comparing dimensions and then entries. *)
lemma outer_prod_add_right:
  fixes v::"complex Matrix.vec"
  assumes "dim_vec w = dim_vec x"
  shows "outer_prod v (w + x) = outer_prod v w + (outer_prod v x)" 
proof -
  define paw where "paw = outer_prod v (w+x)"
  define apw where "apw = outer_prod v w + (outer_prod v x)"
  have "paw = apw"
  proof
    have rv: "dim_row paw = dim_vec v" unfolding paw_def using outer_prod_dim assms
      by (metis carrier_matD(1) carrier_vec_dim_vec index_add_vec(2) paw_def)
    also have "... = dim_row apw" unfolding apw_def using outer_prod_dim assms
      by (metis carrier_matD(1) carrier_vec_dim_vec index_add_mat(2))
    finally show dr: "dim_row paw = dim_row apw" .
    have cw: "dim_col paw = dim_vec w" unfolding paw_def using outer_prod_dim assms
      using carrier_vec_dim_vec
      by (metis carrier_matD(2) index_add_vec(2) paw_def) 
    also have "... = dim_col apw" unfolding apw_def using outer_prod_dim
      by (metis assms carrier_matD(2) carrier_vec_dim_vec index_add_mat(3))
    finally show dc: "dim_col paw = dim_col apw" .
    show "⋀i j. i < dim_row apw ⟹ j < dim_col apw ⟹ paw $$ (i, j) = apw $$ (i, j)"
    proof -
      fix i j
      assume  "i < dim_row apw" and "j < dim_col apw" note ij = this
      hence "paw $$ (i,j) = Matrix.vec_index v i * 
        (cnj (Matrix.vec_index w j + (Matrix.vec_index x j)))" 
        using dr dc unfolding  paw_def outer_prod_def by simp
      also have "... = Matrix.vec_index v i * cnj (Matrix.vec_index w j) + 
        Matrix.vec_index v i * cnj (Matrix.vec_index x j)"
        by (simp add: ring_class.ring_distribs(1))
      also have "... = (outer_prod v w) $$ (i,j) + (outer_prod v x) $$ (i,j)" 
        using rv cw dr dc ij assms unfolding outer_prod_def by auto
      also have "... = apw $$ (i,j)" using dr dc ij unfolding apw_def outer_prod_def by simp
      finally show "paw $$ (i, j) = apw $$ (i, j)" .
    qed
  qed
  thus ?thesis unfolding paw_def apw_def by simp
qed

(* outer_prod distributes over a difference of two vectors of equal dimension in its
   first argument. Proved by comparing dimensions and then entries. *)
lemma outer_prod_minus_left:
  fixes v::"complex Matrix.vec"
  assumes "dim_vec v = dim_vec x"
  shows "outer_prod (v - x) w = outer_prod v w - (outer_prod x w)" 
proof -
  define paw where "paw = outer_prod (v-x) w"
  define apw where "apw = outer_prod v w - (outer_prod x w)"
  have "paw = apw"
  proof
    have rv: "dim_row paw = dim_vec v" unfolding paw_def using outer_prod_dim assms
      by (metis carrier_matD(1) carrier_vec_dim_vec index_minus_vec(2) paw_def)
    also have "... = dim_row apw" unfolding apw_def using outer_prod_dim assms
      by (metis carrier_matD(1) carrier_vec_dim_vec index_minus_mat(2))
    finally show dr: "dim_row paw = dim_row apw" .
    have cw: "dim_col paw = dim_vec w" unfolding paw_def using outer_prod_dim assms
      using carrier_vec_dim_vec by (metis carrier_matD(2)) 
    also have "... = dim_col apw" unfolding apw_def using outer_prod_dim
      by (metis apw_def carrier_matD(2) carrier_vec_dim_vec minus_carrier_mat)
    finally show dc: "dim_col paw = dim_col apw" .
    show "⋀i j. i < dim_row apw ⟹ j < dim_col apw ⟹ paw $$ (i, j) = apw $$ (i, j)"
    proof -
      fix i j
      assume  "i < dim_row apw" and "j < dim_col apw" note ij = this
      hence "paw $$ (i,j) = (Matrix.vec_index v i - Matrix.vec_index x i) * 
        cnj (Matrix.vec_index w j)" 
        using dr dc unfolding  paw_def outer_prod_def by simp
      also have "... = Matrix.vec_index v i * cnj (Matrix.vec_index w j) - 
        Matrix.vec_index x i * cnj (Matrix.vec_index w j)"
        by (simp add: ring_class.ring_distribs)
      also have "... = (outer_prod v w) $$ (i,j) - (outer_prod x w) $$ (i,j)" 
        using rv cw dr dc ij assms unfolding outer_prod_def by auto
      also have "... = apw $$ (i,j)" using dr dc ij unfolding apw_def outer_prod_def by simp
      finally show "paw $$ (i, j) = apw $$ (i, j)" .
    qed
  qed
  thus ?thesis unfolding paw_def apw_def by simp
qed

(* outer_prod distributes over a difference of two vectors of equal dimension in its
   second argument. Proved by comparing dimensions and then entries. *)
lemma outer_prod_minus_right:
  fixes v::"complex Matrix.vec"
  assumes "dim_vec w = dim_vec x"
  shows "outer_prod v (w - x) = outer_prod v w - (outer_prod v x)" 
proof -
  define paw where "paw = outer_prod v (w-x)"
  define apw where "apw = outer_prod v w - (outer_prod v x)"
  have "paw = apw"
  proof
    have rv: "dim_row paw = dim_vec v" unfolding paw_def using outer_prod_dim assms
      by (metis carrier_matD(1) carrier_vec_dim_vec paw_def)
    also have "... = dim_row apw" unfolding apw_def using outer_prod_dim assms
      by (metis carrier_matD(1) carrier_vec_dim_vec index_minus_mat(2))
    finally show dr: "dim_row paw = dim_row apw" .
    have cw: "dim_col paw = dim_vec w" unfolding paw_def using outer_prod_dim assms
      using carrier_vec_dim_vec
      by (metis carrier_matD(2) index_minus_vec(2) paw_def) 
    also have "... = dim_col apw" unfolding apw_def using outer_prod_dim
      by (metis assms carrier_matD(2) carrier_vec_dim_vec index_minus_mat(3))
    finally show dc: "dim_col paw = dim_col apw" .
    show "⋀i j. i < dim_row apw ⟹ j < dim_col apw ⟹ paw $$ (i, j) = apw $$ (i, j)"
    proof -
      fix i j
      assume  "i < dim_row apw" and "j < dim_col apw" note ij = this
      hence "paw $$ (i,j) = Matrix.vec_index v i * 
        (cnj (Matrix.vec_index w j - (Matrix.vec_index x j)))" 
        using dr dc unfolding  paw_def outer_prod_def by simp
      also have "... = Matrix.vec_index v i * cnj (Matrix.vec_index w j) - 
        Matrix.vec_index v i * cnj (Matrix.vec_index x j)"
        by (simp add: ring_class.ring_distribs)
      also have "... = (outer_prod v w) $$ (i,j) - (outer_prod v x) $$ (i,j)" 
        using rv cw dr dc ij assms unfolding outer_prod_def by auto
      also have "... = apw $$ (i,j)" using dr dc ij unfolding apw_def outer_prod_def by simp
      finally show "paw $$ (i, j) = apw $$ (i, j)" .
    qed
  qed
  thus ?thesis unfolding paw_def apw_def by simp
qed

(* Expansion of the outer product of two differences into four outer products,
   using outer_prod_minus_left, outer_prod_minus_right and mat_minus_minus. *)
lemma outer_minus_minus:
  fixes a::"complex Matrix.vec" 
  assumes "dim_vec a = dim_vec b"
  and "dim_vec u = dim_vec v"
  shows "outer_prod (a - b) (u - v) = outer_prod a u - outer_prod a v -
      outer_prod b u +  outer_prod b v"
proof -
  have "outer_prod (a - b) (u - v) = outer_prod a (u - v)
    - outer_prod b (u - v)" using  outer_prod_minus_left assms by simp 
  also have "... = outer_prod a u - outer_prod a v -
    outer_prod b (u - v)" using assms outer_prod_minus_right by simp  
  also have "... = outer_prod a u - outer_prod a v -
    (outer_prod b u - outer_prod b v)" using assms outer_prod_minus_right by simp  
  also have "...  = outer_prod a u - outer_prod a v -
    outer_prod b u +  outer_prod b v"  
  proof (rule mat_minus_minus)
    show "outer_prod b u ∈ carrier_mat (dim_vec b) (dim_vec u)" by simp
    show "outer_prod b v ∈ carrier_mat (dim_vec b) (dim_vec u)" using assms by simp
    show "outer_prod a u - outer_prod a v ∈ carrier_mat (dim_vec b) (dim_vec u)" using assms
      by (metis carrier_vecI minus_carrier_mat outer_prod_dim)
  qed
  finally show ?thesis .
qed

(* For an n x n matrix A and vectors u, v of dimension n, the trace of
   outer_prod u v * A equals the inner product of v with A applied to u. *)
lemma  outer_trace_inner:
  assumes "A ∈ carrier_mat n n"
  and "dim_vec u = n"
and "dim_vec v = n"
  shows "Complex_Matrix.trace (outer_prod u v * A) = Complex_Matrix.inner_prod v (A *⇩v u)"
proof -
 have "Complex_Matrix.trace (outer_prod u v * A) = Complex_Matrix.trace (A * outer_prod u v)"
  proof (rule trace_comm)
    show "A ∈ carrier_mat n n" using assms  by simp
    show "outer_prod u v ∈ carrier_mat n n" using  assms
      by (metis carrier_vec_dim_vec outer_prod_dim)
  qed
  also have "... = Complex_Matrix.inner_prod v (A *⇩v u)" using trace_outer_prod_right[of A n u v]
    assms   carrier_vec_dim_vec  by metis
  finally show ?thesis .
qed

(* The n x n zero matrix is Hermitian (it is written as 1 - 1, the difference of a
   Hermitian matrix with itself). *)
lemma zero_hermitian:
  shows "hermitian (0⇩m n n)" unfolding hermitian_def
    by (metis adjoint_minus hermitian_def hermitian_one minus_r_inv_mat one_carrier_mat)

(* The trace of the complex identity matrix of size n is n. *)
lemma  trace_1: 
  shows "Complex_Matrix.trace ((1⇩m n)::complex Matrix.mat) =(n::complex)" using one_mat_def
  by (simp add: Complex_Matrix.trace_def Matrix.mat_def)

(* The trace is additive on square matrices with the same number of rows. *)
lemma  trace_add: 
  assumes "square_mat A"
  and "square_mat B"
  and "dim_row A = dim_row B"
  shows "Complex_Matrix.trace (A + B) = Complex_Matrix.trace A + Complex_Matrix.trace B" 
  using  assms by (simp add: Complex_Matrix.trace_def sum.distrib)

(* bra_vec v is a 1 x dim_vec v matrix; derived from the row dimension of ket_vec v
   and the definition of bra. *)
lemma bra_vec_carrier:
  shows "bra_vec v ∈ carrier_mat 1 (dim_vec v)"
proof -
  have "dim_row (ket_vec v) = dim_vec v" unfolding ket_vec_def by simp
  thus ?thesis using bra_bra_vec[of v] bra_def[of "ket_vec v"] by simp
qed

(* Multiplying an n x m matrix A by the ket of v yields an n x 1 matrix. *)
lemma mat_mult_ket_carrier:
  assumes "A∈ carrier_mat n m"
shows "A * ket_vec v ∈ carrier_mat n 1" using assms
      by (metis bra_bra_vec bra_vec_carrier carrier_matD(1) carrier_matI dagger_of_ket_is_bra 
          dim_row_of_dagger index_mult_mat(2) index_mult_mat(3)) 

(* For an n x m matrix A and a vector v of dimension m, the matrix product of A with
   the ket of v is the ket of the matrix-vector product of A and v. Proved by
   comparing dimensions and then the entries (i, 0). *)
lemma mat_mult_ket:
  assumes "A ∈ carrier_mat n m"
and "dim_vec v = m"
shows "A * ket_vec v = ket_vec (A *⇩v v)"
proof -
  have rn: "dim_row (A * ket_vec v) = n" unfolding times_mat_def using assms by simp
  have co: "dim_col (ket_vec (A *⇩v v)) = 1" using assms unfolding ket_vec_def by simp
  have cov: "dim_col (ket_vec v) = 1" using assms unfolding ket_vec_def by simp
  have er: "dim_row (A * ket_vec v) = dim_row (ket_vec (A *⇩v v))" using assms
    by (metis bra_bra_vec bra_vec_carrier carrier_matD(2) dagger_of_ket_is_bra dim_col_of_dagger 
        dim_mult_mat_vec index_mult_mat(2)) 
  have ec: "dim_col (A * ket_vec v) = dim_col (ket_vec (A *⇩v v))" using assms
    by (metis carrier_matD(2) index_mult_mat(3) mat_mult_ket_carrier)
  {
    fix i::nat 
    fix j::nat
    assume "i < n"
    and "j < 1"
    hence "j = 0" by simp
    have "(A * ket_vec v) $$ (i,0) = Matrix.scalar_prod (Matrix.row A i) (Matrix.col (ket_vec v) 0)" 
      using times_mat_def[of A] ‹i < n› rn cov by simp
    also have "... = Matrix.scalar_prod (Matrix.row A i) v"  using ket_vec_col  by simp
    also have "... =  ket_vec (A *⇩v v) $$ (i,j)" unfolding mult_mat_vec_def
      using ‹i < n› ‹j = 0› assms(1) by auto
  } note idx = this
  have "A * ket_vec v = Matrix.mat n 1 (λ(i, j). Matrix.scalar_prod (Matrix.row A i) (Matrix.col (ket_vec v) j))" 
    using assms unfolding times_mat_def ket_vec_def by simp
  also have "... = ket_vec (A *⇩v v)" using er ec idx rn co by auto
  finally show ?thesis .
qed

(* Conjugating a density operator R by a unitary U of the same dimension yields a
   density operator: U * R * adjoint U is positive and has trace 1. *)
lemma unitary_density:
  assumes "density_operator R"
  and "unitary U"
  and "R carrier_mat n n"
  and "U carrier_mat n n"
shows "density_operator (U * R * (Complex_Matrix.adjoint U))" unfolding density_operator_def
proof (intro conjI)
  (* positivity is preserved by the map X |-> U * X * adjoint U *)
  show "Complex_Matrix.positive (U * R * Complex_Matrix.adjoint U)" 
  proof (rule positive_close_under_left_right_mult_adjoint)
    show "U  carrier_mat n n" using assms by simp
    show "R  carrier_mat n n" using assms by simp
    show "Complex_Matrix.positive R" using assms unfolding density_operator_def by simp
  qed
  (* trace: rotate the product to adjoint U * U * R; the following simp step presumably
     relies on the unitary simp rules to reduce adjoint U * U to the identity *)
  have "Complex_Matrix.trace (U * R * Complex_Matrix.adjoint U) = 
    Complex_Matrix.trace (Complex_Matrix.adjoint U * U * R)" 
    using trace_comm[of "U * R" n "Complex_Matrix.adjoint U"] assms
    by (metis adjoint_dim  mat_assoc_test(10))
  also have "... = Complex_Matrix.trace R" using assms by simp 
  also have "... = 1" using assms unfolding density_operator_def by simp
  finally show "Complex_Matrix.trace (U * R * Complex_Matrix.adjoint U) = 1" .
qed


subsection ‹Tensor product complements›

(* The dimension of tensor_vec u v is the product of the dimensions of u and v.
   The proof goes through the list-level tensor product mult.vec_vec_Tensor. *)
lemma tensor_vec_dim[simp]:
  shows "dim_vec (tensor_vec u v) = dim_vec u * (dim_vec v)" 
proof -
  have "length (mult.vec_vec_Tensor (*) (list_of_vec u) (list_of_vec v)) = 
    length (list_of_vec u) * length (list_of_vec v)" 
    using  mult.vec_vec_Tensor_length[of "1::real" "(*)" "list_of_vec u" "list_of_vec v"]
    by (simp add: Matrix_Tensor.mult_def)  
  thus ?thesis unfolding tensor_vec_def by simp
qed

(* Entry formula for the tensor product of vectors: for a nonempty v and
   i < dim_vec u * dim_vec v, entry i of tensor_vec u v is the entry of u at
   index (i div dim_vec v) times the entry of v at index (i mod dim_vec v). *)
lemma index_tensor_vec[simp]:
  assumes "0 < dim_vec v" 
  and "i < dim_vec u * dim_vec v"
shows "vec_index (tensor_vec u v) i = 
  vec_index u (i div (dim_vec v)) * vec_index v (i mod dim_vec v)" 
proof -
  (* m: complex multiplication with unit 1 is an instance of the Matrix_Tensor.mult locale *)
  have m: "Matrix_Tensor.mult (1::complex) (*)" by (simp add: Matrix_Tensor.mult_def) 
  have "length (list_of_vec v) = dim_vec v" using assms by simp
  hence "vec_index (tensor_vec u v) i = (*) (vec_index u (i div dim_vec v)) (vec_index v (i mod dim_vec v))"
    unfolding tensor_vec_def using mult.vec_vec_Tensor_elements assms m
    by (metis (mono_tags, lifting) length_greater_0_conv length_list_of_vec list_of_vec_index 
        mult.vec_vec_Tensor_elements vec_of_list_index)
  thus ?thesis by simp
qed

(* The outer product of two vector tensor products is the tensor product of the outer
   products: outer_prod (tensor_vec u v) (tensor_vec a b) equals
   tensor_mat (outer_prod u a) (outer_prod v b), provided a and b are nonempty.
   Positivity of dim_vec v is not assumed; it is recovered from the index bounds. *)
lemma  outer_prod_tensor_comm:
  fixes a::"complex Matrix.vec"
  fixes u::"complex Matrix.vec"
  assumes "0 < dim_vec a"
  and "0 < dim_vec b"
shows "outer_prod (tensor_vec u v) (tensor_vec a b) = tensor_mat (outer_prod u a) (outer_prod v b)"
proof -
  define ot where "ot = outer_prod (tensor_vec u v) (tensor_vec a b)"
  define to where "to = tensor_mat (outer_prod u a) (outer_prod v b)"
  define dv where "dv = dim_vec v"
  define db where "db = dim_vec b"
  have "ot = to"
  proof
    (* the default rule splits matrix equality into a row-count goal, a column-count
       goal and an entrywise goal, each discharged by one show below *)
    have ro: "dim_row ot = dim_vec u * dim_vec v" unfolding ot_def outer_prod_def by simp
    have "dim_row to = dim_row (outer_prod u a) * dim_row (outer_prod v b)" 
      unfolding to_def by simp 
    also have "... = dim_vec u * dim_vec v" using outer_prod_dim
      by (metis carrier_matD(1) carrier_vec_dim_vec) 
    finally have rt: "dim_row to = dim_vec u * dim_vec v" .
    show "dim_row ot = dim_row to" using ro rt by simp
    have co: "dim_col ot = dim_vec a * dim_vec b" unfolding ot_def outer_prod_def by simp
    have "dim_col to = dim_col (outer_prod u a) * dim_col (outer_prod v b)" unfolding to_def by simp
    also have "... = dim_vec a * dim_vec b" using outer_prod_dim
      by (metis carrier_matD(2) carrier_vec_dim_vec)
    finally have ct: "dim_col to = dim_vec a * dim_vec b" .
    show "dim_col ot = dim_col to" using co ct by simp
    show "i j. i < dim_row to  j < dim_col to  ot $$ (i, j) = to $$ (i, j)"
    proof -
      fix i j
      assume "i < dim_row to" and "j < dim_col to" note ij = this
      (* rewrite entry (i, j) step by step, splitting i and j into their div and mod
         parts with respect to dv and db, until index_tensor_mat applies *)
      have "ot $$ (i,j) = Matrix.vec_index (tensor_vec u v) i * 
        (conjugate (Matrix.vec_index (tensor_vec a b) j))"
        unfolding ot_def outer_prod_def using ij rt ct by simp
      also have "... = vec_index u (i div dv) * vec_index v (i mod dv)  * 
        (conjugate (Matrix.vec_index (tensor_vec a b) j))" using ij rt assms 
        unfolding dv_def
        by (metis index_tensor_vec less_nat_zero_code nat_0_less_mult_iff neq0_conv)
      also have "... = vec_index u (i div dv) * vec_index v (i mod dv)  *
        (conjugate (vec_index a (j div db) * vec_index b (j mod db)))" using ij ct assms 
        unfolding db_def by simp
      also have "... = vec_index u (i div dv) * vec_index v (i mod dv)  *
        (conjugate (vec_index a (j div db))) * (conjugate (vec_index b (j mod db)))" by simp
      also have "... = vec_index u (i div dv) * (conjugate (vec_index a (j div db))) * 
        vec_index v (i mod dv) * (conjugate (vec_index b (j mod db)))" by simp
      also have "... = (outer_prod u a) $$ (i div dv, j div db) * 
        vec_index v (i mod dv) * (conjugate (vec_index b (j mod db)))" 
      proof -
        have "i div dv < dim_vec u" using ij rt unfolding dv_def
          by (simp add: less_mult_imp_div_less)
        moreover have "j div db < dim_vec a" using ij ct assms unfolding db_def
          by (simp add: less_mult_imp_div_less)
        ultimately have "vec_index u (i div dv) * (conjugate (vec_index a (j div db))) = 
          (outer_prod u a) $$ (i div dv, j div db)" unfolding outer_prod_def by simp
        thus ?thesis by simp
      qed
      also have "... = (outer_prod u a) $$ (i div dv, j div db) * 
        (outer_prod v b) $$ (i mod dv, j mod db)" 
      proof -
        have "i mod dv < dim_vec v" using ij rt unfolding dv_def
          using assms mod_less_divisor
          by (metis less_nat_zero_code mult.commute neq0_conv times_nat.simps(1))
        moreover have "j mod db < dim_vec b" using ij ct assms unfolding db_def
          by (simp add: less_mult_imp_div_less)
        ultimately have "vec_index v (i mod dv) * (conjugate (vec_index b (j mod db))) = 
          (outer_prod v b) $$ (i mod dv, j mod db)" unfolding outer_prod_def by simp
        thus ?thesis by simp
      qed
      also have "... = tensor_mat (outer_prod u a) (outer_prod v b) $$ (i, j)" 
      proof (rule index_tensor_mat[symmetric])
        show "dim_row (outer_prod u a) = dim_vec u" unfolding outer_prod_def by simp
        show "dim_row (outer_prod v b) = dv" unfolding outer_prod_def dv_def by simp
        show "dim_col (outer_prod v b) = db" unfolding db_def outer_prod_def by simp
        show "i < dim_vec u * dv" unfolding dv_def using ij rt by simp
        show "dim_col (outer_prod u a) = dim_vec a" unfolding outer_prod_def by simp
        show "j < dim_vec a * db" unfolding db_def using ij ct by simp
        show "0 < dim_vec a" using assms by simp
        show "0 < db" unfolding db_def using assms by simp
      qed
      finally show "ot $$ (i, j) = to $$ (i, j)" unfolding to_def .
    qed
  qed
  thus ?thesis unfolding ot_def to_def by simp
qed

(* The adjoint distributes over the tensor product of matrices:
   adjoint (tensor_mat m1 m2) = tensor_mat (adjoint m1) (adjoint m2),
   for m1 of size r1 x c1 and m2 of size r2 x c2. All four dimensions must be positive;
   r2 and c2 serve as div/mod divisors when indexing the tensor product, and the
   positivity of r1 and c1 is presumably required by index_tensor_mat. *)
lemma tensor_mat_adjoint:
  assumes "m1  carrier_mat r1 c1"
    and "m2  carrier_mat r2 c2"
    and "0 < c1"
    and "0 < c2"
and "0 < r1"
and "0 < r2"
  shows "Complex_Matrix.adjoint (tensor_mat m1 m2) = 
  tensor_mat (Complex_Matrix.adjoint m1) (Complex_Matrix.adjoint m2)"
  apply (rule eq_matI, auto)
  (* after eq_matI and auto, what remains is the entrywise equality below *)
proof -
  fix i j
  assume "i < dim_col m1 * dim_col m2" and "j < dim_row m1 * dim_row m2" note ij = this
  have c1: "dim_col m1 = c1" using assms by simp
  have r1: "dim_row m1 = r1" using assms by simp
  have c2: "dim_col m2 = c2" using assms by simp
  have r2: "dim_row m2 = r2" using assms by simp
  have "Complex_Matrix.adjoint (m1  m2) $$ (i, j) = conjugate ((m1  m2) $$ (j, i))" 
    using  ij by (simp add: adjoint_eval) 
  also have "... = conjugate (m1 $$ (j div r2, i div c2) * m2 $$ (j mod r2, i mod c2))" 
  proof -
    have "(m1  m2) $$ (j, i) = m1 $$ (j div r2, i div c2) * m2 $$ (j mod r2, i mod c2)"
    proof (rule index_tensor_mat[of m1 r1 c1 m2 r2 c2 j i], (auto simp add: assms ij c1 c2 r1 r2))
      show "j < r1 * r2" using ij r1 r2 by simp
      show "i < c1 * c2" using ij c1 c2 by simp
    qed
    thus ?thesis by simp
  qed
  (* conjugation distributes over the product; then each factor is read back as an
     entry of the corresponding adjoint *)
  also have "... = conjugate (m1 $$ (j div r2, i div c2)) * conjugate ( m2 $$ (j mod r2, i mod c2))" 
    by simp
  also have "... = (Complex_Matrix.adjoint m1) $$ (i div c2, j div r2) * 
    conjugate ( m2 $$ (j mod r2, i mod c2))"
    by (metis adjoint_eval c2 ij less_mult_imp_div_less r2)
  also have "... = (Complex_Matrix.adjoint m1) $$ (i div c2, j div r2) *
    (Complex_Matrix.adjoint m2) $$ (i mod c2, j mod r2)"
    using 0 < c2 0 < r2 by (simp add: adjoint_eval c2 r2)
  also have "... = (tensor_mat (Complex_Matrix.adjoint m1) (Complex_Matrix.adjoint m2)) $$ (i,j)"
  proof (rule index_tensor_mat[symmetric], (simp add: ij c1 c2 r1 r2) +)
    show "i < c1 * c2" using ij c1 c2 by simp
    show "j < r1 * r2" using ij r1 r2 by simp
    show "0 < r1" using assms by simp
    show "0 < r2" using assms by simp
  qed
  finally show "Complex_Matrix.adjoint (m1  m2) $$ (i, j) =
           (Complex_Matrix.adjoint m1  Complex_Matrix.adjoint m2) $$ (i, j)" .
qed

(* Variant of index_tensor_mat stated with dim_row and dim_col of the factors instead
   of explicit carrier dimensions. Entry (i, j) of the tensor product is the entry of A
   at the div-indices times the entry of B at the mod-indices, where both divisions are
   by the dimensions of B. *)
lemma index_tensor_mat':
  assumes "0 < dim_col A"
  and "0 < dim_col B"
  and "i < dim_row A * dim_row B"
  and "j < dim_col A * dim_col B"
  shows "(A  B) $$ (i, j) = 
    A $$ (i div (dim_row B), j div (dim_col B)) * B $$ (i mod (dim_row B), j mod (dim_col B))"
  by (rule index_tensor_mat, (simp add: assms)+)

(* The tensor product of U and V has dim_row U * dim_row V rows and
   dim_col U * dim_col V columns. *)
lemma tensor_mat_carrier:
  shows "tensor_mat U V  carrier_mat (dim_row U * dim_row V) (dim_col U * dim_col V)" by auto

(* The tensor product of the identity matrices of positive dimensions d1 and d2 is the
   identity matrix of dimension d1 * d2. *)
lemma tensor_mat_id:
  assumes "0 < d1"
  and "0 < d2"
shows "tensor_mat (1m d1) (1m d2) = 1m (d1 * d2)"
proof (rule eq_matI, auto)
  (* diagonal entries are 1 *)
  show "tensor_mat (1m d1) (1m d2) $$ (i, i) = 1" if "i < (d1 * d2)" for i
    using that index_tensor_mat'[of "1m d1" "1m d2"]   
    by (simp add: assms less_mult_imp_div_less)  
next
  (* off-diagonal entries are 0: since i and j differ, their div parts or their mod
     parts differ (mult_div_mod_eq) *)
  show "tensor_mat (1m d1) (1m d2) $$ (i, j) = 0" if "i < d1 * d2" "j < d1 * d2" "i  j" for i j
    using that index_tensor_mat[of "1m d1" d1 d1 "1m d2" d2 d2 i j]
    by (metis assms(1) assms(2) index_one_mat(1) index_one_mat(2) index_one_mat(3) 
        less_mult_imp_div_less mod_less_divisor mult_div_mod_eq mult_not_zero)
qed

(* The tensor product of two Hermitian matrices, square and of positive dimensions,
   is Hermitian; this follows from tensor_mat_adjoint and hermitian_def. *)
lemma tensor_mat_hermitian:
  assumes "A  carrier_mat n n"
  and "B  carrier_mat n' n'"
  and "0 < n"
  and "0 < n'"
  and "hermitian A"
  and "hermitian B"
  shows "hermitian (tensor_mat A B)" using assms by (metis hermitian_def tensor_mat_adjoint)

(* The tensor product of two unitary matrices of positive dimensions is unitary.
   Proof: the adjoint of tensor_mat U V is the tensor product of the adjoints;
   by mult_distr_tensor, multiplying the two gives the tensor product of two identity
   matrices, which is an identity matrix by tensor_mat_id. *)
lemma  tensor_mat_unitary:
  assumes "Complex_Matrix.unitary U"
  and "Complex_Matrix.unitary V"
and "0 < dim_row U"
and "0 < dim_row V"
shows "Complex_Matrix.unitary (tensor_mat U V)" 
proof -
  define UI where "UI = tensor_mat U V"
  have "Complex_Matrix.adjoint UI = 
    tensor_mat (Complex_Matrix.adjoint U) (Complex_Matrix.adjoint V)" unfolding UI_def
  proof (rule tensor_mat_adjoint)
    (* The carrier goals fix r1 = c1 = dim_row U and r2 = c2 = dim_row V. The four
       positivity premises of tensor_mat_adjoint then reduce to two distinct goals,
       each occurring twice, hence the repeated show statements. *)
    show "U  carrier_mat (dim_row U) (dim_row U)" using assms unfolding Complex_Matrix.unitary_def 
      by simp
    show "V  carrier_mat (dim_row V) (dim_row V)" using assms unfolding Complex_Matrix.unitary_def 
      by simp
    show "0 < dim_row V" using assms by simp
    show "0 < dim_row U" using assms by simp
    show "0 < dim_row V" using assms by simp
    show "0 < dim_row U" using assms by simp
  qed
  hence "UI * (Complex_Matrix.adjoint UI) = 
    tensor_mat (U * Complex_Matrix.adjoint U) (V * Complex_Matrix.adjoint V)"
    using mult_distr_tensor[of U "Complex_Matrix.adjoint U" "V" "Complex_Matrix.adjoint V"]
    unfolding UI_def
    by (metis (no_types, lifting) Complex_Matrix.unitary_def adjoint_dim_col adjoint_dim_row 
        assms carrier_matD(2) )    
  also have "... = tensor_mat (1m (dim_row U)) (1m (dim_row V))" using assms unitary_simps(2)
    by (metis Complex_Matrix.unitary_def)
  also have "... = (1m (dim_row U * dim_row V))" using tensor_mat_id assms by simp
  finally have "UI * (Complex_Matrix.adjoint UI) = (1m (dim_row U * dim_row V))" .
  hence "inverts_mat UI (Complex_Matrix.adjoint UI)" unfolding inverts_mat_def UI_def by simp
  thus ?thesis using assms unfolding Complex_Matrix.unitary_def UI_def
    by (metis carrier_matD(2) carrier_matI dim_col_tensor_mat dim_row_tensor_mat)
qed


subsection ‹Fixed carrier matrices locale›

text ‹We define a locale for matrices with a fixed number of rows and columns, and define a
finite sum operation on this locale. The \verb+Types_To_Sets+ transfer tools then make it possible
to obtain lemmata on the locale for free. ›

(* Locale for the set fc_mats of all dimR x dimC matrices over a field.
   The sublocale interpretations below show that, on fc_mats, matrix addition, the
   dimR x dimC zero matrix, subtraction and negation satisfy the set-relative
   algebraic structures of the Group_On_With theory from Types_To_Sets. *)
locale fixed_carrier_mat =
  fixes fc_mats::"'a::field Matrix.mat set" 
  fixes dimR dimC
  assumes fc_mats_carrier: "fc_mats = carrier_mat dimR dimC"
begin

(* addition is closed and associative on fc_mats *)
sublocale semigroup_add_on_with fc_mats "(+)"
proof (unfold_locales)
  show "a b. a  fc_mats  b  fc_mats  a + b  fc_mats" using fc_mats_carrier by simp
  show "a b c. a  fc_mats  b  fc_mats  c  fc_mats  a + b + c = a + (b + c)" 
    using fc_mats_carrier by simp
qed

(* addition is commutative on fc_mats *)
sublocale ab_semigroup_add_on_with fc_mats "(+)"
proof (unfold_locales)
  show "a b. a  fc_mats  b  fc_mats  a + b = b + a" using fc_mats_carrier 
    by (simp add: comm_add_mat) 
qed

(* the dimR x dimC zero matrix is in fc_mats and is a left unit for addition *)
sublocale comm_monoid_add_on_with fc_mats "(+)" "0m dimR dimC"
proof (unfold_locales)
  show "0m dimR dimC  fc_mats" using fc_mats_carrier by simp
  show "a. a  fc_mats  0m dimR dimC + a = a" using fc_mats_carrier by simp
qed

(* negation gives additive inverses, stays in fc_mats, and subtraction is addition
   of the negation *)
sublocale ab_group_add_on_with fc_mats "(+)" "0m dimR dimC" "(-)" "uminus"
proof (unfold_locales)
  show "a. a  fc_mats  - a + a = 0m dimR dimC" using fc_mats_carrier by simp
  show "a b. a  fc_mats  b  fc_mats  a - b = a + - b" using fc_mats_carrier
    by (simp add: add_uminus_minus_mat)
  show "a. a  fc_mats  - a  fc_mats" using fc_mats_carrier by simp
qed
end

(* fc_mats is closed under multiplication by a scalar a. *)
lemma (in fixed_carrier_mat) smult_mem:
  assumes "A  fc_mats"
  shows "a m A  fc_mats" using fc_mats_carrier assms by auto

(* Finite sum of the matrices A i over i in I, computed with sum_with using matrix
   addition and the dimR x dimC zero matrix as the neutral element. *)
definition (in fixed_carrier_mat) sum_mat where
"sum_mat A I = sum_with (+) (0m dimR dimC) A I"

(* The empty sum is the dimR x dimC zero matrix. *)
lemma (in fixed_carrier_mat) sum_mat_empty[simp]:
  shows "sum_mat A {} = 0m dimR dimC" unfolding sum_mat_def by simp

(* If every A i with i in I belongs to fc_mats, then sum_mat A I is a
   dimR x dimC matrix. *)
lemma (in fixed_carrier_mat) sum_mat_carrier:
  shows "(i. i  I  (A i) fc_mats)  sum_mat A I  carrier_mat dimR dimC" 
  unfolding sum_mat_def using sum_with_mem[of A I] fc_mats_carrier by auto

(* Splitting off one term: sum_mat A (insert x I) = A x + sum_mat A I, for finite I,
   with A x and every A i (i in I) in fc_mats.
   NOTE(review): assumption x presumably states that x is not in I; its relation
   symbol is not visible in this copy, so confirm against upstream. *)
lemma (in fixed_carrier_mat) sum_mat_insert:
  assumes "A x  fc_mats" "A ` I  fc_mats"
    and A: "finite I" and x: "x  I"
  shows "sum_mat A (insert x I) = A x + sum_mat A I" unfolding sum_mat_def
  using assms sum_with_insert[of A x I] by simp


subsection ‹A locale for square matrices›

(* Specialisation of fixed_carrier_mat to complex square matrices of a fixed positive
   dimension: dimR = dimC (dim_eq) and 0 < dimR (npos). *)
locale cpx_sq_mat = fixed_carrier_mat "fc_mats::complex Matrix.mat set" for fc_mats +
  assumes dim_eq: "dimR = dimC"
  and npos: "0 < dimR"

(* The dimR x dimR identity matrix belongs to fc_mats. *)
lemma (in cpx_sq_mat) one_mem:
  shows "1m dimR  fc_mats" using fc_mats_carrier dim_eq by simp

(* Every member of fc_mats is a square matrix. *)
lemma (in cpx_sq_mat) square_mats:
  assumes "A  fc_mats"
  shows "square_mat A" using fc_mats_carrier dim_eq assms by simp

(* fc_mats is closed under matrix multiplication. Squareness (dim_eq) is what makes
   the column count of the product equal to dimR. *)
lemma (in cpx_sq_mat) cpx_sq_mat_mult:
  assumes "A  fc_mats"
  and "B  fc_mats"
shows "A * B  fc_mats"
proof -
  have "dim_row (A * B) = dimR" using assms fc_mats_carrier by simp
  moreover have "dim_col (A * B) = dimR" using assms fc_mats_carrier dim_eq by simp
  ultimately show ?thesis using fc_mats_carrier carrier_mat_def dim_eq by auto
qed

(* Left multiplication distributes over sum_mat: for finite I, R in fc_mats and all
   A i in fc_mats, sum_mat (λi. R * A i) I = R * sum_mat A I.
   This has the same conclusion as mult_sum_mat_distrib_left, with the premises in a
   different order. Proved by induction on the finite set I. *)
lemma (in cpx_sq_mat) sum_mat_distrib_left:
  shows "finite I  R fc_mats  (i. i  I  (A i) fc_mats)  
    sum_mat (λi. R * (A i)) I = R * (sum_mat A I)"
proof (induct rule: finite_induct)
  case empty
  hence a: "sum_mat (λi. R * (A i)) {} = 0m dimR dimC" unfolding sum_mat_def by simp 
  have "sum_mat A {} = 0m dimR dimC" unfolding sum_mat_def by simp
  hence "R * (sum_mat A {}) = 0m dimR dimC" using  fc_mats_carrier
      right_mult_zero_mat[of R dimR dimC dimC] empty dim_eq by simp
  thus ?case using a by simp
next
  case (insert x F)
  (* split off the term for x on both sides, then factor R out of the sum *)
  hence "sum_mat (λi. R * A i) (insert x F) = R * (A x) + sum_mat (λi. R * A i) F"  
    using sum_mat_insert[of "λi. R * A i" x F] by (simp add: image_subsetI fc_mats_carrier dim_eq) 
  also have "... = R * (A x) + R * (sum_mat A F)" using insert by simp
  also have "... = R * (A x + (sum_mat A F))"
    by (metis dim_eq fc_mats_carrier insert.prems(1) insert.prems(2) insertCI mult_add_distrib_mat 
        sum_mat_carrier)
  also have "... = R * sum_mat A (insert x F)" 
  proof -
    have "A x + (sum_mat A F) = sum_mat A (insert x F)" 
      by (rule sum_mat_insert[symmetric], (auto simp add: insert))
    thus ?thesis by simp
  qed
  finally show ?case .
qed

(* Right multiplication distributes over sum_mat: for finite I, R in fc_mats and all
   A i in fc_mats, sum_mat (λi. A i * R) I = sum_mat A I * R.
   Proved by induction on the finite set I. *)
lemma (in cpx_sq_mat) sum_mat_distrib_right:
  shows "finite I  R fc_mats  (i. i  I  (A i) fc_mats)  
    sum_mat (λi. (A i) * R) I = (sum_mat A I) * R"
proof (induct rule: finite_induct)
  case empty
  hence a: "sum_mat (λi. (A i) * R) {} = 0m dimR dimC" unfolding sum_mat_def by simp 
  have "sum_mat A {} = 0m dimR dimC" unfolding sum_mat_def by simp
  hence "(sum_mat A {}) * R = 0m dimR dimC" using  fc_mats_carrier right_mult_zero_mat[of R ] 
      dim_eq empty by simp
  thus ?case using a by simp
next
  case (insert x F)
  (* all summands A i * R stay in fc_mats, so sum_mat_insert applies *)
  have a: "(λi. A i * R) ` F  fc_mats" using insert cpx_sq_mat_mult
    by (simp add: image_subsetI) 
  have "A x * R  fc_mats" using insert 
      by (metis insertI1 local.fc_mats_carrier mult_carrier_mat dim_eq)
  hence "sum_mat (λi. A i * R) (insert x F) = (A x) * R + sum_mat (λi. A i * R) F"  using insert a
    using sum_mat_insert[of "λi. A i * R" x F]  by (simp add: image_subsetI local.fc_mats_carrier) 
  also have "... = (A x) * R + (sum_mat A F) * R" using insert by simp
  also have "... = (A x + (sum_mat A F)) * R" 
  proof (rule add_mult_distrib_mat[symmetric])
    show "A x  carrier_mat dimR dimC" using insert fc_mats_carrier by simp
    show "sum_mat A F  carrier_mat dimR dimC" using insert fc_mats_carrier sum_mat_carrier by blast
    show "R  carrier_mat dimC dimC" using insert fc_mats_carrier dim_eq by simp
  qed    
  also have "... = sum_mat A (insert x F) * R" 
  proof -
    have "A x + (sum_mat A F) = sum_mat A (insert x F)" 
      by (rule sum_mat_insert[symmetric], (auto simp add: insert))
    thus ?thesis by simp
  qed
  finally show ?case .
qed

(* The trace of a finite sum of matrices in fc_mats is the sum of their traces.
   Proved by induction on I after unfolding sum_mat into sum_with; the step case uses
   trace_add, which applies because members of fc_mats are square. *)
lemma (in cpx_sq_mat)  trace_sum_mat:
  fixes A::"'b  complex Matrix.mat"
  shows "finite I  (i. i  I  (A i) fc_mats) 
  Complex_Matrix.trace (sum_mat A I) = ( i I. Complex_Matrix.trace (A i))" unfolding sum_mat_def
proof (induct rule: finite_induct)
  case empty
  then show ?case using trace_zero dim_eq by simp
next
  case (insert x F)
  have "Complex_Matrix.trace (sum_with (+) (0m dimR dimC) A (insert x F)) = 
    Complex_Matrix.trace (A x + sum_with (+) (0m dimR dimC) A F)" 
    using sum_with_insert[of A x F] insert by (simp add: image_subset_iff dim_eq) 
  also have "... = Complex_Matrix.trace (A x) + 
    Complex_Matrix.trace (sum_with (+) (0m dimR dimC) A F)" using trace_add square_mats insert 
    by (metis carrier_matD(1) fc_mats_carrier image_subset_iff insert_iff sum_with_mem) 
  also have "... = Complex_Matrix.trace (A x) + ( i F. Complex_Matrix.trace (A i))" 
      using insert by simp
    also have "... = ( i (insert x F). Complex_Matrix.trace (A i))" 
      using sum_with_insert[of A x F] insert by (simp add: image_subset_iff) 
  finally show ?case .
qed

(* fc_mats is closed under multiplication by a scalar x. This restates smult_mem of
   fixed_carrier_mat under the cpx_sq_mat name. *)
lemma (in cpx_sq_mat) cpx_sq_mat_smult:
  assumes "A  fc_mats"
  shows "x  m A  fc_mats"
  using assms fc_mats_carrier by auto

(* A * (B + C) = A * B + A * C for A, B, C in fc_mats.
   Note the naming: this lemma distributes a LEFT factor A over a sum, the opposite of
   the HOL convention for distrib_left. The companion mult_add_distrib_left handles
   (B + C) * A. *)
lemma (in cpx_sq_mat) mult_add_distrib_right:
  assumes "A fc_mats" "B fc_mats" "C fc_mats"
  shows "A * (B + C) = A * B + A * C"
  using assms fc_mats_carrier mult_add_distrib_mat dim_eq by simp

(* (B + C) * A = B * A + C * A for A, B, C in fc_mats.
   Note the naming: this lemma distributes a RIGHT factor A over a sum; the companion
   mult_add_distrib_right handles A * (B + C). *)
lemma (in cpx_sq_mat) mult_add_distrib_left:
  assumes "A fc_mats" "B fc_mats" "C fc_mats"
  shows "(B + C) * A = B * A + C * A"
  using assms fc_mats_carrier add_mult_distrib_mat dim_eq by simp

(* Left multiplication by B in fc_mats distributes over a finite sum_mat:
   sum_mat (λi. B * A i) I = B * sum_mat A I. The conclusion matches
   sum_mat_distrib_left, with the premises in a different order.
   Proved by induction on the finite set I. *)
lemma (in cpx_sq_mat)  mult_sum_mat_distrib_left:
  shows "finite I  (i. i  I  (A i) fc_mats)  B  fc_mats 
  (sum_mat (λi. B * (A i)) I) = B * (sum_mat A I)" 
proof (induct rule: finite_induct)
  case empty
  hence "sum_mat A {} = 0m dimR dimC" using sum_mat_empty by simp
  hence "B * (sum_mat A {}) = 0m dimR dimC" using empty by (simp add: fc_mats_carrier dim_eq)
  moreover have "sum_mat (λi. B * (A i)) {} = 0m dimR dimC" using sum_mat_empty[of "λi. B * (A i)"] 
    by simp
  ultimately show ?case by simp
next
  case (insert x F)
  (* split off the term for x, apply the induction hypothesis, then factor B out *)
  have "sum_mat (λi. B * (A i)) (insert x F) = B * (A x) + sum_mat (λi. B * (A i)) F"
    using sum_with_insert[of "λi. B * (A i)" x F] insert
    by (simp add: image_subset_iff local.sum_mat_def cpx_sq_mat_mult)
  also have "... = B * (A x) + B * (sum_mat A F)" using insert by simp
  also have "... = B * (A x + (sum_mat A F))" 
  proof (rule mult_add_distrib_right[symmetric])
    show "B fc_mats" using insert by simp
    show "A x  fc_mats" using insert by simp
    show "sum_mat A F  fc_mats" using insert by (simp add: fc_mats_carrier sum_mat_carrier) 
  qed
  also have "... = B * (sum_mat A (insert x F))" using sum_with_insert[of A x F] insert 
    by (simp add: image_subset_iff sum_mat_def)
  finally show ?case .
qed

lemma (in cpx_sq_mat)  mult_sum_mat_distrib_right:
  shows "finite I  (i. i  I  (A i) fc_mats)  B  fc_mats 
  (sum_mat (λi. (A i) * B) I) = (sum_mat A I) * B" 
proof (induct rule: finite_induct)
  case empty
  hence "sum_mat A {} = 0m dimR dimC" using sum_mat_empty by simp
  hence "(sum_mat A {}) * B = 0m dimR dimC" using empty by (simp add: fc_mats_carrier dim_eq)
  moreover have "sum_mat (λi. (A i) * B) {} = 0m dimR dimC" by simp
  ultimately show ?case by simp
next
  case (insert x F)
  have "sum_mat (λi. (A i) * B) (insert x F) = (A x) * B + sum_mat (λi. (A i) * B) F"
    using sum_with_insert[of "λi. (A i) * B" x F] insert
    by (simp add: image_subset