Theory DL_Deep_Model_Poly

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Polynomials representing the Deep Network Model›

theory DL_Deep_Model_Poly
imports DL_Deep_Model Polynomials.More_MPoly_Type Jordan_Normal_Form.Determinant
begin

(* If A x is an n-by-n matrix (carrier_mat n n) for every x and each of its
   entries is polyfun N in x, then det (A x) is polyfun N in x as well.
   polyfun comes from Polynomials.More_MPoly_Type; presumably polyfun N g means
   that g is given by a polynomial whose variables lie in N.
   Proof: expand the determinant via det_def' into a finite sum over all
   permutations p of {0..<n} of signof p times the product of the entries
   A x (i, p i); products, constant multiples and finite sums of such
   functions stay polyfun N. *)
lemma polyfun_det:
assumes "x. (A x)  carrier_mat n n"
assumes "x i j. i<n  j<n  polyfun N (λx. (A x) $$ (i,j))"
shows "polyfun N (λx. det (A x))"
proof -
  {
    (* For a fixed permutation p, the corresponding Leibniz summand is polyfun N. *)
    fix p assume "p {p. p permutes {0..<n}}"
    then have "p permutes {0..<n}" by auto
    (* p maps indices below n to indices below n, so every factor is an entry of A x. *)
    then have "x. x < n  p x < n" using permutes_in_image by auto
    then have "polyfun N (λx. i = 0..<n. A x $$ (i, p i))"
      using polyfun_Prod[of "{0..<n}" N "λi x. A x $$ (i, p i)"] assms by simp
    then have "polyfun N (λx. signof p * (i = 0..<n. A x $$ (i, p i)))" using polyfun_const polyfun_mult by blast
  }
  (* There are finitely many permutations, so polyfun_Sum applies to the whole expansion. *)
  moreover have "finite {i. i permutes {0..<n}}" by (simp add: finite_permutations)
  ultimately show ?thesis  unfolding det_def'[OF assms(1)]
    using polyfun_Sum[OF finite {i. i permutes {0..<n}}, of N "λp x. signof p * (i = 0..<n. A x $$ (i, p i))"]
    by blast
qed

(* Each entry (i,j), with i < m and j < n, of the m-by-n matrix that
   extract_matrix reads from the weight vector f starting at offset a is a
   single variable of f. index_extract_matrix rewrites the entry to such a
   variable and polyfun_single applies. two_digit_le, instantiated with i < m
   and j < n, supplies the bound that places the variable in
   {..<a + (m * n + c)}. The extra summand c lets callers use a larger
   variable set than the matrix itself needs. *)
lemma polyfun_extract_matrix:
assumes "i<m" "j<n"
shows "polyfun {..<a + (m * n + c)} (λf. extract_matrix (λi. f (i + a)) m n $$ (i,j))"
unfolding index_extract_matrix[OF assms] apply (rule polyfun_single) using two_digit_le[OF assms] by simp

(* Entry j (with j < m) of the matrix-vector product of A x and v x is polyfun N
   in x, provided that for every x, A x is m-by-n and v x has dimension n, and
   that all entries of A x and of v x are polyfun N. The entry is the scalar
   product of row j of A x with v x, i.e. a finite sum over i < n of products
   of polyfun N functions. *)
lemma polyfun_mult_mat_vec:
assumes "x. v x  carrier_vec n"
assumes "j. j<n  polyfun N (λx. v x $ j)"
assumes "x. A x  carrier_mat m n"
assumes "i j. i<m  j<n  polyfun N (λx. A x $$ (i,j))"
assumes "j < m"
shows "polyfun N (λx. ((A x) *v (v x)) $ j)"
proof -
  (* The dimension facts hold uniformly in x. *)
  have "x. j < dim_row (A x)" using j < m assms(3) carrier_matD(1) by force
  have "x. n = dim_vec (v x)" using assms(1) carrier_vecD by fastforce
  {
    (* For each i < n, the i-th summand of the scalar product is polyfun N. *)
    fix i assume "i  {0..<n}"
    then have "i < n" by auto
    {
      fix x
      have "i < dim_vec (v x)" using assms(1) carrier_vecD i<n by fastforce
      have "j < dim_row (A x)" using j < m assms(3) carrier_matD(1) by force
      have "dim_col (A x) = dim_vec (v x)" by (metis assms(1) assms(3) carrier_matD(2) carrier_vecD)
      (* The i-th component of row j is exactly the matrix entry (j,i). *)
      then have "row (A x) j $ i = A x $$ (j,i)" "i<n" using j < dim_row (A x) i<n by (simp_all add: i < dim_vec (v x))
    }
    then have "polyfun N (λx. row (A x) j $ i * v x $ i)"
      using polyfun_mult assms(4)[OF j < m] assms(2) by fastforce
  }
  (* Unfold entry j of the product into the sum over {0..<n} and apply polyfun_Sum. *)
  then show ?thesis unfolding index_mult_mat_vec[OF x. j < dim_row (A x)] scalar_prod_def
    using polyfun_Sum[of "{0..<n}" N "λi x. row (A x) j $ i * v x $ i"] finite_atLeastLessThan[of 0 n] x. n = dim_vec (v x)
    by simp
qed

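(* The offset a (the network reads its weights from f starting at position a)
   generalizes the statement so that the induction goes through;
   polyfun_evaluate_net below is the special case a = 0. *)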
(* Output entry j of a valid network m, evaluated on inputs whose dimensions
   match input_sizes m, is a polynomial in the weight vector f. The weights are
   inserted via insert_weights s m reading f from offset a, and only variables
   below a + count_weights s m occur. The flag s controls how Pool layers lay
   out their weights (see the Pool case); presumably it selects weight sharing.
   Proved by structural induction on m, generalizing over inputs, j and a. *)
lemma polyfun_evaluate_net_plus_a:
assumes "map dim_vec inputs = input_sizes m"
assumes "valid_net m"
assumes "j < output_size m"
shows "polyfun {..<a + count_weights s m} (λf. evaluate_net (insert_weights s m (λi. f (i + a))) inputs $ j)"
using assms proof (induction m arbitrary:inputs j a)
  case (Input)
  (* After unfolding, the output does not depend on f, so polyfun_const applies. *)
  then show ?case unfolding insert_weights.simps evaluate_net.simps using polyfun_const by metis
next
  case (Conv x m)
  then obtain x1 x2 where "x=(x1,x2)" by fastforce
  (* A convolutional layer (x1, x2) multiplies the x1-by-x2 matrix
     extract_matrix (λi. f (i + a)) x1 x2 with the output of the inner network m,
     which reads its weights from offset x1 * x2 + a. The goals produced by
     polyfun_mult_mat_vec are discharged one after another below. *)
  show ?case unfolding x=(x1,x2) insert_weights.simps evaluate_net.simps drop_map unfolding list_of_vec_index
  proof (rule polyfun_mult_mat_vec)
    (* Goal: the inner network output lies in carrier_vec (output_size m) for every f. *)
    {
      fix f
      have 1:"valid_net' (insert_weights s m (λi. f (i + x1 * x2)))"
        using valid_net (Conv x m) valid_net.simps by (metis
        convnet.distinct(1) convnet.distinct(5) convnet.inject(2) remove_insert_weights)
      have 2:"map dim_vec inputs = input_sizes (insert_weights s m (λi. f (i + x1 * x2)))"
        using input_sizes_remove_weights remove_insert_weights
        by (simp add: Conv.prems(1))
      have "dim_vec (evaluate_net (insert_weights s m (λi. f (i + x1 * x2))) inputs) = output_size m"
       using output_size_correct[OF 1 2] using remove_insert_weights by auto
      then show "evaluate_net (insert_weights s m (λi. f (i + x1 * x2))) inputs  carrier_vec (output_size m)"
        using carrier_vec_def by (metis (full_types) mem_Collect_eq)
    }

    (* Goal: the entries of the inner network output are polynomials; this is the
       induction hypothesis instantiated with offset x1 * x2 + a. *)
    have "map dim_vec inputs = input_sizes m" by (simp add: Conv.prems(1))
    have "valid_net m" using Conv.prems(2) valid_net.cases by fastforce
    show "j. j < output_size m   polyfun {..<a + count_weights s (Conv (x1, x2) m)}
          (λf. evaluate_net (insert_weights s m (λi. f (i + x1 * x2 + a))) inputs $ j)"
      unfolding vec_of_list_index count_weights.simps
      using Conv(1)[OF map dim_vec inputs = input_sizes m valid_net m, of _ "x1 * x2 + a"]
      unfolding semigroup_add_class.add.assoc ab_semigroup_add_class.add.commute[of "x1 * x2" a]
      by blast

    (* Goals about the weight matrix: its dimensions and its polynomial entries. *)
    have "output_size m = x2" using Conv.prems(2) x = (x1, x2) valid_net.cases by fastforce
    show "f. extract_matrix (λi. f (i + a)) x1 x2  carrier_mat x1 (output_size m)" unfolding output_size m = x2 using dim_extract_matrix
      using carrier_matI by (metis (no_types, lifting))

    show "i j. i < x1  j < output_size m  polyfun {..<a + count_weights s (Conv (x1, x2) m)} (λf. extract_matrix (λi. f (i + a)) x1 x2 $$ (i, j))"
      unfolding output_size m = x2 count_weights.simps using polyfun_extract_matrix[of _ x1 _ x2 a "count_weights s m"] by blast

    (* Goal: the requested output index lies below the number of matrix rows. *)
    show "j < x1" using Conv.prems(3) x = (x1, x2) by auto
  qed
next
  case (Pool m1 m2 inputs j a)
  (* A2/B2: the inputs split into those of m1 (take) and those of m2 (drop).
     A3/B3: both subnetworks are valid. A4/B4: j is a valid output index of both. *)
  have A2:"f. map dim_vec (take (length (input_sizes (insert_weights s m1 (λi. f (i + a))))) inputs) = input_sizes m1"
    by (metis Pool.prems(1)  append_eq_conv_conj input_sizes.simps(3) input_sizes_remove_weights remove_insert_weights take_map)
  have B2:"f. map dim_vec (drop (length (input_sizes (insert_weights s m1 (λi. f (i + a))))) inputs) = input_sizes m2"
    using Pool.prems(1) append_eq_conv_conj input_sizes.simps(3) input_sizes_remove_weights remove_insert_weights by (metis drop_map)
  have A3:"valid_net m1" and B3:"valid_net m2" using valid_net (Pool m1 m2) valid_net.simps by blast+
  have "output_size (Pool m1 m2) = output_size m2" unfolding output_size.simps
    using valid_net (Pool m1 m2) "valid_net.cases" by fastforce
  then have A4:"j < output_size m1" and B4:"j < output_size m2" using j < output_size (Pool m1 m2) by simp_all

  (* Outputs of the two subnetworks as functions of f. If s holds, m2 reads its
     weights from offset a like m1; otherwise from offset count_weights s m1 + a. *)
  let ?net1 = "λf. evaluate_net (insert_weights s m1 (λi. f (i + a)))
    (take (length (input_sizes (insert_weights s m1 (λi. f (i + a))))) inputs)"
  let ?net2 = "λf. evaluate_net (insert_weights s m2 (if s then λi. f (i + a) else (λi. f (i + count_weights s m1 + a))))
    (drop (length (input_sizes (insert_weights s m1 (λi. f (i + a))))) inputs)"
  (* j is in range for both subnetwork outputs, for every weight vector f. *)
  have length1: "f. output_size m1 = dim_vec (?net1 f)"
    by (metis A2 A3 input_sizes_remove_weights output_size_correct remove_insert_weights)
  then have jlength1:"f. j < dim_vec (?net1 f)" using A4 by metis
  have length2: "f. output_size m2 = dim_vec (?net2 f)"
    by (metis B2 B3 input_sizes_remove_weights output_size_correct remove_insert_weights)
  then have jlength2:"f. j < dim_vec (?net2 f)" using B4 by metis
  (* cong1/cong2: the input split depends on the weights only through the input
     sizes, which are unaffected by insert_weights. So the split may be computed
     with an arbitrary weight vector xf, matching the shape of the induction
     hypotheses. *)
  have cong1:"xf. (λf. evaluate_net (insert_weights s m1 (λi. f (i + a)))
        (take (length (input_sizes (insert_weights s m1 (λi. xf (i + a))))) inputs) $ j)
         = (λf. ?net1 f $ j)"
    using input_sizes_remove_weights remove_insert_weights by auto
  have cong2:"xf. (λf. evaluate_net (insert_weights s m2 (λi. f (i + (a + (if s then 0 else count_weights s m1)))))
        (drop (length (input_sizes (insert_weights s m1 (λi. xf (i + a))))) inputs) $ j)
         = (λf. ?net2 f $ j)"
    unfolding semigroup_add_class.add.assoc[symmetric] ab_semigroup_add_class.add.commute[of a "if s then 0 else count_weights s m1"]
    using input_sizes_remove_weights remove_insert_weights by auto

  (* Entry j of the Pool output is the product of entry j of both subnetwork
     outputs (index_component_mult), so polyfun_mult applies. Each factor comes
     from an induction hypothesis, and polyfun_subset enlarges its variable set
     to the combined one: the max of both weight counts if s holds, their sum
     otherwise. *)
  show ?case unfolding insert_weights.simps evaluate_net.simps  count_weights.simps
    unfolding  index_component_mult[OF jlength1 jlength2]
    apply (rule polyfun_mult)
     using Pool.IH(1)[OF A2 A3 A4, of a, unfolded cong1]
     apply (simp add:polyfun_subset[of "{..<a + count_weights s m1}" "{..<a + (if s then max (count_weights s m1) (count_weights s m2) else count_weights s m1 + count_weights s m2)}"]) 
    using Pool.IH(2)[OF B2 B3 B4, of "a + (if s then 0 else count_weights s m1)", unfolded cong2 semigroup_add_class.add.assoc[of a]]
    by (simp add:polyfun_subset[of "{..<a + ((if s then 0 else count_weights s m1) + count_weights s m2)}" "{..<a + (if s then max (count_weights s m1) (count_weights s m2) else count_weights s m1 + count_weights s m2)}"])
qed

(* Output entry j of a valid network m, evaluated on inputs of matching sizes,
   is a polynomial in the weight vector f. Only the first count_weights s m
   variables occur. This is the instance a = 0 of
   polyfun_evaluate_net_plus_a. *)
lemma polyfun_evaluate_net:
assumes "map dim_vec inputs = input_sizes m"
assumes "valid_net m"
assumes "j < output_size m"
shows "polyfun {..<count_weights s m} (λf. evaluate_net (insert_weights s m f) inputs $ j)"
using polyfun_evaluate_net_plus_a[where a=0, OF assms] by simp

(* For a valid network m, a valid index is into the input sizes, and an output
   index j, the entry at index is of tensor j of tensors_from_net is a
   polynomial in the weight vector f using variables below count_weights s m.
   lookup_tensors_from_net reduces the lookup to evaluating the network on
   base_input, and polyfun_evaluate_net finishes the proof. *)
lemma polyfun_tensors_from_net:
assumes "valid_net m"
assumes "is  input_sizes m"
assumes "j < output_size m"
shows "polyfun {..<count_weights s m} (λf. Tensor.lookup (tensors_from_net (insert_weights s m f) $ j) is)"
proof -
  (* Inserting weights preserves validity, input sizes and output size. *)
  have 1:"f. valid_net' (insert_weights s m f)" by (simp add: assms(1) remove_insert_weights)
  have input_sizes:"f. input_sizes (insert_weights s m f) = input_sizes m"
    unfolding input_sizes_remove_weights by (simp add: remove_insert_weights)
  have 2:"f. is  input_sizes (insert_weights s m f)"
    unfolding input_sizes using assms(2) by blast
  have 3:"f. j < output_size' (insert_weights s m f)"
    by (simp add: assms(3) remove_insert_weights)
  (* base_input depends only on the input sizes, hence not on the weights. *)
  have "f1 f2. base_input (insert_weights s m f1) is = base_input (insert_weights s m f2) is"
    unfolding base_input_def by (simp add: input_sizes)
  then have "xf. (λf. evaluate_net (insert_weights s m f) (base_input (insert_weights s m xf) is) $ j)
    = (λf. evaluate_net (insert_weights s m f) (base_input (insert_weights s m f) is) $ j)"
    by metis
  then show ?thesis unfolding lookup_tensors_from_net[OF 1 2 3]
    using polyfun_evaluate_net[OF base_input_length[OF 2, unfolded input_sizes, symmetric] assms(1) assms(3), of s]
    by simp
qed

(* If every tensor T x has dimensions ds and all its entries are polyfun N in x,
   then every entry (i,j) of the matricization matricize I (T x) is polyfun N.
   The row and column counts nr and nc must be the same for all x. Entry (i,j)
   is the tensor entry at the index obtained by weaving the digit encodings of
   i and j. *)
lemma polyfun_matricize:
assumes "x. dims (T x) = ds"
assumes "is. is  ds  polyfun N (λx. Tensor.lookup (T x) is)"
assumes "x. dim_row (matricize I (T x)) = nr"
assumes "x. dim_col (matricize I (T x)) = nc"
assumes "i < nr"
assumes "j < nc"
shows "polyfun N (λx. matricize I (T x) $$ (i,j))"
proof -
  (* Tensor index corresponding to matrix entry (i,j). The bound variable x does
     not occur in the body, since ds, I, i and j are fixed. *)
  let ?weave = "λ x. (weave I
    (digit_encode (nths ds I ) i)
    (digit_encode (nths ds (-I )) j))"
  have 1:"x. matricize I (T x) $$ (i,j) = Tensor.lookup (T x) (?weave x)" unfolding matricize_def
    by (metis (no_types, lifting) assms(1) assms(3) assms(4) assms(5) assms(6) case_prod_conv
    dim_col_mat(1) dim_row_mat(1) index_mat(1) matricize_def)
  (* The woven index is a valid index for dimensions ds. *)
  have "x. ?weave x  ds"
    using valid_index_weave(1) assms(2) digit_encode_valid_index dim_row_mat(1) matricize_def
    using assms digit_encode_valid_index matricize_def by (metis dim_col_mat(1))
  then have "polyfun N (λx. Tensor.lookup (T x) (?weave x))" using assms(2) by simp
  then show ?thesis unfolding 1 using assms(1) by blast
qed

(* Unnamed auxiliary fact about natural numbers, proved from not_le; presumably
   it states that not (a < b) is equivalent to b being at most a.
   NOTE(review): because it has no name, later proofs cannot cite it by name;
   confirm whether it is needed at all. *)
lemma "(¬ (a::nat) < b) = (a  b)"
by (metis not_le)

(* If A x is an m-by-n matrix for every x with all entries polyfun N, then every
   entry (i,j) of submatrix (A x) I J is polyfun N. The indices must satisfy
   i < card of the row indices below m in I, and j < card of the column indices
   below n in J. The entry equals A x at (pick I i, pick J j); pick presumably
   returns the i-th element of the set. The assumptions that I and J are
   infinite are used by card_le_pick_inf to bound the picked indices by m and
   n. *)
lemma polyfun_submatrix:
assumes "x. (A x)  carrier_mat m n"
assumes "x i j. i<m  j<n  polyfun N (λx. (A x) $$ (i,j))"
assumes "i < card {i. i < m  i  I}"
assumes "j < card {j. j < n  j  J}"
assumes "infinite I" "infinite J"
shows "polyfun N (λx. (submatrix (A x) I J) $$ (i,j))"
proof -
  (* Rewrite the submatrix entry as an entry of the original matrix. *)
  have 1:"x. (submatrix (A x) I J) $$ (i,j) = (A x) $$ (pick I i, pick J j)"
    using submatrix_index by (metis (no_types, lifting) Collect_cong assms(1) assms(3) assms(4) carrier_matD(1) carrier_matD(2))
  (* The picked row and column indices are within the dimensions of A x. *)
  have "pick I i < m"  "pick J j < n" using card_le_pick_inf[OF infinite I] card_le_pick_inf[OF infinite J]
    i < card {i. i < m  i  I}[unfolded set_le_in] j < card {j. j < n  j  J}[unfolded set_le_in] not_less by metis+
  then show ?thesis unfolding 1 by (simp add: assms(2))
qed

(* Polynomiality results for the deep model deep_model_l rs with output index y.
   They culminate in polyfun_det_deep_model: the determinant of
   witness_submatrix f is a polynomial in the weight vector f. *)
context deep_model_correct_params_y
begin

(* witness_submatrix f: the submatrix of A' f that keeps exactly the rows and
   columns whose indices lie in rows_with_1. *)
definition witness_submatrix where
"witness_submatrix f = submatrix (A' f) rows_with_1 rows_with_1"


(* For a valid index is into the input sizes of the deep model, the entry at is
   of output tensor y is a polynomial in the weights f. Only variables below
   weight_space_dim occur. Instance of polyfun_tensors_from_net. *)
lemma polyfun_tensor_deep_model:
assumes "is  input_sizes (deep_model_l rs)"
shows "polyfun {..<weight_space_dim}
  (λf. Tensor.lookup (tensors_from_net (insert_weights shared_weights (deep_model_l rs) f) $ y) is)"
proof -
  have 1:"f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs"
    using remove_insert_weights by metis
  (* y is a valid output index of the deep model. *)
  then have "y < output_size ( deep_model_l rs)" using valid_deep_model y_valid length_output_deep_model by force
  have 0:"{..<weight_space_dim} = set [0..<weight_space_dim]" by auto
  then show ?thesis unfolding weight_space_dim_def using polyfun_tensors_from_net assms(1) valid_deep_model
    y < output_size ( deep_model_l rs ) by metis
qed

(* Input sizes of the deep model, restated as 2 * N_half copies of last rs.
   NOTE(review): this lemma has the same name as the fact input_sizes_deep_model
   used in its own proof, presumably the version from DL_Deep_Model. Later
   references inside this context, e.g. in polyfun_matrix_deep_model,
   presumably resolve to this restated version. *)
lemma input_sizes_deep_model: "input_sizes (deep_model_l rs) = replicate (2 * N_half) (last rs)"
  unfolding N_half_def using input_sizes_deep_model deep
  by (metis (no_types, lifting) Nitpick.size_list_simp(2) One_nat_def Suc_1 Suc_le_lessD diff_Suc_Suc length_tl less_imp_le_nat list.size(3) not_less_eq numeral_3_eq_3 power_eq_if)

(* Every entry (i,j) of the matricization A' f, with i and j below
   (last rs) ^ N_half, is a polynomial in the weights f. The proof combines
   polyfun_tensor_deep_model with polyfun_matricize. *)
lemma polyfun_matrix_deep_model:
assumes "i<(last rs) ^ N_half"
assumes "j<(last rs) ^ N_half"
shows "polyfun {..<weight_space_dim} (λf. A' f $$ (i,j))"
proof -
  have 0:"y < output_size ( deep_model_l rs )" using valid_deep_model y_valid length_output_deep_model by force
  have 1:"f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs"
    using remove_insert_weights by metis
  (* All tensor entries of A x at valid indices are polynomials. *)
  have 2:"(f is. is  replicate (2 * N_half) (last rs) 
         polyfun {..<weight_space_dim} (λx. Tensor.lookup (A x) is))"
    unfolding A_def using polyfun_tensor_deep_model[unfolded input_sizes_deep_model] 0 by blast
  show ?thesis
    unfolding A'_def A_def apply (rule polyfun_matricize)
    using dims_tensor_deep_model[OF 1] 2[unfolded A_def]
    using dims_A'_pow[unfolded A'_def A_def] i<(last rs) ^ N_half j<(last rs) ^ N_half
    by auto
qed

(* Every entry (i,j) of witness_submatrix f, with i and j below r ^ N_half, is a
   polynomial in the weights f. The proof combines polyfun_submatrix with
   polyfun_matrix_deep_model. *)
lemma polyfun_submatrix_deep_model:
assumes "i < r ^ N_half"
assumes "j < r ^ N_half"
shows "polyfun {..<weight_space_dim} (λf. witness_submatrix f $$ (i,j))"
unfolding witness_submatrix_def
proof (rule polyfun_submatrix)
  have 1:"f. remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs"
    using remove_insert_weights by metis
  (* A' f is square with side (last rs) ^ N_half. *)
  show "f. A' f  carrier_mat ((last rs) ^ N_half) ((last rs) ^ N_half)"
    using "1" dims_A'_pow using weight_space_dim_def by auto
  show "f i j. i < last rs ^ N_half  j < last rs ^ N_half 
        polyfun {..<weight_space_dim} (λf. A' f $$ (i, j))"
    using polyfun_matrix_deep_model weight_space_dim_def by force
  (* i and j are within the number of selected rows/columns (card_rows_with_1). *)
  show "i < card {i. i < last rs ^ N_half  i  rows_with_1}"
    using assms(1) card_rows_with_1 dims_Aw'_pow set_le_in by metis
  show "j < card {i. i < last rs ^ N_half  i  rows_with_1}"
    using assms(2) card_rows_with_1 dims_Aw'_pow set_le_in by metis
  show "infinite rows_with_1" "infinite rows_with_1" by (simp_all add: infinite_rows_with_1)
qed

(* The determinant of witness_submatrix f is a polynomial in the weights f.
   The proof uses polyfun_det: witness_submatrix f is square of side r ^ N_half,
   and its entries are polynomials by polyfun_submatrix_deep_model. *)
lemma polyfun_det_deep_model:
shows "polyfun {..<weight_space_dim} (λf. det (witness_submatrix f))"
proof (rule polyfun_det)
  fix f
  (* NOTE(review): this fact is neither named nor chained into the following
     steps, so it appears to be unused. *)
  have "remove_weights (insert_weights shared_weights (deep_model_l rs) f) = deep_model_l rs"
    using remove_insert_weights by metis

  show "witness_submatrix f  carrier_mat (r ^ N_half) (r ^ N_half)"
    unfolding witness_submatrix_def apply (rule carrier_matI) unfolding dim_submatrix[unfolded set_le_in]
    unfolding dims_A'_pow[unfolded weight_space_dim_def] using card_rows_with_1 dims_Aw'_pow by simp_all
  show "i j. i < r ^ N_half  j < r ^ N_half  polyfun {..<weight_space_dim} (λf. witness_submatrix f $$ (i, j))"
    using polyfun_submatrix_deep_model by blast
qed

end

end