(* Theory Partial_Inj *)

section ‹ Partial Injections ›

theory Partial_Inj
  imports Partial_Fun
begin

text ‹ A partial injection is a partial function that is injective, i.e. that satisfies
  @{const pfun_inj}. The type is non-empty because the empty partial function is injective
  (lemma ‹pfun_inj_empty›). ›

typedef ('a, 'b) pinj = "{f :: ('a, 'b) pfun. pfun_inj f}"
  morphisms pfun_of_pinj pinj_of_pfun
  by (auto intro: pfun_inj_empty)

text ‹ Two partial injections are equal exactly when their underlying partial functions are. ›

lemma pinj_eq_pfun: "f = g ⟷ pfun_of_pinj f = pfun_of_pinj g"
  by (simp add: pfun_of_pinj_inject)

text ‹ The partial function underlying a partial injection is always injective. ›

lemma pfun_inj_pinj [simp]: "pfun_inj (pfun_of_pinj f)"
  using pfun_of_pinj by auto

type_notation pinj (infixr "⤔" 1)

setup_lifting type_definition_pinj

text ‹ The inverse of a partial injection, lifted from @{const pfun_inv}. The lifting
  obligation (the inverse of an injective partial function is again injective) is
  discharged with ‹pfun_inj_inv›. ›

lift_definition pinv :: "'a ⤔ 'b ⇒ 'b ⤔ 'a" is pfun_inv
  by (simp add: pfun_inj_inv)

unbundle lattice_syntax

text ‹ The empty partial injection is the bottom element. It is lifted from the bottom
  element ‹⊥› of partial functions and written ‹{}⇩ρ›. ›

instantiation pinj :: (type, type) bot
begin
  lift_definition bot_pinj :: "('a, 'b) pinj" is "⊥"
    by simp
instance ..
end

abbreviation pinj_empty :: "('a, 'b) pinj" ("{}⇩ρ") where "{}⇩ρ ≡ ⊥"

text ‹ Application of a partial injection to an argument, written ‹f(x)⇩ρ› and lifted
  from @{const pfun_app}. ›

lift_definition pinj_app :: "('a, 'b) pinj ⇒ 'a ⇒ 'b" ("_'(_')⇩ρ" [999,0] 999)
is "pfun_app" .

text ‹ Adding a maplet to a partial injection requires that we remove any other maplet that points
  to the value @{term v}, to preserve injectivity. ›

(* First range-restrict f to everything except v, then add the maplet k ↦ v. The
   obligation that the result is injective uses pfun_inj_rres and pfun_inj_upd. *)
lift_definition pinj_upd :: "('a, 'b) pinj ⇒ 'a ⇒ 'b ⇒ ('a, 'b) pinj"
is "λ f k v. pfun_upd (f ▹⇩p (- {v})) k v"
  by (simp add: pfun_inj_rres pfun_inj_upd)

text ‹ Domain and range of a partial injection, lifted from @{const pdom} and @{const pran}. ›

lift_definition pidom :: "'a ⤔ 'b ⇒ 'a set" is pdom .

lift_definition piran :: "'a ⤔ 'b ⇒ 'b set" is pran .

text ‹ Domain restriction ‹A ◃⇩ρ f›, range restriction ‹f ▹⇩ρ B› and composition of
  partial injections. Each is lifted from the corresponding partial-function operation,
  and each preserves injectivity (‹pfun_inj_dres›, ‹pfun_inj_rres›, ‹pfun_inj_comp›). ›

lift_definition pinj_dres :: "'a set ⇒ ('a, 'b) pinj ⇒ ('a, 'b) pinj" (infixr "◃⇩ρ" 85) is pdom_res
  by (simp add: pfun_inj_dres)

lift_definition pinj_rres :: "('a, 'b) pinj ⇒ 'b set ⇒ ('a, 'b) pinj" (infixl "▹⇩ρ" 86) is pran_res
  by (simp add: pfun_inj_rres)

lift_definition pinj_comp :: "'b ⤔ 'c ⇒ 'a ⤔ 'b ⇒ 'a ⤔ 'c" (infixl "∘⇩ρ" 55) is "(∘⇩p)"
  by (simp add: pfun_inj_comp)

text ‹ Maplet notation for partial injections. ‹f(x ↦ y)⇩ρ› stands for ‹pinj_upd f x y›.
  Several comma-separated maplets are applied from left to right. ‹{x ↦ y, …}⇩ρ› starts
  from the empty partial injection ‹{}⇩ρ›. ›

syntax
  "_PinjUpd"  :: "[('a, 'b) pinj, maplets] => ('a, 'b) pinj" ("_'(_')⇩ρ" [900,0]900)
  "_Pinj"     :: "maplets => ('a, 'b) pinj"            ("(1{_}⇩ρ)")

(* The enumeration syntax must start from the empty partial injection (pinj_empty), not
   from the empty partial function (pempty). pinj_upd expects a pinj, so starting from
   pempty made every {x ↦ y}⇩ρ term ill-typed and such terms never printed back in this form. *)
translations
  "_PinjUpd m (_Maplets xy ms)"  == "_PinjUpd (_PinjUpd m xy) ms"
  "_PinjUpd m (_maplet  x y)"    == "CONST pinj_upd m x y"
  "_Pinj ms"                     => "_PinjUpd (CONST pinj_empty) ms"
  "_Pinj (_Maplets ms1 ms2)"     <= "_PinjUpd (_Pinj ms1) ms2"
  "_Pinj ms"                     <= "_PinjUpd (CONST pinj_empty) ms"

text ‹ Applying an updated partial injection. At the updated key the new value is returned.
  Elsewhere, the old partial injection is applied after the maplets onto ‹v› are removed. ›

lemma pinj_app_upd [simp]: "(f(k ↦ v)⇩ρ)(x)⇩ρ = (if (k = x) then v else (f ▹⇩ρ (-{v})) (x)⇩ρ)"
  by (transfer, simp)

text ‹ Extensionality: two partial injections are equal iff they have the same domain and
  agree on every point of it. ›

lemma pinj_eq_iff: "f = g ⟷ (pidom(f) = pidom(g) ∧ (∀ x∈pidom(f). f(x)⇩ρ = g(x)⇩ρ))"
  by (transfer, simp add: pfun_eq_iff)

text ‹ Laws for the inverse of a partial injection. ›

lemma pinv_pempty [simp]: "pinv {}⇩ρ = {}⇩ρ"
  by (transfer, simp)

text ‹ Inverting an update. Maplets from ‹x› are removed before inverting, and the inverse
  maplet ‹y ↦ x› is then added. ›

lemma pinv_pinj_upd [simp]: "pinv (f(x ↦ y)⇩ρ) = (pinv ((-{x}) ◃⇩ρ f))(y ↦ x)⇩ρ"
  by (transfer, subst pfun_inv_upd, simp_all add: pfun_inj_dres pfun_inj_rres  pfun_inv_rres pdres_rres_commute, simp add: pfun_inv_dres)

text ‹ The inverse is an involution. ›

lemma pinv_pinv: "pinv (pinv f) = f"
  by (transfer, simp add: pfun_inj_inv_inv)

text ‹ The inverse of a composition is the reversed composition of the inverses. ›

lemma pinv_pcomp: "pinv (pinj_comp f g) = pinj_comp (pinv g) (pinv f)"
  by (transfer, simp add: pfun_eq_graph pfun_graph_pfun_inv pfun_graph_comp pfun_inj_comp converse_relcomp)

text ‹ Simplification laws transferred from the corresponding partial-function lemmas
  (‹pdom_zero›, ‹pdom_res_zero›, ‹pran_res_zero›, ‹pdom_res_empty›, ‹pran_res_empty›),
  together with the range of the empty partial injection. ›

lemmas pidom_empty [simp] = pdom_zero[Transfer.transferred]
lemma piran_zero [simp]: "piran {}⇩ρ = {}" by (transfer, simp)

lemmas pinj_dres_empty [simp] = pdom_res_zero[Transfer.transferred]
lemmas pinj_rres_empty [simp] = pran_res_zero[Transfer.transferred]

lemmas pidom_res_empty [simp] = pdom_res_empty[Transfer.transferred]
lemmas piran_res_empty [simp] = pran_res_empty[Transfer.transferred]

text ‹ Domain restriction of an update. If the key is kept, restrict first and then
  update. Otherwise the update only contributes the removal of maplets onto ‹v›. ›

lemma pidom_res_upd: "A ◃⇩ρ f(k ↦ v)⇩ρ = (if k ∈ A then (A ◃⇩ρ f)(k ↦ v)⇩ρ else A ◃⇩ρ (f ▹⇩ρ (- {v})))"
  by (transfer, simp, metis pdom_res_swap)

text ‹ Range restriction of an update. If the value is kept, restrict first and then
  update. Otherwise the update only contributes the removal of the key ‹x›. ›

lemma piran_res_upd: "f(x ↦ v)⇩ρ ▹⇩ρ A = (if v ∈ A then (f ▹⇩ρ A)(x ↦ v)⇩ρ else ((- {x}) ◃⇩ρ f) ▹⇩ρ A)"
  by (transfer, simp add: inf.commute)
     (metis (no_types, opaque_lifting) ComplI Compl_Un double_compl insert_absorb insert_is_Un pdom_res_swap pran_res_twice)

text ‹ Removing ‹x› from the domain and ‹y› from the range beforehand does not change the
  result of the update ‹x ↦ y›. ›

lemma pinj_upd_with_dres_rres: "((-{x}) ◃⇩ρ f ▹⇩ρ (-{y}))(x ↦ y)⇩ρ = f(x ↦ y)⇩ρ"
  by (transfer, simp add: pdom_res_swap)

text ‹ Restricting twice intersects the restriction sets. Consequently, restrictions of the
  same kind commute, and domain and range restriction commute with each other. ›

lemma pidres_twice: "A ◃⇩ρ B ◃⇩ρ f = (A ∩ B) ◃⇩ρ f"
  by (transfer, metis pdom_res_twice)

lemma pidres_commute: "A ◃⇩ρ B ◃⇩ρ f = B ◃⇩ρ A ◃⇩ρ f"
  by (metis (no_types, opaque_lifting) inf_commute pidres_twice)

lemma pidres_rres_commute: "A ◃⇩ρ (P ▹⇩ρ B) = (A ◃⇩ρ P) ▹⇩ρ B"
  by (transfer, simp, metis (mono_tags, opaque_lifting) pdres_rres_commute)

lemma pirres_twice: "f ▹⇩ρ A ▹⇩ρ B = f ▹⇩ρ (A ∩ B)"
  by (transfer, metis (no_types, opaque_lifting) pran_res_twice)

lemma pirres_commute: "f ▹⇩ρ A ▹⇩ρ B = f ▹⇩ρ B ▹⇩ρ A"
  by (metis inf_commute pirres_twice)

text ‹ The domain after an update is the new key together with the domain of ‹f›, once
  the maplets onto ‹v› have been removed from ‹f›. ›

lemma pidom_upd: "pidom (f(k ↦ v)⇩ρ) = insert k (pidom (f ▹⇩ρ (- {v})))"
  by (transfer, simp)

(* FIXME: Properly integrate using a proof strategy for coercion to partial injections *)

text ‹ On the range of ‹f›, applying ‹f› after its inverse gives back the argument. ›

lemma f_pinv_f_apply: "x ∈ pran (pfun_of_pinj f) ⟹ (pfun_of_pinj f)(pfun_of_pinj (pinv f) (x)⇩p)⇩p = x"
  by (transfer, simp add: f_pfun_inv_f_apply)

text ‹ Build a partial injection from an association list. Earlier pairs take precedence,
  because the head pair is added last by maplet update. Each such update also removes any
  existing maplet onto the same value. ›

fun pinj_of_alist :: "('a × 'b) list ⇒ 'a ⤔ 'b" where
"pinj_of_alist [] = {}⇩ρ" |
"pinj_of_alist (p # ps) = (pinj_of_alist ps)(fst p ↦ snd p)⇩ρ"

text ‹ Code equations implementing the empty partial injection and maplet update by the
  association-list representation. ›

lemma pinj_empty_alist [code]: "{}⇩ρ = pinj_of_alist []"
  by simp

lemma pinj_upd_alist [code]: "(pinj_of_alist xs)(k ↦ v)⇩ρ = pinj_of_alist ((k, v) # xs)"
  by simp

context begin

text ‹ Injective association lists: the keys are pairwise distinct, and so are the values. ›

definition ialist :: "('a × 'b) list ⇒ bool" where
"ialist xs = (distinct (map fst xs) ∧ distinct (map snd xs))"

text ‹ Remove every pair whose key or value already appeared in an earlier pair. The result
  is always an injective association list (‹ialist_clearjunk› below). Lists that are
  already injective are left unchanged (‹ialist_clearjunk_fp›). ›

qualified fun clearjunk :: "('a × 'b) list ⇒ ('a × 'b) list" where
"clearjunk [] = []" |
"clearjunk (p#ps) = p # filter (λ (k', v'). k' ≠ fst p ∧ v' ≠ snd p) (clearjunk ps)"

lemma ialist_clearjunk: "ialist (clearjunk xs)"
  by (induct xs rule:clearjunk.induct, auto simp add: ialist_def, (meson distinct_map_filter)+)

lemma ialist_clearjunk_fp: "ialist xs ⟹ clearjunk xs = xs"
  by (induct xs, auto simp add: ialist_def filter_id_conv rev_image_eqI)

lemma clearjunk_idem [simp]: "clearjunk (clearjunk xs) = clearjunk xs"
  using ialist_clearjunk ialist_clearjunk_fp by blast

text ‹ Removing a key (resp. a value) that does not occur in the list has no effect. ›

lemma pinj_of_alist_ndres: "k ∉ fst ` set xs ⟹ (-{k}) ◃⇩ρ (pinj_of_alist xs) = pinj_of_alist xs"
  by (induct xs, auto simp add: pidom_res_upd)

lemma pinj_of_alist_nrres: "v ∉ snd ` set xs ⟹ (pinj_of_alist xs) ▹⇩ρ (- {v}) = pinj_of_alist xs"
  by (induct xs, auto simp add: piran_res_upd)

text ‹ For an injective list, the domain is exactly the set of keys. ›

lemma pidom_ialist: "ialist xs ⟹ pidom (pinj_of_alist xs) = set (map fst xs)"
  by (induct xs, auto simp add: ialist_def pidom_upd)
     (metis (no_types, lifting) fst_conv image_eqI pinj_of_alist_nrres)+

text ‹ For an injective list, filtering out the pairs that clash with ‹p› is the same as
  removing the key ‹fst p› from the domain and the value ‹snd p› from the range. ›

lemma pinj_of_alist_filter_as_dres_rres:
  "ialist xs ⟹ pinj_of_alist (filter (λ(k', v'). k' ≠ fst p ∧ v' ≠ snd p) xs) = (-{fst p}) ◃⇩ρ pinj_of_alist xs ▹⇩ρ (-{snd p})"
  by (induct xs rule: pinj_of_alist.induct)
     (auto simp add: ialist_def piran_res_upd pinj_of_alist_ndres pidom_res_upd
     ,metis (no_types, lifting) pinj_of_alist_nrres pirres_commute)

text ‹ Cleaning a list does not change the partial injection it denotes. ›

lemma pinj_of_alist_clearjunk: "pinj_of_alist (clearjunk xs) = pinj_of_alist xs"
  by (induct xs rule:clearjunk.induct, simp add: pinj_eq_iff)
     (simp add: ialist_clearjunk pinj_of_alist_filter_as_dres_rres pinj_upd_with_dres_rres)

text ‹ For an injective list, the inverse is obtained by swapping every pair. ›

lemma pinv_pinj_of_ialist:
  "ialist xs ⟹ pinv (pinj_of_alist xs) = pinj_of_alist (map (λ (x, y). (y, x)) xs)"
  by (induct xs rule: pinj_of_alist.induct, auto simp add: ialist_def simp add: pinj_of_alist_ndres)

text ‹ For an injective list, the underlying partial function is the one built directly from
  the list by @{const pfun_of_alist}. ›

lemma pfun_of_ialist: "ialist xs ⟹ pfun_of_pinj (pinj_of_alist xs) = pfun_of_alist xs"
  by (induct xs rule: pinj_of_alist.induct, auto simp add: bot_pinj.rep_eq ialist_def pinj_upd.rep_eq )
     (metis pinj_of_alist_nrres pinj_rres.rep_eq)

declare clearjunk.simps [simp del]

end

text ‹ Code equation for the inverse: first clean the list with ‹clearjunk› so that it is
  injective, then swap every pair. ›

lemma pinv_pinj_of_alist [code]: "pinv (pinj_of_alist xs) = pinj_of_alist (map (λ (x, y). (y, x)) (Partial_Inj.clearjunk xs))"
  by (metis ialist_clearjunk pinj_of_alist_clearjunk pinv_pinj_of_ialist)

text ‹ Code equation for the underlying partial function: clean the list with ‹clearjunk›,
  then build the partial function with @{const pfun_of_alist}. ›

lemma pfun_of_pinj_of_alist [code]: 
  "pfun_of_pinj (pinj_of_alist xs) = pfun_of_alist (Partial_Inj.clearjunk xs)"
  by (metis ialist_clearjunk pfun_of_ialist pinj_of_alist_clearjunk)

(* From here on, the defining equations of pinj_of_alist are no longer simp rules. *)
declare pinj_of_alist.simps [simp del]

end