Theory ETP_OT

subsection ‹ Oblivious transfer constructed from ETPs ›

text‹Here we construct the oblivious transfer (OT) protocol based on enhanced trapdoor permutations (ETPs) given in @{cite "DBLP:books/sp/17/Lindell17"} (Chapter 4) and prove
semi-honest security for both parties. We show information-theoretic security for Party 1 and reduce the security of
Party 2 to the hard-core predicate (HCP) assumption.›

theory ETP_OT imports
  "HOL-Number_Theory.Cong"
  ETP
  OT_Functionalities
  Semi_Honest_Def
begin

(* Type abbreviations used in the security proofs.
   A view of Party 1 pairs its two input bits with two range elements (cf. R1 and S1);
   dist1 is the type of a distinguisher on such views. *)
type_synonym 'range viewP1 = "((bool × bool) × 'range × 'range) spmf"
type_synonym 'range dist1 = "((bool × bool) × 'range × 'range)  bool spmf"
(* A view of Party 2 consists of its choice bit, the index and two bits (cf. R2 and S2);
   dist2 is the type of a distinguisher on such views. *)
type_synonym 'index viewP2 = "(bool × 'index × (bool × bool)) spmf"
type_synonym 'index dist2 = "(bool × 'index × bool × bool)  bool spmf"
(* Type of the HCP adversary built from a Party 2 distinguisher (see 𝒜 below):
   index, choice bit, output bit, distinguisher and a range element yield a guessed bit. *)
type_synonym ('index, 'range) advP2 = "'index  bool  bool  'index dist2  'range  bool spmf"

(* Rewrites the negated constant in the else-branch of a conditional to True.
   Not declared [simp] and not referenced elsewhere in this theory. *)
lemma if_False_True: "(if x then False else ¬ False)  (if x then False else True)"
  by simp

(* Simp rule: a conditional whose then-branch is True is replaced by an implication. *)
lemma if_then_True [simp]: "(if b then True else x)  (¬ b  x)"
  by simp

(* Simp rule: a conditional whose else-branch is True is replaced by an implication. *)
lemma if_else_True [simp]: "(if b then x else True)  (b  x)"
  by simp

(* Boolean negation is injective on every set. It is declared [simp] so that goals
   of the form inj_on Not A are discharged automatically. *)
lemma inj_on_Not [simp]: "inj_on Not A"
  by(auto simp add: inj_on_def)

(* Context for the concrete OT protocol. The parameters are handed to the locale etp:
   I samples an index together with a trapdoor, F is the forward map, and Finv is used
   below as the trapdoor inverse of F (it additionally takes the trapdoor).
   B is the hard-core predicate. Everything up to the matching end is stated relative
   to these fixed parameters. *)
locale ETP_base = etp: etp I domain range F Finv B
  for I :: "('index × 'trap) spmf" ― ‹samples index and trapdoor›
    and domain :: "'index  'range set" 
    and range :: "'index  'range set"
    and B :: "'index  'range  bool" ― ‹hard core predicate›
    and F :: "'index  'range  'range"
    and Finv :: "'index  'trap  'range  'range"
begin

text‹The probabilistic program that defines the protocol.›

(* The OT protocol as a single probabilistic program.
   input1 is Party 1's pair of bits and σ is Party 2's choice bit. Party 1's output is
   unit. Party 2's output is obtained by unmasking the selected β with the hard-core bit
   of the corresponding preimage.
   Note: despite their names, bσ and bσ' are simply the first and second input bit,
   independently of σ. The choice by σ happens only in the returned value, which selects
   the second bit when σ holds and the first bit otherwise (cf. correctness).
   The let-binding for xσ shadows the sampled xσ by recomputing it as Finv α τ (F α xσ). *)
definition protocol :: "(bool × bool)  bool  (unit × bool) spmf"
  where "protocol input1 σ = do {  
    let (bσ, bσ') = input1;
    (α :: 'index, τ :: 'trap)  I;
    xσ :: 'range  etp.S α;
    yσ' :: 'range  etp.S α;
    let (yσ :: 'range) = F α xσ;
    let (xσ :: 'range) = Finv α τ yσ;
    let (xσ' :: 'range) = Finv α τ yσ';
    let (βσ :: bool) = xor (B α xσ) bσ;
    let (βσ' :: bool) = xor (B α xσ') bσ';
    return_spmf ((), if σ then xor (B α xσ') βσ' else xor (B α xσ) βσ)}"

(* Correctness: the protocol produces exactly the outputs of the OT functionality funct_OT_12. *)
lemma correctness: "protocol (m0,m1) c = funct_OT_12 (m0,m1) c"
proof-
  (* Cancellation of a repeated mask, phrased with boolean equality; it is passed to
     the final auto call together with the unfolded definitions. *)
  have "(B α (Finv α τ yσ') = (B α (Finv α τ yσ') = m1)) = m1" 
    for α τ yσ'  by auto
  then show ?thesis 
    by(auto simp add: protocol_def funct_OT_12_def Let_def etp.B_F_inv_rewrite bind_spmf_const etp.lossless_S local.etp.lossless_I lossless_weight_spmfD split_def cong: bind_spmf_cong)
qed

text ‹ Party 1 views ›

(* Real view of Party 1: its input bits and the two range elements it receives.
   The image F α xσ of a sampled xσ is placed first when σ is false and second when
   σ is true. The other element yσ' is sampled directly from etp.S α. *)
definition R1 :: "(bool × bool)  bool  'range viewP1"
  where "R1 input1 σ = do {
    let (b0, b1) = input1;
    (α, τ)  I;
    xσ  etp.S α;
    yσ'  etp.S α;
    let yσ = F α xσ;
    return_spmf ((b0, b1), if σ then yσ' else yσ, if σ then yσ else yσ')}"

(* The real view of Party 1 never fails; this follows from losslessness of I and etp.S. *)
lemma lossless_R1: "lossless_spmf (R1 msgs σ)"
  by(simp add: R1_def local.etp.lossless_I split_def etp.lossless_S Let_def)

(* Simulator for Party 1. It uses only the input bits, since Party 1's functionality
   output is unit. It samples an index and two independent elements of etp.S α.
   It does not depend on Party 2's choice bit. *)
definition S1 :: "(bool × bool)  unit  'range viewP1"
  where "S1 == (λ input1 (). do {
    let (b0, b1) = input1;
    (α, τ)  I;
    y0 :: 'range  etp.S α;
    y1  etp.S α;
    return_spmf ((b0, b1), y0, y1)})" 

(* The simulator for Party 1 never fails; this follows from losslessness of I and etp.S. *)
lemma lossless_S1: "lossless_spmf (S1 msgs ())"
  by(simp add: S1_def local.etp.lossless_I split_def etp.lossless_S)

text ‹ Party 2 views ›

(* Real view of Party 2: its choice bit σ, the index α and the two masked bits it receives.
   The first bit is masked with the selected message (b1 if σ holds, else b0) and the
   second with the other one. Both preimages are computed with the trapdoor inverse Finv. *)
definition R2 :: "(bool × bool)  bool  'index viewP2"
  where "R2 msgs σ = do {
    let (b0,b1) = msgs;
    (α, τ)  I;
    xσ  etp.S α;
    yσ'  etp.S α;
    let yσ = F α xσ;
    let xσ = Finv α τ yσ;
    let xσ' = Finv α τ yσ';
    let βσ = (B α xσ)  (if σ then b1 else b0) ;
    let βσ' = (B α xσ')  (if σ then b0 else b1);
    return_spmf (σ, α,(βσ, βσ'))}"

(* The real view of Party 2 never fails; this follows from losslessness of I and etp.S. *)
lemma lossless_R2: "lossless_spmf (R2 msgs σ)"
  by(simp add: R2_def split_def local.etp.lossless_I etp.lossless_S)

(* Simulator for Party 2, given only σ and the functionality output bσ.
   The first bit is masked with bσ. Unlike R2, it uses the sampled xσ directly rather
   than Finv α τ (F α xσ). The second bit is the unmasked hard-core bit of
   Finv α τ yσ' for a fresh yσ', because the simulator has no second message. *)
definition S2 :: "bool  bool  'index viewP2"
  where "S2 σ bσ = do {
    (α, τ)  I;
    xσ  etp.S α;
    yσ'  etp.S α;
    let xσ' = Finv α τ yσ';
    let βσ = (B α xσ)  bσ;
    let βσ' = B α xσ';
    return_spmf (σ, α, (βσ, βσ'))}"

(* The simulator for Party 2 never fails; this follows from losslessness of I and etp.S. *)
lemma lossless_S2: "lossless_spmf (S2 σ bσ)"
  by(simp add: S2_def local.etp.lossless_I etp.lossless_S split_def)

text ‹ Security for Party 1 ›

text‹Security for Party 1 is information-theoretic (perfect): the real view and the simulated view are equal as distributions.›

(* Perfect security for Party 1: the real view equals the simulated view.
   The lemma holds for arbitrary functionality inputs x and y, because S1 receives only
   Party 1's (unit) output. *)
lemma P1_security: "R1 input1 σ = funct_OT_12 x y  (λ (s1, s2). S1 input1 s1)" 
  including monad_normalisation
proof-
   (* Step 1: write the image F α xσ as a map over a sample from etp.S α. *)
   have "R1 input1 σ =  do {
    let (b0,b1) = input1;
    (α, τ)  I;
    yσ' :: 'range  etp.S α;
    yσ  map_spmf (λ xσ. F α xσ) (etp.S α);
    return_spmf ((b0,b1), if σ then yσ' else yσ, if σ then yσ else yσ')}"
     by(simp add: bind_map_spmf o_def Let_def R1_def)
   (* Step 2: by etp.uni_set_samp, the mapped sample is replaced by a fresh sample. *)
   also have "... = do {
    let (b0,b1) = input1;
    (α, τ)  I;
    yσ' :: 'range  etp.S α;
    yσ  etp.S α;
    return_spmf ((b0,b1), if σ then yσ' else yσ, if σ then yσ else yσ')}"
     by(simp add: etp.uni_set_samp Let_def split_def cong: bind_spmf_cong)
   (* Step 3: for either value of σ this is the simulator's output distribution. *)
   also have "... = funct_OT_12 x y  (λ (s1, s2). S1 input1 s1)"
     by(cases σ; simp add: S1_def R1_def Let_def funct_OT_12_def)
   ultimately show ?thesis by auto
qed 

text ‹ The HCP adversary used in the proof of security for Party 2 ›

(* HCP adversary built from a Party 2 distinguisher D2.
   It flips a coin βσ' and builds a view whose first bit is masked with bσ and whose
   second bit is βσ'. It outputs βσ' if D2 accepts and the negation of βσ' otherwise.
   NOTE(review): the challenge argument x is not used in the body. This matches the
   views R2/S2, which contain only (σ, α, βσ, βσ') and no range elements; confirm
   that this is the intended notion of Party 2's view. *)
definition 𝒜 :: "('index, 'range) advP2"
  where "𝒜 α σ bσ D2 x = do {
    βσ'  coin_spmf;
    xσ  etp.S α;
    let βσ = (B α xσ)  bσ;
    d  D2(σ, α, βσ, βσ');
    return_spmf(if d then βσ' else ¬ βσ')}"

(* The adversary 𝒜 is lossless (for indices drawn from I) whenever the distinguisher is. *)
lemma lossless_𝒜: 
  assumes " view. lossless_spmf (D2 view)"
  shows "y  set_spmf I   lossless_spmf (𝒜 (fst y) σ bσ D2 x)"
  by(simp add: 𝒜_def etp.lossless_S assms)

(* Transfers a bound on the HCP advantage for the bit Party 2 selects to the HCP game
   composed with the OT functionality. Unfolding funct_OT_12 makes the two coincide. *)
lemma assm_bound_funct_OT_12: 
  assumes "etp.HCP_adv 𝒜 σ (if σ then b1 else b0) D  HCP_ad"
  shows "¦spmf (funct_OT_12 (b0,b1) σ  (λ (out1,out2). 
              etp.HCP_game 𝒜 σ out2 D)) True - 1/2¦  HCP_ad"
(is "?lhs  HCP_ad")
proof-
  have "?lhs = ¦spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D) True - 1/2¦" 
    by(simp add: funct_OT_12_def)
  thus ?thesis using assms etp.HCP_adv_def by simp
qed

(* Variant of assm_bound_funct_OT_12 for an arbitrary input pair m1, given an HCP
   advantage bound that holds for every bσ. *)
lemma assm_bound_funct_OT_12_collapse: 
  assumes " bσ. etp.HCP_adv 𝒜 σ bσ D  HCP_ad"
  shows "¦spmf (funct_OT_12 m1 σ  (λ (out1,out2). etp.HCP_game 𝒜 σ out2 D)) True - 1/2¦  HCP_ad"
  using assm_bound_funct_OT_12 surj_pair assms by metis 

text ‹ To prove security for Party 2 we case-split on the message bit that Party 2 does not select (b0 if σ holds, b1 otherwise) ›

(* Case: the bit Party 2 does not select is False. Masking with False leaves βσ'
   unchanged, so the real and simulated views give the same acceptance probability
   for every distinguisher D2. *)
lemma R2_S2_False:
  assumes "((if σ then b0 else b1) = False)" 
  shows "spmf (R2 (b0,b1) σ  (D2 :: (bool × 'index × bool × bool)  bool spmf)) True 
                = spmf (funct_OT_12 (b0,b1) σ  (λ (out1,out2). S2 σ out2  D2)) True"
proof-
  have "σ  ¬ b0" using assms by simp
  moreover have "¬ σ  ¬ b1" using assms by simp
  ultimately show ?thesis
    by(auto simp add: R2_def S2_def split_def local.etp.F_f_inv assms funct_OT_12_def cong: bind_spmf_cong_simp) 
qed

(* Case: the bit Party 2 does not select is True. The second bit of the real view is
   then the negation of the simulated one, and the distinguishing advantage of D2 equals
   twice the HCP advantage of 𝒜 for the bit Party 2 selects.
   Proof outline:
   1. Rewrite the HCP game for 𝒜 as a fair coin choosing between HCP_game_true and
      HCP_game_false.
   2. Express these via the True probability of D_true and the False probability of D_false.
   3. Use losslessness of D_false to write the latter as 1 minus its True probability.
   4. Identify D_true and D_false with S2D and R2D, i.e. the simulated and real views
      fed to D2. *)
lemma R2_S2_True:
  assumes "((if σ then b0 else b1) = True)" 
    and lossless_D: " a. lossless_spmf (D2 a)"
  shows "¦(spmf (bind_spmf (R2 (b0,b1) σ) D2) True) - spmf (funct_OT_12 (b0,b1) σ  (λ (out1, out2). S2 σ out2  (λ view. D2 view))) True¦
                         = ¦2*((spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True) - 1/2)¦"
proof-
  have  "(spmf (funct_OT_12 (b0,b1) σ  (λ (out1, out2). S2 σ out2  D2)) True
              - spmf (bind_spmf (R2 (b0,b1) σ) D2) True) 
                    = 2 * ((spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True) - 1/2)"
  proof-
    have  "((spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True) - 1/2)  = 
                  1/2*(spmf (bind_spmf (S2 σ (if σ then b1 else b0)) D2) True
                        - spmf (bind_spmf (R2 (b0,b1) σ) D2) True)"
      including monad_normalisation
    proof- 
      (* The case assumption, split by the value of σ. *)
      have σ_true_b0_true: "σ  b0 = True" using assms(1) by simp
      have σ_false_b1_true: "¬ σ  b1" using assms(1) by simp 
      have return_True_False: "spmf (return_spmf (¬ d)) True = spmf (return_spmf d) False"
        for d by(cases d; simp)
      (* Auxiliary games. HCP_game_true and HCP_game_false differ only in whether the
         second view bit is the hard-core bit or its negation. *)
      define HCP_game_true where "HCP_game_true == λ σ bσ. do {
    (α, τ)  I;
    xσ  etp.S α;
    x  (etp.S α);
    let βσ = (B α xσ)  bσ;
    let βσ' = B α (Finv α τ x); 
    d  D2(σ, α, βσ, βσ');
    let b' = (if d then βσ' else ¬ βσ');
    let b = B α (Finv α τ x);
    return_spmf (b = b')}"
      define HCP_game_false where "HCP_game_false == λ σ bσ. do {
    (α, τ)  I;
    xσ  etp.S α;
    x  (etp.S α);
    let βσ = (B α xσ)  bσ;
    let βσ' = ¬ B α (Finv α τ x); 
    d  D2(σ, α, βσ, βσ');
    let b' = (if d then βσ' else ¬ βσ');
    let b = B α (Finv α τ x);
    return_spmf (b = b')}"
      (* The HCP game with 𝒜 unfolded and the coin flip moved to the front. *)
      define HCP_game_𝒜 where "HCP_game_𝒜 == λ σ bσ. do {
    βσ'  coin_spmf;
    (α, τ)  I;
    x  etp.S α;
    x'  etp.S α;
    d  D2 (σ, α, (B α x)  bσ, βσ');
    let b' = (if d then  βσ' else ¬ βσ');
    return_spmf (B α (Finv α τ x') = b')}"
      (* The simulated view composed with D2. *)
      define S2D where "S2D == λ σ bσ . do {
      (α, τ)  I;
      xσ  etp.S α;
      yσ'  etp.S α;
      let xσ' = Finv α τ yσ';
      let βσ = (B α xσ)  bσ;
      let βσ' = B α xσ';
      d :: bool  D2(σ, α, βσ, βσ');
      return_spmf d}"
      (* The real view composed with D2. *)
      define R2D where "R2D == λ msgs σ.  do {
      let (b0,b1) = msgs;
      (α, τ)  I;
      xσ  etp.S α;
      yσ'  etp.S α;
      let yσ = F α xσ;
      let xσ = Finv α τ yσ;
      let xσ' = Finv α τ yσ';
      let βσ = (B α xσ)  (if σ then b1 else b0) ;
      let βσ' = (B α xσ')  (if σ then b0 else b1);
      b :: bool  D2(σ, α,(βσ, βσ'));
      return_spmf b}"
      (* D2 run on a view with the hard-core bit (D_true) or its negation (D_false). *)
      define D_true where "D_true  == λσ bσ. do {
    (α, τ)  I;
    xσ  etp.S α;
    x  (etp.S α);
    let βσ = (B α xσ)  bσ;
    let βσ' = B α (Finv α τ x);
    d :: bool  D2(σ, α, βσ, βσ');
    return_spmf d}"
      define D_false where "D_false == λ σ bσ. do {
    (α, τ)  I;
    xσ  etp.S α;
    x  etp.S α;
    let βσ = (B α xσ)  bσ;
    let βσ' = ¬ B α (Finv α τ x);
    d :: bool  D2(σ, α, βσ, βσ');
    return_spmf d}"
      (* Used below to rewrite the False probability of D_false as 1 minus its True probability. *)
      have lossless_D_false: "lossless_spmf (D_false σ (if σ then b1 else b0))"
        apply(auto simp add: D_false_def lossless_D local.etp.lossless_I) 
        using local.etp.lossless_S by auto
      (* Unfold the HCP game with 𝒜 and commute the coin flip to the front. *)
      have "spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True =  spmf (HCP_game_𝒜 σ (if σ then b1 else b0)) True" 
        apply(simp add: etp.HCP_game_def HCP_game_𝒜_def 𝒜_def split_def etp.F_f_inv)
        by(rewrite bind_commute_spmf[where q = "coin_spmf"]; rewrite bind_commute_spmf[where q = "coin_spmf"]; rewrite bind_commute_spmf[where q = "coin_spmf"]; auto)+
      (* Case split on the coin: the guessed bit βσ' either equals the hard-core bit
         (HCP_game_true) or its negation (HCP_game_false). The script splits (conjI) into
         two symmetric goals; each reorders the binds and compares the integrals pointwise. *)
      also have "... = spmf (bind_spmf (map_spmf Not coin_spmf) (λb. if b then HCP_game_true σ (if σ then b1 else b0) else HCP_game_false σ (if σ then b1 else b0))) True"
        unfolding HCP_game_𝒜_def HCP_game_true_def HCP_game_false_def 𝒜_def Let_def
        apply(simp add: split_def cong: if_cong)
        supply [[simproc del: monad_normalisation]]
        apply(subst if_distrib[where f = "bind_spmf _" for f, symmetric]; simp cong: bind_spmf_cong add: if_distribR )+
        apply(rewrite in "_ = " bind_commute_spmf)
        apply(rewrite in "bind_spmf _ "  in "_ = " bind_commute_spmf)
        apply(rewrite in "bind_spmf _ " in "bind_spmf _ " in "_ = " bind_commute_spmf)
        apply(rewrite in " = _" bind_commute_spmf)
        apply(rewrite in "bind_spmf _ " in " = _" bind_commute_spmf)
        apply(rewrite in "bind_spmf _ " in "bind_spmf _ " in " = _" bind_commute_spmf)
        apply(fold map_spmf_conv_bind_spmf)
        apply(rule conjI; rule impI; simp) 
         apply(simp only: spmf_bind)
         apply(rule Bochner_Integration.integral_cong[OF refl])+
         apply clarify
        subgoal for r rσ α τ 
          apply(simp only: UNIV_bool spmf_of_set integral_spmf_of_set) 
          apply(simp cong: if_cong split del: if_split)
          apply(cases "B r (Finv r rσ τ)") 
          by auto
        apply(rewrite in "_ = " bind_commute_spmf)
        apply(rewrite in "bind_spmf _ "  in "_ = " bind_commute_spmf)
        apply(rewrite in "bind_spmf _ " in "bind_spmf _ " in "_ = " bind_commute_spmf)
        apply(rewrite in " = _" bind_commute_spmf)
        apply(rewrite in "bind_spmf _ " in " = _" bind_commute_spmf)
        apply(rewrite in "bind_spmf _ " in "bind_spmf _ " in " = _" bind_commute_spmf)
        apply(simp only: spmf_bind)
        apply(rule Bochner_Integration.integral_cong[OF refl])+
        apply clarify
        subgoal for r rσ α τ 
          apply(simp only: UNIV_bool spmf_of_set integral_spmf_of_set) 
          apply(simp cong: if_cong split del: if_split)
          apply(cases " B r (Finv r rσ τ)") 
          by auto
        done
      (* Evaluate the fair coin. *)
      also have "... = 1/2*(spmf (HCP_game_true σ (if σ then b1 else b0)) True) + 1/2*(spmf (HCP_game_false σ (if σ then b1 else b0)) True)"
        by(simp add: spmf_bind UNIV_bool spmf_of_set integral_spmf_of_set)
      (* HCP_game_true outputs True exactly when D2 does; HCP_game_false outputs True
         exactly when D2 outputs False. *)
      also have "... = 1/2*(spmf (D_true σ (if σ then b1 else b0)) True) + 1/2*(spmf (D_false σ (if σ then b1 else b0)) False)"   
      proof-
        have "spmf (I  (λ(α, τ). etp.S α  (λxσ. etp.S α  (λx. D2 (σ, α, B α xσ = (¬ (if σ then b1 else b0)), ¬ B α (Finv α τ x))  (λd. return_spmf (¬ d)))))) True 
                = spmf (I  (λ(α, τ). etp.S α  (λxσ. etp.S α  (λx. D2 (σ, α, B α xσ = (¬ (if σ then b1 else b0)), ¬ B α (Finv α τ x)))))) False"
          (is "?lhs = ?rhs")
        proof-
          have "?lhs = spmf (I  (λ(α, τ). etp.S α  (λxσ. etp.S α  (λx. D2 (σ, α, B α xσ = (¬ (if σ then b1 else b0)), ¬ B α (Finv α τ x))  (λd. return_spmf (d)))))) False"
            by(simp only: split_def return_True_False spmf_bind) 
          then show ?thesis by simp
        qed
        then show ?thesis  by(simp add: HCP_game_true_def HCP_game_false_def Let_def D_true_def D_false_def if_distrib[where f="(=) _"] cong: if_cong)   
      qed
      (* Losslessness of D_false turns its False probability into 1 minus its True probability. *)
      also have "... =  1/2*((spmf (D_true σ (if σ then b1 else b0) ) True) + (1 - spmf (D_false σ (if σ then b1 else b0) ) True))"
        by(simp add: spmf_False_conv_True lossless_D_false)
      (* Arithmetic rearrangement. *)
      also have "... = 1/2 + 1/2* (spmf (D_true σ (if σ then b1 else b0)) True) - 1/2*(spmf (D_false σ (if σ then b1 else b0)) True)" 
        by(simp)     
      (* Identify D_true with the simulated view and D_false with the real view, using
         the case assumption in both cases for σ. *)
      also have "... =  1/2 + 1/2* (spmf (S2D σ (if σ then b1 else b0) ) True) - 1/2*(spmf (R2D (b0,b1) σ ) True)"
        apply(auto  simp add: local.etp.F_f_inv S2D_def R2D_def D_true_def D_false_def  assms split_def cong: bind_spmf_cong_simp)
         apply(simp add: σ_true_b0_true)
        by(simp add: σ_false_b1_true)
      ultimately show ?thesis by(simp add: S2D_def R2D_def R2_def S2_def split_def)
    qed
    then show ?thesis by(auto simp add: funct_OT_12_def)
  qed
  thus ?thesis by simp
qed

(* Combined bound for Party 2, by case split on the bit Party 2 does not select:
   R2_S2_False handles the False case, where the two probabilities coincide, and
   R2_S2_True handles the True case, where equality with the HCP term holds. *)
lemma P2_adv_bound:
  assumes lossless_D: " a. lossless_spmf (D2 a)"
  shows "¦(spmf (bind_spmf (R2 (b0,b1) σ) D2) True) - spmf (funct_OT_12 (b0,b1) σ  (λ (out1, out2). S2 σ out2  (λ view. D2 view))) True¦
                          ¦2*((spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D2) True) - 1/2)¦"
  by(cases "(if σ then b0 else b1)"; auto simp add: R2_S2_False R2_S2_True assms)

(* Instantiates the generic semi-honest security framework sim_det_def (from
   Semi_Honest_Def) with the views, simulators, functionality and protocol above.
   Its side conditions are discharged by the losslessness lemmas. *)
sublocale OT_12: sim_det_def R1 S1 R2 S2 funct_OT_12 protocol 
  unfolding sim_det_def_def 
  by(simp add: lossless_R1 lossless_S1 lossless_R2 lossless_S2 funct_OT_12_def)

(* Correctness in the sense of the OT_12 framework, obtained from the lemma correctness. *)
lemma correct: "OT_12.correctness m1 m2"
  unfolding OT_12.correctness_def  
  by (metis prod.collapse correctness)

(* Perfect security for Party 1 in the sense of the OT_12 framework, obtained from P1_security. *)
lemma P1_security_inf_the: "OT_12.perfect_sec_P1 m1 m2" 
  unfolding OT_12.perfect_sec_P1_def using P1_security by simp 

(* Security for Party 2: if, for every bσ, the HCP advantage of 𝒜 built from the lossless
   distinguisher D is bounded by HCP_ad, then Party 2's advantage in the OT_12 framework
   is at most 2 * HCP_ad. *)
lemma P2_security:
  assumes " a. lossless_spmf (D a)"
  and " bσ. etp.HCP_adv 𝒜 m2 bσ D  HCP_ad"
  shows "OT_12.adv_P2 m1 m2 D  2 * HCP_ad"
proof-
  (* Composing with the functionality just selects the bit Party 2 receives. *)
  have "spmf (etp.HCP_game 𝒜 σ (if σ then b1 else b0) D) True = spmf (funct_OT_12 (b0,b1) σ  (λ (out1, out2). etp.HCP_game 𝒜 σ out2 D)) True"
    for σ b0 b1
  by(simp add: funct_OT_12_def)
  (* By P2_adv_bound, the advantage is at most twice the distance of the HCP game from 1/2. *)
  hence "OT_12.adv_P2 m1 m2 D  ¦2*((spmf (funct_OT_12 m1 m2  (λ (out1, out2). etp.HCP_game 𝒜 m2 out2 D)) True) - 1/2)¦"
    unfolding OT_12.adv_P2_def using P2_adv_bound assms surj_pair prod.collapse by metis
  (* Bound that distance using the HCP assumption. The intermediate statement is phrased
     with reciprocals and is closed from assm_bound_funct_OT_12_collapse. *)
  moreover have "¦2*((spmf (funct_OT_12 m1 m2  (λ (out1, out2). etp.HCP_game 𝒜 m2 out2 D)) True) - 1/2)¦  ¦2*HCP_ad¦" 
  proof -
    have "(r. ¦(1::real) / r¦  1 / ¦r¦)  2 / ¦1 / (spmf (funct_OT_12 m1 m2 
                 (λ(x, y). ((λu b. etp.HCP_game 𝒜 m2 b D)::unit  bool  bool spmf) x y)) True - 1 / 2)¦ 
                       HCP_ad / (1 / 2)"
      using assm_bound_funct_OT_12_collapse assms by auto
    then show ?thesis
      by fastforce
  qed 
  (* Sign of HCP_ad, derived from the bound on HCP_adv and its definition; this lets
     the final step drop the absolute value in ¦2*HCP_ad¦. *)
  moreover have "HCP_ad  0" 
    using assms(2)  local.etp.HCP_adv_def by auto
  ultimately show ?thesis by argo
qed

end

text ‹ We also consider the asymptotic setting, in which the index sampler takes a security parameter and Party 2's advantage is shown to be negligible in it ›

(* Asymptotic setting: the sampler I takes a security parameter n. For every n, the data
   (I n, domain, range, F, Finv) is assumed to satisfy the ETP_base locale predicate.
   NOTE(review): B is not passed to ETP_base in the assumption, presumably because the
   locale predicate only takes the parameters that occur in its assumptions. The
   parameter f is not used anywhere in this locale; confirm whether it can be dropped. *)
locale ETP_sec_para = 
  fixes I :: "nat  ('index × 'trap) spmf"
    and domain ::  "'index  'range set"
    and range ::  "'index  'range set"
    and f :: "'index  ('range  'range)"
    and F :: "'index  'range  'range"
    and Finv :: "'index  'trap  'range  'range"
    and B :: "'index  'range  bool"
  assumes ETP_base: " n. ETP_base (I n) domain range F Finv"
begin

(* Interprets ETP_base for every I n. Its results become available here with n as an
   extra leading argument, e.g. OT_12.correctness n m1 m2. *)
sublocale ETP_base "(I n)" domain range 
  using ETP_base  by simp

(* Correctness holds for every security parameter n. *)
lemma correct_asym: "OT_12.correctness n m1 m2"
  by(simp add: correct)

(* Perfect security for Party 1 holds for every security parameter n. *)
lemma P1_sec_asym: "OT_12.perfect_sec_P1 n m1 m2"
  using P1_security_inf_the by simp                                                                

(* Asymptotic security for Party 2. Suppose the HCP advantage of 𝒜 built from the lossless
   distinguisher D is bounded, uniformly in bσ, by a negligible function etp_advantage.
   Then Party 2's advantage is negligible, since by P2_security it is at most
   2 * etp_advantage n. *)
lemma P2_sec_asym: 
  assumes " a. lossless_spmf (D a)" 
    and HCP_adv_neg: "negligible (λ n. etp_advantage n)"
    and etp_adv_bound: " bσ n. etp.HCP_adv n 𝒜 m2 bσ D  etp_advantage n"
  shows "negligible (λ n. OT_12.adv_P2 n m1 m2 D)" 
proof-
  have "negligible (λ n. 2 * etp_advantage n)" using HCP_adv_neg 
    by (simp add: negligible_cmultI)
  moreover have "¦OT_12.adv_P2 n m1 m2 D¦ = OT_12.adv_P2 n m1 m2 D" for n unfolding OT_12.adv_P2_def by simp
  moreover have  "OT_12.adv_P2 n m1 m2 D  2 * etp_advantage n" for n using assms P2_security by blast
  ultimately show ?thesis 
    using assms negligible_le HCP_adv_neg P2_security by presburger 
qed

end

end