(* Theory Multiset_More *)

(*  Title:       More about Multisets
    Author:      Mathias Fleury <mathias.fleury at mpi-inf.mpg.de>, 2015
    Author:      Jasmin Blanchette <blanchette at in.tum.de>, 2014, 2015
    Author:      Anders Schlichtkrull <andschl at dtu.dk>, 2017
    Author:      Dmitriy Traytel <traytel at in.tum.de>, 2014
    Maintainer:  Mathias Fleury <mathias.fleury at mpi-inf.mpg.de>
*)

section ‹More about Multisets›

theory Multiset_More
  imports
    "HOL-Library.Multiset_Order"
    "HOL-Library.Sublist"
begin

text ‹
Isabelle's theory of finite multisets is not as developed as other areas, such as lists and sets.
The present theory introduces some missing concepts and lemmas. Some of it is expected to move to
Isabelle's library.
›


subsection ‹Basic Setup›

(* Global attribute setup: the facts below are given the simp/iff/dest/intro
   attributes shown in brackets for the remainder of the theory. *)
declare
  diff_single_trivial [simp]
  in_image_mset [iff]
  image_mset.compositionality [simp]

  (*To have the same rules as the set counter-part*)
  mset_subset_eqD[dest, intro?] (*@{thm subsetD}*)

  Multiset.in_multiset_in_set[simp]
  inter_add_left1[simp]
  inter_add_left2[simp]
  inter_add_right1[simp]
  inter_add_right2[simp]

  sum_mset_sum_list[simp]


subsection ‹Lemmas about Intersection, Union and Pointwise Inclusion›

(* Inclusion is preserved when an element is added on the right. *)
lemma subset_mset_imp_subset_add_mset: "A ⊆# B ⟹ A ⊆# add_mset x B"
  by (auto simp add: subseteq_mset_def le_SucI)

(* An element that does not occur in A can be dropped from the right-hand side. *)
lemma subset_add_mset_notin_subset_mset: ‹A ⊆# add_mset b B ⟹ b ∉# A ⟹ A ⊆# B›
  by (simp add: subset_mset.le_iff_sup)

(* Elimination rule for strict inclusion: it yields inclusion and non-inclusion
   in the other direction. *)
lemma subset_msetE [elim!]: "⟦A ⊂# B; ⟦A ⊆# B; ¬ B ⊆# A⟧ ⟹ R⟧ ⟹ R"
  by (simp add: subset_mset.less_le_not_le)

(* Subtracting a multiset with empty intersection changes nothing. *)
lemma Diff_triv_mset: "M ∩# N = {#} ⟹ M - N = M"
  by (metis diff_intersect_left_idem diff_zero)

(* The two one-sided differences never share an element. *)
lemma diff_intersect_sym_diff: "(A - B) ∩# (B - A) = {#}"
  by (rule multiset_eqI) simp

(* A subsequence of a list yields a sub-multiset of the list's multiset.
   Proof: outer induction on xs, inner induction on ys. *)
lemma subseq_mset_subseteq_mset: "subseq xs ys ⟹ mset xs ⊆# mset ys"
proof (induct xs arbitrary: ys)
  case (Cons x xs)
  note Outer_Cons = this
  then show ?case
  proof (induct ys)
    case (Cons y ys)
    have "subseq xs ys"
      by (metis Cons.prems(2) subseq_Cons' subseq_Cons2_iff)
    then show ?case
      using Cons by (metis mset.simps(2) mset_subset_eq_add_mset_cancel subseq_Cons2_iff
          subset_mset_imp_subset_add_mset)
  qed simp
qed simp

(* For finite sets, mset_set commutes with intersection. *)
lemma finite_mset_set_inter:
  ‹finite A ⟹ finite B ⟹ mset_set (A ∩ B) = mset_set A ∩# mset_set B›
  apply (induction A rule: finite_induct)
  subgoal by auto
  subgoal for a A
    by (cases ‹a ∈ B›; cases ‹a ∈# mset_set B›)
      (use multi_member_split[of a ‹mset_set B›] in
        ‹auto simp: mset_set.insert_remove›)
  done


subsection ‹Lemmas about Filter and Image›

(* Mapping can only merge elements, so the count of an image is at least the
   count of the original element. *)
lemma count_image_mset_ge_count: "count (image_mset f A) (f b) ≥ count A b"
  by (induction A) auto

(* For an injective function the counts coincide. *)
lemma count_image_mset_inj:
  assumes ‹inj f›
  shows ‹count (image_mset f M) (f x) = count M x›
  by (induct M) (use assms in ‹auto simp: inj_on_def›)

(* If f is injective on the elements of M, the count of y in the image is
   bounded by the count of its preimage (chosen via inv_into) in M.
   Proof: induction on M, with case splits on "x ∈# M" and "y = f x". *)
lemma count_image_mset_le_count_inj_on:
  "inj_on f (set_mset M) ⟹ count (image_mset f M) y ≤ count M (inv_into (set_mset M) f y)"
proof (induct M)
  case (add x M)
  note ih = this(1) and inj_xM = this(2)

  have inj_M: "inj_on f (set_mset M)"
    using inj_xM by simp

  show ?case
  proof (cases "x ∈# M")
    case x_in_M: True
    show ?thesis
    proof (cases "y = f x")
      case y_eq_fx: True
      show ?thesis
        using x_in_M ih[OF inj_M] unfolding y_eq_fx by (simp add: inj_M insert_absorb)
    next
      case y_ne_fx: False
      show ?thesis
        using x_in_M ih[OF inj_M] y_ne_fx insert_absorb by fastforce
    qed
  next
    case x_ni_M: False
    show ?thesis
    proof (cases "y = f x")
      case y_eq_fx: True
      have "f x ∉# image_mset f M"
        using x_ni_M inj_xM by force
      thus ?thesis
        unfolding y_eq_fx
        by (metis (no_types) inj_xM count_add_mset count_greater_eq_Suc_zero_iff count_inI
          image_mset_add_mset inv_into_f_f union_single_eq_member)
    next
      case y_ne_fx: False
      show ?thesis
      proof (rule ccontr)
        assume neg_conj: "¬ count (image_mset f (add_mset x M)) y
          ≤ count (add_mset x M) (inv_into (set_mset (add_mset x M)) f y)"

        have cnt_y: "count (add_mset (f x) (image_mset f M)) y = count (image_mset f M) y"
          using y_ne_fx by simp

        have "inv_into (set_mset M) f y ∈# add_mset x M ⟶
          inv_into (set_mset (add_mset x M)) f (f (inv_into (set_mset M) f y)) =
          inv_into (set_mset M) f y"
          by (meson inj_xM inv_into_f_f)
        hence "0 < count (image_mset f (add_mset x M)) y ⟶
          count M (inv_into (set_mset M) f y) = 0 ∨ x = inv_into (set_mset M) f y"
          using neg_conj cnt_y ih[OF inj_M]
          by (metis (no_types) count_add_mset count_greater_zero_iff count_inI f_inv_into_f
            image_mset_add_mset set_image_mset)
        thus False
          using neg_conj cnt_y x_ni_M ih[OF inj_M]
          by (metis (no_types) count_greater_zero_iff count_inI eq_iff image_mset_add_mset
            less_imp_le)
      qed
    qed
  qed
qed simp

(* Filtering a list by p and by its negation partitions the list's multiset. *)
lemma mset_filter_compl: "mset (filter p xs) + mset (filter (Not ∘ p) xs) = mset xs"
  by (induction xs) (auto simp: ac_simps)

text ‹Near duplicate of @{thm [source] filter_eq_replicate_mset}: @{thm filter_eq_replicate_mset}.›

(* Filtering for equality with L keeps exactly "count A L" copies of L. *)
lemma filter_mset_eq: "filter_mset ((=) L) A = replicate_mset (count A L) L"
  by (auto simp: multiset_eq_iff)

(* Congruence rule for filter_mset: the multiset may be replaced by an equal
   one and the predicate by one that agrees on all elements of M.
   The conclusion mentions M' so that the assumption "M = M'" is actually used;
   instantiating M' with M gives the weaker form "filter_mset P M = filter_mset Q M". *)
lemma filter_mset_cong[fundef_cong]:
  assumes "M = M'" "⋀a. a ∈# M ⟹ P a = Q a"
  shows "filter_mset P M = filter_mset Q M'"
proof -
  have "M - filter_mset Q M = filter_mset (λa. ¬Q a) M"
    by (metis multiset_partition add_diff_cancel_left')
  (* First establish the statement for the same multiset on both sides. *)
  then have "filter_mset P M = filter_mset Q M"
    by (auto simp: filter_mset_eq_conv assms)
  (* Then replace M by M' on the right using the first assumption. *)
  with assms(1) show ?thesis
    by simp
qed

(* Filtering on "P ∘ f" before mapping equals mapping and then filtering on P. *)
lemma image_mset_filter_swap: "image_mset f {# x ∈# M. P (f x)#} = {# x ∈# image_mset f M. P x#}"
  by (induction M) auto

(* Congruence rule for image_mset that also allows replacing the multiset. *)
lemma image_mset_cong2:
  "(⋀x. x ∈# M ⟹ f x = g x) ⟹ M = N ⟹ image_mset f M = image_mset g N"
  by (hypsubst, rule image_mset_cong)

(* A filter is empty iff no element satisfies the predicate. *)
lemma filter_mset_empty_conv: ‹(filter_mset P M = {#}) = (∀L∈#M. ¬ P L)›
  by (induction M) auto

(* Inclusion of two filters of the same multiset is pointwise implication. *)
lemma multiset_filter_mono2: ‹filter_mset P A ⊆# filter_mset Q A ⟷ (∀a∈#A. P a ⟶ Q a)›
  by (induction A) (auto intro: subset_mset.trans)

(* The mapped function only matters on elements that pass the filter. *)
lemma image_filter_cong:
  assumes ‹⋀C. C ∈# M ⟹ P C ⟹ f C = g C›
  shows ‹{#f C. C ∈# {#C ∈# M. P C#}#} = {#g C | C∈# M. P C#}›
  using assms by (induction M) auto

(* image_mset_filter_swap, read right-to-left in comprehension syntax. *)
lemma image_mset_filter_swap2: ‹{#C ∈# {#P x. x ∈# D#}. Q C #} = {#P x. x ∈# {#C| C ∈# D. Q (P C)#}#}›
  by (simp add: image_mset_filter_swap)

declare image_mset_cong2 [cong]

(* If no element of the finite set X satisfies P, then filtering mset_set X
   by P gives the empty multiset. *)
lemma filter_mset_empty_if_finite_and_filter_set_empty:
  assumes
    "{x ∈ X. P x} = {}" and
    "finite X"
  shows "{#x ∈# mset_set X. P x#} = {#}"
proof -
  (* A multiset with no elements is the empty multiset. *)
  have empty_empty: "⋀Y. set_mset Y = {} ⟹ Y = {#}"
    by auto
  from assms have "set_mset {#x ∈# mset_set X. P x#} = {}"
    by auto
  then show ?thesis
    by (rule empty_empty)
qed


subsection ‹Lemmas about Sum›

(* Summing an image over "mset xs" is the list sum of the mapped list. *)
lemma sum_image_mset_sum_map[simp]: "sum_mset (image_mset f (mset xs)) = sum_list (map f xs)"
  by (metis mset_map sum_mset_sum_list)

(* In a canonically ordered monoid, sums are monotone w.r.t. multiset inclusion. *)
lemma sum_image_mset_mono:
  fixes f :: "'a ⇒ 'b::canonically_ordered_monoid_add"
  assumes sub: "A ⊆# B"
  shows "(∑m ∈# A. f m) ≤ (∑m ∈# B. f m)"
  by (metis image_mset_union le_iff_add sub subset_mset.add_diff_inverse sum_mset.union)

(* Every summand is bounded by the whole sum. *)
lemma sum_image_mset_mono_mem:
  "n ∈# M ⟹ f n ≤ (∑m ∈# M. f m)" for f :: "'a ⇒ 'b::canonically_ordered_monoid_add"
  using le_iff_add multi_member_split by fastforce

(* count expressed as a sum of indicator values. *)
lemma count_sum_mset_if_1_0: ‹count M a = (∑x∈#M. if x = a then 1 else 0)›
  by (induction M) auto

(* A common divisor of all summands divides the sum. *)
lemma sum_mset_dvd:
  fixes k :: "'a::comm_semiring_1_cancel"
  assumes "∀m ∈# M. k dvd f m"
  shows "k dvd (∑m ∈# M. f m)"
  using assms by (induct M) auto

(* Division by a common divisor distributes over the sum. *)
lemma sum_mset_distrib_div_if_dvd:
  fixes k :: "'a::unique_euclidean_semiring"
  assumes "∀m ∈# M. k dvd f m"
  shows "(∑m ∈# M. f m) div k = (∑m ∈# M. f m div k)"
  using assms by (induct M) (auto simp: div_plus_div_distrib_dvd_left)


subsection ‹Lemmas about Remove›

(* Removing at least "count A a" copies of a deletes a from the element set;
   removing fewer copies leaves the element set unchanged. *)
lemma set_mset_minus_replicate_mset[simp]:
  "n ≥ count A a ⟹ set_mset (A - replicate_mset n a) = set_mset A - {a}"
  "n < count A a ⟹ set_mset (A - replicate_mset n a) = set_mset A"
  unfolding set_mset_def by (auto split: if_split simp: not_in_iff)

(* removeAll_mset C M removes every occurrence of C from M. *)
abbreviation removeAll_mset :: "'a ⇒ 'a multiset ⇒ 'a multiset" where
  "removeAll_mset C M ≡ M - replicate_mset (count M C) C"

(* removeAll_mset on "mset L" corresponds to removeAll on the list L. *)
lemma mset_removeAll[simp, code]: "removeAll_mset C (mset L) = mset (removeAll C L)"
  by (induction L) (auto simp: ac_simps multiset_eq_iff split: if_split_asm)

(* removeAll_mset as a filter on disequality. *)
lemma removeAll_mset_filter_mset: "removeAll_mset C M = filter_mset ((≠) C) M"
  by (induction M) (auto simp: ac_simps multiset_eq_iff)

(* remove1_mset C M removes a single occurrence of C from M. *)
abbreviation remove1_mset :: "'a ⇒ 'a multiset ⇒ 'a multiset" where
  "remove1_mset C M ≡ M - {#C#}"

(* Removing all occurrences removes at least as much as removing one. *)
lemma removeAll_subseteq_remove1_mset: "removeAll_mset x M ⊆# remove1_mset x M"
  by (auto simp: subseteq_mset_def)

(* Removing b does not affect membership of a different element a. *)
lemma in_remove1_mset_neq:
  assumes ab: "a ≠ b"
  shows "a ∈# remove1_mset b C ⟷ a ∈# C"
  by (metis assms diff_single_trivial in_diffD insert_DiffM insert_noteq_member)

(* NOTE: despite "le" in the name, the statement is about strict decrease of size. *)
lemma size_mset_removeAll_mset_le_iff: "size (removeAll_mset x M) < size M ⟷ x ∈# M"
  by (auto intro: count_inI mset_subset_size simp: subset_mset_def multiset_eq_iff)

(* The size drops by one exactly when the removed element is present. *)
lemma size_remove1_mset_If: ‹size (remove1_mset x M) = size M - (if x ∈# M then 1 else 0)›
  by (auto simp: size_Diff_subset_Int)

(* NOTE: despite "le" in the name, the statement is about strict decrease of size. *)
lemma size_mset_remove1_mset_le_iff: "size (remove1_mset x M) < size M ⟷ x ∈# M"
  using less_irrefl
  by (fastforce intro!: mset_subset_size elim: in_countE simp: subset_mset_def multiset_eq_iff)

(* remove1_mset is the identity exactly on multisets not containing the element. *)
lemma remove_1_mset_id_iff_notin: "remove1_mset a M = M ⟷ a ∉# M"
  by (meson diff_single_trivial multi_drop_mem_not_eq)

(* Symmetric variant of remove_1_mset_id_iff_notin. *)
lemma id_remove_1_mset_iff_notin: "M = remove1_mset a M ⟷ a ∉# M"
  using remove_1_mset_id_iff_notin by metis

(* Elimination rule for "remove1_mset L x1 = M": case split on whether L occurs in x1. *)
lemma remove1_mset_eqE:
  "remove1_mset L x1 = M ⟹
    (L ∈# x1 ⟹ x1 = M + {#L#} ⟹ P) ⟹
    (L ∉# x1 ⟹ x1 = M ⟹ P) ⟹
  P"
  by (cases "L ∈# x1") auto

(* Filtering out the preimages of y before mapping equals removing all y afterwards. *)
lemma image_filter_ne_mset[simp]:
  "image_mset f {#x ∈# M. f x ≠ y#} = removeAll_mset y (image_mset f M)"
  by (induction M) simp_all

(* image_mset commutes with remove1_mset when the removed element is present. *)
lemma image_mset_remove1_mset_if:
  "image_mset f (remove1_mset a M) =
   (if a ∈# M then remove1_mset (f a) (image_mset f M) else image_mset f M)"
  by (auto simp: image_mset_Diff)

(* Filtering on disequality with y removes all occurrences of y. *)
lemma filter_mset_neq: "{#x ∈# M. x ≠ y#} = removeAll_mset y M"
  by (metis add_diff_cancel_left' filter_eq_replicate_mset multiset_partition)

(* Variant of filter_mset_neq with an additional predicate P. *)
lemma filter_mset_neq_cond: "{#x ∈# M. P x ∧ x ≠ y#} = removeAll_mset y {# x∈#M. P x#}"
  by (metis filter_filter_mset filter_mset_neq)

(* remove1_mset on add_mset: either the added element is removed, or removal goes deeper. *)
lemma remove1_mset_add_mset_If:
  "remove1_mset L (add_mset L' C) = (if L = L' then C else remove1_mset L C + {#L'#})"
  by (auto simp: multiset_eq_iff)

(* Subtracting "B minus one b" leaves one extra b compared to "A - B", provided
   b occurs in both and A has at least as many copies of b as B. *)
lemma minus_remove1_mset_if:
  "A - remove1_mset b B = (if b ∈# B ∧ b ∈# A ∧ count A b ≥ count B b then {#b#} + (A - B) else A - B)"
  by (auto simp: multiset_eq_iff count_greater_zero_iff[symmetric]
    simp del: count_greater_zero_iff)

(* Equality of two add_mset terms with distinct added elements. *)
lemma add_mset_eq_add_mset_ne:
  "a ≠ b ⟹ add_mset a A = add_mset b B ⟷ a ∈# B ∧ b ∈# A ∧ A = add_mset b (B - {#a#})"
  by (metis (no_types, lifting) diff_single_eq_union diff_union_swap multi_self_add_other_not_self
    remove_1_mset_id_iff_notin union_single_eq_diff)

(* Full characterization of equality of two add_mset terms. *)
lemma add_mset_eq_add_mset: ‹add_mset a M = add_mset b M' ⟷
  (a = b ∧ M = M') ∨ (a ≠ b ∧ b ∈# M ∧ add_mset a (M - {#b#}) = M')›
  by (metis add_mset_eq_add_mset_ne add_mset_remove_trivial union_single_eq_member)

(* TODO move to Multiset: could replace add_mset_remove_trivial_eq? *)
lemma add_mset_remove_trivial_iff: ‹N = add_mset a (N - {#b#}) ⟷ a ∈# N ∧ a = b›
  by (metis add_left_cancel add_mset_remove_trivial insert_DiffM2 single_eq_single
      size_mset_remove1_mset_le_iff union_single_eq_member)

(* Symmetric variant of add_mset_remove_trivial_iff. *)
lemma trivial_add_mset_remove_iff: ‹add_mset a (N - {#b#}) = N ⟷ a ∈# N ∧ a = b›
  by (subst eq_commute) (fact add_mset_remove_trivial_iff)

(* Removing L from a singleton gives the empty multiset iff the singleton is {#L#}. *)
lemma remove1_single_empty_iff[simp]: ‹remove1_mset L {#L'#} = {#} ⟷ L = L'›
  using add_mset_remove_trivial_iff by fastforce

(* In the multiset order, "add_mset x M < N" implies "M < N with one x removed". *)
lemma add_mset_less_imp_less_remove1_mset:
  assumes xM_lt_N: "add_mset x M < N"
  shows "M < remove1_mset x N"
proof -
  have "M < N"
    using assms le_multiset_right_total mset_le_trans by blast
  then show ?thesis
    by (metis add_less_cancel_right add_mset_add_single diff_single_trivial insert_DiffM2 xM_lt_N)
qed

(* An element absent from A can be dropped from the subtrahend. *)
lemma remove_diff_multiset[simp]: ‹x13 ∉# A ⟹ A - add_mset x13 B = A - B›
  by (metis diff_intersect_left_idem inter_add_right1)

(* Removing all occurrences of an absent element is the identity. *)
lemma removeAll_notin: ‹a ∉# A ⟹ removeAll_mset a A = A›
  using count_inI by force

(* The multiset of "drop a N" consists of the entries of N at the indices a..<length N.
   Proof: induction on N; the Cons case shifts the index interval by Suc. *)
lemma mset_drop_upto: ‹mset (drop a N) = {#N!i. i ∈# mset_set {a..<length N}#}›
proof (induction N arbitrary: a)
  case Nil
  then show ?case by simp
next
  case (Cons c N)
  have upt: ‹{0..<Suc (length N)} = insert 0 {1..<Suc (length N)}›
    by auto
  then have H: ‹mset_set {0..<Suc (length N)} = add_mset 0 (mset_set {1..<Suc (length N)})›
    unfolding upt by auto
  (* On positive indices, the case distinction on the index is just "N ! (x-1)". *)
  have mset_case_Suc: ‹{#case x of 0 ⇒ c | Suc x ⇒ N ! x . x ∈# mset_set {Suc a..<Suc b}#} =
    {#N ! (x-1) . x ∈# mset_set {Suc a..<Suc b}#}› for a b
    by (rule image_mset_cong) (auto split: nat.splits)
  have Suc_Suc: ‹{Suc a..<Suc b} = Suc ` {a..<b}› for a b
    by auto
  then have mset_set_Suc_Suc: ‹mset_set {Suc a..<Suc b} = {#Suc n. n ∈# mset_set {a..<b}#}› for a b
    unfolding Suc_Suc by (subst image_mset_mset_set[symmetric]) auto
  have *: ‹{#N ! (x-Suc 0) . x ∈# mset_set {Suc a..<Suc b}#} = {#N ! x . x ∈# mset_set {a..<b}#}›
    for a b
    by (auto simp add: mset_set_Suc_Suc)
  show ?case
    apply (cases a)
    using Cons[of 0] Cons by (auto simp: nth_Cons drop_Cons H mset_case_Suc *)
qed


subsection ‹Lemmas about Replicate›

(* Difference of two replications of the same element. *)
lemma replicate_mset_minus_replicate_mset_same[simp]:
  "replicate_mset m x - replicate_mset n x = replicate_mset (m - n) x"
  by (induct m arbitrary: n, simp, metis left_diff_repeat_mset_distrib' repeat_mset_replicate_mset)

(* Strict inclusion, inclusion, and the multiset order on replications of the
   same element reduce to the corresponding comparison of the multiplicities. *)
lemma replicate_mset_subset_iff_lt[simp]: "replicate_mset m x ⊂# replicate_mset n x ⟷ m < n"
  by (induct n m rule: diff_induct) (auto intro: subset_mset.gr_zeroI)

lemma replicate_mset_subseteq_iff_le[simp]: "replicate_mset m x ⊆# replicate_mset n x ⟷ m ≤ n"
  by (induct n m rule: diff_induct) auto

lemma replicate_mset_lt_iff_lt[simp]: "replicate_mset m x < replicate_mset n x ⟷ m < n"
  by (induct n m rule: diff_induct) (auto intro: subset_mset.gr_zeroI gr_zeroI)

lemma replicate_mset_le_iff_le[simp]: "replicate_mset m x ≤ replicate_mset n x ⟷ m ≤ n"
  by (induct n m rule: diff_induct) auto

(* Two replications are equal iff the multiplicities agree and, unless both are
   empty, the elements agree. *)
lemma replicate_mset_eq_iff[simp]:
  "replicate_mset m x = replicate_mset n y ⟷ m = n ∧ (m ≠ 0 ⟶ x = y)"
  by (cases m; cases n; simp)
    (metis in_replicate_mset insert_noteq_member size_replicate_mset union_single_eq_diff)

(* Replication distributes over addition of multiplicities. *)
lemma replicate_mset_plus: "replicate_mset (a + b) C = replicate_mset a C + replicate_mset b C"
  by (induct a) (auto simp: ac_simps)

(* List replication corresponds to multiset replication. *)
lemma mset_replicate_replicate_mset: "mset (replicate n L) = replicate_mset n L"
  by (induction n) auto

(* A multiset has the singleton element set {a} iff it is a nonempty replication of a. *)
lemma set_mset_single_iff_replicate_mset: "set_mset U = {a} ⟷ (∃n > 0. U = replicate_mset n a)"
  by (rule, metis count_greater_zero_iff count_replicate_mset insertI1 multi_count_eq singletonD
    zero_less_iff_neq_zero, force)

(* If all elements equal y, the multiset is some replication of y. *)
lemma ex_replicate_mset_if_all_elems_eq:
  assumes "∀x ∈# M. x = y"
  shows "∃n. M = replicate_mset n y"
  using assms by (metis count_replicate_mset mem_Collect_eq multiset_eqI neq0_conv set_mset_def)


subsection ‹Multiset and Set Conversions›

(* Each element of a finite set occurs exactly once in mset_set; otherwise the count is 0. *)
lemma count_mset_set_if: "count (mset_set A) a = (if a ∈ A ∧ finite A then 1 else 0)"
  by auto

lemma mset_set_set_mset_empty_mempty[iff]: "mset_set (set_mset D) = {#} ⟷ D = {#}"
  by (simp add: mset_set_empty_iff)

(* mset_set never produces duplicates. *)
lemma count_mset_set_le_one: "count (mset_set A) x ≤ 1"
  by (simp add: count_mset_set_if)

lemma mset_set_set_mset_subseteq[simp]: "mset_set (set_mset A) ⊆# A"
  by (simp add: mset_set_set_mset_msubset)

lemma mset_sorted_list_of_set[simp]: "mset (sorted_list_of_set A) = mset_set A"
  by (metis mset_sorted_list_of_multiset sorted_list_of_mset_set)

lemma sorted_sorted_list_of_multiset[simp]:
  "sorted (sorted_list_of_multiset (M :: 'a::linorder multiset))"
  by (metis mset_sorted_list_of_multiset sorted_list_of_multiset_mset sorted_sort)

(* Taking a prefix of a list yields a sub-multiset. *)
lemma mset_take_subseteq: "mset (take n xs) ⊆# mset xs"
  apply (induct xs arbitrary: n)
   apply simp
  by (case_tac n) simp_all

lemma sorted_list_of_multiset_eq_Nil[simp]: "sorted_list_of_multiset M = [] ⟷ M = {#}"
  by (metis mset_sorted_list_of_multiset sorted_list_of_multiset_empty)


subsection ‹Duplicate Removal›

(* TODO: use abbreviation? *)
(* remdups_mset S keeps exactly one copy of every element of S. *)
definition remdups_mset :: "'v multiset ⇒ 'v multiset" where
  "remdups_mset S = mset_set (set_mset S)"

(* Duplicate removal does not change the set of elements. *)
lemma set_mset_remdups_mset[simp]: ‹set_mset (remdups_mset A) = set_mset A›
  unfolding remdups_mset_def by auto

lemma count_remdups_mset_eq_1: "a ∈# remdups_mset A ⟷ count (remdups_mset A) a = 1"
  unfolding remdups_mset_def by (auto simp: count_eq_zero_iff intro: count_inI)

lemma remdups_mset_empty[simp]: "remdups_mset {#} = {#}"
  unfolding remdups_mset_def by auto

lemma remdups_mset_singleton[simp]: "remdups_mset {#a#} = {#a#}"
  unfolding remdups_mset_def by auto

lemma remdups_mset_eq_empty[iff]: "remdups_mset D = {#} ⟷ D = {#}"
  unfolding remdups_mset_def by blast

(* Adding an element only changes the result if the element is new. *)
lemma remdups_mset_singleton_sum[simp]:
  "remdups_mset (add_mset a A) = (if a ∈# A then remdups_mset A else add_mset a (remdups_mset A))"
  unfolding remdups_mset_def by (simp_all add: insert_absorb)

(* remdups on lists corresponds to remdups_mset on multisets. *)
lemma mset_remdups_remdups_mset[simp]: "mset (remdups D) = remdups_mset (mset D)"
  by (induction D) (auto simp add: ac_simps)

declare mset_remdups_remdups_mset[symmetric, code]

lemma count_remdups_mset_If: ‹count (remdups_mset A) a = (if a ∈# A then 1 else 0)›
  unfolding remdups_mset_def by auto

lemma notin_add_mset_remdups_mset:
  ‹a ∉# A ⟹ add_mset a (remdups_mset A) = remdups_mset (add_mset a A)›
  by auto


subsection ‹Repeat Operation›

(* repeat_mset n A is the n-fold iteration of "(+) A" starting from the empty multiset. *)
lemma repeat_mset_compower: "repeat_mset n A = (((+) A) ^^ n) {#}"
  by (induction n) auto

(* Repeating "m * n" times is the m-fold iteration of adding "repeat_mset n A". *)
lemma repeat_mset_prod: "repeat_mset (m * n) A = (((+) (repeat_mset n A)) ^^ m) {#}"
  by (induction m) (auto simp: repeat_mset_distrib)


subsection ‹Cartesian Product›

text ‹Definition of the Cartesian product over multisets. The construction mimics that of the
  Cartesian product on sets and uses the same theorem names (adding only the suffix ‹_mset› to Sigma
  and Times). See file @{file ‹~~/src/HOL/Product_Type.thy›}.›

(* Dependent product of multisets: for every a in A, pair a with every element
   of B a, and take the multiset sum of all these pair multisets. *)
definition Sigma_mset :: "'a multiset ⇒ ('a ⇒ 'b multiset) ⇒ ('a × 'b) multiset" where
  "Sigma_mset A B ≡ ∑⇩# {#{#(a, b). b ∈# B a#}. a ∈# A #}"

(* Non-dependent special case, written "A ×# B". *)
abbreviation Times_mset :: "'a multiset ⇒ 'b multiset ⇒ ('a × 'b) multiset" (infixr "×#" 80) where
  "Times_mset A B ≡ Sigma_mset A (λ_. B)"

hide_const (open) Times_mset

text ‹Contrary to the set version @{term ‹SIGMA x:A. B›}, we use the non-ASCII symbol ‹∈#›.›

(* Binder-style notation "SIGMAMSET x∈#A. B", translated to "Sigma_mset A (λx. B)". *)
syntax
  "_Sigma_mset" :: "[pttrn, 'a multiset, 'b multiset] => ('a * 'b) multiset"
  ("(3SIGMAMSET _∈#_./ _)" [0, 0, 10] 10)
translations
  "SIGMAMSET x∈#A. B" == "CONST Sigma_mset A (λx. B)"

text ‹Link between the multiset and the set cartesian product:›

(* The element set of a multiset product is the set product of the element sets. *)
lemma Times_mset_Times: "set_mset (A ×# B) = set_mset A × set_mset B"
  unfolding Sigma_mset_def by auto

(* Introduction rule for membership in Sigma_mset. *)
lemma Sigma_msetI [intro!]: "⟦a ∈# A; b ∈# B a⟧ ⟹ (a, b) ∈# Sigma_mset A B"
  by (unfold Sigma_mset_def) auto

(* Elimination rule: every member of Sigma_mset is a pair of suitable components. *)
lemma Sigma_msetE[elim!]: "⟦c ∈# Sigma_mset A B; ⋀x y. ⟦x ∈# A; y ∈# B x; c = (x, y)⟧ ⟹ P⟧ ⟹ P"
  by (unfold Sigma_mset_def) auto

text ‹Elimination of @{term "(a, b) ∈# A ×# B"} -- introduces no eigenvariables.›

lemma Sigma_msetD1: "(a, b) ∈# Sigma_mset A B ⟹ a ∈# A"
  by blast

lemma Sigma_msetD2: "(a, b) ∈# Sigma_mset A B ⟹ b ∈# B a"
  by blast

lemma Sigma_msetE2: "⟦(a, b) ∈# Sigma_mset A B; ⟦a ∈# A; b ∈# B a⟧ ⟹ P⟧ ⟹ P"
  by blast

(* Congruence rule for the SIGMAMSET binder. *)
lemma Sigma_mset_cong:
  "⟦A = B; ⋀x. x ∈# B ⟹ C x = D x⟧ ⟹ (SIGMAMSET x ∈# A. C x) = (SIGMAMSET x ∈# B. D x)"
  by (metis (mono_tags, lifting) Sigma_mset_def image_mset_cong)

(* The count in a multiset sum is the sum of the counts in the summands. *)
lemma count_sum_mset: "count (∑⇩# M) b = (∑P ∈# M. count P b)"
  by (induction M) auto

(* Sigma_mset distributes over multiset addition in either argument. *)
lemma Sigma_mset_plus_distrib1[simp]: "Sigma_mset (A + B) C = Sigma_mset A C + Sigma_mset B C"
  unfolding Sigma_mset_def by auto

lemma Sigma_mset_plus_distrib2[simp]:
  "Sigma_mset A (λi. B i + C i) = Sigma_mset A B + Sigma_mset A C"
  unfolding Sigma_mset_def by (induction A) (auto simp: multiset_eq_iff)

(* Products with a singleton factor are images under pairing. *)
lemma Times_mset_single_left: "{#a#} ×# B = image_mset (Pair a) B"
  unfolding Sigma_mset_def by auto

lemma Times_mset_single_right: "A ×# {#b#} = image_mset (λa. Pair a b) A"
  unfolding Sigma_mset_def by (induction A) auto

lemma Times_mset_single_single[simp]: "{#a#} ×# {#b#} = {#(a, b)#}"
  unfolding Sigma_mset_def by simp

lemma count_image_mset_Pair:
  "count (image_mset (Pair a) B) (x, b) = (if x = a then count B b else 0)"
  by (induction B) auto

(* The count of a pair in Sigma_mset is the product of the component counts. *)
lemma count_Sigma_mset: "count (Sigma_mset A B) (a, b) = count A a * count (B a) b"
  by (induction A) (auto simp: Sigma_mset_def count_image_mset_Pair)

lemma Sigma_mset_empty1[simp]: "Sigma_mset {#} B = {#}"
  unfolding Sigma_mset_def by auto

lemma Sigma_mset_empty2[simp]: "A ×# {#} = {#}"
  by (auto simp: multiset_eq_iff count_Sigma_mset)

(* Sigma_mset is monotone in both arguments w.r.t. multiset inclusion. *)
lemma Sigma_mset_mono:
  assumes "A ⊆# C" and "⋀x. x ∈# A ⟹ B x ⊆# D x"
  shows "Sigma_mset A B ⊆# Sigma_mset C D"
proof -
  have "count A a * count (B a) b ≤ count C a * count (D a) b" for a b
    using assms unfolding subseteq_mset_def by (metis count_inI eq_iff mult_eq_0_iff mult_le_mono)
  then show ?thesis
    by (auto simp: subseteq_mset_def count_Sigma_mset)
qed

lemma mem_Sigma_mset_iff[iff]: "((a,b) ∈# Sigma_mset A B) = (a ∈# A ∧ b ∈# B a)"
  by blast

lemma mem_Times_mset_iff: "x ∈# A ×# B ⟷ fst x ∈# A ∧ snd x ∈# B"
  by (induct x) simp

(* Sigma_mset is empty iff every fibre over an element of I is empty. *)
lemma Sigma_mset_empty_iff: "(SIGMAMSET i∈#I. X i) = {#} ⟷ (∀i∈#I. X i = {#})"
  by (auto simp: Sigma_mset_def)

(* A nonempty common factor can be cancelled in inclusions and equations. *)
lemma Times_mset_subset_mset_cancel1: "x ∈# A ⟹ (A ×# B ⊆# A ×# C) = (B ⊆# C)"
  by (auto simp: subseteq_mset_def count_Sigma_mset)

lemma Times_mset_subset_mset_cancel2: "x ∈# C ⟹ (A ×# C ⊆# B ×# C) = (A ⊆# B)"
  by (auto simp: subseteq_mset_def count_Sigma_mset)

lemma Times_mset_eq_cancel2: "x ∈# C ⟹ (A ×# C = B ×# C) = (A = B)"
  by (auto simp: multiset_eq_iff count_Sigma_mset dest!: in_countE)

(* Bounded quantifiers over Sigma_mset split into nested bounded quantifiers. *)
lemma split_paired_Ball_mset_Sigma_mset[simp]:
  "(∀z∈#Sigma_mset A B. P z) ⟷ (∀x∈#A. ∀y∈#B x. P (x, y))"
  by blast

lemma split_paired_Bex_mset_Sigma_mset[simp]:
  "(∃z∈#Sigma_mset A B. P z) ⟷ (∃x∈#A. ∃y∈#B x. P (x, y))"
  by blast

(* A sum that only picks up the summands equal to a is "count M a" iterated additions of f a. *)
lemma sum_mset_if_eq_constant:
  "(∑x∈#M. if a = x then (f x) else 0) = (((+) (f a)) ^^ (count M a)) 0"
  by (induction M) (auto simp: ac_simps)

(* Iterated addition of k, starting from 0, is multiplication. *)
lemma iterate_op_plus: "(((+) k) ^^ m) 0 = k * m"
  by (induction m) auto

(* NOTE: "untion" in the name is a typo for "union"; the name is kept because
   Sigma_mset_Un_distrib1 below refers to it. *)
lemma untion_image_mset_Pair_distribute:
  "∑⇩#{#image_mset (Pair x) (C x). x ∈# J - I#} =
   ∑⇩# {#image_mset (Pair x) (C x). x ∈# J#} - ∑⇩#{#image_mset (Pair x) (C x). x ∈# I#}"
  by (auto simp: multiset_eq_iff count_sum_mset count_image_mset_Pair sum_mset_if_eq_constant
    iterate_op_plus diff_mult_distrib2)

(* Sigma_mset distributes over union, intersection and difference in each argument. *)
lemma Sigma_mset_Un_distrib1: "Sigma_mset (I ∪# J) C = Sigma_mset I C ∪# Sigma_mset J C"
  by (auto simp add: Sigma_mset_def union_mset_def untion_image_mset_Pair_distribute)

lemma Sigma_mset_Un_distrib2: "(SIGMAMSET i∈#I. A i ∪# B i) = Sigma_mset I A ∪# Sigma_mset I B"
  by (auto simp: multiset_eq_iff count_sum_mset count_image_mset_Pair sum_mset_if_eq_constant
    Sigma_mset_def diff_mult_distrib2 iterate_op_plus max_def not_in_iff)

lemma Sigma_mset_Int_distrib1: "Sigma_mset (I ∩# J) C = Sigma_mset I C ∩# Sigma_mset J C"
  by (auto simp: multiset_eq_iff count_sum_mset count_image_mset_Pair sum_mset_if_eq_constant
    Sigma_mset_def iterate_op_plus min_def not_in_iff)

lemma Sigma_mset_Int_distrib2: "(SIGMAMSET i∈#I. A i ∩# B i) = Sigma_mset I A ∩# Sigma_mset I B"
  by (auto simp: multiset_eq_iff count_sum_mset count_image_mset_Pair sum_mset_if_eq_constant
    Sigma_mset_def iterate_op_plus min_def not_in_iff)

lemma Sigma_mset_Diff_distrib1: "Sigma_mset (I - J) C = Sigma_mset I C - Sigma_mset J C"
  by (auto simp: multiset_eq_iff count_sum_mset count_image_mset_Pair sum_mset_if_eq_constant
    Sigma_mset_def iterate_op_plus min_def not_in_iff diff_mult_distrib2)

lemma Sigma_mset_Diff_distrib2: "(SIGMAMSET i∈#I. A i - B i) = Sigma_mset I A - Sigma_mset I B"
  by (auto simp: multiset_eq_iff count_sum_mset count_image_mset_Pair sum_mset_if_eq_constant
    Sigma_mset_def iterate_op_plus min_def not_in_iff diff_mult_distrib)

(* Sigma_mset over a multiset sum is the sum of the individual Sigma_msets. *)
lemma Sigma_mset_Union: "Sigma_mset (∑⇩#X) B = (∑⇩# (image_mset (λA. Sigma_mset A B) X))"
  by (auto simp: multiset_eq_iff count_sum_mset count_image_mset_Pair sum_mset_if_eq_constant
    Sigma_mset_def iterate_op_plus min_def not_in_iff sum_mset_distrib_left)

(* Non-dependent instances of the Sigma_mset distributivity lemmas. *)
lemma Times_mset_Un_distrib1: "(A ∪# B) ×# C = A ×# C ∪# B ×# C"
  by (fact Sigma_mset_Un_distrib1)

lemma Times_mset_Int_distrib1: "(A ∩# B) ×# C = A ×# C ∩# B ×# C"
  by (fact Sigma_mset_Int_distrib1)

lemma Times_mset_Diff_distrib1: "(A - B) ×# C = A ×# C - B ×# C"
  by (fact Sigma_mset_Diff_distrib1)

(* A product is empty iff one of its factors is empty. *)
lemma Times_mset_empty[simp]: "A ×# B = {#} ⟷ A = {#} ∨ B = {#}"
  by (auto simp: Sigma_mset_empty_iff)

(* NOTE: in this lemma the element is added to the right factor. *)
lemma Times_insert_left: "A ×# add_mset x B = A ×# B + image_mset (λa. Pair a x) A"
  unfolding add_mset_add_single[of x B] Sigma_mset_plus_distrib2
  by (simp add: Times_mset_single_right)

(* NOTE: in this lemma the element is added to the left factor. *)
lemma Times_insert_right: "add_mset a A ×# B = A ×# B + image_mset (Pair a) B"
  unfolding add_mset_add_single[of a A] Sigma_mset_plus_distrib1
  by (simp add: Times_mset_single_left)

(* Projecting a product onto one component repeats that factor "size of the other factor" times. *)
lemma fst_image_mset_times_mset [simp]:
  "image_mset fst (A ×# B) = (if B = {#} then {#} else repeat_mset (size B) A)"
  by (induct B) (auto simp: Times_mset_single_right ac_simps Times_insert_left)

lemma snd_image_mset_times_mset [simp]:
  "image_mset snd (A ×# B) = (if A = {#} then {#} else repeat_mset (size A) B)"
  by (induct B) (auto simp add: Times_mset_single_right Times_insert_left image_mset_const_eq)

(* Swapping the components of every pair swaps the factors. *)
lemma product_swap_mset: "image_mset prod.swap (A ×# B) = B ×# A"
  by (induction A) (auto simp add: Times_mset_single_left Times_mset_single_right
      Times_insert_right Times_insert_left)

context
begin

(* Named constant for "A ×# B", registered as a code abbreviation. *)
qualified definition product_mset :: "'a multiset ⇒ 'b multiset ⇒ ('a × 'b) multiset" where
  [code_abbrev]: "product_mset A B = A ×# B"

lemma member_product_mset: "x ∈# product_mset A B ⟷ x ∈# A ×# B"
  by (simp add: Multiset_More.product_mset_def)

end

(* count_Sigma_mset stated as an equation between functions on pairs. *)
lemma count_Sigma_mset_abs_def: "count (Sigma_mset A B) = (λ(a, b) ⇒ count A a * count (B a) b)"
  by (auto simp: fun_eq_iff count_Sigma_mset)

(* Mapping one factor equals mapping the corresponding component of every pair. *)
lemma Times_mset_image_mset1: "image_mset f A ×# B = image_mset (λ(a, b). (f a, b)) (A ×# B)"
  by (induct B) (auto simp: Times_insert_left)

lemma Times_mset_image_mset2: "A ×# image_mset f B = image_mset (λ(a, b). (a, f b)) (A ×# B)"
  by (induct A) (auto simp: Times_insert_right)

(* A sum over a subset of a singleton has at most one summand. *)
lemma sum_le_singleton: "A ⊆ {x} ⟹ sum f A = (if x ∈ A then f x else 0)"
  by (auto simp: subset_singleton_iff elim: finite_subset)

(* The product is associative up to re-bracketing of the triples. *)
lemma Times_mset_assoc: "(A ×# B) ×# C = image_mset (λ(a, b, c). ((a, b), c)) (A ×# B ×# C)"
  by (auto simp: multiset_eq_iff count_Sigma_mset count_image_mset vimage_def Times_mset_Times
      Int_commute count_eq_zero_iff intro!: trans[OF _ sym[OF sum_le_singleton[of _ "(_, _, _)"]]]
      cong: sum.cong if_cong)


subsection ‹Transfer Rules›

(* Transfer rule: multiset addition respects rel_mset R. *)
lemma plus_multiset_transfer[transfer_rule]:
  "(rel_fun (rel_mset R) (rel_fun (rel_mset R) (rel_mset R))) (+) (+)"
  by (unfold rel_fun_def rel_mset_def)
    (force dest: list_all2_appendI intro: exI[of _ "_ @ _"] conjI[rotated])

(* Transfer rule: multiset difference respects rel_mset R when R is bi-unique.
   The witness lists are obtained by folding remove1 over the representing lists. *)
lemma minus_multiset_transfer[transfer_rule]:
  assumes [transfer_rule]: "bi_unique R"
  shows "(rel_fun (rel_mset R) (rel_fun (rel_mset R) (rel_mset R))) (-) (-)"
proof (unfold rel_fun_def rel_mset_def, safe)
  fix xs ys xs' ys'
  assume [transfer_rule]: "list_all2 R xs ys" "list_all2 R xs' ys'"
  have "list_all2 R (fold remove1 xs' xs) (fold remove1 ys' ys)"
    by transfer_prover
  moreover have "mset (fold remove1 xs' xs) = mset xs - mset xs'"
    by (induct xs' arbitrary: xs) auto
  moreover have "mset (fold remove1 ys' ys) = mset ys - mset ys'"
    by (induct ys' arbitrary: ys) auto
  ultimately show "∃xs'' ys''.
    mset xs'' = mset xs - mset xs' ∧ mset ys'' = mset ys - mset ys' ∧ list_all2 R xs'' ys''"
    by blast
qed

(* Register the existing fact rel_mset_Zero as a transfer rule. *)
declare rel_mset_Zero[transfer_rule]

(* Transfer rule: count respects rel_mset R when R is bi-unique.
   Proof: unfold the relators and induct over the related representing lists. *)
lemma count_transfer[transfer_rule]:
  assumes "bi_unique R"
  shows "(rel_fun (rel_mset R) (rel_fun R (=))) count count"
unfolding rel_fun_def rel_mset_def proof safe
  fix x y xs ys
  assume "list_all2 R xs ys" "R x y"
  then show "count (mset xs) x = count (mset ys) y"
  proof (induct xs ys rule: list.rel_induct)
    case (Cons x' xs y' ys)
    then show ?case
      using assms unfolding bi_unique_alt_def2 by (auto simp: rel_fun_def)
  qed simp
qed

(* Transfer rule for multiset inclusion: on the left-hand side, inclusion is
   restricted to the elements in the domain of R. *)
lemma subseteq_multiset_transfer[transfer_rule]:
  assumes [transfer_rule]: "bi_unique R" "right_total R"
  shows "(rel_fun (rel_mset R) (rel_fun (rel_mset R) (=)))
    (λM N. filter_mset (Domainp R) M ⊆# filter_mset (Domainp R) N) (⊆#)"
proof -
  (* Comparing counts of the filtered multisets is comparing counts on the domain of R. *)
  have count_filter_mset_less:
    "(∀a. count (filter_mset (Domainp R) M) a ≤ count (filter_mset (Domainp R) N) a) ⟷
     (∀a ∈ {x. Domainp R x}. count M a ≤ count N a)" for M and N by auto
  show ?thesis unfolding subseteq_mset_def count_filter_mset_less
    by transfer_prover
qed

(* Transfer rule for sum_mset, given that R relates 0 with 0 and respects (+). *)
lemma sum_mset_transfer[transfer_rule]:
  "R 0 0 ⟹ rel_fun R (rel_fun R R) (+) (+) ⟹ (rel_fun (rel_mset R) R) sum_mset sum_mset"
  using sum_list_transfer[of R] unfolding rel_fun_def rel_mset_def by auto

(* Transfer rule for Sigma_mset. *)
lemma Sigma_mset_transfer[transfer_rule]:
  "(rel_fun (rel_mset R) (rel_fun (rel_fun R (rel_mset S)) (rel_mset (rel_prod R S))))
     Sigma_mset Sigma_mset"
  by (unfold Sigma_mset_def) transfer_prover


subsection ‹Even More about Multisets›

subsubsection ‹Multisets and Functions›

(* If every element of Ds is in the range of f, then Ds is the image under f
   of some multiset. *)
lemma range_image_mset:
  assumes "set_mset Ds ⊆ range f"
  shows "Ds ∈ range (image_mset f)"
proof -
  have "∀D. D ∈# Ds ⟶ (∃C. f C = D)"
    using assms by blast
  (* Choose a preimage function f_i on the elements of Ds. *)
  then obtain f_i where
    f_p: "∀D. D ∈# Ds ⟶ (f (f_i D) = D)"
    by metis
  define Cs where
    "Cs ≡ image_mset f_i Ds"
  from f_p Cs_def have "image_mset f Cs = Ds"
    by auto
  then show ?thesis
    by blast
qed


subsubsection ‹Multisets and Lists›

lemma length_sorted_list_of_multiset[simp]: "length (sorted_list_of_multiset A) = size A"
  by (metis mset_sorted_list_of_multiset size_mset)

(* Some list (picked with the choice operator SOME) whose multiset is m. *)
definition list_of_mset :: "'a multiset ⇒ 'a list" where
  "list_of_mset m = (SOME l. m = mset l)"

(* Every multiset is the multiset of some list. *)
lemma list_of_mset_exi: "∃l. m = mset l"
  using ex_mset by metis

lemma mset_list_of_mset[simp]: "mset (list_of_mset m) = m"
  by (metis (mono_tags, lifting) ex_mset list_of_mset_def someI_ex)

lemma length_list_of_mset[simp]: "length (list_of_mset A) = size A"
  unfolding list_of_mset_def by (metis (mono_tags) ex_mset size_mset someI_ex)

(* List version of range_image_mset: Ds is the multiset of "map f Cl" for some list Cl. *)
lemma range_mset_map:
  assumes "set_mset Ds ⊆ range f"
  shows "Ds ∈ range (λCl. mset (map f Cl))"
proof -
  have "Ds ∈ range (image_mset f)"
    by (simp add: assms range_image_mset)
  then obtain Cs where Cs_p: "image_mset f Cs = Ds"
    by auto
  (* Turn the preimage multiset into a list. *)
  define Cl where "Cl = list_of_mset Cs"
  then have "mset Cl = Cs"
    by auto
  then have "image_mset f (mset Cl) = Ds"
    using Cs_p by auto
  then have "mset (map f Cl) = Ds"
    by auto
  then show ?thesis
    by auto
qed

lemma list_of_mset_empty[iff]: "list_of_mset m = [] ⟷ m = {#}"
  by (metis (mono_tags, lifting) ex_mset list_of_mset_def mset_zero_iff_right someI_ex)

(* Membership in "mset xs" expressed through list indexing. *)
lemma in_mset_conv_nth: "(x ∈# mset xs) = (∃i<length xs. xs ! i = x)"
  by (auto simp: in_set_conv_nth)

(* An element of a summand is an element of the list sum. *)
lemma in_mset_sum_list:
  assumes "L ∈# LL"
  assumes "LL ∈ set Ci"
  shows "L ∈# sum_list Ci"
  using assms by (induction Ci) auto

(* Conversely, an element of the list sum comes from some summand. *)
lemma in_mset_sum_list2:
  assumes "L ∈# sum_list Ci"
  obtains LL where
    "LL ∈ set Ci"
    "L ∈# LL"
  using assms by (induction Ci) auto

(* TODO: Make [simp]. *)
lemma in_mset_sum_list_iff: "a ∈# sum_list 𝒜 ⟷ (∃A ∈ set 𝒜. a ∈# A)"
  by (metis in_mset_sum_list in_mset_sum_list2)

(* Pointwise inclusion of two equally long lists of multisets lifts to their sums.
   Proof: induction on the common length n, splitting off head and tail. *)
lemma subseteq_list_Union_mset:
  assumes "length Ci = n"
  assumes "length CAi = n"
  assumes "∀i<n.  Ci ! i ⊆# CAi ! i "
  shows "∑⇩# (mset Ci) ⊆# ∑⇩# (mset CAi)"
  using assms proof (induction n arbitrary: Ci CAi)
  case 0
  then show ?case by auto
next
  case (Suc n)
  from Suc have "∀i<n. tl Ci ! i ⊆# tl CAi ! i"
    by (simp add: nth_tl)
  hence "∑⇩#(mset (tl Ci)) ⊆# ∑⇩#(mset (tl CAi))" using Suc by auto
  moreover
  have "hd Ci ⊆# hd CAi" using Suc
    by (metis hd_conv_nth length_greater_0_conv zero_less_Suc)
  ultimately
  show "∑⇩#(mset Ci) ⊆# ∑⇩#(mset CAi)"
    using Suc by (cases Ci; cases CAi) (auto intro: subset_mset.add_mono)
qed

(* Lists with equal multisets are either both distinct or both not distinct. *)
lemma same_mset_distinct_iff:
  ‹mset M = mset M' ⟹ distinct M ⟷ distinct M'›
  by (fact mset_eq_imp_distinct_iff)


subsubsection ‹More on Multisets and Functions›

(* A sub-multiset of the same size is the whole multiset. *)
lemma subseteq_mset_size_eql: "X ⊆# Y ⟹ size Y = size X ⟹ X = Y"
  using mset_subset_size subset_mset_def by fastforce

(* If the image of C' is the multiset of a list lC, then C' can be listed in an
   order whose image under η is exactly lC. *)
lemma image_mset_of_subset_list:
  assumes "image_mset η C' = mset lC"
  shows "∃qC'. map η qC' = lC ∧ mset qC' = C'"
  using assms apply (induction lC arbitrary: C')
  subgoal by simp
  subgoal by (fastforce dest!: msed_map_invR intro: exI[of _ ‹_ # _›])
  done

(* Every sub-multiset of an image is the image of a sub-multiset. *)
lemma image_mset_of_subset:
  assumes "A ⊆# image_mset η C'"
  shows "∃A'. image_mset η A' = A ∧ A' ⊆# C'"
proof -
  define C where "C = image_mset η C'"

  (* List C with the elements of A first, followed by the remaining elements. *)
  define lA where "lA = list_of_mset A"
  define lD where "lD = list_of_mset (C-A)"
  define lC where "lC = lA @ lD"

  have "mset lC = C"
    using C_def assms unfolding lD_def lC_def lA_def by auto
  then have "∃qC'. map η qC' = lC ∧ mset qC' = C'"
    using assms image_mset_of_subset_list unfolding C_def by metis
  then obtain qC' where qC'_p: "map η qC' = lC ∧ mset qC' = C'"
    by auto
  (* The prefix of the preimage list of the same length as lA maps onto lA. *)
  let ?lA' = "take (length lA) qC'"
  have m: "map η ?lA' = lA"
    using qC'_p lC_def
    by (metis append_eq_conv_conj take_map)
  let ?A' = "mset ?lA'"

  have "image_mset η ?A' = A"
    using m using lA_def
    by (metis (full_types) ex_mset list_of_mset_def mset_map someI_ex)
  moreover have "?A' ⊆# C'"
    using qC'_p unfolding lA_def
    using mset_take_subseteq by blast
  ultimately show ?thesis by blast
qed

(* If all elements are equal to y, the element set has at most one member. *)
lemma all_the_same: "∀x ∈# X. x = y ⟹ card (set_mset X) ≤ Suc 0"
  by (metis card.empty card.insert card_mono finite.intros(1) finite_insert le_SucI singletonI subsetI)

(* A member of a multiset of multisets is included in their sum. *)
lemma Melem_subseteq_Union_mset[simp]:
  assumes "x ∈# T"
  shows "x ⊆# ∑⇩#T"
  using assms sum_mset.remove by force

(* List version, using sum_list. *)
lemma Melem_subset_eq_sum_list[simp]:
  assumes "x ∈# mset T"
  shows "x ⊆# sum_list T"
  using assms by (metis mset_subset_eq_add_left sum_mset.remove sum_mset_sum_list)

(* Index-based variants of the two lemmas above. *)
lemma less_subset_eq_Union_mset[simp]:
  assumes "i < length CAi"
  shows "CAi ! i ⊆# ∑⇩#(mset CAi)"
proof -
  from assms have "CAi ! i ∈# mset CAi"
    by auto
  then show ?thesis
    by auto
qed

lemma less_subset_eq_sum_list[simp]:
  assumes "i < length CAi"
  shows "CAi ! i ⊆# sum_list CAi"
proof -
  from assms have "CAi ! i ∈# mset CAi"
    by auto
  then show ?thesis
    by auto
qed


subsubsection ‹More on Multiset Order›

(* A two-element multiset is smaller in the multiset order than another one if
   each of its elements is below some element of the other.
   Proof: unfold the characterization less_multisetDM and use the whole
   right-hand side as the removed part and the whole left-hand side as the added part. *)
lemma less_multiset_doubletons:
  assumes
    "y < t ∨ y < s"
    "x < t ∨ x < s"
  shows
    "{#y, x#} < {#t, s#}"
  unfolding less_multisetDM
proof (intro exI)
  let ?X = "{#t, s#}"
  let ?Y = "{#y, x#}"
  show "?X ≠ {#} ∧ ?X ⊆# {#t, s#} ∧ {#y, x#} = {#t, s#} - ?X + ?Y
    ∧ (∀k. k ∈# ?Y ⟶ (∃a. a ∈# ?X ∧ k < a))"
    using add_eq_conv_diff assms by auto
qed

end