(* Theory Matrix_Utils *)

(***********************************************************************************
 * Copyright (c) University of Exeter, UK
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 ***********************************************************************************)

section ‹Proofs and Definitions that Enrich the Matrix Formalization›

theory
  Matrix_Utils
  imports
    Jordan_Normal_Form.Matrix
    "HOL-Combinatorics.Permutations"
begin

text‹
  This theory provides additional definitions and lemmas that are useful when working with vectors 
  and matrices as provided by @{theory "Jordan_Normal_Form.Matrix"}. Furthermore, this theory contains
  additional theorems over lists, in particular properties of @{const "map2"} (and, hence, of
  @{const "zip"}).
›

subsection‹List Properties›



(* Pointwise multiplication of two lists of equal length, expressed by
   indexing: map2 with multiplication agrees with multiplying the i-th
   elements for every i in [0..< length xs].
   Proved by metis from map2_map_map and map_nth.
   NOTE(review): the term delimiters and the implication arrow between the
   length premise and the equation appear to be stripped in this copy. *)
lemma map2_to_map_idx_eq: 
  length xs = length ys  (map2 (*) xs (ys)) = map (λ i. xs!i * ys!i) [0..< length xs]
  using map2_map_map map_nth
  by metis 

(* Variant of map2_to_map_idx_eq without a length premise: the index range is
   bounded by min (length xs) (length ys), i.e. by the shorter list.
   Proved by list extensionality (nth_equalityI) followed by auto.
   NOTE(review): the statement delimiters appear to be missing in this copy. *)
lemma map2_to_map_idx: 
  (map2 (*) xs (ys)) = map (λ i. xs!i * ys!i) [0..< min (length xs) (length ys)]
  by (rule nth_equalityI, auto)

(* Pointwise multiplication of lists is commutative in its two list arguments.
   Proof: simultaneous induction over both lists (list_induct2'), with every
   case closed by simp_all using mult.commute.
   NOTE(review): the proof uses only mult.commute, so the comm_ring sort
   constraint may be stronger than necessary. Confirm before generalizing.
   NOTE(review): the statement delimiters appear to be missing in this copy. *)
lemma map2_mult_commute: 
  map2 (*) (xs::'a::comm_ring list) ys = map2 (*) ys xs
  by (induction xs ys rule:list_induct2', simp_all add: mult.commute)

subsection‹Vector and Matrix Properties›

(* Product of a vector with a matrix, written in infix form as v v* A
   (left-associative, priority 70).
   The result has dim_col A entries. Entry i combines column i of A with v,
   presumably via the scalar product. The operator between "col A i" and "v"
   appears stripped in this copy, as do the function arrows in the type and
   the definitional equality in the where clause.
   The definition itself imposes no relation between dim_vec v and dim_row A. *)
definition mult_vec_mat :: " 'a Matrix.vec  'a :: semiring_0 Matrix.mat  'a Matrix.vec" (infixl "v*" 70)
  where "v v* A  vec (dim_col A) (λ i. col A i  v)"

(* The vector-matrix product has exactly one entry per column of A.
   Follows by unfolding mult_vec_mat_def. *)
lemma dim_mult_vec_mat: dim_vec (v v* A) = dim_col A
  by (auto simp add: mult_vec_mat_def)

(* For an in-range column index i, entry i of the vector-matrix product is
   column i of A combined with v (presumably by the scalar product).
   Follows by unfolding mult_vec_mat_def.
   NOTE(review): the implication arrow and the operator between "col A i" and
   "v" appear stripped in this copy. *)
lemma index_mult_vec_mat: i < dim_col A  (v v* A) $ i = col A i  v
  by (auto simp add: mult_vec_mat_def)

(* Every row list produced by mat_to_list M has length dim_col M.
   This has the same content as the later lemma dim_col_list, with the
   equation oriented the other way.
   NOTE(review): the universal quantifier and the membership symbol appear
   stripped in this copy. *)
lemma dim_col_mat_list:  m  set (mat_to_list M). dim_col M = length m
  unfolding mat_to_list_def dim_col_def o_def
  by simp

(* If mat_to_list M is non-empty, dim_col M is the length of its first row.
   This is a direct consequence of dim_col_mat_list.
   NOTE(review): the inequality and implication symbols appear stripped. *)
lemma dim_col_mat_list': mat_to_list M  []  dim_col M = length (hd (mat_to_list M))
  using dim_col_mat_list by fastforce 

(* The scalar product of two vectors built from lists, written as an explicit
   finite sum of the products v!i times w!i over the indices of w.
   NOTE(review): the scalar-product operator, the sum binder and the
   membership symbol appear stripped in this copy. *)
lemma scalar_prod_list: 
  ((vec_of_list v)  (vec_of_list w)) = ( i  {0 ..< length w}. v!i *  w!i)
  unfolding scalar_prod_def vec_of_list_def 
  by (simp, metis (no_types, lifting) dim_vec sum.cong vec.abs_eq vec_of_list.abs_eq vec_of_list_index) 

(* A matrix built from a list of columns has one column per list element. *)
lemma dim_col_mat_of_col_list: dim_col (mat_of_cols_list n As) = length As
  unfolding mat_of_cols_list_def by simp

(* A matrix built from a list of columns has n rows, where n is the explicit
   dimension argument. *)
lemma dim_row_mat_of_col_list: dim_row (mat_of_cols_list n As) = n
  unfolding mat_of_cols_list_def by simp

(* A matrix built from a list of rows has n columns, where n is the explicit
   dimension argument. *)
lemma dim_col_mat_of_row_list: dim_col (mat_of_rows_list n As) = n
  unfolding mat_of_rows_list_def by simp

(* A matrix built from a list of rows has one row per list element. *)
lemma dim_row_mat_of_row_list: dim_row (mat_of_rows_list n As) = length As
  unfolding mat_of_rows_list_def by simp

(* vec_of_list is injective. Proved by metis from list_vec.
   NOTE(review): the implication arrow appears stripped in this copy. *)
lemma vec_of_list_ext: vec_of_list xs = vec_of_list ys  xs = ys
  by (metis list_vec) 

(* list_of_vec is injective. Proved by metis from vec_list. It is used below
   to lift list-level equations to vector-level ones.
   NOTE(review): the implication arrow appears stripped in this copy. *)
lemma list_of_vec_ext: list_of_vec xs = list_of_vec ys  xs = ys
  by (metis vec_list) 

(* Over the range [0..<n] the guard i < n always holds, so the conditional
   collapses to its then-branch. Closed by simp. *)
lemma map_if_lam:
  map (λ i. if i < n then P(i) else Q(i)) [0..<n] = map (λ i. P(i)) [0..<n]
  by simp 

(* Variant of map_if_lam with an additional conjunct p in the guard. Since
   i < n always holds over [0..<n], only p remains as the condition.
   NOTE(review): the conjunction symbol between p and i < n appears stripped
   in this copy. *)
lemma map_if_lam':
  map (λ i. if p  i < n then (P i) else (Q i)) [0..<n] = map (λ i. if p then (P i) else (Q i)) [0..<n]
  by simp 

(* Nested version of map_if_lam. For the outer index i ranging over [0..<n]
   the guard i < n always holds; the inner range [0..<m] is unaffected. *)
lemma map_if_lam'':
  map (λi. map (λia. if i < n then P i ia else Q i ia) [0..<m]) [0..<n] 
  =map (λi. map (λia. P i ia) [0..<m]) [0..<n]
  by simp

(* For lists of equal length, adding the corresponding vectors and reading the
   result back as a list gives the pointwise sum map2 (+) v w. *)
lemma vec_add_list: 
  assumes length v = length w 
  shows list_of_vec ((vec_of_list v) + (vec_of_list w)) = map2 (+) v w
  unfolding plus_vec_def
  apply(simp add:vec_of_list_index)      
  (* Close the remaining goal with the map2/map/nth correspondences,
     using the equal-length assumption. *)
  using assms map2_map_map [of "(+)"]  map_nth
  by metis 

(* Vector-level form of vec_add_list: the sum of two list-built vectors of
   equal length is the vector built from the pointwise sum.
   The proof reduces to vec_add_list via injectivity of list_of_vec
   (list_of_vec_ext). *)
lemma vec_add_list': 
  assumes length v = length w 
  shows ((vec_of_list v) + (vec_of_list w)) = vec_of_list (map2 (+) v w)
  apply(rule list_of_vec_ext, simp only: list_vec)
  using vec_add_list assms by blast 

(* Column i of the matrix assembled column-wise from As (with row dimension
   d), read back as a list, is the i-th list As!i.
   Assumptions, as far as this copy shows:
   - i is a valid index into As;
   - the lists in As all have the same length and are presumably non-empty;
   - d is the length of the first list.
   NOTE(review): the quantifiers, membership, conjunction and inequality
   symbols in the second assumption appear stripped in this copy. *)
lemma mat_col_list: 
  assumes i < length As 
    and  a  set As.  a'  set As. length a = length a'  a  []
    and d = length (hd As) 
  shows list_of_vec ( col (mat_of_cols_list d As) i ) =  As!i 
  apply (intro nth_equalityI)
  (* First goal from nth_equalityI: both lists have the same length. *)
  subgoal using assms by simp 
      (metis dim_row_mat_of_col_list list.set_sel(1) list.size(3) not_less_zero nth_mem)
  (* Second goal: the lists agree at every index j. *)
  subgoal for j using assms unfolding list_of_vec_index length_list_of_vec
    by (auto simp: mat_of_cols_list_def)
  done

(* Multiplying the list-built vector vs (of length n) with the matrix whose n
   columns are the lists in As, read back as a list: entry i is the sum over
   ia < length vs of As!i!ia times vs!ia.
   NOTE(review): the sum binder and several logical symbols in the
   assumptions appear stripped in this copy. *)
lemma mult_vec_mat_col_list: 
  assumes length vs=n
    and  a  set As.  a'  set As. length a = length a'  a  []
    and length (hd As)=d 
    and length As=n 
    and As  [] 
  shows    list_of_vec ((vec_of_list vs) v* (mat_of_cols_list d As)) = map  (λi. ia = 0..<length vs. As ! i ! ia * vs ! ia) [0..<n]
  apply (intro nth_equalityI)
  (* First goal: equal lengths. *)
  subgoal using assms by (simp add: mult_vec_mat_def mat_of_cols_list_def)
  (* Second goal: entry-wise equality, by unfolding the scalar product and
     comparing the sums term by term. *)
  subgoal for i using assms unfolding list_of_vec_index mult_vec_mat_def mat_of_cols_list_def
    by (auto simp: scalar_prod_def vec_of_list_index intro!: sum.cong arg_cong2[of _ _ _ _ "(*)"])
      (metis list.set_sel(1) list_of_vec_index list_of_vec_vec map_nth nth_mem)
  done

(* Multiplying the list-built vector vs (of length d) with the matrix built
   from the rows As, read back as a list. Entry i (for i < d) is the sum over
   ia < length vs of the ia-th element of the list of i-th row entries,
   times vs!ia.
   NOTE(review): the sum binder and several logical symbols in the
   assumptions appear stripped in this copy. *)
lemma mult_vec_mat_row_list: 
  assumes length vs=d
    and  a  set As.  a'  set As. length a = length a'  a  []
    and length (hd As)=d 
    and length As=n 
    and As  []                                
  shows list_of_vec ((vec_of_list vs) v* (mat_of_rows_list d As)) = map (λi. ia = 0..<length vs. map (λia. As ! ia ! i) [0..<length As] ! ia * vs ! ia) [0..<d]
  apply (intro nth_equalityI)
  (* First goal: equal lengths. *)
  subgoal using assms by (simp add: mult_vec_mat_def mat_of_rows_list_def)
  (* Second goal: entry-wise equality of the sums. *)
  subgoal for i using assms unfolding list_of_vec_index mult_vec_mat_def mat_of_rows_list_def
    by (auto simp: scalar_prod_def vec_of_list_index intro!: sum.cong arg_cong2[of _ _ _ _ "(*)"])
     (metis list_of_vec_index list_of_vec_vec)
  done
  
(* Vector-level form of mult_vec_mat_row_list: the product equals the vector
   built from the list of entry sums. Obtained from mult_vec_mat_row_list via
   injectivity of list_of_vec (list_of_vec_ext).
   NOTE(review): the sum binder and several logical symbols appear stripped
   in this copy. *)
lemma mult_vec_mat_row_list': 
  assumes length vs=d
    and  a  set As.  a'  set As. length a = length a'  a  []
    and length (hd As)=d 
    and length As=n 
    and As  []                                
  shows    ((vec_of_list vs) v* (mat_of_rows_list d As)) = vec_of_list (map (λi. ia = 0..<length vs. map (λia. As ! ia ! i) [0..<length As] ! ia * vs ! ia) [0..<d])
  apply(rule list_of_vec_ext, simp only:list_vec)
  using assms mult_vec_mat_row_list by blast

(* When d is the minimal row length of As and i < d, column i of the matrix
   built from the rows As, read back as a list, consists of the i-th element
   of every row. *)
lemma col_of_rows_list:
  assumes d = Min (set (map length As))
    and i < d
  shows   list_of_vec (col (mat_of_rows_list d As) i) = map (λ as. (as!i)) As
  apply (intro nth_equalityI)
  (* First goal: equal lengths. *)
  subgoal by (simp add: mat_of_rows_list_def)
  (* Second goal: agreement at every index j. *)
  subgoal for j using assms unfolding list_of_vec_index
    by (auto simp: mat_of_rows_list_def)
  done

(* Vector-level statement of col_of_rows_list: if every row in the non-empty
   list As has length d, then column i of the matrix built from the rows As is
   the vector of the i-th elements of the rows.
   The statement places no bound on i. The proof splits on whether i < d.
   NOTE(review): the quantifier, membership and inequality symbols in the
   assumptions appear stripped in this copy. *)
lemma col_of_rows_list':
  assumes  as  set As. length as = d
    and As  []
  shows (col (mat_of_rows_list d As) i) = vec_of_list (map (λ as. (as!i)) As)
proof (cases "i < d")
  case True
  (* In-range column: reduce to col_of_rows_list through list_of_vec. *)
  then show ?thesis 
    apply (subst list_of_vec_ext) 
    by (auto simp add: vec_list vec_of_list_ext assms col_of_rows_list)
next
  case False
  (* Out-of-range column: compare both sides entry by entry as lists, then
     lift back with list_of_vec_ext. *)
  then show ?thesis
  proof -
    have "list_of_vec (col (mat_of_rows_list d As) i) = 
          list_of_vec (vec_of_list (map (λas. as ! i) As))"
    proof (rule nth_equalityI)
      show "length (list_of_vec (col (mat_of_rows_list d As) i)) = 
            length (list_of_vec (vec_of_list (map (λas. as ! i) As)))"
        by (simp add: mat_of_rows_list_def)
    next
      fix j
      assume j_bound: "j < length (list_of_vec (col (mat_of_rows_list d As) i))"
      then have j_lt: "j < length As"
        by (simp add: mat_of_rows_list_def)
       have lhs: "list_of_vec (col (mat_of_rows_list d As) i) ! j = 
                 mat (length As) d (λ(r, c). As ! r ! c) $$ (j, i)"
        unfolding list_of_vec_index mat_of_rows_list_def col_def
        using j_lt by simp
      have rhs: "list_of_vec (vec_of_list (map (λas. as ! i) As)) ! j = 
                 (map (λas. as ! i) As) ! j"
        unfolding list_of_vec_index
        using j_lt  vec_of_list_index by blast 
      have "(map (λas. as ! i) As) ! j = As ! j ! i"
        using j_lt by simp
      (* The matrix entry at column i is outside the d declared columns, so it
         is unfolded down to the underlying representation (Abs_mat_inverse,
         mk_mat_def, undef_mat_def) before being matched with As ! j ! i. *)
      moreover have "mat (length As) d (λ(r, c). As ! r ! c) $$ (j, i) = As ! j ! i"
        using j_lt False 
        apply (simp add: mat_def assms index_mat_def)
        apply (subst Abs_mat_inverse)
        apply blast
        apply(simp add: mk_mat_def undef_mat_def)
      using  assms(1) map_nth nth_mem  
      by metis
      ultimately show "list_of_vec (col (mat_of_rows_list d As) i) ! j = 
                       list_of_vec (vec_of_list (map (λas. as ! i) As)) ! j"
        using lhs rhs by simp
    qed
    then show ?thesis
      by (rule list_of_vec_ext)
  qed
qed

(* Round trip: converting a matrix to its row lists and back, with its own
   column count as the dimension argument, gives the same matrix. *)
lemma list_mat: mat_of_rows_list (dim_col A) (mat_to_list A) = A
  unfolding mat_to_list_def mat_of_rows_list_def
  by(auto)

(* The list_mat round trip applied to a transposed matrix, transposed back,
   yields the original matrix. Proved from transpose_transpose and list_mat.
   NOTE(review): the transpose superscripts and the statement delimiters
   appear stripped in this copy. *)
lemma list_mat_transpose_transpose: (mat_of_rows_list (dim_col xT) (mat_to_list xT))T = x
  using transpose_transpose[of "x", symmetric] list_mat by metis 

(* Converse round trip: if every row in rs has length dimc, then building a
   matrix from rs and converting it back to row lists returns rs.
   NOTE(review): the quantifier, membership and implication symbols appear
   stripped in this copy. *)
lemma mat_list:
   r  set(rs). length r = dimc mat_to_list (mat_of_rows_list dimc rs) = rs
  unfolding mat_of_rows_list_def mat_to_list_def
  by (intro nth_equalityI, auto)

(* The number of rows equals the number of row lists produced by
   mat_to_list. Obtained from dim_row_mat_of_row_list and list_mat. *)
lemma dim_row_list:  dim_row m = length (mat_to_list m)
  by (metis dim_row_mat_of_row_list list_mat) 

(* Every row list produced by mat_to_list m has length dim_col m. This
   restates dim_col_mat_list with the equation oriented the other way.
   NOTE(review): the quantifier and membership symbols appear stripped. *)
lemma dim_col_list:  c  set (mat_to_list m). length c = dim_col m
  by (simp add: mat_to_list_def) 

(* For vectors of equal dimension, the scalar product equals sum_list of the
   pointwise products of their list representations.
   NOTE(review): the scalar-product operator and the relation between the two
   sides appear stripped in this copy. The relation is possibly the
   definitional equality; confirm.
   Proof: unfold scalar_prod_def, then induct on dim_vec x.
   NOTE(review): the fact name bound by the note command in the Suc case is
   never referenced afterwards. *)
lemma scalar_prod_sum_list_lv_eq: 
  assumes same_dim: dim_vec (x::'a::comm_ring Matrix.vec) = dim_vec y 
  shows x  y  sum_list (map2 (*) (list_of_vec x) (list_of_vec y))
proof(unfold scalar_prod_def, insert assms,induction "dim_vec x" )
  case 0
  then show ?case by simp
next
  case (Suc xa) note * = this 
  then show ?case 
    apply(simp add:list_of_vec_map sum_def)
    by (simp add: comm_monoid_add_class.sum_def interv_sum_list_conv_sum_set_nat map2_map_map) 
qed

(* List-level counterpart of scalar_prod_sum_list_lv_eq: for lists of equal
   length, the scalar product of the corresponding vectors equals sum_list of
   the pointwise products.
   NOTE(review): the scalar-product operator and the relation between the two
   sides appear stripped in this copy. The relation is possibly the
   definitional equality; confirm.
   Proof: unfold scalar_prod_def, then induct on length x. The step case uses
   map2_to_map_idx_eq and vec_of_list_index.
   NOTE(review): the fact name bound by the note command in the Suc case is
   never referenced afterwards. *)
lemma scalar_prod_sum_list_vl_eq:
  assumes same_dim: length (x::'a::comm_ring list) = length y 
  shows (vec_of_list x)  (vec_of_list y)  sum_list (map2 (*) x y)
proof(unfold scalar_prod_def, insert assms,induction "length x" )
  case 0
  then show ?case by simp
next
  case (Suc xa) note * = this 
  then show ?case 
    apply(simp add:list_of_vec_map sum_def)
    using comm_monoid_add_class.sum_def interv_sum_list_conv_sum_set_nat 
    by (metis (mono_tags, lifting) atLeastLessThan_upt map2_to_map_idx_eq map_eq_conv vec_of_list_index)
qed


end