(* Theory PST_RBT *)

section ‹Priority Search Trees on top of RBTs›

theory PST_RBT
imports
  "HOL-Data_Structures.Cmp"
  "HOL-Data_Structures.Isin2"
  "HOL-Data_Structures.Lookup2"
  PST_General
begin
  
text ‹
We obtain a priority search map based on red-black trees via the 
general priority search tree augmentation.

This theory has been derived from the standard Isabelle implementation of
red-black trees in @{session "HOL-Data_Structures"}.
›

subsection ‹Definitions›

subsubsection ‹The Code›

(* Node colors of the red-black tree. *)
datatype tcolor = Red | Black

(* Augmented red-black tree. Each node stores
     (key, priority) x (color x (key, priority)):
   the node's own entry, its color, and a second (key, priority) pair that
   belongs to the general PST augmentation (bound as mkp in the patterns
   below). Presumably this is the minimal-priority entry of the subtree,
   maintained by mkNode from PST_General -- TODO confirm there. *)
type_synonym ('k,'p) rbth = "(('k×'p) × (tcolor × ('k × 'p))) tree"

(* Patterns for red and black nodes. They expose the augmentation
   component mkp so that equations can match on a node's color. In this
   theory they appear only in patterns (function equations and case
   branches), usually with mkp ignored as _. New nodes are built with
   mkR / mkB below instead.
   NOTE(review): the definitional symbol between the two sides of these
   abbreviations (and other logical symbols elsewhere in this copy) seems
   to be missing -- confirm against the upstream source. *)
abbreviation R where "R mkp l a r  Node l (a, Red,mkp) r"
abbreviation B where "B mkp l a r  Node l (a, Black,mkp) r"

(* Smart constructors for red and black nodes. mkNode (from PST_General)
   takes the color, left subtree, entry and right subtree; it presumably
   also computes the augmentation from l, a and r -- TODO confirm in
   PST_General. *)
abbreviation "mkR  mkNode Red"
abbreviation "mkB  mkNode Black"

(* Insertion rebalancing on the left: build a black node from t1, the
   entry and t2. If the left subtree is a red node with a red child (a
   red-red violation produced by insertion), the three entries are rotated
   into a red node with two black children. Otherwise a plain black node
   is built. Equations are tried top to bottom. *)
fun baliL :: "('k,'p::linorder) rbth  'k×'p  ('k,'p) rbth  ('k,'p) rbth" 
  where
  "baliL (R _ (R _ t1 a1 t2) a2 t3) a3 t4 = mkR (mkB t1 a1 t2) a2 (mkB t3 a3 t4)"
| "baliL (R _ t1 a1 (R _ t2 a2 t3)) a3 t4 = mkR (mkB t1 a1 t2) a2 (mkB t3 a3 t4)"
| "baliL t1 a t2 = mkB t1 a t2"

(* Mirror image of baliL: repairs a red-red violation in the right
   subtree, otherwise builds a plain black node. *)
fun baliR :: "('k,'p::linorder) rbth  'k×'p  ('k,'p) rbth  ('k,'p) rbth" 
  where
"baliR t1 a1 (R _ (R _ t2 a2 t3) a3 t4) = mkR (mkB t1 a1 t2) a2 (mkB t3 a3 t4)" |
"baliR t1 a1 (R _ t2 a2 (R _ t3 a3 t4)) = mkR (mkB t1 a1 t2) a2 (mkB t3 a3 t4)" |
"baliR t1 a t2 = mkB t1 a t2"

(* Recolor the root with color c. A Leaf is left unchanged. For a node,
   the entry, the subtrees and the stored augmentation mkp are reused
   as-is: the node is rebuilt with Node rather than mkNode. *)
fun paint :: "tcolor  ('k,'p::linorder) rbth  ('k,'p::linorder) rbth" where
"paint c Leaf = Leaf" |
"paint c (Node l (a, (_,mkp)) r) = Node l (a, (c,mkp)) r"

(* Deletion rebalancing on the left. Used when the left subtree's black
   height is one less than the right one's; see invh_baldL_invc and
   invh_baldL_Black below for the precondition. Equations in matching
   order:
   - left subtree red: blacken it under a red root;
   - right subtree black: redden it and rebalance with baliR;
   - right subtree red with a black left child: rotate that child's entry
     to the root, repainting the outer right subtree red before baliR;
   - otherwise: build a red node unchanged. *)
fun baldL :: "('k,'p::linorder) rbth  'k × 'p  ('k,'p::linorder) rbth 
     ('k,'p::linorder) rbth" 
where
"baldL (R _ t1 x t2) y t3 = mkR (mkB t1 x t2) y t3" |
"baldL bl x (B _ t1 y t2) = baliR bl x (mkR t1 y t2)" |
"baldL bl x (R _ (B _ t1 y t2) z t3) 
  = mkR (mkB bl x t1) y (baliR t2 z (paint Red t3))" |
"baldL t1 x t2 = mkR t1 x t2"

(* Mirror image of baldL, used when the right subtree's black height is
   one less than the left one's (see invh_baldR_invc below). *)
fun baldR :: "('k,'p::linorder) rbth  'k × 'p  ('k,'p::linorder) rbth 
     ('k,'p::linorder) rbth" 
where
"baldR t1 x (R _ t2 y t3) = mkR t1 x (mkB t2 y t3)" |
"baldR (B _ t1 x t2) y t3 = baliL (mkR t1 x t2) y t3" |
"baldR (R _ t1 x (B _ t2 y t3)) z t4 
  = mkR (baliL (paint Red t1) x t2) y (mkB t3 z t4)" |
"baldR t1 x t2 = mkR t1 x t2"

(* Join two trees into one whose in-order list is inorder l @ inorder r
   (lemma inorder_combine). del uses it to drop the entry of a matched
   node.
   - Two red roots, or two black roots: the facing inner subtrees are
     combined recursively. If that result has a red root, its entry is
     rotated to the top. Otherwise the result is placed under the outer
     entries; in the black case it is rebalanced with baldL.
   - A red root facing a root of the other color: descend into the red
     side alone.
   The black height is preserved for inputs of equal black height (lemma
   invh_combine). *)
fun combine :: "('k,'p::linorder) rbth  ('k,'p::linorder) rbth 
     ('k,'p::linorder) rbth" 
where
"combine Leaf t = t" |
"combine t Leaf = t" |
"combine (R _ t1 a t2) (R _ t3 c t4) =
  (case combine t2 t3 of
     R _ u2 b u3  (mkR (mkR t1 a u2) b (mkR u3 c t4)) |
     t23  mkR t1 a (mkR t23 c t4))" |
"combine (B _ t1 a t2) (B _ t3 c t4) =
  (case combine t2 t3 of
     R _ t2' b t3'  mkR (mkB t1 a t2') b (mkB t3' c t4) |
     t23  baldL t1 a (mkB t23 c t4))" |
"combine t1 (R _ t2 a t3) = mkR (combine t1 t2) a t3" |
"combine (R _ t1 a t2) t3 = mkR t1 a (combine t2 t3)"

(* Color of the root; a Leaf counts as black. *)
fun color :: "('k,'p) rbth  tcolor" where
"color Leaf = Black" |
"color (Node _ (_, (c,_)) _) = c"


(* Insert key x with priority y, or overwrite the pair if x is already
   present. The search descends by cmp on keys; an equal key replaces the
   stored pair with (x,y).
   - Below a black node, the recursive result is rebalanced with baliL /
     baliR.
   - Below a red node, the node is simply rebuilt with mkR, so the result
     may carry a red-red violation at its root. invc_upd only guarantees
     invc2 in general. *)
fun upd :: "'a::linorder  'b::linorder  ('a,'b) rbth  ('a,'b) rbth" where
"upd x y Leaf = mkR Leaf (x,y) Leaf" |
"upd x y (B _ l (a,b) r) = (case cmp x a of
  LT  baliL (upd x y l) (a,b) r |
  GT  baliR l (a,b) (upd x y r) |
  EQ  mkB l (x,y) r)" |
"upd x y (R _ l (a,b) r) = (case cmp x a of
  LT  mkR (upd x y l) (a,b) r |
  GT  mkR l (a,b) (upd x y r) |
  EQ  mkR l (x,y) r)"

(* Public insertion: run upd, then paint the root black, which restores
   the full color invariant (see rbt_update). *)
definition update :: "'a::linorder  'b::linorder  ('a,'b) rbth  ('a,'b) rbth" 
where
"update x y t = paint Black (upd x y t)"


(* Remove the entry with key x, if present. The search descends by cmp on
   keys.
   - If the child it descends into is a non-empty black node, the result
     is rebalanced with baldL / baldR, because deleting from a black-rooted
     tree lowers its black height by one (see del_invc_invh).
   - Otherwise the node is rebuilt red with mkR.
   - On an equal key, the node's two subtrees are joined with combine.
   The node's own color c is bound but not used here.
   NOTE(review): the connectives in the if-conditions look stripped in
   this copy; presumably "l ~= Leaf & color l = Black" -- confirm. *)
fun del :: "'a::linorder  ('a,'b::linorder)rbth  ('a,'b)rbth" where
"del x Leaf = Leaf" |
"del x (Node l ((a,b), (c,_)) r) = (case cmp x a of
     LT  if l  Leaf  color l = Black
           then baldL (del x l) (a,b) r else mkR (del x l) (a,b) r |
     GT  if r  Leaf color r = Black
           then baldR l (a,b) (del x r) else mkR l (a,b) (del x r) |
  EQ  combine l r)"


(* Public deletion: run del, then paint the root black (see
   rbt_delete). *)
definition delete :: "'a::linorder  ('a,'b::linorder) rbth  ('a,'b) rbth" where
"delete x t = paint Black (del x t)"


subsubsection ‹Invariants›

(* Black height: the number of black nodes on the leftmost path; a Leaf
   contributes 0. Under invh every root-to-leaf path has this same
   count. *)
fun bheight :: "('k,'p) rbth  nat" where
"bheight Leaf = 0" |
"bheight (Node l (x, (c,_)) r) = (if c = Black then bheight l + 1 else bheight l)"

(* Color invariant. It holds recursively in both subtrees, and a red node
   must have black children. A Leaf has color Black, so a red node may
   have Leaf children.
   NOTE(review): the connectives look stripped in this copy; the intended
   reading is presumably "c = Red --> color l = Black & color r = Black"
   -- confirm. *)
fun invc :: "('k,'p) rbth  bool" where
"invc Leaf = True" |
"invc (Node l (a, (c,_)) r) =
  (invc l  invc r  (c = Red  color l = Black  color r = Black))"

(* Weaker color invariant: both subtrees satisfy invc, but the root is not
   constrained, so it may be red with a red child. This is what upd and
   del guarantee before update / delete repaint the root black. *)
fun invc2 :: "('k,'p) rbth  bool" ― ‹Weaker version› where
"invc2 Leaf = True" |
"invc2 (Node l (a, _) r) = (invc l  invc r)"

(* Height invariant: at every node the left and right subtrees have the
   same black height. *)
fun invh :: "('k,'p) rbth  bool" where
"invh Leaf = True" |
"invh (Node l (x, _) r) = (invh l  invh r  bheight l = bheight r)"

(* Overall invariant of the priority-search red-black tree. It combines:
   - the color invariant invc;
   - the height invariant invh;
   - the PST augmentation invariant invpst (from PST_General);
   - a black root. *)
definition rbt :: "('k,'p::linorder) rbth  bool" where
"rbt t = (invc t  invh t  invpst t  color t = Black)"


subsection ‹Functional Correctness›

(* Repainting the root leaves the in-order list of entries unchanged. *)
lemma inorder_paint[simp]: "inorder(paint c t) = inorder t"
by(cases t) (auto)

(* mkNode places the entry a between the in-order lists of l and r, for
   any color c. *)
lemma inorder_mkNode[simp]:
  "inorder (mkNode c l a r) = inorder l @ a # inorder r"
by (auto simp: mkNode_def)


(* The rebalancing operations do not change the in-order list. As simp
   rules, these lemmas and inorder_combine reduce the correctness proofs
   for upd and del to list reasoning. Each is proved by case analysis
   along the defining equations. *)
lemma inorder_baliL[simp]:
  "inorder(baliL l a r) = inorder l @ a # inorder r"
by(cases "(l,a,r)" rule: baliL.cases) (auto)

lemma inorder_baliR[simp]:
  "inorder(baliR l a r) = inorder l @ a # inorder r"
by(cases "(l,a,r)" rule: baliR.cases) (auto)


lemma inorder_baldL[simp]:
  "inorder(baldL l a r) = inorder l @ a # inorder r"
by (cases "(l,a,r)" rule: baldL.cases) auto

lemma inorder_baldR[simp]:
  "inorder(baldR l a r) = inorder l @ a # inorder r"
by(cases "(l,a,r)" rule: baldR.cases) auto

(* combine concatenates the in-order lists. The proof is by induction
   along combine, splitting the nested case expressions on tree shape and
   color. *)
lemma inorder_combine[simp]:
  "inorder(combine l r) = inorder l @ inorder r"
by (induction l r rule: combine.induct) (auto split: tree.split tcolor.split)

(* Functional correctness against the list specification. On a tree whose
   in-order list satisfies sorted1 (key-sortedness as defined in
   HOL-Data_Structures), upd / update behave like upd_list on that
   list. *)
lemma inorder_upd:
  "sorted1(inorder t)  inorder(upd x y t) = upd_list x y (inorder t)"
by(induction x y t rule: upd.induct)
  (auto simp: upd_list_simps)

lemma inorder_update:
  "sorted1(inorder t)  inorder(update x y t) = upd_list x y (inorder t)"
by(simp add: update_def inorder_upd)

(* Likewise, del / delete behave like del_list. *)
lemma inorder_del:
 "sorted1(inorder t)   inorder(del x t) = del_list x (inorder t)"
by(induction x t rule: del.induct)
  (auto simp: del_list_simps)

lemma inorder_delete:
  "sorted1(inorder t)  inorder(delete x t) = del_list x (inorder t)"
by(simp add: delete_def inorder_del)


subsection ‹Invariant Preservation›

(* Painting black yields a black root (for Leaf as well, since Leaf is
   black). *)
lemma color_paint_Black: "color (paint Black t) = Black"
by (cases t) auto

(* The empty tree satisfies the full invariant. *)
theorem rbt_Leaf: "rbt Leaf"
by (simp add: rbt_def)

(* The full color invariant implies the weaker one. *)
lemma invc2I: "invc t  invc2 t"
by (cases t rule: invc.cases) simp+

(* invc2 ignores the root color, so repainting preserves it. *)
lemma paint_invc2: "invc2 t  invc2 (paint c t)"
by (cases t) auto

(* Painting the root black turns the weak color invariant into the full
   one; update and delete rely on this. *)
lemma invc_paint_Black: "invc2 t  invc (paint Black t)"
by (cases t) auto

(* Repainting does not affect the height invariant. *)
lemma invh_paint: "invh t  invh (paint c t)"
by (cases t) auto

(* Unfolding rules for invc of the smart constructors: a red node needs
   valid black children, a black node only valid children. *)
lemma invc_mkRB[simp]:
  "invc (mkR l a r)  invc l  invc r  color l = Black  color r = Black"
  "invc (mkB l a r)  invc l  invc r"
by (simp_all add: mkNode_def)

(* mkNode gives the node the requested color. *)
lemma color_mkNode[simp]: "color (mkNode c l a r) = c"
by (simp_all add: mkNode_def)


subsubsection ‹Update›

(* baliL repairs a possible red-red violation in its left argument. With
   a weakly valid (invc2) left subtree and a valid right subtree, the
   result satisfies invc. invc_baliR is the mirror image. *)
lemma invc_baliL:
  "invc2 l; invc r  invc (baliL l a r)"
by (induct l a r rule: baliL.induct) auto

lemma invc_baliR:
  "invc l; invc2 r  invc (baliR l a r)"
by (induct l a r rule: baliR.induct) auto

(* Black height of the smart constructors: a red node adds nothing, a
   black node adds one. *)
lemma bheight_mkRB[simp]:
  "bheight (mkR l a r) = bheight l"
  "bheight (mkB l a r) = Suc (bheight l)"
  by (simp_all add: mkNode_def)

(* On subtrees of equal black height, baliL / baliR produce a tree whose
   black height is one more than that of the inputs. *)
lemma bheight_baliL:
  "bheight l = bheight r  bheight (baliL l a r) = Suc (bheight l)"
by (induct l a r rule: baliL.induct) auto

lemma bheight_baliR:
  "bheight l = bheight r  bheight (baliR l a r) = Suc (bheight l)"
by (induct l a r rule: baliR.induct) auto

(* Unfolding rule for invh of a node built with mkNode. *)
lemma invh_mkNode[simp]:
  "invh (mkNode c l a r)  invh l  invh r  bheight l = bheight r"
by (simp add: mkNode_def)

(* baliL / baliR preserve the height invariant on equal-height valid
   subtrees. *)
lemma invh_baliL:
  " invh l; invh r; bheight l = bheight r   invh (baliL l a r)"
by (induct l a r rule: baliL.induct) auto

lemma invh_baliR:
  " invh l; invh r; bheight l = bheight r   invh (baliR l a r)"
by (induct l a r rule: baliR.induct) auto


(* Color invariant after upd:
   (1) if the root was black, the result satisfies invc;
   (2) in general only invc2 holds.
   Part (2) is what rbt_update uses before painting the root black. *)
lemma invc_upd: assumes "invc t"
  shows "color t = Black  invc (upd x y t)" "invc2 (upd x y t)"
using assms
by (induct x y t rule: upd.induct) 
   (auto simp: invc_baliL invc_baliR invc2I mkNode_def)

(* Height invariant after upd: invh is preserved and the black height is
   unchanged. *)
lemma invh_upd: assumes "invh t"
  shows "invh (upd x y t)" "bheight (upd x y t) = bheight t"
using assms
by(induct x y t rule: upd.induct)
  (auto simp: invh_baliL invh_baliR bheight_baliL bheight_baliR)


(* The PST augmentation invariant invpst (from PST_General) is unaffected
   by repainting and is preserved by the insertion-side operations. *)
lemma invpst_paint[simp]: "invpst (paint c t) = invpst t"
by (cases "(c,t)" rule: paint.cases) auto

lemma invpst_baliR: "invpst l  invpst r  invpst (baliR l a r)"
by (cases "(l,a,r)" rule: baliR.cases) auto

lemma invpst_baliL: "invpst l  invpst r  invpst (baliL l a r)"
by (cases "(l,a,r)" rule: baliL.cases) auto

lemma invpst_upd: "invpst t  invpst (upd x y t)"
by (induct x y t rule: upd.induct) (auto simp: invpst_baliR invpst_baliL)


(* update preserves the full invariant rbt. upd yields invc2 and invh;
   painting the root black restores invc and the black root; invpst is
   preserved by upd. *)
theorem rbt_update: "rbt t  rbt (update x y t)"
by (simp add: invc_upd(2) invh_upd(1) color_paint_Black invc_paint_Black 
  invh_paint rbt_def update_def invpst_upd)


subsubsection ‹Delete›

(* Painting a black-rooted tree red lowers its black height by one. For
   Leaf both sides are 0, since 0 - 1 = 0 in nat. *)
lemma bheight_paint_Red:
  "color t = Black  bheight (paint Red t) = bheight t - 1"
by (cases t) auto

(* baldL restores invh when the left subtree is one black level short.
   With a color-valid right subtree the result's black height is
   bheight l + 1, i.e. that of r. *)
lemma invh_baldL_invc:
  " invh l;  invh r;  bheight l + 1 = bheight r;  invc r 
    invh (baldL l a r)  bheight (baldL l a r) = bheight l + 1"
by (induct l a r rule: baldL.induct)
   (auto simp: invh_baliR invh_paint bheight_baliR bheight_paint_Red)

(* Variant of the previous lemma for a black right subtree; used by
   invh_combine. *)
lemma invh_baldL_Black:
  " invh l;  invh r;  bheight l + 1 = bheight r;  color r = Black 
    invh (baldL l a r)  bheight (baldL l a r) = bheight r"
by (induct l a r rule: baldL.induct) (auto simp add: invh_baliR bheight_baliR)

(* Color invariant of baldL: full invc if the right subtree is black,
   otherwise at least invc2. *)
lemma invc_baldL: "invc2 l; invc r; color r = Black  invc (baldL l a r)"
by (induct l a r rule: baldL.induct) (auto simp: invc_baliR invc2I mkNode_def)

lemma invc2_baldL: " invc2 l; invc r   invc2 (baldL l a r)"
by (induct l a r rule: baldL.induct) 
   (auto simp: invc_baliR paint_invc2 invc2I mkNode_def)

(* Mirror image of invh_baldL_invc for a right subtree that is one black
   level short. *)
lemma invh_baldR_invc:
  " invh l;  invh r;  bheight l = bheight r + 1;  invc l 
   invh (baldR l a r)  bheight (baldR l a r) = bheight l"
by(induct l a r rule: baldR.induct)
  (auto simp: invh_baliL bheight_baliL invh_paint bheight_paint_Red)

(* Color invariant of baldR: full invc if the left subtree is black,
   otherwise at least invc2. *)
lemma invc_baldR: "invc a; invc2 b; color a = Black  invc (baldR a x b)"
by (induct a x b rule: baldR.induct) (simp_all add: invc_baliL mkNode_def)

lemma invc2_baldR: " invc l; invc2 r  invc2 (baldR l x r)"
by (induct l x r rule: baldR.induct) 
   (auto simp: invc_baliL paint_invc2 invc2I mkNode_def)

(* combine preserves invh and keeps the common black height of its two
   equal-height inputs. *)
lemma invh_combine:
  " invh l; invh r; bheight l = bheight r 
   invh (combine l r)  bheight (combine l r) = bheight l"
by (induct l r rule: combine.induct)
   (auto simp: invh_baldL_Black split: tree.splits tcolor.splits)

(* Color invariant after combine:
   (1) full invc if both inputs have black roots;
   (2) in general invc2. *)
lemma invc_combine:
  assumes "invc l" "invc r"
  shows "color l = Black  color r = Black  invc (combine l r)"
         "invc2 (combine l r)"
using assms
by (induct l r rule: combine.induct)
   (auto simp: invc_baldL invc2I mkNode_def split: tree.splits tcolor.splits)

(* A non-Leaf tree is a Node; used as a destruction rule in
   del_invc_invh. *)
lemma neq_LeafD: "t  Leaf  l x c r. t = Node l (x,c) r"
by(cases t) auto

(* Main invariant lemma for del. invh is preserved, and:
   - if the root was red, the black height is unchanged and invc holds;
   - if it was black, the black height drops by one and invc2 holds.
   Proof by induction along del. Case 2 is the Node equation, with search
   key x, node key y and node color c. It is split by comparing x with y,
   and then by c. The Leaf case is closed by the final auto. *)
lemma del_invc_invh: "invh t  invc t  invh (del x t) 
   (color t = Red  bheight (del x t) = bheight t  invc (del x t) 
    color t = Black  bheight (del x t) = bheight t - 1  invc2 (del x t))"
proof (induct x t rule: del.induct)
case (2 x _ y _ c)
  have "x = y  x < y  x > y" by auto
  thus ?case proof (elim disjE)
    (* Key found: del joins the subtrees with combine. *)
    assume "x = y"
    with 2 show ?thesis
    by (cases c) (simp_all add: invh_combine invc_combine)
  next
    (* Descend left: del uses baldL or mkR. *)
    assume "x < y"
    with 2 show ?thesis
      by(cases c)
        (auto 
          simp: invh_baldL_invc invc_baldL invc2_baldL mkNode_def 
          dest: neq_LeafD)
  next
    (* Descend right: del uses baldR or mkR. *)
    assume "y < x"
    with 2 show ?thesis
      by(cases c)
        (auto 
          simp: invh_baldR_invc invc_baldR invc2_baldR mkNode_def 
          dest: neq_LeafD)
  qed
qed auto

(* The PST augmentation invariant invpst is preserved by the deletion-side
   operations: baldR / baldL, combine, and hence del. *)
lemma invpst_baldR: "invpst l  invpst r  invpst (baldR l a r)"
by (cases "(l,a,r)" rule: baldR.cases) (auto simp: invpst_baliL)

lemma invpst_baldL: "invpst l  invpst r  invpst (baldL l a r)"
by (cases "(l,a,r)" rule: baldL.cases) (auto simp: invpst_baliR)

lemma invpst_combine: "invpst l  invpst r  invpst (combine l r)"
by(induction l r rule: combine.induct)
  (auto split: tree.splits tcolor.splits simp: invpst_baldR invpst_baldL)

lemma invpst_del: "invpst t  invpst (del x t)"
by(induct x t rule: del.induct)
  (auto simp: invpst_baldR invpst_baldL invpst_combine)

(* delete preserves the full invariant rbt. The proof unfolds the
   definitions, instantiates del_invc_invh with the deleted key k, and
   uses the fact that painting the root black restores invc. *)
theorem rbt_delete: "rbt t  rbt (delete k t)"
apply (clarsimp simp: delete_def rbt_def)
apply (frule (1) del_invc_invh[where x=k])
apply (auto simp: invc_paint_Black invh_paint color_paint_Black invpst_del)
done

(* For a non-empty tree satisfying rbt, pst_getmin (from PST_General)
   returns an entry satisfying is_min2 with respect to the tree's entries
   (set_tree t); presumably is_min2 means minimal priority -- see
   PST_General. Only the invpst part of rbt is needed, via
   pst_getmin_ismin. *)
lemma rbt_getmin_ismin: 
  "rbt t  tLeaf  is_min2 (pst_getmin t) (set_tree t)"
unfolding rbt_def by (simp add: pst_getmin_ismin)

(* Emptiness test: the tree is empty iff it is Leaf. *)
definition "rbt_is_empty t  t = Leaf"

(* Equivalently, iff its in-order list is empty. *)
lemma rbt_is_empty: "rbt_is_empty t  inorder t = []"
by (cases t) (auto simp: rbt_is_empty_def)

(* The empty priority search map. *)
definition empty where "empty = Leaf"


subsection ‹Overall Correctness›

(* Instantiate the abstract priority-map specification PrioMap_by_Ordered
   (defined outside this theory) with the operations above, using rbt as
   the invariant. The obligations are discharged as follows:
   - inorder_update / inorder_delete: the list-level obligations;
   - rbt_Leaf / rbt_update / rbt_delete: invariant preservation;
   - rbt_getmin_ismin: the getmin obligation. *)
interpretation PM: PrioMap_by_Ordered
where empty = empty and lookup = lookup and update = update and delete = delete
and inorder = inorder and inv = "rbt" and is_empty = rbt_is_empty 
and getmin = pst_getmin
apply standard
apply (auto simp: lookup_map_of inorder_update inorder_delete rbt_update 
                  rbt_delete rbt_Leaf rbt_is_empty empty_def 
            dest: rbt_getmin_ismin)
done

end