Theory Concrete_Multivariate_Polynomials

(*******************************************************************************

  Project: Sumcheck Protocol

  Authors: Azucena Garvia Bosshard <zucegb@gmail.com>
           Christoph Sprenger, ETH Zurich <sprenger@inf.ethz.ch>
           Jonathan Bootle, IBM Research Europe <jbt@zurich.ibm.com>

*******************************************************************************)

section ‹Multivariate Polynomials: Instance›

theory Concrete_Multivariate_Polynomials
  imports 
    "../Generalized_Sumcheck_Protocol/Sumcheck_as_Public_Coin_Proof"
    Polynomial_Instantiation
    Roots_Bounds
begin

― ‹Remove total_degree_zero from the default simpset for the rest of this theory.
   It is still used explicitly where needed (deg_zero proves its statement by fact).
   NOTE(review): the motivation for removing it from the simpset is not stated here.›
declare total_degree_zero [simp del]

subsection ‹Auxiliary lemmas›

― ‹Move a side condition out of a cardinality bound: if, under the hypothesis P,
   the set of x satisfying Q has at most d elements, then so does the set of x
   satisfying both P and Q. When P is false that set is empty; the proof is a
   case split on P.›
lemma card_indep_bound: 
  assumes "P ⟹ card {x. Q x} ≤ d"
  shows "card {x. P ∧ Q x} ≤ d" 
  using assms
  by (cases P) auto
 
― ‹Summing f over the points equal to x at which f does not vanish yields f x.
   The proof first rewrites the side condition f x' ≠ 0 to f x ≠ 0 (sum.cong),
   then splits on whether f x = 0. Declared as a simp rule.›
lemma sum_point_neq_zero [simp]: "(∑x' | x' = x ∧ f x' ≠ 0. f x') = f x"
proof - 
  have ‹(∑x' | x' = x ∧ f x' ≠ 0. f x') = (∑x' | x' = x ∧ f x ≠ 0. f x')› 
    by (intro sum.cong) auto
  also have ‹… = f x› 
    by (cases "f x = 0") (simp_all)
  finally show ?thesis .
qed


subsection ‹Proving the assumptions of the locale›

subsubsection ‹Variables›

― ‹The term @{term vars} is already defined in @{theory "Polynomials.More_MPoly_Type"}.›

― ‹We show the assumptions @{thm [source] "multi_variate_polynomial.vars_finite"}, 
@{thm [source] "multi_variate_polynomial.vars_add"}, 
@{thm [source] "multi_variate_polynomial.vars_zero"} and 
@{thm [source] "multi_variate_polynomial.vars_inst"} from the locale 
@{locale "multi_variate_polynomial"}.›

― ‹The zero polynomial has no variables.›
lemma vars_zero: ‹vars 0 = {}›
  by (simp add: vars_def zero_mpoly.rep_eq)

― ‹Instantiating p with σ eliminates the variables in dom σ and introduces no
   new variables.›
lemma vars_inst: ‹vars (inst p σ) ⊆ vars p - dom σ›
  by (auto simp add: vars_def inst_defs keys_def MPoly_inverse
                     finite_inst_fun_keys lookup_inst_monom_resid 
           elim!: sum.not_neutral_contains_not_neutral split: if_splits)  
   

― ‹Lemmas for translating the roots bound to the format of the locale assumption.›

― ‹Negation does not change the set of variables.›
lemma vars_minus: ‹vars p = vars (-p)›
  by(simp add: vars_def uminus_mpoly.rep_eq)

― ‹The variables of a difference are among the variables of the operands.
   Derived from vars_add applied to p and -q, together with vars_minus.›
lemma vars_subtr: 
  fixes p q :: ‹'a::comm_ring mpoly›
  shows ‹vars (p - q) ⊆ vars p ∪ vars q›
  by(simp add: vars_add[where ?p1.0 = "p" and ?p2.0 = "-q", simplified] vars_minus[where p = "q"])


subsubsection ‹Degree›

― ‹We define the degree of a multivariate polynomial as its total degree.›

― ‹deg is notation for total_degree.›
abbreviation deg :: ‹('a::zero) mpoly ⇒ nat› where
  ‹deg p ≡ total_degree p› 


― ‹We show the assumptions @{thm [source] "multi_variate_polynomial.deg_zero"}, 
@{thm [source] "multi_variate_polynomial.deg_add"} and @{thm [source] "multi_variate_polynomial.deg_inst"}.›

― ‹The zero polynomial has total degree 0.›
lemma deg_zero: ‹deg 0 = 0› by (fact total_degree_zero)

― ‹The total degree of a sum is at most the larger of the two degrees.
   The calculation unfolds total_degree to a Max over the monomial degrees,
   enlarges the key set of p + q to the union of the key sets (keys_add),
   distributes the image over the union (image_Un) and splits the Max over the
   union (Max_Un).›
lemma deg_add: ‹deg (p + q) ≤ max (deg p) (deg q)› 
proof - 
  have ‹deg (p + q) = Max (insert 0 ((λx. sum (lookup x) (keys x)) ` keys (mapping_of p + mapping_of q)))›
    by (simp add: total_degree.rep_eq plus_mpoly.rep_eq)
  also have ‹... ≤ Max (insert 0 ((λx. sum (lookup x) (keys x)) ` (keys (mapping_of p) ∪ keys (mapping_of q))))›
    by (intro Max_mono Set.insert_mono image_mono Poly_Mapping.keys_add) (auto)
  also have ‹... = Max ((insert 0 ((λx. sum (lookup x) (keys x)) ` keys (mapping_of p))) ∪
                       (insert 0 ((λx. sum (lookup x) (keys x)) ` keys (mapping_of q))))›
    by (simp add: image_Un)
  also have ‹... = max (Max (insert 0 ((λx. sum (lookup x) (keys x)) ` keys (mapping_of p)))) 
                        (Max (insert 0 ((λx. sum (lookup x) (keys x)) ` keys (mapping_of q))))›
    by (intro Max_Un) (auto)
  also have ‹... = max (deg p) (deg q)›
    by (simp add: total_degree.rep_eq)
  finally show ?thesis .
qed

― ‹Instantiation does not increase the total degree. After transfer, the goal
   compares the Max of the monomial degrees over the keys of inst_aux p σ with
   the Max over the keys of p.›
lemma deg_inst: ‹deg (inst p σ) ≤ deg p› 
proof (transfer)
  fix p :: ‹(nat ⇒⇩0 nat) ⇒⇩0 'a› and σ :: "(nat, 'a) subst"
  show ‹Max (insert 0 ((λm. sum (lookup m) (keys m)) ` keys (inst_aux p σ))) 
      ≤ Max (insert 0 ((λm. sum (lookup m) (keys m)) ` keys p))› 
    by (auto simp add: keys_def inst_defs finite_inst_fun_keys lookup_inst_monom_resid
             elim!: sum.not_neutral_contains_not_neutral)
       (fastforce simp add: Max_ge_iff intro!: disjI2  intro: sum_mono2)
qed


― ‹Lemmas for translating the roots bound to the format of the locale assumption.›

― ‹Negation does not change the total degree.›
lemma deg_minus: ‹deg p = deg (-p)›
  by(auto simp add: total_degree_def uminus_mpoly.rep_eq)

― ‹The total degree of a difference is at most the larger of the two degrees.
   Derived from deg_add applied to p and -q, together with deg_minus.›
lemma deg_subtr:
  fixes p q :: ‹'a::comm_ring mpoly›
  shows ‹deg (p - q) ≤ max (deg p) (deg q)›
  by(auto simp add: deg_add[where p = "p" and q = "-q", simplified] deg_minus[where p = "q"])


subsubsection ‹Evaluation›

― ‹Our evaluation is defined as insertion in MPoly_Type›

― ‹Evaluation under a partial substitution σ: insertion of the values the (σ v).
   For a variable v outside dom σ this yields the unspecified value the None.›
abbreviation eval :: ‹'a mpoly ⇒ (nat, 'a) subst ⇒ ('a::comm_semiring_1)› where 
  ‹eval p σ ≡ insertion (the o σ) p›

― ‹We show the assumptions @{thm [source] "multi_variate_polynomial.eval_zero"}, 
@{thm [source] "multi_variate_polynomial.eval_add"} and @{thm [source] "multi_variate_polynomial.eval_inst"}.›

― ‹Evaluating the zero polynomial yields 0.›
lemma eval_zero: ‹eval 0 σ = 0› 
  by (fact insertion_zero)

― ‹Evaluation is additive when the variables of both operands lie in dom σ.
   Proved directly from insertion_add.›
lemma eval_add: ‹vars p ∪ vars q ⊆ dom σ ⟹ eval (p + q) σ = eval p σ + eval q σ› 
  by (intro insertion_add)

― ‹evaluation and instantiation›

― ‹Evaluating an instantiated polynomial under ρ equals evaluating the original
   polynomial under ρ ++ σ, where σ takes precedence on its own domain.
   After two transfers, the statement becomes an equation between insertion_fun
   terms on the function representation. It is proved by a calculation that
   (1) unfolds insertion_fun and inst_fun,
   (2) distributes the coefficient into the inner sum,
   (3) rewrites the monomial product as a product over variables outside dom σ,
   (4) unfolds inst_monom_coeff as a product over variables in dom σ,
   (5) merges both products into one product over ρ ++ σ,
   (6) drops zero summands, and
   (7) regroups the outer sum by inst_monom_resid (sum.group).
   NOTE(review): in this copy the mathematical symbols and cartouche delimiters
   appear to be missing from the formal text. For example, the statement on the
   lemma line has no cartouche, and the sum and product binders in the
   calculation are absent. Restore them from the upstream source before
   checking this proof.›
lemma eval_inst: eval (inst p σ) ρ = eval p (ρ ++ σ)
proof (transfer, transfer)
  fix p :: (nat 0 nat)  'a and σ ρ :: (nat, 'a) subst 
  assume fin: finite {m. p m  0}  
  show insertion_fun (the  ρ) (inst_fun p σ) = insertion_fun (the  ρ ++ σ) p 
  proof -
    let ?mon = λσ m v. the (σ v) ^ lookup m v

    ― ‹Only variables in keys m can contribute a factor different from 1. This
       gives finiteness of the variable sets used in the products below (exp_fin).›
    have x ^ lookup m v  1  v  keys m for x :: "'a" and v and m :: "nat 0 nat"
      using zero_less_iff_neq_zero by (fastforce simp add: in_keys_iff)
    then have exp_fin: finite {v. P v  f v ^ lookup m v  1} 
      for f :: "nat  'a" and m :: "nat 0 nat" and P :: "nat  bool"
      by (auto intro: finite_subset[where B="keys m"])

    note fin_simps [simp] = fin this this[where P1="λ_. True", simplified]
    note map_add_simps [simp] = map_add_dom_app_simps(1,3)

    ― ‹Start of the calculation: unfold insertion_fun and inst_fun.›
    have insertion_fun (the  ρ) (inst_fun p σ) = 
      (m. (m' | inst_monom_resid m' σ = m  p m'  0  inst_monom_coeff m' σ  0. 
              p m' * inst_monom_coeff m' σ) * (v. ?mon ρ m v))
      by (simp add: insertion_fun_def inst_fun_def)

    also have  =
      (m. (m' | inst_monom_resid m' σ = m  p m'  0  inst_monom_coeff m' σ  0. 
              p m' * inst_monom_coeff m' σ * (v. ?mon ρ m v)))
      by (intro Sum_any.cong) (simp add: sum_distrib_right)

    also have  =
      (m. (m' | inst_monom_resid m' σ = m  p m'  0  inst_monom_coeff m' σ  0. 
              p m' * inst_monom_coeff m' σ * (v | v  dom σ  ?mon ρ m' v  1. ?mon ρ m' v)))
      by (intro Sum_any.cong sum.cong)
         (auto simp add: lookup_inst_monom_resid Prod_any.expand_set intro: arg_cong)

    also have  =
      (m. (m' | inst_monom_resid m' σ = m  p m'  0  
                    (v | v  dom σ  ?mon σ m' v  1. ?mon σ m' v)  0. 
              p m' * 
              ((v | v  dom σ  ?mon σ m' v  1. ?mon (ρ ++ σ) m' v) * 
               (v | v  dom σ  ?mon ρ m' v  1. ?mon (ρ ++ σ) m' v))))
      by (simp add: inst_monom_coeff_def mult.assoc)

    also have  =
      (m. (m' | inst_monom_resid m' σ = m  p m'  0  
                    (v | v  dom σ  ?mon σ m' v  1. ?mon σ m' v)  0 
                    (v | v  dom σ  ?mon ρ m' v  1. ?mon ρ m' v)  0. 
              p m' * 
              ((v | v  dom σ  ?mon σ m' v  1. ?mon (ρ ++ σ) m' v) * 
               (v | v  dom σ  ?mon ρ m' v  1. ?mon (ρ ++ σ) m' v))))
      by (intro Sum_any.cong sum.mono_neutral_right) (auto)

    also have  =
      (m. (m' | inst_monom_resid m' σ = m  p m'  0  
                    (v | v  dom σ  ?mon σ m' v  1. ?mon σ m' v)  0 
                    (v | v  dom σ  ?mon ρ m' v  1. ?mon ρ m' v)  0. 
              p m' * 
              (v | v  dom σ  ?mon σ m' v  1  v  dom σ  ?mon ρ m' v  1. ?mon (ρ ++ σ) m' v)))
      by (subst prod.union_disjoint[symmetric])
         (auto  intro!: Sum_any.cong sum.cong prod.cong intro: arg_cong)

    also have  =
      (m. (m' | inst_monom_resid m' σ = m  p m'  0  
                    (v | v  dom σ  ?mon σ m' v  1. ?mon σ m' v)  0 
                    (v | v  dom σ  ?mon ρ m' v  1. ?mon ρ m' v)  0. 
                    p m' * (v. ?mon (ρ ++ σ) m' v)))
      apply (intro Sum_any.cong sum.cong arg_cong[where f="(*) x" for x], simp)
      apply (simp add: Prod_any.expand_set)
      apply (intro prod.cong, simp_all)
      by (metis (no_types, opaque_lifting) map_add_dom_app_simps(1,3))

    also have  =
      (m. (m' | inst_monom_resid m' σ = m  p m'  0  (v. ?mon (ρ ++ σ) m' v)  0. 
                    p m' * (v. ?mon (ρ ++ σ) m' v)))
      apply (intro Sum_any.cong sum.mono_neutral_right, simp_all)
       apply (safe, simp_all)
       ― ‹fixme: cannot get auto/fastforce to do instantiations below›
       subgoal for m v z    
         by (auto dest: Prod_any_not_zero[rotated, where a=v])
      subgoal for m' v
        by (auto simp add: domIff dest: Prod_any_not_zero[rotated, where a=v])
      done

    also have  = 
      (m. (sum 
              (λm'. p m' * (v. ?mon (ρ ++ σ) m' v)) 
              {m'  {m'. p m'  0  (v. ?mon (ρ ++ σ) m' v)  0}. 
                    inst_monom_resid m' σ = m}))
      by (intro Sum_any.cong sum.cong) (auto)

    also have  = 
      (m  (λm'. inst_monom_resid m' σ) ` {m'. p m'  0  (v. the ((ρ ++ σ) v) ^ lookup m' v)  0}.
            (sum 
              (λm'. p m' * (v. ?mon (ρ ++ σ) m' v)) 
              {m'  {m'. p m'  0  (v. ?mon (ρ ++ σ) m' v)  0}. 
                    inst_monom_resid m' σ = m}))
      by (intro Sum_any.expand_superset) (auto elim: sum.not_neutral_contains_not_neutral)

    ― ‹Regroup: summing over the residual monomials and then over their preimages
       is the plain sum over all monomials m (sum.group).›
    also have  = (m. p m * (v. ?mon (ρ ++ σ) m v))
      by (subst Sum_any.expand_set, subst sum.group) (auto)

    also have  = insertion_fun (the  ρ ++ σ) p
      by (simp add: insertion_fun_def)
    finally show ?thesis .
  qed
qed

― ‹Lemmas for translating the roots bound to the format of the locale assumption.›

― ‹Evaluation commutes with negation. The proof unfolds insertion to its
   defining sum over monomials and pulls the sign out of the sum (sum_negf).›
lemma eval_minus:
  fixes p :: ‹'a::comm_ring_1 mpoly›
  shows ‹eval (-p) σ = - eval p σ›
  using sum_negf[where f = "λa . (lookup (mapping_of p) a * (∏aa. the (σ aa) ^ lookup a aa))"]
  by(auto simp add: uminus_mpoly.rep_eq insertion_def insertion_aux_def insertion_fun_def)
    (smt (verit) Collect_cong Sum_any.expand_set add.inverse_neutral neg_equal_iff_equal)

― ‹Evaluation commutes with subtraction, provided the variables of both operands
   lie in dom σ. Derived from eval_add applied to p and -q, together with
   eval_minus and vars_minus.›
lemma eval_subtr:
  fixes p q :: ‹'a::comm_ring_1 mpoly›
  assumes ‹vars p ⊆ dom σ› ‹vars q ⊆ dom σ›
  shows ‹eval (p - q) σ = eval p σ - eval q σ›
  using assms
  by(auto simp add: vars_minus[where p = "q"] eval_minus[where p = "q"]
                    eval_add[where p = "p" and q = "-q", simplified])


subsubsection ‹Roots assumption›

― ‹For a polynomial whose variables are among x, evaluation at the substitution
   [x ↦ r] coincides with insertion of the constant function with value r
   (insertion_irrelevant_vars).›
lemma univariate_eval_as_insertion: 
  fixes p::‹'a::comm_ring_1 mpoly› and r
  assumes "vars p ⊆ {x}"
  shows "eval p [x ↦ r] = insertion (λx. r) p"
  using assms 
  by (intro insertion_irrelevant_vars) auto

― ‹Roots bound phrased with eval: a nonzero polynomial whose variables are among
   x has at most deg p roots. Obtained from univariate_mpoly_roots_bound by
   rewriting eval with univariate_eval_as_insertion.›
lemma univariate_mpoly_roots_bound_eval:   ― ‹variant using @{term eval}›
  fixes p::"'a::idom mpoly"
  assumes ‹vars p ⊆ {x}› ‹p ≠ 0› 
  shows ‹card {r. eval p [x ↦ r] = 0} ≤ deg p›
  using assms
  by (simp add: univariate_eval_as_insertion univariate_mpoly_roots_bound)

― ‹Roots assumption of the locale: two distinct polynomials of degree at most d,
   whose variables are among x, agree on at most d values r. The hypotheses
   appear inside the set comprehension. Repeated application of
   card_indep_bound moves them into the context. The agreement points are then
   counted as the roots of p - q, whose degree is at most d by deg_subtr.›
lemma mpoly_roots:
  fixes p q :: ‹'a::idom mpoly› and d x 
  shows ‹card {r. deg p ≤ d ∧ vars p ⊆ {x} ∧ deg q ≤ d ∧ vars q ⊆ {x} ∧ 
                  p ≠ q ∧ eval p [x ↦ r] = eval q [x ↦ r]} ≤ d›
proof (intro card_indep_bound)
  assume ‹deg p ≤ d› ‹vars p ⊆ {x}› ‹deg q ≤ d› ‹vars q ⊆ {x}› ‹p ≠ q›
  show ‹card {r. eval p [x ↦ r] = eval q [x ↦ r]} ≤ d›
  proof -
    have ‹card {r. eval p [x ↦ r] = eval q [x ↦ r]} = card {r. eval (p - q) [x ↦ r] = 0}›
      using ‹vars p ⊆ {x}› ‹vars q ⊆ {x}› by (simp add: eval_subtr)
    also have ‹… ≤ deg (p - q)› 
      using ‹vars p ⊆ {x}› ‹vars q ⊆ {x}› ‹p ≠ q›
      by (intro univariate_mpoly_roots_bound_eval subset_trans[OF vars_subtr]) (auto)
    also have ‹… ≤ d› using ‹deg p ≤ d› ‹deg q ≤ d› 
      by (intro le_trans[OF deg_subtr]) (simp)
    finally show ?thesis .
  qed
qed


subsection ‹Locale interpretation›

text ‹Finally, collect all relevant lemmas and instantiate the abstract polynomials locale.›

― ‹All facts used to discharge the assumptions of the locale
   multi_variate_polynomial. vars_finite and vars_add are not proved in this
   theory; presumably they come from an imported theory.›
lemmas multi_variate_polynomial_lemmas = 
  vars_finite vars_zero vars_add vars_inst 
  deg_zero deg_add deg_inst
  eval_zero eval_add eval_inst
  mpoly_roots

― ‹Interpret the abstract polynomial locale with vars, total degree (deg), eval
   and inst. The type annotation on deg restricts the coefficient type to the
   sort {finite, idom}. The locale assumptions are discharged with
   multi_variate_polynomial_lemmas.›
interpretation mpoly: 
  multi_variate_polynomial vars "deg :: 'a::{finite, idom} mpoly ⇒ nat" eval inst 
  by (unfold_locales) (auto simp add: multi_variate_polynomial_lemmas)


text ‹Here are the main results, specialized for type @{typ 'a mpoly}. 
The completeness theorem for this type is @{thm [display] "mpoly.completeness"}
and the soundness theorem reads @{thm [display] "mpoly.soundness"}.
›

end