File ‹~~/src/Provers/splitter.ML›
(*Logic-specific parameters needed to instantiate the generic splitter.
  The rule shapes quoted below are not visible in this file: they are what the
  field names and their uses in the functor suggest, and should be confirmed
  against the instantiating object-logic.*)
signature SPLITTER_DATA =
sig
  (*context used to strip / look up the object-logic judgment*)
  val context       : Proof.context
  (*applied to every split rule before its conclusion is inspected*)
  val mk_eq         : thm -> thm
  (*resolved into iffD with RS; presumably "x == y ==> x = y"*)
  val meta_eq_to_iff: thm 
  (*presumably "[| P = Q; Q |] ==> P" -- confirm*)
  val iffD          : thm 
  (*elimination rule; the head of its first premise is taken as the disjunction constant*)
  val disjE         : thm 
  (*used as elimination rule when flattening premises*)
  val conjE         : thm 
  (*used as elimination rule when flattening premises*)
  val exE           : thm 
  (*eliminated after the split in split_asm_tac; presumably "[| ~ Q; P ==> Q |] ==> ~ P"*)
  val contrapos     : thm 
  (*eliminated before the split in split_asm_tac; presumably "[| Q; ~ P ==> ~ Q |] ==> P"*)
  val contrapos2    : thm 
  (*destruction rule; the head of its premise is taken as the negation constant*)
  val notnotD       : thm 
  (*tried on the subgoals produced by a "bang" split*)
  val safe_tac      : Proof.context -> tactic
end
(*Interface of the generic case splitter produced by functor Splitter.*)
signature SPLITTER =
sig
  (*lower-level functions*)
  (*cmap_of_split_thms: index split rules by the name of the constant heading
    the split term; each entry is (constant type, term, rule, type, arity)*)
  val cmap_of_split_thms: thm list -> (string * (typ * term * thm * typ * int) list) list
  val split_posns: (string * (typ * term * thm * typ * int) list) list ->
    theory -> typ list -> term -> (thm * (typ * typ * int list) list * int list * typ * term) list
    (*first argument is a table as built by cmap_of_split_thms; the result is a sorted list of split candidates*)
  (*tactics*)
  (*split the conclusion of subgoal i; candidates with shorter position paths are preferred*)
  val split_tac       : Proof.context -> thm list -> int -> tactic
  (*like split_tac, but with the ordering on position-path length reversed*)
  val split_inside_tac: Proof.context -> thm list -> int -> tactic
  (*split within a premise of subgoal i*)
  val split_asm_tac   : Proof.context -> thm list -> int -> tactic
  (*register a split rule as a simplifier looper*)
  val add_split: thm -> Proof.context -> Proof.context
  (*as add_split, additionally trying Data.safe_tac on the subgoals of the split*)
  val add_split_bang: thm -> Proof.context -> Proof.context
  (*remove the simplifier looper belonging to a split rule*)
  val del_split: thm -> Proof.context -> Proof.context
  (*method modifiers built from the keyword "split" (add, add with bang, delete)*)
  val split_modifiers : Method.modifier parser list
end;
functor Splitter(Data: SPLITTER_DATA): SPLITTER =
struct
(*Names of object-logic constants, read off the rules supplied in Data:
  const_not is the head of the premise of notnotD, const_or the head of the
  first premise of disjE (both after dropping the judgment), and
  const_Trueprop is the name of the judgment constant itself.
  A rule of unexpected shape raises Bind when the functor is applied.*)
local
  (*First premise of a rule, with the object-logic judgment removed.*)
  fun main_prem rule =
    Object_Logic.drop_judgment Data.context
      (#1 (Logic.dest_implies (Thm.prop_of rule)));
in
  val const_not =
    (case main_prem Data.notnotD of
      Const (c, _) $ _ => c
    | _ => raise Bind);
  val const_or =
    (case main_prem Data.disjE of
      Const (c, _) $ _ $ _ => c
    | _ => raise Bind);
end;
val const_Trueprop = Object_Logic.judgment_name Data.context;
(*Uniform failure for rules that do not have the shape expected of a split rule.*)
fun split_format_err () = error "Wrong format for split rule";
(*Analyse a split rule after normalising it with Data.mk_eq.
  The conclusion must be a meta-equation whose left-hand side is a variable
  applied to a term t headed by a constant; any other shape is rejected via
  split_format_err.  The result pairs the head constant of t (name and type)
  with a flag that is true iff the right-hand side is an application headed by
  the negation constant; gen_split_tac and gen_add_split use that flag to
  choose split_asm_tac over split_tac.*)
fun split_thm_info thm =
  let
    fun negated (Const (s, _) $ _) = s = const_not
      | negated _ = false;
  in
    (case Thm.concl_of (Data.mk_eq thm) of
      \<^Const_>‹Pure.eq _ for ‹Var _ $ t› c› =>
        (case #1 (strip_comb t) of
          Const p => (p, negated c)
        | _ => split_format_err ())
    | _ => split_format_err ())
  end;
(*Build the constant map used by split_posns.
  Every rule is first normalised with Data.mk_eq; its conclusion must have the
  form "_ $ (_ $ lhs) $ _" with lhs headed by a constant c, otherwise
  split_format_err is raised.  Under key c the table collects entries
  (type of c, lhs, normalised rule, type of the application containing lhs,
  number of arguments of c in lhs); later rules for the same constant are put
  in front of earlier ones.*)
fun cmap_of_split_thms thms =
  let
    fun register eq_thm table =
      (case Thm.concl_of eq_thm of
        _ $ (app as _ $ lhs) $ _ =>
          (case strip_comb lhs of
            (Const (c, cT), args) =>
              let val entry = (cT, lhs, eq_thm, fastype_of app, length args) in
                (case AList.lookup (op =) table c of
                  NONE => (c, [entry]) :: table
                | SOME entries => AList.update (op =) (c, entry :: entries) table)
              end
          | _ => split_format_err ())
      | _ => split_format_err ());
  in fold register (map Data.mk_eq thms) [] end;
(*abss Ts t: apply Term.abs with an empty variable name once for every type
  in Ts, folding from the head of the list.*)
fun abss Ts t = fold (fn T => Term.abs ("", T)) Ts t;
fun mk_case_split_tac order =
let
(*meta_iffD: meta_eq_to_iff resolved into iffD; this is the first rule that
  split_tac applies to the subgoal.*)
val meta_iffD = Data.meta_eq_to_iff RS Data.iffD;

(*lift: from the premise "!!x. Q(x) == R(x)" conclude
  "P(%x. Q(x)) == P(%x. R(x))".
  The statement is pure meta-logic and both propositions are parsed in theory
  Pure, so the lemma is proved in Pure as well, rather than in whatever theory
  happens to be current when this file is loaded.*)
val lift = Goal.prove_global \<^theory>‹Pure› ["P", "Q", "R"]
  [Syntax.read_prop_global \<^theory>‹Pure› "!!x :: 'b. Q(x) == R(x) :: 'c"]
  (Syntax.read_prop_global \<^theory>‹Pure› "P(%x. Q(x)) == P(%x. R(x))")
  (fn {context = ctxt, prems} =>
    rewrite_goals_tac ctxt prems THEN resolve_tac ctxt [reflexive_thm] 1)

(*abs_lift: the abstraction passed to P on the left-hand side of the
  conclusion of lift; inst_lift hands it to Thm.rename_boundvars.*)
val _ $ _ $ (_ $ (_ $ abs_lift) $ _) = Thm.prop_of lift;

(*trlift: lift chained with transitivity of meta-equality.*)
val trlift = lift RS transitive_thm;
(*mk_cntxt t pos T: abstract over the subterm of t addressed by pos.
  pos is a list of argument indices (0-based, relative to the arguments
  returned by strip_comb) with the innermost step first; it is reversed before
  the descent.  Returns (Abs ("", T, t'), u), where u is the addressed subterm
  and t' is t with u replaced by Bound 0 and the heads and sibling arguments
  along the path shifted by incr_boundvars 1.
  Raises Bind if a step of the path has no corresponding argument.*)
fun mk_cntxt t pos T =
  let
    val shift = incr_boundvars 1;
    fun descend [] tm = (Bound 0, tm)
      | descend (k :: path) tm =
          let val (head, args) = strip_comb tm in
            (case chop k args of
              (left, arg :: right) =>
                let val (hole, found) = descend path arg in
                  (list_comb (shift head, map shift left @ hole :: map shift right), found)
                end
            | _ => raise Bind)
          end;
    val (body, found) = descend (rev pos) t;
  in (Abs ("", T, body), found) end;
(*mk_cntxt_splitthm t tt T: abstract t over the occurrences of tt.
  A subterm that Envir.aeconv identifies with tt (shifted by the number of
  abstractions passed so far) becomes the new bound variable; other loose
  bound variables at or above the current depth are incremented by one to
  make room for the added binder of type T.*)
fun mk_cntxt_splitthm t tt T =
  let
    fun abstract_at depth tm =
      if Envir.aeconv (incr_boundvars depth tt, tm) then Bound depth
      else
        (case tm of
          Abs (x, U, body) => Abs (x, U, abstract_at (depth + 1) body)
        | Bound k => if k < depth then tm else Bound (k + 1)
        | f $ arg => abstract_at depth f $ abstract_at depth arg
        | _ => tm);
  in Abs ("", T, abstract_at 0 t) end;
(*Curried wrapper around add_loose_bnos, starting at level 0, in the argument
  order expected by fold: adds the loose bound numbers of tm to acc.*)
fun add_lbnos tm acc = add_loose_bnos (tm, 0, acc);
(*type_test (T, lbnos, apsns): look up the entry of apsns at the smallest
  index in lbnos and check that its second component equals T.
  lbnos must be non-empty (foldl1).*)
fun type_test (T, lbnos, apsns) =
  let
    val least = foldl1 Int.min lbnos;
    val (_, U: typ, _) = nth apsns least;
  in T = U end;
(*Assemble the split candidate (at most one) for an occurrence of a split
  constant.
    thm    split rule            T    type stored with the rule
    T'     type of the whole term being searched
    n      number of arguments the rule expects
    ts     actual arguments      apsns  enclosing abstractions (innermost first)
    pos    path to the occurrence
    TB, t  type and term of the matched application
  Nothing is returned when fewer than n arguments are present.  If the first
  n arguments mention none of the enclosing bound variables, the candidate
  needs no lifting and is accepted iff T = T'; otherwise it is accepted iff
  type_test succeeds, and carries the abstractions outermost first.  The term
  in the result is t with its loose bounds lowered by the number of
  enclosing abstractions.*)
fun mk_split_pack (thm, T: typ, T', n, ts, apsns, pos, TB, t) =
  if length ts < n then []
  else
    let
      val depth = length apsns;
      val inner_bnos =
        filter (fn i => i < depth) (fold add_lbnos (take n ts) []);
      val lowered = incr_boundvars (~ depth) t;
      fun pack lifts = [(thm, lifts, pos, TB, lowered)];
    in
      if null inner_bnos then (if T = T' then pack [] else [])
      else if type_test (T, inner_bnos, apsns) then pack (rev apsns)
      else []
    end;
local
  exception MATCH
in
  (*Extend the type environment tyenv by matching the pair of types TU with
    Sign.typ_match; a failing type match is reported as the local MATCH.*)
  fun typ_match thy (tyenv, TU) =
    Sign.typ_match thy TU tyenv handle Type.TYPE_MATCH => raise MATCH;

  (*fomatch thy (Ts, pat, obj): structural matching of pat against obj.
    A Var in pat matches any term whose type (computed relative to the bound
    variable types Ts) matches its own; Free and Const must agree on the name
    and match on the type; Bound indices must coincide; abstractions and
    applications are traversed componentwise.  Only type instantiations are
    accumulated -- term instantiations of Vars are not recorded or compared.
    Returns true iff the traversal succeeds.*)
  fun fomatch thy args =
    let
      fun if_same_name a b tyenv TU =
        if a = b then typ_match thy (tyenv, TU) else raise MATCH;
      fun mtch tyenv (Ts, Var (_, T), t) =
            typ_match thy (tyenv, (T, fastype_of1 (Ts, t)))
        | mtch tyenv (_, Free (a, T), Free (b, U)) = if_same_name a b tyenv (T, U)
        | mtch tyenv (_, Const (a, T), Const (b, U)) = if_same_name a b tyenv (T, U)
        | mtch tyenv (_, Bound i, Bound j) = if i = j then tyenv else raise MATCH
        | mtch tyenv (Ts, Abs (_, T, t), Abs (_, U, u)) =
            mtch (typ_match thy (tyenv, (T, U))) (U :: Ts, t, u)
        | mtch tyenv (Ts, f $ t, g $ u) =
            mtch (mtch tyenv (Ts, f, g)) (Ts, t, u)
        | mtch _ _ = raise MATCH;
      val matches = (mtch Vartab.empty args; true) handle MATCH => false;
    in matches end;
end;
(*Collect all split candidates inside term t.
    cmap  table from cmap_of_split_thms
    Ts    types of the bound variables t may refer to
  The term is traversed recursively.  Under an abstraction the path is
  extended by 0 and the abstraction is recorded as (bound type, body type,
  path to the abstraction).  At an application headed by a constant c, the
  entries registered for c are tried in order; the first entry whose constant
  type has the actual type as an instance and whose pattern first-order
  matches the constant applied to its first n arguments decides the result
  for this node (via mk_split_pack).  Candidates found in arguments are
  prepended to those of the node, later arguments first.*)
fun split_posns (cmap : (string * (typ * term * thm * typ * int) list) list) thy Ts t =
  let
    val T' = fastype_of1 (Ts, t);
    fun posns Ts pos apsns (Abs (_, T, body)) =
          let val U = fastype_of1 (T :: Ts, body)
          in posns (T :: Ts) (0 :: pos) ((T, U, pos) :: apsns) body end
      | posns Ts pos apsns tm =
          let
            val (head, args) = strip_comb tm;
            fun candidates c = these (AList.lookup (op =) cmap c);
            (*first applicable entry wins, even if it yields no pack*)
            fun first_match _ [] = []
              | first_match cT ((gcT, pat, thm, T, n) :: rest) =
                  let val redex = list_comb (head, take n args) in
                    if Sign.typ_instance thy (cT, gcT) andalso fomatch thy (Ts, pat, redex)
                    then
                      mk_split_pack
                        (thm, T, T', n, args, apsns, pos, type_of1 (Ts, redex), redex)
                    else first_match cT rest
                  end;
            val here =
              (case head of
                Const (c, cT) => first_match cT (candidates c)
              | _ => []);
            (*visit the arguments left to right, numbering them from 0*)
            fun below _ [] acc = acc
              | below k (arg :: rest) acc =
                  below (k + 1) rest (posns Ts (k :: pos) apsns arg @ acc);
          in below 0 args here end
  in posns Ts [] [] t end;
(*Ordering on split candidates: first by the number of abstractions that
  must be lifted over (ascending), then by the length of the position path,
  compared with the `order` parameter of mk_case_split_tac.*)
fun shorter ((_, ps, pos, _, _), (_, qs, qos, _, _)) =
  (case int_ord (length ps, length qs) of
    EQUAL => order (length pos, length qos)
  | ord => ord);
(*Inspect subgoal i of state.  Returns the parameter types of the subgoal
  (innermost first), the first argument t of the binary operator heading its
  conclusion, and the split candidates of t sorted by `shorter`.
  Raises Bind if the conclusion is not a binary application.*)
fun select thy cmap state i =
  let
    val subgoal = Thm.term_of (Thm.cprem_of state i);
    val param_Ts = rev (map snd (Logic.strip_params subgoal));
    val lhs =
      (case Logic.strip_assums_concl subgoal of
        _ $ l $ _ => l
      | _ => raise Bind);
    val packs = sort shorter (split_posns cmap thy param_Ts lhs);
  in (param_Ts, lhs, packs) end;
(*Externally visible variant of split_posns: the candidates are returned
  sorted by `shorter`.*)
fun exported_split_posns cmap thy Ts t =
  split_posns cmap thy Ts t |> sort shorter;
(*Instantiate the lifting rule trlift for subgoal i of state.
    Ts           parameter types of the subgoal
    t            term in which to lift
    (T, U, pos)  abstraction to lift over: bound type, body type, path
  The rule is lifted over the subgoal (with its bound variable renamed after
  the abstraction found at pos) and the variable heading the left-hand side
  of its conclusion is instantiated with the context built by mk_cntxt,
  abstracted over the parameters.*)
fun inst_lift ctxt Ts t (T, U, pos) state i =
  let
    val (cntxt, u) = mk_cntxt t pos (T --> U);
    val goal = Thm.cprem_of state i;
    val trlift' = Thm.lift_rule goal (Thm.rename_boundvars abs_lift u trlift);
    val lhs =
      Thm.prop_of trlift'
      |> Logic.strip_assums_concl
      |> Logic.dest_equals
      |> fst;
    val P =
      (case strip_comb lhs of
        (Var (xi, _), _) => xi
      | _ => raise Bind);
  in infer_instantiate ctxt [(P, Thm.cterm_of ctxt (abss Ts cntxt))] trlift' end;
(*Instantiate split rule thm for subgoal i of state.
    Ts   parameter types of the subgoal
    t    term to be split           tt  the occurrence to split on
    TB   type of that occurrence
  The rule is lifted over the subgoal and the variable heading the left-hand
  side of its conclusion is instantiated with t abstracted over tt
  (mk_cntxt_splitthm), abstracted over the parameters.*)
fun inst_split ctxt Ts t tt thm TB state i =
  let
    val thm' = Thm.lift_rule (Thm.cprem_of state i) thm;
    val lhs =
      Thm.prop_of thm'
      |> Logic.strip_assums_concl
      |> Logic.dest_equals
      |> fst;
    val P =
      (case strip_comb lhs of
        (Var (xi, _), _) => xi
      | _ => raise Bind);
    val cntxt = mk_cntxt_splitthm t tt TB;
  in infer_instantiate ctxt [(P, Thm.cterm_of ctxt (abss Ts cntxt))] thm' end;
(*split_tac ctxt splits i: split the conclusion of subgoal i with the best
  applicable rule from splits.
  Fails without rules, if subgoal i does not exist, or if no candidate is
  found.  After applying meta_iffD, the best candidate (according to `select`)
  is processed: while it still sits below abstractions, the outermost pending
  one is lifted over with trlift, the auxiliary subgoal i+1 is closed by
  reflexivity, and the selection is repeated; once no abstraction is left the
  instantiated split rule is composed with the subgoal.
  Malformed rules raise an error as soon as the tactic is built.*)
fun split_tac _ [] _ = no_tac
  | split_tac ctxt splits i =
      let
        val thy = Proof_Context.theory_of ctxt;
        val cmap = cmap_of_split_thms splits;
        fun lift_tac Ts t p st =
          compose_tac ctxt (false, inst_lift ctxt Ts t p st i, 2) i st;
        fun apply_split Ts t (thm, _, _, TB, tt) st =
          compose_tac ctxt (false, inst_split ctxt Ts t tt thm TB st i, 0) i st;
        fun lift_split_tac st =
          (case select thy cmap st i of
            (_, _, []) => no_tac st
          | (Ts, t, (pack as (_, [], _, _, _)) :: _) => apply_split Ts t pack st
          | (Ts, t, (_, p :: _, _, _, _) :: _) =>
              EVERY
               [lift_tac Ts t p,
                resolve_tac ctxt [reflexive_thm] (i + 1),
                lift_split_tac] st);
      in
        COND (has_fewer_prems i) no_tac
          (resolve_tac ctxt [meta_iffD] i THEN lift_split_tac)
      end;
in (split_tac, exported_split_posns) end;  
(*The two instances of the conclusion splitter.  split_tac compares position
  paths of candidates with int_ord on their lengths, split_inside_tac with the
  reversed order; the candidate finder belonging to the latter is discarded.*)
val (split_tac, split_posns) = mk_case_split_tac int_ord;
val (split_inside_tac, _) = mk_case_split_tac (rev_order o int_ord);
(*split_asm_tac ctxt splits i: split inside a premise of subgoal i.
  Fails without rules or if no premise mentions the head constant of any of
  the given split rules.  Otherwise the first such premise is selected, turned
  into the conclusion with contrapos2, split with split_tac, turned back into
  a premise with contrapos, and finally flattened by flat_prems_tac.  The
  whole sequence is wrapped in DETERM.*)
fun split_asm_tac _ [] = K no_tac
  | split_asm_tac ctxt splits =
  (*names of the constants the given rules split on*)
  let val cname_list = map (fst o fst o split_thm_info) splits;
      fun tac (t,i) =
          (*index of the first premise containing one of these constants, ~1 if none*)
          let val n = find_index (exists_Const (member (op =) cname_list o #1))
                                 (Logic.strip_assums_hyp t);
              (*does the first premise, possibly below meta-quantifiers, consist of
                the judgment applied to a disjunction?*)
              fun first_prem_is_disj (Const (\<^const_name>‹Pure.imp›, _) $ (Const (c, _)
                    $ (Const (s, _) $ _ $ _ )) $ _ ) = c = const_Trueprop andalso s = const_or
              |   first_prem_is_disj (Const(\<^const_name>‹Pure.all›,_)$Abs(_,_,t)) =
                                        first_prem_is_disj t
              |   first_prem_is_disj _ = false;
              (*flatten the premises of subgoal i: a leading disjunction is eliminated
                with disjE (rotating the premises of both new subgoals and continuing
                on the second one), then conjunctions/existentials are eliminated and
                double negations removed on subgoal i*)
              fun flat_prems_tac i = SUBGOAL (fn (t,i) =>
                           (if first_prem_is_disj t
                            then EVERY[eresolve_tac ctxt [Data.disjE] i, rotate_tac ~1 i,
                                       rotate_tac ~1  (i+1),
                                       flat_prems_tac (i+1)]
                            else all_tac)
                           THEN REPEAT (eresolve_tac ctxt [Data.conjE,Data.exE] i)
                           THEN REPEAT (dresolve_tac ctxt [Data.notnotD]   i)) i;
          (*the rotations presumably bring the premise to work on to the front -- confirm*)
          in if n<0 then  no_tac  else (DETERM (EVERY'
                [rotate_tac n, eresolve_tac ctxt [Data.contrapos2],
                 split_tac ctxt splits,
                 rotate_tac ~1, eresolve_tac ctxt [Data.contrapos], rotate_tac ~1,
                 flat_prems_tac] i))
          end;
  in SUBGOAL tac
  end;
(*Try the given split rules one at a time, in order, on a subgoal.  For each
  rule the flag computed by split_thm_info decides between split_asm_tac and
  split_tac; the first rule whose tactic succeeds wins.*)
fun gen_split_tac _ [] = K no_tac
  | gen_split_tac ctxt (split :: rest) =
      let
        val single_tac =
          if snd (split_thm_info split) then split_asm_tac else split_tac;
      in single_tac ctxt [split] ORELSE' gen_split_tac ctxt rest end;
(*Compact rendering of a type for use in looper names: a type constructor is
  printed as its name, preceded by its arguments in parentheses when there
  are any; everything else is printed as "_".*)
fun string_of_typ T =
  if not (is_Type T) then "_"
  else
    let
      val args = dest_Type_args T;
      val prefix =
        if null args then ""
        else enclose "(" ")" (commas (map string_of_typ args));
    in prefix ^ dest_Type_name T end;
(*Name under which the looper for a split rule is registered with the
  simplifier: built from the constant name, its type, and whether the rule is
  treated as a premise split.*)
fun split_name (name, T) asm =
  String.concat
    ["split ", (if asm then "asm " else ""), name, " :: ", string_of_typ T];
(*Register a split rule as a simplifier looper.
  The rule is stored without its context and re-attached to the context the
  looper is run in.  Rules flagged by split_thm_info use split_asm_tac;
  otherwise split_tac is used, and with bang = true Data.safe_tac is
  additionally tried on every resulting subgoal.*)
fun gen_add_split bang split ctxt =
  let
    val (name, asm) = split_thm_info split;
    val split0 = Thm.trim_context split;
    fun tac ctxt' =
      let
        val split' = Thm.transfer' ctxt' split0;
        val plain_tac = split_tac ctxt' [split'];
      in
        if asm then split_asm_tac ctxt' [split']
        else if bang then
          plain_tac THEN_ALL_NEW (TRY o SELECT_GOAL (Data.safe_tac ctxt'))
        else plain_tac
      end;
  in Simplifier.addloop (ctxt, (split_name name asm, tac)) end;
(*Register a split rule as a simplifier looper; the "bang" variant also tries
  Data.safe_tac on the subgoals produced by the split.*)
fun add_split split ctxt = gen_add_split false split ctxt;
fun add_split_bang split ctxt = gen_add_split true split ctxt;
(*Remove the simplifier looper registered for a split rule; the looper is
  identified by the name that split_name derives from the rule.*)
fun del_split split ctxt =
  Simplifier.delloop (ctxt, uncurry split_name (split_thm_info split));
(*Keyword shared by the method modifiers, and the attributes for adding
  (plain or "bang") and deleting split rules.*)
val splitN = "split";
val split_add = Simplifier.attrib o gen_add_split;
val split_del = Simplifier.attrib del_split;
(*Argument parser of the "split" attribute: a bang selects the "bang" variant
  of adding, a deletion marker selects removal, and no argument means plain
  adding.*)
val add_del = Scan.lift
  (Args.bang >> K (split_add true)
   || Args.del >> K split_del
   || Scan.succeed (split_add false));
(*Install the attribute under the name "split".*)
val _ = Theory.setup
  (Attrib.setup \<^binding>‹split› add_del "declare case split rule");
(*Method modifiers for split rules: the keyword "split" followed by a colon
  adds rules, followed by a bang-colon adds them in "bang" mode, and followed
  by the deletion marker and a colon removes them.  Each modifier carries its
  own source-position token, so the three entries are deliberately written out.*)
val split_modifiers =
 [Args.$$$ splitN -- Args.colon >> K (Method.modifier (split_add false) ⌂),
  Args.$$$ splitN -- Args.bang_colon >> K (Method.modifier (split_add true) ⌂),
  Args.$$$ splitN -- Args.del -- Args.colon >> K (Method.modifier split_del ⌂)];
(*Install the proof method "split": it takes a list of theorems and applies
  gen_split_tac with them to the first subgoal, failing unless the
  proposition of the goal state changes (CHANGED_PROP).*)
val _ =
  Theory.setup
    (Method.setup \<^binding>‹split›
      (Attrib.thms >> (fn ths => fn ctxt => SIMPLE_METHOD' (CHANGED_PROP o gen_split_tac ctxt ths)))
      "apply case split rule");
end;