(* Theory FWInitFinLift *)

(*  Title:      JinjaThreads/MM/FWInitFinLift.thy
    Author:     Andreas Lochbihler
*)

section ‹Synthetic first and last actions for each thread›

theory FWInitFinLift
imports
  FWLTS
  FWLiftingSem
begin

(* Life-cycle marker that init_fin pairs with a thread's local state:
   PreStart is left through the InitialThreadAction step, Running is the
   status under which steps of the underlying semantics are replayed, and
   Finished is entered through the ThreadFinishAction step. *)
datatype status = 
  PreStart
| Running
| Finished

(* Converts a thread action over plain local states 'x into one over
   status-annotated local states by applying Pair PreStart through
   convert_extTA (defined elsewhere; presumably it rewrites the local states
   of spawned threads, so those start in status PreStart).
   Fix: the function arrow between the two thread_action types was missing. *)
abbreviation convert_TA_initial :: "('l,'t,'x,'m,'w,'o) thread_action ⇒ ('l,'t,status × 'x,'m,'w,'o) thread_action"
where "convert_TA_initial == convert_extTA (Pair PreStart)"

(* convert_obs_initial and convert_TA_initial commute, so the two
   conversions may be applied to a thread action in either order. *)
lemma convert_obs_initial_convert_TA_initial: 
  "convert_obs_initial (convert_TA_initial ta) = convert_TA_initial (convert_obs_initial ta)"
by(simp add: convert_obs_initial_def)

(* convert_TA_initial is injective; registered as a simp rule so equalities
   between converted thread actions reduce to equalities of the originals.
   Fix: the logical equivalence between the two equations was missing. *)
lemma convert_TA_initial_inject [simp]:
  "convert_TA_initial ta = convert_TA_initial ta' ⟷ ta = ta'"
by(cases ta)(cases ta', auto)

context final_thread begin

(* A status-annotated local state counts as final only when the thread has
   performed its finish step (status Finished) and the underlying local
   state is final.
   Fix: restored the function arrow in the type and the equivalence and
   conjunction in the defining equation. *)
primrec init_fin_final :: "status × 'x ⇒ bool"
where "init_fin_final (status, x) ⟷ status = Finished ∧ final x"

end

context multithreaded_base begin

(* Single-thread semantics extended with synthetic first and last actions.
   - NormalAction: a step of the underlying semantics is replayed between
     Running states, with its thread action passed through
     convert_obs_initial and convert_TA_initial.
   - InitialThreadAction: a PreStart thread becomes Running without changing
     its local state or the shared memory.
   - ThreadFinishAction: a Running thread whose local state is final becomes
     Finished, again leaving local state and memory unchanged.
   Fix: restored the turnstile in the mixfix annotation and rules, the angle
   brackets and arrow of the underlying step, the meta-implications, and the
   observation brackets around the two synthetic actions. *)
inductive init_fin :: "('l,'t,status × 'x,'m,'w,'o action) semantics" ("_ ⊢ _ -_→i _" [50,0,0,51] 51)
where
  NormalAction:
  "t ⊢ ⟨x, m⟩ -ta→ ⟨x', m'⟩
  ⟹ t ⊢ ((Running, x), m) -convert_TA_initial (convert_obs_initial ta)→i ((Running, x'), m')"

| InitialThreadAction:
  "t ⊢ ((PreStart, x), m) -⦃InitialThreadAction⦄→i ((Running, x), m)"

| ThreadFinishAction:
  "final x ⟹ t ⊢ ((Running, x), m) -⦃ThreadFinishAction⦄→i ((Finished, x), m)"

end

(* split_paired_Ex is removed from the simpset only while the case-analysis
   rules are generated, and restored right afterwards. *)
declare split_paired_Ex [simp del]

(* Simplification rules for init_fin steps whose source or target status is
   a known constructor.
   Fix: restored the turnstile of the init_fin notation in every pattern. *)
inductive_simps (in multithreaded_base) init_fin_simps [simp]:
  "t ⊢ ((Finished, x), m) -ta→i xm'"
  "t ⊢ ((PreStart, x), m) -ta→i xm'"
  "t ⊢ ((Running, x), m) -ta→i xm'"
  "t ⊢ xm -ta→i ((Finished, x'), m')"
  "t ⊢ xm -ta→i ((Running, x'), m')"
  "t ⊢ xm -ta→i ((PreStart, x'), m')"

declare split_paired_Ex [simp]

context multithreaded begin

(* The extended semantics init_fin together with init_fin_final again
   satisfies the assumptions of locale multithreaded. *)
lemma multithreaded_init_fin: "multithreaded init_fin_final init_fin"
by(unfold_locales)(fastforce simp add: init_fin.simps convert_obs_initial_def ta_upd_simps dest: new_thread_memory)+

end

(* Copy of multithreaded_base with fixed parameter types; its only purpose
   is to host the "if"-prefixed interpretation of the extended semantics.
   Fix: restored the function arrows in the parameter types and the
   composition operator in the sublocale instantiation. *)
locale if_multithreaded_base = multithreaded_base +
  constrains final :: "'x ⇒ bool" 
  and r :: "('l,'t,'x,'m,'w,'o) semantics" 
  and convert_RA :: "'l released_locks ⇒ 'o list"

(* Instantiate multithreaded_base with the extended semantics; observations
   produced by convert_RA are wrapped in NormalAction. *)
sublocale if_multithreaded_base < "if": multithreaded_base
  "init_fin_final"
  "init_fin"
  "map NormalAction ∘ convert_RA"
.

(* Combines if_multithreaded_base with the assumptions of multithreaded.
   Fix: restored the function arrows in the parameter types and the
   composition operator in the sublocale instantiation. *)
locale if_multithreaded = if_multithreaded_base + multithreaded +
  constrains final :: "'x ⇒ bool" 
  and r :: "('l,'t,'x,'m,'w,'o) semantics" 
  and convert_RA :: "'l released_locks ⇒ 'o list"

(* The extended semantics inherits the multithreaded assumptions; the proof
   obligation is discharged by multithreaded_init_fin. *)
sublocale if_multithreaded < "if": multithreaded
  "init_fin_final"
  "init_fin"
  "map NormalAction ∘ convert_RA"
by(rule multithreaded_init_fin)

(* Fix throughout this context: restored the stripped logical symbols
   (meta-implication, equivalence, conjunction, existential quantifier). *)
context τmultithreaded begin

(* A step of the extended semantics is silent exactly when it is the lifting
   of a τmove of the underlying semantics between two Running states. *)
inductive init_fin_τmove :: "('l,'t,status × 'x,'m,'w,'o action) τmoves"
where
  "τmove (x, m) ta (x', m') 
  ⟹ init_fin_τmove ((Running, x), m) (convert_TA_initial (convert_obs_initial ta)) ((Running, x'), m')"

(* Characterisation of init_fin_τmove by the status of its endpoints:
   no silent move starts or ends in PreStart or Finished, and a silent move
   touching a Running state stays within Running. *)
lemma init_fin_τmove_simps [simp]:
  "init_fin_τmove ((PreStart, x), m) ta x'm' = False"
  "init_fin_τmove xm ta ((PreStart, x'), m') = False"
  "init_fin_τmove ((Running, x), m) ta ((s, x'), m') ⟷
   (∃ta'. ta = convert_TA_initial (convert_obs_initial ta') ∧ s = Running ∧ τmove (x, m) ta' (x', m'))"
  "init_fin_τmove ((s, x), m) ta ((Running, x'), m') ⟷ 
   s = Running ∧ (∃ta'. ta = convert_TA_initial (convert_obs_initial ta') ∧ τmove (x, m) ta' (x', m'))"
  "init_fin_τmove ((Finished, x), m) ta x'm' = False"
  "init_fin_τmove xm ta ((Finished, x'), m') = False"
by(simp_all add: init_fin_τmove.simps)

(* A silent move of the underlying semantics lifts to a silent move between
   Running states. *)
lemma init_fin_silent_move_RunningI:
  assumes "silent_move t (x, m) (x', m')"
  shows "τtrsys.silent_move (init_fin t) init_fin_τmove ((Running, x), m) ((Running, x'), m')"
using assms by(cases)(auto intro: τtrsys.silent_move.intros init_fin.NormalAction)

(* Same lifting for finite sequences of silent moves. *)
lemma init_fin_silent_moves_RunningI:
  assumes "silent_moves t (x, m) (x', m')"
  shows "τtrsys.silent_moves (init_fin t) init_fin_τmove ((Running, x), m) ((Running, x'), m')"
using assms
by(induct rule: rtranclp_induct2)(auto elim: rtranclp.rtrancl_into_rtrancl intro: init_fin_silent_move_RunningI)

(* Converse direction: a lifted silent move projects to a silent move of the
   underlying semantics, and both endpoints are Running. *)
lemma init_fin_silent_moveD:
  assumes "τtrsys.silent_move (init_fin t) init_fin_τmove ((s, x), m) ((s', x'), m')"
  shows "silent_move t (x, m) (x', m') ∧ s = s' ∧ s' = Running"
using assms by(auto elim!: τtrsys.silent_move.cases init_fin.cases)

(* For sequences only status preservation can be concluded, because the
   empty sequence is possible in any status. *)
lemma init_fin_silent_movesD:
  assumes "τtrsys.silent_moves (init_fin t) init_fin_τmove ((s, x), m) ((s', x'), m')"
  shows "silent_moves t (x, m) (x', m') ∧ s = s'"
using assms
by(induct "((s, x), m)" "((s', x'), m')" arbitrary: s' x' m')
  (auto 7 2 simp only: dest!: init_fin_silent_moveD intro: rtranclp.rtrancl_into_rtrancl)

(* Divergence of the lifted system implies divergence of the underlying one
   and forces status Running. *)
lemma init_fin_τdivergeD:
  assumes "τtrsys.τdiverge (init_fin t) init_fin_τmove ((status, x), m)"
  shows "τdiverge t (x, m) ∧ status = Running"
proof
  from assms show "status = Running"
    by(cases rule: τtrsys.τdiverge.cases[consumes 1])(auto dest: init_fin_silent_moveD)
  (* Generalise the start state so that the coinduction hypothesis is
     available for every successor state. *)
  moreover define xm where "xm = (x, m)"
  ultimately have "∃x m. xm = (x, m) ∧ τtrsys.τdiverge (init_fin t) init_fin_τmove ((Running, x), m)"
    using assms by blast
  thus "τdiverge t xm"
  proof(coinduct)
    case (τdiverge xm)
    then obtain x m 
      where diverge: "τtrsys.τdiverge (init_fin t) init_fin_τmove ((Running, x), m)" 
      and xm: "xm = (x, m)" by blast
    thus ?case
      by(cases rule:τtrsys.τdiverge.cases[consumes 1])(auto dest!: init_fin_silent_moveD)
  qed
qed

(* Divergence of the underlying system lifts to divergence from the
   corresponding Running state. *)
lemma init_fin_τdiverge_RunningI:
  assumes "τdiverge t (x, m)"
  shows "τtrsys.τdiverge (init_fin t) init_fin_τmove ((Running, x), m)"
proof -
  define sxm where "sxm = ((Running, x), m)"
  with assms have "∃x m. τdiverge t (x, m) ∧ sxm = ((Running, x), m)" by blast
  thus "τtrsys.τdiverge (init_fin t) init_fin_τmove sxm"
  proof(coinduct rule: τtrsys.τdiverge.coinduct[consumes 1, case_names τdiverge])
    case (τdiverge sxm)
    then obtain x m where "τdiverge t (x, m)" and "sxm = ((Running, x), m)" by blast
    thus ?case by(cases)(auto intro: init_fin_silent_move_RunningI)
  qed
qed

(* Both directions combined into a single equivalence. *)
lemma init_fin_τdiverge_conv:
  "τtrsys.τdiverge (init_fin t) init_fin_τmove ((status, x), m) ⟷
   τdiverge t (x, m) ∧ status = Running"
by(blast intro: init_fin_τdiverge_RunningI dest: init_fin_τdivergeD)

end

(* If the underlying semantics has no silent moves at all, neither does the
   lifted one. *)
lemma init_fin_τmoves_False:
  "τmultithreaded.init_fin_τmove (λ_ _ _. False) = (λ_ _ _. False)"
by(simp add: fun_eq_iff τmultithreaded.init_fin_τmove.simps)

(* if_multithreaded_base extended with a τmove predicate.
   Fix: restored the function arrows in the parameter types and the
   composition operator in the sublocale instantiation. *)
locale if_τmultithreaded = if_multithreaded_base + τmultithreaded +
  constrains final :: "'x ⇒ bool" 
  and r :: "('l,'t,'x,'m,'w,'o) semantics" 
  and convert_RA :: "'l released_locks ⇒ 'o list"
  and τmove :: "('l,'t,'x,'m,'w,'o) τmoves"

(* Instantiate τmultithreaded with the extended semantics and the lifted
   silent-move predicate. *)
sublocale if_τmultithreaded < "if": τmultithreaded
  "init_fin_final"
  "init_fin"
  "map NormalAction ∘ convert_RA"
  "init_fin_τmove"
.

(* if_multithreaded_base extended with the assumptions of τmultithreaded_wf.
   Fix: restored the function arrows in the parameter types. *)
locale if_τmultithreaded_wf = if_multithreaded_base + τmultithreaded_wf +
  constrains final :: "'x ⇒ bool" 
  and r :: "('l,'t,'x,'m,'w,'o) semantics" 
  and convert_RA :: "'l released_locks ⇒ 'o list"
  and τmove :: "('l,'t,'x,'m,'w,'o) τmoves"

(* Make the facts of if_multithreaded and if_τmultithreaded available inside
   if_τmultithreaded_wf. *)
sublocale if_τmultithreaded_wf < if_multithreaded
by unfold_locales

sublocale if_τmultithreaded_wf < if_τmultithreaded .

context τmultithreaded_wf begin

(* The lifted semantics with init_fin_τmove again satisfies
   τmultithreaded_wf. The two obligations shown below are that a silent
   lifted step leaves the shared memory unchanged and that its thread action
   is empty.
   Fix: restored the composition operator in the interpretation and the
   turnstile in the assumed init_fin step. *)
lemma τmultithreaded_wf_init_fin:
  "τmultithreaded_wf init_fin_final init_fin init_fin_τmove"
proof -
  interpret "if": multithreaded init_fin_final init_fin "map NormalAction ∘ convert_RA"
    by(rule multithreaded_init_fin)
  show ?thesis
  proof(unfold_locales)
    fix t x m ta x' m'
    assume "init_fin_τmove (x, m) ta (x', m')" "t ⊢ (x, m) -ta→i (x', m')" 
    thus "m = m'" by(cases)(auto dest: τmove_heap)
  next
    fix s ta s'
    assume "init_fin_τmove s ta s'"
    thus "ta = ε" by(cases)(auto dest: silent_tl)
  qed
qed

end

(* The extended semantics inherits the τmultithreaded_wf assumptions; the
   obligation is discharged by τmultithreaded_wf_init_fin.
   Fix: restored the composition operator in the observation conversion. *)
sublocale if_τmultithreaded_wf < "if": τmultithreaded_wf
  "init_fin_final"
  "init_fin"
  "map NormalAction ∘ convert_RA"
  "init_fin_τmove"
by(rule τmultithreaded_wf_init_fin)


(* Lifts an indexed thread invariant to status-annotated local states by
   ignoring the status component.
   Fix: restored the function arrows in the type. *)
primrec init_fin_lift_inv :: "('i ⇒ 't ⇒ 'x ⇒ 'm ⇒ bool) ⇒ 'i ⇒ 't ⇒ status × 'x ⇒ 'm ⇒ bool"
where "init_fin_lift_inv P I t (s, x) = P I t x"

context lifting_inv begin

(* The lifted invariant is preserved by the extended semantics, so
   lifting_inv also holds for init_fin.
   Fix: restored the composition operator in the interpretation. *)
lemma lifting_inv_init_fin_lift_inv:
  "lifting_inv init_fin_final init_fin (init_fin_lift_inv P)"
proof -
  interpret "if": multithreaded init_fin_final init_fin "map NormalAction ∘ convert_RA"
    by(rule multithreaded_init_fin)
  show ?thesis
    by(unfold_locales)(fastforce elim!: init_fin.cases dest: invariant_red invariant_NewThread invariant_other)+
qed

end

(* if_multithreaded combined with an indexed thread invariant P.
   Fix: restored the function arrows in the parameter types and the
   composition operator in the sublocale instantiation. *)
locale if_lifting_inv =
  if_multithreaded +
  lifting_inv +
  constrains final :: "'x ⇒ bool" 
  and r :: "('l,'t,'x,'m,'w,'o) semantics" 
  and convert_RA :: "'l released_locks ⇒ 'o list"
  and P :: "'i ⇒ 't ⇒ 'x ⇒ 'm ⇒ bool"

(* The extended semantics satisfies lifting_inv for the lifted invariant. *)
sublocale if_lifting_inv < "if": lifting_inv
  init_fin_final
  init_fin
  "map NormalAction ∘ convert_RA"
  "init_fin_lift_inv P"
by(rule lifting_inv_init_fin_lift_inv)

(* Lifts a thread well-formedness predicate to status-annotated local states
   by ignoring the status component.
   Fix: restored the function arrows in the type. *)
primrec init_fin_lift :: "('t ⇒ 'x ⇒ 'm ⇒ bool) ⇒ 't ⇒ status × 'x ⇒ 'm ⇒ bool"
where "init_fin_lift P t (s, x) = P t x"

context lifting_wf begin

(* The lifted predicate is preserved by the extended semantics, so
   lifting_wf also holds for init_fin.
   Fixes: restored the composition operator in the interpretation and
   dropped the duplicated, empty "dest:" modifier in the fastforce call. *)
lemma lifting_wf_init_fin_lift:
  "lifting_wf init_fin_final init_fin (init_fin_lift P)"
proof -
  interpret "if": multithreaded init_fin_final init_fin "map NormalAction ∘ convert_RA"
    by(rule multithreaded_init_fin)
  show ?thesis
    by(unfold_locales)(fastforce elim!: init_fin.cases dest: preserves_red preserves_other preserves_NewThread)+
qed

end

(* if_multithreaded combined with a thread well-formedness predicate P.
   Fix: restored the function arrows in the parameter types and the
   composition operator in the sublocale instantiation. *)
locale if_lifting_wf =
  if_multithreaded +
  lifting_wf +
  constrains final :: "'x ⇒ bool" 
  and r :: "('l,'t,'x,'m,'w,'o) semantics" 
  and convert_RA :: "'l released_locks ⇒ 'o list"
  and P :: "'t ⇒ 'x ⇒ 'm ⇒ bool"

(* The extended semantics satisfies lifting_wf for the lifted predicate. *)
sublocale if_lifting_wf < "if": lifting_wf 
  init_fin_final
  init_fin
  "map NormalAction ∘ convert_RA"
  "init_fin_lift P"
by(rule lifting_wf_init_fin_lift)

(* Every if_lifting_wf instance yields an if_lifting_inv instance whose
   invariant ignores its unit-typed index argument. *)
lemma (in if_lifting_wf) if_lifting_inv:
  "if_lifting_inv final r (λ_::unit. P)"
proof -
  interpret lifting_inv final r convert_RA  "λ_ :: unit. P" by(rule lifting_inv)
  show ?thesis by unfold_locales
qed

(* Thread invariants in the presence of silent moves.
   Fix throughout: restored the function arrows in the parameter types, the
   premise brackets and meta-implications in the lemma statements, and the
   composition operator in the last lemma. *)
locale τlifting_inv = τmultithreaded_wf +
  lifting_inv +
  constrains final :: "'x ⇒ bool" 
  and r :: "('l,'t,'x,'m,'w,'o) semantics" 
  and convert_RA :: "'l released_locks ⇒ 'o list"
  and τmove :: "('l,'t,'x,'m,'w,'o) τmoves"
  and P :: "'i ⇒ 't ⇒ 'x ⇒ 'm ⇒ bool"
begin

(* One silent multithreaded step preserves ts_inv with the same index map. *)
lemma redT_silent_move_invariant:
  "⟦ τmredT s s'; ts_inv P Is (thr s) (shr s) ⟧ ⟹ ts_inv P Is (thr s') (shr s')"
by(auto dest!: redT_invariant mτmove_silentD)

(* The same holds for any finite sequence of silent steps. *)
lemma redT_silent_moves_invariant:
  "⟦ mthr.silent_moves s s'; ts_inv P Is (thr s) (shr s) ⟧ ⟹ ts_inv P Is (thr s') (shr s')"
by(induct rule: rtranclp_induct)(auto dest: redT_silent_move_invariant)

(* For a run mixing silent and visible steps, the index map is updated by
   upd_invs with the thread-spawn parts (thr_a) of the recorded actions. *)
lemma redT_τrtrancl3p_invariant:
  "⟦ mthr.τrtrancl3p s ttas s'; ts_inv P Is (thr s) (shr s) ⟧
  ⟹ ts_inv P (upd_invs Is P (concat (map (thr_a ∘ snd) ttas))) (thr s') (shr s')"
proof(induct arbitrary: Is rule: mthr.τrtrancl3p.induct)
  case τrtrancl3p_refl thus ?case by simp
next
  case (τrtrancl3p_step s s' tls s'' tl)
  thus ?case by(cases tl)(force dest: redT_invariant)
next
  case (τrtrancl3p_τstep s s' tls s'' tl)
  thus ?case by(cases tl)(force dest: redT_invariant mτmove_silentD)
qed

end

(* Thread well-formedness in the presence of silent moves.
   Fix throughout: restored the function arrows in the parameter types and
   the premise brackets and meta-implications in the lemma statements. *)
locale τlifting_wf = τmultithreaded +
  lifting_wf +
  constrains final :: "'x ⇒ bool" 
  and r :: "('l,'t,'x,'m,'w,'o) semantics" 
  and convert_RA :: "'l released_locks ⇒ 'o list"
  and τmove :: "('l,'t,'x,'m,'w,'o) τmoves"
  and P :: "'t ⇒ 'x ⇒ 'm ⇒ bool"
begin

(* One silent multithreaded step preserves ts_ok. *)
lemma redT_silent_move_preserves:
  "⟦ τmredT s s'; ts_ok P (thr s) (shr s) ⟧ ⟹ ts_ok P (thr s') (shr s')"
by(auto dest: redT_preserves)

(* Any finite sequence of silent steps preserves ts_ok. *)
lemma redT_silent_moves_preserves:
  "⟦ mthr.silent_moves s s'; ts_ok P (thr s) (shr s) ⟧ ⟹ ts_ok P (thr s') (shr s')"
by(induct rule: rtranclp.induct)(auto dest: redT_silent_move_preserves)

(* Runs mixing silent and visible steps preserve ts_ok as well. *)
lemma redT_τrtrancl3p_preserves:
  "⟦ mthr.τrtrancl3p s ttas s'; ts_ok P (thr s) (shr s) ⟧ ⟹ ts_ok P (thr s') (shr s')"
by(induct rule: mthr.τrtrancl3p.induct)(auto dest: redT_silent_moves_preserves redT_preserves)

end

(* Lifts a multithreaded state to the status-annotated state space by
   tagging the local state of every existing thread with the status s;
   locks, shared memory, wait sets and interrupts are copied unchanged.
   Fix: restored the function arrows in the type. *)
definition init_fin_lift_state :: "status ⇒ ('l,'t,'x,'m,'w) state ⇒ ('l,'t,status × 'x,'m,'w) state"
where "init_fin_lift_state s σ = (locks σ, (λt. map_option (λ(x, ln). ((s, x), ln)) (thr σ t), shr σ), wset σ, interrupts σ)"

(* Projects a thread map over annotated local states back to plain local
   states by dropping the first component of each local-state pair.
   Note that 'status here is a type variable, so the definition is
   polymorphic in the annotation type.
   Fix: restored the function arrow in the type and the composition
   operator in the defining equation. *)
definition init_fin_descend_thr :: "('l,'t,'status × 'x) thread_info ⇒ ('l,'t,'x) thread_info"
where "init_fin_descend_thr ts = map_option (λ((s, x), ln). (x, ln)) ∘ ts"

(* Projects a whole annotated state back to a plain state: the thread map
   goes through init_fin_descend_thr, all other components are copied.
   Fix: restored the function arrow in the type. *)
definition init_fin_descend_state :: "('l,'t,'status × 'x,'m,'w) state ⇒ ('l,'t,'x,'m,'w) state"
where "init_fin_descend_state σ = (locks σ, (init_fin_descend_thr (thr σ), shr σ), wset σ, interrupts σ)"

(* ts_ok of the lifted predicate on a lifted state coincides with ts_ok of
   the original predicate on the original state.
   Fix: restored the equivalence between the two sides. *)
lemma ts_ok_init_fin_lift_init_fin_lift_state [simp]:
  "ts_ok (init_fin_lift P) (thr (init_fin_lift_state s σ)) (shr (init_fin_lift_state s σ)) ⟷ ts_ok P (thr σ) (shr σ)"
by(auto simp add: init_fin_lift_state_def intro!: ts_okI dest: ts_okD)

(* ts_inv of the lifted invariant on a lifted state coincides with ts_inv of
   the original invariant on the original state.
   Fix: restored the equivalence between the two sides. *)
lemma ts_inv_init_fin_lift_inv_init_fin_lift_state [simp]:
  "ts_inv (init_fin_lift_inv P) I (thr (init_fin_lift_state s σ)) (shr (init_fin_lift_state s σ)) ⟷ 
   ts_inv P I (thr σ) (shr σ)"
by(auto simp add: init_fin_lift_state_def intro!: ts_invI dest: ts_invD)

(* Projections of a lifted state: every component except the thread map is
   unchanged, and the thread map tags each local state with s.
   NOTE(review): the fact names interrupts_init_fin_lift_stae and
   thr_init_fin_list_state look like misspellings of "state" and "lift";
   they are kept because renaming would break existing references, e.g. the
   use of thr_init_fin_list_state in the proof of thr_init_fin_list_state'. *)
lemma init_fin_lift_state_conv_simps:
  shows shr_init_fin_lift_state: "shr (init_fin_lift_state s σ) = shr σ"
  and locks_init_fin_lift_state: "locks (init_fin_lift_state s σ) = locks σ"
  and wset_init_fin_lift_state: "wset (init_fin_lift_state s σ) = wset σ"
  and interrupts_init_fin_lift_stae: "interrupts (init_fin_lift_state s σ) = interrupts σ"
  and thr_init_fin_list_state: 
  "thr (init_fin_lift_state s σ) t = map_option (λ(x, ln). ((s, x), ln)) (thr σ t)"
by(simp_all add: init_fin_lift_state_def)

(* Point-free form of thr_init_fin_list_state.
   Fix: restored the composition operator on the right-hand side. *)
lemma thr_init_fin_list_state': 
  "thr (init_fin_lift_state s σ) = map_option (λ(x, ln). ((s, x), ln)) ∘ thr σ"
by(simp add: fun_eq_iff thr_init_fin_list_state)

(* An existing thread keeps its plain local state and ln component under
   the projection.
   Fix: restored the meta-quantifier binding ln, the option brackets around
   both thread entries, and the meta-implication. *)
lemma init_fin_descend_thr_Some_conv [simp]:
  "⋀ln. ts t = ⌊((status, x), ln)⌋ ⟹ init_fin_descend_thr ts t = ⌊(x, ln)⌋"
by(simp add: init_fin_descend_thr_def)

(* A thread that is absent before the projection is absent afterwards.
   Fix: restored the meta-implication between premise and conclusion. *)
lemma init_fin_descend_thr_None_conv [simp]:
  "ts t = None ⟹ init_fin_descend_thr ts t = None"
by(simp add: init_fin_descend_thr_def)

(* The projection neither adds nor removes threads.
   Fix: restored the equivalence between the two equations. *)
lemma init_fin_descend_thr_eq_None [simp]:
  "init_fin_descend_thr ts t = None ⟷ ts t = None"
by(simp add: init_fin_descend_thr_def)

(* Component-wise behaviour of init_fin_descend_state: only the thread map
   is changed (through init_fin_descend_thr); locks, shared memory, wait
   sets and interrupts are preserved. *)
lemma init_fin_descend_state_simps [simp]:
  "init_fin_descend_state (ls, (ts, m), ws, is) = (ls, (init_fin_descend_thr ts, m), ws, is)"
  "locks (init_fin_descend_state s) = locks s"
  "thr (init_fin_descend_state s) = init_fin_descend_thr (thr s)"
  "shr (init_fin_descend_state s) = shr s"
  "wset (init_fin_descend_state s) = wset s"
  "interrupts (init_fin_descend_state s) = interrupts s"
by(simp_all add: init_fin_descend_state_def)

(* The projection commutes with a point update of the thread map; the new
   entry is projected in the same way as existing entries. *)
lemma init_fin_descend_thr_update [simp]:
  "init_fin_descend_thr (ts(t := v)) = (init_fin_descend_thr ts)(t := map_option (λ((status, x), ln). (x, ln)) v)"
by(simp add: init_fin_descend_thr_def fun_eq_iff)

(* ts_ok of P on the projected thread map equals ts_ok of the lifted
   predicate on the annotated thread map (as functions of the memory). *)
lemma ts_ok_init_fin_descend_state: 
  "ts_ok P (init_fin_descend_thr ts) = ts_ok (init_fin_lift P) ts"
by(rule ext)(auto 4 3 intro!: ts_okI dest: ts_okD simp add: init_fin_descend_thr_def)

(* The projection does not change which thread ids are free. *)
lemma free_thread_id_init_fin_descend_thr [simp]: 
  "free_thread_id (init_fin_descend_thr ts) = free_thread_id ts"
by(simp add: free_thread_id.simps fun_eq_iff)

(* After applying a new-thread action with redT_updT', a thread id is
   unused in the projected map exactly when it is unused in the annotated
   map.
   Fix: restored the equivalence between the two equations. *)
lemma redT_updT'_init_fin_descend_thr_eq_None [simp]:
  "redT_updT' (init_fin_descend_thr ts) nt t = None ⟷ redT_updT' ts nt t = None"
by(cases nt) simp_all

(* thread_ok for a single new-thread action is unaffected by the
   projection of the thread map. *)
lemma thread_ok_init_fin_descend_thr [simp]: 
  "thread_ok (init_fin_descend_thr ts) nta = thread_ok ts nta"
by(cases nta) simp_all

(* thread_oks for a list of new-thread actions is unaffected by the
   projection of the thread map. *)
lemma threads_ok_init_fin_descend_thr [simp]:
  "thread_oks (init_fin_descend_thr ts) ntas = thread_oks ts ntas"
by(induct ntas arbitrary: ts)(auto elim!: thread_oks_ts_change[THEN iffD1, rotated 1])

(* Applying a status-annotated new-thread action and then projecting gives
   the same thread map as projecting first and applying the plain action. *)
lemma init_fin_descend_thr_redT_updT [simp]:
  "init_fin_descend_thr (redT_updT ts (convert_new_thread_action (Pair status) nt)) =
   redT_updT (init_fin_descend_thr ts) nt"
by(cases nt) simp_all

(* List version of init_fin_descend_thr_redT_updT: the projection commutes
   with applying a whole list of new-thread actions. *)
lemma init_fin_descend_thr_redT_updTs [simp]:
  "init_fin_descend_thr (redT_updTs ts (map (convert_new_thread_action (Pair status)) nts)) =
   redT_updTs (init_fin_descend_thr ts) nts"
by(induct nts arbitrary: ts) simp_all

context final_thread begin

(* A conditional action that is permitted in an annotated state (with
   finality judged by init_fin_final) is also permitted in the projected
   state. Only this direction is stated.
   Fix: restored the meta-implication between premise and conclusion. *)
lemma cond_action_ok_init_fin_descend_stateI [simp]:
  "final_thread.cond_action_ok init_fin_final s t ct ⟹ cond_action_ok (init_fin_descend_state s) t ct"
by(cases ct)(auto simp add: final_thread.cond_action_ok.simps init_fin_descend_thr_def)

(* List version of the previous lemma.
   Fix: restored the meta-implication between premise and conclusion. *)
lemma cond_action_oks_init_fin_descend_stateI [simp]:
  "final_thread.cond_action_oks init_fin_final s t cts ⟹ cond_action_oks (init_fin_descend_state s) t cts"
by(induct cts)(simp_all add: final_thread.cond_action_oks.simps cond_action_ok_init_fin_descend_stateI)

end


(* Builds the observation list of a start thread t: an InitialThreadAction
   for t, followed by each given observation wrapped in NormalAction and
   paired with t.
   Fix: restored the function arrows in the type. *)
definition lift_start_obs :: "'t ⇒ 'o list ⇒ ('t × 'o action) list"
where "lift_start_obs t obs = (t, InitialThreadAction) # map (λob. (t, NormalAction ob)) obs"

(* lift_start_obs adds exactly one element, the initial action. *)
lemma length_lift_start_obs [simp]: "length (lift_start_obs t obs) = Suc (length obs)"
by(simp add: lift_start_obs_def)

(* The elements of lift_start_obs are the initial action of t plus the
   given observations, each wrapped in NormalAction and paired with t.
   Fix: restored the composition operator between Pair t and NormalAction. *)
lemma set_lift_start_obs [simp]:
  "set (lift_start_obs t obs) =
   insert (t, InitialThreadAction) ((Pair t ∘ NormalAction) ` set obs)"
by(auto simp add: lift_start_obs_def o_def)

(* lift_start_obs yields a duplicate-free list exactly when the given
   observation list is duplicate-free. *)
lemma distinct_lift_start_obs [simp]: "distinct (lift_start_obs t obs) = distinct obs"
by(auto simp add: lift_start_obs_def distinct_map intro: inj_onI)

end