File ‹rewrite_test.ML›
local
(* Proof context captured by the @{context} antiquotation; each test value
   in this file starts its pipeline from it. *)
val init = @{context}
(* Pass `id` to BoxID.add_prim_id and check that the primitive id it hands
   back equals the expected `prim_id` (via `assert`, tagged "add_prim_id").
   Returns the context produced by BoxID.add_prim_id. *)
fun add_prim_id (prim_id, id) ctxt =
  case BoxID.add_prim_id id ctxt of
    (assigned, updated) =>
      (assert (prim_id = assigned) "add_prim_id"; updated)
(* Turn an (id, theorem) pair into (id, term) by applying Util.rhs_of to the
   theorem component; the id is passed through unchanged. *)
fun to_term_info info = apsnd Util.rhs_of info
(* Set-style equality (eq_set) on a pair of lists, elements compared with =. *)
fun eq_id_list lists = eq_set (op =) lists
(* Set-style equality (eq_set) on a pair of (id, term) lists: ids compared
   with =, terms compared with aconv. *)
fun eq_info_list lists = eq_set (eq_pair (op =) (op aconv)) lists
(* Render one (id, term) pair as "(<box id>, <term>)", using
   BoxID.string_of_box_id for the id and Syntax.string_of_term for the term. *)
fun print_term_info ctxt (id, t) =
  let
    val id_str = BoxID.string_of_box_id id
    val term_str = Syntax.string_of_term ctxt t
  in
    String.concat ["(", id_str, ", ", term_str, ")"]
  end
(* Render a list of (id, term) pairs with print_term_info and join the
   pieces with `commas`. *)
fun print_term_infos ctxt lst =
  lst |> map (print_term_info ctxt) |> commas
(* String form of a list of box ids: Util.string_of_list applied with
   BoxID.string_of_box_id as the element printer. *)
fun string_of_box_id_list ids = Util.string_of_list BoxID.string_of_box_id ids
(* Final pipeline stage: emit "Finished <str>" through `tracing`, ignore the
   second argument, and return unit. *)
fun done str _ = tracing ("Finished " ^ str)
(* Compare the expected (id, term) list `info` with the actual list `info'`
   after mapping to_term_info over the latter. The comparison is
   eq_info_list. On success the context is returned unchanged; otherwise
   both lists are traced and Fail txt is raised. *)
fun assert_eq_info (info, info') txt ctxt =
  let
    val actual = map to_term_info info'
    fun report () =
      tracing ("Expected: " ^ print_term_infos ctxt info ^ "\n" ^
               "Actual: " ^ print_infos ctxt info')
  in
    if not (eq_info_list (info, actual)) then (report (); raise Fail txt)
    else ctxt
  end
(* Read `str` as a term, certify it, and check that
   RewriteTable.get_rewrite_info on it matches `info_str`, a list of
   (id, term string) pairs whose strings are read in ctxt.
   Fails with "assert_rew_info" on mismatch; otherwise returns ctxt. *)
fun assert_rew_info str info_str ctxt =
  let
    val ct = str |> Syntax.read_term ctxt |> Thm.cterm_of ctxt
    val expected = map (fn (id, s) => (id, Syntax.read_term ctxt s)) info_str
    val actual = Auto2_Setup.RewriteTable.get_rewrite_info ctxt ct
  in
    assert_eq_info (expected, actual) "assert_rew_info" ctxt
  end
(* Same shape as a rewrite-info check, but queries
   RewriteTable.get_subterm_rewrite_info for the certified term read from
   `str`. Expected (id, term string) pairs come from `info_str`.
   Fails with "assert_srew_info" on mismatch; otherwise returns ctxt. *)
fun assert_srew_info str info_str ctxt =
  let
    val ct = str |> Syntax.read_term ctxt |> Thm.cterm_of ctxt
    val expected = map (fn (id, s) => (id, Syntax.read_term ctxt s)) info_str
    val actual = Auto2_Setup.RewriteTable.get_subterm_rewrite_info ctxt ct
  in
    assert_eq_info (expected, actual) "assert_srew_info" ctxt
  end
(* Read and certify `str`, build its reflexivity theorem, and check the
   result of RewriteTable.get_head_rep_with_id_th on (id, theorem) against
   the expected (id, term string) pairs in `info_str`.
   Fails with "assert_head_rep" on mismatch; otherwise returns ctxt. *)
fun assert_head_rep (id, str) info_str ctxt =
  let
    val ct = str |> Syntax.read_term ctxt |> Thm.cterm_of ctxt
    val expected = map (fn (id', s) => (id', Syntax.read_term ctxt s)) info_str
    val refl_th = Thm.reflexive ct
    val actual =
      Auto2_Setup.RewriteTable.get_head_rep_with_id_th ctxt (id, refl_th)
  in
    assert_eq_info (expected, actual) "assert_head_rep" ctxt
  end
(* Check that RewriteTable.is_equiv_t id ctxt on the two terms read from
   (str1, str2) yields `exp_equiv`. On mismatch the id and both terms are
   traced and Fail is raised with "assert_equiv" when equivalence was
   expected, "assert_not_equiv" otherwise. Returns ctxt on success. *)
fun assert_equiv_gen exp_equiv id (str1, str2) ctxt =
  let
    val t1 = Syntax.read_term ctxt str1
    val t2 = Syntax.read_term ctxt str2
    val is_equiv = Auto2_Setup.RewriteTable.is_equiv_t id ctxt (t1, t2)
  in
    if exp_equiv = is_equiv then ctxt
    else
      (tracing ("id: " ^ (BoxID.string_of_box_id id));
       trace_tlist ctxt "Input:" [t1, t2];
       raise Fail (if exp_equiv then "assert_equiv" else "assert_not_equiv"))
  end
(* assert_equiv_gen specialised to expect equivalence. *)
fun assert_equiv id strs ctxt = assert_equiv_gen true id strs ctxt
(* assert_equiv_gen specialised to expect non-equivalence. *)
fun assert_not_equiv id strs ctxt = assert_equiv_gen false id strs ctxt
(* Read and certify `str`, then check RewriteTable.simplify_info on it
   against the expected (id, term string) pairs in `info_str`.
   Fails with "assert_simpl_info" on mismatch; otherwise returns ctxt. *)
fun assert_simpl_info str info_str ctxt =
  let
    val ct = str |> Syntax.read_term ctxt |> Thm.cterm_of ctxt
    val expected = map (fn (id, s) => (id, Syntax.read_term ctxt s)) info_str
    val actual = Auto2_Setup.RewriteTable.simplify_info ctxt ct
  in
    assert_eq_info (expected, actual) "assert_simpl_info" ctxt
  end
(* Check that the right-hand side of RewriteTable.simplify id ctxt, applied
   to the certified term read from `str`, is aconv to the term read from
   `str'`. On mismatch the input, expected and actual terms are traced and
   Fail "assert_simpl" is raised. Returns ctxt on success. *)
fun assert_simpl id (str, str') ctxt =
  let
    val input = Syntax.read_term ctxt str
    val expected = Syntax.read_term ctxt str'
    val actual =
      input |> Thm.cterm_of ctxt
            |> Auto2_Setup.RewriteTable.simplify id ctxt
            |> Util.rhs_of
  in
    if not (actual aconv expected) then
      (trace_t ctxt "Input:" input;
       trace_t ctxt "Expected:" expected;
       trace_t ctxt "Actual:" actual;
       raise Fail "assert_simpl")
    else ctxt
  end
(* Read and certify `str`, fetch RewriteTable.get_head_equiv for it, pass
   the result through BoxID.merge_box_with_info ctxt id, and compare with
   the expected (id, term string) pairs in `info_str`.
   Fails with "assert_head_equivs" on mismatch; otherwise returns ctxt. *)
fun assert_head_equivs (id, str) info_str ctxt =
  let
    val cu = str |> Syntax.read_term ctxt |> Thm.cterm_of ctxt
    val expected = map (fn (id', s) => (id', Syntax.read_term ctxt s)) info_str
    val head_equivs = Auto2_Setup.RewriteTable.get_head_equiv ctxt cu
    val actual = BoxID.merge_box_with_info ctxt id head_equivs
  in
    assert_eq_info (expected, actual) "assert_head_equivs" ctxt
  end
(* For every term returned by RewriteTable.get_all_terms ctxt, fold
   RewriteTable.update_subsimp over that term's subterm rewrite info,
   threading the context through all updates. *)
fun index_reps ctxt =
  let
    (* The rewrite info is always looked up in the incoming `ctxt`, not in
       the accumulated context `acc`. *)
    fun index_one t acc =
      let
        val infos = Auto2_Setup.RewriteTable.get_subterm_rewrite_info ctxt t
      in
        fold Auto2_Setup.RewriteTable.update_subsimp infos acc
      end
    val all_terms = Auto2_Setup.RewriteTable.get_all_terms ctxt
  in
    fold index_one all_terms ctxt
  end
(* Read `str` as a term in ctxt and pass it to Proof_Context.augment,
   returning the resulting context. *)
fun declare_term str ctxt =
  let
    val t = Syntax.read_term ctxt str
  in
    ctxt |> Proof_Context.augment t
  end
(* Read both strings as terms, build a theorem for the pair with
   Util.assume_meta_eq, and hand (id, theorem) to RewriteTable.add_equiv.
   Returns the updated context. *)
fun add_equiv id (str1, str2) ctxt =
  let
    val lhs = Syntax.read_term ctxt str1
    val rhs = Syntax.read_term ctxt str2
    val eq_th = Util.assume_meta_eq (Proof_Context.theory_of ctxt) (lhs, rhs)
  in
    Auto2_Setup.RewriteTable.add_equiv (id, eq_th) ctxt
  end
(* Read and certify `str`, call RewriteTable.add_term with the pair
   ([], cterm), and keep only the second component of its result. *)
fun add_term str ctxt =
  let
    val ct = Thm.cterm_of ctxt (Syntax.read_term ctxt str)
    val (_, updated) = Auto2_Setup.RewriteTable.add_term ([], ct) ctxt
  in
    updated
  end
(* Collect the ids of those entries of RewriteTable.equiv_neighs ctxt t1
   whose right-hand side (via to_term_info) is aconv to t2, and compare
   them with `ids` using eq_id_list. On mismatch the two terms and both id
   lists are traced and Fail "assert_eq_edges" is raised. Returns ctxt on
   success. *)
fun assert_eq_edges (str1, str2) ids ctxt =
  let
    val t1 = Syntax.read_term ctxt str1
    val t2 = Syntax.read_term ctxt str2
    fun ends_at_t2 (_, t) = t2 aconv t
    val neighs = map to_term_info (Auto2_Setup.RewriteTable.equiv_neighs ctxt t1)
    val ids' = map fst (filter ends_at_t2 neighs)
  in
    if not (eq_id_list (ids, ids')) then
      (trace_tlist ctxt "Input:" [t1, t2];
       tracing ("Expected: " ^ (string_of_box_id_list ids) ^ "\n" ^
                "Actual: " ^ (string_of_box_id_list ids'));
       raise Fail "assert_eq_edges")
    else ctxt
  end
(* Read both strings as terms, build a theorem for the pair with
   Util.assume_meta_eq, and hand (id, theorem) to RewriteTable.update_simp.
   Returns the updated context. *)
fun update_simp id (str1, str2) ctxt =
  let
    val lhs = Syntax.read_term ctxt str1
    val rhs = Syntax.read_term ctxt str2
    val eq_th = Util.assume_meta_eq (Proof_Context.theory_of ctxt) (lhs, rhs)
  in
    Auto2_Setup.RewriteTable.update_simp (id, eq_th) ctxt
  end
(* Check that the right-hand side of RewriteTable.get_rewrite id ctxt,
   applied to the certified term read from `str`, is aconv to the term read
   from `str'`. On mismatch the input, expected and actual terms are traced
   and Fail "assert_rewrite" is raised. Returns ctxt on success. *)
fun assert_rewrite id (str, str') ctxt =
  let
    val input = Syntax.read_term ctxt str
    val expected = Syntax.read_term ctxt str'
    val rewrite_th =
      Auto2_Setup.RewriteTable.get_rewrite id ctxt (Thm.cterm_of ctxt input)
    val actual = Util.rhs_of rewrite_th
  in
    case actual aconv expected of
      true => ctxt
    | false =>
        (trace_t ctxt "Input:" input;
         trace_t ctxt "Expected:" expected;
         trace_t ctxt "Actual:" actual;
         raise Fail "assert_rewrite")
  end
(* Read both strings as terms, build a theorem for the pair with assume_eq,
   call RewriteTable.add_rewrite on (id, theorem), and keep only the second
   component of its result. *)
fun add_rewrite id (str1, str2) ctxt =
  let
    val lhs = Syntax.read_term ctxt str1
    val rhs = Syntax.read_term ctxt str2
    val th = assume_eq (Proof_Context.theory_of ctxt) (lhs, rhs)
    val (_, updated) = Auto2_Setup.RewriteTable.add_rewrite (id, th) ctxt
  in
    updated
  end
in
(* Registers prim ids 1-4, declares and adds the terms a, b, c, d, then
   alternates add_equiv with assert_eq_edges: after each added equivalence
   the set of box ids reported for that pair of terms is checked. The pair
   (a, d) is added under [4], [3], [2] and [] in turn, with a different
   expected id set after each step. *)
val test_equiv_basic =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1]), (4, [2])]
|> declare_term "[a, b, c, d]"
|> fold add_term ["a", "b", "c", "d"]
|> add_equiv [2] ("a", "b") |> assert_eq_edges ("a", "b") [[2]]
|> add_equiv [3] ("b", "c") |> assert_eq_edges ("b", "c") [[3]]
|> add_equiv [] ("a", "b") |> assert_eq_edges ("a", "b") [[]]
|> add_equiv [4] ("a", "d") |> assert_eq_edges ("a", "d") [[4]]
|> add_equiv [3] ("a", "d") |> assert_eq_edges ("a", "d") [[3], [4]]
|> add_equiv [2] ("a", "d") |> assert_eq_edges ("a", "d") [[2], [3]]
|> add_equiv [] ("a", "d") |> assert_eq_edges ("a", "d") [[]]
|> done "test_equiv_basic"
(* Registers prim ids 1-4 and the terms a-e, then applies update_simp to e
   under several box ids. The first four updates are each followed by an
   assert_rewrite check; after two further updates the complete rewrite
   info of e is checked with assert_rew_info. *)
val test_simp_basic =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1]), (4, [2])]
|> declare_term "[a, b, c, d, e]"
|> fold add_term ["a", "b", "c", "d", "e"]
|> update_simp [2] ("e", "d") |> assert_rewrite [2, 3] ("e", "d")
|> update_simp [3] ("e", "c") |> assert_rewrite [3] ("e", "c")
|> update_simp [3] ("e", "b") |> assert_rewrite [2, 3] ("e", "b")
|> update_simp [4] ("e", "a") |> assert_rewrite [2, 3] ("e", "b")
|> update_simp [2] ("e", "b")
|> update_simp [2, 3] ("e", "a")
|> assert_rew_info "e" [([], "e"), ([4], "a"), ([3], "b"),
([2], "b"), ([2, 3], "a")]
|> done "test_simp_basic"
(* Adds the single term "f b d", applies update_simp for (d, c) under [2]
   and (b, a) under [3], and checks the subterm rewrite info of "f b d".
   After index_reps, assert_head_rep is checked for several (id, term)
   queries; every expected entry refers back to "f b d". *)
val test_head_rep =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1])]
|> fold declare_term ["[a, b, c, d]", "f a b"]
|> add_term "f b d"
|> update_simp [2] ("d", "c")
|> update_simp [3] ("b", "a")
|> assert_srew_info "f b d" [([], "f b d"), ([2], "f b c"),
([3], "f a d"), ([2, 3], "f a c")]
|> index_reps
|> assert_head_rep ([], "f a c") [([2, 3], "f b d")]
|> assert_head_rep ([], "f b c") [([2], "f b d")]
|> assert_head_rep ([], "f a d") [([3], "f b d")]
|> assert_head_rep ([], "f b d") [([], "f b d")]
|> assert_head_rep ([3], "f b c") [([2, 3], "f b d")]
|> done "test_head_rep"
(* Adds the two terms "f b c" and "f a d", applies update_simp for (d, c)
   under [2] and (b, a) under [3], runs index_reps, and then queries
   assert_head_rep for "f a c" under the ids [], [2] and [3]. *)
val test_head_rep3 =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1])]
|> fold declare_term ["[a, b, c, d]", "f a b"]
|> add_term "f b c" |> add_term "f a d"
|> update_simp [2] ("d", "c")
|> update_simp [3] ("b", "a")
|> index_reps
|> assert_head_rep ([], "f a c") [([2], "f a d"), ([3], "f b c")]
|> assert_head_rep ([2], "f a c") [([2], "f a d"), ([2, 3], "f b c")]
|> assert_head_rep ([3], "f a c") [([3], "f b c"), ([2, 3], "f a d")]
|> done "test_head_rep3"
(* Here the rewrites (d, c) under [2] and (b, a) under [3] are added with
   add_rewrite BEFORE any "f ..." term is added. Terms are then added in
   two batches, with assert_head_rep / assert_eq_edges checks after the
   first batch and assert_head_rep / assert_equiv / assert_not_equiv checks
   after the second. *)
val test_add_term =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1])]
|> fold declare_term ["[a, b, c, d]", "f a b"]
|> add_rewrite [2] ("d", "c")
|> add_rewrite [3] ("b", "a")
|> index_reps
|> add_term "f b c" |> add_term "f a d"
|> assert_head_rep ([], "f a c") [([2], "f a d"), ([3], "f b c")]
|> assert_eq_edges ("f a d", "f b c") [[2, 3]]
|> add_term "f a c" |> add_term "f b d"
|> assert_head_rep ([], "f a c") [([], "f a c"), ([2], "f a d"),
([3], "f b c"), ([2, 3], "f b d")]
|> assert_equiv [2, 3] ("f a c", "f b d")
|> assert_not_equiv [2] ("f a c", "f b d")
|> assert_not_equiv [3] ("f a c", "f b d")
|> done "test_add_term"
(* Interleaves add_term and add_rewrite calls: terms "f a d", "f b d" and
   "f b c" are added at different points relative to the rewrites (d, c)
   under [2] and (a, b) under [3], with assert_head_rep, assert_eq_edges
   and assert_rew_info checks in between. *)
val test_add_rewrite =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1])]
|> fold declare_term ["[a, b, c, d]", "f a b"]
|> add_term "f a d"
|> add_rewrite [2] ("d", "c")
|> assert_head_rep ([], "f a c") [([2], "f a d")]
|> add_term "f b d"
|> assert_head_rep ([], "f b c") [([2], "f b d")]
|> add_rewrite [3] ("a", "b")
|> assert_eq_edges ("f a d", "f b d") [[3]]
|> assert_rew_info "f b d" [([], "f b d"), ([2], "f b c"),
([3], "f a d"), ([2, 3], "f a c")]
|> add_term "f b c"
|> assert_head_rep ([], "f a c") [([2], "f a d"), ([3], "f b c"),
([2, 3], "f b d")]
|> done "test_add_rewrite"
(* Adds the rewrites ("f b d", e) under [1], (d, c) under [2] and (b, a)
   under [3], then checks assert_simpl_info and assert_simpl for the terms
   "f a c", "f a d", "f b c" and "f b d". Afterwards ("f b d", e) is added
   again under [], the simplify info of "f b d" is re-checked, and
   "f a c" is added and checked equivalent to "f b d" under [2, 3]. *)
val test_simplify =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1])]
|> fold declare_term ["[a, b, c, d]", "[f a b, e]"]
|> add_rewrite [1] ("f b d", "e")
|> add_rewrite [2] ("d", "c")
|> add_rewrite [3] ("b", "a")
|> assert_simpl_info "f a c" [([], "f a c"), ([2, 3], "e")]
|> assert_simpl_info "f a d" [([], "f a d"), ([2], "f a c"),
([3], "e")]
|> assert_simpl_info "f b c" [([], "f b c"), ([3], "f a c"),
([2], "e")]
|> assert_simpl_info "f b d" [([], "f b d"), ([1], "e")]
|> assert_simpl [2] ("f a c", "f a c")
|> assert_simpl [2, 3] ("f a c", "e")
|> assert_simpl [2] ("f a d", "f a c")
|> assert_simpl [3] ("f a d", "e")
|> assert_simpl [2, 3] ("f a d", "e")
|> add_rewrite [] ("f b d", "e")
|> assert_simpl_info "f b d" [([], "e")]
|> add_term "f a c"
|> assert_equiv [2, 3] ("f a c", "f b d")
|> done "test_simplify"
(* Registers prim ids 1-4, adds the rewrites (b, c) under [3], (a, b) under
   [4] and (c, d) under [2], and checks assert_simpl_info for c and d. *)
val test_simplify2 =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1]), (4, [2])]
|> declare_term "[a, b, c, d]"
|> add_rewrite [3] ("b", "c")
|> add_rewrite [4] ("a", "b")
|> add_rewrite [2] ("c", "d")
|> assert_simpl_info "c" [([], "c"), ([3], "b"), ([3, 4], "a")]
|> assert_simpl_info "d" [([], "d"), ([2], "c"), ([2, 3], "b"),
([3, 4], "a")]
|> done "test_simplify2"
(* Registers prim ids 1 and 2, adds rewrites for e and g under [1] and [2],
   then the rewrite (e, g) under []. Both e and g are expected to report
   the same assert_simpl_info list afterwards. *)
val test_simplify3 =
init |> fold add_prim_id [(1, []), (2, [1])]
|> declare_term "[a, b, c, d, e, g]"
|> add_rewrite [1] ("e", "c") |> add_rewrite [1] ("g", "d")
|> add_rewrite [2] ("e", "b") |> add_rewrite [2] ("g", "a")
|> add_rewrite [] ("e", "g")
|> assert_simpl_info "e" [([], "e"), ([1], "c"), ([2], "a")]
|> assert_simpl_info "g" [([], "e"), ([1], "c"), ([2], "a")]
|> done "test_simplify3"
(* Adds the rewrites ("f b d", e) under [1], (d, c) under [2] and (b, a)
   under [3], then checks assert_head_equivs for several (id, term)
   queries. After add_term "f a c" three of the queries are repeated with
   updated expectations. *)
val test_head_equivs =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1])]
|> fold declare_term ["[a, b, c, d]", "[f a b, e]"]
|> add_rewrite [1] ("f b d", "e")
|> add_rewrite [2] ("d", "c")
|> add_rewrite [3] ("b", "a")
|> assert_head_equivs ([], "c") [([], "c"), ([2], "d")]
|> assert_head_equivs ([], "f b d") [([], "f b d"), ([1], "e")]
|> assert_head_equivs ([2], "f b c") [([2], "f b c"), ([2], "e"), ([2], "f b d")]
|> assert_head_equivs ([], "f b c") [([], "f b c"), ([2], "e"), ([2], "f b d")]
|> assert_head_equivs ([], "f a c") [([], "f a c"), ([2, 3], "e"), ([2, 3], "f b d")]
|> add_term "f a c"
|> assert_head_equivs ([], "f b d") [([], "f b d"), ([1], "e"), ([2, 3], "f a c")]
|> assert_head_equivs ([], "f a c") [([], "f a c"), ([2, 3], "e"), ([2, 3], "f b d")]
|> assert_head_equivs ([], "f b c") [([], "f b c"), ([2], "e"), ([2], "f b d"), ([3], "f a c")]
|> done "test_head_equivs"
(* Registers prim ids 1-4, adds the rewrites (b, c) under [3] and (a, b)
   under [4], and checks assert_head_equivs for b. The check is repeated
   after adding (c, d) under [2], now expecting an extra ([2, 3], d)
   entry. *)
val test_head_equivs2 =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1]), (4, [2])]
|> declare_term "[a, b, c, d]"
|> add_rewrite [3] ("b", "c")
|> add_rewrite [4] ("a", "b")
|> assert_head_equivs ([], "b") [([], "b"), ([3], "c"), ([4], "a")]
|> add_rewrite [2] ("c", "d")
|> assert_head_equivs ([], "b") [([], "b"), ([3], "c"), ([4], "a"),
([2, 3], "d")]
|> done "test_head_equivs2"
(* Adds the rewrite (b, a) under [], then ("f b d", e) under [2] and
   ("f a d", e) under [3], and checks assert_head_equivs for both
   "f a d" and "f b d". *)
val test_head_equivs3 =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1])]
|> fold declare_term ["[a, b, c, d]", "[f a b, e]"]
|> add_rewrite [] ("b", "a")
|> add_rewrite [2] ("f b d", "e")
|> add_rewrite [3] ("f a d", "e")
|> assert_head_equivs ([], "f a d") [([], "f b d"), ([2], "e"), ([3], "e"), ([], "f a d")]
|> assert_head_equivs ([], "f b d") [([], "f a d"), ([2], "e"), ([3], "e"), ([], "f b d")]
|> done "test_head_equivs3"
(* Adds the rewrite (b, a) under [2], then ("f b c", e) under [] and
   ("f a c", e) under [3], and checks assert_head_equivs for both
   "f b c" and "f a c". *)
val test_head_equivs4 =
init |> fold add_prim_id [(1, []), (2, [1]), (3, [1])]
|> fold declare_term ["[a, b, c, d]", "[f a b, e]"]
|> add_rewrite [2] ("b", "a")
|> add_rewrite [] ("f b c", "e")
|> add_rewrite [3] ("f a c", "e")
|> assert_head_equivs ([], "f b c") [([], "f b c"), ([], "e"), ([3], "f a c"), ([2], "f a c")]
|> assert_head_equivs ([], "f a c") [([], "f a c"), ([2], "e"), ([3], "e"), ([3], "f b c"), ([2], "f b c")]
|> done "test_head_equivs4"
(* Uses terms declared with a nat type annotation and no prim ids: adds the
   rewrites ("a + b", "b + a") and (a, c) under [], then checks via
   assert_simpl and assert_equiv that "b + c" relates to "a + b". *)
val test_simplify_nat =
init |> declare_term "[a::nat, b, c]"
|> add_rewrite [] ("a + b", "b + a")
|> add_rewrite [] ("a", "c")
|> assert_simpl [] ("b + c", "a + b")
|> assert_equiv [] ("b + c", "a + b")
|> done "test_simplify_nat"
end;