File ‹nat_sub.ML›
(* Interface of the NatSub structure: normalization and expansion of
   expressions of type nat built from +, - and *, and the two proof
   steps defined from them. *)
signature NAT_SUB =
sig
(* Head function symbols (+, -, * on nat) that this structure works with. *)
val fheads: term list
(* Conversions on well-formed terms for sums, differences and products. *)
val norm_plus1: wfconv
val norm_plus: wfconv
val norm_minus': wfconv
val move_outmost: term -> wfconv
val cancel_terms: wfconv
val norm_minus: wfconv
(* A monomial: list of factors (as cterms) paired with an integer
   coefficient. A polynomial is a list of monomials. *)
type monomial = cterm list * int
val reduce_monomial_list: monomial list -> monomial list
val add_polynomial_list: monomial list * monomial list -> monomial list
(* Polynomial representation of a cterm, computed without proof. *)
val norm_minus_ct: cterm -> monomial list
(* Term rebuilt from the polynomial representation of a cterm. *)
val norm_ring_term: cterm -> term
(* Expansion of a term into equivalent forms, each result carrying the
   box id it holds in, a well-formed term and an equality theorem. *)
val get_sub_head_equiv:
Proof.context -> box_id * cterm -> (box_id * (wfterm * thm)) list
val nat_sub_expand_once:
Proof.context -> box_id * wfterm -> (box_id * (wfterm * thm)) list
val nat_sub_expand:
Proof.context -> box_id * cterm -> (box_id * (wfterm * thm)) list
(* Proof steps and the function registering them in a theory. *)
val nat_sub_expand_equiv: proofstep
val nat_sub_expand_unit: proofstep
val add_nat_sub_proofsteps: theory -> theory
end;
structure NatSub : NAT_SUB =
struct
(* Local shorthand for the well-formed term module of the Auto2 setup. *)
structure WfTerm = Auto2_Setup.WfTerm
(* The three binary operations on nat handled here: +, - and *. *)
val fheads = [@{term "(+) :: (nat => nat => nat)"},
@{term "(-) :: (nat => nat => nat)"},
@{term "(*) :: (nat => nat => nat)"}]
(* AC information for + on nat; used by cancel_terms to split a sum into
   its summands. *)
val plus_info = Nat_Util.plus_ac_on_typ @{theory} natT
(* Local aliases for the term predicates and numeral destructor of
   UtilArith. *)
val is_numc = UtilArith.is_numc
val dest_numc = UtilArith.dest_numc
val is_plus = UtilArith.is_plus
val is_minus = UtilArith.is_minus
val is_times = UtilArith.is_times
val is_zero = UtilArith.is_zero
val is_one = UtilArith.is_one
(* Rewriting with an object-level equation, specialized to fheads. *)
val wf_rewr_obj_eq = WfTerm.rewr_obj_eq fheads
(* Alias for Nat_Util.nat0 (presumably the term 0::nat -- confirm in
   Nat_Util). *)
val nat0 = Nat_Util.nat0
(* Unit and zero laws for * and +, kept as theorems so that callers can
   pass them to wf_rewr_obj_eq (directly or after obj_sym). *)
val mult_1 = @{thm mult_1}
val mult_1_right = @{thm mult_1_right}
val mult_0 = @{thm mult_0}
val mult_0_right = @{thm mult_0_right}
val add_0 = @{thm add_0}
val add_0_right = @{thm add_0_right}
(* Associativity (in both directions) and commutativity of +, as
   conversions on well-formed terms. *)
val add_assoc = wf_rewr_obj_eq @{thm add_ac(1)}
val add_assoc_sym = wf_rewr_obj_eq (obj_sym @{thm add_ac(1)})
val add_comm = wf_rewr_obj_eq @{thm add_ac(2)}
(* Reassociate, commute the argument, then reassociate back. With the
   usual reading of add_ac this turns (a + b) + c into (a + c) + b. *)
val swap_add = WfTerm.every_conv [add_assoc, WfTerm.arg_conv add_comm,
add_assoc_sym]
(* The same three conversions for *. *)
val times_assoc = wf_rewr_obj_eq @{thm mult_ac(1)}
val times_assoc_sym = wf_rewr_obj_eq (obj_sym @{thm mult_ac(1)})
val times_comm = wf_rewr_obj_eq @{thm mult_ac(2)}
(* Multiplicative counterpart of swap_add. *)
val swap_times = WfTerm.every_conv [times_assoc, WfTerm.arg_conv times_comm,
times_assoc_sym]
(* Distributivity theorems; distrib_r_th is applied when the left factor
   is a sum, distrib_l_th when the right factor is a sum (see
   norm_mult_poly_monomial and norm_mult_polynomials). *)
val distrib_r_th = @{thm nat_distrib(1)}
val distrib_l_th = @{thm nat_distrib(2)}
(* Nat_Util.nat_fold_conv lifted to well-formed terms; used wherever two
   numerals have to be combined. *)
val nat_fold_wfconv = WfTerm.conv_of Nat_Util.nat_fold_conv
(* Rewrite with add_0_left if it applies, otherwise leave the term alone. *)
val cancel_0_left = WfTerm.try_conv (wf_rewr_obj_eq @{thm add_0_left})
(* Rewrite with add_0_left in the reverse direction. *)
val append_0_left = wf_rewr_obj_eq (obj_sym @{thm add_0_left})
(* Ordering on atoms (factors of a monomial). Any two numerals compare as
   EQUAL, a numeral is GREATER than every non-numeral, and two
   non-numerals are compared with Term_Ord.term_ord. *)
fun compare_atom (t1, t2) =
    case (is_numc t1, is_numc t2) of
        (true, true) => EQUAL
      | (true, false) => GREATER
      | (false, true) => LESS
      | (false, false) => Term_Ord.term_ord (t1, t2)
(* Normalize a product lhs * rhs in which rhs is a single atom.

   - A factor 1 or 0 on either side is removed with the corresponding
     mult_1 / mult_1_right / mult_0 / mult_0_right rewrite.
   - If lhs is itself a product, rhs is compared (compare_atom) with the
     last factor of lhs: when rhs is smaller the two are swapped and the
     new left part is normalized recursively; when both are numerals
     they are folded into one numeral.
   - Otherwise lhs and rhs are two atoms: they are swapped if out of
     order, folded if both are numerals, and left alone in the remaining
     cases. *)
fun norm_mult_atom wft =
    let
      val (lhs, rhs) = wft |> WfTerm.term_of |> Util.dest_binop_args
      val cv =
          if is_one lhs then wf_rewr_obj_eq mult_1
          else if is_one rhs then wf_rewr_obj_eq mult_1_right
          else if is_zero lhs then wf_rewr_obj_eq mult_0
          else if is_zero rhs then wf_rewr_obj_eq mult_0_right
          else if is_times lhs then
            let
              (* Last factor of the product on the left. *)
              val tail_atom = dest_arg lhs
            in
              case compare_atom (tail_atom, rhs) of
                  LESS => WfTerm.all_conv
                | EQUAL =>
                  if is_numc tail_atom andalso is_numc rhs then
                    WfTerm.every_conv [times_assoc,
                                       WfTerm.arg_conv nat_fold_wfconv]
                  else
                    WfTerm.all_conv
                | GREATER =>
                  WfTerm.every_conv [swap_times,
                                     WfTerm.arg1_conv norm_mult_atom]
            end
          else
            case compare_atom (lhs, rhs) of
                LESS => WfTerm.all_conv
              | EQUAL =>
                if is_numc lhs andalso is_numc rhs then nat_fold_wfconv
                else WfTerm.all_conv
              | GREATER => times_comm
    in
      cv wft
    end
(* Normalize a product whose right argument may itself be a product.
   When it is, the term is reassociated to the left, the new left
   argument is normalized recursively, and the remaining right atom is
   inserted with norm_mult_atom. Otherwise norm_mult_atom is applied
   directly. *)
fun norm_mult_monomial wft =
    let
      val rhs = wft |> WfTerm.term_of |> Util.dest_binop_args |> snd
      val cv =
          if is_times rhs then
            WfTerm.every_conv [times_assoc_sym,
                               WfTerm.arg1_conv norm_mult_monomial,
                               norm_mult_atom]
          else
            norm_mult_atom
    in
      cv wft
    end
(* Split a monomial term into (body, coefficient):
   - body * c with c a numeral gives (body, c);
   - a bare numeral c gives (1, c);
   - any other term t gives (t, 1). *)
fun dest_monomial t =
    let
      val has_coeff = is_times t andalso is_numc (dest_arg t)
    in
      if has_coeff then (dest_arg1 t, dest_numc (dest_arg t))
      else if is_numc t then (Nat_Util.mk_nat 1, dest_numc t)
      else (t, 1)
    end
(* Bring a monomial into the explicit form body * coefficient, following
   the same case split as dest_monomial: a term that already ends in a
   numeral factor is unchanged, a bare numeral is rewritten with mult_1
   reversed, and any other term with mult_1_right reversed. *)
fun norm_monomial wft =
    let
      val t = WfTerm.term_of wft
      val cv =
          if is_times t andalso is_numc (dest_arg t) then WfTerm.all_conv
          else if is_numc t then wf_rewr_obj_eq (obj_sym mult_1)
          else wf_rewr_obj_eq (obj_sym mult_1_right)
    in
      cv wft
    end
(* Ordering on monomial terms that ignores coefficients. Only the bodies
   returned by dest_monomial are compared; a body equal to 1 (a constant
   monomial) is GREATER than any other body, and two such bodies are
   EQUAL. *)
fun compare_monomial (t1, t2) =
    let
      val body1 = fst (dest_monomial t1)
      val body2 = fst (dest_monomial t2)
    in
      case (is_one body1, is_one body2) of
          (true, true) => EQUAL
        | (true, false) => GREATER
        | (false, true) => LESS
        | (false, false) => Term_Ord.term_ord (body1, body2)
    end
(* Combine the two sides of a binary term whose sides are monomials with
   the same body: put both sides in body * coefficient form, factor the
   body out using distrib_l_th in reverse, and fold the resulting sum of
   coefficients. *)
fun combine_monomial wft =
    let
      val explicit_coeffs = WfTerm.binop_conv norm_monomial
      val factor_body = wf_rewr_obj_eq (obj_sym distrib_l_th)
      val fold_coeffs = WfTerm.arg_conv nat_fold_wfconv
    in
      wft |> WfTerm.every_conv [explicit_coeffs, factor_body, fold_coeffs]
    end
(* Normalize a sum a + b in which b is a single monomial.

   A zero on either side is removed. If a is itself a sum, b is compared
   (compare_monomial) with the last summand of a: when b is smaller the
   two are swapped and the new left part is normalized recursively; when
   the bodies agree the two summands are merged. If a is not a sum the
   same comparison is made between a and b directly. *)
fun norm_plus1 wft =
    let
      val (a, b) = wft |> WfTerm.term_of |> Util.dest_binop_args
      (* Conversion merging two summands with equal bodies: numerals are
         folded, other monomials are combined by adding coefficients. *)
      fun merge_cv (x, y) =
          if is_numc x andalso is_numc y then nat_fold_wfconv
          else combine_monomial
      val cv =
          if is_zero a then wf_rewr_obj_eq add_0
          else if is_zero b then wf_rewr_obj_eq add_0_right
          else if is_plus a then
            let
              val a_last = dest_arg a
            in
              case compare_monomial (a_last, b) of
                  LESS => WfTerm.all_conv
                | EQUAL =>
                  WfTerm.every_conv [add_assoc,
                                     WfTerm.arg_conv (merge_cv (a_last, b))]
                | GREATER =>
                  WfTerm.every_conv [swap_add, WfTerm.arg1_conv norm_plus1]
            end
          else
            case compare_monomial (a, b) of
                LESS => WfTerm.all_conv
              | EQUAL => merge_cv (a, b)
              | GREATER => add_comm
    in
      cv wft
    end
(* Normalize a sum whose right argument may itself be a sum. When it is,
   the term is reassociated to the left, the new left argument is
   normalized recursively, and the remaining right summand is inserted
   with norm_plus1. Otherwise norm_plus1 is applied directly. *)
fun norm_plus wft =
    let
      val rhs = wft |> WfTerm.term_of |> Util.dest_binop_args |> snd
      val cv =
          if is_plus rhs then
            WfTerm.every_conv [add_assoc_sym,
                               WfTerm.arg1_conv norm_plus,
                               norm_plus1]
          else
            norm_plus1
    in
      cv wft
    end
(* Rearrange a left-nested sum so that the summand u becomes its last
   (outermost right) argument. A term equal to u is returned unchanged.
   Raises Fail if u cannot be found along the left spine of the sum. *)
fun move_outmost u wft =
    let
      val t = WfTerm.term_of wft
      fun not_found () = raise Fail "move_outmost: u not found in wft."
    in
      if u aconv t then WfTerm.all_conv wft
      else if is_plus t then
        let
          val (a, b) = Util.dest_binop_args t
        in
          if u aconv b then
            (* u is already outermost. *)
            WfTerm.all_conv wft
          else if is_plus a then
            (* Bring u to the end of a, then swap it with b. *)
            WfTerm.every_conv [WfTerm.arg1_conv (move_outmost u),
                               swap_add] wft
          else if u aconv a then
            add_comm wft
          else
            not_found ()
        end
      else
        not_found ()
    end
(* Cancel summands that occur on both sides of a binary term whose two
   arguments are sums (used on differences by norm_minus').

   The summands of both arguments are listed with dest_ac; the first pair
   (t1, t2), one from each side, with the same monomial body is selected.
   Both are moved to the outermost position of their side, one of the
   nat_sub_combine theorems is applied depending on how the coefficients
   compare, leftover "0 +" prefixes are removed, and the process repeats
   until no such pair remains. *)
fun cancel_terms wft =
let
val (a, b) = wft |> WfTerm.term_of |> Util.dest_binop_args
val (ts1, ts2) = apply2 (Auto2_HOL_Extra_Setup.ACUtil.dest_ac plus_info) (a, b)
(* A pair is a candidate when the bodies agree; pairs involving the
   term nat0 are never selected. *)
fun find_same_arg (t1, t2) =
case compare_monomial (t1, t2) of
EQUAL => if t1 aconv nat0 orelse t2 aconv nat0 then NONE
else SOME (t1, t2)
| _ => NONE
(* Put t1 in outermost position on the left side. A side consisting of a
   single summand is first rewritten with add_0_left reversed, so that
   the side has the shape of a sum. *)
fun prepare_left t1 =
if length ts1 = 1 then WfTerm.arg1_conv append_0_left
else WfTerm.arg1_conv (move_outmost t1)
(* Same as prepare_left, for the right side. *)
fun prepare_right t2 =
if length ts2 = 1 then WfTerm.arg_conv append_0_left
else WfTerm.arg_conv (move_outmost t2)
(* Cancel the outermost summands t1 and t2, with coefficients n1, n2. *)
fun apply_th (t1, t2) wft =
let
val (_, n1) = dest_monomial t1
val (_, n2) = dest_monomial t2
(* Fold the numeral expression in the argument position, then drop a
   resulting factor 1 if there is one. *)
val fold_cv =
WfTerm.every_conv [
WfTerm.arg_conv nat_fold_wfconv,
WfTerm.try_conv (wf_rewr_obj_eq mult_1),
WfTerm.try_conv (wf_rewr_obj_eq mult_1_right)]
(* Theorem from nat_le_th for the smaller and the larger coefficient;
   it discharges the premise of nat_sub_combine2 / nat_sub_combine3. *)
val le_th = if n1 >= n2 then Nat_Util.nat_le_th n2 n1
else Nat_Util.nat_le_th n1 n2
in
if n1 = n2 then
(* Equal coefficients: a single rewrite with nat_sub_combine. *)
wf_rewr_obj_eq @{thm nat_sub_combine} wft
else if n1 > n2 then
(* Larger coefficient on the left: after the rewrite, the left side is
   the one that is folded and renormalized. *)
WfTerm.every_conv [
WfTerm.binop_conv (WfTerm.arg_conv norm_monomial),
wf_rewr_obj_eq (le_th RS @{thm nat_sub_combine2}),
WfTerm.arg1_conv (WfTerm.arg_conv fold_cv),
WfTerm.arg1_conv norm_plus] wft
else
(* Larger coefficient on the right: symmetric to the case above. *)
WfTerm.every_conv [
WfTerm.binop_conv (WfTerm.arg_conv norm_monomial),
wf_rewr_obj_eq (le_th RS @{thm nat_sub_combine3}),
WfTerm.arg_conv (WfTerm.arg_conv fold_cv),
WfTerm.arg_conv norm_plus] wft
end
in
case get_first find_same_arg (Util.all_pairs (ts1, ts2)) of
NONE => WfTerm.all_conv wft
| SOME (t1, t2) =>
(* Cancel one pair, clean up both sides, then look for the next pair. *)
WfTerm.every_conv [prepare_left t1, prepare_right t2,
apply_th (t1, t2),
WfTerm.arg1_conv cancel_0_left,
WfTerm.arg_conv cancel_0_left,
cancel_terms] wft
end
(* Normalize a product whose left argument may be a sum and whose right
   argument is a monomial. A sum on the left is distributed with
   distrib_r_th; the left part of the result is handled recursively, the
   right part with norm_mult_monomial, and the resulting sum is
   normalized with norm_plus. Without a sum on the left this is just
   norm_mult_monomial. *)
fun norm_mult_poly_monomial wft =
    let
      val lhs = wft |> WfTerm.term_of |> Util.dest_binop_args |> fst
      val cv =
          if is_plus lhs then
            WfTerm.every_conv [wf_rewr_obj_eq distrib_r_th,
                               WfTerm.arg1_conv norm_mult_poly_monomial,
                               WfTerm.arg_conv norm_mult_monomial,
                               norm_plus]
          else
            norm_mult_monomial
    in
      cv wft
    end
(* Normalize a product of two polynomials. A sum on the right is
   distributed with distrib_l_th; the left part of the result is handled
   recursively, the right part with norm_mult_poly_monomial, and the
   resulting sum is normalized with norm_plus. Without a sum on the right
   this is just norm_mult_poly_monomial. *)
fun norm_mult_polynomials wft =
    let
      val rhs = wft |> WfTerm.term_of |> Util.dest_binop_args |> snd
      val cv =
          if is_plus rhs then
            WfTerm.every_conv [wf_rewr_obj_eq distrib_l_th,
                               WfTerm.arg1_conv norm_mult_polynomials,
                               WfTerm.arg_conv norm_mult_poly_monomial,
                               norm_plus]
          else
            norm_mult_poly_monomial
    in
      cv wft
    end
(* Recursive normalization of a nat expression into a binary form whose
   two sides are normalized sums with common summands cancelled.

   For +, - and * both arguments are normalized first, then the matching
   theorem (nat_sub1, nat_sub2, nat_sub3) is used as a rewrite, both
   sides are normalized with norm_plus and cancel_terms is applied. For
   products the sub-products produced by nat_sub3 are additionally
   normalized with norm_mult_polynomials. Any other term is rewritten
   with nat_sub_norm. *)
fun norm_minus' wft =
    let
      val t = WfTerm.term_of wft
      (* Common pipeline; mid_steps run between the rewrite with th and
         the final normalization of both sides. *)
      fun via th mid_steps =
          WfTerm.every_conv (
            [WfTerm.binop_conv norm_minus', wf_rewr_obj_eq th]
            @ mid_steps
            @ [WfTerm.binop_conv norm_plus, cancel_terms])
      val cv =
          if is_plus t then via @{thm nat_sub1} []
          else if is_minus t then via @{thm nat_sub2} []
          else if is_times t then
            via @{thm nat_sub3}
                [WfTerm.binop_conv (WfTerm.binop_conv norm_mult_polynomials)]
          else
            wf_rewr_obj_eq @{thm nat_sub_norm}
    in
      cv wft
    end
(* Full normalization: norm_minus' followed by an optional rewrite with
   diff_zero (skipped when it does not apply). *)
fun norm_minus wft =
    let
      val drop_zero = WfTerm.try_conv (wf_rewr_obj_eq @{thm diff_zero})
    in
      wft |> WfTerm.every_conv [norm_minus', drop_zero]
    end
(* A monomial is a list of factors together with an integer coefficient;
   the empty factor list stands for a constant (see norm_minus_ct).
   Coefficients can be negative: negate_polynomial_list flips their sign. *)
type monomial = cterm list * int
(* Build the term t1 + t2 with + at type nat => nat => nat. *)
fun mk_plus (t1, t2) =
    list_comb (Const (@{const_name plus}, natT --> natT --> natT), [t1, t2])
(* Left-nested sum of a list of terms: [a, b, c] becomes (a + b) + c.
   The empty list gives the numeral 0 and a singleton gives its element. *)
fun list_plus ts =
    case ts of
        [] => Nat_Util.mk_nat 0
      | t :: rest => fold (fn u => fn acc => mk_plus (acc, u)) rest t
(* Build the term t1 * t2 with * at type nat => nat => nat. *)
fun mk_times (t1, t2) =
    list_comb (Const (@{const_name times}, natT --> natT --> natT), [t1, t2])
(* Left-nested product of a list of terms: [a, b, c] becomes (a * b) * c.
   The empty list gives the numeral 1 and a singleton gives its element. *)
fun list_times ts =
    case ts of
        [] => Nat_Util.mk_nat 1
      | t :: rest => fold (fn u => fn acc => mk_times (acc, u)) rest t
(* Ordering on monomials that ignores coefficients. A monomial with no
   factors (a constant) is GREATER than any monomial with factors, two
   constants are EQUAL, and otherwise the products of the factors are
   compared with Term_Ord.term_ord. *)
fun compare_monomial_list ((l1, _), (l2, _)) =
    case (null l1, null l2) of
        (true, true) => EQUAL
      | (true, false) => GREATER
      | (false, true) => LESS
      | (false, false) =>
        let
          val body_of = list_times o map Thm.term_of
        in
          Term_Ord.term_ord (body_of l1, body_of l2)
        end
(* Merge adjacent monomials with identical factor lists, working from the
   end of the list towards the front. Coefficients of merged monomials
   are added and a merged monomial with coefficient 0 is dropped. In
   addition, when the already reduced tail is a single monomial with
   coefficient 0, that monomial is discarded. The input is expected to be
   sorted so that equal factor lists are adjacent (callers sort with
   compare_monomial_list first). *)
fun reduce_monomial_list [] = []
  | reduce_monomial_list ((l1, c1) :: tail) =
    case reduce_monomial_list tail of
        [] => [(l1, c1)]
      | [(_, 0)] => [(l1, c1)]
      | (l2, c2) :: rest' =>
        if eq_list (op aconvc) (l1, l2) then
          let
            val total = c1 + c2
          in
            if total = 0 then rest' else (l1, total) :: rest'
          end
        else
          (l1, c1) :: (l2, c2) :: rest'
(* Product of two monomials: the factor lists are concatenated and sorted
   by compare_atom, and the coefficients are multiplied. *)
fun mult_monomial ((l1, c1), (l2, c2)) =
    let
      val atom_ord = compare_atom o apply2 Thm.term_of
    in
      (sort atom_ord (l1 @ l2), c1 * c2)
    end
(* Product of two polynomials: multiply every pair of monomials, sort the
   results and merge monomials with equal factor lists. *)
fun mult_polynomial_term (ls1, ls2) =
    let
      val products = map mult_monomial (Util.all_pairs (ls1, ls2))
    in
      reduce_monomial_list (sort compare_monomial_list products)
    end
(* Sum of two polynomials: concatenate, sort and merge monomials with
   equal factor lists. *)
fun add_polynomial_list (ls1, ls2) =
    reduce_monomial_list (sort compare_monomial_list (ls1 @ ls2))
(* Negate every coefficient of a polynomial. *)
fun negate_polynomial_list ls =
    map (apsnd (fn n => ~n)) ls
(* Polynomial representation of a cterm, computed without producing a
   proof. Sums, differences and products are handled recursively (a
   difference adds the negated right operand), a numeral becomes a
   constant monomial, and every other term is a monomial with itself as
   only factor and coefficient 1. *)
fun norm_minus_ct ct =
    let
      val t = Thm.term_of ct
      (* Polynomials of the two operands; only called for binary terms. *)
      fun operands () =
          (norm_minus_ct (Thm.dest_arg1 ct), norm_minus_ct (Thm.dest_arg ct))
    in
      if is_plus t then
        add_polynomial_list (operands ())
      else if is_minus t then
        operands () |> apsnd negate_polynomial_list |> add_polynomial_list
      else if is_times t then
        mult_polynomial_term (operands ())
      else if is_numc t then
        [([], dest_numc t)]
      else
        [([ct], 1)]
    end
(* All factors appearing in the polynomial representation of ct, as one
   flat list (in monomial order, with repetitions). *)
fun subterms_of ct =
    maps fst (norm_minus_ct ct)
(* Term for a monomial: the numeral c when there are no factors, the
   product of the factors when c = 1, and product * c otherwise. *)
fun to_monomial (l, c) =
    case l of
        [] => mk_nat c
      | _ =>
        let
          val body = list_times (map Thm.term_of l)
        in
          if c = 1 then body else mk_times (body, mk_nat c)
        end
(* Rebuild a term from the polynomial representation of ct. Monomials
   with non-negative coefficient form the sum P, the others (with their
   coefficients negated) form the sum M; the result is P - M, or just P
   when there are no negative monomials. *)
fun norm_ring_term ct =
    let
      val (ps, ms) =
          ct |> norm_minus_ct |> filter_split (fn (_, n) => n >= 0)
      val sum_of = list_plus o map to_monomial
      val pos_sum = sum_of ps
    in
      if null ms then pos_sum
      else
        Const (@{const_name minus}, natT --> natT --> natT) $
          pos_sum $ sum_of (negate_polynomial_list ms)
    end
(* Head equivalences of ct (from WellformData.get_head_equiv), each
   followed by simplification of the factors occurring in its right-hand
   side. Every result carries the merged box id and the composed
   equality; results whose equality is reflexive are dropped. *)
fun get_sub_head_equiv ctxt (id, ct) =
    let
      (* Simplify one head equivalence and chain the equalities. *)
      fun simplify_rhs (id', (wft, eq_th)) =
          let
            val cts = subterms_of (Thm.rhs_of eq_th)
          in
            (id', wft)
              |> Auto2_Setup.WellformData.simplify ctxt fheads cts
              |> map (fn (id'', (wft', eq_th')) =>
                         (BoxID.merge_boxes ctxt (id', id''),
                          (wft', Util.transitive_list [eq_th, eq_th'])))
          end
      val head_equivs =
          Auto2_Setup.WellformData.get_head_equiv ctxt fheads (id, ct)
    in
      filter_out (fn (_, (_, th)) => Thm.is_reflexive th)
                 (maps simplify_rhs head_equivs)
    end
(* One round of expansion: for every factor of the polynomial
   representation of wft, collect its equivalent forms with
   get_sub_head_equiv and rewrite wft with each of them separately. *)
fun nat_sub_expand_once ctxt (id, wft) =
    wft |> WfTerm.cterm_of
        |> subterms_of
        |> maps (fn cu => get_sub_head_equiv ctxt (id, cu))
        |> map (fn (id', wf_eq) =>
                   (id', WfTerm.rewrite_on_eqs fheads [wf_eq] wft))
(* Repeated expansion of ct. Returns a list of (box id, (well-formed
   term, equality theorem)) entries, each equality having ct's
   simplified forms as starting point. The total number of entries is
   limited by the max_ac configuration option. *)
fun nat_sub_expand ctxt (id, ct) =
let
val max_ac = Config.get ctxt Auto2_HOL_Extra_Setup.AC_ProofSteps.max_ac
(* The first entry is at least as good as the second when its factors
   form a subsequence of the second's factors and its box id is an
   ancestor of (or equal to) the second's. *)
fun ac_equiv_eq_better (id, (_, th)) (id', (_, th')) =
let
val seq1 = subterms_of (Thm.rhs_of th)
val seq2 = subterms_of (Thm.rhs_of th')
in
Util.is_subseq (op aconvc) (seq1, seq2) andalso
BoxID.is_eq_ancestor ctxt id id'
end
fun has_ac_equiv_eq_better infos info' =
exists (fn info => ac_equiv_eq_better info info') infos
(* Worklist loop: old holds the entries already expanded, new the
   entries still to be expanded. *)
fun helper (old, new) =
case new of
[] => old
| (id', (wft, eq_th)) :: rest =>
(* Stop once the cap is exceeded, keeping as many pending entries as
   still fit. *)
if length old + length new > max_ac then
old @ take (max_ac - length old) new
else let
val old' = ((id', (wft, eq_th)) :: old)
(* Prefix an expansion of wft with the equality leading to wft. *)
fun merge_info (id'', (wft', eq_th')) =
(BoxID.merge_boxes ctxt (id', id''),
(wft', Util.transitive_list [eq_th, eq_th']))
(* Expand the current entry once, reduce the results with
   Util.max_partial, and discard those for which an entry that is at
   least as good is already known or pending. *)
val rhs_expand =
(nat_sub_expand_once ctxt (id', wft))
|> Util.max_partial ac_equiv_eq_better
|> map merge_info
|> filter_out (has_ac_equiv_eq_better (old' @ rest))
in
helper (old', rest @ rhs_expand)
end
val ts = subterms_of ct
(* Starting entries: ct as well-formed term(s), with its factors
   simplified. *)
val start = (Auto2_Setup.WellformData.cterm_to_wfterm ctxt fheads (id, ct))
|> maps (Auto2_Setup.WellformData.simplify ctxt fheads ts)
in
helper ([], start)
end
(* True when t is a sum, difference or product whose last argument has
   type nat. *)
fun is_nat_sub_form t =
    (is_plus t orelse is_minus t orelse is_times t)
    andalso fastype_of (dest_arg t) = natT
(* Body of the two-item proof step: given two term items, try to show
   that the terms are equal by expanding both and comparing polynomial
   representations. Produces one update per derived equality. *)
fun nat_sub_expand_equiv_fn ctxt item1 item2 =
let
val {id = id1, tname = tname1, ...} = item1
val {id = id2, tname = tname2, ...} = item2
val (ct1, ct2) = (the_single tname1, the_single tname2)
val (t1, t2) = (Thm.term_of ct1, Thm.term_of ct2)
val id = BoxID.merge_boxes ctxt (id1, id2)
in
(* Both terms must be +, - or * expressions on nat. *)
if not (is_nat_sub_form t1) orelse not (is_nat_sub_form t2) then []
(* Skip the pair when t2 is smaller than t1 in the term order, so that
   the two orderings of a pair are not both processed. *)
else if Term_Ord.term_ord (t2, t1) = LESS then []
(* Nothing to do if the rewrite table already has them equivalent. *)
else if Auto2_Setup.RewriteTable.is_equiv id ctxt (ct1, ct2) then []
else let
val expand1 = nat_sub_expand ctxt (id, ct1)
val expand2 = nat_sub_expand ctxt (id, ct2)
(* For one expansion of each term: if the polynomial representations of
   the right-hand sides coincide, normalize both with norm_minus and
   chain the four equalities into t1 = t2. The names ct1, ct2, id1, id2
   below shadow the outer ones. *)
fun get_equiv ((id1, (wft1, eq_th1)), (id2, (wft2, eq_th2))) =
let
val ct1 = Thm.rhs_of eq_th1
val ct2 = Thm.rhs_of eq_th2
val ts1 = norm_minus_ct ct1
val ts2 = norm_minus_ct ct2
in
if eq_list (eq_pair (eq_list (op aconvc)) (op =)) (ts1, ts2) then
let
val (wft1', eq1) = norm_minus wft1
val (wft2', eq2) = norm_minus wft2
(* Sanity check: equal polynomial representations are expected to give
   identical normal forms. *)
val _ = assert (WfTerm.term_of wft1' aconv WfTerm.term_of wft2')
"nat_sub_expand_equiv_fn"
val id' = BoxID.merge_boxes ctxt (id1, id2)
val eq = Util.transitive_list [
eq_th1, eq1, meta_sym eq2, meta_sym eq_th2]
in
[(id', to_obj_eq eq)]
end
else []
end
in
(* Keep results that pass BoxID.has_incr_id, reduce them with
   Util.max_partial on box ids, and turn each into a theorem update. *)
(maps get_equiv (Util.all_pairs (expand1, expand2)))
|> filter (BoxID.has_incr_id o fst)
|> Util.max_partial (BoxID.id_is_eq_ancestor ctxt)
|> map (fn (id, th) => Auto2_Setup.Update.thm_update_sc 1 (id, th))
end
end
(* Proof step matching two term items of type nat and running
   nat_sub_expand_equiv_fn on them. *)
val nat_sub_expand_equiv =
{name = "nat_sub_expand_equiv",
args = [TypedMatch (TY_TERM, @{term_pat "?A::nat"}),
TypedMatch (TY_TERM, @{term_pat "?B::nat"})],
func = TwoStep nat_sub_expand_equiv_fn}
(* Body of the single-item proof step: expand the term of the item and,
   for every expansion whose polynomial representation is empty or
   consists of exactly one monomial with coefficient 1, produce an
   update with the equality between the original term and the
   norm_minus normal form of that expansion. *)
fun nat_sub_expand_unit_fn ctxt item =
    let
      val {id, tname, ...} = item
      val ct = the_single tname
      val t = Thm.term_of ct
      (* Empty polynomial, or a single monomial with coefficient 1. *)
      fun is_unit_form ts =
          case ts of
              [] => true
            | [(_, c)] => c = 1
            | _ => false
      fun unit_eqs (id', (wft, eq_th)) =
          if is_unit_form (norm_minus_ct (Thm.rhs_of eq_th)) then
            let
              val norm_eq = snd (norm_minus wft)
            in
              [(id', to_obj_eq (Util.transitive_list [eq_th, norm_eq]))]
            end
          else []
    in
      if is_nat_sub_form t then
        nat_sub_expand ctxt (id, ct)
          |> maps unit_eqs
          |> filter (BoxID.has_incr_id o fst)
          |> map (fn (id', th) =>
                     Auto2_Setup.Update.thm_update_sc 1 (id', th))
      else []
    end
(* Proof step matching one term item of type nat and running
   nat_sub_expand_unit_fn on it. *)
val nat_sub_expand_unit =
{name = "nat_sub_expand_unit",
args = [TypedMatch (TY_TERM, @{term_pat "?A::nat"})],
func = OneStep nat_sub_expand_unit_fn}
(* Theory transformation adding both proof steps of this structure via
   add_prfstep. *)
val add_nat_sub_proofsteps =
fold add_prfstep [
nat_sub_expand_equiv, nat_sub_expand_unit
]
end
(* Register the proof steps of NatSub when this file is loaded. *)
val _ = Theory.setup NatSub.add_nat_sub_proofsteps