(* Theory NN_Digraph *)

(***********************************************************************************
 * Copyright (c) University of Exeter, UK
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * SPDX-License-Identifier: BSD-2-Clause
 ***********************************************************************************)

section‹Neural Networks as Graphs›
text‹
  In this theory, we use the AFP entry ``Graph Theory''~@{cite "noschinski:graph:2013"}
  to model neural networks. In particular, we make use of the formalization of directed 
  graphs.
›

theory NN_Digraph
imports
  Graph_Theory.Digraph
begin

(* Reverse function application ("pipe"): the defining equation applies the
   function argument f to the value a, so the value is written first and the
   function second. The operator is declared left-associative with priority 70.
   NOTE(review): the infix operator symbol and the type arrows do not appear
   in this copy of the theory; restore them from the upstream source. *)
definition
  pipe :: 'a  ('a  'b)  'b (infixl  70)  where
  a  f  = f a

text‹
We follow the notation used in~@{cite "aggarwal:neural:2018"}, i.e., a neural network consists
of edges and neurons (nodes). 
›

(* Neuron identifiers are natural numbers. *)
type_synonym id = nat

(* Parameters of a parameterised neuron. These records are wrapped by the
   Neuron constructor of the neuron datatype. The activation field has the
   label type 'b; the prediction functions look it up in the activation table
   of a neural network. *)
record ('a, 'b) Neuron  = 
  φ :: 'b             ― ‹activation function› 
  α :: 'a             ― ‹learning rate›
  β :: 'a             ― ‹bias›
  uid :: id           ― ‹unique identifier›

(* A vertex of a neural-network graph. It is one of:
   - an input neuron In, carrying only an identifier;
   - an output neuron Out, carrying only an identifier;
   - a parameterised Neuron, carrying a full Neuron record. *)
datatype ('a, 'b) neuron = In id | Out id | Neuron ('a, 'b) Neuron

(* Identifier of a vertex, defined for every constructor of neuron. For a
   parameterised Neuron it is the uid field of the record (the record selector
   has to be written Neuron.uid here to distinguish it from this function). *)
fun uid where
  uid (In nid)   = nid
| uid (Out nid)  = nid
| uid (Neuron n) = Neuron.uid n

(* An edge stores its weight and both endpoints explicitly. The nn_pregraph
   locale requires the graph's tail and head functions to be exactly tl and
   hd. *)
record ('a, 'b) edge = 
  ω  :: 'a            ― ‹weight input to head›
  tl :: ('a, 'b) neuron   ― ‹source neuron›
  hd :: ('a, 'b) neuron   ― ‹target neuron›

(* Pre-digraphs whose vertices are neurons and whose arcs are edge records. *)
type_synonym ('a, 'b) nn_pregraph = (('a, 'b) neuron, ('a, 'b) edge) pre_digraph

(* Apply upd to every arc of G. The vertex set and the tail and head
   functions of G are kept unchanged. Since the new arc set is the image of
   the old one, arcs that upd maps to the same value are merged. *)
definition upd_edge :: ('a, 'b) nn_pregraph  (('a, 'b) edge  ('a, 'b) edge)   
                         ('a, 'b) nn_pregraph where
          upd_edge G upd   =  
                                             verts = verts G , 
                                             arcs = upd ` (arcs G), 
                                             tail = tail G, 
                                             head = head G
                                          

(* Arc update that sets the weight of an arc to ω'. An arc is selected by
   comparing uid (hd a) with hd_nid and uid (tl a) with tl_nid; note that the
   head id comes before the tail id. Every other arc is returned unchanged,
   and the endpoints of an arc are never modified.
   NOTE(review): the connective between the two id tests is missing in this
   copy; presumably a conjunction. *)
definition updω ω' hd_nid tl_nid a = (if uid (hd a) = hd_nid  uid (tl a) = tl_nid 
                                      then ω = ω', tl = tl a, hd = hd a  
                                      else a)

(* Apply upd to the record of every parameterised Neuron vertex. In and Out
   vertices are left as they are, because upd_Neuron maps In to In and Out to
   Out. The same map is applied to both endpoints of every arc, so arcs keep
   pointing at the updated vertices; arc weights are preserved. The tail and
   head functions of G are kept; under nn_pregraph these are tl and hd, so
   they read the rewritten endpoints. *)
definition upd_neuron :: ('a, 'b) nn_pregraph  (('a, 'b) Neuron  ('a, 'b) Neuron)  
                          ('a, 'b) nn_pregraph where
          upd_neuron G upd = (let upd_Neuron = case_neuron In Out (λ n. Neuron (upd n))
                              in           
                                              verts = upd_Neuron ` (verts G) , 
                                              arcs = (λ a.  ω = ω a,  
                                                             tl = upd_Neuron (tl a), 
                                                             hd = upd_Neuron (hd a)) ` (arcs G), 
                                              tail = tail G, 
                                              head = head G
                                           )

(* Neuron update that sets the activation label φ to φ' on the neuron whose
   uid is nid. All other fields, and all other neurons, are unchanged. *)
definition updφ φ' nid n = (if Neuron.uid n = nid 
                             then φ = φ', α = α n, β = β n, uid = Neuron.uid n  
                             else n)

(* Neuron update that sets the bias β to β' on the neuron whose uid is nid.
   All other fields, and all other neurons, are unchanged. *)
definition updβ β' nid n = (if Neuron.uid n = nid 
                             then φ = φ n, α = α n, β = β', uid = Neuron.uid n  
                             else n)

(* Neuron update that sets the learning rate α to α' on the neuron whose uid
   is nid. All other fields, and all other neurons, are unchanged. *)
definition updα α' nid n  = (if Neuron.uid n = nid 
                             then φ = φ n, α = α', β = β n, uid = Neuron.uid n  
                             else n)

text‹
  A neural network is a directed graph without loops and without multi-edges. Moreover, 
  no two neurons share the same identifier (@{const uid}), and identifiers strictly
  increase along every edge.
›

(* Vertices that are not the target (hd field) of any arc, i.e. vertices
   without incoming arcs. An isolated vertex is therefore both an input and an
   output vertex. *)
definition input_verts :: (('a, 'b) neuron, ('a, 'b) edge) pre_digraph  ('a, 'b) neuron set 
  where
          input_verts G = (verts G) - (hd ` arcs G)

(* Vertices that are not the source (tl field) of any arc, i.e. vertices
   without outgoing arcs. *)
definition output_verts :: (('a, 'b) neuron, ('a, 'b) edge) pre_digraph  ('a, 'b) neuron set 
  where
          output_verts G = (verts G) - (tl ` arcs G)

(* Vertices that are neither input nor output vertices, i.e. vertices with
   at least one incoming and at least one outgoing arc. *)
definition internal_verts :: (('a, 'b) neuron, ('a, 'b) edge) pre_digraph  ('a, 'b) neuron set
  where
          internal_verts G = (verts G) - ((input_verts G)  (output_verts G))

(* Well-formed neural-network pre-graphs: a digraph (finite, without loops
   and without multi-arcs) over neurons and edge records in which
   - distinct vertices have distinct identifiers (id_vert_inj);
   - the graph's tail/head functions are the tl/hd fields of the arc
     (tail_eq_tl, head_eq_hd);
   - identifiers strictly increase along every arc (ids_growing), so in
     particular no directed cycle can exist.
   The weight type 'a must be a commutative additive monoid with a
   multiplication and a linear order. *)
locale nn_pregraph = digraph G 
  for G::(('a::{comm_monoid_add,times,linorder}, 'b) neuron, ('a, 'b) edge) pre_digraph + 
  assumes id_vert_inj: inj_on uid (verts G)
  and     tail_eq_tl:  tail G = tl
  and     head_eq_hd:  head G = hd
  and     ids_growing:  e  arcs G. uid (tl e) < uid (hd e) ―‹Not strictly necessary, but simplifies termination proofs.›
begin

(* Inside the locale, the locale predicate itself holds for G. *)
lemma nn_pregraph: "nn_pregraph G" by intro_locales

end 

(* The set of identifiers used by the vertices of G. *)
definition uids G = uid ` verts G

subsection‹Neurons as Vertices›
context nn_pregraph
begin 
subsubsection‹The operation @{constadd_vert} preserves neural networks›

(* Adding a vertex n to an nn_pregraph yields an nn_pregraph, provided the
   guard on n holds. The proof splits on whether n is already a vertex; if it
   is not, the uid of n must not yet be used.
   NOTE(review): the connectives of the guard are missing in this copy;
   presumably "uid n is not in uids G, or n is already in verts G". Confirm
   against upstream. *)
lemma nn_pregraph_add_neuron: 
  assumes uid n  (uids G)  n  verts G 
  shows nn_pregraph (add_vert n)
  apply standard
  subgoal by (simp add: wf_digraph.tail_in_verts wf_digraph_add_vert) 
  subgoal by (simp add: wf_digraph.head_in_verts wf_digraph_add_vert)
  subgoal by (simp add: verts_add_vert) 
  subgoal by (simp add: arcs_add_vert) 
  subgoal by (simp add: arcs_add_vert head_add_vert no_loops pre_digraph.tail_add_vert)
  subgoal by (simp add: arc_to_ends_def arcs_add_vert head_add_vert no_multi_arcs tail_add_vert) 
  subgoal proof(cases n  verts G)
    case True
    then show ?thesis 
      by (simp add: id_vert_inj insert_absorb verts_add_vert)
  next
    case False
    then show ?thesis 
      using assms verts_add_vert arcs_add_vert head_add_vert no_loops apply(simp)
      using id_vert_inj uids_def image_def by blast 
  qed
  subgoal using tail_eq_tl tail_add_vert by simp 
  subgoal using head_eq_hd head_add_vert by simp 
  subgoal using arcs_add_vert ids_growing by blast 
done 


(* Guarded vertex insertion: adds n when the same guard as in
   nn_pregraph_add_neuron holds, and returns G unchanged otherwise. *)
definition add_neuron::('a, 'b) neuron  ('a, 'b) nn_pregraph where
          add_neuron n = (if (uid n  (uids G)  n  verts G ) then add_vert n else G)

(* Consequently add_neuron preserves the nn_pregraph property for every n. *)
lemma nn_pregraph_add_nn_neuron:  nn_pregraph (add_neuron a)
  using add_neuron_def nn_pregraph_add_neuron nn_pregraph_axioms
  by simp
end

subsubsection‹The operation @{constpre_digraph.del_vert} preserves neural networks›
context nn_pregraph 
begin 

(* Deleting any vertex n preserves the nn_pregraph property; no side
   condition on n is needed. *)
lemma nn_pregraph_del_vert: nn_pregraph (del_vert n)
  apply standard
  subgoal by (simp add: wf_digraph.tail_in_verts wf_digraph_del_vert) 
  subgoal
    apply(simp add: ends_del_vert no_multi_arcs pre_digraph.arcs_del_vert inj_on_def)
    by (simp add: head_del_vert verts_del_vert)
  subgoal by (simp add: fin_digraph.finite_verts fin_digraph_del_vert)
  subgoal by (simp add: fin_digraph.finite_arcs fin_digraph_del_vert)
  subgoal by (simp add: head_del_vert no_loops pre_digraph.arcs_del_vert tail_del_vert)
  subgoal by (simp add: ends_del_vert no_multi_arcs pre_digraph.arcs_del_vert)
  subgoal 
    apply(simp add: ends_del_vert no_multi_arcs pre_digraph.arcs_del_vert inj_on_def)
    by (metis Diff_iff id_vert_inj inj_on_def verts_del_vert)
  subgoal using tail_eq_tl tail_del_vert by simp 
  subgoal using head_eq_hd head_del_vert by simp 
  subgoal using arcs_del_vert ids_growing by blast 
  done 

end 
subsection‹Arcs (Edges)›

(* Register the defining equation of pre_digraph.add_arc as a code equation,
   so that add_nn_edge below can be used by the code generator. *)
declare pre_digraph.add_arc_def [code]
(* Guarded arc insertion: adds a to G with pre_digraph.add_arc when the
   guard holds, and returns G unchanged otherwise. The guard:
   - inspects, for both endpoints, the endpoint's uid against uids G and its
     membership in verts G;
   - compares uid (hd a) with uid (tl a), and requires
     uid (tl a) < uid (hd a), which keeps the ids_growing invariant;
   - relates arc_to_ends G a to arcs_ends G, and a to arcs G.
   NOTE(review): the logical connectives of the guard are missing in this
   copy. Presumably each endpoint must already be a vertex or carry a fresh
   uid, and a must not duplicate the ends of an existing arc unless a itself
   is already present. Confirm against upstream. *)
definition add_nn_edge G a =   (if (uid (tl a)  (uids G)  (tl a)  verts G) 
                                   (uid (hd a)  (uids G)  (hd a)  verts G)
                                   uid (hd a)  uid (tl a) 
                                   ((arc_to_ends G a)  arcs_ends G  a  arcs G)
                                   uid (tl a) < uid (hd a)
                               then pre_digraph.add_arc G a
                               else G)             

context nn_pregraph
begin 

subsubsection‹The operation @{constadd_arc} preserves neural networks›
(* Adding an arc a preserves nn_pregraph under the same conditions that
   add_nn_edge checks: the endpoint conditions on uids G / verts G, strictly
   increasing ids along a, and the arcs_ends condition ruling out multi-arcs
   (connectives garbled in this copy; see add_nn_edge). *)
lemma nn_pregraph_add_arc: 
  assumes uid (tl a)  (uids G)  (tl a)  verts G 
  and     uid (hd a)  (uids G)  (hd a)  verts G 
  and     uid (tl a) < uid (hd a)
  and     uid (hd a)  uid (tl a) 
  and     (arc_to_ends G a)  arcs_ends G  a  arcs G
shows nn_pregraph (add_arc a)
  apply standard
  subgoal by (meson wf_digraph.tail_in_verts wf_digraph_add_arc) 
  subgoal by (meson wf_digraph_add_arc wf_digraph_def) 
  subgoal by (simp add: pre_digraph.verts_add_arc_conv) 
  subgoal by simp  
  subgoal using assms 
    by (metis head_add_arc head_eq_hd insert_iff no_loops pre_digraph.arcs_add_arc 
              pre_digraph.tail_add_arc tail_eq_tl)
  subgoal 
    by (metis arc_to_ends_def arcs_add_arc assms(5) insert_iff no_multi_arcs pre_digraph.head_add_arc 
              pre_digraph.tail_add_arc wf_digraph.dominatesI wf_digraph_axioms)
  subgoal  using assms
    by (smt (z3) Un_iff head_eq_hd id_vert_inj image_eqI inj_on_def insertE pre_digraph.verts_add_arc_conv
                   singletonD tail_eq_tl uids_def) 
  subgoal using tail_eq_tl by simp  
  subgoal using head_eq_hd by simp 
  subgoal by (simp add: assms(3) ids_growing) 
  done  

(* Make add_nn_edge executable via the code generator. *)
declare add_nn_edge_def[code]

(* add_nn_edge preserves nn_pregraph for every a: it either adds a when the
   guard of nn_pregraph_add_arc holds, or returns G unchanged. *)
lemma nn_pregraph_add_nn_edge: nn_pregraph (add_nn_edge G a)
  using add_nn_edge_def nn_pregraph_add_arc nn_pregraph_axioms 
  by metis 


subsubsection‹The operation @{constdel_arc} preserves neural networks›
(* Deleting any arc a preserves the nn_pregraph property. *)
lemma nn_pregraph_del_arc: nn_pregraph (del_arc a)
  apply standard
  subgoal by simp 
  subgoal by simp 
  subgoal by simp 
  subgoal by simp
  subgoal by (simp add: no_loops)
  subgoal by (simp add: arc_to_ends_def no_multi_arcs)
  subgoal by (simp add: id_vert_inj)
  subgoal by (simp add: tail_eq_tl) 
  subgoal by (simp add: head_eq_hd) 
  subgoal using tail_eq_tl head_eq_hd ids_growing by simp
  done


end 

subsection‹Updating Neurons›
context nn_pregraph begin 

(* Simp rules about updφ, updβ and updα (connectives garbled in this copy):
   - the *_nid_immutable lemmas: neurons whose uid differs from nid are left
     unchanged;
   - the *_id_immutable lemmas: the uid is never changed;
   - the remaining lemmas: each update leaves the other two parameters
     untouched.
   All are proved by a case split on the record n followed by unfolding the
   update definitions. *)
lemma updφ_nid_immutable[simp]: Neuron.uid n  nid  n = (updφ φ' nid n)
 and  updφ_id_immutable[simp]: Neuron.uid n = Neuron.uid (updφ φ' nid n)
 and  updφ_α_immutable[simp]:  α n = α (updφ φ' nid n)
 and  updφ_β_immutable[simp]:  β n = β (updφ φ' nid n)
 and  updβ_nid_immutable[simp]: Neuron.uid n  nid  n = (updβ β' nid n)  
 and  updβ_id_immutable[simp]: Neuron.uid n = Neuron.uid (updβ β' nid n)
 and  updβ_φ_immutable[simp]:  φ n = φ (updβ β' nid n)
 and  updβ_α_immutable[simp]:  α n = α (updβ β' nid n)
 and  updα_nid_immutable[simp]: Neuron.uid n  nid  n = (updα α' nid n)  
 and  updα_id_immutable[simp]: Neuron.uid n = Neuron.uid (updα α' nid n)
 and  updα_φ_immutable[simp]:  φ n = φ (updα α' nid n)
 and  updα_β_immutable[simp]:  β n = β (updα α' nid n)
 by(cases "n", simp_all add:updφ_def updα_def updβ_def)+


(* The following lemmas all assume that upd never changes the uid of a
   Neuron record. Under that assumption, upd_neuron preserves in turn:
   well-formedness, finiteness, absence of multi-arcs, absence of loops, and
   finally the whole nn_pregraph property. *)
lemma wf_digraph_update_neuron:
  assumes  n. Neuron.uid n = Neuron.uid (upd n) 
  showswf_digraph (upd_neuron G upd )
  unfolding upd_neuron_def
  apply(simp add: wf_digraph assms image_def wf_digraph_def tail_eq_tl head_eq_hd Let_def)
  using head_eq_hd head_in_verts tail_eq_tl tail_in_verts
  by fastforce 

lemma fin_digraph_update_neuron:
  assumes  n. Neuron.uid n = Neuron.uid (upd n) 
  showsfin_digraph (upd_neuron G upd )
  apply standard 
     apply (meson assms wf_digraph.tail_in_verts wf_digraph_update_neuron)
    apply (meson assms wf_digraph.head_in_verts wf_digraph_update_neuron) 
  by(simp_all add: upd_neuron_def Let_def)


lemma nomulti_digraph_update_neuron:
  assumes  n. Neuron.uid n = Neuron.uid (upd n) 
  shows nomulti_digraph (upd_neuron G upd )
  apply standard 
  subgoal by (meson assms wf_digraph_def wf_digraph_update_neuron) 
  subgoal by (meson upd_neuron_def  assms wf_digraph_def wf_digraph_update_neuron) 
  subgoal 
    apply(simp add:image_def arc_to_ends_def upd_neuron_def Let_def id_vert_inj assms 
                        wf_digraph tail_eq_tl head_eq_hd)
    using assms head_eq_hd head_in_verts id_vert_inj inj_onD no_multi_arcs tail_eq_tl tail_in_verts     
    apply simp 
    by (smt (verit) arc_to_ends_def edge.select_convs(2,3) inj_onD neuron.simps(10,11,12) uid.elims uid.simps(3)) 
  done 

lemma loopfree_digraph_update_neuron:
  assumes  n. Neuron.uid n = Neuron.uid (upd n) 
  shows loopfree_digraph (upd_neuron G upd)
  apply standard 
  subgoal by (meson assms wf_digraph.tail_in_verts wf_digraph_update_neuron)
  subgoal by (meson assms wf_digraph.head_in_verts wf_digraph_update_neuron) 
  subgoal
    apply(simp add:image_def arc_to_ends_def Let_def upd_neuron_def id_vert_inj assms 
                        wf_digraph tail_eq_tl head_eq_hd)
    using assms ids_growing order_less_irrefl neuron.distinct(5) neuron.inject(3)
    apply simp 
    by (smt (verit) edge.select_convs(2,3) less_not_refl3 neuron.simps(10,11,12) uid.elims uid.simps(3))
  done 

(* Main result: a uid-preserving neuron update keeps G an nn_pregraph. *)
lemma nn_pregraph_update_neuron: 
  assumes  n. Neuron.uid n = Neuron.uid (upd n) 
  showsnn_pregraph (upd_neuron G upd)
  apply standard 
  subgoal by (meson assms wf_digraph.tail_in_verts wf_digraph_update_neuron) 
  subgoal by (meson assms wf_digraph.head_in_verts wf_digraph_update_neuron) 
  subgoal using assms fin_digraph.finite_verts fin_digraph_update_neuron by blast 
  subgoal using assms fin_digraph.finite_arcs fin_digraph_update_neuron by blast 
  subgoal by (metis assms(1) loopfree_digraph.no_loops loopfree_digraph_update_neuron) 
  subgoal by (metis assms(1) nomulti_digraph.no_multi_arcs nomulti_digraph_update_neuron)
  subgoal apply(simp add:Let_def assms upd_neuron_def id_vert_inj image_iff inj_on_def uid.elims)
    using assms id_vert_inj inj_on_def neuron.distinct neuron.simps uid.elims by (smt (verit))
  subgoal by (simp add: tail_eq_tl upd_neuron_def Let_def) 
  subgoal by (simp add: head_eq_hd upd_neuron_def Let_def) 
  subgoal apply(simp add:  assms upd_neuron_def ids_growing Let_def) using ids_growing
    by (smt (z3) assms neuron.exhaust neuron.simps(10) neuron.simps(11) neuron.simps(12) uid.simps(3)) 
  
  done
 
(* Instances: the three parameter updates never change a uid, so each of
   them preserves nn_pregraph. *)
lemma nn_pregraph_updφ[simp]: nn_pregraph (upd_neuron G (updφ φ' nid))
 and  nn_pregraph_updβ[simp]: nn_pregraph (upd_neuron G (updβ β' nid))
 and  nn_pregraph_updα[simp]: nn_pregraph (upd_neuron G (updα α' nid))
  using nn_pregraph_update_neuron by simp_all

end

subsection‹Updating arcs (edges)›

context nn_pregraph begin 

(* updω never changes the endpoints of an arc. *)
lemma updω_tl_immutable[simp]: (tl a = tl (updω  ω' nhd ntl a))
and   updω_hd_immutable[simp]: (hd a = hd (updω  ω' nhd ntl a))
 by(auto simp: updω_def split: if_split) 

(* Hence updω does not change arc_to_ends G of an arc (tail G and head G are
   tl and hd in this locale). *)
lemma updω_ends_immutable[simp]: (arc_to_ends G a = arc_to_ends G (updω  ω' nhd ntl a))
  by (auto simp add: arc_to_ends_def head_eq_hd tail_eq_tl) 

(* upd_edge keeps the tail function, the head function and the vertex set. *)
lemma upd_edge_tail_immutable: 
  tail (upd_edge G upd) = tail G 
  by (simp add: upd_edge_def) 

lemma upd_edge_head_immutable: 
  head (upd_edge G upd)  = head G 
  by (simp add: upd_edge_def) 

lemma upd_edge_vert_immutable: verts (upd_edge G upd) = verts G
  by(simp add: upd_edge_def)


(* Every arc of upd_edge G upd is the image under upd of an arc of G. *)
lemma upd_edge_arcs: a  arcs (upd_edge G upd)    x  arcs G. a = upd x
  by (auto simp: upd_edge_def)


(* The following lemmas assume that upd preserves arc_to_ends G of every arc
   of G. Under that assumption, upd_edge preserves in turn: well-formedness,
   finiteness, absence of multi-arcs and absence of loops. *)
lemma wf_digraph_update_edge: 
  assumes  a  arcs G. (arc_to_ends G a = arc_to_ends G (upd a ))
  shows  wf_digraph (upd_edge G upd )
  apply unfold_locales
  subgoal  
    using assms upd_edge_vert_immutable upd_edge_tail_immutable 
    apply(simp add:upd_edge_def arc_to_ends_def image_def) 
     using tail_in_verts by auto
  subgoal 
    using assms upd_edge_vert_immutable upd_edge_tail_immutable 
    apply(simp add:upd_edge_def arc_to_ends_def image_def) 
     using head_in_verts by auto
   done

lemma fin_digraph_update_edge:  
  assumes  a  arcs G. (arc_to_ends G a = arc_to_ends G (upd a ))
  shows  fin_digraph (upd_edge G upd ) 
  by (metis fin_digraph_axioms fin_digraph_axioms_def fin_digraph_def finite_imageI 
            pre_digraph.select_convs(1) pre_digraph.select_convs(2) upd_edge_def 
            wf_digraph_update_edge assms)  

lemma nomulti_digraph_update_edge:  
  assumes  a  arcs G. (arc_to_ends G a = arc_to_ends G (upd a ))
  shows  nomulti_digraph (upd_edge G upd )
  apply standard 
  subgoal using assms by (meson wf_digraph_def wf_digraph_update_edge) 
  subgoal using assms by (meson wf_digraph_def wf_digraph_update_edge)
  subgoal  using assms 
    by (metis arc_to_ends_def local.upd_edge_arcs  nn_pregraph.upd_edge_head_immutable 
              nn_pregraph.upd_edge_tail_immutable nn_pregraph_axioms no_multi_arcs)
  done 

lemma loopfree_digraph_update_edge:
  assumes  a  arcs G. (arc_to_ends G a = arc_to_ends G (upd a ))
  shows  loopfree_digraph (upd_edge G upd)
  apply standard 
  subgoal using assms by (simp add: wf_digraph.tail_in_verts wf_digraph_update_edge) 
  subgoal using assms by (simp add: wf_digraph.head_in_verts wf_digraph_update_edge) 
  subgoal using assms by (metis adj_not_same arc_to_ends_def dominatesI local.upd_edge_arcs 
                                upd_edge_head_immutable upd_edge_tail_immutable) 
  done


(* Main result: upd_edge preserves nn_pregraph when upd keeps the ends of
   every arc and the updated arcs still have strictly increasing ids. *)
lemma nn_pregraph_update_edge:
  assumes  a  arcs G. (arc_to_ends G a = arc_to_ends G (upd a ))
  and      a  arcs G. uid (tl (upd a )) < uid (hd (upd a )) 
shows  nn_pregraph (upd_edge G upd ) 
  apply(simp add: digraph_def fin_digraph_update_edge head_eq_hd id_vert_inj loopfree_digraph_update_edge 
            nn_pregraph_axioms_def nn_pregraph_def nomulti_digraph_update_edge 
            tail_eq_tl upd_edge_def assms upd_edge_arcs nn_pregraph_axioms.intro 
            upd_edge_head_immutable upd_edge_tail_immutable upd_edge_vert_immutable)
  by (metis assms(1) fin_digraph_update_edge head_eq_hd loopfree_digraph_update_edge 
            nomulti_digraph_update_edge tail_eq_tl upd_edge_def)

(* Instance: updating a weight with updω preserves nn_pregraph. *)
lemma nn_pregraph_updω[simp]: nn_pregraph (upd_edge G (updω ω' nhd ntl))
  using nn_pregraph_update_edge
  by (metis (no_types, lifting) ids_growing updω_ends_immutable updω_hd_immutable updω_tl_immutable) 

end


(* A neural network: its graph together with an activation table. The table
   is a partial map from activation labels (the φ field of a neuron) to their
   meaning of type 'c. *)
record ('a, 'b, 'c) neural_network = 
  graph :: "(('a, 'b) neuron, ('a, 'b) edge) pre_digraph"
  activation_tab :: 'b  'c option

(* Lift upd_edge to neural networks: update the arcs of the graph and keep
   the activation table. *)
definition upd_edge' :: ('a, 'b, 'c) neural_network  (('a, 'b) edge  ('a, 'b) edge)
                         ('a, 'b, 'c) neural_network where
          upd_edge' N upd   =  
                                 graph = upd_edge (graph N) upd,
                                 activation_tab = activation_tab N
                              

(* Lift upd_neuron to neural networks: update the neurons of the graph and
   keep the activation table. *)
definition upd_neuron' :: ('a, 'b, 'c) neural_network  (('a, 'b) Neuron  ('a, 'b) Neuron)  
                          ('a, 'b, 'c) neural_network where
          upd_neuron' N upd =  
                                  graph = upd_neuron (graph N) upd,
                                  activation_tab = activation_tab N
                               

(* The input layer: the input vertices (no incoming arcs) of the graph. *)
definition input_layer :: ('a, 'b, 'c) neural_network  ('a, 'b) neuron set where
          input_layer N = input_verts (graph N)

(* Number of input neurons. *)
definition arity :: ('a, 'b, 'c) neural_network  nat where
          arity N = card (input_layer N)

(* Identifiers of the input neurons. *)
definition input_layer_ids :: ('a, 'b, 'c) neural_network  id set where
          input_layer_ids N = uid ` (input_layer N)

(* The output layer: the output vertices (no outgoing arcs) of the graph. *)
definition output_layer :: ('a, 'b, 'c) neural_network  ('a, 'b) neuron set where
          output_layer N = output_verts (graph N)

(* Identifiers of the output neurons. *)
definition output_layer_ids :: ('a, 'b, 'c) neural_network  id set where
          output_layer_ids N = uid ` (output_layer N)

(* Arcs of the graph whose target neuron has identifier nid. *)
definition incoming_arcs :: ('a, 'b, 'c) neural_network  id  ('a, 'b) edge set where

           incoming_arcs N nid = {a . a  arcs (graph N)  uid (hd a) = nid}

(* Turn a set of edges into a list sorted by the uid of the source neuron,
   by folding insort_key over the set starting from []. The lemma
   incoming_arcs_l_alt_def shows that this coincides with
   sorted_key_list_of_set (λ x. uid (tl x)). *)
definition sorted_list_of_set'  map_fun id id (folding_on.F (insort_key (λ x. uid (tl x)))) []

(* The incoming arcs of neuron nid as a list sorted by source-neuron uid. *)
definition incoming_arcs_l :: ('a, 'b, 'c) neural_network  id  ('a, 'b) edge list where
           incoming_arcs_l N nid = sorted_list_of_set' (incoming_arcs N nid)

context nn_pregraph begin 
(* Intended: the sorted list contains exactly the incoming arcs.
   NOTE(review): left unproven (oops). N is a free variable here, not tied to
   the locale's G, so the statement presumably also needs incoming_arcs N nid
   to be finite, because a fold over an infinite set does not enumerate it. *)
lemma incoming_arcs_l_eq_incoming_arcs: set (incoming_arcs_l N nid)= (incoming_arcs N nid)
  unfolding incoming_arcs_l_def  sorted_list_of_set'_def
  oops

(* incoming_arcs_l expressed with the library function sorted_key_list_of_set,
   keyed by the uid of the source neuron. *)
lemma incoming_arcs_l_alt_def: (incoming_arcs_l N nid)
= (sorted_key_list_of_set (λ x. uid (tl x)) (incoming_arcs N nid))
  unfolding incoming_arcs_l_def sorted_list_of_set'_def incoming_arcs_def
  by (simp add: sorted_key_list_of_set_def)

(* For an injective key f, inserting with insort_key commutes (the name says
   "insert", but the lemma is about insort_key). *)
lemma insert_key_comm: "inj f   (insort_key f y  insort_key f x) = (insort_key f x  insort_key f y)"
  apply (cases "x = y") 
  by (auto intro: insort_key_left_comm simp add: inj_def inj_on_def fun_eq_iff)

(* Arc endpoints are vertices, and uid is injective on the arc sources. *)
lemma tl_subset_verts: tl ` (arcs G)  verts G
  using tail_eq_tl tail_in_verts by force 

lemma hd_subset_verts: hd ` (arcs G)  verts G
  using head_eq_hd head_in_verts by force 

lemma inj_on_tl: inj_on uid (tl ` (arcs G))
using id_vert_inj tl_subset_verts
  by (meson inj_on_subset)

end



(* Arcs of the graph whose source neuron has identifier nid. *)
definition outgoing_arcs :: ('a, 'b, 'c) neural_network  id  ('a, 'b) edge set where
           outgoing_arcs N nid = {a . a  arcs (graph N)  uid (tl a) = nid}

(* All vertices (neurons) of the network's graph. *)
definition neurons :: ('a, 'b, 'c) neural_network  ('a, 'b) neuron set where
  neurons = verts o graph

(* All arcs (edges) of the network's graph. *)
definition edges :: ('a, 'b, 'c) neural_network  ('a, 'b) edge set where
  edges = arcs o graph


(* Neural-network graphs: nn_pregraphs whose vertex kinds match their
   position in the graph:
   - the input vertices are exactly the In vertices;
   - the output vertices are exactly the Out vertices;
   - the internal vertices are exactly the parameterised Neuron vertices.
   id_vert_inj restates the identically named assumption of nn_pregraph. *)
locale nn_graph = nn_pregraph + 
  assumes id_vert_inj: inj_on uid (verts G)
  and     inputs_In:   
            input_verts G = Set.filter (λ v. (case v of In _  True | _  False)) (verts G)
  and     outputs_Out: 
            output_verts G = Set.filter (λ v. (case v of Out _  True | _  False)) (verts G)
  and     internal_Neuron: 
            internal_verts G = Set.filter (λ v. (case v of Neuron _  True | _  False)) (verts G)
begin

(* Inside the locale, the locale predicate itself holds for G. *)
lemma nn_graph: "nn_graph G" by intro_locales
end 


(* A neural network whose graph is an nn_graph and whose activation table
   has an entry for the activation label φ of every parameterised Neuron
   vertex. *)
locale neural_network_digraph = 
  fixes N::('a::{comm_monoid_add,times,linorder,one}, 'b, 'c) neural_network 
  assumes nn_graph (graph N)
  and φ ` {n . Neuron n  (verts (graph N)) }   dom (activation_tab N) 

subsection‹The empty neural network›

(* The graph with no vertices and no arcs. Its tail and head functions are
   the edge selectors, as the nn_pregraph locale requires. *)
definition empty:: ('a, 'b) nn_pregraph where
          empty =  verts={}, arcs = {}, tail = edge.tl, head = edge.hd 

(* The empty graph is an nn_pregraph. *)
lemma nn_pregraph_empty[simp]: nn_pregraph (empty)
  unfolding empty_def apply standard
  by(auto simp add: input_verts_def output_verts_def internal_verts_def)

(* The empty graph is an nn_graph. *)
lemma nn_graph_empty[simp]: nn_graph (empty)
  unfolding empty_def apply standard
  by(auto simp add: input_verts_def output_verts_def internal_verts_def)

(* Invariant rule for fold: if P holds for the start value and every step
   f x preserves P, then P holds for the result. *)
lemma fold_inv: "P e  ( e' x. P e'  P (f x e'))  P (fold f xs e)"
  by (simp add: fold_invariant)


(* Adding a list of edges one by one with add_nn_edge (via foldr) preserves
   nn_pregraph. The proof is by induction on the edge list. *)
lemma nn_pregraph_fold: nn_pregraph G  nn_pregraph (foldr (λ a b. add_nn_edge b a) edge_list G)
proof(induction "edge_list")
  case Nil
  then show ?case by (simp add:nn_graph.axioms)  
next
  case (Cons a edge_list) note * = this
  then show ?case 
    using fold_inv[of nn_pregraph G "(λx s. add_nn_edge s x)" edge_list]
    by (simp add: fold_inv nn_pregraph.nn_pregraph_add_nn_edge) 
  qed

(* Build a graph from a list of edges, starting from empty and adding each
   edge with add_nn_edge. Because foldr is used, the last edge of the list is
   added first. Edges rejected by the guard of add_nn_edge are silently
   skipped. *)
definition 
mk_nn_pregraph edge_list = foldr (λ a b. add_nn_edge b a) edge_list empty

(* Every graph built by mk_nn_pregraph is an nn_pregraph. *)
lemma nn_pregraph_mk: nn_pregraph(mk_nn_pregraph edge_list)
  using mk_nn_pregraph_def nn_pregraph_fold
  by (metis nn_graph.axioms(1) nn_graph_empty)
 
(* add_nn_edge never removes vertices. *)
lemma verts_subseteq_add_edge: "nn_pregraph G  verts G  verts (add_nn_edge G a)"
  unfolding add_nn_edge_def pre_digraph.add_arc_def 
  by auto

subsection‹Computing Predictions of Neural Networks›
(* Status flag paired with the results of the prediction functions. *)
datatype error = OK | ERROR

(* Specialisation of neural_network_digraph in which the activation table
   maps labels to functions on the weight type 'a (arrow missing in this
   copy); predictdigraph_single' applies these functions to weighted sums. *)
locale neural_network_digraph_single = neural_network_digraph N
  for N::('a::{comm_monoid_add,times,linorder,one}, 'b, 'a  'a) neural_network 


(* Value carried along edge e, computed with fuel n (the first argument).
   Because of "sequential", each equation only applies when the earlier ones
   do not match. The cases are:
   - an edge into an In neuron, or out of an Out neuron, yields (0, ERROR);
   - an edge out of an In neuron with id uidin yields v * ω' if the input map
     has a value v for uidin, and (0, ERROR) otherwise;
   - any other edge, with fuel left, leaves a parameterised neuron tl'. The
     values of all incoming arcs of tl' are computed recursively with fuel
     n - 1 and summed (summation symbol missing in this copy). The bias
     β tl' is added, the result is passed through the activation function
     registered for φ tl' in the activation table, and then scaled by the
     edge weight;
   - a missing activation function, or exhausted fuel, yields (0, ERROR).
   The recursive results are paired with the uid of their source neuron,
   presumably so that equal values from different predecessors are not
   merged by the set image.
   NOTE(review): a predecessor result of ERROR becomes ((0,0), ERROR), but
   the sum only reads fst (fst v). The error flag is therefore dropped: that
   predecessor silently contributes 0 and the overall result can still be OK.
   Confirm whether errors are meant to propagate.
   Termination is shown by size_change; the fuel n decreases on every
   recursive call. *)
function (sequential, domintros) predictdigraph_single'::nat
                                  ('a::{comm_monoid_add,times,linorder,one}, 'b, 'a  'a) neural_network 
                                  (id  'a)  ('a, 'b) edge   ('a ×  error)
  where 
  predictdigraph_single' _ N inputs (ω=_,  tl= _, hd=In _)       = (0, ERROR)
| predictdigraph_single' _ N inputs (ω=_,  tl=Out _ , hd=_ )     = (0, ERROR)
| predictdigraph_single' _ N inputs (ω=ω', tl=In uidin, hd=_ )   = (case inputs uidin of 
                                                                None    (0, ERROR)  
                                                              | Some v  (v * ω',  OK))
| predictdigraph_single' n N inputs e = (if 0 < n then  
                  (let 
                      ω'    = ω e;
                      tl' = (case (tl e) of (Neuron t)  t); 
                      E'    = incoming_arcs N (Neuron.uid tl');
                      lvals = ((λ e'. (case predictdigraph_single' (n-1)  N inputs e' of  
                                            (_, ERROR)  ((0,0),  ERROR)  
                                          | (v, OK)    ((v ,uid (tl e')), OK))) ` E')
                  in  
                  (  case (activation_tab N) (φ tl') of 
                         Some f  (ω' *(  f(( v  lvals.  (fst (fst v))) + (β tl'))),  OK)
                       | None   (0, ERROR)))
 else  (0, ERROR) )
  apply pat_completeness                                                                 
  by(simp_all) 

termination 
  by(size_change)

(* Run predictdigraph_single' with fuel card (edges N), i.e. the number of
   edges of N, and turn the (value, flag) pair into an option: Some value
   for OK, None for ERROR. *)
definition
predictdigraph_single N inputs e = (case predictdigraph_single' (card (edges N)) N inputs e of 
                                        (r, OK)  Some r
                                      | (_, ERROR)  None) 

(* Identifiers of the input vertices of N, as a sorted list. *)
definition
  get_input_neuron_ids_l N = sorted_list_of_set (uid ` (input_verts (graph N)))
(* Input map that assigns the values vs, in order, to the sorted input
   neuron ids.
   - zip truncates to the shorter list, so surplus values are ignored and
     ids without a value map to None.
   - The ids are distinct, so the rev does not affect lookups. *)
definition 
  mk_input_map N vs = map_of (rev (zip (get_input_neuron_ids_l N) vs))
(* Identifiers of the output vertices of N, as a sorted list. Despite the
   name, these are neuron (vertex) ids, not edge ids. *)
definition 
  get_output_edge_ids_l N = sorted_list_of_set (uid ` (output_verts (graph N)))
(* For each output neuron id, in sorted order, the edge whose target has
   that id. the_elem only yields a determined value when there is exactly one
   such edge; otherwise the result is unspecified. *)
definition 
  get_output_edge_l N = map the_elem (map (λ i. {e . e  edges N  i = uid (hd e)})  (get_output_edge_ids_l N))
(* Predict the value on every output edge for the input values inputs'.
   Returns Some list of results if all predictions succeed, and None as soon
   as one of them fails (this is the behaviour of those). *)
definition 
  predictdigraph_single_list N inputs' = those (map (λ e.  predictdigraph_single N (mk_input_map N inputs') e) (get_output_edge_l N))


context neural_network_digraph_single begin

(* ids_growing restated for the edges of a neural network: along every edge
   the source id is smaller than the target id. *)
lemma ids_growing':  neural_network_digraph N  e  edges N  uid (tl e) < uid (hd e)
  by (metis comp_def edges_def neural_network_digraph_def nn_graph.axioms(1) nn_pregraph.ids_growing) 

end

context neural_network_digraph begin
(* NOTE(review): placeholder. It ignores both arguments and always returns
   ([], ERROR); presumably a multi-output prediction is still to be
   implemented. *)
fun (sequential) predictdigraph::(id  'a) list  ('a, 'b) edge list  ('a list × error)
  where 
  predictdigraph  _ _ = ([], ERROR)
end

(* A data point: values for input neuron ids and expected values for output
   neuron ids. NOTE(review): the type arrows are missing in this copy;
   presumably partial maps from id to 'a, like the input maps used by
   predictdigraph_single'. *)
record 'a data =
  inputs::id  'a 
  outputs::id  'a 

end