Theory Properties

(***********************************************************************************
 * Copyright (c) University of Exeter, UK
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 ***********************************************************************************)

section‹Desirable Properties of Neural Networks Predictions›

theory Properties
imports 
  Prediction_Utils 
  "HOL-Library.Interval"
  "HOL-Library.Interval_Float"
begin 

subsection‹Approximate Comparison of Results›
(* approx a ε b: a and b are considered equal up to the tolerance ε, judged by
   comparing their absolute difference ¦a-b¦ with ε (presumably ¦a-b¦ ≤ ε; the
   relation symbol is not visible in this copy of the file -- confirm).
   Infix notation: a ≈[ε]≈ b. *)
definition approx a ε b = (¦a-b¦  ε)
notation approx ("((_)/ ≈[(_)]≈ (_))" [60, 60] 60)

(* checkget_result_list ε r s compares an optional result list r with an optional
   reference list s and returns the pair (r, verdict):
   - None vs. None: verdict True;
   - Some xs vs. Some ys: verdict True iff every pair of corresponding elements
     satisfies x ≈[ε]≈ y (pairs built with map2, combined by fold with (∧));
   - one side None, the other Some: verdict False.
   NOTE(review): map2 zips the two lists, so if xs and ys differ in length only
   their common prefix is compared and the verdict can still be True -- confirm
   that this is intended. *)
fun checkget_result_list where
   checkget_result_list _ None None           = (None,True) 
 | checkget_result_list ε (Some xs) (Some ys) = (Some xs, fold (∧) (map2 (λ x y. x ≈[ε]≈ y)  xs ys) True)
 | checkget_result_list _ r _ = (r,False)

(* check_result_list r ε s: only the Boolean verdict of checkget_result_list ε r s,
   written r ≈[ε]≈l s. Note the argument order: the tolerance is the middle
   argument here, but the first argument of checkget_result_list. *)
definition check_result_list r ε s = snd (checkget_result_list ε r s)
notation check_result_list ("((_)/ ≈[(_)]≈l (_))" [60, 60] 60)

(* checkget_result_singleton ε r s: single-value counterpart of
   checkget_result_list. Returns the pair (r, verdict), where the verdict is
   True for None/None, x ≈[ε]≈ y for Some x/Some y, and False when exactly one
   side is None. *)
fun checkget_result_singleton where
   checkget_result_singleton _ None None         = (None,True) 
 | checkget_result_singleton ε (Some x) (Some y) = (Some x, x ≈[ε]≈ y)
 | checkget_result_singleton _ r _ = (r,False)

(* check_result_singleton r ε s: only the Boolean verdict of
   checkget_result_singleton ε r s, written r ≈[ε]≈s s (tolerance in the
   middle, as for check_result_list). *)
definition check_result_singleton r ε s = snd (checkget_result_singleton ε r s)
notation check_result_singleton ("((_)/ ≈[(_)]≈s (_))" [60, 60] 60)

(* ensure_testdata_range_list delta inputs P outputs checks the prediction
   function P against test data. Inputs and expected outputs are paired with
   zip. For each pair e, the result P (fst e) must satisfy ≈[delta]≈l
   (check_result_list) against Some (snd e), with tolerance delta. P (fst e) is
   an option, since it is compared with a Some value.
   The per-pair verdicts are conjoined by foldl (∧) True, so an empty test set
   is accepted.
   NOTE(review): zip truncates to the shorter list, so surplus inputs or outputs
   are silently ignored. *)
definition 
  ensure_testdata_range_list :: real  real list list  (real list  real list)  real list list  bool 
  where
 ensure_testdata_range_list delta inputs P outputs
   =  foldl (∧) True
      (map (λ e. (P (fst e))  ≈[delta]≈l Some (snd e))
           (zip inputs outputs))     
notation ensure_testdata_range_list ("(_) l {(_)} (_) {(_)}" [61, 3, 90, 3] 60)


subsubsection‹Interval Arithmetic›
(* interval_distance a b: the gap between two intervals. Write a = [la, ua] and
   b = [lb, ub], with bounds obtained via bounds_of_interval. The result is:
   - lb - ua if the first test on ua and lb succeeds (presumably ua ≤ lb, i.e.
     a lies entirely below b);
   - la - ub if the second test on ub and la succeeds (presumably ub ≤ la, i.e.
     b lies entirely below a);
   - 0 otherwise (overlapping intervals).
   The comparison symbols are not visible in this copy -- confirm. *)
definition interval_distance :: 'a::{preorder,minus, zero, ord} interval  'a interval  'a where
          interval_distance a b = (let (la, ua) = bounds_of_interval a;
                                        (lb, ub) = bounds_of_interval b
                                    in if ua  lb then lb - ua
                                                   else if ub  la then la -ub
                                                                    else 0)

(* intervals_of_list δ xs: maps each value x of xs to the error interval
   Interval (x - ¦δ¦, x + ¦δ¦), of radius ¦δ¦ around x. Using ¦δ¦ rather than δ
   keeps the lower bound below the upper bound even for negative δ. *)
fun intervals_of_list where 
   intervals_of_list _ [] = []
 | intervals_of_list δ (x#xs) = (Interval (x- ¦δ¦, x+¦δ¦))#(intervals_of_list δ xs)

(* intervals_of_l δ: applies intervals_of_list δ to every list in a list of
   lists. The result is an interval list list, the shape of the expected
   outputs taken by ensure_testdata_interval_list. *)
definition intervals_of_l δ = map (intervals_of_list δ)

(* Membership of x in the closed set {a..b} implies membership in the set of the
   interval Interval (a,b). The membership/implication symbols are not visible
   in this copy.
   No separate hypothesis relating a and b is needed: x ∈ {a..b} already gives
   a ≤ x ≤ b, and the proof chains these bounds with dual_order.trans. *)
lemma intervall_in_implies_set: " (x  {a..b})  (x  set_of (Interval (a,b))) "
  by (metis atLeastAtMost_iff dual_order.trans fst_conv lower_Interval set_of_eq snd_conv upper_Interval) 

(* For ordered bounds a and b (presumably a ≤ b; the relation between a and b is
   not visible in this copy), membership in set_of (Interval (a,b)) coincides
   with membership in the closed set {a..b}. *)
lemma in_set_interval: "a b  (x  set_of (Interval (a,b))) = (x  {a..b})"
  by (simp add: lower_Interval set_of_eq upper_Interval) 

(* check_result_list_interval_list xs ys: compares an optional list of values xs
   with an optional list of intervals ys.
   - Both None: accepted.
   - Both Some: accepted iff every value lies in the set (set_of) of the
     corresponding interval (pairs built with map2, combined by fold with (∧)).
   - Any mix of None and Some: rejected.
   NOTE(review): as with checkget_result_list, map2 compares only the common
   prefix when the two lists differ in length. *)
fun check_result_list_interval_list :: 'a::preorder list option  'a interval list option  bool where
   check_result_list_interval_list None None           = True 
 | check_result_list_interval_list (Some xs) (Some ys) = fold (∧) (map2 (λ x y. x  set_of y)  xs ys) True 
 | check_result_list_interval_list _ _ = False

notation check_result_list_interval_list ("((_)/ l (_))" [60, 60] 60)

text ‹We define @{term "check_result_list_interval_list"} for checking that two lists are approximately
equal: each computed value must lie within the corresponding error interval. We need these error
intervals because IEEE 754 floating-point arithmetic in Python may introduce rounding errors
compared to the mathematical reals in Isabelle.›

(* ensure_testdata_interval_list inputs P outputs: interval-based test-data check.
   Inputs and expected interval lists are paired with zip. For each pair, the
   optional result a = P (fst e) is checked against b = Some (snd e) with
   check_result_list_interval_list (the infix used in (a l b)). Hence every
   computed value must lie in its expected interval.
   The verdicts are conjoined by foldl (∧) True, so an empty test set is
   accepted; zip silently drops surplus inputs or outputs. *)
definition
  ensure_testdata_interval_list :: real list list  (real list  real list)  real interval list list  bool 
  where
 ensure_testdata_interval_list inputs P outputs
   =  foldl (∧) True
      (map (λ e. let a = (P (fst e)) in let b = Some (snd e) in (a l b))
           (zip inputs outputs))     

notation ensure_testdata_interval_list ("il {(_)} (_) {(_)}" [3, 90, 3] 60)

text ‹Using @{term "check_result_list_interval_list"}, the property 
@{term "ensure_testdata_interval_list"} checks that the (symbolically) computed predictions of a neural 
network meet our expectations.›

subsection‹Maximum Classifiers›
(* ensure_testdata_max_list inputs P outputs: a weaker test-data check for
   classifiers. Inputs and expected outputs are paired with zip. For each pair:
   - P (fst e) must be defined; a None result fails the check;
   - the pos_of_max of the prediction must equal the pos_of_max of the expected
     output (pos_of_max comes from Prediction_Utils; presumably the index of the
     maximal element).
   The exact output values are not compared. Verdicts are conjoined by
   foldl (∧) True. *)
definition
  ensure_testdata_max_list :: real list list  (real list  real list)  real list list  bool 
  where
 ensure_testdata_max_list inputs P outputs 
   =  foldl (∧) True
      (map (λ e. case P (fst e) of 
                   None  False 
                 | Some p  pos_of_max p = pos_of_max (snd e)) 
           (zip inputs outputs))
notation ensure_testdata_max_list ("l {(_)} (_) {(_)}" [3, 90, 3] 60)

text ‹Many classification networks use the maximum output as the result, without normalisation 
(e.g., to values between 0 and 1). In such cases, a weaker form of ensuring compliance with the
predictions can be used. It only checks, for each given input, that the maximum output occurs at
the expected position; this can be tested using @{term "ensure_testdata_max_list"}.›

(* ensure_delta_min δ P: every result list xs in the range (ran) of the
   prediction function P is related to δ by δmin xs. Presumably this reads
   ∀ xs ∈ ran P. δ ≤ δmin xs; the quantifier and relation symbols are not
   visible in this copy. δmin comes from Prediction_Utils; according to the
   accompanying text it expresses a minimum distance between the
   classification outputs.
   NOTE(review): the mixfix template below shows no operator symbol between its
   two arguments in this copy -- confirm against the original theory. *)
definition ensure_delta_min :: real  (real list  real list)  bool where 
          ensure_delta_min δ P = ( xs  ran P. δ  δmin xs)
notation ensure_delta_min ("(_)  (_)" [61, 90] 60)

(* Equivalent formulation of ensure_delta_min that quantifies over the domain
   (dom) of P and inspects the (P x), instead of quantifying over the range of
   P. *)
lemma ensure_delta_min_dom: ensure_delta_min δ P = ( x  dom P. δ  δmin (the (P x)))
  by(auto simp add:ensure_delta_min_def dom_def ran_def)

text ‹
Further properties that we formalised can increase the confidence in the predictions of a neural 
network by reducing the likelihood of ambiguous classification results. This includes, e.g., the
requirement that, for a given input, the classification outputs have at least a
given minimum distance (e.g., avoiding situations where all classification
outputs show nearly identical values); this requirement is formalised by @{term "ensure_delta_min"}.›

subsection‹Distance-based Properties›

subsubsection‹Distance and Measurements›

(* Locale distance: axiomatises a distance function d on lists. Its values lie
   in a linearly ordered additive abelian group ('b::linordered_ab_group_add).
   - identity: for lists of equal length, d x y = 0 exactly when x = y;
   - symmetry: d x y = d y x, for lists of any length;
   - triangle_inequality: for lists of equal length, d x z is bounded by
     d x y + d y z.
   The lemmas inside the locale are derived from these assumptions only. *)
locale distance = 
  fixes d::'a list  'a list  ('b::{linordered_ab_group_add})
  assumes identity: length x = length y  (d x y = 0) = (x = y)
  and symmetry:     (d x y = d y x)
  and triangle_inequality: length x = length y ; length z = length y  (d x z  d x y + d y z)
begin

(* zero: identity without the length premise. Either (d x y = 0) = (x = y)
   holds, or the lengths of x and y differ (presumably ∨ and ≠; the connectives
   are not visible in this copy). *)
lemma zero: (d x y = 0) = (x = y)  (length x  length y)
  using distance_axioms distance_def by fastforce 

(* Triangle inequality instantiated with z := x: d x y + d y x bounds d x x. *)
lemma length x = length y ; length z = length y  d x y + d y x  d x x
  using distance.triangle_inequality distance_axioms by blast

(* Non-negativity for lists of equal length, derived from the axioms. The proof
   relies on zero_le_double_add_iff_zero_le_single_add (presumably
   0 ≤ a + a ⟷ 0 ≤ a). *)
lemma  length x = length y  0  d x y
  using distance_axioms distance_def
        linordered_ab_group_add_class.zero_le_double_add_iff_zero_le_single_add
  by (metis (mono_tags, opaque_lifting))

end 


(* mapfoldr map_f fold_f e xs ys: combines corresponding elements of xs and ys
   with map_f, then folds the results from the right with fold_f, starting from
   e. Pairing uses map2, so the longer list is truncated. This is the common
   scheme behind the hamming, manhattan and euclidean distances below. *)
definition mapfoldr :: ('a  'a  'b)  ('b  'c  'c)  'c  'a list  'a list  'c where 
mapfoldr map_f fold_f e xs ys = foldr fold_f (map2 (λ e0 e1 . map_f e0 e1) xs ys) e

(* hamming x y: the Hamming distance, i.e. the number of positions at which x
   and y differ. Each unequal pair of corresponding elements adds 1. Only the
   common prefix is compared (mapfoldr uses map2). *)
definition hamming::'a list  'a list  nat where
          hamming x y = mapfoldr (=) (λ e a. if e then a else a + 1) 0 x y



(* Identity axiom for the Hamming distance: for lists of equal length,
   hamming x y = 0 exactly when x = y. Proved by simultaneous induction over
   both lists (list_induct2). *)
lemma hamming_identity: length x = length y  (hamming x y = 0) = (x = y)
proof(induction rule:list_induct2)
  case Nil
  then show ?case 
    by (simp add: hamming_def mapfoldr_def) 
next
  case (Cons x xs y ys) note * = this
  then show ?case 
    by (simp add: hamming_def mapfoldr_def) 
qed

(* Symmetry of the Hamming distance, for lists of arbitrary length. After
   unfolding, the zipped pairs of x and y are swapped (zip_commute), and the
   equality test is symmetric. *)
lemma hamming_symmetry: hamming x y = hamming y x
  apply(simp add: hamming_def mapfoldr_def)
  using list_all2_all_nthI[where P="(=)", unfolded list.rel_eq]
  by (smt (verit, best) map2_map_map map_eq_conv zip_commute zip_map_fst_snd)

(* One-step unfolding of hamming on non-empty lists: equal heads contribute 0,
   different heads contribute 1, plus the distance of the tails. The lemmas
   below use this to reason recursively instead of unfolding hamming_def and
   mapfoldr_def. *)
lemma hamming_unroll: "length xs = length ys 
       hamming (x#xs) (y#ys) = (if x = y then hamming xs ys else 1 + hamming xs ys)"
proof(cases "x = y")
  case True
  then show ?thesis by(simp add: hamming_def mapfoldr_def)
next
  case False
  then show ?thesis by(simp add: hamming_def mapfoldr_def)
qed

(* Triangle inequality for the Hamming distance on three lists of equal length.
   Proved by simultaneous induction over all three lists (list_induct3); the
   step case follows from hamming_unroll. *)
lemma hamming_triangle_inequality: 
    length xs = length ys ; length ys = length zs 
 hamming xs zs  hamming xs ys + (hamming ys zs)
proof(induction rule:list_induct3)
  case Nil
  then show ?case by simp
next
  case (Cons x xs y ys z zs)
  then show ?case
    by (simp add: hamming_unroll)
qed

(* The Hamming distance is an instance of the distance locale.
   NOTE(review): hamming returns a nat, while the locale requires a
   linordered_ab_group_add codomain. The use of of_nat_add/of_nat_le_iff in the
   last subgoal suggests the values are coerced with of_nat here (presumably by
   Isabelle's coercion inference) -- confirm the target type. *)
global_interpretation hamming_distance: distance hamming
  apply(unfold_locales)
  (* identity *)
  subgoal by(simp add: hamming_identity)  
  (* symmetry *)
  subgoal by(simp add: hamming_symmetry)  
  (* triangle inequality, transported through of_nat *)
  subgoal by (metis hamming_triangle_inequality of_nat_add of_nat_le_iff)
  done

(* manhattan: the Manhattan (L1) distance, i.e. the sum of the absolute
   differences ¦x-y¦ of corresponding elements. Only the common prefix is
   compared (mapfoldr uses map2). *)
definition manhattan::real list  real list  real where
          manhattan = mapfoldr (λ x y . ¦x-y¦) (+)  0

  (* One-step unfolding of manhattan on non-empty lists: the absolute difference
     of the heads plus the distance of the tails. *)
  lemma manhattan_unroll: "length xs = length ys 
         manhattan (x#xs) (y#ys) = ¦x - y¦ + manhattan xs ys"
    by(simp add: manhattan_def mapfoldr_def)
  
  
  (* Non-negativity of the Manhattan distance for lists of equal length. Proved
     by simultaneous induction (list_induct2); the step case uses
     manhattan_unroll. *)
  lemma manhattan_positive: length x = length y  0  manhattan x y
  proof(induction rule:list_induct2)
    case Nil
    then show ?case by (simp add: manhattan_def mapfoldr_def)
  next
    case (Cons x xs y ys) note * = this
    then show ?case 
      using manhattan_unroll[of xs ys x y] by simp 
  qed
  
  (* Identity axiom for the Manhattan distance: for lists of equal length,
     manhattan x y = 0 exactly when x = y. *)
  lemma manhattan_identity: length x = length y  (manhattan x y = 0) = (x = y)
  proof(induction rule:list_induct2)
    case Nil
    then show ?case by (simp add: manhattan_def mapfoldr_def)
  next
    case (Cons x xs y ys) note * = this
    then show ?case 
      proof(cases "x = y")
        (* Equal heads contribute 0; the claim reduces to the induction
           hypothesis for the tails. *)
        case True
        then show ?thesis 
      using manhattan_unroll[of xs ys x y]
      by (simp add: Cons.IH Cons.hyps)
      next
        (* Different heads: ¦x - y¦ is positive and the tail distance is
           non-negative (manhattan_positive), so the sum cannot be 0
           (add_nonneg_eq_0_iff). *)
        case False
        then show ?thesis 
        using manhattan_unroll[of xs ys x y] * manhattan_positive
        by (simp add: add_nonneg_eq_0_iff) 
      qed
  qed

(* Symmetry of the Manhattan distance, for lists of arbitrary length. Induction
   with list_induct2' covers all four combinations of empty and non-empty
   lists, so no length premise is needed. *)
lemma manhattan_symmetry: manhattan x y = manhattan y x
  apply (induct x y rule:list_induct2')
  subgoal by(simp add: manhattan_def mapfoldr_def) 
  subgoal by(simp add: manhattan_def mapfoldr_def) 
  subgoal by(simp add: manhattan_def mapfoldr_def) 
  subgoal by(simp add: manhattan_def mapfoldr_def) 
  done

(* Triangle inequality for the Manhattan distance on three real lists of equal
   length. Proved by simultaneous induction (list_induct3); the step case
   follows from manhattan_unroll. *)
lemma manhattan_triangle_inequality: 
    length xs = length ys ; length ys = length (zs::real list) 
 manhattan xs zs  manhattan xs ys + (manhattan ys zs)
proof(induction rule:list_induct3)
  case Nil
  then show ?case by(simp add:manhattan_def mapfoldr_def)   
next
  case (Cons x xs y ys z zs)
  then show ?case
    by (simp add: manhattan_unroll)
qed

(* The Manhattan distance is an instance of the distance locale. Each locale
   assumption is discharged by the corresponding lemma above. *)
global_interpretation manhattan_distance: distance manhattan
  apply(unfold_locales)
  (* identity *)
  subgoal by(simp add: manhattan_identity)
  (* symmetry *)
  subgoal by(simp add: manhattan_symmetry)  
  (* triangle inequality *)
  subgoal by(simp add: manhattan_triangle_inequality)
  done

(* avg_difference xs ys: the Manhattan distance divided by the length of the
   shorter list (converted to real). This is the mean absolute difference over
   the compared positions. If either list is empty the divisor is 0, and
   division by zero yields 0 in Isabelle/HOL. *)
definition avg_difference::real list  real list  real where
          avg_difference xs ys = (manhattan xs ys) / (min (length xs) (length ys))

(* avg_difference is an instance of the distance locale. The proofs reuse the
   locale facts of manhattan_distance. *)
global_interpretation avg_difference_distance: distance avg_difference
  apply(unfold_locales)
  (* identity *)
  subgoal using avg_difference_def distance_def manhattan_distance.distance_axioms
    by fastforce
  (* symmetry: the divisor min (length xs) (length ys) is symmetric (min.commute) *)
  subgoal using avg_difference_def distance_def manhattan_distance.distance_axioms
    by (metis min.commute)
  (* triangle inequality: divide the Manhattan inequality by the common,
     non-negative length *)
  subgoal unfolding avg_difference_def distance_def manhattan_distance.distance_axioms
    by (metis add_divide_distrib divide_right_mono manhattan_distance.triangle_inequality of_nat_0_le_iff)
  done


(* euclidean X Y: the Euclidean (L2) distance, i.e. the square root of the sum
   of the squared differences of corresponding elements (presumably (x-y)²; the
   superscript is not visible in this copy). As with the other distances, map2
   restricts the sum to the common prefix of X and Y. No distance-locale
   interpretation is given for euclidean in this theory. *)
definition euclidean::real list  real list  real where
          euclidean X Y = sqrt (mapfoldr (λ x y . (x-y)2) (+)  0 X Y)

(* Non-negativity of the Euclidean distance for lists of equal length. Proved by
   simultaneous induction (list_induct2) and unfolding. *)
lemma euclidean_positive: length x = length y  0  euclidean x y
  proof(induction rule:list_induct2)
    case Nil
    then show ?case 
      by(simp add: euclidean_def mapfoldr_def)
  next
    case (Cons x xs y ys)
    then show ?case 
      by(simp add: euclidean_def mapfoldr_def)
  qed

(* Identity for the Euclidean distance: for lists of equal length,
   euclidean x y = 0 exactly when x = y. *)
lemma euclidean_identity: length x = length y  (euclidean x y = 0) = (x = y)
proof(induction rule:list_induct2)
  case Nil
  then show ?case by(simp add: euclidean_def mapfoldr_def)
next
  case (Cons x xs y ys) note * = this
  then show ?case  
  proof(cases "x = y")
    (* Equal heads contribute 0; the claim reduces to the induction hypothesis. *)
    case True
    then show ?thesis using * by(simp add: euclidean_def mapfoldr_def)
  next
    (* Different heads: the squared head difference is positive (fact ***,
       presumably 0 < (x - y)²). The remaining sum is non-negative
       (euclidean_positive), so the total cannot be 0 (add_nonneg_eq_0_iff). *)
    case False note ** = this
    then have ***: x  y  0  (x - y)2 by simp
    then show ?thesis apply(simp add: euclidean_def euclidean_positive * ** *** mapfoldr_def)
      using ** euclidean_positive[unfolded euclidean_def mapfoldr_def, simplified]
      by (simp add: Cons.hyps add_nonneg_eq_0_iff)  
  qed
qed 


(* Symmetry of the Euclidean distance, for lists of arbitrary length (induction
   with list_induct2'). In the Cons/Cons case the squared difference is
   symmetric (power2_commute). *)
lemma euclidean_symmetry: euclidean x y = euclidean y x
  apply (induct x y rule:list_induct2')
  subgoal by(simp add: euclidean_def mapfoldr_def) 
  subgoal by(simp add: euclidean_def mapfoldr_def) 
  subgoal by(simp add: euclidean_def mapfoldr_def) 
  subgoal by(simp add: euclidean_def mapfoldr_def power2_commute) 
  done

(* check d P inputref prediction P': a generic distance-based property of a
   (partial) prediction function. Consider every input x in dom prediction
   whose distance d inputref x to the reference input satisfies P. For each such
   x, the predictions for x and for inputref must be related by P'. Presumably
   this is ∀ … ⟶ …; the quantifier and implication symbols are not visible in
   this copy. *)
definition 
  check :: ('a list  'a list  'b)  ('b  bool)  'a list  ('a list  'a list) 
              ('a list option   'a list option  bool)  bool where
 check d P inputref prediction P' 
        = ( x  dom prediction. P(d inputref x)  P' (prediction x) (prediction inputref))


(* Sanity check for the shape of check (stated here with dist). Guarding the
   quantifier over dom prediction by P is equivalent to quantifying only over
   the subset {l ∈ dom prediction. P (dist i l)}. *)
lemma " (( l  dom prediction. P(dist i l)  P' (prediction l) (prediction i)))
      = (( l  {l  dom prediction .  P (dist i l)}.  P' (prediction l)  (prediction i)))"
  by auto 


(* Suppose xs and ys have equal length and differ in at most one position, i.e.
   their Hamming distance is at most 1 (the relation symbol is not visible in
   this copy). Then xs is ys with a single index i updated to xs!i. *)
lemma hamming_update_1: 
  "length xs = length ys  hamming xs ys  1  ( i. xs = ys[i := xs!i])"
proof(induction rule:list_induct2)
  case Nil
  then show ?case by simp
next
  case (Cons x' xs' y' ys')
  then show ?case unfolding hamming_def mapfoldr_def
  proof(cases "x' = y'") 
    (* Equal heads: the difference is confined to the tails. Take the witness
       index i from the induction hypothesis and shift it by one (split on i so
       that the index arithmetic simplifies). *)
    case True note * = this
    then have h: "hamming xs' ys'  1" using hamming_unroll by (metis Cons.hyps Cons.prems) 
    then show ?thesis 
    proof -
      obtain i where xs'_eq: "xs' = ys'[i := xs'!i]" using Cons.IH h by blast
      show ?thesis
      proof (cases i)
        case 0
        then show ?thesis 
          apply (intro exI[of _ "Suc 0"])
          apply simp
          using xs'_eq * by simp
      next
        case (Suc j)
        then show ?thesis
          apply (intro exI[of _ "Suc (Suc j)"])
          apply (simp split: nat.splits)
          using xs'_eq * Suc by simp
      qed
    qed
  next
    (* Different heads: the head already accounts for distance 1. *)
    case False note ** = this
    then show ?thesis proof(cases " hamming xs' ys' = 0")
      (* Equal tails (hamming_identity): updating index 0 suffices. *)
      case True
      then show ?thesis using hamming_identity 
        by (metis Cons.hyps list_update_code(2) nth_Cons_0)
    next
      (* Tails at distance 1 would make the total distance 2, contradicting the
         premise. *)
      case False
      then have h: hamming xs' ys' = 1 using ** Cons.hyps Cons.prems hamming_unroll 
        by fastforce 
      then show ?thesis
        apply(intro exI[of _ "0"])
        apply simp
        using h Cons.hyps Cons.prems hamming_unroll **
        by (metis add_le_same_cancel1 not_one_le_zero)
    qed
  qed
qed

(* Case principle for lists of equal length at Hamming distance at most 1. To
   show P ys it suffices to have:
   - P xs (assumption p);
   - P ys whenever ys arises from xs by updating a single in-range index i with
     ys!i (assumption u). *)
lemma hamming_cases1:
  assumes l: length xs = length ys
  and h: hamming xs ys  1
  and p: P xs
  and u:  i. i  < length xs  ys = xs[i := (ys!i)]  P ys
shows P ys
proof(insert assms, induct "xs")
  case Nil
  then show ?case by simp
next
  (* Obtain the single update from hamming_update_1 (with the roles of the lists
     swapped via hamming_symmetry). An out-of-range index leaves the list
     unchanged (list_update_beyond), which falls back to p. *)
  case (Cons x' xs')
  then show ?case  
    using hamming_update_1[of "x'#xs'" "ys"]
          hamming_update_1 One_nat_def linorder_not_le list_update_beyond
    by (metis hamming_symmetry) 
qed


(* Suppose xs and ys have equal length and Hamming distance at most 2 (the
   relation symbol is not visible in this copy). Then xs is ys with two indices
   i and j updated to the corresponding values of xs. The step case combines
   hamming_unroll with hamming_update_1 on the tails. *)
lemma hamming_update_2: 
  "length xs = length ys  hamming xs ys  2  ( i j. xs = (ys[i := xs!i])[j:= xs!j])"
proof(induction rule:list_induct2)
  case Nil
  then show ?case by simp
next
  case (Cons x xs y ys)
  then show ?case 
    by (metis Suc_1 Suc_eq_plus1_left Suc_le_mono hamming_unroll hamming_update_1 
              list_update_code(2) list_update_code(3) nth_Cons_0 nth_Cons_Suc) 
qed

(* Case principle for lists of equal length at Hamming distance at most 2. To
   show P ys it suffices to have:
   - P xs (assumption p);
   - P ys whenever ys arises from xs by updating two in-range indices i and j
     with the values of ys (assumption u). *)
lemma hamming_cases2:
  assumes l: length xs = length ys
  and h: hamming xs ys  2
  and p: P xs
  and u:  i j. i < length xs  j < length xs  ys = xs[i := ys!i,j := ys!j]  P ys
shows P ys
proof(insert assms, induct "xs")
  case Nil
  then show ?case by simp
next
  (* Obtain the two updates from hamming_update_2 (roles swapped via
     hamming_symmetry); out-of-range indices leave the list unchanged. *)
  case (Cons x' xs')
  then show ?case  
    using  hamming_update_2[of "x'#xs'" "ys"]
    by (metis  hamming_symmetry hamming_update_2 length_list_update list_update_beyond
               list_update_id verit_comp_simplify1(3))
  qed

(* Suppose the Hamming distance of two lists of equal length is Suc n. Then a
   single update of ys at some index i with xs!i reduces the distance to n.
   This is used to step down from distance 3 in hamming_update_3. *)
lemma hamming_update_n:
  "length xs = length ys  hamming xs ys = Suc n  ( i. hamming xs (ys[i := xs!i]) = n)"
proof(induction rule:list_induct2)
  case Nil
  then show ?case unfolding hamming_def mapfoldr_def by simp 
next
  case (Cons x xs y ys)
  then show ?case 
  by (metis Suc_eq_plus1_left diff_Suc_1 hamming_unroll length_list_update list_update_code(2) 
            list_update_code(3) nth_Cons_0 nth_Cons_Suc)
qed

(* Suppose xs and ys have equal length and Hamming distance at most 3 (the
   relation symbol is not visible in this copy). Then xs is ys with three
   indices i, j, k updated to the corresponding values of xs. The proof splits
   on the distance of the tails: equal to 3, below 3, or above 3. *)
lemma hamming_update_3: 
  "length xs = length ys  hamming xs ys  3  ( i j k. xs = ys[i := xs!i,j:= xs!j,k:=xs!k])"
proof(induction rule:list_induct2)
  case Nil
  then show ?case by simp
next
  case (Cons x xs y ys)
  then show ?case proof(cases "hamming xs ys =3")
    (* Tails at distance 3: remove one difference with hamming_update_n, then
       apply hamming_update_2 to the remaining distance 2. *)
    case True
    then show ?thesis 
          using  hamming_update_n hamming_update_2
          by (metis (mono_tags, opaque_lifting) Cons.hyps Cons.prems One_nat_def Suc_1 le_Suc_eq 
                     length_Cons length_list_update list_update_id numeral_3_eq_3)
  next
    case False note * = this
    then show ?thesis 
    proof(cases "hamming xs ys < 3")
      (* Tails at distance below 3: two updates for the tails
         (hamming_update_2), plus an index-0 update for the head. *)
      case True
      then show ?thesis 
        by (metis Cons.hyps One_nat_def Suc_1 hamming_update_2 less_Suc_eq_le list_update_code(2) 
                  list_update_code(3) nth_Cons_0 nth_Cons_Suc numeral_3_eq_3)
    next
      (* Tails at distance above 3 contradict the premise (via hamming_unroll). *)
      case False
      then show ?thesis 
        using * Cons.prems Cons.hyps hamming_unroll[of "xs" "ys" "x" "y"] 
        by(simp split:if_splits) 
    qed
  qed
qed

(* Case principle for lists of equal length at Hamming distance at most 3. To
   show P ys it suffices to have:
   - P xs (assumption p);
   - P ys whenever ys arises from xs by updating three in-range indices i, j, k
     with the values of ys (assumption u). *)
lemma hamming_cases3:
  assumes l: length xs = length ys
  and h: hamming xs ys  3
  and p: P xs
  and u:  i j k. i < length xs  j < length xs  k < length xs  ys = xs[i := ys!i,j := ys!j,k := ys!k]  P ys
shows P ys
proof(insert assms, induct "xs")
  case Nil
  then show ?case by simp
next
  case (Cons x' xs')
  then show ?case proof(cases "hamming (x'#xs') ys = 3")
    (* Distance exactly 3: obtain the three updates from hamming_update_3 (roles
       swapped via hamming_symmetry). *)
    case True
    then show ?thesis 
      using hamming_symmetry hamming_update_2 hamming_update_3 Cons.prems linorder_not_le 
            length_list_update length_Cons list_update_id list_update_beyond
      apply(simp)
      by (smt (verit, del_insts) Cons.prems(2) Cons.prems(4) hamming_symmetry hamming_update_3 
              length_Cons length_list_update linorder_not_le list_update_beyond list_update_id) 
  next
    (* Distance at most 2: fall back to hamming_cases2. *)
    case False
    then have h: "hamming (x'#xs') ys  2" using Cons.prems by simp
    then show ?thesis 
        using hamming_cases2 Cons.prems
        by (metis list_update_id)
  qed
qed

end