(* Theory Euler_Partition *)

(* Author: Lukas Bulwahn <lukas.bulwahn-at-gmail.com>
   Dedicated to Sandra H. for a wonderful relaxing summer
*)

section ‹Euler's Partition Theorem›

theory Euler_Partition
imports
  Main
  Card_Number_Partitions.Number_Partition
begin

subsection ‹Preliminaries›

subsubsection ‹Additions to Divides Theory›

(* Exact division of powers of a natural number: for c ≤ b and a > 0,
   a ^ b div a ^ c = a ^ (b - c). Positivity of a keeps a ^ c non-zero
   (power_not_zero), so the division cancels exactly. *)
lemma power_div_nat:
  assumes "c ≤ b"
  assumes "a > 0"
  shows  "(a :: nat) ^ b div a ^ c = a ^ (b - c)"
by (metis assms nonzero_mult_div_cancel_right le_add_diff_inverse2 less_not_refl2 power_add power_not_zero)

subsubsection ‹Additions to Groups-Big Theory›

(* If b divides every summand, dividing a finite sum by b can be done term
   by term. Proof by induction over the finite set A. *)
lemma sum_div:
  assumes "finite A"
  assumes "⋀a. a ∈ A ⟹ (b::'b::euclidean_semiring) dvd f a"
  shows "(∑a∈A. f a) div b = (∑a∈A. (f a) div b)"
using assms
proof (induct)
  case insert from this show ?case by auto (subst div_add; auto intro!: dvd_sum)
qed (auto)

(* If every summand is divisible by b, then so is the finite sum. *)
lemma sum_mod:
  assumes "finite A"
  assumes "⋀a. a ∈ A ⟹ f a mod b = (0::'b::unique_euclidean_semiring)"
  shows "(∑a∈A. f a) mod b = 0"
using assms by induct (auto simp add: mod_add_eq [symmetric])

subsubsection ‹Additions to Finite-Set Theory›

(* Only finitely many exponents i satisfy 2 ^ i ≤ n: all of them lie in
   {0..n}. *)
lemma finite_exponents:
  "finite {i. 2 ^ i ≤ (n::nat)}"
proof -
  have "{i::nat. 2 ^ i ≤ n} ⊆ {0..n}"
    using dual_order.trans by fastforce
  from finite_subset[OF this] show ?thesis by simp
qed

subsection ‹Binary Encoding of Natural Numbers›

(* bitset n: the bit positions of n, i.e. all i for which n div 2 ^ i is
   odd. *)
definition bitset :: "nat ⇒ nat set"
where
  "bitset n = {i. odd (n div (2 ^ i))}"

(* A set bit at position b forces 2 ^ b ≤ n: if n < 2 ^ b, then
   n div 2 ^ b = 0, which is even. *)
lemma in_bitset_bound:
  "b ∈ bitset n ⟹ 2 ^ b ≤ n"
unfolding bitset_def using not_less by fastforce

(* Weaker linear bound on bit positions, derived from in_bitset_bound and
   b ≤ 2 ^ b. *)
lemma in_bitset_bound_weak:
  "b ∈ bitset n ⟹ b ≤ n"
by (meson order.trans in_bitset_bound self_le_ge2_pow[OF order_refl])

(* bitset n is finite, being contained in {..n}. *)
lemma finite_bitset:
  "finite (bitset n)"
proof -
  have "bitset n ⊆ {..n}" by (auto dest: in_bitset_bound_weak)
  from this show ?thesis using finite_subset by auto
qed

(* Zero has no set bits: 0 div 2 ^ i = 0 is even for every i. *)
lemma bitset_0:
  "bitset 0 = {}"
unfolding bitset_def by auto

(* Doubling shifts every set bit up by one position. Proof by extensionality,
   with a case split on whether the position is 0 or a successor. *)
lemma bitset_2n: "bitset (2 * n) = Suc ` (bitset n)"
proof (rule set_eqI)
  fix x
  show "(x ∈ bitset (2 * n)) = (x ∈ Suc ` bitset n)"
    unfolding bitset_def by (cases x) auto
qed

(* For even n, adding 1 sets bit 0 and leaves all other bits unchanged. *)
lemma bitset_Suc:
  assumes "even n"
  shows "bitset (n + 1) = insert 0 (bitset n)"
proof (rule set_eqI)
  fix x
  from assms show "(x ∈ bitset (n + 1)) = (x ∈ insert 0 (bitset n))"
    unfolding bitset_def by (cases x) (auto simp add: Divides.div_mult2_eq)
qed

(* 2 * n + 1 has bit 0 set; its remaining bits are those of n shifted up by
   one position. Combines bitset_Suc (2 * n is even) with bitset_2n. *)
lemma bitset_2n1:
  "bitset (2 * n + 1) = insert 0 (Suc ` (bitset n))"
by (subst bitset_Suc) (auto simp add: bitset_2n)

(* Summing 2 ^ i over the set bits of n gives back n, so bitset n is the
   binary representation of n. Proof by induction over the binary structure
   of n (nat_bit_induct: 0, 2 * n, 2 * n + 1). *)
lemma sum_bitset:
  "(∑i∈bitset n. 2 ^ i) = n"
proof (induct rule: nat_bit_induct)
  case zero
  show ?case by (auto simp add: bitset_0)
next
  case (even n)
  from this show ?case
    by (simp add: bitset_2n sum.reindex sum_distrib_left[symmetric])
next
  case (odd n)
  have "(∑i∈bitset (2 * n + 1). 2 ^ i) = (∑i∈insert 0 (Suc ` bitset n). 2 ^ i)"
    by (simp only: bitset_2n1)
  also have "... = 2 ^ 0 + (∑i∈Suc ` bitset n. 2 ^ i)"
    by (subst sum.insert) (auto simp add: finite_bitset)
  also have "... = 2 * n + 1"
    using odd by (simp add: sum.reindex sum_distrib_left[symmetric])
  finally show ?case by simp
qed

(* Dividing a sum of distinct powers of two by 2 ^ j does two things:
   - it discards the powers below 2 ^ j, since they add up to less than 2 ^ j
     (fact "bound");
   - it shifts every remaining power down by j.
   B is split at j, and the upper part is divisible by 2 ^ j (fact "mod0"). *)
lemma binarysum_div:
  assumes "finite B"
  shows "(∑i∈B. (2::nat) ^ i) div 2 ^ j = (∑i∈B. if i < j then 0 else 2 ^ (i - j))"
  (is "_ = (∑i∈_. ?f i)")
proof -
  have split_B: "B = {i∈B. i < j} ∪ {i∈B. j ≤ i}" by auto
  have bound: "(∑i | i ∈ B ∧ i < j. (2::nat) ^ i) < 2 ^ j"
  proof (rule order.strict_trans1)
    show "(∑i | i ∈ B ∧ i < j. (2::nat) ^ i) ≤ (∑i<j. 2 ^ i)" by (auto intro: sum_mono2)
    show "... < 2 ^ j" using sum_power2 by (simp add: atLeast0LessThan)
  qed
  from this have zero: "(∑i | i ∈ B ∧ i < j. (2::nat) ^ i) div (2 ^ j) = 0" by (elim div_less)
  from assms have mod0: "(∑i | i ∈ B ∧ j ≤ i. (2::nat) ^ i) mod 2 ^ j = 0"
    by (auto intro!: sum_mod simp add: le_imp_power_dvd)
  from assms have "(∑i∈B. (2::nat) ^ i) div (2 ^ j) = ((∑i | i ∈ B ∧ i < j. 2 ^ i) + (∑i | i ∈ B ∧ j ≤ i. 2 ^ i)) div 2 ^ j"
    by (subst sum.union_disjoint[symmetric]) (auto simp add: split_B[symmetric])
  also have "... = (∑i | i ∈ B ∧ j ≤ i. 2 ^ i) div 2 ^ j"
    by (simp add: div_add1_eq zero mod0)
  also have "... = (∑i | i ∈ B ∧ j ≤ i. 2 ^ i div 2 ^ j)"
    using assms by (subst sum_div) (auto simp add: sum_div le_imp_power_dvd)
  also have "... = (∑i | i ∈ B ∧ j ≤ i. 2 ^ (i - j))"
    by (rule sum.cong[OF refl]) (auto simp add: power_div_nat)
  also have "... = (∑i∈B. ?f i)"
    using assms by (subst split_B; subst sum.union_disjoint) auto
  finally show ?thesis .
qed

(* The shifted sum produced by binarysum_div is odd exactly when x ∈ B. The
   term for i = x is 2 ^ 0 = 1. Every other term is even: it is either 0 or
   a power of two with a positive exponent. *)
lemma odd_iff:
  assumes "finite B"
  shows "odd (∑i∈B. if i < x then (0::nat) else 2 ^ (i - x)) = (x ∈ B)" (is "odd (∑i∈_. ?s i) = _")
proof -
  from assms have even: "even (∑i∈B - {x}. ?s i)"
    by (subst dvd_sum) auto
  show ?thesis
  proof
    assume "odd (∑i∈B. ?s i)"
    from this even show "x ∈ B" by (cases "x ∈ B") auto
  next
    assume "x ∈ B"
    from assms this have "(∑i∈B. ?s i) = 1 + (∑i∈B-{x}. ?s i)"
      by (auto simp add: sum.remove)
    from assms this even show "odd (∑i∈B. ?s i)" by auto
  qed
qed

(* Converse of sum_bitset: for a finite set B, the set bits of
   ∑i∈B. 2 ^ i are exactly B. *)
lemma bitset_sum:
  assumes "finite B"
  shows "bitset (∑i∈B. 2 ^ i) = B"
using assms unfolding bitset_def by (simp add: binarysum_div odd_iff)

subsection ‹Decomposition of a Number into a Power of Two and an Odd Number›

(* index n: how many times the factor 2 can be removed from n before an odd
   number is reached; index 0 = 0 by convention. Termination holds because
   the recursive argument n div 2 is smaller than n for n > 0. *)
function (sequential) index :: "nat ⇒ nat"
where
  "index 0 = 0"
| "index n = (if odd n then 0 else Suc (index (n div 2)))"
by (pat_completeness) auto

termination
proof
  show "wf {(x::nat, y). x < y}" by (simp add: wf)
next
  fix n show "(Suc n div 2, Suc n) ∈ {(x, y). x < y}" by simp
qed

(* oddpart n: n with all factors of 2 removed; oddpart 0 = 0 by convention.
   Termination is proved exactly as for index. *)
function (sequential) oddpart :: "nat ⇒ nat"
where
  "oddpart 0 = 0"
| "oddpart n = (if odd n then n else oddpart (n div 2))"
by pat_completeness auto

termination
proof
  show "wf {(x::nat, y). x < y}" by (simp add: wf)
next
  fix n show "(Suc n div 2, Suc n) ∈ {(x, y). x < y}" by simp
qed

(* The odd part of n is odd exactly when n is non-zero (oddpart 0 = 0). *)
lemma odd_oddpart:
  "odd (oddpart n) ⟷ n ≠ 0"
by (induct n rule: index.induct) auto

(* Every natural number factors as 2 ^ (index n) times its odd part. For
   n = 0 both sides are 0, since oddpart 0 = 0; the successor case follows
   from the recursive equations of index and oddpart. *)
lemma index_oddpart_decomposition:
  "n = 2 ^ (index n) * oddpart n"
proof (induct n rule: index.induct)
  case (2 n)
  from this show "Suc n = 2 ^ index (Suc n) * oddpart (Suc n)"
    by (simp add: mult.assoc)
qed (simp)

(* Removing factors of 2 never increases a number. *)
lemma oddpart_leq:
  "oddpart n ≤ n"
by (induct n rule: index.induct) (simp, metis div_le_dividend le_Suc_eq le_trans oddpart.simps(2))

(* The decomposition into a power of two times an odd number is unique.
   Proof by induction on i, generalizing i'. *)
lemma index_oddpart_unique:
  assumes "odd (m :: nat)" "odd m'"
  shows "(2 ^ i * m = 2 ^ i' * m') ⟷ (i = i' ∧ m = m')"
proof (induct i arbitrary: i')
  case 0
  from assms show ?case by auto
next
  case (Suc _ i')
  from assms this show ?case by (cases i') auto
qed

(* For odd m, index and oddpart recover the two factors of 2 ^ i * m.
   This follows from index_oddpart_unique applied to the decomposition
   given by index_oddpart_decomposition. *)
lemma index_oddpart:
  assumes "odd m"
  shows "index (2 ^ i * m) = i" "oddpart (2 ^ i * m) = m"
using index_oddpart_unique[where i=i and m=m and m'="oddpart (2 ^ i * m)" and i'="index (2 ^ i * m)"]
  assms odd_oddpart index_oddpart_decomposition by force+

subsection ‹Partitions With Only Distinct and Only Odd Parts›

(* Map from multiplicity functions with distinct parts to ones with odd
   parts:
   - an odd number i receives multiplicity ∑ 2 ^ j over all j with
     p (2 ^ j * i) = 1;
   - an even number receives multiplicity 0. *)
definition odd_of_distinct :: "(nat ⇒ nat) ⇒ nat ⇒ nat"
where
  "odd_of_distinct p = (λi. if odd i then (∑j | p (2 ^ j * i) = 1. 2 ^ j) else 0)"

(* Map in the other direction: i receives multiplicity 1 iff bit (index i)
   is set in the multiplicity p (oddpart i) of its odd part, and 0
   otherwise. *)
definition distinct_of_odd :: "(nat ⇒ nat) ⇒ nat ⇒ nat"
where
  "distinct_of_odd p = (λi. if index i ∈ bitset (p (oddpart i)) then 1 else 0)"

(* odd_of_distinct assigns non-zero multiplicities only to odd numbers. *)
lemma odd:
  "odd_of_distinct p i ≠ 0 ⟹ odd i"
unfolding odd_of_distinct_def by auto

(* distinct_of_odd produces only multiplicities 0 or 1, i.e. distinct
   parts. *)
lemma distinct_distinct_of_odd:
  "distinct_of_odd p i ≤ 1"
unfolding distinct_of_odd_def by auto

(* Parts of odd_of_distinct p lie in 1..n whenever all parts of p are at
   most n. A non-zero multiplicity at i is odd, so i ≥ 1. It also stems from
   some part 2 ^ j * i of p, which gives i ≤ 2 ^ j * i ≤ n. *)
lemma odd_of_distinct:
  assumes "odd_of_distinct p i ≠ 0"
  assumes "⋀i. p i ≠ 0 ⟹ i ≤ n"
  shows "1 ≤ i ∧ i ≤ n"
proof
  from assms(1) odd have "odd i"
    by simp
  then show "1 ≤ i"
    by (auto elim: oddE)
next
  from assms(1) obtain j where "p (2 ^ j * i) > 0"
    by (auto simp add: odd_of_distinct_def split: if_splits) fastforce
  with assms(2) have "i ≤ 2 ^ j * i" "2 ^ j * i ≤ n"
    by simp_all
  then show "i ≤ n"
    by (rule order_trans)
qed

(* Parts of distinct_of_odd p lie in 1..n, provided that p has only odd
   parts and every p i * i is at most n.
   - Lower bound: i = 0 is excluded because p 0 = 0 (0 is even) and
     bitset 0 = {}.
   - Upper bound: uses 2 ^ index i ≤ p (oddpart i) together with the
     decomposition i = 2 ^ index i * oddpart i. *)
lemma distinct_of_odd:
  assumes "⋀i. p i * i ≤ n" "⋀i. p i ≠ 0 ⟹ odd i"
  assumes "distinct_of_odd p i ≠ 0"
  shows "1 ≤ i ∧ i ≤ n"
proof
  from assms(3) have index: "index i ∈ bitset (p (oddpart i))"
    unfolding distinct_of_odd_def by (auto split: if_split_asm)
  have "i ≠ 0"
  proof
    assume zero: "i = 0"
    from assms(2) have "p 0 = 0" by auto
    from index zero this show "False" by (auto simp add: bitset_0)
  qed
  from this show "1 ≤ i" by auto
  from assms(1) have leq_n: "p (oddpart i) * oddpart i ≤ n" by auto
  from index have "2 ^ index i ≤ p (oddpart i)" by (rule in_bitset_bound)
  from this leq_n show "i ≤ n"
    by (subst index_oddpart_decomposition[of i]) (meson dual_order.trans eq_imp_le mult_le_mono)
qed

(* On multiplicity functions with only odd parts, odd_of_distinct undoes
   distinct_of_odd. *)
lemma odd_distinct:
  assumes "⋀i. p i ≠ 0 ⟹ odd i"
  shows "odd_of_distinct (distinct_of_odd p) = p"
using assms unfolding odd_of_distinct_def distinct_of_odd_def
by (auto simp add: fun_eq_iff index_oddpart sum_bitset)

(* On multiplicity functions with distinct parts (all multiplicities ≤ 1)
   inside 1..n, distinct_of_odd undoes odd_of_distinct.

   For x > 0, the exponents j with p (2 ^ j * oddpart x) = 1 are exactly the
   indices of the used parts that share x's odd part (fact "eq"). That set
   is therefore finite, so bitset_sum applies. *)
lemma distinct_odd:
  assumes "⋀i. p i ≠ 0 ⟹ 1 ≤ i ∧ i ≤ n" "⋀i. p i ≤ 1"
  shows "distinct_of_odd (odd_of_distinct p) = p"
proof -
  from assms have "{i. p i = 1} ⊆ {..n}" by auto
  from this have finite: "finite {i. p i = 1}" by (simp add: finite_subset)
  have "⋀x j. x > 0 ⟹ p (2 ^ j * oddpart x) = 1 ⟹
    index (2 ^ j * oddpart x) ∈ index ` {i. p i = 1 ∧ oddpart x = oddpart i}"
    by (rule imageI) (auto intro: imageI simp add: index_oddpart odd_oddpart)
  from this have eq: "⋀x. x > 0 ⟹ {j. p (2 ^ j * oddpart x) = 1} = index ` {i. p i = 1 ∧ oddpart x = oddpart i}"
    by (auto simp add: index_oddpart odd_oddpart index_oddpart_decomposition[symmetric])
  from finite have all_finite: "⋀x. x > 0 ⟹ finite {j. p (2 ^ j * oddpart x) = 1}"
    unfolding eq by auto
  show ?thesis
  proof
    fix x
    from assms(1) have p0: "p 0 = 0" by auto
    show "distinct_of_odd (odd_of_distinct p) x = p x"
    proof (cases "x > 0")
      case False
      from this p0 show ?thesis
        unfolding odd_of_distinct_def distinct_of_odd_def
        by (auto simp add: odd_oddpart bitset_0)
    next
      case True
      from p0 assms(2)[of x] all_finite[OF True] show ?thesis
        unfolding odd_of_distinct_def distinct_of_odd_def
        by (auto simp add: odd_oddpart bitset_0 bitset_sum index_oddpart_decomposition[symmetric])
    qed
  qed
qed

(* distinct_of_odd preserves the weighted sum ∑i≤n. p i * i.
   - The range 1..n is split into the disjoint classes
     {2 ^ k * m | k} ∩ {..n}, one for each odd m ≤ n (facts "set_eq",
     "no_overlap", "inj", "reindex").
   - On each class, the weighted sum of distinct_of_odd p equals p m * m
     (fact "inner_eq"). *)
lemma sum_distinct_of_odd:
  assumes "⋀i. p i ≠ 0 ⟹ 1 ≤ i ∧ i ≤ n"
  assumes "⋀i. p i * i ≤ n"
  assumes "⋀i. p i ≠ 0 ⟹ odd i"
  shows "(∑i≤n. distinct_of_odd p i * i) = (∑i≤n. p i * i)"
proof -
  {
    fix m
    assume odd: "odd (m :: nat)"
    have finite: "finite {k. 2 ^ k * m ≤ n ∧ k ∈ bitset (p m)}" by (simp add: finite_bitset)
    have "(∑i | ∃k. i = 2 ^ k * m ∧ i ≤ n. distinct_of_odd p i * i) =
      (∑i | ∃k. i = 2 ^ k * m ∧ i ≤ n. if index i ∈ bitset (p (oddpart i)) then i else 0)"
      unfolding distinct_of_odd_def by (auto intro: sum.cong)
    also have "... = (∑i | ∃k. i = 2 ^ k * m ∧ k ∈ bitset (p m) ∧ i ≤ n. i)"
      using odd by (intro sum.mono_neutral_cong_right) (auto simp add: index_oddpart)
    also have "... = (∑k | 2 ^ k * m ≤ n ∧ k ∈ bitset (p m). 2 ^ k * m)"
      using odd by (auto intro!: sum.reindex_cong[OF _ _ refl] inj_onI)
    also have "... = (∑k∈bitset (p m). 2 ^ k * m)"
      using assms(2)[of m] finite dual_order.trans in_bitset_bound
      by (fastforce intro!: sum.mono_neutral_cong_right)
    also have "... = (∑k∈bitset (p m). 2 ^ k) * m"
      by (subst sum_distrib_right) auto
    also have "... = p m * m"
      by (auto simp add: sum_bitset)
    finally have "(∑i | ∃k. i = 2 ^ k * m ∧ i ≤ n. distinct_of_odd p i * i) = p m * m" .
  } note inner_eq = this

  have set_eq: "{i. 1 ≤ i ∧ i ≤ n} = ((λm. {i. ∃k. i = (2 ^ k) * m ∧ i ≤ n}) ` {m. m ≤ n ∧ odd m})"
  proof -
    {
      fix x
      assume "1 ≤ x" "x ≤ n"
      from this oddpart_leq[of x] have "oddpart x ≤ n ∧ odd (oddpart x) ∧ (∃k. 2 ^ index x * oddpart x = 2 ^ k * oddpart x)"
        by (auto simp add: odd_oddpart)
      from this have "∃m≤n. odd m ∧ (∃k. x = 2 ^ k * m)"
        by (auto simp add: index_oddpart_decomposition[symmetric])
    }
    from this show ?thesis by (auto simp add: Suc_leI odd_pos)
  qed
  let ?S = "(λm. {i. ∃k. i = 2 ^ k * m ∧ i ≤ n}) ` {m. m ≤ n ∧ odd m}"
  have no_overlap: "∀A∈?S. ∀B∈?S. A ≠ B ⟶ A ∩ B = {}"
    by (auto simp add: index_oddpart_unique)
  have inj: "inj_on (λm. {i. (∃k. i = 2 ^ k * m) ∧ i ≤ n}) {m. m ≤ n ∧ odd m}"
    unfolding inj_on_def by auto (force simp add: index_oddpart_unique)
  have reindex: "⋀F. (∑i | 1 ≤ i ∧ i ≤ n. F i) = (∑m | m ≤ n ∧ odd m. (∑i | ∃k. i = 2 ^ k * m ∧ i ≤ n. F i))"
    unfolding set_eq by (subst sum.Union_disjoint) (auto simp add: no_overlap intro: sum.reindex_cong[OF inj])
  have "(∑i≤n. distinct_of_odd p i * i) = (∑i | 1 ≤ i ∧ i ≤ n. distinct_of_odd p i * i)"
    by (auto intro: sum.mono_neutral_right)
  also have "... = (∑m | m ≤ n ∧ odd m. ∑i | ∃k. i = 2 ^ k * m ∧ i ≤ n. distinct_of_odd p i * i)"
    by (simp only: reindex)
  also have "... = (∑i | i ≤ n ∧ odd i. p i * i)"
    by (rule sum.cong[OF refl]; subst inner_eq) auto
  also have "... = (∑i≤n. p i * i)"
    using assms(3) by (auto intro: sum.mono_neutral_left)
  finally show ?thesis .
qed

(* In a partition of n, no single weighted part p i * i exceeds n. The proof
   is by contradiction: such a part would occur as one summand of
   ∑j≤n. p j * j = n, with all other summands non-negative. *)
lemma leq_n:
  assumes "⋀i. 0 < p i ⟹ 1 ≤ i ∧ i ≤ (n::nat)"
  assumes "(∑i≤n. p i * i) = n"
  shows "p i * i ≤ n"
proof (rule ccontr)
  assume "¬ p i * i ≤ n"
  from this have gr_n: "p i * i > n" by auto
  from this assms(1) have "1 ≤ i ∧ i ≤ n" by force
  from this have "(∑j≤n. p j * j) = p i * i + (∑j | j ≤ n ∧ j ≠ i. p j * j)"
    by (subst sum.insert[symmetric]) (auto intro: sum.cong simp del: sum.insert)
  from this gr_n assms(2) show False by simp
qed

(* distinct_of_odd maps partitions of n into odd parts to partitions of n
   into distinct parts. *)
lemma distinct_of_odd_in_distinct_partitions:
  assumes "p ∈ {p. p partitions n ∧ (∀i. p i ≠ 0 ⟶ odd i)}"
  shows "distinct_of_odd p ∈ {p. p partitions n ∧ (∀i. p i ≤ 1)}"
proof
  have "distinct_of_odd p partitions n"
  proof (rule partitionsI)
    fix i assume "distinct_of_odd p i ≠ 0"
    from this assms show "1 ≤ i ∧ i ≤ n"
    unfolding partitions_def
    by (rule_tac distinct_of_odd) (auto simp add: leq_n)
  next
    from assms show "(∑i≤n. distinct_of_odd p i * i) = n"
      by (subst  sum_distinct_of_odd) (auto simp add: distinct_distinct_of_odd leq_n elim: partitionsE)
  qed
  moreover have "∀i. distinct_of_odd p i ≤ 1"
    by (intro allI distinct_distinct_of_odd)
  ultimately show "distinct_of_odd p partitions n ∧ (∀i. distinct_of_odd p i ≤ 1)" by simp
qed

(* odd_of_distinct maps partitions of n into distinct parts to partitions of
   n into odd parts.
   - First, every weighted part odd_of_distinct p i * i is shown to be at
     most n (fact "less_n"). This is needed to apply sum_distinct_of_odd.
   - Then the sum is transported back along distinct_odd. *)
lemma odd_of_distinct_in_odd_partitions:
  assumes "p ∈ {p. p partitions n ∧ (∀i. p i ≤ 1)}"
  shows "odd_of_distinct p ∈ {p. p partitions n ∧ (∀i. p i ≠ 0 ⟶ odd i)}"
proof
  from assms have distinct: "⋀i. p i = 0 ∨ p i = 1"
    using le_imp_less_Suc less_Suc_eq_0_disj by fastforce
  from assms have set_eq: "{x. p x = 1} = {x ∈ {..n}. p x = 1}"
     unfolding partitions_def by auto
  from assms have sum: "(∑i≤n. p i * i) = n"
     unfolding partitions_def by auto
  {
    fix i
    assume i: "odd (i :: nat)"
    have 3: "inj_on index {x. p x = 1 ∧ oddpart x = i}"
      unfolding inj_on_def by auto (metis index_oddpart_decomposition)
    {
      fix j assume "p (2 ^ j * i) = 1"
      from this i have "j ∈ index ` {x. p x = 1 ∧ oddpart x = i}"
        by (auto simp add: index_oddpart(1, 2) intro!: image_eqI[where x="2 ^ j * i"])
    }
    from i this have "{j. p (2 ^ j * i) = 1} = index ` {x. p x = 1 ∧ oddpart x = i}"
      by (auto simp add: index_oddpart_decomposition[symmetric])
    from 3 this have "(∑j | p (2 ^ j * i) = 1. 2 ^ j) * i = (∑x | p x = 1 ∧ oddpart x = i. 2 ^ index x) * i"
      by (auto intro: sum.reindex_cong[where l = "index"])
    also have "... = (∑x | p x = 1 ∧ oddpart x = i. 2 ^ index x * oddpart x)"
      by (auto simp add: sum_distrib_right)
    also have "... = (∑x | p x = 1 ∧ oddpart x = i. x)"
       by (simp only: index_oddpart_decomposition[symmetric])
    also have "... ≤ (∑x | p x = 1. x)"
      using set_eq by (intro sum_mono2) auto
    also have "... = (∑x≤n. p x * x)"
      using distinct by (subst set_eq) (force intro!: sum.mono_neutral_cong_left)
    also have "... = n" using sum .
    finally have "(∑j | p (2 ^ j * i) = 1. 2 ^ j) * i ≤ n" .
  }
  from this have less_n: "⋀i. odd_of_distinct p i * i ≤ n"
    unfolding odd_of_distinct_def by auto
  have "odd_of_distinct p partitions n"
  proof (rule partitionsI)
    fix i assume "odd_of_distinct p i ≠ 0"
    from this assms show "1 ≤ i ∧ i ≤ n"
      by (elim CollectE conjE partitionsE odd_of_distinct) auto
  next
    have "(∑i≤n. odd_of_distinct p i * i) = (∑i≤n. distinct_of_odd (odd_of_distinct p) i * i)"
      using assms less_n by (subst sum_distinct_of_odd) (auto elim!: partitionsE odd_of_distinct simp only: odd)
    also have "... = (∑i≤n. p i * i)" using assms
      by (auto elim!: partitionsE simp only:) (subst distinct_odd, auto)
    also with assms have "... = n" by (auto elim: partitionsE)
    finally show "(∑i≤n. odd_of_distinct p i * i) = n" .
  qed
  moreover have "∀i. odd_of_distinct p i ≠ 0 ⟶ odd i"
    by (intro allI impI odd)
  ultimately show "odd_of_distinct p partitions n ∧ (∀i. odd_of_distinct p i ≠ 0 ⟶ odd i)" by simp
qed

subsection ‹Euler's Partition Theorem›

(* Euler's partition theorem: the number of partitions of n into distinct
   parts equals the number of partitions of n into odd parts.
   - odd_of_distinct and distinct_of_odd map each set into the other and
     are mutually inverse on them.
   - Hence both maps are injective, and card_bij_eq applies; both sets are
     finite by finite_partitions. *)
theorem Euler_partition_theorem:
  "card {p. p partitions n ∧ (∀i. p i ≤ 1)} = card {p. p partitions n ∧ (∀i. p i ≠ 0 ⟶ odd i)}"
  (is "card ?distinct_partitions = card ?odd_partitions")
proof (rule card_bij_eq)
  from odd_of_distinct_in_odd_partitions show
    "odd_of_distinct ` ?distinct_partitions ⊆ ?odd_partitions" by auto
  moreover from distinct_of_odd_in_distinct_partitions show
    "distinct_of_odd ` ?odd_partitions ⊆ ?distinct_partitions" by auto
  moreover have "∀p∈?distinct_partitions. distinct_of_odd (odd_of_distinct p) = p"
    by auto (subst distinct_odd; auto simp add: partitions_def)
  moreover have "∀p∈?odd_partitions. odd_of_distinct (distinct_of_odd p) = p"
    by auto (subst odd_distinct; auto simp add: partitions_def)
  ultimately show "inj_on odd_of_distinct ?distinct_partitions"
    "inj_on distinct_of_odd ?odd_partitions"
    by (intro bij_betw_imp_inj_on bij_betw_byWitness; auto)+
  show "finite ?distinct_partitions" "finite ?odd_partitions"
    by (simp add: finite_partitions)+
qed

end