File ‹mat_alg.ML›
(* Pretty-prints a list of terms as a single comma-separated string. *)
fun string_of_terms ctxt ts =
  Pretty.string_of (Pretty.block (Pretty.commas (map (Syntax.pretty_term ctxt) ts)))
(* Emits a tracing message: the label [s], a space, then the printed term [t]. *)
fun trace_t ctxt s t =
  let
    val shown = Syntax.string_of_term ctxt t
  in
    tracing (s ^ " " ^ shown)
  end
(* Traces a theorem together with its hypotheses, in the form
   "<s> [hyp1, ..., hypk] ==> <prop>". *)
fun trace_fullthm ctxt s th =
  let
    val hyps_str = string_of_terms ctxt (Thm.hyps_of th)
    val prop_str = Syntax.string_of_term ctxt (Thm.prop_of th)
  in
    tracing (String.concat [s, " [", hyps_str, "] ==> ", prop_str])
  end
(* The HOL type of natural numbers; used for the dimension argument of carrier_mat. *)
val natT : typ = HOLogic.natT
(* True iff the term is a fully applied binary [times] (any type, matrix or scalar). *)
fun is_times (Const (@{const_name times}, _) $ _ $ _) = true
  | is_times _ = false
(* True iff the term is a fully applied binary [plus] (any type). *)
fun is_plus (Const (@{const_name plus}, _) $ _ $ _) = true
  | is_plus _ = false
(* True iff the term is a fully applied binary [minus] (any type). *)
fun is_minus (Const (@{const_name minus}, _) $ _ $ _) = true
  | is_minus _ = false
(* True iff the term is unary negation [uminus] applied to one argument. *)
fun is_uminus (Const (@{const_name uminus}, _) $ _) = true
  | is_uminus _ = false
(* Returns the two last arguments (a, b) of a term of shape [f $ a $ b];
   the head [f] is not inspected.  Raises Fail "dest_binop" otherwise. *)
fun dest_binop (_ $ a $ b) = (a, b)
  | dest_binop _ = raise Fail "dest_binop"
(* Returns the last argument [x] of an application [f $ x].
   Raises Fail "dest_arg" for non-applications. *)
fun dest_arg (_ $ x) = x
  | dest_arg _ = raise Fail "dest_arg"
(* Returns the second-to-last argument [a] of [f $ a $ b].
   Raises Fail "dest_arg1" if the term does not have that shape. *)
fun dest_arg1 (_ $ a $ _) = a
  | dest_arg1 _ = raise Fail "dest_arg1"
(* True iff the type of [t] (computed with fastype_of) is an instance of the
   type constructor "Matrix.mat". *)
fun is_mat_type t =
  (case fastype_of t of
    Type (tyname, _) => tyname = "Matrix.mat"
  | _ => false)
(* True iff the term is [smult_mat c A] (scalar times matrix), fully applied. *)
fun is_smult_mat (Const (@{const_name smult_mat}, _) $ _ $ _) = true
  | is_smult_mat _ = false
(* True iff the term is [mat_adjoint A]. *)
fun is_adjoint (Const (@{const_name mat_adjoint}, _) $ _) = true
  | is_adjoint _ = false
(* True iff the term is [one_mat n], i.e. an identity matrix of some dimension. *)
fun is_id_mat (Const (@{const_name one_mat}, _) $ _) = true
  | is_id_mat _ = false
(* True iff the term is [zero_mat nr nc], a zero matrix of some dimensions. *)
fun is_zero_mat (Const (@{const_name zero_mat}, _) $ _ $ _) = true
  | is_zero_mat _ = false
(* Flattens a left-nested product [(a * b) * c] into the list [a, b, c].
   Only the left operand is descended into, so a right operand that is itself
   a product (e.g. in [a * (b * c)]) is kept as a single list element. *)
fun strip_times t =
  let
    (* walk down the left spine, prepending right operands to [acc] *)
    fun collect u acc =
      if is_times u then collect (dest_arg1 u) (dest_arg u :: acc)
      else u :: acc
  in
    collect t []
  end
(* Builds the HOL term [carrier_mat n n].  The element type of the resulting
   set is the type of the sample term [t]; [n] should be a term of type nat
   (the constant is typed nat => nat => _ set). *)
fun carrier_mat n t =
  let
    val setT = HOLogic.mk_setT (fastype_of t)
    val carrier_const = Const (@{const_name carrier_mat}, natT --> natT --> setT)
  in
    carrier_const $ n $ n
  end
(* Builds the membership term [t : carrier_mat n n]. *)
fun mk_mem_carrier n t =
  let
    val carrier = carrier_mat n t
  in
    HOLogic.mk_mem (t, carrier)
  end
(* Returns the assumed theorem [t : carrier_mat n n], carrying itself as a
   hypothesis; callers are responsible for discharging that hypothesis. *)
fun assume_carrier ctxt n t =
  mk_mem_carrier n t
  |> HOLogic.mk_Trueprop
  |> Thm.cterm_of ctxt
  |> Thm.assume
(* Proves [t : carrier_mat n n] by structural recursion over the operators
   times, plus, uminus, minus, mat_adjoint and smult_mat.  The theorems for
   the sub-terms are combined with the corresponding closure rule.
   Any other sub-term is not analysed: its membership is introduced by
   assume_carrier, so the result may carry hypotheses that the caller has
   to discharge. *)
fun prod_in_carrier ctxt n t =
  let
    val recur = prod_in_carrier ctxt n
    (* both operands of a binary operator, left one first *)
    fun binary rule =
      let
        val (a, b) = dest_binop t
      in
        [recur a, recur b] MRS rule
      end
    (* the last argument only; for smult_mat this is the matrix, not the scalar *)
    fun unary rule = recur (dest_arg t) RS rule
  in
    if is_times t then binary @{thm mult_carrier_mat}
    else if is_plus t then binary @{thm add_carrier_mat'}
    else if is_uminus t then unary @{thm uminus_carrier_mat}
    else if is_minus t then binary @{thm minus_carrier_mat'}
    else if is_adjoint t then unary @{thm adjoint_dim}
    else if is_smult_mat t then unary @{thm smult_carrier_mat}
    else assume_carrier ctxt n t
  end
(* Flips an object-level equation a = b into b = a (resolution with HOL.sym). *)
fun obj_sym th = th RSN (1, @{thm HOL.sym})
(* Turns an object-level equation a = b into the meta-equation a == b
   (resolution with HOL.eq_reflection); premises of [th] are kept. *)
fun to_meta_eq th = th RSN (1, @{thm HOL.eq_reflection})
(* Turns a meta-equation a == b into the object-level equation a = b
   (resolution with HOL.meta_eq_to_obj_eq). *)
fun to_obj_eq th = th RSN (1, @{thm HOL.meta_eq_to_obj_eq})
(* Conversion: rewrites [ct] with the object-level equation [th].
   The left-hand side of [th] (after to_meta_eq) is matched against [ct].
   Each premise of the instantiated rule is then proved with prod_in_carrier;
   premises are assumed to have the shape Trueprop (X : carrier_mat n n).
   The returned meta-equation may carry carrier hypotheses introduced by
   prod_in_carrier.
   On failure, the rule and the term it was applied to are traced, then
   Fail "MATCH" is raised (left-hand side does not match [ct]) or
   Fail "THM" (a THM exception, e.g. premise resolution failed). *)
fun rewr_cv ctxt n th ct =
  let
    (* trace enough context to locate the failing rewrite, then re-raise *)
    fun failed kind =
      (trace_fullthm ctxt "rewr_cv: rule" th;
       trace_t ctxt "rewr_cv: failed on term" (Thm.term_of ct);
       raise Fail kind)
    fun rewrite () =
      let
        val meta_th = to_meta_eq th
        val lhs_pat = meta_th |> Thm.concl_of |> dest_arg1 |> Thm.cterm_of ctxt
        val inst_th = Thm.instantiate (Thm.match (lhs_pat, ct)) meta_th
        (* Trueprop $ (mem $ X $ S): dest_arg gives the membership, dest_arg1 gives X *)
        val carrier_ths =
          map (fn prem => prod_in_carrier ctxt n (prem |> dest_arg |> dest_arg1))
            (Thm.prems_of inst_th)
      in
        carrier_ths MRS inst_th
      end
  in
    rewrite ()
    handle THM _ => failed "THM"
         | Pattern.MATCH => failed "MATCH"
  end
  
(* Conversion on a product [a * b] whose right factor [b] is presumably
   already in normal form.  In order of priority it:
   - pulls a scalar multiple in [a] out (mult_smult_assoc_mat) and
     re-normalises the inner product;
   - pulls a scalar multiple in [b] out (mult_smult_distrib) and
     re-normalises the inner product;
   - re-associates a * (b1 * b2) to (a * b1) * b2 (symmetric assoc_mult_mat)
     and re-normalises the left product;
   - drops an identity factor on either side (left/right_mult_one_mat).
   Otherwise the term is returned unchanged. *)
fun assoc_times_norm ctxt n ct =
  let
    val (a, b) = dest_binop (Thm.term_of ct)
    val recur = assoc_times_norm ctxt n
    val rewr = rewr_cv ctxt n
    (* rewrite once, then normalise the sub-term selected by [follow] *)
    fun step rule follow = Conv.every_conv [rewr rule, follow recur] ct
  in
    if is_smult_mat a then step @{thm mult_smult_assoc_mat} Conv.arg_conv
    else if is_smult_mat b then step @{thm mult_smult_distrib} Conv.arg_conv
    else if is_times b then step (obj_sym @{thm assoc_mult_mat}) Conv.arg1_conv
    else if is_id_mat a then rewr @{thm left_mult_one_mat} ct
    else if is_id_mat b then rewr @{thm right_mult_one_mat} ct
    else Conv.all_conv ct
  end
(* Conversion inserting the last summand of [a + b] into its place in a
   left-nested sum [a] that is presumably already sorted by Term_Ord.term_ord.
   One insertion-sort step: while b orders before the last summand of [a], it
   is swapped backwards (swap_plus_mat); for a two-term sum, comm_add_mat
   orders the operands.  Non-matrix sums are left unchanged; dest_binop is
   applied before that check. *)
fun assoc_plus_one_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val (a, b) = dest_binop t
    fun out_of_order pair = Term_Ord.term_ord pair = GREATER
  in
    if not (is_mat_type t) then
      Conv.all_conv ct
    else if is_plus a then
      (if out_of_order (dest_arg a, b) then
         Conv.every_conv
           [rewr_cv ctxt n @{thm swap_plus_mat},
            Conv.arg1_conv (assoc_plus_one_norm ctxt n)] ct
       else
         Conv.all_conv ct)
    else if out_of_order (a, b) then
      rewr_cv ctxt n @{thm comm_add_mat} ct
    else
      Conv.all_conv ct
  end
(* Conversion normalising a matrix sum [a + b] whose operands are presumably
   already normalised:
   - a + (b1 + b2) is re-associated to (a + b1) + b2 (symmetric
     assoc_add_mat); the left part is normalised recursively and b2 is then
     inserted with assoc_plus_one_norm;
   - a zero matrix on either side is dropped (left/right_add_zero_mat);
   - otherwise the single summand [b] is inserted with assoc_plus_one_norm.
   Non-matrix sums are left unchanged; dest_binop is applied before that check. *)
fun assoc_plus_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val (a, b) = dest_binop t
    val insert_last = assoc_plus_one_norm ctxt n
  in
    if not (is_mat_type t) then Conv.all_conv ct
    else if is_plus b then
      Conv.every_conv
        [rewr_cv ctxt n (obj_sym @{thm assoc_add_mat}),
         Conv.arg1_conv (assoc_plus_norm ctxt n),
         insert_last] ct
    else if is_zero_mat a then rewr_cv ctxt n @{thm left_add_zero_mat} ct
    else if is_zero_mat b then rewr_cv ctxt n @{thm right_add_zero_mat} ct
    else insert_last ct
  end
(* Conversion distributing a scalar multiple over a left-nested sum.
   If the matrix argument is a sum, add_smult_distrib_left_mat is applied.
   The recursion then continues on the left summand only, which suffices for
   left-nested sums. *)
fun smult_plus_norm ctxt n ct =
  if not (is_plus (dest_arg (Thm.term_of ct))) then
    Conv.all_conv ct
  else
    Conv.every_conv
      [rewr_cv ctxt n @{thm add_smult_distrib_left_mat},
       Conv.arg1_conv (smult_plus_norm ctxt n)] ct
(* Conversion for a product [p * m], where [p] is presumably a normalised
   (left-nested) sum and [m] a single monomial.
   If [p] is a sum, add_mult_distrib_mat is applied.  The left part is handled
   recursively, the new right product is normalised with assoc_times_norm, and
   the sum is re-sorted with assoc_plus_norm.  If [p] is not a sum, the product
   is normalised directly with assoc_times_norm. *)
fun norm_mult_poly_monomial ctxt n ct =
  if not (is_plus (dest_arg1 (Thm.term_of ct))) then
    assoc_times_norm ctxt n ct
  else
    Conv.every_conv
      [rewr_cv ctxt n @{thm add_mult_distrib_mat},
       Conv.arg1_conv (norm_mult_poly_monomial ctxt n),
       Conv.arg_conv (assoc_times_norm ctxt n),
       assoc_plus_norm ctxt n] ct
(* Conversion multiplying out a matrix product [p * q] of two normalised sums.
   If [q] is a sum, mult_add_distrib_mat is applied.  The left product is
   handled recursively, the right product goes through norm_mult_poly_monomial,
   and the resulting sum is re-sorted with assoc_plus_norm.  If [q] is a single
   monomial, the product goes straight to norm_mult_poly_monomial.
   Products that are not of matrix type are returned unchanged, mirroring the
   is_mat_type guard in assoc_plus_norm.  The matrix rules used further down
   cannot match a scalar product, so rewr_cv would otherwise raise Fail "MATCH"
   on e.g. x * (y + z) or x * (y * z). *)
fun norm_mult_polynomials ctxt n ct =
  let
    val t = Thm.term_of ct
  in
    if not (is_mat_type t) then
      Conv.all_conv ct
    else if is_plus (dest_arg t) then
      Conv.every_conv [
        rewr_cv ctxt n @{thm mult_add_distrib_mat},
        Conv.arg1_conv (norm_mult_polynomials ctxt n),
        Conv.arg_conv (norm_mult_poly_monomial ctxt n),
        assoc_plus_norm ctxt n] ct
    else
      norm_mult_poly_monomial ctxt n ct
  end
(* True iff the term is [trace A]. *)
fun is_trace (Const (@{const_name trace}, _) $ _) = true
  | is_trace _ = false
(* Conversion on [trace P], where P is presumably a normalised left-nested
   product.  While the last factor of P orders (Term_Ord.term_ord) strictly
   below some earlier factor, the factors are rotated cyclically with
   trace_comm and the product is re-associated with assoc_times_norm.
   Rotation stops once the last factor is not below any other factor. *)
fun norm_trace_times ctxt n ct =
  let
    val factors = ct |> Thm.term_of |> dest_arg |> strip_times
    val (others, final) = split_last factors
    val rotate = exists (fn f => Term_Ord.term_ord (final, f) = LESS) others
  in
    if not rotate then
      Conv.all_conv ct
    else
      Conv.every_conv
        [rewr_cv ctxt n @{thm trace_comm},
         Conv.arg_conv (assoc_times_norm ctxt n),
         norm_trace_times ctxt n] ct
  end
(* Conversion on [trace S].  If S is a sum, trace_add_linear splits the trace
   over the sum.  The left trace is handled recursively and the right one with
   norm_trace_times.  Otherwise the whole term goes to norm_trace_times. *)
fun norm_trace_plus ctxt n ct =
  if is_plus (dest_arg (Thm.term_of ct)) then
    Conv.every_conv
      [rewr_cv ctxt n @{thm trace_add_linear},
       Conv.arg1_conv (norm_trace_plus ctxt n),
       Conv.arg_conv (norm_trace_times ctxt n)] ct
  else
    norm_trace_times ctxt n ct
(* Main normaliser: returns a meta-equation [ct == normal form].
   - Products, sums, scalar multiples: operands are normalised first.  Then
     the result is multiplied out (norm_mult_polynomials), sorted
     (assoc_plus_norm), or the scalar is distributed (smult_plus_norm).
   - A - B and -A are rewritten (minus_add_uminus_mat, uminus_mat) and the
     result is normalised again.
   - adjoint (A * B) is rewritten with adjoint_mult.  adjoint (adjoint A) is
     rewritten with adjoint_adjoint through Conv.rewr_conv, without carrier
     premises.  Both results are re-normalised.  Any other adjoint is returned
     unchanged; its argument is not normalised.
   - trace X: X is normalised, then the trace with norm_trace_plus.
   All other terms are returned unchanged. *)
fun assoc_norm ctxt n ct =
  let
    val t = Thm.term_of ct
    val norm = assoc_norm ctxt n
    (* normalise the operands selected by [on_operands], then apply [finish] *)
    fun operands_then on_operands finish =
      Conv.every_conv [on_operands norm, finish] ct
    (* rewrite once with [rule], then normalise the result again *)
    fun rewrite_then_norm rule =
      Conv.every_conv [rewr_cv ctxt n rule, norm] ct
  in
    if is_times t then operands_then Conv.binop_conv (norm_mult_polynomials ctxt n)
    else if is_plus t then operands_then Conv.binop_conv (assoc_plus_norm ctxt n)
    else if is_smult_mat t then operands_then Conv.arg_conv (smult_plus_norm ctxt n)
    else if is_minus t then rewrite_then_norm @{thm minus_add_uminus_mat}
    else if is_uminus t then rewrite_then_norm @{thm uminus_mat}
    else if is_adjoint t then
      let
        val inner = dest_arg t
      in
        if is_times inner then
          rewrite_then_norm @{thm adjoint_mult}
        else if is_adjoint inner then
          Conv.every_conv [Conv.rewr_conv (to_meta_eq @{thm adjoint_adjoint}), norm] ct
        else
          Conv.all_conv ct
      end
    else if is_trace t then operands_then Conv.arg_conv (norm_trace_plus ctxt n)
    else Conv.all_conv ct
  end
(* Proves the object-level equation [t] by normalising both sides with
   assoc_norm and comparing the normal forms up to alpha-conversion.
   [t] is only split with dest_binop, so its head operator is not checked.
   On success it returns [lhs = rhs], possibly carrying carrier hypotheses
   introduced during normalisation.  Otherwise both normal forms are traced and
   Fail is raised. *)
fun prove_by_assoc_norm ctxt n t =
  let
    val _ = trace_t ctxt "To show equation:" t
    val (left, right) = dest_binop t
    val normalize = assoc_norm ctxt n o Thm.cterm_of ctxt
    val left_eq = normalize left
    val right_eq = normalize right
    val left_nf = Thm.rhs_of left_eq
    val right_nf = Thm.rhs_of right_eq
  in
    if not (left_nf aconvc right_nf) then
      (trace_t ctxt "Left side is:" (Thm.term_of left_nf);
       trace_t ctxt "Right side is:" (Thm.term_of right_nf);
       raise Fail "Normalization are not equal.")
    else
      to_obj_eq (Thm.transitive left_eq (Thm.symmetric right_eq))
  end
(* Tactic: solves the first subgoal, an equation under meta-implications,
   with prove_by_assoc_norm.  [n] is the source text of the matrix dimension.
   Premises of the subgoal are re-introduced with Thm.implies_intr.  A carrier
   hypothesis identical to one of them is thereby discharged.  The remaining
   hypotheses of the proof become new leading subgoals of the resulting state.
   Returns Seq.empty if there is no subgoal.
   NOTE(review): subgoals with outermost parameters are not stripped here.
   HOLogic.dest_Trueprop presumably raises on them; confirm. *)
fun prove_by_assoc_norm_tac n ctxt state =
  let
    (* carrier_mat expects nat arguments.  Plain read_term would leave e.g. a
       numeral dimension polymorphic and make the carrier terms ill-typed, so
       the dimension is constrained to nat before type checking. *)
    val n =
      Syntax.parse_term ctxt n
      |> Type.constraint natT
      |> Syntax.check_term ctxt
  in
    if Thm.nprems_of state = 0 then Seq.empty else
      let
        val subgoal = hd (Thm.take_cprems_of 1 state)
        val cprems = Drule.strip_imp_prems subgoal
        val concl = HOLogic.dest_Trueprop (Thm.term_of (Drule.strip_imp_concl subgoal))
        (* rebuild "P1 ==> ... ==> Pk ==> concl" from the proof of concl *)
        val subgoal_th = fold Thm.implies_intr (rev cprems) (prove_by_assoc_norm ctxt n concl)
        val res = Thm.implies_elim state subgoal_th
      in
        (* leftover hypotheses (unproved carrier facts) become new subgoals *)
        Seq.single (fold Thm.implies_intr (Thm.chyps_of subgoal_th) res)
      end
  end
(* Proof-method parser: reads the dimension as a term argument and runs
   prove_by_assoc_norm_tac as a simple method. *)
val mat_assoc_method : (Proof.context -> Method.method) context_parser =
  let
    fun method_of n ctxt = SIMPLE_METHOD (prove_by_assoc_norm_tac n ctxt)
  in
    Scan.lift Parse.term >> method_of
  end