(* Theory LLL_Basis_Reduction.Gram_Schmidt_2 *)

(*
    Authors:    Ralph Bottesch
                Jose Divasón
                Maximilian Haslbeck
                Sebastiaan Joosten
                René Thiemann
                Akihisa Yamada
    License:    BSD
*)

section ‹Gram-Schmidt›

theory Gram_Schmidt_2
  imports 
    Jordan_Normal_Form.Gram_Schmidt
    Jordan_Normal_Form.Show_Matrix
    Jordan_Normal_Form.Matrix_Impl
    Norms
    Int_Rat_Operations
begin
(* TODO: add documentation and references to the computer algebra book *)

(* Remove the HOL-Algebra notation for the group inverse Group.m_inv,
   presumably so that inv is free for other uses in this theory.
   NOTE(review): confirm the motivation. *)
no_notation Group.m_inv  ("invı _" [81] 80)

(* TODO: Is a function like this already in the library?
   find_index is used to rewrite the sumlists in the lattice_of definition into finsums. *)

(* Position (counting from 0) of the first occurrence of y in the list.
   If y does not occur, the result is the length of the list. *)
fun find_index :: "'b list  'b  nat" where
  "find_index [] _ = 0" |
  "find_index (x#xs) y = (if x = y then 0 else find_index xs y + 1)"

(* For an element that does not occur, find_index returns the list length. *)
lemma find_index_not_in_set: "x  set xs  find_index xs x = length xs"
  by (induction xs) auto

(* For a member x, the list entry at position find_index xs x is x itself. *)
lemma find_index_in_set: "x  set xs  xs ! (find_index xs x) = x"
  by (induction xs) auto

(* Distinct members of xs get distinct indices. *)
lemma find_index_inj: "inj_on (find_index xs) (set xs)"
  by (induction xs) (auto simp add: inj_on_def)

(* Relates the strict bound find_index xs x < length xs to membership of x
   in xs (presumably an equivalence; confirm the connective).
   Despite the name (leq), the bound is strict. *)
lemma find_index_leq_length: "find_index xs x < length xs  x  set xs"
  by (induction xs) (auto)


(* TODO: move *)

(* Moves the head r of the right operand into the reversed prefix.
   Presumably named unsimp because simp normalizes rev (r#xs) in the
   opposite direction. *)
lemma rev_unsimp: "rev xs @ (r # rs) = rev (r#xs) @ rs" by(induct xs,auto)


(* TODO: unify *)

(* Over a trivially conjugatable field, corthogonal and orthogonal coincide;
   the proof just unfolds both definitions and simplifies. *)
lemma corthogonal_is_orthogonal[simp]: 
  "corthogonal (xs :: 'a :: trivial_conjugatable_ordered_field vec list) = orthogonal xs"
  unfolding corthogonal_def orthogonal_def by simp


(* TODO: move *)

context vec_module begin

(* The integer lattice generated by the list fs: all list sums of
   of_int (c i) times fs ! i over the indices 0 ..< length fs, for an
   arbitrary integer coefficient function c. *)
definition lattice_of :: "'a vec list  'a vec set" where
  "lattice_of fs = range (λ c. sumlist (map (λ i. of_int (c i) v fs ! i) [0 ..< length fs]))"

(* The same set, with the list sum replaced by a finsum over the index set
   {0 ..< length fs}. Requires every generator to lie in carrier_vec n. *)
lemma lattice_of_finsum:
  assumes "set fs  carrier_vec n"
  shows "lattice_of fs = range (λ c. finsum V (λ i. of_int (c i) v fs ! i) {0 ..< length fs})"
proof -
  have "sumlist (map (λ i. of_int (c i) v fs ! i) [0 ..< length fs])
        = finsum V (λ i. of_int (c i) v fs ! i) {0 ..< length fs}" for c
    using  assms by (subst sumlist_map_as_finsum) (fastforce)+
  then show ?thesis
    unfolding lattice_of_def by auto
qed

(* Elimination rule: a lattice element is the list sum of some integer
   combination c of the generators. *)
lemma in_latticeE: assumes "f  lattice_of fs" obtains c where
    "f = sumlist (map (λ i. of_int (c i) v fs ! i) [0 ..< length fs])" 
  using assms unfolding lattice_of_def by auto
    
(* Introduction rule: every integer combination of the generators lies in
   the lattice. *)
lemma in_latticeI: assumes "f = sumlist (map (λ i. of_int (c i) v fs ! i) [0 ..< length fs])" 
  shows "f  lattice_of fs" 
  using assms unfolding lattice_of_def by auto

(* A finsum over the indices 0 ..< l (with l = length vs) of integer
   multiples of vs ! x can be rewritten as a finsum over the set of
   vectors of vs, for a suitable integer coefficient function c.
   The proof is by induction on l, splitting off the last element v of vs:
   - if v already occurs earlier, g l is added to v's existing coefficient;
   - otherwise v receives the coefficient g l. *)
lemma finsum_over_indexes_to_vectors:
  assumes "set vs  carrier_vec n" "l = length vs"
  shows "c. (Vx{0..<l}. of_int (g x) v vs ! x) = (Vvset vs. of_int (c v) v v)"
  using assms proof (induction l arbitrary: vs)
  case (Suc l)
  (* write vs = vs' @ [v] and apply the induction hypothesis to vs' *)
  then obtain vs' v where vs'_def: "vs = vs' @ [v]"
    by (metis Zero_not_Suc length_0_conv rev_exhaust)
  have c: "c. (Vi{0..<l}. of_int (g i) v vs' ! i) = (Vvset vs'. of_int (c v) v v)"
    using Suc vs'_def by (auto)
  then obtain c 
    where c_def: "(Vx{0..<l}. of_int (g x) v vs' ! x) = (Vvset vs'. of_int (c v) v v)"
    by blast
  have "(Vx{0..<Suc l}. of_int (g x) v vs ! x) 
        = of_int (g l) v vs ! l + (Vx{0..<l}. of_int (g x) v vs ! x)"
     using Suc by (subst finsum_insert[symmetric]) (fastforce intro!: finsum_cong')+
  also have "vs = vs' @ [v]"
    using vs'_def by simp
  also have "(Vx{0..<l}. of_int (g x) v (vs' @ [v]) ! x) = (Vx{0..<l}. of_int (g x) v vs' ! x)"
    using Suc vs'_def by (intro finsum_cong') (auto simp add: in_mono append_Cons_nth_left)
  also note c_def
  also have "(vs' @ [v]) ! l = v"
    using Suc vs'_def by auto
  (* absorb the last summand into a coefficient function over set vs *)
  also have "d'. of_int (g l) v v + (Vvset vs'. of_int (c v) v v) = (Vvset vs. of_int (d' v) v v)"
  proof (cases "v  set vs'")
    case True
    (* v already occurs in vs': merge g l into the coefficient of v *)
    then have I: "set vs' = insert v (set vs' - {v})"
      by blast
    define c' where "c' x = (if x = v then c x + g l else c x)" for x
    have "of_int (g l) v v + (Vvset vs'. of_int (c v) v v)
          = of_int (g l) v v + (of_int (c v) v v + (Vvset vs' - {v}. of_int (c v) v v))"
      using Suc vs'_def by (subst I, subst finsum_insert) fastforce+
    also have " = of_int (g l) v v + of_int (c v) v v + (Vvset vs' - {v}. of_int (c v) v v)"
      using Suc vs'_def by (subst a_assoc) (auto intro!: finsum_closed)
    also have "of_int (g l) v v + of_int (c v) v v = of_int (c' v)  v v"
      unfolding c'_def by (auto simp add: add_smult_distrib_vec)
    also have "(Vvset vs' - {v}. of_int (c v) v v) = (Vvset vs' - {v}. of_int (c' v) v v)"
      using Suc vs'_def unfolding c'_def by (intro finsum_cong') (auto)
    also have "of_int (c' v) v v + (Vvset vs' - {v}. of_int (c' v) v v)
               = (Vvinsert v (set vs'). of_int (c' v) v v)"
      using Suc vs'_def by (subst finsum_insert[symmetric]) (auto)
    finally show ?thesis
      using vs'_def by force
  next
    case False
    (* v is new: it gets the coefficient g l, all others keep c *)
    define c' where "c' x = (if x = v then g l else c x)" for x
    have "of_int (g l) v v + (Vvset vs'. of_int (c v) v v)
          = of_int (c' v) v v + (Vvset vs'. of_int (c v) v v)"
      unfolding c'_def by simp
    also have "(Vvset vs'. of_int (c v) v v) = (Vvset vs'. of_int (c' v) v v)"
      unfolding c'_def using Suc False vs'_def by (auto intro!: finsum_cong')
    also have "of_int (c' v) v v + (Vvset vs'. of_int (c' v) v v)
               = (Vvinsert v (set vs'). of_int (c' v) v v)"
      using False Suc vs'_def by (subst finsum_insert[symmetric]) (auto)
    also have "(Vvset vs'. of_int (c' v) v v) = (Vvset vs'. of_int (c v) v v)"
      unfolding c'_def using False Suc vs'_def by (auto intro!: finsum_cong')
    finally show ?thesis
      using vs'_def by auto
  qed
  finally show ?case
    unfolding vs'_def by blast
qed (auto)

(* Characterization of lattice_of as a finsum over the SET of generators
   (instead of over indices), with an integer coefficient per vector. *)
lemma lattice_of_altdef:
  assumes "set vs  carrier_vec n"
  shows "lattice_of vs = range (λc. Vvset vs. of_int (c v) v v)"
proof -
  have "v  lattice_of vs" if "v  range (λc. Vvset vs. of_int (c v) v v)" for v
  proof -
    obtain c where v: "v = (Vvset vs. of_int (c v) v v)"
      using v  range (λc. Vvset vs. of_int (c v) v v) by (auto)
    (* only the first occurrence of each vector (as located by find_index)
       receives its coefficient, so duplicates in vs are not counted twice *)
    define c' where "c' i = (if find_index vs (vs ! i) = i then c (vs ! i) else 0)" for i
    have "v = (Vvset vs. of_int (c' (find_index vs v)) v vs ! (find_index vs v))"
      unfolding v
      using assms by (auto intro!: finsum_cong' simp add: c'_def find_index_in_set in_mono)
    also have " = (Vifind_index vs ` (set vs). of_int (c' i) v vs ! i)"
      using assms find_index_in_set find_index_inj by (subst finsum_reindex) fastforce+
    also have " = (Viset [0..<length vs]. of_int (c' i) v vs ! i)"
    proof -
      have "i  find_index vs ` set vs" if "i < length vs" "find_index vs (vs ! i) = i" for i
        using that by (metis imageI nth_mem)
      then show ?thesis
        unfolding c'_def using find_index_leq_length assms 
        by (intro add.finprod_mono_neutral_cong_left) (auto simp add: in_mono find_index_leq_length)
    qed
    also have " = sumlist (map (λi. of_int (c' i) v vs ! i) [0..<length vs])"
      using assms by (subst sumlist_map_as_finsum) (fastforce)+
    finally show ?thesis
      unfolding lattice_of_def by blast
  qed
  (* converse inclusion via finsum_over_indexes_to_vectors *)
  moreover have "v  range (λc. Vvset vs. of_int (c v) v v)" if "v  lattice_of vs" for v
  proof -
    obtain c where "v = sumlist (map (λi. of_int (c i) v vs ! i) [0..<length vs])"
      using v  lattice_of vs unfolding lattice_of_def by (auto)
    also have " = (Vx{0..<length vs}. of_int (c x) v vs ! x)"
      using that assms by (subst sumlist_map_as_finsum) fastforce+
    also obtain d where  " = (Vvset vs. of_int (d v) v v)"
      using finsum_over_indexes_to_vectors assms by blast
    finally show ?thesis
      by blast
  qed
  ultimately show ?thesis
    by fastforce
qed

(* Every generator lies in the lattice: take coefficient 1 for f and 0 for
   every other vector. *)
lemma basis_in_latticeI:
  assumes fs: "set fs  carrier_vec n" and "f  set fs" 
  shows "f  lattice_of fs"
proof -
  define c :: "'a vec  int" where "c v = (if v = f then 1 else 0)" for v
  have "f = (Vv{f}. of_int (c v) v v)"
    using assms by (auto simp add: c_def)
  also have " = (Vvset fs. of_int (c v) v v)"
    using assms by (intro add.finprod_mono_neutral_cong_left) (auto simp add: c_def)
  finally show ?thesis
    using assms lattice_of_altdef by blast
qed

(* The lattice depends only on the set of generators (via lattice_of_altdef). *)
lemma lattice_of_eq_set:
  assumes "set fs = set gs" "set fs  carrier_vec n"
  shows "lattice_of fs = lattice_of gs"
  using assms lattice_of_altdef by simp

(* Swapping two generators does not change the lattice. *)
lemma lattice_of_swap: assumes fs: "set fs  carrier_vec n" 
  and ij: "i < length fs" "j < length fs" "i  j" 
  and gs: "gs = fs[ i := fs ! j, j := fs ! i]" 
shows "lattice_of gs = lattice_of fs"
  using assms mset_swap by (intro lattice_of_eq_set) auto

(* Adding an integer multiple l of generator j to generator i (i and j
   distinct) does not change the lattice. The proof has three steps:
   1. For i < j, show the inclusion lattice_of fs into the updated lattice.
   2. Turn it into an equality by undoing the update with -l.
   3. Reduce the case j < i to i < j by swapping. *)
lemma lattice_of_add: assumes fs: "set fs  carrier_vec n" 
  and ij: "i < length fs" "j < length fs" "i  j" 
  and gs: "gs = fs[ i := fs ! i + of_int l v fs ! j]" 
shows "lattice_of gs = lattice_of fs"
proof -
  (* main: for i < j, lattice_of fs is contained in the lattice of ?gs *)
  {
    fix i j l and fs :: "'a vec list" 
    assume *: "i < j" "j < length fs" and fs: "set fs  carrier_vec n"
    (* NOTE(review): ij(1) is the outer assumption about the outer i and fs,
       which the fix above shadows; it looks unrelated to this inner goal.
       Confirm before removing it. *)
    note * = ij(1) *
    let ?gs = "fs[ i := fs ! i + of_int l v fs ! j]"
    let ?len = "[0..<i] @ [i] @ [Suc i..<j] @ [j] @ [Suc j..<length fs]" 
    have "[0 ..< length fs] = [0 ..< j] @ [j] @ [Suc j ..< length fs]" using *
      by (metis append_Cons append_self_conv2 less_Suc_eq_le less_imp_add_positive upt_add_eq_append 
          upt_conv_Cons zero_less_Suc)
    also have "[0 ..< j] = [0 ..< i] @ [i] @ [Suc i ..< j]" using *
      by (metis append_Cons append_self_conv2 less_Suc_eq_le less_imp_add_positive upt_add_eq_append 
          upt_conv_Cons zero_less_Suc)
    finally have len: "[0..<length fs] = ?len" by simp
    from fs have fs: " i. i < length fs  fs ! i  carrier_vec n" unfolding set_conv_nth by auto
    from fs have fsd: " i. i < length fs  dim_vec (fs ! i) = n" by auto
    from fsd[of i] fsd[of j] * have fsd: "dim_vec (fs ! i) = n" "dim_vec (fs ! j) = n" by auto
    {
      fix f
      assume "f  lattice_of fs" 
      from in_latticeE[OF this, unfolded len] obtain c where
        f: "f = sumlist (map (λi. of_int (c i) v fs ! i) ?len)" by auto
      define sc where "sc = (λ xs. sumlist (map (λi. of_int (c i) v fs ! i) xs))"
      (* new coefficients: index i now also contributes c i * l times fs ! j,
         which is compensated at index j *)
      define d where "d = (λ k. if k = j then c j - c i * l else c k)"
      define sd where "sd = (λ xs. sumlist (map (λi. of_int (d i) v ?gs ! i) xs))"
      have isc: "set is  {0 ..< length fs}  sc is  carrier_vec n" for "is" 
        unfolding sc_def by (intro sumlist_carrier, auto simp: fs)
      have isd: "set is  {0 ..< length fs}  sd is  carrier_vec n" for "is" 
        unfolding sd_def using * by (intro sumlist_carrier, auto, rename_tac k,
        case_tac "k = i", auto simp: fs)
      let ?a = "sc [0..<i]" let ?b = "sc [i]" let ?c = "sc [Suc i ..< j]" let ?d = "sc [j]" 
      let ?e = "sc [Suc j ..< length fs]" 
      let ?A = "sd [0..<i]" let ?B = "sd [i]" let ?C = "sd [Suc i ..< j]" let ?D = "sd [j]" 
      let ?E = "sd [Suc j ..< length fs]" 
      let ?CC = "carrier_vec n" 
      have ae: "?a  ?CC" "?b  ?CC" "?c  ?CC" "?d  ?CC" "?e  ?CC"  
        using * by (auto intro: isc)
      have AE: "?A  ?CC" "?B  ?CC" "?C  ?CC" "?D  ?CC" "?E  ?CC"  
        using * by (auto intro: isd)
      (* index ranges avoiding i and j are unaffected by the change *)
      have sc_sd: "{i,j}  set is  {}  sc is = sd is" for "is" 
        unfolding sc_def sd_def by (rule arg_cong[of _ _ sumlist], rule map_cong, auto simp: d_def,
        rename_tac k, case_tac "i = k", auto)
      have "f = ?a + (?b + (?c + (?d + ?e)))"         
        unfolding f map_append sc_def using fs *
        by ((subst sumlist_append, force, force)+, simp)
      also have " = ?a + ((?b + ?d) + (?c + ?e))" using ae by auto          
      also have " = ?A + ((?b + ?d) + (?C + ?E))" 
        using * by (auto simp: sc_sd)
      also have "?b + ?d = ?B + ?D" unfolding sd_def sc_def d_def sumlist_def
        by (rule eq_vecI, insert * fsd, auto simp: algebra_simps)
      finally have "f = ?A + (?B + (?C + (?D + ?E)))" using AE by auto
      also have " = sumlist (map (λi. of_int (d i) v ?gs ! i) ?len)" 
        unfolding f map_append sd_def using fs *
        by ((subst sumlist_append, force, force)+, simp)
      also have " = sumlist (map (λi. of_int (d i) v ?gs ! i) [0 ..< length ?gs])"
        unfolding len[symmetric] by simp
      finally have "f = sumlist (map (λi. of_int (d i) v ?gs ! i) [0 ..< length ?gs])" .
      from in_latticeI[OF this] have "f  lattice_of ?gs" .
    }
    hence "lattice_of fs  lattice_of ?gs" by blast
  } note main = this 
  (* for i < j: equality, using main once with l and once with -l on the
     updated list, which restores fs *)
  {
    fix i j and fs :: "'a vec list" 
    assume *: "i < j" "j < length fs" and fs: "set fs  carrier_vec n"
    let ?gs = "fs[ i := fs ! i + of_int l v fs ! j]"
    define gs where "gs = ?gs" 
    from main[OF * fs, of l, folded gs_def]
    have one: "lattice_of fs  lattice_of gs" .
    have *: "i < j" "j < length gs" "set gs  carrier_vec n" using * fs unfolding gs_def set_conv_nth
      by (auto, rename_tac k, case_tac "k = i", (force intro!: add_carrier_vec)+)
    from fs have fs: " i. i < length fs  fs ! i  carrier_vec n" unfolding set_conv_nth by auto
    from fs have fsd: " i. i < length fs  dim_vec (fs ! i) = n" by auto
    from fsd[of i] fsd[of j] * have fsd: "dim_vec (fs ! i) = n" "dim_vec (fs ! j) = n" by (auto simp: gs_def)
    from main[OF *, of "-l"]
    have "lattice_of gs  lattice_of (gs[i := gs ! i + of_int (- l) v gs ! j])" .
    also have "gs[i := gs ! i + of_int (- l) v gs ! j] = fs" unfolding gs_def
      by (rule nth_equalityI, auto, insert * fsd, rename_tac k, case_tac "k = i", auto)
    ultimately have "lattice_of fs = lattice_of ?gs" using one unfolding gs_def by auto
  } note main = this
  show ?thesis
  proof (cases "i < j")
    case True
    from main[OF this ij(2) fs] show ?thesis unfolding gs by simp
  next
    case False
    (* j < i: swap i and j, apply the i < j case, then swap back *)
    with ij have ji: "j < i" by auto
    define hs where "hs = fs[i := fs ! j, j := fs ! i]" 
    define ks where "ks = hs[j := hs ! j + of_int l v hs ! i]" 
    from ij fs have ij': "i < length hs" "set hs  carrier_vec n" unfolding hs_def by auto
    hence ij'': "set ks  carrier_vec n" "i < length ks" "j < length ks" "i  j" 
      using ji unfolding ks_def set_conv_nth by (auto, rename_tac k, case_tac "k = i", 
        force, case_tac "k = j", (force intro!: add_carrier_vec)+)
    from lattice_of_swap[OF fs ij refl] 
    have "lattice_of fs = lattice_of hs" unfolding hs_def by auto
    also have " = lattice_of ks" 
      using main[OF ji ij'] unfolding ks_def .
    also have " = lattice_of (ks[i := ks ! j, j := ks ! i])" 
      by (rule sym, rule lattice_of_swap[OF ij'' refl])
    also have "ks[i := ks ! j, j := ks ! i] = gs" unfolding gs ks_def hs_def
      by (rule nth_equalityI, insert ij, auto, 
       rename_tac k, case_tac "k = i", force, case_tac "k = j", auto)
    finally show ?thesis by simp
  qed
qed

(* Orthogonal complement of W within carrier_vec n: the vectors of
   dimension n whose scalar product with every element of W is 0. *)
definition "orthogonal_complement W = {x. x  carrier_vec n  (y  W. x  y = 0)}"

(* The orthogonal complement is antitone in W. *)
lemma orthogonal_complement_subset:
  assumes "A  B"
  shows "orthogonal_complement B  orthogonal_complement A"
unfolding orthogonal_complement_def using assms by auto

end

context vec_space
begin


(* The orthogonal complement of span S coincides with that of S.
   - One inclusion is antitonicity, since S lies in its own span.
   - For the other, x is shown orthogonal to every lincomb a A with A a
     finite subset of S, by finite induction over A. *)
lemma in_orthogonal_complement_span[simp]:
  assumes [intro]:"S  carrier_vec n"
  shows "orthogonal_complement (span S) = orthogonal_complement S"
proof
  show "orthogonal_complement (span S)  orthogonal_complement S"
    by(fact orthogonal_complement_subset[OF in_own_span[OF assms]])
  {fix x :: "'a vec"
    fix a fix A :: "'a vec set"
    assume x [intro]:"x  carrier_vec n" and f: "finite A" and S:"A  S"
    assume i0:"yS. x  y = 0"
    have "x  lincomb a A = 0"
      unfolding comm_scalar_prod[OF x lincomb_closed[OF subset_trans[OF S assms]]]
    proof(insert S,atomize(full),rule finite_induct[OF f],goal_cases)
      case 1 thus ?case using assms x by force
    next
      case (2 f F)
      (* inductive step: the new summand a f times f is orthogonal to x by i0 *)
      { assume i:"insert f F  S"
        hence F:"F  S" and f: "f  S" by auto
        from F f assms
        have [intro]:"F  carrier_vec n"
          and fc[intro]:"f  carrier_vec n"
          and [intro]:"x  F  x  carrier_vec n" for x by auto
        have laf:"lincomb a F  x = 0" using F 2 by auto
        have [simp]:"(uF. (a u v u)  x) = 0"
          by(insert laf[unfolded lincomb_def],atomize(full),subst finsum_scalar_prod_sum) auto
        from f i0 have [simp]:"f  x = 0" by (subst comm_scalar_prod) auto
        from lincomb_closed[OF subset_trans[OF i assms]]
        have "lincomb a (insert f F)  x = 0" unfolding lincomb_def
          apply(subst finsum_scalar_prod_sum,force,force)
          using 2(1,2) smult_scalar_prod_distrib[OF fc x] by auto
      } thus ?case by auto
      qed
  }
  thus "orthogonal_complement S  orthogonal_complement (span S)"
    unfolding orthogonal_complement_def span_def by auto
qed

end

context cof_vec_space
begin

(* A list of vectors in carrier_vec n that is distinct and whose set of
   elements is linearly independent. *)
definition lin_indpt_list :: "'a vec list  bool" where
  "lin_indpt_list fs = (set fs  carrier_vec n  distinct fs  lin_indpt (set fs))" 

(* A list of exactly n vectors in carrier_vec n whose span covers all of
   carrier_vec n. *)
definition basis_list :: "'a vec list  bool" where
  "basis_list fs = (set fs  carrier_vec n  length fs = n  carrier_vec n  span (set fs))"

(* The rows of an n x n upper triangular matrix with no zero on the
   diagonal form a linearly independent list. *)
lemma upper_triangular_imp_lin_indpt_list:
  assumes A: "A  carrier_mat n n"
    and tri: "upper_triangular A"
    and diag: "0  set (diag_mat A)"
  shows "lin_indpt_list (rows A)"
  using upper_triangular_imp_distinct[OF assms]
  using upper_triangular_imp_lin_indpt_rows[OF assms] A
  unfolding lin_indpt_list_def by (auto simp: rows_def)

(* A basis_list is distinct, linearly independent, and its set is a basis.
   Distinctness: a duplicate would give card (set fs) < length fs = dim,
   contradicting that a generating set has at least dim elements. *)
lemma basis_list_basis: assumes "basis_list fs" 
  shows "distinct fs" "lin_indpt (set fs)" "basis (set fs)" 
proof -
  from assms[unfolded basis_list_def] 
  have len: "length fs = n" and C: "set fs  carrier_vec n" 
    and span: "carrier_vec n  span (set fs)" by auto
  show b: "basis (set fs)" 
  proof (rule dim_gen_is_basis[OF finite_set C])
    show "card (set fs)  dim" unfolding dim_is_n unfolding len[symmetric] by (rule card_length)
    show "span (set fs) = carrier_vec n" using span C by auto
  qed
  thus "lin_indpt (set fs)" unfolding basis_def by auto  
  show "distinct fs" 
  proof (rule ccontr)
    assume "¬ distinct fs" 
    hence "card (set fs) < length fs" using antisym_conv1 card_distinct card_length by auto
    also have " = dim" unfolding len dim_is_n ..
    finally have "card (set fs) < dim" by auto
    also have "  card (set fs)" using span finite_set[of fs] 
      using b basis_def gen_ge_dim by auto
    finally show False by simp
  qed
qed

(* Every basis_list is a lin_indpt_list. *)
lemma basis_list_imp_lin_indpt_list: assumes "basis_list fs" shows "lin_indpt_list fs" 
  using basis_list_basis[OF assms] assms unfolding lin_indpt_list_def basis_list_def by auto

(* A matrix whose n rows form a basis has nonzero determinant.
   The proof works with the transpose. A nonzero kernel vector v would yield
   a linear combination f of G equal to 0 with f (G ! i) nonzero, which
   contradicts linear independence. *)
lemma basis_det_nonzero:
  assumes db:"basis (set G)" and len:"length G = n"
  shows "det (mat_of_rows n G)  0"
proof -
  have M_car1:"mat_of_rows n G  carrier_mat n n" using assms by auto
  hence M_car:"(mat_of_rows n G)T  carrier_mat n n" by auto
  have li:"lin_indpt (set G)"
   and inc_2:"set G  carrier_vec n"
   and issp:"carrier_vec n = span (set G)"
   and RG_in_carr:"i. i < length G  G ! i  carrier_vec n"
    using assms[unfolded basis_def] by auto
  hence "basis_list G" unfolding basis_list_def using len by auto
  from basis_list_basis[OF this] have di:"distinct G" by auto
  have "det ((mat_of_rows n G)T)  0" unfolding det_0_iff_vec_prod_zero[OF M_car] 
  proof
    assume "v. v  carrier_vec n  v  0v n  (mat_of_rows n G)T *v v = 0v n"
    then obtain v where v:"v  span (set G)"
                          "v  0v n" "(mat_of_rows n G)T *v v = 0v n"
      unfolding issp by blast
    from finite_in_span[OF finite_set inc_2 v(1)] obtain a
      where aA: "v = lincomb a (set G)" by blast
    from v(1,2)[folded issp] obtain i where i:"v $ i  0" "i < n" by fastforce
    hence inG:"G ! i  set G" using len by auto
    have di2: "distinct [0..<length G]" by auto
    define f where "f = (λl. i  set [0..<length G]. if l = G ! i then v $ i else 0)"
    hence f':"f (G ! i) = (ia[0..<n]. if G ! ia = G ! i then v $ ia else 0)"
      unfolding f_def sum.distinct_set_conv_list[OF di2] unfolding len by metis
    from v have "mat_of_cols n G *v v = 0v n"
      unfolding transpose_mat_of_rows by auto
    with mat_of_cols_mult_as_finsum[OF v(1)[folded issp len] RG_in_carr]
    have f:"lincomb f (set G) = 0v n" unfolding len f_def by auto
    note [simp] = list_trisect[OF i(2)[folded len],unfolded len]
    note x = i(2)[folded len]
    (* by distinctness, only index i itself contributes to f (G ! i) *)
    have [simp]:"(x[0..<i]. if G ! x = G ! i then v $ x else 0) = 0"
      by (rule sum_list_0,auto simp: nth_eq_iff_index_eq[OF di less_trans[OF _ x] x])
    have [simp]:"(x[Suc i..<n]. if G ! x = G ! i then v $ x else 0) = 0"
      apply (rule sum_list_0) using nth_eq_iff_index_eq[OF di _ x] len by auto
    from i(1) have "f (G ! i)  0" unfolding f' by auto
    from lin_dep_crit[OF finite_set subset_refl TrueI inG this f]
    have "lin_dep (set G)".
    thus False using li by auto
  qed
  thus det0:"det (mat_of_rows n G)  0" by (unfold det_transpose[OF M_car1])
qed

(* Adding a multiple c of us ! j to a different entry us ! i keeps the list
   linearly independent. Key step: the new entry ?v is not in the span of
   the remaining vectors ?E; otherwise us ! i itself would lie in span ?E. *)
lemma lin_indpt_list_add_vec: assumes  
      i: "j < length us" "i < length us" "i  j" 
   and indep: "lin_indpt_list  us" 
shows "lin_indpt_list (us [i := us ! i + c v us ! j])" (is "lin_indpt_list ?V")
proof -
  from indep[unfolded lin_indpt_list_def] have us: "set us  carrier_vec n" 
    and dist: "distinct us" and indep: "lin_indpt (set us)" by auto
  let ?E = "set us - {us ! i}" 
  let ?us = "insert (us ! i) ?E"
  let ?v = "us ! i + c v us ! j"
  from us i have usi: "us ! i  carrier_vec n" "us ! i  ?E" "us ! i  set us" 
    and usj: "us ! j  carrier_vec n" by auto
  from usi usj have v: "?v  carrier_vec n" by auto
  from subset_li_is_li[OF indep] have indepE: "lin_indpt ?E" by auto
  have Vid: "set ?V = insert ?v ?E" using set_update_distinct[OF dist i(2)] by auto
  have E: "?E  carrier_vec n" using us by auto
  have V: "set ?V  carrier_vec n" using us v unfolding Vid by auto
  from dist i have diff: "us ! i  us ! j" unfolding distinct_conv_nth by auto
  have vspan: "?v  span ?E"
  proof
    assume mem: "?v  span ?E" 
    (* subtracting c times us ! j (which is in ?E) would put us ! i into span ?E *)
    from diff i have "us ! j  ?E" by auto
    hence "us ! j  span ?E" using E by (metis span_mem)
    hence "- c v us ! j  span ?E" using smult_in_span[OF E] by auto
    from span_add1[OF E mem this] have "?v + (- c v us ! j)  span ?E" .
    also have "?v + (- c v us ! j) = us ! i" using usi usj by auto
    finally have mem: "us ! i  span ?E" .
    from in_spanE[OF this] obtain a A where lc: "us ! i = lincomb a A" and A: "finite A" 
      "A  set us - {us ! i}" 
      by auto
    let ?a = "a (us ! i := -1)" let ?A = "insert (us ! i) A" 
    from A have fin: "finite ?A" by auto
    have lc: "lincomb ?a A = us ! i" unfolding lc
      by (rule lincomb_cong, insert A us lc, auto)
    have "lincomb ?a ?A = 0v n" 
      by (subst lincomb_insert2[OF A(1)], insert A us lc usi diff, auto)
    from not_lindepD[OF indep _ _ _ this] A usi 
    show False by auto
  qed
  (* ?v is new with respect to ?E, so independence carries over to insert ?v ?E *)
  hence vmem: "?v  ?E" using span_mem[OF E, of ?v] by auto
  from lin_dep_iff_in_span[OF E indepE v this] vspan 
  have indep1: "lin_indpt (set ?V)" unfolding Vid by auto
  from vmem dist have "distinct ?V" by (metis distinct_list_update)
  with indep1 V show ?thesis unfolding lin_indpt_list_def by auto
qed

(* For an orthogonal list gs and k not exceeding length gs, the scalar
   product of the combinations with coefficients g and h over the first k
   vectors is the sum of g i * h i * (gs ! i . gs ! i).
   Induction on k; the cross terms with the new vector gs ! k vanish by
   orthogonality. *)
lemma scalar_prod_lincomb_orthogonal: assumes ortho: "orthogonal gs" and gs: "set gs  carrier_vec n"
  shows "k  length gs  sumlist (map (λ i. g i v gs ! i) [0 ..< k])  sumlist (map (λ i. h i v gs ! i) [0 ..< k])
  = sum_list (map (λ i. g i * h i * (gs ! i  gs ! i)) [0 ..< k])"
proof (induct k)
  case (Suc k)
  note ortho = orthogonalD[OF ortho]
  let ?m = "length gs" 
  from gs Suc(2) have gsi[simp]: " i. i  k  gs ! i  carrier_vec n" by auto
  from Suc have kn: "k  ?m" and k: "k < ?m" by auto
  let ?v1 = "sumlist (map (λi. g i v gs ! i) [0..<k])" 
  let ?v2 = "(g k v gs ! k)" 
  let ?w1 = "sumlist (map (λi. h i v gs ! i) [0..<k])" 
  let ?w2 = "(h k v gs ! k)" 
  from Suc have id: "[0 ..< Suc k] = [0 ..< k] @ [k]" by simp
  have id: "sumlist (map (λi. g i v gs ! i) [0..<Suc k]) = ?v1 + ?v2"
     "sumlist (map (λi. h i v gs ! i) [0..<Suc k]) = ?w1 + ?w2"
    unfolding id map_append
    by (subst sumlist_append, insert Suc(2), auto)+
  have v1: "?v1  carrier_vec n" by (rule sumlist_carrier, insert Suc(2), auto)
  have v2: "?v2  carrier_vec n" by (insert Suc(2), auto)
  have w1: "?w1  carrier_vec n" by (rule sumlist_carrier, insert Suc(2), auto)
  have w2: "?w2  carrier_vec n" by (insert Suc(2), auto)
  have gsk: "gs ! k  carrier_vec n" by simp
  have v12: "?v1 + ?v2  carrier_vec n" using v1 v2 by auto
  have w12: "?w1 + ?w2  carrier_vec n" using w1 w2 by auto
  (* earlier vectors are orthogonal to gs ! k *)
  have 0: " g h. i < k  (g v gs ! i)  (h v gs ! k) = 0" for i
    by (subst scalar_prod_smult_distrib[OF _ gsk], (insert k, auto)[1],
    subst smult_scalar_prod_distrib[OF _ gsk], (insert k, auto)[1], insert ortho[of i k] k, auto)
  have 1: "?v1  ?w2 = 0" 
    by (subst scalar_prod_left_sum_distrib[OF _ w2], (insert Suc(2), auto)[1], rule sum_list_neutral, 
        insert 0, auto)   
  have 2: "?v2  ?w1 = 0" unfolding comm_scalar_prod[OF v2 w1]
    apply (subst scalar_prod_left_sum_distrib[OF _ v2])
     apply ((insert gs, force)[1])
    apply (rule sum_list_neutral)
    by (insert 0, auto)
  show ?case unfolding id
    unfolding scalar_prod_add_distrib[OF v12 w1 w2]
      add_scalar_prod_distrib[OF v1 v2 w1]
      add_scalar_prod_distrib[OF v1 v2 w2]
      scalar_prod_smult_distrib[OF w2 gsk]
      smult_scalar_prod_distrib[OF gsk gsk]
    unfolding Suc(1)[OF kn]
    by (simp add: 1 2 comm_scalar_prod[OF v2 w1])
qed auto
end


locale gram_schmidt = cof_vec_space n f_ty
  for n :: nat and f_ty :: "'a :: {trivial_conjugatable_linordered_field} itself"
begin

(* Gram matrix of the first k vectors of G: M times M transposed, where M is
   the k x n matrix whose i-th row is G ! i. *)
definition Gramian_matrix where
  "Gramian_matrix G k = (let M = mat k n (λ (i,j). (G ! i) $ j) in M * MT)"

(* For k not exceeding length G, M can equivalently be built as
   mat_of_rows n (take k G). *)
lemma Gramian_matrix_alt_def: "k  length G  
  Gramian_matrix G k = (let M = mat_of_rows n (take k G) in M * MT)"
  unfolding Gramian_matrix_def Let_def
  by (rule arg_cong[of _ _ "λ x. x * xT"], unfold mat_of_rows_def, intro eq_matI, auto)

(* Determinant of the Gram matrix of the first k vectors of G. *)
definition Gramian_determinant where
  "Gramian_determinant G k = det (Gramian_matrix G k)"

(* With k = 0 the Gram matrix is 0 x 0, whose determinant is 1. *)
lemma Gramian_determinant_0 [simp]: "Gramian_determinant G 0 = 1"
  unfolding Gramian_determinant_def Gramian_matrix_def Let_def
  by (simp add: times_mat_def)

(* An orthogonal list of vectors in carrier_vec n is a lin_indpt_list.
   Suppose a linear combination with coefficients lc is 0. Its squared norm
   is then 0, and by orthogonality it equals the sum of
   lc(gs!i)^2 * sq_norm (gs!i). All summands are nonnegative, so each one is
   0. Orthogonality of gs ! i with itself rules out sq_norm (gs!i) = 0,
   hence every coefficient is 0. *)
lemma orthogonal_imp_lin_indpt_list: 
  assumes ortho: "orthogonal gs" and gs: "set gs  carrier_vec n"
  shows "lin_indpt_list gs" 
proof -
  from corthogonal_distinct[of gs] ortho have dist: "distinct gs" by simp
  show ?thesis unfolding lin_indpt_list_def
  proof (intro conjI gs dist finite_lin_indpt2 finite_set)
    fix lc
    assume 0: "lincomb lc (set gs) = 0v n" (is "?lc = _") 
    have lc: "?lc  carrier_vec n" by (rule lincomb_closed[OF gs])
    let ?m = "length gs" 
    from 0 have "0 = ?lc  ?lc" by simp
    also have "?lc = lincomb_list (λi. lc (gs ! i)) gs" 
      unfolding lincomb_as_lincomb_list_distinct[OF gs dist] ..
    also have " = sumlist (map (λi. lc (gs ! i) v gs ! i) [0..< ?m])" 
      unfolding lincomb_list_def by auto 
    also have "   = (i[0..<?m]. (lc (gs ! i) * lc (gs ! i)) * sq_norm (gs ! i))" (is "_ = sum_list ?sum")
      unfolding scalar_prod_lincomb_orthogonal[OF ortho gs le_refl]
      by (auto simp: sq_norm_vec_as_cscalar_prod power2_eq_square)
    finally have sum_0: "sum_list ?sum = 0" ..
    have nonneg: " x. x  set ?sum  x  0" 
      using zero_le_square[of "lc (gs ! i)" for i] sq_norm_vec_ge_0[of "gs ! i" for i] by auto  
    {
      fix x
      assume x: "x  set gs" 
      then obtain i where i: "i < ?m" and x: "x = gs ! i" unfolding set_conv_nth 
        by auto
      hence "lc x * lc x * sq_norm x  set ?sum" by auto
      with sum_list_nonneg_eq_0_iff[of ?sum, OF nonneg] sum_0 
      have "lc x = 0  sq_norm x = 0" by auto
      with orthogonalD[OF ortho, OF i i, folded x]
      have "lc x = 0" by (auto simp: sq_norm_vec_as_cscalar_prod)
    }
    thus "vset gs. lc v = 0" by auto
  qed
qed

(* If v is orthogonal to every element of S, then it is orthogonal to every
   element y of span S. The proof writes y as lincomb a A and distributes
   the scalar product over the finsum. *)
lemma orthocompl_span:
  assumes " x. x  S  v  x = 0" "S  carrier_vec n" and [intro]: "v  carrier_vec n"
  and "y  span S" 
  shows "v  y = 0"
proof -
  {fix a A
   assume "y = lincomb a A" "finite A" "A  S"
   note assms = assms this
   hence [intro!]:"lincomb a A  carrier_vec n" "(λv. a v v v)  A  carrier_vec n" by auto
   have "xA. (a x v x)  v = 0" proof fix x assume "x  A" note assms = assms this
     hence x:"x  S" by auto
     with assms have [intro]:"x  carrier_vec n" by auto
     from assms(1)[OF x] have "x  v = 0" by(subst comm_scalar_prod) force+
     thus "(a x v x)  v = 0"
       apply(subst smult_scalar_prod_distrib) by force+
   qed
   hence "v  lincomb a A = 0" apply(subst comm_scalar_prod) apply force+ unfolding lincomb_def
     apply(subst finsum_scalar_prod_sum) by force+
  }
  thus ?thesis using y  span S unfolding span_def by auto
qed

(* If v is orthogonal to every element of the list S, it is orthogonal to
   sumlist S, since the sum lies in span (set S); see orthocompl_span. *)
lemma orthogonal_sumlist:
  assumes ortho: " x. x  set S  v  x = 0" and S: "set S  carrier_vec n" and v: "v  carrier_vec n"
  shows "v  sumlist S = 0"
  by (rule orthocompl_span[OF ortho S v sumlist_in_span[OF S span_mem[OF S]]])

(* Uniqueness of the projection onto W. If y1 and y2 both lie in W, and both
   x - y1 and x - y2 lie in the orthogonal complement of W, then y1 = y2.
   The scalar products collected in eq1 .. eq4 show sq_norm (y1 - y2) = 0. *)
lemma oc_projection_alt_def:
  assumes carr:"(W::'a vec set)  carrier_vec n" "x  carrier_vec n"
      and alt1:"y1  W" "x - y1  orthogonal_complement W"
      and alt2:"y2  W" "x - y2  orthogonal_complement W"
  shows  "y1 = y2"
proof -
  have carr:"y1  carrier_vec n" "y2  carrier_vec n" "x  carrier_vec n" "- y1  carrier_vec n" 
    "0v n  carrier_vec n"
    using alt1 alt2 carr by auto
  hence "y1 - y2  carrier_vec n" by auto
  note carr = this carr
  from alt1 have "yaW  (x - y1)  ya = 0" for ya
    unfolding orthogonal_complement_def by blast
  hence "(x - y1)  y2 = 0" "(x - y1)  y1 = 0" using alt2 alt1 by auto
  hence eq1:"y1  y2 = x  y2" "y1  y1 = x  y1" using carr minus_scalar_prod_distrib by force+
  from this(1) have eq2:"y2  y1 = x  y2" using carr comm_scalar_prod by force
  from alt2 have "yaW  (x - y2)  ya = 0" for ya
    unfolding orthogonal_complement_def by blast
  hence "(x - y2)  y1 = 0" "(x - y2)  y2 = 0" using alt2 alt1 by auto
  hence eq3:"y2  y2 = x  y2" "y2  y1 = x  y1" using carr minus_scalar_prod_distrib by force+
  with eq2 have eq4:"x  y1 = x  y2" by auto
  have "(y1 - y2)2 = 0" unfolding sq_norm_vec_as_cscalar_prod cscalar_prod_is_scalar_prod using carr
    apply(subst minus_scalar_prod_distrib) apply force+
    apply(subst (0 0) scalar_prod_minus_distrib) apply force+
    unfolding eq1 eq2 eq3 eq4 by auto
  with sq_norm_vec_eq_0[of "(y1 - y2)"] carr have "y1 - y2 = 0v n" by fastforce
  hence "y1 - y2 + y2 = y2" using carr by fastforce
  also have "y1 - y2 + y2 = y1" using carr by auto
  finally show "y1 = y2" .
qed

(* w is an orthogonal-complement projection of v with respect to S when:
   - w lies in carrier_vec n,
   - v - w lies in span S,
   - w is orthogonal to every element of S. *)
definition
  "is_oc_projection w S v = (w  carrier_vec n  v - w  span S  ( u. u  S  w  u = 0))"

(* Such a projection is never longer than v. Expanding gives
     sq_norm v = sq_norm (v - w) + 2 * (w . (v - w)) + sq_norm w,
   where the middle term is 0 by orthocompl_span and sq_norm (v - w) is
   nonnegative. *)
lemma is_oc_projection_sq_norm: assumes "is_oc_projection w S v"
  and S: "S  carrier_vec n" 
  and v: "v  carrier_vec n" 
shows "sq_norm w  sq_norm v" 
proof -
  from assms[unfolded is_oc_projection_def]
  have w: "w  carrier_vec n" 
    and vw: "v - w  span S" and ortho: " u. u  S  w  u = 0" by auto
  have "sq_norm v = sq_norm ((v - w) + w)" using v w 
    by (intro arg_cong[of _ _ sq_norm_vec], auto)
  also have " = ((v - w) + w)  ((v - w) + w)" unfolding sq_norm_vec_as_cscalar_prod
    by simp
  also have " = (v - w)  ((v - w) + w) + w  ((v - w) + w)" 
    by (rule add_scalar_prod_distrib, insert v w, auto)
  also have " = ((v - w)  (v - w) + (v - w)  w) + (w  (v - w) + w  w)" 
    by (subst (1 2) scalar_prod_add_distrib, insert v w, auto)
  also have " = sq_norm (v - w) + 2 * (w  (v - w)) + sq_norm w" 
    unfolding sq_norm_vec_as_cscalar_prod using v w by (auto simp: comm_scalar_prod[of w _ "v - w"])
  also have "  2 * (w  (v - w)) + sq_norm w" using sq_norm_vec_ge_0[of "v - w"] by auto
  also have "w  (v - w) = 0" using orthocompl_span[OF ortho S w vw] by auto
  finally show ?thesis by auto
qed

(* Some projection of fi with respect to S, picked by Hilbert's choice
   operator SOME. The result is only characterized when such a projection
   exists. *)
definition oc_projection where
"oc_projection S fi  (SOME v. is_oc_projection v S fi)"

(* span U is closed under negation: if a = lincomb aa A, then
   - a = lincomb (λ x. - 1 * aa x) A. *)
lemma inv_in_span:
  assumes incarr[intro]:"U  carrier_vec n" and insp:"a  span U"
  shows "- a  span U"
proof -
  from insp[THEN in_spanE] obtain aa A where a:"a = lincomb aa A" "finite A" "A  U" by auto
  with assms have [intro!]:"(λv. aa v v v)  A  carrier_vec n" by auto
  from a(1) have e1:"- a = lincomb (λ x. - 1 * aa x) A" unfolding smult_smult_assoc[symmetric] lincomb_def
    by(subst finsum_smult[symmetric]) force+
  show ?thesis using e1 a span_def by blast
qed

(* Let G be a list of exactly n vectors of dimension n whose span does not
   contain all of carrier_vec n.  Then the n x n matrix with rows G has
   determinant zero.
   Proof sketch:
   - run gauss_jordan on ?A, the transpose of mat_of_rows n G, together with
     the identity ?B;
   - by gauss_jordan_transform the result is B = P * ?A for a unit P;
   - if B were the identity, then ?A *v (P *v v) = v for a vector v outside
     span (set G), so v would be a linear combination of G, a contradiction;
   - a non-identity Gauss-Jordan result then yields det = 0. *)
lemma non_span_det_zero:
  assumes len: "length G = n"
  and nonb:"¬ (carrier_vec n  span (set G))"
  and carr:"set G  carrier_vec n"
  shows "det (mat_of_rows n G) = 0" unfolding det_0_iff_vec_prod_zero
proof -
  let ?A = "(mat_of_rows n G)T" let ?B = "1m n"
  from carr have carr_mat:"?A  carrier_mat n n" "?B  carrier_mat n n" "mat_of_rows n G  carrier_mat n n"
    using len mat_of_rows_carrier(1) by auto
  from carr have g_len:" i. i < length G  G ! i  carrier_vec n" by auto
  (* the witness vector v that is not in span (set G) *)
  from nonb obtain v where v:"v  carrier_vec n" "v  span (set G)" by fast
  hence "v  0v n" using span_zero by auto
  obtain B C where gj:"gauss_jordan ?A ?B = (B,C)" by force
  note gj = carr_mat(1,2) gj
  hence B:"B = fst (gauss_jordan ?A ?B)" by auto
  from gauss_jordan[OF gj] have BC:"B  carrier_mat n n" by auto
  (* B = P * ?A for some unit P; obtain an explicit inverse PI of P *)
  from gauss_jordan_transform[OF gj] obtain P where
   P:"PUnits (ring_mat TYPE('a) n ?B)"  "B = P * ?A" by fast
  hence PC:"P  carrier_mat n n" unfolding Units_def by (simp add: ring_mat_simps)
  from mat_inverse[OF PC] P obtain PI where "mat_inverse P = Some PI" by fast
  from mat_inverse(2)[OF PC this]
  have PI:"P * PI = 1m n" "PI * P = 1m n" "PI  carrier_mat n n" by auto
  (* Key step: Gauss-Jordan does not reduce ?A to the identity.  Otherwise P
     is also a right inverse of ?A, and then v = ?A *v (P *v v) is a linear
     combination of the vectors of G, contradicting the choice of v. *)
  have "B  1m n" proof
    assume "B = ?B"
    hence "?A * P = ?B" unfolding P
      using PC P(2) carr_mat(1) mat_mult_left_right_inverse by blast
    hence "?A * P *v v = v" using v by auto
    hence "?A *v (P *v v) = v" unfolding assoc_mult_mat_vec[OF carr_mat(1) PC v(1)].
    hence v_eq:"mat_of_cols n G *v (P *v v) = v"
      unfolding transpose_mat_of_rows by auto
    have pvc:"P *v v  carrier_vec (length G)" using PC v len by auto
    from mat_of_cols_mult_as_finsum[OF pvc g_len,unfolded v_eq] obtain a where
      "v = lincomb a (set G)" by auto
    hence "v  span (set G)" by (intro in_spanI[OF _ finite_set subset_refl])
    thus False using v by auto
  qed
  (* conclude from the key step via gauss_jordan_check_invertable,
     det_non_zero_imp_unit and det_transpose *)
  with det_non_zero_imp_unit[OF carr_mat(1)] show ?thesis
    unfolding gauss_jordan_check_invertable[OF carr_mat(1,2)] B det_transpose[OF carr_mat(3)]
    by metis
qed

(* Let G be a list of n vectors of dimension n.  This _iff lemma relates
   three properties pairwise:
   - G spans carrier_vec n,
   - det (mat_of_rows n G) is nonzero, and
   - set G is a basis.
   ?q1 relates spanning and the determinant, ?q2 spanning and basis, and
   ?q3 the determinant and basis.
   The three auxiliary facts come from:
   - dc: non_span_det_zero,
   - cb: basis_list_basis,
   - bd: basis_det_nonzero.
   All three statements are then derived from dc, cb and bd. *)
lemma span_basis_det_zero_iff:
assumes "length G = n" "set G  carrier_vec n"
shows "carrier_vec n  span (set G)  det (mat_of_rows n G)  0" (is ?q1)
      "carrier_vec n  span (set G)  basis (set G)" (is ?q2)
      "det (mat_of_rows n G)  0  basis (set G)" (is ?q3)
proof -
  have dc:"det (mat_of_rows n G)  0  carrier_vec n  span (set G)"
    using assms non_span_det_zero by auto
  have cb:"carrier_vec n  span (set G)  basis (set G)" using assms basis_list_basis 
    by (auto simp: basis_list_def)
  have bd:"basis (set G)  det (mat_of_rows n G)  0" using assms basis_det_nonzero by auto
  show ?q1 ?q2 ?q3 using dc cb bd by metis+
qed

(* A linearly independent list of vectors never contains the zero vector.
   lin_indpt_list G yields lin_indpt (set G), and vs_zero_lin_dep states that
   a set containing 0v n is linearly dependent. *)
lemma lin_indpt_list_nonzero:
  assumes "lin_indpt_list G" 
  shows "0v n  set G"
proof-
  from assms[unfolded lin_indpt_list_def] have "lin_indpt (set G)" by auto
  from vs_zero_lin_dep[OF _ this] assms[unfolded lin_indpt_list_def] show zero: "0v n  set G" by auto
qed

(* Orthogonal-complement projections are unique: if a and b are both
   is_oc_projection of v w.r.t. an n-dimensional S, then a = b.
   The proof first shows v - a = v - b via oc_projection_alt_def on span S.
   Both differences lie in span S, and a = v - (v - a) and b = v - (v - b)
   lie in the orthogonal complement.  Cancelling v then gives a = b. *)
lemma is_oc_projection_eq:
  assumes ispr:"is_oc_projection a S v" "is_oc_projection b S v" 
    and carr: "S  carrier_vec n" "v  carrier_vec n"
  shows "a = b"
proof -
  from carr have c2:"span S  carrier_vec n" "v  carrier_vec n" by auto
  have a:"v - (v - a) = a" using carr ispr by auto
  have b:"v - (v - b) = b" using carr ispr by auto
  have "(v - a) = (v - b)" 
    apply(rule oc_projection_alt_def[OF c2])
    using ispr a b unfolding in_orthogonal_complement_span[OF carr(1)]
    unfolding orthogonal_complement_def is_oc_projection_def by auto
  hence "v - (v - a) = v - (v - b)" by metis
  thus ?thesis unfolding a b.
qed



(* adjuster_wit wits w us records the coefficients that adjuster uses.
   Lemma adjuster_wit_small shows that its second component equals
   adjuster n w us.
   For each u in us it computes the coefficient a = (w . u) / sq_norm u and
   pushes a onto the front of the accumulator wits.  It returns:
   - the final accumulator, in which the coefficient for the last element of
     us comes first, and
   - the vector sum of the terms -a . u, which is 0v n for an empty us. *)
fun adjuster_wit :: "'a list  'a vec  'a vec list  'a list × 'a vec"
  where "adjuster_wit wits w [] = (wits, 0v n)"
  |  "adjuster_wit wits w (u#us) = (let a = (w  u)/ sq_norm u in 
            case adjuster_wit (a # wits) w us of (wit, v)
          (wit, -a v u + v))"

(* sub2_wit us ws is a variant of gram_schmidt_sub2 that also records
   witnesses.  us holds the vectors already produced, most recent first.
   For each w in ws it:
   - obtains the coefficient list wit and the vector aw from
     adjuster_wit [] w us,
   - forms the new vector u = aw + w, and
   - continues on the rest of ws with u prepended to us.
   Result: the list of coefficient lists and the list of new vectors u, both
   in the order of ws. *)
fun sub2_wit where
    "sub2_wit us [] = ([], [])"
  | "sub2_wit us (w # ws) =
     (case adjuster_wit [] w us of (wit,aw)  let u = aw + w in
      case sub2_wit (u # us) ws of (wits, vvs)  (wit # wits, u # vvs))"  
    
    
(* main us runs sub2_wit starting from an empty prefix.  The second component
   is the Gram-Schmidt orthogonalization of us (lemma gso_connect).  The first
   component holds the per-vector coefficient lists recorded by adjuster_wit. *)
definition main :: "'a vec list  'a list list × 'a vec list" where 
  "main us = sub2_wit [] us"
end


(* Locale gram_schmidt_fs fixes a dimension n and a list fs of vectors over a
   trivial_conjugatable_linordered_field.  It defines the Gram-Schmidt vectors
   gso and the coefficients mu of fs by mutual recursion.  It then proves that
   the witness-recording functions adjuster_wit, sub2_wit and main compute
   exactly these values, and that they agree with gram_schmidt. *)
locale gram_schmidt_fs = 
  fixes n :: nat and fs :: "'a :: {trivial_conjugatable_linordered_field} vec list"
begin

(* Interpret the gram_schmidt locale for dimension n.
   NOTE(review): presumably this is what makes the vector-module constants
   used below (sumlist, adjuster, adjuster_wit, sub2_wit, main, ...) available
   here -- confirm. *)
sublocale gram_schmidt n "TYPE('a)" .

(* Mutually recursive definition of the Gram-Schmidt data of fs:
   - gso i is fs ! i plus the sum over j < i of (- mu i j) times gso j;
   - mu i j is (fs ! i . gso j) / sq_norm (gso j) for j < i, 1 for i = j,
     and 0 for j > i. *)
fun gso and μ where
  "gso i = fs ! i + sumlist (map (λ j. - μ i j v gso j) [0 ..< i])" 
| "μ i j = (if j < i then (fs ! i  gso j)/ sq_norm (gso j) else if i = j then 1 else 0)" 
    
(* Keep the recursive equations out of the default simp set.  The proofs
   below unfold them explicitly where needed, e.g. gso.simps[of i] and
   μ.simps[of i jj]. *)
declare gso.simps[simp del]
declare μ.simps[simp del]


(* gso j is n-dimensional whenever every fs ! i with i <= j is.
   The proof is by strong induction on j: gso j is fs ! j plus a sumlist of
   scaled earlier gso vectors. *)
lemma gso_carrier'[intro]:
  assumes " i. i  j  fs ! i  carrier_vec n"
  shows "gso j  carrier_vec n"
using assms proof(induct j rule:nat_less_induct[rule_format])
  case (1 j)
  then show ?case unfolding gso.simps[of j] by (auto intro!:sumlist_carrier add_carrier_vec)
qed

(* Correctness of adjuster_wit with respect to gso and mu.  Assume:
   - us lists gso (j-1), ..., gso 0, i.e. map gso (rev [0..<j]);
   - w = fs ! i, with j bounded by n and by i;
   - the accumulator wits already holds mu i j, ..., mu i (i-1).
   Then the vector result a
   - equals adjuster n w us,
   - is n-dimensional, and
   - equals the sum over k < j of (- mu i k) times gso k,
   and the returned coefficient list is mu i 0, ..., mu i (i-1).
   The proof is by induction on us.  In the step, the new coefficient
   (w . u) / (u . u) is identified with mu i jj, where j = Suc jj. *)
lemma adjuster_wit: assumes res: "adjuster_wit wits w us = (wits',a)"
  and w: "w  carrier_vec n"
    and us: " i. i  j  fs ! i  carrier_vec n"
    and us_gs: "us = map gso (rev [0 ..< j])" 
    and wits: "wits = map (μ i) [j ..< i]" 
    and j: "j  n" "j  i" 
    and wi: "w = fs ! i" 
  shows "adjuster n w us = a  a  carrier_vec n  wits' = map (μ i) [0 ..< i] 
      (a = sumlist (map (λj. - μ i j v gso j) [0..<j]))"
  using res us us_gs wits j
proof (induct us arbitrary: wits wits' a j)
  case (Cons u us wits wits' a j)
  note us_gs = Cons(4)
  note wits = Cons(5)
  note jn = Cons(6-7)
  from us_gs obtain jj where j: "j = Suc jj" by (cases j, auto)
  from jn j have jj: "jj  n" "jj < n" "jj  i" "jj < i" by auto
  have zj: "[0 ..< j] = [0 ..< jj] @ [jj]" unfolding j by simp
  have jjn: "[jj ..< i] = jj # [j ..< i]" using jj unfolding j by (metis upt_conv_Cons)
  from us_gs[unfolded zj] have ugs: "u = gso jj" and us: "us = map gso (rev [0..<jj])" by auto
  (* the coefficient computed by adjuster_wit in this step *)
  let ?w = "w  u / (u  u)" 
  have muij: "?w = μ i jj" unfolding μ.simps[of i jj] ugs wi sq_norm_vec_as_cscalar_prod using jj by auto
  have wwits: "?w # wits = map (μ i) [jj..<i]" unfolding jjn wits muij by simp
  obtain wwits b where rec: "adjuster_wit (?w # wits) w us = (wwits,b)" by force
  from Cons(1)[OF this Cons(3) us wwits jj(1,3),unfolded j] have IH: 
     "adjuster n w us = b" "wwits = map (μ i) [0..<i]"
     "b = sumlist (map (λj. - μ i j v gso j) [0..<jj])"
      and b: "b  carrier_vec n" by auto
  from Cons(2)[simplified, unfolded Let_def rec split sq_norm_vec_as_cscalar_prod 
      cscalar_prod_is_scalar_prod]
  have id: "wits' = wwits" and a: "a = - ?w v u + b" by auto
  have 1: "adjuster n w (u # us) = a" unfolding a IH(1)[symmetric] by auto     
  from id IH(2) have wits': "wits' =  map (μ i) [0..<i]" by simp
  have carr:"set (map (λj. - μ i j v gso j) [0..<j])  carrier_vec n"
            "set (map (λj. - μ i j v gso j) [0..<jj])  carrier_vec n" and u:"u  carrier_vec n" 
    using Cons j by (auto intro!:gso_carrier')
  from u b a have ac: "a  carrier_vec n" "dim_vec (-?w v u) = n" "dim_vec b = n" "dim_vec u = n" by auto
  show ?case
    apply (intro conjI[OF 1] ac exI conjI wits')
    unfolding carr a IH zj muij ugs[symmetric] map_append
    apply (subst sumlist_append)
    using Cons.prems j apply force
    using b u ugs IH(3) by auto
qed auto

(* Correctness of sub2_wit with respect to gso and mu.  Assume:
   - us = gso (i-1), ..., gso 0;
   - ws = fs ! i, ..., fs ! (m-1);
   - the relevant fs vectors are n-dimensional, and m is bounded by n.
   If sub2_wit us ws = (wits, vvs), then:
   - vvs agrees with gram_schmidt_sub2 n us ws,
   - vvs = gso i, ..., gso (m-1), and
   - wits lists, for each k in [i..<m], the coefficients mu k 0 .. mu k (k-1).
   The proof is by induction on ws, using the adjuster_wit lemma above for
   each new vector. *)
lemma sub2_wit:
  assumes "set us  carrier_vec n" "set ws  carrier_vec n" "length us + length ws = m" 
    and "ws = map (λ i. fs ! i) [i ..< m]"
    and "us = map gso (rev [0 ..< i])" 
    and us: " j. j < m  fs ! j  carrier_vec n"
    and mn: "m  n" 
  shows "sub2_wit us ws = (wits,vvs)  gram_schmidt_sub2 n us ws = vvs 
     vvs = map gso [i ..< m]  wits = map (λ i. map (μ i) [0..<i]) [i ..< m]"
  using assms(1-6)
proof (induct ws arbitrary: us vvs i wits)
  case (Cons w ws us vs)  
  note us = Cons(3) note wws = Cons(4)
  note wsf' = Cons(6)
  note us_gs = Cons(7)
  from wsf' have "i < m" "i  m" by (cases "i < m", auto)+
  hence i_m: "[i ..< m] = i # [Suc i ..< m]" by (metis upt_conv_Cons)
  from i < m mn have "i < n" "i  n" "i  m" by auto
  hence i_n: "[i ..< n] = i # [Suc i ..< n]" by (metis upt_conv_Cons)
  from wsf' i_m have wsf: "ws = map (λ i. fs ! i) [Suc i ..< m]" 
    and fiw: "fs !  i = w" by auto
  from wws have w: "w  carrier_vec n" and ws: "set ws  carrier_vec n" by auto
  (* the empty accumulator passed to adjuster_wit, written as a mu list *)
  have list: "map (μ i) [i ..< i] = []" by auto
  let ?a = "adjuster_wit [] w us" 
  obtain wit a where a: "?a = (wit,a)" by force
  obtain wits' vv where gs: "sub2_wit ((a + w) # us) ws = (wits',vv)" by force      
  from adjuster_wit[OF a w Cons(8) us_gs list[symmetric] i  n _ fiw[symmetric]] us wws i < m
  have awus: "set ((a + w) # us)  carrier_vec n"  
     and aa: "adjuster n w us = a" "a  carrier_vec n" 
     and aaa: "a = sumlist (map (λj. - μ i j v gso j) [0..<i])"  
     and wit: "wit = map (μ i) [0..<i]" 
    by auto
  (* the newly produced vector is exactly gso i *)
  have aw_gs: "a + w = gso i" unfolding gso.simps[of i] fiw aaa[symmetric] using aa(2) w by auto
  with us_gs have us_gs': "(a + w) # us = map gso (rev [0..<Suc i])" by auto
  from Cons(1)[OF gs awus ws _ wsf us_gs' Cons(8)] Cons(5) 
  have IH: "gram_schmidt_sub2 n ((a + w) # us) ws = vv"  
    and vv: "vv = map gso [Suc i..<m]" 
    and wits': "wits' = map (λi. map (μ i) [0..<i]) [Suc i ..< m]" by auto
  from gs a aa IH Cons(5) 
  have gs_vs: "gram_schmidt_sub2 n us (w # ws) = vs" and vs: "vs = (a + w) # vv" using Cons(2)
    by (auto simp add: Let_def snd_def split:prod.splits)
  from Cons(2)[unfolded sub2_wit.simps a split Let_def gs] have wits: "wits = wit # wits'" by auto
  from vs vv aw_gs have vs: "vs = map gso [i ..< m]" unfolding i_m by auto
  with gs_vs show ?case unfolding wits wit wits' by (auto simp: i_m)
qed auto
  
(* Gram-Schmidt on a prefix of fs.  Let us = take k fs, where
   k <= m = length fs and m <= n, and let vs = snd (main us).  Then vs is
   both gram_schmidt n us and map gso [0..<k].  Obtained from lemma sub2_wit
   with an empty prefix, together with gram_schmidt_code. *)
lemma partial_connect: fixes vs
  assumes "length fs = m" "k  m" "m  n" "set us  carrier_vec n" "snd (main us) = vs" 
  "us = take k fs" "set fs  carrier_vec n"
shows "gram_schmidt n us = vs" 
    "vs = map gso [0..<k]" 
proof -
  have [simp]: "map ((!) fs) [0..<k] = take k fs" using assms(1,2) by (intro nth_equalityI, auto)
  have carr: "j < m  fs ! j  carrier_vec n" for j using assms by auto
  note assms(5)[unfolded main_def]
  have "gram_schmidt_sub2 n [] (take k fs) = vvs  vvs = map gso [0..<k]  wits = map (λi. map (μ i) [0..<i]) [0..<k]"
    if "vvs = snd (sub2_wit [] (take k fs))" "wits = fst (sub2_wit [] (take k fs))" for vvs wits
    using assms that by (intro sub2_wit) (auto)
  with assms main_def
  show "gram_schmidt n us = vs" "vs = map gso [0..<k]" unfolding gram_schmidt_code
    by (auto simp add: main_def case_prod_beta')
qed

(* The vector component of adjuster_wit coincides with adjuster, for any
   initial accumulator.  Here v is the accumulator and a is the vector being
   adjusted (the w of adjuster_wit's definition). *)
lemma adjuster_wit_small:
  "(adjuster_wit v a xs) = (x1,x2)
   (fst (adjuster_wit v a xs) = x1  x2 = adjuster n a xs)"
proof(induct xs arbitrary: a v x1 x2)
  case (Cons a xs)
  then show ?case
    by (auto simp:Let_def sq_norm_vec_as_cscalar_prod split:prod.splits) 
qed auto

(* Relates sub2_wit to gram_schmidt_sub: prepending rev xs to the vectors
   produced by sub2_wit xs us gives rev (gram_schmidt_sub n xs us).  The proof
   first generalizes over the result pair of sub2_wit and inducts on us. *)
lemma sub2: "rev xs @ snd (sub2_wit xs us) = rev (gram_schmidt_sub n xs us)"
proof -
  have "sub2_wit xs us = (x1, x2)  rev xs @ x2 = rev (gram_schmidt_sub n xs us)"
    for x1 x2 xs us
    apply(induct us arbitrary: xs x1 x2)
    by (auto simp:Let_def rev_unsimp adjuster_wit_small split:prod.splits simp del:rev.simps)
  thus ?thesis 
    apply (cases us)
    by (auto simp:Let_def rev_unsimp adjuster_wit_small split:prod.splits simp del:rev.simps)
qed

(* The vector part of main is exactly gram_schmidt (lemma sub2 with xs = []). *)
lemma gso_connect: "snd (main us) = gram_schmidt n us" unfolding main_def gram_schmidt_def
  using sub2[of Nil us] by auto

(* weakly_reduced alpha k: for every pair of consecutive indices i, Suc i
   below k, sq_norm (gso i) is bounded by alpha * sq_norm (gso (Suc i)). *)
definition weakly_reduced :: "'a  nat  bool" 
  (* for k = n, this is the notion of a reduced basis in the book "Modern Computer Algebra" *)
  where "weakly_reduced α k = ( i. Suc i < k  
    sq_norm (gso i)  α * sq_norm (gso (Suc i)))" 
  
(* reduced alpha k: weakly_reduced alpha k, and additionally every
   coefficient mu i j with j < i < k has absolute value at most 1/2. *)
definition reduced :: "'a  nat  bool" 
  (* this is the notion of a reduced basis from the original LLL paper *)
  where "reduced α k = (weakly_reduced α k  
    ( i j. i < k  j < i  abs (μ i j)  1/2))"


end (* gram_schmidt_fs *)


locale gram_schmidt_fs_Rn = gram_schmidt_fs +
  assumes fs_carrier: "set fs  carrier_vec n"
begin

(* m abbreviates the number of vectors in fs. *)
abbreviation (input) m where "m  length fs"

(* M k is the k x k matrix of Gram-Schmidt coefficients, with entry mu i j at
   (i, j).  By the definition of mu, it is lower triangular with 1 on the
   diagonal. *)
definition M where "M k = mat k k (λ (i,j). μ i j)"

(* Every input vector fs ! i with i < m is n-dimensional (from fs_carrier). *)
lemma f_carrier[simp]: "i < m  fs ! i  carrier_vec n" 
  using fs_carrier unfolding set_conv_nth by force

(* Every Gram-Schmidt vector gso i with i < m is n-dimensional
   (gso_carrier' combined with f_carrier). *)
lemma gso_carrier[simp]: "i < m  gso i  carrier_vec n" 
  using gso_carrier' f_carrier by auto

(* Dimension form of gso_carrier, for use as a simp rule. *)
lemma gso_dim[simp]: "i < m  dim_vec (gso i) = n" by auto
(* Dimension form of f_carrier, for use as a simp rule. *)
lemma f_dim[simp]: "i < m  dim_vec (fs ! i) = n" by auto

(* For non-empty fs, the first Gram-Schmidt vector is the first input vector:
   the sum in the defining equation of gso 0 ranges over the empty list. *)
lemma fs0_gso0: "0 < m  fs ! 0 = gso 0" 
  unfolding gso.simps[of 0] using f_dim[of 0] 
  by (cases fs, auto simp add: upt_rec)

lemma fs_by_gso_def : 
assumes i: "i < m"
shows "fs ! i = gso i + M.sumlist (map (λja. μ i ja v gso ja) [0..<i])" (is "_ = _ + ?sum")
proof -
  {
    fix f
    have a: "M.sumlist (map (λja. f ja v gso ja) [0..<i])  carrier_vec n" 
      using gso_carrier i by (intro M.sumlist_carrier, auto)
    hence "dim_vec (M.sumlist (map (λja. f ja v gso ja) [