File ‹UD.ML›
(*Interface of the "ud" machinery (unoverloading of constant-instance definitions).*)
signature UD =
sig
(*Result of unoverload_definition: "trivial" carries one theorem,
  "nontrivial" carries a pair of theorems.*)
datatype ud_thm_out_type = trivial of thm | nontrivial of thm * thm
(*For a constant instance (c, T): the (optional definition name, description)
  pairs taken from the specifications recorded in the given Defs.T.*)
val axioms_of_ci :
theory -> Defs.T -> string * typ -> (string option * string) list
(*For a constant instance (c, T): a list of theorems derived from the entries
  reported by axioms_of_ci.*)
val das_of_ci : theory -> Defs.T -> string * typ -> thm list
(*Main entry point: takes a binding with mixfix and a constant instance and
  returns the resulting theorem(s) together with the updated theory.*)
val unoverload_definition :
binding * mixfix -> string * typ -> theory -> ud_thm_out_type * theory
end;
structure UD : UD =
struct
(*Prefix a message with the tag "ud: " used by all messages of this module.*)
fun mk_msg_unoverload_definition msg = String.concat ["ud: ", msg];
(*Outcome of unoverload_definition; this declaration mirrors the one in
  signature UD. In this file, "trivial" is built from the stored "with"
  theorem, and "nontrivial" from the pair (definition theorem of the newly
  declared constant, "with" theorem).*)
datatype ud_thm_out_type = trivial of thm | nontrivial of thm * thm
local
(*Try to match the type list Ts against the type list Us.
  Returns SOME of the type substitution induced by the matching, or NONE when
  the quick pre-check Type.could_matches fails or the actual matching raises
  Type.TYPE_MATCH.*)
fun match_args (Ts, Us) =
  if not (Type.could_matches (Ts, Us)) then NONE
  else
    let
      (*the cheap pre-check may succeed although real matching fails*)
      val tyenv_opt =
        SOME (Type.raw_matches (Ts, Us) Vartab.empty)
          handle Type.TYPE_MATCH => NONE
    in
      (case tyenv_opt of
        NONE => NONE
      | SOME tyenv => SOME (Envir.subst_type tyenv))
    end;
in
(*Collect the definitional specifications of the constant instance (c, T).
  The instance is turned into a Defs entry (item plus type arguments); every
  specification of that item whose left-hand side type arguments match the
  arguments of the entry contributes a pair (optional definition name,
  description).*)
fun axioms_of_ci thy defs (c, T) =
  let
    val (item, args) = Theory.const_dep thy (c, T)
  in
    Defs.specifications_of defs item
    |> filter (fn {lhs, ...} => is_some (match_args (lhs, args)))
    |> map (fn {def, description, ...} => (def, description))
  end;
(*Definitional theorems of a constant instance.
  Takes the definition names reported by axioms_of_ci (entries without a name
  are dropped), looks each one up with Thm.axiom (names for which the lookup
  fails are dropped as well) and passes every remaining theorem through
  Drule.abs_def.*)
fun das_of_ci thy defs =
  axioms_of_ci thy defs
  #> map_filter #1
  #> map_filter (try (Thm.axiom thy))
  #> map Drule.abs_def;
end;
local
(*Error messages used by the unoverloading functions of this local block;
  each one carries the common "ud: " prefix.*)
val msg_no_cids =
  mk_msg_unoverload_definition "no suitable constant-instance definitions"
and msg_ud_ex =
  mk_msg_unoverload_definition "unoverloaded constant already exists"
and msg_multiple_cids =
  mk_msg_unoverload_definition "multiple constant-instance definitions"
and msg_extra_type_variables =
  mk_msg_unoverload_definition "specification depends on extra type variables";
(*Instantiate every schematic type variable of thm by the type obtained from
  applying map_sortsT (given the theory of ctxt) to that variable.
  The type variables are collected from the full proposition of thm; the
  result is the normalized instance of thm.*)
fun map_sorts ctxt map_sortsT thm =
  let
    val thy = Proof_Context.theory_of ctxt
    val tvars = Term.add_tvars (Thm.full_prop_of thm) []
    val resort = Term.map_atyps (map_sortsT thy)
    fun inst_of tvar = (tvar, Thm.ctyp_of ctxt (resort (TVar tvar)))
    val tvar_insts = map inst_of tvars
  in Drule.instantiate_normalize (TVars.make tvar_insts, Vars.empty) thm end;
(*Trivial case: store cid' itself as the theorem named b qualified by "with",
  tagged with the UD_With attribute, print it, and return it wrapped in
  "trivial" together with the updated theory.*)
fun trivial_ud thy b cid' =
  let
    val with_binding = Binding.qualify_name true b "with"
    val (with_thm, thy') =
      Global_Theory.add_thm ((with_binding, cid'), [UD_With.UDWithData.add]) thy
    val _ =
      Proof_Display.print_theorem (Position.thread_data ()) (Proof_Context.init_global thy')
        (Thm.derivation_name with_thm, [with_thm])
  in (trivial with_thm, thy') end;
(*Nontrivial case of unoverloading.
  Given the definitional theorem cid (a meta-equation) and the type T of the
  constant instance, this function
    1. selects constants occurring in the right-hand side of cid (rhs_consts)
       and replaces them by fresh fixed variables,
    2. declares a new constant named b qualified by "with" that takes these
       variables as arguments,
    3. adds the definition b qualified by "with_def" for the new constant,
    4. derives from cid and that definition a theorem that is stored under
       b qualified by "with" with the UD_With attribute.
  Returns nontrivial (with_def_thm, with_thm) and the updated theory.
  Raises ERROR with msg_extra_type_variables when adding the definition fails
  with a message containing "Extra type variables on rhs"; any other ERROR is
  re-raised with its original message.*)
fun nontrivial_ud thy defs (b, mixfix) cid T =
let
(*Chain cid with the reversed definition of the new constant: the rhs of
  with_def_thm is matched against thm_rhs_ct, with_def_thm is instantiated
  accordingly and flipped, and the result is composed with cid by
  transitivity.*)
fun mk_with_thm thm_rhs_ct with_def_thm cid =
let
val with_def_thm_rhs_ct = with_def_thm
|> Thm.cprop_of
|> Thm.dest_equals_rhs
val inst = Thm.match (with_def_thm_rhs_ct, thm_rhs_ct)
val const_with_def = with_def_thm
|> Drule.instantiate_normalize inst
|> Thm.symmetric
val with_thm = Thm.transitive cid const_with_def
in with_thm end;
val ctxt = Proof_Context.init_global thy
val cid' = Thm.unvarify_global thy cid
val cid_rhs_ct = cid' |> Thm.cprop_of |> Thm.dest_equals_rhs
(*Logic.unoverload_types_term and Logic.unvarify_types_local_term are not
  defined in this file; presumably they adjust the types of the rhs and fix
  its type variables in the returned context -- TODO confirm.*)
val (cid_rhs_t, ctxt') = cid_rhs_ct
|> Thm.term_of
|> Logic.unoverload_types_term thy
|> Logic.unvarify_types_local_term ctxt
(*Constants of the rhs that become arguments of the new constant.*)
val rhs_consts =
let
(*first components of what Sorts.params_of_sort reports for the sorts of the
  schematic type variables of T, without duplicates; they are compared with
  constant names below*)
val sort_const_cs = Term.add_tvarsT T []
|> map #2
|> map (Sorts.params_of_sort thy)
|> flat
|> distinct op=
|> map #1
val consts = Term.add_consts cid_rhs_t []
(*constants for which axioms_of_ci reports no specification at all*)
val consts_no_cids = consts
|> map (fn const => (const, axioms_of_ci thy defs const))
|> filter (fn (_, cid_opt) => null cid_opt)
|> map fst
(*constants whose type contains fixed type variables and whose name occurs
  in sort_const_cs*)
val consts_in_sort = consts
|> filter (fn (_, T) => Term.has_tfreesT T)
|> filter (fn (c, _) => member op= sort_const_cs c)
(*constants registered in UD_Consts; they are excluded from the result*)
val elim_consts = thy
|> UD_Consts.get_keys
|> map (UD_Consts.const_of_key thy #> the #> dest_Const)
val consts_out = consts_no_cids @ consts_in_sort
|> distinct op=
|> filter_out (member (swap #> Term.could_match_const) elim_consts)
in consts_out end
(*fresh fixed variable names derived from the base names of rhs_consts*)
val (cs, ctxt'') = ctxt'
|> Variable.variant_fixes (map (#1 #> Long_Name.base_name) rhs_consts)
(*Association from each selected constant to its replacement variable.
  mk_opt_id is defined elsewhere; presumably it returns the argument unchanged
  when the lookup yields NONE -- TODO confirm.*)
val fv_of_const =
(map Const rhs_consts ~~ map Free (cs ~~ map #2 rhs_consts))
|> AList.lookup op=
|> mk_opt_id I
val rhs_const_ts = map Const rhs_consts
val arg_ts = map fv_of_const rhs_const_ts
val cid_rhs_t' = map_aterms fv_of_const cid_rhs_t
(*Declare the new constant; its type maps the types of the replacement
  variables to the type of the rewritten rhs.*)
fun declare_const_with thy =
let
val T = map type_of arg_ts ---> type_of cid_rhs_t'
val b' = Binding.qualify_name true b "with"
in Sign.declare_const_global ((b', T), mixfix) thy end
val (lhst, thy') = declare_const_with thy
val lhst' = Term.list_comb (lhst, arg_ts)
(*Add the definition: new constant applied to the variables == rewritten rhs.*)
val (with_def_thm, thy'') =
let
val b' = Binding.qualify_name true b "with_def"
val def_t = Logic.mk_equals (lhst', cid_rhs_t')
in
Global_Theory.add_defs false [((b', def_t), [])] thy' |>> the_single
(*replace the system message about extra type variables by the message of
  this module; all other errors pass through unchanged*)
handle ERROR c =>
let val msg_match = "Extra type variables on rhs"
in
if String.isSubstring msg_match c
then error msg_extra_type_variables
else error c
end
end
val _ =
Proof_Display.print_theorem (Position.thread_data ()) (Proof_Context.init_global thy'')
(Thm.derivation_name with_def_thm, [with_def_thm])
(*Derive the "with" theorem, export it from ctxt'' (which holds the fixed
  variables) to ctxt, and store it with the UD_With attribute.*)
val (with_thm, thy''') =
let
val b' = Binding.qualify_name true b "with"
val with_thm = cid'
|> mk_with_thm cid_rhs_ct with_def_thm
|> singleton (Proof_Context.export ctxt'' ctxt)
val attr = single UD_With.UDWithData.add
in Global_Theory.add_thm ((b', with_thm), attr) thy'' end
val _ =
Proof_Display.print_theorem (Position.thread_data ()) (Proof_Context.init_global thy''')
(Thm.derivation_name with_thm, [with_thm])
in (nontrivial (with_def_thm, with_thm), thy''') end;
in
(*Unoverload the definition of the constant instance (c, T).
  b and mixfix name the results; thy is the theory to extend.
  Steps:
    - fail with msg_ud_ex when the head constant on the left-hand side of a
      theorem registered in UD_With could match the instance;
    - obtain the definitional theorems of the instance via das_of_ci and fail
      with msg_no_cids / msg_multiple_cids unless there is exactly one;
    - instantiate that theorem to the instance, adjust its sorts with
      map_sorts and unfold the registered UD_With theorems in it;
    - dispatch to trivial_ud when the resulting right-hand side is an
      application whose head is a constant that could match the head constant
      on the right-hand side of a registered theorem, and to nontrivial_ud
      otherwise.
  Returns the produced theorem(s) and the updated theory.*)
fun unoverload_definition (b, mixfix) (c, T) thy =
  let
    (*head constants of the side selected by lr of each theorem, taken
      underneath any leading abstractions*)
    fun get_ud_const lr =
      map (Thm.cprop_of #> lr #> Thm.term_of #> strip_abs_body #> head_of #> dest_Const)

    val ctxt = Proof_Context.init_global thy
    val defs = Theory.defs_of thy
    val ud_thms =
      map (Local_Defs.meta_rewrite_rule ctxt) (UD_With.UDWithData.get ctxt)
    val T' = Logic.varifyT_mixed_global (Type.default_sorts_of_empty_sorts thy T)

    (*refuse to process an instance that has been processed before*)
    val _ =
      if member (swap #> Term.could_match_const)
          (get_ud_const Thm.dest_equals_lhs ud_thms) (c, T')
      then error msg_ud_ex
      else ()

    (*exactly one constant-instance definition is required*)
    val cid =
      (case das_of_ci thy defs (c, T') of
        [] => error msg_no_cids
      | [single_cid] => single_cid
      | _ => error msg_multiple_cids)

    val cid' =
      let
        val lhs_head = head_of (Thm.term_of (Thm.dest_equals_lhs (Thm.cprop_of cid)))
        val insts =
          Thm.first_order_match
            (Thm.cterm_of ctxt lhs_head, Thm.cterm_of ctxt (Const (c, T')))
        val instantiated = Drule.instantiate_normalize insts cid
        val resorted = map_sorts ctxt Type.default_sorts_of_empty_sorts instantiated
      in Local_Defs.unfold ctxt ud_thms resorted end

    val cid_rhs_t = Thm.term_of (Thm.dest_equals_rhs (Thm.cprop_of cid'))

    val ud_rhs_consts = get_ud_const Thm.dest_equals_rhs ud_thms
    val trivial_flag =
      Term.is_comb cid_rhs_t andalso
        (case head_of cid_rhs_t of
          Const head_const =>
            member (swap #> Term.could_match_const) ud_rhs_consts head_const
        | _ => false)
  in
    if trivial_flag
    then trivial_ud thy b cid'
    else nontrivial_ud thy defs (b, mixfix) cid' T'
  end;
end;
(*Error message for input of the "ud" command that is not a constant.*)
val msg_not_constant =
  mk_msg_unoverload_definition "the input term is not a constant";

(*Theory transformation behind the "ud" command.
  Reads t as a term pattern in the global context of thy and requires it to be
  a constant (otherwise fails with msg_not_constant). When no binding is
  supplied, the base name of the constant is used. The theorems returned by
  unoverload_definition are discarded; only the updated theory is returned.*)
fun process_ud ((b_opt, t), mixfix) thy =
  let
    val ctxt = Proof_Context.init_global thy
    val (c, T) =
      (case Proof_Context.read_term_pattern ctxt t of
        Const const => const
      | _ => error msg_not_constant)
    val b = the_default (Binding.name (Long_Name.base_name c)) b_opt
  in snd (unoverload_definition (b, mixfix) (c, T) thy) end;
(*Syntax of the "ud" command: an optional binding, a constant and an optional
  mixfix annotation.*)
val parse_ud =
  Scan.option Parse.binding -- Parse.const -- Parse.opt_mixfix';

(*Register the "ud" command; the parsed arguments are handed to process_ud and
  the resulting theory transformation is run as a toplevel theory transition.*)
val _ =
  Outer_Syntax.command
    \<^command_keyword>‹ud›
    "unoverloading of constant-instance definitions"
    (parse_ud >> (fn args => Toplevel.theory (process_ud args)));
end;