(* Theory LogisticFunction *)

(*  Title:       LogisticFunction.thy
    Author:      Filip Smola, 2019-2021
*)

theory LogisticFunction
  imports HyperdualFunctionExtension
begin

subsection‹Logistic Function›

text‹Define the standard logistic function and its hyperdual variant:›
(* The standard logistic function on the reals: 1 / (1 + exp (- x)), written with inverse. *)
definition logistic :: "real ⇒ real"
  where "logistic x = inverse (1 + exp (-x))"
(* Hyperdual counterpart of logistic: the same formula, with exp replaced by its
   hyperdual extension and arithmetic carried out in the hyperdual numbers. *)
definition hyp_logistic :: "real hyperdual ⇒ real hyperdual"
  where "hyp_logistic x = inverse (1 + (*h* exp) (-x))"

text‹The hyperdual extension of the logistic function is its hyperdual variant:›
lemma hypext_logistic:
  "(*h* logistic) x = hyp_logistic x"
proof -
  (* Push the hyperdual extension through the denominator, using the composition,
     negation and constant-addition extension rules. *)
  have "(*h* (λx. exp (- x) + 1)) x = (*h* exp) (- x) + of_comp 1"
    by (simp add: hypext_compose hypext_uminus hypext_fun_cadd twice_field_differentiable_at_compose)
  (* Reorder the sum and identify the embedded constant with the hyperdual one. *)
  then have "(*h* (λx. 1 + exp (- x))) x = 1 + (*h* exp) (- x)"
    by (simp add: one_hyperdual_def add.commute)
  (* The denominator does not vanish at the base point
     (used together with hypext_fun_inverse below). *)
  moreover have "1 + exp (- Base x) ≠ 0"
    by (metis exp_ge_zero add_eq_0_iff neg_0_le_iff_le not_one_le_zero)
  (* The denominator is twice differentiable at the base point
     (also used together with hypext_fun_inverse below). *)
  moreover have "(λx. 1 + exp (- x)) twice_field_differentiable_at Base x"
  proof -
    have "(λx. exp (- x)) twice_field_differentiable_at Base x"
      by (simp add: twice_field_differentiable_at_compose)
    then have "(λx. exp (- x) + 1) twice_field_differentiable_at Base x"
      using twice_field_differentiable_at_compose[of "λx. exp (- x)" "Base x" "λx. x + 1"]
      by simp
    then show ?thesis
      by (simp add: add.commute)
  qed
  (* Combine the three facts via the extension rule for inverse. *)
  ultimately have "(*h* (λx. inverse (1 + exp (- x)))) x = inverse (1 + (*h* exp) (- x))"
    by (simp add: hypext_fun_inverse)
  (* Both sides are now exactly the unfolded definitions. *)
  then show ?thesis
    unfolding logistic_def hyp_logistic_def .
qed

text‹From the properties of autodiff, we know that it gives us the derivative:›
lemma "Eps1 (hyp_logistic (β x)) = deriv logistic x"
  (* Rewrite hyp_logistic back into the hyperdual extension of logistic
     (hypext_logistic above). Then read the derivative off the Eps1 component;
     Eps1_hypext presumably states that Eps1 of an extension at β x is the
     derivative. *)
  by (metis Eps1_hypext hypext_logistic)
text‹which is equal to the known derivative of the standard logistic function:›
lemma "First (autodiff logistic x) = exp (- x) / (1 + exp (- x)) ^ 2"
  (* Checks the first derivative produced by autodiff against the closed form
     exp (- x) / (1 + exp (- x)) ^ 2. *)
  (* Move to hyperdual variant: *)
  apply (simp only: autodiff.simps hyperdual_to_derivs.simps derivs.sel hypext_logistic)
  (* Unfold extensions of functions that have a hyperdual variant (all except exp): *)
  apply (simp only: hyp_logistic_def inverse_hyperdual.code hyperdual.sel)
  (* Finish by expanding the extension of exp and hyperdual computations: *)
  apply (simp add: hyperdualx_def hypext_exp_Hyperdual hyperdual_bases)
  done

text‹Similarly, we can get the second derivative:›
lemma "Second (autodiff logistic x) = deriv (deriv logistic) x"
  (* Closed directly by the rule autodiff_extract_second, which (judging by this
     use) identifies the Second component of autodiff with the second derivative. *)
  by (rule autodiff_extract_second)
text‹and derive its value:›
lemma "Second (autodiff logistic x) = ((exp (- x) - 1) * exp (- x)) / ((1 + exp (- x)) ^ 3)"
  (* Move to hyperdual variant: *)
  apply (simp only: autodiff.simps hyperdual_to_derivs.simps derivs.sel hypext_logistic)
  (* Unfold extensions of functions that have a hyperdual variant (all except exp): *)
  apply (simp only: hyp_logistic_def inverse_hyperdual.code hyperdual.sel)
  (* Finish by expanding the extension of exp and hyperdual computations: *)
  apply (simp add: hyperdualx_def hypext_exp_Hyperdual hyperdual_bases)
  (* Simplify the resulting expression: factor out exp (- x) and bring both
     fractions over the common denominator (1 + exp (- x)) ^ 3. *)
proof -
  (* Factor exp (- x) out of both terms. *)
  have
    "2 * (exp (- x) * exp (- x)) / (1 + exp (- x)) ^ 3 - exp (- x) / (1 + exp (- x)) ^ 2 =
     (2 * exp (- x) / (1 + exp (- x)) ^ 3 - 1 / (1 + exp (- x)) ^ 2) * exp (- x)"
    by (simp add: field_simps)
  (* Rewrite 1 / (1 + exp (- x)) ^ 2 as (1 + exp (- x)) / (1 + exp (- x)) ^ 3. *)
  also have "... = (2 * exp (- x) / (1 + exp (- x)) ^ 3 - (1 + exp (- x)) / (1 + exp (- x)) ^ 3) * exp (- x)"
  proof -
    have "inverse ((1 + exp (- x)) ^ 2) = inverse (1 + exp (- x)) ^ 2"
      by (simp add: power_inverse)
    also have "... = (1 + exp (- x)) * inverse (1 + exp (- x)) * inverse (1 + exp (- x)) ^ 2"
      by (simp add: inverse_eq_divide)
    also have "... = (1 + exp (- x)) * inverse (1 + exp (- x)) ^ 3"
      by (simp add: power2_eq_square power3_eq_cube)
    finally have "inverse ((1 + exp (- x)) ^ 2) = (1 + exp (- x)) * inverse ((1 + exp (- x)) ^ 3)"
      by (simp add: power_inverse)
    then show ?thesis
      by (simp add: inverse_eq_divide)
  qed
  (* Combine the two fractions over their common denominator. *)
  also have "... = (2 * exp (- x) - (1 + exp (- x))) / (1 + exp (- x)) ^ 3 * exp (- x)"
    by (metis diff_divide_distrib)
  (* The numerator 2 * exp (- x) - (1 + exp (- x)) reduces to exp (- x) - 1,
     which gives the goal left after the apply steps above. *)
  finally show
    "2 * (exp (- x) * exp (- x)) / (1 + exp (- x)) ^ 3 - exp (- x) / (1 + exp (- x)) ^ 2 =
     (exp (- x) - 1) * exp (- x) / (1 + exp (- x)) ^ 3"
    by (simp add: field_simps)
qed

end