Theory Storjohann

section ‹Storjohann's basis reduction algorithm (abstract version)›

text ‹This theory contains the soundness proofs of Storjohann's basis
  reduction algorithm, both for the normal variant and for the variant with improved swap order.

  The implementation of Storjohann's version of LLL uses modular operations throughout.
  It is an abstract implementation that is already quite close to the final, executable implementation.
  In particular, the swap operation here is derived from the computation lemma for the swap
  operation in the old, integer-only formalization of LLL.›

theory Storjohann
  imports 
    Storjohann_Mod_Operation
    LLL_Basis_Reduction.LLL_Number_Bounds
    Sqrt_Babylonian.NthRoot_Impl
begin

subsection ‹Definition of algorithm›

text ‹In the definition of the algorithm, the Boolean flag ‹first› determines whether only the first
  vector of the reduced basis should be computed, i.e., a short vector. In that case the modulus can
  be chosen slightly smaller than the modulus required for computing the whole reduced matrix.›

(* Selects from a non-empty list of triples (n, d, i) an element whose fraction n / d is maximal.
   Two candidates are compared by cross-multiplication (n1 * d2 versus n2 * d1), which agrees with
   the order on the rationals only for positive denominators; the correctness lemma
   max_list_rats_with_index below assumes d > 0 for all elements.
   There is no equation for the empty list. *)
fun max_list_rats_with_index :: "(int * int * nat) list  (int * int * nat)" where
  "max_list_rats_with_index [x] = x" |
  "max_list_rats_with_index ((n1,d1,i1) # (n2,d2,i2) # xs) 
     = max_list_rats_with_index ((if n1 * d2  n2 * d1 then (n2,d2,i2) else (n1,d1,i1)) # xs)"

context LLL
begin

(* Base for the moduli: every modulus produced by compute_mod_of_max_gso_norm is a power of
   log_base. *)
definition "log_base = (10 :: int)" 

(* Factor used when computing the modulus: 1 if only the first vector is requested (first holds
   and m is non-zero), and m otherwise. It uses the same case split as g_bnd_mode below. *)
definition bound_number :: "bool  nat" where
  "bound_number first = (if first  m  0 then 1 else m)" 

(* Modulus for a GSO norm bound mn: the power of log_base given by log_ceiling applied to
   max 2 (root_rat_ceiling 2 (mn * (bound_number first + 3)) + 1).
   The proof of lemma compute_mod_of_max_gso_norm treats root_rat_ceiling 2 as the ceiling of the
   square root. That lemma shows that the result is > 1 and satisfies mod_invariant. *)
definition compute_mod_of_max_gso_norm :: "bool  rat  int" where
  "compute_mod_of_max_gso_norm first mn = log_base ^ (log_ceiling log_base (max 2 (
     root_rat_ceiling 2 (mn * (rat_of_nat (bound_number first) + 3)) + 1)))"

(* GSO norm bound depending on the mode. If only the first vector is requested (first and m
   non-zero), only the squared norm of the first GSO vector is bounded by b. Otherwise the full
   bound g_bnd b fs is required. *)
definition g_bnd_mode :: "bool  rat  int vec list  bool" where 
  "g_bnd_mode first b fs = (if first  m  0 then sq_norm (gso fs 0)  b else g_bnd b fs)" 

(* Reads the value d_i from the dmu matrix: d_0 = 1, and for i > 0 it is the diagonal entry
   dmu $$ (i - 1, i - 1). Lemma d_of_main relates it to d fs i. *)
definition d_of where "d_of dmu i = (if i = 0 then 1 :: int else dmu $$ (i - 1, i - 1))"

(* Returns (b, i). Here b is the maximum of d_(i+1) / d_i over all i < m (only i = 0 when first
   holds), and i is an index attaining it, chosen by max_list_rats_with_index. For m = 0 the result
   is (0, 0). Since d_(i+1) / d_i is the squared norm of gso fs i, b is a GSO norm bound
   (lemma compute_max_gso_norm). *)
definition compute_max_gso_norm :: "bool  int mat  rat × nat" where
  "compute_max_gso_norm first dmu = (if m = 0 then (0,0) else 
      case max_list_rats_with_index (map (λ i. (d_of dmu (Suc i), d_of dmu i, i)) [0 ..< (if first then 1 else m)])
      of (num, denom, i)  (of_int num / of_int denom, i))"


context
  fixes p :: int ― ‹the modulus›
    and first :: bool ― ‹only compute first vector of reduced basis›
begin

(* Size-reduces row i against row j in the modular state.
   The multiplier is c = round_num_denom (dmu $$ (i,j)) (d_of dmu (Suc j)). Presumably this is
   round (dmu(i,j) / d_(j+1)) = round (mu i j); cf. the assumption c = round (μ fs i j) in
   basis_reduction_mod_add_row_main. If c = 0, the state is unchanged.
   Otherwise:
   - mfs ! i becomes (mfs ! i - c * mfs ! j) reduced entrywise with symmod p;
   - in row i of dmu, the entries of columns j' up to j are updated by subtracting c * dmu(j,j');
   - column j is kept exact, while the other updated columns are reduced with
     symmod (p * d_j' * d_(j'+1)). *)
definition basis_reduction_mod_add_row :: 
  "int vec list  int mat  nat  nat  (int vec list × int mat)"  where
  "basis_reduction_mod_add_row mfs dmu i j = 
    (let c = round_num_denom (dmu $$ (i,j)) (d_of dmu (Suc j)) in
      (if c = 0 then (mfs, dmu) 
        else (mfs[ i := (map_vec (λ x. x symmod p)) (mfs ! i - c v mfs ! j)], 
             mat m m (λ(i',j'). (if (i' = i  j'  j) 
                then (if j'=j then (dmu $$ (i,j') - c * dmu $$ (j,j')) 
                      else (dmu $$ (i,j') - c * dmu $$ (j,j')) 
                            symmod (p * d_of dmu j' * d_of dmu (Suc j')))
                else (dmu $$ (i',j')))))))"

(* basis_reduction_mod_add_rows_loop mfs dmu i k size-reduces row i against rows
   k - 1, k - 2, ..., 0, in that order, using basis_reduction_mod_add_row. *)
fun basis_reduction_mod_add_rows_loop where
  "basis_reduction_mod_add_rows_loop mfs dmu i 0 = (mfs, dmu)"
| "basis_reduction_mod_add_rows_loop mfs dmu i (Suc j) = (
     let (mfs', dmu') = basis_reduction_mod_add_row mfs dmu i j
      in basis_reduction_mod_add_rows_loop mfs' dmu' i j)" 

(* After a swap at index k, re-reduces the strictly-lower entries in columns k and k - 1 with
   symmod (p * d_j * d_(j+1)). All other entries are unchanged. *)
definition basis_reduction_mod_swap_dmu_mod :: "int mat  nat  int mat" where
  "basis_reduction_mod_swap_dmu_mod dmu k = mat m m (λ(i, j). (
    if j < i  (j = k  j = k - 1) then 
        dmu $$ (i, j) symmod (p * d_of dmu j * d_of dmu (Suc j))
    else dmu $$ (i, j)))"

(* Swaps rows k - 1 and k of mfs and updates dmu accordingly. Per the theory header, the update
   formulas are derived from the swap computation lemma of the integer-only LLL formalization.
   Entries outside rows/columns k - 1 and k are kept. The new diagonal entry at k - 1 and the
   entries below row k in columns k - 1 and k are computed with an integer division by d_k
   (presumably exact; TODO confirm via the soundness lemma for the swap). The result is then passed
   to basis_reduction_mod_swap_dmu_mod to reduce columns k - 1 and k again. *)
definition basis_reduction_mod_swap where
  "basis_reduction_mod_swap mfs dmu k = 
     (mfs[k := mfs ! (k - 1), k - 1 := mfs ! k],
      basis_reduction_mod_swap_dmu_mod (mat m m (λ(i,j). (
      if j < i then
        if i = k - 1 then 
           dmu $$ (k, j)
        else if i = k  j  k - 1 then 
             dmu $$ (k - 1, j)
        else if i > k  j = k then
           ((d_of dmu (Suc k)) * dmu $$ (i, k - 1) - dmu $$ (k, k - 1) * dmu $$ (i, j)) 
              div (d_of dmu k)
        else if i > k  j = k - 1 then
           (dmu $$ (k, k - 1) * dmu $$ (i, j) + dmu $$ (i, k) * (d_of dmu (k-1)))
              div (d_of dmu k)
        else dmu $$ (i, j)
      else if i = j then 
        if i = k - 1 then 
          ((d_of dmu (Suc k)) * (d_of dmu (k-1)) + dmu $$ (k, k - 1) * dmu $$ (k, k - 1)) 
            div (d_of dmu k)
        else (d_of dmu (Suc i))
      else dmu $$ (i, j))
    )) k)" 

(* Recomputes the GSO norm bound b, together with its index g_idx, and the modulus p' it requires.
   If p' < p, the state switches to the smaller modulus:
   - mfs is reduced entrywise with symmod p';
   - strictly-lower dmu entries are reduced with symmod (p' * d_j * d_(j+1)).
   The d values are read through the vector d_vec.
   Result: (modulus, mfs, dmu, g_idx). *)
fun basis_reduction_adjust_mod where
  "basis_reduction_adjust_mod mfs dmu = 
    (let (b,g_idx) = compute_max_gso_norm first dmu;
         p' = compute_mod_of_max_gso_norm first b
        in if p' < p then 
           let mfs' = map (map_vec (λx. x symmod p')) mfs;
               d_vec = vec (Suc m) (λ i. d_of dmu i);
               dmu' = mat m m (λ (i,j). if j < i then dmu $$ (i,j) 
                 symmod (p' * d_vec $ j * d_vec $ (Suc j)) else
                 dmu $$ (i,j))
             in (p', mfs', dmu', g_idx)
           else (p, mfs, dmu, g_idx))" 

(* Swap step at index i: size-reduce row i against row i - 1, then swap rows i - 1 and i.
   The modulus is re-adjusted only if i - 1 is the index g_idx at which the current maximal GSO norm
   was found; otherwise p and g_idx are kept.
   The callers in this theory pass i > 0: basis_reduction_mod_step handles i = 0 separately, and
   compute_max_gso_quot yields indices of the form Suc i. *)
definition basis_reduction_adjust_swap_add_step where
  "basis_reduction_adjust_swap_add_step mfs dmu g_idx i = (
    let i1 = i - 1; 
        (mfs1, dmu1) = basis_reduction_mod_add_row mfs dmu i i1;
        (mfs2, dmu2) = basis_reduction_mod_swap mfs1 dmu1 i
      in if i1 = g_idx then basis_reduction_adjust_mod mfs2 dmu2
         else (p, mfs2, dmu2, g_idx))"


(* One iteration of the main loop at index i; j counts the swaps performed.
   The index is simply incremented if i = 0 or if the Lovasz condition holds. In terms of d values
   the condition reads d_i * d_i * denom <= num * d_(i-1) * d_(i+1), where num / denom is alpha.
   Otherwise a size-reduction plus swap is performed at i, i is decremented and j is incremented.
   Result: (p, mfs, dmu, g_idx, i, j). *)
definition basis_reduction_mod_step where
  "basis_reduction_mod_step mfs dmu g_idx i (j :: int) = (if i = 0 then (p, mfs, dmu, g_idx, Suc i, j)
     else let di = d_of dmu i;
              (num, denom) = quotient_of α
      in if di * di * denom  num * d_of dmu (i - 1) * d_of dmu (Suc i) then
          (p, mfs, dmu, g_idx, Suc i, j)
      else let (p', mfs', dmu', g_idx') = basis_reduction_adjust_swap_add_step mfs dmu g_idx i
          in (p', mfs', dmu', g_idx', i - 1, j + 1))" 

(* basis_reduction_mod_add_rows_outer_loop mfs dmu k size-reduces rows 1, ..., k in increasing
   order. Each row i is reduced against all rows below index i (via
   basis_reduction_mod_add_rows_loop). *)
primrec basis_reduction_mod_add_rows_outer_loop where
  "basis_reduction_mod_add_rows_outer_loop mfs dmu 0 = (mfs, dmu)" |
  "basis_reduction_mod_add_rows_outer_loop mfs dmu (Suc i) = 
    (let (mfs', dmu') = basis_reduction_mod_add_rows_outer_loop mfs dmu i in
      basis_reduction_mod_add_rows_loop mfs' dmu' (Suc i) (Suc i))"
end

text ‹The main loop of the normal variant of Storjohann's algorithm.›
(* Applies basis_reduction_mod_step as long as i < m and then returns (p, mfs, dmu).
   Defined with partial_function (tailrec), so the definition needs no termination argument. *)
partial_function (tailrec) basis_reduction_mod_main where
  "basis_reduction_mod_main p first mfs dmu g_idx i (j :: int) = (
    (if i < m 
       then 
         case basis_reduction_mod_step p first mfs dmu g_idx i j
         of (p', mfs', dmu', g_idx', i', j')   
           basis_reduction_mod_main p' first mfs' dmu' g_idx' i' j'
       else
         (p, mfs, dmu)))"

(* For k = 1, ..., m - 1, forms the quotient d_k * d_k / (d_(k+1) * d_(k-1)) as a
   (numerator, denominator, k) triple and returns one with maximal value. This quotient equals
   the squared norm of gso (k-1) divided by that of gso k.
   The list is empty for m <= 1, for which max_list_rats_with_index has no equation; the only
   caller, basis_reduction_iso_main, checks m > 1 first. *)
definition compute_max_gso_quot:: "int mat  (int * int * nat)" where
  "compute_max_gso_quot dmu = max_list_rats_with_index 
    (map (λi. ((d_of dmu (i+1)) * (d_of dmu (i+1)), (d_of dmu (i+2)) * (d_of dmu i), Suc i)) [0..<(m-1)])"

text ‹The main loop of Storjohann's algorithm with improved swap order.›
(* Improved swap order. While m > 1, take the index with maximal quotient from
   compute_max_gso_quot. If that quotient exceeds alpha (compared by cross-multiplication with
   quotient_of alpha), perform the swap step there and continue with j + 1. Otherwise stop and
   return (p, mfs, dmu). *)
partial_function (tailrec) basis_reduction_iso_main where
  "basis_reduction_iso_main p first mfs dmu g_idx (j :: int) = (
    (if m > 1 then
      (let (max_gso_num, max_gso_denum, indx) = compute_max_gso_quot dmu;
        (num, denum) = quotient_of α in
        (if (max_gso_num * denum  > num * max_gso_denum) then
            case basis_reduction_adjust_swap_add_step p first mfs dmu g_idx indx of
              (p', mfs', dmu', g_idx') 
          basis_reduction_iso_main p' first mfs' dmu' g_idx' (j + 1) 
         else
           (p, mfs, dmu)))
     else (p, mfs, dmu)))"

(* The initial modular basis: every vector of fs_init reduced entrywise with symmod p. *)
definition compute_initial_mfs where
  "compute_initial_mfs p = map (map_vec (λx. x symmod p)) fs_init"

(* Reduces the strictly-lower entries of dmu with symmod (p * d_j' * d_(j'+1)). Entries on and
   above the diagonal are unchanged. *)
definition compute_initial_dmu where
  "compute_initial_dmu p dmu = mat m m (λ(i',j'). if j' < i' 
        then dmu $$ (i', j') symmod (p * d_of dmu j' * d_of dmu (Suc j')) 
        else dmu $$ (i', j'))"

(* The initial dmu matrix: the entries of dμ_impl fs_init on and below the diagonal (presumably
   the d-mu values of fs_init), and 0 above it.
   dμ_impl fs_init is bound once by the let and reused for every matrix entry. Code generated
   from this definition therefore does not recompute it inside the per-entry function. After
   unfolding Let_def the right-hand side is the same term as before. *)
definition "dmu_initial = (let dmu = dμ_impl fs_init
   in mat m m (λ (i,j). 
   if j ≤ i then dmu !! i !! j else 0))"

(* Initial state (p, mfs, dmu, g_idx):
   - the modulus p is computed from the GSO norm bound of dmu_initial;
   - the basis and dmu are reduced with respect to p;
   - g_idx is the index where the maximal norm was found. *)
definition "compute_initial_state first = 
  (let dmu = dmu_initial;
       (b, g_idx) = compute_max_gso_norm first dmu;
       p = compute_mod_of_max_gso_norm first b
     in (p, compute_initial_mfs p, compute_initial_dmu p dmu, g_idx))" 

text ‹Storjohann's algorithm›
(* Full reduction with the normal swap order:
   - run the main loop from index 0 with swap counter 0;
   - size-reduce rows 1 .. m - 1 with the final modulus;
   - return the resulting modular basis. *)
definition reduce_basis_mod :: "int vec list" where
  "reduce_basis_mod = (
     let first = False;
         (p0, mfs0, dmu0, g_idx) = compute_initial_state first;
         (p', mfs', dmu') = basis_reduction_mod_main p0 first mfs0 dmu0 g_idx 0 0;
         (mfs'', dmu'') = basis_reduction_mod_add_rows_outer_loop p' mfs' dmu' (m-1)
      in mfs'')"

text ‹Storjohann's algorithm with improved swap order›
(* Same as reduce_basis_mod, but uses the main loop with improved swap order,
   basis_reduction_iso_main. *)
definition reduce_basis_iso :: "int vec list" where
  "reduce_basis_iso = (
     let first = False; 
         (p0, mfs0, dmu0, g_idx) = compute_initial_state first;
         (p', mfs', dmu') = basis_reduction_iso_main p0 first mfs0 dmu0 g_idx 0;
         (mfs'', dmu'') = basis_reduction_mod_add_rows_outer_loop p' mfs' dmu' (m-1)
      in mfs'')"

text ‹Storjohann's algorithm for computing a short vector›
(* Short-vector mode (first = True) with the normal swap order. No final size-reduction is done;
   the result is the head of the modular basis (hd, so presumably m > 0 is assumed here -- TODO
   confirm). *)
definition 
  "short_vector_mod = (
     let first = True;
         (p0, mfs0, dmu0, g_idx) = compute_initial_state first;
         (p', mfs', dmu') = basis_reduction_mod_main p0 first mfs0 dmu0 g_idx 0 0
      in hd mfs')"

text ‹Storjohann's algorithm (iso-variant) for computing a short vector›
(* Short-vector mode (first = True) with the improved swap order. No final size-reduction is done;
   the result is the head of the modular basis (hd, so presumably m > 0 is assumed here -- TODO
   confirm). *)
definition 
  "short_vector_iso = (
     let first = True; 
         (p0, mfs0, dmu0, g_idx) = compute_initial_state first;
         (p', mfs', dmu') = basis_reduction_iso_main p0 first mfs0 dmu0 g_idx 0
      in hd mfs')"
end

subsection ‹Towards soundness of Storjohann's algorithm›

(* On a non-empty list, max_list_rats_with_index returns one of the list's elements. *)
lemma max_list_rats_with_index_in_set: 
  assumes max: "max_list_rats_with_index xs = (nm, dm, im)"
  and len: "length xs  1"
shows "(nm, dm, im)  set xs"
  using assms
proof (induct xs rule: max_list_rats_with_index.induct)
  case (2 n1 d1 i1 n2 d2 i2 xs)
  have "1  length ((if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) # xs)" by simp
  moreover have "max_list_rats_with_index ((if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) # xs)
              = (nm, dm, im)" using 2 by simp
  moreover have "(if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) 
        set ((n1, d1, i1) # (n2, d2, i2) # xs)" by simp
  moreover then have "set ((if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) # xs) 
        set ((n1, d1, i1) # (n2, d2, i2) # xs)" by auto
  ultimately show ?case using 2(1) by auto
qed auto

(* If all denominators are positive, the fraction of the element selected by
   max_list_rats_with_index is at least n / d for every (n, d, i) in the list.
   Step case: the cross-multiplication test is converted into a comparison of rationals (fact id). *)
lemma max_list_rats_with_index: assumes " n d i. (n,d,i)  set xs  d > 0" 
  and max: "max_list_rats_with_index xs = (nm, dm, im)" 
  and "(n,d,i)  set xs" 
shows "rat_of_int n / of_int d  of_int nm / of_int dm" 
  using assms
proof (induct xs arbitrary: n d i rule: max_list_rats_with_index.induct)
  case (2 n1 d1 i1 n2 d2 i2 xs n d i)
  let ?r = "rat_of_int" 
  from 2(2) have "d1 > 0" "d2 > 0" by auto
  hence d: "?r d1 > 0" "?r d2 > 0" by auto
  have "(n1 * d2  n2 * d1) = (?r n1 * ?r d2  ?r n2 * ?r d1)" 
    unfolding of_int_mult[symmetric] by presburger
  also have " = (?r n1 / ?r d1  ?r n2 / ?r d2)" using d 
    by (smt divide_strict_right_mono leD le_less_linear mult.commute nonzero_mult_div_cancel_left 
        not_less_iff_gr_or_eq times_divide_eq_right)
  finally have id: "(n1 * d2  n2 * d1) = (?r n1 / ?r d1  ?r n2 / ?r d2)" .
  obtain n' d' i' where new: "(if n1 * d2  n2 * d1 then (n2, d2, i2) else (n1, d1, i1)) = (n',d',i')" 
    by force  
  have nd': "(n',d',i')  {(n1,d1,i1), (n2, d2, i2)}" using new[symmetric] by auto
  from 2(3) have res: "max_list_rats_with_index ((n',d',i') # xs) = (nm, dm, im)" using new by auto
  note 2 = 2[unfolded new]
  show ?case 
  proof (cases "(n,d,i)  set xs")
    case True
    show ?thesis 
      by (rule 2(1)[of n d, OF 2(2) res], insert True nd', force+)
  next
    case False
    with 2(4) have "n = n1  d = d1  n = n2  d = d2" by auto
    hence "?r n / ?r d  ?r n' / ?r d'" using new[unfolded id]
      by (metis linear prod.inject)
    also have "?r n' / ?r d'  ?r nm / ?r dm" 
      by (rule 2(1)[of n' d', OF 2(2) res], insert nd', force+)
    finally show ?thesis .
  qed
qed auto

context LLL
begin

(* log_base is at least 2. This is needed together with log_ceiling_sound in the proof of lemma
   compute_mod_of_max_gso_norm. *)
lemma log_base: "log_base  2" unfolding log_base_def by auto

(* Weak LLL invariant indexed by i. It requires that fs
   - is linearly independent,
   - generates the lattice L,
   - is weakly reduced up to i,
   - has length m,
   and that i is bounded by m. *)
definition LLL_invariant_weak' :: "nat  int vec list  bool" where 
  "LLL_invariant_weak' i fs = ( 
    gs.lin_indpt_list (RAT fs)  
    lattice_of fs = L 
    weakly_reduced fs i 
    i  m  
    length fs = m    
  )" 

(* Destruction rules for LLL_invariant_weak'. They include the derived carrier facts for the
   vectors of fs and their GSO vectors. *)
lemma LLL_invD_weak: assumes "LLL_invariant_weak' i fs"
  shows 
  "lin_indep fs" 
  "length (RAT fs) = m" 
  "set fs  carrier_vec n"
  " i. i < m  fs ! i  carrier_vec n" 
  " i. i < m  gso fs i  carrier_vec n" 
  "length fs = m"
  "lattice_of fs = L" 
  "weakly_reduced fs i"
  "i  m"
proof (atomize (full), goal_cases)
  case 1
  interpret gs': gram_schmidt_fs_lin_indpt n "RAT fs"
    by (standard) (use assms LLL_invariant_weak'_def gs.lin_indpt_list_def in auto)
  show ?case
    using assms gs'.fs_carrier gs'.f_carrier gs'.gso_carrier
    by (auto simp add: LLL_invariant_weak'_def gram_schmidt_fs.reduced_def)
qed

(* Introduction rule for LLL_invariant_weak'. *)
lemma LLL_invI_weak: assumes  
  "set fs  carrier_vec n"
  "length fs = m"
  "lattice_of fs = L" 
  "i  m"
  "lin_indep fs" 
  "weakly_reduced fs i" 
shows "LLL_invariant_weak' i fs" 
  unfolding LLL_invariant_weak'_def Let_def using assms by auto

(* The indexed invariant LLL_invariant_weak' i implies the index-free LLL_invariant_weak. *)
lemma LLL_invw'_imp_w: "LLL_invariant_weak' i fs  LLL_invariant_weak fs" 
  unfolding LLL_invariant_weak'_def LLL_invariant_weak_def by auto
  
(* Subtracting c times row j from row i (with j < i < m) preserves LLL_invariant_weak' i, and it
   preserves any GSO norm bound g_bnd B. Both facts are obtained from basis_reduction_add_row_main,
   which also states that the GSO vectors do not change. *)
lemma basis_reduction_add_row_weak: 
  assumes Linvw: "LLL_invariant_weak' i fs"
  and i: "i < m"  and j: "j < i" 
  and fs': "fs' = fs[ i := fs ! i - c v fs ! j]" 
shows "LLL_invariant_weak' i fs'"
  "g_bnd B fs  g_bnd B fs'" 
proof (atomize(full), goal_cases)
  case 1
  note Linv = LLL_invw'_imp_w[OF Linvw]
  note main = basis_reduction_add_row_main[OF Linv i j fs']
  have bnd: "g_bnd B fs  g_bnd B fs'" using main(6) unfolding g_bnd_def by auto
  note new = LLL_inv_wD[OF main(1)]
  note old = LLL_invD_weak[OF Linvw]
  have red: "weakly_reduced fs' i" using weakly_reduced fs i main(6) i < m
    unfolding gram_schmidt_fs.weakly_reduced_def by auto
  have inv: "LLL_invariant_weak' i fs'" using LLL_inv_wD[OF main(1)] i < m
    by (intro LLL_invI_weak, auto intro: red)
  show ?case using inv red main bnd by auto
qed

(* The weak invariant at index m implies the weak invariant at any smaller index i, since weak
   reducedness up to m implies weak reducedness up to i. *)
lemma LLL_inv_weak_m_impl_i:
  assumes inv: "LLL_invariant_weak' m fs"
  and i: "i  m"
shows "LLL_invariant_weak' i fs"
proof -
  have "weakly_reduced fs i" using LLL_invD_weak(8)[OF inv]
    by (meson assms(2) gram_schmidt_fs.weakly_reduced_def le_trans less_imp_le_nat linorder_not_less)
  then show ?thesis
    using LLL_invI_weak[of fs i, OF LLL_invD_weak(3,6,7)[OF inv] _ LLL_invD_weak(1)[OF inv]] 
      LLL_invD_weak(2,4,5,8-)[OF inv] i by simp
qed
 
(* Relation between a GSO norm bound b and a modulus p:
   - b is bounded by (p - 1)^2 / (bound_number first + 3);
   - p is a power of log_base. *)
definition mod_invariant where 
  "mod_invariant b p first = (b  rat_of_int (p - 1)^2 / (rat_of_nat (bound_number first) + 3)
      ( e. p = log_base ^ e))"  

(* Soundness of compute_mod_of_max_gso_norm. Assumptions:
   - the bound mn is non-negative;
   - the assumption named m relates the case m = 0 to mn = 0.
   Conclusions: the computed modulus p is > 1, and mod_invariant mn p first holds.
   Proof sketch: p' = ceiling (sqrt q) + 1 with q = mn * (bound_number first + 3). Then
   p >= max 2 p' >= p' by log_ceiling_sound, and q <= ceiling (sqrt q)^2 gives the bound first for
   p' and then for p. *)
lemma compute_mod_of_max_gso_norm: assumes mn: "mn  0"
  and m: "m = 0  mn = 0" 
  and p: "p = compute_mod_of_max_gso_norm first mn" 
shows  
  "p > 1" 
  "mod_invariant mn p first" 
proof -
  let ?m = "bound_number first" 
  define p' where "p' = root_rat_ceiling 2 (mn * (rat_of_nat ?m + 3)) + 1" 
  define p'' where "p'' = max 2 p'" 
  define q where "q = real_of_rat (mn * (rat_of_nat ?m + 3))" 
  have *: "-1 < (0 :: real)" by simp
  also have "0  root 2 (real_of_rat (mn * (rat_of_nat ?m + 3)))" using mn by auto
  finally have "p'  0 + 1" unfolding p'_def
    by (intro plus_left_mono, simp)
  hence p': "p' > 0" by auto
  have p'': "p'' > 1" unfolding p''_def by auto
  have pp'': "p  p''" unfolding compute_mod_of_max_gso_norm_def p  p'_def[symmetric] p''_def[symmetric]
    using log_base p'' log_ceiling_sound by auto
  hence pp': "p  p'" unfolding p''_def by auto    
  show "p > 1" using pp'' p'' by auto

  have q0: "q  0" unfolding q_def using mn m by auto
  have "(mn  rat_of_int (p' - 1)^2 / (rat_of_nat ?m + 3)) 
    = (real_of_rat mn  real_of_rat (rat_of_int (p' - 1)^2 / (rat_of_nat ?m + 3)))" using of_rat_less_eq by blast
  also have " = (real_of_rat mn  real_of_rat (rat_of_int (p' - 1)^2) / real_of_rat (rat_of_nat ?m + 3))" by (simp add: of_rat_divide)
  also have " = (real_of_rat mn  ((real_of_int (p' - 1))^2) / real_of_rat (rat_of_nat ?m + 3))" 
    by (metis of_rat_of_int_eq of_rat_power)
  also have " = (real_of_rat mn  (real_of_int sqrt q)^2 / real_of_rat (rat_of_nat ?m + 3))" 
    unfolding p'_def sqrt_def q_def by simp
  also have "" 
  proof -
    have "real_of_rat mn  q / real_of_rat (rat_of_nat ?m + 3)" unfolding q_def using m
      by (auto simp: of_rat_mult)
    also have "  (real_of_int sqrt q)^2 / real_of_rat (rat_of_nat ?m + 3)" 
    proof (rule divide_right_mono)
      have "q = (sqrt q)^2" using q0 by simp
      also have "  (real_of_int sqrt q)^2" 
        by (rule power_mono, auto simp: q0)
      finally show "q  (real_of_int sqrt q)^2" .
    qed auto
    finally show ?thesis .
  qed
  finally have "mn  rat_of_int (p' - 1)^2 / (rat_of_nat ?m + 3)" .
  also have "  rat_of_int (p - 1)^2 / (rat_of_nat ?m + 3)"
    unfolding power2_eq_square
    by (intro divide_right_mono mult_mono, insert p' pp', auto) 
  finally have "mn  rat_of_int (p - 1)^2 / (rat_of_nat ?m + 3)" .
  moreover have " e. p = log_base ^ e" unfolding p compute_mod_of_max_gso_norm_def by auto
  ultimately show "mod_invariant mn p first" unfolding mod_invariant_def by auto
qed

(* g_bnd_mode depends only on the GSO vectors with index < m. *)
lemma g_bnd_mode_cong: assumes " i. i < m  gso fs i = gso fs' i"
  shows "g_bnd_mode first b fs = g_bnd_mode first b fs'"
  using assms unfolding g_bnd_mode_def g_bnd_def by auto

(* Invariant of the modular algorithm. It links the integer basis fs, which the algorithm does not
   store, to the modular state (mfs, dmu, p) at loop index i:
   - fs has the usual LLL properties (length m, lattice L, linear independence, weakly reduced
     up to i), and i <= m;
   - mfs is fs reduced entrywise with symmod p;
   - every strictly-lower d-mu value of fs at (i', j') is bounded in absolute value by
     p * d_j' * d_(j'+1);
   - dmu holds the d-mu values of fs;
   - p > 1, and the GSO norm bound b satisfies g_bnd_mode and mod_invariant. *)
definition LLL_invariant_mod :: "int vec list  int vec list  int mat  int  bool  rat  nat  bool" where 
  "LLL_invariant_mod fs mfs dmu p first b i = ( 
    length fs = m 
    length mfs = m 
    i  m 
    lattice_of fs = L 
    gs.lin_indpt_list (RAT fs) 
    weakly_reduced fs i 
    (map (map_vec (λx. x symmod p)) fs = mfs) 
    (i' < m.  j' < i'. ¦ fs i' j'¦ < p * d fs j' * d fs (Suc j')) 
    (i' < m. j' < m.  fs i' j' = dmu $$ (i',j')) 
    p > 1 
    g_bnd_mode first b fs 
    mod_invariant b p first
)"

(* Destruction rules for LLL_invariant_mod. They include the derived carrier facts for the vectors
   of fs and mfs and for the GSO vectors. *)
lemma LLL_invD_mod: assumes "LLL_invariant_mod fs mfs dmu p first b i"
shows 
  "length mfs = m"
  "i  m"
  "length fs = m"
  "lattice_of fs = L"
  "gs.lin_indpt_list (RAT fs)"
  "weakly_reduced fs i"
  "(map (map_vec (λx. x symmod p)) fs = mfs)"
  "(i' < m. j' < i'. ¦ fs i' j'¦ < p * d fs j' * d fs (Suc j'))"
  "(i' < m. j' < m.  fs i' j' = dmu $$ (i',j'))"
  " i. i < m  fs ! i  carrier_vec n" 
  "set fs  carrier_vec n"
  " i. i < m  gso fs i  carrier_vec n" 
  " i. i < m  mfs ! i  carrier_vec n"
  "set mfs  carrier_vec n"
  "p > 1"
  "g_bnd_mode first b fs"
  "mod_invariant b p first"
proof (atomize (full), goal_cases)
  case 1
  interpret gs': gram_schmidt_fs_lin_indpt n "RAT fs"
    using assms LLL_invariant_mod_def gs.lin_indpt_list_def 
    by (meson gram_schmidt_fs_Rn.intro gram_schmidt_fs_lin_indpt.intro gram_schmidt_fs_lin_indpt_axioms.intro)
  have allfs: "i < m. fs ! i  carrier_vec n" using assms gs'.f_carrier 
    by (simp add: LLL.LLL_invariant_mod_def)
  then have setfs: "set fs  carrier_vec n" by (metis LLL_invariant_mod_def assms in_set_conv_nth subsetI)
  have allgso: "(i < m. gso fs i  carrier_vec n)" using assms gs'.gso_carrier
    by (simp add: LLL.LLL_invariant_mod_def)
  show ?case
    using assms gs'.fs_carrier gs'.f_carrier gs'.gso_carrier allfs allgso 
      LLL_invariant_mod_def gram_schmidt_fs.reduced_def in_set_conv_nth setfs by fastforce
qed

(* Introduction rule for LLL_invariant_mod. *)
lemma LLL_invI_mod: assumes 
  "length mfs = m"
  "i  m"
  "length fs = m"
  "lattice_of fs = L"
  "gs.lin_indpt_list (RAT fs)"
  "weakly_reduced fs i"
  "map (map_vec (λx. x symmod p)) fs = mfs"
  "(i' < m. j' < i'. ¦ fs i' j'¦ < p * d fs j' * d fs (Suc j'))"
  "(i' < m. j' < m.  fs i' j' = dmu $$ (i',j'))"
  "p > 1"
  "g_bnd_mode first b fs"
  "mod_invariant b p first"
shows "LLL_invariant_mod fs mfs dmu p first b i" 
  unfolding LLL_invariant_mod_def using assms by blast

(* Variant of LLL_invariant_mod without the loop index. It drops the conditions i <= m and
   weakly_reduced fs i and keeps all other conditions. *)
definition LLL_invariant_mod_weak :: "int vec list  int vec list  int mat  int  bool  rat  bool" where 
  "LLL_invariant_mod_weak fs mfs dmu p first b = ( 
    length fs = m 
    length mfs = m 
    lattice_of fs = L 
    gs.lin_indpt_list (RAT fs) 
    (map (map_vec (λx. x symmod p)) fs = mfs) 
    (i' < m.  j' < i'. ¦ fs i' j'¦ < p * d fs j' * d fs (Suc j')) 
    (i' < m. j' < m.  fs i' j' = dmu $$ (i',j')) 
    p > 1 
    g_bnd_mode first b fs 
    mod_invariant b p first
)"

(* Destruction rules for LLL_invariant_mod_weak, including the derived carrier facts. *)
lemma LLL_invD_modw: assumes "LLL_invariant_mod_weak fs mfs dmu p first b"
shows 
  "length mfs = m"
  "length fs = m"
  "lattice_of fs = L"
  "gs.lin_indpt_list (RAT fs)"
  "(map (map_vec (λx. x symmod p)) fs = mfs)"
  "(i' < m. j' < i'. ¦ fs i' j'¦ < p * d fs j' * d fs (Suc j'))"
  "(i' < m. j' < m.  fs i' j' = dmu $$ (i',j'))"
  " i. i < m  fs ! i  carrier_vec n" 
  "set fs  carrier_vec n"
  " i. i < m  gso fs i  carrier_vec n" 
  " i. i < m  mfs ! i  carrier_vec n"
  "set mfs  carrier_vec n"
  "p > 1"
  "g_bnd_mode first b fs"
  "mod_invariant b p first"
proof (atomize (full), goal_cases)
  case 1
  interpret gs': gram_schmidt_fs_lin_indpt n "RAT fs"
    using assms LLL_invariant_mod_weak_def gs.lin_indpt_list_def 
    by (meson gram_schmidt_fs_Rn.intro gram_schmidt_fs_lin_indpt.intro gram_schmidt_fs_lin_indpt_axioms.intro)
  have allfs: "i < m. fs ! i  carrier_vec n" using assms gs'.f_carrier 
    by (simp add: LLL.LLL_invariant_mod_weak_def)
  then have setfs: "set fs  carrier_vec n" by (metis LLL_invariant_mod_weak_def assms in_set_conv_nth subsetI)
  have allgso: "(i < m. gso fs i  carrier_vec n)" using assms gs'.gso_carrier
    by (simp add: LLL.LLL_invariant_mod_weak_def)
  show ?case
    using assms gs'.fs_carrier gs'.f_carrier gs'.gso_carrier allfs allgso 
      LLL_invariant_mod_weak_def gram_schmidt_fs.reduced_def in_set_conv_nth setfs by fastforce
qed

(* Introduction rule for LLL_invariant_mod_weak. *)
lemma LLL_invI_modw: assumes 
  "length mfs = m"
  "length fs = m"
  "lattice_of fs = L"
  "gs.lin_indpt_list (RAT fs)"
  "map (map_vec (λx. x symmod p)) fs = mfs"
  "(i' < m. j' < i'. ¦ fs i' j'¦ < p * d fs j' * d fs (Suc j'))"
  "(i' < m. j' < m.  fs i' j' = dmu $$ (i',j'))"
  "p > 1"
  "g_bnd_mode first b fs"
  "mod_invariant b p first"
shows "LLL_invariant_mod_weak fs mfs dmu p first b" 
  unfolding LLL_invariant_mod_weak_def using assms by blast

(* The diagonal d-mu value at index i equals d fs (Suc i), because μ fs i i = 1. *)
lemma ddμ:
  assumes i: "i < m"
  shows "d fs (Suc i) =  fs i i"
proof-
  have "μ fs i i = 1" using i by (simp add: gram_schmidt_fs.μ.simps)
  then show ?thesis using dμ_def by simp
qed

(* If the diagonal of dmu agrees with the d-mu values of fs, then d_of dmu i = d fs i for every
   i <= m. The case i = 0 holds by definition; for i > 0 it follows from lemma ddμ. *)
lemma d_of_main: assumes "(i' < m.  fs i' i' = dmu $$ (i',i'))"
  and "i  m"
shows "d_of dmu i = d fs i" 
proof (cases "i = 0")
  case False
  with assms have "i - 1 < m" by auto
  from assms(1)[rule_format, OF this] ddμ[OF this, of fs] False
  show ?thesis by (simp add: d_of_def)
next
  case True
  thus ?thesis unfolding d_of_def True d_def by simp
qed

(* Under LLL_invariant_mod, d_of dmu i = d fs i for all i <= m (a corollary of d_of_main).
   The invariant's arguments are given in the order of the definition (p first b) and of the
   sibling lemma d_of_weak. Before, first and b were swapped, so the variable named first was
   typed as rat and b as bool. *)
lemma d_of: assumes inv: "LLL_invariant_mod fs mfs dmu p first b j"
  and "i ≤ m" 
shows "d_of dmu i = d fs i" 
  by (rule d_of_main[OF _ assms(2)], insert LLL_invD_mod(9)[OF inv], auto)

(* Under LLL_invariant_mod_weak, d_of dmu i = d fs i for all i <= m (a corollary of d_of_main). *)
lemma d_of_weak: assumes inv: "LLL_invariant_mod_weak fs mfs dmu p first b"
  and "i  m" 
shows "d_of dmu i = d fs i" 
  by (rule d_of_main[OF _ assms(2)], insert LLL_invD_modw(7)[OF inv], auto)

(* Soundness of compute_max_gso_norm. Assumptions: the diagonal of dmu agrees with the d-mu
   values of fs, and fs satisfies LLL_invariant_weak. Then the computed bound b
   - satisfies g_bnd_mode first b fs,
   - is non-negative,
   - and is 0 when m = 0.
   In the general case each d_(i+1) / d_i, which equals the squared norm of gso fs i
   (LLL_d_Suc), is bounded by the maximum selected by max_list_rats_with_index. *)
lemma compute_max_gso_norm: assumes dmu: "(i' < m.  fs i' i' = dmu $$ (i',i'))" 
  and Linv: "LLL_invariant_weak fs" 
shows "g_bnd_mode first (fst (compute_max_gso_norm first dmu)) fs" 
  "fst (compute_max_gso_norm first dmu)  0" 
  "m = 0  fst (compute_max_gso_norm first dmu) = 0" 
proof -
  show gbnd: "g_bnd_mode first (fst (compute_max_gso_norm first dmu)) fs" 
  proof (cases "first  m  0")
    case False
    have "?thesis = (g_bnd (fst (compute_max_gso_norm first dmu)) fs)" unfolding g_bnd_mode_def using False by auto
    also have  unfolding g_bnd_def
    proof (intro allI impI)
      fix i
      assume i: "i < m" 
      have id: "(if first then 1 else m) = m" using False i by auto
      define list where "list = map (λ i. (d_of dmu (Suc i), d_of dmu i, i)) [0 ..< m ]" 
      obtain num denom j where ml: "max_list_rats_with_index list = (num, denom, j)" 
        by (metis prod_cases3)
      have dpos: "d fs i > 0" using LLL_d_pos[OF Linv, of i]  i by auto
      have pos: "(n, d, i)  set list  0 < d" for n d i 
        using LLL_d_pos[OF Linv] unfolding list_def using d_of_main[OF dmu] by auto
      from i have "list ! i  set list" using i unfolding list_def by auto
      also have "list ! i = (d_of dmu (Suc i), d_of dmu i, i)" unfolding list_def using i by auto
      also have " = (d fs (Suc i), d fs i, i)" using d_of_main[OF dmu] i by auto
      finally have "(d fs (Suc i), d fs i, i)  set list" . 
      from max_list_rats_with_index[OF pos ml this] 
      have "of_int (d fs (Suc i)) / of_int (d fs i)  fst (compute_max_gso_norm first dmu)" 
        unfolding compute_max_gso_norm_def list_def[symmetric] ml id split using i by auto
      also have "of_int (d fs (Suc i)) / of_int (d fs i) = sq_norm (gso fs i)" 
        using LLL_d_Suc[OF Linv i] dpos by auto
      finally show "sq_norm (gso fs i)  fst (compute_max_gso_norm first dmu)" .
    qed
    finally show ?thesis .
  next
    case True
    thus ?thesis unfolding g_bnd_mode_def compute_max_gso_norm_def using d_of_main[OF dmu] 
      LLL_d_Suc[OF Linv, of 0] LLL_d_pos[OF Linv, of 0] LLL_d_pos[OF Linv, of 1] by auto
  qed
  show "fst (compute_max_gso_norm first dmu)  0" 
  proof (cases "m = 0")
    case True
    thus ?thesis unfolding compute_max_gso_norm_def by simp
  next
    case False
    hence 0: "0 < m" by simp
    have "0  sq_norm (gso fs 0)" by blast
    also have "  fst (compute_max_gso_norm first dmu)" 
      using gbnd[unfolded g_bnd_mode_def g_bnd_def] using 0 by metis
    finally show ?thesis .
  qed
qed (auto simp: LLL.compute_max_gso_norm_def)


(* Suppose that at index i < m either i = 0 or the Lovasz condition
   sq_norm (gso fs (i - 1)) <= alpha * sq_norm (gso fs i) holds. Then LLL_invariant_mod carries
   over from i to Suc i, and LLL_measure strictly decreases. *)
lemma increase_i_mod:
  assumes Linv: "LLL_invariant_mod fs mfs dmu p first b i"
  and i: "i < m" 
  and red_i: "i  0  sq_norm (gso fs (i - 1))  α * sq_norm (gso fs i)"
shows "LLL_invariant_mod fs mfs dmu p first b (Suc i)" "LLL_measure i fs > LLL_measure (Suc i) fs" 
proof -
  note inv = LLL_invD_mod[OF Linv]
  from inv have red: "weakly_reduced fs i"  by (auto)
  from red red_i i have red: "weakly_reduced fs (Suc i)" 
    unfolding gram_schmidt_fs.weakly_reduced_def
    by (intro allI impI, rename_tac ii, case_tac "Suc ii = i", auto)
  show "LLL_invariant_mod fs mfs dmu p first b (Suc i)"
    by (intro LLL_invI_mod, insert inv red i, auto)
  show "LLL_measure i fs > LLL_measure (Suc i) fs" unfolding LLL_measure_def using i by auto
qed

lemma basis_reduction_mod_add_row_main:
  assumes Linvmw: "LLL_invariant_mod_weak fs mfs dmu p first b"
  and i: "i < m"  and j: "j < i" 
  and c: "c = round (μ fs i j)"
  and mfs': "mfs' = mfs[ i := (map_vec (λ x. x symmod p)) (mfs ! i - c v mfs ! j)]"
  and dmu': "dmu' = mat m m (λ(i',j'). (if (i' = i  j'  j) 
        then (if j'=j then (dmu $$ (i,j') - c * dmu $$ (j,j')) 
              else (dmu $$ (i,j') - c * dmu $$ (j,j')) 
                    symmod (p * (d_of dmu j') * (d_of dmu (Suc j'))))
        else (dmu $$ (i',j'))))"
shows "(fs'. LLL_invariant_mod_weak fs' mfs' dmu' p first b 
        LLL_measure i fs' = LLL_measure i fs
         (μ_small_row i fs (Suc j)  μ_small_row i fs' j) 
         (k < m. gso fs' k = gso fs k)
         (ii  m. d fs' ii = d fs ii)
         ¦μ fs' i j¦  1 / 2
         (i' j'. i' < i  j'  i'  μ fs' i' j' = μ fs i' j')
         (LLL_invariant_mod fs mfs dmu p first b i  LLL_invariant_mod fs' mfs' dmu' p first b i))"
proof -
  define fs' where "fs' = fs[ i := fs ! i - c v fs ! j]"
  from LLL_invD_modw[OF Linvmw] have gbnd: "g_bnd_mode first b fs" and p1: "p > 1" and pgtz: "p > 0" by auto
  have Linvww: "LLL_invariant_weak fs" using LLL_invD_modw[OF Linvmw] LLL_invariant_weak_def by simp
  have 
    Linvw': "LLL_invariant_weak fs'" and
    01: "c = round (μ fs i j)  μ_small_row i fs (Suc j)  μ_small_row i fs' j" and
    02: "LLL_measure i fs' = LLL_measure i fs" and
    03: " i. i < m  gso fs' i = gso fs i" and
    04: " i' j'. i' < m  j' < m  
      μ fs' i' j' = (if i' = i  j'  j then μ fs i j' - of_int c * μ fs j j' else μ fs i' j')" and
    05: " ii. ii  m  d fs' ii = d fs ii" and 
    06: "¦μ fs' i j¦  1 / 2" and
    061: "(i' j'. i' < i  j'  i'  μ fs i' j' = μ fs' i' j')"
    using basis_reduction_add_row_main[OF Linvww i j fs'_def] c i by auto
  have 07: "lin_indep fs'" and 
    08: "length fs' = m" and 
    09: "lattice_of fs' = L" using LLL_inv_wD Linvw' by auto
  have 091: "fs_int_indpt n fs'" using 07 using Gram_Schmidt_2.fs_int_indpt.intro by simp
  define I where "I = {(i',j'). i' = i  j' < j}"
  have 10: "I  {(i',j'). i' < m  j' < i'}" "(i,j) I" "j'  j. (i,j')  I" using I_def i j by auto
  obtain fs'' where 
    11: "lattice_of fs'' = L" and
    12: "map (map_vec (λ x. x symmod p)) fs'' = map (map_vec (λ x. x symmod p)) fs'" and
    13: "lin_indep fs''" and
    14: "length fs'' = m" and
    15: "( k < m. gso fs'' k = gso fs' k)" and
    16: "( k  m. d fs'' k = d fs' k)" and
    17: "( i' < m.  j' < m.  fs'' i' j' = 
      (if (i',j')  I then  fs' i' j' symmod (p * d fs' j' * d fs' (Suc j')) else  fs' i' j'))"
    using mod_finite_set[OF 07 08 10(1) 09 pgtz] by blast
  have 171: "(i' j'. i' < i  j'  i'  μ fs'' i' j' = μ fs' i' j')"
  proof -
    {
      fix i' j'
      assume i'j': "i' < i" "j'  i'"
      have "rat_of_int ( fs'' i' j') = rat_of_int ( fs' i' j')" using "17" I_def i i'j' by auto
      then have "rat_of_int (int_of_rat (rat_of_int (d fs'' (Suc j')) * μ fs'' i' j')) = 
        rat_of_int (int_of_rat (rat_of_int (d fs' (Suc j')) * μ fs' i' j'))"
        using dμ_def i'j' j by auto
      then have "rat_of_int (d fs'' (Suc j')) * μ fs'' i' j' = 
        rat_of_int (d fs' (Suc j')) * μ fs' i' j'" 
        by (smt "08" "091" "13" "14" d_def dual_order.strict_trans fs_int.d_def 
            fs_int_indpt.fs_int_mu_d_Z fs_int_indpt.intro i i'j'(1) i'j'(2) int_of_rat(2))
      then have "μ fs'' i' j' = μ fs' i' j'" by (smt "16" 
            LLL_d_pos[OF Linvw'] Suc_leI int_of_rat(1)
            dual_order.strict_trans fs'_def i i'j' j 
            le_neq_implies_less nonzero_mult_div_cancel_left of_int_hom.hom_zero)
    }
    then show ?thesis by simp
  qed
  then have 172: "(i' j'. i' < i  j'  i'  μ fs'' i' j' = μ fs i' j')" using 061 by simp (* goal *)
  have 18: "LLL_measure i fs'' = LLL_measure i fs'" using 16 LLL_measure_def logD_def D_def by simp
  have 19: "(k < m. gso fs'' k = gso fs k)" using 03 15 by simp
  have "j'  {j..(m-1)}. j' < m" using j i by auto
  then have 20: "j'  {j..(m-1)}.  fs'' i j' =  fs' i j'" 
    using 10(3) 17 Suc_lessD less_trans_Suc by (meson atLeastAtMost_iff i)
  have 21: "j'  {j..(m-1)}. μ fs'' i j' = μ fs' i j'" 
  proof -
    {
      fix j'
      assume j': "j'  {j..(m-1)}"
      define μ'' :: rat where "μ'' = μ fs'' i j'"
      define μ' :: rat where "μ' = μ fs' i j'"
      have "rat_of_int ( fs'' i j') = rat_of_int ( fs' i j')" using 20 j' by simp
      moreover have "j' < length fs'" using i j' 08 by auto
      ultimately have "rat_of_int (d fs' (Suc j')) * gram_schmidt_fs.μ n (map of_int_hom.vec_hom fs') i j'
        = rat_of_int (d fs'' (Suc j')) * gram_schmidt_fs.μ n (map of_int_hom.vec_hom fs'') i j'"
        using 20 08 091 13 14 fs_int_indpt.dμ_def fs_int.d_def fs_int_indpt.dμ dμ_def d_def i fs_int_indpt.intro j'
        by metis
      then have "rat_of_int (d fs' (Suc j')) * μ'' = rat_of_int (d fs' (Suc j')) * μ'" 
        using 16 i j' μ'_def μ''_def unfolding dμ_def by auto
      moreover have "0 < d fs' (Suc j')" using LLL_d_pos[OF Linvw', of "Suc j'"] i j' by auto
      ultimately have "μ fs'' i j' = μ fs' i j'" using μ'_def μ''_def by simp
    }
    then show ?thesis by simp
  qed
  then have 22: "μ fs'' i j = μ fs' i j" using i j by simp
  then have 23: "¦μ fs'' i j¦  1 / 2" using 06 by simp (* goal *)
  have 24: "LLL_measure i fs'' = LLL_measure i fs" using 02 18 by simp (* goal *)
  have 25: "( k  m. d fs'' k = d fs k)" using 16 05 by simp (* goal *)
  have 26: "( k < m. gso fs'' k = gso fs k)" using 15 03 by simp (* goal *)
  have 27: "μ_small_row i fs (Suc j)  μ_small_row i fs'' j"
    using 21 01 μ_small_row_def i j c by auto (* goal *)
  have 28: "length fs = m" "length mfs = m" using LLL_invD_modw[OF Linvmw] by auto
  have 29: "map (map_vec (λx. x symmod p)) fs = mfs" using assms LLL_invD_modw by simp
  have 30: " i. i < m  fs ! i  carrier_vec n" " i. i < m  mfs ! i  carrier_vec n"
    using LLL_invD_modw[OF Linvmw] by auto
  have 31: " i. i < m  fs' ! i  carrier_vec n" using fs'_def 30(1) 
    using "08" "091" fs_int_indpt.f_carrier by blast
  have 32: " i. i < m  mfs' ! i  carrier_vec n" unfolding mfs' using 30(2) 28(2) 
    by (metis (no_types, lifting) Suc_lessD j less_trans_Suc map_carrier_vec minus_carrier_vec 
        nth_list_update_eq nth_list_update_neq smult_closed)
  have 33: "length mfs' = m" using 28(2) mfs' by simp (* invariant goal *)
  then have 34: "map (map_vec (λx. x symmod p)) fs' = mfs'"
  proof -
    {
      fix i' j'
      have j2: "j < m" using j i by auto
      assume i': "i' < m"
      assume j': "j' < n"
      then have fsij: "(fs ! i' $ j') symmod p = mfs ! i' $ j'" using 30 i' j' 28 29 by fastforce
      have "mfs' ! i $ j' = (mfs ! i $ j'- (c v mfs ! j) $ j') symmod p"
        unfolding mfs' using 30(2) j' 28 j2 
        by (metis (no_types, lifting) carrier_vecD i index_map_vec(1) index_minus_vec(1) 
            index_minus_vec(2) index_smult_vec(2) nth_list_update_eq)
      then have mfs'ij: "mfs' ! i $ j' = (mfs ! i $ j'- c * mfs ! j $ j') symmod p" 
        unfolding mfs' using 30(2) i' j' 28 j2 by fastforce
      have "(fs' ! i' $ j') symmod p = mfs' ! i' $ j'"
      proof(cases "i' = i")
        case True
        show ?thesis using fs'_def mfs' True 28 fsij 
        proof -
          have "fs' ! i' $ j' = (fs ! i' - c v fs ! j) $ j'" using fs'_def True i' j' 28(1) by simp
          also have " = fs ! i' $ j' - (c v fs ! j) $ j'" using i' j' 30(1)
            by (metis Suc_lessD carrier_vecD i index_minus_vec(1) index_smult_vec(2) j less_trans_Suc)
          finally have "fs' ! i' $ j' = fs ! i' $ j' - (c v fs ! j) $ j'" by auto
          then have "(fs' ! i' $ j') symmod p = (fs ! i' $ j' - (c v fs ! j) $ j') symmod p" by auto
          also have " = ((fs ! i' $ j') symmod p - ((c v fs ! j) $ j') symmod p) symmod p"
            by (simp add: sym_mod_diff_eq)
          also have "(c v fs ! j) $ j' = c * (fs ! j $ j')" 
            using i' j' True 28 30(1) j
            by (metis Suc_lessD carrier_vecD index_smult_vec(1) less_trans_Suc)
          also have "((fs ! i' $ j') symmod p - (c * (fs ! j $ j')) symmod p) symmod p = 
            ((fs ! i' $ j') symmod p - c * ((fs ! j $ j') symmod p)) symmod p" 
            using i' j' True 28 30(1) j by (metis sym_mod_diff_right_eq sym_mod_mult_right_eq)
          also have "((fs ! j $ j') symmod p) = mfs ! j $ j'" using 30 i' j' 28 29 j2 by fastforce
          also have "((fs ! i' $ j') symmod p - c * mfs ! j $ j') symmod p = 
            (mfs ! i' $ j' - c * mfs ! j $ j') symmod p" using fsij by simp
          finally show ?thesis using mfs'ij by (simp add: True)
        qed
      next
        case False
        show ?thesis using fs'_def mfs' False 28 fsij by simp
      qed
    }
    then have "i' < m. (map_vec (λx. x symmod p)) (fs' ! i') = mfs' ! i'"
      using 31 32 33 08 by fastforce
    then show ?thesis using 31 32 33 08 by (simp add: map_nth_eq_conv)
  qed
  then have 35: "map (map_vec (λx. x symmod p)) fs'' = mfs'" using 12 by simp (* invariant req. *)
  have 36: "lin_indep fs''"  using 13 by simp (* invariant req. *)
  have Linvw'': "LLL_invariant_weak fs''" using LLL_invariant_weak_def 11 13 14 by simp
  have 39: "(i' < m. j' < i'. ¦ fs'' i' j'¦ < p * d fs'' j' * d fs'' (Suc j'))" (* invariant req. *)
  proof -
    {
      fix i' j'
      assume i': "i' < m"
      assume j': "j' < i'"
      define pdd where "pdd = (p * d fs'' j' * d fs'' (Suc j'))"
      then have pddgtz: "pdd > 0" 
        using pgtz j' LLL_d_pos[OF Linvw', of "Suc j'"] LLL_d_pos[OF Linvw', of j'] j' i' 16 by simp
      have "¦ fs'' i' j'¦ < p * d fs'' j' * d fs'' (Suc j')"
      proof(cases "i' = i")
        case i'i: True
        then show ?thesis
        proof (cases "j' < j")
          case True
          then have eq'': " fs'' i' j' =  fs' i' j' symmod (p * d fs'' j' * d fs'' (Suc j'))"
            using 16 17 10 I_def True i' j' i'i by simp
          have "0 < pdd" using pddgtz by simp
          then show ?thesis unfolding eq'' unfolding pdd_def[symmetric] using sym_mod_abs by blast
        next
          case fls: False
          then have "(i',j')  I" using I_def i'i by simp
          then have dmufs''fs': " fs'' i' j' =  fs' i' j'" using 17 i' j' by simp
          show ?thesis
          proof (cases "j' = j")
            case True
            define μ'' where "μ'' = μ fs'' i' j'" 
            define d'' where "d'' = d fs'' (Suc j')"
            have pge1: "p  1" using pgtz by simp
            have lh: "¦μ''¦  1 / 2" using 23 True i'i μ''_def by simp
            moreover have eq: " fs'' i' j' = μ'' * d''" using dμ_def i' j' μ''_def d''_def 
              by (smt "14" "36" LLL.d_def Suc_lessD fs_int.d_def fs_int_indpt.dμ fs_int_indpt.intro 
                  int_of_rat(1) less_trans_Suc mult_of_int_commute of_rat_mult of_rat_of_int_eq)
            moreover have Sj': "Suc j'  m" "j'  m" using True j' i i' by auto
            moreover then have gtz: "0 < d''" using LLL_d_pos[OF Linvw''] d''_def by simp
            moreover have "rat_of_int ¦ fs'' i' j'¦ = ¦μ'' * (rat_of_int d'')¦" 
              using eq by (metis of_int_abs of_rat_hom.injectivity of_rat_mult of_rat_of_int_eq)
            moreover then have "¦μ'' * rat_of_int d'' ¦ =  ¦μ''¦ * rat_of_int ¦d''¦"
              by (metis (mono_tags, opaque_lifting) abs_mult of_int_abs)
            moreover have " = ¦μ''¦ * rat_of_int d'' " using gtz by simp
            moreover have " < rat_of_int d''" using lh gtz by simp
            ultimately have "rat_of_int ¦ fs'' i' j'¦ < rat_of_int d''" by simp
            then have "¦ fs'' i' j'¦ <  d fs'' (Suc j')" using d''_def by simp
            then have "¦ fs'' i' j'¦ < p * d fs'' (Suc j')" using pge1
              by (smt mult_less_cancel_right2)
            then show ?thesis using pge1 LLL_d_pos[OF Linvw'' Sj'(2)] gtz unfolding d''_def
              by (smt mult_less_cancel_left2 mult_right_less_imp_less)
          next
            case False
            have "j' < m" using i' j' by simp
            moreover have "j' > j" using False fls by simp
            ultimately have "μ fs' i' j' = μ fs i' j'" using i' 04 i by simp
            then have " fs' i' j' =  fs i' j'" using dμ_def i' j' 05 by simp
            then have " fs'' i' j' =  fs i' j'" using dmufs''fs' by simp
            then show ?thesis using LLL_invD_modw[OF Linvmw] i' j' 25 by simp
          qed
        qed
      next
        case False
        then have "(i',j')  I" using I_def by simp
        then have dmufs''fs': " fs'' i' j' =  fs' i' j'" using 17 i' j' by simp
        have "μ fs' i' j' = μ fs i' j'" using i' 04 j' False by simp
        then have " fs' i' j' =  fs i' j'" using dμ_def i' j' 05 by simp
        moreover then have " fs'' i' j' =  fs i' j'" using dmufs''fs' by simp
        then show ?thesis using LLL_invD_modw[OF Linvmw] i' j' 25 by simp
      qed
    }
    then show ?thesis by simp
  qed
  have 40: "(i' < m. j' < m. i'  i  j' > j   fs' i' j' = dmu $$ (i',j'))"
  proof -
    {
      fix i' j'
      assume i': "i' < m" and j': "j' < m"
      assume assm: "i'  i  j' > j"
      have " fs' i' j' = dmu $$ (i',j')"
      proof (cases "i'  i")
        case True
        then show ?thesis using fs'_def LLL_invD_modw[OF Linvmw] dμ_def i i' j j'
          04 28(1) LLL_invI_weak basis_reduction_add_row_main(8)[OF Linvww] by auto
      next
        case False
        then show ?thesis 
          using 05 LLL_invD_modw[OF Linvmw] dμ_def i j j' 04 assm by simp
      qed
    }
    then show ?thesis by simp
  qed
  have 41: "j'  j.  fs' i j' = dmu $$ (i,j') - c * dmu $$ (j,j')"
  proof -
    {
      let ?oi = "of_int :: _  rat" 
      fix j'
      assume j': "j'  j"
      define dj' μi μj where "dj' = d fs (Suc j')" and "μi = μ fs i j'" and "μj = μ fs j j'"
      have "?oi ( fs' i j') = ?oi (d fs (Suc j')) * (μ fs i j' - ?oi c * μ fs j j')"
        using j' 04 dμ_def 
        by (smt "05" "08" "091" Suc_leI d_def diff_diff_cancel fs_int.d_def 
            fs_int_indpt.fs_int_mu_d_Z i int_of_rat(2) j less_imp_diff_less less_imp_le_nat)
      also have " = (?oi dj') * (μi - of_int c * μj)" 
        using dj'_def μi_def μj_def by (simp add: of_rat_mult)
      also have " = (rat_of_int dj') * μi - of_int c * (rat_of_int dj') * μj" by algebra
      also have " = rat_of_int ( fs i j') - ?oi c * rat_of_int ( fs j j')" unfolding dj'_def μi_def μj_def
        using i j j' dμ_def
        using "28"(1) LLL.LLL_invD_modw(4) Linvmw d_def fs_int.d_def fs_int_indpt.fs_int_mu_d_Z fs_int_indpt.intro by auto
      also have " = rat_of_int (dmu $$ (i,j')) - ?oi c * rat_of_int (dmu $$ (j,j'))" 
        using LLL_invD_modw(7)[OF Linvmw] dμ_def j' i j by auto
      finally have "?oi ( fs' i j') = rat_of_int (dmu $$ (i,j')) - ?oi c * rat_of_int (dmu $$ (j,j'))" by simp
      then have " fs' i j' = dmu $$ (i,j') - c * dmu $$ (j,j')"
        using of_int_eq_iff by fastforce
    }
    then show ?thesis by simp
  qed
  have 42: "(i' < m. j' < m.  fs'' i' j' = dmu' $$ (i',j'))"
  proof -
    {
      fix i' j'
      assume i': "i' < m" and j': "j' < m"
      have " fs'' i' j' = dmu' $$ (i',j')" 
      proof (cases "i' = i")
        case i'i: True
        then show ?thesis
        proof (cases "j' > j")
          case True
          then have "(i',j')I" using I_def by simp
          moreover then have " fs' i' j' =  fs i' j'" using "04" "05" True Suc_leI dμ_def i' j' by simp
          moreover have "dmu' $$ (i',j') = dmu $$ (i',j')" using dmu' True i' j' by simp
          ultimately show ?thesis using "17" "40" True i' j' by auto
        next
          case False
          then have j'lej: "j'  j" by simp
          then have eq': " fs' i j' = dmu $$ (i,j') - c * dmu $$ (j,j')" using 41 by simp
          have id: "d_of dmu j' = d fs j'" "d_of dmu (Suc j') = d fs (Suc j')" 
            using d_of_weak[OF Linvmw] j' < m by auto
          show ?thesis
          proof (cases "j'  j")
            case True
            then have j'ltj: "j' < j" using True False by simp
            then have "(i',j')  I" using I_def True i'i by simp
            then have " fs'' i' j' = 
              (dmu $$ (i,j') - c * dmu $$ (j,j')) symmod (p * d fs' j' * d fs' (Suc j'))"
              using 17 i' 41 j'lej by (simp add: j' i'i)
            also have " = (dmu $$ (i,j') - c * dmu $$ (j,j')) symmod (p * d fs j' * d fs (Suc j'))"
              using 05 i j'ltj j by simp
            also have " = dmu' $$ (i,j')" 
              unfolding dmu' index_mat(1)[OF i < m j' < m] split id using j'lej True by auto
            finally show ?thesis using i'i by simp
          next
            case False
            then have j'j: "j' = j" by simp
            then have " fs'' i j' =  fs' i j'" using 20 j' by simp
            also have " = dmu $$ (i,j') - c * dmu $$ (j,j')" using eq' by simp
            also have " = dmu' $$ (i,j')" using dmu' j'j i j' by simp
            finally show ?thesis using i'i by simp
          qed
        qed
      next
        case False
        then have "(i',j')I" using I_def by simp
        moreover then have " fs' i' j' =  fs i' j'" by (simp add: "04" "05" False Suc_leI dμ_def i' j')
        moreover then have "dmu' $$ (i',j') = dmu $$ (i',j')" using dmu' False i' j' by simp
        ultimately show ?thesis using "17" "40" False i' j' by auto
      qed
    }
    then show ?thesis by simp
  qed
  from gbnd 26 have gbnd: "g_bnd_mode first b fs''" using g_bnd_mode_cong[of fs'' fs] by simp
  {
    assume Linv: "LLL_invariant_mod fs mfs dmu p first b i"
    have Linvw: "LLL_invariant_weak' i fs" using Linv LLL_invD_mod LLL_invI_weak by simp
    note Linvww = LLL_invw'_imp_w[OF Linvw]
    have 00: "LLL_invariant_weak' i fs'" using Linvw basis_reduction_add_row_weak[OF Linvw i j fs'_def] by auto
    have 37: "weakly_reduced fs'' i" using 15 LLL_invD_weak(8)[OF 00] gram_schmidt_fs.weakly_reduced_def 
      by (smt Suc_lessD i less_trans_Suc) (* invariant req. *)
    have 38: "LLL_invariant_weak' i fs''"
      using 00 11 14 36 37 i 31 12  LLL_invariant_weak'_def by blast
    have "LLL_invariant_mod fs'' mfs' dmu' p first b i"
      using LLL_invI_mod[OF 33 _ 14 11 13 37 35 39 42 p1 gbnd LLL_invD_mod(17)[OF Linv]] i by simp
  }
  moreover have "LLL_invariant_mod_weak fs'' mfs' dmu' p first b"
    using LLL_invI_modw[OF 33 14 11 13 35 39 42 p1 gbnd LLL_invD_modw(15)[OF Linvmw]] by simp
  ultimately show ?thesis using 27 23 24 25 26 172 by auto
qed

(* D_mod dmu: a potential computed directly from the stored matrix dmu.
   It aggregates d_of dmu i over all i < m and converts the result to nat.
   Here d_of dmu 0 is 1, and d_of dmu i for i > 0 is the diagonal entry
   dmu $$ (i - 1, i - 1).
   Presumably the aggregate is the product of these values, making this the
   counterpart of D for the modular representation; TODO confirm. *)
definition D_mod :: "int mat  nat" where "D_mod dmu = nat ( i < m. d_of dmu i)"

(* logD_mod dmu: the measure derived from D_mod dmu.
   If α = 4/3, it is D_mod dmu itself.
   Otherwise it is the floor of the logarithm of D_mod dmu to base
   1 / reduction, converted to nat. *)
definition logD_mod :: "int mat  nat"
  where "logD_mod dmu = (if α = 4/3 then (D_mod dmu) else nat (floor (log (1 / of_rat reduction) (D_mod dmu))))" 
end

(* Locale bundling the parameters of an LLL instance (n, m, fs_init, α) with
   the algorithm state:
   - the index i, the basis fs, the list mfs and the matrix dmu,
   - the modulus p, the flag first, and the bound b (as used by g_bnd_mode).
   Its only assumption is that LLL.LLL_invariant_mod holds for all of these
   at index i. *)
locale fs_int'_mod = 
  fixes n m fs_init α i fs mfs dmu p first b 
  assumes LLL_inv_mod: "LLL.LLL_invariant_mod n m fs_init α fs mfs dmu p first b i"

context LLL_with_assms
begin

(* Swap step, weak-invariant version.
   Assumptions:
   - the weak invariant LLL_invariant_weak' holds at index i;
   - i < m, and i is nonzero (assumption i0);
   - |μ fs i (i-1)| is bounded by 1/2 (assumption mu_F1_i);
   - sq_norm (gso fs (i - 1)) > α * sq_norm (gso fs i) (assumption norm_ineq).
   Conclusion: the list fs', obtained by exchanging fs ! (i - 1) and fs ! i,
   satisfies the weak invariant at index i - 1. *)
lemma basis_reduction_swap_weak': assumes Linvw: "LLL_invariant_weak' i fs"
  and i: "i < m"
  and i0: "i  0"
  and mu_F1_i: "¦μ fs i (i-1)¦  1 / 2"
  and norm_ineq: "sq_norm (gso fs (i - 1)) > α * sq_norm (gso fs i)" 
  and fs'_def: "fs' = fs[i := fs ! (i - 1), i - 1 := fs ! i]" 
shows "LLL_invariant_weak' (i - 1) fs'" 
proof -
  (* Collect the facts of the weak invariant for fs and the results of the
     generic swap lemma basis_reduction_swap_main for this swap. *)
  note inv = LLL_invD_weak[OF Linvw]
  note invw = LLL_invw'_imp_w[OF Linvw]
  note main = basis_reduction_swap_main[OF invw disjI2[OF mu_F1_i] i i0 norm_ineq fs'_def]
  note inv' = LLL_inv_wD[OF main(1)]
  (* Weak reduction up to i gives weak reduction up to i - 1.
     That property transfers from fs to fs' by main(5), which presumably says
     the swap leaves the GSO vectors below i - 1 unchanged; TODO confirm. *)
  from weakly_reduced fs i have "weakly_reduced fs (i - 1)" 
    unfolding gram_schmidt_fs.weakly_reduced_def by auto
  also have "weakly_reduced fs (i - 1) = weakly_reduced fs' (i - 1)" 
    unfolding gram_schmidt_fs.weakly_reduced_def 
    by (intro all_cong, insert i0 i main(5), auto)
  finally have red: "weakly_reduced fs' (i - 1)" .
  (* Reassemble the weak invariant for fs' at index i - 1. *)
  show "LLL_invariant_weak' (i - 1) fs'" using i
    by (intro LLL_invI_weak red inv', auto)
qed

(* When μ_small_row i fs 0 holds, μ_small fs i follows by unfolding
   μ_small_row_def and μ_small_def.
   The weak invariant at i and the bound i < m are also assumed. *)
lemma basis_reduction_add_row_done_weak: 
  assumes Linv: "LLL_invariant_weak' i fs"
  and i: "i < m" 
  and mu_small: "μ_small_row i fs 0" 
shows "μ_small fs i"
proof -
  note inv = LLL_invD_weak[OF Linv]
  from mu_small 
  have mu_small: "μ_small fs i" unfolding μ_small_row_def μ_small_def by auto
  (* NOTE(review): the goal is exactly the fact mu_small derived just above.
     The appeal to LLL_invI_weak (and hence to Linv, inv and i) therefore
     looks unnecessary here; kept unchanged. *)
  show ?thesis
    using i mu_small LLL_invI_weak[OF inv(3,6,7,9,1)] by auto
qed     

(* From the modular invariant at index m, for an index i bounded by m
   (assumption i), derive:
   (1) the modular invariant at index i;
   (2) the weak invariant LLL_invariant_weak' at m;
   (3) the weak invariant LLL_invariant_weak' at i.
   NOTE(review): the fact name fsinvwi introduced in the proof is not
   referenced afterwards. *)
lemma LLL_invariant_mod_to_weak_m_to_i: assumes
  inv: "LLL_invariant_mod fs mfs dmu p first b m"
  and i: "i  m"
shows "LLL_invariant_mod fs mfs dmu p first b i"
  "LLL_invariant_weak' m fs"
  "LLL_invariant_weak' i fs"
proof -
  (* Steps for (1):
     - obtain the weak invariant at m;
     - lower it to i via LLL_inv_weak_m_impl_i;
     - extract weak reduction at i;
     - rebuild the modular invariant at i. *)
  show "LLL_invariant_mod fs mfs dmu p first b i" 
  proof -
    have "LLL_invariant_weak' m fs" using LLL_invD_mod[OF inv] LLL_invI_weak by simp
    then have "LLL_invariant_weak' i fs" using LLL_inv_weak_m_impl_i i by simp
    then have "weakly_reduced fs i" using i LLL_invD_weak(8) by simp
    then show ?thesis using LLL_invD_mod[OF inv] LLL_invI_mod i by simp
  qed
  then show fsinvwi: "LLL_invariant_weak' i fs" using LLL_invD_mod LLL_invI_weak by simp
  (* The weak invariant at m follows directly from the modular one at m. *)
  show "LLL_invariant_weak' m fs" using LLL_invD_mod[OF inv] LLL_invI_weak by simp
qed

lemma basis_reduction_mod_swap_main: 
  assumes Linvmw: "LLL_invariant_mod_weak fs mfs dmu p first b"
  and k: "k < m"
  and k0: "k  0"
  and mu_F1_i: "¦μ fs k (k-1)¦  1 / 2"
  and norm_ineq: "sq_norm (gso fs (k - 1)) > α * sq_norm (gso fs k)" 
  and mfs'_def: "mfs' = mfs[k := mfs ! (k - 1), k - 1 := mfs ! k]"
  and dmu'_def: "dmu' = (mat m m (λ(i,j). (
      if j < i then
        if i = k - 1 then 
           dmu $$ (k, j)
        else if i = k  j  k - 1 then 
             dmu $$ (k - 1, j)
        else if i > k  j = k then
           ((d_of dmu (Suc k)) * dmu $$ (i, k - 1) - dmu $$ (k, k - 1) * dmu $$ (i, j)) 
              div (d_of dmu k)
        else if i > k  j = k - 1 then
           (dmu $$ (k, k - 1) * dmu $$ (i, j) + dmu $$ (i, k) * (d_of dmu (k-1)))
              div (d_of dmu k)
        else dmu $$ (i, j)
      else if i = j then 
        if i = k - 1 then 
          ((d_of dmu (Suc k)) * (d_of dmu (k-1)) + dmu $$ (k, k - 1) * dmu $$ (k, k - 1)) 
            div (d_of dmu k)
        else (d_of dmu (Suc i))
      else dmu $$ (i, j))
    ))"
  and dmu'_mod_def: "dmu'_mod = mat m m (λ(i, j). (
        if j < i  (j = k  j = k - 1) then 
          dmu' $$ (i, j) symmod (p * (d_of dmu' j) * (d_of dmu' (Suc j)))
        else dmu' $$ (i, j)))"
shows "(fs'. LLL_invariant_mod_weak fs' mfs' dmu'_mod p first b 
        LLL_measure (k-1) fs' < LLL_measure k fs 
        (LLL_invariant_mod fs mfs dmu p first b k  LLL_invariant_mod fs' mfs' dmu'_mod p first b (k-1)))" 
proof - 
  define fs' where "fs' = fs[k := fs ! (k - 1), k - 1 := fs ! k]"
  have pgtz: "p > 0" and p1: "p > 1" using LLL_invD_modw[OF Linvmw] by auto
  have invw: "LLL_invariant_weak fs" using LLL_invD_modw[OF Linvmw] LLL_invariant_weak_def by simp
  note swap_main = basis_reduction_swap_main(3-)[OF invw disjI2[OF mu_F1_i] k k0 norm_ineq fs'_def]
  note ddμ_swap = d_dμ_swap[OF invw disjI2[OF mu_F1_i] k k0 norm_ineq fs'_def]
  have invw': "LLL_invariant_weak fs'" using fs'_def assms invw basis_reduction_swap_main(1) by simp
  have 02: "LLL_measure k fs > LLL_measure (k - 1) fs'" by fact
  have 03: " i j. i < m  j < i  
           fs' i j = (
        if i = k - 1 then 
            fs k j
        else if i = k  j  k - 1 then 
              fs (k - 1) j
        else if i > k  j = k then
           (d fs (Suc k) *  fs i (k - 1) -  fs k (k - 1) *  fs i j) div d fs k
        else if i > k  j = k - 1 then 
           ( fs k (k - 1) *  fs i j +  fs i k * d fs (k - 1)) div d fs k
        else  fs i j)"
    using ddμ_swap by auto
  have 031: "i. i < k-1  gso fs' i = gso fs i" 
    using swap_main(2) k k0 by auto
  have 032: " ii. ii  m  of_int (d fs' ii) = (if ii = k then 
           sq_norm (gso fs' (k - 1)) / sq_norm (gso fs (k - 1)) * of_int (d fs k)
           else of_int (d fs ii))" 
    by fact 
  have gbnd: "g_bnd_mode first b fs'"
  proof (cases "first  m  0")
    case True
    have "sq_norm (gso fs' 0)  sq_norm (gso fs 0)" 
    proof (cases "k - 1 = 0")
      case False
      thus ?thesis using 031[of 0] by simp
    next
      case *: True
      have k_1: "k - 1 < m" using k by auto
      from * k0 have k1: "k = 1" by simp
      (* this is a copy of what is done in LLL.swap-main, should be made accessible in swap-main *)
      have "sq_norm (gso fs' 0)  abs (sq_norm (gso fs' 0))" by simp
      also have " = abs (sq_norm (gso fs 1) + μ fs 1 0 * μ fs 1 0 * sq_norm (gso fs 0))" 
        by (subst swap_main(3)[OF k_1, unfolded *], auto simp: k1)
      also have "  sq_norm (gso fs 1) + abs (μ fs 1 0) * abs (μ fs 1 0) * sq_norm (gso fs 0)"
        by (simp add: sq_norm_vec_ge_0)
      also have "  sq_norm (gso fs 1) + (1 / 2) * (1 / 2) * sq_norm (gso fs 0)" 
        using mu_F1_i[unfolded k1] 
        by (intro plus_right_mono mult_mono, auto)
      also have " < 1 / α * sq_norm (gso fs 0) + (1 / 2) * (1 / 2) * sq_norm (gso fs 0)" 
        by (intro add_strict_right_mono, insert norm_ineq[unfolded mult.commute[of α],
          THEN mult_imp_less_div_pos[OF α0(1)]] k1, auto)
      also have " = reduction * sq_norm (gso fs 0)" unfolding reduction_def
        using α0 by (simp add: ring_distribs add_divide_distrib)
      also have "<