(**********************************************************************************
 * Copyright (c) University of Exeter, UK
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 ***********************************************************************************)

chapter‹Activation Functions›

theory Activation_Functions
imports
  "HOL-Analysis.Derivative"
  TensorFlow_Import
  NN_Common
  "Interval_Analysis.Affine_Functions"
  Jordan_Normal_Form.Matrix (* only needed for Matrix.vec type constructor in ML package *)
begin

text‹
  In this theory, we provide definitions for the most common activation functions. 
  Moreover, we provide an ML API for working with HOL terms representing activation functions.
›

section‹Defining Activation Functions and Their Derivatives›

text‹
  Many common activation functions are based on the function @{term "f(x) = e^x"} (written 
  @{term "f(x) = exp (x)"} in Isabelle/HOL). For those activation functions, we also define 
  approximations using the Taylor series of the exponential function:
›

definition
  exptaylor n x = (∑i = 0..n . x^i / fact i)

lemma exptaylor2: "exptaylor 2 (x::real) = (1::real) + x + x^2/2"
  by(code_simp, simp)
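
text‹
  As an additional sanity check, the zeroth-order approximation is the constant function one
  (an illustrative example lemma; the proof sketch follows the same pattern as the lemma above):
›

lemma exptaylor0: "exptaylor 0 (x::real) = 1"
  by(code_simp, simp)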

subsection‹Activation Functions›

definition 
  identity = (λv. v)

lemma identity_linear[simp]: affine_fun identity
unfolding identity_def   by simp 
  
definition binary_step :: 'a::{zero, ord, one} ⇒ 'a where
 binary_step = (λ v. if v ≤ 0 then 0 else 1)

hide_const sign
definition
  sign = sgn

definition
  softsign = (λv. v / (¦v¦ + 1))
definition 
  logistic L k v0 = (λv. L / (1 + exp(-k * (v -v0))))
definition 
  logistictaylor n L k v0 = (λv. L / (1 + (exptaylor n (-k * (v -v0)))))


definition sigmoid :: "real ⇒ real" where
 sigmoid = (λv. 1 / (1 + exp(-v)))

definition
  sigmoidtaylor n = (λv. 1 / (1 + (exptaylor n (-v))))

lemma sigmoid = (logistic (1.0::real) 1.0 0)
  unfolding sigmoid_def logistic_def by auto 
lemma sigmoidtaylor n = (logistictaylor n (1.0::real) 1.0 0)
  unfolding sigmoidtaylor_def logistictaylor_def by auto 


definition 
  swish           = (λv. v * (sigmoid v))
definition 
  swishtaylor n    = (λv. v * (sigmoidtaylor n v))
definition 
  relu            = (λv. max 0 v)

definition 
  generalized_relu α m t = (λv. case m of Some m' ⇒ min (if v ≤ t then α * v else v) m' 
                                         | None ⇒ if v ≤ t then α * v else v)

  
lemma relu = (generalized_relu (0.0::real) None (0.0))
  unfolding relu_def generalized_relu_def by auto 

definition 
  softplus = (λv. ln (1 + exp v))

definition 
  elu α = (λv. if v ≤ 0 then α * ((exp v)-1) else v) 
definition 
  elutaylor n α = (λv. if v ≤ 0 then α * ((exptaylor n v)-1) else v) 
definition 
  selu = (λv. let α = 1.67326324; 
                   scale = 1.05070098 
               in if v ≤ 0 then scale * α * ((exp v)-1) else scale * v) 
definition 
  selutaylor n = (λv. let α = 1.67326324; 
                        scale = 1.05070098 
                      in if v ≤ 0 then scale * α * ((exptaylor n v)-1) else scale * v) 

definition
  prelu α  = (λv. if v < 0 then α * v else v)
definition 
  silu = (λv. v / (1 + (exp (-v))))
definition 
  silutaylor n = (λv. v / (1 + (exptaylor n (-v))))
definition
  gaussian = (λv. exp (- v⇧2))
definition 
  gaussiantaylor n = (λv. exptaylor n (- v⇧2))

definition 
  hard_sigmoid = (λv. if v < -2.5 then 0 else if v > 2.5 then 1 else 0.2 * v + 0.5)
definition
  gelu_approx = (λv. 0.5 * v * (1 + tanh(sqrt(2 / pi) * (v + 0.044715 * v * v * v)))) 

text‹
  Note that the error function @{term "erf"} is available in the AFP entry~cite"eberl:erf:2018"; it can 
  be used to define a non-approximated @{term "gelu"} activation function.
›

definition softmax :: "('a::{banach,real_normed_algebra_1,inverse}) list ⇒ 'a list" where 
  softmax vs = map (λ v. exp v / (∑v'←vs. exp v')) vs 

definition msoftmax :: "('a::{banach,real_normed_algebra_1,inverse}) vec ⇒ 'a vec" where 
  msoftmax vs = map_vec (λ v. exp v / (∑v'←(list_of_vec vs). exp v')) vs 


definition softmaxtaylor :: "nat ⇒ ('a::{banach,real_normed_algebra_1,inverse}) list ⇒ 'a list" where
  softmaxtaylor n vs = map (λ v. (exptaylor n v) / (∑v'←vs. (exptaylor n v'))) vs 

definition msoftmaxtaylor :: "nat ⇒ ('a::{banach,real_normed_algebra_1,inverse}) vec ⇒ 'a vec" where
  msoftmaxtaylor n vs = map_vec (λ v. (exptaylor n v) / (∑v'←(list_of_vec vs). (exptaylor n v'))) vs


lemma softmaxtaylor2: 
      "softmaxtaylor 2 vs = map (λ (v::real). (1 + v + v2/2) / (foldl (+) 0 (map (λ v'. 1 + v' + v'2/2) vs))) vs"
  unfolding softmaxtaylor_def exptaylor2  
  by (simp add: Groups.add_ac(2) fold_plus_sum_list_rev foldl_conv_fold) 

lemma softmaxtaylor2': "softmaxtaylor 2 vs = map (λ (v::real). (1 + v + v⇧2/2) / (foldl (λa x. a + (1 + x + x⇧2 / 2)) 0 vs)) vs"
  apply(simp add: softmaxtaylor2)
  by(code_simp, simp)
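
text‹
  For illustration, applying ‹softmax› to a uniform input yields a uniform output
  (a small example lemma added for illustration; the proof is a simple evaluation):
›

lemma "softmax [0::real, 0] = [1/2, 1/2]"
  unfolding softmax_def by simp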

definition
  argmax vs = map (λ v. if v = Max (set vs) then 1 else 0) vs 
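​
text‹
  For illustration, ‹argmax› marks every position holding the maximum value with ‹1› and all
  other positions with ‹0›; in particular, all tied maxima are marked (small example lemmas
  added for illustration; the proofs are simple evaluations):
›

lemma "argmax [0::real, 1] = ([0, 1]::real list)"
  unfolding argmax_def by (simp add: max_def)

lemma "argmax [1::real, 1] = ([1, 1]::real list)"
  unfolding argmax_def by simp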

text‹
\autoref{tab:tensorflow-activation} provides a mapping from our names of the activation functions 
to the names used by TensorFlow (see 🌐‹https://www.tensorflow.org/api_docs/python/tf/keras/activations/›).
  
\begin{landscape}
\begin{table}
\caption{Mapping of the activation functions supported by TensorFlow.}
\label{tab:tensorflow-activation}
\begin{center}
\small\renewcommand{\arraystretch}{1.2}
\begin{tabular}{@ {}llp{14cm}@ {}}
   \toprule
                                 & TensorFlow 2.8.0 & Definition\\
    \cmidrule(r){2-3}
    @{constidentity}          & linear                         & \vspace{-.7cm}@{thm [display, margin=120]  "identity_def"}\vspace{-0.9cm}\\ 
    @{constsoftsign}          & softsign                       & \vspace{-.7cm}@{thm [display, margin=120]  "softsign_def"}\vspace{-0.9cm}\\ 
    @{constsigmoid}           & sigmoid                        & \vspace{-.7cm}@{thm [display, margin=120]  "sigmoid_def"}\vspace{-0.9cm}\\  
    @{constsigmoidtaylor}      & \multicolumn{1}{c}{--}         & \vspace{-.7cm}@{thm [display, margin=120]  "sigmoidtaylor_def"}\vspace{-0.9cm}\\  
    @{constswish}             & swish                          & \vspace{-.7cm}@{thm [display, margin=120]  "swish_def"}\vspace{-0.9cm}\\  
    @{constswishtaylor}        & \multicolumn{1}{c}{--}         & \vspace{-.7cm}@{thm [display, margin=120]  "swishtaylor_def"}\vspace{-0.9cm}\\  
    @{consttanh}              & tanh                           & \vspace{-.7cm}@{thm [display, margin=120]  "tanh_def"}\vspace{-0.9cm}\\ 
    @{constgeneralized_relu}  & relu                           & \vspace{-.7cm}@{thm [display, margin=120]  "generalized_relu_def"}\vspace{-0.9cm}\\
    @{constrelu}              & relu (with default parameters) & \vspace{-.7cm}@{thm [display, margin=120]  "relu_def"}\vspace{-0.9cm}\\ 
    @{constgelu_approx}       & gelu (approx=True)             & \vspace{-.7cm}@{thm [display, margin=120]  "gelu_approx_def"}\vspace{-0.9cm}\\
    \multicolumn{1}{c}{--}       & gelu (approx=False)            & \multicolumn{1}{c}{--}\\
    @{constsoftplus}          & softplus                       & \vspace{-.7cm}@{thm [display, margin=120]  "softplus_def"}\vspace{-0.9cm}\\     
    @{constelu}               & elu                            & \vspace{-.7cm}@{thm [display, margin=120]  "elu_def"}\vspace{-0.9cm}\\
    @{constelutaylor}          & \multicolumn{1}{c}{--}          & \vspace{-.7cm}@{thm [display, margin=120]  "elutaylor_def"}\vspace{-0.9cm}\\
    @{constselu}              & selu                           &\vspace{-.7cm}@{thm [display, margin=120]  "selu_def"}\vspace{-0.9cm}\\
    @{constselutaylor}         & \multicolumn{1}{c}{--}         & \vspace{-.7cm}@{thm [display, margin=120]  "selutaylor_def"}\vspace{-0.9cm}\\
    @{constexp}               & exponential                    & \vspace{-.7cm}@{thm [display, margin=120]  "exp_def"}\vspace{-0.9cm}\\
    @{constexptaylor}          & \multicolumn{1}{c}{--}          & \vspace{-.7cm}@{thm [display, margin=120]  "exptaylor_def"}\vspace{-0.9cm}\\
    @{consthard_sigmoid}      & hard\_sigmoid                  & \vspace{-.7cm}@{thm [display, margin=120]  "hard_sigmoid_def"}\vspace{-0.9cm}\\ 
    @{constsoftmax}           & softmax                        & \vspace{-.7cm}@{thm [display, margin=120]  "softmax_def"}\vspace{-0.9cm}\\ 
    @{constsoftmaxtaylor}      & \multicolumn{1}{c}{--}          & \vspace{-.7cm}@{thm [display, margin=120]  "softmaxtaylor_def"}\vspace{-0.9cm}\\ 
    \bottomrule
\end{tabular}\end{center}\end{table}
\end{landscape}
›

subsection‹Derivatives of Activation Functions›

lemma has_real_derivative_transform:
  x ∈ s ⟹ (⋀x. x ∈ s ⟹ g x = f x) ⟹ (f has_real_derivative f') (at x within s) 
                                         ⟹ (g has_real_derivative f') (at x within s)
  by (simp add: has_derivative_transform has_field_derivative_def)

lemma one_plus_exp_eq: "(1 + exp v) = (exp v) * (1 + exp (- v))  "
  by (simp add: distrib_left exp_minus_inverse)

definition identity'        = (λ v. 1.0)
lemma identity'[simp]: (identity has_real_derivative (identity' v))  (at v)
  by(simp add:identity_def identity'_def) 

definition logistic' L k v0 = (λ v. (exp( (-k)*(v-v0)) * k * L) / (1 + exp((-k)*(v-v0)))⇧2)
lemma logistic'[simp]: ((logistic L k v0) has_real_derivative ((logistic' L k v0) v))  (at v)
  apply (simp add:logistic_def logistic'_def, intro derivative_eq_intros, simp_all)
    subgoal by (metis add_eq_0_iff exp_ge_zero le_minus_one_simps(3))
    subgoal by (simp add: power2_eq_square) 
    done 

definition tanh'            = (λ v. 1 - ((tanh v)⇧2))
lemma tanh'[simp]: (tanh has_real_derivative (tanh' v))  (at v)
  by (auto intro: derivative_eq_intros simp add:tanh'_def) 

definition softplus'        = (λ v. 1 / (1 + exp(-v)))
lemma softplus'[simp]: (softplus has_real_derivative (softplus' v))  (at v)
    apply (simp add: softplus_def softplus'_def, intro derivative_eq_intros, simp_all add: add_pos_pos) 
    by (metis one_plus_exp_eq add.left_neutral exp_not_eq_zero mult.right_neutral 
              nonzero_divide_mult_cancel_left)

definition prelu1'          = (λ v. 1)
lemma prelu1'[simp]: ((prelu 1) has_real_derivative (prelu1' v))  (at v)
proof (-)
  have *:  (λ v. if v < (0::real) then (1::real) * v else v) = (λ v. v) by auto
  show ?thesis
  by (simp add: prelu_def prelu1'_def if_split *)
qed
  
definition silu'            = (λ v. (1 + exp (-v) + v * (exp (-v)) ) / ((1 + exp(-v))⇧2))
lemma silu'[simp]: (silu has_real_derivative (silu' v))  (at v)
  apply(simp add:silu_def silu'_def) 
  apply (intro derivative_eq_intros, simp_all) 
  subgoal by (metis add_eq_0_iff exp_ge_zero le_minus_one_simps(3))
  subgoal by (simp add: power2_eq_square) 
  done

definition gaussian'        = (λ v. -2 * v * exp (- v⇧2))
lemma gaussian'[simp]: (gaussian has_real_derivative (gaussian' v))  (at v)
  by (simp add:gaussian_def gaussian'_def, intro derivative_eq_intros, simp_all, simp) 


subsection‹Single Class Folding Activation Functions›

datatype activationsingle = Identity |  Sign | BinaryStep | Logistic real real real | Logistictaylor nat real real real 
                         | Tanh | Sigmoid | Sigmoidtaylor nat | ReLU | GReLU real "real option" real 
                         | Softplus | SoftSign | Swish | Swishtaylor nat | GeLUapprox | ELU real 
                         | ELUtaylor nat real | SELU | SELUtaylor nat | PReLU real | SiLU | SiLUtaylor nat 
                         | Gaussian | Gaussiantaylor nat | Exp | Exptaylor nat | HardSigmoid 

fun φsingle:: activationsingle ⇒ (real ⇒ real) option where
   φsingle Identity          = Some identity                 
 | φsingle Sign              = Some sign
 | φsingle BinaryStep        = Some binary_step
 | φsingle SoftSign          = Some softsign                 
 | φsingle (Logistic L k v0) = Some (logistic L k v0)           
 | φsingle (Logistictaylor n L k v0) = Some (logistictaylor n L k v0)           
 | φsingle Sigmoid           = Some sigmoid                  
 | φsingle (Sigmoidtaylor n)   = Some (sigmoidtaylor n)                  
 | φsingle Swish             = Some swish                    
 | φsingle (Swishtaylor n)    = Some (swishtaylor n)                    
 | φsingle Tanh              = Some tanh                     
 | φsingle ReLU              = Some relu                     
 | φsingle GeLUapprox        = Some gelu_approx              
 | φsingle (GReLU α m t)     = Some (generalized_relu α m t) 
 | φsingle Softplus          = Some softplus                     
 | φsingle (ELU α)           = Some (elu α)                  
 | φsingle (ELUtaylor n α)    = Some (elutaylor n α)                  
 | φsingle SELU              = Some selu                     
 | φsingle (SELUtaylor n)      = Some (selutaylor n)                     
 | φsingle Exp               = Some exp                      
 | φsingle (Exptaylor n)       = Some (exptaylor n)                      
 | φsingle HardSigmoid       = Some hard_sigmoid             
 | φsingle (PReLU α)         = Some (prelu α)
 | φsingle SiLU              = Some silu
 | φsingle (SiLUtaylor n)      = Some (silutaylor n)
 | φsingle Gaussian          = Some gaussian
 | φsingle (Gaussiantaylor n)  = Some (gaussiantaylor n)
text‹ 
  The datatype @{typeactivationsingle} enumerates standard activation functions that are applied 
  to the weighted sum (fold) of all inputs of a single neuron. The function @{constφsingle} 
  provides easy access to the activation function itself.
›
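
text‹
  For example, the defining equations of ‹φsingle› are available to the simplifier, so the
  underlying HOL function can be retrieved directly (illustrative example):
›

lemma "φsingle ReLU = Some relu"
  by simp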

fun φsingle':: activationsingle ⇒ (real ⇒ real option) where
   φsingle' Identity          = (λv. Some (identity' v))
 | φsingle' Sign              = (λv. None)
 | φsingle' BinaryStep        = (λv. None)
 | φsingle' (Logistic L k v0) = (λv. Some (logistic' L k v0 v))
 | φsingle' (Logistictaylor n L k v0) = (λv. None)
 | φsingle' Tanh              = (λv. Some (tanh' v))
 | φsingle' ReLU              = (λv. None)
 | φsingle' Softplus          = (λv. Some (softplus' v))
 | φsingle' (ELU α)           =  (λv. None)
 | φsingle' (ELUtaylor n α)    =  (λv. None)
 | φsingle' (PReLU α)         = (λ v. if α = 1 then Some (prelu1' v) else None)
 | φsingle' SiLU              = (λv. Some (silu' v))
 | φsingle' (SiLUtaylor n)     = (λv. None)
 | φsingle' Gaussian          = (λv. Some (gaussian' v))
 | φsingle' (Gaussiantaylor n) = (λv. None)
 | φsingle' (GReLU v va vb)   = (λv. None)
 | φsingle' GeLUapprox        = (λv. None)
 | φsingle' Sigmoid           = (λv. None)
 | φsingle' (Sigmoidtaylor n)  = (λv. None)
 | φsingle' SoftSign          = (λv. None)
 | φsingle' Swish             = (λv. None)
 | φsingle' (Swishtaylor n)    = (λv. None)
 | φsingle' SELU              = (λv. None)
 | φsingle' (SELUtaylor n)     = (λv. None)
 | φsingle' Exp               = (λ v. Some (exp v))
 | φsingle' (Exptaylor n)      = (λ v. None)
 | φsingle' HardSigmoid       = (λv. None) 

text‹
  The function @{constφsingle'} defines, for differentiable activation functions, their derivative. 
  Note that we require differentiability in the mathematical sense. For example, while some machine 
  learning textbooks consider the binary step function differentiable everywhere except at the point 0, 
  we consider it non-differentiable, as the binary step function is not continuous at the point 0. In 
  the following, we also provide the ``approximated derivatives'' of non-continuous activation functions:
›
                 
lemma 
  assumes v ∈ (dom (φsingle' a))
  shows   ((λ v. the (φsingle a) v) has_real_derivative (the (φsingle' a v))) (at v within (dom (φsingle' a)))
  using assms by (cases a, auto)
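
text‹
  For instance, for ‹Tanh› the derivative recorded in ‹φsingle'› coincides with ‹tanh'›
  (illustrative example):
›

lemma "the (φsingle' Tanh v) = tanh' v"
  by simp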

subsection‹Multiclass Folding Activation Functions›

datatype activationmulti = mIdentity |  mSign | mBinaryStep |  mLogistic real real real | mLogistictaylor nat real real real 
                        | mTanh  | mSigmoid | mSigmoidtaylor nat | mReLU | mGReLU real "real option" real
                        | mSoftplus | mSoftSign | mSwish | mSwishtaylor nat | mGeLUapprox | mELU real 
                        | mELUtaylor nat real | mSELU | mSELUtaylor nat | mPReLU real | mSiLU | mSiLUtaylor nat 
                        | mGaussian | mGaussiantaylor nat | mExp | mExptaylor nat | mHardSigmoid | mSoftmax 
                        | mSoftmaxtaylor nat | mArgmax

fun φmulti :: activationmulti ⇒ (real list ⇒ real list) option where
   φmulti  mIdentity          = Some (map identity)                 
 | φmulti  mSign              = Some (map sign)
 | φmulti  mBinaryStep        = Some (map binary_step)
 | φmulti  mSoftSign          = Some (map softsign)                 
 | φmulti  (mLogistictaylor n L k v0) = Some (map (logistictaylor n L k v0))           
 | φmulti  (mLogistic L k v0) = Some (map (logistic L k v0))           
 | φmulti  mSigmoid           = Some (map sigmoid)                   
 | φmulti  (mSigmoidtaylor n)  = Some (map (sigmoidtaylor n))                   
 | φmulti  mSwish             = Some (map swish)                      
 | φmulti  (mSwishtaylor n)    = Some (map (swishtaylor n))                      
 | φmulti  mTanh              = Some (map tanh)                     
 | φmulti  mReLU              = Some (map relu)                      
 | φmulti  mGeLUapprox        = Some (map gelu_approx)               
 | φmulti  (mGReLU α m t)     = Some (map (generalized_relu α m t))  
 | φmulti  mSoftplus          = Some (map softplus)                      
 | φmulti  (mELU α)           = Some (map (elu α))                  
 | φmulti  (mELUtaylor n α)    = Some (map (elutaylor n α))                  
 | φmulti  mSELU              = Some (map selu)                     
 | φmulti  (mSELUtaylor n)     = Some (map (selutaylor n))                     
 | φmulti  mExp               = Some (map exp)                      
 | φmulti  (mExptaylor n)      = Some (map (exptaylor n))                      
 | φmulti  mHardSigmoid       = Some (map hard_sigmoid)              
 | φmulti  (mPReLU α)         = Some (map (prelu α))
 | φmulti  mSiLU              = Some (map silu)
 | φmulti  (mSiLUtaylor n)     = Some (map (silutaylor n))
 | φmulti  mGaussian          = Some (map gaussian)
 | φmulti  (mGaussiantaylor n) = Some (map (gaussiantaylor n))
 | φmulti  mSoftmax           = Some softmax
 | φmulti  (mSoftmaxtaylor n)  = Some (softmaxtaylor n)
 | φmulti  mArgmax            = Some argmax

text‹ 
  The datatype @{typeactivationmulti} enumerates standard activation functions that are applied 
  to the (folded, weighted) inputs of all neurons of a layer at once, i.e., they operate on lists 
  of reals. The function @{constφmulti} provides easy access to the activation function itself.
›
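
text‹
  As in the single-neuron case, the defining equations of ‹φmulti› are directly available to the
  simplifier (illustrative example):
›

lemma "φmulti mSoftmax = Some softmax"
  by simp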

section‹Encoding of Activation Functions›

ML‹signature ACTIVATION_TERM = sig
    datatype mode = MultiList | MultiMatrix | Single
    datatype activationT = Elu | Exponential | GRelu | Gelu | Hard_sigmoid | Linear | Relu | Selu 
                         | Sigmoid | Softmax | Softmax_taylor | Softplus | Softsign | Sign | BinaryStep | Swish | Tanh
                         | Sigmoid_taylor
    val add_function: binding -> term list -> local_theory -> Proof.context
    val def_phi_tab: mode -> string -> activationT list -> local_theory -> Proof.context
    val term_of_activation_eqn_multi_list: activationT -> term
    val term_of_activation_eqn_multi_matrix: activationT -> term
    val term_of_activation_eqn_single: activationT -> term
    val term_of_activation_multi: activationT -> term
    val term_of_activation_single: activationT -> term
  end›

ML_file ‹Tools/Activation_Functions.ML›
                        
text‹
  The ML structure @{ML_structure ‹Activation_Term: ACTIVATION_TERM›} provides the core infrastructure
  for constructing HOL terms for the activation functions on the ML level.
›
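
text‹
  A minimal usage sketch (relying only on the signature shown above; the value binding below is
  purely illustrative): the following ML snippet constructs the HOL term representing the ReLU
  activation function in single mode.
›

ML‹
  (* Construct the HOL term for the ReLU activation function on the ML level. *)
  val relu_term = Activation_Term.term_of_activation_single Activation_Term.Relu
›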

end