(* Theory More_Matrix *)

theory More_Matrix
  imports "Jordan_Normal_Form.Matrix"
    "Jordan_Normal_Form.DL_Rank"
    "Jordan_Normal_Form.VS_Connect"
    "Jordan_Normal_Form.Gauss_Jordan_Elimination"
begin

section "Kronecker Product"

(* Kronecker product of two matrices.  If A has ra rows and ca columns and B
   has rb rows and cb columns, the result has ra*rb rows and ca*cb columns and
   its entry (i, j) is  A $$ (i div rb, j div cb) * B $$ (i mod rb, j mod cb).
   Equivalently, block (p, q) of the result is B with every entry multiplied
   on the left by A $$ (p, q).  Later proofs unfold this exact shape
   (kronecker_product_def together with Let_def). *)
definition kronecker_product :: "'a :: ring mat  'a mat  'a mat" where
  "kronecker_product A B =
  (let ra = dim_row A; ca = dim_col A;
       rb = dim_row B; cb = dim_col B
  in
    mat (ra*rb) (ca*cb)
    (λ(i,j).
      A $$ (i div rb, j div cb) *
      B $$ (i mod rb, j mod cb)
  ))"

(* Index-bound helper over nat: if d < a and c < b then b*d + c < a*b.
   This is what keeps a "block index * block size + offset" position inside
   a Kronecker product (see kronecker_inverse_index, sum_sum_mod_div). *)
lemma arith:
  assumes "d < a"
  assumes "c < b"
  shows "b*d+c < a*(b::nat)"
proof -
  (* first bound the offset inside block d by the start of block d+1 *)
  have "b*d+c < b*(d+1)"
    by (simp add: assms(2))
  (* then b*(d+1) <= b*a because d+1 <= a *)
  thus ?thesis
    by (metis One_nat_def Suc_leI add.right_neutral add_Suc_right assms(1) less_le_trans mult.commute mult_le_cancel2)
qed

(* Dimensions of a Kronecker product: rows multiply and columns multiply.
   Both equations follow from the mat constructor in kronecker_product_def
   once the let-bindings are unfolded. *)
lemma dim_kronecker[simp]:
  "dim_row (kronecker_product A B) = dim_row A * dim_row B"
  "dim_col (kronecker_product A B) = dim_col A * dim_col B"
  by (simp_all add: kronecker_product_def Let_def)

(* Reading the Kronecker product at row dim_row B*r+v and column dim_col B*s+w,
   with (r,s) a valid index of A and (v,w) a valid index of B, gives
   A $$ (r,s) * B $$ (v,w). *)
lemma kronecker_inverse_index:
  assumes "r < dim_row A" "s < dim_col A"
  assumes "v < dim_row B" "w < dim_col B"
  shows "kronecker_product A B $$ (dim_row B*r+v, dim_col B*s+w) = A $$ (r,s) * B $$ (v,w)"
proof -
  (* the row position is in range *)
  from arith[OF assms(1) assms(3)]
  have "dim_row B*r+v < dim_row A * dim_row B" .
  (* the column position is in range *)
  moreover from arith[OF assms(2) assms(4)]
  have "dim_col B * s + w < dim_col A * dim_col B" .
  ultimately show ?thesis
    unfolding kronecker_product_def Let_def
    using assms by auto
qed

(* A (x) (B + C) = A (x) B + A (x) C.  B and C must have equal dimensions. *)
lemma kronecker_distr_left:
  assumes "dim_row B = dim_row C" "dim_col B = dim_col C"
  shows "kronecker_product A (B+C) = kronecker_product A B + kronecker_product A C"
  unfolding kronecker_product_def Let_def
  using assms apply (auto simp add: mat_eq_iff) 
  (* remaining goal: entrywise left distributivity once the mod-indices are in range *)
  by (metis (no_types, lifting) distrib_left index_add_mat(1) mod_less_divisor mult_eq_0_iff neq0_conv not_less_zero)

(* (B + C) (x) A = B (x) A + C (x) A.  B and C must have equal dimensions. *)
lemma kronecker_distr_right:
  assumes "dim_row B = dim_row C" "dim_col B = dim_col C"
  shows "kronecker_product (B+C) A = kronecker_product B A + kronecker_product C A"
  unfolding kronecker_product_def Let_def
  using assms by (auto simp add: mat_eq_iff less_mult_imp_div_less distrib_right)

(* Simp rule: when both dimensions are positive, the indices (i mod nr, j mod nc)
   are always in range, so indexing  mat nr nc f  there just evaluates f. *)
lemma index_mat_mod[simp]: "nr > 0 & nc > 0  mat nr nc f $$ (i mod nr,j mod nc) = f (i mod nr,j mod nc)"
  by auto

(* Associativity: A (x) (B (x) C) = (A (x) B) (x) C. *)
lemma kronecker_assoc:
  shows "kronecker_product A (kronecker_product B C) = kronecker_product (kronecker_product A B) C"
  unfolding kronecker_product_def Let_def
  (* split on whether B (x) C has positive row and column counts *)
  apply (case_tac "dim_row B * dim_row C > 0 & dim_col B * dim_col C > 0")
   apply (auto simp add: mat_eq_iff less_mult_imp_div_less)
  by (smt (verit, best) div_less_iff_less_mult div_mult2_eq kronecker_inverse_index linordered_semiring_strict_class.mult_pos_pos mod_less_divisor mod_mult2_eq mult.assoc mult.commute mult_div_mod_eq)

(* A double sum over ia < x and ja < y equals a single sum over ia < x*y,
   using the index map  ia |-> (ia div y, ia mod y). *)
lemma sum_sum_mod_div:
  "(ia = 0::nat..<x. ja = 0..<y. f ia ja) =
   (ia = 0..<x*y. f (ia div y) (ia mod y))"
proof -
  (* the div/mod map is injective on {0..<x*y} *)
  have 1: "inj_on (λia. (ia div y, ia mod y)) {0..<x * y}"
    by (smt (verit, best) Pair_inject div_mod_decomp inj_onI)
  (* every pair (a, b) in the grid is hit, namely by a*y + b *)
  have 21: "{0..<x} × {0..<y}  (λia. (ia div y, ia mod y)) ` {0..<x * y}"
  proof clarsimp
    fix a b
    assume *:"a < x" "b < y"
    have "a * y +  b  {0..<x*y}"
      by (metis arith * atLeastLessThan_iff le0 mult.commute)
    thus "(a, b)  (λia. (ia div y, ia mod y)) ` {0..<x * y}"
      using * by (auto simp add: image_iff)
        (metis a * y + b  {0..<x * y} add.commute add.right_neutral div_less div_mult_self1 less_zeroE mod_eq_self_iff_div_eq_0 mod_mult_self1)
  qed
  (* conversely the image stays inside the grid *)
  have 22:"(λia. (ia div y, ia mod y)) ` {0..<x * y}  {0..<x} × {0..<y}"
    using less_mult_imp_div_less apply auto
    by (metis mod_less_divisor mult.commute neq0_conv not_less_zero)
  have 2: "{0..<x} × {0..<y} = (λia. (ia div y, ia mod y)) ` {0..<x * y}"
    using 21 22 by auto
  (* rewrite the double sum as a sum over the cartesian product *)
  have *: "(ia = 0::nat..<x. ja = 0..<y. f ia ja) =
        ((x, y){0..<x} × {0..<y}. f x y)"
    by (auto simp add: sum.cartesian_product)
  (* reindex along the bijection established by 1 and 2 *)
  show ?thesis unfolding *
    apply (intro sum.reindex_cong[of "λia. (ia div y, ia mod y)"])
    using 1 2 by auto
qed

(* Mixed-product property of the Kronecker product:
     (A (x) B) * (C (x) D) = (A * C) (x) (B * D),
   provided dim_col A = dim_row C and dim_col B = dim_row D.
   The scalar type must be commutative because entries of A, B, C, D are
   reordered inside the products. *)
lemma kronecker_of_mult:
  assumes "dim_col (A :: 'a :: comm_ring mat) = dim_row C"
  assumes "dim_col B = dim_row D"
  shows "kronecker_product A B * kronecker_product C D = kronecker_product (A * C) (B * D)"
  unfolding kronecker_product_def Let_def mat_eq_iff
proof clarsimp
  fix i j
  assume ij: "i < dim_row A * dim_row B" "j < dim_col C * dim_col D"
  (* 1, 2: write the entries of A * C and B * D as row-column scalar products *)
  have 1: "(A * C) $$ (i div dim_row B, j div dim_col D) =
    row A (i div dim_row B)  col C (j div dim_col D)"
    using ij less_mult_imp_div_less by (auto intro!: index_mult_mat)
  have 2: "(B * D) $$ (i mod dim_row B, j mod dim_col D) =
    row B (i mod dim_row B)  col D (j mod dim_col D)"
    using ij apply (auto intro!: index_mult_mat)
    using gr_implies_not0 apply fastforce
    using gr_implies_not0 by fastforce
  (* 3: pointwise, each summand of the combined scalar product can be
     rewritten between matrix-entry form and row/col-entry form *)
  have 3: "x. x < dim_row C * dim_row D 
         A $$ (i div dim_row B, x div dim_row D) *
         B $$ (i mod dim_row B, x mod dim_row D) *
         (C $$ (x div dim_row D, j div dim_col D) *
          D $$ (x mod dim_row D, j mod dim_col D)) =
         row A (i div dim_row B) $ (x div dim_row D) *
         col C (j div dim_col D) $ (x div dim_row D) *
         (row B (i mod dim_row B) $ (x mod dim_row D) *
          col D (j mod dim_col D) $ (x mod dim_row D))"
  proof -
    fix x
    assume *:"x < dim_row C * dim_row D"
    have 1: "row A (i div dim_row B) $ (x div dim_row D) = A $$ (i div dim_row B, x div dim_row D)"
      by (simp add: * assms(1) less_mult_imp_div_less row_def)
    have 2: "row B (i mod dim_row B) $ (x mod dim_row D) = B $$ (i mod dim_row B, x mod dim_row D)"
      by (metis "*" assms(2) ij(1) index_row(1) mod_less_divisor nat_0_less_mult_iff neq0_conv not_less_zero)
    have 3: "col C (j div dim_col D) $ (x div dim_row D) = C $$ (x div dim_row D, j div dim_col D)"
      by (simp add: "*" ij(2) less_mult_imp_div_less)
    have 4: "col D (j mod dim_col D) $ (x mod dim_row D) = D $$ (x mod dim_row D, j mod dim_col D)"
      by (metis "*" bot_nat_0.not_eq_extremum ij(2) index_col mod_less_divisor mult_zero_right not_less_zero)
    show "A $$ (i div dim_row B, x div dim_row D) *
         B $$ (i mod dim_row B, x mod dim_row D) *
         (C $$ (x div dim_row D, j div dim_col D) *
          D $$ (x mod dim_row D, j mod dim_col D)) =
         row A (i div dim_row B) $ (x div dim_row D) *
         col C (j div dim_col D) $ (x div dim_row D) *
         (row B (i mod dim_row B) $ (x mod dim_row D) *
          col D (j mod dim_col D) $ (x mod dim_row D))" unfolding 1 2 3 4
      by (simp add: mult.assoc mult.left_commute)
  qed
  (* *: the product of the two scalar products, flattened into a single sum
     over dim_row C * dim_row D indices via sum_sum_mod_div *)
  have *: "(A * C) $$ (i div dim_row B, j div dim_col D) *
        (B * D) $$ (i mod dim_row B, j mod dim_col D) =
    (ia = 0..<dim_row C * dim_row D.
               A $$ (i div dim_row B, ia div dim_row D) *
               B $$ (i mod dim_row B, ia mod dim_row D) *
               (C $$ (ia div dim_row D, j div dim_col D) *
                D $$ (ia mod dim_row D, j mod dim_col D)))"
    unfolding 1 2 scalar_prod_def sum_product sum_sum_mod_div
    apply (auto simp add: sum_product sum_sum_mod_div intro!: sum.cong)
    using 3 by presburger
  (* the (i, j) entry of the left-hand side is exactly that single sum *)
  show "vec (dim_col A * dim_col B)
          (λj. A $$ (i div dim_row B, j div dim_col B) *
               B $$ (i mod dim_row B, j mod dim_col B)) 
       vec (dim_row C * dim_row D)
          (λi. C $$ (i div dim_row D, j div dim_col D) *
               D $$ (i mod dim_row D, j mod dim_col D)) =
        (A * C) $$ (i div dim_row B, j div dim_col D) *
        (B * D) $$ (i mod dim_row B, j mod dim_col D)"
    unfolding * scalar_prod_def
    by (auto simp add: assms sum_product sum_sum_mod_div intro!: sum.cong)
qed

(* If A is square and A and B invert each other, then B has the same number of
   rows and columns as A.  The first result uses inverts_mat B A, the second
   uses inverts_mat A B. *)
lemma inverts_mat_length:
  assumes "square_mat A" "inverts_mat A B" "inverts_mat B A"
  shows "dim_row B = dim_row A" "dim_col B = dim_col A"
   apply (metis assms(1) assms(3) index_mult_mat(3) index_one_mat(3) inverts_mat_def square_mat.simps)
  by (metis assms(1) assms(2) index_mult_mat(3) index_one_mat(3) inverts_mat_def square_mat.simps)

(* Remainder bound when the divisor is known to be a factor of a strict upper
   bound: if m < n * i then m mod i < i.  When i = 0 the premise reads
   m < 0 and is contradictory.  Otherwise this is the usual remainder bound. *)
lemma less_mult_imp_mod_less:
  "m mod i < i" if "m < n * i" for m n i :: nat
  using that by (cases "i = 0") (simp_all add: mod_less_divisor)

(* The Kronecker product of the identity matrices of sizes x and y is the
   identity matrix of size x*y. *)
lemma kronecker_one:
  shows "kronecker_product ((1m x)::'a :: ring_1 mat) (1m y) = 1m (x*y)"
  unfolding kronecker_product_def Let_def
  apply  (auto simp add:mat_eq_iff less_mult_imp_div_less less_mult_imp_mod_less)
  (* remaining goal: the diagonal conditions on (div, mod) pairs coincide *)
  by (metis div_mult_mod_eq)

(* The Kronecker product of two invertible matrices is invertible.  The witness
   inverse is the Kronecker product of the two inverses, via the mixed-product
   property (kronecker_of_mult) and kronecker_one. *)
lemma kronecker_invertible:
  assumes "invertible_mat (A :: 'a :: comm_ring_1 mat)" "invertible_mat B"
  shows "invertible_mat (kronecker_product A B)"
proof -
  (* pick two-sided inverses Ai of A and Bi of B *)
  obtain Ai where Ai: "inverts_mat A Ai" "inverts_mat Ai A" using assms invertible_mat_def by blast
  obtain Bi where Bi: "inverts_mat B Bi" "inverts_mat Bi B" using assms invertible_mat_def by blast
  have "square_mat (kronecker_product A B)"
    by (metis (no_types, lifting) assms(1) assms(2) dim_col_mat(1) dim_row_mat(1) invertible_mat_def kronecker_product_def square_mat.simps)
  (* (A (x) B) * (Ai (x) Bi) = (A*Ai) (x) (B*Bi) = identity *)
  moreover have "inverts_mat (kronecker_product A B) (kronecker_product Ai Bi)"
    using Ai Bi unfolding inverts_mat_def
    by (metis (no_types, lifting) dim_kronecker(1) index_mult_mat(3) index_one_mat(3) kronecker_of_mult kronecker_one)
  (* and symmetrically in the other order *)
  moreover have "inverts_mat (kronecker_product Ai Bi) (kronecker_product A B)"
    using Ai Bi unfolding inverts_mat_def
    by (metis (no_types, lifting) dim_kronecker(1) index_mult_mat(3) index_one_mat(3) kronecker_of_mult kronecker_one)
  ultimately show ?thesis unfolding invertible_mat_def by blast
qed

section "More DL Rank"

(* Entrywise conjugation of matrices: 'a mat is an instance of class conjugate
   whenever 'a is.  Basic simp rules about dimensions, indexing, rows and
   columns of the conjugate matrix are proved before the class instance. *)
instantiation mat :: (conjugate) conjugate
begin

(* conjugate m keeps the dimensions of m and conjugates every entry *)
definition conjugate_mat :: "'a :: conjugate mat  'a mat"
  where "conjugate m = mat (dim_row m) (dim_col m) (λ(i,j). conjugate (m $$ (i,j)))"

lemma dim_row_conjugate[simp]: "dim_row (conjugate m) = dim_row m"
  unfolding conjugate_mat_def by auto

lemma dim_col_conjugate[simp]: "dim_col (conjugate m) = dim_col m"
  unfolding conjugate_mat_def by auto

(* Despite the name, this is about matrices: conjugation preserves the carrier.
   Later proofs refer to the vector version by its qualified name
   Matrix.carrier_vec_conjugate. *)
lemma carrier_vec_conjugate[simp]: "m  carrier_mat nr nc  conjugate m  carrier_mat nr nc"
  by (auto)

(* in-range entries of conjugate m are the conjugates of the entries of m *)
lemma mat_index_conjugate[simp]:
  shows "i < dim_row m  j < dim_col m  conjugate m  $$ (i,j) = conjugate (m $$ (i,j))"
  unfolding conjugate_mat_def by auto

lemma row_conjugate[simp]: "i < dim_row m  row (conjugate m) i = conjugate (row m i)"
  by (auto)

lemma col_conjugate[simp]: "i < dim_col m  col (conjugate m) i = conjugate (col m i)"
  by (auto)

lemma rows_conjugate: "rows (conjugate m) = map conjugate (rows m)"
  by (simp add: list_eq_iff_nth_eq)

lemma cols_conjugate: "cols (conjugate m) = map conjugate (cols m)"
  by (simp add: list_eq_iff_nth_eq)

(* class axioms: conjugation is an involution and is injective *)
instance
proof
  fix a b :: "'a mat"
  show "conjugate (conjugate a) = a"
    unfolding mat_eq_iff by auto
  let ?a = "conjugate a"
  let ?b = "conjugate b"
  show "conjugate a = conjugate b  a = b"
    by (metis dim_col_conjugate dim_row_conjugate mat_index_conjugate conjugate_cancel_iff mat_eq_iff)
qed

end

(* Conjugate transpose of a matrix: conjugate the transpose.  It is written
   with the postfix H notation declared below. *)
abbreviation conjugate_transpose :: "'a::conjugate mat   'a mat"
  where "conjugate_transpose A  conjugate (AT)"

notation conjugate_transpose ("(_H)" [1000])

(* Conjugation and transposition commute: transposing the conjugate equals the
   conjugate transpose. *)
lemma transpose_conjugate:
  shows "(conjugate A)T = AH"
  unfolding conjugate_mat_def
  by auto

(* The zero vector of dimension dim_row A lies in the span of the columns of A.
   The span is taken in the vector module of dimension dim_row A, whose record
   is written out explicitly. *)
lemma vec_module_col_helper:
  fixes A:: "('a :: field) mat"
  shows "(0v (dim_row A)  LinearCombinations.module.span class_ring carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0v (dim_row A), add = (+), smult = (⋅v) (set (cols A)))"
proof -
  have "v. (0::'a) v v + v = v"
    by auto
  then show "0v (dim_row A)  LinearCombinations.module.span class_ring carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0v (dim_row A), add = (+), smult = (⋅v) (set (cols A))"
    by (metis cols_dim module_vec_def right_zero_vec smult_carrier_vec vec_space.prod_in_span zero_carrier_vec)
qed

(* The span of the columns of A is closed under scalar multiplication.
   NOTE(review): the statement also carries a distributivity premise that the
   proof does not use; presumably it is phrased this way so that it matches a
   goal left over in vec_module_col, where it is applied by blast. *)
lemma vec_module_col_helper2:
  fixes A:: "('a :: field) mat"
  shows "a x. x  LinearCombinations.module.span class_ring
                carrier = carrier_vec (dim_row A), mult = undefined, one = undefined,
                   zero = 0v (dim_row A), add = (+), smult = (⋅v)
                (set (cols A)) 
           (a b v. (a + b) v v = a v v + b v v) 
           a v x
            LinearCombinations.module.span class_ring
               carrier = carrier_vec (dim_row A), mult = undefined, one = undefined,
                  zero = 0v (dim_row A), add = (+), smult = (⋅v)
               (set (cols A))"
proof -
  fix a :: 'a and x :: "'a vec"
  assume "x  LinearCombinations.module.span class_ring carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0v (dim_row A), add = (+), smult = (⋅v) (set (cols A))"
  then show "a v x  LinearCombinations.module.span class_ring carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0v (dim_row A), add = (+), smult = (⋅v) (set (cols A))"
    by (metis (full_types) cols_dim idom_vec.smult_in_span module_vec_def)
qed

(* The vector module of dimension dim_row A, with its carrier restricted to the
   span of the columns of A, is a module over the field class_ring.
   The abelian-group axioms are established first (interpret), then the
   remaining module axioms. *)
lemma vec_module_col: "module (class_ring :: 'a :: field ring)
  (module_vec TYPE('a) 
    (dim_row A)
      carrier :=
         LinearCombinations.module.span
          class_ring (module_vec TYPE('a) (dim_row A)) (set (cols A)))"
proof -
  (* abelian group part: closure under addition, zero, associativity,
     commutativity and additive inverses inside the column span *)
  interpret abelian_group "module_vec TYPE('a) (dim_row A)
      carrier :=
         LinearCombinations.module.span
          class_ring (module_vec TYPE('a) (dim_row A)) (set (cols A))"
    apply (unfold_locales)
          apply (auto simp add:module_vec_def)
          apply (metis cols_dim module_vec_def partial_object.select_convs(1) ring.simps(2) vec_vs vectorspace.span_add1)
         apply (metis assoc_add_vec cols_dim module_vec_def vec_space.cV vec_vs vectorspace.span_closed)
    using vec_module_col_helper[of A] apply (auto)    
       apply (metis cols_dim left_zero_vec module_vec_def partial_object.select_convs(1) vec_vs vectorspace.span_closed)
      apply (metis cols_dim module_vec_def partial_object.select_convs(1) right_zero_vec vec_vs vectorspace.span_closed)
     apply (metis cols_dim comm_add_vec module_vec_def vec_space.cV vec_vs vectorspace.span_closed)
    unfolding Units_def apply auto
    by (metis (no_types, opaque_lifting) cols_dim comm_add_vec module_vec_def partial_object.select_convs(1) uminus_l_inv_vec vec_space.vec_neg vec_vs vectorspace.span_closed vectorspace.span_neg)
  (* module part: scalar multiplication axioms; closure under scaling comes
     from vec_module_col_helper2 *)
  show ?thesis
    apply (unfold_locales)
    unfolding class_ring_simps apply auto
    unfolding module_vec_simps using add_smult_distrib_vec apply auto
     apply (auto simp add:module_vec_def)
    using vec_module_col_helper2
     apply blast
    by (smt (verit) cols_dim module_vec_def smult_add_distrib_vec vec_space.cV vec_vs vectorspace.span_closed)
qed

(* The span of the columns of a matrix is a vector space over the field
   class_ring.  It is the module from vec_module_col, and class_ring is a field. *)
lemma vec_vs_col: "vectorspace (class_ring :: 'a :: field ring)
  (module_vec TYPE('a) (dim_row A)
      carrier :=
         LinearCombinations.module.span
          class_ring
          (module_vec TYPE('a)
            (dim_row A))
          (set (cols A)))"
  unfolding vectorspace_def
  using vec_module_col class_field 
  by (auto simp: class_field_def)

(* The columns of A * B are the columns of B, each multiplied by A. *)
lemma cols_mat_mul_map:
  shows "cols (A * B) = map ((*v) A) (cols B)"
  unfolding list_eq_iff_nth_eq
  by auto

(* Set version of cols_mat_mul_map: the set of columns of A * B is the image of
   the columns of B under multiplication by A. *)
lemma cols_mat_mul:
  shows "set (cols (A * B)) = (*v) A ` set (cols B)"
  by (simp add: cols_mat_mul_map)

(* Any subset S of the elements of a list is finite.  It can therefore be
   enumerated by a distinct list ss with set ss = S. *)
lemma set_obtain_sublist:
  assumes "S  set ls"
  obtains ss where "distinct ss" "S = set ss"
  using assms finite_distinct_list infinite_super by blast

(* Let A be nr x n and let every c in cs have dimension n.  Then A times the
   matrix with columns cs is the matrix with columns A *v c, for c in cs. *)
lemma mul_mat_of_cols:
  assumes "A  carrier_mat nr n"
  assumes "j. j < length cs  cs ! j  carrier_vec n"
  shows "A * (mat_of_cols n cs) = mat_of_cols nr (map ((*v) A) cs)"
  unfolding mat_eq_iff
  using assms apply auto
  apply (subst mat_of_cols_index)
  by auto

(* Rearrangement of a triple product: x * (y * z) = y * x * z.
   It is used to reorder products inside sums, e.g. in
   cscalar_prod_conjugate_transpose.  Only associativity and commutativity of
   multiplication are needed, so it is stated for ab_semigroup_mult.  Every
   {conjugatable_ring, comm_ring} type is such a semigroup, so all existing
   uses remain valid. *)
lemma helper:
  fixes x y z :: "'a :: ab_semigroup_mult"
  shows "x * (y * z) = y * x * z"
  by (simp add: mult.assoc mult.left_commute)

(* Adjoint identity for the conjugate scalar product.  For A an nr x nc matrix,
   x of dimension nr and y of dimension nc:  x .c (A y) = (A^H x) .c y,
   where .c is the conjugate scalar product.  The proof expands both sides into
   double sums, swaps the order of summation and reorders the factors. *)
lemma cscalar_prod_conjugate_transpose:
  fixes x y ::"'a :: {conjugatable_ring, comm_ring} vec"
  assumes "A  carrier_mat nr nc"
  assumes "x  carrier_vec nr"
  assumes "y  carrier_vec nc"
  shows "x ∙c (A *v y) = (AH *v x) ∙c y"
  unfolding mult_mat_vec_def scalar_prod_def
  using assms apply (auto simp add: sum_distrib_left sum_distrib_right sum_conjugate conjugate_dist_mul)
  apply (subst sum.swap)
  by (meson helper mult.assoc mult.left_commute sum.cong)

(* If A (A^H v) = 0 then A^H v = 0.
   By the adjoint identity, (A^H v) .c (A^H v) = (A (A^H v)) .c v = 0.
   A vector whose conjugate self-product is 0 must itself be 0. *)
lemma mat_mul_conjugate_transpose_vec_eq_0:                        
  fixes v ::"'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors,comm_ring} vec"
  assumes "A  carrier_mat nr nc"
  assumes "v  carrier_vec nr"
  assumes "A *v (AH *v v) = 0v nr"
  shows "AH *v v = 0v nc"
proof -
  have "(AH *v v) ∙c (AH *v v) = (A *v (AH *v v)) ∙c v"
    by (metis (mono_tags, lifting) Matrix.carrier_vec_conjugate assms(1) assms(2) assms(3) carrier_matD(2) conjugate_zero_vec cscalar_prod_conjugate_transpose dim_row_conjugate index_transpose_mat(2) mult_mat_vec_def scalar_prod_left_zero scalar_prod_right_zero vec_carrier)
  also have "... = 0"
    by (simp add: assms(2) assms(3))
  ultimately have "(AH *v v) ∙c (AH *v v) = 0" by auto
  (* This final step uses conjugate_square_eq_0_vec.  Presumably it is the step
     that needs "real-like" entries, i.e. the conjugatable_ordered_ring sort. *)
  thus ?thesis
    apply (subst conjugate_square_eq_0_vec[symmetric])
    using assms(1) carrier_dim_vec apply fastforce
    by auto
qed

(* Row i (i < nr) of the matrix with columns ls collects the i-th entry of each
   column. *)
lemma row_mat_of_cols:
  assumes "i < nr"
  shows "row (mat_of_cols nr ls) i = vec (length ls) (λj. (ls ! j) $i)"
  by (simp add: assms mat_of_cols_index vec_eq_iff)

(* Take the matrix whose first column is a (of dimension nr) and whose other
   columns are ls.  Multiplying it by vCons m v splits off m times a; the rest
   is the product of the other columns with v. *)
lemma mat_of_cols_cons_mat_vec:
  fixes v ::"'a::comm_ring vec"
  assumes "v  carrier_vec (length ls)"
  assumes "dim_vec a = nr"
  shows
    "mat_of_cols nr (a # ls) *v (vCons m v) =
   m v a + mat_of_cols nr ls *v v"
  unfolding mult_mat_vec_def vec_eq_iff
  using assms by
    (auto simp add: row_mat_of_cols vec_Suc o_def mult.commute)

(* Scaling any vector by 0 gives the zero vector of the same dimension. *)
lemma smult_vec_zero:
  fixes v ::"'a::ring vec"
  shows "0 v v = 0v (dim_vec v)"
  unfolding smult_vec_def vec_eq_iff
  by (auto)

(* Prepending extra columns ls (each of dimension nr) to ss does not change the
   product, provided the coefficient vector is padded with length ls leading
   zeros.  Proved by induction on ls. *)
lemma helper2:
  fixes A ::"'a::comm_ring mat"
  fixes v ::"'a vec"
  assumes "v  carrier_vec (length ss)"
  assumes "x. x  set ls  dim_vec x = nr"
  shows
    "mat_of_cols nr ss *v v =
   mat_of_cols nr (ls @ ss) *v (0v (length ls) @v v)"
  using assms(2)
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a ls)
  (* the new first column contributes 0 times a, i.e. nothing *)
  then show ?case apply (auto simp add:zero_vec_Suc)
    apply (subst mat_of_cols_cons_mat_vec)
    by (auto simp add:assms smult_vec_zero)
qed

(* Permute the column list ss and the coefficient list v by the same
   permutation f of {..<length ss}.  The matrix-vector product is unchanged. *)
lemma mat_of_cols_mult_mat_vec_permute_list:
  fixes v ::"'a::comm_ring list"
  assumes "f permutes {..<length ss}"
  assumes "length ss = length v"
  shows
    "mat_of_cols nr (permute_list f ss) *v vec_of_list (permute_list f v) =
     mat_of_cols nr ss *v vec_of_list v"
  unfolding mat_of_cols_def mult_mat_vec_def vec_eq_iff scalar_prod_def
proof clarsimp
  fix i
  assume "i < nr"
  (* reindex the sum for row i along the permutation f *)
  from sum.permute[OF assms(1)]
  have "(ia<length ss. ss ! f ia $ i * v ! f ia) =
  sum ((λia. ss ! f ia $ i * v ! f ia)  f) {..<length ss}" .
  also have "... = (ia = 0..<length v. ss ! f ia $ i * v ! f ia)"
    using assms(2) calculation lessThan_atLeast0 by auto
  ultimately have *: "(ia = 0..<length v.
             ss ! f ia $ i * v ! f ia) =
         (ia = 0..<length v.
             ss ! ia $ i * v ! ia)"
    by (metis (mono_tags, lifting) g. sum g {..<length ss} = sum (g  f) {..<length ss} assms(2) comp_apply lessThan_atLeast0 sum.cong)
  show "(ia = 0..<length v.
         vec (length ss) (λj. permute_list f ss ! j $ i) $ ia *
         vec_of_list (permute_list f v) $ ia) =
         (ia = 0..<length v. vec (length ss) (λj. ss ! j $ i) $ ia * vec_of_list v $ ia)"
    using assms * by (auto simp add: permute_list_nth vec_of_list_index)
qed

(* Move a subset of the positions to the back.  Let ss be distinct indices into
   ls.  There is a permutation f of {..<length ls} such that permute_list f ls
   lists the entries at indices not in ss first, in their original order.
   These are followed by the entries at indices ss, in the order of ss. *)
lemma subindex_permutation:
  assumes "distinct ss" "set ss  {..<length ls}"
  obtains f where "f permutes {..<length ls}"
    "permute_list f ls = map ((!) ls) (filter (λi. i  set ss) [0..<length ls]) @ map ((!) ls) ss"
proof -
  (* the rearranged index list has the same elements as [0..<length ls] ... *)
  have "set [0..<length ls] = set (filter (λi. i  set ss) [0..<length ls] @ ss)"
    using assms unfolding multiset_eq_iff by auto
  (* ... and both are distinct, so they have the same multiset *)
  then have "mset [0..<length ls] = mset (filter (λi. i  set ss) [0..<length ls] @ ss)"
    apply (subst set_eq_iff_mset_eq_distinct[symmetric])
    using assms by auto  
  then have "mset ls = mset (map ((!) ls)
           (filter (λi. i  set ss)
             [0..<length ls]) @ map ((!) ls) ss)"
    by (metis map_append map_nth mset_map)
  (* equal multisets are related by a permutation *)
  thus ?thesis
    by (metis mset_eq_permutation that)
qed

(* Converse form of subindex_permutation: ls itself is a permutation of the
   list that has the entries outside ss first and the entries at ss last. *)
lemma subindex_permutation2:
  assumes "distinct ss" "set ss  {..<length ls}"
  obtains f where "f permutes {..<length ls}"
    "ls = permute_list f (map ((!) ls) (filter (λi. i  set ss) [0..<length ls]) @ map ((!) ls) ss)"
  using subindex_permutation
  by (metis assms(1) assms(2) length_permute_list mset_eq_permutation mset_permute_list)

(* Let ss be a distinct list whose elements all occur in ls.  Then ss can be
   written as ls indexed by a distinct list ids of valid indices.  Each index is
   chosen with Hilbert choice (SOME, written @). *)
lemma distinct_list_subset_nths:
  assumes "distinct ss" "set ss  set ls"
  obtains ids where "distinct ids" "set ids  {..<length ls}" "ss = map ((!) ls) ids"
proof -
  (* for each element of ss, some position in ls where it occurs *)
  let ?ids = "map (λi. @j. j < length ls  ls!j = i ) ss"
  have 1: "distinct ?ids" unfolding distinct_map
    using assms apply (auto simp add: inj_on_def)
    by (smt (verit) in_set_conv_nth someI subset_eq)
  have 2: "set ?ids  {..<length ls}"
    using assms apply (auto)
    by (metis (mono_tags, lifting) in_mono in_set_conv_nth tfl_some)
  have 3: "ss = map ((!) ls) ?ids"
    using assms apply (auto simp add: list_eq_iff_nth_eq)
    by (smt (verit, best) in_set_conv_nth someI subsetD)
  show "(ids. distinct ids 
            set ids  {..<length ls} 
            ss = map ((!) ls) ids  thesis) 
    thesis" using 1 2 3 by blast
qed

(* Let ss be a distinct selection of columns of A (an nr x nc matrix) with
   coefficient vector v.  The combination mat_of_cols nr ss *v v equals A *v c
   for some c of dimension nc.
   Proof idea: permute the columns of A so that ss comes last, pad v with zeros
   in front (helper2), then undo the permutation
   (mat_of_cols_mult_mat_vec_permute_list). *)
lemma helper3: 
  fixes A ::"'a::comm_ring mat"
  assumes A: "A  carrier_mat nr nc"
  assumes ss:"distinct ss" "set ss  set (cols A)"
  assumes "v  carrier_vec (length ss)"
  obtains c where "mat_of_cols nr ss *v v = A *v c" "dim_vec c = nc"
proof -
  (* express ss as columns of A at distinct indices ids *)
  from distinct_list_subset_nths[OF ss]
  obtain ids where ids: "distinct ids" "set ids  {..<length (cols A)}"
    and ss: "ss = map ((!) (cols A)) ids" by blast
  (* ?ls: the remaining columns of A, in order *)
  let ?ls = " map ((!) (cols A)) (filter (λi. i  set ids) [0..<length (cols A)])"
  from subindex_permutation2[OF ids] obtain f where
    f: "f permutes {..<length (cols A)}"
    "cols A = permute_list f (?ls @ ss)" using ss by blast
  have *: "x. x  set ?ls  dim_vec x = nr"
    using A by auto
  (* ?cs1: v padded with one leading zero per column of ?ls *)
  let ?cs1 = "(list_of_vec (0v (length ?ls) @v v))"
  from helper2[OF assms(4) ]
  have "mat_of_cols nr ss *v v = mat_of_cols nr (?ls @ ss) *v vec_of_list (?cs1)"
    using *
    by (metis vec_list)
  also have "... = mat_of_cols nr (permute_list f (?ls @ ss)) *v vec_of_list (permute_list f ?cs1)"
    apply (auto intro!: mat_of_cols_mult_mat_vec_permute_list[symmetric])
     apply (metis cols_length f(1) f(2) length_append length_map length_permute_list)
    using assms(4) by auto
  also have "... =  A *v vec_of_list (permute_list f ?cs1)" using f(2) assms by auto
  (* the witness is c = vec_of_list (permute_list f ?cs1), of length nc *)
  ultimately show
    "(c. mat_of_cols nr ss *v v = A *v c  dim_vec c = nc  thesis)  thesis"
    by (metis A assms(4) carrier_matD(2) carrier_vecD cols_length dim_vec_of_list f(2) index_append_vec(2) index_zero_vec(2) length_append length_list_of_vec length_permute_list)
qed

(* Variant of mat_mul_conjugate_transpose_vec_eq_0 for a combination of
   distinct columns ss of A^H.  If A sends that combination to 0, the
   combination itself is 0.  helper3 rewrites it as A^H *v c. *)
lemma mat_mul_conjugate_transpose_sub_vec_eq_0:                        
  fixes A ::"'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors,comm_ring} mat"
  assumes "A  carrier_mat nr nc"
  assumes "distinct ss" "set ss  set (cols (AH))"
  assumes "v  carrier_vec (length ss)"
  assumes "A *v (mat_of_cols nc ss *v v) = 0v nr"
  shows "(mat_of_cols nc ss *v v) = 0v nc"
proof -
  have "AH  carrier_mat nc nr" using assms(1) by auto
  from  helper3[OF this assms(2-4)]
  obtain c where c: "mat_of_cols nc ss *v v = AH *v c" "dim_vec c = nr" by blast
  have 1: "c  carrier_vec nr"
    using c carrier_vec_dim_vec by blast
  have 2: "A *v (AH *v c) = 0v nr" using c assms(5) by auto
  from mat_mul_conjugate_transpose_vec_eq_0[OF assms(1) 1 2]
  have "AH *v c = 0v nc" .
  thus ?thesis unfolding c(1)[symmetric] .
qed

(* A unit of the square-matrix ring ring_mat TYPE('a) n b is an invertible
   matrix. *)
lemma Units_invertible:
  fixes A:: "'a::semiring_1 mat"
  assumes "A  Units (ring_mat TYPE('a) n b)"
  shows "invertible_mat A"
  using assms unfolding Units_def invertible_mat_def
  apply (auto simp add: ring_mat_def)
  using inverts_mat_def by blast

(* Conversely, an invertible matrix is a unit of the matrix ring of its own
   dimension, dim_row A. *)
lemma invertible_Units:
  fixes A:: "'a::semiring_1 mat"
  assumes "invertible_mat A"
  shows "A  Units (ring_mat TYPE('a) (dim_row A) b)"
  using assms unfolding Units_def invertible_mat_def
  apply (auto simp add: ring_mat_def)
  by (metis assms carrier_mat_triv invertible_mat_def inverts_mat_def inverts_mat_length(1) inverts_mat_length(2))

(* Over a field, a square n x n matrix is invertible exactly when its
   determinant is nonzero.  Both directions go through the Units
   characterisation above. *)
lemma invertible_det:
  fixes A:: "'a::field mat"
  assumes "A  carrier_mat n n"
  shows "invertible_mat A  det A  0"
  apply auto
  using invertible_Units unit_imp_det_non_zero apply fastforce
  using assms by (auto intro!: Units_invertible det_non_zero_imp_unit)

context vec_space begin

(* In a distinct list, the only index at which ss ! i occurs is i itself. *)
lemma find_indices_distinct:
  assumes "distinct ss"
  assumes "i < length ss"
  shows "find_indices (ss ! i) ss = [i]"
proof -
  have "set (find_indices (ss ! i) ss) = {i}"
    using assms apply auto by (simp add: assms(1) assms(2) nth_eq_iff_index_eq)
  thus ?thesis
    by (metis distinct.simps(2) distinct_find_indices empty_iff empty_set insert_iff list.exhaust list.simps(15)) 
qed

(* Let ss be distinct and linearly independent.  If lincomb_list f ss is the
   zero vector, then every coefficient f i with i < length ss is 0.  The proof
   converts lincomb_list into lincomb over set ss; for distinct lists each
   coefficient is a single f i (find_indices_distinct). *)
lemma lin_indpt_lin_comb_list:
  assumes "distinct ss"
  assumes "lin_indpt (set ss)"
  assumes "set ss  carrier_vec n"
  assumes "lincomb_list f ss = 0v n"
  assumes "i < length ss"
  shows "f i = 0"
proof -
  from lincomb_list_as_lincomb[OF assms(3)]
  have "lincomb_list f ss = lincomb (mk_coeff ss f) (set ss)" .
  also have "... = lincomb  (λv. sum f (set (find_indices v ss))) (set ss)"
    unfolding mk_coeff_def
    apply (subst R.sumlist_map_as_finsum)
    by (auto simp add: distinct_find_indices)
  ultimately have "lincomb_list f ss = lincomb  (λv. sum f (set (find_indices v ss))) (set ss)" by auto
  then have *:"lincomb (λv. sum f (set (find_indices v ss))) (set ss) = 0v n"
    using assms(4) by auto
  have "finite (set ss)" by simp
  (* linear independence forces all set-based coefficients to be 0 *)
  from not_lindepD[OF assms(2) this _ _ *]
  have "(λv. sum f (set (find_indices v ss)))  set ss  {0}"
    by auto
  from funcset_mem[OF this]
  have "sum f (set (find_indices (nth ss i) ss))  {0}"
    using assms(5) by auto
  thus ?thesis unfolding find_indices_distinct[OF assms(1) assms(5)]
    by auto
qed

(* Note: inside this locale the dimension parameter is n, so for a matrix A
   with dim_row A = n the local rank coincides with vec_space.rank (dim_row A).
   Example (kept for reference, not checked):
lemma foo:
  assumes "dim_row A = n"
  shows "rank A = vec_space.rank (dim_row A) A"
  by (simp add: assms) *)

(* The column span of A * B is contained in the column span of A
   (A is n x d, B is d x nc). *)
lemma span_mat_mul_subset:
  assumes "A  carrier_mat n d"
  assumes "B  carrier_mat d nc"
  shows "span (set (cols (A * B)))  span (set (cols A))"
proof -
  (* every combination of the columns of A * B is a combination of the columns
     of A, with coefficients B *v (the original coefficient vector) *)
  have *: "v. ca. lincomb_list v (cols (A * B)) =
              lincomb_list ca  (cols A)"
  proof -
    fix v
    have "lincomb_list v (cols (A * B)) = (A * B) *v vec nc v"
      apply (subst lincomb_list_as_mat_mult)
       apply (metis assms(1) carrier_dim_vec carrier_matD(1) cols_dim index_mult_mat(2) subset_code(1))
      by (metis assms(1) assms(2) carrier_matD(1) carrier_matD(2) cols_length index_mult_mat(2) index_mult_mat(3) mat_of_cols_cols)
    also have "... = A *v (B *v vec nc v)"
      using assms(1) assms(2) by auto
    also have "... = lincomb_list (λi. (B *v vec nc v) $ i) (cols A)"
      apply (subst lincomb_list_as_mat_mult)
      using assms(1) carrier_dim_vec cols_dim apply blast
      by (metis assms(1) assms(2) carrier_matD(1) carrier_matD(2) cols_length dim_mult_mat_vec dim_vec eq_vecI index_vec mat_of_cols_cols)
    ultimately have "lincomb_list v (cols (A * B)) =
              lincomb_list (λi. (B *v vec nc v) $ i) (cols A)" by auto
    thus "ca. lincomb_list v (cols (A * B)) = lincomb_list ca (cols A)" by auto
  qed
  (* switch both spans to their list-based form and apply * *)
  show ?thesis
    apply (subst span_list_as_span[symmetric])
     apply (metis assms(1) carrier_matD(1) cols_dim index_mult_mat(2))
    apply (subst span_list_as_span[symmetric])
    using assms(1) cols_dim apply blast
    by (auto simp add:span_list_def *)
qed

(* rank (A * B) <= rank A: the column span of A * B is a subspace of the column
   span of A, so its dimension is at most that of the column span of A. *)
lemma rank_mat_mul_right:
  assumes "A  carrier_mat n d"
  assumes "B  carrier_mat d nc"
  shows "rank (A * B)  rank A"
proof -
  have "subspace class_ring (local.span (set (cols (A*B))))
        (vs (local.span (set (cols A))))"
    unfolding subspace_def
    by (metis assms(1) assms(2) carrier_matD(1) cols_dim index_mult_mat(2) nested_submodules span_is_submodule vec_space.span_mat_mul_subset vec_vs_col)
  from vectorspace.subspace_dim[OF _ this]
  have "vectorspace.dim class_ring
   (vs (local.span (set (cols A)))
    carrier := local.span (set (cols (A * B)))) 
  vectorspace.dim class_ring
      (vs (local.span (set (cols A))))"
    apply auto
    by (metis (no_types) assms(1) carrier_matD(1) fin_dim_span_cols index_mult_mat(2) mat_of_cols_carrier(1) mat_of_cols_cols vec_vs_col)
  thus ?thesis unfolding rank_def
    by auto
qed

(* Zero vectors (of dimension n) can be dropped from a sumlist of
   n-dimensional vectors without changing the sum. *)
lemma sumlist_drop:
  assumes "v. v  set ls  dim_vec v = n"
  shows "sumlist ls = sumlist (filter (λv. v  0v n) ls)"
  using assms
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a ls)
  then show ?case using dim_sumlist by auto
qed

(* lincomb_list written as the sumlist of c j times s ! j over the indices
   j < length s. *)
lemma lincomb_list_alt:
  shows "lincomb_list c s =
    sumlist (map2 (λi j. i v s ! j) (map (λi. c i) [0..<length s]) [0..<length s])"
  unfolding lincomb_list_def
  by (smt (verit, ccfv_SIG) length_map map2_map_map map_nth nth_equalityI nth_map)

(* In the sumlist form of lincomb_list, indices with c i = 0 may be filtered
   out.  Induction on the index list ls.  Note that the Cons case names its
   tail variable "s", which shadows the vector list s of the statement; nothing
   in the proof refers to it by name. *)
lemma lincomb_list_alt2:
  assumes "v. v  set s  dim_vec v = n"
  assumes "i. i  set ls  i < length s"
  shows "
    sumlist (map2 (λi j. i v s ! j) (map (λi. c i) ls) ls) =
    sumlist (map2 (λi j. i v s ! j) (map (λi. c i) (filter (λi. c i  0) ls)) (filter (λi. c i  0) ls))"
  using assms(2)
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a s)
  (* if c a = 0, the head summand is 0 times a vector, which vanishes *)
  then show ?case
    apply auto
    apply (subst smult_l_null)
     apply (simp add: assms(1) carrier_vecI)
    apply (subst left_zero_vec)
     apply (subst sumlist_carrier)
      apply auto
    by (metis (no_types, lifting) assms(1) carrier_dim_vec mem_Collect_eq nth_mem set_zip_rightD)
qed 

(* Let ls be a distinct list with the same elements as [a, b], where a and b
   differ.  Then ls is [a, b] or [b, a]. *)
lemma two_set:
  assumes "distinct ls"
  assumes "set ls = set [a,b]"
  assumes "a  b"
  shows "ls = [a,b]  ls = [b,a]"
  apply (cases ls)
  using assms(2) apply auto[1]
proof -
  fix x xs
  assume ls:"ls = x # xs"
  (* ls has at least two elements *)
  obtain y ys where xs:"xs = y # ys"
    by (metis (no_types) ls = x # xs assms(2) assms(3) list.set_cases list.set_intros(1) list.set_intros(2) set_ConsD)
  have 1:"x = a  x = b"
    using ls = x # xs assms(2) by auto
  have 2:"y = a  y = b"
    using ls = x # xs xs = y # ys assms(2) by auto
  (* and, by distinctness, no more than two *)
  have 3:"ys = []"
    by (metis (no_types) ls = x # xs xs = y # ys assms(1) assms(2) distinct.simps(2) distinct_length_2_or_more in_set_member member_rec(2) neq_Nil_conv set_ConsD)
  show "ls = [a, b]  ls = [b, a]" using ls xs 1 2 3 assms
    by auto
qed

(* For distinct valid indices i and j, filtering [0..<length ls] down to the
   indices i and j yields [i, j] or [j, i]. *)
lemma filter_disj_inds:
  assumes "i < length ls" "j < length ls" "i  j"
  shows "filter (λia. ia  j  ia = i) [0..<length ls] = [i, j] 
  filter (λia. ia  j  ia = i) [0..<length ls] = [j,i]"
proof -
  have 1: "distinct (filter (λia. ia = i  ia = j) [0..<length ls])"
    using distinct_filter distinct_upt by blast
  have 2:"set (filter (λia. ia = i  ia = j) [0..<length ls]) = {i, j}"
    using assms by auto
  show ?thesis using two_set[OF 1]
    using assms(3) empty_set filter_cong list.simps(15)
    by (smt(verit, ccfv_SIG) "2" assms(3) empty_set filter_cong list.simps(15))
qed

(* Suppose every vanishing lincomb_list over ls has all coefficients 0, and
   all vectors of ls have dimension n.  Then ls contains no duplicates.
   If ls ! i = ls ! j with i and j different, then coefficients 1 at i and -1
   at j give a vanishing combination that is not all zero. *)
lemma lincomb_list_indpt_distinct:
  assumes "v. v  set ls  dim_vec v = n"
  assumes
    "c. lincomb_list c ls = 0v n  (i. i < (length ls)  c i = 0)"
  shows "distinct ls"
  unfolding distinct_conv_nth
proof clarsimp
  fix i j
  assume ij: "i < length ls" "j < length ls" "i  j" 
  assume lsij: "ls ! i = ls ! j"
  (* the combination with +1 at i and -1 at j evaluates to ls!i - ls!j *)
  have "lincomb_list (λk. if k = i then 1 else if k = j then -1 else 0) ls =
     (ls ! i) - (ls ! j)"
    unfolding lincomb_list_alt
    apply (subst lincomb_list_alt2[OF assms(1)])
      apply auto
    using  filter_disj_inds[OF ij]
    apply auto
    using ij(3) apply force
    using assms(1) ij(2) apply auto[1]
    using ij(3) apply blast
    using assms(1) ij(2) by auto
  also have "...  = 0v n" unfolding lsij
    apply (rule minus_cancel_vec)
    using j < length ls assms(1)
    using carrier_vec_dim_vec nth_mem by blast
  ultimately have "lincomb_list (λk. if k = i then 1 else if k = j then -1 else 0) ls = 0v n" by auto
  (* yet its coefficient at i is 1, contradicting the hypothesis *)
  from assms(2)[OF this]
  show False
    using i < length ls by auto
qed

end

(* Rank facts over a conjugatable ordered field.  The locale extends
   vec_space f_ty n: n is the fixed vector dimension, and f_ty only pins down
   the element type 'a. *)
locale conjugatable_vec_space = vec_space f_ty n for
  f_ty::"'a::conjugatable_ordered_field itself"
  and n
begin                                                           

(* Let A have n rows and nc columns.  Then the rank of the conjugate
   transpose A^H (whose columns are nc-dimensional) is at most the rank of
   A * A^H.
   Proof idea: take a maximal linearly independent set S of columns of A^H.
   Multiplication by A maps S injectively to a linearly independent set of
   columns of A * A^H, so card S is at most rank (A * A^H).  The step that
   cancels A uses mat_mul_conjugate_transpose_sub_vec_eq_0. *)
lemma transpose_rank_mul_conjugate_transpose:
  fixes A :: "'a mat"
  assumes "A  carrier_mat n nc"
  shows "vec_space.rank nc AH  rank (A * AH)"
proof -
  have 1: "AH  carrier_mat nc n" using assms by auto
  have 2: "A * AH  carrier_mat n n" using assms by auto
      (* S will be a maximal linearly independent subset of the columns of
         A^H, i.e. of the conjugated rows of A. *)
  let ?P = "(λT. T  set (cols AH)  module.lin_indpt class_ring (module_vec TYPE('a) nc) T)"
  (* Candidate sets are finite, with at most n elements (A^H has n columns),
     which is what maximal_exists requires. *)
  have *:"A. ?P A  finite A  card A  n"
  proof clarsimp
    fix S
    assume S: "S  set (cols AH)"
    have "card S  card (set (cols AH))" using S
      using card_mono by blast
    also have "...  length (cols AH)" using card_length by blast
    also have "...  n" using assms by auto
    ultimately show "finite S  card S  n"
      by (meson List.finite_set S dual_order.trans finite_subset)
  qed
  (* The empty set satisfies ?P, so a maximal set exists. *)
  have **:"?P {}"
    apply (subst module.lin_dep_def)
    by (auto simp add: vec_module)
  from maximal_exists[OF *]
  obtain S where S: "maximal S ?P" using **
    by (metis (no_types, lifting)) 
      (* Some properties of S: its size is the rank of A^H, it consists of
         columns of A^H, and it is linearly independent. *)
  from vec_space.rank_card_indpt[OF 1 S]
  have rankeq: "vec_space.rank nc AH = card S" .

  have s_hyp: "S  set (cols AH)"
    using S unfolding maximal_def by simp
  have modhyp: "module.lin_indpt class_ring (module_vec TYPE('a) nc) S" 
    using S unfolding maximal_def by simp

(* Switch to a list representation: ss enumerates S without repetition. *)
  obtain ss where ss: "set ss = S" "distinct ss"
    by (metis (mono_tags) S maximal_def set_obtain_sublist)
  have ss2: "set (map ((*v) A) ss) = (*v) A ` S"
    by (simp add: ss(1))
  (* The matrix with columns A *v s (s from ss) is A times the matrix with
     columns ss. *)
  have rw_hyp: "cols (mat_of_cols n (map ((*v) A) ss)) = cols  (A * mat_of_cols nc ss)" 
    unfolding cols_def apply (auto)
    using mat_vec_as_mat_mat_mult[of A n nc]
    by (metis (no_types, lifting) "1" assms carrier_matD(1) cols_dim mul_mat_of_cols nth_mem s_hyp ss(1) subset_code(1))
  then have rw: "mat_of_cols n (map ((*v) A) ss) = A * mat_of_cols nc ss"
    by (metis assms carrier_matD(1) index_mult_mat(2) mat_of_cols_carrier(2) mat_of_cols_cols) 
  (* Key step: the image list map ((*v) A) ss admits only the trivial
     vanishing linear combination. *)
  have indpt: "c. lincomb_list c (map ((*v) A) ss) = 0v n 
      i. (i < (length ss)  c i = 0)"
  proof clarsimp
    fix c i
    assume *: "lincomb_list c (map ((*v) A) ss) = 0v n"
    assume i: "i < length ss"
    have "wset (map ((*v) A) ss). dim_vec w = n"
      using assms by auto
    (* Write the combination as a matrix-vector product. *)
    from lincomb_list_as_mat_mult[OF this]
    have "A * mat_of_cols nc ss *v  vec (length ss) c = 0v n"
      using * rw by auto
    then have hq: "A *v (mat_of_cols nc ss *v vec (length ss) c) =  0v n"
      by (metis assms assoc_mult_mat_vec mat_of_cols_carrier(1) vec_carrier)

    (* Cancel A: the inner vector is built from columns of A^H, and
       mat_mul_conjugate_transpose_sub_vec_eq_0 then yields that it is 0. *)
    then have eq1: "(mat_of_cols nc ss *v vec (length ss) c) =  0v nc"
      apply (intro mat_mul_conjugate_transpose_sub_vec_eq_0)
      using assms ss s_hyp by auto

(* Rewrite the inner vector back to a lincomb_list, then use linear
   independence of S. *)
    have dim_hyp2: "wset ss. dim_vec w = nc"
      using ss(1) s_hyp
      by (metis "1" carrier_matD(1) carrier_vecD cols_dim subsetD) 
    from vec_module.lincomb_list_as_mat_mult[OF this, symmetric]
    have "mat_of_cols nc ss *v vec (length ss) c = module.lincomb_list (module_vec TYPE('a) nc) c ss" .
    then have "module.lincomb_list (module_vec TYPE('a) nc) c ss = 0v nc" using eq1 by auto
    from vec_space.lin_indpt_lin_comb_list[OF ss(2) _ _ this i]
    show "c i = 0" using modhyp ss s_hyp
      using "1" cols_dim by blast
  qed
  (* Hence multiplication by A does not identify distinct elements of S. *)
  have distinct: "distinct (map ((*v) A) ss)"
    by (metis (no_types, lifting) assms carrier_matD(1) dim_mult_mat_vec imageE indpt length_map lincomb_list_indpt_distinct ss2)
  then have 3: "card S = card ((*v) A ` S)"
    by (metis ss distinct_card image_set length_map)
  (* The image lies among the columns of A * A^H ... *)
  then have 4: "(*v) A ` S  set (cols (A * AH))"
    using cols_mat_mul S  set (cols AH) by blast
  (* ... and is linearly independent: a nontrivial dependency would give a
     nontrivial vanishing lincomb_list, contradicting indpt. *)
  have 5: "lin_indpt ((*v) A ` S)"
  proof clarsimp
    assume ld:"lin_dep ((*v) A ` S)"
    have *: "finite ((*v) A ` S)"
      by (metis List.finite_set ss2)
    have **: "(*v) A ` S  carrier_vec n"
      using "2" "4" cols_dim by blast
    from finite_lin_dep[OF * ld **]
    obtain a v where
      a: "lincomb a ((*v) A ` S) = 0v n" and
      v: "v  (*v) A ` S" "a v  0" by blast
    obtain i where i:"v = map ((*v) A) ss ! i" "i < length ss"
      using v unfolding ss2[symmetric]
      using find_first_le nth_find_first by force
    from ss2[symmetric]
    have "set (map ((*v) A) ss) carrier_vec n" using ** ss2 by auto
    (* Since the list is distinct, the set combination equals a list
       combination. *)
    from lincomb_as_lincomb_list_distinct[OF this distinct] have
      "lincomb_list
     (λi. a (map ((*v) A) ss ! i))  (map ((*v) A) ss) = 0v n"
      using a ss2 by auto
    from indpt[OF this]
    show False using v i by simp
  qed
  from rank_ge_card_indpt[OF 2 4 5]
  have "card ((*v) A ` S)  rank (A * AH)" .
  thus ?thesis using rankeq 3 by linarith
qed

(* The rank of A^H (as nc-dimensional vectors) is at most the rank of A.
   Combines transpose_rank_mul_conjugate_transpose with rank_mat_mul_right. *)
lemma conjugate_transpose_rank_le:
  assumes "A  carrier_mat n nc"
  shows "vec_space.rank nc (AH)  rank A"
  by (metis assms carrier_matD(2) carrier_mat_triv dim_row_conjugate dual_order.trans index_transpose_mat(2) rank_mat_mul_right transpose_rank_mul_conjugate_transpose)

(* Conjugation commutes with finite vector sums: conjugating finsum V f U
   gives finsum V (conjugate o f) U, for f mapping U into carrier_vec n.
   NOTE(review): V is presumably the locale's vector module; confirm.
   The proof is by induction over U.  The infinite and empty cases are closed
   by auto; the insert case uses finsum_insert and conjugate_add_vec. *)
lemma conjugate_finsum:
  assumes f: "f : U  carrier_vec n"
  shows "conjugate (finsum V f U) = finsum V (conjugate  f) U"
  using f
proof (induct U rule: infinite_finite_induct)
  case (infinite A)
  then show ?case by auto
next
  case empty
  then show ?case by auto
next
  case (insert u U)
  hence f: "f : U  carrier_vec n" "f u : carrier_vec n"  by auto
  (* Conjugation keeps vectors in carrier_vec n, so finsum_insert applies
     to conjugate o f as well. *)
  then have cf: "conjugate  f : U  carrier_vec n"
    "(conjugate  f) u : carrier_vec n"
     apply (simp add: Pi_iff)
    by (simp add: f(2))
  then show ?case
    unfolding finsum_insert[OF insert(1) insert(2) f]
    unfolding finsum_insert[OF insert(1) insert(2) cf ]
    apply (subst conjugate_add_vec[of _ n])
    using f(2) apply blast
    using M.finsum_closed f(1) apply blast
    by (simp add: comp_def f(1) insert.hyps(3))
qed

(* For A in carrier_mat n d, the rank of conjugate A is at most the rank of A.
   Proof: take a maximal linearly independent set S of columns of conjugate A.
   Then conjugate ` S is a linearly independent set of columns of A with the
   same cardinality, since conjugation is injective. *)
lemma rank_conjugate_le:
  assumes A:"A  carrier_mat n d"
  shows "rank (conjugate (A))  rank A"
proof -
  (* S is a maximal linearly independent set of columns of (conjugate A) *)
  let ?P = "(λT. T  set (cols (conjugate A))  lin_indpt T)"
  have *:"A. ?P A  finite A  card A  d"
    by (metis List.finite_set assms card_length card_mono carrier_matD(2) cols_length dim_col_conjugate dual_order.trans rev_finite_subset)
  have **:"?P {}"
    by (simp add: finite_lin_indpt2)
  from maximal_exists[OF *]
  obtain S where S: "maximal S ?P" using **
    by (metis (no_types, lifting))
  have s_hyp: "S  set (cols (conjugate A))" "lin_indpt S"
    using S unfolding maximal_def
     apply blast
    by (metis (no_types, lifting) S maximal_def)
  from rank_card_indpt[OF _ S, of d]
  have rankeq: "rank (conjugate A) = card S" using assms by auto 
  (* Conjugating columns of conjugate A gives columns of A. *)
  have 1:"conjugate ` S  set (cols A)"
    using S apply auto
    by (metis (no_types, lifting) cols_conjugate conjugate_id image_eqI in_mono list.set_map s_hyp(1))
  (* By contradiction: a vanishing combination c over a finite T inside
     conjugate ` S with c v nonzero conjugates to a vanishing combination
     over conjugate ` T, a subset of S, with coefficients
     conjugate o c o conjugate.  That contradicts lin_indpt S. *)
  have 2: "lin_indpt (conjugate ` S)"
    apply (rule ccontr)
    apply (auto simp add: lin_dep_def)
  proof -
    fix T c v
    assume T: "T  conjugate ` S" "finite T" and
      lc:"lincomb c T = 0v n" and "v  T"  "c v  0"
    let ?T = "conjugate ` T"
    let ?c = "conjugate  c  conjugate"
    have 1: "finite ?T"  using T by auto
    have 2: "?T  S"  using T by auto
    have 3: "?c  ?T  UNIV" by auto
    (* Reindex the sum over conjugate ` T (conjugation is injective). *)
    have "lincomb ?c ?T = (VxT. conjugate (c x) v conjugate x)"
      unfolding lincomb_def
      apply (subst finsum_reindex)
        apply auto
       apply (metis "2" carrier_vec_conjugate assms carrier_matD(1) cols_dim image_eqI s_hyp(1) subsetD)
      by (meson conjugate_cancel_iff inj_onI)
    also have "... = (VxT. conjugate (c x v x)) "
      by (simp add: conjugate_smult_vec)
    also have "... = conjugate (VxT. (c x v x))"
      apply(subst conjugate_finsum[of "λx.(c x v x)" T])
       apply (auto simp add:o_def)
      by (smt (verit, ccfv_SIG) Matrix.carrier_vec_conjugate Pi_I' T(1) assms carrier_matD(1) cols_dim dim_row_conjugate imageE s_hyp(1) smult_carrier_vec subset_eq) 
    also have "... = conjugate (lincomb c T)"
      using lincomb_def by presburger
    ultimately have "lincomb ?c ?T = conjugate (lincomb c T)" by auto
    then have 4:"lincomb ?c ?T = 0v n" using lc by auto
    from not_lindepD[OF s_hyp(2) 1 2 3 4]
    have "conjugate  c  conjugate  conjugate ` T  {0}" .
    then have "c v = 0"
      by (simp add: Pi_iff v  T)
    thus False using c v  0 by auto
  qed
  from rank_ge_card_indpt[OF A 1 2]
  have 3:"card (conjugate ` S)  rank A" .
  (* Conjugation is injective, so it preserves the cardinality of S. *)
  have 4: "card (conjugate ` S) = card S"
    apply (auto intro!: card_image)
    by (meson conjugate_cancel_iff inj_onI)
  show ?thesis using rankeq 3 4 by auto
qed

(* Conjugation preserves rank.  Apply rank_conjugate_le to A and to
   conjugate A (conjugate_id), then use antisymmetry. *)
lemma rank_conjugate:
  assumes "A  carrier_mat n d"
  shows "rank (conjugate A) = rank A"
  using  rank_conjugate_le
  by (metis carrier_vec_conjugate assms conjugate_id dual_order.antisym)

end (* end of locale conjugatable_vec_space *)

(* For any matrix A over a conjugatable ordered field, with no dimension
   assumptions, rank A equals rank A^H.  Obtained by applying the locale
   lemma conjugate_transpose_rank_le to A and to A^H, which is A again by
   transpose_transpose / conjugate_id, and then using antisymmetry. *)
lemma conjugate_transpose_rank:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  shows "vec_space.rank (dim_row A) A = vec_space.rank (dim_col A) (AH)"
  using  conjugatable_vec_space.conjugate_transpose_rank_le
  by (metis (no_types, lifting) Matrix.transpose_transpose carrier_matI conjugate_id dim_col_conjugate dual_order.antisym index_transpose_mat(2) transpose_conjugate)

(* rank A equals rank A^T.  Follows from conjugate_transpose_rank together
   with rank_conjugate (conjugation preserves rank). *)
lemma transpose_rank:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  shows "vec_space.rank (dim_row A) A = vec_space.rank (dim_col A) (AT)"
  by (metis carrier_mat_triv conjugatable_vec_space.rank_conjugate conjugate_transpose_rank index_transpose_mat(2))

(* For A of size n x d and B of size d x nc, rank (A * B) is bounded by
   rank B.  Proved by transposing the product (transpose_mult) and applying
   vec_space.rank_mat_mul_right.  Ranks are moved across the transpose with
   conjugate_transpose_rank and rank_conjugate. *)
lemma rank_mat_mul_left:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  assumes "A  carrier_mat n d"
  assumes "B  carrier_mat d nc"
  shows "vec_space.rank n (A * B)  vec_space.rank d B"
  by (metis (no_types, lifting) Matrix.transpose_transpose assms(1) assms(2) carrier_matD(1) carrier_matD(2) carrier_mat_triv conjugatable_vec_space.rank_conjugate conjugate_transpose_rank index_mult_mat(3) index_transpose_mat(3) transpose_mult vec_space.rank_mat_mul_right)

section "Results on Invertibility"

(* Extract specific columns of a matrix.
   take_cols A inds has dim_row A rows.  Its columns are the columns of A at
   the positions listed in inds, in the order of inds; repetitions are kept.
   Indices that are not below dim_col A are silently dropped, so the result
   has one column per in-range entry of inds. *)
definition take_cols :: "'a mat => nat list => 'a mat"
  where "take_cols A inds = mat_of_cols (dim_row A) (map ((!) (cols A)) (filter ((>) (dim_col A)) inds))"

(* take_cols_var A inds: like take_cols, but without range filtering.  The
   result has dim_row A rows and exactly length inds columns; column k is
   cols A ! (inds ! k).  Callers should ensure every index is below
   dim_col A, since an out-of-range index selects an unspecified list
   element (List.nth). *)
definition take_cols_var :: "'a mat => nat list => 'a mat"
  where "take_cols_var A inds = mat_of_cols (dim_row A) (map ((!) (cols A)) inds)"

(* take_rows A inds has dim_col A columns.  Its rows are the rows of A at
   the positions listed in inds, in the order of inds.  Indices that are not
   below dim_row A are silently dropped, mirroring take_cols. *)
definition take_rows :: "'a mat => nat list => 'a mat"
  where "take_rows A inds = mat_of_rows (dim_col A) (map ((!) (rows A)) (filter ((>) (dim_row A)) inds))"

(* Congruence for mat_of_cols: equal column lists give equal matrices.
   Useful with intro! to reduce a matrix equality to a list equality. *)
lemma cong1:
  "x = y   mat_of_cols n x = mat_of_cols n y"
  by simp

(* Every element of filter P ls satisfies P; stated for an in-range
   position j of the filtered list. *)
lemma nth_filter:
  assumes "j < length (filter P ls)"
  shows "P  ((filter P ls) ! j)"
  using assms by (simp add: list_ball_nth)

(* Left multiplication commutes with column selection: for A of size nr x n
   and B of size n x nc, A * take_cols B inds = take_cols (A * B) inds. *)
lemma take_cols_mat_mul:
  assumes "A  carrier_mat nr n"
  assumes "B  carrier_mat n nc"
  shows "A * take_cols B inds = take_cols (A * B) inds"
proof -
  (* Every selected column of B is an n-dimensional vector. *)
  have "j. j < length (map ((!) (cols B)) (filter ((>) nc) inds)) 
      (map ((!) (cols B)) (filter ((>) nc) inds)) ! j  carrier_vec n"
    using assms apply auto
    apply (subst cols_nth)
    using nth_filter by auto
  (* Push A inside mat_of_cols: the new columns are A *v (selected column). *)
  from mul_mat_of_cols[OF assms(1) this]
  have "A *  take_cols B inds = mat_of_cols nr (map (λx. A *v cols B ! x) (filter ((>) (dim_col B)) inds))"
    unfolding take_cols_def using assms by (auto simp add: o_def)
  (* These are the selected columns of A * B; cong1 reduces the matrix
     equality to an equality of column lists. *)
  also have "... = take_cols (A * B) inds"
    unfolding take_cols_def using assms by (auto intro!: cong1)
  ultimately show ?thesis by auto
qed

(* take_cols keeps the row count of A.  The column count is only asserted
   to exist, because it depends on how many indices are in range. *)
lemma take_cols_carrier_mat:
  assumes "A  carrier_mat nr nc"
  obtains n where "take_cols A inds  carrier_mat nr n"
  using assms
  unfolding take_cols_def
  by fastforce

(* If every index in inds is below nc, nothing is filtered out, so
   take_cols A inds is an nr x (length inds) matrix. *)
lemma take_cols_carrier_mat_strict:
  assumes "A  carrier_mat nr nc"
  assumes "i. i  set inds  i < nc"
  shows "take_cols A inds  carrier_mat nr (length inds)"
  using assms
  unfolding take_cols_def
  by auto

(* Run gauss_jordan on A with right-hand side take_cols A inds.  The
   transformed right-hand side D is the same selection of columns taken
   from the transformed matrix C. *)
lemma gauss_jordan_take_cols:  
  assumes "gauss_jordan A (take_cols A inds) = (C,D)"
  shows "D = take_cols C inds"
proof -
  (* Name the dimensions of A; any matrix has some, so auto closes this. *)
  obtain nr nc where A: "A   carrier_mat nr nc" by auto
  from take_cols_carrier_mat[OF this]
  obtain n where B: "take_cols A inds  carrier_mat nr n" by auto
  (* gauss_jordan_transform: both outputs are the inputs multiplied on the
     left by the same unit P of the matrix ring.  The last ring_mat argument
     is instantiated with undefined.  NOTE(review): it appears irrelevant
     here. *)
  from gauss_jordan_transform[OF A B assms, of undefined]
  obtain P where PP:"PUnits (ring_mat TYPE('a) nr undefined)" and
    CD: "C = P * A" "D = P * take_cols A inds" by blast
  (* Units of ring_mat are nr x nr matrices. *)
  have P: "P  carrier_mat nr nr"
    by (metis (no_types, lifting) Units_def PP mem_Collect_eq partial_object.select_convs(1) ring_mat_def)
  (* Column selection commutes with left multiplication by P. *)
  from take_cols_mat_mul[OF P A]
  have "P * take_cols A inds = take_cols (P * A) inds" .
  thus ?thesis using CD by blast  
qed

(* If every index is below dim_col A, take_cols A inds has exactly
   length inds columns. *)
lemma dim_col_take_cols:
  assumes "j. j  set inds  j < dim_col A"
  shows "dim_col (take_cols A inds) = length inds"
  using assms
  unfolding take_cols_def
  by auto

(* take_rows preserves the number of columns. *)
lemma dim_col_take_rows[simp]:
  shows "dim_col (take_rows A inds) = dim_col A"
  by (simp add: take_rows_def)

(* Every column of take_cols A inds is a column of A. *)
lemma cols_take_cols_subset:
  shows "set (cols (take_cols A inds))  set (cols A)"
  unfolding take_cols_def
  (* cols_mat_of_cols recovers the selected column list, with its side
     condition discharged by auto.  Each selected column is cols A ! i for
     some in-range i (in_set_conv_nth). *)
  apply (subst cols_mat_of_cols)
   apply auto
  using in_set_conv_nth by fastforce

(* take_cols preserves the number of rows. *)
lemma dim_row_take_cols[simp]:
  shows "dim_row (take_cols A ls) = dim_row A"
  unfolding take_cols_def by simp

(* Stacking A above B adds the row counts. *)
lemma dim_row_append_rows[simp]:
  shows "dim_row (A @r B) = dim_row A + dim_row B"
  unfolding append_rows_def by simp

(* Matrices with the same column count and the same list of rows are equal. *)
lemma rows_inj:
  assumes "dim_col A = dim_col B"
  assumes "rows A = rows B"
  shows "A = B"
  unfolding mat_eq_iff
  (* Remaining goals: equal row counts (length_rows), equal column counts
     (first assumption), and equal entries (mat_of_rows_rows). *)
  apply auto
    apply (metis assms(2) length_rows)
  using assms(1) apply blast
  by (metis assms(1) assms(2) mat_of_rows_rows)

(* Entry (i,j) of A @r B is taken from A when i < dim_row A.  Otherwise it is
   entry (i - dim_row A, j) of B.  Reduces to index_mat_four_block. *)
lemma append_rows_index:
  assumes "dim_col A = dim_col B"
  assumes "i < dim_row A + dim_row B"
  assumes "j < dim_col A"
  shows "(A @r B) $$ (i,j) = (if i < dim_row A then A $$ (i,j) else B $$ (i-dim_row A,j))"
  unfolding append_rows_def
  apply (subst index_mat_four_block)
  using assms by auto

(* Row version of append_rows_index: row i of A @r B is row i of A when
   i < dim_row A, and row (i - dim_row A) of B otherwise.  Proved
   componentwise via vec_eq_iff. *)
lemma row_append_rows:
  assumes "dim_col A = dim_col B"
  assumes "i < dim_row A + dim_row B"
  shows "row (A @r B) i = (if i < dim_row A then row A i else row B (i-dim_row A))"
  unfolding vec_eq_iff
  using assms by (auto simp add: append_rows_def)

(* Right multiplication distributes over row stacking when A and B have the
   same column count: (A @r B) * C = A * C @r B * C. *)
lemma append_rows_mat_mul:
  assumes "dim_col A = dim_col B"
  shows "(A @r B) * C = A * C @r B * C"
  unfolding mat_eq_iff
  (* Dimension goals follow from append_rows_def.  Entry goals are split on
     whether the row index falls in the A part or the B part, using
     append_rows_index and row_append_rows. *)
  apply auto
   apply (simp add: append_rows_def)
  apply (subst index_mult_mat)
    apply auto
   apply (simp add: append_rows_def)
  apply (subst  append_rows_index)
     apply auto
    apply (simp add: append_rows_def)
   apply (metis add.right_neutral append_rows_def assms index_mat_four_block(3) index_mult_mat(1) index_mult_mat(3) index_zero_mat(3) row_append_rows trans_less_add1)
  by (metis add_cancel_right_right add_diff_inverse_nat append_rows_def assms index_mat_four_block(3) index_mult_mat(1) index_mult_mat(3) index_zero_mat(3) nat_add_left_cancel_less row_append_rows)

(* The naturals below n form a set with at most n elements. *)
lemma cardlt:
  shows "card  {i. i < (n::nat)}  n"
  by auto

lemma row_echelon_form_zero_rows:
  assumes row_ech: "row_echelon_form A"
  assumes dim_asm: "dim_col A  dim_row A"
  shows "take_rows A [0..<length (pivot_positions A)] @r  0m (dim_row A - length (pivot_positions A))  (dim_col A) = A"
proof -
  have ex_pivot_fun: " f. pivot_fun A f (dim_col A)" using row_ech unfolding row_echelon_form_def by auto
  have len_help: "length (pivot_positions A) = card {i. i < dim_row A  row A i  0v (dim_col A)}"
    using ex_pivot_fun pivot_positions[where A = "A",where nr = "dim_row A", where nc = "dim_col A"]
    by auto
  then have len_help2: "length (pivot_positions A)  dim_row A"
    by (metis (no_types, lifting) card_mono cardlt finite_Collect_less_nat le_trans mem_Collect_eq subsetI)
  have fileq: "filter (λy. y < dim_row A) [0..< length (pivot_positions A)] = [0..<length (pivot_positions A)]"
    apply (rule filter_True)
    using len_help2 by auto
  have "n. card {i. i < n   row A i  0v (dim_col A)}  n"
  proof clarsimp 
    fix n
    have h: "x. x  {i. i < n  row A i  0v (dim_col A)}  x{..<n}"
      by simp
    then have h1: "{i. i < n   row A i  0v (dim_col A)}  {..<n}"
      by blast
    then have h2: "(card {i. i < n   row A i  0v (dim_col A)}::nat)  (card {..<n}::nat)"
      using card_mono by blast 
    then show "(card {i. i < n  row A i  0v (dim_col A)}::nat)  (n::nat)" using h2 card_lessThan[of n]
      by auto
  qed
  then have pivot_len: "length (pivot_positions A)  dim_row A "  using len_help
    by simp
  have alt_char: "mat_of_rows (dim_col A)
         (map ((!) (rows A)) (filter (λy. y < dim_col A) [0..<length (pivot_positions A)])) = 
      mat_of_rows (dim_col A) (map ((!) (rows A))  [0..<length (pivot_positions A)])"
    using pivot_len dim_asm
    by auto
  have h1: "i j. i < dim_row A 
           j < dim_col A 
           i < dim_row (take_rows A [0..<length (pivot_positions A)]) 
           take_rows A [0..<length (pivot_positions A)] $$ (i, j) = A $$ (i, j)"
  proof - 
    fix i 
    fix j
    assume "i < dim_row A"
    assume j_lt: "j < dim_col A"
    assume i_lt: "i < dim_row (take_rows A [0..<length (pivot_positions A)])" 
    have lt: "length (pivot_positions A)  dim_row A" using pivot_len by auto
    have h1: "take_rows A [0..<length (pivot_positions A)] $$ (i, j) = (row (take_rows A [0..<length (pivot_positions A)]) i)$j"
      by (simp add: i_lt j_lt)
    then have h2: "(row (take_rows A [0..<length (pivot_positions A)]) i)$j = (row A i)$j"
      using lt alt_char i_lt unfolding take_rows_def by auto
    show "take_rows A [0..<length (pivot_positions A)] $$ (i, j) = A $$ (i, j)"
      using h1 h2
      by (simp add: i < dim_row A j_lt) 
  qed
  let ?nc = "dim_col A"
  let ?nr = "dim_row A"
  have h2: "i j. i < dim_row A 
           j < dim_col A 
           ¬ i < dim_row (take_rows A [0..<length (pivot_positions A)]) 
           0m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) =
           A $$ (i, j)"
  proof - 
    fix i
    fix j
    assume lt_i: "i < dim_row A"
    assume lt_j: "j < dim_col A"
    assume not_lt: "¬ i < dim_row (take_rows A [0..<length (pivot_positions A)])"
    let ?ip = "i+1"
    have h0: "f. pivot_fun A f (dim_col A)  f i = ?nc"
    proof -  
      have half1: "f. pivot_fun A f (dim_col A)" using assms unfolding row_echelon_form_def
        by blast
      have half2: "f. pivot_fun A f (dim_col A)  f i = ?nc " 
      proof clarsimp
        fix f
        assume is_piv: "pivot_fun A f (dim_col A)"
        have len_pp: "length (pivot_positions A) = card {i. i < ?nr  row A i  0v ?nc}" using is_piv pivot_positions[of A ?nr ?nc f]
          by auto
        have  "i. (i < ?nr  row A i  0v ?nc)   (i < ?nr  f i  ?nc)"
          using is_piv pivot_fun_zero_row_iff[of A f ?nc ?nr]
          by blast
        then have len_pp_var: "length (pivot_positions A) = card {i. i < ?nr  f i  ?nc}" 
          using len_pp  by auto 
        have allj_hyp: "j < ?nr. f j = ?nc  ((Suc j) < ?nr  f (Suc j) = ?nc)" 
          using is_piv unfolding pivot_fun_def 
          using lt_i
          by (metis le_antisym le_less) 
        have if_then_bad: "f i  ?nc  (j. j  i  f j  ?nc)"
        proof clarsimp 
          fix j
          assume not_i: "f i  ?nc"
          assume j_leq: "j  i"
          assume bad_asm: "f j = ?nc"
          have "k. k  j   k < ?nr  f k = ?nc"
          proof -
            fix k :: nat
            assume a1: "j  k"
            assume a2: "k < dim_row A"
            have f3: "n. ¬ n < dim_row A  f n  f j  ¬ Suc n < dim_row A  f (Suc n) = f j"
              using allj_hyp bad_asm by presburger
            obtain nn :: "nat  nat  (nat  bool)  nat" where
              f4: "n na p nb nc. (¬ n  na  Suc n  Suc na)  (¬ p nb  ¬ nc  nb  ¬ p (nn nc nb p)  p nc)  (¬ p nb  ¬ nc  nb  p nc  p (Suc (nn nc nb p)))"
              using inc_induct by (metis Suc_le_mono)
            then have f5: "p. ¬ p k  p j  p (Suc (nn j k p))"
              using a1 by presburger
            have f6: "p. ¬ p k  ¬ p (nn j k p)  p j"
              using f4 a1 by meson
            { assume "nn j k (λn. n < dim_row A  f n  dim_col A) < dim_row A  f (nn j k (λn. n < dim_row A  f n  dim_col A))  dim_col A"
              moreover
              { assume "(nn j k (λn. n < dim_row A  f n  dim_col A) < dim_row A  f (nn j k (λn. n < dim_row A  f n  dim_col A))  dim_col A)  (¬ j < dim_row A  f j = dim_col A)"
                then have "¬ k < dim_row A  f k = dim_col A"
                  using f6
                  by (metis (mono_tags, lifting)) }
              ultimately have "(¬ j < dim_row A  f j = dim_col A)  (¬ Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A))) = dim_col A)  ¬ k < dim_row A  f k = dim_col A"
                using bad_asm
                by blast }
            moreover
            { assume "(¬ j < dim_row A  f j = dim_col A)  (¬ Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A))) = dim_col A)"
              then have "¬ k < dim_row A  f k = dim_col A"
                using f5
              proof -
                have "¬ (Suc (nn j k (λn. n < dim_row A  f n  dim_col A)) < dim_row A  f (Suc (nn j k (λn. n < dim_row A  f n  dim_col A)))  dim_col A)  ¬ (j < dim_row A  f j  dim_col A)"
                  using (¬ j < dim_row A  f j = dim_col