(* Theory SumArr *)

(*
 * Copyright 2016, Data61, CSIRO
 *
 * This software may be distributed and modified according to the terms of
 * the BSD 2-Clause license. Note that NO WARRANTY is provided.
 * See "LICENSE_BSD2.txt" for details.
 *
 * @TAG(DATA61_BSD)
 *)
section ‹Case-study›

theory SumArr
imports
  "../OG_Syntax"
  Word_Lib.Word_32
begin

(* NOTE(review): presumably this bundle supplies the infix bit-operation
   notation (such as OR, used in sumarr below) - confirm against Word_Lib. *)
unbundle bit_operations_syntax

(* A routine is a natural number; it indexes the per-thread fields of
   sumarr_state and is the second component of procedure names. *)
type_synonym routine = nat
type_synonym word32 = "32 word"
(* Procedure names: pairs such as (''sumarr'', i) used as keys of the
   procedure maps. *)
type_synonym funcs = "string × nat"
(* Fault tags. InvalidMem labels the array-bound guards in sumarr; Overflow
   is declared but not used in the visible part of this theory. *)
datatype faults = Overflow | InvalidMem
(* Arrays are modelled as plain lists. *)
type_synonym 'a array = "'a list"
 
text ‹Sumarr computes the combined sum of all the elements of
multiple arrays. It does this by running a number of threads in
parallel, each computing the sum of elements of one of the arrays,
and then adding the result to a global variable gsum shared by all threads.
›
(* State of the sumarr program. The per-thread fields are functions indexed
   by routine; the global fields are shared; ghost_lock is a ghost field that
   records, per routine, whether that routine holds the lock. *)
record sumarr_state =
 ― ‹local variables of threads›
  tarr :: "routine ⇒ word32 array"
  tid :: "routine ⇒ word32"
  ti :: "routine ⇒ word32"
  tsum :: "routine ⇒ word32"
 ― ‹global variables›
  glock :: nat
  gsum :: word32
  gdone :: word32
  garr :: "(word32 array) array"
 ― ‹ghost variables›
  ghost_lock :: "routine ⇒ bool"

(* NSUM is the 32-bit word 10. It bounds the summation loop in sumarr and is
   the array length required by tarr_inv and garr_inv. *)
definition NSUM :: word32
  where "NSUM = 10"

(* MAXSUM is the 32-bit word 1500: the cap applied by local_sum and checked
   by the summation loop in sumarr. *)
definition MAXSUM :: word32
  where "MAXSUM = 1500"

(* Length of an array, converted to a 32-bit word with of_nat. *)
definition
 array_length :: "'a array ⇒ word32"
where
 "array_length arr ≡ of_nat (length arr)"

(* Element of arr at the word index n: list lookup at position unat n. *)
definition
 array_nth :: "'a array ⇒ word32 ⇒ 'a"
where
 "array_nth arr n ≡ arr ! unat n"

(* True when the word index idx, read as a natural number, is smaller than
   the length of arr. Used as the guard condition for array reads in sumarr. *)
definition
 array_in_bound :: "'a array ⇒ word32 ⇒ bool"
where
 "array_in_bound arr idx ≡ unat idx < (length arr)"

(* Sum of all elements of a word array, computed over the natural numbers
   (each element is converted with unat before summing). *)
definition
  array_nat_sum :: "('a :: len) word array ⇒ nat"
where
  "array_nat_sum arr ≡ sum_list (map unat arr)"

(* The natural-number sum of arr, capped at unat MAXSUM, converted back with
   of_nat. This is the value a single thread is specified to compute. *)
definition
  "local_sum arr ≡ of_nat (min (unat MAXSUM) (array_nat_sum arr))"

(* Sum of the local_sum of every array in a list of arrays. *)
definition
  "global_sum arr ≡ sum_list (map local_sum arr)"

(* Invariant on the local array of routine i: it has exactly unat NSUM
   elements and equals the i-th global array. *)
definition
  "tarr_inv s i ≡
    length (tarr s i) = unat NSUM ∧ tarr s i = garr s ! i"

(* Holds for routine i before it takes the lock: bit i of gdone is clear, and
   if the other routine (1 - i) does not hold the ghost lock then either
   nothing has been added yet (gdone = 0 and gsum = 0) or the other routine is
   done and gsum equals the local sum of the other routine's array. *)
abbreviation
  "sumarr_inv_till_lock s i ≡ ¬ bit (gdone s) i ∧ ((¬ (ghost_lock s) (1 - i)) ⟶ ((gdone s = 0 ∧ gsum s = 0) ∨
    (bit (gdone s) (1 - i) ∧ gsum s = local_sum (garr s !(1 - i)))))"

(* Relates glock to the ghost lock flags: glock is the sum of fromEnum of
   the two flags, and at most one of routines 0 and 1 holds the ghost lock. *)
abbreviation
  "lock_inv s ≡
    (glock s = fromEnum (ghost_lock s 0) + fromEnum (ghost_lock s 1)) ∧
    (¬(ghost_lock s) 0 ∨ ¬(ghost_lock s) 1)"

(* The global array holds exactly two arrays, and the array belonging to the
   other routine (index 1 - i) has unat NSUM elements. *)
abbreviation
  "garr_inv s i ≡ (∃a b. garr s = [a, b]) ∧
    length (garr s ! (1-i)) = unat NSUM"

(* Combined invariant for routine i: lock_inv, tarr_inv and garr_inv hold,
   and the thread id of routine i is of_nat i + 1. *)
abbreviation
  "sumarr_inv s i ≡ lock_inv s ∧ tarr_inv s i ∧ garr_inv s i ∧
    tid s i = (of_nat i + 1)"

(* Annotated lock procedure for routine i: waits until glock is 0, then sets
   glock to 1 and records in ghost_lock that routine i holds the lock. *)
definition
  lock :: "routine ⇒ (sumarr_state, funcs, faults) ann_com"
where
  "lock i ≡
    ⦃ ´sumarr_inv i ∧ ´tsum i = local_sum (´tarr i) ∧ ´sumarr_inv_till_lock i⦄
    AWAIT ´glock = 0
    THEN ´glock:=1,, ´ghost_lock:=´ghost_lock (i:= True)
    END"

(* Assertion used inside the critical section after gsum has been updated but
   before gdone is: bit i of gdone is still clear, and gsum is either the
   local sum of routine i (nobody done yet) or the global sum (the other
   routine is already done). *)
definition
 "sumarr_in_lock1 s i ≡ ¬bit (gdone s) i ∧ ((gdone s = 0 ∧ gsum s = local_sum (tarr s i)) ∨
   (bit (gdone s) (1 - i) ∧ ¬ bit (gdone s) i ∧ gsum s = global_sum (garr s)))"

(* Assertion used inside the critical section after gdone has been updated:
   bit i of gdone is set, and gsum is the local sum of routine i if the other
   routine is not done, or the global sum if it is. *)
definition
 "sumarr_in_lock2 s i ≡ (bit (gdone s) i ∧ ¬ bit (gdone s) (1 - i) ∧ gsum s = local_sum (tarr s i)) ∨
   (bit (gdone s) i ∧ bit (gdone s) (1 - i) ∧ gsum s = global_sum (garr s))"

(* Annotated unlock procedure for routine i: in one atomic step, resets glock
   to 0 and clears the ghost lock flag of routine i.
   NOTE(review): the atomic brackets are restored on the assumption that the
   two-statement body joined by ",," must be atomic here - confirm. *)
definition
  unlock :: "routine ⇒ (sumarr_state, funcs, faults) ann_com"
where
  "unlock i ≡
    ⦃ ´sumarr_inv i ∧ ´tsum i = local_sum (´tarr i) ∧ ´glock = 1 ∧
    ´ghost_lock i ∧ bit ´gdone (unat (´tid i - 1)) ∧ ´sumarr_in_lock2 i ⦄
    ⟨´glock := 0,, ´ghost_lock:=´ghost_lock (i:= False)⟩"

(* Postcondition of routine i: bit i of gdone is set, routine i no longer
   holds the ghost lock, and, provided the other routine does not hold it,
   gsum is the global sum when both bits of gdone are set and the local sum
   of array i otherwise. *)
definition
 "local_postcond s i ≡ (¬ (ghost_lock s) (1 - i) ⟶ gsum s = (if bit (gdone s) 0 ∧ bit (gdone s) 1
              then global_sum (garr s)
              else local_sum (garr s ! i))) ∧ bit (gdone s) i ∧ ¬ghost_lock s i"

(* Annotated program run by routine i.
   1. Reset tsum i and ti i to 0.
   2. Loop while ti i < NSUM. Each array read is guarded by array_in_bound
      with fault tag InvalidMem. The element at index ti i is added to tsum i.
      If that element or the new tsum i is at least MAXSUM, tsum i is set to
      MAXSUM and THROW leaves the loop; otherwise ti i is incremented.
   3. The CATCH handler is SKIP, so both exits continue with
      tsum i = local_sum (tarr i).
   4. Call lock, add tsum i to gsum, OR tid i into gdone, call unlock.
   Every statement is preceded by its Owicki-Gries assertion.
   NOTE(review): assertion brackets, connectives, comparison operators and
   guard arrows were reconstructed from a text that had lost them; check
   against the lemmas local_sum_Suc, local_sum_MAXSUM and local_sum_MAXSUM'. *)
definition
  sumarr :: "routine ⇒ (sumarr_state, funcs, faults) ann_com"
where
  "sumarr i ≡
  ⦃´sumarr_inv i ∧ ´sumarr_inv_till_lock i⦄
  ´tsum:=´tsum(i:=0) ;;
  ⦃ ´tsum i = 0 ∧ ´sumarr_inv i ∧ ´sumarr_inv_till_lock i⦄
  ´ti:=´ti(i:=0) ;;
  TRY
    ⦃ ´tsum i = 0 ∧ ´sumarr_inv i ∧ ´ti i = 0 ∧ ´sumarr_inv_till_lock i⦄
    WHILE ´ti i < NSUM
    INV ⦃ ´sumarr_inv i ∧ ´ti i ≤ NSUM ∧ ´tsum i ≤ MAXSUM ∧
          ´tsum i = local_sum (take (unat (´ti i)) (´tarr i)) ∧ ´sumarr_inv_till_lock i⦄
    DO
      ⦃´sumarr_inv i ∧ ´ti i < NSUM ∧ ´tsum i ≤ MAXSUM ∧
       ´tsum i = local_sum (take (unat (´ti i)) (´tarr i)) ∧ ´sumarr_inv_till_lock i⦄
     (InvalidMem, ⦃ array_in_bound (´tarr i)  (´ti i) ⦄) ⟼
       ⦃ ´sumarr_inv i ∧ ´ti i < NSUM ∧ ´tsum i ≤ MAXSUM ∧
         ´tsum i = local_sum (take (unat (´ti i)) (´tarr i)) ∧ ´sumarr_inv_till_lock i⦄
       ´tsum:=´tsum(i:=´tsum i + array_nth (´tarr i) (´ti i));;
      ⦃´sumarr_inv i ∧ ´ti i < NSUM ∧
         local_sum (take (unat (´ti i)) (´tarr i)) ≤ MAXSUM ∧
         (´tsum i < MAXSUM ∧ array_nth (´tarr i) (´ti i) < MAXSUM ⟶
           ´tsum i = local_sum (take (Suc (unat (´ti i))) (´tarr i))) ∧
         (array_nth (´tarr i) (´ti i) ≥ MAXSUM ∨ ´tsum i ≥ MAXSUM ⟶
           local_sum (´tarr i) = MAXSUM) ∧
       ´sumarr_inv_till_lock i ⦄
     (InvalidMem, ⦃ array_in_bound (´tarr i)  (´ti i) ⦄) ⟼
       ⦃ ´sumarr_inv i ∧ ´ti i < NSUM ∧
         (´tsum i < MAXSUM ∧ array_nth (´tarr i) (´ti i) < MAXSUM ⟶
           ´tsum i = local_sum (take (Suc (unat (´ti i))) (´tarr i))) ∧
         (array_nth (´tarr i) (´ti i) ≥ MAXSUM ∨ ´tsum i ≥ MAXSUM ⟶
           local_sum (´tarr i) = MAXSUM) ∧
         ´sumarr_inv_till_lock i⦄
       IF array_nth (´tarr i) (´ti i) ≥ MAXSUM ∨ ´tsum i ≥ MAXSUM
       THEN
         ⦃ ´sumarr_inv i ∧ ´ti i < NSUM ∧ local_sum (´tarr i) = MAXSUM ∧ ´sumarr_inv_till_lock i⦄
         ´tsum:=´tsum(i:=MAXSUM);;
         ⦃ ´sumarr_inv i ∧ ´ti i < NSUM ∧ ´tsum i ≤ MAXSUM ∧
           ´tsum i = local_sum (´tarr i) ∧ ´sumarr_inv_till_lock i ⦄
         THROW
       ELSE
         ⦃ ´sumarr_inv i ∧ ´ti i < NSUM ∧ ´tsum i ≤ MAXSUM ∧
           ´tsum i = local_sum (take (Suc (unat (´ti i))) (´tarr i)) ∧ ´sumarr_inv_till_lock i⦄
         SKIP
       FI;;
      ⦃ ´sumarr_inv i ∧ ´ti i < NSUM ∧ ´tsum i ≤ MAXSUM ∧
       ´tsum i = local_sum (take (Suc (unat (´ti i))) (´tarr i)) ∧ ´sumarr_inv_till_lock i ⦄
     ´ti:=´ti(i:=´ti i + 1)
    OD
  CATCH
    ⦃ ´sumarr_inv i ∧ ´tsum i = local_sum (´tarr i) ∧ ´sumarr_inv_till_lock i⦄ SKIP
  END;;
  ⦃ ´sumarr_inv i ∧ ´tsum i = local_sum (´tarr i) ∧ ´sumarr_inv_till_lock i⦄
  SCALL (''lock'', i) 0;;
  ⦃ ´sumarr_inv i ∧ ´tsum i = local_sum (´tarr i) ∧ ´glock = 1 ∧
    ´ghost_lock i ∧ ´sumarr_inv_till_lock i ⦄
  ´gsum:=´gsum + ´tsum i ;;
  ⦃ ´sumarr_inv i ∧ ´tsum i = local_sum (´tarr i) ∧ ´glock = 1 ∧
    ´ghost_lock i ∧ ´sumarr_in_lock1 i ⦄
  ´gdone:=(´gdone OR ´tid i) ;;
  ⦃ ´sumarr_inv i ∧ ´tsum i = local_sum (´tarr i) ∧ ´glock = 1 ∧
    ´ghost_lock i ∧ bit ´gdone (unat (´tid i - 1)) ∧ ´sumarr_in_lock2 i ⦄
  SCALL (''unlock'', i) 0"

(* Precondition of the whole parallel program: lock free, gsum and gdone
   zero, exactly two global arrays each of length unat NSUM, and neither
   routine holding the ghost lock. *)
definition
 precond
where
 "precond s ≡ (glock s) = 0 ∧ (gsum s) = 0 ∧ (gdone s) = 0 ∧
               (∃a b. garr s = [a, b]) ∧
               (∀xs∈set (garr s). length xs = unat NSUM) ∧
               (ghost_lock s) 0 = False ∧ (ghost_lock s) 1 = False"

(* Postcondition of the whole parallel program: gsum equals the global sum
   of garr and bits 0 and 1 of gdone are both set. *)
definition
 postcond
where
 "postcond s ≡ (gsum s) = global_sum (garr s) ∧
               (∀i < 2. bit (gdone s) i)"

(* Annotated call of procedure (''sumarr'', i).
   The entry function initialises the locals of routine i: tarr i becomes the
   i-th global array, tid i becomes of_nat i + 1, ti i and tsum i are left
   undefined. The return function builds the state after the call from the
   final state t, with the locals of routine i taken from the caller state s.
   The result-passing command is Skip. *)
definition
  "call_sumarr i ≡
    ⦃length (´garr ! i) = unat NSUM ∧ ´lock_inv ∧ ´garr_inv i ∧
     ´sumarr_inv_till_lock i⦄
    CALLX (λs. s⦇tarr:=(tarr s)(i:=garr s ! i),
                 tid:=(tid s)(i:=of_nat i+1),
                 ti:=(ti s)(i:=undefined),
                 tsum:=(tsum s)(i:=undefined)⦈)
          ⦃´sumarr_inv i ∧ ´sumarr_inv_till_lock i⦄
          (''sumarr'', i) 0
          (λs t. t⦇tarr:= (tarr t)(i:=(tarr s) i),
                   tid:=(tid t)(i:=(tid s i)),
                   ti:=(ti t)(i:=(ti s i)),
                   tsum:=(tsum t)(i:=(tsum s i))⦈)
          (λ_ _. Skip)
          ⦃´local_postcond i⦄ ⦃´local_postcond i⦄
          ⦃False⦄ ⦃False⦄"

(* Procedure bodies: maps (''sumarr'', i), (''lock'', i) and (''unlock'', i),
   for i in [0..<2], to the command part (com) of the annotated procedure. *)
definition
  "Γ ≡ map_of (map (λi. ((''sumarr'', i), com (sumarr i))) [0..<2]) ++
  map_of (map (λi. ((''lock'', i), com (lock i))) [0..<2]) ++
  map_of (map (λi. ((''unlock'', i), com (unlock i))) [0..<2])"

(* Procedure annotations: maps the same keys as the procedure-body map to a
   singleton list holding the annotation part (ann) of the procedure. *)
definition
  "Θ ≡ map_of (map (λi. ((''sumarr'', i), [ann (sumarr i)])) [0..<2]) ++
  map_of (map (λi. ((''lock'', i), [ann (lock i)])) [0..<2]) ++
  map_of (map (λi. ((''unlock'', i), [ann (unlock i)])) [0..<2])"

(* Limit the number of subgoals printed in proof states to 10; the proofs
   below generate many subgoals. *)
declare [[goals_limit = 10]]

(* The local sum of the empty array is 0 (registered as a simp rule). *)
lemma [simp]:
  "local_sum [] = 0"
  unfolding local_sum_def array_nat_sum_def
  by simp

(* Adding a word smaller than MAXSUM to MAXSUM does not wrap around. *)
lemma MAXSUM_le_plus:
  "x < MAXSUM ⟹ MAXSUM ≤ MAXSUM + x"
  unfolding MAXSUM_def
  apply (rule word_le_plus[rotated], assumption)
  apply clarsimp
 done

(* Loop step below the cap: if the local sum of the first n elements plus
   element n is still below MAXSUM (and element n itself is below MAXSUM),
   that word sum is the local sum of the first Suc n elements. *)
lemma local_sum_Suc:
  "⟦n < length arr; local_sum (take n arr) + arr ! n < MAXSUM;
    arr ! n < MAXSUM⟧ ⟹
    local_sum (take n arr) + arr ! n =
      local_sum (take (Suc n) arr)"
  apply (subst take_Suc_conv_app_nth)
   apply clarsimp
  apply (clarsimp simp: local_sum_def array_nat_sum_def )
   apply (subst (asm) min_def, clarsimp split: if_splits)
   apply (clarsimp simp: MAXSUM_le_plus word_not_le[symmetric])
  apply (subst min_absorb2)
   apply (subst of_nat_mono_maybe_le[where 'a=32])
     apply (clarsimp simp: MAXSUM_def)
    apply (clarsimp simp: MAXSUM_def)
    apply unat_arith
   apply (clarsimp simp: MAXSUM_def)
   apply unat_arith
  apply clarsimp
 done

(* If some element of arr is at least MAXSUM, the local sum is capped at
   MAXSUM. *)
lemma local_sum_MAXSUM:
  "k < length arr ⟹ MAXSUM ≤ arr ! k ⟹ local_sum arr = MAXSUM"
  apply (clarsimp simp: local_sum_def array_nat_sum_def)
  apply (rule word_unat.Rep_inverse')
  apply (rule min_absorb1[symmetric])
  apply (subst (asm) word_le_nat_alt)
  apply (rule le_trans[rotated])
   apply (rule elem_le_sum_list)
   apply simp
  apply clarsimp
 done

(* If the local sum of the first k elements plus element k reaches MAXSUM
   (with both summands individually at most MAXSUM), the local sum of the
   whole array is MAXSUM. The proof splits arr as vs @ u # ws around index k. *)
lemma local_sum_MAXSUM':
  ‹local_sum arr = MAXSUM›
  if ‹k < length arr›
    ‹MAXSUM ≤ local_sum (take k arr) + arr ! k›
    ‹local_sum (take k arr) ≤ MAXSUM›
    ‹arr ! k ≤ MAXSUM›
proof -
  define vs u ws where ‹vs = take k arr› ‹u = arr ! k› ‹ws = drop (Suc k) arr›
  with ‹k < length arr› have *: ‹arr = vs @ u # ws›
    and **: ‹take k arr = vs› ‹arr ! k = u›
    by (simp_all add: id_take_nth_drop)
  from that show ?thesis
    apply (simp add: **)
    apply (simp add: *)
    apply (simp add: local_sum_def array_nat_sum_def ac_simps)
    apply (rule word_unat.Rep_inverse')
    apply (rule min_absorb1[symmetric])
    apply (subst (asm) word_le_nat_alt)
    apply (subst (asm) unat_plus_simple[THEN iffD1])
     apply (rule word_add_le_mono2[where i=0, simplified])
     apply (clarsimp simp: MAXSUM_def)
     apply unat_arith
    apply (rule le_trans, assumption)
    apply (rule add_mono)
     apply simp_all
    apply (subst le_unat_uoi)
    apply (rule min.cobounded1)
     apply simp
    done
qed

(* The minimum of any word and 0 is 0, in either argument order. *)
lemma word_min_0[simp]:
  fixes x :: "'a::len word"
  shows "min x 0 = 0"
    and "min 0 x = 0"
  by (simp_all add: min_def)

(* ML helper: the subgoal-indexed version of the TRY tactical; it applies
   tac to subgoal i and succeeds unchanged if tac fails. *)
ML ‹fun TRY' tac i = TRY (tac i)›


(* Rewrites a pair of implications with the same conclusion so that the
   second one may additionally assume the negation of the first premise. *)
lemma imp_disjL_context':
  "((P ⟶ R) ∧ (Q ⟶ R)) = ((P ⟶ R) ∧ (¬P ∧ Q ⟶ R))"
by auto

(* Lookup of key (p, i) in a map built from [0..<n] with keys (p, i) yields
   g i whenever i < n. *)
lemma map_of_prod_1[simp]:
  "i < n ⟹
    map_of (map (λi. ((p, i), g i)) [0..<n])
       (p, i) = Some (g i)"
  apply (rule map_of_is_SomeI)
   apply (clarsimp simp: distinct_map o_def)
   apply (meson inj_onI prod.inject)
  apply clarsimp
  done

(* Appending a map whose keys all have first component p does not affect the
   lookup of a key whose first component q differs from p. *)
lemma map_of_prod_2[simp]:
  "i < n ⟹ p ≠ q ⟹
    (m ++
    map_of (map (λi. ((p, i), g i)) [0..<n]))
       (q, i) = m (q, i)"
  apply (rule map_add_dom_app_simps)
  apply (subst dom_map_of_conv_image_fst)
  apply clarsimp
  done

(* Lookup facts for the procedure-body and procedure-annotation maps: for
   n < 2 each of sumarr, lock and unlock resolves to its command and to its
   singleton annotation list; the last three facts select the head of those
   lists. The facts are stored with oghoare_simps unfolded. *)
lemma sumarr_proc_simp[unfolded oghoare_simps]:
 "n < 2 ⟹ Γ (''sumarr'',n) = Some (com (sumarr n))"
 "n < 2 ⟹ Θ (''sumarr'',n) = Some ([ann (sumarr n)])"
 "n < 2 ⟹ Γ (''lock'',n) = Some (com (lock n))"
 "n < 2 ⟹ Θ (''lock'',n) = Some ([ann (lock n)])"
 "n < 2 ⟹ Γ (''unlock'',n) = Some (com (unlock n))"
 "n < 2 ⟹ Θ (''unlock'',n) = Some ([ann (unlock n)])"
 "[ann (sumarr n)]!0 = ann (sumarr n)"
 "[ann (lock n)]!0 = ann (lock n)"
 "[ann (unlock n)]!0 = ann (unlock n)"
 by (simp add: Γ_def Θ_def)+

(* The lookup facts of sumarr_proc_simp with the definitions of sumarr,
   unlock and lock (and oghoare_simps) unfolded; used as proc_simp rules in
   proofs that unfold those definitions themselves. *)
lemmas sumarr_proc_simp_unfolded = sumarr_proc_simp[unfolded sumarr_def unlock_def lock_def oghoare_simps]

(* Sequential correctness of the annotated sumarr procedure for i < 2: it
   establishes local_postcond i on normal termination and False on abrupt
   termination. The oghoare method generates the verification conditions;
   most are discharged in parallel by clarsimp followed by fast_force, the
   remaining ones by hand using the local_sum lemmas. *)
lemma oghoare_sumarr:
  ‹Γ, Θ |⊢⇘/F⇙ sumarr i ⦃´local_postcond i⦄, ⦃False⦄› if ‹i < 2›
proof -
  from that have ‹i = 0 ∨ i = 1› by auto
  note sumarr_proc_simp_unfolded[proc_simp add]
  show ?thesis
  using that apply -
  unfolding sumarr_def unlock_def lock_def
    ann_call_def call_def block_def
  apply simp
  apply oghoare (*24*)
  unfolding tarr_inv_def array_length_def array_nth_def array_in_bound_def
    sumarr_in_lock1_def sumarr_in_lock2_def
  apply (tactic "PARALLEL_ALLGOALS ((TRY' o SOLVED')
          (clarsimp_tac (@{context} addsimps
                @{thms local_postcond_def global_sum_def ex_in_conv[symmetric]}) 
          THEN_ALL_NEW fast_force_tac
             (@{context} addSDs @{thms less_2_cases}
                         addIs @{thms local_sum_Suc unat_mono}
                        )
          ))") (*4*)
  using ‹i = 0 ∨ i = 1› apply rule
      apply (clarsimp simp add: bit_simps even_or_iff)
      apply (clarsimp simp add: bit_simps even_or_iff)
   apply clarsimp
   apply (rule conjI)
    apply (fastforce intro!: local_sum_Suc unat_mono)
   apply (subst imp_disjL_context')
   apply (rule conjI)
    apply clarsimp
    apply (erule local_sum_MAXSUM[rotated])
    apply unat_arith
   apply (clarsimp simp: not_le)
   apply (erule (1) local_sum_MAXSUM'[rotated] ; unat_arith)
  apply clarsimp
    apply unat_arith
  apply (fact that)
 done
qed

(* For i < 2, the index of the other routine (Suc 0 - i) is also below 2. *)
lemma less_than_two_2[simp]:
  "i < 2 ⟹ Suc 0 - i < 2"
  by arith

(* Sequential correctness of the annotated call call_sumarr i for i < 2,
   with the same postconditions as oghoare_sumarr, which is used for the
   procedure body. *)
lemma oghoare_call_sumarr:
notes sumarr_proc_simp[proc_simp add]
shows
  "i < 2 ⟹
  Γ, Θ |⊢⇘/F⇙ call_sumarr i ⦃´local_postcond i⦄, ⦃False⦄"
  unfolding call_sumarr_def ann_call_def call_def block_def
    tarr_inv_def
  apply oghoare (*10*)

  apply (clarsimp; fail | ((simp only: pre.simps)?, rule  oghoare_sumarr))+
  apply (clarsimp simp: sumarr_def tarr_inv_def)
  apply (clarsimp simp: local_postcond_def; fail)+
  done

(* Two distinct indices below 2 are each other's complement: Suc 0 - i = j. *)
lemma less_than_two_inv[simp]:
  "i < 2 ⟹ j < 2 ⟹ i ≠ j ⟹ Suc 0 - i = j"
  by simp
 
(* Interference freedom between the two calls: for distinct i, j < 2 the
   assertions of call_sumarr i (including its postconditions) are preserved
   by the atomic steps of call_sumarr j. The generated goals are discharged
   in parallel by clarsimp followed by fast_force. *)
lemma inter_aux_call_sumarr [simplified]:
  notes sumarr_proc_simp_unfolded [proc_simp add]
     com.simps [oghoare_simps add]
     bit_simps [simp]
  shows
  "i < 2 ⟹ j < 2 ⟹ i ≠ j ⟹ interfree_aux Γ Θ
     F (com (call_sumarr i), (ann (call_sumarr i), ⦃´local_postcond i⦄, ⦃False⦄),
        com (call_sumarr j), ann (call_sumarr j))"
  unfolding call_sumarr_def ann_call_def call_def block_def
    tarr_inv_def sumarr_def lock_def unlock_def
  apply oghoare_interfree_aux (*650*)
  unfolding 
    tarr_inv_def local_postcond_def sumarr_in_lock1_def
    sumarr_in_lock2_def
  by (tactic "PARALLEL_ALLGOALS (
                  TRY' (remove_single_Bound_mem @{context}) THEN'
                  (TRY' o SOLVED')
                  (clarsimp_tac @{context} THEN_ALL_NEW
                    fast_force_tac (@{context} addSDs @{thms less_2_cases})
                  ))") (* 2 minutes *)

(* The global precondition implies the precondition of each call
   call_sumarr i with i < 2. *)
lemma pre_call_sumarr:
  "i < 2 ⟹ precond x ⟹ x ∈ pre (ann (call_sumarr i))"
  unfolding precond_def call_sumarr_def ann_call_def
  by (fastforce dest: less_2_cases simp: array_length_def)

(* The local postconditions of both routines together imply the global
   postcondition. *)
lemma post_call_sumarr:
  "local_postcond x 0 ⟹ local_postcond x 1 ⟹ postcond x"
  unfolding postcond_def local_postcond_def
  by (fastforce dest: less_2_cases split: if_splits)

(* Main theorem: starting from precond, the parallel composition of
   call_sumarr 0 and call_sumarr 1 establishes postcond on normal termination
   and False on abrupt termination. The five subgoals are closed by the
   precondition, sequential-correctness, postcondition and
   interference-freedom lemmas above. *)
lemma sumarr_correct: 
  "Γ, Θ |⊩⇘/F⇙ ⦃´precond⦄
    COBEGIN
      SCHEME [0 ≤ m < 2]
      call_sumarr m
      ⦃´local_postcond m⦄,⦃False⦄
    COEND
   ⦃´postcond⦄, ⦃False⦄"
  apply oghoare (* 5 subgoals *)
      apply (fastforce simp: pre_call_sumarr)
     apply (rule oghoare_call_sumarr, simp)
    apply (clarsimp simp: post_call_sumarr)
   apply (simp add: inter_aux_call_sumarr)
  apply clarsimp
 done

end