(* Theory Karatsuba_Sum_Lemmas *)
section "Auxiliary Sum Lemmas"
theory Karatsuba_Sum_Lemmas
  imports Karatsuba_Preliminaries "Expander_Graphs.Extra_Congruence_Method"
begin
(* Congruence for list sums: if f and g agree on every element of xs,
   then summing f over xs gives the same result as summing g over xs. *)
lemma sum_list_eq: "(⋀x. x ∈ set xs ⟹ f x = g x) ⟹ sum_list (map f xs) = sum_list (map g xs)"
  by (rule arg_cong[OF list.map_cong0])
(* Splits the summand for index 0 off a sum over the range [0..<Suc n]. *)
lemma sum_list_split_0: "(∑i ← [0..<Suc n]. f i) = f 0 + (∑i ← [1..<Suc n]. f i)"
proof -
  (* The range [0..<Suc n] is written as its head 0 followed by [1..<Suc n]. *)
  have "[0..<Suc n] = 0 # [1..<Suc n]" using upt_eq_Cons_conv by auto
  then show ?thesis by simp
qed
(* Index transformation: summing f (g i) over xs is the same as summing f
   over the mapped list map g xs. *)
lemma sum_list_index_trafo: "(∑i ← xs. f (g i)) = (∑i ← map g xs. f i)"
proof (induction xs)
  case Nil
  show ?case by simp
next
  case (Cons y ys)
  then show ?case by simp
qed
(* Index shift by a constant c: a sum of f (i + c) over [a..<b] equals the sum
   of f over the shifted range [a+c..<b+c]. *)
lemma sum_list_index_shift: "(∑i ← [a..<b]. f (i + c)) = (∑i ← [a+c..<b+c]. f i)"
proof -
  (* Move the shift from the summand into the index list. *)
  have "(∑i ← [a..<b]. f (i + c)) = (∑i ← (map (λj. j + c) [a..<b]). f i)"
    by (intro sum_list_index_trafo)
  (* Adding c to every element of [a..<b] yields [a+c..<b+c]. *)
  also have "map (λj. j + c) [a..<b] = [a+c..<b+c]"
    using map_add_const_upt by simp
  finally show ?thesis .
qed
(* Index shift by one: a sum of f over [k+1..<j+1] equals the sum of f (i + 1)
   over [k..<j].
   NOTE(review): the variable n of the premise "n = j - k" does not occur in the
   conclusion; the premise is kept so that the statement stays unchanged. *)
lemma list_sum_index_shift: "n = j - k ⟹ (∑i ← [k+1..<j+1]. f i) = (∑i ← [k..<j]. f (i + 1))"
  using sum_list_index_trafo[where g = "λl. l + 1" and xs = "[k..<j]" and f = f, symmetric]
  using map_Suc_upt by simp
(* Index shift for ranges starting at 0: a sum of a (i + c) over [0..<m] equals
   the sum of a over [c..<m+c]. Proved by induction on m with a and c generalized. *)
lemma list_sum_index_shift': "(∑i ← [0..<m]. a (i + c)) = (∑i ← [c..<m+c]. a i)"
  by (induction m arbitrary: a c) auto
(* Concatenation of adjacent ranges: the sums over [0..<m] and [m..<m+c] add up
   to the sum over [0..<m+c]. *)
lemma list_sum_index_concat: "(∑i ← [0..<m]. a i) + (∑i ← [m..<m+c]. a i) = (∑i ← [0..<m+c]. a i)"
proof -
  (* Split the full range into the two adjacent sub-ranges. *)
  have "(∑i ← [0..<m+c]. a i) = (∑i ← [0..<m] @ [m..<m+c]. a i)"
    using upt_add_eq_append[of 0 m c] by simp
  then show ?thesis using sum_list_append by simp
qed
(* An additive map f that sends 0 to 0 commutes with list sums:
   applying f to the sum equals summing the images under f. *)
lemma sum_list_linear:
  assumes "⋀a b. f (a + b) = f a + f b"
  assumes "f 0 = 0"
  shows "f (∑i ← xs. g i) = (∑i ← xs. f (g i))"
proof (induction xs)
  case Nil
  (* Empty sum: both sides reduce to f 0 = 0. *)
  show ?case using assms(2) by simp
next
  case (Cons y ys)
  (* Head summand: use additivity of f, then the induction hypothesis. *)
  then show ?case using assms(1) by simp
qed
(* The conversion int commutes with list sums; obtained from sum_list_linear
   together with the facts int_ops(5) and int_ops(1). *)
lemma sum_list_int:
  shows "int (∑i ← xs. g i) = (∑i ← xs. int (g i))"
  by (intro sum_list_linear int_ops(5) int_ops(1))
(* Splits the last summand off a sum over [0..<n] when n is a successor Suc n'. *)
lemma sum_list_split_Suc:
  assumes "n = Suc n'"
  shows "(∑i ← [0..<n]. f i) = (∑i ← [0..<n']. f i) + f n'"
  by (simp add: assms)
(* Upper bound: if every summand f i (for i in xs) is at most B, the sum over xs
   is at most length xs * B. *)
lemma sum_list_estimation_leq:
  assumes "⋀i. i ∈ set xs ⟹ f i ≤ B"
  shows "(∑i ← xs. f i) ≤ length xs * B"
  using assms by (induction xs)(simp, fastforce)
(* Strict upper bound: for a non-empty list xs whose summands are all strictly
   below B, the sum over xs is strictly below length xs * B. *)
lemma sum_list_estimation_le:
  assumes "⋀i. i ∈ set xs ⟹ f i < B"
  assumes "xs ≠ []"
  shows "(∑i ← xs. f i) < length xs * B"
proof -
  from ‹xs ≠ []› have "length xs > 0" by simp
  from ‹xs ≠ []› obtain x where "x ∈ set xs" by fastforce
  (* Some element lies strictly below B, hence B is positive ... *)
  then have "B > 0" using assms(1) by fastforce
  (* ... and can be written as a successor, turning "<" into "≤" for B'. *)
  then obtain B' where "B = Suc B'" using not0_implies_Suc by blast
  with assms(1) have "⋀i. i ∈ set xs ⟹ f i ≤ B'" by fastforce
  with sum_list_estimation_leq have "(∑i ← xs. f i) ≤ length xs * B'" by blast
  also have "... < length xs * B" using ‹B = Suc B'› ‹length xs > 0› by simp
  finally show ?thesis .
qed
subsection "@{class semiring_1} Sums"
(* Multiplying by an indicator of_bool x selects a when x holds and 0 otherwise. *)
lemma (in semiring_1) of_bool_mult: "of_bool x * a = (if x then a else 0)"
  by simp
(* Inclusion-exclusion for indicators of a disjunction. *)
lemma (in semiring_1_cancel) of_bool_disj: "of_bool (x ∨ y) = of_bool x + of_bool y - of_bool x * of_bool y"
  by simp
(* For mutually exclusive x and y, the indicator of the disjunction is the sum
   of the two indicators. *)
lemma (in semiring_1) of_bool_disj_excl: "¬ (x ∧ y) ⟹ of_bool (x ∨ y) = of_bool x + of_bool y"
  by simp
(* Under the indicator of_bool (i = j), the summand f i may be replaced by f j. *)
lemma (in semiring_1) of_bool_var_swap:
  "(∑i ← xs. of_bool (i = j) * f i) = (∑i ← xs. of_bool (i = j) * f j)"
  by (induction xs) simp_all
(* A sum with indicator of_bool (i = j) equals the number of occurrences of j
   in xs (count_list xs j) times f j.
   NOTE(review): this lemma has no name, so it cannot be referenced later. *)
lemma "(∑i ← xs. of_bool (i = j) * f i) = count_list xs j * f j"
  by (induction xs) simp_all
(* For a distinct list xs, the indicator sum over i = j collapses to the single
   term f j j, guarded by the indicator of j being a member of xs. *)
lemma (in semiring_1) of_bool_distinct:
  "distinct xs ⟹ (∑i ← xs. of_bool (i = j) * f i j) = of_bool (j ∈ set xs) * f j j"
  by (induction xs) auto
(* Special case of of_bool_distinct where j is known to be an element of xs:
   the indicator sum is exactly f j j. *)
lemma (in semiring_1) of_bool_distinct_in:
  "distinct xs ⟹ j ∈ set xs ⟹ (∑i ← xs. of_bool (i = j) * f i j) = f j j"
  using of_bool_distinct[of xs j f] of_bool_mult by simp
(* If xs is distinct and at most one element of xs satisfies P (any two
   elements satisfying P are equal), the sum of the indicators of P over xs
   is at most 1. *)
lemma (in linordered_semiring_1) of_bool_sum_leq_1:
  assumes "distinct xs"
  assumes "⋀i j. i ∈ set xs ⟹ j ∈ set xs ⟹ P i ⟹ P j ⟹ i = j"
  shows "(∑l ← xs. of_bool (P l)) ≤ 1"
  using assms
proof (induction xs)
  case Nil
  then show ?case by simp
next
  case (Cons a xs)
  (* Case distinction on whether the head element satisfies P. *)
  consider "P a" | "¬ P a" by blast
  then show ?case
  proof cases
    case 1
    (* The head contributes 1 ... *)
    then have r: "(∑l←a # xs. of_bool (P l)) = 1 + (∑l←xs. of_bool (P l))"
      by simp
    (* ... and no element of the tail can satisfy P, since it differs from a. *)
    have "of_bool (P l) = 0" if "l ∈ set xs" for l
    proof -
      from that have "a ≠ l" using Cons by auto
      then have "¬ P l" using Cons ‹l ∈ set xs› 1 by force
      then show "of_bool (P l) = 0" by simp
    qed
    then have "(∑l←xs. of_bool (P l)) = (∑l←xs. 0)"
      using list.map_cong0[of xs] by metis
    then show ?thesis using r by simp
  next
    case 2
    (* The head contributes 0; the bound follows from the induction hypothesis. *)
    then have "(∑l←a # xs. of_bool (P l)) = (∑l←xs. of_bool (P l))"
      by simp
    then show ?thesis using Cons by simp
  qed
qed
(* Declares nat as an instance of the class linordered_semiring_1; the instance
   obligation is discharged by the default proof step "..". *)
instantiation nat :: linordered_semiring_1
begin
  instance ..
end
(* The product of two list sums equals the double sum over all products f i * g j. *)
lemma (in semiring_1) sum_list_mult_sum_list: "(∑i ← xs. f i) * (∑j ← ys. g j) = (∑i ← xs. ∑j ← ys. f i * g j)"
  by (simp add: sum_list_const_mult sum_list_mult_const)
(* Congruence for list sums inside the semiring_1 context: summands that agree
   on all elements of xs give equal sums. *)
lemma (in semiring_1) semiring_1_sum_list_eq:
"(⋀i. i ∈ set xs ⟹ f i = g i) ⟹ (∑i ← xs. f i) = (∑i ← xs. g i)"
  using arg_cong[OF list.map_cong0] by blast
(* Interchange of the order of summation for a double list sum. *)
lemma (in semiring_1) sum_swap:
"(∑i ← xs. (∑j ← ys. f i j)) = (∑j ← ys. (∑i ← xs. f i j))"
proof (induction xs)
  case (Cons a xs)
  (* Separate the row belonging to the head element a. *)
  have "(∑i ← (a # xs). (∑j ← ys. f i j)) = (∑j ← ys. f a j) + (∑i ← xs. (∑j ← ys. f i j))"
    by simp
  (* Swap the remaining double sum with the induction hypothesis. *)
  also have "... = (∑j ← ys. f a j) + (∑j ← ys. (∑i ← xs. f i j))"
    using Cons by simp
  (* Merge the two sums over ys into one. *)
  also have "... = (∑j ← ys. f a j + (∑i ← xs. f i j))"
    using sum_list_addf[of "λj. f a j" _ ys] by simp
  also have "... = (∑j ← ys. (∑i ← (a # xs). f i j))" by simp
  finally show ?case .
qed simp
(* A sum over an appended list xs @ ys is the sum over xs plus the sum over ys. *)
lemma (in semiring_1) sum_append:
  "(∑i ← (xs @ ys). f i) = (∑i ← xs. f i) + (∑i ← ys. f i)"
  by (induction xs) (simp_all add: add.assoc)
(* Variant of sum_append in which the appended list is given through an
   equation zs = xs @ ys, convenient for use with intro. *)
lemma (in semiring_1) sum_append':
  assumes "zs = xs @ ys"
  shows "(∑i ← zs. f i) = (∑i ← xs. f i) + (∑i ← ys. f i)"
  unfolding assms by (rule sum_append)
subsubsection "Power Sums"
(* A sum weighted by the indicator of P equals the plain sum over the list
   filtered by P. *)
lemma (in semiring_1) sum_list_of_bool_filter: "(∑i ← xs. of_bool (P i) * f i) = (∑i ← filter P xs. f i)"
  by (induction xs; simp)
(* Filtering the range [a..<b] by "i < c" lowers the upper bound to min b c. *)
lemma upt_filter_less: "filter (λi. i < c) [a..<b] = [a..<min b c]"
  by (induction b; simp)
(* Filtering the range [a..<b] by "i ≥ c" raises the lower bound to max a c. *)
lemma upt_filter_geq: "filter (λi. i ≥ c) [a..<b] = [max a c..<b]"
  by (induction b; simp)
(* An indicator "i < c" inside a sum over [a..<b] is absorbed into the upper
   bound of the range, which becomes min b c. *)
lemma (in semiring_1) sum_list_of_bool_less: "(∑i ← [a..<b]. of_bool (i < c) * f i) = (∑i ← [a..<min b c]. f i)"
  by (simp only: sum_list_of_bool_filter upt_filter_less)
(* An indicator "i ≥ c" inside a sum over [a..<b] is absorbed into the lower
   bound of the range, which becomes max a c. *)
lemma (in semiring_1) sum_list_of_bool_geq: "(∑i ← [a..<b]. of_bool (i ≥ c) * f i) = (∑i ← [max a c..<b]. f i)"
  by (simp only: sum_list_of_bool_filter upt_filter_geq)
(* An indicator for membership in [c..<d] inside a sum over [a..<b] restricts
   the summation range to [max a c..<min b d]. *)
lemma (in semiring_1) sum_list_of_bool_range: "(∑i ← [a..<b]. of_bool (i ∈ set [c..<d]) * f i) =
    (∑i ← [max a c..<min b d]. f i)"
proof -
  (* Split the membership indicator into a lower-bound and an upper-bound indicator. *)
  have "(∑i ← [a..<b]. of_bool (i ∈ set [c..<d]) * f i) =
      (∑i ← [a..<b]. of_bool (i ≥ c) * (of_bool (i < d) * f i))"
    by (intro semiring_1_sum_list_eq; simp)
  (* Absorb both indicators into the range bounds. *)
  then show ?thesis unfolding sum_list_of_bool_geq sum_list_of_bool_less .
qed
(* Cauchy product for finite list sums: the product of a sum over [0..<n] and a
   sum over [0..<m] is regrouped by the index sum k = i + j, where k ranges over
   [0..<n + m - 1] and the inner index l over [k + 1 - m..<min (k + 1) n]. *)
lemma (in comm_semiring_1) cauchy_product:
"(∑i ← [0..<n]. f i) * (∑j ← [0..<m]. g j) =
    (∑k ← [0..<n + m - 1]. ∑l ← [k + 1 - m..<min (k + 1) n]. f l * g (k - l))"
proof  -
  (* Distribute the product into a double sum. *)
  have "(∑i ← [0..<n]. f i) * (∑j ← [0..<m]. g j) =
    (∑i ← [0..<n]. ∑j ← [0..<m]. f i * g j)"
    unfolding sum_list_mult_const[symmetric]
    unfolding sum_list_const_mult[symmetric]
    by (rule refl)
  (* Insert a third sum over k with the indicator k = i + j. *)
  also have "... = (∑i ← [0..<n]. ∑j ← [0..<m]. ∑k ← [0..<n + m - 1]. of_bool (k = i + j) * (f i * g j))"
    by (intro semiring_1_sum_list_eq of_bool_distinct_in[symmetric]; simp)
  (* Move the sum over k to the outside. *)
  also have "... = (∑k ← [0..<n + m - 1]. ∑i ← [0..<n]. ∑j ← [0..<m]. of_bool (k = i + j) * (f i * g j))"
    unfolding sum_swap[where xs = "[0..<m]" and ys = "[0..<n + m - 1]"]
    unfolding sum_swap[where xs = "[0..<n]" and ys = "[0..<n + m - 1]"]
    by (rule refl)
  (* Solve the indicator condition for j. *)
  also have "... = (∑k ← [0..<n + m - 1]. ∑i ← [0..<n]. ∑j ← [0..<m]. of_bool (k ≥ i ∧ j = k - i) * (f i * g j))"
    by (intro semiring_1_sum_list_eq; simp)
  also have "... = (∑k ← [0..<n + m - 1]. ∑i ← [0..<n]. ∑j ← [0..<m]. of_bool (j = k - i) * (of_bool (k ≥ i) * (f i * g j)))"
    by (intro semiring_1_sum_list_eq; simp)
  (* Collapse the sum over j to the single index j = k - i. *)
  also have "... = (∑k ← [0..<n + m - 1]. ∑i ← [0..<n]. of_bool (k - i ∈ set [0..<m]) * ((of_bool (k ≥ i) * (f i * g (k - i)))))"
    by (intro semiring_1_sum_list_eq of_bool_distinct distinct_upt)
  (* Rephrase the remaining conditions as lower and upper bounds on i. *)
  also have "... = (∑k ← [0..<n + m - 1]. ∑i ← [0..<n]. of_bool (i ≥ k + 1 - m) * ((of_bool (k + 1 > i) * (f i * g (k - i)))))"
    by (intro semiring_1_sum_list_eq; auto)
  (* Absorb the bound indicators into the range of the inner sum. *)
  also have "... = (∑k ← [0..<n + m - 1]. ∑l ← [k + 1 - m..<min (k + 1) n]. f l * g (k - l))"
    apply (intro semiring_1_sum_list_eq)
    unfolding sum_list_of_bool_geq sum_list_of_bool_less max_0L min.commute[of n]
    by (rule refl)
  finally show ?thesis .
qed
(* Product of two power sums in x with n and m coefficients (0 < m ≤ n),
   written as a power sum whose coefficients are convolution sums. The index
   range [0..<n + m - 1] is split into the three parts [0..<m], [m..<n] and
   [n..<n + m - 1], each with explicit bounds for the inner sum. *)
lemma (in comm_semiring_1) power_sum_product:
  assumes "m > 0"
  assumes "n ≥ m"
  shows
"(∑i←[0..<n]. f i * x ^ i) * (∑j←[0..<m]. g j * x ^ j) =
  (∑k←[0..<m]. (∑i←[0..<Suc k]. f i * g (k - i)) * x ^ k) +
  (∑k←[m..<n]. (∑i←[Suc k - m..<Suc k]. f i * g (k - i)) * x ^ k) +
  (∑k←[n..<n + m - 1]. (∑i←[Suc k - m..<n]. f i * g (k - i)) * x ^ k)"
proof -
  (* Decomposition of the outer index range into three consecutive parts. *)
  have 1: "[0..<n + m - 1] = [0..<m] @ [m..<n] @ [n..<n + m - 1]"
    using upt_add_eq_append'[of 0 m "n + m - 1"] upt_add_eq_append'[of m n "n + m - 1"] assms by simp
  (* Apply the Cauchy product to the summands including their powers of x. *)
  have "(∑i ← [0..<n]. f i * x ^ i) * (∑j ← [0..<m]. g j * x ^ j) =
      (∑k ← [0..<n + m - 1]. ∑l ← [k + 1 - m..<min (k + 1) n]. (f l * x ^ l) * (g (k - l) * x ^ (k - l)))"
    by (rule cauchy_product)
  (* Combine the two powers of x into x ^ k. *)
  also have "... = (∑k ← [0..<n + m - 1]. ∑l ← [k + 1 - m..<min (k + 1) n]. f l * g (k - l) * x ^ k)"
    apply (intro semiring_1_sum_list_eq)
    using mult.commute mult.assoc power_add[symmetric]
    by simp
  (* Pull the factor x ^ k out of the inner sum. *)
  also have "... = (∑k ← [0..<n + m - 1]. (∑l ← [k + 1 - m..<min (k + 1) n]. f l * g (k - l)) * x ^ k)"
    by (intro semiring_1_sum_list_eq sum_list_mult_const)
  (* Split the outer sum along the decomposition 1. *)
  also have "... = (∑k←[0..<m]. (∑i←[k + 1 - m..<min (k + 1) n]. f i * g (k - i)) * x ^ k) +
      (∑k←[m..<n]. (∑i←[k + 1 - m..<min (k + 1) n]. f i * g (k - i)) * x ^ k) +
      (∑k←[n..<n + m - 1]. (∑i←[k + 1 - m..<min (k + 1) n]. f i * g (k - i)) * x ^ k)"
    unfolding 1 sum_append add.assoc by (rule refl)
  (* Simplify the inner bounds separately on each of the three parts. *)
  also have "... = (∑k←[0..<m]. (∑i←[0..<Suc k]. f i * g (k - i)) * x ^ k) +
      (∑k←[m..<n]. (∑i←[Suc k - m..<Suc k]. f i * g (k - i)) * x ^ k) +
      (∑k←[n..<n + m - 1]. (∑i←[Suc k - m..<n]. f i * g (k - i)) * x ^ k)"
    using assms by (intro_cong "[cong_tag_2 (+)]" more: semiring_1_sum_list_eq; simp)
  finally show ?thesis .
qed
(* Special case of power_sum_product for two power sums of the same length
   n > 0; obtained by instantiating both lengths with n. *)
lemma (in comm_semiring_1) power_sum_product_same_length:
  assumes "n > 0"
  shows "(∑i←[0..<n]. f i * x ^ i) * (∑j←[0..<n]. g j * x ^ j) =
  (∑k←[0..<n]. (∑i←[0..<Suc k]. f i * g (k - i)) * x ^ k) +
  (∑k←[n..<2 * n - 1]. (∑i←[Suc k - n..<n]. f i * g (k - i)) * x ^ k)"
  using power_sum_product[of n n f x g, OF assms order.refl]
  by (simp add: semiring_numeral_class.mult_2)
(* Index transformation inside the semiring_1 context: summing f (g i) over xs
   equals summing f over map g xs. *)
lemma (in semiring_1) sum_index_transformation:
  shows "(∑i ← xs. f (g i)) = (∑j ← map g xs. f j)"
proof (induction xs)
  case Nil
  show ?case by simp
next
  case (Cons y ys)
  then show ?case by simp
qed
(* Splits a power sum with exponents i * c at position j ≤ n into the part
   below j and the part from j on, the latter with x ^ (j * c) factored out
   and its indices shifted to start at 0. *)
lemma (in comm_semiring_1) power_sum_split:
  fixes f :: "nat ⇒ 'a"
  fixes x :: 'a
  fixes c :: nat
  assumes "j ≤ n"
  shows "(∑i ← [0..<n]. f i * x ^ (i * c)) =
      (∑i ← [0..<j]. f i * x ^ (i * c)) +
      x ^ (j * c) * (∑i ← [0..<n - j]. f (j + i) * x ^ (i * c))"
proof -
  have "(λi. i + j) = (+) j" by fastforce
  (* Split the index range at j. *)
  have "(∑i ← [0..<n]. f i * x ^ (i * c)) =
    (∑i ← [0..<j]. f i * x ^ (i * c)) + (∑i ← [j..<n]. f i * x ^ (i * c))"
    apply (intro sum_append' upt_add_eq_append') using ‹j ≤ n› by auto
  (* Express [j..<n] as the range [0..<n - j] shifted by j. *)
  also have "(∑i ← [j..<n]. f i * x ^ (i * c)) =
    (∑i ← map ((+) j) [0..<n - j]. f i * x ^ (i * c))"
    apply (intro_cong "[cong_tag_1 sum_list, cong_tag_2 map]" more: refl)
    using ‹j ≤ n› map_add_upt[of j "n - j"] ‹(λi. i + j) = (+) j› by simp
  (* Move the shift from the index list into the summand. *)
  also have "... = (∑i ← [0..<n - j]. f (j + i) * x ^ ((j + i) * c))"
    by (intro sum_index_transformation[symmetric])
  (* Separate the common factor x ^ (j * c) in every summand ... *)
  also have "... = (∑i ← [0..<n - j]. x ^ (j * c) * (f (j + i) * x ^ (i * c)))"
    apply (intro semiring_1_sum_list_eq)
    using mult.commute mult.assoc by (simp add: power_add add_mult_distrib)
  (* ... and pull it out of the sum. *)
  also have "... = x ^ (j * c) * (∑i ← [0..<n - j]. (f (j + i) * x ^ (i * c)))"
    by (intro sum_list_const_mult)
  finally show ?thesis .
qed
subsection "@{type nat} Sums"
(* Geometric sum over the natural numbers for a base q > 1, stated without
   division: (q - 1) times the sum of q ^ i for i < n equals q ^ n - 1. *)
lemma geo_sum_nat:
  assumes "(q :: nat) > 1"
  shows "(q - 1) * (∑i ← [0..<n]. q ^ i) = q ^ n - 1"
proof (induction n)
  case (Suc n)
  (* Split off the last summand q ^ n. *)
  have "(q - 1) * (∑i ← [0..<Suc n]. q ^ i) = (q - 1) * (q ^ n + (∑i ← [0..<n]. q ^ i))"
    by simp
  also have "... = (q - 1) * q ^ n + (q - 1) * (∑i ← [0..<n]. q ^ i)"
    using add_mult_distrib mult.commute by metis
  (* Apply the induction hypothesis to the remaining sum. *)
  also have "... = (q - 1) * q ^ n + (q ^ n - 1)"
    using Suc.IH by simp
  also have "... = q * q ^ n - 1" using ‹q > 1› by (simp add: diff_mult_distrib)
  finally show ?case by simp
qed simp
(* A power sum in base q > 1 with n coefficients, each strictly below q,
   is strictly below q ^ n. *)
lemma geo_sum_bound:
  assumes "(q :: nat) > 1"
  assumes "⋀i. i < n ⟹ f i < q"
  shows "(∑i ← [0..<n]. f i * q ^ i) < q ^ n"
proof -
  (* Every coefficient is at most q - 1. *)
  from assms have "⋀i. i < n ⟹ f i ≤ (q - 1)" by fastforce
  (* Bound the sum termwise by replacing each coefficient with q - 1. *)
  then have "(∑i ← [0..<n]. f i * q ^ i) ≤ (∑i ← [0..<n]. (q - 1) * q ^ i)"
    apply (intro sum_list_mono mult_le_mono1)
    using assms by simp
  also have "... = (q - 1) * (∑i ← [0..<n]. q ^ i)"
    by (intro sum_list_const_mult)
  (* Evaluate the geometric sum. *)
  also have "... = q ^ n - 1"
    by (intro geo_sum_nat assms)
  also have "... < q ^ n" using ‹q > 1› by simp
  finally show ?thesis .
qed
(* For a power sum over nat with exponents i * c whose coefficients are all
   below x ^ c (x > 1, c > 0) and a split position j ≤ n:
   dividing by x ^ (j * c) yields the power sum of the coefficients from j on
   (first statement), and the remainder modulo x ^ (j * c) is the power sum of
   the coefficients below j (second statement). *)
lemma power_sum_nat_split_div_mod:
  assumes "x > 1"
  assumes "c > 0"
  assumes "⋀i. i < n ⟹ (f i :: nat) < x ^ c"
  assumes "j ≤ n"
  shows "(∑i ← [0..<n]. f i * x ^ (i * c)) div x ^ (j * c)
      = (∑i ← [0..<n - j]. f (j + i) * x ^ (i * c))"
        "(∑i ← [0..<n]. f i * x ^ (i * c)) mod x ^ (j * c)
      = (∑i ← [0..<j]. f i * x ^ (i * c))"
proof -
  define sum where "sum = (∑i ← [0..<n]. f i * x ^ (i * c))"
  (* Decompose the sum at position j via power_sum_split. *)
  then have "sum = (∑i ← [0..<j]. f i * x ^ (i * c)) +
      x ^ (j * c) * (∑i ← [0..<n - j]. f (j + i) * x ^ (i * c))"
    (is "sum = ?sum1 + x ^ (j * c) * ?sum2")
    using power_sum_split ‹j ≤ n› by blast
  (* Rewrite the lower part as a power sum in base x ^ c ... *)
  have "?sum1 = (∑i ← [0..<j]. f i * (x ^ c) ^ i)"
    apply (intro_cong "[cong_tag_2 (*)]" more: semiring_1_sum_list_eq refl)
    using power_mult mult.commute by metis
  (* ... so that geo_sum_bound shows it is smaller than the divisor. *)
  also have "... < (x ^ c) ^ j"
    apply (intro geo_sum_bound)
    subgoal using assms one_less_power by blast
    subgoal using assms by simp
    done
  finally have "?sum1 < x ^ (j * c)" by (simp add: power_mult mult.commute)
  (* With the lower part below the divisor, div and mod can be read off the decomposition. *)
  then show "sum mod x ^ (j * c) = ?sum1" "sum div (x ^ (j * c)) = ?sum2" using ‹sum = ?sum1 + x ^ (j * c) * ?sum2›
    using assms(1) by fastforce+
qed
(* Coefficient extraction: for a power sum over nat with exponents i * c whose
   coefficients are all below x ^ c (x > 1, c > 0), the coefficient f j for
   j < n is recovered by dividing by x ^ (j * c) and reducing modulo x ^ c. *)
lemma power_sum_nat_extract_coefficient:
  assumes "x > 1"
  assumes "c > 0"
  assumes "⋀i. i < n ⟹ (f i :: nat) < x ^ c"
  assumes "j < n"
  shows "((∑i ← [0..<n]. f i * x ^ (i * c)) div x ^ (j * c)) mod x ^ c = f j"
proof -
  (* Dividing by x ^ (j * c) drops the coefficients below position j. *)
  have "(∑i ← [0..<n]. f i * x ^ (i * c)) div x ^ (j * c) =
    (∑i ← [0..<n - j]. f (j + i) * x ^ (i * c))" (is "?sum = _")
    apply (intro power_sum_nat_split_div_mod(1) assms)
    using assms by simp_all
  (* Reducing modulo x ^ (1 * c) keeps only the first remaining coefficient. *)
  moreover have "... mod x ^ (1 * c) = (∑i ← [0..<1]. f (j + i) * x ^ (i * c))"
    apply (intro power_sum_nat_split_div_mod(2) assms)
    using assms by simp_all
  ultimately show "?sum mod x ^ c = f j" by simp
qed
(* Uniqueness of coefficients: if two power sums over nat with exponents i * c
   and coefficients below x ^ c (x > 1, c > 0) are equal, then their
   coefficients agree at every index i < n. *)
lemma power_sum_nat_eq:
  assumes "x > 1"
  assumes "c > 0"
  assumes "⋀i. i < n ⟹ (f i :: nat) < x ^ c"
  assumes "⋀i. i < n ⟹ g i < x ^ c"
  assumes "(∑i ← [0..<n]. f i * x ^ (i * c)) = (∑i ← [0..<n]. g i * x ^ (i * c))"
    (is "?sumf = ?sumg")
  shows "⋀i. i < n ⟹ f i = g i"
proof -
  fix i
  assume "i < n"
  (* Express f i through div and mod of the first sum ... *)
  then have "f i = (?sumf div x ^ (i * c)) mod x ^ c"
    apply (intro power_sum_nat_extract_coefficient[symmetric] assms) by assumption
  (* ... replace the first sum by the equal second sum ... *)
  also have "... = (?sumg div x ^ (i * c)) mod x ^ c"
    using assms by argo
  (* ... and read off g i in the same way. *)
  also have "... = g i"
    apply (intro power_sum_nat_extract_coefficient assms) using ‹i < n› by simp_all
  finally show "f i = g i" .
qed
end