(* Theory SatSolverCode *)

(*    Title:              SatSolverVerification/SatSolverCode.thy
      Author:             Filip Maric
      Maintainer:         Filip Maric <filip at matf.bg.ac.yu>
*)

section‹Functional implementation of a SAT solver with two-watched-literal propagation›
theory SatSolverCode
imports SatSolverVerification "HOL-Library.Code_Target_Numeral"
begin

(******************************************************************************)
subsection‹Specification›
(******************************************************************************)

(* Code-generation preprocessing rule: occurrences of the membership test
   "literal el clause" are rewritten to List.member clause literal, so that
   generated code uses List.member. *)
lemma [code_unfold]:
  fixes literal :: Literal and clause :: Clause
  shows "literal el clause = List.member clause literal"
  by (auto simp add: member_def)

(* Three-valued solver result: UNDEF while solving is in progress;
   solve_loop_body sets TRUE when every variable of Vbl is assigned without
   conflict, and FALSE on a conflict at level 0 (addClause also sets FALSE
   for an empty clause). *)
datatype ExtendedBool = TRUE | FALSE | UNDEF

(* Complete solver state: result flag, formula F and assertion trail M,
   conflict flag and conflict clause index, unit-propagation queue Q with the
   reason map, two-watched-literal bookkeeping (indexed by clause position in
   F), and the conflict-analysis data C, Cl, Cll, Cn. *)
record State =
  ― ‹Satisfiability flag: UNDEF, TRUE or FALSE›
"getSATFlag" :: ExtendedBool
  ― ‹Formula›
"getF"       :: Formula
  ― ‹Assertion Trail›
"getM"       :: LiteralTrail
  ― ‹Conflict flag›
"getConflictFlag"   :: bool   ― ‹raised iff M falsifies F›
  ― ‹Conflict clause index›
"getConflictClause" :: nat    ― ‹corresponding clause from F is false in M›
  ― ‹Unit propagation queue›
"getQ" :: "Literal list"
  ― ‹Unit propagation graph›
"getReason" :: "Literal ⇒ nat option" ― ‹index of a clause that is a reason for propagation of a literal›
  ― ‹Two-watch literal scheme›
  ― ‹clause indices instead of clauses are used›
"getWatch1" :: "nat ⇒ Literal option"  ― ‹First watch of a clause›
"getWatch2" :: "nat ⇒ Literal option"  ― ‹Second watch of a clause›
"getWatchList" :: "Literal ⇒ nat list" ― ‹Watch list of a given literal›
  ― ‹Conflict analysis data structures›
"getC"   :: Clause             ― ‹Conflict analysis clause - always false in M›
"getCl"  :: Literal            ― ‹Last asserted literal in (opposite getC)›
"getCll" :: Literal            ― ‹Second last asserted literal in (opposite getC)›
"getCn"  :: nat                ― ‹Number of literals of (opposite getC) on the (currentLevel M)›

(* Make literal the first watch of the clause with index clause and prepend
   clause to the watch list of literal. The watch list of the previous first
   watch, if any, is not modified here. *)
definition
setWatch1 :: "nat ⇒ Literal ⇒ State ⇒ State"
where
"setWatch1 clause literal state =
    state⦇ getWatch1 := (getWatch1 state)(clause := Some literal),
           getWatchList := (getWatchList state)(literal := clause # (getWatchList state literal))
         ⦈
"
declare setWatch1_def[code_unfold]

(* Make literal the second watch of the clause with index clause and prepend
   clause to the watch list of literal. The watch list of the previous second
   watch, if any, is not modified here. *)
definition
setWatch2 :: "nat ⇒ Literal ⇒ State ⇒ State"
where
"setWatch2 clause literal state =
    state⦇ getWatch2 := (getWatch2 state)(clause := Some literal),
           getWatchList := (getWatchList state)(literal := clause # (getWatchList state literal))
         ⦈
"
declare setWatch2_def[code_unfold]


(* Exchange the first and second watch of the clause with index clause.
   Watch lists are left unchanged, since the set of watched literals is the same. *)
definition
swapWatches :: "nat ⇒ State ⇒ State"
where
"swapWatches clause state ==
    state⦇ getWatch1 := (getWatch1 state)(clause := (getWatch2 state clause)),
           getWatch2 := (getWatch2 state)(clause := (getWatch1 state clause))
         ⦈
"
declare swapWatches_def[code_unfold]

(* Scan clause left to right and return the first literal that differs from
   both watches w1 and w2 and is not literalFalse in M; None if there is none. *)
primrec getNonWatchedUnfalsifiedLiteral :: "Clause ⇒ Literal ⇒ Literal ⇒ LiteralTrail ⇒ Literal option"
where
"getNonWatchedUnfalsifiedLiteral [] w1 w2 M = None" |
"getNonWatchedUnfalsifiedLiteral (literal # clause) w1 w2 M =
    (if literal ≠ w1 ∧
        literal ≠ w2 ∧
        ¬ (literalFalse literal (elements M)) then
            Some literal
     else
            getNonWatchedUnfalsifiedLiteral clause w1 w2 M
    )
"

(* Record the clause with index clause as the reason for literal. *)
definition
setReason :: "Literal ⇒ nat ⇒ State ⇒ State"
where
"setReason literal clause state =
    state⦇ getReason := (getReason state)(literal := Some clause) ⦈
"
declare setReason_def[code_unfold]

(* One pass over a watch list of literal. The second argument is the
   remaining list of clause indices. newWl accumulates, in reverse order,
   the clause indices that keep literal as a watch. For each clause index:
   - if literal is its first watch, the watches are swapped first, so that
     literal ends up as the second watch w2;
   - if w1 is literalTrue in M, the clause is kept unchanged;
   - otherwise, if getNonWatchedUnfalsifiedLiteral finds a literal l', then l'
     becomes the second watch (setWatch2 adds the clause to the watch list of
     l') and the clause is not kept in newWl;
   - otherwise, if w1 is literalFalse in M, the conflict flag is raised with
     this clause as conflict clause, and the clause is kept;
   - otherwise w1 is appended to Q (unless already in Q), this clause is
     recorded as its reason, and the clause is kept.
   When the list is exhausted, the watch list of literal is replaced by newWl.
   NOTE(review): if a clause lacks a first or second watch, the state passed
   to this iteration is returned (without the swap and without writing newWl
   back), and the rest of the list is abandoned. InvariantWatchesEl presumably
   rules this case out. *)
primrec notifyWatches_loop::"Literal ⇒ nat list ⇒ nat list ⇒ State ⇒ State"
where
"notifyWatches_loop literal [] newWl state = state⦇ getWatchList := (getWatchList state)(literal := newWl) ⦈" |
"notifyWatches_loop literal (clause # list') newWl state =
    (let state' = (if Some literal = (getWatch1 state clause) then
                       (swapWatches clause state)
                   else
                       state) in
    case (getWatch1 state' clause) of
        None ⇒ state
    |   Some w1 ⇒ (
    case (getWatch2 state' clause) of
        None ⇒ state
    |   Some w2 ⇒
    (if (literalTrue w1 (elements (getM state'))) then
        notifyWatches_loop literal list' (clause # newWl) state'
     else
        (case (getNonWatchedUnfalsifiedLiteral (nth (getF state') clause) w1 w2 (getM state')) of
            Some l' ⇒
                notifyWatches_loop literal list' newWl (setWatch2 clause l' state')
          | None ⇒
                (if (literalFalse w1 (elements (getM state'))) then
                    let state'' = (state'⦇ getConflictFlag := True, getConflictClause := clause ⦈) in
                    notifyWatches_loop literal list' (clause # newWl) state''
                else
                    let state'' = state'⦇ getQ := (if w1 el (getQ state') then
                                                      (getQ state')
                                                  else
                                                      (getQ state') @ [w1]
                                                  )
                                         ⦈ in
                   let state''' = (setReason w1 clause state'') in
                   notifyWatches_loop literal list' (clause # newWl) state'''
                )
        )
    )
    )
    )
"

(* Run notifyWatches_loop over the whole current watch list of literal,
   starting with an empty accumulator for the rebuilt watch list. *)
definition
notifyWatches :: "Literal ⇒ State ⇒ State"
where
"notifyWatches literal state ==
    notifyWatches_loop literal (getWatchList state literal) [] state
"
declare notifyWatches_def[code_unfold]


(* Append (literal, decision) to the trail M. Then notify the clauses that
   watch opposite literal, presumably because that literal has just become
   false in M. *)
definition
assertLiteral :: "Literal ⇒ bool ⇒ State ⇒ State"
where
"assertLiteral literal decision state ==
    let state' = (state⦇ getM := (getM state) @ [(literal, decision)] ⦈) in
    notifyWatches (opposite literal) state'
"


(* Assert the head of Q as a non-decision literal, then drop it from Q.
   notifyWatches_loop only appends to Q, so the head removed by tl is still
   the literal just asserted. hd is unspecified on an empty queue; the caller
   exhaustiveUnitPropagate only calls this when Q is non-empty. *)
definition
applyUnitPropagate :: "State ⇒ State"
where
"applyUnitPropagate state =
    (let state' = (assertLiteral (hd (getQ state)) False state) in
    state'⦇ getQ := tl (getQ state') ⦈)
"

(* Repeatedly apply applyUnitPropagate until a conflict is flagged or Q is
   empty. It is defined as a tail-recursive partial function, so no
   termination proof is required at definition time. *)
partial_function (tailrec)
exhaustiveUnitPropagate :: "State ⇒ State"
where
exhaustiveUnitPropagate_unfold[code]:
"exhaustiveUnitPropagate state =
    (if (getConflictFlag state) ∨ (getQ state) = [] then
        state
    else
        exhaustiveUnitPropagate (applyUnitPropagate state)
    )
"

(* Inductive domain predicate for exhaustiveUnitPropagate: a state is in the
   domain if, whenever another propagation step would be taken (no conflict
   and Q non-empty), the successor state is in the domain. It is presumably
   used to reason about termination. *)
inductive
exhaustiveUnitPropagate_dom :: "State ⇒ bool"
where
step: "(¬ getConflictFlag state ∧ getQ state ≠ [] ⟹
    exhaustiveUnitPropagate_dom (applyUnitPropagate state))
    ⟹ exhaustiveUnitPropagate_dom state"


(* Add one input clause during initialization. clause' is the clause with its
   literals that are false in M removed, and duplicates removed. Then:
   - clause' true in M: the state is unchanged;
   - clause' empty: the SAT flag is set to FALSE;
   - clause' has one literal: that literal is asserted as a non-decision and
     exhaustively propagated (the clause itself is not stored in F);
   - clause' is a tautology (clauseTautology): the state is unchanged;
   - otherwise clause' is appended to F with its first two literals as
     watches (it has at least two literals at this point). *)
definition
addClause :: "Clause ⇒ State ⇒ State"
where
"addClause clause state =
    (let clause' = (remdups (removeFalseLiterals clause (elements (getM state)))) in
    (if (clauseTrue clause' (elements (getM state))) then
        state
    else (if clause'=[] then
        state⦇ getSATFlag := FALSE ⦈
    else (if (length clause' = 1) then
        let state' = (assertLiteral (hd clause') False state) in
        exhaustiveUnitPropagate state'
    else (if (clauseTautology clause') then
        state
    else
        let clauseIndex = length (getF state) in
        let state'   = state⦇ getF := (getF state) @ [clause'] ⦈ in
        let state''  = setWatch1 clauseIndex (nth clause' 0) state' in
        let state''' = setWatch2 clauseIndex (nth clause' 1) state'' in
        state'''
   )))
 ))"


(* Empty starting state: undecided, with no clauses, no trail, no conflict,
   an empty queue, no reasons or watches, and empty watch lists. getCl and
   getCll hold the placeholder Pos 0; presumably they are only meaningful
   during conflict analysis. *)
definition
initialState :: "State"
where
"initialState =
    ⦇ getSATFlag = UNDEF,
      getF = [],
      getM = [],
      getConflictFlag = False,
      getConflictClause = 0,
      getQ = [],
      getReason = λ l. None,
      getWatch1 = λ c. None,
      getWatch2 = λ c. None,
      getWatchList = λ l. [],
      getC = [],
      getCl = (Pos 0),
      getCll = (Pos 0),
      getCn = 0
    ⦈
"

(* Add the clauses of the formula to the state one by one, left to right,
   using addClause. *)
primrec initialize :: "Formula ⇒ State ⇒ State"
where
"initialize [] state = state" |
"initialize (clause # formula) state = initialize formula (addClause clause state)"

(* Set getCl to the literal of (opposite getC) that was asserted last in M,
   as computed by getLastAssertedLiteral. *)
definition
findLastAssertedLiteral :: "State ⇒ State"
where
"findLastAssertedLiteral state =
   state ⦇ getCl := getLastAssertedLiteral (oppositeLiteralList (getC state)) (elements (getM state)) ⦈"

(* Set getCn to the number of literals l in getC whose opposite is at the
   current decision level of M. Duplicates in getC are counted separately;
   setConflictAnalysisClause applies remdups to getC before calling this. *)
definition
countCurrentLevelLiterals :: "State ⇒ State"
where
"countCurrentLevelLiterals state =
   (let cl = currentLevel (getM state) in
        state ⦇ getCn := length (filter (λ l. elementLevel (opposite l) (getM state) = cl) (getC state)) ⦈)"

(* Install clause as the conflict-analysis clause getC. First drop the
   opposites of the literals asserted at level 0 and remove duplicates, then
   recompute getCl and getCn. *)
definition setConflictAnalysisClause :: "Clause ⇒ State ⇒ State"
where
"setConflictAnalysisClause clause state =
  (let oppM0 = oppositeLiteralList (elements (prefixToLevel 0 (getM state))) in
   let state' = state (| getC := remdups (list_diff clause oppM0) |) in
     countCurrentLevelLiterals (findLastAssertedLiteral state')
  )"
 
(* Start conflict analysis from the clause of F at index getConflictClause. *)
definition
applyConflict :: "State ⇒ State"
where
"applyConflict state =
   (let conflictClause = (nth (getF state) (getConflictClause state)) in
    setConflictAnalysisClause conflictClause state)"

(* Explain literal: if it has a recorded reason clause, resolve getC with that
   clause on opposite literal and install the resolvent as the new getC.
   Otherwise the state is unchanged. *)
definition
applyExplain :: "Literal ⇒ State ⇒ State"
where
"applyExplain literal state =
    (case (getReason state literal) of
        None ⇒
            state
    |   Some reason ⇒
            let res = resolve (getC state) (nth (getF state) reason) (opposite literal) in
            setConflictAnalysisClause res state
    )
"

(* Repeatedly explain the last asserted literal getCl until exactly one
   literal of getC is at the current level (getCn = 1), which is the UIP
   condition the name refers to. *)
partial_function (tailrec)
applyExplainUIP :: "State ⇒ State"
where
applyExplainUIP_unfold:
"applyExplainUIP state =
    (if (getCn state = 1) then
         state
     else
         applyExplainUIP (applyExplain (getCl state) state)
    )
"

(* Inductive domain predicate for applyExplainUIP: whenever another
   explanation step would be taken (getCn ≠ 1), the successor state must be
   in the domain. It is presumably used to reason about termination. *)
inductive
applyExplainUIP_dom :: "State ⇒ bool"
where
step:
"(getCn state ≠ 1 ⟹
     applyExplainUIP_dom (applyExplain (getCl state) state))
  ⟹ applyExplainUIP_dom state
"

(* Learn the conflict-analysis clause getC. If getC is the unit clause
   [opposite getCl], nothing is added. Otherwise getC is appended to F, watched
   by opposite getCl (first watch) and by opposite ll (second watch). Here ll
   is the last asserted literal of (opposite getC) other than getCl; it is
   stored in getCll. *)
definition
applyLearn :: "State ⇒ State"
where
"applyLearn state =
        (if getC state = [opposite (getCl state)] then
            state
         else
            let state' = state⦇ getF := (getF state) @ [getC state] ⦈ in
            let l  = (getCl state) in
            let ll = (getLastAssertedLiteral (removeAll l (oppositeLiteralList (getC state))) (elements (getM state))) in
            let clauseIndex = length (getF state) in
            let state''  = setWatch1 clauseIndex (opposite l) state' in
            let state''' = setWatch2 clauseIndex (opposite ll) state'' in
            state'''⦇ getCll := ll ⦈
        )
"

(* Level to backjump to: 0 when getC is the unit clause [opposite getCl],
   otherwise the level of getCll in M. *)
definition
getBackjumpLevel :: "State ⇒ nat"
where
"getBackjumpLevel state ==
    (if getC state = [opposite (getCl state)] then
        0
     else
        elementLevel (getCll state) (getM state)
     )
"

(* Backjump: clear the conflict flag and Q, cut M back to the backjump level,
   and assert opposite getCl as a non-decision literal. When the level is
   positive, getC is not the unit [opposite getCl] (see getBackjumpLevel), so
   in solve_loop_body applyLearn has just appended it at index length F - 1.
   That index is recorded as the reason for opposite getCl. *)
definition
applyBackjump :: "State ⇒ State"
where
"applyBackjump state =
    (let l = (getCl state) in
     let level = getBackjumpLevel state in
     let state' = state⦇ getConflictFlag := False, getQ := [], getM := (prefixToLevel level (getM state)) ⦈ in
     let state'' = (if level > 0 then setReason (opposite l) (length (getF state) - 1) state' else state') in
     assertLiteral (opposite l) False state''
    )
"

(* Decision heuristic, left abstract. The only assumption is that, whenever
   some variable of Vbl is unassigned in M, the selected literal's variable is
   one of those unassigned variables. No definition is given, so presumably
   an implementation must be supplied before code can be exported. *)
axiomatization selectLiteral :: "State ⇒ Variable set ⇒ Literal"
where
selectLiteral_def:
"Vbl - vars (elements (getM state)) ≠ {} ⟶
    var (selectLiteral state Vbl) ∈ (Vbl - vars (elements (getM state)))"

(* Assert the literal chosen by selectLiteral as a decision literal. *)
definition
applyDecide :: "State ⇒ Variable set ⇒ State"
where
"applyDecide state Vbl =
    assertLiteral (selectLiteral state Vbl) True state
"

(* One solver iteration. Propagate exhaustively, then:
   - conflict at level 0: set the SAT flag to FALSE;
   - conflict at a higher level: analyze (applyConflict, applyExplainUIP),
     learn (applyLearn) and backjump (applyBackjump);
   - no conflict and every variable of Vbl assigned in M: set the flag to TRUE;
   - otherwise: make a decision (applyDecide). *)
definition
solve_loop_body :: "State ⇒ Variable set ⇒ State"
where
"solve_loop_body state Vbl =
    (let state' = exhaustiveUnitPropagate state in
    (if (getConflictFlag state') then
        (if (currentLevel (getM state')) = 0 then
            state'⦇ getSATFlag := FALSE ⦈
         else
            (applyBackjump
            (applyLearn
            (applyExplainUIP
            (applyConflict
                state'
            )
            )
            )
            )
         )
     else
        (if (vars (elements (getM state')) ⊇ Vbl) then
            state'⦇ getSATFlag := TRUE ⦈
         else
            applyDecide state' Vbl
        )
    )
    )
"


(* Iterate solve_loop_body while the SAT flag is UNDEF. *)
partial_function (tailrec)
solve_loop :: "State ⇒ Variable set ⇒ State"
where
solve_loop_unfold:
"solve_loop state Vbl =
    (if (getSATFlag state) ≠ UNDEF then
        state
     else
        let state' = solve_loop_body state Vbl in
        solve_loop state' Vbl
    )
"

(* Inductive domain predicate for solve_loop: while the flag is UNDEF, the
   state after one solve_loop_body step must be in the domain. It is
   presumably used to reason about termination. *)
inductive
solve_loop_dom :: "State ⇒ Variable set ⇒ bool"
where
step:
"(getSATFlag state = UNDEF ⟹
     solve_loop_dom (solve_loop_body state Vbl) Vbl)
  ⟹ solve_loop_dom state Vbl"

(* Top-level entry point: load F0 into initialState, run the solver loop over
   the variables of F0, and return the final SAT flag. *)
definition solve::"Formula ⇒ ExtendedBool"
where
"solve F0 =
    (getSATFlag
        (solve_loop
            (initialize F0 initialState) (vars F0)
        )
    )
"

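(* 
code_modulename SML
  Nat Numbers
  Int Numbers
  Ring_and_Field Numbers

code_modulename OCaml
  Nat Numbers
  Int Numbers
  Ring_and_Field Numbers

NOTE: selectLiteral is only axiomatized, so an implementation presumably has
to be provided before the export below can succeed.

export_code solve in OCaml file "code/solve.ml"
                  in SML file "code/solve.ML"
                  in Haskell file "code/"
*)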

(******************************************************************************)
(*      I N V A R I A N T S                                                   *)
(******************************************************************************)

(* Every clause index in any watch list is a valid index into F.
   The conjunct 0 ≤ c is trivially true for nat. *)
definition
InvariantWatchListsContainOnlyClausesFromF :: "(Literal ⇒ nat list) ⇒ Formula ⇒ bool"
where
"InvariantWatchListsContainOnlyClausesFromF Wl F =
    (∀ (l::Literal) (c::nat). c ∈ set (Wl l) ⟶ 0 ≤ c ∧ c < length F)
"

(* No watch list contains a clause index twice. *)
definition
InvariantWatchListsUniq :: "(Literal ⇒ nat list) ⇒ bool"
where
"InvariantWatchListsUniq Wl =
    (∀ l. uniq (Wl l))
"

(* Clause c is in the watch list of l exactly when l is its first or second
   watch. *)
definition
InvariantWatchListsCharacterization :: "(Literal ⇒ nat list) ⇒ (nat ⇒ Literal option) ⇒ (nat ⇒ Literal option) ⇒ bool"
where
"InvariantWatchListsCharacterization Wl w1 w2 =
    (∀ (c::nat) (l::Literal). c ∈ set (Wl l) = (Some l = (w1 c) ∨ Some l = (w2 c)))
"

(* Every clause of the formula has both watches set, and both watched
   literals occur in that clause. *)
definition
InvariantWatchesEl :: "Formula ⇒ (nat ⇒ Literal option) ⇒ (nat ⇒ Literal option) ⇒ bool"
where
"InvariantWatchesEl formula watch1 watch2 ==
    ∀ (clause::nat). 0 ≤ clause ∧ clause < length formula ⟶
        (∃ (w1::Literal) (w2::Literal). watch1 clause = Some w1 ∧ watch2 clause = Some w2 ∧
             w1 el (nth formula clause) ∧ w2 el (nth formula clause))
"

(* For every clause of the formula, the first and second watch differ. *)
definition
InvariantWatchesDiffer :: "Formula ⇒ (nat ⇒ Literal option) ⇒ (nat ⇒ Literal option) ⇒ bool"
where
"InvariantWatchesDiffer formula watch1 watch2 ==
    ∀ (clause::nat). 0 ≤ clause ∧ clause < length formula ⟶ watch1 clause ≠ watch2 clause
"

(* If watch w1 is false in M, then either
   - some literal of clause is true in M at a level ≤ the level of opposite w1, or
   - every non-watched literal of clause is false in M at a level ≤ the level
     of opposite w1. *)
definition
watchCharacterizationCondition::"Literal ⇒ Literal ⇒ LiteralTrail ⇒ Clause ⇒ bool"
where
"watchCharacterizationCondition w1 w2 M clause =
    (literalFalse w1 (elements M) ⟶
        ( (∃ l. l el clause ∧ literalTrue l (elements M) ∧ elementLevel l M ≤ elementLevel (opposite w1) M) ∨
          (∀ l. l el clause ∧ l ≠ w1 ∧ l ≠ w2 ⟶
                literalFalse l (elements M) ∧ elementLevel (opposite l) M ≤ elementLevel (opposite w1) M)
          )
    )
"

(* For every clause of F with watches w1 and w2, watchCharacterizationCondition
   holds for both orders of the watches. *)
definition
InvariantWatchCharacterization::"Formula ⇒ (nat ⇒ Literal option) ⇒ (nat ⇒ Literal option) ⇒ LiteralTrail ⇒ bool"
where
"InvariantWatchCharacterization F watch1 watch2 M =
    (∀ c w1 w2. (0 ≤ c ∧ c < length F ∧ Some w1 = watch1 c ∧ Some w2 = watch2 c) ⟶
          watchCharacterizationCondition w1 w2 M (nth F c) ∧
          watchCharacterizationCondition w2 w1 M (nth F c)
    )
"

(* When there is no conflict, Q contains exactly the literals that are unit
   literals of some clause of F with respect to M. *)
definition
InvariantQCharacterization :: "bool ⇒ Literal list ⇒ Formula ⇒ LiteralTrail ⇒ bool"
where
"InvariantQCharacterization conflictFlag Q F M ==
   ¬ conflictFlag ⟶ (∀ (l::Literal). l el Q = (∃ (c::Clause). c el F ∧ isUnitClause c l (elements M)))
"

(* The unit-propagation queue contains no duplicates. *)
definition
InvariantUniqQ :: "Literal list ⇒ bool"
where
"InvariantUniqQ Q =
    uniq Q
"

(* The conflict flag is set exactly when F is false in M. *)
definition
InvariantConflictFlagCharacterization :: "bool ⇒ Formula ⇒ LiteralTrail ⇒ bool"
where
"InvariantConflictFlagCharacterization conflictFlag F M ==
    conflictFlag = formulaFalse F (elements M)
"

(* For every level' below level, the prefix of M up to level' does not make
   F false. *)
definition
InvariantNoDecisionsWhenConflict :: "Formula ⇒ LiteralTrail ⇒ nat ⇒ bool"
where
"InvariantNoDecisionsWhenConflict F M level =
    (∀ level'. level' < level ⟶
              ¬ formulaFalse F (elements (prefixToLevel level' M))
    )
"

(* For every level' below level, no clause of F is unit with respect to the
   prefix of M up to level'. *)
definition
InvariantNoDecisionsWhenUnit :: "Formula ⇒ LiteralTrail ⇒ nat ⇒ bool"
where
"InvariantNoDecisionsWhenUnit F M level =
    (∀ level'. level' < level ⟶
              ¬ (∃ clause literal. clause el F ∧
                                   isUnitClause clause literal (elements (prefixToLevel level' M)))
    )
"

(* F, together with the level-0 literals of M (converted by val2form), is
   equivalent to the input formula F0. *)
definition InvariantEquivalentZL :: "Formula ⇒ LiteralTrail ⇒ Formula ⇒ bool"
where
"InvariantEquivalentZL F M F0 =
    equivalentFormulae (F @ val2form (elements (prefixToLevel 0 M))) F0
"

(* For every literal:
   - if it is in M, is not a decision, and is at a level > 0, then GetReason
     gives a valid index into F whose clause isReason for it;
   - if the current level is > 0 and the literal is in Q, then GetReason gives
     a valid index into F whose clause is either unit for it or false in M. *)
definition
InvariantGetReasonIsReason :: "(Literal ⇒ nat option) ⇒ Formula ⇒ LiteralTrail ⇒ Literal set ⇒ bool"
where
"InvariantGetReasonIsReason GetReason F M Q ==
     ∀ literal. (literal el (elements M) ∧ ¬ literal el (decisions M) ∧ elementLevel literal M > 0 ⟶
                  (∃ (reason::nat). (GetReason literal) = Some reason ∧ 0 ≤ reason ∧ reason < length F ∧
                        isReason (nth F reason) literal (elements M)
                  )
                ) ∧
               (currentLevel M > 0 ∧ literal ∈ Q ⟶
                  (∃ (reason::nat). (GetReason literal) = Some reason ∧ 0 ≤ reason ∧ reason < length F ∧
                        (isUnitClause (nth F reason) literal (elements M) ∨ clauseFalse (nth F reason) (elements M))
                  )
                )
"

(* When the conflict flag is set, conflictClause is a valid index into F and
   its clause is false in M. *)
definition
InvariantConflictClauseCharacterization :: "bool ⇒ nat ⇒ Formula ⇒ LiteralTrail ⇒ bool"
where
"InvariantConflictClauseCharacterization conflictFlag conflictClause F M  ==
         conflictFlag ⟶ (conflictClause < length F ∧
                           clauseFalse (nth F conflictClause) (elements M))"

(* Cl is the last asserted literal of (opposite C) with respect to M. *)
definition
InvariantClCharacterization :: "Literal ⇒ Clause ⇒ LiteralTrail ⇒ bool"
where
"InvariantClCharacterization Cl C M ==
  isLastAssertedLiteral Cl (oppositeLiteralList C) (elements M)"

(* Unless C consists only of opposite Cl, Cll is the last asserted literal of
   (opposite C) with Cl removed. *)
definition
InvariantCllCharacterization :: "Literal ⇒ Literal ⇒ Clause ⇒ LiteralTrail ⇒ bool"
where
"InvariantCllCharacterization Cl Cll C M ==
  set C ≠ {opposite Cl} ⟶
      isLastAssertedLiteral Cll (removeAll Cl (oppositeLiteralList C)) (elements M)"

(* Cl was asserted at the current decision level of M. *)
definition
InvariantClCurrentLevel :: "Literal ⇒ LiteralTrail ⇒ bool"
where
"InvariantClCurrentLevel Cl M ==
  elementLevel Cl M = currentLevel M"

(* Cn equals the number of distinct literals of C whose opposite is at the
   current level of M. *)
definition
InvariantCnCharacterization :: "nat ⇒ Clause ⇒ LiteralTrail ⇒ bool"
where
"InvariantCnCharacterization Cn C M ==
  Cn = length (filter (λ l. elementLevel (opposite l) M = currentLevel M) (remdups C))
"

(* The conflict-analysis clause contains no duplicate literals. *)
definition
InvariantUniqC :: "Clause ⇒ bool"
where
"InvariantUniqC clause = uniq clause"

(* Every variable occurring in Q is a variable of F0 or belongs to Vbl. *)
definition
InvariantVarsQ :: "Literal list ⇒ Formula ⇒ Variable set ⇒ bool"
where
"InvariantVarsQ Q F0 Vbl ==
  vars Q ⊆ vars F0 ∪ Vbl"

(******************************************************************************)

end