Theory Gauss_Jordan_IArray_Impl
section ‹Code Generation for Basic Matrix Operations›
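text ‹In this theory we provide efficient implementations for code generation
of the elementary row transformations (swapping two rows, multiplying a row by a scalar,
and adding a multiple of one row to another), together with a code equation for
Gauss-Jordan elimination. These implementations are necessary because the default
implementations would construct a whole new matrix in every step, whereas the ones
below only rebuild the affected rows and share all other rows with the input matrix.›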
theory Gauss_Jordan_IArray_Impl
imports
Polynomial_Interpolation.Missing_Unsorted
Matrix_IArray_Impl
Gauss_Jordan_Elimination
begin
(* Swap rows i and j of a matrix in the IArray representation (nr, nc, A),
   where A is the array of rows.
   - If both i and j are below nr, the outer row array is rebuilt from its
     list of rows with the entries at positions i and j exchanged. The row
     arrays themselves are reused as they are: rows i and j are just moved,
     and all other rows are shared with the input.
   - Otherwise the representation is returned unchanged.
   The proof obligation of lift_definition (the result is again a valid
   mat_impl representation) is discharged by auto with if_splits. *)
lift_definition mat_swaprows_impl :: "nat ⇒ nat ⇒ 'a mat_impl ⇒ 'a mat_impl" is
"λ i j (nr,nc,A). if i < nr ∧ j < nr then
let Ai = IArray.sub A i;
Aj = IArray.sub A j;
Arows = IArray.list_of A;
A' = IArray.IArray (Arows [i := Aj, j := Ai])
in (nr,nc,A')
else (nr,nc,A)"
by (auto split: if_splits)
(* Code equation for mat_swaprows on matrices implemented via mat_impl.
   If both row indices are below the number of rows, the swap is performed
   by mat_swaprows_impl, which only rebuilds the outer row array.
   Otherwise evaluation stops via Code.abort with an out-of-bounds message.
   Logically, Code.abort just evaluates its fallback closure, which here is
   the original call. The out-of-range case therefore holds trivially and is
   closed by the final qed auto. *)
lemma [code]: "mat_swaprows k l (mat_impl A) = (let nr = dim_row_impl A in
if l < nr ∧ k < nr then
mat_impl (mat_swaprows_impl k l A) else Code.abort (STR ''index out of bounds in mat_swaprows'')
(λ _. mat_swaprows k l (mat_impl A)))" (is "?l = ?r")
proof (cases "l < dim_row_impl A ∧ k < dim_row_impl A")
case True
(* In range: the right-hand side reduces to the mat_impl branch. *)
hence id: "?r = mat_impl (mat_swaprows_impl k l A)" by simp
show ?thesis unfolding id unfolding mat_swaprows_def
(* Compare both matrices entrywise. Goal 1 is the entry equation; the
   remaining eq_matI goals (dimensions) are closed by transfer and auto. *)
proof (rule eq_matI, goal_cases)
case (1 i j)
thus ?case using True
(* Move to the concrete representation (nr, nc, rows). *)
proof (transfer, goal_cases)
case (1 i k l A j)
obtain nr nc rows where A: "A = (nr,nc,rows)" by (cases A, auto)
from 1[unfolded A]
have nr: "length (IArray.list_of rows) = nr"
and nc: "IArray.all (λr. length (IArray.list_of r) = nc) rows"
and ij: "i < nr" "j < nc" and ij': "(i < nr ∧ j < nc) = True"
and l: "l < nr" "k < nr" by auto
(* Entry (i, j) lies in row k, in row l, or in an untouched row. *)
show ?case unfolding A prod.simps fst_conv o_def snd_conv Let_def mk_mat_def ij' if_True
using ij nr nc l
by (cases "k = i"; cases "l = i", auto)
qed
qed ((transfer, auto)+)
(* Out of range: the right-hand side is Code.abort, which logically equals
   its fallback, i.e. the left-hand side. *)
qed auto
(* Multiply row k by the scalar a, using the supplied multiplication mul.
   Each entry x of row k becomes mul a x. Only row k is rebuilt, and all
   other row arrays are shared with the input.
   NOTE(review): k is not checked against nr. Logically the list update is a
   no-op for k >= nr. However, the let-binding still reads IArray.sub A k,
   which in strictly evaluated target languages may raise an out-of-bounds
   exception. Confirm whether callers guarantee k < nr.
   The proof obligation is that every row of the result has length nc: each
   row is either an original row or the mapped row k. *)
lift_definition mat_multrow_gen_impl :: "('a ⇒ 'a ⇒ 'a) ⇒ nat ⇒ 'a ⇒ 'a mat_impl ⇒ 'a mat_impl" is
"λ mul k a (nr,nc,A). let Ak = IArray.sub A k; Arows = IArray.list_of A;
Ak' = IArray.IArray (map (mul a) (IArray.list_of Ak));
A' = IArray.IArray (Arows [k := Ak'])
in (nr,nc,A')"
proof (auto, goal_cases)
case (1 mul k a nc b row)
show ?case
proof (cases b)
case (IArray rows)
(* A row of the updated list is either an original row or, if k is a valid
   index, the new row k. *)
with 1 have "row ∈ set rows ∨ k < length rows ∧ row = IArray (map (mul a) (IArray.list_of (rows ! k)))"
by (cases "k < length rows", auto simp: set_list_update dest: in_set_takeD in_set_dropD)
with 1 IArray show ?thesis by (cases, auto)
qed
qed
(* Code equation for mat_multrow_gen on mat_impl matrices: the operation is
   delegated to mat_multrow_gen_impl. Unlike the swaprows and addrow code
   equations in this theory, there is no Code.abort guard on the row index.
   The proof compares both matrices entrywise. *)
lemma [code]: "mat_multrow_gen mul k a (mat_impl A) = mat_impl (mat_multrow_gen_impl mul k a A)"
unfolding mat_multrow_gen_def
proof (rule eq_matI, goal_cases)
case (1 i j)
thus ?case
(* Move to the concrete representation (nr, nc, rows). *)
proof (transfer, goal_cases)
case (1 i mul k a A j)
obtain nr nc rows where A: "A = (nr,nc,rows)" by (cases A, auto)
from 1[unfolded A]
have nr: "length (IArray.list_of rows) = nr"
and nc: "IArray.all (λr. length (IArray.list_of r) = nc) rows"
and ij: "i < nr" "j < nc" and ij': "(i < nr ∧ j < nc) = True" by auto
(* Column j is a valid index into row i. *)
have len: "j < length (IArray.list_of (IArray.list_of rows ! i))"
using ij nc nr by (cases rows, auto)
(* Entry (i, j) is either in the scaled row k or in an untouched row. *)
show ?case unfolding A prod.simps fst_conv o_def snd_conv Let_def mk_mat_def ij' if_True
using ij nr nc
by (cases "k = i", auto simp: len)
qed
(* Remaining eq_matI goals (dimensions). *)
qed ((transfer, auto)+)
(* Add a multiple of row l to row k, using the supplied addition ad and
   multiplication mul.
   - The new row k has entries ad (mul a (Al !! i)) (Ak !! i), i.e. a times
     the entry of row l, added to the entry of row k.
   - The length of the new row is the minimum of the two row lengths; both
     equal nc for a valid representation.
   - The outer array is rebuilt with IArray.of_fun, reusing every row other
     than k.
   If l is not below nr, the representation is returned unchanged.
   NOTE(review): only l is range-checked. For k >= nr the new row is never
   placed, but IArray.sub A k is still let-bound and may raise in strictly
   evaluated targets. Confirm that callers keep k < nr. *)
lift_definition mat_addrow_gen_impl
:: "('a ⇒ 'a ⇒ 'a) ⇒ ('a ⇒ 'a ⇒ 'a) ⇒ 'a ⇒ nat ⇒ nat ⇒ 'a mat_impl ⇒ 'a mat_impl" is
"λ ad mul a k l (nr,nc,A). if l < nr then let Ak = IArray.sub A k; Al = IArray.sub A l;
Ak' = IArray.of_fun (λ i. ad (mul a (Al !! i)) (Ak !! i)) (min (IArray.length Ak) (IArray.length Al));
A' = IArray.of_fun (λ i. if i = k then Ak' else A !! i) (IArray.length A)
in (nr,nc,A') else (nr,nc,A)"
proof (goal_cases)
case (1 ad mul a k l pp)
obtain nr nc A where pp: "pp = (nr,nc,A)" by (cases pp)
obtain rows where A: "A = IArray rows" by (cases A)
(* Invariant of the input: nr rows, each of length nc. *)
from 1[unfolded pp A, simplified]
have nr: "length rows = nr" and nc: "⋀ r. r∈set rows ⟹ length (IArray.list_of r) = nc" by auto
show ?case
proof (cases "l < nr")
case False
(* The input is returned unchanged. *)
thus ?thesis unfolding pp A prod.simps using nr nc by auto
next
case True
(* Every row of the result is an original row or the new row k. *)
thus ?thesis unfolding pp A prod.simps Let_def using nr nc
by (auto simp: set_list_update dest: in_set_takeD in_set_dropD)
qed
qed
(* Code equation for mat_addrow_gen on mat_impl matrices, registered under the
   fact name mat_addrow_gen_impl.
   If the source row l is below the number of rows, the operation is performed
   by the constant mat_addrow_gen_impl. Otherwise evaluation stops via
   Code.abort with an out-of-bounds message. Logically Code.abort evaluates
   its fallback (the original call), so the out-of-range case is closed by the
   final qed simp. *)
lemma mat_addrow_gen_impl[code]: "mat_addrow_gen ad mul a k l (mat_impl A) = (if l < dim_row_impl A then
mat_impl (mat_addrow_gen_impl ad mul a k l A) else Code.abort (STR ''index out of bounds in mat_addrow'')
(λ _. mat_addrow_gen ad mul a k l (mat_impl A)))" (is "?l = ?r")
proof (cases "l < dim_row_impl A")
case True
(* In range: the right-hand side reduces to the mat_impl branch. *)
hence id: "?r = mat_impl (mat_addrow_gen_impl ad mul a k l A)" by simp
show ?thesis unfolding id unfolding mat_addrow_gen_def
(* Compare both matrices entrywise; the remaining eq_matI goals
   (dimensions) are closed by transfer and auto. *)
proof (rule eq_matI, goal_cases)
case (1 i j)
thus ?case using True
(* Move to the concrete representation (nr, nc, rows). *)
proof (transfer, goal_cases)
case (1 i ad mul a k l A j)
obtain nr nc rows where A: "A = (nr,nc,rows)" by (cases A, auto)
from 1[unfolded A Let_def]
have nr: "length (IArray.list_of rows) = nr"
and nc: "IArray.all (λr. length (IArray.list_of r) = nc) rows"
and ij: "i < nr" "j < nc" and ij': "(i < nr ∧ j < nc) = True"
and l: "l < nr" by auto
(* Column j is a valid index into both row i and row l. *)
have len: "j < length (IArray.list_of (IArray.list_of rows ! i))"
"j < length (IArray.list_of (IArray.list_of rows ! l))"
using ij nc nr l by (cases rows, auto)+
(* Entry (i, j) is either in the updated row k or in an untouched row. *)
show ?case unfolding A prod.simps fst_conv o_def snd_conv Let_def mk_mat_def ij' if_True
using ij nr nc l
by (cases "k = i", auto simp: len)
qed next
qed ((transfer, auto simp:Let_def)+)
qed simp
(* Code equation for gauss_jordan_main. It applies the same row operations to
   A and to B while scanning pivot positions (i, j):
   - if i or j is out of range, return (A, B);
   - if A $$ (i,j) = 0, look for the first row i' with i < i' < nr whose entry
     in column j is nonzero. If there is one, swap rows i and i' in both A and
     B and stay at (i, j); if there is none, move on to column Suc j;
   - if A $$ (i,j) = 1, call eliminate_entries (defined elsewhere; presumably
     clearing the other entries of column j) with column j of A, and continue
     at (Suc i, Suc j);
   - otherwise, first multiply row i of A and B by the inverse of the pivot,
     then call eliminate_entries with column j of the normalised A, and
     continue at (Suc i, Suc j).
   As the proof shows, in the last case the defining equations first recurse
   on the normalised matrices at the same position (i, j). This equation
   performs that step inline, which is justified because the normalised pivot
   is 1. *)
lemma gauss_jordan_main_code[code]:
"gauss_jordan_main A B i j = (let nr = dim_row A; nc = dim_col A in
if i < nr ∧ j < nc then let aij = A $$ (i,j) in if aij = 0 then
(case [ i' . i' <- [Suc i ..< nr], A $$ (i',j) ≠ 0]
of [] ⇒ gauss_jordan_main A B i (Suc j)
| (i' # _) ⇒ gauss_jordan_main (swaprows i i' A) (swaprows i i' B) i j)
else if aij = 1 then let v = (λ i. A $$ (i,j)) in
gauss_jordan_main
(eliminate_entries v A i j) (eliminate_entries v B i j) (Suc i) (Suc j)
else let iaij = inverse aij; A' = multrow i iaij A; B' = multrow i iaij B;
v = (λ i. A' $$ (i,j)) in gauss_jordan_main
(eliminate_entries v A' i j) (eliminate_entries v B' i j) (Suc i) (Suc j)
else (A,B))" (is "?l = ?r")
proof -
(* simps: one unfolding step of the defining equations at (A, B, i, j). *)
note simps = gauss_jordan_main.simps[of A B i j] Let_def
let ?nr = "dim_row A"
let ?nc = "dim_col A"
let ?A' = "multrow i (inverse (A $$ (i,j))) A"
let ?B' = "multrow i (inverse (A $$ (i,j))) B"
show ?thesis
proof (cases "i < ?nr ∧ j < ?nc ∧ A $$ (i,j) ≠ 0 ∧ A $$ (i,j) ≠ 1")
case False
(* In all other cases both sides coincide after one unfolding step. *)
thus ?thesis unfolding simps by (auto split: if_splits)
next
case True
(* Scaling row i by the inverse of the nonzero pivot makes the pivot 1. *)
from True have id: "?A' $$ (i,j) = 1" by auto
(* One unfolding gives the recursive call on the normalised matrices. *)
from True have "?l = gauss_jordan_main ?A' ?B' i j" unfolding simps by (simp add: Let_def)
(* Unfolding that call again, now with pivot 1, yields the right-hand side. *)
also have "… = ?r" unfolding Let_def gauss_jordan_main.simps[of ?A' ?B' i j] id
using True by simp
finally show ?thesis .
qed
qed
end