Theory Automatic_Refinement.Tagged_Solver
theory Tagged_Solver
imports Refine_Util
begin
ML ‹
  signature TAGGED_SOLVER = sig
    (* A solver: (trigger theorems, full name, description, tactic). *)
    type solver = thm list * string * string * (Proof.context -> tactic')

    (** Registry of declared solvers (kept in the generic context) **)

    (* All currently declared solvers. *)
    val get_solvers: Proof.context -> solver list
    (* Look up a solver by its full name. *)
    val lookup_solver: string -> Context.generic -> solver option
    (* Declare a new solver: triggers, name, description, tactic. *)
    val declare_solver: thm list -> binding -> string
      -> (Proof.context -> tactic') -> morphism
      -> Context.generic -> Context.generic
    (* Add further trigger theorems to an existing solver. *)
    val add_triggers: string -> thm list -> morphism
      -> Context.generic -> Context.generic
    (* Remove a solver by its full name. *)
    val delete_solver: string -> morphism -> Context.generic -> Context.generic
    (* Pretty-print all solvers with their triggers. *)
    val pretty_solvers: Proof.context -> Pretty.T

    (** Selecting candidate solvers for a subgoal **)

    val get_potential_solvers: Proof.context -> int -> thm -> solver list
    val tac_of_solver: Proof.context -> solver -> tactic'
    val get_potential_tacs: Proof.context -> int -> thm -> tactic' list

    (** Configuration options (also settable via the method's flags) **)

    val cfg_keep: bool Config.T
    val cfg_trace: bool Config.T
    val cfg_full: bool Config.T
    val cfg_step: bool Config.T

    (** Solving strategies **)

    val solve_greedy_step_tac: Proof.context -> tactic'
    val solve_full_step_tac: Proof.context -> tactic'
    val solve_greedy_tac: Proof.context -> tactic'
    val solve_full_tac: Proof.context -> tactic'
    val solve_greedy_keep_tac: Proof.context -> tactic'
    val solve_full_keep_tac: Proof.context -> tactic'
    (* Dispatch to one of the strategies above according to the cfg_* options. *)
    val solve_tac: Proof.context -> tactic'
  end
  structure Tagged_Solver : TAGGED_SOLVER = struct
    (* (trigger theorems, full name, description, tactic) *)
    type solver = thm list * string * string * (Proof.context -> tactic')

    (* Solvers are stored twice:
         - in an Item_Net indexed by the conclusions of their trigger
           theorems, for retrieving candidates by goal;
         - in a symbol table keyed by full name, for lookup, listing and
           duplicate detection.
       Two solvers are identified iff their names are equal. *)
    structure solvers = Generic_Data (
      type T = solver Item_Net.T * solver Symtab.table
      val empty = (Item_Net.init 
        ((op =) o apply2 #2) 
        (fn p:solver => #1 p |> map Thm.concl_of)
      ,
        Symtab.empty
      )
  
      fun merge ((n1,t1),(n2,t2)) 
        = (Item_Net.merge (n1,n2), Symtab.merge ((op =) o apply2 #2) (t1,t2))
    )

    (* All declared solvers, taken from the name table. *)
    fun get_solvers ctxt = solvers.get (Context.Proof ctxt) 
      |> #2 |> Symtab.dest |> map #2

    (* Look up a solver by its full (qualified) name. *)
    fun lookup_solver n context = let
      val tab = solvers.get context |> #2
    in
      Symtab.lookup tab n
    end

    (* Add trigger theorems (transformed by phi) to the existing solver n.
       New triggers are placed in front of the old ones. Triggers whose
       proposition is already present are not added again, so repeated
       declarations do not make match_tac yield duplicate branches.
       Raises an error if no solver named n exists. *)
    fun add_triggers n thms phi context = 
      case lookup_solver n context of
        NONE => 
          error ("Undefined solver: " ^ n)
      | SOME (trigs, name, desc, tac) => let
          val new_trigs = map (Morphism.thm phi) thms
          val trigs' = distinct Thm.eq_thm_prop (new_trigs @ trigs)
          val solver = (trigs', name, desc, tac)
        in 
          (* Item_Net.update replaces the entry with the same name. *)
          solvers.map (Item_Net.update solver ## Symtab.update (name, solver)) 
            context
        end

    (* Declare a new solver. The binding is transformed by phi and then
       qualified with the current name space. Raises an error if a solver
       with the resulting full name already exists. *)
    fun declare_solver thms n desc tac phi context = let
      val thms = map (Morphism.thm phi) thms
      val n = Morphism.binding phi n
      val n = Context.cases Sign.full_name Proof_Context.full_name context n
      val _ = 
        if Symtab.defined (solvers.get context |> #2) n then
          error ("Duplicate solver " ^ n)
        else ()
      val solver = (thms,n,desc,tac)
    in
      solvers.map (Item_Net.update solver ## Symtab.update (n,solver)) context
    end

    (* Remove the solver with full name n from both indexes. The morphism is
       ignored; n is looked up as given. Raises an error if n is undefined. *)
    fun delete_solver n _ context = 
      case lookup_solver n context of
        NONE => error ("Undefined solver: " ^ n)
      | SOME solver => 
          solvers.map (Item_Net.remove solver ## Symtab.delete (#2 solver))
            context

    (* Configuration options; all default to false. *)
    val cfg_keep = 
      Attrib.setup_config_bool @{binding tagged_solver_keep} (K false)
    val cfg_trace = 
      Attrib.setup_config_bool @{binding tagged_solver_trace} (K false)
    val cfg_step = 
      Attrib.setup_config_bool @{binding tagged_solver_step} (K false)
    val cfg_full = 
      Attrib.setup_config_bool @{binding tagged_solver_full} (K false)

    (* True iff i names an existing subgoal of st. Out-of-range indices must
       not reach Logic.concl_of_goal, which raises on them. *)
    fun valid_subgoal i st = 1 <= i andalso i <= Thm.nprems_of st

    (* Candidate solvers for subgoal i: those retrieved from the trigger net
       by the subgoal's conclusion. Empty if i is not a valid subgoal. *)
    fun get_potential_solvers ctxt i st =
      if valid_subgoal i st then
        let
          val concl = Logic.concl_of_goal (Thm.prop_of st) i
          val net = solvers.get (Context.Proof ctxt) |> #1
        in Item_Net.retrieve net concl end
      else []

    (* Apply a trigger theorem by matching, then run the solver's tactic. *)
    fun notrace_tac_of_solver ctxt (thms,_,_,tac) = 
      match_tac ctxt thms THEN' tac ctxt

    (* Same as notrace_tac_of_solver, but reports progress via tracing. *)
    fun trace_tac_of_solver ctxt (thms, name, _, tac) i st = 
      let
        val _ = tracing ("Trying solver " ^ name)
      in
        (* Pulled heads are re-consed rather than returning the original
           sequences: Seq is not memoizing, so reusing them would evaluate
           match_tac and the solver tactic a second time. *)
        case Seq.pull (match_tac ctxt thms i st) of 
          NONE => (tracing "  No trigger"; Seq.empty)
        | SOME (m, ms) =>
            (case Seq.pull (Seq.maps (tac ctxt i) (Seq.cons m ms)) of 
              NONE => (tracing ("  No solution (" ^ name ^ ")"); Seq.empty)
            | SOME (r, rs) => (tracing ("  OK (" ^ name ^ ")"); Seq.cons r rs))
      end

    (* Tactic for a solver, traced iff cfg_trace is set. *)
    fun tac_of_solver ctxt = 
      if Config.get ctxt cfg_trace then
        trace_tac_of_solver ctxt
      else
        notrace_tac_of_solver ctxt

    (* Tactics to try on subgoal i: eq_assume_tac first, then the candidate
       solvers' tactics. Empty if i is not a valid subgoal. *)
    fun get_potential_tacs ctxt i st = 
      if valid_subgoal i st then
        eq_assume_tac :: map (tac_of_solver ctxt) (get_potential_solvers ctxt i st)
      else []

    (* Single step: first applicable tactic only. *)
    fun solve_greedy_step_tac ctxt i st = 
      (FIRST' (get_potential_tacs ctxt i st)) i st
    (* Single step: all alternatives (backtracking). *)
    fun solve_full_step_tac ctxt i st = 
      (APPEND_LIST' (get_potential_tacs ctxt i st) i st)

    (* Recursively solve subgoal i and all new subgoals, committing to the
       first applicable tactic at each step. *)
    fun solve_greedy_tac ctxt i st = let
      val tacs = get_potential_tacs ctxt i st
    in
      (FIRST' tacs THEN_ALL_NEW_FWD solve_greedy_tac ctxt) i st
    end

    (* Recursively solve, exploring all alternatives at each step. *)
    fun solve_full_tac ctxt i st = let
      val tacs = get_potential_tacs ctxt i st
    in
      (APPEND_LIST' tacs THEN_ALL_NEW_FWD solve_full_tac ctxt) i st
    end

    (* Like solve_greedy_tac, but new subgoals that cannot be solved are
       left in place (TRY). *)
    fun solve_greedy_keep_tac ctxt i st = let
      val tacs = get_potential_tacs ctxt i st
    in
      (FIRST' tacs THEN_ALL_NEW_FWD (TRY o solve_greedy_keep_tac ctxt)) i st
    end

    (* Like solve_full_tac, but unsolved new subgoals are left in place. *)
    fun solve_full_keep_tac ctxt i st = let
      val tacs = get_potential_tacs ctxt i st
    in
      (APPEND_LIST' tacs THEN_ALL_NEW_FWD (TRY o solve_full_keep_tac ctxt)) i st
    end

    (* Select the strategy from (keep, step, full). In step mode the keep
       flag has no effect. *)
    fun solve_tac ctxt = 
      case (Config.get ctxt cfg_keep, Config.get ctxt cfg_step, 
            Config.get ctxt cfg_full) of
        (_,true,false) => solve_greedy_step_tac ctxt
      | (_,true,true) => solve_full_step_tac ctxt
      | (true,false,false) => solve_greedy_keep_tac ctxt
      | (false,false,false) => solve_greedy_tac ctxt
      | (true,false,true) => solve_full_keep_tac ctxt
      | (false,false,true) => solve_full_tac ctxt

    (* "Solvers:" list; each entry shows name, description and triggers. *)
    fun pretty_solvers ctxt = let
      fun pretty_solver (ts,name,desc,_) = Pretty.block (
        Pretty.str (name ^ ": " ^ desc) :: Pretty.fbrk 
        :: Pretty.str "Triggers: "
        :: Pretty.commas (map (Thm.pretty_thm ctxt) ts))
      val solvers = get_solvers ctxt
    in
      Pretty.big_list "Solvers:" (map pretty_solver solvers)
    end
  end
›
method_setup tagged_solver = ‹
  let
    (* One optional flag per Tagged_Solver configuration option. *)
    val keep_flag  = Refine_Util.parse_bool_config "keep"  Tagged_Solver.cfg_keep
    val trace_flag = Refine_Util.parse_bool_config "trace" Tagged_Solver.cfg_trace
    val full_flag  = Refine_Util.parse_bool_config "full"  Tagged_Solver.cfg_full
    val step_flag  = Refine_Util.parse_bool_config "step"  Tagged_Solver.cfg_step
    val flags = keep_flag || trace_flag || full_flag || step_flag

    (* The flags' parse result is discarded; the method simply runs solve_tac
       in the context it is given. *)
    fun solver_method ctxt = SIMPLE_METHOD' (Tagged_Solver.solve_tac ctxt)
  in
    Refine_Util.parse_paren_lists flags >> (fn _ => solver_method)
  end
› "Select tactic to solve goal by pattern"
term ‹True›
end