(* Theory Discrete_Log *)

theory Discrete_Log imports 
  CryptHOL.CryptHOL
  Cyclic_Group_Ext
begin 

(* Discrete-logarithm game over a cyclic group 𝒢 whose order is assumed positive. *)
locale dis_log = 
  fixes 𝒢 :: "'grp cyclic_group" (structure)
  assumes order_gt_0 [simp]: "order 𝒢 > 0"
begin

(* Adversary shapes used by this theory:
   - dislog_adv   : given one group element, returns a single exponent;
   - dislog_adv'  : given one group element, returns a pair of exponents;
   - dislog_adv2  : given a pair of group elements, returns a single exponent. *)
type_synonym 'grp' dislog_adv = "'grp' ⇒ nat spmf"

type_synonym 'grp' dislog_adv' = "'grp' ⇒ (nat × nat) spmf"

type_synonym 'grp' dislog_adv2 = "'grp' × 'grp' ⇒ nat spmf"

(* The discrete-log game.
   1. Draw x with sample_uniform (order 𝒢).
   2. Hand 𝒜 the element h = ❙g [^] x.
   3. Output True iff 𝒜's answer x' is congruent to x modulo order 𝒢.
   If 𝒜 does not return, TRY ... ELSE turns that into the result False. *)
definition dis_log :: "'grp dislog_adv ⇒ bool spmf"
where "dis_log 𝒜 = TRY do {
  x ← sample_uniform (order 𝒢);
  let h = ❙g [^] x; 
  x' ← 𝒜 h;
  return_spmf ([x = x'] (mod order 𝒢))} ELSE return_spmf False"

(* Advantage of 𝒜: the probability that the game returns True. *)
definition advantage :: "'grp dislog_adv ⇒ real"
where "advantage 𝒜 ≡ spmf (dis_log 𝒜) True" 

(* The game is lossless whenever the adversary is lossless on every input. *)
lemma lossless_dis_log: "⟦0 < order 𝒢; ∀ h. lossless_spmf (𝒜 h)⟧ ⟹ lossless_spmf (dis_log 𝒜)"
by(auto simp add:  dis_log_def)

end 

(* Variants of the discrete-log game relative to a second group element
   g' = ❙g [^] x, where x is a fixed exponent (a locale parameter). *)
locale dis_log_alt = 
  fixes 𝒢 :: "'grp cyclic_group" (structure)
    and x :: nat
  assumes order_gt_0 [simp]: "order 𝒢 > 0"
begin

(* The plain game is available under the prefix dis_log; its locale
   assumption follows from order_gt_0 above. *)
sublocale dis_log: dis_log 𝒢
  unfolding dis_log_def by simp

definition "g' = ❙g [^] x"

(* Pair-output game.
   - 𝒜 receives h = ❙g [^] w for w drawn with sample_uniform (order 𝒢).
   - 𝒜 answers with a pair (w1', w2').
   - 𝒜 wins iff w ≡ w1' + x * w2' modulo order 𝒢.
   Adversary failure counts as a loss via TRY ... ELSE. *)
definition dis_log2 :: "'grp dis_log.dislog_adv' ⇒ bool spmf"
where "dis_log2 𝒜 = TRY do {
    w ← sample_uniform (order 𝒢);
    let h = ❙g [^] w;
    (w1',w2') ← 𝒜 h;
    return_spmf ([w = (w1' + x * w2')]  (mod (order 𝒢)))} ELSE return_spmf False"

(* Probability that the pair-output game returns True. *)
definition advantage2 :: "'grp dis_log.dislog_adv' ⇒ real"
where "advantage2 𝒜 ≡ spmf (dis_log2 𝒜) True" 

(* Turns a pair-output adversary into a plain discrete-log adversary by
   combining its answer (w1, w2) into the single exponent w1 + x * w2. *)
definition adversary2 :: "('grp ⇒ (nat × nat) spmf) ⇒ 'grp ⇒ nat spmf"
  where "adversary2 𝒜 h = do {
    (w1,w2) ← 𝒜 h;
    return_spmf (w1 + x * w2)}"

(* Two-element game.
   - 𝒜 receives (❙g [^] w, g' [^] w) for w drawn with sample_uniform (order 𝒢).
   - 𝒜 must return w modulo order 𝒢.
   The let only introduces h; it rebinds w to the same value. *)
definition dis_log3 :: "'grp dis_log.dislog_adv2 ⇒ bool spmf"
where "dis_log3 𝒜 = TRY do {
    w ← sample_uniform (order 𝒢);
    let (h,w) = ((❙g [^] w, g' [^] w), w);
    w' ← 𝒜 h;
    return_spmf ([w = w'] (mod (order 𝒢)))} ELSE return_spmf False"

(* Probability that the two-element game returns True. *)
definition advantage3 :: "'grp dis_log.dislog_adv2 ⇒ real"
  where "advantage3 𝒜 ≡ spmf (dis_log3 𝒜) True" 

(* Plain discrete-log adversary built from a two-element adversary: on input h
   it runs 𝒜 on (h, h [^] x). The parameter is named h (not g) so it is not
   confused with the generator ❙g. *)
definition adversary3:: "'grp dis_log.dislog_adv2 ⇒ 'grp ⇒ nat spmf"
  where "adversary3 𝒜 h = 𝒜 (h, h [^] x)"

end 

(* Reductions from the variant games of dis_log_alt to the plain discrete-log
   game, under the additional assumption that 𝒢 is a cyclic group. *)
locale dis_log_alt_reductions = dis_log_alt + cyclic_group 𝒢 
begin

(* advantage3 equals the plain discrete-log advantage of adversary3.
   The proof unfolds both games. power_swap presumably swaps the order of two
   exponentiations, matching g' [^] w with (❙g [^] w) [^] x; see
   Cyclic_Group_Ext to confirm. *)
lemma dis_log_adv3:
  shows "advantage3 𝒜 = dis_log.advantage (adversary3 𝒜)"
  unfolding dis_log_alt.advantage3_def
  by(simp add: advantage3_def dis_log.advantage_def adversary3_def dis_log.dis_log_def dis_log3_def Let_def g'_def power_swap)

(* advantage2 equals the plain discrete-log advantage of adversary2.
   After unfolding, both games compare w with w1 + x * w2 modulo order 𝒢;
   split_def handles the pair pattern in the bind. *)
lemma dis_log_adv2:
  shows  "advantage2 𝒜 = dis_log.advantage (adversary2 𝒜)"
  unfolding dis_log_alt.advantage2_def
  by(simp add: advantage2_def dis_log2_def dis_log.advantage_def dis_log.dis_log_def adversary2_def split_def)

end 

end