# Theory Jordan_Normal_Form.Matrix

```
(*
Author:      René Thiemann
*)
(* with contributions from Alexander Bentkamp, Universität des Saarlandes *)

section‹Vectors and Matrices›

text ‹We define vectors as pairs of a dimension and a characteristic function from natural numbers
to elements.
Similarly, matrices are defined as triples of two dimensions and one
characteristic function from pairs of natural numbers to elements.
Via a subtype we ensure that the characteristic function always behaves the same
on indices outside the intended range. Hence, every vector and every matrix has a unique representation.

In this part we define basic operations like matrix addition, matrix multiplication, scalar product,
etc. We connect these operations to HOL-Algebra with its explicit carrier sets.›

theory Matrix
imports
Polynomial_Interpolation.Ring_Hom
Missing_Ring
Conjugate
"HOL-Algebra.Module"
begin

subsection‹Vectors›

text ‹Here we specify which value is returned when an index is out of bounds.
The chosen solution has the advantage that the implementation later on
does not need to perform any index comparison.›

(* Value returned by a vector's characteristic function at indices outside
   its dimension: the (unspecified) result of indexing the empty list. *)
definition undef_vec :: "nat ⇒ 'a" where
"undef_vec i ≡ [] ! i"

(* Normalise a characteristic function f for dimension n: keep f below n and
   map every index i ≥ n to undef_vec (i - n), so out-of-range values do not
   depend on f. *)
definition mk_vec :: "nat ⇒ (nat ⇒ 'a) ⇒ (nat ⇒ 'a)" where
"mk_vec n f ≡ λ i. if i < n then f i else undef_vec (i - n)"

(* Vectors: pairs of a dimension n and a characteristic function normalised
   by mk_vec n. *)
typedef 'a vec = "{(n, mk_vec n f) | n f :: nat ⇒ 'a. True}"
by auto

(* Register vec with the lifting package so that the operations below can be
   defined on the representation (dimension, function) via lift_definition. *)
setup_lifting type_definition_vec

(* Dimension and entry access are the two projections of the representation. *)
lift_definition dim_vec :: "'a vec ⇒ nat" is fst .
lift_definition vec_index :: "'a vec ⇒ (nat ⇒ 'a)" (infixl "\$" 100) is snd .
(* Build a vector of dimension n from f; f is normalised by mk_vec. *)
lift_definition vec :: "nat ⇒ (nat ⇒ 'a) ⇒ 'a vec"
is "λ n f. (n, mk_vec n f)" by auto

(* Conversions between lists and vectors; the dimension is the list length. *)
lift_definition vec_of_list :: "'a list ⇒ 'a vec" is
"λ v. (length v, mk_vec (length v) (nth v))" by auto

lift_definition list_of_vec :: "'a vec ⇒ 'a list" is
"λ (n,v). map v [0 ..< n]" .

(* The set of all vectors of dimension n. *)
definition carrier_vec :: "nat ⇒ 'a vec set" where
"carrier_vec n = { v . dim_vec v = n}"

lemma carrier_vec_dim_vec[simp]: "v ∈ carrier_vec (dim_vec v)" unfolding carrier_vec_def by auto

(* Simplification rules for the vector constructor; index_vec only applies
   to indices below the dimension. *)
lemma dim_vec[simp]: "dim_vec (vec n f) = n" by transfer simp
lemma vec_carrier[simp]: "vec n f ∈ carrier_vec n" unfolding carrier_vec_def by auto
lemma index_vec[simp]: "i < n ⟹ vec n f \$ i = f i" by transfer (simp add: mk_vec_def)
(* Extensionality: vectors with equal dimension that agree on all indices
   below it are equal (out-of-range values coincide by mk_vec). *)
lemma eq_vecI[intro]: "(⋀ i. i < dim_vec w ⟹ v \$ i = w \$ i) ⟹ dim_vec v = dim_vec w
⟹ v = w"
by (transfer, auto simp: mk_vec_def)

lemma carrier_dim_vec: "v ∈ carrier_vec n ⟷ dim_vec v = n"
unfolding carrier_vec_def by auto

lemma carrier_vecD[simp]: "v ∈ carrier_vec n ⟹ dim_vec v = n" using carrier_dim_vec by auto

lemma carrier_vecI: "dim_vec v = n ⟹ v ∈ carrier_vec n" using carrier_dim_vec by auto

(* Pointwise addition. The result has the dimension of the second argument;
   the dimensions of the two arguments are not compared. *)
instantiation vec :: (plus) plus
begin
definition plus_vec :: "'a vec ⇒ 'a vec ⇒ 'a :: plus vec" where
"v⇩1 + v⇩2 ≡ vec (dim_vec v⇩2) (λ i. v⇩1 \$ i + v⇩2 \$ i)"
instance ..
end

(* Pointwise subtraction, again with the dimension of the second argument. *)
instantiation vec :: (minus) minus
begin
definition minus_vec :: "'a vec ⇒ 'a vec ⇒ 'a :: minus vec" where
"v⇩1 - v⇩2 ≡ vec (dim_vec v⇩2) (λ i. v⇩1 \$ i - v⇩2 \$ i)"
instance ..
end

(* Zero vector of dimension n. *)
definition
zero_vec :: "nat ⇒ 'a :: zero vec" ("0⇩v")
where "0⇩v n ≡ vec n (λ i. 0)"

lemma zero_carrier_vec[simp]: "0⇩v n ∈ carrier_vec n"
unfolding zero_vec_def carrier_vec_def by auto

lemma index_zero_vec[simp]: "i < n ⟹ 0⇩v n \$ i = 0" "dim_vec (0⇩v n) = n"
unfolding zero_vec_def by auto

(* The only vector of dimension 0 is the zero vector of dimension 0. *)
lemma vec_of_dim_0[simp]: "dim_vec v = 0 ⟷ v = 0⇩v 0" by auto

(* unit_vec n i: the vector of dimension n with entry 1 at position i and 0
   everywhere else. *)
definition
unit_vec :: "nat ⇒ nat ⇒ ('a :: zero_neq_one) vec"
where "unit_vec n i = vec n (λ j. if j = i then 1 else 0)"

lemma index_unit_vec[simp]:
"i < n ⟹ j < n ⟹ unit_vec n i \$ j = (if j = i then 1 else 0)"
"i < n ⟹ unit_vec n i \$ i = 1"
"dim_vec (unit_vec n i) = n"
unfolding unit_vec_def by auto

(* For an in-range position i, unit vectors are injective in their position:
   if i ≠ j they already differ at index i. *)
lemma unit_vec_eq[simp]:
assumes i: "i < n"
shows "(unit_vec n i = unit_vec n j) = (i = j)"
proof -
have "i ≠ j ⟹ unit_vec n i \$ i ≠ unit_vec n j \$ i"
unfolding unit_vec_def using i by simp
then show ?thesis by metis
qed

(* An in-range unit vector differs from the zero vector at index i. *)
lemma unit_vec_nonzero[simp]:
assumes i_n: "i < n" shows "unit_vec n i ≠ zero_vec n" (is "?l ≠ ?r")
proof -
have "?l \$ i = 1" "?r \$ i = 0" using i_n by auto
thus "?l ≠ ?r" by auto
qed

lemma unit_vec_carrier[simp]: "unit_vec n i ∈ carrier_vec n"
unfolding unit_vec_def carrier_vec_def by auto

(* All unit vectors of dimension n, ordered by position 0 .. n - 1. *)
definition unit_vecs:: "nat ⇒ 'a :: zero_neq_one vec list"
where "unit_vecs n = map (unit_vec n) [0..<n]"

text "List of the first i unit vectors"

(* unit_vecs_first n i = [unit_vec n 0, ..., unit_vec n (i - 1)], built by appending. *)
fun unit_vecs_first:: "nat ⇒ nat ⇒ 'a::zero_neq_one vec list"
where "unit_vecs_first n 0 = []"
|   "unit_vecs_first n (Suc i) = unit_vecs_first n i @ [unit_vec n i]"

(* unit_vecs agrees with unit_vecs_first; shown by induction on a prefix
   length m ≤ n. *)
lemma unit_vecs_first: "unit_vecs n = unit_vecs_first n n"
unfolding unit_vecs_def set_map set_upt
proof -
{fix m
have "m ≤ n ⟹ map (unit_vec n) [0..<m] = unit_vecs_first n m"
proof (induct m)
case (Suc m) then have mn:"m≤n" by auto
show ?case unfolding upt_Suc using Suc(1)[OF mn] by auto
qed auto
}
thus "map (unit_vec n) [0..<n] = unit_vecs_first n n" by auto
qed

text "List of the last i unit vectors"

(* unit_vecs_last n i = [unit_vec n (n - i), ..., unit_vec n (n - 1)], built by consing. *)
fun unit_vecs_last:: "nat ⇒ nat ⇒ 'a :: zero_neq_one vec list"
where "unit_vecs_last n 0 = []"
|   "unit_vecs_last n (Suc i) = unit_vec n (n - Suc i) # unit_vecs_last n i"

lemma unit_vecs_last_carrier: "set (unit_vecs_last n i) ⊆ carrier_vec n"
by (induct i;auto)

(* Code equation: unit_vecs is executed via unit_vecs_last (which conses,
   whereas unit_vecs_first appends). Proved by induction on m ≤ n. *)
lemma unit_vecs_last[code]: "unit_vecs n = unit_vecs_last n n"
proof -
{ fix m assume "m = n"
have "m ≤ n ⟹ map (unit_vec n) [n-m..<n] = unit_vecs_last n m"
proof (induction m)
case (Suc m)
then have nm:"n - Suc m < n" by auto
have ins: "[n - Suc m ..< n] = (n - Suc m) # [n - m ..< n]"
unfolding upt_conv_Cons[OF nm]
by (auto simp: Suc.prems Suc_diff_Suc Suc_le_lessD)
show ?case
unfolding ins
unfolding unit_vecs_last.simps
unfolding list.map
using Suc
unfolding Suc by auto
qed simp
}
thus "unit_vecs n = unit_vecs_last n n"
unfolding unit_vecs_def by auto
qed

lemma unit_vecs_carrier: "set (unit_vecs n) ⊆ carrier_vec n"
proof
fix u :: "'a vec"  assume u: "u ∈ set (unit_vecs n)"
then obtain i where "u = unit_vec n i" unfolding unit_vecs_def by auto
then show "u ∈ carrier_vec n"
using unit_vec_carrier by auto
qed

(* unit_vec n i is not among the last j unit vectors when i < n - j. *)
lemma unit_vecs_last_distinct:
"j ≤ n ⟹ i < n - j ⟹ unit_vec n i ∉ set (unit_vecs_last n j)"
by (induction j arbitrary:i, auto)

(* unit_vec n j is not among the first i unit vectors when i ≤ j < n. *)
lemma unit_vecs_first_distinct:
"i ≤ j ⟹ j < n ⟹ unit_vec n j ∉ set (unit_vecs_first n i)"
by (induction i arbitrary:j, auto)

(* Apply f to every entry; the dimension is preserved. *)
definition map_vec where "map_vec f v ≡ vec (dim_vec v) (λi. f (v \$ i))"

(* Pointwise negation. *)
instantiation vec :: (uminus) uminus
begin
definition uminus_vec :: "'a :: uminus vec ⇒ 'a vec" where
"- v ≡ vec (dim_vec v) (λ i. - (v \$ i))"
instance ..
end

(* Scalar multiplication from the left: every entry is multiplied by a. *)
definition smult_vec :: "'a :: times ⇒ 'a vec ⇒ 'a vec" (infixl "⋅⇩v" 70)
where "a ⋅⇩v v ≡ vec (dim_vec v) (λ i. a * v \$ i)"

(* Scalar product: the sum of the products of corresponding entries, taken
   over the indices below the dimension of the second argument w. *)
definition scalar_prod :: "'a vec ⇒ 'a vec ⇒ 'a :: semiring_0" (infix "∙" 70)
where "v ∙ w ≡ ∑ i ∈ {0 ..< dim_vec w}. v \$ i * w \$ i"

(* Additive monoid of vectors of dimension n in HOL-Algebra record form:
   the monoid operation is + and the unit is 0⇩v n. The argument ty is not
   used in the body; it only fixes the element type. *)
definition monoid_vec :: "'a itself ⇒ nat ⇒ ('a :: monoid_add vec) monoid" where
"monoid_vec ty n ≡ ⦇
carrier = carrier_vec n,
mult = (+),
one = 0⇩v n⦈"

(* Module structure on vectors of dimension n in HOL-Algebra record form:
   addition +, zero 0⇩v n and scalar multiplication ⋅⇩v. The multiplicative
   fields mult and one are left undefined. The add field is mandatory in the
   record (it comes after zero and before smult); module_vec_simps states
   that it is +. *)
definition module_vec ::
"'a :: semiring_1 itself ⇒ nat ⇒ ('a,'a vec) module" where
"module_vec ty n ≡ ⦇
carrier = carrier_vec n,
mult = undefined,
one = undefined,
zero = 0⇩v n,
add = (+),
smult = (⋅⇩v)⦈"

(* Field projections of monoid_vec and module_vec. *)
lemma monoid_vec_simps:
"mult (monoid_vec ty n) = (+)"
"carrier (monoid_vec ty n) = carrier_vec n"
"one (monoid_vec ty n) = 0⇩v n"
unfolding monoid_vec_def by auto

lemma module_vec_simps:
"add (module_vec ty n) = (+)"
"zero (module_vec ty n) = 0⇩v n"
"carrier (module_vec ty n) = carrier_vec n"
"smult (module_vec ty n) = (⋅⇩v)"
unfolding module_vec_def by auto

(* Finite sums of vectors of dimension n, defined as finprod in monoid_vec. *)
definition finsum_vec :: "'a :: monoid_add itself ⇒ nat ⇒ ('c ⇒ 'a vec) ⇒ 'c set ⇒ 'a vec" where
"finsum_vec ty n = finprod (monoid_vec ty n)"

(* Entries and dimension of a vector sum; the dimension is that of the
   second summand. *)
lemma index_add_vec[simp]:
"i < dim_vec v⇩2 ⟹ (v⇩1 + v⇩2) \$ i = v⇩1 \$ i + v⇩2 \$ i" "dim_vec (v⇩1 + v⇩2) = dim_vec v⇩2"
unfolding plus_vec_def by auto

(* Entry and dimension simp rules for the pointwise vector operations; the
   entry rules only apply to in-range indices. *)
lemma index_minus_vec[simp]:
"i < dim_vec v⇩2 ⟹ (v⇩1 - v⇩2) \$ i = v⇩1 \$ i - v⇩2 \$ i" "dim_vec (v⇩1 - v⇩2) = dim_vec v⇩2"
unfolding minus_vec_def by auto

lemma index_map_vec[simp]:
"i < dim_vec v ⟹ map_vec f v \$ i = f (v \$ i)"
"dim_vec (map_vec f v) = dim_vec v"
unfolding map_vec_def by auto

lemma map_carrier_vec[simp]: "map_vec h v ∈ carrier_vec n = (v ∈ carrier_vec n)"
unfolding map_vec_def carrier_vec_def by auto

lemma index_uminus_vec[simp]:
"i < dim_vec v ⟹ (- v) \$ i = - (v \$ i)"
"dim_vec (- v) = dim_vec v"
unfolding uminus_vec_def by auto

lemma index_smult_vec[simp]:
"i < dim_vec v ⟹ (a ⋅⇩v v) \$ i = a * v \$ i" "dim_vec (a ⋅⇩v v) = dim_vec v"
unfolding smult_vec_def by auto

(* Vectors of dimension n are closed under addition. *)
lemma add_carrier_vec[simp]:
"v⇩1 ∈ carrier_vec n ⟹ v⇩2 ∈ carrier_vec n ⟹ v⇩1 + v⇩2 ∈ carrier_vec n"
unfolding carrier_vec_def by auto

(* Vectors of dimension n are closed under subtraction. *)
lemma minus_carrier_vec[simp]:
"v⇩1 ∈ carrier_vec n ⟹ v⇩2 ∈ carrier_vec n ⟹ v⇩1 - v⇩2 ∈ carrier_vec n"
unfolding carrier_vec_def by auto

(* Commutativity of vector addition for vectors of equal dimension; declared
   as ac_simps so that proofs using ac_simps (e.g. comm_monoid_vec) can use it. *)
lemma comm_add_vec[ac_simps]:
"(v⇩1 :: 'a :: ab_semigroup_add vec) ∈ carrier_vec n ⟹ v⇩2 ∈ carrier_vec n ⟹ v⇩1 + v⇩2 = v⇩2 + v⇩1"
by (intro eq_vecI, auto simp: ac_simps)

(* Associativity of vector addition for vectors of equal dimension. *)
lemma assoc_add_vec[simp]:
"(v⇩1 :: 'a :: semigroup_add vec) ∈ carrier_vec n ⟹ v⇩2 ∈ carrier_vec n ⟹ v⇩3 ∈ carrier_vec n
⟹ (v⇩1 + v⇩2) + v⇩3 = v⇩1 + (v⇩2 + v⇩3)"
by (intro eq_vecI, auto simp: ac_simps)

(* Group laws for vectors of a fixed dimension n. *)
lemma zero_minus_vec[simp]: "(v :: 'a :: group_add vec) ∈ carrier_vec n ⟹ 0⇩v n - v = - v"
by (intro eq_vecI, auto)

lemma minus_zero_vec[simp]: "(v :: 'a :: group_add vec) ∈ carrier_vec n ⟹ v - 0⇩v n = v"
by (intro eq_vecI, auto)

lemma minus_cancel_vec[simp]: "(v :: 'a :: group_add vec) ∈ carrier_vec n ⟹ v - v = 0⇩v n"
by (intro eq_vecI, auto)

lemma minus_add_uminus_vec: "(v :: 'a :: group_add vec) ∈ carrier_vec n ⟹
w ∈ carrier_vec n ⟹ v - w = v + (- w)"
by (intro eq_vecI, auto)

(* Vectors of dimension n form a commutative monoid under addition. *)
lemma comm_monoid_vec: "comm_monoid (monoid_vec TYPE ('a :: comm_monoid_add) n)"
by (unfold_locales, auto simp: monoid_vec_def ac_simps)

lemma left_zero_vec[simp]: "(v :: 'a :: monoid_add vec) ∈ carrier_vec n  ⟹ 0⇩v n + v = v" by auto

lemma right_zero_vec[simp]: "(v :: 'a :: monoid_add vec) ∈ carrier_vec n  ⟹ v + 0⇩v n = v" by auto

lemma uminus_carrier_vec[simp]:
"(- v ∈ carrier_vec n) = (v ∈ carrier_vec n)"
unfolding carrier_vec_def by auto

lemma uminus_r_inv_vec[simp]:
"(v :: 'a :: group_add vec) ∈ carrier_vec n ⟹ (v + - v) = 0⇩v n"
by (intro eq_vecI, auto)

lemma uminus_l_inv_vec[simp]:
"(v :: 'a :: group_add vec) ∈ carrier_vec n ⟹ (- v + v) = 0⇩v n"
by (intro eq_vecI, auto)

(* Every vector of dimension n has a two-sided additive inverse of dimension
   n, namely - v. *)
lemma add_inv_exists_vec:
"(v :: 'a :: group_add vec) ∈ carrier_vec n ⟹ ∃ w ∈ carrier_vec n. w + v = 0⇩v n ∧ v + w = 0⇩v n"
by (intro bexI[of _ "- v"], auto)

(* Vectors of dimension n form an abelian group under addition. *)
lemma comm_group_vec: "comm_group (monoid_vec TYPE ('a :: ab_group_add) n)"
by (unfold_locales, insert add_inv_exists_vec, auto simp: monoid_vec_def ac_simps Units_def)

(* HOL-Algebra finprod lemmas instantiated for monoid_vec and restated in
   terms of finsum_vec, with the monoid fields unfolded. *)
lemmas finsum_vec_insert =
comm_monoid.finprod_insert[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

lemmas finsum_vec_closed =
comm_monoid.finprod_closed[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

lemmas finsum_vec_empty =
comm_monoid.finprod_empty[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

lemma smult_carrier_vec[simp]: "(a ⋅⇩v v ∈ carrier_vec n) = (v ∈ carrier_vec n)"
unfolding carrier_vec_def by auto

(* The scalar product with a zero vector of matching dimension is 0. *)
lemma scalar_prod_left_zero[simp]: "v ∈ carrier_vec n ⟹ 0⇩v n ∙ v = 0"
unfolding scalar_prod_def
by (rule sum.neutral, auto)

lemma scalar_prod_right_zero[simp]: "v ∈ carrier_vec n ⟹ v ∙ 0⇩v n = 0"
unfolding scalar_prod_def
by (rule sum.neutral, auto)

(* The scalar product with a unit vector selects an entry: the sum is split
   into the term at index i and a remainder that is shown to vanish. *)
lemma scalar_prod_left_unit[simp]: assumes v: "(v :: 'a :: semiring_1 vec) ∈ carrier_vec n" and i: "i < n"
shows "unit_vec n i ∙ v = v \$ i"
proof -
let ?f = "λ k. unit_vec n i \$ k * v \$ k"
have id: "(∑k∈{0..<n}. ?f k) = unit_vec n i \$ i * v \$ i + (∑k∈{0..<n} - {i}. ?f k)"
by (rule sum.remove, insert i, auto)
also have "(∑ k∈{0..<n} - {i}. ?f k) = 0"
by (rule sum.neutral, insert i, auto)
finally
show ?thesis unfolding scalar_prod_def using i v by simp
qed

lemma scalar_prod_right_unit[simp]: assumes i: "i < n"
shows "(v :: 'a :: semiring_1 vec) ∙ unit_vec n i = v \$ i"
proof -
let ?f = "λ k. v \$ k * unit_vec n i \$ k"
have id: "(∑k∈{0..<n}. ?f k) = v \$ i * unit_vec n i \$ i + (∑k∈{0..<n} - {i}. ?f k)"
by (rule sum.remove, insert i, auto)
also have "(∑k∈{0..<n} - {i}. ?f k) = 0"
by (rule sum.neutral, insert i, auto)
finally
show ?thesis unfolding scalar_prod_def using i by simp
qed

(* Distributivity of the scalar product over vector addition (both sides). *)
lemma add_scalar_prod_distrib: assumes v: "v⇩1 ∈ carrier_vec n" "v⇩2 ∈ carrier_vec n" "v⇩3 ∈ carrier_vec n"
shows "(v⇩1 + v⇩2) ∙ v⇩3 = v⇩1 ∙ v⇩3 + v⇩2 ∙ v⇩3"
proof -
have "(∑i∈{0..<dim_vec v⇩3}. (v⇩1 + v⇩2) \$ i * v⇩3 \$ i) = (∑i∈{0..<dim_vec v⇩3}. v⇩1 \$ i * v⇩3 \$ i + v⇩2 \$ i * v⇩3 \$ i)"
by (rule sum.cong, insert v, auto simp: algebra_simps)
thus ?thesis unfolding scalar_prod_def using v by (auto simp: sum.distrib)
qed

lemma scalar_prod_add_distrib: assumes v: "v⇩1 ∈ carrier_vec n" "v⇩2 ∈ carrier_vec n" "v⇩3 ∈ carrier_vec n"
shows "v⇩1 ∙ (v⇩2 + v⇩3) = v⇩1 ∙ v⇩2 + v⇩1 ∙ v⇩3"
proof -
have "(∑i∈{0..<dim_vec v⇩3}. v⇩1 \$ i * (v⇩2 + v⇩3) \$ i) = (∑i∈{0..<dim_vec v⇩3}. v⇩1 \$ i * v⇩2 \$ i + v⇩1 \$ i * v⇩3 \$ i)"
by (rule sum.cong, insert v, auto simp: algebra_simps)
thus ?thesis unfolding scalar_prod_def using v by (auto intro: sum.distrib)
qed

(* Scalars can be pulled out of either argument of the scalar product; the
   right-hand version needs a commutative ring. *)
lemma smult_scalar_prod_distrib[simp]: assumes v: "v⇩1 ∈ carrier_vec n" "v⇩2 ∈ carrier_vec n"
shows "(a ⋅⇩v v⇩1) ∙ v⇩2 = a * (v⇩1 ∙ v⇩2)"
unfolding scalar_prod_def sum_distrib_left
by (rule sum.cong, insert v, auto simp: ac_simps)

lemma scalar_prod_smult_distrib[simp]: assumes v: "v⇩1 ∈ carrier_vec n" "v⇩2 ∈ carrier_vec n"
shows "v⇩1 ∙ (a ⋅⇩v v⇩2) = (a :: 'a :: comm_ring) * (v⇩1 ∙ v⇩2)"
unfolding scalar_prod_def sum_distrib_left
by (rule sum.cong, insert v, auto simp: ac_simps)

(* Over a commutative semiring the scalar product is symmetric. *)
lemma comm_scalar_prod: assumes "(v⇩1 :: 'a :: comm_semiring_0 vec) ∈ carrier_vec n" "v⇩2 ∈ carrier_vec n"
shows "v⇩1 ∙ v⇩2 = v⇩2 ∙ v⇩1"
unfolding scalar_prod_def
by (rule sum.cong, insert assms, auto simp: ac_simps)

(* Scalar multiplication distributes over addition of scalars. *)
lemma add_smult_distrib_vec:
"((a::'a::ring) + b) ⋅⇩v v = a ⋅⇩v v + b ⋅⇩v v"
unfolding smult_vec_def plus_vec_def
by (rule eq_vecI, auto simp: distrib_right)

(* Scalar multiplication distributes over addition of vectors of equal dimension. *)
lemma smult_add_distrib_vec:
assumes "v ∈ carrier_vec n" "w ∈ carrier_vec n"
shows "(a::'a::ring) ⋅⇩v (v + w) = a ⋅⇩v v + a ⋅⇩v w"
apply (rule eq_vecI)
unfolding smult_vec_def plus_vec_def
using assms distrib_left by auto

(* Iterated scalar multiplication equals multiplication by the product. *)
lemma smult_smult_assoc:
"a ⋅⇩v (b ⋅⇩v v) = (a * b::'a::ring) ⋅⇩v v"
apply (rule sym, rule eq_vecI)
unfolding smult_vec_def plus_vec_def using mult.assoc by auto

lemma one_smult_vec [simp]:
"(1::'a::ring_1) ⋅⇩v v = v" unfolding smult_vec_def
by (rule eq_vecI,auto)

lemma uminus_zero_vec[simp]: "- (0⇩v n) = (0⇩v n :: 'a :: group_add vec)"
by (intro eq_vecI, auto)

(* An entry of a finite vector sum is the sum of the corresponding entries;
   proved by induction over the finite set F. *)
lemma index_finsum_vec: assumes "finite F" and i: "i < n"
and vs: "vs ∈ F → carrier_vec n"
shows "finsum_vec TYPE('a :: comm_monoid_add) n vs F \$ i = sum (λ f. vs f \$ i) F"
using ‹finite F› vs
proof (induct F)
case (insert f F)
hence IH: "finsum_vec TYPE('a) n vs F \$ i = (∑f∈F. vs f \$ i)"
and vs: "vs ∈ F → carrier_vec n" "vs f ∈ carrier_vec n" by auto
show ?case unfolding finsum_vec_insert[OF insert(1-2) vs]
unfolding sum.insert[OF insert(1-2)]
unfolding IH[symmetric]
by (rule index_add_vec, insert i, insert finsum_vec_closed[OF vs(1)], auto)
qed (insert i, auto simp: finsum_vec_empty)

text ‹The non-strict ordering on vectors is defined pointwise (for vectors of equal dimension),
and the strict ordering is derived from it such that the @{class order} constraints are satisfied.›

(* v ≤ w: equal dimensions and entry-wise ≤. v < w: v ≤ w but not w ≤ v. *)
instantiation vec :: (ord) ord
begin

definition less_eq_vec :: "'a vec ⇒ 'a vec ⇒ bool" where
"less_eq_vec v w = (dim_vec v = dim_vec w ∧ (∀ i < dim_vec w. v \$ i ≤ w \$ i))"

definition less_vec :: "'a vec ⇒ 'a vec ⇒ bool" where
"less_vec v w = (v ≤ w ∧ ¬ (w ≤ v))"
instance ..
end

(* The pointwise ordering inherits the preorder and order properties of the
   entry type. *)
instantiation vec :: (preorder) preorder
begin
instance
by (standard, auto simp: less_vec_def less_eq_vec_def order_trans)
end

instantiation vec :: (order) order
begin
instance
by (standard, intro eq_vecI, auto simp: less_eq_vec_def order.antisym)
end

subsection‹Matrices›

text ‹As for vectors, we specify which value is returned when an index is out of bounds.
It is defined in a way that the implementation only has to perform few
index comparisons.›

(* Value of out-of-range entries: the nr × nc entries of f are collected as a
   list of row lists, which is then indexed at (i,j). For i < nr and j < nc
   this is f (i,j); otherwise at least one list index is out of bounds. *)
definition undef_mat :: "nat ⇒ nat ⇒ (nat × nat ⇒ 'a) ⇒ nat × nat ⇒ 'a" where
"undef_mat nr nc f ≡ λ (i,j). [[f (i,j). j <- [0 ..< nc]] . i <- [0 ..< nr]] ! i ! j"

(* undef_mat nr nc f only depends on the values of f inside the nr × nc range. *)
lemma undef_cong_mat: assumes "⋀ i j. i < nr ⟹ j < nc ⟹ f (i,j) = f' (i,j)"
shows "undef_mat nr nc f x = undef_mat nr nc f' x"
proof (cases x)
case (Pair i j)
have nth_map_ge: "⋀ i xs. ¬ i < length xs ⟹ xs ! i = [] ! (i - length xs)"
by (metis append_Nil2 nth_append)
note [simp] = Pair undef_mat_def nth_map_ge[of i] nth_map_ge[of j]
show ?thesis
by (cases "i < nr", simp, cases "j < nc", insert assms, auto)
qed

(* Normalise an entry function: keep f inside the nr × nc range and use
   undef_mat outside of it. *)
definition mk_mat :: "nat ⇒ nat ⇒ (nat × nat ⇒ 'a) ⇒ (nat × nat ⇒ 'a)" where
"mk_mat nr nc f ≡ λ (i,j). if i < nr ∧ j < nc then f (i,j) else undef_mat nr nc f (i,j)"

(* mk_mat nr nc f only depends on the in-range values of f. *)
lemma cong_mk_mat: assumes "⋀ i j. i < nr ⟹ j < nc ⟹ f (i,j) = f' (i,j)"
shows "mk_mat nr nc f = mk_mat nr nc f'"
using undef_cong_mat[of nr nc f f', OF assms]
using assms unfolding mk_mat_def
by auto

(* Matrices: triples (nr, nc, f) with f normalised by mk_mat nr nc. *)
typedef 'a mat = "{(nr, nc, mk_mat nr nc f) | nr nc f :: nat × nat ⇒ 'a. True}"
by auto

(* Register mat with the lifting package. *)
setup_lifting type_definition_mat

(* Projections of the representation (nr, nc, f): number of rows, number of
   columns, and entry access. *)
lift_definition dim_row :: "'a mat ⇒ nat" is fst .
lift_definition dim_col :: "'a mat ⇒ nat" is "fst o snd" .
lift_definition index_mat :: "'a mat ⇒ (nat × nat ⇒ 'a)" (infixl "\$\$" 100) is "snd o snd" .
(* Build an nr × nc matrix from an entry function on index pairs. *)
lift_definition mat :: "nat ⇒ nat ⇒ (nat × nat ⇒ 'a) ⇒ 'a mat"
is "λ nr nc f. (nr, nc, mk_mat nr nc f)" by auto
(* Build an nr × nc matrix from a function mapping each row index to a row vector. *)
lift_definition mat_of_row_fun :: "nat ⇒ nat ⇒ (nat ⇒ 'a vec) ⇒ 'a mat" ("mat⇩r")
is "λ nr nc f. (nr, nc, mk_mat nr nc (λ (i,j). f i \$ j))" by auto

(* Row-major list-of-lists representation. *)
definition mat_to_list :: "'a mat ⇒ 'a list list" where
"mat_to_list A = [ [A \$\$ (i,j) . j <- [0 ..< dim_col A]] . i <- [0 ..< dim_row A]]"

fun square_mat :: "'a mat ⇒ bool" where "square_mat A = (dim_col A = dim_row A)"

(* Upper triangular: every entry (i,j) with j < i < dim_row A is 0. *)
definition upper_triangular :: "'a::zero mat ⇒ bool"
where "upper_triangular A ≡
∀i < dim_row A. ∀ j < i. A \$\$ (i,j) = 0"

lemma upper_triangularD[elim] :
"upper_triangular A ⟹ j < i ⟹ i < dim_row A ⟹ A \$\$ (i,j) = 0"
unfolding upper_triangular_def by auto

lemma upper_triangularI[intro] :
"(⋀i j. j < i ⟹ i < dim_row A ⟹ A \$\$ (i,j) = 0) ⟹ upper_triangular A"
unfolding upper_triangular_def by auto

lemma dim_row_mat[simp]: "dim_row (mat nr nc f) = nr" "dim_row (mat⇩r nr nc g) = nr"
by (transfer, simp)+

lemma dim_col_mat[simp]: "dim_col (mat nr nc f) = nc" "dim_col (mat⇩r nr nc g) = nc"
by (transfer, simp)+

(* The set of all nr × nc matrices. *)
definition carrier_mat :: "nat ⇒ nat ⇒ 'a mat set"
where "carrier_mat nr nc = { m . dim_row m = nr ∧ dim_col m = nc}"

lemma carrier_mat_triv[simp]: "m ∈ carrier_mat (dim_row m) (dim_col m)"
unfolding carrier_mat_def by auto

lemma mat_carrier[simp]: "mat nr nc f ∈ carrier_mat nr nc"
unfolding carrier_mat_def by auto

(* The set of all in-range entries of A. *)
definition elements_mat :: "'a mat ⇒ 'a set"
where "elements_mat A = set [A \$\$ (i,j). i <- [0 ..< dim_row A], j <- [0 ..< dim_col A]]"

lemma elements_matD [dest]:
"a ∈ elements_mat A ⟹ ∃i j. i < dim_row A ∧ j < dim_col A ∧ a = A \$\$ (i,j)"
unfolding elements_mat_def by force

lemma elements_matI [intro]:
"A ∈ carrier_mat nr nc ⟹ i < nr ⟹ j < nc ⟹ a = A \$\$ (i,j) ⟹ a ∈ elements_mat A"
unfolding elements_mat_def carrier_mat_def by force

(* Entries of matrices built with mat and mat⇩r, for in-range indices. The
   proof unfolds the normalisation mk_mat on the representation. *)
lemma index_mat[simp]:  "i < nr ⟹ j < nc ⟹ mat nr nc f \$\$ (i,j) = f (i,j)"
"i < nr ⟹ j < nc ⟹ mat⇩r nr nc g \$\$ (i,j) = g i \$ j"
by (transfer', simp add: mk_mat_def)+

(* Extensionality: A = B if both dimensions agree and all in-range entries agree. *)
lemma eq_matI[intro]: "(⋀ i j . i < dim_row B ⟹ j < dim_col B ⟹ A \$\$ (i,j) = B \$\$ (i,j))
⟹ dim_row A = dim_row B
⟹ dim_col A = dim_col B
⟹ A = B"
by (transfer, auto intro!: cong_mk_mat, auto simp: mk_mat_def)

lemma carrier_matI[intro]:
assumes "dim_row A = nr" "dim_col A = nc" shows  "A ∈ carrier_mat nr nc"
using assms unfolding carrier_mat_def by auto

lemma carrier_matD[dest,simp]: assumes "A ∈ carrier_mat nr nc"
shows "dim_row A = nr" "dim_col A = nc" using assms
unfolding carrier_mat_def by auto

(* mat only depends on the in-range values of its entry function. *)
lemma cong_mat: assumes "nr = nr'" "nc = nc'" "⋀ i j. i < nr ⟹ j < nc ⟹
f (i,j) = f' (i,j)" shows "mat nr nc f = mat nr' nc' f'"
by (rule eq_matI, insert assms, auto)

(* row A i: the i-th row of A, a vector of dimension dim_col A. *)
definition row :: "'a mat ⇒ nat ⇒ 'a vec" where
"row A i = vec (dim_col A) (λ j. A \$\$ (i,j))"

(* The list of all rows of A, in order. *)
definition rows :: "'a mat ⇒ 'a vec list" where
"rows A = map (row A) [0..<dim_row A]"

lemma row_carrier[simp]: "row A i ∈ carrier_vec (dim_col A)" unfolding row_def by auto

lemma rows_carrier[simp]: "set (rows A) ⊆ carrier_vec (dim_col A)" unfolding rows_def by auto

lemma length_rows[simp]: "length (rows A) = dim_row A" unfolding rows_def by auto

lemma nth_rows[simp]: "i < dim_row A ⟹ rows A ! i = row A i"
unfolding rows_def by auto

lemma row_mat_of_row_fun[simp]: "i < nr ⟹ dim_vec (f i) = nc ⟹ row (mat⇩r nr nc f) i = f i"
by (rule eq_vecI, auto simp: row_def)

lemma set_rows_carrier:
assumes "A ∈ carrier_mat m n" and "v ∈ set (rows A)" shows "v ∈ carrier_vec n"
using assms by (auto simp: rows_def row_def)

(* Matrix with length rs rows and n columns whose entry (i,j) is entry j of
   the i-th vector of rs; the vectors are not checked to have dimension n. *)
definition mat_of_rows :: "nat ⇒ 'a vec list ⇒ 'a mat"
where "mat_of_rows n rs = mat (length rs) n (λ(i,j). rs ! i \$ j)"

(* The same construction for rows given as lists. *)
definition mat_of_rows_list :: "nat ⇒ 'a list list ⇒ 'a mat" where
"mat_of_rows_list nc rs = mat (length rs) nc (λ (i,j). rs ! i ! j)"

lemma mat_of_rows_carrier[simp]:
"mat_of_rows n vs ∈ carrier_mat (length vs) n"
"dim_row (mat_of_rows n vs) = length vs"
"dim_col (mat_of_rows n vs) = n"
unfolding mat_of_rows_def by auto

lemma mat_of_rows_row[simp]:
assumes i:"i < length vs" and n: "vs ! i ∈ carrier_vec n"
shows "row (mat_of_rows n vs) i = vs ! i"
unfolding mat_of_rows_def row_def using n i by auto

(* rows and mat_of_rows are mutually inverse, given matching dimensions. *)
lemma rows_mat_of_rows[simp]:
assumes "set vs ⊆ carrier_vec n" shows "rows (mat_of_rows n vs) = vs"
unfolding rows_def apply (rule nth_equalityI)
using assms unfolding subset_code(1) by auto

lemma mat_of_rows_rows[simp]:
"mat_of_rows (dim_col A) (rows A) = A"
unfolding mat_of_rows_def by (rule, auto simp: row_def)

(* col A j: the j-th column of A, a vector of dimension dim_row A. *)
definition col :: "'a mat ⇒ nat ⇒ 'a vec" where
"col A j = vec (dim_row A) (λ i. A \$\$ (i,j))"

(* The list of all columns of A, in order. *)
definition cols :: "'a mat ⇒ 'a vec list" where
"cols A = map (col A) [0..<dim_col A]"

(* Matrix with n rows and length cs columns whose entry (i,j) is entry i of
   the j-th vector of cs. *)
definition mat_of_cols :: "nat ⇒ 'a vec list ⇒ 'a mat"
where "mat_of_cols n cs = mat n (length cs) (λ(i,j). cs ! j \$ i)"

(* The same construction for columns given as lists. *)
definition mat_of_cols_list :: "nat ⇒ 'a list list ⇒ 'a mat" where
"mat_of_cols_list nr cs = mat nr (length cs) (λ (i,j). cs ! j ! i)"

lemma col_dim[simp]: "col A i ∈ carrier_vec (dim_row A)" unfolding col_def by auto

lemma dim_col[simp]: "dim_vec (col A i) = dim_row A" by auto

lemma cols_dim[simp]: "set (cols A) ⊆ carrier_vec (dim_row A)" unfolding cols_def by auto

lemma cols_length[simp]: "length (cols A) = dim_col A" unfolding cols_def by auto

lemma cols_nth[simp]: "i < dim_col A ⟹ cols A ! i = col A i"
unfolding cols_def by auto

lemma mat_of_cols_carrier[simp]:
"mat_of_cols n vs ∈ carrier_mat n (length vs)"
"dim_row (mat_of_cols n vs) = n"
"dim_col (mat_of_cols n vs) = length vs"
unfolding mat_of_cols_def by auto

lemma col_mat_of_cols[simp]:
assumes j:"j < length vs" and n: "vs ! j ∈ carrier_vec n"
shows "col (mat_of_cols n vs) j = vs ! j"
unfolding mat_of_cols_def col_def using j n by auto

(* cols and mat_of_cols are mutually inverse, given matching dimensions. *)
lemma cols_mat_of_cols[simp]:
assumes "set vs ⊆ carrier_vec n" shows "cols (mat_of_cols n vs) = vs"
unfolding cols_def apply(rule nth_equalityI)
using assms unfolding subset_code(1) by auto

lemma mat_of_cols_cols[simp]:
"mat_of_cols (dim_row A) (cols A) = A"
unfolding mat_of_cols_def by (rule, auto simp: col_def)

(* A ≤ B: equal dimensions and entry-wise ≤. A < B: A ≤ B but not B ≤ A. *)
instantiation mat :: (ord) ord
begin

definition less_eq_mat :: "'a mat ⇒ 'a mat ⇒ bool" where
"less_eq_mat A B = (dim_row A = dim_row B ∧ dim_col A = dim_col B ∧
(∀ i < dim_row B. ∀ j < dim_col B. A \$\$ (i,j) ≤ B \$\$ (i,j)))"

definition less_mat :: "'a mat ⇒ 'a mat ⇒ bool" where
"less_mat A B = (A ≤ B ∧ ¬ (B ≤ A))"
instance ..
end

(* Transitivity is shown entry-wise with order_trans. *)
instantiation mat :: (preorder) preorder
begin
instance
proof (standard, auto simp: less_mat_def less_eq_mat_def, goal_cases)
case (1 A B C i j)
thus ?case using order_trans[of "A \$\$ (i,j)" "B \$\$ (i,j)" "C \$\$ (i,j)"] by auto
qed
end

instantiation mat :: (order) order
begin
instance
by (standard, intro eq_matI, auto simp: less_eq_mat_def order.antisym)
end

(* Pointwise addition; the result has the dimensions of the second argument. *)
instantiation mat :: (plus) plus
begin
definition plus_mat :: "('a :: plus) mat ⇒ 'a mat ⇒ 'a mat" where
"A + B ≡ mat (dim_row B) (dim_col B) (λ ij. A \$\$ ij + B \$\$ ij)"
instance ..
end

(* Apply f to every entry, preserving the dimensions. *)
definition map_mat :: "('a ⇒ 'b) ⇒ 'a mat ⇒ 'b mat" where
"map_mat f A ≡ mat (dim_row A) (dim_col A) (λ ij. f (A \$\$ ij))"

(* Scalar multiplication from the left, via map_mat. *)
definition smult_mat :: "'a :: times ⇒ 'a mat ⇒ 'a mat" (infixl "⋅⇩m" 70)
where "a ⋅⇩m A ≡ map_mat (λ b. a * b) A"

(* Zero matrix of size nr × nc. *)
definition zero_mat :: "nat ⇒ nat ⇒ 'a :: zero mat" ("0⇩m") where
"0⇩m nr nc ≡ mat nr nc (λ ij. 0)"

lemma elements_0_mat [simp]: "elements_mat (0⇩m nr nc) ⊆ {0}"
unfolding elements_mat_def zero_mat_def by auto

definition transpose_mat :: "'a mat ⇒ 'a mat" where
"transpose_mat A ≡ mat (dim_col A) (dim_row A) (λ (i,j). A \$\$ (j,i))"

(* Identity matrix of size n × n. *)
definition one_mat :: "nat ⇒ 'a :: {zero,one} mat" ("1⇩m") where
"1⇩m n ≡ mat n n (λ (i,j). if i = j then 1 else 0)"

(* Pointwise negation. *)
instantiation mat :: (uminus) uminus
begin
definition uminus_mat :: "'a :: uminus mat ⇒ 'a mat" where
"- A ≡ mat (dim_row A) (dim_col A) (λ ij. - (A \$\$ ij))"
instance ..
end

(* Pointwise subtraction; the result has the dimensions of the second argument. *)
instantiation mat :: (minus) minus
begin
definition minus_mat :: "('a :: minus) mat ⇒ 'a mat ⇒ 'a mat" where
"A - B ≡ mat (dim_row B) (dim_col B) (λ ij. A \$\$ ij - B \$\$ ij)"
instance ..
end

(* Matrix product: entry (i,j) is the scalar product of row i of A and
   column j of B. The inner dimensions are not compared. *)
instantiation mat :: (semiring_0) times
begin
definition times_mat :: "'a :: semiring_0 mat ⇒ 'a mat ⇒ 'a mat"
where "A * B ≡ mat (dim_row A) (dim_col B) (λ (i,j). row A i ∙ col B j)"
instance ..
end

(* Matrix-vector product: entry i is the scalar product of row i of A and v. *)
definition mult_mat_vec :: "'a :: semiring_0 mat ⇒ 'a vec ⇒ 'a vec" (infixl "*⇩v" 70)
where "A *⇩v v ≡ vec (dim_row A) (λ i. row A i ∙ v)"

(* B is a right inverse of A: A * B is the identity of size dim_row A. *)
definition inverts_mat :: "'a :: semiring_1 mat ⇒ 'a mat ⇒ bool" where
"inverts_mat A B ≡ A * B = 1⇩m (dim_row A)"

(* A is square and some B is both a right and a left inverse of A. *)
definition invertible_mat :: "'a :: semiring_1 mat ⇒ bool"
where "invertible_mat A ≡ square_mat A ∧ (∃B. inverts_mat A B ∧ inverts_mat B A)"

(* Additive monoid of nr × nc matrices in HOL-Algebra record form: the monoid
   operation is + and the unit is 0⇩m nr nc. *)
definition monoid_mat :: "'a :: monoid_add itself ⇒ nat ⇒ nat ⇒ 'a mat monoid" where
"monoid_mat ty nr nc ≡ ⦇
carrier = carrier_mat nr nc,
mult = (+),
one = 0⇩m nr nc⦈"

(* Ring structure of n × n matrices in HOL-Algebra record form: multiplication
   *, addition +, unit 1⇩m n and zero 0⇩m n n, with b as the record extension.
   The add field is mandatory in the record (after zero, before the
   extension); ring_mat_simps states that it is +. *)
definition ring_mat :: "'a :: semiring_1 itself ⇒ nat ⇒ 'b ⇒ ('a mat,'b) ring_scheme" where
"ring_mat ty n b ≡ ⦇
carrier = carrier_mat n n,
mult = (*),
one = 1⇩m n,
zero = 0⇩m n n,
add = (+),
… = b⦈"

(* Module structure of nr × nc matrices in HOL-Algebra record form: addition
   +, zero 0⇩m nr nc and scalar multiplication ⋅⇩m. The add field is mandatory
   in the record (after zero, before smult); module_mat_simps states that
   it is +. *)
definition module_mat :: "'a :: semiring_1 itself ⇒ nat ⇒ nat ⇒ ('a,'a mat)module" where
"module_mat ty nr nc ≡ ⦇
carrier = carrier_mat nr nc,
mult = (*),
one = 1⇩m nr,
zero = 0⇩m nr nc,
add = (+),
smult = (⋅⇩m)⦈"

(* Field projections of ring_mat and module_mat. *)
lemma ring_mat_simps:
"mult (ring_mat ty n b) = (*)"
"add (ring_mat ty n b) = (+)"
"one (ring_mat ty n b) = 1⇩m n"
"zero (ring_mat ty n b) = 0⇩m n n"
"carrier (ring_mat ty n b) = carrier_mat n n"
unfolding ring_mat_def by auto

lemma module_mat_simps:
"mult (module_mat ty nr nc) = (*)"
"add (module_mat ty nr nc) = (+)"
"one (module_mat ty nr nc) = 1⇩m nr"
"zero (module_mat ty nr nc) = 0⇩m nr nc"
"carrier (module_mat ty nr nc) = carrier_mat nr nc"
"smult (module_mat ty nr nc) = (⋅⇩m)"
unfolding module_mat_def by auto

(* Entries (for in-range indices) and dimensions of zero and identity matrices. *)
lemma index_zero_mat[simp]: "i < nr ⟹ j < nc ⟹ 0⇩m nr nc \$\$ (i,j) = 0"
"dim_row (0⇩m nr nc) = nr" "dim_col (0⇩m nr nc) = nc"
unfolding zero_mat_def by auto

lemma index_one_mat[simp]: "i < n ⟹ j < n ⟹ 1⇩m n \$\$ (i,j) = (if i = j then 1 else 0)"
"dim_row (1⇩m n) = n" "dim_col (1⇩m n) = n"
unfolding one_mat_def by auto

(* Entries and dimensions of a matrix sum; the dimensions come from the
   second summand. *)
lemma index_add_mat[simp]:
"i < dim_row B ⟹ j < dim_col B ⟹ (A + B) \$\$ (i,j) = A \$\$ (i,j) + B \$\$ (i,j)"
"dim_row (A + B) = dim_row B" "dim_col (A + B) = dim_col B"
unfolding plus_mat_def by auto

(* Entry and dimension simp rules for the matrix operations; the entry rules
   only apply to in-range indices. Subtraction takes its dimensions from B. *)
lemma index_minus_mat[simp]:
"i < dim_row B ⟹ j < dim_col B ⟹ (A - B) \$\$ (i,j) = A \$\$ (i,j) - B \$\$ (i,j)"
"dim_row (A - B) = dim_row B" "dim_col (A - B) = dim_col B"
unfolding minus_mat_def by auto

lemma index_map_mat[simp]:
"i < dim_row A ⟹ j < dim_col A ⟹ map_mat f A \$\$ (i,j) = f (A \$\$ (i,j))"
"dim_row (map_mat f A) = dim_row A" "dim_col (map_mat f A) = dim_col A"
unfolding map_mat_def by auto

lemma index_smult_mat[simp]:
"i < dim_row A ⟹ j < dim_col A ⟹ (a ⋅⇩m A) \$\$ (i,j) = a * A \$\$ (i,j)"
"dim_row (a ⋅⇩m A) = dim_row A" "dim_col (a ⋅⇩m A) = dim_col A"
unfolding smult_mat_def by auto

lemma index_uminus_mat[simp]:
"i < dim_row A ⟹ j < dim_col A ⟹ (- A) \$\$ (i,j) = - (A \$\$ (i,j))"
"dim_row (- A) = dim_row A" "dim_col (- A) = dim_col A"
unfolding uminus_mat_def by auto

lemma index_transpose_mat[simp]:
"i < dim_col A ⟹ j < dim_row A ⟹ transpose_mat A \$\$ (i,j) = A \$\$ (j,i)"
"dim_row (transpose_mat A) = dim_col A" "dim_col (transpose_mat A) = dim_row A"
unfolding transpose_mat_def by auto

lemma index_mult_mat[simp]:
"i < dim_row A ⟹ j < dim_col B ⟹ (A * B) \$\$ (i,j) = row A i ∙ col B j"
"dim_row (A * B) = dim_row A" "dim_col (A * B) = dim_col B"
by (auto simp: times_mat_def)

lemma dim_mult_mat_vec[simp]: "dim_vec (A *⇩v v) = dim_row A"
by (auto simp: mult_mat_vec_def)

lemma index_mult_mat_vec[simp]: "i < dim_row A ⟹ (A *⇩v v) \$ i = row A i ∙ v"
by (auto simp: mult_mat_vec_def)

lemma index_row[simp]:
"i < dim_row A ⟹ j < dim_col A ⟹ row A i \$ j = A \$\$ (i,j)"
"dim_vec (row A i) = dim_col A"
by (auto simp: row_def)

lemma index_col[simp]: "i < dim_row A ⟹ j < dim_col A ⟹ col A j \$ i = A \$\$ (i,j)"
by (auto simp: col_def)

(* The identity and the square zero matrix are upper triangular. *)
lemma upper_triangular_one[simp]: "upper_triangular (1⇩m n)"
by (rule, auto)

lemma upper_triangular_zero[simp]: "upper_triangular (0⇩m n n)"
by (rule, auto)

lemma mat_row_carrierI[intro,simp]: "mat⇩r nr nc r ∈ carrier_mat nr nc"
by (unfold carrier_mat_def carrier_vec_def, auto)

(* Row-wise extensionality: two matrices with equal dimensions are equal as
   soon as all of their rows coincide. The proof reduces to entry-wise
   equality (eq_matI) and reads each entry off the corresponding row. *)
lemma eq_rowI: assumes rows: "⋀ i. i < dim_row B ⟹ row A i = row B i"
and dims: "dim_row A = dim_row B" "dim_col A = dim_col B"
shows "A = B"
proof (rule eq_matI[OF _ dims])
fix i j
assume i: "i < dim_row B" and j: "j < dim_col B"
from rows[OF i] have id: "row A i \$ j = row B i \$ j" by simp
show "A \$\$ (i, j) = B \$\$ (i, j)"
using index_row(1)[OF i j, folded id] index_row(1)[of i A j] i j dims
by auto
qed

(* The set of entries of map_mat f A is the image under f of A's entries. *)
lemma elements_mat_map[simp]: "elements_mat (map_mat f A) = f ` elements_mat A"
by fastforce

(* Rows and columns of a matrix given by its characteristic function f. *)
lemma row_mat[simp]: "i < nr ⟹ row (mat nr nc f) i = vec nc (λ j. f (i,j))"
by auto

lemma col_mat[simp]: "j < nc ⟹ col (mat nr nc f) j = vec nr (λ i. f (i,j))"
by auto

(* Carrier membership of the zero matrix and of scalar multiples. *)
lemma zero_carrier_mat[simp]: "0⇩m nr nc ∈ carrier_mat nr nc"
unfolding carrier_mat_def by auto

lemma smult_carrier_mat[simp]:
"A ∈ carrier_mat nr nc ⟹ k ⋅⇩m A ∈ carrier_mat nr nc"
unfolding carrier_mat_def by auto

"B ∈ carrier_mat nr nc ⟹ A + B ∈ carrier_mat nr nc"
unfolding carrier_mat_def by force

(* The n × n identity matrix is square of dimension n. *)
lemma one_carrier_mat[simp]: "1⇩m n ∈ carrier_mat n n"
unfolding carrier_mat_def by auto

(* Negation preserves the carrier (one direction, not a simp rule). *)
lemma uminus_carrier_mat:
"A ∈ carrier_mat nr nc ⟹ (- A ∈ carrier_mat nr nc)"
unfolding carrier_mat_def by auto

(* Negation preserves the carrier (equivalence, used as simp rule). *)
lemma uminus_carrier_iff_mat[simp]:
"(- A ∈ carrier_mat nr nc) = (A ∈ carrier_mat nr nc)"
unfolding carrier_mat_def by auto

(* A difference A - B lies in B's carrier; no assumption on A is needed. *)
lemma minus_carrier_mat:
"B ∈ carrier_mat nr nc ⟹ (A - B ∈ carrier_mat nr nc)"
unfolding carrier_mat_def by auto

(* Transposition swaps the two carrier dimensions. *)
lemma transpose_carrier_mat[simp]: "(transpose_mat A ∈ carrier_mat nc nr) = (A ∈ carrier_mat nr nc)"
unfolding carrier_mat_def by auto

(* Rows of an nr × nc matrix have length nc, columns have length nr. *)
lemma row_carrier_vec[simp]: "i < nr ⟹ A ∈ carrier_mat nr nc ⟹ row A i ∈ carrier_vec nc"
unfolding carrier_vec_def by auto

lemma col_carrier_vec[simp]: "j < nc ⟹ A ∈ carrier_mat nr nc ⟹ col A j ∈ carrier_vec nr"
unfolding carrier_vec_def by auto

(* Dimensions of matrix-matrix and matrix-vector products. *)
lemma mult_carrier_mat[simp]:
"A ∈ carrier_mat nr n ⟹ B ∈ carrier_mat n nc ⟹ A * B ∈ carrier_mat nr nc"
unfolding carrier_mat_def by auto

lemma mult_mat_vec_carrier[simp]:
"A ∈ carrier_mat nr n ⟹ v ∈ carrier_vec n ⟹ A *⇩v v ∈ carrier_vec nr"
unfolding carrier_mat_def carrier_vec_def by auto

"(A :: 'a :: comm_monoid_add mat) ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc ⟹ A + B = B + A"
by (intro eq_matI, auto simp: ac_simps)

(* A - A is the zero matrix of A's dimensions. *)
lemma minus_r_inv_mat[simp]:
"(A :: 'a :: group_add mat) ∈ carrier_mat nr nc ⟹ (A - A) = 0⇩m nr nc"
by (intro eq_matI, auto)

(* - A is a left additive inverse of A. *)
lemma uminus_l_inv_mat[simp]:
"(A :: 'a :: group_add mat) ∈ carrier_mat nr nc ⟹ (- A + A) = 0⇩m nr nc"
by (intro eq_matI, auto)

"(A :: 'a :: group_add mat) ∈ carrier_mat nr nc ⟹ ∃ B ∈ carrier_mat nr nc. B + A = 0⇩m nr nc ∧ A + B = 0⇩m nr nc"
by (intro bexI[of _ "- A"], auto)

"(A :: 'a :: monoid_add mat) ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc ⟹ C ∈ carrier_mat nr nc
⟹ (A + B) + C = A + (B + C)"
by (intro eq_matI, auto simp: ac_simps)

(* Negation of a sum reverses the summands: - (A + B) = - B + - A.
   The order matters because group_add need not be commutative. The lemma
   header (with the group_add type constraint that minus_add requires) was
   missing. *)
lemma uminus_add_mat: fixes A :: "'a :: group_add mat"
assumes "A ∈ carrier_mat nr nc"
and "B ∈ carrier_mat nr nc"
shows "- (A + B) = - B + - A"
by (intro eq_matI, insert assms, auto simp: minus_add)

(* Transposition is an involution. *)
lemma transpose_transpose[simp]:
"transpose_mat (transpose_mat A) = A"
by (intro eq_matI, auto)

(* The identity matrix is symmetric. *)
lemma transpose_one[simp]: "transpose_mat (1⇩m n) = (1⇩m n)"
by auto

(* Rows of the transpose are columns of the original, and vice versa. *)
lemma row_transpose[simp]:
"j < dim_col A ⟹ row (transpose_mat A) j = col A j"
unfolding row_def col_def
by (intro eq_vecI, auto)

lemma col_transpose[simp]:
"i < dim_row A ⟹ col (transpose_mat A) i = row A i"
unfolding row_def col_def
by (intro eq_vecI, auto)

(* Rows and columns of the zero matrix are zero vectors. *)
lemma row_zero[simp]:
"i < nr ⟹ row (0⇩m nr nc) i = 0⇩v nc"
by (intro eq_vecI, auto)

lemma col_zero[simp]:
"j < nc ⟹ col (0⇩m nr nc) j = 0⇩v nr"
by (intro eq_vecI, auto)

(* Rows and columns of the identity matrix are unit vectors. *)
lemma row_one[simp]:
"i < n ⟹ row (1⇩m n) i = unit_vec n i"
by (intro eq_vecI, auto)

lemma col_one[simp]:
"j < n ⟹ col (1⇩m n) j = unit_vec n j"
by (intro eq_vecI, auto)

(* Transposition distributes over addition, subtraction and negation. *)
lemma transpose_add: "A ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc
⟹ transpose_mat (A + B) = transpose_mat A + transpose_mat B"
by (intro eq_matI, auto)

lemma transpose_minus: "A ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc
⟹ transpose_mat (A - B) = transpose_mat A - transpose_mat B"
by (intro eq_matI, auto)

lemma transpose_uminus: "transpose_mat (- A) = - (transpose_mat A)"
by (intro eq_matI, auto)

"A ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc ⟹ i < nr
⟹ row (A + B) i = row A i + row B i"
"i < dim_row A ⟹ dim_row B = dim_row A ⟹ dim_col B = dim_col A ⟹ row (A + B) i = row A i + row B i"
by (rule eq_vecI, auto)

"A ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc ⟹ j < nc
⟹ col (A + B) j = col A j + col B j"
by (rule eq_vecI, auto)

(* Row i of a product: its j-th entry is row A i ∙ col B j. *)
lemma row_mult[simp]: assumes m: "A ∈ carrier_mat nr n" "B ∈ carrier_mat n nc"
and i: "i < nr"
shows "row (A * B) i = vec nc (λ j. row A i ∙ col B j)"
by (rule eq_vecI, insert m i, auto)

(* Column j of a product: its i-th entry is row A i ∙ col B j. *)
lemma col_mult[simp]: assumes m: "A ∈ carrier_mat nr n" "B ∈ carrier_mat n nc"
and j: "j < nc"
shows "col (A * B) j = vec nr (λ i. row A i ∙ col B j)"
by (rule eq_vecI, insert m j, auto)

(* (A * B)ᵀ = Bᵀ * Aᵀ over a commutative semiring; commutativity is needed
   to swap the scalar products (comm_scalar_prod). *)
lemma transpose_mult:
"(A :: 'a :: comm_semiring_0 mat) ∈ carrier_mat nr n ⟹ B ∈ carrier_mat n nc
⟹ transpose_mat (A * B) = transpose_mat B * transpose_mat A"
by (intro eq_matI, auto simp: comm_scalar_prod[of _ n])

"(A :: 'a :: monoid_add mat) ∈ carrier_mat nr nc  ⟹ 0⇩m nr nc + A = A"
by (intro eq_matI, auto)

(* Adding a negated matrix is subtraction. *)
lemma add_uminus_minus_mat: "A ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc ⟹
A + (- B) = A - (B :: 'a :: group_add mat)"
by (intro eq_matI, auto)

(* The zero matrix is a right identity for addition. *)
lemma right_add_zero_mat[simp]: "A ∈ carrier_mat nr nc ⟹
A + 0⇩m nr nc = (A :: 'a :: monoid_add mat)"
by (intro eq_matI, auto)

(* Multiplying by a zero matrix on the left or right yields a zero matrix.
   The primed variants take a dimension equation instead of a carrier
   assumption and are declared as simp rules. *)
lemma left_mult_zero_mat:
"A ∈ carrier_mat n nc ⟹ 0⇩m nr n * A = 0⇩m nr nc"
by (intro eq_matI, auto)

lemma left_mult_zero_mat'[simp]: "dim_row A = n ⟹ 0⇩m nr n * A = 0⇩m nr (dim_col A)"
by (rule left_mult_zero_mat, unfold carrier_mat_def, simp)

lemma right_mult_zero_mat:
"A ∈ carrier_mat nr n ⟹ A * 0⇩m n nc = 0⇩m nr nc"
by (intro eq_matI, auto)

lemma right_mult_zero_mat'[simp]: "dim_col A = n ⟹ A * 0⇩m n nc = 0⇩m (dim_row A) nc"
by (rule right_mult_zero_mat, unfold carrier_mat_def, simp)

(* The identity matrix is a left and right unit for multiplication; again
   the primed variants are the simp-friendly dimension-equation forms. *)
lemma left_mult_one_mat:
"(A :: 'a :: semiring_1 mat) ∈ carrier_mat nr nc ⟹ 1⇩m nr * A = A"
by (intro eq_matI, auto)

lemma left_mult_one_mat'[simp]: "dim_row (A :: 'a :: semiring_1 mat) = n ⟹ 1⇩m n * A = A"
by (rule left_mult_one_mat, unfold carrier_mat_def, simp)

lemma right_mult_one_mat:
"(A :: 'a :: semiring_1 mat) ∈ carrier_mat nr nc ⟹ A * 1⇩m nc = A"
by (intro eq_matI, auto)

lemma right_mult_one_mat'[simp]: "dim_col (A :: 'a :: semiring_1 mat) = n ⟹ A * 1⇩m n = A"
by (rule right_mult_one_mat, unfold carrier_mat_def, simp)

(* The identity matrix acts as identity on vectors of matching length. *)
lemma one_mult_mat_vec[simp]:
"(v :: 'a :: semiring_1 vec) ∈ carrier_vec n ⟹ 1⇩m n *⇩v v = v"
by (intro eq_vecI, auto)

(* Subtraction of equally-sized matrices is addition of the negation.
   The lemma header (name and group_add type constraint) was missing before
   the `shows` line. *)
lemma minus_add_uminus_mat: fixes A :: "'a :: group_add mat"
shows "A ∈ carrier_mat nr nc ⟹ B ∈ carrier_mat nr nc ⟹
A - B = A + (- B)"
by (intro eq_matI, auto)

(* Matrix multiplication distributes over addition from the right ... *)
lemma add_mult_distrib_mat[algebra_simps]: assumes m: "A ∈ carrier_mat nr n"
"B ∈ carrier_mat nr n" "C ∈ carrier_mat n nc"
shows "(A + B) * C = A * C + B * C"
using m by (intro eq_matI, auto simp: add_scalar_prod_distrib[of _ n])

(* ... and from the left. *)
lemma mult_add_distrib_mat[algebra_simps]: assumes m: "A ∈ carrier_mat nr n"
"B ∈ carrier_mat n nc" "C ∈ carrier_mat n nc"
shows "A * (B + C) = A * B + A * C"
using m by (intro eq_matI, auto simp: scalar_prod_add_distrib[of _ n])

(* Matrix-vector multiplication distributes over a sum of matrices ... *)
lemma add_mult_distrib_mat_vec[algebra_simps]: assumes m: "A ∈ carrier_mat nr nc"
"B ∈ carrier_mat nr nc" "v ∈ carrier_vec nc"
shows "(A + B) *⇩v v = A *⇩v v + B *⇩v v"
using m by (intro eq_vecI, auto intro!: add_scalar_prod_distrib)

(* ... and over a sum of vectors. *)
lemma mult_add_distrib_mat_vec[algebra_simps]: assumes m: "A ∈ carrier_mat nr nc"
"v⇩1 ∈ carrier_vec nc" "v⇩2 ∈ carrier_vec nc"
shows "A *⇩v (v⇩1 + v⇩2) = A *⇩v v⇩1 + A *⇩v v⇩2"
using m by (intro eq_vecI, auto simp: scalar_prod_add_distrib[of _ nc])

(* Matrix-vector multiplication commutes with scaling the vector:
   A *⇩v (k ⋅⇩v v) = k ⋅⇩v (A *⇩v v).
   The proof first shows both sides have length nr, then compares entries.
   The entry-wise subproof previously ended after the `unfolding` steps with
   no terminal method, so the proof was incomplete. *)
lemma mult_mat_vec:
assumes m: "(A::'a::field mat) ∈ carrier_mat nr nc" and v: "v ∈ carrier_vec nc"
shows "A *⇩v (k ⋅⇩v v) = k ⋅⇩v (A *⇩v v)" (is "?l = ?r")
proof
have nr: "dim_vec ?l = nr" using m v by auto
also have "... = dim_vec ?r" using m v by auto
finally show "dim_vec ?l = dim_vec ?r".

show "⋀i. i < dim_vec ?r ⟹ ?l \$ i = ?r \$ i"
proof -
fix i assume "i < dim_vec ?r"
hence i: "i < dim_row A" using nr m by auto
hence i2: "i < dim_vec (A *⇩v v)" using m by auto
show "?l \$ i = ?r \$ i"
apply (subst (1) mult_mat_vec_def)
apply (subst (2) smult_vec_def)
unfolding index_vec[OF i] index_vec[OF i2]
unfolding mult_mat_vec_def smult_vec_def
unfolding scalar_prod_def index_vec[OF i]
(* both sides are now sums over the entries of v; pull the scalar k out of
   the sum and reorder the factors of each summand *)
by (simp add: mult.left_commute sum_distrib_left)
qed
qed

(* Associativity of vector-matrix-vector scalar products:
   (v⇩1ᵀ A) ∙ v⇩2 = v⇩1 ∙ (A v⇩2), with both sides written using explicit vec
   constructions. The proof expands both scalar products into double sums,
   swaps the order of summation, and reassociates the products.
   The reassociation step previously had no justification. *)
lemma assoc_scalar_prod: assumes *: "v⇩1 ∈ carrier_vec nr" "A ∈ carrier_mat nr nc" "v⇩2 ∈ carrier_vec nc"
shows "vec nc (λj. v⇩1 ∙ col A j) ∙ v⇩2 = v⇩1 ∙ vec nr (λi. row A i ∙ v⇩2)"
proof -
have "vec nc (λj. v⇩1 ∙ col A j) ∙ v⇩2 = (∑i∈{0..<nc}. vec nc (λj. ∑k∈{0..<nr}. v⇩1 \$ k * col A j \$ k) \$ i * v⇩2 \$ i)"
unfolding scalar_prod_def using * by auto
also have "… = (∑i∈{0..<nc}. (∑k∈{0..<nr}. v⇩1 \$ k * col A i \$ k) * v⇩2 \$ i)"
by (rule sum.cong, auto)
also have "… = (∑i∈{0..<nc}. (∑k∈{0..<nr}. v⇩1 \$ k * col A i \$ k * v⇩2 \$ i))"
unfolding sum_distrib_right ..
also have "… = (∑k∈{0..<nr}. (∑i∈{0..<nc}. v⇩1 \$ k * col A i \$ k * v⇩2 \$ i))"
by (rule sum.swap)
also have "… = (∑k∈{0..<nr}. (∑i∈{0..<nc}. v⇩1 \$ k * (col A i \$ k * v⇩2 \$ i)))"
by (simp add: mult.assoc)
also have "… = (∑k∈{0..<nr}. v⇩1 \$ k * (∑i∈{0..<nc}. col A i \$ k * v⇩2 \$ i))"
unfolding sum_distrib_left ..
also have "… = (∑k∈{0..<nr}. v⇩1 \$ k * vec nr (λk. ∑i∈{0..<nc}. row A k \$ i * v⇩2 \$ i) \$ k)"
using * by auto
also have "… = v⇩1 ∙ vec nr (λi. row A i ∙ v⇩2)" unfolding scalar_prod_def using * by simp
finally show ?thesis .
qed

(* Adjoint identity for the transpose: (Aᵀ y) ∙ x = y ∙ (A x).
   The proof rewrites Aᵀ y column-wise, commutes the scalar products, and
   then applies assoc_scalar_prod. *)
lemma transpose_vec_mult_scalar:
fixes A :: "'a :: comm_semiring_0 mat"
assumes A: "A ∈ carrier_mat nr nc"
and x: "x ∈ carrier_vec nc"
and y: "y ∈ carrier_vec nr"
shows "(transpose_mat A *⇩v y) ∙ x = y ∙ (A *⇩v x)"
proof -
have "(transpose_mat A *⇩v y) = vec nc (λi. col A i ∙ y)"
unfolding mult_mat_vec_def using A by auto
also have "… = vec nc (λi. y ∙ col A i)"
by (intro eq_vecI, simp, rule comm_scalar_prod[OF _ y], insert A, auto)
also have "… ∙ x = y ∙ vec nr (λi. row A i ∙ x)"
by (rule assoc_scalar_prod[OF y A x])
also have "vec nr (λi. row A i ∙ x) = A *⇩v x"
unfolding mult_mat_vec_def using A by auto
finally show ?thesis .
qed

(* Matrix multiplication is associative for compatible dimensions. *)
lemma assoc_mult_mat[simp]:
"A ∈ carrier_mat n⇩1 n⇩2 ⟹ B ∈ carrier_mat n⇩2 n⇩3 ⟹ C ∈ carrier_mat n⇩3 n⇩4
⟹ (A * B) * C = A * (B * C)"
by (intro eq_matI, auto simp: assoc_scalar_prod)

(* Applying a product to a vector equals applying the factors in turn. *)
lemma assoc_mult_mat_vec[simp]:
"A ∈ carrier_mat n⇩1 n⇩2 ⟹ B ∈ carrier_mat n⇩2 n⇩3 ⟹ v ∈ carrier_vec n⇩3
⟹ (A * B) *⇩v v = A *⇩v (B *⇩v v)"
by (intro eq_vecI, auto simp add: mult_mat_vec_def assoc_scalar_prod)

(* Interpretations of the matrix structures (monoid_mat, ring_mat,
   module_mat) as HOL-Algebra structures: commutative monoid / group under
   addition, semiring, ring, and abelian group. *)
lemma comm_monoid_mat: "comm_monoid (monoid_mat TYPE('a :: comm_monoid_add) nr nc)"
by (unfold_locales, auto simp: monoid_mat_def ac_simps)

lemma comm_group_mat: "comm_group (monoid_mat TYPE('a :: ab_group_add) nr nc)"
by (unfold_locales, insert add_inv_exists_mat, auto simp: monoid_mat_def ac_simps Units_def)

lemma semiring_mat: "semiring (ring_mat TYPE('a :: semiring_1) n b)"
by (unfold_locales, auto simp: ring_mat_def algebra_simps)

lemma ring_mat: "ring (ring_mat TYPE('a :: comm_ring_1) n b)"
by (unfold_locales, insert add_inv_exists_mat, auto simp: ring_mat_def algebra_simps Units_def)

lemma abelian_group_mat: "abelian_group (module_mat TYPE('a :: comm_ring_1) nr nc)"
by (unfold_locales, insert add_inv_exists_mat, auto simp: module_mat_def Units_def)

(* Scalar multiplication commutes with taking a row or a column. *)
lemma row_smult[simp]: assumes i: "i < dim_row A"
shows "row (k ⋅⇩m A) i = k ⋅⇩v (row A i)"
by (rule eq_vecI, insert i, auto)

lemma col_smult[simp]: assumes i: "i < dim_col A"
shows "col (k ⋅⇩m A) i = k ⋅⇩v (col A i)"
by (rule eq_vecI, insert i, auto)

(* Negation commutes with taking a row. *)
lemma row_uminus[simp]: assumes i: "i < dim_row A"
shows "row (- A) i = - (row A i)"
by (rule eq_vecI, insert i, auto)

(* Negation can be pulled out of the left argument of a scalar product. *)
lemma scalar_prod_uminus_left[simp]: assumes dim: "dim_vec v = dim_vec (w :: 'a :: ring vec)"
shows "- v ∙ w = - (v ∙ w)"
unfolding scalar_prod_def dim[symmetric]
by (subst sum_negf[symmetric], rule sum.cong, auto)

(* Negation commutes with taking a column. *)
lemma col_uminus[simp]: assumes i: "i < dim_col A"
shows "col (- A) i = - (col A i)"
by (rule eq_vecI, insert i, auto)

(* Negation can be pulled out of the right argument of a scalar product. *)
lemma scalar_prod_uminus_right[simp]: assumes dim: "dim_vec v = dim_vec (w :: 'a :: ring vec)"
shows "v ∙ - w = - (v ∙ w)"
unfolding scalar_prod_def dim
by (subst sum_negf[symmetric], rule sum.cong, auto)

(* Negation can be pulled out of either factor of a product A * B when the
   inner dimensions agree. Outside this context both lemmas carry the
   equation dim_col A = dim_row B as a premise. *)
context fixes A B :: "'a :: ring mat"
assumes dim: "dim_col A = dim_row B"
begin
lemma uminus_mult_left_mat[simp]: "(- A * B) = - (A * B)"
by (intro eq_matI, insert dim, auto)

lemma uminus_mult_right_mat[simp]: "(A * - B) = - (A * B)"
by (intro eq_matI, insert dim, auto)
end

(* Right distributivity of matrix multiplication over subtraction:
   (A - B) * C = A * C - B * C.
   The previous one-line proof applied `subst uminus_mult_left_mat` directly
   to a goal with no `- X * Y` subterm, so it could not succeed. The proof
   below rewrites A - B as A + - B, distributes with add_mult_distrib_mat,
   pulls the negation out of the product, and folds the result back into a
   difference with add_uminus_minus_mat. *)
lemma minus_mult_distrib_mat[algebra_simps]: fixes A :: "'a :: ring mat"
assumes m: "A ∈ carrier_mat nr n" "B ∈ carrier_mat nr n" "C ∈ carrier_mat n nc"
shows "(A - B) * C = A * C - B * C"
proof -
have AB: "A - B = A + - B"
by (intro eq_matI, insert m, auto)
have "(A - B) * C = A * C + - B * C"
unfolding AB by (rule add_mult_distrib_mat[OF m(1) uminus_carrier_mat[OF m(2)] m(3)])
also have "… = A * C + - (B * C)"
by (subst uminus_mult_left_mat, insert m, auto)
also have "… = A * C - B * C"
by (rule add_uminus_minus_mat[OF mult_carrier_mat[OF m(1,3)] mult_carrier_mat[OF m(2,3)]])
finally show ?thesis .
qed

(* Matrix-vector multiplication distributes over a difference of matrices:
   (A - B) *⇩v v = A *⇩v v - B *⇩v v.
   The previous proof applied `subst add_mult_distrib_mat_vec` to a goal
   containing no `(A + _) *⇩v v` redex, so it could not succeed. The proof
   below rewrites A - B as A + - B, distributes, pulls the negation out of
   the matrix-vector product entry-wise, and folds back into a difference. *)
lemma minus_mult_distrib_mat_vec[algebra_simps]: assumes A: "(A :: 'a :: ring mat) ∈ carrier_mat nr nc"
and B: "B ∈ carrier_mat nr nc"
and v: "v ∈ carrier_vec nc"
shows "(A - B) *⇩v v = A *⇩v v - B *⇩v v"
proof -
have AB: "A - B = A + - B"
by (intro eq_matI, insert A B, auto)
have "(A - B) *⇩v v = A *⇩v v + - B *⇩v v"
unfolding AB by (rule add_mult_distrib_mat_vec[OF A uminus_carrier_mat[OF B] v])
also have "… = A *⇩v v + - (B *⇩v v)"
by (intro eq_vecI, insert A B v, auto)
also have "… = A *⇩v v - B *⇩v v"
by (intro eq_vecI, insert A B v, auto)
finally show ?thesis .
qed

(* Matrix-vector multiplication distributes over a difference of vectors:
   A *⇩v (v - w) = A *⇩v v - A *⇩v w.
   The previous proof applied `subst mult_add_distrib_mat_vec` to a goal
   containing no `A *⇩v (_ + _)` redex, so it could not succeed. The proof
   below rewrites v - w as v + - w, distributes, pulls the negation out of
   the product entry-wise, and folds back into a difference. *)
lemma mult_minus_distrib_mat_vec[algebra_simps]: assumes A: "(A :: 'a :: ring mat) ∈ carrier_mat nr nc"
and v: "v ∈ carrier_vec nc"
and w: "w ∈ carrier_vec nc"
shows "A *⇩v (v - w) = A *⇩v v - A *⇩v w"
proof -
have vw: "v - w = v + - w"
by (intro eq_vecI, insert v w, auto)
have "A *⇩v (v - w) = A *⇩v v + A *⇩v (- w)"
unfolding vw by (rule mult_add_distrib_mat_vec[OF A v], insert w, auto simp: carrier_vec_def)
also have "… = A *⇩v v + - (A *⇩v w)"
by (intro eq_vecI, insert A v w, auto)
also have "… = A *⇩v v - A *⇩v w"
by (intro eq_vecI, insert A v w, auto)
finally show ?thesis .
qed

(* Left distributivity of matrix multiplication over subtraction:
   A * (B - C) = A * B - A * C.
   The previous one-line proof applied `subst uminus_mult_right_mat` to a
   goal with no `X * - Y` subterm, so it could not succeed. The proof below
   rewrites B - C as B + - C, distributes with mult_add_distrib_mat, pulls
   the negation out of the product, and folds back with add_uminus_minus_mat. *)
lemma mult_minus_distrib_mat[algebra_simps]: fixes A :: "'a :: ring mat"
assumes m: "A ∈ carrier_mat nr n" "B ∈ carrier_mat n nc" "C ∈ carrier_mat n nc"
shows "A * (B - C) = A * B - A * C"
proof -
have BC: "B - C = B + - C"
by (intro eq_matI, insert m, auto)
have "A * (B - C) = A * B + A * - C"
unfolding BC by (rule mult_add_distrib_mat[OF m(1,2) uminus_carrier_mat[OF m(3)]])
also have "… = A * B + - (A * C)"
by (subst uminus_mult_right_mat, insert m, auto)
also have "… = A * B - A * C"
by (rule add_uminus_minus_mat[OF mult_carrier_mat[OF m(1,2)] mult_carrier_mat[OF m(1,3)]])
finally show ?thesis .
qed

(* Negation can be pulled out of the matrix in a matrix-vector product,
   provided the vector length matches the column count. *)
lemma uminus_mult_mat_vec[simp]: assumes v: "dim_vec v = dim_col (A :: 'a :: ring mat)"
shows "- A *⇩v v = - (A *⇩v v)"
using v by (intro eq_vecI, auto)

(* A vector of length n negates to the zero vector iff it is the zero
   vector. The nontrivial direction shows each entry is 0 via double
   negation; the converse is closed by the final `auto`. *)
lemma uminus_zero_vec_eq: assumes v: "(v :: 'a :: group_add vec) ∈ carrier_vec n"
shows "(- v = 0⇩v n) = (v = 0⇩v n)"
proof
assume z: "- v = 0⇩v n"
{
fix i
assume i: "i < n"
have "v \$ i = - (- (v \$ i))" by simp
also have "- (v \$ i) = 0" using arg_cong[OF z, of "λ v. v \$ i"] i v by auto
also have "- 0 = (0 :: 'a)" by simp
finally have "v \$ i = 0" .
}
thus "v = 0⇩v n" using v
by (intro eq_vecI, auto)
qed auto

(* map_mat preserves (and reflects) carrier membership. *)
lemma map_carrier_mat[simp]:
"(map_mat f A ∈ carrier_mat nr nc) = (A ∈ carrier_mat nr nc)"
unfolding carrier_mat_def by auto

(* A column of map_mat f A is the mapped column of A. *)
lemma col_map_mat[simp]:
assumes "j < dim_col A" shows "col (map_mat f A) j = map_vec f (col A j)"
unfolding map_mat_def map_vec_def using assms by auto

(* Scaling a vector by 1 leaves it unchanged. *)
lemma scalar_vec_one[simp]: "1 ⋅⇩v (v :: 'a :: semiring_1 vec) = v"
by (rule eq_vecI, auto)

(* A scalar factor can be pulled out of either argument of a scalar product
   over a commutative semiring, given equal vector lengths. *)
lemma scalar_prod_smult_right[simp]:
"dim_vec w = dim_vec v ⟹ w ∙ (k ⋅⇩v v) = (k :: 'a :: comm_semiring_0) * (w ∙ v)"
unfolding scalar_prod_def sum_distrib_left
by (auto intro: sum.cong simp: ac_simps)

lemma scalar_prod_smult_left[simp]:
"dim_vec w = dim_vec v ⟹ (k ⋅⇩v w) ∙ v = (k :: 'a :: comm_semiring_0) * (w ∙ v)"
unfolding scalar_prod_def sum_distrib_left
by (auto intro: sum.cong simp: ac_simps)

(* A scalar factor on the right operand of a product can be moved outside. *)
lemma mult_smult_distrib: assumes A: "A ∈ carrier_mat nr n" and B: "B ∈ carrier_mat n nc"
shows "A * (k ⋅⇩m B) = (k :: 'a :: comm_semiring_0) ⋅⇩m (A * B)"
by (rule eq_matI, insert A B, auto)

(* Scalar multiplication distributes over matrix addition ... *)
lemma add_smult_distrib_left_mat: assumes "A ∈ carrier_mat nr nc" "B ∈ carrier_mat nr nc"
shows "k ⋅⇩m (A + B) = (k :: 'a :: semiring) ⋅⇩m A + k ⋅⇩m B"
by (rule eq_matI, insert assms, auto simp: field_simps)

(* ... and over addition of scalars. *)
lemma add_smult_distrib_right_mat: assumes "A ∈ carrier_mat nr nc"
shows "(k + l) ⋅⇩m A = (k :: 'a :: semiring) ⋅⇩m A + l ⋅⇩m A"
by (rule eq_matI, insert assms, auto simp: field_simps)

(* A scalar factor on the left operand of a product can be moved outside. *)
lemma mult_smult_assoc_mat: assumes A: "A ∈ carrier_mat nr n" and B: "B ∈ carrier_mat n nc"
shows "(k ⋅⇩m A) * B = (k :: 'a :: comm_semiring_0) ⋅⇩m (A * B)"
by (rule eq_matI, insert A B, auto)

definition similar_mat_wit :: "'a :: semiring_1 mat ⇒ 'a mat ⇒ 'a mat ⇒ 'a mat ⇒ bool" where
"similar_mat_wit A B P Q = (let n = dim_row A in {A,B,P,Q} ⊆ carrier_mat n```