(* Theory BenOr_Kozen_Reif.More_Matrix *)
theory More_Matrix
  imports "Jordan_Normal_Form.Matrix"
    "Jordan_Normal_Form.DL_Rank"
    "Jordan_Normal_Form.VS_Connect"
    "Jordan_Normal_Form.Gauss_Jordan_Elimination"
begin
section "Kronecker Product"
(* Kronecker product A ⊗ B.  A has size ra × ca and B has size rb × cb.
   The result is an (ra*rb) × (ca*cb) matrix whose entry (i, j) is
     A $$ (i div rb, j div cb) * B $$ (i mod rb, j mod cb).
   The div part of an index selects the entry of A; the mod part selects the
   entry of B.  Later proofs unfold this definition together with Let_def,
   so its exact syntactic shape matters. *)
definition kronecker_product :: "'a :: ring mat ⇒ 'a mat ⇒ 'a mat" where
  "kronecker_product A B =
  (let ra = dim_row A; ca = dim_col A;
       rb = dim_row B; cb = dim_col B
  in
    mat (ra*rb) (ca*cb)
    (λ(i,j).
      A $$ (i div rb, j div cb) *
      B $$ (i mod rb, j mod cb)
  ))"
(* Arithmetic helper: a composite index b*d + c with d < a and c < b lies
   below a*b.  It bounds the composite row/column indices of Kronecker
   products and the flattened index of double sums.  The fact name `arith` is
   unrelated to the `arith` proof method (separate name spaces). *)
lemma arith:
  assumes "d < a"
  assumes "c < b"
  shows "b*d+c < a*(b::nat)"
proof -
  (* c < b keeps b*d + c below the next multiple of b ... *)
  have "b*d+c < b*(d+1)"
    by (simp add: assms(2))
  (* ... and d + 1 ≤ a gives b*(d+1) ≤ a*b. *)
  thus ?thesis
    by (metis One_nat_def Suc_leI add.right_neutral add_Suc_right assms(1) less_le_trans mult.commute mult_le_cancel2)
qed
(* Dimensions of a Kronecker product.  These are simp rules, so later proofs
   get the sizes of kronecker_product A B automatically. *)
lemma dim_kronecker[simp]:
  "dim_row (kronecker_product A B) = dim_row A * dim_row B"
  "dim_col (kronecker_product A B) = dim_col A * dim_col B"
  unfolding kronecker_product_def Let_def by auto
(* Entry-wise characterisation of the Kronecker product.  Write the row index
   as dim_row B * r + v and the column index as dim_col B * s + w, with (r,s)
   inside A and (v,w) inside B.  The Kronecker entry is then
   A $$ (r,s) * B $$ (v,w). *)
lemma kronecker_inverse_index:
  assumes "r < dim_row A" "s < dim_col A"
  assumes "v < dim_row B" "w < dim_col B"
  shows "kronecker_product A B $$ (dim_row B*r+v, dim_col B*s+w) = A $$ (r,s) * B $$ (v,w)"
proof -
  (* Both composite indices are in range, so the entry can be read off the
     `mat` constructor in the definition. *)
  from arith[OF assms(1) assms(3)]
  have "dim_row B*r+v < dim_row A * dim_row B" .
  moreover from arith[OF assms(2) assms(4)]
  have "dim_col B * s + w < dim_col A * dim_col B" .
  ultimately show ?thesis
    unfolding kronecker_product_def Let_def
    using assms by auto
qed
(* Left distributivity: A ⊗ (B + C) = A ⊗ B + A ⊗ C for B, C of equal
   dimensions.  The proof reduces to entry-wise equality (mat_eq_iff); the
   remaining entry goal is closed by metis using distrib_left. *)
lemma kronecker_distr_left:
  assumes "dim_row B = dim_row C" "dim_col B = dim_col C"
  shows "kronecker_product A (B+C) = kronecker_product A B + kronecker_product A C"
  unfolding kronecker_product_def Let_def
  using assms apply (auto simp add: mat_eq_iff) 
  by (metis (no_types, lifting) distrib_left index_add_mat(1) mod_less_divisor mult_eq_0_iff neq0_conv not_less_zero)
(* Right distributivity: (B + C) ⊗ A = B ⊗ A + C ⊗ A for B, C of equal
   dimensions, proved entry-wise with distrib_right. *)
lemma kronecker_distr_right:
  assumes "dim_row B = dim_row C" "dim_col B = dim_col C"
  shows "kronecker_product (B+C) A = kronecker_product B A + kronecker_product C A"
  unfolding kronecker_product_def Let_def
  using assms by (auto simp add: mat_eq_iff less_mult_imp_div_less distrib_right)
(* Simp rule: when both dimensions are positive, indexing `mat nr nc f` at
   (i mod nr, j mod nc) is always in range and yields f at that position.
   (`&` is the ASCII syntax for ∧.) *)
lemma index_mat_mod[simp]: "nr > 0 & nc > 0 ⟹ mat nr nc f $$ (i mod nr,j mod nc) = f (i mod nr,j mod nc)"
  by auto
(* Associativity of the Kronecker product:
   A ⊗ (B ⊗ C) = (A ⊗ B) ⊗ C. *)
lemma kronecker_assoc:
  shows "kronecker_product A (kronecker_product B C) = kronecker_product (kronecker_product A B) C"
  unfolding kronecker_product_def Let_def
  (* Split on whether B ⊗ C has positive dimensions; this lets the nested
     div/mod index arithmetic be related (div_mult2_eq, mod_mult2_eq). *)
  apply (case_tac "dim_row B * dim_row C > 0 & dim_col B * dim_col C > 0")
   apply (auto simp add: mat_eq_iff less_mult_imp_div_less)
  by (smt (verit, best) div_less_iff_less_mult div_mult2_eq kronecker_inverse_index linordered_semiring_strict_class.mult_pos_pos mod_less_divisor mod_mult2_eq mult.assoc mult.commute mult_div_mod_eq)
(* Flattening a double sum: summing f over [0,x) × [0,y) is the same as
   summing over the single range [0, x*y).  The flat index ia corresponds to
   the pair (ia div y, ia mod y). *)
lemma sum_sum_mod_div:
  "(∑ia = 0::nat..<x. ∑ja = 0..<y. f ia ja) =
   (∑ia = 0..<x*y. f (ia div y) (ia mod y))"
proof -
  (* The index map ia ↦ (ia div y, ia mod y) is injective on [0, x*y) ... *)
  have 1: "inj_on (λia. (ia div y, ia mod y)) {0..<x * y}"
    by (smt (verit, best) Pair_inject div_mod_decomp inj_onI)
  (* ... every pair (a, b) is hit, namely by a * y + b ... *)
  have 21: "{0..<x} × {0..<y} ⊆ (λia. (ia div y, ia mod y)) ` {0..<x * y}"
  proof clarsimp
    fix a b
    assume *:"a < x" "b < y"
    have "a * y +  b ∈ {0..<x*y}"
      by (metis arith * atLeastLessThan_iff le0 mult.commute)
    thus "(a, b) ∈ (λia. (ia div y, ia mod y)) ` {0..<x * y}"
      using * by (auto simp add: image_iff)
        (metis ‹a * y + b ∈ {0..<x * y}› add.commute add.right_neutral div_less div_mult_self1 less_zeroE mod_eq_self_iff_div_eq_0 mod_mult_self1)
  qed
  (* ... and the image stays inside the product of ranges. *)
  have 22:"(λia. (ia div y, ia mod y)) ` {0..<x * y} ⊆ {0..<x} × {0..<y}"
    using less_mult_imp_div_less apply auto
    by (metis mod_less_divisor mult.commute neq0_conv not_less_zero)
  have 2: "{0..<x} × {0..<y} = (λia. (ia div y, ia mod y)) ` {0..<x * y}"
    using 21 22 by auto
  (* Rewrite the nested sum as a single sum over the product set, then
     reindex it along the bijection established above. *)
  have *: "(∑ia = 0::nat..<x. ∑ja = 0..<y. f ia ja) =
        (∑(x, y)∈{0..<x} × {0..<y}. f x y)"
    by (auto simp add: sum.cartesian_product)
  show ?thesis unfolding *
    apply (intro sum.reindex_cong[of "λia. (ia div y, ia mod y)"])
    using 1 2 by auto
qed
(* Mixed-product property of the Kronecker product over a commutative ring:
     (A ⊗ B) * (C ⊗ D) = (A * C) ⊗ (B * D),
   provided the products A * C and B * D are dimensionally compatible. *)
lemma kronecker_of_mult:
  assumes "dim_col (A :: 'a :: comm_ring mat) = dim_row C"
  assumes "dim_col B = dim_row D"
  shows "kronecker_product A B * kronecker_product C D = kronecker_product (A * C) (B * D)"
  unfolding kronecker_product_def Let_def mat_eq_iff
proof clarsimp
  (* Compare the two sides at an arbitrary in-range entry (i, j). *)
  fix i j
  assume ij: "i < dim_row A * dim_row B" "j < dim_col C * dim_col D"
  (* 1, 2: the entries of A * C and B * D as scalar products of a row and a
     column. *)
  have 1: "(A * C) $$ (i div dim_row B, j div dim_col D) =
    row A (i div dim_row B) ∙ col C (j div dim_col D)"
    using ij less_mult_imp_div_less by (auto intro!: index_mult_mat)
  have 2: "(B * D) $$ (i mod dim_row B, j mod dim_col D) =
    row B (i mod dim_row B) ∙ col D (j mod dim_col D)"
    using ij apply (auto intro!: index_mult_mat)
    using gr_implies_not0 apply fastforce
    using gr_implies_not0 by fastforce
  (* 3: a single summand of the left-hand side, rewritten from matrix entries
     into row/column vector entries and rearranged by commutativity. *)
  have 3: "⋀x. x < dim_row C * dim_row D ⟹
         A $$ (i div dim_row B, x div dim_row D) *
         B $$ (i mod dim_row B, x mod dim_row D) *
         (C $$ (x div dim_row D, j div dim_col D) *
          D $$ (x mod dim_row D, j mod dim_col D)) =
         row A (i div dim_row B) $ (x div dim_row D) *
         col C (j div dim_col D) $ (x div dim_row D) *
         (row B (i mod dim_row B) $ (x mod dim_row D) *
          col D (j mod dim_col D) $ (x mod dim_row D))"
  proof -
    fix x
    assume *:"x < dim_row C * dim_row D"
    have 1: "row A (i div dim_row B) $ (x div dim_row D) = A $$ (i div dim_row B, x div dim_row D)"
      by (simp add: * assms(1) less_mult_imp_div_less row_def)
    have 2: "row B (i mod dim_row B) $ (x mod dim_row D) = B $$ (i mod dim_row B, x mod dim_row D)"
      by (metis "*" assms(2) ij(1) index_row(1) mod_less_divisor nat_0_less_mult_iff neq0_conv not_less_zero)
    have 3: "col C (j div dim_col D) $ (x div dim_row D) = C $$ (x div dim_row D, j div dim_col D)"
      by (simp add: "*" ij(2) less_mult_imp_div_less)
    have 4: "col D (j mod dim_col D) $ (x mod dim_row D) = D $$ (x mod dim_row D, j mod dim_col D)"
      by (metis "*" bot_nat_0.not_eq_extremum ij(2) index_col mod_less_divisor mult_zero_right not_less_zero)
    show "A $$ (i div dim_row B, x div dim_row D) *
         B $$ (i mod dim_row B, x mod dim_row D) *
         (C $$ (x div dim_row D, j div dim_col D) *
          D $$ (x mod dim_row D, j mod dim_col D)) =
         row A (i div dim_row B) $ (x div dim_row D) *
         col C (j div dim_col D) $ (x div dim_row D) *
         (row B (i mod dim_row B) $ (x mod dim_row D) *
          col D (j mod dim_col D) $ (x mod dim_row D))" unfolding 1 2 3 4
      by (simp add: mult.assoc mult.left_commute)
  qed
  (* *: the product of the two scalar products equals one flat sum over
     [0, dim_row C * dim_row D), by sum_product and sum_sum_mod_div. *)
  have *: "(A * C) $$ (i div dim_row B, j div dim_col D) *
        (B * D) $$ (i mod dim_row B, j mod dim_col D) =
    (∑ia = 0..<dim_row C * dim_row D.
               A $$ (i div dim_row B, ia div dim_row D) *
               B $$ (i mod dim_row B, ia mod dim_row D) *
               (C $$ (ia div dim_row D, j div dim_col D) *
                D $$ (ia mod dim_row D, j mod dim_col D)))"
    unfolding 1 2 scalar_prod_def sum_product sum_sum_mod_div
    apply (auto simp add: sum_product sum_sum_mod_div intro!: sum.cong)
    using 3 by presburger
  (* The left-hand entry is the same flat sum, using the dimension
     assumptions to align the summation ranges. *)
  show "vec (dim_col A * dim_col B)
          (λj. A $$ (i div dim_row B, j div dim_col B) *
               B $$ (i mod dim_row B, j mod dim_col B)) ∙
       vec (dim_row C * dim_row D)
          (λi. C $$ (i div dim_row D, j div dim_col D) *
               D $$ (i mod dim_row D, j mod dim_col D)) =
        (A * C) $$ (i div dim_row B, j div dim_col D) *
        (B * D) $$ (i mod dim_row B, j mod dim_col D)"
    unfolding * scalar_prod_def
    by (auto simp add: assms sum_product sum_sum_mod_div intro!: sum.cong)
qed
(* If A is square and A, B invert each other (inverts_mat in both
   directions), then B has the same dimensions as A. *)
lemma inverts_mat_length:
  assumes "square_mat A" "inverts_mat A B" "inverts_mat B A"
  shows "dim_row B = dim_row A" "dim_col B = dim_col A"
   (* first goal from inverts_mat B A, second from inverts_mat A B *)
   apply (metis assms(1) assms(3) index_mult_mat(3) index_one_mat(3) inverts_mat_def square_mat.simps)
  by (metis assms(1) assms(2) index_mult_mat(3) index_one_mat(3) inverts_mat_def square_mat.simps)
(* Companion to less_mult_imp_div_less: the bound m < n * i forces i > 0,
   hence m mod i < i. *)
lemma less_mult_imp_mod_less:
  "m mod i < i" if "m < n * i" for m n i :: nat
  using gr_implies_not_zero that by fastforce
(* The Kronecker product of identity matrices is an identity matrix:
   1⇩m x ⊗ 1⇩m y = 1⇩m (x*y). *)
lemma kronecker_one:
  shows "kronecker_product ((1⇩m x)::'a :: ring_1 mat) (1⇩m y) = 1⇩m (x*y)"
  unfolding kronecker_product_def Let_def
  apply  (auto simp add:mat_eq_iff less_mult_imp_div_less less_mult_imp_mod_less)
  by (metis div_mult_mod_eq)
(* The Kronecker product of invertible matrices over a commutative ring is
   invertible.  The witness inverse is Ai ⊗ Bi; both inversion equations
   follow from kronecker_of_mult and kronecker_one. *)
lemma kronecker_invertible:
  assumes "invertible_mat (A :: 'a :: comm_ring_1 mat)" "invertible_mat B"
  shows "invertible_mat (kronecker_product A B)"
proof -
  obtain Ai where Ai: "inverts_mat A Ai" "inverts_mat Ai A" using assms invertible_mat_def by blast
  obtain Bi where Bi: "inverts_mat B Bi" "inverts_mat Bi B" using assms invertible_mat_def by blast
  have "square_mat (kronecker_product A B)"
    by (metis (no_types, lifting) assms(1) assms(2) dim_col_mat(1) dim_row_mat(1) invertible_mat_def kronecker_product_def square_mat.simps)
  moreover have "inverts_mat (kronecker_product A B) (kronecker_product Ai Bi)"
    using Ai Bi unfolding inverts_mat_def
    by (metis (no_types, lifting) dim_kronecker(1) index_mult_mat(3) index_one_mat(3) kronecker_of_mult kronecker_one)
  moreover have "inverts_mat (kronecker_product Ai Bi) (kronecker_product A B)"
    using Ai Bi unfolding inverts_mat_def
    by (metis (no_types, lifting) dim_kronecker(1) index_mult_mat(3) index_one_mat(3) kronecker_of_mult kronecker_one)
  ultimately show ?thesis unfolding invertible_mat_def by blast
qed
section "More DL Rank"
instantiation mat :: (conjugate) conjugate
begin
(* Entry-wise conjugation of a matrix: the dimensions are unchanged and every
   entry is replaced by its conjugate. *)
definition conjugate_mat :: "'a :: conjugate mat ⇒ 'a mat"
  where "conjugate m = mat (dim_row m) (dim_col m) (λ(i,j). conjugate (m $$ (i,j)))"
lemma dim_row_conjugate[simp]: "dim_row (conjugate m) = dim_row m"
  unfolding conjugate_mat_def by auto
lemma dim_col_conjugate[simp]: "dim_col (conjugate m) = dim_col m"
  unfolding conjugate_mat_def by auto
(* NOTE: despite the "vec" in its name, this lemma is about carrier_mat.  The
   name is kept unchanged because other proofs may refer to it. *)
lemma carrier_vec_conjugate[simp]: "m ∈ carrier_mat nr nc ⟹ conjugate m ∈ carrier_mat nr nc"
  by (auto)
lemma mat_index_conjugate[simp]:
  shows "i < dim_row m ⟹ j < dim_col m ⟹ conjugate m  $$ (i,j) = conjugate (m $$ (i,j))"
  unfolding conjugate_mat_def by auto
(* Rows, columns and their lists commute with conjugation. *)
lemma row_conjugate[simp]: "i < dim_row m ⟹ row (conjugate m) i = conjugate (row m i)"
  by (auto)
lemma col_conjugate[simp]: "i < dim_col m ⟹ col (conjugate m) i = conjugate (col m i)"
  by (auto)
lemma rows_conjugate: "rows (conjugate m) = map conjugate (rows m)"
  by (simp add: list_eq_iff_nth_eq)
lemma cols_conjugate: "cols (conjugate m) = map conjugate (cols m)"
  by (simp add: list_eq_iff_nth_eq)
instance
proof
  fix a b :: "'a mat"
  (* Conjugation is an involution ... *)
  show "conjugate (conjugate a) = a"
    unfolding mat_eq_iff by auto
  (* ... and injective, both checked entry-wise. *)
  show "conjugate a = conjugate b ⟷ a = b"
    by (metis dim_col_conjugate dim_row_conjugate mat_index_conjugate conjugate_cancel_iff mat_eq_iff)
qed
end
(* Conjugate transpose A⇧H, defined as conjugate (A⇧T) and written with the
   postfix notation ⇧H. *)
abbreviation conjugate_transpose :: "'a::conjugate mat  ⇒ 'a mat"
  where "conjugate_transpose A ≡ conjugate (A⇧T)"
notation conjugate_transpose (‹(_⇧H)› [1000])
(* Conjugation and transposition commute: (conjugate A)⇧T = A⇧H. *)
lemma transpose_conjugate:
  shows "(conjugate A)⇧T = A⇧H"
  unfolding conjugate_mat_def
  by auto
(* The zero vector of dimension dim_row A lies in the span of the columns of
   A.  The module record is spelled out field by field; the proof relates it
   to module_vec via module_vec_def, so it presumably is module_vec unfolded
   (NOTE(review): confirm). *)
lemma vec_module_col_helper:
  fixes A:: "('a :: field) mat"
  shows "(0⇩v (dim_row A) ∈ LinearCombinations.module.span class_ring ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈ (set (cols A)))"
proof -
  have "∀v. (0::'a) ⋅⇩v v + v = v"
    by auto
  (* The fact above is chained into metis by `then`. *)
  then show "0⇩v (dim_row A) ∈ LinearCombinations.module.span class_ring ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈ (set (cols A))"
    by (metis cols_dim module_vec_def right_zero_vec smult_carrier_vec vec_space.prod_in_span zero_carrier_vec)
qed
(* The span of the columns of A, in the same explicit module record as in
   vec_module_col_helper, is closed under scalar multiplication.
   NOTE: the distributivity premise (a + b) ⋅⇩v v = ... is not used by the
   proof below; the proof assumes only the span membership. *)
lemma vec_module_col_helper2:
  fixes A:: "('a :: field) mat"
  shows "⋀a x. x ∈ LinearCombinations.module.span class_ring
                ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined,
                   zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈
                (set (cols A)) ⟹
           (⋀a b v. (a + b) ⋅⇩v v = a ⋅⇩v v + b ⋅⇩v v) ⟹
           a ⋅⇩v x
           ∈ LinearCombinations.module.span class_ring
               ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined,
                  zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈
               (set (cols A))"
proof -
  fix a :: 'a and x :: "'a vec"
  assume "x ∈ LinearCombinations.module.span class_ring ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈ (set (cols A))"
  then show "a ⋅⇩v x ∈ LinearCombinations.module.span class_ring ⦇carrier = carrier_vec (dim_row A), mult = undefined, one = undefined, zero = 0⇩v (dim_row A), add = (+), smult = (⋅⇩v)⦈ (set (cols A))"
    by (metis (full_types) cols_dim idom_vec.smult_in_span module_vec_def)
qed
(* The span of the columns of A, used as the carrier of module_vec, is itself
   a module over class_ring. *)
lemma vec_module_col: "module (class_ring :: 'a :: field ring)
  (module_vec TYPE('a) 
    (dim_row A)
      ⦇carrier :=
         LinearCombinations.module.span
          class_ring (module_vec TYPE('a) (dim_row A)) (set (cols A))⦈)"
proof -
  (* First the additive structure: closure under +, associativity, zero,
     commutativity and additive inverses, all inside the span. *)
  interpret abelian_group "module_vec TYPE('a) (dim_row A)
      ⦇carrier :=
         LinearCombinations.module.span
          class_ring (module_vec TYPE('a) (dim_row A)) (set (cols A))⦈"
    apply (unfold_locales)
          apply (auto simp add:module_vec_def)
          apply (metis cols_dim module_vec_def partial_object.select_convs(1) ring.simps(2) vec_vs vectorspace.span_add1)
         apply (metis assoc_add_vec cols_dim module_vec_def vec_space.cV vec_vs vectorspace.span_closed)
    using vec_module_col_helper[of A] apply (auto)    
       apply (metis cols_dim left_zero_vec module_vec_def partial_object.select_convs(1) vec_vs vectorspace.span_closed)
      apply (metis cols_dim module_vec_def partial_object.select_convs(1) right_zero_vec vec_vs vectorspace.span_closed)
     apply (metis cols_dim comm_add_vec module_vec_def vec_space.cV vec_vs vectorspace.span_closed)
    unfolding Units_def apply auto
    by (metis (no_types, opaque_lifting) cols_dim comm_add_vec module_vec_def partial_object.select_convs(1) uminus_l_inv_vec vec_space.vec_neg vec_vs vectorspace.span_closed vectorspace.span_neg)
  (* Then the scalar-multiplication axioms; closure under scalar
     multiplication comes from vec_module_col_helper2. *)
  show ?thesis
    apply (unfold_locales)
    unfolding class_ring_simps apply auto
    unfolding module_vec_simps using add_smult_distrib_vec apply auto
     apply (auto simp add:module_vec_def)
    using vec_module_col_helper2
     apply blast
    by (smt (verit) cols_dim module_vec_def smult_add_distrib_vec vec_space.cV vec_vs vectorspace.span_closed)
qed
(* The column span of A is also a vector space: it is a module
   (vec_module_col) over class_ring, which is a field (class_field). *)
lemma vec_vs_col: "vectorspace (class_ring :: 'a :: field ring)
  (module_vec TYPE('a) (dim_row A)
      ⦇carrier :=
         LinearCombinations.module.span
          class_ring
          (module_vec TYPE('a)
            (dim_row A))
          (set (cols A))⦈)"
  unfolding vectorspace_def
  using vec_module_col class_field 
  by (auto simp: class_field_def)
(* The columns of A * B are A applied to the columns of B. *)
lemma cols_mat_mul_map:
  shows "cols (A * B) = map ((*⇩v) A) (cols B)"
  unfolding list_eq_iff_nth_eq
  by auto
(* Set version of cols_mat_mul_map: the column set of A * B is the image of
   the column set of B under multiplication by A. *)
lemma cols_mat_mul:
  shows "set (cols (A * B)) = (*⇩v) A ` set (cols B)"
  by (simp add: cols_mat_mul_map)
(* Any subset S of the elements of a list is the set of some distinct list.
   The hypothesis S ⊆ set ls is used only to know that S is finite. *)
lemma set_obtain_sublist:
  assumes "S ⊆ set ls"
  obtains ss where "distinct ss" "S = set ss"
  using assms finite_distinct_list infinite_super by blast
(* Multiplying A ∈ carrier_mat nr n by the matrix whose columns are cs
   (each of dimension n) gives the matrix whose columns are A times each
   element of cs. *)
lemma mul_mat_of_cols:
  assumes "A ∈ carrier_mat nr n"
  assumes "⋀j. j < length cs ⟹ cs ! j ∈ carrier_vec n"
  shows "A * (mat_of_cols n cs) = mat_of_cols nr (map ((*⇩v) A) cs)"
  unfolding mat_eq_iff
  using assms apply auto
  apply (subst mat_of_cols_index)
  by auto
(* Rearrangement x * (y * z) = y * x * z in a commutative ring; used in the
   proof of cscalar_prod_conjugate_transpose. *)
lemma helper:
  fixes x y z ::"'a :: {conjugatable_ring, comm_ring}"
  shows "x * (y * z) = y * x * z"
  by (simp add: mult.assoc mult.left_commute)
(* Adjoint identity for the conjugate scalar product ∙c: for
   A ∈ carrier_mat nr nc, x of dimension nr and y of dimension nc,
   x ∙c (A y) = (A⇧H x) ∙c y.
   Proof: expand both sides into double sums, distribute, swap the order of
   summation (sum.swap) and match summands using helper. *)
lemma cscalar_prod_conjugate_transpose:
  fixes x y ::"'a :: {conjugatable_ring, comm_ring} vec"
  assumes "A ∈ carrier_mat nr nc"
  assumes "x ∈ carrier_vec nr"
  assumes "y ∈ carrier_vec nc"
  shows "x ∙c (A *⇩v y) = (A⇧H *⇩v x) ∙c y"
  unfolding mult_mat_vec_def scalar_prod_def
  using assms apply (auto simp add: sum_distrib_left sum_distrib_right sum_conjugate conjugate_dist_mul)
  apply (subst sum.swap)
  by (meson helper mult.assoc mult.left_commute sum.cong)
(* If A (A⇧H v) = 0 then A⇧H v = 0, over an ordered conjugatable ring
   without zero divisors.  Proof: (A⇧H v) ∙c (A⇧H v) = (A A⇧H v) ∙c v = 0,
   and a vector whose conjugate square is 0 is the zero vector
   (conjugate_square_eq_0_vec). *)
lemma mat_mul_conjugate_transpose_vec_eq_0:                        
  fixes v ::"'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors,comm_ring} vec"
  assumes "A ∈ carrier_mat nr nc"
  assumes "v ∈ carrier_vec nr"
  assumes "A *⇩v (A⇧H *⇩v v) = 0⇩v nr"
  shows "A⇧H *⇩v v = 0⇩v nc"
proof -
  have "(A⇧H *⇩v v) ∙c (A⇧H *⇩v v) = (A *⇩v (A⇧H *⇩v v)) ∙c v"
    by (metis (mono_tags, lifting) Matrix.carrier_vec_conjugate assms(1) assms(2) assms(3) carrier_matD(2) conjugate_zero_vec cscalar_prod_conjugate_transpose dim_row_conjugate index_transpose_mat(2) mult_mat_vec_def scalar_prod_left_zero scalar_prod_right_zero vec_carrier)
  also have "... = 0"
    by (simp add: assms(2) assms(3))
      
  (* `ultimately` collects both equations of the chain and auto combines
     them; `finally` would be the more conventional way to close it. *)
  ultimately have "(A⇧H *⇩v v) ∙c (A⇧H *⇩v v) = 0" by auto
  thus ?thesis
    apply (subst conjugate_square_eq_0_vec[symmetric])
    using assms(1) carrier_dim_vec apply fastforce
    by auto
qed
(* Row i of mat_of_cols nr ls collects the i-th component of every element
   of ls. *)
lemma row_mat_of_cols:
  assumes "i < nr"
  shows "row (mat_of_cols nr ls) i = vec (length ls) (λj. (ls ! j) $i)"
  by (simp add: assms mat_of_cols_index vec_eq_iff)
(* Prepending a column a and a coefficient m: the matrix-vector product
   splits into m ⋅⇩v a plus the product with the remaining columns. *)
lemma mat_of_cols_cons_mat_vec:
  fixes v ::"'a::comm_ring vec"
  assumes "v ∈ carrier_vec (length ls)"
  assumes "dim_vec a = nr"
  shows
    "mat_of_cols nr (a # ls) *⇩v (vCons m v) =
   m ⋅⇩v a + mat_of_cols nr ls *⇩v v"
  unfolding mult_mat_vec_def vec_eq_iff
  using assms by
    (auto simp add: row_mat_of_cols vec_Suc o_def mult.commute)
(* Scaling any vector by 0 gives the zero vector of the same dimension. *)
lemma smult_vec_zero:
  fixes v ::"'a::ring vec"
  shows "0 ⋅⇩v v = 0⇩v (dim_vec v)"
  unfolding smult_vec_def vec_eq_iff
  by (auto)
(* Prepending extra columns ls (all of dimension nr) does not change a
   matrix-vector product if the coefficient vector is padded with zeros in
   front.  Proof by induction on ls using mat_of_cols_cons_mat_vec. *)
lemma helper2:
  fixes A ::"'a::comm_ring mat"
  fixes v ::"'a vec"
  assumes "v ∈ carrier_vec (length ss)"
  assumes "⋀x. x ∈ set ls ⟹ dim_vec x = nr"
  shows
    "mat_of_cols nr ss *⇩v v =
   mat_of_cols nr (ls @ ss) *⇩v (0⇩v (length ls) @⇩v v)"
  using assms(2)
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a ls)
  then show ?case apply (auto simp add:zero_vec_Suc)
    apply (subst mat_of_cols_cons_mat_vec)
    by (auto simp add:assms smult_vec_zero)
qed
(* Permuting the columns and the coefficient list by the same permutation f
   of the index set leaves the matrix-vector product unchanged.  Each
   component is a sum over the indices, which is invariant under
   reindexing by f (sum.permute). *)
lemma mat_of_cols_mult_mat_vec_permute_list:
  fixes v ::"'a::comm_ring list"
  assumes "f permutes {..<length ss}"
  assumes "length ss = length v"
  shows
    "mat_of_cols nr (permute_list f ss) *⇩v vec_of_list (permute_list f v) =
     mat_of_cols nr ss *⇩v vec_of_list v"
  unfolding mat_of_cols_def mult_mat_vec_def vec_eq_iff scalar_prod_def
proof clarsimp
  fix i
  assume "i < nr"
  from sum.permute[OF assms(1)]
  have "(∑ia<length ss. ss ! f ia $ i * v ! f ia) =
  sum ((λia. ss ! f ia $ i * v ! f ia) ∘ f) {..<length ss}" .
  also have "... = (∑ia = 0..<length v. ss ! f ia $ i * v ! f ia)"
    using assms(2) calculation lessThan_atLeast0 by auto
  (* As elsewhere in this file, `ultimately` closes the chain; the final
     reindexing step is done by metis with sum.permute. *)
  ultimately have *: "(∑ia = 0..<length v.
             ss ! f ia $ i * v ! f ia) =
         (∑ia = 0..<length v.
             ss ! ia $ i * v ! ia)"
    by (metis (mono_tags, lifting) ‹⋀g. sum g {..<length ss} = sum (g ∘ f) {..<length ss}› assms(2) comp_apply lessThan_atLeast0 sum.cong)
  show "(∑ia = 0..<length v.
         vec (length ss) (λj. permute_list f ss ! j $ i) $ ia *
         vec_of_list (permute_list f v) $ ia) =
         (∑ia = 0..<length v. vec (length ss) (λj. ss ! j $ i) $ ia * vec_of_list v $ ia)"
    using assms * by (auto simp add: permute_list_nth vec_of_list_index)
qed
(* Given a distinct list ss of valid indices into ls, some permutation f of
   the index set reorders ls as the elements NOT indexed by ss (in order)
   followed by the elements indexed by ss (in the order of ss).  The proof
   shows the two index lists are equal as multisets and then applies
   mset_eq_permutation. *)
lemma subindex_permutation:
  assumes "distinct ss" "set ss ⊆ {..<length ls}"
  obtains f where "f permutes {..<length ls}"
    "permute_list f ls = map ((!) ls) (filter (λi. i ∉ set ss) [0..<length ls]) @ map ((!) ls) ss"
proof -
  (* NOTE(review): this goal is a set equation, so unfolding multiset_eq_iff
     appears not to contribute here; confirm before removing. *)
  have "set [0..<length ls] = set (filter (λi. i ∉ set ss) [0..<length ls] @ ss)"
    using assms unfolding multiset_eq_iff by auto
  (* Both lists are distinct, so equal sets imply equal multisets. *)
  then have "mset [0..<length ls] = mset (filter (λi. i ∉ set ss) [0..<length ls] @ ss)"
    apply (subst set_eq_iff_mset_eq_distinct[symmetric])
    using assms by auto  
  then have "mset ls = mset (map ((!) ls)
           (filter (λi. i ∉ set ss)
             [0..<length ls]) @ map ((!) ls) ss)"
    by (metis map_append map_nth mset_map)
  thus ?thesis
    by (metis mset_eq_permutation that)
qed
(* Inverse direction of subindex_permutation: ls itself is a permutation of
   "elements not indexed by ss" @ "elements indexed by ss". *)
lemma subindex_permutation2:
  assumes "distinct ss" "set ss ⊆ {..<length ls}"
  obtains f where "f permutes {..<length ls}"
    "ls = permute_list f (map ((!) ls) (filter (λi. i ∉ set ss) [0..<length ls]) @ map ((!) ls) ss)"
  using subindex_permutation
  by (metis assms(1) assms(2) length_permute_list mset_eq_permutation mset_permute_list)
(* A distinct list ss whose elements all occur in ls can be written as ls
   indexed by a distinct list of valid indices.  The indices are chosen with
   Hilbert choice: for each element, some position j with ls ! j equal to
   that element. *)
lemma distinct_list_subset_nths:
  assumes "distinct ss" "set ss ⊆ set ls"
  obtains ids where "distinct ids" "set ids ⊆ {..<length ls}" "ss = map ((!) ls) ids"
proof -
  let ?ids = "map (λi. @j. j < length ls ∧ ls!j = i ) ss"
  (* distinct elements get distinct chosen positions *)
  have 1: "distinct ?ids" unfolding distinct_map
    using assms apply (auto simp add: inj_on_def)
    by (smt (verit) in_set_conv_nth someI subset_eq)
  have 2: "set ?ids ⊆ {..<length ls}"
    using assms apply (auto)
    by (metis (mono_tags, lifting) in_mono in_set_conv_nth tfl_some)
  have 3: "ss = map ((!) ls) ?ids"
    using assms apply (auto simp add: list_eq_iff_nth_eq)
    by (smt (verit, best) in_set_conv_nth someI subsetD)
  show "(⋀ids. distinct ids ⟹
            set ids ⊆ {..<length ls} ⟹
            ss = map ((!) ls) ids ⟹ thesis) ⟹
    thesis" using 1 2 3 by blast
qed
(* A combination of a distinct subset ss of A's columns is a combination of
   all of A's columns: mat_of_cols nr ss v = A c for some c of dimension nc.
   Construction: pad v with zeros for the columns not in ss (helper2), then
   permute columns and coefficients back into A's column order
   (subindex_permutation2, mat_of_cols_mult_mat_vec_permute_list). *)
lemma helper3: 
  fixes A ::"'a::comm_ring mat"
  assumes A: "A ∈ carrier_mat nr nc"
  assumes ss:"distinct ss" "set ss ⊆ set (cols A)"
  assumes "v ∈ carrier_vec (length ss)"
  obtains c where "mat_of_cols nr ss *⇩v v = A *⇩v c" "dim_vec c = nc"
proof -
  from distinct_list_subset_nths[OF ss]
  obtain ids where ids: "distinct ids" "set ids ⊆ {..<length (cols A)}"
    and ss: "ss = map ((!) (cols A)) ids" by blast
  (* ?ls: the columns of A that are not among ss *)
  let ?ls = " map ((!) (cols A)) (filter (λi. i ∉ set ids) [0..<length (cols A)])"
  from subindex_permutation2[OF ids] obtain f where
    f: "f permutes {..<length (cols A)}"
    "cols A = permute_list f (?ls @ ss)" using ss by blast
  have *: "⋀x. x ∈ set ?ls ⟹ dim_vec x = nr"
    using A by auto
  (* ?cs1: v padded with zeros for the columns in ?ls *)
  let ?cs1 = "(list_of_vec (0⇩v (length ?ls) @⇩v v))"
  from helper2[OF assms(4) ]
  have "mat_of_cols nr ss *⇩v v = mat_of_cols nr (?ls @ ss) *⇩v vec_of_list (?cs1)"
    using *
    by (metis vec_list)
  also have "... = mat_of_cols nr (permute_list f (?ls @ ss)) *⇩v vec_of_list (permute_list f ?cs1)"
    apply (auto intro!: mat_of_cols_mult_mat_vec_permute_list[symmetric])
     apply (metis cols_length f(1) f(2) length_append length_map length_permute_list)
    using assms(4) by auto
  also have "... =  A *⇩v vec_of_list (permute_list f ?cs1)" using f(2) assms by auto
  (* The witness is c = vec_of_list (permute_list f ?cs1). *)
  ultimately show
    "(⋀c. mat_of_cols nr ss *⇩v v = A *⇩v c ⟹ dim_vec c = nc ⟹ thesis) ⟹ thesis"
    by (metis A assms(4) carrier_matD(2) carrier_vecD cols_length dim_vec_of_list f(2) index_append_vec(2) index_zero_vec(2) length_append length_list_of_vec length_permute_list)
qed
(* Generalisation of mat_mul_conjugate_transpose_vec_eq_0 to combinations of
   a distinct subset ss of the columns of A⇧H.  If A maps such a combination
   to 0, the combination itself is 0.  helper3 rewrites it as A⇧H c. *)
lemma mat_mul_conjugate_transpose_sub_vec_eq_0:                        
  fixes A ::"'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors,comm_ring} mat"
  assumes "A ∈ carrier_mat nr nc"
  assumes "distinct ss" "set ss ⊆ set (cols (A⇧H))"
  assumes "v ∈ carrier_vec (length ss)"
  assumes "A *⇩v (mat_of_cols nc ss *⇩v v) = 0⇩v nr"
  shows "(mat_of_cols nc ss *⇩v v) = 0⇩v nc"
proof -
  have "A⇧H ∈ carrier_mat nc nr" using assms(1) by auto
  from  helper3[OF this assms(2-4)]
  obtain c where c: "mat_of_cols nc ss *⇩v v = A⇧H *⇩v c" "dim_vec c = nr" by blast
  have 1: "c ∈ carrier_vec nr"
    using c carrier_vec_dim_vec by blast
  have 2: "A *⇩v (A⇧H *⇩v c) = 0⇩v nr" using c assms(5) by auto
  from mat_mul_conjugate_transpose_vec_eq_0[OF assms(1) 1 2]
  have "A⇧H *⇩v c = 0⇩v nc" .
  thus ?thesis unfolding c(1)[symmetric] .
qed
(* A unit of the matrix ring ring_mat TYPE('a) n b is an invertible matrix. *)
lemma Units_invertible:
  fixes A:: "'a::semiring_1 mat"
  assumes "A ∈ Units (ring_mat TYPE('a) n b)"
  shows "invertible_mat A"
  using assms unfolding Units_def invertible_mat_def
  apply (auto simp add: ring_mat_def)
  using inverts_mat_def by blast
(* Converse of Units_invertible: an invertible matrix is a unit of the
   matrix ring of its own size (the inverse has the same dimensions by
   inverts_mat_length). *)
lemma invertible_Units:
  fixes A:: "'a::semiring_1 mat"
  assumes "invertible_mat A"
  shows "A ∈ Units (ring_mat TYPE('a) (dim_row A) b)"
  using assms unfolding Units_def invertible_mat_def
  apply (auto simp add: ring_mat_def)
  by (metis assms carrier_mat_triv invertible_mat_def inverts_mat_def inverts_mat_length(1) inverts_mat_length(2))
(* Over a field, a square matrix is invertible iff its determinant is
   non-zero.  Both directions go through the Units characterisation above. *)
lemma invertible_det:
  fixes A:: "'a::field mat"
  assumes "A ∈ carrier_mat n n"
  shows "invertible_mat A ⟷ det A ≠ 0"
  apply auto
  using invertible_Units unit_imp_det_non_zero apply fastforce
  using assms by (auto intro!: Units_invertible det_non_zero_imp_unit)
(* The following lemmas are stated inside the vec_space locale.  The locale
   parameter n is the vector dimension, cf. carrier_vec n below. *)
context vec_space begin
(* In a distinct list, the only index holding ss ! i is i itself. *)
lemma find_indices_distinct:
  assumes "distinct ss"
  assumes "i < length ss"
  shows "find_indices (ss ! i) ss = [i]"
proof -
  have "set (find_indices (ss ! i) ss) = {i}"
    using assms apply auto by (simp add: assms(1) assms(2) nth_eq_iff_index_eq)
  (* find_indices returns a distinct list, so a singleton set fixes it. *)
  thus ?thesis
    by (metis distinct.simps(2) distinct_find_indices empty_iff empty_set insert_iff list.exhaust list.simps(15)) 
qed
(* If the elements of a distinct list ss are linearly independent, every
   coefficient of a vanishing lincomb_list over ss is 0.  The proof converts
   lincomb_list to lincomb over set ss.  Because ss is distinct, each vector
   gets exactly the coefficient at its unique index (find_indices_distinct). *)
lemma lin_indpt_lin_comb_list:
  assumes "distinct ss"
  assumes "lin_indpt (set ss)"
  assumes "set ss ⊆ carrier_vec n"
  assumes "lincomb_list f ss = 0⇩v n"
  assumes "i < length ss"
  shows "f i = 0"
proof -
  from lincomb_list_as_lincomb[OF assms(3)]
  have "lincomb_list f ss = lincomb (mk_coeff ss f) (set ss)" .
  also have "... = lincomb  (λv. sum f (set (find_indices v ss))) (set ss)"
    unfolding mk_coeff_def
    apply (subst R.sumlist_map_as_finsum)
    by (auto simp add: distinct_find_indices)
  ultimately have "lincomb_list f ss = lincomb  (λv. sum f (set (find_indices v ss))) (set ss)" by auto
  then have *:"lincomb (λv. sum f (set (find_indices v ss))) (set ss) = 0⇩v n"
    using assms(4) by auto
  have "finite (set ss)" by simp
  (* Linear independence forces every lincomb coefficient to be 0. *)
  from not_lindepD[OF assms(2) this _ _ *]
  have "(λv. sum f (set (find_indices v ss))) ∈ set ss → {0}"
    by auto
  from funcset_mem[OF this]
  have "sum f (set (find_indices (nth ss i) ss)) ∈ {0}"
    using assms(5) by auto
  thus ?thesis unfolding find_indices_distinct[OF assms(1) assms(5)]
    by auto
qed
(* The column span of A * B is contained in the column span of A.  Every
   combination of columns of A * B with coefficients v equals the
   combination of columns of A with coefficients B (vec nc v). *)
lemma span_mat_mul_subset:
  assumes "A ∈ carrier_mat n d"
  assumes "B ∈ carrier_mat d nc"
  shows "span (set (cols (A * B))) ⊆ span (set (cols A))"
proof -
  have *: "⋀v. ∃ca. lincomb_list v (cols (A * B)) =
              lincomb_list ca  (cols A)"
  proof -
    fix v
    have "lincomb_list v (cols (A * B)) = (A * B) *⇩v vec nc v"
      apply (subst lincomb_list_as_mat_mult)
       apply (metis assms(1) carrier_dim_vec carrier_matD(1) cols_dim index_mult_mat(2) subset_code(1))
      by (metis assms(1) assms(2) carrier_matD(1) carrier_matD(2) cols_length index_mult_mat(2) index_mult_mat(3) mat_of_cols_cols)
    also have "... = A *⇩v (B *⇩v vec nc v)"
      using assms(1) assms(2) by auto
    also have "... = lincomb_list (λi. (B *⇩v vec nc v) $ i) (cols A)"
      apply (subst lincomb_list_as_mat_mult)
      using assms(1) carrier_dim_vec cols_dim apply blast
      by (metis assms(1) assms(2) carrier_matD(1) carrier_matD(2) cols_length dim_mult_mat_vec dim_vec eq_vecI index_vec mat_of_cols_cols)
    ultimately have "lincomb_list v (cols (A * B)) =
              lincomb_list (λi. (B *⇩v vec nc v) $ i) (cols A)" by auto
    thus "∃ca. lincomb_list v (cols (A * B)) = lincomb_list ca (cols A)" by auto
  qed
  (* Express both spans as span_list and compare them through fact *. *)
  show ?thesis
    apply (subst span_list_as_span[symmetric])
     apply (metis assms(1) carrier_matD(1) cols_dim index_mult_mat(2))
    apply (subst span_list_as_span[symmetric])
    using assms(1) cols_dim apply blast
    by (auto simp add:span_list_def *)
qed
(* rank (A * B) ≤ rank A.  The column span of A * B is a subspace of the
   column span of A (span_mat_mul_subset), and a subspace of a
   finite-dimensional space has no larger dimension (vectorspace.subspace_dim). *)
lemma rank_mat_mul_right:
  assumes "A ∈ carrier_mat n d"
  assumes "B ∈ carrier_mat d nc"
  shows "rank (A * B) ≤ rank A"
proof -
  have "subspace class_ring (local.span (set (cols (A*B))))
        (vs (local.span (set (cols A))))"
    unfolding subspace_def
    by (metis assms(1) assms(2) carrier_matD(1) cols_dim index_mult_mat(2) nested_submodules span_is_submodule vec_space.span_mat_mul_subset vec_vs_col)
  from vectorspace.subspace_dim[OF _ this]
  have "vectorspace.dim class_ring
   (vs (local.span (set (cols A)))
    ⦇carrier := local.span (set (cols (A * B)))⦈) ≤
  vectorspace.dim class_ring
      (vs (local.span (set (cols A))))"
    apply auto
    by (metis (no_types) assms(1) carrier_matD(1) fin_dim_span_cols index_mult_mat(2) mat_of_cols_carrier(1) mat_of_cols_cols vec_vs_col)
  thus ?thesis unfolding rank_def
    by auto
qed
(* Dropping zero vectors from a list of n-dimensional vectors does not
   change its sumlist.  Proof by induction on ls. *)
lemma sumlist_drop:
  assumes "⋀v. v ∈ set ls ⟹ dim_vec v = n"
  shows "sumlist ls = sumlist (filter (λv. v ≠ 0⇩v n) ls)"
  using assms
proof (induction ls)
  case Nil
  then show ?case by auto
next
  case (Cons a ls)
  then show ?case using dim_sumlist by auto
qed
(* lincomb_list c s rewritten as a sumlist over pairs
   (coefficient c j, index j) for j in [0..<length s]. *)
lemma lincomb_list_alt:
  shows "lincomb_list c s =
    sumlist (map2 (λi j. i ⋅⇩v s ! j) (map (λi. c i) [0..<length s]) [0..<length s])"
  unfolding lincomb_list_def
  by (smt (verit, ccfv_SIG) length_map map2_map_map map_nth nth_equalityI nth_map)
(* In the sumlist form of lincomb_list_alt, the terms whose coefficient c i is
   0 can be filtered out.  The vectors of s all have dimension n, and every
   index in ls is a valid index into s.  Proof by induction on the index
   list ls. *)
lemma lincomb_list_alt2:
  assumes "⋀v. v ∈ set s ⟹ dim_vec v = n"
  assumes "⋀i. i ∈ set ls ⟹ i < length s"
  shows "
    sumlist (map2 (λi j. i ⋅⇩v s ! j) (map (λi. c i) ls) ls) =
    sumlist (map2 (λi j. i ⋅⇩v s ! j) (map (λi. c i) (filter (λi. c i ≠ 0) ls)) (filter (λi. c i ≠ 0) ls))"
  using assms(2)
proof (induction ls)
  case Nil
  then show ?case by auto
next
  (* The tail is named ls rather than s, so it no longer shadows the vector
     list s from the statement. *)
  case (Cons a ls)
  then show ?case
    apply auto
    (* A term with coefficient 0 is the zero vector and drops out of the sum. *)
    apply (subst smult_l_null)
     apply (simp add: assms(1) carrier_vecI)
    apply (subst left_zero_vec)
     apply (subst sumlist_carrier)
      apply auto
    by (metis (no_types, lifting) assms(1) carrier_dim_vec mem_Collect_eq nth_mem set_filter set_zip_rightD)
qed
(* A distinct list whose set is {a, b}, with a ≠ b, is [a, b] or [b, a]. *)
lemma two_set:
  assumes "distinct ls"
  assumes "set ls = set [a,b]"
  assumes "a ≠ b"
  shows "ls = [a,b] ∨ ls = [b,a]"
  apply (cases ls)
  using assms(2) apply auto[1]
proof -
  (* ls = x # y # ys with x, y ∈ {a, b}; distinctness forces ys = []. *)
  fix x xs
  assume ls:"ls = x # xs"
  obtain y ys where xs:"xs = y # ys"
    by (metis (no_types) ‹ls = x # xs› assms(2) assms(3) list.set_cases list.set_intros(1) list.set_intros(2) set_ConsD)
  have 1:"x = a ∨ x = b"
    using ‹ls = x # xs› assms(2) by auto
  have 2:"y = a ∨ y = b"
    using ‹ls = x # xs› ‹xs = y # ys› assms(2) by auto
  have 3:"ys = []"
    by (metis (no_types) ‹ls = x # xs› ‹xs = y # ys› assms(1) assms(2) distinct.simps(2) distinct_length_2_or_more in_set_member member_rec(2) neq_Nil_conv set_ConsD)
  show "ls = [a, b] ∨ ls = [b, a]" using ls xs 1 2 3 assms
    by auto
qed
(* For two distinct valid indices i and j, filtering [0..<length ls] by
   "ia ≠ j ⟶ ia = i" (that is, ia = i ∨ ia = j) yields [i, j] or [j, i].
   Facts 1 and 2 use the disjunctive form.  The final smt step bridges the
   two forms via filter_cong and applies two_set. *)
lemma filter_disj_inds:
  assumes "i < length ls" "j < length ls" "i ≠ j"
  shows "filter (λia. ia ≠ j ⟶ ia = i) [0..<length ls] = [i, j] ∨
  filter (λia. ia ≠ j ⟶ ia = i) [0..<length ls] = [j,i]"
proof -
  have 1: "distinct (filter (λia. ia = i ∨ ia = j) [0..<length ls])"
    using distinct_filter distinct_upt by blast
  have 2:"set (filter (λia. ia = i ∨ ia = j) [0..<length ls]) = {i, j}"
    using assms by auto
  show ?thesis using two_set[OF 1]
    using assms(3) empty_set filter_cong list.simps(15)
    by (smt(verit, ccfv_SIG) "2" assms(3) empty_set filter_cong list.simps(15))
qed
(* If the only vanishing lincomb_list over ls is the trivial one, then ls is
   distinct.  Proof by contradiction: if ls ! i = ls ! j with i ≠ j, the
   coefficients 1 at i, -1 at j and 0 elsewhere give
   ls ! i - ls ! j = 0, yet they are not all zero. *)
lemma lincomb_list_indpt_distinct:
  assumes "⋀v. v ∈ set ls ⟹ dim_vec v = n"
  assumes
    "⋀c. lincomb_list c ls = 0⇩v n ⟹ (∀i. i < (length ls) ⟶ c i = 0)"
  shows "distinct ls"
  unfolding distinct_conv_nth
proof clarsimp
  fix i j
  assume ij: "i < length ls" "j < length ls" "i ≠ j" 
  assume lsij: "ls ! i = ls ! j"
  (* Only the terms at i and j survive the zero-coefficient filter
     (lincomb_list_alt2, filter_disj_inds). *)
  have "lincomb_list (λk. if k = i then 1 else if k = j then -1 else 0) ls =
     (ls ! i) - (ls ! j)"
    unfolding lincomb_list_alt
    apply (subst lincomb_list_alt2[OF assms(1)])
      apply auto
    using  filter_disj_inds[OF ij]
    apply auto
    using ij(3) apply force
    using assms(1) ij(2) apply auto[1]
    using ij(3) apply blast
    using assms(1) ij(2) by auto
  also have "...  = 0⇩v n" unfolding lsij
    apply (rule minus_cancel_vec)
    using ‹j < length ls› assms(1)
    using carrier_vec_dim_vec nth_mem by blast
  ultimately have "lincomb_list (λk. if k = i then 1 else if k = j then -1 else 0) ls = 0⇩v n" by auto
  (* The hypothesis then forces the coefficient at i, which is 1, to be 0. *)
  from assms(2)[OF this]
  show False
    using ‹i < length ls› by auto
qed
end
(* Locale for dimension-n vectors over a conjugatable_ordered_field 'a.  The
   element type is fixed through the parameter f_ty :: 'a itself.  Inside the
   locale, plain "rank" is the rank for matrices whose columns have dimension
   n, while "vec_space.rank nc" is used for column dimension nc.  The lemmas
   relate the rank of a matrix to the ranks of its conjugate and of its
   conjugate transpose. *)
locale conjugatable_vec_space = vec_space f_ty n for
  f_ty::"'a::conjugatable_ordered_field itself"
  and n
begin                                                           
(* Let A have n rows and nc columns.  The rank of its conjugate transpose A⇧H
   (an nc × n matrix) is at most the rank of the n × n product A * A⇧H.
   Proof outline:
   - Take a maximal linearly independent set S of columns of A⇧H; its size
     card S is the rank of A⇧H.
   - Multiplying by A is injective on S and keeps the images linearly
     independent.
   - Every image is a column of A * A⇧H.
   So card S is a lower bound for rank (A * A⇧H). *)
lemma transpose_rank_mul_conjugate_transpose:
  fixes A :: "'a mat"
  assumes "A ∈ carrier_mat n nc"
  shows "vec_space.rank nc A⇧H ≤ rank (A * A⇧H)"
proof -
  have 1: "A⇧H ∈ carrier_mat nc n" using assms by auto
  have 2: "A * A⇧H ∈ carrier_mat n n" using assms by auto
      
  (* ?P T: T is a set of columns of A⇧H that is linearly independent in the
     module of dimension-nc vectors. *)
  let ?P = "(λT. T ⊆ set (cols A⇧H) ∧ module.lin_indpt class_ring (module_vec TYPE('a) nc) T)"
  (* Every such T is finite with at most n elements, because A⇧H has n
     columns.  maximal_exists needs this bound. *)
  have *:"⋀A. ?P A ⟹ finite A ∧ card A ≤ n"
  proof clarsimp
    fix S
    assume S: "S ⊆ set (cols A⇧H)"
    have "card S ≤ card (set (cols A⇧H))" using S
      using card_mono by blast
    also have "... ≤ length (cols A⇧H)" using card_length by blast
    also have "... ≤ n" using assms by auto
    ultimately show "finite S ∧ card S ≤ n"
      by (meson List.finite_set S dual_order.trans finite_subset)
  qed
  (* The empty set satisfies ?P, so a maximal ?P-set exists. *)
  have **:"?P {}"
    apply (subst module.lin_dep_def)
    by (auto simp add: vec_module)
  from maximal_exists[OF *]
  obtain S where S: "maximal S ?P" using **
    by (metis (no_types, lifting)) 
      
  (* The size of a maximal independent set of columns is the rank. *)
  from vec_space.rank_card_indpt[OF 1 S]
  have rankeq: "vec_space.rank nc A⇧H = card S" .
  have s_hyp: "S ⊆ set (cols A⇧H)"
    using S unfolding maximal_def by simp
  have modhyp: "module.lin_indpt class_ring (module_vec TYPE('a) nc) S" 
    using S unfolding maximal_def by simp
  (* Enumerate S as a duplicate-free list ss.  Then relate the list of images
     of ss under multiplication by A to the product A * mat_of_cols nc ss. *)
  obtain ss where ss: "set ss = S" "distinct ss"
    by (metis (mono_tags) S maximal_def set_obtain_sublist)
  have ss2: "set (map ((*⇩v) A) ss) = (*⇩v) A ` S"
    by (simp add: ss(1))
  have rw_hyp: "cols (mat_of_cols n (map ((*⇩v) A) ss)) = cols  (A * mat_of_cols nc ss)" 
    unfolding cols_def apply (auto)
    using mat_vec_as_mat_mat_mult[of A n nc]
    by (metis (no_types, lifting) "1" assms carrier_matD(1) cols_dim mul_mat_of_cols nth_mem s_hyp ss(1) subset_code(1))
  then have rw: "mat_of_cols n (map ((*⇩v) A) ss) = A * mat_of_cols nc ss"
    by (metis assms carrier_matD(1) index_mult_mat(2) mat_of_cols_carrier(2) mat_of_cols_cols) 
  (* The image list is independent in the lincomb_list sense:
     - a vanishing combination with coefficients c gives
       A *⇩v (mat_of_cols nc ss *⇩v vec (length ss) c) = 0⇩v n;
     - mat_mul_conjugate_transpose_sub_vec_eq_0 then makes the inner vector
       zero;
     - independence of S (lin_indpt_lin_comb_list) forces every coefficient
       to vanish. *)
  have indpt: "⋀c. lincomb_list c (map ((*⇩v) A) ss) = 0⇩v n ⟹
      ∀i. (i < (length ss) ⟶ c i = 0)"
  proof clarsimp
    fix c i
    assume *: "lincomb_list c (map ((*⇩v) A) ss) = 0⇩v n"
    assume i: "i < length ss"
    have "∀w∈set (map ((*⇩v) A) ss). dim_vec w = n"
      using assms by auto
    from lincomb_list_as_mat_mult[OF this]
    have "A * mat_of_cols nc ss *⇩v  vec (length ss) c = 0⇩v n"
      using * rw by auto
    then have hq: "A *⇩v (mat_of_cols nc ss *⇩v vec (length ss) c) =  0⇩v n"
      by (metis assms assoc_mult_mat_vec mat_of_cols_carrier(1) vec_carrier)
    then have eq1: "(mat_of_cols nc ss *⇩v vec (length ss) c) =  0⇩v nc"
      apply (intro mat_mul_conjugate_transpose_sub_vec_eq_0)
      using assms ss s_hyp by auto
    have dim_hyp2: "∀w∈set ss. dim_vec w = nc"
      using ss(1) s_hyp
      by (metis "1" carrier_matD(1) carrier_vecD cols_dim subsetD) 
    from vec_module.lincomb_list_as_mat_mult[OF this, symmetric]
    have "mat_of_cols nc ss *⇩v vec (length ss) c = module.lincomb_list (module_vec TYPE('a) nc) c ss" .
    then have "module.lincomb_list (module_vec TYPE('a) nc) c ss = 0⇩v nc" using eq1 by auto
    from vec_space.lin_indpt_lin_comb_list[OF ss(2) _ _ this i]
    show "c i = 0" using modhyp ss s_hyp
      using "1" cols_dim by blast
  qed
  (* Consequences of indpt:
     - the images are pairwise distinct (lincomb_list_indpt_distinct);
     - hence S and its image have the same cardinality;
     - the image lies among the columns of A * A⇧H (cols_mat_mul). *)
  have distinct: "distinct (map ((*⇩v) A) ss)"
    by (metis (no_types, lifting) assms carrier_matD(1) dim_mult_mat_vec imageE indpt length_map lincomb_list_indpt_distinct ss2)
  then have 3: "card S = card ((*⇩v) A ` S)"
    by (metis ss distinct_card image_set length_map)
  then have 4: "(*⇩v) A ` S ⊆ set (cols (A * A⇧H))"
    using cols_mat_mul ‹S ⊆ set (cols A⇧H)› by blast
  (* The image set is linearly independent in dimension n.  A vanishing
     lincomb over it with a nonzero coefficient would give a vanishing
     lincomb_list over the distinct image list, contradicting indpt. *)
  have 5: "lin_indpt ((*⇩v) A ` S)"
  proof clarsimp
    assume ld:"lin_dep ((*⇩v) A ` S)"
    have *: "finite ((*⇩v) A ` S)"
      by (metis List.finite_set ss2)
    have **: "(*⇩v) A ` S ⊆ carrier_vec n"
      using "2" "4" cols_dim by blast
    from finite_lin_dep[OF * ld **]
    obtain a v where
      a: "lincomb a ((*⇩v) A ` S) = 0⇩v n" and
      v: "v ∈ (*⇩v) A ` S" "a v ≠ 0" by blast
    obtain i where i:"v = map ((*⇩v) A) ss ! i" "i < length ss"
      using v unfolding ss2[symmetric]
      using find_first_le nth_find_first by force
    from ss2[symmetric]
    have "set (map ((*⇩v) A) ss)⊆ carrier_vec n" using ** ss2 by auto
    from lincomb_as_lincomb_list_distinct[OF this distinct] have
      "lincomb_list
     (λi. a (map ((*⇩v) A) ss ! i))  (map ((*⇩v) A) ss) = 0⇩v n"
      using a ss2 by auto
    from indpt[OF this]
    show False using v i by simp
  qed
  (* An independent set of columns of A * A⇧H bounds its rank from below. *)
  from rank_ge_card_indpt[OF 2 4 5]
  have "card ((*⇩v) A ` S) ≤ rank (A * A⇧H)" .
  thus ?thesis using rankeq 3 by linarith
qed
(* For A ∈ carrier_mat n nc, the rank of A⇧H is at most rank A.  This chains
   the previous lemma with rank_mat_mul_right, which bounds rank (A * A⇧H)
   by rank A. *)
lemma conjugate_transpose_rank_le:
  assumes "A ∈ carrier_mat n nc"
  shows "vec_space.rank nc (A⇧H) ≤ rank A"
  by (metis assms carrier_matD(2) carrier_mat_triv dim_row_conjugate dual_order.trans index_transpose_mat(2) rank_mat_mul_right transpose_rank_mul_conjugate_transpose)
(* Entrywise conjugation commutes with finite sums of vectors.  If f maps U
   into carrier_vec n, then conjugating the finsum V of f over U equals
   summing the conjugated vectors.
   Proof by infinite_finite_induct: the infinite and empty cases close by
   auto, and the insert step uses conjugate_add_vec and closure of finsum. *)
lemma conjugate_finsum:
  assumes f: "f : U → carrier_vec n"
  shows "conjugate (finsum V f U) = finsum V (conjugate ∘ f) U"
  using f
proof (induct U rule: infinite_finite_induct)
  case (infinite A)
  then show ?case by auto
next
  case empty
  then show ?case by auto
next
  case (insert u U)
  hence f: "f : U → carrier_vec n" "f u : carrier_vec n"  by auto
  then have cf: "conjugate ∘ f : U → carrier_vec n"
    "(conjugate ∘ f) u : carrier_vec n"
     apply (simp add: Pi_iff)
    by (simp add: f(2))
  then show ?case
    unfolding finsum_insert[OF insert(1) insert(2) f]
    unfolding finsum_insert[OF insert(1) insert(2) cf ]
    apply (subst conjugate_add_vec[of _ n])
    using f(2) apply blast
    using M.finsum_closed f(1) apply blast
    by (simp add: comp_def f(1) insert.hyps(3))
qed
(* For A ∈ carrier_mat n d, rank (conjugate A) ≤ rank A.  The strategy
   mirrors transpose_rank_mul_conjugate_transpose:
   - take a maximal independent set S of columns of conjugate A;
   - conjugating S gives a set of columns of A of the same size;
   - that set is still linearly independent. *)
lemma rank_conjugate_le:
  assumes A:"A ∈ carrier_mat n d"
  shows "rank (conjugate (A)) ≤ rank A"
proof -
  
  (* ?P T: T is an independent set of columns of conjugate A.  Such sets are
     finite with at most d elements; the empty set qualifies. *)
  let ?P = "(λT. T ⊆ set (cols (conjugate A)) ∧ lin_indpt T)"
  have *:"⋀A. ?P A ⟹ finite A ∧ card A ≤ d"
    by (metis List.finite_set assms card_length card_mono carrier_matD(2) cols_length dim_col_conjugate dual_order.trans rev_finite_subset)
  have **:"?P {}"
    by (simp add: finite_lin_indpt2)
  from maximal_exists[OF *]
  obtain S where S: "maximal S ?P" using **
    by (metis (no_types, lifting))
  have s_hyp: "S ⊆ set (cols (conjugate A))" "lin_indpt S"
    using S unfolding maximal_def
     apply blast
    by (metis (no_types, lifting) S maximal_def)
  from rank_card_indpt[OF _ S, of d]
  have rankeq: "rank (conjugate A) = card S" using assms by auto 
  (* Conjugating the columns of conjugate A gives columns of A. *)
  have 1:"conjugate ` S ⊆ set (cols A)"
    using S apply auto
    by (metis (no_types, lifting) cols_conjugate conjugate_id image_eqI in_mono list.set_map s_hyp(1))
  (* Independence survives conjugation.  Let lincomb c T = 0 with
     T ⊆ conjugate ` S.  The conjugated coefficients ?c over ?T (a subset of
     S) combine to conjugate (lincomb c T), which is also 0.  Independence of
     S (not_lindepD) then makes ?c vanish on ?T, so c v = 0. *)
  have 2: "lin_indpt (conjugate ` S)"
    apply (rule ccontr)
    apply (auto simp add: lin_dep_def)
  proof -
    fix T c v
    assume T: "T ⊆ conjugate ` S" "finite T" and
      lc:"lincomb c T = 0⇩v n" and "v ∈ T"  "c v ≠ 0"
    let ?T = "conjugate ` T"
    let ?c = "conjugate ∘ c ∘ conjugate"
    have 1: "finite ?T"  using T by auto
    have 2: "?T ⊆ S"  using T by auto
    have 3: "?c ∈ ?T → UNIV" by auto
    have "lincomb ?c ?T = (⨁⇘V⇙x∈T. conjugate (c x) ⋅⇩v conjugate x)"
      unfolding lincomb_def
      apply (subst finsum_reindex)
        apply auto
       apply (metis "2" carrier_vec_conjugate assms carrier_matD(1) cols_dim image_eqI s_hyp(1) subsetD)
      by (meson conjugate_cancel_iff inj_onI)
    also have "... = (⨁⇘V⇙x∈T. conjugate (c x ⋅⇩v x)) "
      by (simp add: conjugate_smult_vec)
    also have "... = conjugate (⨁⇘V⇙x∈T. (c x ⋅⇩v x))"
      apply(subst conjugate_finsum[of "λx.(c x ⋅⇩v x)" T])
       apply (auto simp add:o_def)
      by (smt (verit, ccfv_SIG) Matrix.carrier_vec_conjugate Pi_I' T(1) assms carrier_matD(1) cols_dim dim_row_conjugate imageE s_hyp(1) smult_carrier_vec subset_eq) 
    also have "... = conjugate (lincomb c T)"
      using lincomb_def by presburger
    ultimately have "lincomb ?c ?T = conjugate (lincomb c T)" by auto
    then have 4:"lincomb ?c ?T = 0⇩v n" using lc by auto
    from not_lindepD[OF s_hyp(2) 1 2 3 4]
    have "conjugate ∘ c ∘ conjugate ∈ conjugate ` T → {0}" .
    then have "c v = 0"
      by (simp add: Pi_iff ‹v ∈ T›)
    thus False using ‹c v ≠ 0› by auto
  qed
  (* conjugate is injective (conjugate_cancel_iff), so card is preserved. *)
  from rank_ge_card_indpt[OF A 1 2]
  have 3:"card (conjugate ` S) ≤ rank A" .
  have 4: "card (conjugate ` S) = card S"
    apply (auto intro!: card_image)
    by (meson conjugate_cancel_iff inj_onI)
  show ?thesis using rankeq 3 4 by auto
qed
(* rank (conjugate A) = rank A.  Apply rank_conjugate_le to A and to
   conjugate A; conjugate_id is used to cancel the double conjugation.
   Antisymmetry finishes the proof. *)
lemma rank_conjugate:
  assumes "A ∈ carrier_mat n d"
  shows "rank (conjugate A) = rank A"
  using  rank_conjugate_le
  by (metis carrier_vec_conjugate assms conjugate_id dual_order.antisym)
end 
(* A matrix and its conjugate transpose have the same rank.  Each rank is
   measured in the dimension of that matrix's own columns: dim_row A for A
   and dim_col A for A⇧H.  Obtained from
   conjugatable_vec_space.conjugate_transpose_rank_le, applied in both
   directions, and antisymmetry. *)
lemma conjugate_transpose_rank:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  shows "vec_space.rank (dim_row A) A = vec_space.rank (dim_col A) (A⇧H)"
  using  conjugatable_vec_space.conjugate_transpose_rank_le
  by (metis (no_types, lifting) Matrix.transpose_transpose carrier_matI conjugate_id dim_col_conjugate dual_order.antisym index_transpose_mat(2) transpose_conjugate)
(* A matrix and its (plain) transpose have the same rank.  Derived from
   conjugate_transpose_rank together with
   conjugatable_vec_space.rank_conjugate. *)
lemma transpose_rank:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  shows "vec_space.rank (dim_row A) A = vec_space.rank (dim_col A) (A⇧T)"
  by (metis carrier_mat_triv conjugatable_vec_space.rank_conjugate conjugate_transpose_rank index_transpose_mat(2))
(* Multiplying on the left cannot increase the rank:
   rank (A * B) ≤ rank B, for A of size n × d and B of size d × nc.
   The proof transposes the product (transpose_mult) and applies the
   right-factor bound vec_space.rank_mat_mul_right.  It moves between a
   matrix and its transpose using conjugate_transpose_rank and
   rank_conjugate. *)
lemma rank_mat_mul_left:
  fixes A::"'a::{conjugatable_ordered_field} mat"
  assumes "A ∈ carrier_mat n d"
  assumes "B ∈ carrier_mat d nc"
  shows "vec_space.rank n (A * B) ≤ vec_space.rank d B"
  by (metis (no_types, lifting) Matrix.transpose_transpose assms(1) assms(2) carrier_matD(1) carrier_matD(2) carrier_mat_triv conjugatable_vec_space.rank_conjugate conjugate_transpose_rank index_mult_mat(3) index_transpose_mat(3) transpose_mult vec_space.rank_mat_mul_right)
section "Results on Invertibility"
(* take_cols A inds: the matrix with dim_row A rows whose columns are the
   columns of A at the positions listed in inds, in list order, with
   repetitions kept.  Indices that are not below dim_col A are silently
   dropped by the filter, so the result can have fewer than length inds
   columns (see dim_col_take_cols for the in-range case). *)
definition take_cols :: "'a mat ⇒ nat list ⇒ 'a mat"
  where "take_cols A inds = mat_of_cols (dim_row A) (map ((!) (cols A)) (filter ((>) (dim_col A)) inds))"
(* Variant of take_cols without the range filter: one column per entry of
   inds.  For an index not below dim_col A, the entry cols A ! i lies beyond
   the list.  NOTE(review): such columns carry no meaningful value, so
   callers presumably keep every index in range. *)
definition take_cols_var :: "'a mat ⇒ nat list ⇒ 'a mat"
  where "take_cols_var A inds = mat_of_cols (dim_row A) (map ((!) (cols A)) (inds))"
(* Row analogue of take_cols: the matrix with dim_col A columns whose rows
   are the rows of A at the positions in inds, in list order.  Indices that
   are not below dim_row A are dropped. *)
definition take_rows :: "'a mat ⇒ nat list ⇒ 'a mat"
  where "take_rows A inds = mat_of_rows (dim_col A) (map ((!) (rows A)) (filter ((>) (dim_row A)) inds))"
(* Congruence rule for mat_of_cols.  It is used as an intro rule, e.g. in
   take_cols_mat_mul, to reduce a matrix equality to an equality of column
   lists. *)
lemma cong1:
  "x = y ⟹  mat_of_cols n x = mat_of_cols n y"
  by auto
(* Every element of filter P ls, accessed by a valid index, satisfies P. *)
lemma nth_filter:
  assumes "j < length (filter P ls)"
  shows "P  ((filter P ls) ! j)"
  by (simp add: assms list_ball_nth)
(* Left multiplication commutes with column selection:
   A * take_cols B inds = take_cols (A * B) inds. *)
lemma take_cols_mat_mul:
  assumes "A ∈ carrier_mat nr n"
  assumes "B ∈ carrier_mat n nc"
  shows "A * take_cols B inds = take_cols (A * B) inds"
proof -
  (* Every selected column of B is a dimension-n vector; this is the side
     condition of mul_mat_of_cols. *)
  have "⋀j. j < length (map ((!) (cols B)) (filter ((>) nc) inds)) ⟹
      (map ((!) (cols B)) (filter ((>) nc) inds)) ! j ∈ carrier_vec n"
    using assms apply auto
    apply (subst cols_nth)
    using nth_filter by auto
  (* A times a matrix built from columns is the matrix of the transformed
     columns. *)
  from mul_mat_of_cols[OF assms(1) this]
  have "A *  take_cols B inds = mat_of_cols nr (map (λx. A *⇩v cols B ! x) (filter ((>) (dim_col B)) inds))"
    unfolding take_cols_def using assms by (auto simp add: o_def)
  (* These transformed columns are exactly the selected columns of A * B. *)
  also have "... = take_cols (A * B) inds"
    unfolding take_cols_def using assms by (auto intro!: cong1)
  ultimately show ?thesis by auto
qed
(* take_cols keeps the row count.  Some column count n always exists for the
   result, even when inds contains out-of-range indices. *)
lemma take_cols_carrier_mat:
  assumes "A ∈ carrier_mat nr nc"
  obtains n where "take_cols A inds ∈ carrier_mat nr n"
  unfolding take_cols_def
  using assms
  by fastforce
(* When every index is below nc, nothing is filtered out and the result has
   exactly length inds columns. *)
lemma take_cols_carrier_mat_strict:
  assumes "A ∈ carrier_mat nr nc"
  assumes "⋀i. i ∈ set inds ⟹ i < nc"
  shows "take_cols A inds ∈ carrier_mat nr (length inds)"
  unfolding take_cols_def
  using assms by auto
(* Run Gauss–Jordan elimination on A, augmented with a selection of A's own
   columns.  The right-hand result D is the same selection of columns taken
   from the reduced matrix C. *)
lemma gauss_jordan_take_cols:  
  assumes "gauss_jordan A (take_cols A inds) = (C,D)"
  shows "D = take_cols C inds"
proof -
  obtain nr nc where A: "A  ∈ carrier_mat nr nc" by auto
  from take_cols_carrier_mat[OF this]
  obtain n where B: "take_cols A inds ∈ carrier_mat nr n" by auto
  (* gauss_jordan_transform yields a unit P of the nr × nr matrix ring with
     C = P * A and D = P * take_cols A inds.  The last argument of ring_mat
     is instantiated arbitrarily with undefined. *)
  from gauss_jordan_transform[OF A B assms, of undefined]
  obtain P where PP:"P∈Units (ring_mat TYPE('a) nr undefined)" and
    CD: "C = P * A" "D = P * take_cols A inds" by blast
  have P: "P ∈ carrier_mat nr nr"
    by (metis (no_types, lifting) Units_def PP mem_Collect_eq partial_object.select_convs(1) ring_mat_def)
  (* Move P inside the column selection. *)
  from take_cols_mat_mul[OF P A]
  have "P * take_cols A inds = take_cols (P * A) inds" .
  thus ?thesis using CD by blast  
qed
(* With all indices in range, take_cols has one column per index. *)
lemma dim_col_take_cols:
  assumes "⋀j. j ∈ set inds ⟹ j < dim_col A"
  shows "dim_col (take_cols A inds) = length inds"
  unfolding take_cols_def
  using assms by auto
(* Selecting rows keeps the column count (simp rule). *)
lemma dim_col_take_rows[simp]:
  shows "dim_col (take_rows A inds) = dim_col A"
  unfolding take_rows_def by auto
(* Every column of take_cols A inds is a column of A. *)
lemma cols_take_cols_subset:
  shows "set (cols (take_cols A inds)) ⊆ set (cols A)"
  unfolding take_cols_def
  apply (subst cols_mat_of_cols)
   apply auto
  using in_set_conv_nth by fastforce
(* Selecting columns keeps the row count (simp rule). *)
lemma dim_row_take_cols[simp]:
  shows "dim_row (take_cols A ls) = dim_row A"
  by (simp add: take_cols_def)
(* Stacking A on top of B adds the row counts (simp rule). *)
lemma dim_row_append_rows[simp]:
  shows "dim_row (A @⇩r B) = dim_row A + dim_row B"
  by (simp add: append_rows_def)
(* Two matrices with the same column count and the same list of rows are
   equal.  The column count is needed because an empty row list does not
   determine it. *)
lemma rows_inj:
  assumes "dim_col A = dim_col B"
  assumes "rows A = rows B"
  shows "A = B"
  unfolding mat_eq_iff
  apply auto
    apply (metis assms(2) length_rows)
  using assms(1) apply blast
  by (metis assms(1) assms(2) mat_of_rows_rows)
(* Entry formula for a vertical stack A @⇩r B of matrices with equal column
   counts.  Rows below dim_row A come from A; the remaining rows come from B,
   shifted up by dim_row A. *)
lemma append_rows_index:
  assumes "dim_col A = dim_col B"
  assumes "i < dim_row A + dim_row B"
  assumes "j < dim_col A"
  shows "(A @⇩r B) $$ (i,j) = (if i < dim_row A then A $$ (i,j) else B $$ (i-dim_row A,j))"
  unfolding append_rows_def
  apply (subst index_mat_four_block)
  using assms by auto
(* Row formula for a vertical stack.  This is the row-level counterpart of
   append_rows_index. *)
lemma row_append_rows:
  assumes "dim_col A = dim_col B"
  assumes "i < dim_row A + dim_row B"
  shows "row (A @⇩r B) i = (if i < dim_row A then row A i else row B (i-dim_row A))"
  unfolding vec_eq_iff
  using assms by (auto simp add: append_rows_def)
(* Right multiplication distributes over vertical stacking:
   (A @⇩r B) * C = A * C @⇩r B * C, for A and B with equal column counts.
   The proof compares dimensions and then entries, splitting on whether the
   row index falls in the A part or the B part. *)
lemma append_rows_mat_mul:
  assumes "dim_col A = dim_col B"
  shows "(A @⇩r B) * C = A * C @⇩r B * C"
  unfolding mat_eq_iff
  apply auto
   apply (simp add: append_rows_def)
  apply (subst index_mult_mat)
    apply auto
   apply (simp add: append_rows_def)
  apply (subst  append_rows_index)
     apply auto
    apply (simp add: append_rows_def)
   apply (metis add.right_neutral append_rows_def assms index_mat_four_block(3) index_mult_mat(1) index_mult_mat(3) index_zero_mat(3) row_append_rows trans_less_add1)
  by (metis add_cancel_right_right add_diff_inverse_nat append_rows_def assms index_mat_four_block(3) index_mult_mat(1) index_mult_mat(3) index_zero_mat(3) nat_add_left_cancel_less row_append_rows)
(* The set of naturals below n has at most n elements.  Used as a bound in
   row_echelon_form_zero_rows. *)
lemma cardlt:
  shows "card  {i. i < (n::nat)} ≤ n"
  by simp
(* Let A be in row echelon form with at least as many columns as rows.  Then
   A is its first length (pivot_positions A) rows stacked on top of a zero
   block.  In other words, every row from index length (pivot_positions A)
   onwards is zero.
   The number of pivots equals the number of nonzero rows.  A row at or past
   that index has no pivot (f i = dim_col A for every pivot function f), and
   so it is zero. *)
lemma row_echelon_form_zero_rows:
  assumes row_ech: "row_echelon_form A"
  assumes dim_asm: "dim_col A ≥ dim_row A"
  shows "take_rows A [0..<length (pivot_positions A)] @⇩r  0⇩m (dim_row A - length (pivot_positions A))  (dim_col A) = A"
proof -
  (* Basic facts: a pivot function exists, the number of pivots is the number
     of nonzero rows, and this number is at most dim_row A. *)
  have ex_pivot_fun: "∃ f. pivot_fun A f (dim_col A)" using row_ech unfolding row_echelon_form_def by auto
  have len_help: "length (pivot_positions A) = card {i. i < dim_row A ∧ row A i ≠ 0⇩v (dim_col A)}"
    using ex_pivot_fun pivot_positions[where A = "A",where nr = "dim_row A", where nc = "dim_col A"]
    by auto
  then have len_help2: "length (pivot_positions A) ≤ dim_row A"
    by (metis (no_types, lifting) card_mono cardlt finite_Collect_less_nat le_trans mem_Collect_eq subsetI)
  (* The range filter inside take_rows removes nothing from the index list. *)
  have fileq: "filter (λy. y < dim_row A) [0..< length (pivot_positions A)] = [0..<length (pivot_positions A)]"
    apply (rule filter_True)
    using len_help2 by auto
  (* NOTE(review): this general bound, and pivot_len derived from it, repeat
     what len_help2 already states. *)
  have "∀n. card {i. i < n  ∧ row A i ≠ 0⇩v (dim_col A)} ≤ n"
  proof clarsimp 
    fix n
    have h: "∀x. x ∈ {i. i < n ∧ row A i ≠ 0⇩v (dim_col A)} ⟶ x∈{..<n}"
      by simp
    then have h1: "{i. i < n  ∧ row A i ≠ 0⇩v (dim_col A)} ⊆ {..<n}"
      by blast
    then have h2: "(card {i. i < n  ∧ row A i ≠ 0⇩v (dim_col A)}::nat) ≤ (card {..<n}::nat)"
      using card_mono by blast 
    then show "(card {i. i < n ∧ row A i ≠ 0⇩v (dim_col A)}::nat) ≤ (n::nat)" using h2 card_lessThan[of n]
      by auto
  qed
  then have pivot_len: "length (pivot_positions A) ≤ dim_row A "  using len_help
    by simp
  (* The same holds for a filter bounded by dim_col A; this uses dim_asm. *)
  have alt_char: "mat_of_rows (dim_col A)
         (map ((!) (rows A)) (filter (λy. y < dim_col A) [0..<length (pivot_positions A)])) = 
      mat_of_rows (dim_col A) (map ((!) (rows A))  [0..<length (pivot_positions A)])"
    using pivot_len dim_asm
    by auto
  (* Upper block: entries of the selected rows agree with A. *)
  have h1: "⋀i j. i < dim_row A ⟹
           j < dim_col A ⟹
           i < dim_row (take_rows A [0..<length (pivot_positions A)]) ⟹
           take_rows A [0..<length (pivot_positions A)] $$ (i, j) = A $$ (i, j)"
  proof - 
    fix i 
    fix j
    assume "i < dim_row A"
    assume j_lt: "j < dim_col A"
    assume i_lt: "i < dim_row (take_rows A [0..<length (pivot_positions A)])" 
    have lt: "length (pivot_positions A) ≤ dim_row A" using pivot_len by auto
    have h1: "take_rows A [0..<length (pivot_positions A)] $$ (i, j) = (row (take_rows A [0..<length (pivot_positions A)]) i)$j"
      by (simp add: i_lt j_lt)
    then have h2: "(row (take_rows A [0..<length (pivot_positions A)]) i)$j = (row A i)$j"
      using lt alt_char i_lt unfolding take_rows_def by auto
    show "take_rows A [0..<length (pivot_positions A)] $$ (i, j) = A $$ (i, j)"
      using h1 h2
      by (simp add: ‹i < dim_row A› j_lt) 
  qed
  let ?nc = "dim_col A"
  let ?nr = "dim_row A"
  (* Lower block: the zero-matrix entry equals the entry of A, because rows
     at or past the pivot count are zero. *)
  have h2: "⋀i j. i < dim_row A ⟹
           j < dim_col A ⟹
           ¬ i < dim_row (take_rows A [0..<length (pivot_positions A)]) ⟹
           0⇩m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) =
           A $$ (i, j)"
  proof - 
    fix i
    fix j
    assume lt_i: "i < dim_row A"
    assume lt_j: "j < dim_col A"
    assume not_lt: "¬ i < dim_row (take_rows A [0..<length (pivot_positions A)])"
    let ?ip = "i+1"
    (* Row i has no pivot: f i = dim_col A, shown for every pivot function f. *)
    have h0: "∃f. pivot_fun A f (dim_col A) ∧ f i = ?nc"
    proof -  
      have half1: "∃f. pivot_fun A f (dim_col A)" using assms unfolding row_echelon_form_def
        by blast
      have half2: "∀f. pivot_fun A f (dim_col A) ⟶ f i = ?nc " 
      proof clarsimp
        fix f
        assume is_piv: "pivot_fun A f (dim_col A)"
        (* The pivot count equals the number of rows whose pivot value is
           not dim_col A. *)
        have len_pp: "length (pivot_positions A) = card {i. i < ?nr ∧ row A i ≠ 0⇩v ?nc}" using is_piv pivot_positions[of A ?nr ?nc f]
          by auto
        have  "∀i. (i < ?nr ∧ row A i ≠ 0⇩v ?nc) ⟷  (i < ?nr ∧ f i ≠ ?nc)"
          using is_piv pivot_fun_zero_row_iff[of A f ?nc ?nr]
          by blast
        then have len_pp_var: "length (pivot_positions A) = card {i. i < ?nr ∧ f i ≠ ?nc}" 
          using len_pp  by auto 
        (* Once a row has pivot value dim_col A, so does the next row
           (from pivot_fun_def). *)
        have allj_hyp: "∀j < ?nr. f j = ?nc ⟶ ((Suc j) < ?nr ⟶ f (Suc j) = ?nc)" 
          using is_piv unfolding pivot_fun_def 
          using lt_i
          by (metis le_antisym le_less) 
        (* Hence, if row i has a pivot, every earlier row has one too. *)
        have if_then_bad: "f i ≠ ?nc ⟶ (∀j. j ≤ i ⟶ f j ≠ ?nc)"
        proof clarsimp 
          fix j
          assume not_i: "f i ≠ ?nc"
          assume j_leq: "j ≤ i"
          assume bad_asm: "f j = ?nc"
          (* From f j = dim_col A, propagate the value to every row k with
             j ≤ k < dim_row A.  The low-level argument below uses a witness
             function nn obtained from inc_induct. *)
          have "⋀k. k ≥ j ⟹  k < ?nr ⟹ f k = ?nc"
          proof -
            fix k :: nat
            assume a1: "j ≤ k"
            assume a2: "k < dim_row A"
            have f3: "⋀n. ¬ n < dim_row A ∨ f n ≠ f j ∨ ¬ Suc n < dim_row A ∨ f (Suc n) = f j"
              using allj_hyp bad_asm by presburger
            obtain nn :: "nat ⇒ nat ⇒ (nat ⇒ bool) ⇒ nat" where
              f4: "⋀n na p nb nc. (¬ n ≤ na ∨ Suc n ≤ Suc na) ∧ (¬ p nb ∨ ¬ nc ≤ nb ∨ ¬ p (nn nc nb p) ∨ p nc) ∧ (¬ p nb ∨ ¬ nc ≤ nb ∨ p nc ∨ p (Suc (nn nc nb p)))"
              using inc_induct by (metis Suc_le_mono)
            then have f5: "⋀p. ¬ p k ∨ p j ∨ p (Suc (nn j k p))"
              using a1 by presburger
            have f6: "⋀p. ¬ p k ∨ ¬ p (nn j k p) ∨ p j"
              using f4 a1 by meson
            { assume "nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A) < dim_row A ∧ f (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A)) ≠ dim_col A"
              moreover
              { assume "(nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A) < dim_row A ∧ f (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A)) ≠ dim_col A) ∧ (¬ j < dim_row A ∨ f j = dim_col A)"
                then have "¬ k < dim_row A ∨ f k = dim_col A"
                  using f6
                  by (metis (mono_tags, lifting)) }
              ultimately have "(¬ j < dim_row A ∨ f j = dim_col A) ∧ (¬ Suc (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A)) < dim_row A ∨ f (Suc (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A))) = dim_col A) ∨ ¬ k < dim_row A ∨ f k = dim_col A"
                using bad_asm
                by blast }
            moreover
            { assume "(¬ j < dim_row A ∨ f j = dim_col A) ∧ (¬ Suc (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A)) < dim_row A ∨ f (Suc (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A))) = dim_col A)"
              then have "¬ k < dim_row A ∨ f k = dim_col A"
                using f5
              proof -
                have "¬ (Suc (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A)) < dim_row A ∧ f (Suc (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A))) ≠ dim_col A) ∧ ¬ (j < dim_row A ∧ f j ≠ dim_col A)"
                  using ‹(¬ j < dim_row A ∨ f j = dim_col A) ∧ (¬ Suc (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A)) < dim_row A ∨ f (Suc (nn j k (λn. n < dim_row A ∧ f n ≠ dim_col A))) = dim_col A)› by linarith
                then have "¬ (k < dim_row A ∧ f k ≠ dim_col A)"
                  by (metis (mono_tags, lifting) a2 bad_asm f5 le_less)
                then show ?thesis
                  by meson
              qed }
            ultimately show "f k = dim_col A"
              using f3 a2 by (metis (lifting) Suc_lessD bad_asm)
          qed
          then show "False" using lt_i not_i
            using j_leq by blast 
        qed
        (* If row i had a pivot, rows 0..i would all have pivots, so the
           pivot count would be at least i + 1. *)
        have "f i ≠ ?nc ⟶ ({0..<?ip} ⊆ {y. y < ?nr ∧ f y ≠ dim_col A})"
        proof -
          have h1: "f i ≠ dim_col A ⟶ (∀j≤i. j < ?nr ∧ f j ≠ dim_col A)"
            using if_then_bad lt_i by auto
          then show ?thesis by auto
        qed
        then have gteq: "f i ≠ ?nc ⟶ (card {i. i < ?nr ∧ f i ≠ dim_col A} ≥ (i+1))"
          using card_lessThan[of ?ip] card_mono[where B = "{i. i < ?nr ∧ f i ≠ dim_col A} ", where A = "{0..<?ip}"]
          by auto
        (* But i is at or past the pivot count, a contradiction. *)
        then have clear: "dim_row (take_rows A [0..<length (pivot_positions A)]) = length (pivot_positions A)"
          unfolding take_rows_def using dim_asm fileq by (auto)
        have "i + 1 > length (pivot_positions A)" using not_lt clear by auto
        then show "f i = ?nc" using gteq len_pp_var by auto
      qed
      show ?thesis using half1 half2
        by blast 
    qed
    (* A row without a pivot is the zero row, so A $$ (i, j) = 0. *)
    then have h1a: "row A i =  0⇩v (dim_col A)" 
      using pivot_fun_zero_row_iff[of A _ ?nc ?nr]
      using lt_i by blast
    then have h1: "A $$ (i, j) = 0"
      using index_row(1) lt_i lt_j by fastforce 
    (* The shifted index lies inside the zero block, so that entry is 0 too. *)
    have h2a: "i - dim_row (take_rows A [0..<length (pivot_positions A)]) < dim_row A - length (pivot_positions A)"
      using pivot_len lt_i not_lt
      by (simp add: take_rows_def)
    then have h2: "0⇩m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) = 0 " 
      unfolding zero_mat_def using pivot_len lt_i lt_j
      using index_mat(1) by blast 
    then show "0⇩m (dim_row A - length (pivot_positions A)) (dim_col A) $$
           (i - dim_row (take_rows A [0..<length (pivot_positions A)]), j) =
           A $$ (i, j)" using h1 h2
      by simp 
  qed
  (* Dimension bookkeeping: the two blocks together have exactly
     dim_row A rows. *)
  have h3: "(dim_row (take_rows A [0..<length (pivot_positions A)])::nat) + ((dim_row A::nat) - (length (pivot_positions A)::nat)) =
    dim_row A"
  proof - 
    have h0: "dim_row (take_rows A [0..<length (pivot_positions A)]) = (length (pivot_positions A)::nat)" 
      by (simp add: take_rows_def fileq)
    then show ?thesis using add_diff_inverse_nat  pivot_len
      by linarith
  qed
  have h4: " ⋀i j. i < dim_row A ⟹
           j < dim_col A ⟹
           i < dim_row (take_rows A [0..<length (pivot_positions A)]) +
               (dim_row A - length (pivot_positions A))"
    using pivot_len
    by (simp add: h3) 
  (* Conclude by entrywise matrix equality. *)
  then show ?thesis apply (subst mat_eq_iff)
    using h1 h2 h3 h4 by (auto simp add: append_rows_def)
qed
(* A row echelon matrix has at most dim_row A pivots.  The pivot count
   equals the number of nonzero rows (pivot_positions(4)), which is bounded
   by the number of rows. *)
lemma length_pivot_positions_dim_row:
  assumes "row_echelon_form A"
  shows "length (pivot_positions A) ≤ dim_row A"
proof -
  have 1: "A ∈ carrier_mat (dim_row A) (dim_col A)" by auto
  obtain f where 2: "pivot_fun A f (dim_col A)"
    using assms row_echelon_form_def by blast
  from pivot_positions(4)[OF 1 2] have
    "length (pivot_positions A) = card {i. i < dim_row A ∧ row A i ≠ 0⇩v (dim_col A)}" .
  also have "... ≤ card {i. i < dim_row A}"
    apply (rule card_mono)
    by auto
  ultimately show ?thesis by auto
qed
(* Every pivot position (i, j) of a row echelon R ∈ carrier_mat nr nc lies
   inside the matrix: i < nr and j < nc.  The pivot set is
   {(i, f i) | i < nr ∧ f i ≠ nc}, and f i ≤ nc by pivot_funD(1), so
   f i < nc. *)
lemma rref_pivot_positions:
  assumes "row_echelon_form R"
  assumes R: "R ∈ carrier_mat nr nc"
  shows "⋀i j. (i,j) ∈ set (pivot_positions R) ⟹ i < nr ∧ j < nc"
proof -
  obtain f where f: "pivot_fun R f nc"
    using assms(1) assms(2) row_echelon_form_def by blast
  have *: "⋀i. i < nr ⟹ f i ≤ nc" using f
    using R pivot_funD(1) by blast
  from pivot_positions[OF R f]
  have "set (pivot_positions R) = {(i, f i) |i. i < nr ∧ f i ≠ nc}" by auto
  then have **: "set (pivot_positions R) = {(i, f i) |i. i < nr ∧ f i < nc}"
    using *
    by fastforce
  fix i j
  assume "(i, j) ∈ set (pivot_positions R)"
  thus "i < nr ∧ j < nc" using **
    by simp
qed
(* A pivot function is monotone on the rows: f i ≤ f k whenever
   i < k < nr = dim_row A.  Proof by induction on k.  The base case is
   vacuous; the step uses pivot_funD(1) and pivot_funD(3). *)
lemma pivot_fun_monoton: 
  assumes pf: "pivot_fun A f (dim_col A)"
  assumes dr: "dim_row A = nr"
  shows "⋀ i. i < nr ⟹ (⋀ k. ((k < nr ∧ i < k) ⟶ f i ≤ f k))"
proof -
  fix i
  assume "i < nr"
  show "(⋀ k. ((k < nr ∧ i < k) ⟶ f i ≤ f k))"
  proof -
    fix k
    show "((k < nr ∧ i < k) ⟶ f i ≤ f k)"
    proof (induct k)
      case 0
      then show ?case
        by blast 
    next
      case (Suc k)
      then show ?case 
        by (smt (verit, ccfv_SIG) dr le_less_trans less_Suc_eq less_imp_le_nat pf pivot_funD(1) pivot_funD(3))
    qed
  qed
qed
(* For a row echelon A with pivot function f, each of the first
   length (pivot_positions A) rows i contributes the pivot position (i, f i).
   The key step is f i < dim_col A.  Otherwise rows i, i+1, … would all be
   zero, leaving at most i nonzero rows, which is fewer than the pivot
   count. *)
lemma pivot_positions_contains:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A ≥ dim_row A"
  assumes "pivot_fun A f (dim_col A)"
  shows "∀i < (length (pivot_positions A)). (i, f i) ∈ set (pivot_positions A)"
proof - 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  let ?pp = "pivot_positions A"          
  (* Row indices below the pivot count are valid rows. *)
  have i_nr: "∀i < (length ?pp). i < ?nr" using rref_pivot_positions assms
    using length_pivot_positions_dim_row less_le_trans by blast 
  (* Each such row has a genuine pivot column. *)
  have i_nc: "∀i < (length ?pp). f i < ?nc"
  proof clarsimp 
    fix i
    assume i_lt: "i < length ?pp"
    (* By monotonicity, no pivot at row i means no pivot at any later row. *)
    have fis_nc: "f i = ?nc ⟹ (∀ k > i. k < ?nr ⟶ f k = ?nc)"
    proof -
      assume is_nc: "f i = ?nc"
      show "(∀ k > i. k < ?nr ⟶ f k = ?nc)" 
      proof clarsimp
        fix k
        assume k_gt: "k > i"
        assume k_lt: "k < ?nr"
        have fk_lt: "f k ≤ ?nc" using pivot_funD(1)[of A ?nr f ?nc k] k_lt apply (auto)
          using ‹pivot_fun A f (dim_col A)› by blast 
        show "f k = ?nc"
          using fk_lt is_nc k_gt k_lt assms pivot_fun_monoton[of A f ?nr i k]
          using ‹pivot_fun A f (dim_col A)› by auto 
      qed
    qed
    (* ...so all rows from i on are zero rows... *)
    have ncimp: "f i = ?nc ⟹ (∀ k ≥i. k ∉ { i. i < ?nr ∧ row A i ≠ 0⇩v ?nc})"
    proof -
      assume nchyp: "f i = ?nc"
      show "(∀ k ≥i. k ∉ { i. i < ?nr ∧ row A i ≠ 0⇩v ?nc})"
      proof clarsimp 
        fix k
        assume i_lt: "i ≤ k" 
        assume k_lt: "k < dim_row A"
        show "row A k = 0⇩v (dim_col A) "
          using i_lt k_lt fis_nc
          using pivot_fun_zero_row_iff[of A f ?nc ?nr]
          using ‹pivot_fun A f (dim_col A)› le_neq_implies_less nchyp by blast 
      qed
    qed
    (* ...and at most i rows are nonzero. *)
    then have "f i = ?nc ⟹ card { i. i < ?nr ∧ row A i ≠ 0⇩v ?nc} ≤ i"
    proof - 
      assume nchyp: "f i = ?nc"
      have h: "{ i. i < ?nr ∧ row A i ≠ 0⇩v ?nc} ⊆ {0..<i}"
        using atLeast0LessThan le_less_linear nchyp ncimp by blast
      then show " card { i. i < ?nr ∧ row A i ≠ 0⇩v ?nc} ≤ i"
        using card_lessThan
        using subset_eq_atLeast0_lessThan_card by blast 
    qed
    (* That contradicts i < pivot count = number of nonzero rows. *)
    then show "f i < ?nc" using i_lt pivot_positions(4)[of A ?nr ?nc f]
      apply (auto)
      by (metis ‹pivot_fun A f (dim_col A)› i_nr le_neq_implies_less not_less pivot_funD(1)) 
  qed
  (* Conclude with the set characterisation pivot_positions(1). *)
  then show ?thesis
    using pivot_positions(1)
    by (smt (verit, ccfv_SIG) ‹pivot_fun A f (dim_col A)› carrier_matI i_nr less_not_refl mem_Collect_eq)
qed
(* Every entry (a, b) of pivot_positions_main_gen z A nr nc i j has row
   component a ≥ i.  Proved with the function's own induction rule and its
   defining equations. *)
lemma pivot_positions_form_helper_1:
  shows "(a, b) ∈ set (pivot_positions_main_gen z A nr nc i j) ⟹ i ≤ a"
proof  (induct i j rule: pivot_positions_main_gen.induct[of nr nc A z])
  case (1 i j)
  then show ?case using  pivot_positions_main_gen.simps[of z A nr nc i j]
    by (metis Pair_inject Suc_leD emptyE list.set(1) nle_le set_ConsD)
qed
(* The row components of pivot_positions_main_gen z A nr nc i j are strictly
   increasing.  Proved by the same induction, using
   pivot_positions_form_helper_1 to bound later entries. *)
lemma pivot_positions_form_helper_2:
  shows "sorted_wrt (<) (map fst (pivot_positions_main_gen z A nr nc i j))"
proof  (induct i j rule: pivot_positions_main_gen.induct[of nr nc A z])
  case (1 i j)
  then show ?case using  pivot_positions_main_gen.simps[of z A nr nc i j] 
    by (auto simp: pivot_positions_form_helper_1 Suc_le_lessD) 
qed
(* The row components of pivot_positions A are strictly increasing.  This is
   helper_2 after unfolding pivot_positions_gen_def.  NOTE(review): the
   "using pivot_positions_form_helper_2" line repeats a fact already passed
   to simp and looks redundant. *)
lemma sorted_pivot_positions:
  shows "sorted_wrt (<) (map fst (pivot_positions A))"
  using pivot_positions_form_helper_2
  by (simp add: pivot_positions_form_helper_2 pivot_positions_gen_def) 
(* Let A be in row echelon form with dim_col A ≥ dim_row A.  Then the i-th
   pivot position lies in row i: fst (pivot_positions A ! i) = i for every
   valid index i.
   Two facts drive the proof:
   - every row below the pivot count occurs as a row component
     (pivot_positions_contains);
   - the row components are strictly sorted (sorted_pivot_positions).
   The claim then follows by induction on i. *)
lemma pivot_positions_form:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A ≥ dim_row A"
  shows "∀ i < (length (pivot_positions A)). fst ((pivot_positions A) ! i) = i"
proof clarsimp 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  let ?pp = "pivot_positions A :: (nat × nat) list"
  fix i
  assume i_lt: "i < length (pivot_positions A)"
  have "∃f. pivot_fun A f ?nc" using row_ech unfolding row_echelon_form_def
    by blast
  then obtain f where pf:"pivot_fun A f ?nc"
    by blast                  
  (* Each row index below the pivot count appears as a row component. *)
  have all_f_in: "∀i < (length ?pp). (i, f i) ∈ set ?pp"
    using pivot_positions_contains pf
      assms 
    by blast   
  (* Strict sortedness, stated index-wise. *)
  have sorted_hyp: "⋀ (p::nat) (q::nat). p < (length ?pp) ⟹ q < (length ?pp) ⟹ p < q ⟹ (fst (?pp ! p) < fst (?pp ! q))"
  proof -
    fix p::nat
    fix q::nat
    assume p_lt: "p < q"
    assume p_welldef: "p < (length ?pp)"
    assume q_welldef: "q < (length ?pp)"
    show "fst (?pp ! p) < fst (?pp ! q)"
      using sorted_pivot_positions p_lt p_welldef q_welldef sorted_wrt_nth_less by fastforce
  qed
  have h: "i < (length ?pp) ⟶ fst (pivot_positions A ! i) = i"
  proof (induct i)
    case 0
    (* Some entry has row 0.  By sortedness it must be entry 0. *)
    have "∃j. fst (pivot_positions A ! j) = 0"
      by (metis all_f_in fst_conv i_lt in_set_conv_nth length_greater_0_conv list.size(3) not_less0)
    then obtain j where jth:" fst (pivot_positions A ! j) = 0"
      by blast      
    have "j ≠ 0 ⟶ (fst (pivot_positions A ! 0) > 0 ⟶ j ≤ 0)"
      by (smt (verit, ccfv_SIG) all_f_in fst_conv i_lt in_set_conv_nth less_nat_zero_code not_gr_zero sorted_hyp)
    then show ?case
      using jth neq0_conv by blast
  next
    case (Suc i)
    have ind_h: "i < length (pivot_positions A) ⟶ fst (pivot_positions A ! i) = i"
      using Suc.hyps by blast 
    have thesis_h: "(Suc i) < length (pivot_positions A) ⟹ fst (pivot_positions A ! (Suc i)) = (Suc i)"
    proof - 
      assume suc_i_lt: "(Suc i) < length (pivot_positions A)"
      have fst_i_is: "fst (pivot_positions A ! i) = i" using suc_i_lt ind_h
        using Suc_lessD by blast 
      (* Some entry j has row Suc i.  Sortedness puts j after i... *)
      have "(∃j < (length ?pp). fst (pivot_positions A ! j) = (Suc i))"
        by (metis suc_i_lt all_f_in fst_conv  in_set_conv_nth)
      then obtain j where jth: "j < (length ?pp) ∧ fst (pivot_positions A ! j) = (Suc i)"
        by blast
      have "j > i"
        using sorted_hyp apply (auto)
        by (metis Suc_lessD ‹fst (pivot_positions A ! i) = i› jth less_not_refl linorder_neqE_nat n_not_Suc_n suc_i_lt)
      (* ...and j > Suc i is impossible.  Entry Suc i would then have a row
         component strictly between i and Suc i. *)
      have "j > (Suc i) ⟹ False"
      proof -
        assume j_gt: "j > (Suc i)"
        then have h1: "fst (pivot_positions A ! (Suc i)) > i"
          using fst_i_is sorted_pivot_positions
          using sorted_hyp suc_i_lt by force
        have "fst (pivot_positions A ! j) > fst (pivot_positions A ! (Suc i))"
          using jth j_gt sorted_hyp apply (auto)
          by fastforce 
        then have h2: "fst (pivot_positions A ! (Suc i)) < (Suc i)" 
          using jth
          by simp   
        show "False" using h1 h2
          using not_less_eq by blast 
      qed
      show "fst (pivot_positions A ! (Suc i)) = (Suc i)"
        using Suc_lessI ‹Suc i < j ⟹ False› ‹i < j› jth by blast
    qed
    then show ?case
      by blast 
  qed
  then show "fst (pivot_positions A ! i) = i"
    using i_lt by auto
qed
(* For a matrix A in row echelon form with at least as many columns as rows,
   selecting the pivot columns of A (in the order of pivot_positions A) gives
   the identity matrix of size length (pivot_positions A) stacked on top of a
   zero block with the remaining dim_row A - length (pivot_positions A) rows.
   Note on fact names: h1 and h2 are bound more than once in this proof
   (h1 at top level and again inside the pointwise argument, h2 twice at top
   level).  Inside "have h1: ... by ..." or "have h2: ... proof ... qed",
   a reference to h1/h2 denotes the previously bound fact, not the one being
   proved. *)
lemma take_cols_pivot_eq:
  assumes row_ech: "row_echelon_form A"
  assumes dim_h: "dim_col A ≥ dim_row A"
  shows "take_cols A (map snd (pivot_positions A)) =
    1⇩m (length (pivot_positions A)) @⇩r
    0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))"
proof - 
  let ?nr = "dim_row A"
  let ?nc = "dim_col A"
  (* The right-hand side has exactly one column per pivot. *)
  have h1: " dim_col
     (1⇩m (length (pivot_positions A)) @⇩r
      0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) = (length (pivot_positions A))"
    by (simp add: append_rows_def)
  (* The number of pivots equals the number of nonzero rows of A. *)
  have len_pivot: "length (pivot_positions A) = card {i. i < ?nr ∧ row A i ≠ 0⇩v ?nc}"
    using row_ech pivot_positions(4) row_echelon_form_def by blast
  (* Characterise the pivot positions through a pivot function f:
     they are the pairs (i, f i) with i < dim_row A and f i < dim_col A. *)
  have pp_leq_nc: "∀f. pivot_fun A f ?nc ⟶ (∀i < ?nr. f i ≤ ?nc)" unfolding pivot_fun_def
    by meson 
  have pivot_set: "∃f. pivot_fun A f ?nc ∧ set (pivot_positions A) = {(i, f i) | i. i < ?nr ∧ f i ≠ ?nc}"
    using row_ech row_echelon_form_def pivot_positions(1)
    by (smt (verit) Collect_cong carrier_matI)
  then have pivot_set_alt: "∃f. pivot_fun A f ?nc ∧ set (pivot_positions A) = {(i, f i) | i. i < ?nr ∧ row A i ≠ 0⇩v ?nc}"
    using pivot_positions pivot_fun_zero_row_iff Collect_cong carrier_mat_triv
    by (smt (verit, best))
  have "∃f. pivot_fun A f ?nc ∧ set (pivot_positions A) = {(i, f i) | i. f i ≤ ?nc ∧ i < ?nr ∧ f i ≠ ?nc}"
    using pivot_set pp_leq_nc by auto
  then have pivot_set_var: "∃f. pivot_fun A f ?nc ∧ set (pivot_positions A) = {(i, f i) | i. i < ?nr ∧ f i < ?nc}"
    by auto
  (* The list of pivot column indices is as long as the pivot list itself. *)
  have "length (map snd (pivot_positions A)) = card (set (map snd (pivot_positions A)))" 
    using row_ech row_echelon_form_def pivot_positions(3) distinct_card[where xs = "map snd (pivot_positions A)"]
    by (metis carrier_mat_triv)
  then have "length (map snd (pivot_positions A)) = card (set (pivot_positions A))"
    by (metis card_distinct distinct_card distinct_map length_map) 
  then have "length (map snd (pivot_positions A)) = card {i. i < ?nr ∧ row A i ≠ 0⇩v ?nc}"
    using pivot_set_alt
    by (simp add: len_pivot) 
  then have length_asm: "length (map snd (pivot_positions A)) = length (pivot_positions A)"
    using len_pivot by linarith
  (* Every pivot column index is a valid column index of A, so the range
     filter inside take_cols keeps all of them (h2b). *)
  then have "∀a. List.member (map snd (pivot_positions A)) a ⟶ a < dim_col A"
  proof clarsimp 
    fix a
    assume a_in: "List.member (map snd (pivot_positions A)) a"
    have "∃v ∈ set (pivot_positions A). a = snd v" 
      using a_in in_set_member[where xs = "(pivot_positions A)"] apply (auto)
      by (metis in_set_impl_in_set_zip2 in_set_member length_map snd_conv zip_map_fst_snd) 
    then show "a < dim_col A"
      using pivot_set_var in_set_member by auto
  qed
  then have h2b: "(filter (λy. y < dim_col A) (map snd (pivot_positions A))) =  (map snd (pivot_positions A))"
    by (meson filter_True in_set_member)
  then have h2a: "length (map ((!) (cols A)) (filter (λy. y < dim_col A) (map snd (pivot_positions A)))) = length (pivot_positions A)"
    using length_asm
    by (simp add: h2b) 
  (* The column dimensions of both sides agree. *)
  then have h2: "length (pivot_positions A) ≤ dim_row A ⟹
    dim_col (take_cols A (map snd (pivot_positions A))) = (length (pivot_positions A))" 
    unfolding take_cols_def using mat_of_cols_carrier by auto
  have h_len: "length (pivot_positions A) ≤ dim_row A ⟹
    dim_col (take_cols A (map snd (pivot_positions A))) =
    dim_col
     (1⇩m (length (pivot_positions A)) @⇩r
      0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A)))" 
    using h1 h2
    by (simp add: h1 assms length_pivot_positions_dim_row)
  (* Pointwise equality.  Rows i < length (pivot_positions A) of the selected
     pivot columns form the identity (h1a); the remaining rows are zero (h1b). *)
  have h2: "⋀i j. length (pivot_positions A) ≤ dim_row A ⟹
           i < dim_row A ⟹
           j < dim_col
                (1⇩m (length (pivot_positions A)) @⇩r
                 0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) ⟹
           take_cols A (map snd (pivot_positions A)) $$ (i, j) =
           (1⇩m (length (pivot_positions A)) @⇩r
            0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j)" 
  proof -
    fix i 
    fix j 
    let ?pp = "(pivot_positions A)"
    assume len_lt: "length (pivot_positions A) ≤ dim_row A" 
    assume i_lt: " i < dim_row A" 
    assume j_lt: "j < dim_col
                (1⇩m (length (pivot_positions A)) @⇩r
                 0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A)))"
    (* ?w is the column index of the j-th pivot. *)
    let ?w = "((map snd (pivot_positions A)) ! j)"
    have breaking_it_down: "mat_of_cols (dim_row A)
     (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j)  
     =  ((cols A) ! ?w) $ i"
      apply (auto)
      by (metis comp_apply h1 i_lt j_lt length_map mat_of_cols_index nth_map) 
    (* Upper part: the j-th pivot sits at position (j, f j), and its column
       has a 1 in row j (pivot_funD(4)) and 0 in every other row
       (pivot_funD(5)). *)
    have h1a: "i < (length ?pp) ⟹ (mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) 
        = (1⇩m (length (pivot_positions A))) $$ (i, j))"
    proof - 
      
      assume "i < (length ?pp)"
      have "∃f. pivot_fun A f ?nc" using row_ech unfolding row_echelon_form_def
        by blast
      then obtain f where "pivot_fun A f ?nc"
        by blast
      have j_nc: "j < (length ?pp)" using j_lt
        by (simp add: h1) 
      then have j_lt_nr: "j < ?nr" using dim_h
        using len_lt by linarith 
      then have is_this_true: "(pivot_positions A) ! j = (j, f j)" 
        using pivot_positions_form pivot_positions(1)[of A ?nr ?nc f]
      proof -
        have "pivot_positions A ! j ∈ set (pivot_positions A)"
          using j_nc nth_mem by blast
        then have "∃n. pivot_positions A ! j = (n, f n) ∧ n < dim_row A ∧ f n ≠ dim_col A"
          using ‹⟦A ∈ carrier_mat (dim_row A) (dim_col A); pivot_fun A f (dim_col A)⟧ ⟹ set (pivot_positions A) = {(i, f i) |i. i < dim_row A ∧ f i ≠ dim_col A}› ‹pivot_fun A f (dim_col A)› by blast
        then show ?thesis
          by (metis (no_types) ‹⋀A. ⟦row_echelon_form A; dim_row A ≤ dim_col A⟧ ⟹ ∀i<length (pivot_positions A). fst (pivot_positions A ! i) = i› dim_h fst_conv j_nc row_ech)
      qed
      then have w_is: "?w = f j"
        by (metis h1 j_lt nth_map snd_conv)
      have h0: "i = j ⟶ ((cols A) ! ?w) $ i = 1" using w_is pivot_funD(4)[of A ?nr f ?nc i]
        by (metis ‹∀a. List.member (map snd (pivot_positions A)) a ⟶ a < dim_col A› ‹i < length (pivot_positions A)› ‹pivot_fun A f (dim_col A)› cols_length i_lt in_set_member length_asm mat_of_cols_cols mat_of_cols_index nth_mem)
      have h1:  "i ≠ j ⟶ ((cols A) ! ?w) $ i = 0" using w_is pivot_funD(5)
        by (metis ‹∀a. List.member (map snd (pivot_positions A)) a ⟶ a < dim_col A› ‹pivot_fun A f (dim_col A)› cols_length h1 i_lt in_set_member j_lt len_lt length_asm less_le_trans mat_of_cols_cols mat_of_cols_index nth_mem)
      show "(mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) 
        = (1⇩m (length (pivot_positions A))) $$ (i, j))" using h0 h1 breaking_it_down
        by (metis ‹i < length (pivot_positions A)› h2 h_len index_one_mat(1) j_lt len_lt) 
    qed
    (* Lower part: rows at or below the pivot count are zero rows of A,
       by row_echelon_form_zero_rows. *)
    have h1b: "i ≥ (length ?pp) ⟹ (mat_of_cols (dim_row A) (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j)  = 0)"
    proof - 
      assume i_gt: "i ≥ (length ?pp)"
      have h0a: "((cols A) ! ((map snd (pivot_positions A)) ! j)) $ i = (row A i) $ ?w"
        by (metis ‹∀a. List.member (map snd (pivot_positions A)) a ⟶ a < dim_col A› cols_length h1 i_lt in_set_member index_row(1) j_lt length_asm mat_of_cols_cols mat_of_cols_index nth_mem)
      have h0b: 
        "take_rows A [0..<length (pivot_positions A)] @⇩r 0⇩m (dim_row A - length (pivot_positions A)) (dim_col A) = A"
        using assms row_echelon_form_zero_rows[of A]
        by blast 
      then have h0c: "(row A i) = 0⇩v (dim_col A)"  using i_gt
        by (smt (verit, best) add_diff_cancel_left' add_diff_cancel_right' add_less_cancel_left dim_col_take_rows 
            dim_row_append_rows i_lt index_zero_mat(2) index_zero_mat(3) le_Suc_ex len_lt nat_less_le nle_le row_append_rows row_zero)
      then show ?thesis using h0a breaking_it_down apply (auto)
        by (metis ‹∀a. List.member (map snd (pivot_positions A)) a ⟶ a < dim_col A› h1 in_set_member index_zero_vec(1) j_lt length_asm nth_mem) 
    qed
    (* Glue the two cases together via append_rows_index. *)
    have h1: " mat_of_cols (dim_row A)
     (map ((!) (cols A)) (map snd (pivot_positions A))) $$ (i, j) =
           (1⇩m (length (pivot_positions A)) @⇩r
            0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j) " using h1a h1b
      by (smt (verit) add_diff_inverse_nat append_rows_index diff_less_mono h1 i_lt index_one_mat(2) index_one_mat(3) index_zero_mat(1) index_zero_mat(2) index_zero_mat(3) j_lt leD len_lt not_le_imp_less)
    then show "take_cols A (map snd (pivot_positions A)) $$ (i, j) =
           (1⇩m (length (pivot_positions A)) @⇩r
            0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))) $$
           (i, j)" 
      unfolding take_cols_def
      by (simp add: h2b)
  qed
  (* Combine the dimension and pointwise facts.  length_pivot_positions_dim_row
     discharges the premise length (pivot_positions A) ≤ dim_row A. *)
  show ?thesis
    unfolding mat_eq_iff
    using length_pivot_positions_dim_row[OF assms(1)] h_len h2 by auto
qed
(* Factorisation of a row echelon matrix A (with dim_col A ≥ dim_row A):
   A equals its pivot columns times its first length (pivot_positions A) rows.
   take_cols_pivot_eq turns the left factor into the identity-over-zero
   block; row_echelon_form_zero_rows identifies the product with A. *)
lemma rref_right_mul:
  assumes "row_echelon_form A"
  assumes "dim_col A ≥ dim_row A"
  shows
    "take_cols A (map snd (pivot_positions A)) * take_rows A [0..<length (pivot_positions A)] = A"
proof -
  from take_cols_pivot_eq[OF assms] have
    1: "take_cols A (map snd (pivot_positions A)) =
    1⇩m (length (pivot_positions A)) @⇩r
    0⇩m (dim_row A - length (pivot_positions A)) (length (pivot_positions A))" .
  (* Multiplying the identity-over-zero block distributes over the row
     append (append_rows_mat_mul): the top rows are kept, the rest are zero. *)
  have 2: "take_cols A (map snd (pivot_positions A)) * take_rows A [0..<length (pivot_positions A)] =
    take_rows A [0..<length (pivot_positions A)]  @⇩r 0⇩m (dim_row A - length (pivot_positions A)) (dim_col A)"
    unfolding 1
    apply (simp add: append_rows_mat_mul)
    by (metis (no_types, lifting) "1" add_right_imp_eq assms dim_col_take_rows dim_row_append_rows dim_row_take_cols index_one_mat(2) index_zero_mat(2) left_mult_one_mat' left_mult_zero_mat' row_echelon_form_zero_rows)
  (* Here "..." abbreviates the right-hand side of fact 2. *)
  from row_echelon_form_zero_rows[OF assms] have "... = A" .
  thus ?thesis
    by (simp add: "2")
qed
context conjugatable_vec_space begin
(* The columns of the n×n identity matrix are linearly independent.
   The identity is its own transpose, so its columns and rows coincide.
   Its determinant is nonzero, which makes the rows independent. *)
lemma lin_indpt_id:
  shows "lin_indpt (set (cols (1⇩m n)::'a vec list))"
proof -
  have *: "set (cols (1⇩m n)) = set (rows (1⇩m n))"
    by (metis cols_transpose transpose_one)
  have "det (1⇩m n) ≠ 0" using det_one by auto
  from det_not_0_imp_lin_indpt_rows[OF _ this]
  have "lin_indpt (set (rows (1⇩m n)))"
    using one_carrier_mat by blast
  thus ?thesis
    by (simp add: *) 
qed
(* Any selection of columns of the identity matrix is linearly independent.
   The selection is a subset of an independent set (lin_indpt_id). *)
lemma lin_indpt_take_cols_id:
  shows "lin_indpt (set (cols (take_cols (1⇩m n) inds)))"
proof - 
  have subset_h: "set (cols (take_cols (1⇩m n) inds)) ⊆ set (cols (1⇩m n)::'a vec list)"
    using cols_take_cols_subset by blast
  then show ?thesis using lin_indpt_id subset_li_is_li by auto
qed
(* The column list of the d×d identity matrix is exactly unit_vecs d.
   Proved by comparing the two lists elementwise. *)
lemma cols_id_unit_vecs:
  shows "cols (1⇩m d) = unit_vecs d"
  unfolding unit_vecs_def list_eq_iff_nth_eq
  by auto
(* The columns of the identity matrix are pairwise distinct.
   Follows from cols_id_unit_vecs and the distinctness of unit_vecs. *)
lemma distinct_cols_id:
  shows "distinct (cols (1⇩m d)::'a vec list)"
  by (simp add: conjugatable_vec_space.cols_id_unit_vecs vec_space.unit_vecs_distinct)
(* Indexing a distinct list ls at distinct, in-bounds positions inds yields
   a distinct list.  Indexing is injective on valid indices (inj_on_nth). *)
lemma distinct_map_nth:
  assumes "distinct ls"
  assumes "distinct inds"
  assumes "⋀j. j ∈ set inds ⟹ j < length ls"
  shows "distinct (map ((!) ls) inds)"
  by (simp add: assms(1) assms(2) assms(3) distinct_map inj_on_nth)
(* Selecting distinct column indices from the identity matrix yields
   distinct columns.  take_cols first filters out indices that are out of
   range; filtering keeps the index list distinct (distinct_filter).
   distinct_map_nth then applies to the columns of the identity. *)
lemma distinct_take_cols_id:
  assumes "distinct inds"
  shows "distinct (cols (take_cols (1⇩m n) inds) :: 'a vec list)"
  unfolding take_cols_def
  apply (subst cols_mat_of_cols)
   apply (auto intro!:  distinct_map_nth simp add: distinct_cols_id)
  using assms distinct_filter by blast
(* For distinct indices, the rank of a column selection of the n×n identity
   equals the number of selected indices below n.  Out-of-range indices are
   dropped by take_cols and do not count.  The selected columns are linearly
   independent and distinct, so lin_indpt_full_rank applies. *)
lemma rank_take_cols:
  assumes "distinct inds"
  shows "rank (take_cols (1⇩m n) inds) = length (filter ((>) n) inds)"
  apply (subst lin_indpt_full_rank[of _ "length (filter ((>) n) inds)"])
     apply (auto simp add: lin_indpt_take_cols_id)
   apply (metis (full_types) index_one_mat(2) index_one_mat(3) length_map mat_of_cols_carrier(1) take_cols_def)
  by (simp add: assms distinct_take_cols_id)
(* Left multiplication by an invertible n×n matrix A preserves the rank of
   an n×nc matrix B.
   - rank (A * B) ≤ rank B always holds (rank_mat_mul_left).
   - With C the inverse of A, B = (C * A) * B, which gives rank B ≤ rank (A * B). *)
lemma rank_mul_left_invertible_mat:
  fixes A::"'a mat"
  assumes "invertible_mat A"
  assumes "A ∈ carrier_mat n n"
  assumes "B ∈ carrier_mat n nc"
  shows "rank (A * B) = rank B"
proof -
  obtain C where C: "inverts_mat A C" "inverts_mat C A"
    using assms invertible_mat_def by blast 
  from C have ceq: "C * A = 1⇩m n"
    by (metis assms(2) carrier_matD(2) index_mult_mat(3) index_one_mat(3) inverts_mat_def)
  then have *:"B = C*A*B"
    using assms(3) by auto
  from rank_mat_mul_left[OF assms(2-3)]
  have **: "rank (A*B) ≤ rank B" .
  have 1: "C ∈ carrier_mat n n" using C ceq
    by (metis assms(2) carrier_matD(1) carrier_matI index_mult_mat(3) index_one_mat(3) inverts_mat_def) 
  have 2: "A * B ∈ carrier_mat n nc" using assms by auto  
  (* Reverse inequality.  Multiplication is associative, so C*A*B is C times (A*B). *)
  have "rank B = rank (C* A * B)" using * by auto
  also have "... ≤ rank (A*B)" using rank_mat_mul_left[OF 1 2]
    using "1" assms(2) assms(3) by auto
  ultimately show ?thesis using ** by auto
qed
(* For an invertible n×n matrix A and distinct indices inds, the selected
   columns have rank equal to the number of indices below n.
   - take_cols A inds = A * take_cols (identity n) inds.
   - Left multiplication by the invertible A preserves rank.
   - rank_take_cols then gives the count. *)
lemma invertible_take_cols_rank:
  fixes A::"'a mat"
  assumes "invertible_mat A"
  assumes "A ∈ carrier_mat n n"
  assumes "distinct inds"
  shows "rank (take_cols A inds) = length (filter ((>) n) inds)"
proof -
  have " A = A * 1⇩m n" using assms(2) by auto
  then have "take_cols A inds = A * take_cols (1⇩m n) inds"
    by (metis assms(2) one_carrier_mat take_cols_mat_mul)
  then have "rank (take_cols A inds) = rank (take_cols (1⇩m n) inds)"
    by (metis assms(1) assms(2) conjugatable_vec_space.rank_mul_left_invertible_mat one_carrier_mat take_cols_carrier_mat)
  thus ?thesis
    by (simp add: assms(3) conjugatable_vec_space.rank_take_cols)
qed
(* Selecting columns never increases rank.
   take_cols R ls = R * take_cols (identity nc) ls, and right multiplication
   cannot raise rank (rank_mat_mul_right). *)
lemma rank_take_cols_leq:
  assumes R:"R ∈ carrier_mat n nc"
  shows "rank (take_cols R ls) ≤ rank R"
proof -
  from take_cols_mat_mul[OF R]
  have "take_cols R ls =  R * take_cols (1⇩m nc) ls"
    by (metis assms one_carrier_mat right_mult_one_mat)
  thus ?thesis
    by (metis assms one_carrier_mat take_cols_carrier_mat vec_space.rank_mat_mul_right)
qed
(* Suppose R factors through a selection of its own columns,
   R = take_cols R ls * B.  Then that selection has rank at least rank R
   (rank_mat_mul_right applied to the factorisation).
   NOTE(review): the carrier assumption R is not used by the proof below. *)
lemma rank_take_cols_geq:
  assumes R:"R ∈ carrier_mat n nc"
  assumes t:"take_cols R ls ∈ carrier_mat n r"
  assumes B:"B ∈ carrier_mat r nc"
  assumes "R = (take_cols R ls) * B"
  shows "rank (take_cols R ls) ≥ rank R"
  by (metis B assms(4) t vec_space.rank_mat_mul_right)
(* For an n×nc matrix R in row echelon form with nc ≥ n, keeping only the
   pivot columns preserves rank.
   - "≤" comes from rank_take_cols_leq.
   - "≥" comes from rank_take_cols_geq applied to the factorisation
     given by rref_right_mul. *)
lemma rref_drop_pivots:
  assumes row_ech: "row_echelon_form R"
  assumes dims: "R ∈ carrier_mat n nc"
  assumes order: "nc ≥ n"
  shows "rank (take_cols R (map snd (pivot_positions R))) = rank R"
proof -
  let ?B = "take_rows R [0..<length (pivot_positions R)]"
  have equa: "R = take_cols R (map snd (pivot_positions R)) * ?B" using assms rref_right_mul
    by (metis carrier_matD(1) carrier_matD(2))
  (* Dimensions needed by rank_take_cols_geq.  The pivot columns form an
     n × length (pivot_positions R) matrix, and ?B has
     length (pivot_positions R) rows and nc columns. *)
  have ex_r: "∃r. take_cols R (map snd (pivot_positions R)) ∈ carrier_mat n r ∧ ?B ∈ carrier_mat r nc"
  proof - 
    have h1:
      "take_cols R (map snd (pivot_positions R)) ∈ carrier_mat n (length (pivot_positions R))"
      using assms
      by (metis in_set_impl_in_set_zip2 length_map rref_pivot_positions take_cols_carrier_mat_strict zip_map_fst_snd)
    have "∃ f. pivot_fun R f nc" using row_ech unfolding row_echelon_form_def using dims
      by blast
    then have "length (pivot_positions R) = card {i. i < n ∧ row R i ≠ 0⇩v nc}"
      using pivot_positions[of R n nc]
      using dims by auto 
    then have "nc ≥ length (pivot_positions R)" using order
      using carrier_matD(1) dims dual_order.trans length_pivot_positions_dim_row row_ech by blast
    then have "dim_col R ≥ length (pivot_positions R)" using dims by auto
    then have h2: "?B ∈ carrier_mat (length (pivot_positions R)) nc" unfolding take_rows_def
      using dims 
      by (smt (verit) atLeastLessThan_iff carrier_matD(2) filter_True le_eq_less_or_eq length_map 
          length_pivot_positions_dim_row less_trans map_nth mat_of_cols_carrier(1) row_ech set_upt transpose_carrier_mat transpose_mat_of_rows) 
    show ?thesis using h1 h2
      by blast
  qed
    
  have "rank R  ≤ rank (take_cols R (map snd (pivot_positions R)))"
    using dims ex_r rank_take_cols_geq[where R = "R", where B = "?B", where ls = "(map snd (pivot_positions R))", where nc = "nc"]
    using equa by blast
  thus ?thesis
    using assms(2) conjugatable_vec_space.rank_take_cols_leq le_antisym by blast
qed
(* On the pivot column indices of gauss_jordan_single A, take_cols A and
   take_cols_var A agree.  All these indices are valid column indices of A,
   so the range filter inside take_cols removes nothing.  After unfolding
   both definitions, simp closes the goal.
   NOTE(review): presumably take_cols_var is the variant without the range
   filter; confirm against its definition. *)
lemma gjs_and_take_cols_var:
  fixes A::"'a mat"
  assumes A:"A ∈ carrier_mat n nc"
  assumes order: "nc ≥ n"
  shows "(take_cols A (map snd (pivot_positions (gauss_jordan_single A)))) = 
  (take_cols_var A (map snd (pivot_positions (gauss_jordan_single A))))"
proof -
  let ?gjs = "(gauss_jordan_single A)"
  have "∀x. List.member (map snd (pivot_positions (gauss_jordan_single A))) x ⟶ x ≤ dim_col A"  
    using rref_pivot_positions gauss_jordan_single(3) carrier_matD(2) gauss_jordan_single(2) in_set_impl_in_set_zip2 in_set_member length_map less_irrefl less_trans not_le_imp_less zip_map_fst_snd
    by (metis (no_types, lifting) carrier_mat_triv)
  (* The range filter used by take_cols leaves the pivot index list unchanged. *)
  then have "(filter (λy. y < dim_col A) (map snd (pivot_positions (gauss_jordan_single A)))) = 
    (map snd (pivot_positions (gauss_jordan_single A)))"
    by (metis (no_types, lifting) A carrier_matD(2) filter_True gauss_jordan_single(2) gauss_jordan_single(3) in_set_impl_in_set_zip2 length_map rref_pivot_positions zip_map_fst_snd)
  then show ?thesis unfolding take_cols_def take_cols_var_def
    by simp
qed
(* Let R = gauss_jordan_single A.  Selecting from A the columns that are
   pivot columns of R preserves the rank of A.
   - R = P * A for an invertible P (gauss_jordan_transform).
   - With a left inverse Pi of P, A = Pi * R, so the selected columns of A
     are Pi times the selected columns of R.
   - Those have rank rank R by rref_drop_pivots.
   - Finally rank R = rank (P * A) = rank A, since P is invertible. *)
lemma gauss_jordan_single_rank:
  fixes A::"'a mat"
  assumes A:"A ∈ carrier_mat n nc"
  assumes order: "nc ≥ n"
  shows "rank (take_cols A (map snd (pivot_positions (gauss_jordan_single A)))) = rank A"
proof -
  let ?R = "gauss_jordan_single A"
  obtain P where P:"P∈Units (ring_mat TYPE('a) n undefined)" and
    i: "?R = P * A" using gauss_jordan_transform[OF A]
    by (metis A carrier_matD(1) fst_eqD gauss_jordan_single_def surj_pair zero_carrier_mat)
  have pcarrier: "P ∈ carrier_mat n n" using P unfolding Units_def
    by (auto simp add: ring_mat_def)
  have "invertible_mat P" using P unfolding invertible_mat_def Units_def inverts_mat_def
    apply auto
     apply (simp add: ring_mat_simps(5))
    by (metis index_mult_mat(2) index_one_mat(2) ring_mat_simps(1) ring_mat_simps(3))
  (* Obtain an invertible left inverse Pi of P. *)
  then
  obtain Pi where Pi: "invertible_mat Pi" "Pi * P = 1⇩m n"
  proof -
    assume a1: "⋀Pi. ⟦invertible_mat Pi; Pi * P = 1⇩m n⟧ ⟹ thesis"
    have "dim_row P = n"
      by (metis (no_types) A assms(1) carrier_matD(1) gauss_jordan_single(2) i index_mult_mat(2))
    then show ?thesis
      using a1 by (metis (no_types) ‹invertible_mat P› index_mult_mat(3) index_one_mat(3) invertible_mat_def inverts_mat_def square_mat.simps)
  qed
  then have pi_carrier:"Pi ∈ carrier_mat n n"
    by (metis carrier_mat_triv index_mult_mat(2) index_one_mat(2) invertible_mat_def square_mat.simps)
  have R1:"row_echelon_form ?R"
    using assms(2) gauss_jordan_single(3) by blast
  have R2: "?R ∈ carrier_mat n nc"
    using A assms(2) gauss_jordan_single(2) by auto
  have Rcm: "take_cols ?R (map snd (pivot_positions ?R))
    ∈ carrier_mat n (length (map snd (pivot_positions ?R)))"
    apply (rule take_cols_carrier_mat_strict[OF R2])
    using rref_pivot_positions[OF R1 R2] by auto
  have "Pi * ?R = A" using i Pi
    by (smt (verit, best) A assoc_mult_mat left_mult_one_mat pcarrier pi_carrier)
  (* Rewrite A as Pi * R.  Column selection commutes with left
     multiplication (take_cols_mat_mul), and Pi does not change rank. *)
  then have "rank (take_cols A (map snd (pivot_positions ?R))) = rank (take_cols (Pi * ?R) (map snd (pivot_positions ?R)))"
    by auto
  also have "... = rank ( Pi * take_cols ?R (map snd (pivot_positions ?R)))"
    by (metis A gauss_jordan_single(2) pi_carrier take_cols_mat_mul)
  also have "... = rank (take_cols ?R (map snd (pivot_positions ?R)))"
    by (intro rank_mul_left_invertible_mat[OF Pi(1) pi_carrier Rcm])
  also have "... = rank ?R"
    using assms(2) conjugatable_vec_space.rref_drop_pivots gauss_jordan_single(3)
    using R1 R2 by blast
  (* It remains to show rank R = rank A, using R = P * A with P invertible. *)
  ultimately show ?thesis                                                            
    using A ‹P ∈ carrier_mat n n› ‹invertible_mat P› conjugatable_vec_space.rank_mul_left_invertible_mat i
    by auto
qed
(* Any subset B of the columns of an invertible n×n matrix A is linearly
   independent.
   - det A ≠ 0, so the rows of the transpose of A are independent.
   - Those rows are the columns of A.
   - A subset of an independent set is independent (subset_li_is_li). *)
lemma lin_indpt_subset_cols:
  fixes A:: "'a mat"
  fixes B:: "'a vec set"
  assumes "A ∈ carrier_mat n n"
  assumes inv: "invertible_mat A"
  assumes "B ⊆ set (cols A)"
  shows "lin_indpt B"
proof -
  have "det A ≠ 0"
    using assms(1) inv invertible_det by blast
  then have "lin_indpt (set (rows A⇧T))"
    using assms(1) idom_vec.lin_dep_cols_imp_det_0 by auto
  thus ?thesis using subset_li_is_li assms(3)
    by auto
qed
(* Take a distinct list B of columns of an invertible n×n matrix A and form
   the matrix with these columns.  That matrix has rank length B.
   - B is linearly independent (lin_indpt_subset_cols).
   - So set B is its own maximal independent subset of the columns.
   - rank_card_indpt then gives rank = card (set B), which equals length B
     because B is distinct. *)
lemma rank_invertible_subset_cols:
  fixes A:: "'a mat"
  fixes B:: "'a vec list"
  assumes inv: "invertible_mat A"
  assumes A_square: "A ∈ carrier_mat n n"
  assumes set_sub: "set (B) ⊆ set (cols A)"
  assumes dist_B: "distinct B"
  shows "rank (mat_of_cols n B) = length B"
proof - 
  let ?B_mat = "(mat_of_cols n B)"
  have h1: "lin_indpt (set(B))" 
    using assms lin_indpt_subset_cols[of A] by auto
  (* All vectors in B have dimension n, so mat_of_cols n B returns B as its columns. *)
  have "set B ⊆ carrier_vec n"
    using set_sub A_square cols_dim[of A] by auto
  then have cols_B: "cols (mat_of_cols n B) = B" using cols_mat_of_cols by auto
  then have "maximal (set B) (λT. T ⊆ set (B) ∧ lin_indpt T)" using h1
    by (simp add: maximal_def subset_antisym)
  then have h2: "maximal (set B) (λT. T ⊆ set (cols (mat_of_cols n B)) ∧ lin_indpt T)"
    using cols_B by auto
  have h3: "rank (mat_of_cols n B) = card (set B)"
    using h1 h2 rank_card_indpt[of ?B_mat]
    using mat_of_cols_carrier(1) by blast 
  then show ?thesis using assms distinct_card by auto
qed
end
end