(* Theory BenOr_Kozen_Reif.BKR_Algorithm *)
theory BKR_Algorithm
  imports
    Sturm_Tarski.Sturm_Tarski
    More_Matrix
    
begin
section "Setup"
(* Select the entries of qss found at the positions listed in index_list.
   The result follows the order of index_list, repetitions included. *)
definition retrieve_polys :: "'a list ⇒ nat list ⇒ 'a list"
  where "retrieve_polys qss index_list = map ((!) qss) index_list"
(* The integer changes_R_smods p (pderiv p * prod_list I), cast to a rational.
   NOTE(review): by the Sturm-Tarski theorem of the imported Sturm_Tarski entry
   this is presumably the Tarski query of prod_list I over the roots of p;
   confirm against the accompanying proofs. *)
definition construct_NofI :: "real poly ⇒ real poly list ⇒ rat"
  where "construct_NofI p I = rat_of_int (changes_R_smods p (pderiv p * prod_list I))"
(* Right-hand-side vector: one entry per index list in Is, in the same order.
   Each entry is construct_NofI p applied to the polynomials of qs that the
   index list selects (via retrieve_polys). *)
definition construct_rhs_vector :: "real poly ⇒ real poly list ⇒ nat list list ⇒ rat vec"
  where "construct_rhs_vector p qs Is =
    vec_of_list (map (λixs. construct_NofI p (retrieve_polys qs ixs)) Is)"
section "Base Case"
(* Fixed triple (matrix, subsets, signs) for the base case:
     matrix  = the 2x2 matrix with rows [1, 1] and [1, -1],
     subsets = [[], [0]],
     signs   = [[1], [-1]]. *)
definition base_case_info :: "rat mat × nat list list × rat list list"
  where "base_case_info =
    (mat_of_rows_list 2 [[1, 1], [1, -1]], [[], [0]], [[1], [-1]])"
(* Base-case solve: multiply the fixed matrix with rows [1/2, 1/2] and
   [1/2, -1/2] by the right-hand-side vector built for p, the single
   polynomial q and the subsets [[], [0]].
   The fixed matrix is the inverse of the 2x2 matrix with rows [1, 1], [1, -1]. *)
definition base_case_solve_for_lhs :: "real poly ⇒ real poly ⇒ rat vec"
  where "base_case_solve_for_lhs p q =
    mat_of_rows_list 2 [[1/2, 1/2], [1/2, -1/2]] *⇩v construct_rhs_vector p [q] [[], [0]]"
(* Diagnostic command only: displays the fact gauss_jordan_compute_inverse
   (expected to come from the imported theories). It introduces no definitions. *)
thm "gauss_jordan_compute_inverse"
(* Unwrap an optional matrix: Some c yields c itself, while None falls back to
   the identity matrix of dimension dimen. *)
primrec matr_option:: "nat ⇒ 'a::{one, zero} mat option ⇒ 'a mat"
  where "matr_option dimen None = 1⇩m dimen"
  | "matr_option dimen (Some c) = c" 
(* Matrix equality test: both dimensions agree and the mat_to_list
   renderings of the two matrices coincide. *)
definition mat_equal :: "'a :: field mat ⇒ 'a mat ⇒ bool"
  where "mat_equal A B ⟷
    dim_row A = dim_row B ∧ dim_col A = dim_col B ∧ mat_to_list A = mat_to_list B"
(* Optional matrix inverse.
   - Non-square A: None.
   - Square A: run gauss_jordan on A together with the identity matrix of the
     same dimension. If the first component of the result passes mat_equal
     against that identity, return Some of the second component, else None.
   NOTE(review): appears to be a variant of the library inverse routine that
   compares via mat_equal; confirm against the imported matrix theories. *)
definition mat_inverse_var :: "'a :: field mat ⇒ 'a mat option"
  where "mat_inverse_var A =
    (if dim_row A = dim_col A
     then
       let idm = 1⇩m (dim_row A) in
       (case gauss_jordan A idm of
          (E, R) ⇒ if mat_equal E idm then Some R else None)
     else None)"
(* Solve the system for its left-hand-side vector: multiply the matrix obtained
   from mat_inverse_var matr by the right-hand-side vector for p, qs and subsets.
   When mat_inverse_var yields None, matr_option substitutes the identity
   matrix of dimension dim_row matr. *)
definition solve_for_lhs :: "real poly ⇒ real poly list ⇒ nat list list ⇒ rat mat ⇒ rat vec"
  where "solve_for_lhs p qs subsets matr =
    matr_option (dim_row matr) (mat_inverse_var matr) *⇩v construct_rhs_vector p qs subsets"
section "Smashing" 
(* Pair every index list of s1 with every index list of s2 (s1 varies slowest).
   Each pair becomes the s1 list followed by the s2 list with n added to every
   one of its indices. *)
definition subsets_smash :: "nat ⇒ nat list list ⇒ nat list list ⇒ nat list list"
  where "subsets_smash n s1 s2 =
    concat (map (λfront. map (λback. front @ map ((+) n) back) s2) s1)"
(* Pair every list of s1 with every list of s2 (s1 varies slowest) and
   concatenate each pair, the s1 part first. *)
definition signs_smash :: "'a list list ⇒ 'a list list ⇒ 'a list list"
  where "signs_smash s1 s2 =
    concat (map (λfront. map (λback. front @ back) s2) s1)"
(* Combine two systems into one. The result consists of
     - the concatenated polynomial lists qs1 @ qs2,
     - the Kronecker product of the two matrices,
     - the smashed subsets, with the second system's indices shifted by length qs1,
     - the smashed sign lists.
   The argument p does not occur in the defining right-hand side. *)
definition smash_systems ::
  "real poly ⇒ real poly list ⇒ real poly list ⇒ nat list list ⇒ nat list list ⇒
   rat list list ⇒ rat list list ⇒ rat mat ⇒ rat mat ⇒
   real poly list × rat mat × nat list list × rat list list"
  where "smash_systems p qs1 qs2 subsets1 subsets2 signs1 signs2 mat1 mat2 =
    (qs1 @ qs2,
     kronecker_product mat1 mat2,
     subsets_smash (length qs1) subsets1 subsets2,
     signs_smash signs1 signs2)"
(* Tuple-based front end for smash_systems: takes two systems, each given as
   (polynomials, matrix, subsets, signs), and passes their components on in the
   argument order smash_systems expects. *)
fun combine_systems ::
  "real poly ⇒ real poly list × rat mat × nat list list × rat list list ⇒
   real poly list × rat mat × nat list list × rat list list ⇒
   real poly list × rat mat × nat list list × rat list list"
  where "combine_systems p (qs1, m1, sub1, sgn1) (qs2, m2, sub2, sgn2) =
    smash_systems p qs1 qs2 sub1 sub2 sgn1 sgn2 m1 m2"
section "Reduction"
(* Positions of the nonzero entries of lhs_vec, in ascending order. *)
definition find_nonzeros_from_input_vec :: "rat vec ⇒ nat list"
  where "find_nonzeros_from_input_vec lhs_vec =
    filter (λk. lhs_vec $ k ≠ 0) [0 ..< dim_vec lhs_vec]"
(* Pick the elements of subsets at the given positions, in the order (and with
   the repetitions) in which the positions are listed. *)
definition take_indices :: "'a list ⇒ nat list ⇒ 'a list"
  where "take_indices subsets indices = map (nth subsets) indices"
(* Build a matrix with dim_row matr rows out of the columns of matr selected by
   indices_to_keep, keeping the order of indices_to_keep. *)
definition take_cols_from_matrix :: "'a mat ⇒ nat list ⇒ 'a mat"
  where "take_cols_from_matrix matr indices_to_keep =
    mat_of_cols (dim_row matr) (take_indices (cols matr) indices_to_keep)"
(* Build a matrix with dim_col matr columns out of the rows of matr selected by
   indices_to_keep, keeping the order of indices_to_keep. *)
definition take_rows_from_matrix :: "'a mat ⇒ nat list ⇒ 'a mat"
  where "take_rows_from_matrix matr indices_to_keep =
    mat_of_rows (dim_col matr) (take_indices (rows matr) indices_to_keep)"
(* Keep only those columns of A whose position carries a nonzero entry in
   lhs_vec (positions computed by find_nonzeros_from_input_vec). *)
fun reduce_mat_cols:: "'a mat ⇒ rat vec ⇒ 'a mat"
  where "reduce_mat_cols A lhs_vec = take_cols_from_matrix A (find_nonzeros_from_input_vec lhs_vec)"
(* Second components of the pivot positions of gauss_jordan_single applied to
   the transpose of A.
   NOTE(review): presumably these are the pivot column indices of the reduced
   transpose, i.e. indices of a maximal set of linearly independent rows of A;
   confirm against the definition of pivot_positions. *)
definition rows_to_keep :: "('a :: field) mat ⇒ nat list"
  where "rows_to_keep A =
    map snd (pivot_positions (gauss_jordan_single (transpose_mat A)))"
(* One reduction step on a system (A, subsets, signs) given its solution lhs_vec.
   1. Drop the columns of A whose lhs_vec entry is zero.
   2. From the resulting matrix keep only the rows chosen by rows_to_keep.
   The returned triple holds that matrix, the subsets at the kept row positions,
   and the signs at the positions of the nonzero lhs_vec entries. *)
fun reduction_step ::
  "rat mat ⇒ rat list list ⇒ nat list list ⇒ rat vec ⇒ rat mat × nat list list × rat list list"
  where "reduction_step A signs subsets lhs_vec =
    (let A_cols = reduce_mat_cols A lhs_vec;
         kept_rows = rows_to_keep A_cols
     in (take_rows_from_matrix A_cols kept_rows,
         take_indices subsets kept_rows,
         take_indices signs (find_nonzeros_from_input_vec lhs_vec)))"
(* Reduce a full system (polynomials, matrix, subsets, signs): solve it for the
   left-hand-side vector with solve_for_lhs and hand that vector to
   reduction_step. The polynomial list is not part of the result. *)
fun reduce_system ::
  "real poly ⇒ real poly list × rat mat × nat list list × rat list list ⇒
   rat mat × nat list list × rat list list"
  where "reduce_system p (qs, m, subs, signs) =
    reduction_step m signs subs (solve_for_lhs p qs subs m)"
section "Overall algorithm"
  
(* Divide-and-conquer computation of the reduced system (matrix, subsets, signs)
   for p and the list qs.
   - qs empty: reduce the base-case system for the one-element list [1], then
     drop the first entry of every sign list in the result.
   - qs of length 1: reduce the base-case system for qs itself.
   - otherwise: split qs at position (length qs) div 2, compute the data for
     both halves recursively, merge them with combine_systems and reduce the
     merged system. *)
fun calculate_data:: "real poly ⇒ real poly list ⇒  (rat mat × (nat list list × rat list list))"
  where
    "calculate_data p qs = 
  ( let len = length qs in
    if len = 0 then
      (λ(a,b,c).(a,b,map (drop 1) c)) (reduce_system p ([1],base_case_info))
    else if len ≤ 1 then reduce_system p (qs,base_case_info)
    else
    (let q1 = take (len div 2) qs; left = calculate_data p q1;
         q2 = drop (len div 2) qs; right = calculate_data p q2;
         comb = combine_systems p (q1,left) (q2,right) in
         reduce_system p comb
    )
  )"
(* Top-level entry point: the sign lists, i.e. the third component of the
   triple returned by calculate_data p qs. Registered as a code equation. *)
definition find_consistent_signs_at_roots :: "real poly ⇒ real poly list ⇒ rat list list"
  where [code]:
    "find_consistent_signs_at_roots p qs =
      (let (A, Is, sgns) = calculate_data p qs in sgns)"
(* Projection form of find_consistent_signs_at_roots: the let-bound tuple
   pattern is replaced by an explicit snd (snd ...) on the calculate_data result. *)
lemma find_consistent_signs_at_roots_thm:
  "find_consistent_signs_at_roots p qs = snd (snd (calculate_data p qs))"
  by (simp add: case_prod_beta find_consistent_signs_at_roots_def)
end