(* Theory Optimal_BST *)

(* Author: Tobias Nipkow, based on work by Daniel Somogyi *)

section ‹Optimal BSTs: The `Cubic' Algorithm\label{sec:cubic}›

theory Optimal_BST
imports Weighted_Path_Length Monad_Memo_DP.OptBST
begin

subsection ‹Function ‹argmin››

text ‹Function ‹argmin› was moved to ‹Monad_Memo_DP.argmin›.
It iterates over a list and returns the rightmost element that minimizes a given function:

@{thm [display] argmin.simps}
›

text ‹An optimized version that avoids repeated computation of ‹f x›:›

(* argmin2 f xs returns an element of xs minimizing f together with its f-value.
   Each element's value f x is computed once (let-bound as fx) and the value of the
   best candidate of the tail is carried along in the second component, so it is
   never recomputed. Because the comparison is strict (<), a later element wins on
   ties. There is no equation for the empty list. *)
fun argmin2 :: "('a ⇒ ('b::linorder)) ⇒ 'a list ⇒ 'a * 'b" where
"argmin2 f (x#xs) =
  (let fx = f x
   in if xs = [] then (x, fx)
      else let mfm = argmin2 f xs
           in if fx < snd mfm then (x,fx) else mfm)"

(* On non-empty lists, argmin2 returns exactly argmin together with its f-value. *)
lemma argmin2_argmin: "xs ≠ [] ⟹ argmin2 f xs = (argmin f xs, f(argmin f xs))"
by (induction xs) (auto simp: Let_def)

(* Code equation for argmin (declared via [code]): generated code evaluates argmin
   through argmin2, which computes f only once per element. On the empty list the
   right-hand side is undefined. *)
lemma argmin_argmin2[code]: "argmin f xs = (if xs = [] then undefined else fst(argmin2 f xs))"
apply(auto simp: argmin2_argmin)
(* NOTE(review): presumably the goal left over here is the xs = [] case, which is
   closed via argmin.elims — confirm if the proof needs maintenance. *)
apply (meson argmin.elims list.distinct(1))
done



(* The argmin of a non-empty list is one of the list's elements. *)
lemma argmin_in: "xs ≠ [] ⟹ argmin f xs ∈ set xs"
using argmin_forall[of xs "λx. x∈set xs"] by blast

(* Pairing argmin with its f-value is the same as taking argmin snd over the list
   of pairs (x, f x). *)
lemma argmin_pairs: "xs ≠ [] ⟹
  (argmin f xs,f (argmin f xs)) = argmin snd (map (λx. (x,f x)) xs)"
by (induction f xs rule:argmin.induct) (auto, smt snd_conv)

(* argmin commutes with map: minimizing c over map f xs yields the image under f of
   the element of xs that minimizes c o f. *)
lemma argmin_map: "xs ≠ [] ⟹ argmin c (map f xs) = f(argmin (c o f) xs)"
by(induction xs) (simp_all add: Let_def)


subsection ‹The `Cubic' Algorithm›

text ‹We hide the details of the access frequencies ‹a› and ‹b› by working with an abstract
version of function ‹w› defined above (summing ‹a› and ‹b›). Later we interpret ‹w› accordingly.›

(* All definitions and lemmas in this locale are parameterized by an abstract
   weight function w :: int ⇒ int ⇒ nat. *)
locale Optimal_BST =
fixes w :: "int ⇒ int ⇒ nat"
begin

subsubsection ‹Functions ‹wpl› and ‹min_wpl››

(* Make the wpl development available here, instantiated with this locale's w. *)
sublocale wpl where w = w .

text ‹Function ‹min_wpl i j› computes the minimal weighted path length of any tree ‹t›
where @{prop"inorder t = [i..j]"}. It simply tries all possible indices between ‹i› and ‹j›
as the root. Thus it implicitly constructs all possible trees.›

(* Register conj_cong as a congruence rule for the function package.
   NOTE(review): which of the definitions below actually needs it is not evident
   from this theory alone — confirm before removing. *)
declare conj_cong [fundef_cong]
function min_wpl :: "int ⇒ int ⇒ nat" where
"min_wpl i j =
  (if i > j then 0
   else Min ((λk. min_wpl i (k-1) + min_wpl (k+1) j) ` {i..j}) + w i j)"
by auto
(* Both recursive calls are on intervals strictly shorter than [i..j] (since
   i ≤ k ≤ j), so the interval length nat(j-i+1) decreases. *)
termination by (relation "measure (λ(i,j). nat(j-i+1))") auto
(* The defining equation has min_wpl on its right-hand side; unfolding it
   unconditionally would make the simplifier loop. The guarded equations
   min_wpl_simps below are used instead. *)
declare min_wpl.simps[simp del]

text ‹Note that for efficiency reasons we have pulled ‹+ w i j› out of ‹Min›.
In the lemma below this is reversed because it simplifies the proofs.
Similar optimizations are possible in other functions below.›

lemma min_wpl_simps[simp]:
  "i > j ⟹ min_wpl i j = 0"
  "i ≤ j ⟹ min_wpl i j =
     Min ((λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j})"
by(auto simp add: min_wpl.simps[of i j] Min_add_commute)

(* Split an integer interval at an inner point j.
   NOTE(review): HOL's List theory may already provide an identical lemma named
   upto_split1 — confirm and consider dropping this local copy. *)
lemma upto_split1: 
  "⟦ i ≤ j;  j ≤ k ⟧ ⟹ [i..k] = [i..j-1] @ [j..k]"
proof (induction j rule: int_ge_induct)
  case base thus ?case by (simp add: upto_rec1)
next
  case step thus ?case using upto_rec1 upto_rec2 by simp
qed

text‹Function @{const min_wpl} returns a lower bound for all possible BSTs:›

theorem min_wpl_is_optimal:
  "inorder t = [i..j] ⟹ min_wpl i j ≤ wpl i j t"
proof(induction i j t rule: wpl.induct)
  case 1
  thus ?case by(simp add: upto.simps split: if_splits)
next
  case (2 i j l k r)
  then show ?case 
  proof cases
    assume "i > j" thus ?thesis by(simp)
  next
    assume [arith]: "¬ i > j"

    note inorder = inorder_upto_split[OF "2.prems"]

    let ?M = "(λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j}"
    let ?w = "min_wpl i (k-1) + min_wpl (k+1) j + w i j"

    (* The root k of the given tree is one of the candidates in ?M. *)
    have aux_min:"Min ?M ≤ ?w"
    proof (rule Min_le)
      show "finite ?M" by simp
      show "?w ∈ ?M" using inorder(3,4) by simp
    qed

    have "min_wpl i j = Min ?M" by(simp)
    also have "... ≤ ?w" by (rule aux_min)
    also have "... ≤ wpl i (k-1) l + wpl (k+1) j r + w i j"
      using inorder(1,2) "2.IH" by simp
    also have "... = wpl i j ⟨l,k,r⟩" by simp
    finally show ?thesis .
  qed
qed

text ‹Now we show that the lower bound computed by @{const min_wpl}
is the wpl of an optimal tree that can be computed in the same manner.›

subsubsection ‹Function ‹opt_bst››

text‹This is the functional equivalent of the standard cubic imperative algorithm.
Unless it is memoized, the complexity is again exponential.
The pattern of recursion is the same as for @{const min_wpl} but instead of the minimal weight
it computes a tree with the minimal weight:›

(* For every root k ∈ [i..j], build the candidate tree whose subtrees are computed
   recursively, then select a candidate with minimal wpl i j. *)
function opt_bst :: "int ⇒ int ⇒ int tree" where
"opt_bst i j =
  (if i > j then Leaf
   else argmin (wpl i j) [⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]])"
by auto
(* Same measure as for min_wpl: the interval length decreases on each call. *)
termination by (relation "measure (λ(i,j) . nat(j-i+1))") auto
declare opt_bst.simps[simp del]

corollary opt_bst_simps[simp]:
  "i > j ⟹ opt_bst i j = Leaf"
  "i ≤ j ⟹ opt_bst i j =
     (argmin (wpl i j) [⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]])"
by(auto simp add: opt_bst.simps[of i j])

text ‹As promised, @{const opt_bst} computes a tree with the minimal wpl:›

theorem wpl_opt_bst: "wpl i j (opt_bst i j) = min_wpl i j"
proof(induction i j rule: min_wpl.induct)
  case (1 i j)
  show ?case
  proof cases
    assume "i > j" 
    thus ?thesis by(simp)
  next
    assume [arith]: "¬ i > j"
    let ?ts = "[⟨opt_bst i (k-1), k, opt_bst (k+1) j⟩. k ← [i..j]]"
    let ?M = "((λk. min_wpl i (k-1) + min_wpl (k+1) j + w i j) ` {i..j})"
    have 1: "?ts ≠ []" by (auto simp add: upto.simps)
    have "wpl i j (opt_bst i j) = wpl i j (argmin (wpl i j) ?ts)" by simp
    also have "… = Min (wpl i j ` (set ?ts))"
      by(rule argmin_Min[OF 1])
    also have "… = Min ?M"
    proof (rule arg_cong[where f=Min])
      show "wpl i j ` (set ?ts) = ?M" using "1.IH"
        by (force simp: Bex_def image_iff "1.IH")
    qed
    also have "… = min_wpl i j" by simp
    finally show ?thesis .
  qed
qed

corollary opt_bst_is_optimal:
  "inorder t = [i..j] ⟹ wpl i j (opt_bst i j) ≤ wpl i j t"
by (simp add: min_wpl_is_optimal wpl_opt_bst)

subsubsection ‹Function ‹opt_bst_wpl››

text ‹Function @{const opt_bst} is simplistic because it computes the wpl
of each tree anew rather than returning it with the tree. That is what ‹opt_bst_wpl› does:›

(* Like opt_bst, but every candidate is paired with its wpl c1 + c2 + w i j, built
   from the wpls returned by the recursive calls, so argmin snd selects on the
   cached value instead of recomputing wpl. *)
function opt_bst_wpl :: "int ⇒ int ⇒ int tree × nat" where
"opt_bst_wpl i j = 
  (if i > j then (Leaf, 0)
   else argmin snd [let (t1,c1) = opt_bst_wpl i (k-1);
                        (t2,c2) = opt_bst_wpl (k+1) j
                     in (⟨t1,k,t2⟩, c1 + c2 + w i j). k ← [i..j]])"
by auto
(* Same measure as for min_wpl: the interval length decreases on each call. *)
termination
  by (relation "measure (λ(i,j). nat(j-i+1))")(auto)
declare opt_bst_wpl.simps[simp del]

text‹Function @{const opt_bst_wpl} returns an optimal tree and its wpl:›

lemma opt_bst_wpl_eq_pair:
  "opt_bst_wpl i j = (opt_bst i j, wpl i j (opt_bst i j))"
proof(induction i j rule: opt_bst_wpl.induct)
  case (1 i j)
  note [simp] = opt_bst_wpl.simps[of i j]
  show ?case 
  proof cases
    assume "i > j" thus ?thesis using "1.prems" by auto
  next
    assume "¬ i > j"
    thus ?thesis by (simp add: argmin_pairs comp_def "1.IH" cong: list.map_cong_simp)
  qed
qed

corollary opt_bst_wpl_eq_pair': "opt_bst_wpl i j = (opt_bst i j, min_wpl i j)"
by (simp add: opt_bst_wpl_eq_pair wpl_opt_bst)

end (* locale Optimal_BST *)

end