File ‹dpt_sat_solver.ML›
signature DPT_SAT_SOLVER =
sig
(* Mutable solver state (a record of references, see type solver below). *)
type solver
(* A fresh solver with no clauses. *)
val empty_solver : unit -> solver
(* Load a clause set; each inner list is one clause of signed integer
   literals (variable v as v, its negation as ~v, judging by assignment).
   NOTE(review): the body of init is not visible here — confirm details. *)
val init : solver -> int list list -> unit
(* Run the search on the loaded clauses. *)
val solve : solver -> unit
(* NOTE(review): presumably true when solve found a satisfying assignment —
   confirm against the implementation. *)
val satisfied : solver -> bool
(* Current value of a literal: SOME true, SOME false, or NONE when its
   variable is unassigned. *)
val assignment : solver -> int -> bool option
end
structure DPT_SAT_Solver : DPT_SAT_SOLVER =
struct
(* Growable arrays.  A vector keeps its live elements in the first !size
   slots of the backing array !contents.  The remaining capacity holds the
   filler value [dummy].  The operations below write [dummy] into the slots
   they vacate, so removed elements do not stay reachable. *)
structure Vec =
struct

(* Upper bound on any backing array's capacity (the argument is unused). *)
fun limit_for x = Array.maxLen

(* Initial capacity for n elements: max (1, n).  It is never 0, because
   capacity_above cannot grow a 0-capacity array.
   Raises ERROR if n exceeds Array.maxLen. *)
fun capacity_for n x =
  (if (n > Array.maxLen) then
     error "capacity_for: Invalid_argument \"exceeds limit\""
   else () ;
   Int.max (1, n))

(* New capacity, when growing from capacity n, that holds at least m
   elements: max (2 * n, m), capped at Array.maxLen.
   Raises ERROR if m does not fit. *)
fun capacity_atleast m n x =
  let val result = Int.min (limit_for x, (Int.max (2 * n, m)))
  in if (m > result) then
       error "capacity_atleast: Invalid_argument \"exceeds limit\""
     else () ;
     result end

(* Array of capacity cap whose first size slots are x and the rest triv. *)
fun array_make size x cap triv =
  let val a = Array.array (cap, triv)
      val n = Unsynchronized.ref 0
  in while !n < size do
       (Array.update (a, !n, x) ;
        n := !n + 1) ;
     a end

(* Array of capacity cap whose slot k < size holds f k and the rest triv. *)
fun array_init size f cap triv =
  let val a = Array.array (cap, triv)
      val n = Unsynchronized.ref 0
  in while !n < size do
       (Array.update (a, !n, f (!n)) ;
        n := !n + 1) ;
     a end

(* Fresh array of capacity cap holding the first size elements of a,
   padded with triv. *)
fun array_extend a size cap triv =
  let val a' = Array.array (cap, triv)
  in ArraySlice.copy {src = ArraySlice.slice (a, 0, SOME size),
                      dst = a',
                      di = 0} ; a' end

(* Last live element; size must be positive. *)
fun array_top a size = Array.sub (a, size - 1)

(* Return the last live element and reset its slot to triv. *)
fun array_pop a size triv =
  let val last = size - 1
      val result = Array.sub (a, last)
  in Array.update (a, last, triv) ;
     result end

(* Store x just past the live elements; capacity must already suffice. *)
fun array_push a size x = Array.update (a, size, x)

(* Remove element i by shifting the tail left by one.
   Returns the new size. *)
fun array_delete a size i triv =
  let val last = size - 1
  in ArraySlice.copy {src = ArraySlice.slice (a, i + 1, SOME (last - i)),
                      dst = a,
                      di = i} ;
     (* Clear the vacated slot, as array_pop and array_swap_out do, so it
        keeps no stale reference. *)
     Array.update (a, last, triv) ;
     last end

(* Overwrite element i with the last live element and reset the last slot
   to triv.  Order is not preserved.  Returns the new size. *)
fun array_swap_out a size i triv =
  let val last = size - 1
  in Array.update (a, i, Array.sub (a, last)) ;
     Array.update (a, last, triv) ;
     last end

(* Shift elements i .. size - 1 right by one and store x at i.  Capacity
   must already suffice; ArraySlice.copy handles the overlapping ranges. *)
fun array_insert a size i x =
  (ArraySlice.copy {src = ArraySlice.slice (a, i, SOME (size - i)),
                    dst = a,
                    di = i + 1} ;
   Array.update (a, i, x))

(* Unordered filter of [i, size): a rejected element is replaced by the
   current last element, and the freed slot is reset to triv.
   Returns the new size. *)
fun array_fast_filter p a i size triv =
  if i = size then
    size
  else if p (Array.sub (a, i)) then
    array_fast_filter p a (i + 1) size triv
  else
    (Array.update (a, i, Array.sub (a, size - 1)) ;
     Array.update (a, size - 1, triv) ;
     array_fast_filter p a i (size - 1) triv)

(* As array_fast_filter, with the predicate p x. *)
fun array_fast_filter_const p x a i size triv =
  if i = size then
    size
  else if p x (Array.sub (a, i)) then
    array_fast_filter_const p x a (i + 1) size triv
  else
    (Array.update (a, i, Array.sub (a, size - 1)) ;
     Array.update (a, size - 1, triv) ;
     array_fast_filter_const p x a i (size - 1) triv)

(* Index of the first element in [i, size) that fails the predicate, or size
   if there is none.  The variants differ in what the predicate receives:
   the element; x and the element; the index and the element; x, the index
   and the element. *)
fun array_first_not p a i size =
  if i = size then size
  else if not (p (Array.sub (a, i))) then i
  else array_first_not p a (i + 1) size
fun array_first_not2 p x a i size =
  if i = size then size
  else if not (p x (Array.sub (a, i))) then i
  else array_first_not2 p x a (i + 1) size
fun array_first_noti p a i size =
  if i = size then size
  else if not (p i (Array.sub (a, i))) then i
  else array_first_noti p a (i + 1) size
fun array_first_noti2 p x a i size =
  if i = size then size
  else if not (p x i (Array.sub (a, i))) then i
  else array_first_noti2 p x a (i + 1) size

(* Order-preserving in-place filter of [i, size).  Elements satisfying p are
   compacted to the front, the freed tail slots are reset to triv, and the
   new size is returned. *)
fun array_filter p a i size triv =
  let val j = Unsynchronized.ref (array_first_not p a i size)
      val k = Unsynchronized.ref (!j + 1)
  in
    while !k < size do
      (let val x = Array.sub (a, !k)
       in if p x then
            (Array.update (a, !j, x) ;
             j := !j + 1)
          else () ;
          k := !k + 1 end) ;
    k := !j ;
    while !k < size do
      (Array.update (a, !k, triv) ;
       k := !k + 1) ;
    !j
  end

(* As array_filter, with the predicate p x. *)
fun array_filter_const p x a i size triv =
  let val j = Unsynchronized.ref (array_first_not2 p x a i size)
      val k = Unsynchronized.ref (!j + 1)
  in
    while !k < size do
      (let val y = Array.sub (a, !k)
       in if p x y then
            (Array.update (a, !j, y) ;
             j := !j + 1)
          else () ;
          (* Advance the scan; without this the loop never terminates. *)
          k := !k + 1 end) ;
    k := !j ;
    while !k < size do
      (Array.update (a, !k, triv) ;
       k := !k + 1) ;
    !j
  end

(* As array_filter; the predicate also receives the original index. *)
fun array_filteri p a i size triv =
  let val j = Unsynchronized.ref (array_first_noti p a i size)
      val k = Unsynchronized.ref (!j + 1)
  in
    while !k < size do
      (let val x = Array.sub (a, !k) in
         if p (!k) x then
           (Array.update (a, !j, x) ;
            j := !j + 1)
         else () ;
         k := !k + 1 end) ;
    k := !j ;
    while !k < size do
      (Array.update (a, !k, triv) ;
       k := !k + 1) ;
    !j
  end

(* As array_filteri, with the predicate p x. *)
fun array_filteri_const p x a i size triv =
  let val j = Unsynchronized.ref (array_first_noti2 p x a i size)
      val k = Unsynchronized.ref (!j + 1)
  in
    while !k < size do
      (let val y = Array.sub (a, !k)
       in if p x (!k) y then
            (Array.update (a, !j, y) ;
             j := !j + 1)
          else () ;
          (* Advance the scan; without this the loop never terminates. *)
          k := !k + 1 end) ;
    k := !j ;
    while !k < size do
      (Array.update (a, !k, triv) ;
       k := !k + 1) ;
    !j
  end

(* Index of the first element in [i, j) satisfying p (or p x), else j. *)
fun subarray_first p a i j =
  if i = j then j
  else if p (Array.sub (a, i)) then i
  else subarray_first p a (i + 1) j
fun subarray_first_const p x a i j =
  if i = j then j
  else if p x (Array.sub (a, i)) then i
  else subarray_first_const p x a (i + 1) j

(* As subarray_first / subarray_first_const, but raise ERROR
   "array_first: Not_found" when no element matches. *)
fun array_first p a i j =
  let val result = subarray_first p a i j
  in if result = j then error "array_first: Not_found"
     else result end
fun array_first_const p x a i j =
  let val result = subarray_first_const p x a i j
  in if result = j then error "array_first: Not_found"
     else result end

type 'a t =
  {contents : 'a array Unsynchronized.ref,
   size : int Unsynchronized.ref,
   dummy: 'a}

(* Vector of n copies of x, with filler triv. *)
fun make n x triv =
  {contents = Unsynchronized.ref (array_make n x (capacity_for n triv) triv),
   size = Unsynchronized.ref n,
   dummy = triv} : 'a t
val create = make

(* Vector of the n elements f 0 .. f (n - 1), with filler triv. *)
fun init n f triv =
  {contents = Unsynchronized.ref (array_init n f (capacity_for n triv) triv),
   size = Unsynchronized.ref n,
   dummy = triv} : 'a t

(* Copy with its own backing array (the elements themselves are shared). *)
fun copy (v : 'a t) =
  {contents = Unsynchronized.ref (Array.tabulate
                                    (Array.length (!(#contents v)),
                                     (fn i => Array.sub (!(#contents v), i)))),
   size = Unsynchronized.ref (!(#size v)),
   dummy = (#dummy v)} : 'a t

(* Doubled capacity 2 * n, capped at the limit.  Raises ERROR when that is
   not larger than n (in particular for n = 0). *)
fun capacity_above n x =
  let val result = Int.min (limit_for x, 2 * n) in
    if (n >= result) then
      error "capacity_above: Invalid_argument \"exceeds limit\""
    else () ;
    result end

(* Grow the backing array when the vector is full. *)
fun ensure_cap (v : 'a t) =
  let val cap = Array.length (!(#contents v))
  in if !(#size v) = cap then
       (#contents v) := array_extend
                          (!(#contents v)) cap
                          (capacity_above cap (#dummy v))
                          (#dummy v)
     else () end

(* Vector holding a copy of array a. *)
fun of_array a triv =
  let val n = Array.length a
  in {contents = Unsynchronized.ref (array_extend a n (capacity_for n triv) triv),
      size = Unsynchronized.ref n,
      dummy = triv} : 'a t end

(* Vector holding the elements of list l, in order.  Must stay above the
   definition of length below: here length is List's length. *)
fun of_list l triv =
  let val n = length l
      val a = Array.array (capacity_for n triv, triv)
      fun init i [] = ()
        | init i (h :: t) = (Array.update (a, i, h) ;
                             init (i + 1) t)
  in init 0 l ;
     {contents = Unsynchronized.ref a,
      size = Unsynchronized.ref n,
      dummy = triv} : 'a t end

(* Fresh array of exactly the live elements. *)
fun to_array (v : 'a t) =
  let val a = Array.array (!(#size v), #dummy v)
  in ArraySlice.copy {src = ArraySlice.slice
                              (!(#contents v), 0, SOME (!(#size v))),
                      dst = a, di = 0} ; a end

(* Live elements as a list, in order. *)
fun to_list (v : 'a t) =
  ArraySlice.foldr (op ::) []
    (ArraySlice.slice
       (!(#contents v), 0, SOME (!(#size v))))

fun length (v : 'a t) = !(#size v)
fun is_empty (v : 'a t) = !(#size v) = 0

(* Extend to n elements, setting the new ones to x (reallocating if
   needed). *)
fun grow (v : 'a t) n x =
  let val cap = Array.length (!(#contents v))
  in if n > cap then
       ((#contents v) :=
          array_extend (!(#contents v)) (!(#size v))
            (capacity_atleast n cap (#dummy v)) (#dummy v))
     else () ;
     ArraySlice.modify (fn _ => x)
       (ArraySlice.slice (!(#contents v),
                          !(#size v),
                          SOME (n - !(#size v)))) ;
     (#size v) := n end

(* Extend to n elements; each new slot is replaced by f applied to its
   current content.  NOTE(review): f receives the slot's old value (the
   filler), not its index. *)
fun grow_init (v : 'a t) n f =
  let val cap = Array.length (!(#contents v))
  in if n > cap then
       ((#contents v) :=
          array_extend (!(#contents v)) (!(#size v))
            (capacity_atleast n cap (#dummy v)) (#dummy v))
     else () ;
     ArraySlice.modify f
       (ArraySlice.slice (!(#contents v),
                          !(#size v),
                          SOME (n - !(#size v)))) ;
     (#size v) := n end

(* Truncate to n elements, resetting the dropped slots to the filler. *)
fun shrink (v : 'a t) n =
  (ArraySlice.modify (fn _ => #dummy v)
     (ArraySlice.slice (!(#contents v), n,
                        SOME (!(#size v) - n))) ;
   (#size v) := n)
fun clear (v : 'a t) = shrink v 0

fun get (v : 'a t) i = Array.sub (!(#contents v), i)
fun set (v : 'a t) i x = Array.update (!(#contents v), i, x)
(* Set slots i .. j - 1 to x. *)
fun set_all (v : 'a t) i j x =
  ArraySlice.modify (fn _ => x)
    (ArraySlice.slice (!(#contents v),
                       i, SOME (j - i)))

fun top (v : 'a t) = array_top (!(#contents v)) (!(#size v))
fun pop (v : 'a t) =
  let val result = array_pop (!(#contents v)) (!(#size v)) (#dummy v)
  in (#size v) := !(#size v) - 1 ;
     result end
fun push (v : 'a t) x =
  (ensure_cap v ;
   array_push (!(#contents v)) (!(#size v)) x ;
   (#size v) := !(#size v) + 1)
fun insert (v : 'a t) i x =
  (ensure_cap v ;
   array_insert (!(#contents v)) (!(#size v)) i x ;
   (#size v) := !(#size v) + 1)
fun delete (v : 'a t) i =
  (#size v) := array_delete (!(#contents v)) (!(#size v)) i (#dummy v)

fun array_swap a i j =
  if i <> j then
    let val tmp = Array.sub (a, i)
    in Array.update (a, i, Array.sub (a, j)) ;
       Array.update (a, j, tmp) end
  else ()
fun swap (v : 'a t) i j =
  array_swap (!(#contents v)) i j
(* Remove element i by moving the last element into its place. *)
fun swap_out (v : 'a t) i =
  (#size v) := array_swap_out (!(#contents v)) (!(#size v)) i (#dummy v)

(* Reverse a[i, j) in place. *)
fun array_rev a i j =
  if i < j then
    (array_swap a i (j - 1) ;
     array_rev a (i + 1) (j - 1))
  else ()
fun rev (v : 'a t) = array_rev (!(#contents v)) 0 (!(#size v))

(* Iteration over a[i, j): c receives the element; x and the element; or
   the index and the element. *)
fun array_iter c a i j =
  if i <> j then
    (c (Array.sub (a, i)) ;
     array_iter c a (i + 1) j)
  else ()
fun iter c (v : 'a t) = array_iter c (!(#contents v)) 0 (!(#size v))
fun array_iter_const c x a i j =
  if i <> j then
    (c x (Array.sub (a, i)) ;
     array_iter_const c x a (i + 1) j)
  else ()
fun iter_const c x (v : 'a t) = array_iter_const c x (!(#contents v)) 0 (!(#size v))
fun array_iteri c a i j =
  if i <> j then
    (c i (Array.sub (a, i)) ;
     array_iteri c a (i + 1) j)
  else ()
fun iteri c (v : 'a t) = array_iteri c (!(#contents v)) 0 (!(#size v))
(* NOTE(review): despite its name, c receives x and the element but no
   index (same as iter_const).  Kept as-is so existing callers still
   type-check. *)
fun iteri_const c x (v : 'a t) = array_iter_const c x (!(#contents v)) 0 (!(#size v))

fun modify f (v : 'a t) = ArraySlice.modify f (ArraySlice.slice (!(#contents v),
                                                                  0, SOME (!(#size v))))
fun modify_const f x (v : 'a t) = modify (f x) v

(* Folds over the live elements; f is curried (element, then accumulator;
   the foldi variants pass the index first). *)
fun fold_left f acc (v : 'a t) = ArraySlice.foldl (uncurry f) acc
                                   (ArraySlice.slice
                                      (!(#contents v), 0, SOME (!(#size v))))
fun fold_right f acc (v : 'a t) = ArraySlice.foldr (uncurry f) acc
                                    (ArraySlice.slice
                                       (!(#contents v), 0, SOME (!(#size v))))
fun uncurry3 f (x, y, z) = f x y z
fun foldi_left f acc (v : 'a t) = ArraySlice.foldli (uncurry3 f) acc
                                    (ArraySlice.slice
                                       (!(#contents v), 0, SOME (!(#size v))))
fun foldi_right f acc (v : 'a t) = ArraySlice.foldri (uncurry3 f) acc
                                     (ArraySlice.slice
                                        (!(#contents v), 0, SOME (!(#size v))))

(* Index / element of the first match; raise ERROR when there is none. *)
fun first p (v : 'a t) = array_first p (!(#contents v)) 0 (!(#size v))
fun first_const p x (v : 'a t) = array_first_const p x (!(#contents v)) 0 (!(#size v))
fun find p (v : 'a t) = Array.sub (!(#contents v), array_first p (!(#contents v)) 0 (!(#size v)))
(* NOTE(review): unlike find, this returns the pair (backing array, index of
   the first match) instead of the element, which looks like a missing
   Array.sub.  Kept as-is because changing the result type could break
   callers outside this view. *)
fun find_const p x (v : 'a t) = (!(#contents v), array_first_const p x (!(#contents v)) 0 (!(#size v)))

fun memb eq x (v : 'a t) =
  ArraySlice.exists (eq x)
    (ArraySlice.slice ((!(#contents v)),
                       0, SOME (!(#size v))))
fun exists p (v : 'a t) = ArraySlice.exists p
                            (ArraySlice.slice ((!(#contents v)),
                                               0, SOME (!(#size v))))
fun exists_const p x (v : 'a t) = exists (p x) v
fun for_all p (v : 'a t) = ArraySlice.all p
                             (ArraySlice.slice ((!(#contents v)),
                                                0, SOME (!(#size v))))
fun for_all_const p x (v : 'a t) = for_all (p x) v

(* In-place filters.  All of them use the order-preserving array_filter*
   helpers. *)
fun fast_filter p (v : 'a t) =
  (#size v) := array_filter p (!(#contents v)) 0 (!(#size v)) (#dummy v)
fun fast_filter_const p x (v : 'a t) =
  (#size v) := array_filter_const p x (!(#contents v)) 0 (!(#size v)) (#dummy v)
fun filter p (v : 'a t) =
  (#size v) := array_filter p (!(#contents v)) 0 (!(#size v)) (#dummy v)
fun filter_const p x (v : 'a t) =
  (#size v) := array_filter_const p x (!(#contents v)) 0 (!(#size v)) (#dummy v)
fun filteri p (v : 'a t) =
  (#size v) := array_filteri p (!(#contents v)) 0 (!(#size v)) (#dummy v)
fun filteri_const p x (v : 'a t) =
  (#size v) := array_filteri_const p x (!(#contents v)) 0 (!(#size v)) (#dummy v)

(* Merge sorted b[n, m) with sorted a[j, k), writing into a from index i.
   Ties are taken from b (the left half), which keeps the sort stable. *)
fun merge cmp b n m i a j k =
  if n <> m then
    if j = k orelse (fn res => res = LESS orelse res = EQUAL)
                      (cmp (Array.sub (b, n)) (Array.sub (a, j))) then
      (Array.update (a, i, Array.sub (b, n)) ;
       merge cmp b (n + 1) m (i + 1) a j k)
    else
      (Array.update (a, i, Array.sub (a, j)) ;
       merge cmp b n m (i + 1) a (j + 1) k)
  else ()

(* Sort a[i, j) with the curried comparison cmp.  tmp is scratch space for
   left halves and needs at least (j - i) div 2 slots. *)
fun msort cmp a i j tmp =
  if i + 1 < j then
    let val mid = i + (j - i) div 2
    in msort cmp a i mid tmp ;
       msort cmp a mid j tmp ;
       ArraySlice.copy {src = ArraySlice.slice (a, i, SOME (mid - i)),
                        dst = tmp,
                        di = 0} ;
       merge cmp tmp 0 (mid - i) i a mid j end
  else ()

(* Stable merge sort of a[i, j).  Ranges with fewer than two elements are
   already sorted; returning early also avoids reading a[i] to seed the
   scratch buffer, which raises Subscript on an empty array. *)
fun merge_sort cmp a i j =
  if j - i < 2 then ()
  else
    let val tmp = Array.array ((j - i) div 2, Array.sub (a, i))
    in msort cmp a i j tmp end
val array_sort = merge_sort
val array_stable_sort = merge_sort
fun sort cmp (v : 'a t) = array_sort cmp (!(#contents v)) 0 (!(#size v))
fun stable_sort cmp (v : 'a t) = array_stable_sort cmp (!(#contents v)) 0 (!(#size v))
end
type 'a vec = 'a Vec.t
(* Priority queue of integer keys, highest priority first; equal scores are
   ordered by the larger key.  [priority] and [position] are indexed by key.
   [heap] is a binary heap of keys, and [position] records each queued key's
   heap slot, or invalid_position when the key is not queued. *)
structure PriorityQueue =
struct
val invalid_position = ~1

(* Store x at index n of m.  When n is past the end, grow m to n + 1
   elements, filling every new slot with x. *)
fun int_add m n x =
  if n >= Vec.length m then Vec.grow m (n + 1) x
  else Vec.set m n x

(* Same operation as int_add, kept under its own name for real maps. *)
fun real_add m n x = int_add m n x

type t = {priority: real vec,
          position: int vec,
          heap: int vec}

(* Number of queued keys. *)
fun queue_length (q : t) = Vec.length (#heap q)

(* 0-based binary heap index arithmetic. *)
val root_position = 0
fun parent_position child = (child - 1) div 2
fun left_child_position parent = 2 * parent + 1
fun right_sibling_position left = left + 1

(* Make sure key e has priority and position entries, defaulting them to
   0.0 and invalid_position. *)
fun ensure_map_memb (q : t) e =
  if e < Vec.length (#priority q) then ()
  else
    (real_add (#priority q) e 0.0 ;
     int_add (#position q) e invalid_position)

(* x belongs above y: a strictly larger score, or scores whose squared
   difference is not positive (treated as equal) together with x > y. *)
fun higher_priority (q : t) x y =
  let
    val scores = #priority q
    val x_score : real = Vec.get scores x
    val y_score : real = Vec.get scores y
    val delta = x_score - y_score
  in
    x_score > y_score orelse (delta * delta <= 0.0 andalso x > y)
  end

(* Empty queue; e is only used as the heap's filler value. *)
fun create e : t =
  {priority = Vec.make 0 0.0 0.0,
   position = Vec.make 0 invalid_position invalid_position,
   heap = Vec.make 0 e e}

fun is_empty (q : t) = Vec.is_empty (#heap q)

(* Key x is currently in the heap. *)
fun memb (h : t) x =
  let val slots = #position h
  in 0 <= x andalso x < Vec.length slots
     andalso Vec.get slots x <> invalid_position
  end

(* Move the key at heap slot j towards the root while it outranks its
   parent, keeping [position] in sync. *)
fun percolate_up (h : t) j =
  if j <= 0 then ()
  else
    let
      val heap = #heap h
      val up = parent_position j
      val child_key = Vec.get heap j
      val parent_key = Vec.get heap up
    in
      if not (higher_priority h child_key parent_key) then ()
      else
        (Vec.set heap j parent_key ;
         Vec.set heap up child_key ;
         Vec.set (#position h) child_key up ;
         Vec.set (#position h) parent_key j ;
         percolate_up h up)
    end

(* Move the key at heap slot i towards the leaves while a child outranks
   it, keeping [position] in sync. *)
fun percolate_down (h : t) i =
  let
    val heap = #heap h
    val count = Vec.length heap
    val l = left_child_position i
  in
    if l >= count then ()
    else
      let
        val r = right_sibling_position l
        val best =
          if r < count andalso higher_priority h (Vec.get heap r) (Vec.get heap l)
          then r
          else l
        val here_key = Vec.get heap i
        val best_key = Vec.get heap best
      in
        if not (higher_priority h best_key here_key) then ()
        else
          (Vec.set heap i best_key ;
           Vec.set heap best here_key ;
           Vec.set (#position h) best_key i ;
           Vec.set (#position h) here_key best ;
           percolate_down h best)
      end
  end

(* Insert key e unless it is already queued. *)
fun enqueue (h : t) e =
  (ensure_map_memb h e ;
   if Vec.get (#position h) e <> invalid_position then ()
   else
     let val slot = Vec.length (#heap h)
     in Vec.push (#heap h) e ;
        Vec.set (#position h) e slot ;
        percolate_up h slot
     end)

(* Highest-priority key; raises ERROR "queue empty!" on an empty queue. *)
fun head (q : t) =
  if is_empty q then error "queue empty!"
  else Vec.get (#heap q) root_position

(* Remove the highest-priority key; raises ERROR "queue empty!" when
   empty. *)
fun drop (q : t) =
  if is_empty q then error "queue empty!"
  else
    let
      val () = Vec.set (#position q) (head q) invalid_position
      val () = Vec.swap_out (#heap q) root_position
    in
      if is_empty q then ()
      else
        let val new_root = head q
        in Vec.set (#position q) new_root root_position ;
           percolate_down q root_position
        end
    end

(* Remove and return the highest-priority key. *)
fun dequeue q =
  let val front = head q
  in drop q ; front end

(* Score of key e, or 0.0 for a key never seen. *)
fun priority (q : t) e =
  if e >= Vec.length (#priority q) then 0.0
  else Vec.get (#priority q) e

(* Add x to the score of key e, restoring heap order if e is queued. *)
fun inc_priority (q : t) e (x : real) =
  let
    val () = ensure_map_memb q e
    val scores = #priority q
  in
    Vec.set scores e (Vec.get scores e + x) ;
    if memb q e then percolate_up q (Vec.get (#position q) e) else ()
  end

(* Multiply every stored score by x. *)
fun scale (q : t) x = Vec.modify (fn y => x * y) (#priority q)
end
type priority_queue = PriorityQueue.t
(* Global switch for the consistency checks performed through assert. *)
val check_assertions : bool Unsynchronized.ref = Unsynchronized.ref true

(* Raise ERROR err_str if checks are enabled and condition () is false.
   condition is a thunk, so nothing is evaluated while checks are off. *)
fun assert condition err_str =
  if !check_assertions andalso not (condition ()) then error err_str
  else ()
(* A clause: its literals plus a mutable activity score. *)
type clause = {literals : int array, activity : real Unsynchronized.ref}
(* Why a variable holds its value: not at all, by decision, or propagated
   from a clause. *)
datatype reason = Unassigned
                | Decision
                | Propagation of clause
(* Clause with activity 0.0 built from a list of literals. *)
fun new_clause lits =
  {literals = Array.fromList lits, activity = Unsynchronized.ref 0.0}
(* Elements of a, in index order. *)
fun array_to_list a = List.tabulate (Array.length a, fn i => Array.sub (a, i))
val max_float : real = Real.maxFinite
val min_float : real = Real.minNormalPos
(* Fresh, empty decision queue. *)
fun empty_queue () : priority_queue = PriorityQueue.create ~1
val queue_length : priority_queue -> int = PriorityQueue.queue_length
(* Shared placeholder clause with no literals. *)
val empty_clause : clause = new_clause []
(* Marker for "no level recorded". *)
val invalid_level : int = ~1
(* Complete mutable state of the solver.  The per-variable arrays are
   indexed by variable number, i.e. abs of a literal. *)
type solver =
{
(* NOTE(review): presumably the input clauses (filled by init, not visible
   here). *)
clause : clause vec,
(* Learned clauses; rescale_clause_activities rescales these. *)
learned : clause vec,
(* Assigned literals in assignment order (pushed by assign_lit, popped by
   unassign_last_lit). *)
trace : int vec,
(* Per variable: SOME value, or NONE when unassigned. *)
assignment_array : bool option array Unsynchronized.ref,
(* Per variable: why it is assigned (Unassigned when it is not). *)
reason_array : reason array Unsynchronized.ref,
(* Per variable: decision level at which it was assigned. *)
level_array : int array Unsynchronized.ref,
(* Set by set_conflict, cleared by clear_conflict. *)
conflict_detected : bool Unsynchronized.ref,
(* Incremented for each Decision assignment, decremented when one is
   undone. *)
decision_level : int Unsynchronized.ref,
(* Per variable: index of its literal in trace. *)
position : int array Unsynchronized.ref,
(* Per variable: level of a recorded early implication, or invalid_level. *)
early_level_array : int array Unsynchronized.ref,
(* Amount added to a clause's activity on each bump. *)
cls_increment : real Unsynchronized.ref,
(* Per variable: clause of a recorded early implication, or empty_clause. *)
early_reason_array : clause array Unsynchronized.ref,
(* Initialised to 1.05; its use is not visible here. *)
scale : real Unsynchronized.ref,
(* Highest level among the literals of the current conflict. *)
conflict_level : int Unsynchronized.ref,
(* Variables ordered by decision priority; unassign_last_lit re-enqueues
   freed variables. *)
decision_queue : priority_queue Unsynchronized.ref,
(* Conflict literals assigned below conflict_level. *)
lower_conflict_lits : int list Unsynchronized.ref,
(* Trace index of the next literal to propagate. *)
next_inference_lit : int Unsynchronized.ref,
(* Number of conflict literals at conflict_level. *)
num_implication_points : int Unsynchronized.ref,
(* Amount added to a variable's decision priority on each bump. *)
var_increment : real Unsynchronized.ref,
num_vars : int Unsynchronized.ref,
(* literal_map_find offsets literals by max_var + 1. *)
max_var : int Unsynchronized.ref,
(* Upper bound on cls_increment and var_increment checked by invariant. *)
shy_of_max : real Unsynchronized.ref,
(* Variables occurring in the current conflict. *)
conflict_vars : int list Unsynchronized.ref,
(* Literals with an early implication that were unassigned by
   unassign_last_lit; their consumer is not visible here. *)
pending_literals : int Queue.T Unsynchronized.ref,
(* Number of conflicts encountered. *)
num_conflicts : int Unsynchronized.ref,
(* A trace index; backtrack clamps it to the trace length.  Purpose not
   visible here. *)
dispatch_lit : int Unsynchronized.ref,
(* Reset to 0 by backtrack; purpose not visible here. *)
next_inference_cls : int Unsynchronized.ref,
pending_clauses : int list Queue.T Unsynchronized.ref,
bias : (bool array) Unsynchronized.ref,
num_decisions : int Unsynchronized.ref,
(* NOTE(review): presumably per-literal watch lists — confirm. *)
watchers : clause list Unsynchronized.ref array Unsynchronized.ref,
num_inferences : int Unsynchronized.ref,
(* Initialised to 100; use not visible here. *)
conflict_limit : int Unsynchronized.ref,
learneds_limit : int Unsynchronized.ref
}
(* A solver with no clauses, no variables and default tuning values.  Every
   call allocates fresh references, arrays and vectors. *)
fun empty_solver () : solver =
  let
    fun cell x = Unsynchronized.ref x
    fun no_elements () = Array.fromList []
    fun no_clauses () = Vec.make 0 empty_clause empty_clause
  in
    {
      clause = no_clauses (),
      learned = no_clauses (),
      trace = Vec.of_list [] 0,
      assignment_array = cell (no_elements ()),
      reason_array = cell (no_elements ()),
      level_array = cell (no_elements ()),
      conflict_detected = cell false,
      decision_level = cell 0,
      position = cell (no_elements ()),
      early_level_array = cell (no_elements ()),
      cls_increment = cell 1.0,
      early_reason_array = cell (no_elements ()),
      scale = cell 1.05,
      conflict_level = cell 0,
      decision_queue = cell (empty_queue ()),
      lower_conflict_lits = cell [],
      next_inference_lit = cell 0,
      num_implication_points = cell 0,
      var_increment = cell 0.0,
      num_vars = cell 0,
      max_var = cell 0,
      shy_of_max = cell (max_float / 1.05),
      conflict_vars = cell [],
      pending_literals = cell Queue.empty,
      num_conflicts = cell 0,
      dispatch_lit = cell 0,
      next_inference_cls = cell 0,
      pending_clauses = cell Queue.empty,
      bias = cell (no_elements ()),
      num_decisions = cell 0,
      watchers = cell (no_elements ()),
      num_inferences = cell 0,
      conflict_limit = cell 100,
      learneds_limit = cell 0
    }
  end
(* Value of literal lit: the stored value of variable abs lit, negated
   unless lit is positive; NONE when the variable is unassigned. *)
fun assignment (s : solver) lit =
  let val stored = Array.sub (!(#assignment_array s), abs lit)
  in if lit <= 0 then Option.map not stored else stored end
(* lit is currently assigned true. *)
fun is_true (s : solver) lit = assignment s lit = SOME true
(* lit is currently assigned false.  When assertions are enabled this is
   cross-checked against is_true of the complementary literal. *)
fun is_false (s : solver) lit =
  let val result = assignment s lit = SOME false in
  (* A descriptive message, like every other assertion in the solver,
     instead of the empty string. *)
  assert (fn _ => result = is_true s (~lit))
    "is_false violates: result = is_true (~lit)" ;
  result end
(* Reason recorded for the variable of lit. *)
fun reason (s : solver) lit = Array.sub (!(#reason_array s), abs lit)
(* lit's variable has a value, i.e. its reason is not Unassigned.  Checked
   against is_true / is_false when assertions are enabled. *)
fun is_assigned (s : solver) lit =
  let val result = reason s lit <> Unassigned in
    assert (fn _ => result = is_true s lit orelse is_false s lit)
      "is_assigned violates: result = is_true lit orelse is_false lit" ;
    result
  end
(* f holds for every element of a.  Scans left to right and stops at the
   first failure. *)
fun array_forall f a = Array.all f a

(* Some element of a equals x.  Scans left to right and stops at the first
   hit. *)
fun array_mem x a = Array.exists (fn y => x = y) a
(* lits explains lit: lit occurs in lits, and every literal of lits is
   either false or lit itself. *)
fun explanation_valid (s : solver) lit lits =
  array_mem lit lits andalso
  array_forall (fn l => is_false s l orelse l = lit) lits

(* r is an acceptable reason for lit.  Unassigned never is, Decision always
   is, and Propagation only when its clause explains lit. *)
fun reason_valid (_ : solver) _ Unassigned = false
  | reason_valid _ _ Decision = true
  | reason_valid s lit (Propagation cls) = explanation_valid s lit (#literals cls)
(* Make lit true with reason reason'.  A Decision first opens a new decision
   level.  The variable's reason, value, level and trace position are
   recorded, and lit is pushed onto the trace. *)
fun assign_lit (s : solver) lit reason' =
  let
    val () = assert (fn _ => not (!(#conflict_detected s)))
               "assign_lit violates: not (!(#conflict_detected s))"
    val () = assert (fn _ => not (is_assigned s lit)) "assign_lit violates: not (is_assigned lit)"
    val () = assert (fn _ => reason_valid s lit reason') ("assign_lit violates: reason_valid lit reason'")
    val var = abs lit
    val () =
      if reason' = Decision
      then #decision_level s := !(#decision_level s) + 1
      else ()
  in
    Array.update (!(#reason_array s), var, reason') ;
    Array.update (!(#assignment_array s), var, SOME (lit > 0)) ;
    Array.update (!(#level_array s), var, !(#decision_level s)) ;
    Array.update (!(#position s), var, Vec.length (#trace s)) ;
    Vec.push (#trace s) lit
  end
(* Recorded early-implication level of lit's variable (invalid_level if
   none). *)
fun early_level (s : solver) lit =
  let val table = !(#early_level_array s) in Array.sub (table, abs lit) end
(* Recorded early-implication clause of lit's variable (empty_clause if
   none). *)
fun early_reason (s : solver) lit =
  let val table = !(#early_reason_array s) in Array.sub (table, abs lit) end
(* Multiply every learned clause's activity and the activity increment by
   min_float, to keep later bumps finite. *)
fun rescale_clause_activities (s : solver) =
  (Vec.iter (fn (c : clause) => #activity c := !(#activity c) * min_float) (#learned s) ;
   #cls_increment s := !(#cls_increment s) * min_float)

(* Add the current increment to cls's activity, rescaling first when the
   sum would exceed max_float. *)
fun inc_clause_activity (s : solver) (cls : clause) =
  let
    val () =
      if !(#activity cls) > max_float - !(#cls_increment s)
      then rescale_clause_activities s
      else ()
  in
    #activity cls := !(#activity cls) + !(#cls_increment s)
  end
(* Decision level at which lit's variable was assigned. *)
fun level (s : solver) lit =
  let val table = !(#level_array s) in Array.sub (table, abs lit) end
(* Some trace literal has not yet been propagated. *)
fun unit_propagation_possible (s : solver) =
  Vec.length (#trace s) > !(#next_inference_lit s)
(* Run every solver-state consistency check through assert, in order.  They
   only fire when check_assertions is set.  Always returns true, so the
   function can itself serve as an assert condition. *)
fun invariant (s : solver) =
  let
    val checks : ((unit -> bool) * string) list =
      [((fn _ => !(#num_vars s) > 1), "invariant violates: !(#num_vars s) > 1"),
       ((fn _ => !(#decision_level s) >= 0), "invariant violates: !(#decision_level s) >= 0"),
       ((fn _ => Vec.length (#trace s) >= !(#decision_level s)),
        "invariant violates: Vec.length (#trace s) >= !(#decision_level s)"),
       ((fn _ => Vec.length (#trace s) + queue_length (!(#decision_queue s)) >= !(#num_vars s)),
        "invariant violates: Vec.length (#trace s) + queue_length (!(#decision_queue s)) >= !(#num_vars s)"),
       ((fn _ => Vec.length (#trace s) = 0 orelse is_true s (Vec.top (#trace s))),
        "invariant violates: Vec.length (#trace s) = 0 orelse is_true (Vec.top (#trace s))"),
       ((fn _ => !(#conflict_detected s) orelse Vec.length (#trace s) >= !(#next_inference_lit s)),
        "invariant violates: !(#conflict_detected s) orelse Vec.length (#trace s) >= !(#next_inference_lit s)"),
       ((fn _ => !(#next_inference_lit s) + 1 >= !(#decision_level s)),
        "invariant violates: !(#next_inference_lit s) + 1 >= !(#decision_level s)"),
       ((fn _ => !(#conflict_detected s) orelse not (unit_propagation_possible s) orelse
                 level s (Vec.get (#trace s) (!(#next_inference_lit s))) = !(#decision_level s)),
        "invariant violates: !(#conflict_detected s) orelse not (unit_propagation_possible ()) orelse\n" ^
        "level (Vec.get (#trace s) (!(#next_inference_lit s))) = !(#decision_level s)"),
       ((fn _ => !(#num_implication_points s) >= 0),
        "invariant violates: !(#num_implication_points s) >= 0"),
       ((fn _ => !(#conflict_detected s) orelse !(#num_implication_points s) = 0),
        "invariant violates: !(#conflict_detected s) orelse !(#num_implication_points s) = 0"),
       ((fn _ => !(#conflict_detected s) orelse !(#lower_conflict_lits s) = []),
        "invariant violates: !(#conflict_detected s) orelse !(#lower_conflict_lits s) = []"),
       ((fn _ => not (!(#conflict_detected s)) orelse !(#conflict_level s) <= !(#decision_level s)),
        "invariant violates: not (!(#conflict_detected s)) orelse !(#conflict_level s) <= !(#decision_level s)"),
       ((fn _ => not (!(#conflict_detected s)) orelse !(#conflict_level s) = 0 orelse
                 !(#num_implication_points s) > 0),
        "invariant violates: not (!(#conflict_detected s)) orelse !(#conflict_level s) = 0 orelse \n" ^
        "!(#num_implication_points s) > 0"),
       ((fn _ => not (!(#conflict_detected s)) orelse !(#conflict_level s) = 0 orelse
                 !(#decision_level s) = !(#conflict_level s)),
        "invariant violates: not (!(#conflict_detected s)) orelse !(#conflict_level s) = 0 orelse\n" ^
        "!(#decision_level s) = !(#conflict_level s)"),
       ((fn _ => not (!(#conflict_detected s)) orelse !(#conflict_level s) = 0 orelse
                 !(#num_implication_points s) = 1),
        "invariant violates: not (!(#conflict_detected s)) orelse !(#conflict_level s) = 0 orelse" ^
        "!(#num_implication_points s) = 1"),
       ((fn _ => 0.0 < !(#cls_increment s) andalso !(#cls_increment s) <= !(#shy_of_max s)),
        "invariant violates: 0.0 < !(#cls_increment s) andalso !(#cls_increment s) <= !(#shy_of_max s)"),
       ((fn _ => 0.0 < !(#var_increment s) andalso !(#var_increment s) <= !(#shy_of_max s)),
        "invariant violates: 0.0 < !(#var_increment s) andalso !(#var_increment s) <= !(#shy_of_max s)")]
  in
    List.app (fn (condition, err_str) => assert condition err_str) checks ;
    true
  end
(* Forget the early implication recorded for lit's variable: reset its
   early reason to empty_clause and its early level to invalid_level. *)
fun remove_early_implication (s : solver) lit =
  let
    fun check_no_conflict () =
      assert (fn _ => not (!(#conflict_detected s)))
        "remove_early_implication violates: not (!(#conflict_detected s))"
  in
    check_no_conflict () ;
    assert (fn _ => early_level s lit > !(#decision_level s))
      "remove_early_implication violates: early_level lit > !(#decision_level s)" ;
    Array.update (!(#early_reason_array s), abs lit, empty_clause) ;
    Array.update (!(#early_level_array s), abs lit, invalid_level) ;
    check_no_conflict ()
  end
(* Multiply all decision priorities and the priority increment by
   min_float. *)
fun rescale_decision_priorities (s : solver) =
  let val queue = !(#decision_queue s)
  in PriorityQueue.scale queue min_float ;
     #var_increment s := !(#var_increment s) * min_float
  end

(* Bump the decision priority of lit's variable by the current increment,
   rescaling first when the result would exceed max_float. *)
fun inc_decision_priority (s : solver) lit =
  let
    val queue = !(#decision_queue s)
    val var = abs lit
  in
    if PriorityQueue.priority queue var > max_float - !(#var_increment s)
    then rescale_decision_priorities s
    else () ;
    PriorityQueue.inc_priority queue var (!(#var_increment s))
  end
(* lit's variable is part of the current conflict. *)
fun is_conflict_lit (s : solver) lit =
  let val vars = !(#conflict_vars s)
  in member (op =) vars (abs lit) end

(* Add the (false) literal lit to the current conflict: record its variable,
   bump its decision priority, and classify it as an implication point (at
   conflict_level) or as a lower-level conflict literal. *)
fun add_conflict_lit (s : solver) lit =
  let
    val () = assert (fn _ => !(#conflict_detected s)) "add_conflict_lit violates: !(#conflict_detected s)"
    val lit_level = level s lit
    val () = assert (fn _ => lit_level <= !(#conflict_level s))
               "add_conflict_lit violates: lit_level <= !(#conflict_level s)"
  in
    if is_conflict_lit s lit then ()
    else
      (assert (fn _ => is_false s lit) "add_conflict_lit violates: is_false lit" ;
       #conflict_vars s := insert (op =) (abs lit) (!(#conflict_vars s)) ;
       inc_decision_priority s lit ;
       if lit_level = !(#conflict_level s)
       then #num_implication_points s := !(#num_implication_points s) + 1
       else #lower_conflict_lits s := lit :: !(#lower_conflict_lits s))
  end
(* A conflict exists at level 0. *)
fun unsatisfiable (s : solver) =
  !(#conflict_detected s) andalso !(#conflict_level s) = 0

(* Pop the newest trace literal and clear its variable's reason and value.
   The variable is re-enqueued for decisions, the decision level is closed
   if the literal was a decision, and the literal is queued in
   pending_literals when it has an early implication recorded. *)
fun unassign_last_lit (s : solver) =
  let
    val () = assert (fn _ => not (Vec.is_empty (#trace s))) "unassign_last_lit violates: not (null (#trace s))"
    val lit = Vec.pop (#trace s)
    val var = abs lit
    val why = Array.sub (!(#reason_array s), var)
  in
    assert (fn _ => is_assigned s lit) "unassign_last_lit violates: is_assigned lit" ;
    assert (fn _ => reason_valid s lit why) "unassign_last_lit violates: reason_valid lit reason" ;
    assert (fn _ => level s lit = !(#decision_level s)) "unassign_last_lit violates: level lit = !(#decision_level s)" ;
    Array.update (!(#reason_array s), var, Unassigned) ;
    Array.update (!(#assignment_array s), var, NONE) ;
    PriorityQueue.enqueue (!(#decision_queue s)) var ;
    if why = Decision then #decision_level s := !(#decision_level s) - 1 else () ;
    if early_level s lit = invalid_level then ()
    else
      (assert (fn _ => early_level s lit <= level s lit)
         "unassign_last_lit violates: early_level lit <= level lit" ;
       #pending_literals s := Queue.enqueue lit (!(#pending_literals s)))
  end
(* The true literal lit was assigned by propagation. *)
fun is_inferred (s : solver) lit =
  let val () = assert (fn _ => is_true s lit) "is_inferred violates: is_true lit"
  in
    case reason s lit of
      Propagation _ => true
    | Decision => false
    | Unassigned => false
  end

(* Clause that propagated lit.  Raises ERROR if lit was not propagated. *)
fun unit_propagating_clause (s : solver) lit =
  (case reason s lit of
     Propagation cls =>
       let val () = assert (fn _ => explanation_valid s lit (#literals cls))
                      "unit_propagating_clause violates: explanation_valid lit (#literals cls)"
       in cls end
   | _ => error "unit_propagating_clause violates: false")
(* Remove lit's variable from the current conflict. *)
fun del_conflict_lit (s : solver) lit =
  (assert (fn _ => !(#conflict_detected s)) "del_conflict_lit violates: !(#conflict_detected s)" ;
   assert (fn _ => is_conflict_lit s lit) "del_conflict_lit violates: is_conflict_lit lit" ;
   #conflict_vars s := List.filter (fn v => v <> abs lit) (!(#conflict_vars s)))

(* Resolve the newest trace literal out of the conflict with its explanation
   lits: add lits to the conflict, drop the top literal (one implication
   point fewer) and unassign it. *)
fun set_explanation (s : solver) lits =
  let
    val implied = Vec.top (#trace s)
    val () = Array.app (add_conflict_lit s) lits
    val () = del_conflict_lit s implied
  in
    #num_implication_points s := !(#num_implication_points s) - 1 ;
    unassign_last_lit s ;
    assert (fn _ => !(#conflict_detected s)) "set_explanation violates: !(#conflict_detected s)"
  end
(* Conflict analysis loop.  Unassign trace literals until the top one is in
   the conflict.  While more than one implication point remains and that
   literal was propagated, replace it by its propagating clause (bumping the
   clause's activity) and repeat. *)
fun explain_inferences (s : solver) =
  let fun last_lit () = Vec.top (#trace s)
  in
    assert (fn _ => !(#conflict_detected s))
      "explain_inferences violates: (#conflict_detected s) ()" ;
    assert (fn _ => not (unsatisfiable s))
      "explain_inferences violates: not (unsatisfiable ())" ;
    (while not (is_conflict_lit s (last_lit ())) do unassign_last_lit s) ;
    if !(#num_implication_points s) > 1 andalso is_inferred s (last_lit ()) then
      let val cls = unit_propagating_clause s (last_lit ()) in
        inc_clause_activity s cls ;
        set_explanation s (#literals cls) ;
        explain_inferences s
      end
    else ()
  end
(* Enter conflict mode for the literal collection lits, traversed with the
   supplied iter / fold.  Counts the conflict and sets conflict_level to the
   highest level among lits.  Adds every literal to the conflict and
   analyses it unless it is at level 0. *)
fun set_conflict (s : solver) iter fold lits =
  let
    fun max_level (l, acc) = Int.max (acc, level s l)
  in
    #num_conflicts s := !(#num_conflicts s) + 1 ;
    #conflict_detected s := true ;
    #conflict_level s := fold max_level 0 lits ;
    iter (add_conflict_lit s) lits ;
    if !(#conflict_level s) = 0 then () else explain_inferences s
  end

(* set_conflict for an array of literals that are all false. *)
fun set_conflict_lits (s : solver) lits =
  let
    val () = assert (fn _ => not (!(#conflict_detected s)))
               "set_conflict_lits violates: not (!(#conflict_detected s))"
    val () = assert (fn _ => array_forall (is_false s) lits)
               "set_conflict_lits violates: array_forall is_false lits"
  in
    set_conflict s Array.app Array.foldl lits ;
    assert (fn _ => !(#conflict_detected s)) "set_conflict_lits violates: !(#conflict_detected s)"
  end
(* Act on the early implication recorded for lit.
   If its early level is within the current decision level, use the
   recorded clause:
   - lit already false: raise a conflict on the clause and bump its activity;
   - lit unassigned: assign it with the clause as reason;
   - lit already true: do nothing.
   Otherwise drop the early implication.  The solver invariant is checked
   afterwards (when assertions are enabled). *)
fun assign_early_implication (s : solver) lit =
  (
  assert (fn _ => not (!(#conflict_detected s)))
    "assign_early_implication violates: not (!(#conflict_detected s))" ;
  let val lvl = early_level s lit
      val cls = early_reason s lit
  in
    (if lvl <= !(#decision_level s) then
       (
       (* assert needs both arguments: with the message missing, it is only
          partially applied and the check never runs. *)
       assert (fn _ => explanation_valid s lit (#literals cls))
         "assign_early_implication violates: explanation_valid lit (#literals reason)" ;
       if is_false s lit then
         (set_conflict_lits s (#literals cls) ;
          inc_clause_activity s cls)
       else if not (is_assigned s lit) then
         assign_lit s lit (Propagation cls)
       else ()
       )
     else
       remove_early_implication s lit) ;
    assert (fn _ => invariant s) "assign_early_implication violates: invariant ()"
  end
  )
(* A conflict above level 0 with exactly one implication point. *)
fun backjump_possible (s : solver) =
  !(#conflict_detected s) andalso !(#conflict_level s) > 0 andalso
  !(#num_implication_points s) = 1

(* Unassign trace literals until the decision level is lev.  Then resume
   propagation at the new end of the trace, reset next_inference_cls to 0
   and clamp dispatch_lit to the trace length. *)
fun backtrack (s : solver) lev =
  let
    val () = assert (fn _ => !(#decision_level s) > 0) "backtrack violates: !(#decision_level s) > 0"
    val () = assert (fn _ => !(#decision_level s) > lev) "backtrack violates: !(#decision_level s) > lev"
    fun undo () =
      if !(#decision_level s) > lev then (unassign_last_lit s ; undo ()) else ()
    val () = undo ()
    val n = Vec.length (#trace s)
  in
    assert (fn _ => !(#next_inference_lit s) >= n) "backtrack violates: !(#next_inference_lit s) >= n" ;
    #next_inference_lit s := n ;
    #next_inference_cls s := 0 ;
    #dispatch_lit s := Int.min (!(#dispatch_lit s), n)
  end
(* Drop the conflict bookkeeping once the learned clause has been built.
   Pre: a backjump is possible and the decision level equals the conflict
   level.  The last trace literal and every lower conflict literal stop
   being conflict literals, the implication-point counter is reset, and the
   conflict flag is cleared. *)
fun clear_conflict (s : solver) =
  (assert (fn _ => backjump_possible s) "clear_conflict violates: backjump_possible s" ;
   assert (fn _ => !(#decision_level s) = !(#conflict_level s)) "clear_conflict violates: !(#decision_level s) = !(#conflict_level s)" ;
   let
     fun unmark lit = del_conflict_lit s lit
     val last_assigned = Vec.top (#trace s)
   in
     unmark last_assigned ;
     #num_implication_points s := 0 ;
     List.app unmark (!(#lower_conflict_lits s)) ;
     #lower_conflict_lits s := [] ;
     #conflict_detected s := false
   end)
(* Inference can proceed if unit propagation is possible or if either the
   pending-literal or the pending-clause queue is non-empty. *)
fun inference_possible (s : solver) =
  let fun has_items q = not (Queue.is_empty (!q)) in
    if unit_propagation_possible s then true
    else has_items (#pending_literals s) orelse has_items (#pending_clauses s)
  end
(* True exactly for the empty list. *)
fun is_empty [] = true
  | is_empty (_ :: _) = false
(* A literal is an implication point if it is a conflict literal assigned at
   the conflict level. *)
fun is_implication_point (s : solver) lit =
  if is_conflict_lit s lit then level s lit = !(#conflict_level s) else false
(* Look up [lit] in an array indexed by literal.  Literal [lit] lives in
   slot lit + max_var + 1, so negative literals map below the positive ones
   (initialize_arrays sizes the watcher array as 2 * (max_var + 1)). *)
fun literal_map_find (s : solver) arr lit =
  let val slot = lit + !(#max_var s) + 1 in Array.sub (!arr, slot) end
(* Drop the literals that are false at decision level 0; such literals can
   never become true again. *)
fun remove_known_false (s : solver) lits =
  filter (fn lit => level s lit <> 0 orelse not (is_false s lit)) lits
(* Sort the whole array [a] in place with the comparison [cmp]. *)
fun array_sort cmp a =
  let val len = Array.length a in Vec.array_sort cmp a 0 len end
(* True when [lit] is unassigned or true. *)
fun is_not_false (s : solver) lit =
  case assignment s lit of
    SOME false => false
  | _ => true
(* Check the two-watched-literal condition for [cls].  For each of its
   first two literals, either that literal is not false, or the other
   watched literal is true at a level no higher than the false literal's
   level.  When an early implication is registered for the true literal,
   its early level is used instead of its assignment level. *)
fun watch_valid (s : solver) cls =
  let fun earliest_level i =
        if early_level s i <> invalid_level then
          (* The message now names the condition that is actually checked. *)
          (assert (fn _ => early_level s i <= level s i)
             "watch_valid violates: early_level i <= level i" ;
           early_level s i)
        else level s i
      fun ok i j =
        let val lit = Array.sub (#literals cls, i)
            val other_lit = Array.sub (#literals cls, j)
        in is_not_false s lit orelse
           (is_true s other_lit andalso earliest_level other_lit <= level s lit)
        end
  in ok 0 1 andalso ok 1 0 end
(* Resolve the current conflict by learning a clause and backjumping.
   Pre (asserted): backjump_possible s, and the decision level equals the
   conflict level.
   The learned clause is the negation of the last trace literal plus the
   lower-level conflict literals, minus any literal known false at level 0.
   A clause with two or more literals is sorted by decreasing level, pushed
   onto #learned and watched on its first two literals; the solver then
   backtracks to the level of the second literal.  A unit clause
   backtracks to level 0.  Pending early implications are then replayed.
   Next, the first literal is either assigned with the learned clause as its
   reason or, if it is already false, a new conflict is recorded on the
   clause.  Finally the variable and clause activity increments are
   multiplied by #scale, with a rescale first when they exceed #shy_of_max. *)
fun backjump (s : solver) =
(
assert (fn _ => backjump_possible s) "backjump violates: backjump_possible s" ;
assert (fn _ => !(#decision_level s) = !(#conflict_level s))
"backjump violates: !(#decision_level s) = !(#conflict_level s)" ;
let val lit = Vec.top (#trace s) in
assert (fn _ => is_implication_point s lit)
"backjump violates: is_implication_point lit" ;
let
val learn_lits = ~ lit :: !(#lower_conflict_lits s)
val norm_learn_lits = remove_known_false s learn_lits
val cls = new_clause norm_learn_lits
val lits = #literals cls
val n = Array.length lits
in
if n >= 2 then
(* Highest level first, so lits[0] = ~lit and lits[1] carries the
   backjump level. *)
let fun lit_level_compare l0 l1 =
if level s l1 < level s l0 then LESS
else if level s l1 = level s l0 then EQUAL
else GREATER in
array_sort lit_level_compare lits ;
assert (fn _ => Array.sub (lits, 0) = ~ lit)
"backjump violates: Array.sub (lits, 0) = ~ lit" ;
assert (fn _ => level s (Array.sub (lits, 0)) > level s (Array.sub (lits, 1)))
("backjump violates: level (Array.sub (lits, 0)) > " ^
"level (Array.sub (lits, 1))") ;
Vec.push (#learned s) cls ;
let
val watcher0 = (literal_map_find s) (#watchers s)
(Array.sub (#literals cls, 0))
val watcher1 = (literal_map_find s) (#watchers s)
(Array.sub (#literals cls, 1))
in
watcher0 := cls :: !watcher0 ;
watcher1 := cls :: !watcher1
end ;
clear_conflict s ;
backtrack s (level s (Array.sub (lits, 1)))
end
else
(
assert (fn _ => Array.length lits = 1)
"backjump violates: Array.length lits = 1" ;
clear_conflict s ;
backtrack s 0
) ;
assert (fn _ => Array.sub (lits, 0) = ~ lit)
"backjump violates: Array.sub (lits, 0) = ~ lit" ;
(* Pass the message as well: with only the predicate, this call is a
   discarded partial application, so the check never runs. *)
assert (fn _ => not (is_assigned s (Array.sub (lits, 0))))
"backjump violates: not (is_assigned (Array.sub (lits, 0)))" ;
(* Replay the early implications queued while backtracking. *)
while not (Queue.is_empty (!(#pending_literals s))) do
let val (pending_literal,
pending_literals) = Queue.dequeue (!(#pending_literals s)) in
(#pending_literals s) := pending_literals ;
assign_early_implication s (pending_literal)
end ;
assert (fn _ => not (is_assigned s (Array.sub (lits, 0))) orelse
early_level s (Array.sub (lits, 0)) <> invalid_level)
("backjump violates: not (is_assigned (Array.sub (lits, 0))) " ^
"orelse early_level (Array.sub (lits, 0)) <> invalid_level") ;
assert (fn _ => not (is_true s (Array.sub (lits, 0))))
"backjump violates: not (is_true (Array.sub (lits, 0)))" ;
if is_false s (Array.sub (lits, 0)) then
(set_conflict_lits s (#literals cls) ;
assert (fn _ => !(#conflict_detected s)) "backjump violates: !(#conflict_detected s)")
else
(
assign_lit s (Array.sub (lits, 0)) (Propagation cls) ;
assert (fn _ => n = 1 orelse watch_valid s cls)
"backjump violates: n = 1 orelse watch_valid cls" ;
assert (fn _ => !(#decision_level s) < !(#conflict_level s))
"backjump violates: (#decision_level s) () < (#conflict_level s) ()" ;
assert (fn _ => not (!(#conflict_detected s)))
"backjump violates: not ((#conflict_detected s) ())" ;
assert (fn _ => inference_possible s)
"backjump violates: inference_possible ()"
) ;
(* Grow the activity increments, rescaling first when they get large. *)
if !(#var_increment s) > !(#shy_of_max s) then rescale_decision_priorities s else () ;
(#var_increment s) := !(#var_increment s) * !(#scale s) ;
if !(#cls_increment s) > !(#shy_of_max s) then rescale_clause_activities s else () ;
(#cls_increment s) := !(#cls_increment s) * !(#scale s) ;
assert (fn _ => invariant s) "backjump violates: invariant ()"
end
end
)
(* True when variable [var] currently has a value in the assignment array. *)
fun is_assigned_var (s : solver) var =
  case Array.sub (!(#assignment_array s), var) of
    NONE => false
  | SOME _ => true
(* A decision is possible while the trace is shorter than the number of
   variables. *)
fun decision_possible (s : solver) =
  !(#num_vars s) > Vec.length (#trace s)
(* Make a branching decision.
   Pre: no conflict, no inference possible, and a decision possible.
   Entries for already-assigned variables are dropped from the head of the
   decision queue.  The next variable is then dequeued and assigned as a
   Decision with the polarity given by its #bias entry (true -> positive
   literal).  The decision counter is incremented. *)
fun decide (s : solver) =
(assert (fn _ => not (!(#conflict_detected s))) "decide violates: not (!(#conflict_detected s))" ;
 assert (fn _ => not (inference_possible s)) "decide violates: not (inference_possible ())" ;
 assert (fn _ => decision_possible s) "decide violates: decision_possible ()" ;
 (* Test for emptiness before peeking.  An exhausted queue is then reported
    by the assertion below instead of PriorityQueue.head being called on an
    empty queue. *)
 while not (PriorityQueue.is_empty (!(#decision_queue s))) andalso
       is_assigned_var s (PriorityQueue.head ((!(#decision_queue s)))) do
   PriorityQueue.drop (!(#decision_queue s)) ;
 assert (fn _ => not (PriorityQueue.is_empty (!(#decision_queue s))))
   "decide violates: not (PriorityQueue.is_empty (#decision_queue s))" ;
 let val var = PriorityQueue.dequeue (!(#decision_queue s))
     val lit =
       if Array.sub (!(#bias s), var) then var
       else ~ var
 in (#num_decisions s) := !(#num_decisions s) + 1 ;
    assign_lit s lit Decision ;
    assert (fn _ => inference_possible s) "decide violates: inference_possible ()" ;
    assert (fn _ => invariant s) "decide violates: invariant ()" ;
    assert (fn _ => !(#decision_level s) > 0) "decide violates: !(#decision_level s) > 0" ;
    assert (fn _ => not (!(#conflict_detected s))) "decide violates: not (!(#conflict_detected s))"
 end)
(* Order learned clauses from most to least valuable.
   forget_some_learned_clauses keeps the first half of the vector sorted
   with this comparison, so valuable clauses must come first:
     - binary clauses before all longer clauses;
     - otherwise HIGHER activity first;
     - on equal activity, shorter clauses first.
   The previous version put LOWER activity first.  That contradicted the
   other two criteria and made clause deletion keep the least active
   learned clauses and discard the most active ones.
   Pre: both clauses have at least two literals. *)
fun cls_compare cls0 cls1 =
  let val len0 = Array.length (#literals cls0)
      val len1 = Array.length (#literals cls1)
  in assert (fn _ => len0 >= 2) "cls_compare violates: len0 >= 2" ;
     assert (fn _ => len1 >= 2) "cls_compare violates: len1 >= 2" ;
     if len0 = 2 then (if len1 = 2 then EQUAL else LESS)
     else if len1 = 2 then GREATER
     else let val n = if !(#activity cls0) > !(#activity cls1) then LESS
                      else if !(#activity cls0) < !(#activity cls1) then GREATER
                      else EQUAL
          in if n = EQUAL then
               (if len0 < len1 then LESS
                else if len0 > len1 then GREATER
                else EQUAL)
             else n end end
(* Keep the elements y of [xs] for which [f i y] holds, where i is y's
   0-based position; the relative order of the kept elements is preserved
   and [f] is applied left to right. *)
fun filteri f xs =
  let
    fun go _ [] acc = rev acc
      | go i (y :: ys) acc = go (i + 1) ys (if f i y then y :: acc else acc)
  in go 0 xs [] end
(* Remove [cls] from the watch lists of its first two literals. *)
fun forget (s : solver) cls =
  let
    fun watch_of i = (literal_map_find s) (#watchers s) (Array.sub (#literals cls, i))
    fun unwatch w = w := filter (fn other => other <> cls) (!w)
    val first_watch = watch_of 0
    val second_watch = watch_of 1
  in
    unwatch first_watch ;
    unwatch second_watch
  end
(* Shrink the learned-clause database.
   Pre: no conflict is recorded.  The learned clauses are sorted with
   cls_compare; the first half and every binary clause are kept, and each
   discarded clause is also removed from its watch lists via forget. *)
fun forget_some_learned_clauses (s : solver) =
  (assert (fn _ => not (!(#conflict_detected s)))
     "forget_some_learned_clauses violates: not ((#conflict_detected s) ())" ;
   Vec.sort cls_compare (#learned s) ;
   let
     val half = Vec.length (#learned s) div 2
     fun keep i cls =
       if i < half orelse Array.length (#literals cls) = 2 then true
       else (forget s cls ; false)
   in
     Vec.filteri keep (#learned s) ;
     assert (fn _ => invariant s)
       "forget_some_learned_clauses violates: invariant ()" ;
     assert (fn _ => not (!(#conflict_detected s)))
       "forget_some_learned_clauses violates: not (!(#conflict_detected s))"
   end)
(* Walk a list sorted in descending order (w.r.t. the curried comparison
   [comp]) and collect the elements not EQUAL to their predecessor, starting
   from the last kept element [x] and accumulator [acc].  Because elements
   are consed onto [acc], the result comes out in ascending order. *)
fun rev_rem_dup comp x acc [] = acc
  | rev_rem_dup comp x acc (h :: t) =
      (case comp h x of
         EQUAL => rev_rem_dup comp x acc t
       | _ => rev_rem_dup comp h (h :: acc) t)
(* Swap the first two arguments of the curried function [f]. *)
fun flip f = fn x => fn y => f y x
(* Sort [l] ascending by the curried comparison [comp] and remove elements
   that compare EQUAL.  The list is first sorted descending, then
   rev_rem_dup deduplicates and reverses it in one pass. *)
fun remove_sorted_duplicates comp l =
  let val descending = sort (uncurry (flip comp)) l
  in
    case descending of
      [] => l
    | first :: rest => rev_rem_dup comp first [first] rest
  end
(* Order literals for placement in a new clause:
     true literals first (lower level first),
     then unassigned literals,
     then false literals (higher level first).
   Remaining ties are broken by the literals' integer values. *)
fun lit_compare (s : solver) lit1 lit2 =
  let
    fun value lit =
      if is_true s lit then SOME true
      else if is_false s lit then SOME false
      else NONE
    fun break_tie EQUAL = int_ord (lit1, lit2)
      | break_tie other = other
    val v1 = value lit1
    val v2 = value lit2
  in
    case (v1, v2) of
      (SOME true, SOME true) =>
        let val l1 = level s lit1
            val l2 = level s lit2
        in break_tie (int_ord (l1, l2)) end
    | (SOME true, _) => LESS
    | (_, SOME true) => GREATER
    | (NONE, NONE) => int_ord (lit1, lit2)
    | (NONE, _) => LESS
    | (_, NONE) => GREATER
    | _ =>
        let val l1 = level s lit1
            val l2 = level s lit2
        in break_tie (int_ord (l2, l1)) end
  end
(* An unassigned literal is always valid; an assigned one must have a level
   between 0 and the current decision level. *)
fun level_valid (s : solver) lit =
  if is_assigned s lit
  then 0 <= level s lit andalso level s lit <= !(#decision_level s)
  else true
(* Remember that clause [cls] already implies the true literal [lit] at
   level [lev], which is below the level where [lit] was assigned.  The
   record is stored per variable and only replaced when no early level is
   registered yet or [lev] is lower than the registered one.
   Pre: no conflict, [lit] is true, lev < level lit, and [cls] is a valid
   explanation of [lit]. *)
fun register_early_implication (s : solver) lit lev cls =
  (assert (fn _ => not (!(#conflict_detected s))) "register_early_implication violates: not (!(#conflict_detected s))" ;
   assert (fn _ => is_true s lit) "register_early_implication violates: is_true lit" ;
   assert (fn _ => lev < level s lit) "register_early_implication violates: lev < level lit" ;
   assert (fn _ => explanation_valid s lit (#literals cls))
     "register_early_implication violates: explanation_valid lit (#literals cls)" ;
   let
     val previous = early_level s lit
     val improves = previous = invalid_level orelse lev < previous
   in
     if not improves then ()
     else
       let val var = abs lit in
         Array.update (!(#early_reason_array s), var, cls) ;
         Array.update (!(#early_level_array s), var, lev)
       end ;
     assert (fn _ => not (!(#conflict_detected s))) "register_early_implication violates: not (!(#conflict_detected s))" ;
     assert (fn _ => is_true s lit) "register_early_implication violates: is_true lit"
   end)
(* Add the clause [lit_list] to the solver.
   While a conflict is recorded, the clause is only queued in
   #pending_clauses; infer adds it later.
   Otherwise its literals are deduplicated (and ordered by lit_compare: true,
   then unassigned, then false), literals false at level 0 are dropped, and
   the decision priority of each remaining literal is bumped.  Then:
     - empty clause: record a conflict;
     - unit clause: conflict if the literal is false; otherwise assign it if
       needed and, when it holds above level 0, register an early
       implication at level 0;
     - longer clause: store it in #clause and watch its first two literals.
       If the second literal is false, either record a conflict (first
       literal false too) or make sure the first literal is true.  When the
       first literal sits at a higher level than the second, register an
       early implication at the second literal's level. *)
fun add_clause (s : solver) lit_list =
(if not (!(#conflict_detected s)) then
  (let val norm_lits = remove_known_false s
                         (remove_sorted_duplicates
                            (lit_compare s) lit_list)
       val cls = new_clause norm_lits
       val lits = #literals cls
       val len = Array.length lits
   in Array.app (inc_decision_priority s) lits ;
      case len
        of 0 => set_conflict_lits s (#literals empty_clause)
         | 1 => let val lit = Array.sub (lits, 0)
                in assert (fn _ => is_not_false s lit orelse !(#decision_level s) > 0)
                     "add_clause violates: is_not_false lit orelse !(#decision_level s) > 0" ;
                   assert (fn _ => level_valid s lit) "add_clause violates: level_valid lit" ;
                   if is_false s lit then
                     (set_conflict_lits s (#literals cls) ;
                      inc_clause_activity s cls)
                   else
                     (
                     if not (is_assigned s lit) then
                       (assign_lit s lit (Propagation cls) ;
                        assert (fn _ => is_true s lit) "add_clause violates: is_true lit")
                     else () ;
                     if level s lit > 0 then
                       register_early_implication s lit 0 cls
                     else ()
                     ) end
         | _ =>
           (assert (fn _ => len >= 2) "add_clause violates: len >= 2" ;
            let val lit0 = Array.sub (lits, 0)
                val lit1 = Array.sub (lits, 1)
            in assert (fn _ => is_not_false s lit0 orelse is_false s lit1)
                 "add_clause violates: is_not_false lit0 orelse is_false lit1" ;
               assert (fn _ => level_valid s lit0) "add_clause violates: level_valid lit0" ;
               assert (fn _ => level_valid s lit1) "add_clause violates: level_valid lit1" ;
               Vec.push (#clause s) cls ;
               (* Watch the first two literals (same indexing as backjump). *)
               let val w0 = (literal_map_find s) (#watchers s) lit0
                   val w1 = (literal_map_find s) (#watchers s) lit1
               in w0 := cls :: !w0 ;
                  w1 := cls :: !w1
               end ;
               if is_false s lit1 then
                 if is_false s lit0 then
                   (set_conflict_lits s (#literals cls) ;
                    inc_clause_activity s cls)
                 else
                   (if not (is_assigned s lit0) then
                      (assign_lit s lit0 (Propagation cls) ;
                       assert (fn _ => is_true s lit0) "add_clause violates: is_true lit0")
                    else () ;
                    (* Whether lit0 was assigned just now or was already
                       true, the clause implies it at lit1's level.
                       Registering only in the freshly-assigned case (as
                       before) left an already-true lit0 above lit1's level
                       unregistered, and the watch_valid assertion below then
                       fails.  The unit-clause case above already registers
                       regardless of whether it assigned. *)
                    let val lev0 = level s lit0
                        val lev1 = level s lit1
                    in if lev0 > lev1 then
                         register_early_implication s lit0 lev1 cls
                       else () end)
               else () ;
               assert (fn _ => !(#conflict_detected s) orelse watch_valid s cls)
                 "add_clause violates: !(#conflict_detected s) orelse watch_valid cls"
            end
           )
   end)
 else
  (assert (fn _ => !(#conflict_detected s)) "add_clause violates: !(#conflict_detected s)" ;
   (#pending_clauses s) := Queue.enqueue lit_list (!(#pending_clauses s))
  ) ;
 assert (fn _ => invariant s) "add_clause violates: invariant ()")
(* Process the watch list [watcher] of the false literal [lit], starting at
   position !(#next_inference_cls s).
   For each clause in the list:
     - if the other watched literal is true, skip the clause;
     - else, if some literal at index >= 2 is not false, swap it into the
       watched position, remove the clause from [watcher] and push it onto
       the new literal's watch list (the cursor is not advanced, because the
       list just shrank);
     - else, if the other watched literal is unassigned, assign it with the
       clause as its Propagation reason;
     - else (other literal false) record a conflict on the clause and stop.
   When the list is exhausted, #next_inference_lit is advanced.  Processing
   then continues with the negation of the next trace literal, if any.
   Pre: no conflict, inference possible, [lit] false.  [watcher] is mutated
   in place. *)
fun unit_propagate_or_conflict (s : solver) lit watcher =
(assert (fn _ => not (!(#conflict_detected s)))
"unit_propagate_or_conflict violates: not (!(#conflict_detected s))" ;
assert (fn _ => inference_possible s)
"unit_propagate_or_conflict violates: inference_possible ()" ;
assert (fn _ => is_false s lit) "unit_propagate_or_conflict violates: is_false lit" ;
if !(#next_inference_cls s) < length (!watcher) then
(let val cls = nth (!watcher) (!(#next_inference_cls s))
val lits = (#literals cls)
in assert (fn _ => Array.length lits >= 2)
"unit_propagate_or_conflict violates: Array.length lits >= 2" ;
assert (fn _ => Array.sub (lits, 0) = lit orelse
Array.sub (lits, 1) = lit)
("unit_propagate_or_conflict violates: Array.sub (lits, 0) = lit orelse " ^
"Array.sub (lits, 1) = lit") ;
(* lit_idx: position of the false watched literal; other_idx: the other
   watched position. *)
let val (lit_idx, other_idx) =
if Array.sub (lits, 0) = lit then (0, 1) else (1, 0)
val other_lit = Array.sub (lits, other_idx)
val other_assign = assignment s other_lit
in if other_assign = SOME true then
(* Clause already satisfied by the other watch: move on. *)
((#next_inference_cls s) := !(#next_inference_cls s) + 1 ;
unit_propagate_or_conflict s lit watcher)
else (assert (fn _ => other_assign <> SOME true)
"unit_propagate_or_conflict violates: other_assign <> SOME true" ;
(* i = index of the first non-false literal at position >= 2, or len if
   there is none.  The fold keeps ~1 until such an index is found and then
   carries it unchanged. *)
let val len = Array.length (lits)
val i = Array.foldli (fn (i, j, ~1) =>
if is_false s j then ~1
else if i >= 2 then i else ~1
| (_, _, x) => x)
~1 (lits)
val i = if i >= 0 then i else len
in if i <> len then
(* Replacement watch found: swap it in and move the clause from [watcher]
   to the new literal's watch list. *)
(Array.update (lits, lit_idx, Array.sub (lits, i)) ;
Array.update (lits, i, lit) ;
watcher := List.take (!watcher, !(#next_inference_cls s)) @
List.drop (!watcher, !(#next_inference_cls s) + 1) ;
(literal_map_find s) (#watchers s) (Array.sub (lits, lit_idx)) :=
cls :: !((literal_map_find s) (#watchers s) (Array.sub (lits, lit_idx))) ;
unit_propagate_or_conflict s lit watcher
)
else if other_assign = NONE then
(* Clause is unit: propagate the other watched literal. *)
(assign_lit s (Array.sub (lits, other_idx)) (Propagation cls) ;
(#num_inferences s) := !(#num_inferences s) + 1 ;
(#next_inference_cls s) := !(#next_inference_cls s) + 1 ;
unit_propagate_or_conflict s lit watcher)
else
(* Every literal is false: conflict; recursion stops here. *)
(set_conflict_lits s (#literals cls) ;
inc_clause_activity s cls)
end)
end end)
else
(* Watch list of [lit] exhausted: advance to the next trace literal. *)
(assert (fn _ => !(#next_inference_cls s) = length (!watcher))
"unit_propagate_or_conflict violates: !(#next_inference_cls s) = length (!watcher)" ;
(#next_inference_lit s) := !(#next_inference_lit s) + 1 ;
(#next_inference_cls s) := 0 ;
if !(#next_inference_lit s) < Vec.length (#trace s) then
let val lit = ~ (Vec.get (#trace s) (!(#next_inference_lit s)))
val watcher = (literal_map_find s) (#watchers s) lit
in unit_propagate_or_conflict s lit watcher end
else ())
)
(* Perform one inference step.
   Pre: no conflict and inference possible.  Priority order:
     1. replay one pending early implication;
     2. otherwise add one pending clause;
     3. otherwise run unit propagation from the negation of the trace
        literal at #next_inference_lit. *)
fun infer (s : solver) =
(assert (fn _ => not (!(#conflict_detected s))) "infer violates: not ((#conflict_detected s) ())" ;
 assert (fn _ => inference_possible s) "infer violates: inference_possible ()" ;
 if not (Queue.is_empty (!(#pending_literals s))) then
   assign_early_implication s (let val (pending_literal,
                                        pending_literals) = Queue.dequeue (!(#pending_literals s))
                               in ((#pending_literals s) := pending_literals ; pending_literal) end)
 else if not (Queue.is_empty (!(#pending_clauses s))) then
   (add_clause s (let val (pending_clause,
                           pending_clauses) = Queue.dequeue (!(#pending_clauses s))
                  in ((#pending_clauses s) := pending_clauses ; pending_clause) end))
 else
   (
   (* Pass the message as well: with only the predicate, this call is a
      discarded partial application, so the check never runs. *)
   assert (fn _ => !(#next_inference_lit s) < Vec.length (#trace s))
     "infer violates: !(#next_inference_lit s) < Vec.length (#trace s)" ;
   let val lit = ~ (Vec.get (#trace s) (!(#next_inference_lit s)))
       val watcher = (literal_map_find s) (#watchers s) lit
   in unit_propagate_or_conflict s lit watcher end
   ) ;
 assert (fn _ => invariant s) "infer violates: invariant ()")
(* Number of clauses currently in the learned-clause vector. *)
fun num_learneds (s : solver) =
  let val learned = #learned s in Vec.length learned end
(* The problem is satisfied when there is no conflict, nothing is left to
   decide, and nothing is left to infer. *)
fun satisfied (s : solver) =
  not (!(#conflict_detected s)) andalso
  not (decision_possible s) andalso
  not (inference_possible s)
(* Search is finished once the problem is satisfied or known unsatisfiable. *)
fun solved (s : solver) =
  if satisfied s then true else unsatisfiable s
(* Run the main CDCL loop until the problem is solved or the conflict
   counter reaches #conflict_limit.  Each round infers as long as possible,
   then either handles a conflict (backjump, and prune learned clauses when
   there are more than #learneds_limit) or makes a decision. *)
fun search (s : solver) =
  let
    (* Infer while no conflict is recorded and inference is possible. *)
    fun propagate () =
      if not (!(#conflict_detected s)) andalso inference_possible s
      then (infer s ; propagate ())
      else ()
    (* Learn from the conflict unless the problem is already unsatisfiable. *)
    fun resolve_conflict () =
      if unsatisfiable s then ()
      else
        (backjump s ;
         if num_learneds s > !(#learneds_limit s)
         then forget_some_learned_clauses s
         else ())
    fun step () =
      (propagate () ;
       if !(#conflict_detected s) then resolve_conflict ()
       else if decision_possible s then decide s
       else assert (fn _ => satisfied s) "search violates: satisfied ()")
    fun loop () =
      if solved s orelse !(#num_conflicts s) >= !(#conflict_limit s) then ()
      else (step () ; loop ())
  in
    loop () ;
    assert (fn _ => not (!(#conflict_detected s)) orelse unsatisfiable s)
      "search violates: not (!(#conflict_detected s)) orelse unsatisfiable ()" ;
    assert (fn _ => invariant s) "search violates: invariant ()"
  end
(* Number of (non-unit, non-empty) problem clauses stored in #clause. *)
fun num_clauses (s : solver) =
  let val clauses = #clause s in Vec.length clauses end
(* Drop every problem and learned clause that contains a true literal,
   unwatching each dropped clause.  restart calls this right after
   backtracking to level 0. *)
fun simplify (s : solver) =
  let
    fun has_true_lit cls = Array.exists (is_true s) (#literals cls)
    fun keep cls = if has_true_lit cls then (forget s cls ; false) else true
  in
    Vec.filter keep (#clause s) ;
    Vec.filter keep (#learned s)
  end
(* Restart the search: return to decision level 0 (if not already there)
   and simplify the clause database.  Pre and post: no conflict; post:
   decision level 0. *)
fun restart (s : solver) =
  (assert (fn _ => not (!(#conflict_detected s)))
     "restart violates: not (conflict_detected s)" ;
   if !(#decision_level s) <= 0 then () else backtrack s 0 ;
   simplify s ;
   assert (fn _ => invariant s) "restart violates: invariant s" ;
   assert (fn _ => not (!(#conflict_detected s))) "restart violates: not (conflict_detected s)";
   assert (fn _ => !(#decision_level s) = 0) "restart violates: !(#decision_level s) = 0")
(* Solve with restarts.  The learned-clause limit starts at
   max (1000, num_clauses / 3).  After each restart + search round the
   conflict limit doubles and the learned-clause limit grows by 20%. *)
fun solve (s : solver) =
  let
    fun rounds () =
      if solved s then ()
      else
        (restart s ;
         search s ;
         #conflict_limit s := !(#conflict_limit s) * 2 ;
         #learneds_limit s := !(#learneds_limit s) * 12 div 10 ;
         rounds ())
  in
    #learneds_limit s := Int.max (1000, (num_clauses s div 3)) ;
    rounds ()
  end
(* Sentinel stored in the #position array for variables without a position. *)
val invalid_position : int = ~1
(* Allocate the per-variable arrays with [max_vars_plus_one] slots (all set
   to their "unassigned" defaults).  Also allocate the per-literal watcher
   array with 2 * max_vars_plus_one slots, each holding its own fresh
   empty-list ref. *)
fun initialize_arrays (s : solver) max_vars_plus_one =
  let
    fun per_var init = Array.array (max_vars_plus_one, init)
  in
    (#assignment_array s) := per_var NONE ;
    (#level_array s) := per_var invalid_level ;
    (#position s) := per_var invalid_position ;
    (#reason_array s) := per_var Unassigned ;
    (#early_reason_array s) := per_var empty_clause ;
    (#early_level_array s) := per_var invalid_level ;
    (#bias s) := per_var true ;
    (#watchers s) := Array.tabulate (2 * max_vars_plus_one, fn _ => Unsynchronized.ref [])
  end
(* Load the clause list [clauses] (literals as non-zero ints, negative =
   negated) into [s].  Every variable occurring in the clauses is queued
   for decisions, all search state and statistics are reset, and the arrays
   are sized from the largest variable.  Then each clause is added with
   add_clause. *)
fun init (s : solver) clauses =
  let
    val vars = sort_distinct int_ord (map abs (flat clauses))
    fun reset_decisions () =
      ((#decision_level s) := 0 ;
       (#decision_queue s) := empty_queue () ;
       List.app (PriorityQueue.enqueue (!(#decision_queue s))) vars)
    fun reset_conflict_state () =
      ((#conflict_detected s) := false ;
       (#conflict_level s) := 0 ;
       (#conflict_vars s) := [] ;
       (#num_implication_points s) := 0 ;
       (#lower_conflict_lits s) := [])
    fun reset_cursors_and_queues () =
      ((#next_inference_lit s) := 0 ;
       (#next_inference_cls s) := 0 ;
       (#pending_clauses s) := Queue.empty ;
       (#pending_literals s) := Queue.empty ;
       (#dispatch_lit s) := 0)
    fun reset_statistics () =
      ((#var_increment s) := 1.0 ;
       (#cls_increment s) := 1.0 ;
       (#num_decisions s) := 0 ;
       (#num_conflicts s) := 0 ;
       (#num_inferences s) := 0)
  in
    reset_decisions () ;
    reset_conflict_state () ;
    reset_cursors_and_queues () ;
    reset_statistics () ;
    (#max_var s) := fold (curry Int.max) vars 0 ;
    initialize_arrays s (!(#max_var s) + 1) ;
    (#num_vars s) := length vars ;
    assert (fn _ => invariant s) "init violates at 1: invariant ()" ;
    List.app (add_clause s) clauses ;
    assert (fn _ => invariant s) "init violates at 2: invariant ()"
  end
end;
(* Register the DPT solver with SAT_Solver under the name "dptsat". *)
let
open SAT_Solver
open Prop_Logic
(* Translate a CNF formula (a tree of And over clauses, each clause a tree
   of Or over literals) into a list of integer clauses.  BoolVar i becomes
   i and Not (BoolVar i) becomes ~i.  A True disjunct becomes the
   tautology [1, ~1] and False contributes no literals.  Any other shape
   raises an error. *)
fun mk_clauses f =
let fun mk_clause g =
let fun mk_lit (BoolVar i) = i
| mk_lit (Not (BoolVar i)) = ~i
| mk_lit _ = error "mk_clauses: formula is not in CNF!"
in case g
of Or (x, y) => mk_clause x @ mk_clause y
| True => [1, ~1]
| False => []
| _ => [mk_lit g] end
in case f
of And (x, y) => mk_clauses x @ mk_clauses y
| _ => [mk_clause f] end
(* Solve [fm] after Prop_Logic.defcnf, presumably a definitional CNF
   transformation.  The extra tautological clause [1, ~1, 2, ~2] always
   mentions variables 1 and 2.  NOTE(review): presumably this ensures they
   are known to the solver (True is encoded with variable 1) -- confirm.
   The result is SATISFIABLE with the solver's assignment function, or
   UNSATISFIABLE NONE. *)
fun dptsat fm =
let val fm = Prop_Logic.defcnf fm
val s = DPT_SAT_Solver.empty_solver ()
val clauses = [1, ~1, 2, ~2] :: mk_clauses fm
in (DPT_SAT_Solver.init s (clauses) ;
DPT_SAT_Solver.solve s ;
if DPT_SAT_Solver.satisfied s then
SATISFIABLE (DPT_SAT_Solver.assignment s)
else UNSATISFIABLE NONE) end
in
SAT_Solver.add_solver ("dptsat", dptsat)
end