(* Theory Gauss-Jordan-Elim-Fun.Gauss_Jordan_Elim_Fun *)
section ‹Gauss-Jordan elimination algorithm›
theory Gauss_Jordan_Elim_Fun
  imports
    "HOL-Combinatorics.Transposition"
begin
text‹Matrices are functions:›
(* A matrix A is indexed as A i j, first by row i and then by column j.
   The function is total, so the relevant dimensions are passed separately. *)
type_synonym 'a matrix = "nat ⇒ nat ⇒ 'a"
text‹In order to restrict to finite matrices, a matrix is usually combined
with one or two natural numbers indicating the maximal row and column of the
matrix.
Gauss-Jordan elimination is parameterized with a natural number ‹n›. It indicates that the matrix ‹A› has ‹n› rows and columns.
In fact, ‹A› is the augmented matrix with ‹n+1› columns. Column
‹n› is the ``right-hand side'', i.e.\ the constant vector ‹b›. The result is the unit matrix augmented with the solution in column
‹n›; see the correctness theorem below.›
(* gauss_jordan A k eliminates columns k-1, k-2, ..., 0 in turn.
   For column m (the Suc m equation) it looks for the first row among 0..m
   whose entry in column m is nonzero:
     - no such row: the result is None;
     - such a row (the pivot): the pivot row is divided by its entry in
       column m, that scaled row times A r m is subtracted from every other
       row r, rows pivot and m are exchanged, and the recursion continues
       with m.
   With k = 0 the matrix is returned unchanged. *)
fun gauss_jordan :: "'a::field matrix ⇒ nat ⇒ 'a matrix option"
where
  "gauss_jordan A 0 = Some A"
| "gauss_jordan A (Suc m) =
    (case dropWhile (λr. A r m = 0) [0..<Suc m] of
       pivot # _ ⇒
         (let scaled = (λc. A pivot c / A pivot m);
              cleared = (λr. if r = pivot then scaled
                             else (λc. A r c - A r m * scaled c))
          in gauss_jordan (Fun.swap pivot m cleared) m)
     | [] ⇒ None)"
text‹Some auxiliary functions:›
(* solution A n x holds iff x satisfies the first n equations of A:
   for every row i < n, the sum of A i j * x j over the columns j < n
   equals the entry A i n in column n. *)
definition solution :: "'a::field matrix ⇒ nat ⇒ (nat ⇒ 'a) ⇒ bool"
  where "solution A n x ⟷ (∀i<n. (∑j∈{0..<n}. A i j * x j) = A i n)"
(* unit A m n holds iff every column j with m ≤ j < n is a unit column:
   the entry A i j is 1 when i = j and 0 otherwise, for all row indices i. *)
definition unit :: "'a::field matrix ⇒ nat ⇒ nat ⇒ bool"
  where "unit A m n ⟷
    (∀i j. m ≤ j ⟶ j < n ⟶ A i j = (if i = j then 1 else 0))"
(* Exchanging two rows p1 and p2, both among the first n rows, does not
   change whether x satisfies the n equations checked by solution. *)
lemma solution_swap:
assumes "p1 < n" "p2 < n"
shows "solution (Fun.swap p1 p2 A) n x = solution A n x" (is "?L = ?R")
proof(cases "p1=p2")
  (* Both indices coincide: discharged by simp. *)
  case True thus ?thesis by simp
next
  case False
  show ?thesis
  proof
    (* Direction ?R ⟹ ?L: unfold solution and the swap. *)
    assume ?R thus ?L using assms False by(simp add: solution_def Fun.swap_def)
  next
   (* Direction ?L ⟹ ?R: show the equation of each row i < n separately. *)
   assume ?L
   show ?R
   proof(auto simp: solution_def)
     fix i assume "i<n"
     (* Case analysis on the row: i = p1, i = p2, or i is neither. *)
     show "(∑j = 0..<n. A i j * x j) = A i n"
     proof cases
       assume "i=p1"
       with ‹?L› assms False show ?thesis
         by(fastforce simp add: solution_def Fun.swap_def)
     next
       assume "i≠p1"
       show ?thesis
       proof cases
         assume "i=p2"
         with ‹?L› assms False show ?thesis
           by(fastforce simp add: solution_def Fun.swap_def)
       next
         assume "i≠p2"
         with ‹i≠p1› ‹?L› ‹i<n› assms False show ?thesis
           by(fastforce simp add: solution_def Fun.swap_def)
       qed
     qed
   qed
 qed
qed
(* Dividing every entry of row p by a nonzero constant c does not change
   whether x is a solution. *)
lemma solution_upd1:
  "c ≠ 0 ⟹ solution (A(p:=(λj. A p j / c))) n x = solution A n x"
(* Distinguish whether the updated row p is among the first n rows. *)
apply(cases "p<n")
 (* Treat the case ¬ p<n first. *)
 prefer 2
 apply(simp add: solution_def)
apply(clarsimp simp add: solution_def)
apply rule
 apply clarsimp
 (* Split on whether the row i under consideration is the updated row p. *)
 apply(case_tac "i=p")
  apply (simp add: sum_divide_distrib[symmetric] eq_divide_eq field_simps)
 apply simp
apply (simp add: sum_divide_distrib[symmetric] eq_divide_eq field_simps)
done
(* Keeping row p (given as ap) and subtracting c i times row p from every
   other row i does not change whether x is a solution.
   Assumptions: ap is row p of A, a agrees with A on all rows other than p,
   and p is among the first n rows. *)
lemma solution_upd_but1: "⟦ ap = A p; ∀i j. i≠p ⟶ a i j = A i j; p<n ⟧ ⟹
 solution (λi. if i=p then ap else (λj. a i j - c i * ap j)) n x =
 solution A n x"
apply(clarsimp simp add: solution_def)
apply rule
 (* Handle the second of the two subgoals first. *)
 prefer 2
 apply (simp add: field_simps sum_subtractf sum_distrib_left[symmetric])
apply(clarsimp)
(* Split on whether the row i under consideration is row p. *)
apply(case_tac "i=p")
 apply simp
apply (auto simp add: field_simps sum_subtractf sum_distrib_left[symmetric] all_conj_distrib)
done
subsection‹Correctness›
text‹The correctness proof:›
(* Generalised correctness statement, suitable for induction on m:
   if m ≤ n, the columns m..<n of A are already unit columns, and
   gauss_jordan A m returns Some B, then all columns 0..<n of B are unit
   columns and column n of B is a solution of the system given by A. *)
lemma gauss_jordan_lemma: "m≤n ⟹ unit A m n ⟹ gauss_jordan A m = Some B ⟹
  unit B 0 n ∧ solution A n (λj. B j n)"
proof(induct m arbitrary: A B)
  case 0
  (* Auxiliary fact: a multiplication can be pushed into an if-expression. *)
  { fix a and b c d :: "'a"
    have "(if a then b else c) * d = (if a then b*d else c*d)" by simp
  } with 0 show ?case by(simp add: unit_def solution_def sum.If_cases)
next
  case (Suc m)
  (* ?Ap' p is row p divided by its entry in column m; ?A' p is the matrix
     after subtracting A i m times that scaled row from every other row i. *)
  let "?Ap' p" = "(λj. A p j / A p m)"
  let "?A' p" = "(λi. if i=p then ?Ap' p else (λj. A i j - A i m * ?Ap' p j))"
  (* The call returned Some B, so a pivot row p was found and the recursive
     call on the transformed, row-swapped matrix returned B. *)
  from ‹gauss_jordan A (Suc m) = Some B›
  obtain p ks where "dropWhile (λi. A i m = 0) [0..<Suc m] = p#ks" and
    rec: "gauss_jordan (Fun.swap p m (?A' p)) m = Some B"
    by (auto split: list.splits)
  (* The pivot row is at most m and its entry in column m is nonzero. *)
  from this have p: "p≤m" "A p m ≠ 0"
    apply(simp_all add: dropWhile_eq_Cons_conv del:upt_Suc)
    by (metis set_upt atLeast0AtMost atLeastLessThanSuc_atLeastAtMost atMost_iff in_set_conv_decomp)
  have "m≤n" "m<n" using ‹Suc m ≤ n› by arith+
  (* The transformed matrix has unit columns m..<n, as the induction
     hypothesis requires. *)
  have "unit (Fun.swap p m (?A' p)) m n" using Suc.prems(2) p
    unfolding unit_def Fun.swap_def Suc_le_eq by (auto simp: le_less)
  (* Apply the induction hypothesis and transfer the solution back to A with
     solution_swap, solution_upd_but1 and solution_upd1. *)
  from Suc.hyps[OF ‹m≤n› this rec] ‹m<n› p
  show ?case
    by (simp only: solution_swap) (simp_all add: solution_swap solution_upd_but1 [where A = "A(p := ?Ap' p)"] solution_upd1)
qed
(* Correctness: if gauss_jordan A n returns Some B, then column n of B is a
   solution of the system given by A.
   Proved from gauss_jordan_lemma instantiated with m = n. *)
theorem gauss_jordan_correct:
  "gauss_jordan A n = Some B ⟹ solution A n (λj. B j n)"
by(simp add:gauss_jordan_lemma[of n n] unit_def  field_simps)
(* Variant of solution with separate parameters for the size of the system
   and the position of the right-hand side: solution2 A m n x holds iff for
   every row i < m, the sum of A i j * x j over the columns j < m equals the
   entry A i n in column n. *)
definition solution2 :: "'a::field matrix ⇒ nat ⇒ nat ⇒ (nat ⇒ 'a) ⇒ bool"
  where "solution2 A m n x ⟷ (∀i<m. (∑j∈{0..<m}. A i j * x j) = A i n)"
(* usolution A m n x holds iff x is a solution in the sense of solution2 and
   every other such solution y agrees with x on all components below m. *)
definition usolution :: "'a::field matrix ⇒ nat ⇒ nat ⇒ (nat ⇒ 'a) ⇒ bool"
  where "usolution A m n x ⟷
    solution2 A m n x ∧ (∀y. solution2 A m n y ⟶ (∀j<m. y j = x j))"
(* If x is the unique solution of the system of size m, then every column
   q < m has a nonzero entry in some row p < m. *)
lemma non_null_if_pivot:
  assumes "usolution A m n x" and "q < m" shows "∃p<m. A p q ≠ 0"
proof(rule ccontr)
  (* Suppose column q is zero in all rows below m. *)
  assume "¬(∃p<m. A p q ≠ 0)"
  hence 1: "⋀p. p<m ⟹ A p q = 0" by simp
  (* Then any y that agrees with x everywhere except possibly at q is also
     a solution. *)
  { fix y assume 2: "∀j. j≠q ⟶ y j = x j"
    { fix i assume "i<m"
      with assms(1) have "A i n = (∑j = 0..<m. A i j * x j)"
        by (auto simp: solution2_def usolution_def)
      with 1[OF ‹i<m›] 2
      have "(∑j = 0..<m. A i j * y j) = A i n"
        by (auto intro!: sum.cong)
    }
    hence "solution2 A m n y" by(simp add: solution2_def)
  }
  (* In particular x with component q set to 0 and x with component q set
     to 1 are both solutions, which contradicts uniqueness. *)
  hence "solution2 A m n (x(q:=0))" and "solution2 A m n (x(q:=1))" by auto
  with assms(1) zero_neq_one ‹q < m›
  show False
    by (simp add: usolution_def)
       (metis fun_upd_same zero_neq_one)
qed
(* A constant factor a on the left of the second multiplicand can be pulled
   out of the sum. *)
lemma lem1:
  fixes f :: "'a ⇒ 'b::field"
  shows "(∑x∈A. f x * (a * g x)) = a * (∑x∈A. f x * g x)"
  by (simp add: sum_distrib_left field_simps)
(* A constant factor a on the right of the second multiplicand can be pulled
   out of the sum. *)
lemma lem2:
  fixes f :: "'a ⇒ 'b::field"
  shows "(∑x∈A. f x * (g x * a)) = a * (∑x∈A. f x * g x)"
  by (simp add: sum_distrib_left field_simps)
subsection‹Completeness›
(* Completeness: if m ≤ n and the system of size m with right-hand side in
   column n has a unique solution x, then gauss_jordan A m returns some
   matrix. *)
lemma gauss_jordan_complete:
  "m ≤ n ⟹ usolution A m n x ⟹ ∃B. gauss_jordan A m = Some B"
proof(induction m arbitrary: A)
  case 0 show ?case by simp
next
  case (Suc m A)
  from ‹Suc m ≤ n› have "m≤n" and "m<Suc m" by arith+
  (* Uniqueness of the solution yields a nonzero entry in column m. *)
  from non_null_if_pivot[OF Suc.prems(2) ‹m<Suc m›]
  obtain p' where "p'<Suc m" and "A p' m ≠ 0" by blast
  (* Hence the pivot search of gauss_jordan returns a non-empty list. *)
  hence "dropWhile (λi. A i m = 0) [0..<Suc m] ≠ []"
    by (simp add: atLeast0LessThan) (metis lessThan_iff linorder_neqE_nat not_less_eq)
  then obtain p xs where 1: "dropWhile (λi. A i m = 0) [0..<Suc m] = p#xs"
    by (metis list.exhaust)
  (* Its head p is the pivot row: at most m, with nonzero entry in column m. *)
  from this have "p≤m" "A p m ≠ 0"
    by (simp_all add: dropWhile_eq_Cons_conv del: upt_Suc)
       (metis set_upt atLeast0AtMost atLeastLessThanSuc_atLeastAtMost atMost_iff in_set_conv_decomp)
  then have p: "p < Suc m" "A p m ≠ 0"
    by auto
  (* ?A is the matrix on which gauss_jordan recurses: pivot row scaled,
     other rows reduced, rows p and m swapped. *)
  let ?Ap' = "(λj. A p j / A p m)"
  let ?A' = "(λi. if i=p then ?Ap' else (λj. A i j - A i m * ?Ap' j))"
  let ?A = "Fun.swap p m ?A'"
  have A: "solution2 A (Suc m) n x" using Suc.prems(2) by(simp add: usolution_def)
  (* Auxiliary computation: evaluating a reduced row on x over the first m
     columns gives the correspondingly reduced right-hand side. *)
  { fix i assume le_m: "p < Suc m" "i < Suc m" "A p m ≠ 0"
    have "(∑j = 0..<m. (A i j - A i m * A p j / A p m) * x j) =
      ((∑j = 0..<Suc m. A i j * x j) - A i m * x m) -
      ((∑j = 0..<Suc m. A p j * x j) - A p m * x m) * A i m / A p m"
      by (simp add: field_simps sum_subtractf sum_divide_distrib
                    sum_distrib_left)
    also have "… = A i n - A p n * A i m / A p m"
      using A le_m
      by (simp add: solution2_def field_simps del: sum.op_ivl_Suc)
    finally have "(∑j = 0..<m. (A i j - A i m * A p j / A p m) * x j) =
      A i n - A p n * A i m / A p m" . }
  (* Existence: x is a solution of the system of size m given by ?A. *)
  then have "solution2 ?A m n x" using p
    by (auto simp add: solution2_def Fun.swap_def field_simps)
  moreover
  (* Uniqueness: any solution y of the system given by ?A is extended to a
     solution ?y of the original system by choosing component m from the
     equation of the pivot row; uniqueness of x then forces y to agree with
     x on all components below m. *)
  { fix y assume a: "solution2 ?A m n y"
    let ?y = "y(m := A p n / A p m - (∑j = 0..<m. A p j * y j) / A p m)"
    have "solution2 A (Suc m) n ?y" unfolding solution2_def
    proof safe
      fix i assume "i < Suc m"
      (* Case analysis on the row: the pivot row p, row m, or any other row. *)
      show "(∑j=0..<Suc m. A i j * ?y j) = A i n"
      proof (cases "i = p")
        assume "i = p" with p show ?thesis by (simp add: field_simps)
      next
        assume "i ≠ p"
        show ?thesis
        proof (cases "i = m")
          assume "i = m"
          with p ‹i ≠ p› have "p < m" by simp
          with a[unfolded solution2_def, THEN spec, of p] p(2)
          have "A p m * (A m m * A p n + A p m * (∑j = 0..<m. y j * A m j)) = A p m * (A m n * A p m + A m m * (∑j = 0..<m. y j * A p j))"
            by (simp add: Fun.swap_def field_simps sum_subtractf lem1 lem2 sum_divide_distrib[symmetric]
                     split: if_splits)
          with ‹A p m ≠ 0› show ?thesis unfolding ‹i = m›
            by simp (simp add: field_simps)
        next
          assume "i ≠ m"
          then have "i < m" using ‹i < Suc m› by simp
          with a[unfolded solution2_def, THEN spec, of i] p(2)
          have "A p m * (A i m * A p n + A p m * (∑j = 0..<m. y j * A i j)) = A p m * (A i n * A p m + A i m * (∑j = 0..<m. y j * A p j))"
            by (simp add: Fun.swap_def split: if_splits)
              (simp add: field_simps sum_subtractf lem1 lem2 sum_divide_distrib [symmetric])
          with ‹A p m ≠ 0› show ?thesis
            by simp (simp add: field_simps)
        qed
      qed
    qed
    with ‹usolution A (Suc m) n x›
    have "∀j<Suc m. ?y j = x j" by (simp add: usolution_def)
    hence "∀j<m. y j = x j"
      by simp (metis less_SucI nat_neq_iff)
  } ultimately have "usolution ?A m n x" 
    by (simp add: usolution_def)
  (* The induction hypothesis applied to ?A gives the result of the
     recursive call, and hence of the whole call. *)
  note * = Suc.IH [OF ‹m ≤ n› this]
  from 1 show ?case
    by auto (use * in blast)
qed
text‹Future work: extend the proof to matrix inversion.›
(* Hide the unqualified name of the constant unit defined in this theory;
   with (open) it remains accessible via its qualified name.
   NOTE(review): presumably done to avoid name clashes in importing
   theories - confirm. *)
hide_const (open) unit
end