#!/usr/bin/python3
import glob
import os
import re
import sys
from hashlib import sha1
from stat import S_IREAD, S_IRGRP, S_IROTH, S_IWRITE
from typing import Union, Collection, Match, Dict, Optional, Tuple, Any, Set

# Set to True whenever a recoverable problem is reported (e.g. a missing or
# malformed AXIOM INSTANTIATION header); the script exits with status 1 at the
# end of the run if it is set.
had_errors = False

# Paths of input files that were read (templates and Axioms_* files) and of
# files passed to write_to_file (recorded whether or not their content
# changed). create_check_theory embeds a SHA1 hash of every file recorded
# here at the time it runs.
source_files: set[str] = set()
generated_files: set[str] = set()


def multisubst(mappings: Collection[Tuple[Union[re.Pattern, str], str]], content: str) -> str:
    r"""Apply several word-bounded substitutions to `content` in a single pass.

    Each entry of `mappings` is a `(pattern, replacement)` pair:

    * a `str` pattern is matched literally (it is passed through `re.escape`);
    * an `re.Pattern` contributes its `.pattern` source (its flags are not
      carried over).

    Every pattern is wrapped in `\b(?:...)\b` word-boundary assertions, so e.g.
    `Axioms` does not match inside `Axioms_Complement`. All patterns are joined
    into one alternation and the text is scanned once. As a result, the output
    of one replacement is never rewritten by another mapping. Where several
    patterns match at the same position, the one listed first wins.
    Replacement strings are inserted literally (no backreference expansion).

    Returns `content` unchanged if `mappings` is empty.
    """
    replacements = []
    patterns = []
    for i, (pat, repl) in enumerate(mappings):
        pat_str = re.escape(pat) if isinstance(pat, str) else pat.pattern
        replacements.append(repl)
        # Name each alternative so the callback can tell which mapping matched.
        patterns.append(f"(?P<GROUP_{i}>\\b(?:{pat_str})\\b)")

    if not patterns:
        # An empty alternation would match the empty string at every position.
        return content

    pattern = re.compile("|".join(patterns))

    def repl_func(m: re.Match) -> str:
        for name, text in m.groupdict().items():
            # Skip alternatives that did not participate in this match, and any
            # named groups that belong to a user-supplied regex rather than to
            # our GROUP_<i> wrappers.
            if text is None or not name.startswith("GROUP_"):
                continue
            return replacements[int(name[len("GROUP_"):])]
        raise AssertionError("match did not come from any substitution pattern")

    return pattern.sub(repl_func, content)

def write_to_file(file, content):
    """Write `content` to `file` as a read-only generated file.

    The path is always recorded in `generated_files`. If the file already
    holds exactly `content` (compared after text-mode reading), nothing is
    written. Otherwise, any existing file is first made writable and deleted,
    because a previous run left it read-only. The new content is then written
    with LF line endings and the file is made read-only for owner, group and
    others.
    """
    generated_files.add(file)

    try:
        with open(file, 'rt') as existing:
            previous = existing.read()
    except FileNotFoundError:
        previous = None

    if previous is not None:
        if previous == content:
            print("(Nothing changed, not writing.)")
            return
        try:
            os.chmod(file, S_IREAD | S_IWRITE | S_IRGRP | S_IROTH)
            os.remove(file)
        except FileNotFoundError:
            pass

    with open(file, 'wt', newline='\n') as out:
        out.write(content)
    os.chmod(file, S_IREAD | S_IRGRP | S_IROTH)


def rewrite_laws(outputfile: str, template: str, substitutions: Dict[str, str]):
    """Instantiate the theory file `template` and write the result to `outputfile`.

    Each key of `substitutions` is replaced by its value wherever it occurs
    in the template, delimited by word boundaries (see `multisubst`). A
    banner comment naming the template is then prepended. The result is
    written through `write_to_file`, and the template path is recorded in
    `source_files`.
    """
    print(f"Rewriting {template} -> {outputfile}")
    source_files.add(template)
    with open(template, 'rt') as handle:
        template_text = handle.read()
    instantiated = multisubst(substitutions.items(), template_text)

    banner = f"""(*
 * This is an autogenerated file. Do not edit.
 * The original is {template}. It was converted using instantiate_laws.py.
 *)

"""
    write_to_file(outputfile, banner + instantiated)


def read_instantiation_header(file: str) -> Tuple[str, Optional[str], Dict[str, str]]:
    r"""Parse the AXIOM INSTANTIATION header of an `Axioms_<name>.thy` file.

    The header is a comment shaped like

        (* AXIOM INSTANTIATION <rest of this line is ignored>
           key \<rightarrow> value
           # lines starting with '#' and blank lines are ignored
        *)

    Returns `(laws_file, laws_complement_file, substitutions)`:

    * `laws_file` is `Laws_<name>.thy`;
    * `laws_complement_file` is `Laws_Complement_<name>.thy` if
      `Axioms_Complement_<name>.thy` exists in the current directory, else None;
    * `substitutions` contains the default theory-name renamings (e.g.
      `Axioms` -> `Axioms_<name>`) plus every `key -> value` header entry.

    Header problems (missing header, malformed line, key given twice) are
    printed and set the module-level `had_errors` flag. Processing continues
    so that all problems are reported in one run.

    Raises ValueError if `file` is not named `Axioms_<name>.thy`. An explicit
    check is used here because an `assert` would be stripped under `python -O`.
    """
    global had_errors

    if not (file.startswith("Axioms_") and file.endswith(".thy")):
        raise ValueError(f"Expected a file named Axioms_<name>.thy, got {file!r}")
    basename = file.removeprefix("Axioms_").removesuffix(".thy")

    source_files.add(file)
    with open(file, 'rt') as f:
        content = f.read()

    header = re.search(r"\(\* AXIOM INSTANTIATION [^\n]*\n(.*?)\*\)", content, re.DOTALL)
    if header is None:
        print(f"*** Could not find AXIOM INSTANTIATION header in {file}.")
        had_errors = True
        lines = []
    else:
        lines = header.group(1).splitlines()

    # Default renamings so that the instantiated theories import each other
    # (and the matching Axioms_* theory) instead of the generic templates.
    substitutions = {
        'theory Laws': f'theory Laws_{basename}',
        'imports Laws': f'imports Laws_{basename}',
        'theory Laws_Complement': f'theory Laws_Complement_{basename}',
        'Axioms': f'Axioms_{basename}',
        'Axioms_Complement': f'Axioms_Complement_{basename}'
    }
    for line in lines:
        if line.strip() == "" or re.match(r"^\s*#", line):
            continue
        entry = re.match(r"^\s*(.+?)\s+\\<rightarrow>\s+(.+?)\s*$", line)
        if entry is None:
            print(f"*** Invalid AXIOM INSTANTIATION line in {file}: {line}")
            had_errors = True
            continue
        key, val = entry.group(1), entry.group(2)
        if key in substitutions:
            # Also triggers when a header entry collides with a default key.
            print(f"*** Repeated AXIOM INSTANTIATION key in {file}: {line}")
            had_errors = True
        substitutions[key] = val

    laws_complement = f"Laws_Complement_{basename}.thy" if os.path.exists(f"Axioms_Complement_{basename}.thy") else None
    return (f"Laws_{basename}.thy", laws_complement, substitutions)

def rewrite_all():
    """Generate the instantiated Laws theories for every Axioms_<name>.thy.

    For each `Axioms_*.thy` file in the current directory, the
    `Laws_<name>.thy` file is generated from `Laws.thy`. When
    `Axioms_Complement_<name>.thy` also exists, `Laws_Complement_<name>.thy`
    is generated from `Laws_Complement.thy` as well. Files are processed in
    sorted order, because `glob.glob` returns them in filesystem order, which
    would make runs and their logs non-reproducible. `Axioms_Complement*`
    files are skipped as inputs; they are only detected alongside their base
    file by `read_instantiation_header`.
    """
    for axioms_file in sorted(glob.glob("Axioms_*.thy")):
        if axioms_file.startswith("Axioms_Complement"):
            continue
        lawfile, lawfile_complement, substitutions = read_instantiation_header(axioms_file)
        rewrite_laws(lawfile, "Laws.thy", substitutions)
        if lawfile_complement is not None:
            rewrite_laws(lawfile_complement, "Laws_Complement.thy", substitutions)


def create_check_theory():
    """Write Check_Autogenerated_Files.thy, which re-verifies file hashes inside Isabelle.

    The SHA1 hash of every file currently recorded in `source_files` and
    `generated_files` is embedded in an ML block. When Isabelle processes the
    theory, each file is re-read and hashed, and an ML `error` is raised if a
    hash differs. This signals that a file changed since this script last
    ran. The function must run after `rewrite_all`, which records the files.
    """
    print("Creating Check_Autogenerated_Files.thy")

    hash_checks = []
    for kind, files in (("Source", source_files), ("Generated", generated_files)):
        for file in sorted(files):
            with open(file, 'rb') as f:
                digest = sha1(f.read()).hexdigest()
            hash_checks.append(f'  check "{kind}" "{file}" "{digest}"')

    # SML's `let ... in e1; ...; en end` needs at least one expression, so
    # fall back to unit when there is nothing to check.
    hash_checks_concat = ";\n".join(hash_checks) if hash_checks else "  ()"

    content = rf"""(*
 * This is an autogenerated file. Do not edit.
 * It was created using instantiate_laws.py.
 * It checks whether the other autogenerated files are up-to-date.
 *)

theory Check_Autogenerated_Files
  (* These imports are not actually needed, but in jEdit, they will conveniently trigger a re-execution of the checking code below upon changes. *)
  imports Laws_Classical Laws_Quantum Laws_Complement_Quantum
begin

ML \<open>
let
  fun check kind file expected = let
    val content = File.read (Path.append (Resources.master_directory \<^theory>) (Path.basic file))
    val hash = SHA1.digest content |> SHA1.rep
    in
      if hash = expected then () else
      error (kind ^ " file " ^ file ^ " has changed.\nPlease run \"python3 instantiate_laws.py\" to recreate autogenerated files.\nExpected SHA1 hash " ^ expected ^ ", got " ^ hash)
    end
in
{hash_checks_concat}
end
\<close>

end
"""

    write_to_file("Check_Autogenerated_Files.thy", content)


def _main():
    """Regenerate all instantiated theories and the check theory.

    Exits with status 1 if any problem was reported along the way.
    """
    rewrite_all()
    create_check_theory()
    if had_errors:
        sys.exit(1)


_main()
