(* Theory Word_Lib.Word_Lemmas *)
section "Lemmas with Generic Word Length"
theory Word_Lemmas
  imports
    Type_Syntax
    Signed_Division_Word
    Signed_Words
    More_Word
    Most_significant_bit
    Enumeration_Word
    Aligned
    Bit_Shifts_Infix_Syntax
    Boolean_Inequalities
    Word_EqI
begin
context
  includes bit_operations_syntax
begin
(* Order-theoretic bounds relating the bitwise operations to max/min on words. *)
(* The maximum of two words is bounded by their bitwise OR. *)
lemma word_max_le_or:
  "max x y ≤ x OR y" for x :: "'a::len word"
  by (simp add: word_bool_le_funs)
(* Bitwise AND is bounded by the minimum of its operands. *)
lemma word_and_le_min:
  "x AND y ≤ min x y" for x :: "'a::len word"
  by (simp add: word_bool_le_funs)
(* Complement reverses the word order: NOT x ≤ y holds iff NOT y ≤ x. *)
lemma word_not_le_eq:
  "(NOT x ≤ y) = (NOT y ≤ x)" for x :: "'a::len word"
  by transfer (auto simp: take_bit_not_eq_mask_diff)
(* Simp rule: comparing two complements is the reversed comparison of the originals. *)
lemma word_not_le_not_eq[simp]:
  "(NOT y ≤ NOT x) = (x ≤ y)" for x :: "'a::len word"
  by (subst word_not_le_eq) simp
(* Complement turns min into max of the complements. *)
lemma not_min_eq:
  "NOT (min x y) = max (NOT x) (NOT y)" for x :: "'a::len word"
  unfolding min_def max_def
  by auto
(* Complement turns max into min of the complements. *)
lemma not_max_eq:
  "NOT (max x y) = min (NOT x) (NOT y)" for x :: "'a::len word"
  unfolding min_def max_def
  by auto
(* An unsigned cast to a type 'b of width n preserves and reflects the order
   of words that are both below 2 ^ n. *)
lemma ucast_le_ucast_eq:
  fixes x y :: "'a::len word"
  assumes x: "x < 2 ^ n"
  assumes y: "y < 2 ^ n"
  assumes n: "n = LENGTH('b::len)"
  shows "(UCAST('a → 'b) x ≤ UCAST('a → 'b) y) ⟷ (x ≤ y)" (is "?L=_")
proof
  assume ?L then show "x ≤ y"
    by (metis x less_mask_eq le_ucast_ucast_le n ucast_ucast_mask)
qed (use n ucast_mono_le y in auto)
(* If the cast of w to 'b is zero, then w is n-aligned for every n up to
   LENGTH('b): its bits below n are all clear. *)
lemma ucast_zero_is_aligned:
  ‹is_aligned w n› if ‹UCAST('a::len → 'b::len) w = 0› ‹n ≤ LENGTH('b)›
proof (rule is_aligned_bitI)
  fix q
  assume ‹q < n›
  moreover have ‹bit (UCAST('a::len → 'b::len) w) q = bit 0 q›
    using that by simp
  with ‹q < n› ‹n ≤ LENGTH('b)› show ‹¬ bit w q›
    by (simp add: bit_simps)
qed
(* The natural-number value of an unsigned cast equals that of the source
   masked down to the target width. *)
lemma unat_ucast_eq_unat_and_mask:
  "unat (UCAST('b::len → 'a::len) w) = unat (w AND mask LENGTH('a))"
  by (metis take_bit_eq_mask unsigned_take_bit_eq unsigned_ucast_eq)
(* Casting x down to 'b and back up to 'a is the identity whenever x is at most
   the image in 'a of the all-ones word of type 'b. *)
lemma le_max_word_ucast_id:
  ‹UCAST('b → 'a) (UCAST('a → 'b) x) = x›
    if ‹x ≤ UCAST('b::len → 'a) (- 1)›
    for x :: ‹'a::len word›
proof -
  (* Restate the hypothesis with the all-ones 'b word written as 2 ^ LENGTH('b) - 1. *)
  from that have a1: ‹x ≤ word_of_int (uint (word_of_int (2 ^ LENGTH('b) - 1) :: 'b word))›
    by (simp add: of_int_mask_eq)
  (* f2 to f7 are low-level integer facts; NOTE(review): they look machine-generated.
     Together they establish f4, i.e. that 2 ^ LENGTH('b) - 1 is unchanged by
     reduction modulo 2 ^ LENGTH('b). *)
  have f2: "((∃i ia. (0::int) ≤ i ∧ ¬ 0 ≤ i + - 1 * ia ∧ i mod ia ≠ i) ∨
            ¬ (0::int) ≤ - 1 + 2 ^ LENGTH('b) ∨ (0::int) ≤ - 1 + 2 ^ LENGTH('b) + - 1 * 2 ^ LENGTH('b) ∨
            (- (1::int) + 2 ^ LENGTH('b)) mod 2 ^ LENGTH('b) =
              - 1 + 2 ^ LENGTH('b)) = ((∃i ia. (0::int) ≤ i ∧ ¬ 0 ≤ i + - 1 * ia ∧ i mod ia ≠ i) ∨
            ¬ (1::int) ≤ 2 ^ LENGTH('b) ∨
            2 ^ LENGTH('b) + - (1::int) * ((- 1 + 2 ^ LENGTH('b)) mod 2 ^ LENGTH('b)) = 1)"
    by force
  have f3: "∀i ia. ¬ (0::int) ≤ i ∨ 0 ≤ i + - 1 * ia ∨ i mod ia = i"
    using mod_pos_pos_trivial by force
  have "(1::int) ≤ 2 ^ LENGTH('b)"
    by simp
  then have "2 ^ LENGTH('b) + - (1::int) * ((- 1 + 2 ^ LENGTH('b)) mod 2 ^ LENGTH('b)) = 1"
    using f3 f2 by blast
  then have f4: "- (1::int) + 2 ^ LENGTH('b) = (- 1 + 2 ^ LENGTH('b)) mod 2 ^ LENGTH('b)"
    by linarith
  have f5: "x ≤ word_of_int (uint (word_of_int (- 1 + 2 ^ LENGTH('b))::'b word))"
    using a1 by force
  have f6: "2 ^ LENGTH('b) + - (1::int) = - 1 + 2 ^ LENGTH('b)"
    by force
  have f7: "- (1::int) * 1 = - 1"
    by auto
  have "∀x0 x1. (x1::int) - x0 = x1 + - 1 * x0"
    by force
  (* Main intermediate bound: x ≤ 2 ^ LENGTH('b) - 1 as a word of type 'a. *)
  then have §: "x ≤ 2 ^ LENGTH('b) - 1"
    using f7 f6 f5 f4 by (metis uint_word_of_int wi_homs(2) word_arith_wis(8) word_of_int_2p)
  then have ‹uint x ≤ uint (2 ^ LENGTH('b) - (1 :: 'a word))›
    by (simp add: word_le_def)
  then have ‹uint x ≤ 2 ^ LENGTH('b) - 1›
    by (simp add: uint_word_ariths)
      (metis ‹1 ≤ 2 ^ LENGTH('b)› ‹uint x ≤ uint (2 ^ LENGTH('b) - 1)› linorder_not_less lt2p_lem uint_1 uint_minus_simple_alt uint_power_lower word_le_def zle_diff1_eq)
  (* Close the goal from the bound using ucast_ucast_mask and and_mask_eq_iff_le_mask. *)
  then show ?thesis
    by (metis § and_mask_eq_iff_le_mask mask_eq_exp_minus_1 ucast_ucast_mask)
qed
(* Bit-level and arithmetic characterisations of the shift operators on words. *)
(* Logical right shift by n divides the unsigned value by 2 ^ n. *)
lemma uint_shiftr_eq:
  ‹uint (w >> n) = uint w div 2 ^ n›
  by word_eqI
(* Bit n of w << m is bit n - m of w, provided m ≤ n and n is inside the word. *)
lemma bit_shiftl_word_iff [bit_simps]:
  ‹bit (w << m) n ⟷ m ≤ n ∧ n < LENGTH('a) ∧ bit w (n - m)›
  for w :: ‹'a::len word›
  by (simp add: bit_simps)
(* Bit n of w >> m is bit m + n of w. *)
lemma bit_shiftr_word_iff:
  ‹bit (w >> m) n ⟷ bit w (m + n)›
  for w :: ‹'a::len word›
  by (simp add: bit_simps)
(* Arithmetic right shift as signed division by 2 ^ n, truncated to the word width. *)
lemma uint_sshiftr_eq:
  ‹uint (w >>> n) = take_bit LENGTH('a) (sint w div 2 ^  n)›
  for w :: ‹'a::len word›
  by (word_eqI_solve dest: test_bit_lenD)
(* Arithmetic right shift of the all-ones word -1 is again -1. *)
lemma sshiftr_n1: "-1 >>> n = -1"
  by (simp add: sshiftr_def)
(* Bits of an arithmetic right shift: positions shifted in from above the top
   are copies of the sign bit, bit (size w - 1). *)
lemma nth_sshiftr:
  "bit (w >>> m) n = (n < size w ∧ (if n + m ≥ size w then bit w (size w - 1) else bit w (n + m)))"
  by (simp add: add.commute bit_sshiftr_iff test_bit_over wsst_TYs(3))
(* Evaluation rule for the arithmetic right shift of numerals. *)
lemma sshiftr_numeral:
  ‹(numeral k >>> numeral n :: 'a::len word) =
    word_of_int (signed_take_bit (LENGTH('a) - 1) (numeral k) >> numeral n)›
  using signed_drop_bit_word_numeral [of n k] by (simp add: sshiftr_def shiftr_def)
(* Arithmetic right shift by n is signed division by 2 ^ n. *)
lemma sshiftr_div_2n: "sint (w >>> n) = sint w div 2 ^ n"
  by word_eqI (cases ‹n < LENGTH('a)›; fastforce simp: le_diff_conv2)
(* mask n expressed as a left-shifted one minus one. *)
lemma mask_eq:
  ‹mask n = (1 << n) - (1 :: 'a::len word)›
  by (simp add: mask_eq_exp_minus_1 shiftl_def)
(* Bits of a left shift, stated in terms of size w. *)
lemma nth_shiftl': "bit (w << m) n ⟷ n < size w ∧ n >= m ∧ bit w (n - m)"
  for w :: "'a::len word"
  by (simp add: bit_simps word_size ac_simps)
(* nth_shiftl' with size w unfolded to the type length. *)
lemmas nth_shiftl = nth_shiftl' [unfolded word_size]
(* Bits of a right shift. *)
lemma nth_shiftr: "bit (w >> m) n = bit w (n + m)"
  for w :: "'a::len word"
  by (simp add: bit_simps ac_simps)
(* Same statement as uint_shiftr_eq, written with the prefix form shiftr. *)
lemma shiftr_div_2n: "uint (shiftr w n) = uint w div 2 ^ n"
  by (fact uint_shiftr_eq)
(* Left and right shifts are conjugate to each other under word_reverse. *)
lemma shiftl_rev: "shiftl w n = word_reverse (shiftr (word_reverse w) n)"
  by word_eqI_solve
lemma rev_shiftl: "word_reverse w << n = word_reverse (w >> n)"
  by (simp add: shiftl_rev)
lemma shiftr_rev: "w >> n = word_reverse (word_reverse w << n)"
  by (simp add: rev_shiftl)
lemma rev_shiftr: "word_reverse w >> n = word_reverse (w << n)"
  by (simp add: shiftr_rev)
(* Restatements of rc1 and rc2, which come from another theory, obtained by
   simplifying with rev_shiftr and revcast_ucast. *)
lemmas ucast_up =
  rc1 [simplified rev_shiftr [symmetric] revcast_ucast [symmetric]]
lemmas ucast_down =
  rc2 [simplified rev_shiftr revcast_ucast [symmetric]]
(* Basic facts about left and right shifts on words. *)
(* Shifting left by at least the word size yields zero. *)
lemma shiftl_zero_size: "size x ≤ n ⟹ x << n = 0"
  for x :: "'a::len word"
  by (simp add: shiftl_def word_size)
(* Left shift by n is multiplication by 2 ^ n. *)
lemma shiftl_t2n: "shiftl w n = 2 ^ n * w"
  for w :: "'a::len word"
  by (simp add: shiftl_def push_bit_eq_mult)
(* Multiplying by 4 is a left shift by 2. *)
lemma word_shift_by_2:
  "x * 4 = (x::'a::len word) << 2"
  by (simp add: shiftl_t2n)
(* Multiplying by 8 is a left shift by 3. *)
lemma word_shift_by_3:
  "x * 8 = (x::'a::len word) << 3"
  by (simp add: shiftl_t2n)
(* slice n w is the unsigned cast of w >> n. *)
lemma slice_shiftr: "slice n w = ucast (w >> n)"
  by word_eqI (cases ‹n ≤ LENGTH('b)›; fastforce simp: ac_simps dest: bit_imp_le_length)
(* Shifting right by at least the word size yields zero. *)
lemma shiftr_zero_size: "size x ≤ n ⟹ x >> n = 0"
  for x :: "'a :: len word"
  by word_eqI
(* Shifting by zero is the identity. *)
lemma shiftr_x_0 [simp]: "x >> 0 = x"
  for x :: "'a::len word"
  by (simp add: shiftr_def)
lemma shiftl_x_0 [simp]: "x << 0 = x"
  for x :: "'a::len word"
  by (simp add: shiftl_def)
lemmas shiftl0 = shiftl_x_0
(* The word 1 survives a right shift only when the shift amount is zero. *)
lemma shiftr_1 [simp]: "(1::'a::len word) >> n = (if n = 0 then 1 else 0)"
  by (simp add: shiftr_def)
(* Clearing the low n bits equals shifting right by n and then left by n. *)
lemma and_not_mask:
  "w AND NOT (mask n) = (w >> n) << n"
  for w :: ‹'a::len word›
  by word_eqI_solve
(* Keeping only the low n bits equals shifting left and then right by size w - n. *)
lemma and_mask:
  "w AND mask n = (w << (size w - n)) >> (size w - n)"
  for w :: ‹'a::len word›
  by word_eqI_solve
(* Right shift by n is word division by 2 ^ n. *)
lemma shiftr_div_2n_w: "w >> n = w div (2^n :: 'a :: len word)"
  by (fact shiftr_eq_div)
(* Right shift is monotone with respect to the word order. *)
lemma le_shiftr:
  "u ≤ v ⟹ u >> (n :: nat) ≤ (v :: 'a :: len word) >> n"
  unfolding shiftr_def
  apply transfer
  apply (simp add: take_bit_drop_bit)
  apply (simp add: drop_bit_eq_div zdiv_mono1)
  done
(* If the shifted values are ordered and distinct, the originals are ordered. *)
lemma le_shiftr':
  "⟦ u >> n ≤ v >> n ; u >> n ≠ v >> n ⟧ ⟹ (u::'a::len word) ≤ v"
  by (metis le_cases le_shiftr verit_la_disequality)
(* mask n shifted right by m vanishes once n ≤ m. *)
lemma shiftr_mask_le:
  "n ≤ m ⟹ mask n >> m = (0 :: 'a::len word)"
  by word_eqI
(* Simp instance of shiftr_mask_le with m = n. *)
lemma shiftr_mask [simp]:
  ‹mask m >> m = (0::'a::len word)›
  by (rule shiftr_mask_le) simp
(* w ≤ mask n holds exactly when w has no bits at or above position n. *)
lemma le_mask_iff:
  "(w ≤ mask n) = (w >> n = 0)"
  for w :: ‹'a::len word›
  by (simp add: less_eq_mask_iff_take_bit_eq_self shiftr_def take_bit_eq_self_iff_drop_bit_eq_0)
(* Masking with mask n is a no-op exactly when w >> n = 0. *)
lemma and_mask_eq_iff_shiftr_0:
  "(w AND mask n = w) = (w >> n = 0)"
  for w :: ‹'a::len word›
  by (simp flip: take_bit_eq_mask add: shiftr_def take_bit_eq_self_iff_drop_bit_eq_0)
(* A left-shifted mask is a wider mask with its low n bits cleared. *)
lemma mask_shiftl_decompose:
  "mask m << n = mask (m + n) AND NOT (mask n :: 'a::len word)"
  by word_eqI_solve
(* Shifts distribute over bitwise AND and OR.  Note the precedences: AND and OR
   bind tighter than the shift operators, so a AND b >> c reads (a AND b) >> c. *)
lemma shiftl_over_and_dist:
  fixes a::"'a::len word"
  shows "(a AND b) << c = (a << c) AND (b << c)"
  by (unfold shiftl_def) (fact push_bit_and)
lemma shiftr_over_and_dist:
  fixes a::"'a::len word"
  shows "a AND b >> c = (a >> c) AND (b >> c)"
  by (unfold shiftr_def) (fact drop_bit_and)
lemma sshiftr_over_and_dist:
  fixes a::"'a::len word"
  shows "a AND b >>> c = (a >>> c) AND (b >>> c)"
  by word_eqI
lemma shiftl_over_or_dist:
  fixes a::"'a::len word"
  shows "a OR b << c = (a << c) OR (b << c)"
  by (unfold shiftl_def) (fact push_bit_or)
lemma shiftr_over_or_dist:
  fixes a::"'a::len word"
  shows "a OR b >> c = (a >> c) OR (b >> c)"
  by (unfold shiftr_def) (fact drop_bit_or)
lemma sshiftr_over_or_dist:
  fixes a::"'a::len word"
  shows "a OR b >>> c = (a >>> c) OR (b >>> c)"
  by word_eqI
(* All six distribution rules above, collected for use as a rewrite set. *)
lemmas shift_over_ao_dists =
  shiftl_over_or_dist shiftr_over_or_dist
  sshiftr_over_or_dist shiftl_over_and_dist
  shiftr_over_and_dist sshiftr_over_and_dist
(* Composition of consecutive shifts. *)
(* Two left shifts add up. *)
lemma shiftl_shiftl:
  fixes a::"'a::len word"
  shows "a << b << c = a << (b + c)"
  by (word_eqI_solve simp: add.commute add.left_commute)
(* Two right shifts add up. *)
lemma shiftr_shiftr:
  fixes a::"'a::len word"
  shows "a >> b >> c = a >> (b + c)"
  by word_eqI (simp add: add.left_commute add.commute)
(* Left then right shift with c ≤ b: mask off the bits lost at the top, then
   shift left by b - c. *)
lemma shiftl_shiftr1:
  fixes a::"'a::len word"
  shows "c ≤ b ⟹ a << b >> c = a AND (mask (size a - b)) << (b - c)"
  by word_eqI (auto simp: ac_simps)
(* Left then right shift with b < c: shift right by c - b, then mask to size a - c bits. *)
lemma shiftl_shiftr2:
  fixes a::"'a::len word"
  shows "b < c ⟹ a << b >> c = (a >> (c - b)) AND (mask (size a - c))"
  by word_eqI_solve
(* Right then left shift with c ≤ b: shift right by b - c and clear the low c bits. *)
lemma shiftr_shiftl1:
  fixes a::"'a::len word"
  shows "c ≤ b ⟹ a >> b << c = (a >> (b - c)) AND (NOT (mask c))"
  by word_eqI_solve
(* Right then left shift with b < c: shift left by c - b and clear the low c bits. *)
lemma shiftr_shiftl2:
  fixes a::"'a::len word"
  shows "b < c ⟹ a >> b << c = (a << (c - b)) AND (NOT (mask c))"
  by word_eqI (auto simp: ac_simps)
(* The composition rules above, collected. *)
lemmas multi_shift_simps =
  shiftl_shiftl shiftr_shiftr
  shiftl_shiftr1 shiftl_shiftr2
  shiftr_shiftl1 shiftr_shiftl2
(* A mask that fits in the word, shifted right by m, is the mask of width n - m. *)
lemma shiftr_mask2:
  "n ≤ LENGTH('a) ⟹ (mask n >> m :: ('a :: len) word) = mask (n - m)"
  by word_eqI_solve
(* Left shift distributes over word addition. *)
lemma word_shiftl_add_distrib:
  fixes x :: "'a :: len word"
  shows "(x + y) << n = (x << n) + (y << n)"
  by (simp add: shiftl_t2n ring_distribs)
(* Clearing the low y bits before shifting right by y makes no difference. *)
lemma mask_shift:
  "(x AND NOT (mask y)) >> y = x >> y"
  for x :: ‹'a::len word›
  by word_eqI
(* Right shift by n divides the natural-number value by 2 ^ n. *)
lemma shiftr_div_2n':
  "unat (w >> n) = unat w div 2 ^ n"
  by word_eqI
(* If x fits in LENGTH('a) - n bits, shifting left and then right by n returns x. *)
lemma shiftl_shiftr_id:
  "⟦ n < LENGTH('a); x < 2 ^ (LENGTH('a) - n) ⟧ ⟹ x << n >> n = (x::'a::len word)"
  by word_eqI (metis add.commute less_diff_conv)
(* Casting w << n to a type no wider than n gives zero. *)
lemma ucast_shiftl_eq_0:
  fixes w :: "'a :: len word"
  shows "⟦ n ≥ LENGTH('b) ⟧ ⟹ ucast (w << n) = (0 :: 'b :: len word)"
  by (transfer fixing: n) (simp add: take_bit_push_bit)
(* A nonzero word x ≤ 2 ^ m stays nonzero after a left shift by n, provided
   m + n < LENGTH('a), so no set bit is shifted out. *)
lemma word_shift_nonzero:
  fixes x:: "'a::len word"
  assumes "x ≤ 2 ^ m"
      and mn: "m + n < LENGTH('a::len)"
      and "x ≠ 0"
    shows "x << n ≠ 0"
proof -
  have "0 < unat x"
    by (simp add: assms unat_gt_0)
  moreover
  have "unat x ≤ 2 ^ m"
    by (simp add: assms word_unat_less_le)
  then have §: "2 ^ n * unat x < 2 ^ LENGTH('a)"
    by (metis add_diff_cancel_right' mn le_add2 nat_le_power_trans order_le_less_trans unat_lt2p unat_power_lower)
  ultimately have "0 < 2 ^ n * unat x mod 2 ^ LENGTH('a)"
    by simp
  with § show ?thesis
    by (metis add.commute add_lessD1 less_zeroE mn mult.commute shiftl_eq_mult unat_0 unat_power_lower unat_word_ariths(2))
qed
(* Upper bound on the natural value of a right-shifted word. *)
lemma word_shiftr_lt:
  fixes w :: "'a::len word"
  shows "unat (w >> n) < (2 ^ (LENGTH('a) - n))"
  by (metis mult.commute add_lessD1 div_mod_decomp nat_power_less_diff shiftr_div_2n' unat_lt2p)
(* If x has no bits at or above n + m, then x >> n < 2 ^ m. *)
lemma shiftr_less_t2n':
  "⟦ x AND mask (n + m) = x; m < LENGTH('a) ⟧ ⟹ x >> n < 2 ^ m" for x :: "'a :: len word"
  by (metis and_mask_eq_iff_shiftr_0 eq_mask_less shiftr_shiftr)
(* If x < 2 ^ (n + m), then x >> n < 2 ^ m. *)
lemma shiftr_less_t2n:
  "x < 2 ^ (n + m) ⟹ x >> n < 2 ^ m" for x :: "'a :: len word"
  by (meson le_add2 less_mask_eq order_le_less_trans p2_gt_0 shiftr_less_t2n' word_gt_a_gt_0)
(* Shifting right by at least the word length yields zero. *)
lemma shiftr_eq_0: "n ≥ LENGTH('a) ⟹ ((w::'a::len word) >> n) = 0"
  by (metis drop_bit_word_beyond shiftr_def)
(* If x < 2 ^ (m - n) and m is below the word length, then x << n < 2 ^ m. *)
lemma shiftl_less_t2n:
  fixes x :: "'a :: len word"
  shows "⟦ x < (2 ^ (m - n)); m < LENGTH('a) ⟧ ⟹ (x << n) < 2 ^ m"
  by (simp add: shiftl_def take_bit_push_bit word_size flip: mask_eq_iff_w2p take_bit_eq_mask)
(* Variant of shiftl_less_t2n with the bound stated as 2 ^ (m + n). *)
lemma shiftl_less_t2n':
  "(x::'a::len word) < 2 ^ m ⟹ m+n < LENGTH('a) ⟹ x << n < 2 ^ (m + n)"
  by (simp add: shiftl_less_t2n)
(* Signed cast of a single-bit word from 'a signed word to 'a word. *)
lemma scast_bit_test [simp]:
  "scast ((1 :: 'a::len signed word) << n) = (1 :: 'a word) << n"
  by word_eqI
(* Translates a natural-number guard, unat x scaled by 2 ^ y staying below 2 ^ n,
   into a test on words: x = 0 or x < 1 << n >> y.  Requires 0 < n < LENGTH('a).
   The proof splits on x = 0, then on y < n, then on whether 2 ^ n equals
   unat x scaled by 2 ^ y. *)
lemma signed_shift_guard_to_word:
  ‹unat x * 2 ^ y < 2 ^ n ⟷ x = 0 ∨ x < 1 << n >> y›
    if ‹n < LENGTH('a)› ‹0 < n›
    for x :: ‹'a::len word›
proof (cases ‹x = 0›)
  case True
  then show ?thesis
    by simp
next
  (* From here on x is nonzero, so unat x ≥ 1. *)
  case False
  then have ‹unat x ≠ 0›
    by (simp add: unat_eq_0)
  then have ‹unat x ≥ 1›
    by simp
  show ?thesis
  proof (cases ‹y < n›)
    (* Case n ≤ y: write y = n + q. *)
    case False
    then have ‹n ≤ y›
      by simp
    then obtain q where ‹y = n + q›
      using le_Suc_ex by blast
    moreover have ‹(2 :: nat) ^ n >> n + q ≤ 1›
      by (simp add: drop_bit_eq_div power_add shiftr_def)
    ultimately show ?thesis
      using ‹x ≠ 0› ‹unat x ≥ 1› ‹n < LENGTH('a)›
      by (simp add: power_add not_less word_le_nat_alt unat_drop_bit_eq shiftr_def shiftl_def)
  next
    (* Case y < n, hence y < LENGTH('a). *)
    case True
    with that have ‹y < LENGTH('a)›
      by simp
    show ?thesis
    proof (cases ‹2 ^ n = unat x * 2 ^ y›)
      (* Boundary case: the product hits 2 ^ n exactly. *)
      case True
      moreover have ‹unat x * 2 ^ y < 2 ^ LENGTH('a)›
        using ‹n < LENGTH('a)› by (simp flip: True)
      moreover have ‹(word_of_nat (2 ^ n) :: 'a word) = word_of_nat (unat x * 2 ^ y)›
        using True by simp
      then have ‹2 ^ n = x * 2 ^ y›
        by simp
      ultimately show ?thesis
        using ‹y < LENGTH('a)›
        by (auto simp: drop_bit_eq_div word_less_nat_alt unat_div unat_word_ariths
          shiftr_def shiftl_def)
    next
      (* Otherwise the strict and non-strict comparisons coincide, and the
         guard reduces to unat x < 2 ^ n div 2 ^ y. *)
      case False
      with ‹y < n› have *: ‹unat x ≠ 2 ^ n div 2 ^ y›
        by (auto simp flip: power_sub power_add)
      have ‹unat x * 2 ^ y < 2 ^ n ⟷ unat x * 2 ^ y ≤ 2 ^ n›
        using False by (simp add: less_le)
      also have ‹… ⟷ unat x ≤ 2 ^ n div 2 ^ y›
        by (simp add: less_eq_div_iff_mult_less_eq)
      also have ‹… ⟷ unat x < 2 ^ n div 2 ^ y›
        using * by (simp add: less_le)
      finally show ?thesis
      using that ‹x ≠ 0› by (simp flip: push_bit_eq_mult drop_bit_eq_div
        add: shiftr_def shiftl_def unat_drop_bit_eq word_less_iff_unsigned [where ?'a = nat])
    qed
  qed
qed
(* When n + m ≥ LENGTH('a), w >> n has no bits at or above m. *)
lemma shiftr_not_mask_0:
  "n+m ≥ LENGTH('a :: len) ⟹ ((w::'a::len word) >> n) AND NOT (mask m) = 0"
  by word_eqI
(* The low n bits of x << n are zero. *)
lemma shiftl_mask_is_0[simp]:
  "(x << n) AND mask n = 0"
  for x :: ‹'a::len word›
  by (simp flip: take_bit_eq_mask add: take_bit_push_bit shiftl_def)
(* a >> (size a - b) already fits in b bits, so masking with mask b is a no-op. *)
lemma rshift_sub_mask_eq:
  "(a >> (size a - b)) AND mask b = a >> (size a - b)"
  for a :: ‹'a::len word›
  by (simp add: and_mask_eq_iff_shiftr_0 shiftr_shiftr shiftr_zero_size)
(* Left then right shift with b ≤ c, covering both shiftl_shiftr1 and shiftl_shiftr2. *)
lemma shiftl_shiftr3:
  "b ≤ c ⟹ a << b >> c = (a >> c - b) AND mask (size a - c)"
  for a :: ‹'a::len word›
  by (cases "b = c") (simp_all add: shiftl_shiftr1 shiftl_shiftr2)
(* Masking commutes with right shift, at the cost of narrowing the mask by n. *)
lemma and_mask_shiftr_comm:
  "m ≤ size w ⟹ (w AND mask m) >> n = (w >> n) AND mask (m-n)"
  for w :: ‹'a::len word›
  by (simp add: and_mask shiftr_shiftr) (simp add: word_size shiftl_shiftr3)
(* Masking commutes with left shift, widening the mask by n. *)
lemma and_mask_shiftl_comm:
  "m+n ≤ size w ⟹ (w AND mask m) << n = (w << n) AND mask (m+n)"
  for w :: ‹'a::len word›
  by (simp add: and_mask word_size shiftl_shiftl) (simp add: shiftl_shiftr1)
(* A word bounded by mask n, shifted left by m, is bounded by mask (m + n). *)
lemma le_mask_shiftl_le_mask: "s = m + n ⟹ x ≤ mask n ⟹ x << m ≤ mask s"
  for x :: ‹'a::len word›
  by (simp add: le_mask_iff shiftl_shiftr3)
(* AND with a single-bit word 1 << n either keeps that bit or yields zero. *)
lemma word_and_1_shiftl:
  "x AND (1 << n) = (if bit x n then (1 << n) else 0)" for x :: "'a :: len word"
  by word_eqI_solve
(* Instances of word_and_1_shiftl for n = 0, 1, 2, and their simplified forms. *)
lemmas word_and_1_shiftls'
    = word_and_1_shiftl[where n=0]
      word_and_1_shiftl[where n=1]
      word_and_1_shiftl[where n=2]
lemmas word_and_1_shiftls = word_and_1_shiftls' [simplified]
(* Extracting a shifted bit field equals extracting at position 0 and shifting back. *)
lemma word_and_mask_shiftl:
  "x AND (mask n << m) = ((x >> m) AND mask n) << m"
  for x :: ‹'a::len word›
  by word_eqI_solve
(* Multiplying by 2 ^ n and then shifting by m is a single left shift by m + n. *)
lemma shift_times_fold:
  "(x :: 'a :: len word) * (2 ^ n) << m = x << (m + n)"
  by (simp add: shiftl_t2n ac_simps power_add)
(* Bit v of x as a word 0 or 1: shift it down and mask with 1. *)
lemma of_bool_nth:
  "of_bool (bit x v) = (x >> v) AND 1"
  for x :: ‹'a::len word›
  by (simp add: bit_iff_odd_drop_bit word_and_1 shiftr_def)
(* x >> n fits in size x - n bits, so masking it with that width is a no-op. *)
lemma shiftr_mask_eq:
  "(x >> n) AND mask (size x - n) = x >> n" for x :: "'a :: len word"
  by (word_eqI_solve dest: test_bit_lenD)
(* shiftr_mask_eq with the mask width given as a separate parameter m. *)
lemma shiftr_mask_eq':
  "m = (size x - n) ⟹ (x >> n) AND mask m = x >> n" for x :: "'a :: len word"
  by (simp add: shiftr_mask_eq)
(* AND with the single-bit word 1 << n is zero exactly when bit n of x is clear. *)
lemma and_eq_0_is_nth:
  fixes x :: "'a :: len word"
  shows "y = 1 << n ⟹ ((x AND y) = 0) = (¬ (bit x n))"
  by (simp add: and_exp_eq_0_iff_not_bit shiftl_def)
(* Contrapositive of word_shift_nonzero. *)
lemma word_shift_zero:
  "⟦ x << n = 0; x ≤ 2^m; m + n < LENGTH('a)⟧ ⟹ (x::'a::len word) = 0"
  using word_shift_nonzero by blast
(* A field extracted with mask n << m is disjoint from the complement of that mask. *)
lemma mask_shift_and_negate[simp]:"(w AND mask n << m) AND NOT (mask n << m) = 0"
  for w :: ‹'a::len word›
  by word_eqI
(* After writing y into the bit field mask n << m of x, clearing that field
   again gives back x with the field cleared. *)
lemma bitfield_op_twice:
  "(x AND NOT (mask n << m) OR ((y AND mask n) << m)) AND NOT (mask n << m) = x AND NOT (mask n << m)"
  for x :: ‹'a::len word›
  by word_eqI_solve
(* Generalisation of bitfield_op_twice where the field mask b << c is given by
   its complement a, and b is some mask. *)
lemma bitfield_op_twice'':
  "⟦NOT a = b << c; ∃x. b = mask x⟧ ⟹ (x AND a OR (y AND b << c)) AND a = x AND a"
  for a b :: ‹'a::len word›
  by word_eqI_solve
(* Facts about shifting by exactly one position. *)
(* Halving is a right shift by one. *)
lemma shiftr1_unfold: "x div 2 = x >> 1"
  by (simp add: drop_bit_eq_div shiftr_def)
lemma shiftr1_is_div_2: "(x::('a::len) word) >> 1 = x div 2"
  by (simp add: drop_bit_eq_div shiftr_def)
(* A left shift by one is doubling. *)
lemma shiftl1_is_mult: "(x << 1) = (x :: 'a::len word) * 2"
  by (metis One_nat_def mult_2 mult_2_right one_add_one
        power_0 power_Suc shiftl_t2n)
(* A right shift by one strictly decreases a nonzero word. *)
lemma shiftr1_lt:"x ≠ 0 ⟹ (x::('a::len) word) >> 1 < x"
  by (simp add: div_less_dividend_word drop_bit_eq_div shiftr_def)
(* Only 0 and 1 vanish under a right shift by one. *)
lemma shiftr1_0_or_1:"(x::('a::len) word) >> 1 = 0 ⟹ x = 0 ∨ x = 1"
  by (metis le_mask_iff mask_1 not_less not_less_iff_gr_or_eq word_less_1)
(* Either bit 0 of x is set, or incrementing x does not change x >> 1. *)
lemma shiftr1_irrelevant_lsb: "bit (x::('a::len) word) 0 ∨ x >> 1 = (x + 1) >> 1"
  by (auto simp: bit_0 shiftr_def drop_bit_Suc ac_simps elim: evenE)
(* If x + 1 vanishes under a right shift by one, then x = 0 or x + 1 wraps to 0. *)
lemma shiftr1_0_imp_only_lsb:"((x::('a::len) word) + 1) >> 1 = 0 ⟹ x = 0 ∨ x + 1 = 0"
  by (metis One_nat_def shiftr1_0_or_1 word_less_1 word_overflow)
(* Conditional form of shiftr1_irrelevant_lsb for x with bit 0 clear. *)
lemma shiftr1_irrelevant_lsb': "¬ (bit (x::('a::len) word) 0) ⟹ x >> 1 = (x + 1) >> 1"
  using shiftr1_irrelevant_lsb [of x] by simp
(* A word of type 'b whose width is twice that of 'a is rebuilt by OR-ing its
   low half, cast through 'a, with its high half, cast through 'a and shifted
   back into place. *)
lemma cast_chunk_assemble_id:
  assumes n: "n = LENGTH('a::len)"
    and "m = LENGTH('b::len)"
    and "n * 2 = m"
  shows "UCAST('a → 'b) (UCAST('b → 'a) x::'a word) OR (UCAST('a → 'b) (UCAST('b → 'a) (x >> n)::'a word) << n) = x"
proof -
  have "((ucast ((ucast (x >> n))::'a word))::'b word) = x >> n"
    by (metis and_mask_eq_iff_shiftr_0 assms mult_2_right order_refl shiftr_eq_0 shiftr_shiftr ucast_ucast_mask)
  with n show ?thesis
    by (auto simp: ucast_ucast_mask simp flip: and_not_mask word_ao_dist2)
qed
(* Same as cast_chunk_assemble_id with signed casts into the half-width type.
   The proof first shows that the signed and unsigned casts into 'a agree. *)
lemma cast_chunk_scast_assemble_id:
  fixes x:: "'b::len word"
  assumes "n = LENGTH('a::len)"
      and "m = LENGTH('b)"
      and "n * 2 = m"
    shows "UCAST('a → 'b) (SCAST('b → 'a) x) OR (UCAST('a → 'b) (SCAST('b → 'a) (x >> n)) << n) = x"
proof -
  have "SCAST('b → 'a) x = UCAST('b → 'a) x"
    by (metis assms down_cast_same is_up is_up_down le_add2 mult_2_right)
  then show ?thesis
    by (metis assms cast_chunk_assemble_id down_cast_same is_down le_add2 mult_2_right)
qed
(* Natural-number form of shiftr_less_t2n. *)
lemma unat_shiftr_less_t2n:
  fixes x :: "'a :: len word"
  shows "unat x < 2 ^ (n + m) ⟹ unat (x >> n) < 2 ^ m"
  by (simp add: shiftr_div_2n' power_add mult.commute less_mult_imp_div_less)
(* A cast-up word shifted left by 2 stays below any n ≥ 2 ^ (LENGTH('b) + 2). *)
lemma ucast_less_shiftl_helper:
  "⟦ LENGTH('b) + 2 < LENGTH('a); 2 ^ (LENGTH('b) + 2) ≤ n⟧
    ⟹ (ucast (x :: 'b::len word) << 2) < (n :: 'a::len word)"
  by (meson add_lessD1 order_less_le_trans shiftl_less_t2n' ucast_less)
(* The complement of a mask shifted to the top of the word is the mask of the
   remaining low bits. *)
lemma NOT_mask_shifted_lenword:
  "NOT (mask len << (LENGTH('a) - len) ::'a::len word) = mask (LENGTH('a) - len)"
  by word_eqI_solve
(* Right shift preserves any strict upper bound. *)
lemma shiftr_less:
  "(w::'a::len word) < k ⟹ w >> n < k"
  by (metis div_le_dividend le_less_trans shiftr_div_2n' unat_arith_simps(2))
(* A nonzero AND implies both operands are nonzero. *)
lemma word_and_notzeroD:
  "w AND w' ≠ 0 ⟹ w ≠ 0 ∧ w' ≠ 0"
  by auto
(* A word whose natural value is below 2 ^ n vanishes under a right shift by n. *)
lemma shiftr_le_0:
  "unat (w::'a::len word) < 2 ^ n ⟹ w >> n = (0::'a::len word)"
  by (auto simp: take_bit_word_eq_self_iff word_less_nat_alt shiftr_def
    simp flip: take_bit_eq_self_iff_drop_bit_eq_0 intro: ccontr)
(* Left shift of of_nat x corresponds to scaling x by 2 ^ n before the conversion. *)
lemma of_nat_shiftl:
  "(of_nat x << n) = (of_nat (x * 2 ^ n) :: ('a::len) word)"
proof -
  have "(of_nat x::'a word) << n = of_nat (2 ^ n) * of_nat x"
    using shiftl_t2n by (metis word_unat_power)
  thus ?thesis by simp
qed
(* A single set bit inside the word is nonzero. *)
lemma shiftl_1_not_0:
  "n < LENGTH('a) ⟹ (1::'a::len word) << n ≠ 0"
  by (simp add: shiftl_t2n)
(* Masking a down can only decrease it, and OR-ing a mask can only increase it. *)
lemma bitmagic_zeroLast_leq_or1Last:
  "(a::('a::len) word) AND (mask len << x - len) ≤ a OR mask (y - len)"
  by (meson le_word_or2 order_trans word_and_le2)
(* If base has its low LENGTH('a) - len bits clear, then base is the high-bit
   prefix of a exactly when a lies in the interval from base to
   base OR mask (LENGTH('a) - len). *)
lemma zero_base_lsb_imp_set_eq_as_bit_operation:
  fixes base ::"'a::len word"
  assumes valid_prefix: "mask (LENGTH('a) - len) AND base = 0"
  shows "(base = NOT (mask (LENGTH('a) - len)) AND a) ⟷
         (a ∈ {base .. base OR mask (LENGTH('a) - len)})"
proof
  (* Forward direction: the prefix of a bounds a from below, and filling in the
     low bits bounds it from above. *)
  have helper3: "x OR y = x OR y AND NOT x" for x y ::"'a::len word" by (simp add: word_oa_dist2)
  from assms show "base = NOT (mask (LENGTH('a) - len)) AND a ⟹
                    a ∈ {base..base OR mask (LENGTH('a) - len)}"
    by (metis and.commute atLeastAtMost_iff helper3 le_word_or2 or.commute word_and_le1)
next
  (* Backward direction: clearing the low bits is monotone (neg_mask_mono_le),
     so the prefix of a is squeezed between base and the prefix of the upper bound,
     which equals base. *)
  assume "a ∈ {base..base OR mask (LENGTH('a) - len)}"
  hence a: "base ≤ a ∧ a ≤ base OR mask (LENGTH('a) - len)" by simp
  show "base = NOT (mask (LENGTH('a) - len)) AND a"
  proof -
    have f2: "∀x⇩0. base AND NOT (mask x⇩0) ≤ a AND NOT (mask x⇩0)"
      using a neg_mask_mono_le by blast
    have f3: "∀x⇩0. a AND NOT (mask x⇩0) ≤ (base OR mask (LENGTH('a) - len)) AND NOT (mask x⇩0)"
      using a neg_mask_mono_le by blast
    have f4: "base = base AND NOT (mask (LENGTH('a) - len))"
      using valid_prefix by (metis mask_eq_0_eq_x word_bw_comms(1))
    hence f5: "∀x⇩6. (base OR x⇩6) AND NOT (mask (LENGTH('a) - len)) =
                      base OR x⇩6 AND NOT (mask (LENGTH('a) - len))"
      using word_ao_dist by (metis)
    have f6: "∀x⇩2 x⇩3. a AND NOT (mask x⇩2) ≤ x⇩3 ∨
                      ¬ (base OR mask (LENGTH('a) - len)) AND NOT (mask x⇩2) ≤ x⇩3"
      using f3 dual_order.trans by auto
    have "base = (base OR mask (LENGTH('a) - len)) AND NOT (mask (LENGTH('a) - len))"
      using f5 by auto
    hence "base = a AND NOT (mask (LENGTH('a) - len))"
      using f2 f4 f6 by (metis eq_iff)
    thus "base = NOT (mask (LENGTH('a) - len)) AND a"
      by (metis word_bw_comms(1))
  qed
qed
(* Comparing of_nat x with a signed word is the same as comparing with its
   signed cast to the unsigned word type. *)
lemma of_nat_eq_signed_scast:
  "(of_nat x = (y :: ('a::len) signed word))
   = (of_nat x = (scast y :: 'a word))"
  by (metis scast_of_nat scast_scast_id(2))
(* For an n-aligned w with w + 2 ^ n ≤ x and no wrap to 0, w is strictly below x. *)
lemma word_aligned_add_no_wrap_bounded:
  "⟦ w + 2^n ≤ x; w + 2^n ≠ 0; is_aligned w n ⟧ ⟹ (w::'a::len word) < x"
  by (blast dest: is_aligned_no_overflow le_less_trans word_leq_le_minus_one)
(* Recursive decomposition of a mask. *)
lemma mask_Suc:
  "mask (Suc n) = (2 :: 'a::len word) ^ n + mask n"
  by (simp add: mask_eq_decr_exp)
(* mask is monotone in its width. *)
lemma mask_mono:
  "sz' ≤ sz ⟹ mask sz' ≤ (mask sz :: 'a::len word)"
  by (simp add: le_mask_iff shiftr_mask_le)
(* An n-aligned word and a word bounded by mask n share no set bits. *)
lemma aligned_mask_disjoint:
  "⟦ is_aligned (a :: 'a :: len word) n; b ≤ mask n ⟧ ⟹ a AND b = 0"
  by (metis and_zero_eq is_aligned_mask le_mask_imp_and_mask word_bw_lcs(1))
(* Consequently, adding such words is the same as OR-ing them. *)
lemma word_and_or_mask_aligned:
  "⟦ is_aligned a n; b ≤ mask n ⟧ ⟹ a + b = a OR b"
  by (simp add: aligned_mask_disjoint word_plus_and_or_coroll)
(* word_and_or_mask_aligned with the roles of the two words swapped. *)
lemma word_and_or_mask_aligned2:
  ‹is_aligned b n ⟹ a ≤ mask n ⟹ a + b = a OR b›
  using word_and_or_mask_aligned [of b n a] by (simp add: ac_simps)
(* Alignment is preserved by unsigned casts. *)
lemma is_aligned_ucastI:
  "is_aligned w n ⟹ is_aligned (ucast w) n"
  by (simp add: bit_ucast_iff is_aligned_nth)
(* A bound by mask n is preserved by unsigned casts. *)
lemma ucast_le_maskI:
  "a ≤ mask n ⟹ UCAST('a::len → 'b::len) a ≤ mask n"
  by (metis and_mask_eq_iff_le_mask ucast_and_mask)
(* The unsigned cast distributes over the sum of a masked word and an aligned word. *)
lemma ucast_add_mask_aligned:
  "⟦ a ≤ mask n; is_aligned b n ⟧ ⟹ UCAST ('a::len → 'b::len) (a + b) = ucast a + ucast b"
  by (metis add.commute is_aligned_ucastI ucast_le_maskI ucast_or_distrib word_and_or_mask_aligned)
(* When casting down or to the same width, the cast commutes with left shift. *)
lemma ucast_shiftl:
  "LENGTH('b) ≤ LENGTH ('a) ⟹ UCAST ('a::len → 'b::len) x << n = ucast (x << n)"
  by word_eqI_solve
(* The cast of an 'a word is bounded by mask n whenever n ≥ LENGTH('a). *)
lemma ucast_leq_mask:
  "LENGTH('a) ≤ n ⟹ ucast (x::'a::len word) ≤ mask n"
  by (metis and_mask_eq_iff_le_mask ucast_and_mask ucast_id ucast_mask_drop)
(* Left shift by n is injective on words that fit in LENGTH('a) - n bits. *)
lemma shiftl_inj:
  ‹x = y›
    if ‹x << n = y << n› ‹x ≤ mask (LENGTH('a) - n)› ‹y ≤ mask (LENGTH('a) - n)›
    for x y :: ‹'a::len word›
proof (cases ‹n < LENGTH('a)›)
  case False
  with that show ?thesis
    by simp
next
  case True
  moreover from that have ‹take_bit (LENGTH('a) - n) x = x› ‹take_bit (LENGTH('a) - n) y = y›
    by (simp_all add: less_eq_mask_iff_take_bit_eq_self)
  ultimately show ?thesis
    using ‹x << n = y << n› by (metis diff_less gr_implies_not0 linorder_cases linorder_not_le shiftl_shiftr_id shiftl_x_0 take_bit_word_eq_self_iff)
qed
(* An n'-aligned base plus a cast offset shifted left by n, where
   n' = n + LENGTH('a) < LENGTH('b), determines both the base and the offset.
   The proof works over naturals and compares both sides via concat_bit. *)
lemma distinct_word_add_ucast_shift_inj:
  ‹p' = p ∧ off' = off›
  if *: ‹p + (UCAST('a::len → 'b::len) off << n) = p' + (ucast off' << n)›
    and ‹is_aligned p n'› ‹is_aligned p' n'› ‹n' = n + LENGTH('a)› ‹n' < LENGTH('b)›
proof -
  from ‹n' = n + LENGTH('a)›
  have [simp]: ‹n' - n = LENGTH('a)› ‹n + LENGTH('a) = n'›
    by simp_all
  from ‹is_aligned p n'› obtain q
    where p: ‹p = push_bit n' (word_of_nat q)› ‹q < 2 ^ (LENGTH('b) - n')›
    by (rule is_alignedE')
  from ‹is_aligned p' n'› obtain q'
    where p': ‹p' = push_bit n' (word_of_nat q')› ‹q' < 2 ^ (LENGTH('b) - n')›
    by (rule is_alignedE')
  define m :: nat where ‹m = unat off›
  then have off: ‹off = word_of_nat m›
    by simp
  define m' :: nat where ‹m' = unat off'›
  then have off': ‹off' = word_of_nat m'›
    by simp
  have ‹push_bit n' q + take_bit n' (push_bit n m) < 2 ^ LENGTH('b)›
    by (metis id_apply is_aligned_no_wrap''' of_nat_eq_id of_nat_push_bit p(1) p(2) take_bit_nat_eq_self_iff take_bit_nat_less_exp take_bit_push_bit that(2) that(5) unsigned_of_nat)
  moreover have ‹push_bit n' q' + take_bit n' (push_bit n m') < 2 ^ LENGTH('b)›
    by (metis ‹n' - n = LENGTH('a)› id_apply is_aligned_no_wrap''' m'_def of_nat_eq_id of_nat_push_bit off' p'(1) p'(2) take_bit_nat_eq_self_iff take_bit_push_bit that(3) that(5) unsigned_of_nat)
  ultimately have ‹push_bit n' q + take_bit n' (push_bit n m) = push_bit n' q' + take_bit n' (push_bit n m')›
    using * by (simp add: p p' off off' push_bit_of_nat push_bit_take_bit word_of_nat_inj unsigned_of_nat shiftl_def flip: of_nat_add)
  then have ‹int (push_bit n' q + take_bit n' (push_bit n m))
    = int (push_bit n' q' + take_bit n' (push_bit n m'))›
    by simp
  then have ‹concat_bit n' (int (push_bit n m)) (int q)
    = concat_bit n' (int (push_bit n m')) (int q')›
    by (simp add: of_nat_push_bit of_nat_take_bit concat_bit_eq)
  then show ?thesis
    by (simp add: p p' off off' take_bit_of_nat take_bit_push_bit word_of_nat_eq_iff concat_bit_eq_iff)
      (simp add: push_bit_eq_mult)
qed
(* Facts about the word enumeration [x .e. y]. *)
(* The enumeration is empty when y < x. *)
lemma word_upto_Nil:
  "y < x ⟹ [x .e. y ::'a::len word] = []"
  by (simp add: upto_enum_red not_le word_less_nat_alt)
(* Every element of [x .e. y] lies between x and y. *)
lemma word_enum_decomp_elem:
  assumes "[x .e. (y ::'a::len word)] = as @ a # bs"
  shows "x ≤ a ∧ a ≤ y"
proof -
  have "set as ⊆ set [x .e. y] ∧ a ∈ set [x .e. y]"
    using assms by (auto dest: arg_cong[where f=set])
  then show ?thesis by auto
qed
(* The prefix before element a is the enumeration from x to a - 1, and is
   empty when a is not above x.  Proved by induction on the prefix. *)
lemma word_enum_prefix:
  "[x .e. (y ::'a::len word)] = as @ a # bs ⟹ as = (if x < a then [x .e. a - 1] else [])"
proof (induction as arbitrary: x)
  case Nil
  show ?case
  proof (cases "x < y")
    case True
    then show ?thesis
      using local.Nil word_upto_Cons_eq by force
  next
    case False
    then show ?thesis
      using local.Nil not_less_iff_gr_or_eq word_upto_Nil by fastforce
  qed
next
  case (Cons b as)
  show ?case
  proof (cases x y rule: linorder_cases)
    case less
    with Cons.prems have "b + 1 ≤ a"
      using word_enum_decomp_elem word_upto_Cons_eq by fastforce
    moreover have "x + 1 ≤ a"
      using Cons.prems less word_enum_decomp_elem word_upto_Cons_eq by fastforce
    moreover have "(x + 1 ≤ a) = (x < a)"
      by (metis less word_Suc_le word_not_simps(3))
    ultimately
    show ?thesis
      using Cons less word_l_diffs(2) less_is_non_zero_p1 olen_add_eqv unat_plus_simple
        word_overflow_unat word_upto_Cons_eq by fastforce
  next
    case equal
    then show ?thesis
      using Cons.prems by auto
  next
    case greater
    then show ?thesis
      using Cons
      by (simp add: word_upto_Nil)
  qed
qed
(* The enumeration is distinct, so an element does not occur in its own prefix. *)
lemma word_enum_decomp_set:
  "[x .e. (y ::'a::len word)] = as @ a # bs ⟹ a ∉ set as"
  by (metis distinct_append distinct_enum_upto' not_distinct_conv_prefix)
(* Combined decomposition facts: bounds on a and on every prefix element, and
   a not occurring in the prefix. *)
lemma word_enum_decomp:
  assumes "[x .e. (y ::'a::len word)] = as @ a # bs"
  shows "x ≤ a ∧ a ≤ y ∧ a ∉ set as ∧ (∀z ∈ set as. x ≤ z ∧ z ≤ y)"
proof -
  from assms
  have "set as ⊆ set [x .e. y] ∧ a ∈ set [x .e. y]"
    by (auto dest: arg_cong[where f=set])
  with word_enum_decomp_set[OF assms]
  show ?thesis by auto
qed
(* If t fits in LENGTH('a) bits and w is the of_nat image of unat t, then t is
   the unsigned cast of w back to 'b. *)
lemma of_nat_unat_le_mask_ucast:
  "⟦of_nat (unat t) = w; t ≤ mask LENGTH('a)⟧ ⟹ t = UCAST('a::len → 'b::len) w"
  by (clarsimp simp: ucast_nat_def ucast_ucast_mask simp flip: and_mask_eq_iff_le_mask)
(* The difference b - a is positive whenever a < b. *)
lemma less_diff_gt0:
  "a < b ⟹ (0 :: 'a :: len word) < b - a"
  by unat_arith
(* NOTE(review): despite its name, this states an upper bound: the natural
   value of a word sum is at most the sum of the natural values. *)
lemma unat_plus_gt:
  "unat ((a :: 'a :: len word) + b) ≤ unat a + unat b"
  by (clarsimp simp: unat_plus_if_size)
(* From a - 1 < b and a ≠ b, conclude a < b. *)
lemma const_less:
  "⟦ (a :: 'a :: len word) - 1 < b; a ≠ b ⟧ ⟹ a < b"
  by (metis less_1_simp word_le_less_eq)
(* Adding y scaled by m, where m has its low n bits clear, commutes with
   clearing the low n bits. *)
lemma add_mult_aligned_neg_mask:
  ‹(x + y * m) AND NOT(mask n) = (x AND NOT(mask n)) + y * m›
  if ‹m AND (2 ^ n - 1) = 0›
  for x y m :: ‹'a::len word›
  by (metis (no_types, opaque_lifting)
            add.assoc add.commute add.right_neutral add_uminus_conv_diff
            mask_eq_decr_exp mask_eqs(2) mask_eqs(6) mult.commute mult_zero_left
            subtract_mask(1) that)
(* For a nonzero n that fits in the word, of_nat n minus 1 has natural value n - 1. *)
lemma unat_of_nat_minus_1:
  "⟦ n < 2 ^ LENGTH('a); n ≠ 0 ⟧ ⟹ unat ((of_nat n:: 'a :: len word) - 1) = n - 1"
  by (simp add: of_nat_diff unat_eq_of_nat)
(* Only 0 satisfies a ≤ a - 1, since decrementing it wraps around. *)
lemma word_eq_zeroI:
  "a ≤ a - 1 ⟹ a = 0" for a :: "'a :: len word"
  by (simp add: word_must_wrap)
(* Rearranges a sum with a leading -1 summand. *)
lemma word_add_format:
  "(-1 :: 'a :: len  word) + b + c = b + (c - 1)"
  by simp
(* Indexing into [i .e. j]: element k is i + of_nat k, for k ≤ unat (j - i). *)
lemma upto_enum_word_nth:
  assumes "i ≤ j" and "k ≤ unat (j - i)"
  shows "[i .e. j] ! k = i + of_nat k"
proof -
  have "unat i + unat (j-i) < 2 ^ LENGTH('a)"
    by (metis add.commute ‹i ≤ j› eq_diff_eq no_olen_add_nat)
  then have "toEnum (unat i + k) = i + word_of_nat k"
    using assms by auto
  moreover have "[j] ! (k + unat i - unat j) = i + word_of_nat k"
    if "¬ k < unat j - unat i"  "unat i ≤ unat j"
    using that assms unat_sub by fastforce
  moreover have "[] ! k = i + word_of_nat k" if "¬ unat i ≤ unat j"
    using that ‹i ≤ j› word_less_eq_iff_unsigned by blast
  ultimately show ?thesis
    by (auto simp: upto_enum_def nth_append)
qed
(* Indexing into the stepped enumeration [a, b .e. c]: element n is a plus
   of_nat n scaled by the step b - a. *)
lemma upto_enum_step_nth:
  "⟦ a ≤ c; n ≤ unat ((c - a) div (b - a)) ⟧
   ⟹ [a, b .e. c] ! n = a + of_nat n * (b - a)"
  by (clarsimp simp: upto_enum_step_def not_less[symmetric] upto_enum_word_nth)
(* Extending the upper bound of [0 .e. a] by one appends 1 + a, provided a is
   below the maximal word -1. *)
lemma upto_enum_inc_1_len:
  fixes a :: "'a::len word"
  assumes "a < - 1"
  shows "[(0 :: 'a :: len word) .e. 1 + a] = [0 .e. a] @ [1 + a]"
proof -
  have "unat (1+a) = 1 + unat a"
    by (simp add: add_eq_0_iff assms unatSuc word_order.not_eq_extremum)
  with assms show ?thesis
    by (simp add: upto_enum_word)
qed
lemma neg_mask_add:
  "y AND mask n = 0 ⟹ x + y AND NOT(mask n) = (x AND NOT(mask n)) + y"
  for x y :: ‹'a::len word›
  by (clarsimp simp: mask_out_sub_mask mask_eqs(7)[symmetric] mask_twice)
(* Shifting right by a, back left by a, and right by a again equals a single right
   shift by a. The intermediate left shift only zeroes the low a bits, and the final
   right shift discards those anyway. Proved bitwise. *)
lemma shiftr_shiftl_shiftr[simp]:
  "(x :: 'a :: len word) >> a << a >> a = x >> a"
  by (word_eqI_solve dest: bit_imp_le_length)
(* Right shift by n distributes over addition when both summands have their low n bits
   clear and the sum does not overflow (x ≤ x + y).
   Proof idea: both summands are n-aligned, so dividing their unsigned values by 2^n
   is exact and distributes over the sum. *)
lemma add_right_shift:
  fixes x y :: ‹'a::len word›
  assumes "x AND mask n = 0" and "y AND mask n = 0" and "x ≤ x + y"
  shows "(x + y) >> n = (x >> n) + (y >> n)"
proof -
  obtain §: "is_aligned x n" "is_aligned y n" "unat x + unat y < 2 ^ LENGTH('a)"
    using assms is_aligned_mask no_olen_add_nat by blast
  (* The shifted summands cannot overflow either. *)
  then have "unat x div 2 ^ n + unat y div 2 ^ n < 2 ^ LENGTH('a)"
    by (metis add_le_mono add_lessD1 div_le_dividend le_Suc_ex)
  with § show ?thesis
    by (metis (no_types, lifting) div_add is_aligned_iff_dvd_nat shiftr_div_2n' unat_plus_if' word_unat_eq_iff)
qed
(* Subtraction analogue of add_right_shift: for x and y with clear low n bits and
   y ≤ x, right shift by n distributes over x - y. Derived from add_right_shift
   by writing x as (x - y) + y. *)
lemma sub_right_shift:
  "⟦ x AND mask n = 0; y AND mask n = 0; y ≤ x ⟧
   ⟹ (x - y) >> n = (x >> n :: 'a :: len word) - (y >> n)"
  by (smt (verit) add_diff_cancel_left' add_right_shift diff_0_right diff_add_cancel mask_eqs(4))
(* If y has all of its low n bits set, masking x AND y to n bits gives the same
   result as masking x alone. *)
lemma and_and_mask_simple:
  "y AND mask n = mask n ⟹ (x AND y) AND mask n = x AND mask n"
  by (simp add: ac_simps)
(* If y has all of its low n bits clear, x AND y has no low n bits either. *)
lemma and_and_mask_simple_not:
  "y AND mask n = 0 ⟹ (x AND y) AND mask n = 0"
  by (simp add: ac_simps)
(* a AND b never exceeds b, so an upper bound on b carries over to a AND b. *)
lemma word_and_le':
  "b ≤ c ⟹ (a :: 'a :: len word) AND b ≤ c"
  by (metis word_and_le1 order_trans)
(* Strict version of word_and_le'. *)
lemma word_and_less':
  "b < c ⟹ (a :: 'a :: len word) AND b < c"
  by transfer simp
(* 2^x, for x below the word length, is the top bit 2^(LENGTH - 1) shifted right
   by LENGTH - 1 - x. *)
lemma shiftr_w2p:
  "x < LENGTH('a) ⟹ 2 ^ x = (2 ^ (LENGTH('a) - 1) >> (LENGTH('a) - 1 - x) :: 'a :: len word)"
  by word_eqI_solve
(* Shifting 2^a right by b ≤ a gives 2^(a - b), for a below the word length. *)
lemma t2p_shiftr:
  "⟦ b ≤ a; a < LENGTH('a) ⟧ ⟹ (2 :: 'a :: len word) ^ a >> b = 2 ^ (a - b)"
  by word_eqI_solve
(* scast of the signed word 1 to the unsigned word of the same length is 1. *)
lemma scast_1[simp]:
  "scast (1 :: 'a :: len signed word) = (1 :: 'a word)"
  by simp
(* Any unsigned cast of the all-ones word -1 in 'b is mask LENGTH('b);
   restated from unsigned_minus_1_eq_mask. *)
lemma unsigned_uminus1 [simp]:
  ‹(unsigned (-1::'b::len word)::'c::len word) = mask LENGTH('b)›
  by (fact unsigned_minus_1_eq_mask)
(* If x already fits in LENGTH('b) bits, x is recovered from its cast y to 'b
   by casting back. *)
lemma ucast_ucast_mask_eq:
  "⟦ UCAST('a::len → 'b::len) x = y; x AND mask LENGTH('b) = x ⟧ ⟹ x = ucast y"
  by (drule sym) (simp flip: take_bit_eq_mask add: unsigned_ucast_eq)
(* Equal casts into a type 'b at least as long as 'a imply equal casts into 'a. *)
lemma ucast_up_eq:
  "⟦ ucast x = (ucast y::'b::len word); LENGTH('a) ≤ LENGTH ('b) ⟧
   ⟹ ucast x = (ucast y::'a::len word)"
  by (simp add: word_eq_iff bit_simps)
(* Inequality form with the roles of the two target types swapped: different casts into
   'b imply different casts into any 'a at least as long as 'b. *)
lemma ucast_up_neq:
  "⟦ ucast x ≠ (ucast y::'b::len word); LENGTH('b) ≤ LENGTH ('a) ⟧
   ⟹ ucast x ≠ (ucast y::'a::len word)"
  by (fastforce dest: ucast_up_eq)
(* If the low n bits of x are zero, then so are its low m bits for any m ≤ n. *)
lemma mask_AND_less_0:
  "⟦ x AND mask n = 0; m ≤ n ⟧ ⟹ x AND mask m = 0"
  for x :: ‹'a::len word›
  by (metis mask_twice2 word_and_notzeroD)
(* Masking to the full word length is the identity. *)
lemma mask_len_id:
  "(x :: 'a :: len word) AND mask LENGTH('a) = x"
  by simp
(* For casts into a type no longer than the source, the signed and unsigned casts
   are the same function. *)
lemma scast_ucast_down_same:
  "LENGTH('b) ≤ LENGTH('a) ⟹ SCAST('a → 'b) = UCAST('a::len → 'b::len)"
  by (simp add: down_cast_same is_down)
(* If an n-aligned a and a b ≤ mask n sum to zero, both are zero. Their bits are
   disjoint, so the sum is a bitwise OR, and an OR is zero only if both operands are. *)
lemma word_aligned_0_sum:
  "⟦ a + b = 0; is_aligned (a :: 'a :: len word) n; b ≤ mask n; n < LENGTH('a) ⟧
   ⟹ a = 0 ∧ b = 0"
  by (simp add: word_plus_and_or_coroll aligned_mask_disjoint word_or_zero)
(* A word left unchanged by AND 1 is 0 or 1. NOTE(review): the LENGTH('a) > 1 premise
   does not appear to be needed by the word_and_1 argument. *)
lemma mask_eq1_nochoice:
  "⟦ LENGTH('a) > 1; (x :: 'a :: len word) AND 1 = x ⟧ ⟹ x = 0 ∨ x = 1"
  by (metis word_and_1)
(* Moving a mask across a shift: if (w >> n) AND x = y, then masking w with x << n
   gives y << n. *)
lemma shiftr_and_eq_shiftl:
  "(w >> n) AND x = y ⟹ w AND (x << n) = (y << n)" for y :: "'a:: len word"
  by (smt (verit, best) and_not_mask mask_eq_0_eq_x shiftl_mask_is_0 shiftl_over_and_dist word_bw_lcs(1))
(* Variant of add_mask_lower_bits with the word length named len.
   Suppose x is n-aligned and p has no set bit at any position n' with n ≤ n' < len.
   Then adding p to x and clearing the low n bits gives back x. *)
lemma add_mask_lower_bits':
  "⟦ len = LENGTH('a); is_aligned (x :: 'a :: len word) n;
     ∀n' ≥ n. n' < len ⟶ ¬ bit p n' ⟧
   ⟹ x + p AND NOT(mask n) = x"
  using add_mask_lower_bits by auto
(* A value bounded by mask (low_bits + high_bits), shifted right by low_bits,
   is bounded by mask high_bits. *)
lemma leq_mask_shift:
  "(x :: 'a :: len word) ≤ mask (low_bits + high_bits) ⟹ (x >> low_bits) ≤ mask high_bits"
  by (simp add: le_mask_iff shiftr_shiftr ac_simps)
(* Consequence with high_bits = LENGTH('b): x >> low_bits fits in 'b, so casting it
   to 'b and back leaves it unchanged. *)
lemma ucast_ucast_eq_mask_shift:
  "(x :: 'a :: len word) ≤ mask (low_bits + LENGTH('b))
   ⟹ ucast((ucast (x >> low_bits)) :: 'b :: len word) = x >> low_bits"
  by (simp add: and_mask_eq_iff_le_mask leq_mask_shift ucast_ucast_mask)
(* Transfers the word inequality of_nat b ≤ a to naturals, provided b is representable
   in 'a word. *)
lemma const_le_unat:
  "⟦ b < 2 ^ LENGTH('a); of_nat b ≤ a ⟧ ⟹ b ≤ unat (a :: 'a :: len word)"
  by (simp add: word_le_nat_alt unsigned_of_nat take_bit_nat_eq_self)
(* The n-th element of [0 .e. x] is of_nat n; this is upto_enum_word_nth with i = 0.
   The first premise is stated in the word type, where 2 ^ LENGTH('a) - 1 is the
   all-ones word. *)
lemma upt_enum_offset_trivial:
  "⟦ x < 2 ^ LENGTH('a) - 1 ; n ≤ unat x ⟧
   ⟹ ([(0 :: 'a :: len word) .e. x] ! n) = of_nat n"
  by (induct x arbitrary: n) (auto simp: upto_enum_word_nth)
(* x is at most the value obtained by clearing its low sz bits and adding 2^sz - 1,
   i.e. setting all of its low sz bits. *)
lemma word_le_mask_out_plus_2sz:
  "x ≤ (x AND NOT(mask sz)) + 2 ^ sz - 1"
  for x :: ‹'a::len word›
  by (metis add_diff_eq word_neg_and_le)
(* ucast into the signed word type of the same length commutes with addition. *)
lemma ucast_add:
  "ucast (a + (b :: 'a :: len word)) = ucast a + (ucast b :: ('a signed word))"
  by transfer (simp add: take_bit_add)
(* Same for subtraction, derived from ucast_add. *)
lemma ucast_minus:
  "ucast (a - (b :: 'a :: len word)) = ucast a - (ucast b :: ('a signed word))"
  by (metis (no_types, opaque_lifting) add_diff_cancel_left' add_diff_eq ucast_add)
(* Adding one in the signed view and casting back with scast equals adding one directly. *)
lemma scast_ucast_add_one [simp]:
  "scast (ucast (x :: 'a::len word) + (1 :: 'a signed word)) = x + 1"
  by (metis scast_ucast_id ucast_1 ucast_add)
(* For a > 0, masking any x with a - 1 gives a value strictly below a. *)
lemma word_and_le_plus_one:
  "a > 0 ⟹ (x :: 'a :: len word) AND (a - 1) < a"
  by (simp add: gt0_iff_gem1 word_and_less')
(* For an up-cast (LENGTH('a) ≤ LENGTH('b)), shifting right after the cast has the same
   unsigned value as shifting before it. *)
lemma unat_of_ucast_then_shift_eq_unat_of_shift[simp]:
  "LENGTH('b) ≥ LENGTH('a)
   ⟹ unat ((ucast (x :: 'a :: len word) :: 'b :: len word) >> n) = unat (x >> n)"
  by (simp add: shiftr_div_2n' unat_ucast_up_simp)
(* Same for masking with mask m. *)
lemma unat_of_ucast_then_mask_eq_unat_of_mask[simp]:
  "LENGTH('b) ≥ LENGTH('a)
   ⟹ unat ((ucast (x :: 'a :: len word) :: 'b :: len word) AND mask m) = unat (x AND mask m)"
  by (metis ucast_and_mask unat_ucast_up_simp)
(* If 2^(n + m) overflows to 0 in 'a and m < LENGTH('a), then every x shifted right
   by n is below 2^m. *)
lemma shiftr_less_t2n3:
  "⟦ (2 :: 'a word) ^ (n + m) = 0; m < LENGTH('a) ⟧
   ⟹ (x :: 'a :: len word) >> n < 2 ^ m"
  by (fastforce intro: shiftr_less_t2n' simp: mask_eq_decr_exp power_overflow)
(* For n > 0, unat (x >> n) is at most any bnd ≥ 2^(LENGTH('a) - n) - 1. *)
lemma unat_shiftr_le_bound:
  "⟦ 2 ^ (LENGTH('a :: len) - n) - 1 ≤ bnd; 0 < n ⟧
   ⟹ unat ((x :: 'a word) >> n) ≤ bnd"
  by (metis add.commute le_diff_conv less_Suc_eq_le order_less_le_trans plus_1_eq_Suc word_shiftr_lt)
(* Right shift by n is injective on n-aligned words. *)
lemma shiftr_eqD:
  "⟦ x >> n = y >> n; is_aligned x n; is_aligned y n ⟧
   ⟹ x = y"
  by (metis is_aligned_shiftr_shiftl)
(* Generalisation of shiftr_shiftl_shiftr: after a right shift by a, a left/right shift
   round trip by any b ≤ a changes nothing. *)
lemma word_shiftr_shiftl_shiftr_eq_shiftr:
  "a ≥ b ⟹ (x :: 'a :: len word) >> a << b >> b = x >> a"
  by (simp add: mask_shift shiftr_shiftl1 shiftr_shiftr)
(* of_int of the unsigned integer value is ucast; restated from Word.of_int_uint. *)
lemma of_int_uint_ucast:
   "of_int (uint (x :: 'a::len word)) = (ucast x :: 'b::len word)"
  by (fact Word.of_int_uint)
(* If msk has all of its low n bits set, AND with msk leaves x mod 2^n unchanged,
   since x mod 2^n only has bits below n. *)
lemma mod_mask_drop:
  "⟦ m = 2 ^ n; 0 < m; mask n AND msk = mask n ⟧
   ⟹ (x mod m) AND msk = x mod m"
  for x :: ‹'a::len word›
  by (simp add: word_mod_2p_is_mask word_bw_assocs)
(* If x fits in LENGTH('a) bits and 'b is at least as long as 'a, casting x to 'b via 'a
   is the same as casting it directly to 'b. *)
lemma mask_eq_ucast_eq:
  "⟦ x AND mask LENGTH('a) = (x :: ('c :: len word));
     LENGTH('a) ≤ LENGTH('b)⟧
    ⟹ ucast (ucast x :: ('a :: len word)) = (ucast x :: ('b :: len word))"
  by (metis ucast_and_mask ucast_id ucast_ucast_mask ucast_up_eq)
(* From of_nat i < 2^n in the word, two things follow: 2^n does not overflow
   (n < LENGTH('a)), and the unsigned value of of_nat i is below 2^n. *)
lemma of_nat_less_t2n:
  "of_nat i < (2 :: ('a :: len) word) ^ n ⟹ n < LENGTH('a) ∧ unat (of_nat i :: 'a word) < 2 ^ n"
  by (metis order_less_trans p2_gt_0 unat_less_power word_neq_0_conv)
(* 2^n - 1 is monotone in n for exponents up to the word length. *)
lemma two_power_increasing_less_1:
  "⟦ n ≤ m; m ≤ LENGTH('a) ⟧ ⟹ (2 :: 'a :: len word) ^ n - 1 ≤ 2 ^ m - 1"
  by (meson le_m1_iff_lt order_le_less_trans p2_gt_0 two_power_increasing word_1_le_power word_le_minus_mono_left)
(* Cancels a common summand x from both sides of ≤ when neither y + x nor z + x
   overflows. *)
lemma word_sub_mono4:
  "⟦ y + x ≤ z + x; y ≤ y + x; z ≤ z + x ⟧ ⟹ y ≤ z" for y :: "'a :: len word"
  by (simp add: word_add_le_iff2)
(* n ≤ unat (2^m - 1), given as a disjunction of = and <, implies n < 2^m,
   provided m < LENGTH('a). *)
lemma eq_or_less_helperD:
  "⟦ n = unat (2 ^ m - 1 :: 'a :: len word) ∨ n < unat (2 ^ m - 1 :: 'a word); m < LENGTH('a) ⟧
   ⟹ n < 2 ^ m"
  by (meson le_less_trans nat_less_le unat_less_power word_power_less_1)
(* For n ≤ m, mask m - mask n is mask m with its low n bits cleared. *)
lemma mask_sub:
  "n ≤ m ⟹ mask m - mask n = mask m AND NOT(mask n :: 'a::len word)"
  by (metis (full_types) and_mask_eq_iff_shiftr_0 mask_out_sub_mask shiftr_mask_le word_bw_comms(1))
(* For sz' ≤ sz: clearing the low sz' bits of ptr and clearing its low sz bits give
   values whose difference is at most 2^sz - 2^sz'.
   The difference is exactly the bits of ptr between sz' and sz. *)
lemma neg_mask_diff_bound:
  "sz'≤ sz ⟹ (ptr AND NOT(mask sz')) - (ptr AND NOT(mask sz)) ≤ 2 ^ sz - 2 ^ sz'"
  (is "_ ⟹ ?lhs ≤ ?rhs")
  for ptr :: ‹'a::len word›
proof -
  assume lt: "sz' ≤ sz"
  hence "?lhs = ptr AND (mask sz AND NOT(mask sz'))"
    by (metis add_diff_cancel_left' multiple_mask_trivia)
  (* The selected bits are bounded by the mask mask sz AND NOT (mask sz'),
     which equals 2^sz - 2^sz'. *)
  also have "… ≤ ?rhs" using lt
    by (metis (mono_tags) add_diff_eq diff_eq_eq eq_iff mask_2pm1 mask_sub word_and_le')
  finally show ?thesis by simp
qed
(* A natural idx below 2^sz, as a word with sz < LENGTH('a), has no bits at or above sz. *)
lemma mask_out_eq_0:
  "⟦ idx < 2 ^ sz; sz < LENGTH('a) ⟧ ⟹ (of_nat idx :: 'a :: len word) AND NOT(mask sz) = 0"
  by (simp add: of_nat_power less_mask_eq mask_eq_0_eq_x)
(* ptr is sz-aligned iff clearing its low sz bits leaves it unchanged. *)
lemma is_aligned_neg_mask_eq':
  "is_aligned ptr sz = (ptr AND NOT(mask sz) = ptr)"
  using is_aligned_mask mask_eq_0_eq_x by blast
(* The unsigned value of ptr is the sum of the unsigned values of its high part
   (low sz bits cleared) and its low part (low sz bits only). *)
lemma neg_mask_mask_unat:
  "sz < LENGTH('a)
   ⟹ unat ((ptr :: 'a :: len word) AND NOT(mask sz)) + unat (ptr AND mask sz) = unat ptr"
  by (metis AND_NOT_mask_plus_AND_mask_eq unat_plus_simple word_and_le2)
(* unat x < 2^n whenever n is at least the word length. *)
lemma unat_pow_le_intro:
  "LENGTH('a) ≤ n ⟹ unat (x :: 'a :: len word) < 2 ^ n"
  by (metis lt2p_lem not_le of_nat_le_iff of_nat_numeral semiring_1_class.of_nat_power uint_nat)
(* If unat x < 2^(m - n) and m < LENGTH('a), then unat (x << n) < 2^m. *)
lemma unat_shiftl_less_t2n:
  ‹unat (x << n) < 2 ^ m›
  if ‹unat (x :: 'a :: len word) < 2 ^ (m - n)› ‹m < LENGTH('a)›
proof (cases ‹n ≤ m›)
  (* n > m: then m - n = 0, so the premise forces x = 0. *)
  case False
  with that show ?thesis
    by (simp add: unsigned_eq_0_iff)
next
  (* n ≤ m: decompose the lengths as m = n + q and LENGTH('a) = r + q + n. *)
  case True
  moreover define q r where ‹q = m - n› and ‹r = LENGTH('a) - n - q›
  ultimately have ‹m - n = q› ‹m = n + q› ‹LENGTH('a) = r + q + n›
    using that by simp_all
  with that show ?thesis
    by (metis le_add2 order_le_less_trans shiftl_less_t2n unat_power_lower word_less_iff_unsigned)
qed
(* Split p + d, for n-aligned p and unat d < 2^n: the low n bits have the unsigned value
   of d, and the high part has the unsigned value of p. *)
lemma unat_is_aligned_add:
  "⟦ is_aligned p n; unat d < 2 ^ n ⟧
   ⟹ unat (p + d AND mask n) = unat d ∧ unat (p + d AND NOT(mask n)) = unat p"
  by (metis add_diff_cancel_left' add_mask_lower_bits is_aligned_add_helper order_le_less_trans
       subtract_mask(2) unat_power_lower word_less_nat_alt)
(* If c + a ≥ LENGTH('a) + b and c < LENGTH('a), then q >> a << b has no bits at or
   above c. Uses unat_is_aligned_add with p = 0. *)
lemma unat_shiftr_shiftl_mask_zero:
  "⟦ c + a ≥ LENGTH('a) + b ; c < LENGTH('a) ⟧
   ⟹ unat (((q :: 'a :: len word) >> a << b) AND NOT(mask c)) = 0"
  by (fastforce intro: unat_is_aligned_add[where p=0 and n=c, simplified, THEN conjunct2]
                       unat_shiftl_less_t2n unat_shiftr_less_t2n unat_pow_le_intro)
(* Symmetric orientation of ucast_of_nat. *)
lemmas of_nat_ucast = ucast_of_nat[symmetric]
(* For x ≤ mask (low_bits + high_bits), the value x >> low_bits already fits in high_bits
   bits (see leq_mask_shift), so masking it with mask high_bits is a no-op. *)
lemma shift_then_mask_eq_shift_low_bits:
  "x ≤ mask (low_bits + high_bits) ⟹ (x >> low_bits) AND mask high_bits = x >> low_bits"
  for x :: ‹'a::len word›
  by (simp add: leq_mask_shift le_mask_imp_and_mask)
(* If x >> low_bits = 0, then x already fits in low_bits bits, so x AND mask low_bits = x
   and the masked value is zero exactly when x is zero (via and_mask_eq_iff_shiftr_0).
   The bound in the first premise is written with the parameters low_bits and high_bits.
   Only the second premise is needed for the argument. *)
lemma leq_low_bits_iff_zero:
  "⟦ x ≤ mask (low_bits + high_bits); x >> low_bits = 0 ⟧ ⟹ (x AND mask low_bits = 0) = (x = 0)"
  for x :: ‹'a::len word›
  using and_mask_eq_iff_shiftr_0 by force
(* Comparing a word a with of_nat c reduces to comparing b = unat a with c,
   provided c is representable in 'a word. *)
lemma unat_less_iff:
  "⟦ unat (a :: 'a :: len word) = b; c < 2 ^ LENGTH('a) ⟧ ⟹ (a < of_nat c) = (b < c)"
  using unat_ucast_less_no_overflow_simp by blast
(* For n-aligned a and offsets b < c ≤ 2^n, a + b ≤ a + (c - 1): adding offsets that stay
   within 2^n of an aligned a does not wrap around. *)
lemma is_aligned_no_overflow3:
 "⟦ is_aligned (a :: 'a :: len word) n; n < LENGTH('a); b < 2 ^ n; c ≤ 2 ^ n; b < c ⟧
  ⟹ a + b ≤ a + (c - 1)"
  by (meson is_aligned_no_wrap' le_m1_iff_lt not_le word_less_sub_1 word_plus_mono_right)
(* Adding an n-aligned p on the right does not change the low n bits;
   mask_add_aligned with the summands swapped. *)
lemma mask_add_aligned_right:
  "is_aligned p n ⟹ (q + p) AND mask n = q AND mask n"
  by (simp add: mask_add_aligned add.commute)
(* A value bounded by mask high_bits, shifted LEFT by low_bits, is bounded by
   mask (low_bits + high_bits). Despite shiftr in the name, the statement uses <<. *)
lemma leq_high_bits_shiftr_low_bits_leq_bits_mask:
  "x ≤ mask high_bits ⟹ (x :: 'a :: len word) << low_bits ≤ mask (low_bits + high_bits)"
  by (metis le_mask_shiftl_le_mask)
(* If 2^m does not overflow in 'a, every power of two 2^n is at most -(2^m). *)
lemma word_two_power_neg_ineq:
  assumes "2 ^ m ≠ (0::'a word)"
  shows "2 ^ n ≤ - (2 ^ m :: 'a :: len word)"
proof (cases "n < LENGTH('a) ∧ m < LENGTH('a)")
  (* Both exponents are in range: argue on the bits of -(2^m). *)
  case True
  with assms show ?thesis
  by (metis bit_minus_exp_iff linorder_not_le nat_less_le nth_bounded possible_bit_word)
next
  (* Otherwise one power overflows to 0. By the assumption it is not 2^m,
     so 2^n = 0 and the inequality holds trivially. *)
  case False
  with assms show ?thesis
    by (force simp: power_overflow)
qed
(* If x ≤ 2^p and p + k < LENGTH('a), multiplying by 2^k does not overflow,
   so unat commutes with the product. *)
lemma unat_shiftl_absorb:
  fixes x :: "'a :: len word"
  shows "⟦ x ≤ 2 ^ p; p + k < LENGTH('a) ⟧ ⟹ unat x * 2 ^ k = unat (x * 2 ^ k)"
  by (smt (verit) add_diff_cancel_right' add_lessD1 le_add2 le_less_trans mult.commute nat_le_power_trans
          unat_lt2p unat_mult_lem unat_power_lower word_le_nat_alt)
(* x ≤ x + z (no wrap-around) holds if the low sz bits of x plus unat z stay below 2^sz.
   Proof: split x into its sz-aligned high part and its low part. Adding z to the low part
   does not wrap, and adding the aligned high part afterwards does not wrap either. *)
lemma word_plus_mono_right_split:
  fixes x :: "'a :: len word"
  assumes "unat (x AND mask sz) + unat z < 2 ^ sz" and "sz < LENGTH('a)"
  shows "x ≤ x + z"
proof -
  have *: "is_aligned (x AND NOT (mask sz)) sz" "word_of_nat (unat z) = z"
          "word_of_nat (unat (x AND mask sz)) = x AND mask sz"
    by auto
  (* Monotonicity for the low part alone. *)
  with assms have "x AND mask sz ≤ (x AND mask sz) + z"
    by (metis (mono_tags, lifting) le_unat_uoi of_nat_add order_less_imp_le unat_plus_simple unat_power_lower)
  (* Add the aligned high part to both sides. *)
  then have "(x AND NOT(mask sz)) + (x AND mask sz) ≤ (x AND NOT(mask sz)) + ((x AND mask sz) + z)"
    by (metis (no_types, lifting) of_nat_power assms * is_aligned_no_wrap' of_nat_add word_plus_mono_right)
  then show ?thesis
    by (simp add: and_not_eq_minus_and)
qed
(* NOT (mask n) is the all-ones word shifted left by n. *)
lemma mul_not_mask_eq_neg_shiftl:
  "NOT(mask n :: 'a::len word) = -1 << n"
  by (simp add: NOT_mask shiftl_t2n)
(* Multiplying x >> n by NOT (mask n) gives the negation of x with its low n bits
   cleared. *)
lemma shiftr_mul_not_mask_eq_and_not_mask:
  "(x >> n) * NOT(mask n) = - (x AND NOT(mask n))"
  for x :: ‹'a::len word›
  by (metis NOT_mask and_not_mask mult_minus_left mult.commute shiftl_t2n)
(* For n up to the word length, mask n is -1 shifted right by LENGTH('a) - n. *)
lemma mask_eq_n1_shiftr:
  "n ≤ LENGTH('a) ⟹ (mask n :: 'a :: len word) = -1 >> (LENGTH('a) - n)"
  by (metis diff_diff_cancel eq_refl mask_full shiftr_mask2)
(* Clearing the low n bits distributes over adding an n-aligned p. *)
lemma is_aligned_mask_out_add_eq:
  "is_aligned p n ⟹ (p + x) AND NOT(mask n) = p + (x AND NOT(mask n))"
  by (simp add: mask_out_add_aligned)
(* Instance of is_aligned_mask_out_add_eq for x = a - b, normalised with field_simps. *)
lemmas is_aligned_mask_out_add_eq_sub
    = is_aligned_mask_out_add_eq[where x="a - b" for a b, simplified field_simps]
(* For n-aligned x, clearing the low n bits of x - 1 gives x - 2^n. *)
lemma aligned_bump_down:
  "is_aligned x n ⟹ (x - 1) AND NOT(mask n) = x - 2 ^ n"
  by (drule is_aligned_mask_out_add_eq[where x="-1"]) (simp add: NOT_mask)
(* The unsigned value of 2^n: 2^n if it fits in the word, otherwise 0. *)
lemma unat_2tp_if:
  "unat (2 ^ n :: ('a :: len) word) = (if n < LENGTH ('a) then 2 ^ n else 0)"
  by (simp add: unsigned_eq_0_iff)
(* The AND of two masks is the mask of the smaller length. *)
lemma mask_of_mask:
  "mask (n::nat) AND mask (m::nat) = (mask (min m n) :: 'a::len word)"
  by word_eqI_solve
(* Casting into a signed word type that is at least as long preserves the unsigned value. *)
lemma unat_signed_ucast_less_ucast:
  "LENGTH('a) ≤ LENGTH('b) ⟹ unat (ucast (x :: 'a :: len word) :: 'b :: len signed word) = unat x"
  by (simp add: unat_ucast_up_simp)
(* When 'b is no longer than 'a, unat b is in range and toEnum agrees with of_nat. *)
lemma toEnum_of_ucast:
  "LENGTH('b) ≤ LENGTH('a) ⟹
   (toEnum (unat (b::'b :: len word))::'a :: len word) = of_nat (unat b)"
  by (simp add: unat_pow_le_intro)
(* If x has no low n bits, adding mask n and then clearing the low n bits returns x. *)
lemma plus_mask_AND_NOT_mask_eq:
  "x AND NOT(mask n) = x ⟹ (x + mask n) AND NOT(mask n) = x" for x::‹'a::len word›
  by (metis AND_NOT_mask_plus_AND_mask_eq is_aligned_neg_mask2 mask_AND_NOT_mask mask_out_add_aligned word_and_not)
(* Alias of unat_ucast_eq_unat_and_mask with the word variable renamed to a. *)
lemmas unat_ucast_mask = unat_ucast_eq_unat_and_mask[where w=a for a]
(* 2^n masked to m bits: 2^n if n < m, otherwise 0. *)
lemma t2n_mask_eq_if:
  "2 ^ n AND mask m = (if n < m then 2 ^ n else (0 :: 'a::len word))"
  by word_eqI_solve
(* Casting never increases the unsigned value. *)
lemma unat_ucast_le:
  "unat (ucast (x :: 'a :: len word) :: 'b :: len word) ≤ unat x"
  by (simp add: ucast_nat_def word_unat_less_le)
(* Let 'b be at least as long as 'a, and let x in 'b be bounded by the cast of the 'a
   all-ones word, so x fits in 'a. Then ucast x ≤ y in 'a iff x ≤ ucast y in 'b. *)
lemma ucast_le_up_down_iff:
  "⟦ LENGTH('a) ≤ LENGTH('b); (x :: 'b :: len word) ≤ ucast (- 1 :: 'a :: len word) ⟧
   ⟹ (ucast x ≤ (y :: 'a word)) = (x ≤ ucast y)"
  using le_max_word_ucast_id ucast_le_ucast by metis
(* (p AND mask a) >> b fits in 'a when a ≤ LENGTH('a) + b, so a round trip through 'a
   is the identity. *)
lemma ucast_ucast_mask_shift:
  "a ≤ LENGTH('a) + b
   ⟹ ucast (ucast (p AND mask a >> b) :: 'a :: len word) = p AND mask a >> b"
  by (simp add: mask_mono ucast_ucast_eq_mask_shift word_and_le')
(* The same fact observed through unat. *)
lemma unat_ucast_mask_shift:
  "a ≤ LENGTH('a) + b
   ⟹ unat (ucast (p AND mask a >> b) :: 'a :: len word) = unat (p AND mask a >> b)"
  by (metis linear ucast_ucast_mask_shift unat_ucast_up_simp)
(* The bits of p below a have nothing at or above any b ≥ a. *)
lemma mask_overlap_zero:
  "a ≤ b ⟹ (p AND mask a) AND NOT(mask b) = 0"
  for p :: ‹'a::len word›
  by (metis NOT_mask_AND_mask mask_lower_twice2 max_def)
(* The low a bits of p, shifted left by c, stay below bit b when a + c ≤ b.
   The name spells shiftl as shifl. *)
lemma mask_shifl_overlap_zero:
  "a + c ≤ b ⟹ (p AND mask a << c) AND NOT(mask b) = 0"
  for p :: ‹'a::len word›
  by (metis and_mask_0_iff_le_mask mask_mono mask_shiftl_decompose order_trans shiftl_over_and_dist word_and_le' word_and_le2)
(* The bits of p at or above a have nothing below any b ≤ a. *)
lemma mask_overlap_zero':
  "a ≥ b ⟹ (p AND NOT(mask a)) AND mask b = 0"
  for p :: ‹'a::len word›
  using mask_AND_NOT_mask mask_AND_less_0 by blast
(* Multiplying by 1 << c is a left shift by c. *)
lemma mask_rshift_mult_eq_rshift_lshift:
  "((a :: 'a :: len word) >> b) * (1 << c) = (a >> b << c)"
  by (simp add: shiftl_t2n)
(* p >> a << a is a-aligned, hence b-aligned for every b ≤ a. *)
lemma shift_alignment:
  "a ≥ b ⟹ is_aligned (p >> a << a) b"
  using is_aligned_shift is_aligned_weaken by blast
(* Three-way split of p at bit positions b ≤ a: the part at or above a, the part between
   b and a, and the part below b sum back to p. *)
lemma mask_split_sum_twice:
  "a ≥ b ⟹ (p AND NOT(mask a)) + ((p AND mask a) AND NOT(mask b)) + (p AND mask b) = p"
  for p :: ‹'a::len word›
  by (simp add: add.commute multiple_mask_trivia word_bw_comms(1) word_bw_lcs(1) word_plus_and_or_coroll2)
(* Shifting p AND mask a right then left by b clears its low b bits. *)
lemma mask_shift_eq_mask_mask:
  "(p AND mask a >> b << b) = (p AND mask a) AND NOT(mask b)"
  for p :: ‹'a::len word›
  by (simp add: and_not_mask)
(* Variant of mask_split_sum_twice: the middle part is written as a shift-multiply,
   and the low part is any n with the same unsigned value as p AND mask b. *)
lemma mask_shift_sum:
  "⟦ a ≥ b; unat n = unat (p AND mask b) ⟧
   ⟹ (p AND NOT(mask a)) + (p AND mask a >> b) * (1 << b) + n = (p :: 'a :: len word)"
  by (metis and_not_mask mask_rshift_mult_eq_rshift_lshift mask_split_sum_twice word_eq_unatI)
(* The composition of two casts that satisfy is_up also satisfies is_up
   (checked on target and source sizes). *)
lemma is_up_compose:
  "⟦ is_up uc; is_up uc' ⟧ ⟹ is_up (uc' ∘ uc)"
  unfolding is_up_def by (simp add: Word.target_size Word.source_size)
(* of_int of the signed integer value is scast; restated from Word.of_int_sint. *)
lemma of_int_sint_scast:
  "of_int (sint (x :: 'a :: len word)) = (scast x :: 'b :: len word)"
  by (fact Word.of_int_sint)
(* scast from 'a word to 'a signed word maps of_nat x to of_nat x. Proved bitwise. *)
lemma scast_of_nat_to_signed [simp]:
  "scast (of_nat x :: 'a :: len word) = (of_nat x :: 'a signed word)"
  by (rule bit_word_eqI) (simp add: bit_simps)
(* scast from the signed to the unsigned type preserves a sum of two of_nat values. *)
lemma scast_of_nat_signed_to_unsigned_add:
  "scast (of_nat x + of_nat y :: 'a :: len signed word) = (of_nat x + of_nat y :: 'a :: len word)"
  by (metis of_nat_add scast_of_nat)
(* The same in the other direction, from unsigned to signed. *)
lemma scast_of_nat_unsigned_to_signed_add:
  "(scast (of_nat x + of_nat y :: 'a :: len word)) = (of_nat x + of_nat y :: 'a :: len signed word)"
  by (metis Abs_fnat_hom_add scast_of_nat_to_signed)
(* For n < LENGTH('a), x AND mask n is one of of_nat 0, ..., of_nat (2^n - 1). *)
lemma and_mask_cases:
  fixes x :: "'a :: len word"
  assumes len: "n < LENGTH('a)"
  shows "x AND mask n ∈ of_nat ` set [0 ..< 2 ^ n]"
  using and_mask_less' len unat_less_power by (fastforce simp add: image_iff Bex_def)
(* Below 2^(LENGTH('a) - 1) the most significant bit is clear, so sint and uint agree. *)
lemma sint_eq_uint_2pl:
  "⟦ (a :: 'a :: len word) < 2 ^ (LENGTH('a) - 1) ⟧
   ⟹ sint a = uint a"
  by (simp add: not_msb_from_less sint_eq_uint word_2p_lem word_size)
(* If unat x = 2^a and a + b ≤ LENGTH('a), then x * 2^b - 1 has unsigned value
   below 2^(a + b). *)
lemma pow_sub_less:
  "⟦ a + b ≤ LENGTH('a); unat (x :: 'a :: len word) = 2 ^ a ⟧
   ⟹ unat (x * 2 ^ b - 1) < 2 ^ (a + b)"
  by (metis eq_or_less_helperD le_eq_less_or_eq power_add unat_eq_of_nat unat_lt2p word_unat_power)
(* If b has its most significant bit clear, unsigned a ≤ b implies signed a <=s b. *)
lemma sle_le_2pl:
  "⟦ (b :: 'a :: len word) < 2 ^ (LENGTH('a) - 1); a ≤ b ⟧ ⟹ a <=s b"
  by (simp add: not_msb_from_less word_sle_msb_le)
(* Strict version of sle_le_2pl. *)
lemma sless_less_2pl:
  "⟦ (b :: 'a :: len word) < 2 ^ (LENGTH('a) - 1); a < b ⟧ ⟹ a <s b"
  using not_msb_from_less word_sless_msb_less by blast
(* Shifting left then right by n clears the top n bits, which is the same as masking
   to size w - n bits. *)
lemma and_mask2:
  "w << n >> n = w AND mask (size w - n)"
  for w :: ‹'a::len word›
  by (rule bit_word_eqI) (auto simp: bit_simps word_size)
(* The difference of two n-aligned words is n-aligned. *)
lemma aligned_sub_aligned_simple:
  "⟦ is_aligned a n; is_aligned b n ⟧ ⟹ is_aligned (a - b) n"
  by (simp add: aligned_sub_aligned)
(* The negation of 1 << n equals -1 << n. *)
lemma minus_one_shift:
  "- (1 << n) = (-1 << n :: 'a::len word)"
  by (simp add: shiftl_def minus_exp_eq_not_mask)
(* Two words have equal ucasts to 'b iff they agree on the low LENGTH('b) bits. *)
lemma ucast_eq_mask:
  "(UCAST('a::len → 'b::len) x = UCAST('a → 'b) y) =
   (x AND mask LENGTH('b) = y AND mask LENGTH('b))"
  by transfer (simp flip: take_bit_eq_mask add: ac_simps)
(* Lemmas about a fixed word w :: 'a word, mostly relating a round trip through 'b
   to sign extension. Lemmas marked private are only visible inside this context. *)
context
  fixes w :: "'a::len word"
begin
(* Signed truncation to n bits, where Suc n = LENGTH('b), does not distinguish uint w
   from uint of w cast to 'b. *)
private lemma sbintrunc_uint_ucast:
  ‹signed_take_bit n (uint (ucast w :: 'b word)) = signed_take_bit n (uint w)› if ‹Suc n = LENGTH('b::len)›
  by (rule bit_eqI) (use that in ‹auto simp: bit_simps›)
(* Bit i of the signed truncation of uint w to n bits, read back as a word:
   positions above n repeat bit n (sign extension), the others are bit i of w. *)
private lemma test_bit_sbintrunc:
  assumes "i < LENGTH('a)"
  shows "bit (word_of_int (signed_take_bit n (uint w)) :: 'a word) i
           = (if n < i then bit w n else bit w i)"
  using assms by (simp add: bit_simps)
(* The same for truncation at the top bit of 'b after casting w to 'b. *)
private lemma test_bit_sbintrunc_ucast:
  assumes len_a: "i < LENGTH('a)"
  shows "bit (word_of_int (signed_take_bit (LENGTH('b) - 1) (uint (ucast w :: 'b word))) :: 'a word) i
          = (if LENGTH('b::len) ≤ i then bit w (LENGTH('b) - 1) else bit w i)"
  using len_a by (auto simp: sbintrunc_uint_ucast bit_simps)
(* Casting w to 'b and sign-extending back with scast gives w exactly when every bit of w
   from LENGTH('b) upwards equals bit LENGTH('b) - 1. *)
lemma scast_ucast_high_bits:
  ‹scast (ucast w :: 'b::len word) = w
     ⟷ (∀ i ∈ {LENGTH('b) ..< size w}. bit w i = bit w (LENGTH('b) - 1))›
proof (cases ‹LENGTH('a) ≤ LENGTH('b)›)
  (* 'b is at least as long as 'a: the index range {LENGTH('b) ..< size w} is empty. *)
  case True
  moreover define m where ‹m = LENGTH('b) - LENGTH('a)›
  ultimately have ‹LENGTH('b) = m + LENGTH('a)›
    by simp
  then show ?thesis
    by (simp add: signed_ucast_eq word_size) word_eqI
next
  (* 'b is shorter: write LENGTH('b) = Suc q and LENGTH('a) = m + q, then argue bitwise. *)
  case False
  define q where ‹q = LENGTH('b) - 1›
  then have ‹LENGTH('b) = Suc q›
    by simp
  moreover define m where ‹m = Suc LENGTH('a) - LENGTH('b)›
  with False ‹LENGTH('b) = Suc q› have ‹LENGTH('a) = m + q›
    by (simp add: not_le)
  ultimately show ?thesis
    apply (simp add: signed_ucast_eq word_size)
    apply transfer
    apply (simp add: signed_take_bit_take_bit)
    apply (simp add: bit_eq_iff bit_take_bit_iff bit_signed_take_bit_iff min_def)
    by (metis atLeastLessThan_iff linorder_not_le nat_less_le not_less_eq)
qed
(* The criterion of scast_ucast_high_bits stated with masks: w is either at most
   mask (LENGTH('b) - 1) or at least NOT (mask (LENGTH('b) - 1)). *)
lemma scast_ucast_mask_compare:
  "scast (ucast w :: 'b::len word) = w
   ⟷ (w ≤ mask (LENGTH('b) - 1) ∨ NOT(mask (LENGTH('b) - 1)) ≤ w)"
  apply (auto simp: Ball_def le_mask_high_bits neg_mask_le_high_bits scast_ucast_high_bits word_size)
  by (metis decr_length_less_iff nless_le)
(* A 'b word cast up to 'a and shifted left by a is below n, whenever
   LENGTH('b) + a < LENGTH('a) and 2^(LENGTH('b) + a) ≤ n. *)
lemma ucast_less_shiftl_helper':
  "⟦ LENGTH('b) + (a::nat) < LENGTH('a); 2 ^ (LENGTH('b) + a) ≤ n⟧
   ⟹ (ucast (x :: 'b::len word) << a) < (n :: 'a::len word)"
  by (meson add_lessD1 order_less_le_trans shiftl_less_t2n' ucast_less)
end
(* Casting down to 'b and then to 'c equals masking to LENGTH('b) and casting directly
   to 'c. *)
lemma ucast_ucast_mask2:
  "is_down (UCAST ('a → 'b)) ⟹
   UCAST ('b::len → 'c::len) (UCAST ('a::len → 'b::len) x) = UCAST ('a → 'c) (x AND mask LENGTH('b))"
  by word_eqI_solve
(* ucast of a complement: complement the cast, then mask back to the source length. *)
lemma ucast_NOT:
  "ucast (NOT x) = NOT(ucast x) AND mask (LENGTH('a))" for x::"'a::len word"
  by word_eqI_solve
(* For casts satisfying is_down the extra mask is unnecessary: ucast commutes with NOT. *)
lemma ucast_NOT_down:
  "is_down UCAST('a::len → 'b::len) ⟹ UCAST('a → 'b) (NOT x) = NOT(UCAST('a → 'b) x)"
  by word_eqI
(* A stepped enumeration over the n-aligned block starting at p is the enumeration
   [0, 2^m .e. 2^n - 1] translated by p. *)
lemma upto_enum_step_shift:
  assumes "is_aligned p n"
  shows "([p , p + 2 ^ m .e. p + 2 ^ n - 1]) = map ((+) p) [0, 2 ^ m .e. 2 ^ n - 1]"
proof -
  (* Either n fits in the word, or the alignment forces p = 0. *)
  consider "n < LENGTH('a)" | "p = 0" "n ≥ LENGTH('a)"
    by (meson assms is_aligned_get_word_bits)
  then show ?thesis
  proof cases
    case 1
    with assms show ?thesis
      using is_aligned_no_overflow linorder_not_le
      by (force simp: upto_enum_step_def)
  qed (auto simp: map_idI)
qed
(* Explicit form of upto_enum_step_shift: for us ≤ sz < LENGTH('a), the elements are
   p + of_nat x * 2^us for x < 2^(sz - us). *)
lemma upto_enum_step_shift_red:
  "⟦ is_aligned p sz; sz < LENGTH('a); us ≤ sz ⟧
     ⟹ [p :: 'a :: len word, p + 2 ^ us .e. p + 2 ^ sz - 1]
          = map (λx. p + of_nat x * 2 ^ us) [0 ..< 2 ^ (sz - us)]"
  by (simp add: upto_enum_step_red upto_enum_step_shift)
(* Every element of [x, y .e. z] lies in the interval {x .. z}. *)
lemma upto_enum_step_subset:
  "set [x, y .e. z] ⊆ {x .. z}"
proof -
  (* Each element x + w * (y - x), with w bounded by the step count, is between x and z. *)
  have "⋀w. ⟦x ≤ z; w ≤ (z - x) div (y - x)⟧
          ⟹ x ≤ x + w * (y - x) ∧ x + w * (y - x) ≤ z"
    by (metis add.commute div_to_mult_word_lt eq_diff_eq le_plus' word_plus_mono_right2)
  then show ?thesis
    by (auto simp: upto_enum_step_def linorder_not_less)
qed
(* Generic distributivity of a down-cast over a binary operation.
   Assumptions: M on 'a words and M' on 'b words are both computed by the same integer
   function L on uint values, reduced modulo the respective word size, and L respects
   reduction modulo 2^LENGTH('b).
   Conclusion: ucast (M a b) = M' (ucast a) (ucast b). *)
lemma ucast_distrib:
  fixes M :: "'a::len word ⇒ 'a::len word ⇒ 'a::len word"
  fixes M' :: "'b::len word ⇒ 'b::len word ⇒ 'b::len word"
  fixes L :: "int ⇒ int ⇒ int"
  assumes lift_M: "⋀x y. uint (M x y) = L (uint x) (uint y)  mod 2 ^ LENGTH('a)"
  assumes lift_M': "⋀x y. uint (M' x y) = L (uint x) (uint y)  mod 2 ^ LENGTH('b)"
  assumes distrib: "⋀x y. (L (x mod (2 ^ LENGTH('b))) (y mod (2 ^ LENGTH('b)))) mod (2 ^ LENGTH('b))
                               = (L x y) mod (2 ^ LENGTH('b))"
  assumes is_down: "is_down (ucast :: 'a word ⇒ 'b word)"
  shows "ucast (M a b) = M' (ucast a) (ucast b)"
  unfolding ucast_eq lift_M
  by (metis lift_M' local.distrib is_down ucast_down_wi uint_word_of_int word_of_int_uint)
(* Casts satisfying is_down commute with addition. *)
lemma ucast_down_add:
    "is_down (ucast:: 'a word ⇒ 'b word) ⟹  ucast ((a :: 'a::len word) + b) = (ucast a + ucast b :: 'b::len word)"
  by (metis (mono_tags, opaque_lifting) of_int_add ucast_down_wi word_of_int_Ex)
(* ... and with subtraction, derived from ucast_down_add. *)
lemma ucast_down_minus:
    "is_down (ucast:: 'a word ⇒ 'b word) ⟹  ucast ((a :: 'a::len word) - b) = (ucast a - ucast b :: 'b::len word)"
  by (metis add_diff_cancel_right' diff_add_cancel ucast_down_add)
(* ... and with multiplication, as an instance of ucast_distrib. *)
lemma ucast_down_mult:
    "is_down (ucast:: 'a word ⇒ 'b word) ⟹  ucast ((a :: 'a::len word) * b) = (ucast a * ucast b :: 'b::len word)"
  by (simp add: mod_mult_eq take_bit_eq_mod ucast_distrib uint_word_arith_bintrs(3))
(* Signed counterpart of ucast_distrib, under the same assumptions on M, M' and L.
   For casts satisfying is_down, scast coincides with ucast (down_cast_same),
   so the result transfers from ucast_distrib. *)
lemma scast_distrib:
  fixes M :: "'a::len word ⇒ 'a::len word ⇒ 'a::len word"
  fixes M' :: "'b::len word ⇒ 'b::len word ⇒ 'b::len word"
  fixes L :: "int ⇒ int ⇒ int"
  assumes lift_M: "⋀x y. uint (M x y) = L (uint x) (uint y)  mod 2 ^ LENGTH('a)"
  assumes lift_M': "⋀x y. uint (M' x y) = L (uint x) (uint y)  mod 2 ^ LENGTH('b)"
  assumes distrib: "⋀x y. (L (x mod (2 ^ LENGTH('b))) (y mod (2 ^ LENGTH('b)))) mod (2 ^ LENGTH('b))
                               = (L x y) mod (2 ^ LENGTH('b))"
  assumes is_down: "is_down (scast :: 'a word ⇒ 'b word)"
  shows "scast (M a b) = M' (scast a) (scast b)"
proof -
  have §: "is_down UCAST('a → 'b)"
    using is_up_down is_down by blast
  then have "UCAST('a → 'b) (M a b) = M' (UCAST('a → 'b) a) (UCAST('a → 'b) b)"
    using lift_M lift_M' local.distrib ucast_distrib by blast
  with § show ?thesis
    using down_cast_same by fastforce
qed
(* Casts scast satisfying is_down commute with +, - and *; each follows from the
   ucast version via down_cast_same. *)
lemma scast_down_add:
    "is_down (scast:: 'a word ⇒ 'b word) ⟹  scast ((a :: 'a::len word) + b) = (scast a + scast b :: 'b::len word)"
  by (metis down_cast_same is_up_down ucast_down_add)
lemma scast_down_minus:
    "is_down (scast:: 'a word ⇒ 'b word) ⟹  scast ((a :: 'a::len word) - b) = (scast a - scast b :: 'b::len word)"
  by (metis down_cast_same is_up_down ucast_down_minus)
lemma scast_down_mult:
    "is_down (scast:: 'a word ⇒ 'b word) ⟹  scast ((a :: 'a::len word) * b) = (scast a * scast b :: 'b::len word)"
  by (metis down_cast_same is_up_down ucast_down_mult)
(* Collapsing two consecutive casts 'a to 'b to 'c into a single cast 'a to 'c, under
   is_up / is_down side conditions. The lemma name lists the outer cast first:
   scast_ucast_* is about scast of ucast. *)
lemma scast_ucast_1:
  "⟦ is_down (ucast :: 'a word ⇒ 'b word); is_down (ucast :: 'b word ⇒ 'c word) ⟧ ⟹
         (scast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  by (metis down_cast_same ucast_eq ucast_down_wi)
lemma scast_ucast_3:
  "⟦ is_down (ucast :: 'a word ⇒ 'c word); is_down (ucast :: 'b word ⇒ 'c word) ⟧ ⟹
         (scast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  by (metis down_cast_same ucast_eq ucast_down_wi)
lemma scast_ucast_4:
  "⟦ is_up (ucast :: 'a word ⇒ 'b word); is_down (ucast :: 'b word ⇒ 'c word) ⟧ ⟹
         (scast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  by (metis down_cast_same ucast_eq ucast_down_wi)
lemma scast_scast_b:
  "⟦ is_up (scast :: 'a word ⇒ 'b word) ⟧ ⟹
     (scast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  by (metis scast_eq sint_up_scast)
lemma ucast_scast_1:
  "⟦ is_down (scast :: 'a word ⇒ 'b word); is_down (ucast :: 'b word ⇒ 'c word) ⟧ ⟹
            (ucast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  by (metis scast_eq ucast_down_wi)
lemma ucast_scast_3:
  "⟦ is_down (scast :: 'a word ⇒ 'c word); is_down (ucast :: 'b word ⇒ 'c word) ⟧ ⟹
     (ucast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  by (metis scast_eq ucast_down_wi)
lemma ucast_scast_4:
  "⟦ is_up (scast :: 'a word ⇒ 'b word); is_down (ucast :: 'b word ⇒ 'c word) ⟧ ⟹
     (ucast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  by (metis down_cast_same scast_eq sint_up_scast)
lemma ucast_ucast_a:
  "⟦ is_down (ucast :: 'b word ⇒ 'c word) ⟧ ⟹
        (ucast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  by (metis down_cast_same ucast_eq ucast_down_wi)
lemma ucast_ucast_b:
  "⟦ is_up (ucast :: 'a word ⇒ 'b word) ⟧ ⟹
     (ucast (ucast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = ucast a"
  by (metis ucast_up_ucast)
lemma scast_scast_a:
  "⟦ is_down (scast :: 'b word ⇒ 'c word) ⟧ ⟹
            (scast (scast (a :: 'a::len word) :: 'b::len word) :: 'c::len word) = scast a"
  by (metis down_cast_same is_down scast_eq ucast_down_wi)
(* A cast uc = scast satisfying is_down maps word_of_int x in the source type to
   word_of_int x in the target type. The [OF refl] attribute discharges uc = scast,
   so the stored fact is stated for scast directly. *)
lemma scast_down_wi [OF refl]:
  "uc = scast ⟹ is_down uc ⟹ uc (word_of_int x) = word_of_int x"
  by (metis down_cast_same is_up_down ucast_down_wi)
(* Collection of the cast rules above together with related library facts about
   is_up / is_down casts. *)
lemmas cast_simps =
  is_down is_up
  scast_down_add scast_down_minus scast_down_mult
  ucast_down_add ucast_down_minus ucast_down_mult
  scast_ucast_1 scast_ucast_3 scast_ucast_4
  ucast_scast_1 ucast_scast_3 ucast_scast_4
  ucast_ucast_a ucast_ucast_b
  scast_scast_a scast_scast_b
  ucast_down_wi scast_down_wi
  ucast_of_nat scast_of_nat
  uint_up_ucast sint_up_scast
  up_scast_surj up_ucast_surj
(* Signed division of sint values stays strictly below 2^(size a - 1) unless a is the
   minimal signed value -(2^(size a - 1)) and b = -1.
   NOTE(review): inside this proof, sdiv_word_min and sdiv_word_max refer to previously
   existing facts of those names, since this lemma is not yet registered. Presumably
   these are the imported Signed_Division_Word facts; confirm. *)
lemma sdiv_word_max:
    "(sint (a :: ('a::len) word) sdiv sint (b :: ('a::len) word) < (2 ^ (size a - 1))) =
          ((a ≠ - (2 ^ (size a - 1)) ∨ (b ≠ -1)))"
    (is "?lhs = (¬ ?a_int_min ∨ ¬ ?b_minus1)")
proof (rule classical)
  assume not_thesis: "¬ ?thesis"
  (* NOTE(review): not_zero is not referenced later in this proof. *)
  have not_zero: "b ≠ 0"
    using not_thesis by force
  (* The quotient lies in the signed range, or equals 2^(size a - 1) exactly. *)
  let ?range = ‹{- (2 ^ (size a - 1))..<2 ^ (size a - 1)} :: int set›
  have result_range: "sint a sdiv sint b ∈ ?range ∪ {2 ^ (size a - 1)}"
    using sdiv_word_min [of a b] sdiv_word_max [of a b] by auto
  (* The value 2^(size a - 1) is reached exactly in the case a = min, b = -1. *)
  have result_range_overflow: "(sint a sdiv sint b = 2 ^ (size a - 1)) = (?a_int_min ∧ ?b_minus1)"
  proof -
    have False
      if "sint a sdiv sint b = 2 ^ (size a - 1)" "¬ (a = - (2 ^ (size a - 1)) ∧ b = - 1)"
    proof (cases "?a_int_min")
      case True
      with that show ?thesis
        by (smt (verit, best) One_nat_def int_sdiv_negated_is_minus1 sint_int_min sint_minus1
            wsst_TYs(3) zero_less_power)
    next
      (* For a other than the minimum, ¦sint a¦ < 2^(size a - 1), so the quotient is too small. *)
      case False
      with that have "¦sint a¦ < 2 ^ (size a - 1)"
        by (smt (verit, best) One_nat_def signed_word_eqI sint_ge sint_int_min sint_less wsst_TYs(3))
      then show ?thesis
        by (metis atLeastAtMost_iff not_less sdiv_int_range that(1))
    qed
    then  show ?thesis
      by (smt (verit, ccfv_SIG) One_nat_def int_sdiv_simps(3) sint_int_min sint_n1 wsst_TYs(3))
  qed
  then show ?thesis
    using result_range by auto
qed
(* Variants with size unfolded to LENGTH via word_size. *)
lemmas sdiv_word_min' = sdiv_word_min [simplified word_size, simplified]
lemmas sdiv_word_max' = sdiv_word_max [simplified word_size, simplified]
(* For +, -, unary -, *, sdiv and smod: the integer result computed on sint values lies
   in the signed range [-(2^(size a - 1)), 2^(size a - 1) - 1] iff it equals sint of the
   corresponding word operation. *)
lemma signed_arith_ineq_checks_to_eq:
  "((- (2 ^ (size a - 1)) ≤ (sint a + sint b)) ∧ (sint a + sint b ≤ (2 ^ (size a - 1) - 1)))
    = (sint a + sint b = sint (a + b ))"
  "((- (2 ^ (size a - 1)) ≤ (sint a - sint b)) ∧ (sint a - sint b ≤ (2 ^ (size a - 1) - 1)))
    = (sint a - sint b = sint (a - b))"
  "((- (2 ^ (size a - 1)) ≤ (- sint a)) ∧ (- sint a) ≤ (2 ^ (size a - 1) - 1))
    = ((- sint a) = sint (- a))"
  "((- (2 ^ (size a - 1)) ≤ (sint a * sint b)) ∧ (sint a * sint b ≤ (2 ^ (size a - 1) - 1)))
    = (sint a * sint b = sint (a * b))"
  "((- (2 ^ (size a - 1)) ≤ (sint a sdiv sint b)) ∧ (sint a sdiv sint b ≤ (2 ^ (size a - 1) - 1)))
    = (sint a sdiv sint b = sint (a sdiv b))"
  "((- (2 ^ (size a - 1)) ≤ (sint a smod sint b)) ∧ (sint a smod sint b ≤ (2 ^ (size a - 1) - 1)))
    = (sint a smod sint b = sint (a smod b))"
  by (auto simp: sint_word_ariths word_size signed_div_arith signed_mod_arith signed_take_bit_int_eq_self_iff intro: sym dest: sym)
(* One-directional form: if the integer result is in the signed range, then sint of the
   word operation equals the integer result. *)
lemma signed_arith_sint:
  "((- (2 ^ (size a - 1)) ≤ (sint a + sint b)) ∧ (sint a + sint b ≤ (2 ^ (size a - 1) - 1)))
    ⟹ sint (a + b) = (sint a + sint b)"
  "((- (2 ^ (size a - 1)) ≤ (sint a - sint b)) ∧ (sint a - sint b ≤ (2 ^ (size a - 1) - 1)))
    ⟹ sint (a - b) = (sint a - sint b)"
  "((- (2 ^ (size a - 1)) ≤ (- sint a)) ∧ (- sint a) ≤ (2 ^ (size a - 1) - 1))
    ⟹ sint (- a) = (- sint a)"
  "((- (2 ^ (size a - 1)) ≤ (sint a * sint b)) ∧ (sint a * sint b ≤ (2 ^ (size a - 1) - 1)))
    ⟹ sint (a * b) = (sint a * sint b)"
  "((- (2 ^ (size a - 1)) ≤ (sint a sdiv sint b)) ∧ (sint a sdiv sint b ≤ (2 ^ (size a - 1) - 1)))
    ⟹ sint (a sdiv b) = (sint a sdiv sint b)"
  "((- (2 ^ (size a - 1)) ≤ (sint a smod sint b)) ∧ (sint a smod sint b ≤ (2 ^ (size a - 1) - 1)))
    ⟹ sint (a smod b) = (sint a smod sint b)"
  by (subst (asm) signed_arith_ineq_checks_to_eq; simp)+
(* For x < 2^(m - n) with n ≤ m < LENGTH('a): x * 2^n + (2^n - 1) ≤ 2^m - 1.
   In other words, placing x above bit n and setting all n low bits stays within m bits.
   Proof: the sum is push_bit n x OR mask n, because the operands have disjoint bits,
   and that value is bounded by mask m. *)
lemma nasty_split_lt:
  ‹x * 2 ^ n + (2 ^ n - 1) ≤ 2 ^ m - 1›
    if ‹x < 2 ^ (m - n)› ‹n ≤ m› ‹m < LENGTH('a::len)›
    for x :: ‹'a::len word›
proof -
  define q where ‹q = m - n›
  with ‹n ≤ m› have ‹m = q + n›
    by simp
  (* Every set bit of x lies below q. *)
  with ‹x < 2 ^ (m - n)› have *: ‹i < q› if ‹bit x i› for i
    using that by simp (metis bit_take_bit_iff take_bit_word_eq_self_iff)
  from ‹m = q + n› have ‹push_bit n x OR mask n ≤ mask m›
    by (auto simp: le_mask_high_bits word_size bit_simps dest!: *)
  (* The operands are bit-disjoint, so OR equals +. *)
  then have ‹push_bit n x + mask n ≤ mask m›
    by (simp add: disjunctive_add bit_simps)
  then show ?thesis
    by (simp add: mask_eq_exp_minus_1 push_bit_eq_mult)
qed
end
end