(* Theory Masking *)

theory Masking
  imports Complex_Main Bit_Counting Utils
begin

subsection ‹Masking›

(* Short names for the bit-counting functions (presumably provided by the
   imported theory Bit_Counting): σ abbreviates count_bits and τ abbreviates
   count_carries. Only τ is used by the lemmas of this theory. *)
abbreviation σ where "σ ≡ count_bits"
abbreviation τ where "τ ≡ count_carries"

(* Setting of the masking lemma.
   ℬ is the base in which numbers are split into digits (nth_digit _ _ ℬ below),
   and b ≤ ℬ bounds the admissible digits; both are powers of two.
   δ > 0 fixes the digit positions n i = (δ+1)^i, and the bound C lies between
   b * ℬ^((δ+1)^ν) and ℬ * ℬ^((δ+1)^ν). *)
locale masking_lemma =
  fixes δ ν ℬ b C :: nat 
  assumes δ_pos: "δ > 0"
      and ℬ_power2: "is_power2 (int ℬ)" 
      and b_power2: "is_power2 (int b)"
      and ℬ_ge_2: "2 ≤ ℬ" 
      and b_le_ℬ: "b ≤ ℬ"
      and C_lower_bound: "C ≥ b * ℬ^(δ+1)^ν"
      and C_upper_bound: "C ≤ ℬ * ℬ^(δ+1)^ν"
begin

(* Digit positions: the i-th masked digit sits at position (δ+1)^i. *)
definition n :: "nat ⇒ nat" where "n j = (δ+1)^j"

(* Digit j of the mask M: ℬ - b at the positions n 1, …, n ν and ℬ - 1 at every
   other position. *)
definition m::"nat ⇒ nat" where
"m j = (if j ∈ n ` {1..ν} then ℬ - b else ℬ - 1)"

text ‹The mask M = mask(b, ℬ, δ, ν): the base-ℬ number whose digit at position j is m j, for j = 0, …, n ν.›
(* The mask: the sum of m j * ℬ^j over the positions j = 0, …, n ν. *)
definition M::nat where "M = (∑j=0..n ν. m j * ℬ^j)" 

(* b is a power of two and therefore at least 1. *)
lemma b_ge_1: "1 ≤ b"
  using b_power2 is_power2_ge1 by force
    
(* b and ℬ are powers of two with b ≤ ℬ, so b divides ℬ. *)
lemma b_dvd_ℬ: "b dvd ℬ" 
  using b_le_ℬ b_power2 ℬ_power2 is_power2_def le_imp_power_dvd by fastforce   

(* The position map n j = (δ+1)^j is injective on every set A; the assumption
   δ > 0 makes the base δ+1 at least 2. *)
lemma n_inj_on: "inj_on n A"
  unfolding n_def inj_on_def by (simp add: δ_pos)

(* Forward direction of the masking lemma, part 1.
   If g = ∑i=1..ν. z i * ℬ^(n i) with every z i < b, then g < C.
   Roadmap:
   - Case b = 1: every z i is 0, so g = 0.
   - Otherwise: bound every digit by b - 1 and evaluate the geometric sum,
     which gives g < b * ℬ^(n ν) ≤ C. *)
lemma direct_g_bound:
  assumes z_bound: "∀i. z i < b"
      and g_code: "g = (∑i=1..ν. z i * ℬ^(n i))"
  shows "g < C" 
proof (cases "b > 1")
  case False 
  (* b = 1 forces every z i to be 0, hence g = 0 < C *)
  hence "b = 1" using b_ge_1 by simp
  hence "g = 0" using z_bound g_code by simp
  thus ?thesis using C_lower_bound b_ge_1 ℬ_ge_2 using zero_less_iff_neq_zero by fastforce
next
  case True

  (* z' spreads z over the digit positions: z' (n i) = z i, and 0 elsewhere *)
  define z' where "z' = (λj. if j ∈ range n then z (inv n j) else 0)"

  have z'_bound: "z' j ≤ b - 1" for j
    unfolding z'_def using z_bound b_ge_1 less_Suc_eq_le by fastforce

  (* Reindex over the positions n ` {1..ν}, enlarge to all positions 0..n ν,
     and bound every digit by b - 1. *)
  have "g = (∑i∈{1..ν}. z' (n i) * ℬ^(n i))"
    using sum.cong z'_def n_inj_on by (simp add: g_code)
  also have "... = (∑j∈n`{1..ν}. z' j * ℬ^j)"
    using n_inj_on[of "{1..ν}"] by (simp add: sum.reindex_cong)
  also have "... ≤ (∑j∈{0..n ν}. z' j * ℬ^j)"
  proof -
    have "n ` {1..ν} ⊆ {0..n ν}" 
      unfolding n_def
      by (metis (no_types, opaque_lifting) One_nat_def Suc_eq_plus1 Suc_le_mono atLeast0AtMost 
          atLeastAtMost_iff image_subset_iff power_increasing zero_le)
    thus ?thesis using sum_mono2 by blast
  qed
  also have "... ≤ (∑j=0..n ν. (b - 1) * ℬ^j)"
    using z'_bound sum_mono by (metis (no_types, lifting) mult_le_cancel2)
  finally have "g ≤ (b-1) * (∑j<(n ν)+1. ℬ^j)"
    using sum_distrib_left by (metis Suc_eq_plus1 atMost_atLeast0 lessThan_Suc_atMost)
  
  (* Multiply by ℬ - 1 to close the geometric sum, then use b ≤ ℬ. *)
  hence "(ℬ-1) * g ≤ (ℬ-1) * (∑j<(n ν)+1. ℬ^j) * (b-1)"
    by (simp add: mult.commute)
  also have "... = (ℬ^(n ν + 1) - 1) * (b-1)"
    using aux_geometric_sum ℬ_ge_2
    by (metis le_imp_less_Suc nat_add_left_cancel_less one_add_one plus_1_eq_Suc)
  also have "... < (b-1) * ℬ * ℬ^(n ν)"
    using True ℬ_ge_2 by auto
  also have "... ≤ (ℬ-1) * b * ℬ^(n ν)"
    using b_le_ℬ ℬ_ge_2
    by (smt (verit) diff_le_mono2 diff_mult_distrib mult.commute mult_le_mono1)
  finally have "g < b * ℬ^(n ν)" 
    by auto
  (* C ≥ b * ℬ^(n ν) by C_lower_bound, since n ν = (δ+1)^ν *)
  thus ?thesis
    using C_lower_bound n_def by simp
qed

(* Forward direction of the masking lemma, part 2.
   If g = ∑i=1..ν. z i * ℬ^(n i) with every z i < b, then τ g M = 0
   (τ = count_carries). The reason: digit-wise, z' j + m j < ℬ, and no single
   digit pair produces a carry. *)
lemma direct_tau_zero:
  assumes z_bound: "∀i. z i < b"
      and g_code: "g = (∑i=1..ν. z i * ℬ^(n i))"
  shows "τ g M = 0"
proof -
  (* z' spreads z over the digit positions: z' (n i) = z i, and 0 elsewhere *)
  define z' where "z' = (λj. if j ∈ n ` {1..ν} then z (inv n j) else 0)"

  (* Rewrite g as a digit sum over all positions 0..n ν, as M is. *)
  have "g = (∑i∈{1..ν}. z' (n i) * ℬ^(n i))"
    using sum.cong z'_def n_inj_on by (simp add: g_code)
  also have "... = (∑j∈n`{1..ν}. z' j * ℬ^j)"
    using n_inj_on[of "{1..ν}"] by (simp add: sum.reindex_cong)
  also have "... = (∑j∈{0..n ν}. z' j * ℬ^j)"
  proof -
    have "n ` {1..ν} ⊆ {0..n ν}"
      unfolding n_def
      by (metis (no_types, opaque_lifting) One_nat_def Suc_eq_plus1 Suc_le_mono atLeast0AtMost 
          atLeastAtMost_iff image_subset_iff power_increasing zero_le)
    (* the added positions all carry the digit 0 *)
    moreover have "j ∈ {0..n ν} - n ` {1..ν} ⟹ z' j * ℬ^j = 0" for j
      unfolding z'_def by auto
    ultimately show ?thesis
      using sum.mono_neutral_left by (simp add: sum.subset_diff)
  qed
  finally have g_sum: "g = (∑j<n ν + 1. z' j * ℬ^j)"
    using Suc_eq_plus1 atLeast0AtMost lessThan_Suc_atMost by presburger
  
  (* Each digit of g plus the matching mask digit stays below the base ℬ. *)
  have zm_bound: "z' j + m j < ℬ" for j
    unfolding z'_def m_def using b_le_ℬ z_bound ℬ_ge_2
    by (smt (verit, best) add_diff_inverse_nat add_le_cancel_right b_ge_1 dual_order.strict_trans2 
        linorder_not_less zero_less_one)
 
  (* No single digit pair produces a carry. *)
  have tau_digits: "τ (z' j) (m j) = 0" for j
  proof (cases "j ∈ n ` {1.. ν}")
    case False thus ?thesis unfolding z'_def by simp
  next
    obtain k where k_def: "b = 2^k" using b_power2 is_power2_def by auto

    case True
    (* The mask digit ℬ - b equals (ℬ div b - 1) * b because b divides ℬ.
       The digit z (inv n j) is below b = 2^k, and
       count_carries_add_shift_no_overflow splits the count into two zero parts. *)
    hence "τ (z' j) (m j) = τ (z (inv n j)) (ℬ - b)"
      using z'_def m_def by simp
    also have "... = τ (z (inv n j) + 0 * b) (0 + (ℬ div b - 1) * b)"
      using b_dvd_ℬ by (simp add: mult.commute right_diff_distrib')
    also have "... = τ (z (inv n j)) 0 + τ 0 (ℬ div b - 1)"
      using count_carries_add_shift_no_overflow k_def z_bound
      by (metis (no_types, lifting) add.right_neutral)
    also have "... = 0" 
      by simp
    finally show ?thesis .
  qed

  (* ℬ = 2^k with k ≥ 1; this is what count_carries_digitwise_no_overflow requires. *)
  obtain k where k_def: "ℬ = 2^k"
    using ℬ_power2 is_power2_def by auto
  have k_ge_1: "1 ≤ k"
    using k_def ℬ_ge_2
    by (metis Suc_eq_plus1 Suc_leD antisym div_le_dividend div_self nat_1_add_1 numeral_eq_one_iff 
        order_refl power_0 semiring_norm(85))

  (* τ g M is the sum of the digit-wise carry counts, each of which is 0. *)
  have "τ g M = τ (∑j<n ν + 1. z' j * ℬ^j) (∑j<n ν + 1. m j * ℬ^j)"
    using g_sum M_def Suc_eq_plus1 atLeast0AtMost lessThan_Suc_atMost by presburger
  also have "... = (∑j<n ν + 1. τ (z' j) (m j))"
    using count_carries_digitwise_no_overflow[OF k_ge_1] k_def zm_bound by blast
  also have "... = 0"
    using tau_digits by simp
  finally show ?thesis .
qed

(* Backward direction of the masking lemma.
   If τ g M = 0 and g < C, then g = ∑i=1..ν. z i * ℬ^(n i) with every z i < b.
   The witness z takes, for each i, the base-ℬ digit of g at position n i. *)
lemma reverse_impl:
  assumes tau_zero: "τ g M = 0"
      and g_bound_C: "g < C"
  shows "∃z. (∀i. z i < b) ∧ g = (∑i=1..ν. z i * ℬ^(n i))"
proof -
  (* g has at most n ν + 1 base-ℬ digits *)
  have g_bound_ℬ: "g < ℬ^(n ν + 1)"
    using C_upper_bound g_bound_C n_def by simp
 
  have g_digit: "j > n ν ⟹ nth_digit g j ℬ = 0" for j
    using nth_digit_def ℬ_ge_2 g_bound_ℬ
    by (metis Suc_eq_plus1 div_less linorder_not_less mod_less nat_1_add_1 not_less_eq_eq 
        order_less_le_trans power_increasing_iff zero_less_two)
  have g_sum: "g = (∑j<n ν + 1. nth_digit g j ℬ * ℬ^j)"
    using digit_gen_sum_repr[OF g_bound_ℬ] ℬ_ge_2 by auto

  (* The carry count of a single digit pair is at most τ g M = 0. Hence each
     digit of g plus the matching mask digit stays below ℬ. *)
  have digit_bound: "j ∈ {0..n ν} ⟹ nth_digit g j ℬ + m j < ℬ" for j
  proof -
    obtain k where k_def: "ℬ = 2^k"
      using ℬ_power2 is_power2_def by auto
    have k_ge_1: "1 ≤ k"
      using k_def ℬ_ge_2
      by (metis Suc_eq_plus1 Suc_leD antisym div_le_dividend div_self nat_1_add_1 numeral_eq_one_iff 
          order_refl power_0 semiring_norm(85))
    assume "j ∈ {0..n ν}"
    hence "j < n ν + 1" 
      by auto
    moreover have 0: "m j < ℬ ∧ nth_digit g j ℬ < ℬ" for j
      unfolding m_def nth_digit_def using ℬ_ge_2 b_ge_1 by fastforce
    ultimately have "τ (∑j<n ν + 1. nth_digit g j ℬ * ℬ^j) (∑j<n ν + 1. m j * ℬ^j) ≥ 
      τ (nth_digit g j ℬ) (m j)"
      using count_carries_digitwise_specific[OF k_ge_1 _ `j < n ν + 1`] k_def by force
    hence "τ (nth_digit g j ℬ) (m j) = 0"
      using tau_zero g_sum M_def
      by (metis Suc_eq_plus1 atLeast0AtMost le_zero_eq lessThan_Suc_atMost)
    thus ?thesis
      using no_carry_no_overflow k_def 0 by auto
  qed
 
  (* the witness: digits of g at the positions n i *)
  define z where "z = (λj. nth_digit g (n j) ℬ)"
 
  (* Off the positions n ` {1..ν}, the mask digit is ℬ - 1. Only the digit 0
     fits next to it without overflow. *)
  have digit_zero: "j ∉ n ` {1..ν} ⟹ nth_digit g j ℬ = 0" for j
  proof (cases "j ∈ {0..n ν}")
    case False thus ?thesis using g_digit by auto
  next
    case True
    assume assm: "j ∉ n ` {1..ν}"
    have "nth_digit g j ℬ + m j < ℬ"
      using True digit_bound by simp
    hence "nth_digit g j ℬ + (ℬ - 1) < ℬ"
      using m_def assm by auto
    thus ?thesis
      using nth_digit_def by simp
  qed

  (* At a position n i, the mask digit is ℬ - b, so that digit of g is below b. *)
  have z_bound: "z i < b" for i
  proof (cases "i ∈ {1..ν}")
    case False
    hence "(n i) ∉ n ` {1..ν}"
      using n_inj_on by blast
    hence "nth_digit g (n i) ℬ = 0"
      unfolding z_def using digit_zero by auto
    thus ?thesis 
      using b_ge_1 z_def by auto
  next
    case True
    hence "n i ∈ {0..n ν}"
      unfolding n_def by (simp add: power_increasing)
    hence "z i + m (n i) < ℬ"
      using digit_bound z_def by auto
    hence "z i + (ℬ - b) < ℬ"
      using m_def by (metis True imageI)
    thus "z i < b" 
      using b_le_ℬ by auto
  qed

  (* Drop the zero digits and reindex the remaining ones by i ∈ {1..ν}. *)
  have g_sum_z: "g = (∑i=1..ν. z i * ℬ^(n i))" 
  proof -
    have "g = (∑j∈{0..n ν}. nth_digit g j ℬ * ℬ^j)"
      using g_sum by (metis Suc_eq_plus1 atLeast0AtMost lessThan_Suc_atMost)
    also have "... = (∑j∈n`{1..ν}. nth_digit g j ℬ * ℬ^j)"
    proof -
      have "n ` {1..ν} ⊆ {0..n ν}"
        unfolding n_def
        by (metis (no_types, opaque_lifting) One_nat_def Suc_eq_plus1 Suc_le_mono atLeast0AtMost 
            atLeastAtMost_iff image_subset_iff power_increasing zero_le)
      moreover have "j ∈ {0..n ν} - n ` {1..ν} ⟹ nth_digit g j ℬ * ℬ^j = 0" for j
        using digit_zero by auto
      ultimately show ?thesis 
        using sum.mono_neutral_right by (metis (no_types, lifting) finite_atLeastAtMost)
    qed
    also have "... = (∑i∈{1..ν}. nth_digit g (n i) ℬ * ℬ^(n i))"
      using n_inj_on sum.reindex by auto
    finally show ?thesis 
      using z_def by simp
  qed

  show ?thesis
    using z_bound g_sum_z by auto 
qed

(* The masking lemma: g is of the form ∑i=1..ν. z i * ℬ^(n i) with all z i < b
   exactly when g < C and τ g M = 0. *)
lemma masking_lemma:
  "(∃z. (∀i. z i < b) ∧ g = (∑i=1..ν. z i * ℬ^(n i))) ⟷ (g < C ∧ τ g M = 0)"
  using reverse_impl direct_tau_zero direct_g_bound by auto

end

end