Theory Parser_Combinator

(* Author: Peter Lammich *)
section ‹Parser Combinators›

theory Parser_Combinator
imports
  "HOL-Library.Monad_Syntax"
  "HOL-Library.Char_ord"
  "HOL-Library.Code_Target_Nat"
  "Certification_Monads.Error_Monad"
  Munta_Error_Monad_Add 
  "Show.Show"
  "HOL-Library.Rewrite"
begin

  (**
    Parser Combinators, based on Sternagel and Thiemann's Parser_Monad, with the following additional features/differences:

    * Setup for the function package to handle recursive parsers;
      parsers carry fuel (ll_fuel) to ensure termination.
    * Uses (unit ⇒ shows) for error messages instead of String (lazy computation of messages, more comfortable due to shows)
    * Everything defined over generic token type
    * Some fancy combinators
        a ∥ b    choice, type a = type b
        --[f]    sequential composition, combine results with f
        --       = --[Pair]
        --*      seq, ignore right result
        *--      seq, ignore left result

    TODO/FIXME

    * Currently, the bind and repeat operations check dynamically that no input is generated (bind)
      or that input is consumed (repeat), and fail otherwise.
      At least for bind, we could try to encode this information into the parser's type.
      However, the interplay with the function package is not clear at this point :(
      Possible solution: Fixed-point based recursion combinator and partial_function. We could then do totality proof afterwards.


  *)


subsection ‹Type Definitions›

(* A token list together with a fuel counter (ll_fuel). Termination of
   recursive parsers is justified by a decrease of this fuel; ll_from_list
   starts with the fuel equal to the length of the list. *)
datatype 'a len_list = LL (ll_fuel: nat) (ll_list: "'a list")
definition "ll_from_list l  LL (length l) l"

(* Make ll_fuel available as a measure to the function package's automatic
   termination proofs. *)
lemma [measure_function]: "is_measure ll_fuel" ..

text ‹
  A parser takes a fuel-annotated list of tokens and returns either an error message or
  a result together with the remaining tokens.
›

(* A parser maps its input either to Inl of a lazily computed error message
   (a function from unit to shows), or to Inr of the result paired with the
   remaining input. *)
type_synonym
  ('t, 'a) parser = "'t len_list  (unit  shows) + ('a × 't len_list)"

text ‹A \emph{consuming parser} (cparser for short) consumes at least one token of input.›
(* p is consuming iff every successful run strictly decreases the fuel. *)
definition is_cparser :: "('t,'a) parser  bool"
where
  "is_cparser p  (l l' x. p l = Inr (x,l')  ll_fuel l' < ll_fuel l)"

(* Introduction rule: establish is_cparser from the unfolded statement. *)
lemma is_cparserI:
  assumes "l l' x. p l = Inr (x, l')  ll_fuel l' < ll_fuel l"
  shows "is_cparser p"
  using assms unfolding is_cparser_def by blast

(* Elimination rule: use the unfolded statement as a premise. *)
lemma is_cparserE:
  assumes "is_cparser p"
    and "(l l' x. p l = Inr (x, l')  ll_fuel l' < ll_fuel l)  P"
  shows "P"
  using assms by (auto simp: is_cparser_def)

(* A successful run of a consuming parser strictly decreases the fuel.
   Declared simp and dest so that automation can use it directly. *)
lemma is_cparser_length[simp, dest]:
  assumes "p l = Inr (x, l')" and "is_cparser p"
  shows "ll_fuel l' < ll_fuel l"
  using assms by (blast elim: is_cparserE)

text ‹Used by fundef congruence rules›
(* PCONG_INR x ts' is Inr (x, ts') under a separate name. The congruence rules
   (e.g. bind_cong) state their successful-result premise with it, presumably
   so that the dest! rule PCONG_EQ_D below matches exactly those premises. *)
definition "PCONG_INR x ts'  Inr (x,ts')"

(* If a consuming parser succeeds with PCONG_INR x l', the fuel has strictly
   decreased. Declared dest! so that it is applied eagerly. *)
lemma PCONG_EQ_D[dest!]:
  assumes "p l = PCONG_INR x l'"
  assumes "is_cparser p"
  shows "ll_fuel l' < ll_fuel l"
  using assms unfolding PCONG_INR_def by auto

(* Rules used by the is_cparser proof tactics in the ML structure below.
   Extended via the [parser_rules] and [consuming] attributes. *)
named_theorems parser_rules

(* Generic logical introduction rules, tried together with parser_rules. *)
lemmas parser0_rules = disjI1 disjI2 allI impI conjI

(* ML tooling for proving is_cparser facts automatically: configuration
   options, debug tracing, the proof tactics, and the implementation of the
   consuming attribute. *)
ML structure Parser_Combinator = struct

    (* parser_simproc: enables the is_cparser_prover simproc (default true). *)
    val cfg_simproc = Attrib.setup_config_bool @{binding parser_simproc} (K true)

    (* parser_debug: enables proof-state tracing in the tactics below (default false). *)
    val cfg_debug = Attrib.setup_config_bool @{binding parser_debug} (K false)

    (* When parser_debug is set, print the proof state labelled with msg (and,
       for trace_tac', the subgoal number i); otherwise these do nothing. *)
    fun trace_tac ctxt msg = if Config.get ctxt cfg_debug then print_tac ctxt msg else all_tac
    fun trace_tac' ctxt msg i =
      if Config.get ctxt cfg_debug then
        print_tac ctxt (msg ^ " on subgoal " ^ Int.toString i)
      else all_tac

    (* One proof step on a subgoal: resolve with one of the generic rules
       parser0_rules or a registered parser_rules fact. Failing that, try
       asm_simp_tac, but only if it solves the subgoal completely. *)
    fun prove_cparser_step_tac ctxt =
      let
        val p_rls = Named_Theorems.get ctxt @{named_theorems parser_rules}
      in
        trace_tac' ctxt "prove_cparser_step" THEN'
        (
          resolve_tac ctxt (@{thms parser0_rules} @ p_rls)
          ORELSE' SOLVED' (asm_simp_tac ctxt)
        )
      end


    (* Solve all subgoals by depth-first search, repeatedly applying
       prove_cparser_step_tac to the first subgoal. *)
    fun prove_cparser_tac ctxt =
      trace_tac ctxt "prove_cparser" THEN
      DEPTH_SOLVE (FIRSTGOAL (prove_cparser_step_tac ctxt))

    (* Implementation of the consuming attribute. Given a definition theorem
       with left-hand side lhs, prove is_cparser lhs as follows: unfold the
       definition, then search depth-first with prove_cparser_step_tac until
       every remaining subgoal has the form is_cparser v for a free variable v.
       Such leftover subgoals become premises of the resulting rule. The rule
       is then generalized and added to parser_rules in the given context.
       Raises an ERROR if no such proof state is found. *)
    fun add_cparser_def def_thm context = let
      val ctxt = Context.proof_of context
      val orig_ctxt = ctxt

      (* Switch the is_cparser simproc off for this internal proof. *)
      val ctxt = Config.put cfg_simproc false ctxt

      (* Turn schematic variables of the definition into fixed free variables. *)
      val (def_thm, ctxt) = yield_singleton (apfst snd oo Variable.import true) def_thm ctxt

      (* Left-hand side of the definition, normalized to a meta-equality first. *)
      val lhs = def_thm
        |> Local_Defs.meta_rewrite_rule ctxt
        |> (fst o Logic.dest_equals o Thm.prop_of)

      (* The goal statement: Trueprop (is_cparser lhs). *)
      val T = fastype_of lhs
      val cp_stmt =
           Const (@{const_name is_cparser}, T --> HOLogic.boolT)$lhs
        |> HOLogic.mk_Trueprop

      (* An acceptable leftover subgoal: is_cparser applied to a free variable. *)
      fun is_cparser_free (@{const Trueprop} $ (Const (@{const_name is_cparser},_) $ Free _)) = true
        | is_cparser_free _ = false

      (* The search stops at a state whose subgoals are all acceptable,
         in particular at a state with no subgoals left. *)
      fun is_goal_ok st =
        Thm.prop_of st |> Logic.strip_imp_prems
        |> map (Logic.strip_assums_concl)
        |> forall is_cparser_free

      (* Take the first proof state found by the search, if any. *)
      val cp_thm =
        cp_stmt |> Thm.cterm_of ctxt
        |> Goal.init
        |> SINGLE (
            unfold_tac ctxt [def_thm] THEN
            trace_tac ctxt "cparser def proof" THEN
            DEPTH_FIRST is_goal_ok (
              FIRSTGOAL (prove_cparser_step_tac ctxt)
            )
          )

      (* Leftover subgoals remain as premises of the concluded theorem. *)
      val cp_thm = case cp_thm of
        NONE => error "Could not prove any is_cparser theorem: Empty result sequence"
      | SOME thm =>
          Goal.conclude thm


      (* Generalize the fixed variables again and normalize variable indexes. *)
      val cp_thm =
           singleton (Variable.export ctxt orig_ctxt) cp_thm
        |> Drule.zero_var_indexes

      (* Disabled alternative using Goal.prove: *)
      (*
      val cp_thm =
        Goal.prove ctxt [] [] cp_stmt (fn {context, ...} => tac context)
      |> singleton (Variable.export ctxt orig_ctxt)
      *)

      (* Register the new rule in parser_rules. *)
      val context = Named_Theorems.add_thm @{named_theorems parser_rules} cp_thm context
    in
      context
    end
  end


(* Attribute [consuming]: applied to a definition theorem, it proves an
   is_cparser rule for the defined constant and adds it to parser_rules
   (see add_cparser_def). *)
attribute_setup consuming = Scan.succeed (Thm.declaration_attribute (Parser_Combinator.add_cparser_def))

(* Simproc: rewrite is_cparser p to True if prove_cparser_tac can prove it.
   It is only active while parser_simproc is set, and it switches that option
   off for its inner proof. If the proof fails it returns NONE, i.e. it leaves
   the term unchanged. *)
simproc_setup is_cparser_prover ("is_cparser p") = fn _ => fn ctxt => fn ct =>
  if Config.get ctxt Parser_Combinator.cfg_simproc then
    let
      open Parser_Combinator
      val t = Thm.term_of ct
      val stmt = Logic.mk_equals (t,@{term True})

      (* Debug trace of the term and the equation to be proved. *)
      val _ = if Config.get ctxt cfg_debug then
          (Pretty.block [Pretty.str "is_cparser simproc invoked on: ", Syntax.pretty_term ctxt t, Pretty.fbrk, Syntax.pretty_term ctxt stmt]) |> Pretty.string_of |> tracing
        else ()

      (* Prevent this simproc from re-entering during the inner proof. *)
      val ctxt = Config.put Parser_Combinator.cfg_simproc false ctxt

      (* Reduce t == True to t via Eq_TrueI, then try prove_cparser_tac. If
         subgoals remain, Goal.prove fails and try turns that into NONE. *)
      val othm = try (Goal.prove ctxt [] [] stmt) (fn {context=ctxt, ...} =>
        FIRSTGOAL (resolve_tac ctxt @{thms HOL.Eq_TrueI})
        THEN TRY (Parser_Combinator.prove_cparser_tac ctxt)
      )

      val _ =
        if Config.get ctxt cfg_debug andalso is_none othm then
          (Pretty.block [Pretty.str "is_cparser simproc failed on: ", Syntax.pretty_term ctxt t, Pretty.fbrk, Syntax.pretty_term ctxt stmt]) |> Pretty.string_of |> tracing
        else ()

      (*
      val _ = case othm of
        NONE => (Pretty.block [Pretty.str "is_cparser simproc failed on: ", Syntax.pretty_term ctxt t, Pretty.fbrk, Syntax.pretty_term ctxt stmt]) |> Pretty.string_of |> warning
     | SOME _ => ();
      *)
    in
      othm
    end
  else
    NONE

text ‹Wrapping a parser to dynamically assert that it consumes tokens.›
(* Run p and succeed only if the fuel strictly decreased; otherwise fail with
   Dynamic parser check failed. This makes any parser consuming. *)
definition ensure_cparser :: "('t,'a) parser  ('t,'a) parser" where
  "ensure_cparser p  λts. do {
    (x, ts')  p ts;
    if (ll_fuel ts' < ll_fuel ts) then Error_Monad.return (x,ts')
    else Error_Monad.error (λ_. shows ''Dynamic parser check failed'')
  }"

(* Holds for every p, because of the dynamic check. *)
lemma ensure_cparser_cparser[parser_rules]: "is_cparser (ensure_cparser p)"
  apply (rule is_cparserI)
  unfolding ensure_cparser_def
  by (auto simp: Error_Monad.bind_def split: sum.splits if_splits)

(* Congruence rule: ensure_cparser p at l depends only on p at l. *)
lemma ensure_cparser_cong[fundef_cong]:
  assumes "l=l'"
  assumes "p l = p' l'"
  shows "ensure_cparser p l = ensure_cparser p' l'"
  using assms by (auto simp: ensure_cparser_def)


(* p1 ::= p2 states, as a meta-level proposition, that p1 and p2 agree on
   every input l. It is used to write recursive parser equations, e.g. repeat. *)
abbreviation bnf_eq :: "('t,'a) parser  ('t,'a) parser  prop" (infix "::=" 2) where
  "bnf_eq p1 p2  (l. p1 l = p2 l)"



subsection ‹Monad-Setup for Parsers›

(* Succeed with x without consuming any input. *)
definition return :: "'a  ('t, 'a) parser"
where
  "return x = (λts. Error_Monad.return (x, ts))"

(* Fail with message e, regardless of the input. *)
definition error_aux :: "(unit  shows)  ('t, 'a) parser"
where
  "error_aux e = (λ_. Error_Monad.error e)"

(* Shorthands for error_aux with a message built by ERR / ERRS. These come
   from the imported error-monad additions; presumably they take a shows and
   a string message respectively -- TODO confirm. *)
abbreviation error where "error s  error_aux (ERR s)"
abbreviation "error_str s  error_aux (ERRS s)"

(* Post-process the error message of m with f; successful results are passed
   through unchanged. *)
definition update_error :: "('t,'a) parser  (shows  shows)  ('t,'a) parser"
  where "update_error m f l  m l <+? (λe _. f (e ()))"

(* Run p and succeed only if the fuel did not increase; otherwise fail with
   Dynamic parser check failed. *)
definition ensure_parser :: "('t,'a) parser  ('t,'a) parser" where
  "ensure_parser p  λts. do {
    (x, ts')  p ts;
    if (ll_fuel ts'  ll_fuel ts) then Error_Monad.return (x,ts')
    else Error_Monad.error (ERRS ''Dynamic parser check failed'')
  }"

(* Monadic bind. Both m and the continuation run under ensure_parser, so a
   bind never increases the fuel. bind_cong relies on this: its premise for
   the continuation only covers inputs with no more fuel than the start.
   Note that ts is rebound to the remaining input after m. *)
definition bind :: "('t, 'a) parser  ('a  ('t, 'b) parser)  ('t, 'b) parser"
where
  "bind m f  λts. do {
    (x, ts)  ensure_parser m ts;
    ensure_parser (f x) ts
  }"

(* Consume and return the next token, decrementing the fuel. Fails with
   Expecting more input when the fuel is 0 or the token list is empty. *)
definition get :: "('t,'t) parser" where
  "get ll  case ll of LL (Suc n) (x#xs)  Error_Monad.return (x,LL n xs) | _  (Error_Monad.error (λ_. shows_string ''Expecting more input''))"

(* Return the remaining token list without consuming anything. *)
definition get_tokens :: "('t,'t list) parser" where
  "get_tokens  λll. Error_Monad.return (ll_list ll,ll)"

(* Let the generic monad bind syntax (do-notation) and the generic
   Error_Syntax.update_error syntax resolve to the parser versions. *)
adhoc_overloading
      Monad_Syntax.bind  bind
  and Error_Syntax.update_error  update_error

(* Congruence rules for the function package. When a recursive parser
   definition applies a Let, if, or case_prod expression to the input, these
   rules let recursive calls be extracted together with the relevant context
   (the bound value, the branch condition, or the decomposed pair). *)
(* TODO: Specialize to parser type? *)
lemma let_cong' [fundef_cong]:
  "M = N  l=l'  (x. x = N  f x l' = g x l')  Let M f l = Let N g l'"
  unfolding Let_def by blast

lemma if_cong' [fundef_cong]:
  assumes "b = c"
    and "l=l'"
    and "c  x l' = u l'"
    and "¬ c  y l' = v l'"
  shows "(if b then x else y) l = (if c then u else v) l'"
  using assms by simp

lemma split_cong' [fundef_cong]:
  "l=l'  (x y. (x, y) = q  f x y l' = g x y l' )  p = q  case_prod f p l = case_prod g q l'"
  by (auto simp: split_def)

(* Congruence rule for bind. The continuations only have to agree on results
   that m2 produces at ts2 (stated via PCONG_INR), and only for remaining
   inputs with no more fuel than ts2. *)
lemma bind_cong [fundef_cong]:
  fixes m1 :: "('t, 'a) parser"
  assumes "m1 ts2 = m2 ts2"
    and " y ts.  m2 ts2 = PCONG_INR y ts; ll_fuel ts  ll_fuel ts2  f1 y ts = f2 y ts"
    and "ts1 = ts2"
  shows "((m1  f1) ts1) = ((m2  f2) ts2)"
  using assms
  unfolding bind_def PCONG_INR_def
  by (auto simp: Error_Monad.bind_def ensure_parser_def split: sum.split prod.split)

(* A bind consumes input if its first parser does, or if every continuation does. *)
lemma is_cparser_bind[parser_rules]:
  assumes "is_cparser p  (x. is_cparser (q x))"
  shows "is_cparser (p  q)"
  apply (rule is_cparserI)
  using assms unfolding is_cparser_def bind_def Error_Monad.bind_def ensure_parser_def
  by (fastforce split: sum.splits if_splits)

(* return succeeds exactly with its argument and the unchanged input. *)
lemma return_eq[simp]: "return x l = Inr y  y=(x,l)"
  unfolding return_def by auto


(* error_aux never succeeds, so it is trivially consuming. *)
lemma is_cparser_error[parser_rules]: "is_cparser (error_aux e)"
  by (auto simp: error_aux_def intro: is_cparserI)

(* get consumes one unit of fuel. *)
lemma is_cparser_get[parser_rules]:
  "is_cparser get"
  apply (rule is_cparserI)
  apply (auto simp: get_def split: len_list.splits nat.splits list.splits)
  done

(* Monad laws. Because bind wraps its stages in ensure_parser, return is a
   unit of bind only up to ensure_parser. *)
lemma monad_laws[simp]:
  "bind m return = ensure_parser m"
  "bind (return x) f = ensure_parser (f x)"
  "bind (bind m f) g = bind m (λx. bind (f x) g)"
  "ensure_parser (ensure_parser m) = ensure_parser m"
  "bind (ensure_parser m) f = bind m f"
  "bind m (λx. ensure_parser (f x)) = bind m f"
  unfolding bind_def return_def ensure_parser_def Error_Monad.bind_def
  by (auto split: if_splits sum.splits prod.splits intro!: ext)

subsection ‹More Combinators›

(* Fail with a message of the form expecting MSG, but found: TOKS. TOKS is
   the quoted list of at most the first 100 tokens of the current input. *)
definition err_expecting_aux :: "(unit  shows)  ('t::show, 'a) parser"
where
  "err_expecting_aux msg = do { ts  get_tokens; error
    (shows_string ''expecting '' o msg () o shows_string '', but found: '' o shows_quote (shows (take 100 ts)))}"

(* Variants taking the message as a shows value or as a plain string. *)
abbreviation err_expecting :: "shows  ('t::show, 'a) parser"
where
  "err_expecting msg  err_expecting_aux (λ_. msg)"

abbreviation "err_expecting_str msg  err_expecting (shows_string msg)"

(* Succeed with () if the token list is empty; fail otherwise. *)
definition "eoi  do {
  tks  get_tokens; if tks=[] then return () else err_expecting_str ''end of input'' }"


(* Choice: run p1; if it fails, run p2 on the same input l. The result is
   tagged Inl (from p1) or Inr (from p2). If both fail, the two error
   messages are concatenated, the second on a new line after a vertical bar. *)
definition alt :: "(_,'a) parser  (_,'b) parser  (_,'a+'b) parser" where
  "alt p1 p2 l  try   do { (r,l)  p1 l; Error_Monad.return (Inl r, l) }
                 catch (λe1. (try do { (r,l)  p2 l; Error_Monad.return (Inr r, l) }
                 catch (λe2. Error_Monad.error (λ_. e1 () o shows ''⏎  | '' o e2 ()))))"

(* Collapse a sum of two equal types. *)
fun sum_join where
  "sum_join (Inl x) = x" | "sum_join (Inr x) = x"

(* Choice between two parsers with the same result type. *)
abbreviation alt' :: "(_,'a) parser  (_,'a) parser  (_,'a) parser" (infixr "" 53)
  where "alt' p q  alt p q  return o sum_join"

(* p --[f] q runs p, then q, and combines the two results with f. *)
abbreviation gseq :: "('t,'a) parser  ('a  'b  'c)  ('t,'b) parser  ('t,'c) parser" ("_--[_]_" [61,0,60] 60)
  where "gseq p f q  p  (λa. q  (λb. return (f a b)))" (* TODO/FIXME: Do-notation and abbreviation generate additional type vars here *)

(* Sequence, pairing the results. *)
abbreviation seq :: "('t,'a) parser  ('t,'b) parser  ('t,'a×'b) parser" (infixr "--" 60)
  where "seq p q  p --[Pair] q"

(* Sequence, keeping only the result of the right parser. *)
abbreviation seq_ignore_left :: "('t,'a) parser  ('t,'b) parser  ('t,'b) parser" (infixr "*--" 60)
  where "p *-- q  p --[λ_ x. x] q"

(* Sequence, keeping only the result of the left parser. *)
abbreviation seq_ignore_right :: "('t,'a) parser  ('t,'b) parser  ('t,'a) parser" (infixr "--*" 60)
  where "p --* q  p --[λx _. x] q"

(* p with f maps f over the result of p. *)
abbreviation map_result :: "('t,'a) parser  ('a'b)  ('t,'b) parser" (infixr "with" 54)
  where "p with f  p  return o f"

(* Parse exactly the token sequence ts and return it. On a mismatch, the
   alternative is tried on the original input, so the error message contains
   expecting Exactly followed by ts. *)
definition "exactly ts  foldr (λt p. do { xget; if x=t then p  return o (#) x else error id}) ts (return [])
     err_expecting (shows_string ''Exactly '' o shows ts)"

(* Register is_cparser (err_expecting_aux msg) as a parser rule. *)
declare err_expecting_aux_def[consuming]

(* A choice is consuming if both alternatives are. *)
lemma alt_is_cparser[parser_rules]:
  "is_cparser p  is_cparser q  is_cparser (alt p q)"
  apply (rule is_cparserI)
  unfolding alt_def
  by (auto simp: Error_Monad.bind_def split: sum.splits)

(* Congruence rule for alt: p2 only has to agree where p1' fails on l'. *)
lemma alt_cong[fundef_cong]:
  " l=l'; p1 l = p1' l'; e. p1' l' = Inl e  p2 l = p2' l'   alt p1 p2 l = alt p1' p2' l'"
  unfolding alt_def by (auto split: sum.splits simp: Error_Monad.bind_def)

(* exactly is consuming for nonempty ts; for ts = [] it can succeed without
   consuming input. *)
lemma [parser_rules]: "ts[]  is_cparser (exactly ts)"
  by (cases ts) (auto simp: exactly_def intro: parser_rules)


(* optional dflt p runs p; if p fails, it returns dflt without consuming input. *)
abbreviation optional :: "'a  ('t,'a) parser  ('t,'a) parser" where
  "optional dflt p  p  return dflt"

(* p? returns Some of the result of p, or None if p fails. *)
abbreviation maybe :: "('t,'a) parser  ('t,'a option) parser" ("(_?)" [1000] 999)
  where "p?  p with Some  return None"


subsubsection ‹Repeat›

(* Parse p zero or more times and collect the results. Each iteration runs
   ensure_cparser p, so an iteration that fails or does not consume input ends
   the repetition, and the results collected so far are returned. Because of
   this check, the recursion terminates. *)
fun repeat :: "('t,'a) parser  ('t,'a list) parser" where
  "repeat p ::= optional [] (ensure_cparser p --[(#)] repeat p)"

(* One or more repetitions of p. *)
abbreviation "repeat1 p  p --[(#)] repeat p"


(* The defining equation contains repeat p on its right-hand side; keep it out
   of the simpset so that it is not unfolded indefinitely. *)
declare repeat.simps[simp del]

(* Congruence rule: repeat p only depends on p at inputs whose fuel is at most
   that of the starting input. *)
lemma repeat_cong[fundef_cong]:
  assumes "nts.  ll_fuel nts  ll_fuel l'   p (nts) = p' (nts)"
  assumes "l=l'"
  shows "repeat p l = repeat p' l'"
  using assms(1)
  unfolding l=l'
  apply (induction p l' rule: repeat.induct)
  apply (rewrite in "_=" repeat.simps)
  apply (rewrite in "=_" repeat.simps)
  apply (intro alt_cong bind_cong)
  apply (auto simp: ensure_cparser_def PCONG_INR_def)
  done

subsubsection ‹Left and Right Associative Chaining›
text ‹Parse a sequence of ‹A› separated by ‹F›,
  and then fold the sequence with the results of ‹F›,
  either left or right associative.

  Example: Assume we have the input ‹x\<^sub>1 o\<^sub>1 … o\<^bsub>n-1\<^esub> x\<^sub>n›,
    and the results of parsing the ‹x›s with ‹A› and the ‹o›s with ‹F›
    are ‹a\<^sub>i› and ‹+\<^sub>i›.
    Then, ‹chainL1› returns ‹(…((a\<^sub>1 +\<^sub>1 a\<^sub>2) +\<^sub>2 a\<^sub>3) +\<^sub>3 …)›
    and ‹chainR1› returns ‹a\<^sub>1 +\<^sub>1 (a\<^sub>2 +\<^sub>2 (a\<^sub>3 +\<^sub>3 …)…)›.›
(* Left- and right-associative chaining of A separated by F. *)
context
  fixes A :: "('t,'a) parser"
  fixes F :: "('t,'a'a'a) parser"
begin
  (* Parse A, then any number of (F, A) pairs. Each pair yields the function
     λa. f a b; these functions are applied left to right, starting from the
     first A result. *)
  definition chainL1 :: "('t,'a) parser" where
    "chainL1  do {
      x  A;
      xs  repeat (F --[λf b a. f a b] A);
      return (foldl (λa f. f a) x xs)
    }"

  (* Right-associative combination:
     fold_shiftr a ((f,b)#xs) = f a (fold_shiftr b xs). *)
  qualified fun fold_shiftr :: "('a  (('a'a'a)×'a) list  'a)" where
    "fold_shiftr a [] = a"
  | "fold_shiftr a ((f,b)#xs) = f a (fold_shiftr b xs)"

  (* Parse A, then any number of (F, A) pairs, and combine the results
     right-associatively via fold_shiftr. *)
  definition chainR1 :: "('t,'a) parser" where
    "chainR1  do {
      x  A;
      xs  repeat (F -- A);
      return (fold_shiftr x xs)
    } "

end


(* Congruence rules for chainL1 and chainR1. They depend on A and F only at
   inputs whose fuel is at most that of the starting input. *)
lemma chainL1_cong[fundef_cong]:
  assumes "l2. ll_fuel l2  ll_fuel l'  A l2 = A' l2"
  assumes "l2. ll_fuel l2  ll_fuel l'  F l2 = F' l2"
  assumes "l=l'"
  shows "chainL1 A F l = chainL1 A' F' l'"
  unfolding chainL1_def
  apply (intro bind_cong repeat_cong assms order_refl)
  by auto

lemma chainR1_cong[fundef_cong]:
  assumes "l2. ll_fuel l2  ll_fuel l'  A l2 = A' l2"
  assumes "l2. ll_fuel l2  ll_fuel l'  F l2 = F' l2"
  assumes "l=l'"
  shows "chainR1 A F l = chainR1 A' F' l'"
  unfolding chainR1_def
  apply (intro bind_cong repeat_cong assms order_refl)
  by auto




subsection ‹Lexing Utilities›

(* Consume one token and accept it if it satisfies the predicate; otherwise
   fail with an expecting errmsg error.
   NOTE(review): err_expecting runs after get has consumed the token, so the
   but found part of the message starts at the token after the rejected one.
   Confirm that this is intended. The same applies to range and any below. *)
definition tk_with_prop' :: "shows  ('t::show  bool)  ('t,'t) parser" where
  [consuming]: "tk_with_prop' errmsg Φ  do {
    xget;
    if Φ x then return x
    else err_expecting errmsg
  }"

(* tk_with_prop' with the identity as message. *)
abbreviation tk_with_prop :: "('t::show  bool)  ('t,'t) parser" where
  "tk_with_prop  tk_with_prop' id"

(* Consume one token x and accept it if it lies between a and b, inclusive. *)
definition range :: "'t::{linorder,show}  't  ('t,'t) parser" where
  [consuming]: "range a b  do {
    xget;
    if ax  xb then return x
    else err_expecting (shows_string ''Token in range '' o shows a o shows_string '' - '' o shows b) }"

(* Consume one token and accept it if it occurs in ts. *)
definition any :: "'t::show list  ('t,'t) parser" where
  [consuming]: "any ts  do { tget; if tset ts then return t else err_expecting (shows_string ''One of '' o shows ts) }"

(* Run ws (presumably a whitespace/separator parser), then p, keeping only
   the result of p. *)
definition "gen_token ws p  ws *-- p"

(* gen_token ws p is consuming whenever p is. *)
lemma [parser_rules]: "is_cparser p  is_cparser (gen_token ws p)"
  unfolding gen_token_def by simp

subsubsection ‹Characters›
(* Character constants: tab, carriage return, and the whitespace characters
   (space, newline, tab, carriage return). *)
abbreviation (input) "char_tab  CHR 0x09"
abbreviation (input) "char_carriage_return  CHR 0x0D"
abbreviation (input) "char_wspace  [CHR '' '', CHR ''⏎'', char_tab, char_carriage_return]"

text ‹Some standard idioms›
(* Single-character lexers; the definitions are registered as consuming. *)
definition [consuming]: "lx_lowercase  (range CHR ''a'' CHR ''z'' )"
definition [consuming]: "lx_uppercase  (range CHR ''A'' CHR ''Z'' )"
definition [consuming]: "lx_alpha  (lx_lowercase  lx_uppercase)"
definition [consuming]: "lx_digit  (range CHR ''0'' CHR ''9'' )"
abbreviation "lx_alphanum  lx_alpha  lx_digit"

subsection ‹Code Generator Setup›
(* Unfold the monad laws, and binds followed by return o f, before code
   generation. *)
declare monad_laws[code_unfold]
lemma bind_return_o_unfold[code_unfold]: "(m  return o f) = do { xm; return (f x)}" by (auto simp: o_def)
declare split[code_unfold] (* TODO: Should this be code_unfold by default? *)


subsection ‹Utilities for Parsing›

text ‹Discard the remaining token sequence and keep only the parse result; errors are passed through›
(* On success, return only the parse result (the first component) and drop
   the remaining input. Errors are returned unchanged. *)
fun show_pres where
  "show_pres (Inr (ll,_)) = Inr ll"
| "show_pres (Inl e) = Inl e"

text ‹Parse complete input, parameterized by parser for trailing whitespace›
(* parse_all ws p s: explode s into characters, run p, then the trailing
   parser ws, then require end of input. Returns the result of p. *)
definition "parse_all ws p  show_pres o (p --* ws --* eoi) o ll_from_list o String.explode"

(* Like parse_all, but the error message is rendered to a string. *)
definition "parse_all_implode ws p s  parse_all ws p s <+? (λmsg. String.implode (msg () ''''))"

(* parse_all_implode with no trailing parser (ws is return ()). *)
definition "parse_all_implode_nows p s  parse_all_implode (return ()) p s"

end