(* Theory KD_Tree *)

(*
  File:     KD_Tree.thy
  Author:   Martin Rau, TU München
*)

section "Definition of the ‹k›-d Tree"

theory KD_Tree
imports
  Complex_Main
  "HOL-Analysis.Finite_Cartesian_Product"
  "HOL-Analysis.Topology_Euclidean_Space"
begin


text ‹
  A ‹k›-d tree is a space-partitioning data structure for organizing points in a ‹k›-dimensional space.
  In principle the ‹k›-d tree is a binary tree. The leaves hold the ‹k›-dimensional points and the nodes
  contain left and right subtrees as well as a discriminator ‹v› at a particular axis ‹k›.
  Every node divides the space into two parts by splitting along the hyperplane that is
  orthogonal to its axis ‹k› and passes through its discriminator ‹v›.
  Consider a node ‹n› with associated discriminator ‹v› at axis ‹k›.
  All points in the left subtree must have a value at axis ‹k› that is less than or
  equal to ‹v›, and all points in the right subtree must have a value at axis ‹k› that is
  strictly greater than ‹v›.

  Deviations from the papers:

  The chosen tree representation is taken from @{cite "DBLP:journals/toms/FriedmanBF77"} with one minor
  adjustment. Originally the leaves hold buckets of points of size ‹b›. This representation fixes the
  bucket size to ‹b = 1›, a single point per leaf. This is only a minor adjustment: the paper
  proves that ‹b = 1› is the optimal bucket size for minimizing the running time of the nearest neighbor
  algorithm @{cite "DBLP:journals/toms/FriedmanBF77"}, the fixed bucket size merely simplifies building the
  optimized ‹k›-d trees @{cite "DBLP:journals/toms/FriedmanBF77"}, and it has little influence on the
  search algorithm @{cite "DBLP:journals/cacm/Bentley75"}.
›

(* A point in 'k-dimensional space: a real-valued vector indexed by the axis
   type 'k.  The coordinate of point p on axis k is written p$k. *)
type_synonym 'k point = "(real, 'k) vec"

(* The Euclidean distance of two points, written out as the square root of
   the sum of the squared coordinate differences over all axes. *)
lemma dist_point_def:
  fixes p0 :: "('k::finite) point"
  shows "dist p0 p1 = sqrt (∑k ∈ UNIV. (p0$k - p1$k)⇧2)"
  unfolding dist_vec_def L2_set_def dist_real_def by simp

(* A k-d tree is either a leaf holding exactly one point, or an inner node
   holding an axis, a real discriminator value and a left and right subtree. *)
datatype 'k kdt = Leaf "'k point" | Node 'k real "'k kdt" "'k kdt"


subsection ‹Definition of the ‹k›-d Tree Invariant and Related Functions›

(* The set of all points stored in the leaves of a tree. *)
fun set_kdt :: "'k kdt ⇒ ('k point) set" where
  "set_kdt (Leaf p) = { p }"
| "set_kdt (Node _ _ l r) = set_kdt l ∪ set_kdt r"

(* The spread of a point set P along axis k: the largest minus the smallest
   coordinate on that axis, and 0 for the empty set.  Max/Min are only
   specified for finite non-empty sets, so this is presumably intended for
   finite P. *)
definition spread :: "('k::finite) ⇒ 'k point set ⇒ real" where
  "spread k P = (if P = {} then 0 else let V = (λp. p$k) ` P in Max V - Min V)"

(* k has the widest spread among the candidate axes K for the points ps:
   no axis in K has a strictly larger spread.  Note that k itself is not
   required to be a member of K. *)
definition widest_spread_axis :: "('k::finite) ⇒ 'k set ⇒ 'k point set ⇒ bool" where
  "widest_spread_axis k K ps ⟷ (∀k' ∈ K. spread k' ps ≤ spread k ps)"

(* The k-d tree invariant.  At every node with axis k and discriminator v:
   - every point of the left subtree has coordinate at most v on axis k,
   - every point of the right subtree has coordinate strictly greater than v,
   - k is an axis of widest spread over all points below the node,
   and both subtrees satisfy the invariant recursively. *)
fun invar :: "('k::finite) kdt ⇒ bool" where
  "invar (Leaf p) ⟷ True"
| "invar (Node k v l r) ⟷ (∀p ∈ set_kdt l. p$k ≤ v) ∧ (∀p ∈ set_kdt r. v < p$k) ∧
    widest_spread_axis k UNIV (set_kdt l ∪ set_kdt r) ∧ invar l ∧ invar r"

(* Number of leaves of a tree; every leaf holds exactly one point. *)
fun size_kdt :: "'k kdt ⇒ nat" where
  "size_kdt (Leaf _) = 1"
| "size_kdt (Node _ _ l r) = size_kdt l + size_kdt r"

(* Length of the longest root-to-leaf path; a single leaf has height 0. *)
fun height :: "'k kdt ⇒ nat" where
  "height (Leaf _) = 0"
| "height (Node _ _ l r) = max (height l) (height r) + 1"

(* Length of the shortest root-to-leaf path; a single leaf has min_height 0. *)
fun min_height :: "'k kdt ⇒ nat" where
  "min_height (Leaf _) = 0"
| "min_height (Node _ _ l r) = min (min_height l) (min_height r) + 1"

(* A tree is balanced if its longest and shortest root-to-leaf paths differ
   by at most one.  The subtraction is on nat; it is exact because
   min_height kdt ≤ height kdt (lemma min_height_le_height). *)
definition balanced :: "'k kdt ⇒ bool" where
  "balanced kdt ⟷ height kdt - min_height kdt ≤ 1"

(* A tree is complete if all of its leaves lie at the same depth: both
   subtrees are complete and have equal height. *)
fun complete :: "'k kdt ⇒ bool" where
  "complete (Leaf _) = True"
| "complete (Node _ _ l r) ⟷ complete l ∧ complete r ∧ height l = height r"


(* The invariant is inherited by the left subtree. *)
lemma invar_l:
  "invar (Node k v l r) ⟹ invar l"
  by simp

(* The invariant is inherited by the right subtree. *)
lemma invar_r:
  "invar (Node k v l r) ⟹ invar r"
  by simp

(* Points of the left subtree lie on or below the discriminator on axis k. *)
lemma invar_l_le_k:
  "invar (Node k v l r) ⟹ ∀p ∈ set_kdt l. p$k ≤ v"
  by simp

(* Points of the right subtree lie strictly above the discriminator on axis k.
   (Despite the "ge" in the name, the bound is strict.) *)
lemma invar_r_ge_k:
  "invar (Node k v l r) ⟹ ∀p ∈ set_kdt r. v < p$k"
  by simp

(* The points of a node are exactly the points of its two subtrees. *)
lemma invar_set:
  "set_kdt (Node k v l r) = set_kdt l ∪ set_kdt r"
  by simp


subsection ‹Lemmas adapted from ‹HOL-Library.Tree› to ‹k›-d Tree›

(* Every tree holds at least one point: leaves have size 1 and a node's size
   is the sum of its subtrees' sizes. *)
lemma size_ge0[simp]:
  "0 < size_kdt kdt"
  by (induction kdt) simp_all

(* Exactly the leaves have size 1; a node has size at least 2 because both
   of its subtrees have positive size (size_ge0). *)
lemma eq_size_1[simp]:
  "size_kdt kdt = 1 ⟷ (∃p. kdt = Leaf p)"
  apply (induction kdt)
  apply (auto)
  using size_ge0 nat_less_le apply blast+
  done

(* Symmetric form of eq_size_1. *)
lemma eq_1_size[simp]:
  "1 = size_kdt kdt ⟷ (∃p. kdt = Leaf p)"
  using eq_size_1 by metis

(* A tree that is not a leaf is a node. *)
lemma neq_Leaf_iff:
  "(∄p. kdt = Leaf p) = (∃k v l r. kdt = Node k v l r)"
  by (cases kdt) auto

(* Exactly the leaves have height 0. *)
lemma eq_height_0[simp]:
  "height kdt = 0 ⟷ (∃p. kdt = Leaf p)"
  by (cases kdt) auto

(* Symmetric form of eq_height_0. *)
lemma eq_0_height[simp]:
  "0 = height kdt ⟷ (∃p. kdt = Leaf p)"
  by (cases kdt) auto

(* Exactly the leaves have min_height 0. *)
lemma eq_min_height_0[simp]:
  "min_height kdt = 0 ⟷ (∃p. kdt = Leaf p)"
  by (cases kdt) auto

(* Symmetric form of eq_min_height_0. *)
lemma eq_0_min_height[simp]:
  "0 = min_height kdt ⟷ (∃p. kdt = Leaf p)"
  by (cases kdt) auto

(* A tree of height h holds at most 2^h points.  Induction on the tree; at a
   node, both subtree sizes are bounded by 2 to the larger subtree height. *)
lemma size_height:
  "size_kdt kdt ≤ 2 ^ height kdt"
proof(induction kdt)
  case (Node k v l r)
  show ?case
  proof (cases "height l ≤ height r")
    case True
    have "size_kdt (Node k v l r) = size_kdt l + size_kdt r" by simp
    also have "… ≤ 2 ^ height l + 2 ^ height r" using Node.IH by arith
    also have "… ≤ 2 ^ height r + 2 ^ height r" using True by simp
    also have "… = 2 ^ height (Node k v l r)"
      using True by (auto simp: max_def mult_2)
    finally show ?thesis .
  next
    case False
    have "size_kdt (Node k v l r) = size_kdt l + size_kdt r" by simp
    also have "… ≤ 2 ^ height l + 2 ^ height r" using Node.IH by arith
    also have "… ≤ 2 ^ height l + 2 ^ height l" using False by simp
    finally show ?thesis using False by (auto simp: max_def mult_2)
  qed
qed simp

(* The shortest root-to-leaf path is never longer than the longest one. *)
lemma min_height_le_height:
  "min_height kdt ≤ height kdt"
  by (induction kdt) auto

(* A tree of minimum height h holds at least 2^h points. *)
lemma min_height_size:
  "2 ^ min_height kdt ≤ size_kdt kdt"
proof(induction kdt)
  case (Node k v l r)
  have "(2::nat) ^ min_height (Node k v l r) ≤ 2 ^ min_height l + 2 ^ min_height r"
    by (simp add: min_def)
  also have "… ≤ size_kdt (Node k v l r)" using Node.IH by simp
  finally show ?case .
qed simp

(* A tree is complete iff its shortest and longest root-to-leaf paths have
   equal length. *)
lemma complete_iff_height:
  "complete kdt ⟷ (min_height kdt = height kdt)"
  apply (induction kdt)
  apply simp
  apply (simp add: min_def max_def)
  by (metis le_antisym le_trans min_height_le_height)

(* A complete tree attains the upper bound of size_height. *)
lemma size_if_complete:
  "complete kdt ⟹ size_kdt kdt = 2 ^ height kdt"
  by (induction kdt) auto

(* Converse of size_if_complete: a tree whose size reaches 2^height is complete.
   Induction on the height.  If either subtree were strictly lower than h,
   then by size_height the node would hold fewer than 2^(Suc h) points,
   which is a contradiction. *)
lemma complete_if_size_height:
  "size_kdt kdt = 2 ^ height kdt ⟹ complete kdt"
proof (induction "height kdt" arbitrary: kdt)
  case 0 thus ?case by auto
next
  case (Suc h)
  hence "∄p. kdt = Leaf p"
    by auto
  then obtain k v l r where [simp]: "kdt = Node k v l r"
    using neq_Leaf_iff by metis
  have 1: "height l ≤ h" and 2: "height r ≤ h" using Suc(2) by(auto)
  have 3: "¬ height l < h"
  proof
    assume 0: "height l < h"
    have "size_kdt kdt = size_kdt l + size_kdt r" by simp
    also have "… ≤ 2 ^ height l + 2 ^ height r"
      using size_height[of l] size_height[of r] by arith
    also have " … < 2 ^ h + 2 ^ height r" using 0 by (simp)
    also have " … ≤ 2 ^ h + 2 ^ h" using 2 by (simp)
    also have "… = 2 ^ (Suc h)" by (simp)
    also have "… = size_kdt kdt" using Suc(2,3) by simp
    finally have "size_kdt kdt < size_kdt kdt" .
    thus False by (simp)
  qed
  have 4: "¬ height r < h"
  proof
    assume 0: "height r < h"
    have "size_kdt kdt = size_kdt l + size_kdt r" by simp
    also have "… ≤ 2 ^ height l + 2 ^ height r"
      using size_height[of l] size_height[of r] by arith
    also have " … < 2 ^ height l + 2 ^ h" using 0 by (simp)
    also have " … ≤ 2 ^ h + 2 ^ h" using 1 by (simp)
    also have "… = 2 ^ (Suc h)" by (simp)
    also have "… = size_kdt kdt" using Suc(2,3) by simp
    finally have "size_kdt kdt < size_kdt kdt" .
    thus False by (simp)
  qed
  from 1 2 3 4 have *: "height l = h" "height r = h" by linarith+
  hence "size_kdt l = 2 ^ height l" "size_kdt r = 2 ^ height r"
    using Suc(3) size_height[of l] size_height[of r] by (auto)
  with * Suc(1) show ?case by simp
qed

(* Dual of complete_if_size_height: a tree whose size equals 2^min_height is
   complete.  Induction on the minimum height.  If either subtree had a
   minimum height strictly above h, then by min_height_size the node would
   hold more than 2^(Suc h) points, which is a contradiction. *)
lemma complete_if_size_min_height:
  "size_kdt kdt = 2 ^ min_height kdt ⟹ complete kdt"
proof (induct "min_height kdt" arbitrary: kdt)
  case 0 thus ?case by auto
next
  case (Suc h)
  hence "∄p. kdt = Leaf p"
    by auto
  then obtain k v l r where [simp]: "kdt = Node k v l r"
    using neq_Leaf_iff by metis
  have 1: "h ≤ min_height l" and 2: "h ≤ min_height r" using Suc(2) by (auto)
  have 3: "¬ h < min_height l"
  proof
    assume 0: "h < min_height l"
    have "size_kdt kdt = size_kdt l + size_kdt r" by simp
    also note min_height_size[of l]
    also(xtrans) note min_height_size[of r]
    also(xtrans) have "(2::nat) ^ min_height l > 2 ^ h"
        using 0 by (simp add: diff_less_mono)
    also(xtrans) have "(2::nat) ^ min_height r ≥ 2 ^ h" using 2 by simp
    also(xtrans) have "(2::nat) ^ h + 2 ^ h = 2 ^ (Suc h)" by (simp)
    also have "… = size_kdt kdt" using Suc(2,3) by simp
    finally show False by (simp add: diff_le_mono)
  qed
  have 4: "¬ h < min_height r"
  proof
    assume 0: "h < min_height r"
    have "size_kdt kdt = size_kdt l + size_kdt r" by simp
    also note min_height_size[of l]
    also(xtrans) note min_height_size[of r]
    also(xtrans) have "(2::nat) ^ min_height r > 2 ^ h"
        using 0 by (simp add: diff_less_mono)
    also(xtrans) have "(2::nat) ^ min_height l ≥ 2 ^ h" using 1 by simp
    also(xtrans) have "(2::nat) ^ h + 2 ^ h = 2 ^ (Suc h)" by (simp)
    also have "… = size_kdt kdt" using Suc(2,3) by simp
    finally show False by (simp add: diff_le_mono)
  qed
  from 1 2 3 4 have *: "min_height l = h" "min_height r = h" by linarith+
  hence "size_kdt l = 2 ^ min_height l" "size_kdt r = 2 ^ min_height r"
    using Suc(3) min_height_size[of l] min_height_size[of r] by (auto)
  with * Suc(1) show ?case
    by (simp add: complete_iff_height)
qed

(* Completeness is characterized by the size reaching 2^height. *)
lemma complete_iff_size:
  "complete kdt ⟷ size_kdt kdt = 2 ^ height kdt"
  using complete_if_size_height size_if_complete by blast

(* For incomplete trees the bound of size_height is strict. *)
lemma size_height_if_incomplete:
  "¬ complete kdt ⟹ size_kdt kdt < 2 ^ height kdt"
  by (meson antisym_conv complete_iff_size not_le size_height)

(* For incomplete trees the bound of min_height_size is strict. *)
lemma min_height_size_if_incomplete:
  "¬ complete kdt ⟹ 2 ^ min_height kdt < size_kdt kdt"
  by (metis complete_if_size_min_height le_less min_height_size)

(* Balance is inherited by the left subtree. *)
lemma balanced_subtreeL:
  "balanced (Node k v l r) ⟹ balanced l"
  by (simp add: balanced_def)

(* Balance is inherited by the right subtree. *)
lemma balanced_subtreeR:
  "balanced (Node k v l r) ⟹ balanced r"
  by (simp add: balanced_def)

(* Balanced trees have minimal height: any tree kdt' with at least as many
   points is at least as high.  The complete case compares 2^height
   directly.  The incomplete case uses height = min_height + 1 for balanced,
   incomplete trees. *)
lemma balanced_optimal:
  assumes "balanced kdt" "size_kdt kdt ≤ size_kdt kdt'"
  shows "height kdt ≤ height kdt'"
proof (cases "complete kdt")
  case True
  have "(2::nat) ^ height kdt ≤ 2 ^ height kdt'"
  proof -
    have "2 ^ height kdt = size_kdt kdt"
      using True by (simp add: complete_iff_height size_if_complete)
    also have "… ≤ size_kdt kdt'" using assms(2) by simp
    also have "… ≤ 2 ^ height kdt'" by (rule size_height)
    finally show ?thesis .
  qed
  thus ?thesis by (simp)
next
  case False
  have "(2::nat) ^ min_height kdt < 2 ^ height kdt'"
  proof -
    have "(2::nat) ^ min_height kdt < size_kdt kdt"
      by(rule min_height_size_if_incomplete[OF False])
    also have "… ≤ size_kdt kdt'" using assms(2) by simp
    also have "… ≤ 2 ^ height kdt'"  by(rule size_height)
    finally have "(2::nat) ^ min_height kdt < (2::nat) ^ height kdt'" .
    thus ?thesis .
  qed
  hence *: "min_height kdt < height kdt'" by simp
  have "min_height kdt + 1 = height kdt"
    using min_height_le_height[of kdt] assms(1) False
    by (simp add: complete_iff_height balanced_def)
  with * show ?thesis by arith
qed


subsection ‹Lemmas adapted from ‹HOL-Library.Tree_Real› to ‹k›-d Tree›

(* Logarithmic form of size_height. *)
lemma size_height_log:
  "log 2 (size_kdt kdt) ≤ height kdt"
  by (simp add: log2_of_power_le size_height)

(* Logarithmic form of min_height_size. *)
lemma min_height_size_log:
  "min_height kdt ≤ log 2 (size_kdt kdt)"
  by (simp add: le_log2_of_power min_height_size)

(* A complete tree's height is exactly the binary logarithm of its size. *)
lemma size_log_if_complete:
  "complete kdt ⟹ height kdt = log 2 (size_kdt kdt)"
  using complete_iff_size log2_of_power_eq by blast

(* Strict logarithmic form of min_height_size_if_incomplete. *)
lemma min_height_size_log_if_incomplete:
  "¬ complete kdt ⟹ min_height kdt < log 2 (size_kdt kdt)"
  by (simp add: less_log2_of_power min_height_size_if_incomplete)

(* The minimum height of a balanced tree is the floor of log2 of its size.
   Complete case: the size is exactly 2^min_height.
   Incomplete case: height = min_height + 1, so
   2^min_height ≤ size < 2^(min_height + 1). *)
lemma min_height_balanced:
  assumes "balanced kdt"
  shows "min_height kdt = nat(floor(log 2 (size_kdt kdt)))"
proof cases
  assume *: "complete kdt"
  hence "size_kdt kdt = 2 ^ min_height kdt"
    by (simp add: complete_iff_height size_if_complete)
  from log2_of_power_eq[OF this] show ?thesis by linarith
next
  assume *: "¬ complete kdt"
  (* balanced and incomplete: the two heights differ by exactly one *)
  hence "height kdt = min_height kdt + 1"
    using assms min_height_le_height[of kdt]
    by(auto simp add: balanced_def complete_iff_height)
  hence "size_kdt kdt < 2 ^ (min_height kdt + 1)"
    by (metis * size_height_if_incomplete)
  hence "log 2 (size_kdt kdt) < min_height kdt + 1"
    using log2_of_power_less size_ge0 by blast
  thus ?thesis using min_height_size_log[of kdt] by linarith
qed

(* The height of a balanced tree is the ceiling of log2 of its size.
   Complete case: the size is exactly 2^height.
   Incomplete case: height = min_height + 1, which combines with
   min_height < log2 size ≤ min_height + 1. *)
lemma height_balanced:
  assumes "balanced kdt"
  shows "height kdt = nat(ceiling(log 2 (size_kdt kdt)))"
proof cases
  assume *: "complete kdt"
  hence "size_kdt kdt = 2 ^ height kdt"
    by (simp add: size_if_complete)
  from log2_of_power_eq[OF this] show ?thesis
    by linarith
next
  assume *: "¬ complete kdt"
  hence **: "height kdt = min_height kdt + 1"
    using assms min_height_le_height[of kdt]
    by(auto simp add: balanced_def complete_iff_height)
  hence "size_kdt kdt ≤ 2 ^ (min_height kdt + 1)" by (metis size_height)
  from  log2_of_power_le[OF this size_ge0] min_height_size_log_if_incomplete[OF *] **
  show ?thesis by linarith
qed

(* Joining two balanced trees whose sizes differ by exactly one (left larger)
   yields a balanced node.  The node's height is determined by the ceiling
   and its min_height by the floor of log2 of the subtree sizes (via
   height_balanced / min_height_balanced).  These two differ by at most one
   for consecutive sizes. *)
lemma balanced_Node_if_wbal1:
  assumes "balanced l" "balanced r" "size_kdt l = size_kdt r + 1"
  shows "balanced (Node k v l r)"
proof -
  from assms(3) have [simp]: "size_kdt l = size_kdt r + 1" by simp
  have "nat ⌈log 2 (1 + size_kdt r)⌉ ≥ nat ⌈log 2 (size_kdt r)⌉"
    by(rule nat_mono[OF ceiling_mono]) simp
  hence 1: "height(Node k v l r) = nat ⌈log 2 (1 + size_kdt r)⌉ + 1"
    using height_balanced[OF assms(1)] height_balanced[OF assms(2)]
    by (simp del: nat_ceiling_le_eq add: max_def)
  have "nat ⌊log 2 (1 + size_kdt r)⌋ ≥ nat ⌊log 2 (size_kdt r)⌋"
    by(rule nat_mono[OF floor_mono]) simp
  hence 2: "min_height(Node k v l r) = nat ⌊log 2 (size_kdt r)⌋ + 1"
    using min_height_balanced[OF assms(1)] min_height_balanced[OF assms(2)]
    by (simp)
  have "size_kdt r ≥ 1" by (simp add: Suc_leI)
  then obtain i where i: "2 ^ i ≤ size_kdt r" "size_kdt r < 2 ^ (i + 1)"
    using ex_power_ivl1[of 2 "size_kdt r"] by auto
  hence i1: "2 ^ i < size_kdt r + 1" "size_kdt r + 1 ≤ 2 ^ (i + 1)" by auto
  from 1 2 floor_log_nat_eq_if[OF i] ceiling_log_nat_eq_if[OF i1]
  show ?thesis by(simp add:balanced_def)
qed

(* Balance does not depend on the order of the subtrees or on the
   axis/discriminator stored at the node. *)
lemma balanced_sym:
  "balanced (Node k v l r) ⟹ balanced (Node k' v' r l)"
  by (auto simp: balanced_def)

(* Joining two balanced trees whose sizes differ by at most one yields a
   balanced node.  Equal sizes are handled via balanced_optimal.  Sizes
   differing by one in either direction reduce to balanced_Node_if_wbal1,
   using balanced_sym. *)
lemma balanced_Node_if_wbal2:
  assumes "balanced l" "balanced r" "abs(int(size_kdt l) - int(size_kdt r)) ≤ 1"
  shows "balanced (Node k v l r)"
proof -
  have "size_kdt l = size_kdt r ∨ (size_kdt l = size_kdt r + 1 ∨ size_kdt r = size_kdt l + 1)" (is "?A ∨ ?B")
    using assms(3) by linarith
  thus ?thesis
  proof
    assume "?A"
    thus ?thesis using assms(1,2)
      apply(simp add: balanced_def min_def max_def)
      by (metis assms(1,2) balanced_optimal le_antisym le_less)
  next
    assume "?B"
    thus ?thesis
      by (meson assms(1,2) balanced_sym balanced_Node_if_wbal1)
  qed
qed

end