(* Theory Check_Monad *)

(* Title:     Check_Monad
   Author:    Christian Sternagel
   Author:    René Thiemann
*)

section ‹A Special Error Monad for Certification with Informative Error Messages›

theory Check_Monad
imports Error_Monad
begin

text ‹A check is either successful or fails with some error.›
(* The left summand carries the error; the right summand (of type unit) marks success. *)
type_synonym 'e check = "'e + unit"

text ‹The trivially successful check.›
abbreviation succeed :: "'e check"
where
  "succeed ≡ return ()"

text ‹check b e succeeds if b holds and otherwise fails with the error e.›
definition check :: "bool ⇒ 'e ⇒ 'e check"
where
  "check b e = (if b then succeed else error e)"

lemma isOK_check [simp]:
  "isOK (check b e) = b" by (simp add: check_def)

text ‹Handling the error of a check with f succeeds iff the check passes or the handler
  succeeds on the reported error e.›
lemma isOK_check_catch [simp]:
  "isOK (try check b e catch f) ⟷ b ∨ isOK (f e)"
  by (auto simp add: catch_def check_def)

text ‹Run the check chk and, if it passes, return res; otherwise the error of chk is passed on.›
definition check_return :: "'a check ⇒ 'b ⇒ 'a + 'b"
where
  "check_return chk res = (chk ⪢ return res)"

lemma check_return [simp]:
  "check_return chk res = return res' ⟷ isOK chk ∧ res' = res"
  unfolding check_return_def by (cases chk) auto

text ‹For code generation, check_return is unfolded into a direct case distinction on the check.›
lemma [code_unfold]:
  "check_return chk res = (case chk of Inr _ ⇒ Inr res | Inl e ⇒ Inl e)"
  unfolding check_return_def bind_def ..

text ‹Run a check on every element of a list. The error produced by forallM is mapped through
  snd, so only its component of type 'e is reported (presumably the error of the failing
  element; see forallM in Error_Monad).›
abbreviation check_allm :: "('a ⇒ 'e check) ⇒ 'a list ⇒ 'e check"
where
  "check_allm f xs ≡ forallM f xs <+? snd"

text ‹Existential counterpart: the error of existsM (a list of errors, judging by the type of
  fld) is combined into a single error by fld.›
abbreviation check_exm :: "('a ⇒ 'e check) ⇒ 'a list ⇒ ('e list ⇒ 'e) ⇒ 'e check"
where
  "check_exm f xs fld ≡ existsM f xs <+? fld"

lemma isOK_check_allm:
  "isOK (check_allm f xs) ⟷ (∀x ∈ set xs. isOK (f x))"
  by simp

text ‹Indexed variant of check_allm: the check additionally receives a natural number,
  the position of the element (cf. isOK_check_all_index below).›
abbreviation check_allm_index :: "('a ⇒ nat ⇒ 'e check) ⇒ 'a list ⇒ 'e check"
where
  "check_allm_index f xs ≡ forallM_index f xs <+? snd"

text ‹Boolean variant of check_allm: an element violating f is itself reported as the error.›
abbreviation check_all :: "('a ⇒ bool) ⇒ 'a list ⇒ 'a check"
where
  "check_all f xs ≡ check_allm (λx. if f x then succeed else error x) xs"

text ‹Indexed boolean variant: a violating element is reported together with its index.›
abbreviation check_all_index :: "('a ⇒ nat ⇒ bool) ⇒ 'a list ⇒ ('a × nat) check"
where
  "check_all_index f xs ≡ check_allm_index (λx i. if f x i then succeed else error (x, i)) xs"

lemma isOK_check_all_index [simp]:
  "isOK (check_all_index f xs) ⟷ (∀i < length xs. f (xs ! i) i)"
  by auto

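text ‹The following version allows the index to be modified during the check: after an element has been checked, the index for the next element is computed from that element and the current index.›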
text ‹Starting with index n, the elements are visited from left to right: element x is checked
  by f x i under the current index i, and g x i becomes the index for the next element.
  The individual checks are sequenced with ⪢.›
definition
  check_allm_gen_index ::
    "('a ⇒ nat ⇒ nat) ⇒ ('a ⇒ nat ⇒ 'e check) ⇒ nat ⇒ 'a list ⇒ 'e check"
where
  "check_allm_gen_index g f n xs = snd (foldl (λ(i, m) x. (g x i, m ⪢ f x i)) (n, succeed) xs)"

text ‹Once the accumulated check has failed with e, the remaining fold steps keep that error.›
lemma foldl_error:
  "snd (foldl (λ(i, m) x. (g x i, m ⪢ f x i)) (n, error e) xs) = error e"
  by (induct xs arbitrary: n) auto

text ‹If the generalized indexed check succeeds, every element passed its check for some index.›
lemma isOK_check_allm_gen_index [simp]:
  assumes "isOK (check_allm_gen_index g f n xs)"
  shows "∀x∈set xs. ∃i. isOK (f x i)"
using assms
proof (induct xs arbitrary: n)
  case (Cons x xs)
  show ?case
  proof (cases "isOK (f x n)")
    case True
    then have "∃i. isOK (f x i)" by auto
    with True Cons show ?thesis
      unfolding check_allm_gen_index_def by (force simp: isOK_iff)
  next
    case False
    ― ‹The head check fails, so by foldl_error the whole fold fails, contradicting the assumption.›
    then obtain e where "f x n = error e" by (cases "f x n") auto
    with foldl_error [of g f _ e] and Cons show ?thesis
      unfolding check_allm_gen_index_def by auto
  qed
qed simp

text ‹Congruence rule for function definitions: check_allm_gen_index only depends on the
  values of g and f at elements of xs.›
lemma check_allm_gen_index [fundef_cong]:
  fixes f :: "'a ⇒ nat ⇒ 'e check"
  assumes "⋀x n. x ∈ set xs ⟹ g x n = g' x n"
    and "⋀x n. x ∈ set xs ⟹ f x n = f' x n"
  shows "check_allm_gen_index g f n xs = check_allm_gen_index g' f' n xs"
proof -
  ― ‹Generalize over the start index and the accumulated check so the induction goes through.›
  { fix n m
    have "foldl (λ(i, m) x. (g x i, m ⪢ f x i)) (n, m) xs =
      foldl (λ(i, m) x. (g' x i, m ⪢ f' x i)) (n, m) xs"
      using assms by (induct xs arbitrary: n m) auto }
  then show ?thesis unfolding check_allm_gen_index_def by simp
qed

text ‹Check that every element of xs occurs in ys; an element of xs missing from ys is the error.›
definition check_subseteq :: "'a list ⇒ 'a list ⇒ 'a check"
where
  "check_subseteq xs ys = check_all (λx. x ∈ set ys) xs"

lemma isOK_check_subseteq [simp]:
  "isOK (check_subseteq xs ys) ⟷ set xs ⊆ set ys"
  by (auto simp: check_subseteq_def)

text ‹Check set equality of two lists via inclusion in both directions (xs in ys first).›
definition check_same_set :: "'a list ⇒ 'a list ⇒ 'a check"
where
  "check_same_set xs ys = (check_subseteq xs ys ⪢ check_subseteq ys xs)"

lemma isOK_check_same_set [simp]:
  "isOK (check_same_set xs ys) ⟷ set xs = set ys"
  unfolding check_same_set_def by auto

text ‹Check that no element of xs occurs in ys; an element of xs that does occur in ys is the error.›
definition check_disjoint :: "'a list ⇒ 'a list ⇒ 'a check"
where
  "check_disjoint xs ys = check_all (λx. x ∉ set ys) xs"

lemma isOK_check_disjoint [simp]:
  "isOK (check_disjoint xs ys) ⟷ set xs ∩ set ys = {}"
  unfolding check_disjoint_def by (auto)

text ‹Run c x y for all elements x and y of xs, including the pairs with x = y.›
definition check_all_combinations :: "('a ⇒ 'a ⇒ 'b check) ⇒ 'a list ⇒ 'b check"
where
  "check_all_combinations c xs = check_allm (λx. check_allm (c x) xs) xs"

lemma isOK_check_all_combinations [simp]:
  "isOK (check_all_combinations c xs) ⟷ (∀x ∈ set xs. ∀y ∈ set xs. isOK (c x y))"
  unfolding check_all_combinations_def by simp

text ‹Run c on each element paired with every later element of the list
  (see isOK_check_pairwise for the index-based characterization).›
fun check_pairwise :: "('a ⇒ 'a ⇒ 'b check) ⇒ 'a list ⇒ 'b check"
where
  "check_pairwise c [] = succeed" |
  "check_pairwise c (x # xs) = (check_allm (c x) xs ⪢ check_pairwise c xs)"

text ‹Auxiliary lemma: the pairwise property of x # xs splits into the pairs whose first
  component is x and the pairs lying entirely within xs.›
lemma pairwise_aux:
  "(∀j<length (x # xs). ∀i<j. P ((x # xs) ! i) ((x # xs) ! j))
     = ((∀j<length xs. P x (xs ! j)) ∧ (∀j<length xs. ∀i<j. P (xs ! i) (xs ! j)))"
  (is "?C = (?A ∧ ?B)")
proof (intro iffI conjI)
  assume *: "?A ∧ ?B"
  show "?C"
  proof (intro allI impI)
    fix i j
    assume "j < length (x # xs)" and "i < j"
    then show "P ((x # xs) ! i) ((x # xs) ! j)"
    proof (induct j)
      case (Suc j)
      then show ?case
        using * by (induct i) simp_all
    qed simp
  qed
qed force+

text ‹check_pairwise succeeds iff c succeeds on every pair of positions i < j.›
lemma isOK_check_pairwise [simp]:
  "isOK (check_pairwise c xs) ⟷ (∀j<length xs. ∀i<j. isOK (c (xs ! i) (xs ! j)))"
proof (induct xs)
  case (Cons x xs)
  have "isOK (check_allm (c x) xs) = (∀j<length xs. isOK (c x (xs ! j)))"
    using all_set_conv_all_nth [of xs "λy. isOK (c x y)"] by simp
  then have "isOK (check_pairwise c (x # xs)) =
    ((∀j<length xs. isOK (c x (xs ! j))) ∧ (∀j<length xs. ∀i<j. isOK (c (xs ! i) (xs ! j))))"
    by (simp add: Cons)
  then show ?case using pairwise_aux [of x xs "λx y. isOK (c x y)"] by simp
qed auto

text ‹Boolean variant of check_exm: each violating element is reported as a singleton list,
  and these lists are combined with concat.›
abbreviation check_exists :: "('a ⇒ bool) ⇒ 'a list ⇒ ('a list) check"
where
  "check_exists f xs ≡ check_exm (λx. if f x then succeed else error [x]) xs concat"

text ‹choice (from Error_Monad) succeeds iff at least one of the given alternatives succeeds.›
lemma isOK_choice [simp]:
  "isOK (choice []) ⟷ False"
  "isOK (choice (x # xs)) ⟷ isOK x ∨ isOK (choice xs)"
  by (auto simp: choice.simps isOK_def split: sum.splits)

text ‹Return the first check if it succeeded (Inr); otherwise fall back to the second one.›
fun or_ok :: "'a check ⇒ 'a check ⇒ 'a check" where
  "or_ok (Inl a) b = b" |
  "or_ok (Inr a) b = Inr a"

text ‹⟷ is needed here instead of =: = binds tighter than ∨, so
  "isOK (or_ok a b) = isOK a ∨ isOK b" would only state (isOK (or_ok a b) = isOK a) ∨ isOK b.›
lemma or_is_or: "isOK (or_ok a b) ⟷ isOK a ∨ isOK b"
  by (cases a) (auto simp: isOK_def)


end