(* Theory JVMDefensive *)

(*  Title:      JinjaThreads/JVM/JVMDefensive.thy
    Author:     Gerwin Klein, Andreas Lochbihler
*)

section ‹A Defensive JVM›

theory JVMDefensive
imports JVMExec "../Common/ExternalCallWF"
begin

text ‹
  Extend the state space by one extra element, ‹TypeError›, that signals a
  type error (or other abnormal termination); ordinary states are wrapped
  in ‹Normal›.›
datatype 'a type_error = TypeError | Normal 'a

context JVM_heap_base begin

text ‹
  A value is an array reference in heap ‹h› if it is a reference and,
  unless it is ‹Null›, its address has a type in ‹h› and that type is an
  array type.›
definition is_Array_ref :: "'addr val ⇒ 'heap ⇒ bool" where
  "is_Array_ref v h ≡
  is_Ref v ∧
  (v ≠ Null ⟶ typeof_addr h (the_Addr v) ≠ None ∧ is_Array (ty_of_htype (the (typeof_addr h (the_Addr v)))))"

(* Make the simplifier unfold is_Array_ref wherever it occurs. *)
declare is_Array_ref_def[simp]

text ‹
  ‹check_instr i P h stk loc C0 M0 pc frs› states the run-time preconditions
  for executing instruction ‹i›: ‹stk› is the operand stack (its top is
  ‹hd stk›), ‹loc› the local variables, ‹C0›/‹M0› the class and name of the
  executing method, ‹pc› the program counter and ‹frs› the remaining call
  frames.  It is used by ‹check› for states without a pending exception.›
primrec check_instr :: "['addr instr, 'addr jvm_prog, 'heap, 'addr val list, 'addr val list,
                        cname, mname, pc, 'addr frame list] ⇒ bool"
where
  check_instr_Load:
  "check_instr (Load n) P h stk loc C0 M0 pc frs =
  (n < length loc)"

| check_instr_Store:
  "check_instr (Store n) P h stk loc C0 M0 pc frs =
  (0 < length stk ∧ n < length loc)"

  (* address literals must not be pushed *)
| check_instr_Push:
  "check_instr (Push v) P h stk loc C0 M0 pc frs =
  (¬is_Addr v)"

| check_instr_New:
  "check_instr (New C) P h stk loc C0 M0 pc frs =
  is_class P C"

| check_instr_NewArray:
  "check_instr (NewArray T) P h stk loc C0 M0 pc frs =
  (is_type P (T⌊⌉) ∧ 0 < length stk ∧ is_Intg (hd stk))"

  (* stack: index on top, array reference below it *)
| check_instr_ALoad:
  "check_instr ALoad P h stk loc C0 M0 pc frs =
  (1 < length stk ∧ is_Intg (hd stk) ∧ is_Array_ref (hd (tl stk)) h)"

  (* stack: value on top, then index, then array reference *)
| check_instr_AStore:
  "check_instr AStore P h stk loc C0 M0 pc frs =
  (2 < length stk ∧ is_Intg (hd (tl stk)) ∧ is_Array_ref (hd (tl (tl stk))) h ∧ typeof⇘h⇙ (hd stk) ≠ None)"

| check_instr_ALength:
  "check_instr ALength P h stk loc C0 M0 pc frs =
  (0 < length stk ∧ is_Array_ref (hd stk) h)"

| check_instr_Getfield:
  "check_instr (Getfield F C) P h stk loc C0 M0 pc frs =
  (0 < length stk ∧ (∃C' T fm. P ⊢ C sees F:T (fm) in C') ∧
  (let (C', T, fm) = field P C F; ref = hd stk in
    C' = C ∧ is_Ref ref ∧ (ref ≠ Null ⟶
      (∃T. typeof_addr h (the_Addr ref) = ⌊T⌋ ∧ P ⊢ class_type_of T ⪯⇧* C))))"

| check_instr_Putfield:
  "check_instr (Putfield F C) P h stk loc C0 M0 pc frs =
  (1 < length stk ∧ (∃C' T fm. P ⊢ C sees F:T (fm) in C') ∧
  (let (C', T, fm) = field P C F; v = hd stk; ref = hd (tl stk) in
    C' = C ∧ is_Ref ref ∧ (ref ≠ Null ⟶
      (∃T'. typeof_addr h (the_Addr ref) = ⌊T'⌋ ∧ P ⊢ class_type_of T' ⪯⇧* C ∧ P,h ⊢ v :≤ T))))"

  (* compare-and-swap is only permitted on volatile fields *)
| check_instr_CAS:
  "check_instr (CAS F C) P h stk loc C0 M0 pc frs =
  (2 < length stk ∧ (∃C' T fm. P ⊢ C sees F:T (fm) in C') ∧
  (let (C', T, fm) = field P C F; v'' = hd stk; v' = hd (tl stk); v = hd (tl (tl stk)) in
     C' = C ∧ is_Ref v ∧ volatile fm ∧ (v ≠ Null ⟶
     (∃T'. typeof_addr h (the_Addr v) = ⌊T'⌋ ∧ P ⊢ class_type_of T' ⪯⇧* C ∧ P,h ⊢ v' :≤ T ∧ P,h ⊢ v'' :≤ T))))"

| check_instr_Checkcast:
  "check_instr (Checkcast T) P h stk loc C0 M0 pc frs =
  (0 < length stk ∧ is_type P T)"

| check_instr_Instanceof:
  "check_instr (Instanceof T) P h stk loc C0 M0 pc frs =
  (0 < length stk ∧ is_type P T ∧ is_Ref (hd stk))"

  (* the receiver is stk!n; the n arguments lie above it (in reverse order).
     A method without a body (meth = None) must satisfy the D∙M(Ts) :: Tr
     judgement, presumably the external-call typing from ExternalCall. *)
| check_instr_Invoke:
  "check_instr (Invoke M n) P h stk loc C0 M0 pc frs =
  (n < length stk ∧ is_Ref (stk!n) ∧
  (stk!n ≠ Null ⟶
    (let a = the_Addr (stk!n);
         T = the (typeof_addr h a);
         C = class_type_of T;
         (D, Ts, Tr, meth) = method P C M
    in typeof_addr h a ≠ None ∧ P ⊢ C has M ∧
       P,h ⊢ rev (take n stk) [:≤] Ts ∧
       (meth = None ⟶ D∙M(Ts) :: Tr))))"

  (* the return value is only type-checked when returning to a caller frame *)
| check_instr_Return:
  "check_instr Return P h stk loc C0 M0 pc frs =
  (0 < length stk ∧ ((0 < length frs) ⟶
    (P ⊢ C0 has M0) ∧
    (let v = hd stk;
         T = fst (snd (snd (method P C0 M0)))
     in P,h ⊢ v :≤ T)))"

| check_instr_Pop:
  "check_instr Pop P h stk loc C0 M0 pc frs =
  (0 < length stk)"

| check_instr_Dup:
  "check_instr Dup P h stk loc C0 M0 pc frs =
  (0 < length stk)"

| check_instr_Swap:
  "check_instr Swap P h stk loc C0 M0 pc frs =
  (1 < length stk)"

  (* second operand on top of the stack, first operand below it *)
| check_instr_BinOpInstr:
  "check_instr (BinOpInstr bop) P h stk loc C0 M0 pc frs =
  (1 < length stk ∧ (∃T1 T2 T. typeof⇘h⇙ (hd stk) = ⌊T2⌋ ∧ typeof⇘h⇙ (hd (tl stk)) = ⌊T1⌋ ∧ P ⊢ T1«bop»T2 : T))"

  (* the jump target pc + b must not be negative *)
| check_instr_IfFalse:
  "check_instr (IfFalse b) P h stk loc C0 M0 pc frs =
  (0 < length stk ∧ is_Bool (hd stk) ∧ 0 ≤ int pc+b)"

| check_instr_Goto:
  "check_instr (Goto b) P h stk loc C0 M0 pc frs =
  (0 ≤ int pc+b)"

| check_instr_Throw:
  "check_instr ThrowExc P h stk loc C0 M0 pc frs =
  (0 < length stk ∧ is_Ref (hd stk) ∧ P ⊢ the (typeof⇘h⇙ (hd stk)) ≤ Class Throwable)"

| check_instr_MEnter:
  "check_instr MEnter P h stk loc C0 M0 pc frs =
   (0 < length stk ∧ is_Ref (hd stk))"

| check_instr_MExit:
  "check_instr MExit P h stk loc C0 M0 pc frs =
   (0 < length stk ∧ is_Ref (hd stk))"

text ‹
  ‹check_xcpt P h n pc xt a› checks a pending exception ‹a›: it must be an
  object of some class ‹C› in heap ‹h›, and if the exception table ‹xt›
  yields a handler for ‹C› at ‹pc›, that handler's stack depth ‹d'› must
  not exceed ‹n›.›
definition check_xcpt :: "'addr jvm_prog ⇒ 'heap ⇒ nat ⇒ pc ⇒ ex_table ⇒ 'addr ⇒ bool"
where
  "check_xcpt P h n pc xt a ≡
  (∃C. typeof_addr h a = ⌊Class_type C⌋ ∧
  (case match_ex_table P C pc xt of None ⇒ True | Some (pc', d') ⇒ d' ≤ n))"

text ‹
  ‹check P σ› is the defensive precondition for one step.  A state without
  frames always passes.  Otherwise the method of the top frame must exist
  and have a body, ‹pc› must index into its code, and the operand stack must
  not exceed ‹mxs›.  Then, without a pending exception, the current
  instruction must pass ‹check_instr›; with a pending exception ‹a›, the
  state must pass ‹check_xcpt› against the current stack height.›
definition check :: "'addr jvm_prog ⇒ ('addr, 'heap) jvm_state ⇒ bool"
where
  "check P σ ≡ let (xcpt, h, frs) = σ in
               (case frs of [] ⇒ True | (stk,loc,C,M,pc)#frs' ⇒
                P ⊢ C has M ∧
                (let (C',Ts,T,meth) = method P C M; (mxs,mxl0,ins,xt) = the meth; i = ins!pc in
                 meth ≠ None ∧ pc < size ins ∧ size stk ≤ mxs ∧
                 (case xcpt of None ⇒ check_instr i P h stk loc C M pc frs'
                           | Some a ⇒ check_xcpt P h (length stk) pc xt a)))"


text ‹
  Defensive single step: ‹TypeError› if ‹check› fails, otherwise the set of
  successor states computed by the aggressive ‹exec›.›
definition exec_d ::
  "'addr jvm_prog ⇒ 'thread_id ⇒ ('addr, 'heap) jvm_state ⇒ ('addr, 'thread_id, 'heap) jvm_ta_state set type_error"
where
  "exec_d P t σ ≡ if check P σ then Normal (exec P t σ) else TypeError"

text ‹
  One-step relation of the defensive machine on ‹type_error› states.  A
  failing check steps with the empty thread action to ‹TypeError›.  Otherwise
  every (thread action, successor) pair produced by ‹exec_d› is a step.›
inductive
  exec_1_d ::
  "'addr jvm_prog ⇒ 'thread_id ⇒ ('addr, 'heap) jvm_state type_error
   ⇒ ('addr, 'thread_id, 'heap) jvm_thread_action ⇒ ('addr, 'heap) jvm_state type_error ⇒ bool"
  ("_,_ ⊢ _ -_-jvmd→ _" [61,0,61,0,61] 60)
  for P :: "'addr jvm_prog" and t :: 'thread_id
where
  exec_1_d_ErrorI: "exec_d P t σ = TypeError ⟹ P,t ⊢ Normal σ -ε-jvmd→ TypeError"
| exec_1_d_NormalI: "⟦ exec_d P t σ = Normal Σ; (tas, σ') ∈ Σ ⟧ ⟹ P,t ⊢ Normal σ -tas-jvmd→ Normal σ'"

text ‹A normal defensive step passed ‹check›, is an ‹exec› step, and starts
  from a state with at least one frame.›
lemma jvmd_NormalD:
  "P,t ⊢ Normal σ -ta-jvmd→ Normal σ' ⟹ check P σ ∧ (ta, σ') ∈ exec P t σ ∧ (∃xcp h f frs. σ = (xcp, h, f # frs))"
apply(erule exec_1_d.cases, auto simp add: exec_d_def split: if_split_asm)
(* NOTE(review): ‹b› is a name generated by the previous step; fragile *)
apply(case_tac b, auto)
done

text ‹Elimination-rule form of ‹jvmd_NormalD›.›
lemma jvmd_NormalE:
  assumes "P,t ⊢ Normal σ -ta-jvmd→ Normal σ'"
  obtains xcp h f frs where "check P σ" "(ta, σ') ∈ exec P t σ" "σ = (xcp, h, f # frs)"
using assms
by(auto dest: jvmd_NormalD)

text ‹‹exec_d› fails exactly when ‹check› fails.›
lemma exec_d_eq_TypeError: "exec_d P t σ = TypeError ⟷ ¬ check P σ"
by(simp add: exec_d_def)

text ‹‹exec_d› agrees with ‹exec› exactly when ‹check› holds.›
lemma exec_d_eq_Normal: "exec_d P t σ = Normal (exec P t σ) ⟷ check P σ"
by(auto simp add: exec_d_def)

end

(* Stop the simplifier from splitting quantifiers over pairs. *)
declare split_paired_All [simp del] split_paired_Ex [simp del]

text ‹If a conditional differs from its else-branch, its condition holds.›
lemma if_neq [dest!]:
  "(if P then A else B) ≠ B ⟹ P"
  by (cases P, auto)

context JVM_heap_base begin

text ‹A state that passes ‹check› does not produce a ‹TypeError›.›
lemma exec_d_no_errorI [intro]:
  "check P σ ⟹ exec_d P t σ ≠ TypeError"
  by (unfold exec_d_def) simp

text ‹Whenever the defensive machine does not fail, it computes the same
  successors as the aggressive one.›
theorem no_type_error_commutes:
  "exec_d P t σ ≠ TypeError ⟹ exec_d P t σ = Normal (exec P t σ)"
  by (unfold exec_d_def, auto)

text ‹Every normal defensive step is also a step of the aggressive machine.›
lemma defensive_imp_aggressive_1:
  "P,t ⊢ (Normal σ) -tas-jvmd→ (Normal σ') ⟹ P,t ⊢ σ -tas-jvm→ σ'"
by(auto elim!: exec_1_d.cases intro!: exec_1.intros simp add: exec_d_def split: if_split_asm)

end

context JVM_heap begin

text ‹A checked step of ‹exec› only extends the heap.›
lemma check_exec_hext:
  assumes exec: "(ta, xcp', h', frs') ∈ exec P t (xcp, h, frs)"
  and check: "check P (xcp, h, frs)"
  shows "h ⊴ h'"
proof -
  from exec have "frs ≠ []" by(auto)
  then obtain f Frs where frs [simp]: "frs = f # Frs"
    by(fastforce simp add: neq_Nil_conv)
  obtain stk loc C0 M0 pc where f [simp]: "f = (stk, loc, C0, M0, pc)"
    by(cases f, blast)
  show ?thesis
  proof(cases xcp)
    case None
    with check obtain C' Ts T mxs mxl0 ins xt
      where mthd: "P ⊢ C0 sees M0 : Ts→T = ⌊(mxs, mxl0, ins, xt)⌋ in C'"
                  "method P C0 M0 = (C', Ts, T, ⌊(mxs, mxl0, ins, xt)⌋)"
      and check_ins: "check_instr (ins ! pc) P h stk loc C0 M0 pc Frs"
      and "pc < length ins"
      and "length stk ≤ mxs"
      by(auto simp add: check_def has_method_def)
    from None exec mthd
    have xexec: "(ta, xcp', h', frs') ∈ exec_instr (ins ! pc) P t h stk loc C0 M0 pc Frs" by(clarsimp)
    thus ?thesis
    proof(cases "ins ! pc")
      case (New C)
      with xexec show ?thesis
        by(auto intro: hext_allocate split: if_split_asm)
    next
      case (NewArray T)
      with xexec show ?thesis
        by(auto intro: hext_allocate split: if_split_asm)
    next
      case AStore
      with xexec check_ins show ?thesis
        by(auto simp add: split_beta split: if_split_asm intro: hext_heap_write)
    next
      case Putfield
      with xexec check_ins show ?thesis
        by(auto intro: hext_heap_write simp add: split_beta split: if_split_asm)
    next
      case CAS
      with xexec check_ins show ?thesis
        by(auto intro: hext_heap_write simp add: split_beta split: if_split_asm)
    next
      case (Invoke M n)
      with xexec check_ins show ?thesis
        apply(auto simp add: min_def split_beta is_Ref_def extRet2JVM_def has_method_def
                split: if_split_asm intro: red_external_aggr_hext)
        (* NOTE(review): ‹va› is a name generated by the previous step; fragile *)
        apply(case_tac va)
        apply(auto 4 3 intro: red_external_aggr_hext is_native.intros)
        done
    next
      case (BinOpInstr bop)
      with xexec check_ins show ?thesis by(auto split: sum.split_asm)
    qed(auto simp add: split_beta split: if_split_asm)
  next
    case (Some a)
    (* exception handling leaves the heap unchanged *)
    with exec have "h' = h" by auto
    thus ?thesis by auto
  qed
qed

text ‹A normal defensive step only extends the heap.›
lemma exec_1_d_hext:
  "⟦ P,t ⊢ Normal (xcp, h, frs) -ta-jvmd→ Normal (xcp', h', frs') ⟧ ⟹ h ⊴ h'"
by(auto elim!: exec_1_d.cases simp add: exec_d_def split: if_split_asm intro: check_exec_hext)

end

end