(* Theory Efficient_Discrete_Sqrt *)
theory Efficient_Discrete_Sqrt
imports
  Complex_Main
  "HOL-Computational_Algebra.Computational_Algebra"
  "HOL-Library.Discrete_Functions"
  "HOL-Library.Tree"
  "HOL-Library.IArray"
begin
section ‹Efficient Algorithms for the Square Root on ‹ℕ››
subsection ‹A Discrete Variant of Heron's Algorithm›
text ‹
  An algorithm for calculating the discrete square root, taken from 
  Cohen~\<^cite>‹"cohen2010algebraic"›. This algorithm is essentially a discretised variant of
  Heron's method or Newton's method specialised to the square root function.
›
(* The discrete square root agrees with the floor of the real square root.
   Proof idea: r = nat (floor (sqrt n)) satisfies r^2 <= n < (r + 1)^2, and by
   floor_sqrt_unique that characterises floor_sqrt n. *)
lemma sqrt_eq_floor_sqrt: "floor_sqrt n = nat ⌊sqrt n⌋"
proof -
  (* Lower bound r^2 <= n, argued over the reals: floor (sqrt n) <= sqrt n and
     squaring is monotone on non-negative numbers. *)
  have "real ((nat ⌊sqrt n⌋)⇧2) = (real (nat ⌊sqrt n⌋))⇧2"
    by simp
  also have "… ≤ sqrt (real n) ^ 2"
    by (intro power_mono) auto
  also have "… = real n" by simp
  finally have "(nat ⌊sqrt n⌋)⇧2 ≤ n"
    by (simp only: of_nat_le_iff)
  (* Upper bound n < (r + 1)^2: first shown for the integer 1 + floor (sqrt n),
     using sqrt n < floor (sqrt n) + 1, then transferred to the natural Suc r. *)
  moreover have "n < (Suc (nat ⌊sqrt n⌋))⇧2" proof -
    have "(1 + ⌊sqrt n⌋)⇧2 > n"
      using floor_correct[of "sqrt n"] real_le_rsqrt[of "1 + ⌊sqrt n⌋" n]
        of_int_less_iff[of n "(1 + ⌊sqrt n⌋)⇧2"] not_le
      by fastforce
    then show ?thesis
      using le_nat_floor[of "Suc (nat ⌊sqrt n⌋)" "sqrt n"]
        of_nat_le_iff[of "(Suc (nat ⌊sqrt n⌋))⇧2" n] real_le_rsqrt[of _ n] not_le
      by fastforce
  qed
  ultimately show ?thesis using floor_sqrt_unique by fast
qed
(* Heron/Newton iteration for the integer square root of n.
   From the current estimate x, compute the next estimate y = (x + n div x) div 2.
   If y is strictly smaller than x, continue from y; otherwise stop and return x.
   The recursive call only happens when the estimate strictly decreases. *)
fun newton_sqrt_aux :: "nat ⇒ nat ⇒ nat" where
  "newton_sqrt_aux x n =
     (let y = (x + n div x) div 2
      in if y < x then newton_sqrt_aux y n else x)"
(* The defining equation contains another call of newton_sqrt_aux on its right-hand
   side, so it is removed from the simpset. The conditional rules in
   newton_sqrt_aux_simps are used for controlled unfolding instead. *)
declare newton_sqrt_aux.simps [simp del]
(* Conditional unfolding rules for newton_sqrt_aux:
   - if the Heron step strictly decreases the estimate, recurse on the new estimate;
   - otherwise the iteration stops and returns the current estimate x. *)
lemma newton_sqrt_aux_simps:
  "(x + n div x) div 2 < x ⟹ newton_sqrt_aux x n = newton_sqrt_aux ((x + n div x) div 2) n"
  "(x + n div x) div 2 ≥ x ⟹ newton_sqrt_aux x n = x"
  by (subst newton_sqrt_aux.simps; simp add: Let_def)+
(* A real Heron step never undershoots the square root. For t > 0, the arithmetic
   mean of t and n/t is at least their geometric mean, which is sqrt n
   (AM-GM, via arith_geo_mean_sqrt). *)
lemma heron_step_real: "⟦t > 0; n ≥ 0⟧ ⟹ (t + n/t) / 2 ≥ sqrt n"
  using arith_geo_mean_sqrt[of t "n/t"] by simp
(* For t > 0, the integer Heron step computed with natural-number division equals
   the floor of the corresponding real Heron step. *)
lemma heron_step_div_eq_floored:
  "(t::nat) > 0 ⟹ (t + (n::nat) div t) div 2 = nat ⌊(t + n/t) / 2⌋"
proof -
  assume "t > 0"
  (* Put the real expression over the common denominator 2 * t. *)
  then have "⌊(t + n/t) / 2⌋ = ⌊(t*t + n) / (2*t)⌋"
    by (simp add: mult_divide_mult_cancel_right[of t "t + n/t" 2, symmetric]
        algebra_simps)
  (* The floor of a quotient of naturals is natural-number division. *)
  also have "… = (t*t + n) div (2*t)"
    using floor_divide_of_nat_eq by blast
  (* Dividing by 2 * t is dividing by t and then by 2. *)
  also have "… = (t*t + n) div t div 2"
    by (simp add: div_mult2_eq ac_simps)
  (* For t > 0, (t * t + n) div t = t + n div t. *)
  also have "… = (t + n div t) div 2"
    by (simp add: ‹0 < t› power2_eq_square)
  finally show ?thesis by simp
qed
(* The integer Heron step never drops below floor_sqrt n, for any starting estimate
   t > 0. Proof: move to the reals via sqrt_eq_floor_sqrt, apply heron_step_real,
   and come back via heron_step_div_eq_floored. *)
lemma heron_step: "t > 0 ⟹ (t + n div t) div 2 ≥ floor_sqrt n"
proof -
  assume "t > 0"
  have "floor_sqrt n = nat ⌊sqrt n⌋" by (rule sqrt_eq_floor_sqrt)
  also have "… ≤ nat ⌊(t + n/t) / 2⌋"
    using heron_step_real[of t n] ‹t > 0› by linarith
  also have "… = (t + n div t) div 2"
    using heron_step_div_eq_floored[OF ‹t > 0›] by simp
  finally show ?thesis .
qed
(* Correctness of the iteration: starting from any estimate x >= floor_sqrt n,
   newton_sqrt_aux returns exactly floor_sqrt n.
   The proof is by induction following the recursion of newton_sqrt_aux. *)
lemma newton_sqrt_aux_correct:
  assumes "x ≥ floor_sqrt n"
  shows   "newton_sqrt_aux x n = floor_sqrt n"
  using assms
proof (induction x n rule: newton_sqrt_aux.induct)
  case (1 x n)
  show ?case
  proof (cases "x = floor_sqrt n")
    case True
    (* x is already the answer. Since x^2 <= n we get n div x >= x, so the step
       does not decrease the estimate and the iteration returns x. *)
    then have "(x ^ 2) div x ≤ n div x" by (intro div_le_mono) simp_all
    also have "(x ^ 2) div x = x" by (simp add: power2_eq_square)
    finally have "(x + n div x) div 2 ≥ x" by linarith
    with True show ?thesis by (auto simp: newton_sqrt_aux_simps)
  next
    case False
    (* x overshoots: n < x^2 forces n div x < x, so the step strictly decreases
       the estimate, while heron_step keeps it >= floor_sqrt n. Hence the
       induction hypothesis applies to the new estimate. *)
    with "1.prems" have x_gt_sqrt: "x > floor_sqrt n" by auto
    (* NOTE(review): this fact is not used afterwards; the calculation below
       re-derives n < x^2 directly from le_floor_sqrt_iff. *)
    with le_floor_sqrt_iff[of x n] have "n < x ^ 2" by simp
    have "x * (n div x) ≤ n" using mult_div_mod_eq[of x n] by linarith
    also have "… < x ^ 2" using le_floor_sqrt_iff[of x n] and x_gt_sqrt by simp
    also have "… = x * x" by (simp add: power2_eq_square)
    finally have "n div x < x" by (subst (asm) mult_less_cancel1) auto
    then have step_decreasing: "(x + n div x) div 2 < x" by linarith
    with x_gt_sqrt have step_ge_sqrt: "(x + n div x) div 2 ≥ floor_sqrt n"
      by (simp add: heron_step)
    from step_decreasing have "newton_sqrt_aux x n = newton_sqrt_aux ((x + n div x) div 2) n"
      by (simp add: newton_sqrt_aux_simps)
    also have "… = floor_sqrt n"
      by (intro "1.IH" step_decreasing step_ge_sqrt) simp_all
    finally show ?thesis .
  qed
qed
(* Integer square root by Heron's method, using n itself as the initial estimate.
   Discrete_sqrt_eq_newton_sqrt shows this is a valid start (floor_sqrt n <= n). *)
definition newton_sqrt :: "nat ⇒ nat" where
  "newton_sqrt n = newton_sqrt_aux n n"
(* Code setup: drop the library's code equation for floor_sqrt and register the
   Heron-based implementation as its code equation instead. The proof uses
   newton_sqrt_aux_correct, whose precondition floor_sqrt n <= n is floor_sqrt_le. *)
declare floor_sqrt_code [code del]
theorem Discrete_sqrt_eq_newton_sqrt [code]: "floor_sqrt n = newton_sqrt n"
  unfolding newton_sqrt_def by (simp add: newton_sqrt_aux_correct floor_sqrt_le)
subsection ‹Square Testing›
text ‹
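  Next, we implement an algorithm to determine whether a given natural number is a perfect square,
  as described by Cohen~\<^cite>‹"cohen2010algebraic"›. The algorithm first performs cheap necessary
  checks with precomputed lookup tables. It tests whether the number is a square modulo 64, and
  whether its residue modulo $45045 = 63 \cdot 65 \cdot 11$ is a square modulo 63, 65, and 11.
  Only if all of these tests pass is the floor of the square root computed and squared, to check
  whether it yields the original number.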
›
(* The sets of squares modulo 11, 63, 64 and 65, written out as literal sets.
   q_m is meant to be {k^2 mod m | k}; this is established for each set by the
   lemmas q11_upto_def, q63_upto_def, q64_upto_def and q65_upto_def. *)
definition q11 :: "nat set"
  where "q11 = {0, 1, 3, 4, 5, 9}"
definition q63 :: "nat set"
  where "q63 = {0, 1, 4, 7, 9, 16, 28, 18, 22, 25, 36, 58, 46, 49, 37, 43}"
definition q64 :: "nat set"
  where "q64 = {0, 1, 4, 9, 16, 17, 25, 36, 33, 49, 41, 57}"
definition q65 :: "nat set"
  where "q65 = {0, 1, 4, 10, 14, 9, 16, 26, 30, 25, 29, 40, 56, 36, 49, 61, 35, 51, 39, 55, 64}"
(* Boolean lookup tables for the square sets: entry i of q_m_array is True iff
   i is in q_m, for i < m. Agreement with the sets is proved by the sub_q*_array lemmas.
   NOTE(review): by count, q63_array, q64_array and q65_array each appear to carry
   one trailing False entry at index m; it is never accessed, since every lookup
   index is reduced modulo m. *)
definition q11_array where
  "q11_array = IArray [True,True,False,True,True,True,False,False,False,True,False]"
definition q63_array where
  "q63_array = IArray [True,True,False,False,True,False,False,True,False,True,False,False,
     False,False,False,False,True,False,True,False,False,False,True,False,False,True,False,
     False,True,False,False,False,False,False,False,False,True,True,False,False,False,False,
     False,True,False,False,True,False,False,True,False,False,False,False,False,False,False,
     False,True,False,False,False,False,False]"
definition q64_array where
  "q64_array = IArray [True,True,False,False,True,False,False,False,False,True,False,False,
     False,False,False,False,True,True,False,False,False,False,False,False,False,True,False,
     False,False,False,False,False,False,True,False,False,True,False,False,False,False,True,
     False,False,False,False,False,False,False,True,False,False,False,False,False,False,
     False,True,False,False,False,False,False,False, False]"
definition q65_array where
  "q65_array = IArray [True,True,False,False,True,False,False,False,False,True,True,False,
     False,False,True,False,True,False,False,False,False,False,False,False,False,True,True,
     False,False,True,True,False,False,False,False,True,True,False,False,True,True,False,
     False,False,False,False,False,False,False,True,False,True,False,False,False,True,True
     ,False,False,False,False,True,False,False,True,False]"
(* For every index below the modulus, the lookup table entry coincides with
   membership in the corresponding set. Each proof enumerates all indices
   i < m and evaluates both sides. *)
lemma sub_q11_array: "i ∈ {..<11} ⟹ IArray.sub q11_array i ⟷ i ∈ q11"
  by (simp add: lessThan_nat_numeral lessThan_Suc q11_def q11_array_def, elim disjE; simp)
lemma sub_q63_array: "i ∈ {..<63} ⟹ IArray.sub q63_array i ⟷ i ∈ q63"
  by (simp add: lessThan_nat_numeral lessThan_Suc q63_def q63_array_def, elim disjE; simp)
lemma sub_q64_array: "i ∈ {..<64} ⟹ IArray.sub q64_array i ⟷ i ∈ q64"
  by (simp add: lessThan_nat_numeral lessThan_Suc q64_def q64_array_def, elim disjE; simp)
lemma sub_q65_array: "i ∈ {..<65} ⟹ IArray.sub q65_array i ⟷ i ∈ q65"
  by (simp add: lessThan_nat_numeral lessThan_Suc q65_def q65_array_def, elim disjE; simp)
(* Executable membership tests: x mod m is always below m, so its membership in
   q_m can be decided by an array lookup (via the sub_q*_array lemmas). *)
lemma in_q11_code: "x mod 11 ∈ q11 ⟷ IArray.sub q11_array (x mod 11)"
  by (subst sub_q11_array) auto
lemma in_q63_code: "x mod 63 ∈ q63 ⟷ IArray.sub q63_array (x mod 63)"
  by (subst sub_q63_array) auto
lemma in_q64_code: "x mod 64 ∈ q64 ⟷ IArray.sub q64_array (x mod 64)"
  by (subst sub_q64_array) auto
lemma in_q65_code: "x mod 65 ∈ q65 ⟷ IArray.sub q65_array (x mod 65)"
  by (subst sub_q65_array) auto
(* Perfect-square test.
   - Cheap necessary conditions come first: n mod 64 must be a square modulo 64.
   - The residue r = n mod 45045 must be a square modulo 63, 65 and 11. These
     moduli divide 45045, so r has the same residues as n.
   - The exact check n = (floor_sqrt n)^2 is the last conjunct.
   Correctness is square_test_correct; the executable form is square_test_code. *)
definition square_test :: "nat ⇒ bool" where
  "square_test n =
    (n mod 64 ∈ q64 ∧ (let r = n mod 45045 in
      r mod 63 ∈ q63 ∧ r mod 65 ∈ q65 ∧ r mod 11 ∈ q11 ∧ n = (floor_sqrt n)⇧2))"
(* Code equation for square_test: identical structure, but every set-membership
   test is replaced by the corresponding array lookup (in_q*_code). *)
lemma square_test_code [code]:
  "square_test n =
    (IArray.sub q64_array (n mod 64) ∧ (let r = n mod 45045 in
           IArray.sub q63_array (r mod 63) ∧ 
           IArray.sub q65_array (r mod 65) ∧
           IArray.sub q11_array (r mod 11) ∧ n = (floor_sqrt n)⇧2))"
    using in_q11_code [symmetric] in_q63_code [symmetric] 
          in_q64_code [symmetric] in_q65_code [symmetric]
  by (simp add: Let_def square_test_def)
(* Every square residue modulo m > 0 is already attained by a base below m
   (namely q mod m, since q^2 mod m = (q mod m)^2 mod m). *)
lemma square_mod_lower: "m > 0 ⟹ (q⇧2 :: nat) mod m = a ⟹ ∃q' < m. q'⇧2 mod m = a"
  using mod_less_divisor mod_mod_trivial power_mod by blast
(* q11 is exactly the set of squares modulo 11. First over bases k < 11, checked by
   evaluation; then over all naturals, since larger bases add no new residues
   (square_mod_lower). *)
lemma q11_upto_def: "q11 = (λk. k⇧2 mod 11) ` {..<11}"
  by (simp add: q11_def lessThan_nat_numeral lessThan_Suc insert_commute)
lemma q11_infinite_def: "q11 = (λk. k⇧2 mod 11) ` {0..}"
  unfolding q11_upto_def image_def proof (auto, goal_cases)
  case (1 xa)
  show ?case
    using square_mod_lower[of 11 xa "xa⇧2 mod 11"]
      ex_nat_less_eq[of 11 "λx. xa⇧2 mod 11 = x⇧2 mod 11"]
    by auto
qed
(* q63 is exactly the set of squares modulo 63. First over bases k < 63, checked by
   evaluation; then over all naturals, via square_mod_lower. *)
lemma q63_upto_def: "q63 = (λk. k⇧2 mod 63) ` {..<63}"
  by (simp add: q63_def lessThan_nat_numeral lessThan_Suc insert_commute)
lemma q63_infinite_def: "q63 = (λk. k⇧2 mod 63) ` {0..}"
  unfolding q63_upto_def image_def proof (auto, goal_cases)
  case (1 xa)
  show ?case
    using square_mod_lower[of 63 xa "xa⇧2 mod 63"]
      ex_nat_less_eq[of 63 "λx. xa⇧2 mod 63 = x⇧2 mod 63"]
    by auto
qed
(* q64 is exactly the set of squares modulo 64. First over bases k < 64, checked by
   evaluation; then over all naturals, via square_mod_lower. *)
lemma q64_upto_def: "q64 = (λk. k⇧2 mod 64) ` {..<64}"
  by (simp add: q64_def lessThan_nat_numeral lessThan_Suc insert_commute)
lemma q64_infinite_def: "q64 = (λk. k⇧2 mod 64) ` {0..}"
  unfolding q64_upto_def image_def proof (auto, goal_cases)
  case (1 xa)
  show ?case
    using square_mod_lower[of 64 xa "xa⇧2 mod 64"]
      ex_nat_less_eq[of 64 "λx. xa⇧2 mod 64 = x⇧2 mod 64"]
    by auto
qed
(* q65 is exactly the set of squares modulo 65. First over bases k < 65, checked by
   evaluation; then over all naturals, via square_mod_lower. *)
lemma q65_upto_def: "q65 = (λk. k⇧2 mod 65) ` {..<65}"
  by (simp add: q65_def lessThan_nat_numeral lessThan_Suc insert_commute)
lemma q65_infinite_def: "q65 = (λk. k⇧2 mod 65) ` {0..}"
  unfolding q65_upto_def image_def proof (auto, goal_cases)
  case (1 xa)
  show ?case
    using square_mod_lower[of 65 xa "xa⇧2 mod 65"]
      ex_nat_less_eq[of 65 "λx. xa⇧2 mod 65 = x⇧2 mod 65"]
    by auto
qed
(* If n is a square q^2, then for every modulus k the residue n mod k is the
   residue of some square (witnessed by q itself). *)
lemma square_mod_existence:
  fixes n k :: nat
  assumes "∃q. q⇧2 = n"
  shows "∃q. n mod k = q⇧2 mod k"
  using assms by auto
(* Correctness of square_test: it holds exactly for perfect squares. *)
theorem square_test_correct: "square_test n ⟷ is_square n"
proof cases
  (* Case n is a square: every residue filter passes, because n's residues are
     residues of a square, and the final check n = (floor_sqrt n)^2 holds as well. *)
  assume "is_square n"
  hence  rhs: "∃q. q⇧2 = n" by (auto elim: is_nth_powerE)
  note sq_mod = square_mod_existence[OF this]
  have q64_member: "n mod 64 ∈ q64" using sq_mod[of 64]
    unfolding q64_infinite_def image_def by simp
  let ?r = "n mod 45045"
  (* 11, 63 and 65 divide 45045, so reducing modulo 45045 first does not change
     the residues modulo 11, 63 and 65. *)
  have "11 dvd (45045::nat)" "63 dvd (45045::nat)" "65 dvd (45045::nat)" by force+
  then have mod_45045: "?r mod 11 = n mod 11" "?r mod 63 = n mod 63" "?r mod 65 = n mod 65"
    using mod_mod_cancel[of _ 45045 n] by presburger+
  then have "?r mod 11 ∈ q11" "?r mod 63 ∈ q63" "?r mod 65 ∈ q65"
    using sq_mod[of 11] sq_mod[of 63] sq_mod[of 65]
    unfolding q11_infinite_def q63_infinite_def q65_infinite_def image_def mod_45045
    by fast+
  then show ?thesis unfolding square_test_def Let_def using q64_member rhs by auto
next
  (* Case n is not a square: the final conjunct n = (floor_sqrt n)^2 fails,
     so square_test n is false regardless of the residue filters. *)
  assume not_rhs: "¬is_square n"
  hence "∄q. q⇧2 = n" by auto
  then have "(floor_sqrt n)⇧2 ≠ n" by simp
  then show ?thesis unfolding square_test_def by (auto simp: is_nth_power_def)
qed
(* Exact square root, if it exists: Some (floor_sqrt n) when n is a perfect square,
   None otherwise. *)
definition get_nat_sqrt :: "nat ⇒ nat option" 
  where "get_nat_sqrt n = (if is_square n then Some (floor_sqrt n) else None)"
(* Code equation for get_nat_sqrt. It applies the same table-based residue filters as
   square_test_code, then computes floor_sqrt n once (bound to x). The result is
   Some x if x^2 = n, and None otherwise. Justified via square_test_correct. *)
lemma get_nat_sqrt_code [code]:
  "get_nat_sqrt n = 
    (if IArray.sub q64_array (n mod 64) ∧ (let r = n mod 45045 in
           IArray.sub q63_array (r mod 63) ∧ 
           IArray.sub q65_array (r mod 65) ∧
           IArray.sub q11_array (r mod 11)) then
       (let x = floor_sqrt n in if x⇧2 = n then Some x else None) else None)"
  unfolding get_nat_sqrt_def square_test_correct [symmetric] square_test_def
  using in_q11_code [symmetric] in_q63_code [symmetric] 
        in_q64_code [symmetric] in_q65_code [symmetric]
  by (auto split: if_splits simp: Let_def )
end