(* Theory BTree_Height *)
theory BTree_Height
  imports BTree
begin
section "Maximum and minimum height"
text "Textbooks usually provide some proofs relating the maximum and minimum height of the BTree
for a given number of nodes. We therefore introduce this counting and show the respective proofs."
subsection "Definition of node/size"
(* Print the facts stored under BTree.btree.size, i.e. the default size function
   that the text below argues is unsuitable for counting nodes. *)
thm BTree.btree.size
  
(* Evaluate the default size function on a small example tree
   (the same tree is fed to nodes further below for comparison). *)
value "size (Node [(Leaf, (0::nat)), (Node [(Leaf, 1), (Leaf, 10)] Leaf, 12), (Leaf, 30), (Leaf, 100)] Leaf)"
text "The default size function does not suit our needs as it regards the length of the list in each node.
 We would like to count the number of nodes in the tree only, not regarding the number of keys."
(* nodes t counts the inner nodes of t:
   a Leaf contributes 0, a Node contributes 1 plus the node counts of all
   subtrees in its list and of its last subtree. Keys are not counted. *)
fun nodes :: "'a btree ⇒ nat" where
  "nodes Leaf = 0"
| "nodes (Node ts t) = 1 + (∑s←subtrees ts. nodes s) + nodes t"

(* Example: node count of the tree that was measured with size above. *)
value "nodes (Node [(Leaf, (0::nat)), (Node [(Leaf, 1), (Leaf, 10)] Leaf, 12), (Leaf, 30), (Leaf, 100)] Leaf)"
subsection "Maximum number of nodes for a given height"
(* Summing n copies of c yields n * c (stated for natural numbers).
   Induction on n; the step case needs right distributivity. *)
lemma sum_list_replicate: "sum_list (replicate n c) = n*c"
  by (induction n) (auto simp add: ring_class.ring_distribs(2))
(* bound k h abbreviates (k + 1) ^ h - 1, the closed form used in all height/node bounds below. *)
abbreviation "bound k h ≡ (k + 1) ^ h - 1"
(* Upper bound on the number of nodes in terms of the height:
   for a balanced tree t of order k,
     nodes t * (2 * k) ≤ (2 * k + 1) ^ height t - 1.
   Proof by induction along nodes.induct. Only the Node case (case 2) needs
   work; the remaining Leaf case is discharged by the closing qed simp. *)
lemma nodes_height_upper_bound:
  "⟦order k t; bal t⟧ ⟹ nodes t * (2*k) ≤ bound (2*k) (height t)"
proof(induction t rule: nodes.induct)
  case (2 ts t)
  (* The bound for one subtree, phrased via the height of the last subtree t. *)
  let ?sub_height = "((2 * k + 1) ^ height t - 1)"
  (* Step 1: pull the constant factor 2 * k into the sum over the subtrees. *)
  have "sum_list (map nodes (subtrees ts)) * (2*k) =
        sum_list (map (λt. nodes t * (2 * k)) (subtrees ts))"
    using sum_list_mult_const by metis
  (* Step 2: bound every summand by ?sub_height using the facts of case 2
     (induction hypotheses and premises) together with sum_list_mono.
     NOTE(review): this relies on bal.simps(2) / order.simps(2) from BTree,
     presumably giving that every subtree is balanced, of order k and of the
     same height as t - confirm against BTree. *)
  also have "… ≤ sum_list (map (λx.?sub_height) (subtrees ts))"
    using 2
    using sum_list_mono[of "subtrees ts" "λt. nodes t * (2 * k)" "λx. bound (2 * k) (height t)"]
    by (metis bal.simps(2) order.simps(2))
  (* Step 3: a map with a constant function is a replicate of that constant. *)
  also have "… = sum_list (replicate (length ts) ?sub_height)"
    using map_replicate_const[of ?sub_height "subtrees ts"] length_map
    by simp
  (* Step 4: the sum of a replicate is a product. *)
  also have "… = (length ts)*(?sub_height)"
    using sum_list_replicate by simp
  (* Step 5: replace length ts by 2 * k using the order premise of the node
     (presumably order bounds length ts by 2 * k - defined in BTree). *)
  also have "… ≤ (2*k)*(?sub_height)"
    using "2.prems"(1)
    by simp
  finally have "sum_list (map nodes (subtrees ts))*(2*k) ≤ ?sub_height*(2*k)"
    by simp
  (* The last subtree t is bounded directly by the facts of case 2. *)
  moreover have "(nodes t)*(2*k) ≤ ?sub_height"
    using 2 by simp
  (* Combine: root node (contributing 2 * k), list subtrees, and last subtree. *)
  ultimately have "(nodes (Node ts t))*(2*k) ≤
         2*k
        + ?sub_height * (2*k)
        + ?sub_height"
    unfolding nodes.simps add_mult_distrib
    by linarith
  (* Arithmetic: rewrite the right-hand side into (2 * k + 1) ^ Suc (height t) - 1. *)
  also have "… =  2*k + (2*k)*((2 * k + 1) ^ height t) - 2*k + (2 * k + 1) ^ height t - 1"
    by (simp add: diff_mult_distrib2 mult.assoc mult.commute)
  also have "… = (2*k)*((2 * k + 1) ^ height t) + (2 * k + 1) ^ height t - 1"
    by simp
  also have "… = (2*k+1)^(Suc(height t)) - 1"
    by simp
  (* NOTE(review): height_bal_tree comes from BTree; presumably it states
     height (Node ts t) = Suc (height t) for balanced nodes - confirm. *)
  finally show ?case
    by (metis "2.prems"(2) height_bal_tree)
qed simp
text "To verify that our upper bound on the number of nodes is sharp, we compare it to the number
of nodes of artificially constructed full trees."
(* full_node k c h builds a tree of height parameter h in which every node
   carries 2 * k (subtree, key) pairs, all keys being c, plus a last subtree;
   every subtree is again a full_node of height parameter h - 1. *)
fun full_node :: "nat ⇒ 'a ⇒ nat ⇒ 'a btree" where
  "full_node k c 0 = Leaf"
| "full_node k c (Suc n) = Node (replicate (2*k) (full_node k c n, c)) (full_node k c n)"

(* Sanity check for k = 2: scaled node counts of full trees of heights 0..4 ... *)
value "let k = (2::nat) in map (λt. nodes t * 2*k) (map (full_node k (1::nat)) [0,1,2,3,4])"
(* ... next to the closed-form bound for the same heights. *)
value "let k = (2::nat) in map (λh. (2*k + (1::nat))^h - 1) [0,1,2,3,4]"
(* An idempotent function (f ∘ f = f) iterated a positive number of times
   is the function itself. Induction on the iteration count c; auto handles
   the base case and simplifies the step, fastforce closes what remains. *)
lemma compow_comp_id: "c > 0 ⟹ f ∘ f = f ⟹ (f ^^ c) = f"
  apply(induction c)
   apply auto
  by fastforce
(* A fixed point of f stays fixed under any number of iterations of f.
   Induction on the iteration count c. *)
lemma compow_id_point: "f x = x ⟹ (f ^^ c) x = x"
  by (induction c) auto
(* The height of a full tree equals the height parameter it was built with.
   Induction follows the recursion of full_node; set_replicate_conv_if
   rewrites the set of the replicated subtree list. *)
lemma height_full_node: "height (full_node k a h) = h"
  by (induction k a h rule: full_node.induct)
    (auto simp add: set_replicate_conv_if)
(* Every full tree is balanced; induction along the recursion of full_node. *)
lemma bal_full_node: "bal (full_node k a h)"
  by (induction k a h rule: full_node.induct) auto
(* Every full tree built with parameter k satisfies order k;
   induction along the recursion of full_node. *)
lemma order_full_node: "order k (full_node k a h)"
  by (induction k a h rule: full_node.induct) auto
(* Full trees attain the upper node bound exactly:
   nodes (full_node k a h) * (2 * k) = (2 * k + 1) ^ h - 1.
   Induction along full_node; the sum over the replicated subtrees is
   evaluated with sum_list_replicate and the rest is arithmetic. *)
lemma full_btrees_sharp: "nodes (full_node k a h) * (2*k) = bound (2*k) h"
  by (induction k a h rule: full_node.induct)
    (auto simp add: height_full_node algebra_simps sum_list_replicate)
(* Summary lemma: a full tree has the requested height, is balanced and of
   order k, and meets the upper node bound with equality. It just collects
   the four facts proved for full_node above. *)
lemma upper_bound_sharp_node:
  "t = full_node k a h ⟹ height t = h ∧ order k t ∧ bal t ∧ bound (2*k) h = nodes t * (2*k)"
  by (simp add: height_full_node order_full_node bal_full_node full_btrees_sharp)
subsection "Maximum height for a given number of nodes"
(* Lower bound on the number of nodes in terms of the height:
   for a balanced tree t of order k,
     (k + 1) ^ height t - 1 ≤ nodes t * k.
   Mirror image of nodes_height_upper_bound: the same chain of steps is run
   in the opposite direction. The Leaf case is closed by the final qed simp. *)
lemma nodes_height_lower_bound:
  "⟦order k t; bal t⟧ ⟹ bound k (height t) ≤ nodes t * k"
proof(induction t rule: nodes.induct)
  case (2 ts t)
  (* The bound for one subtree, phrased via the height of the last subtree t. *)
  let ?sub_height = "((k + 1) ^ height t - 1)"
  (* Step 1: replace k by length ts using the order premise of the node
     (presumably order bounds length ts from below by k - defined in BTree). *)
  have "k*(?sub_height) ≤ (length ts)*(?sub_height)"
    using "2.prems"(1)
    by simp
  (* Step 2: a product is the sum of a replicate. *)
  also have "… = sum_list (replicate (length ts) ?sub_height)"
    using sum_list_replicate by simp
  (* Step 3: a replicate is a map with a constant function over the subtrees. *)
  also have "… = sum_list (map (λx.?sub_height) (subtrees ts))"
    using map_replicate_const[of ?sub_height "subtrees ts"] length_map
    by simp
  (* Step 4: bound every summand from above by the scaled node count, using
     the facts of case 2 (induction hypotheses and premises) and sum_list_mono.
     NOTE(review): relies on bal.simps(2) / order.simps(2) from BTree,
     presumably giving equal subtree heights - confirm against BTree. *)
  also have "… ≤ sum_list (map (λt. nodes t * k) (subtrees ts))"
    using 2
    using sum_list_mono[of "subtrees ts" "λx. bound k (height t)" "λt. nodes t * k"]
    by (metis bal.simps(2) order.simps(2))
  (* Step 5: pull the constant factor k out of the sum. *)
  also have "… = sum_list (map nodes (subtrees ts)) * k"
    using sum_list_mult_const[of nodes k "subtrees ts"] by auto
  finally have "sum_list (map nodes (subtrees ts))*k ≥ ?sub_height*k"
    by simp
  (* The last subtree t is bounded directly by the facts of case 2. *)
  moreover have "(nodes t)*k ≥ ?sub_height"
    using 2 by simp
  (* Combine: root node (contributing k), list subtrees, and last subtree. *)
  ultimately have "(nodes (Node ts t))*k ≥
        k
        + ?sub_height * k
        + ?sub_height"
    unfolding nodes.simps add_mult_distrib
    by linarith
  (* Arithmetic: rewrite the lower bound into (k + 1) ^ Suc (height t) - 1. *)
  also have
    "k + ?sub_height * k + ?sub_height =
     k + k*((k + 1) ^ height t) - k + (k + 1) ^ height t - 1"
    by (simp add: diff_mult_distrib2 mult.assoc mult.commute)
  also have "… = k*((k + 1) ^ height t) + (k + 1) ^ height t - 1"
    by simp
  also have "… = (k+1)^(Suc(height t)) - 1"
    by simp
  (* NOTE(review): height_bal_tree comes from BTree; presumably it states
     height (Node ts t) = Suc (height t) for balanced nodes - confirm. *)
  finally show ?case
    by (metis "2.prems"(2) height_bal_tree)
qed simp
text "To verify that our lower bound on the number of nodes is sharp, we compare it to the number
of nodes of artificially constructed minimally filled (=slim) trees."
(* slim_node k c h builds a tree of height parameter h in which every node
   carries only k (subtree, key) pairs, all keys being c, plus a last subtree;
   every subtree is again a slim_node of height parameter h - 1. *)
fun slim_node :: "nat ⇒ 'a ⇒ nat ⇒ 'a btree" where
  "slim_node k c 0 = Leaf"
| "slim_node k c (Suc n) = Node (replicate k (slim_node k c n, c)) (slim_node k c n)"

(* Sanity check for k = 2: scaled node counts of slim trees of heights 0..4 ... *)
value "let k = (2::nat) in map (λt. nodes t * k) (map (slim_node k (1::nat)) [0,1,2,3,4])"
(* ... next to the closed-form bound for the same heights. *)
value "let k = (2::nat) in map (λh. (k+1::nat)^h - 1) [0,1,2,3,4]"
(* The height of a slim tree equals the height parameter it was built with.
   The induction uses slim_node.induct, the rule generated for the function
   the statement is about (previously full_node.induct was used, which only
   worked because both rules happen to have the same shape).
   set_replicate_conv_if rewrites the set of the replicated subtree list. *)
lemma height_slim_node: "height (slim_node k a h) = h"
  apply(induction k a h rule: slim_node.induct)
   apply (auto simp add: set_replicate_conv_if)
  done
(* Every slim tree is balanced.
   The induction uses slim_node.induct, the rule belonging to the function in
   the statement, instead of the coincidentally compatible full_node.induct. *)
lemma bal_slim_node: "bal (slim_node k a h)"
  apply(induction k a h rule: slim_node.induct)
   apply auto
  done
(* Every slim tree built with parameter k satisfies order k.
   The induction uses slim_node.induct, the rule belonging to the function in
   the statement, instead of the coincidentally compatible full_node.induct. *)
lemma order_slim_node: "order k (slim_node k a h)"
  apply(induction k a h rule: slim_node.induct)
   apply auto
  done
(* Slim trees attain the lower node bound exactly:
   nodes (slim_node k a h) * k = (k + 1) ^ h - 1.
   Induction along slim_node; the sum over the replicated subtrees is
   evaluated with sum_list_replicate and the rest is arithmetic.
   The simp set mirrors the one of full_btrees_sharp: compow_id_point was
   dropped because the goal contains no function iteration (^^) it could
   rewrite. *)
lemma slim_nodes_sharp: "nodes (slim_node k a h) * k = bound k h"
  apply(induction k a h rule: slim_node.induct)
   apply (auto simp add: height_slim_node algebra_simps sum_list_replicate)
  done
(* Summary lemma: a slim tree has the requested height, is balanced and of
   order k, and meets the lower node bound with equality. It just collects
   the four facts proved for slim_node above. *)
lemma lower_bound_sharp_node:
  "t = slim_node k a h ⟹ height t = h ∧ order k t ∧ bal t ∧ bound k h = nodes t * k"
  by (simp add: height_slim_node order_slim_node bal_slim_node slim_nodes_sharp)
text "Since BTrees have special roots, we need to show the overall nodes separately"
(* Lower bound on the node count for a whole tree, i.e. with the weaker
   root_order requirement at the root:
     2 * ((k + 1) ^ (height t - 1) - 1) + (if t is not a Leaf then k else 0)
       ≤ nodes t * k.
   Case distinction on t; the Leaf case is closed by the final qed simp.
   The subtrees below the root are handled by nodes_height_lower_bound. *)
lemma nodes_root_height_lower_bound:
  assumes "root_order k t"
    and "bal t"
  shows "2*((k+1)^(height t - 1) - 1) + (of_bool (t ≠ Leaf))*k  ≤ nodes t * k"
proof (cases t)
  (* Note: the names ts and t bound here shadow the t of the lemma statement;
     from here on t denotes the last subtree of the root. *)
  case (Node ts t)
  let ?sub_height = "((k + 1) ^ height t - 1)"
  (* Step 1: one copy of the subtree bound is at most length ts copies
     (presumably root_order forces a non-empty ts - defined in BTree). *)
  from Node have "?sub_height ≤ length ts * ?sub_height"
    using assms
    by (simp add: Suc_leI)
  (* Step 2: a product is the sum of a replicate. *)
  also have "… = sum_list (replicate (length ts) ?sub_height)"
    using sum_list_replicate
    by simp
  (* Step 3: a replicate is a map with a constant function over the subtrees. *)
  also have "… = sum_list (map (λx. ?sub_height) (subtrees ts))"
    using map_replicate_const[of ?sub_height "subtrees ts"] length_map
    by simp
  (* Step 4: bound each summand by the scaled node count of the respective
     subtree via nodes_height_lower_bound and sum_list_mono. *)
  also have "… ≤ sum_list (map (λt. nodes t * k) (subtrees ts))"
    using Node
      sum_list_mono[of "subtrees ts" "λx. (k+1)^(height t) - 1" "λx. nodes x * k"]
      nodes_height_lower_bound assms
    by fastforce
  (* Step 5: pull the constant factor k out of the sum. *)
  also have "… = sum_list (map nodes (subtrees ts)) * k"
    using sum_list_mult_const[of nodes k "subtrees ts"] by simp
  finally have "sum_list (map nodes (subtrees ts))*k ≥ ?sub_height"
    by simp
  (* The last subtree is bounded directly by nodes_height_lower_bound. *)
  moreover have "(nodes t)*k ≥ ?sub_height"
    using Node assms nodes_height_lower_bound
    by auto
  (* Combine: list subtrees, last subtree, and the root node (contributing k). *)
  ultimately have "(nodes (Node ts t))*k ≥
        ?sub_height
        + ?sub_height + k"
    unfolding nodes.simps add_mult_distrib
    by linarith
  (* NOTE(review): height_bal_tree comes from BTree; presumably it relates the
     height of the root to the height of its last subtree - confirm. *)
  then show ?thesis
    using Node assms(2) height_bal_tree by fastforce
qed simp
(* Upper bound on the node count for a whole tree, i.e. with root_order at
   the root:
     nodes t * (2 * k) ≤ (2 * k + 1) ^ height t - 1.
   Same calculation as in nodes_height_upper_bound, but by case distinction on
   t instead of induction: the subtrees below the root are handled by
   nodes_height_upper_bound. The Leaf case is closed by the final qed simp. *)
lemma nodes_root_height_upper_bound:
  assumes "root_order k t"
    and "bal t"
  shows "nodes t * (2*k) ≤ (2*k+1)^(height t) - 1"
proof(cases t)
  (* Note: the names ts and t bound here shadow the t of the lemma statement;
     from here on t denotes the last subtree of the root. *)
  case (Node ts t)
  let ?sub_height = "((2 * k + 1) ^ height t - 1)"
  (* Step 1: pull the constant factor 2 * k into the sum over the subtrees. *)
  have "sum_list (map nodes (subtrees ts)) * (2*k) =
        sum_list (map (λt. nodes t * (2 * k)) (subtrees ts))"
    using sum_list_mult_const by metis
  (* Step 2: bound each summand via nodes_height_upper_bound and sum_list_mono. *)
  also have "… ≤ sum_list (map (λx.?sub_height) (subtrees ts))"
    using Node
      sum_list_mono[of "subtrees ts" "λx. nodes x * (2*k)"  "λx. (2*k+1)^(height t) - 1"]
      nodes_height_upper_bound assms
    by fastforce
  (* Step 3: a map with a constant function is a replicate of that constant. *)
  also have "… = sum_list (replicate (length ts) ?sub_height)"
    using map_replicate_const[of ?sub_height "subtrees ts"] length_map
    by simp
  (* Step 4: the sum of a replicate is a product. *)
  also have "… = (length ts)*(?sub_height)"
    using sum_list_replicate by simp
  (* Step 5: replace length ts by 2 * k using the root_order assumption
     (presumably root_order bounds length ts by 2 * k - defined in BTree). *)
  also have "… ≤ (2*k)*?sub_height"
    using assms Node
    by simp
  finally have "sum_list (map nodes (subtrees ts))*(2*k) ≤ ?sub_height*(2*k)"
    by simp
  (* The last subtree is bounded directly by nodes_height_upper_bound. *)
  moreover have "(nodes t)*(2*k) ≤ ?sub_height"
    using Node assms nodes_height_upper_bound
    by auto
  (* Combine: root node (contributing 2 * k), list subtrees, and last subtree. *)
  ultimately have "(nodes (Node ts t))*(2*k) ≤
         2*k
        + ?sub_height * (2*k)
        + ?sub_height"
    unfolding nodes.simps add_mult_distrib
    by linarith
  (* Arithmetic: rewrite the right-hand side into (2 * k + 1) ^ Suc (height t) - 1. *)
  also have "… =  2*k + (2*k)*((2 * k + 1) ^ height t) - 2*k + (2 * k + 1) ^ height t - 1"
    by (simp add: diff_mult_distrib2 mult.assoc mult.commute)
  also have "… = (2*k)*((2 * k + 1) ^ height t) + (2 * k + 1) ^ height t - 1"
    by simp
  also have "… = (2*k+1)^(Suc(height t)) - 1"
    by simp
  (* NOTE(review): height_bal_tree comes from BTree; presumably it states
     height (Node ts t) = Suc (height t) for balanced nodes - confirm. *)
  finally show ?thesis
    by (metis Node assms(2) height_bal_tree)
qed simp
(* Under root_order k t, multiplying nodes t by k and dividing by k again
   gives back nodes t. There is no explicit k > 0 assumption; the proof goes
   through the elimination rule of root_order.
   NOTE(review): presumably root_order rules out k = 0 for non-leaf trees -
   confirm against the definition in BTree. *)
lemma root_order_imp_divmuleq: "root_order k t ⟹ (nodes t * k) div k = nodes t"
  using root_order.elims(2) by fastforce
(* Division form of nodes_root_height_lower_bound: for k > 0 the bound is
   stated directly on nodes t instead of on nodes t * k.
   Case distinction on t; the Leaf case is closed by the final qed simp. *)
lemma nodes_root_height_lower_bound_simp:
  assumes "root_order k t"
    and "bal t"
    and "k > 0"
  shows "(2*((k+1)^(height t - 1) - 1)) div k + (of_bool (t ≠ Leaf)) ≤ nodes t"
proof (cases t)
  case Node
  (* Step 1: move the of_bool summand under the division by scaling it with k
     (k divides k, so the division distributes over this sum).
     height_btree.simps is removed from the simp set so that height t is not
     unfolded. *)
  have "(2*((k+1)^(height t - 1) - 1)) div k + (of_bool (t ≠ Leaf)) =
(2*((k+1)^(height t - 1) - 1) + (of_bool (t ≠ Leaf))*k) div k"
    using Node assms
    using div_plus_div_distrib_dvd_left[of k k "(2 * Suc k ^ (height t - Suc 0) - Suc (Suc 0))"]
    by (auto simp add: algebra_simps simp del: height_btree.simps)
  (* Step 2: division is monotone, so the multiplicative bound carries over. *)
  also have "… ≤ (nodes t * k) div k"
    using nodes_root_height_lower_bound[OF assms(1,2)] div_le_mono
    by blast
  (* Step 3: cancel the factor k. *)
  also have "… = nodes t"
    using root_order_imp_divmuleq[OF assms(1)]
    by simp
  finally show ?thesis .
qed simp
(* Division form of nodes_root_height_upper_bound: the bound is stated
   directly on nodes t instead of on nodes t * (2 * k). *)
lemma nodes_root_height_upper_bound_simp:
  assumes "root_order k t"
    and "bal t"
  shows "nodes t ≤ ((2*k+1)^(height t) - 1) div (2*k)"
proof -
  (* Step 1: introduce the factor 2 * k and cancel it again by division. *)
  have "nodes t = (nodes t * (2*k)) div (2*k)"
    using root_order_imp_divmuleq[OF assms(1)]
    by simp
  (* Step 2: division is monotone, so the multiplicative bound carries over. *)
  also have "… ≤ ((2*k+1)^(height t) - 1) div (2*k)"
    using div_le_mono nodes_root_height_upper_bound[OF assms] by blast
  finally show ?thesis .
qed
(* A full tree (including its root) is simply a full_node; the separate name
   parallels slim_tree, whose root differs from the nodes below it. *)
definition full_tree where
  "full_tree = full_node"
(* slim_tree k c h builds a minimally filled tree of height parameter h:
   the root holds a single (subtree, key) pair plus a last subtree, and both
   subtrees are slim_node trees of height parameter h - 1. *)
fun slim_tree :: "nat ⇒ 'a ⇒ nat ⇒ 'a btree" where
  "slim_tree k c 0 = Leaf"
| "slim_tree k c (Suc h) = Node [(slim_node k c h, c)] (slim_node k c h)"
(* Sharpness of nodes_root_height_lower_bound: for k > 0 a slim_tree has the
   requested height, satisfies root_order k, is balanced, and meets the lower
   node bound with equality.
   Case distinction on h; slim_nodes_sharp supplies the node count of the
   slim_node subtrees below the root. *)
lemma lower_bound_sharp:
  "k > 0 ⟹ t = slim_tree k a h ⟹ height t = h ∧ root_order k t ∧ bal t ∧ nodes t * k = 2*((k+1)^(height t - 1) - 1) + (of_bool (t ≠ Leaf))*k"
  apply (cases h)
  using slim_nodes_sharp[of k a]
   apply (auto simp add: algebra_simps bal_slim_node height_slim_node order_slim_node)
  done
(* Sharpness of nodes_root_height_upper_bound: for k > 0 a full_tree has the
   requested height, satisfies root_order k, is balanced, and meets the upper
   node bound with equality.
   After unfolding full_tree to full_node the facts about full_node apply;
   order_impl_root_order (from BTree) is used to obtain root_order.
   NOTE(review): presumably order_impl_root_order derives root_order k t from
   order k t for k > 0 - confirm against BTree. *)
lemma upper_bound_sharp:
  "k > 0 ⟹ t = full_tree k a h ⟹ height t = h ∧ root_order k t ∧ bal t ∧ ((2*k+1)^(height t) - 1) = nodes t * (2*k)"
  unfolding full_tree_def
  using order_impl_root_order[of k t]
  by (simp add: bal_full_node height_full_node order_full_node full_btrees_sharp)
end