(* Theory Invariant_Preservation *)

section ‹‹Invariant_Preservation› – Preservation of invariants under queries›

theory Invariant_Preservation
  imports Function_At Misc_Compressed_Oracle
begin

hide_const (open) Order.top
no_notation Order.bottom ("ı")
unbundle no m_inv_syntax
unbundle lattice_syntax

subsection ‹Invariants›


definition preserves U I J ε  ε  0  (ψspace_as_set I. norm (U *V ψ - Proj J *V U *V ψ)  ε * norm ψ)
  for U :: 'a::chilbert_space CL 'b::chilbert_space

lemma preserves_def_closure:
  assumes space_as_set I = closure I'
  shows preserves U I J ε  ε  0  (ψI'. norm (U *V ψ - Proj J *V U *V ψ)  ε * norm ψ)
proof (rule iffI; (elim conjE)?)
  show preserves U I J ε  0  ε  (ψI'. norm (U *V ψ - Proj J *V U *V ψ)  ε * norm ψ)
    by (metis assms closure_subset in_mono preserves_def)
  show preserves U I J ε
    if 0  ε and bound: (ψI'. norm (U *V ψ - Proj J *V U *V ψ)  ε * norm ψ)
  proof (unfold preserves_def, intro conjI ballI)
    from that show ε  0 by simp
    fix ψ assume ψ  space_as_set I
    with assms have ψ  closure I'
      by simp
    then obtain φ where φ  ψ and φ n  I' for n
      using closure_sequential by blast
    define f where f ξ = ε * norm ξ - norm (U *V ξ - Proj J *V U *V ξ) for ξ
    with φ _  I' bound have bound': f (φ n)  0 for n
      by simp
    have continuous_on UNIV f
      unfolding f_def
      by (intro continuous_intros)
    then have (λn. f (φ n))  f ψ
      using φ  ψ  apply (rule continuous_on_tendsto_compose[where s=UNIV and f=f])
      by auto
    with bound' have f ψ  0
      by (simp add: Lim_bounded2)
    then show norm (U *V ψ - Proj J *V U *V ψ)  ε * norm ψ
      by (simp add: f_def)
  qed
qed

lemma preservesI_closure:
  assumes ε  0
  assumes closure: space_as_set I  closure I'
  assumes csubspace I'
  assumes bound: ψ. ψ  I'  norm ψ = 1  norm (U *V ψ - Proj J *V U *V ψ)  ε
  shows preserves U I J ε
proof -
  have *: space_as_set (ccspan I') = closure I'
    by (metis assms(3) ccspan.rep_eq complex_vector.span_eq_iff)
  have preserves U (ccspan I') J ε
  proof (unfold preserves_def_closure[OF *], intro conjI ballI)
    from assms show ε  0 by simp

    fix ψ assume ψI: ψ  I'
    show norm (U *V ψ - Proj J *V U *V ψ)  ε * norm ψ
    proof (cases ψ = 0)
      case True
      then show ?thesis by auto
    next
      case False
      then have norm ψ > 0
        by simp
      define φ where φ = ψ /C norm ψ
      from ψI have φ  I'
        by (simp add: φ_def csubspace I' complex_vector.subspace_scale)
      moreover from False have norm φ = 1
        by (simp add: φ_def norm_inverse)
      ultimately have norm (U *V φ - Proj J *V U *V φ)  ε
        by (rule bound)
      then have norm (U *V ψ - Proj J *V U *V ψ) / norm ψ  ε
        unfolding φ_def
        by (auto simp flip: scaleC_diff_right
            simp add: norm_inverse divide_inverse_commute cblinfun.scaleC_right)
      with norm ψ > 0 show ?thesis
        by (simp add: divide_le_eq)
    qed
  qed
  then show preserves U I J ε
    by (smt (verit) "*" closure in_mono preserves_def)
qed


lemma preservesI:
  assumes ε  0
  assumes ψ. ψ  space_as_set I  norm ψ = 1  norm (U *V ψ - Proj J *V U *V ψ)  ε
  shows preserves U I J ε
  apply (rule preservesI_closure[where I'=space_as_set I])
  using assms by auto

lemma preservesI':
  assumes ε  0
  assumes ψ. ψ  space_as_set I  norm ψ = 1  norm (Proj (-J) *V U *V ψ)  ε
  shows preserves U I J ε
  using ε0 apply (rule preservesI)
  apply (frule assms(2))
  by (simp_all add: Proj_ortho_compl cblinfun.diff_left)

lemma preserves_onorm: preserves U I J ε  norm ((id_cblinfun - Proj J) oCL U oCL Proj I)  ε
proof (rule iffI) 
  assume pres: preserves U I J ε
  show norm ((id_cblinfun - Proj J) oCL U oCL Proj I)  ε
  proof (rule norm_cblinfun_bound)
    from pres show ε  0
      by (simp add: preserves_def)
    fix ψ
    define φ where φ = Proj I *V ψ
    have normφ: norm φ  norm ψ
      unfolding φ_def apply (rule is_Proj_reduces_norm) by simp

    have norm (((id_cblinfun - Proj J) oCL U oCL Proj I) *V ψ) = norm (U *V φ - Proj J *V U *V φ)
      unfolding φ_def by (simp add: cblinfun.diff_left)
    also from pres have   ε * norm φ
      by (metis Proj_range φ_def cblinfun_apply_in_image preserves_def)
    also have   ε * norm ψ
      by (simp add: 0  ε mult_left_mono normφ)
    finally show norm (((id_cblinfun - Proj J) oCL U oCL Proj I) *V ψ)  ε * norm ψ
      by -
  qed
next
  assume norm: norm ((id_cblinfun - Proj J) oCL U oCL Proj I)  ε
  show preserves U I J ε
  proof (rule preservesI)
    show ε  0
      using norm norm_ge_zero order_trans by blast
    fix ψ assume [simp]: ψ  space_as_set I and [simp]: norm ψ = 1
    have norm (U *V ψ - Proj J *V U *V ψ) = norm ((id_cblinfun - Proj J) *V U *V ψ)
      by (simp add: cblinfun.diff_left)
    also have  = norm ((id_cblinfun - Proj J) *V U *V Proj I *V ψ)
      by (simp add: Proj_fixes_image)
    also have  = norm (((id_cblinfun - Proj J) oCL U oCL Proj I) *V ψ)
      by simp
    also have   norm ((id_cblinfun - Proj J) oCL U oCL Proj I) * norm ψ
      using norm_cblinfun by blast
    also have   ε
      by (simp add: norm)
    finally show norm (U *V ψ - Proj J *V U *V ψ)  ε
      by -
  qed
qed

lemma preserves_cong: 
  assumes ψ. ψ  space_as_set I  U *V ψ = U' *V ψ
  shows preserves U I J ε  preserves U' I J ε
  by (simp add: assms preserves_def)

lemma preserves_mono:
  assumes preserves U I J ε
  assumes I  I'
  assumes J  J'
  assumes ε  ε'
  shows preserves U I' J' ε'
proof (rule preservesI)
  show ε'  0
    by (smt (verit) assms(1) assms(4) preserves_def)
  fix ψ assume ψ  space_as_set I'
  then have ψ  space_as_set I
    using I  I' less_eq_ccsubspace.rep_eq by blast
  assume [simp]: norm ψ = 1

  have norm (U *V ψ - Proj J' *V U *V ψ) = norm ((id_cblinfun - Proj J') *V U *V ψ)
    by (simp add: cblinfun.diff_left)
  also have   norm ((id_cblinfun - Proj J) *V U *V ψ)
  proof -
    from J  J'
    have id_cblinfun - Proj J  id_cblinfun - Proj J'
      by (simp add: Proj_mono)
    then show ?thesis
      by (metis (no_types, lifting) Proj_fixes_image Proj_ortho_compl Proj_range adj_Proj cblinfun_apply_in_image cdot_square_norm cinner_adj_right cnorm_ge_square less_eq_cblinfun_def)
  qed
  also have  = norm (U *V ψ - Proj J *V U *V ψ)
    by (simp add: cblinfun.diff_left)
  also from ψ  space_as_set I preserves U I J ε
  have   ε
    by (auto simp: preserves_def)
  also have   ε'
    using ε  ε'
    by (simp add: mult_right_mono)
  finally show norm (U *V ψ - Proj J' *V U *V ψ)  ε'
    by simp
qed

text ‹The next lemma allows us to decompose the preservation of an invariant into the preservation
  of simpler invariants. The main requirement is that the simpler invariants are all orthogonal.

  This is in particular useful when one wants to show the preservation of an invariant that
  refers to the oracle input register and other unrelated registers.
  One can then decompose the invariant into many invariants that fix the input and unrelated registers
  to specific computational basis states. (I.e., wlog the input register is in a state of the form \<^term>‹ket x›.)

  Unfortunately, we have a proof only in the case of finitely many simpler invariants.
  This excludes, e.g., infinite oracle input registers etc. (e.g., quantum ints, quantum lists).›

lemma invariant_splitting:
  fixes X :: 'i set
  fixes I S :: 'i  'a::chilbert_space ccsubspace
  fixes J :: 'i  'b::chilbert_space ccsubspace
  assumes ortho_S: x y. xX  yX  x  y  orthogonal_spaces (S x) (S y)
  assumes ortho_S': x y. xX  yX  x  y  orthogonal_spaces (S' x) (S' y)
  assumes IS: x. xX  I x  S x
  assumes JS': x. xX  J x  S' x
  assumes USS': x. xX  U *S S x  S' x
  assumes II: II  (xX. I x)
  assumes JJ: JJ  (xX. J x)
  assumes ε0: ε  0
  assumes [iff]: finite X
  assumes pres: x. xX  preserves U (I x) (J x) ε
  shows preserves U II JJ ε
proof -
  have preserves U (xX. I x) (xX. J x) ε
  proof (rule preservesI_closure[where I'=(xX. space_as_set (I x))])
    from ε0 show ε  0 by -

    show csubspace (xX. space_as_set (I x))
      by (simp add: csubspace_set_sum)
    show space_as_set (sum I X)  closure (xX. space_as_set (I x))
      apply (rule eq_refl)
      apply (use finite X in induction)
      by (auto simp: sup_ccsubspace.rep_eq simp flip: closed_sum_def)

    fix ψ assume ψ  (xX. space_as_set (I x))
    then obtain ψ' where ψ'I: ψ' x  space_as_set (I x) and ψ'sum: ψ = (xX. ψ' x) for x
    proof (atomize_elim, use finite X in induction arbitrary: ψ)
      case empty
      then show ?case 
        by (auto intro!: exI[where x=λ_. 0])
    next
      case (insert x X)
      have aux: ψ  space_as_set (I x) + (xX. space_as_set (I x)) 
        ψ0 ψ1. ψ = ψ0 + ψ1  ψ0  (xX. space_as_set (I x))  ψ1  space_as_set (I x)
        by (metis add.commute set_plus_elim)
      from insert.prems
      obtain ψ0 ψ1 where ψ_decomp: ψ = ψ0 + ψ1 and ψ0: ψ0  (xX. space_as_set (I x)) and ψ1: ψ1  space_as_set (I x)
        apply atomize_elim by (auto intro!: aux simp: insert.hyps)
      from insert.IH[OF ψ0]
      obtain ψ0' where ψ0'I: ψ0' x  space_as_set (I x) and ψ0'sum: ψ0 = sum ψ0' X for x
        by auto
      define ψ' where ψ' = ψ0'(x := ψ1)
      have ψ' x  space_as_set (I x) for x
        by (simp add: ψ'_def ψ0'I ψ1)
      moreover have ψ = sum ψ' (insert x X)
        by (metis ψ'_def ψ0'sum ψ_decomp add.commute fun_upd_apply insert.hyps(1) insert.hyps(2) sum.cong sum.insert)
      ultimately show ?case
        by auto
    qed

    assume [simp]: norm ψ = 1

    define η' η where η' x = U *V (ψ' x) - Proj (J x) *V U *V (ψ' x) and η = (xX. η' x) for x
    with pres have η'bound: norm (η' x)  ε * norm (ψ' x) if xX for x
      using that by (simp add: ψ'I preserves_def)
    define US where US x = U *S S x for x
 
    have ψ' x  space_as_set (S x) if xX for x
      using that ψ'I IS less_eq_ccsubspace.rep_eq by blast
    then have Uψ'S': U *V ψ' x  space_as_set (S' x) if xX for x
      using USS'[OF that] that
      by (metis cblinfun_image.rep_eq closure_subset imageI in_mono less_eq_ccsubspace.rep_eq)

    have η'S': η' x  space_as_set (S' x) if xX for x
    proof -
      have Proj (J x) *V U *V (ψ' x)  space_as_set (J x)
        by (metis Proj_range cblinfun_apply_in_image)
      also have   space_as_set (S' x)
        unfolding US_def less_eq_ccsubspace.rep_eq[symmetric] using JS' that by auto
      finally have *: Proj (J x) *V U *V (ψ' x)  space_as_set (S' x)
        by -
      with Uψ'S'[OF that]
      show η' x  space_as_set (S' x)
        unfolding η'_def
        by (metis Proj_fixes_image Proj_range cblinfun.diff_right cblinfun_apply_in_image)
    qed
    from ortho_S' USS'
    have ortho_US: orthogonal_spaces (US x) (US y) if x  y and xX and yX for x y
      by (metis US_def in_mono less_eq_ccsubspace.rep_eq orthogonal_spaces_def
          that(1,2,3))
    have ortho_I: orthogonal_spaces (I x) (I y) if x  y and xX and yX for x y
      by (meson IS less_eq_ccsubspace.rep_eq ortho_S orthogonal_spaces_def subsetD that)
    have ortho_J: orthogonal_spaces (J x) (J y) if x  y and xX and yX for x y
      using JS' ortho_S' that
      by (meson less_eq_ccsubspace.rep_eq orthogonal_spaces_def subsetD)

    from ortho_S' η'S'
    have η'ortho: is_orthogonal (η' x) (η' y) if x  y and xX and yX for x y
      by (meson orthogonal_spaces_def that)
    have ψ'ortho: is_orthogonal (ψ' x) (ψ' y) if x  y and xX and yX for x y
      using ψ'I ortho_I orthogonal_spaces_def that by blast

    have η'2: η' x = U *V ψ' x - Proj (xX. (J x)) *V U *V ψ' x if x  X for x
    proof -
      have Proj (J y) *V U *V ψ' x = 0 if x  y and y  X for y
      proof -
        have U *V ψ' x  space_as_set (S' x)
          using x  X by (rule Uψ'S')
        moreover have orthogonal_spaces (S' x) (J y)
          using JS'[OF yX] ortho_S'[OF xX yX xy]
          by (meson less_eq_ccsubspace.rep_eq orthogonal_spaces_def subset_eq)
        ultimately show ?thesis
          by (metis (no_types, opaque_lifting) Proj_fixes_image Proj_ortho_compl Proj_range Set.basic_monos(7) cancel_comm_monoid_add_class.diff_cancel cblinfun.diff_left cblinfun.diff_right cblinfun_apply_in_image id_cblinfun.rep_eq less_eq_ccsubspace.rep_eq orthogonal_spaces_leq_compl)
      qed
      then have η' x = U *V ψ' x - Proj (J x) *V U *V ψ' x - (yX-{x}. Proj (J y) *V U *V ψ' x)
        unfolding η'_def
        by (metis (no_types, lifting) DiffE Diff_insert_absorb diff_0_right mk_disjoint_insert sum.not_neutral_contains_not_neutral)
      also have  = U *V ψ' x - (yX. Proj (J y) *V U *V ψ' x)
        apply (subst (2) asm_rl[of X = {x}  (X-{x})])
         apply (simp add: insert_absorb x  X)
        apply (subst sum.union_disjoint)
        by auto
      also have  = U *V ψ' x - (yX. Proj (J y)) *V U *V ψ' x
        by (simp add: cblinfun.sum_left)
      also have  = U *V ψ' x - Proj (yX. J y) *V U *V ψ' x
        apply (subst Proj_sum_spaces)
        using ortho_J by auto
      finally show ?thesis
        by -
    qed

    have norm (U *V ψ - Proj (sum J X) *V U *V ψ) = norm (xX. U *V ψ' x - Proj (sum J X) *V U *V ψ' x)
      by (simp add: ψ'sum sum_subtractf cblinfun.sum_right)
    also from η'2 have  = norm (xX. η' x)
      by simp
    also have  = norm η
      using η_def by blast
    also have (norm η)2 = (xX. (norm (η' x))2)
      unfolding η_def
      apply (rule pythagorean_theorem_sum)
      using η'ortho by auto
    also have   (xX. (ε * norm (ψ' x))2)
      apply (rule sum_mono)
      by (simp add: η'bound power_mono)
    also have   ε2 * (xX. (norm (ψ' x))2)
      by (simp add: sum_distrib_left power_mult_distrib)
    also have  = ε2 * (norm ψ)2
    proof -
      have aux: a  X  a'  X  a  a'  ψ = sum ψ' X  is_orthogonal (ψ' a) (ψ' a') for a a'
        by (meson ψ'I IS less_eq_ccsubspace.rep_eq ortho_S orthogonal_spaces_def subset_iff)
      show ?thesis
      apply (subst pythagorean_theorem_sum[symmetric])
        using ψ'sum aux by auto
    qed
    finally show norm (U *V ψ - Proj (sum J X) *V U *V ψ)  ε
      using ε0 norm ψ = 1 by (auto simp flip: power_mult_distrib)
  qed
  then show ?thesis
    apply (rule preserves_mono)
    using assms by auto
qed

text ‹An invariant that consists of all states that are superpositions of the given computational basis states.

Useful for representing a classically formulated condition (e.g., \<^term>‹x ≠ 0›) as an invariant (\<^term>‹ket_invariant {x. x≠0}›).›

definition ket_invariant M = ccspan (ket ` M)

lemma ket_invariant_UNIV[simp]: ket_invariant UNIV = 
  unfolding ket_invariant_def by simp

lemma ket_invariant_empty[simp]: ket_invariant {} = 
  unfolding ket_invariant_def by simp

lemma ket_invariant_Rep_ell2: ψ  space_as_set (ket_invariant I)  (i-I. Rep_ell2 ψ i = 0)
  by (simp add: ket_invariant_def space_ccspan_ket) 

lemma ket_invariant_compl: ket_invariant (-M) = - ket_invariant M
proof -
  have ket_invariant (-M)  - ket_invariant M for M :: 'a set
    unfolding ket_invariant_def
    apply (rule ccspan_leq_ortho_ccspan) 
    by auto
  moreover have - ket_invariant M  ket_invariant (-M)
  proof (rule ccsubspace_leI_unit)
    fix ψ
    assume ψ  space_as_set (- ket_invariant M)
    then have is_orthogonal ψ φ if φ  space_as_set (ket_invariant M) for φ
      using that
      by (auto simp: uminus_ccsubspace.rep_eq orthogonal_complement_def)
    then have is_orthogonal (ket m) ψ if m  M for m
      by (simp add: ccspan_superset' is_orthogonal_sym ket_invariant_def that)
    then have Rep_ell2 ψ m = 0 if m  M for m
      by (simp add: cinner_ket_left that)
    then show ψ  space_as_set (ket_invariant (- M))
      unfolding ket_invariant_Rep_ell2
      by simp
  qed
  ultimately show ?thesis
    by (rule order.antisym)
qed

lemma ket_invariant_tensor: ket_invariant I S ket_invariant J = ket_invariant (I × J)
proof -
  have ket_invariant I S ket_invariant J = ccspan {x s y |x y. x  ket ` I  y  ket ` J}
    by (simp add: tensor_ccsubspace_ccspan ket_invariant_def)
  also have  = ccspan {ket (x, y)| x y. x  I  y  J}
    by (auto intro!: arg_cong[where f=ccspan] simp flip: tensor_ell2_ket)
  also have  = ccspan (ket ` (I × J))
    by (auto intro!: arg_cong[where f=ccspan])
  also have  = ket_invariant (I × J)
    by (simp add: ket_invariant_def)
  finally show ?thesis
    by -
qed


abbreviation preserves_ket U I J ε  preserves U (ket_invariant I) (ket_invariant J) ε

lemma orthogonal_spaces_ket[simp]: orthogonal_spaces (ket_invariant M) (ket_invariant N)  M  N = {} for M N
  apply rule
  apply (simp add: ket_invariant_def orthogonal_spaces_def)
  apply (metis Int_emptyI ccspan_superset imageI inf_commute ket_invariant_def orthogonal_ket subset_iff)
  apply (simp add: orthogonal_spaces_leq_compl ket_invariant_def)
  by (smt (verit, best) ccspan_leq_ortho_ccspan disjoint_iff_not_equal imageE orthogonal_ket)

lemma ket_invariant_le[simp]: ket_invariant M  ket_invariant N  M  N for M N
proof -
  have x  N 
    if x  M and *: ψ. (y. y  M  Rep_ell2 ψ y = 0)  (y. y  N  Rep_ell2 ψ y = 0) for x
    using *[of ket x]
    using x  M by (auto simp: ket.rep_eq)
  then show ?thesis
    by (auto simp add: less_eq_ccsubspace.rep_eq subset_eq Ball_def ket_invariant_Rep_ell2) 
qed

lemma ket_invariant_mono:
  assumes I  J
  shows ket_invariant I  ket_invariant J
  using [[simp_trace]]
  by (simp add: assms)

lemma ket_invariant_Inf: ket_invariant (Inf M) = Inf (ket_invariant ` M)
proof (rule order.antisym)
  show ket_invariant ( M)  Inf (ket_invariant ` M)
    by (simp add: Inf_lower le_Inf_iff)
  show Inf (ket_invariant ` M)  ket_invariant ( M)
  proof (rule ccsubspace_leI_unit)
    fix ψ
    assume ψ  space_as_set (Inf (ket_invariant ` M))
    then have ψ  space_as_set (ket_invariant N) if N  M for N
      by (metis Inf_lower imageI in_mono less_eq_ccsubspace.rep_eq that)
    then have Rep_ell2 ψ n = 0 if n  N and N  M for n N
      using that by (auto simp: ket_invariant_Rep_ell2)
    then have Rep_ell2 ψ n = 0 if n  Inf M for n
      using that by blast
    then show ψ  space_as_set (ket_invariant ( M))
      by (meson ComplD ket_invariant_Rep_ell2)
  qed
qed
    

lemma ket_invariant_INF: ket_invariant (INF xM. f x) = (INF xM. ket_invariant (f x))
  by (simp add: image_image ket_invariant_Inf)


lemma ket_invariant_Sup: ket_invariant (Sup M) = Sup (ket_invariant ` M)
proof -
  have ket_invariant (Sup M) = ket_invariant (- (Inf (uminus ` M)))
    by (subst uminus_Inf, simp)
  also have  = - ket_invariant (Inf (uminus ` M))
    using ket_invariant_compl by blast
  also have  = - Inf (ket_invariant ` uminus ` M)
    using ket_invariant_Inf by auto
  also have  = - Inf (uminus ` ket_invariant ` M)
    by (metis (no_types, lifting) INF_cong image_image ket_invariant_compl)
  also have  = Sup (ket_invariant ` M)
    apply (subst uminus_Inf)
    by (metis (no_types, lifting) SUP_cong image_comp image_image o_apply ortho_involution)
  finally show ?thesis
    by -
qed

lemma ket_invariant_SUP: ket_invariant (SUP xM. f x) = (SUP xM. ket_invariant (f x))
  by (simp add: image_image ket_invariant_Sup)

lemma ket_invariant_inter: ket_invariant M  ket_invariant N = ket_invariant (M  N) for M N
  using ket_invariant_INF[where M=UNIV and f=λx. if x then M else N]
  by (smt (verit) INF_UNIV_bool_expand)

lemma ket_invariant_union: ket_invariant M  ket_invariant N = ket_invariant (M  N) for M N
  using ket_invariant_SUP[where M=UNIV and f=λx. if x then M else N]
  by (smt (verit) SUP_UNIV_bool_expand)

lemma sum_ket_invariant[simp]:
  assumes finite X
  shows (xX. ket_invariant (M x)) = ket_invariant (xX. M x)
  using assms apply induction
  apply auto using ket_invariant_union by blast

lemma ket_invariant_inj[simp]:
  ket_invariant M = ket_invariant N  M = N for M N
  by (metis dual_order.eq_iff ket_invariant_le)

text ‹Given an invariant on the content of a register, this gives the corresponding invariant 
  on the whole state. Useful for plugging together several invariants on different subsystems.›

definition lift_invariant F I = F (Proj I) *S 

lemma lift_invariant_comp: 
  assumes [simp]:  register G
  shows lift_invariant (F o G) = lift_invariant F o lift_invariant G
  by (auto intro!: ext simp: lift_invariant_def Proj_on_own_range register_projector)

lemma lift_invariant_top[simp]: register F  lift_invariant F  = 
  by (metis Proj_on_own_range' cblinfun_compose_id_right id_cblinfun_adjoint lift_invariant_def register_unitary unitary_id unitary_range)

lemma Proj_lift_invariant: register F  Proj (lift_invariant F I) = F (Proj I)
  using [[simproc del: Laws_Quantum.compatibility_warn]]
  unfolding lift_invariant_def
  by (simp add: Proj_on_own_range register_projector) 

lemma ket_invariant_image_assoc: 
  ket_invariant ((λ((a, b), c). (a, b, c)) ` X) = lift_invariant assoc (ket_invariant X)
proof -
  have ket_invariant ((λ((a, b), c). (a, b, c)) ` X) = assoc_ell2 *S ket_invariant X
    by (auto intro!: arg_cong[where f=ccspan] image_eqI simp add: ket_invariant_def image_image cblinfun_image_ccspan)
  also have  = lift_invariant assoc (ket_invariant X)
    by (simp add: lift_invariant_def assoc_ell2_sandwich Proj_sandwich)
  finally show ?thesis
    by -
qed

lemma lift_invariant_inj[simp]: lift_invariant F I = lift_invariant F J  I = J if [register]: register F
proof (rule iffI[rotated], simp)
  assume asm: lift_invariant F I = lift_invariant F J
  then have F (Proj I) *S  = F (Proj J) *S 
    by (simp add: lift_invariant_def)
  then have F (Proj I) = F (Proj J)
    by (metis Proj_lift_invariant asm that)
  then have Proj I = Proj J
    by (simp add: register_inj')
  then show I = J
    using Proj_inj by blast
qed

lemma lift_invariant_decomp:
  fixes U :: _ CL _::chilbert_space
  assumes θ. F θ = sandwich U *V (θ o id_cblinfun)
  assumes unitary U
  shows lift_invariant F I = U *S (I S )
  by (simp add: lift_invariant_def assms Proj_tensor_Proj Proj_sandwich flip: Proj_top)

text ‹Invariants are compatible if their projectors commute, i.e., if you can simultaneously measure them.
  This can happen if they refer to different parts of the system. (E.g., one talks about register X,
  the other about register Y.) But also for example for any ket-invariants.

  See lemma ‹preserves_intersect› below for a useful consequence.›

definition compatible_invariants A B  Proj A oCL Proj B = Proj B oCL Proj A

lemma compatible_invariants_inter: Proj A oCL Proj B = Proj (A  B) if compatible_invariants A B
proof -
  have is_Proj (Proj A oCL Proj B)
    apply (rule is_Proj_I)
    apply (metis Proj_idempotent cblinfun_assoc_left(1) compatible_invariants_def that)
    by (metis adj_Proj adj_cblinfun_compose compatible_invariants_def that)

  have (Proj A oCL Proj B) *S   A
    by (simp add: Proj_image_leq cblinfun_compose_image)
  moreover have (Proj A oCL Proj B) *S   B
    using that by (simp add: Proj_image_leq cblinfun_compose_image compatible_invariants_def)
  ultimately have leq1: (Proj A oCL Proj B) *S   A  B
    by auto

  have leq2: A  B  (Proj A oCL Proj B) *S 
  proof (rule ccsubspace_leI, rule subsetI)
    fix ψ assume ψ  space_as_set (A  B)
    then have Proj A *V ψ = ψ Proj B *V ψ = ψ
      by (simp_all add: Proj_fixes_image)
    then have ψ = (Proj A oCL Proj B) *V ψ
      by simp
    also have (Proj A oCL Proj B) *V ψ  space_as_set ((Proj A oCL Proj B) *S )
      using cblinfun_apply_in_image by blast
    finally show ψ  space_as_set ((Proj A oCL Proj B) *S )
      by -
  qed

  from leq1 leq2 have (Proj A oCL Proj B) *S  = A  B
    using order_class.order_eq_iff by blast

  with is_Proj (Proj A oCL Proj B) show Proj A oCL Proj B = Proj (A  B)
    using Proj_on_own_range by force
qed

lemma compatible_invariants_ket[iff]: compatible_invariants (ket_invariant I) (ket_invariant J)
proof -
  have I: Proj (ket_invariant I) = Proj (ket_invariant (I-J)) + Proj (ket_invariant (IJ))
    apply (subst Proj_sup[symmetric])
    by (auto simp add: Un_Diff_Int ket_invariant_union)
  have J: Proj (ket_invariant J) = Proj (ket_invariant (J-I)) + Proj (ket_invariant (IJ))
    apply (subst Proj_sup[symmetric])
    by (auto intro!: arg_cong[where f=Proj] simp add: Un_Diff_Int ket_invariant_union)
  have Proj (ket_invariant I) oCL Proj (ket_invariant J) = Proj (ket_invariant J) oCL Proj (ket_invariant I)
    apply (simp add: I J)
    by (smt (verit) Diff_disjoint I Int_Diff_disjoint Proj_bot adj_Proj adj_cblinfun_compose cblinfun_compose_add_left cblinfun_compose_add_right orthogonal_projectors_orthogonal_spaces orthogonal_spaces_ket)
  then show ?thesis
    by (simp add: compatible_invariants_def)
qed

lemma preserves_intersect:
  assumes compatible_invariants J1 J2
  assumes pres1: preserves U I J1 ε1
  assumes pres2: preserves U I J2 ε2
  shows preserves U I (J1  J2) (ε1 + ε2)
(* TODO: can be improved to "sqrt (e1^2 + e2^2)" *)
proof (rule preservesI)
  show 0  ε1 + ε2
    by (meson add_nonneg_nonneg pres1 pres2 preserves_def)

  fix ψ assume ψ  space_as_set I and norm ψ = 1
  define φ J where φ = U *V ψ and J = J1  J2

  note norm_diff_triangle_le[trans]

  from pres1
  have norm (φ - Proj J1 *V φ)  ε1
    by (metis ψ  space_as_set I norm ψ = 1 φ_def mult_cancel_left1 preserves_def)
  also 
  have norm (φ - Proj J2 *V φ)  ε2
    using ψ  space_as_set I norm ψ = 1 φ_def pres2 preserves_def by force
  then have norm (Proj J1 *V (φ - Proj J2 *V φ))  ε2
    using Proj_is_Proj is_Proj_reduces_norm order_trans by blast
  then have norm (Proj J1 *V φ - Proj J1 *V Proj J2 *V φ)  ε2
    by (simp add: cblinfun.diff_right)
  also have Proj J1 *V Proj J2 *V φ = Proj J *V φ
    by (metis J_def assms(1) cblinfun_apply_cblinfun_compose compatible_invariants_inter)
  finally show norm (φ - Proj J *V φ)  ε1 + ε2
    by -
qed

lemma preserves_intersect_ket:
  assumes preserves_ket U I J1 ε1
  assumes preserves_ket U I J2 ε2
  shows preserves_ket U I (J1  J2) (ε1 + ε2)
  apply (simp flip: ket_invariant_inter)
  using _ assms apply (rule preserves_intersect)
  by (rule compatible_invariants_ket)

text ‹An invariant is compatible with a register intuitively if the invariant only talks about
  parts of the quantum state outside the register.›

definition compatible_register_invariant F I  (A. Proj I oCL F A = F A oCL Proj I)
  for F :: 'a update  'b update

lemma compatible_register_invariant_top[simp]:
  compatible_register_invariant F 
  by (simp add: compatible_register_invariant_def)

lemma compatible_register_invariant_bot[simp]:
  compatible_register_invariant F 
  by (simp add: compatible_register_invariant_def)


lemma compatible_register_invariant_id:
  assumes y. I = UNIV  I = {}
  shows compatible_register_invariant id (ket_invariant I)
  using assms
  by (metis compatible_register_invariant_bot compatible_register_invariant_top ket_invariant_UNIV ket_invariant_empty)

lemma compatible_register_invariant_compatible_register:
  assumes compatible F G
  shows compatible_register_invariant F (lift_invariant G I)
  unfolding compatible_register_invariant_def lift_invariant_def
  by (metis Proj_is_Proj Proj_on_own_range assms compatible_def register_projector)

lemma compatible_register_invariant_chain[simp]: 
  compatible_register_invariant (F o G) (lift_invariant F I)   compatible_register_invariant G I if [simp]: register F
  by (simp add: compatible_register_invariant_def Proj_lift_invariant register_mult register_inj[THEN inj_eq])

text ‹Allows one to decompose the preservation of an invariant into a part that is preserved
  inside a register, and a part outside of it.›

lemma preserves_register:
  fixes F :: 'a update  'b update
  assumes pres: preserves U' I' J' ε
  assumes reg[register]: register F
  assumes compat: compatible_register_invariant F K
  assumes FU': ψspace_as_set I. F U' *V ψ = U *V ψ
  assumes FI'_I: lift_invariant F I'  I
  assumes KI: K  I
  assumes FJ'K_I: lift_invariant F J'  K  J
  shows preserves U I J ε
proof -
  define PI' PJ' where PI' = Proj I' and PJ' = Proj J'
  have 1: preserves (F U') (lift_invariant F I') (lift_invariant F J') ε
  proof (unfold preserves_onorm)
    have norm ((id_cblinfun - Proj (lift_invariant F J')) oCL F U' oCL Proj (lift_invariant F I'))
        = norm ((id_cblinfun - PJ') oCL U' oCL PI') (is ?lhs = _)
      by (smt (verit, best) PI'_def PJ'_def Proj_lift_invariant reg register_minus register_mult register_norm register_of_id)
    also from pres have   ε
      by (simp add: preserves_onorm PJ'_def PI'_def)
    finally show ?lhs  ε
      by -
  qed

  from compat
  have 2: preserves (F U') K K 0
    by (simp add: preserves_onorm cblinfun_compose_assoc cblinfun_compose_minus_left compatible_register_invariant_def)

  with 1 compat 
  have preserves (F U') (lift_invariant F I'  K) (lift_invariant F J'  K) ε
    apply (subst asm_rl[of ε = ε + 0], simp)
    apply (rule preserves_intersect)
    by (auto simp add: compatible_invariants_def compatible_register_invariant_def preserves_mono Proj_lift_invariant)

  then have preserves (F U') I J ε
    apply (rule preserves_mono)
    using FI'_I FJ'K_I KI by auto
  then show ?thesis
    apply (rule preserves_cong[THEN iffD1, rotated])
    using FU' by auto
qed

lemma preserves_top[simp]: ε  0  preserves U I  ε
  unfolding preserves_onorm by simp

lemma preserves_bot[simp]: ε  0  preserves U  J ε
  unfolding preserves_onorm by simp

lemma preserves_0[simp]: ε  0  preserves 0 I J ε
  unfolding preserves_onorm by simp


text ‹Tensor product of two invariants: The invariant that requires the first part of the system
to satisfy invariant \<^term>‹I› and the second to satisfy \<^term>‹J›.›

definition tensor_invariant I J = ccspan {x s y | x y. x  space_as_set I  y  space_as_set J}

lemma tensor_invariant_via_Proj: tensor_invariant I J = (Proj I o Proj J) *S 
proof (rule Proj_inj, rule tensor_ell2_extensionality, rename_tac ψ φ)
  fix ψ φ
  define ψ1 ψ2 where ψ1 = Proj I ψ and ψ2 = Proj (-I) ψ
  have ψ = ψ1 + ψ2
    by (simp add: ψ1_def ψ2_def Proj_ortho_compl minus_cblinfun.rep_eq)
  have ψ1I: ψ1  space_as_set I
    by (metis Proj_idempotent ψ1_def cblinfun_apply_cblinfun_compose norm_Proj_apply) 

  define φ1 φ2 where φ1 = Proj J φ and φ2 = Proj (-J) φ
  have φ = φ1 + φ2
    by (simp add: φ1_def φ2_def Proj_ortho_compl minus_cblinfun.rep_eq)
  have φ1J: φ1  space_as_set J
    by (metis Proj_idempotent φ1_def cblinfun_apply_cblinfun_compose norm_Proj_apply) 

  have aux: xa  space_as_set I  y  space_as_set J  φ C y  0  is_orthogonal ψ2 xa for xa y
    by (metis Proj_fixes_image ψ = ψ1 + ψ2 ψ1I ψ1_def add_left_imp_eq cblinfun.real.add_right kernel_Proj kernel_memberI orthogonal_complement_orthoI pth_d uminus_ccsubspace.rep_eq)
  have ψ2 s φ  space_as_set (- tensor_invariant I J)
    by (auto intro!: aux orthogonal_complementI simp add: uminus_ccsubspace.rep_eq tensor_invariant_def ccspan.rep_eq
        simp flip: orthogonal_complement_of_closure orthogonal_complement_of_cspan)
  then have ψ2φ: Proj (tensor_invariant I J) *V (ψ2 s φ) = 0
    by (simp add: kernel_memberD)

  have aux: xa  space_as_set I  y  space_as_set J  φ2 C y  0  is_orthogonal ψ1 xa for xa y
    by (metis Proj_fixes_image φ = φ1 + φ2 φ1J φ1_def add_left_imp_eq cblinfun.real.add_right kernel_Proj kernel_memberI orthogonal_complement_orthoI pth_d uminus_ccsubspace.rep_eq)
  have ψ1 s φ2  space_as_set (- tensor_invariant I J)
    by (auto intro!: aux orthogonal_complementI simp add: uminus_ccsubspace.rep_eq tensor_invariant_def ccspan.rep_eq
        simp flip: orthogonal_complement_of_closure orthogonal_complement_of_cspan)
  then have ψ1φ2: Proj (tensor_invariant I J) *V (ψ1 s φ2) = 0
    by (simp add: kernel_memberD)

  have ψ1φ1: Proj (tensor_invariant I J) *V (ψ1 s φ1) = ψ1 s φ1
    by (auto intro!: Proj_fixes_image space_as_set_ccspan_memberI exI[of _ ψ1] exI[of _ φ1]
        simp: tensor_invariant_def ψ1I φ1J)

  have ProjProj: Proj ((Proj I o Proj J) *S ) = Proj I o Proj J
    by (simp add: Proj_on_own_range' adj_Proj comp_tensor_op tensor_op_adjoint)

  show Proj (tensor_invariant I J) *V (ψ s φ) = Proj ((Proj I o Proj J) *S ) *V (ψ s φ)
    apply (simp add: ProjProj tensor_op_ell2 flip: ψ1_def φ1_def)
    apply (simp add: ψ = ψ1 + ψ2 tensor_ell2_add1 cblinfun.add_right ψ2φ)
    by (simp add: ψ1φ1 ψ1φ2 φ = φ1 + φ2 tensor_ell2_add2 cblinfun.add_right)
qed

(* Monotonicity of tensor_invariant in its first argument.
   Proved by unfolding tensor_invariant_def and using monotonicity of ccspan. *)
lemma tensor_invariant_mono_left: I  I'  tensor_invariant I J  tensor_invariant I' J
  by (auto intro!: space_as_set_mono ccspan_mono simp add: tensor_invariant_def less_eq_ccsubspace.rep_eq)

(* Simp rule: the image of tensor_invariant I J under swap_ell2 is tensor_invariant J I.
   Proved by moving the image inside the ccspan (cblinfun_image_ccspan). *)
lemma swap_tensor_invariant[simp]: swap_ell2 *S tensor_invariant I J = tensor_invariant J I
  by (force intro!: arg_cong[where f=ccspan] simp: cblinfun_image_ccspan tensor_invariant_def)

(* tensor_invariant commutes with indexed suprema in its first argument.
   Proof by antisymmetry (order.antisym):
   - one direction follows from SUP_least, tensor_invariant_mono_left and SUP_upper;
   - the other direction rewrites the left-hand side, step by step, into a supremum over
     y in space_as_set J of images of (SUP I) under the operator "x |-> tensor of x and y",
     pushes that image through the SUP (cblinfun_image_SUP) and bounds each piece. *)
lemma tensor_invariant_SUP_left: tensor_invariant (SUP xX. I x) J = (SUP xX. tensor_invariant (I x) J)
proof (rule order.antisym)
  show (SUP xX. tensor_invariant (I x) J)  tensor_invariant (SUP xX. I x) J
    by (auto intro!: SUP_least tensor_invariant_mono_left SUP_upper)

  (* The tensoring-with-y map, packaged as a bounded operator via CBlinfun, applies as expected. *)
  have tensor_left_apply: CBlinfun (λx. x s y) *V x = x s y for x :: 'a ell2 and y :: 'b ell2
    by (simp add: bounded_clinear_tensor_ell22 bounded_clinear_CBlinfun_apply clinear_tensor_ell22)

  show tensor_invariant (SUP xX. I x) J  (SUP xX. tensor_invariant (I x) J)
  proof -
    (* Unfold the definition: a closed span of elementary tensors. *)
    have tensor_invariant (SUP xX. I x) J = ccspan {x s y |x y. x  space_as_set (SUP xX. I x)  y  space_as_set J}
      by (auto simp: tensor_invariant_def)
    (* Group the generating set by the second tensor factor y. *)
    also have  = ccspan (yspace_as_set J. {x s y |x. x  space_as_set (SUP xX. I x)})
      by (auto intro!: arg_cong[where f=ccspan])
    also have  = (yspace_as_set J. ccspan {x s y |x. x  space_as_set (SUP xX. I x)})
      by (smt (verit) Sup.SUP_cong ccspan_Sup image_image)
    (* Express each generating set as the image of a bounded operator. *)
    also have  = (yspace_as_set J. ccspan (cblinfun_apply (CBlinfun (λx. x s y)) ` {x. x  space_as_set (SUP xX. I x)}))
      apply (rule SUP_cong, simp)
      apply (rule arg_cong[where f=ccspan])
      by (auto simp add: image_def tensor_left_apply)
    also have  = (yspace_as_set J. CBlinfun (λx. x s y) *S (SUP xX. I x))
      apply (subst cblinfun_image_ccspan[symmetric])
      by auto
    (* Operator images commute with SUP. *)
    also have  = (yspace_as_set J. (SUP xX. CBlinfun (λx. x s y) *S I x))
      apply (subst cblinfun_image_SUP)
      by simp
    also have   (xX. tensor_invariant (I x) J)
    proof (rule SUP_least)
      fix y 
      assume y  space_as_set J
      (* For fixed y in J, the image of I x consists of tensors that generate tensor_invariant (I x) J. *)
      have (CBlinfun (λx. x s y) *S I x)  (tensor_invariant (I x) J) for x
        apply (rule ccsubspace_leI)
        apply (simp add: tensor_invariant_def cblinfun_image.rep_eq ccspan.rep_eq image_def
            tensor_left_apply)
        apply (rule closure_mono)
        by (auto intro!: complex_vector.span_base y  space_as_set J)
      then show (SUP xX. CBlinfun (λx. x s y) *S I x)  (SUP xX. tensor_invariant (I x) J)
        by (auto intro!: SUP_mono)
    qed
    finally show tensor_invariant ( (I ` X)) J  (xX. tensor_invariant (I x) J)
      by -
  qed
qed

(* Right-argument version of tensor_invariant_SUP_left.
   Derived by swapping the tensor factors with swap_ell2 (swap_tensor_invariant),
   applying the left version, and moving swap_ell2 inside the SUP (cblinfun_image_SUP). *)
lemma tensor_invariant_SUP_right: tensor_invariant I (SUP xX. J x) = (SUP xX. tensor_invariant I (J x))
proof -
  have tensor_invariant I (SUP xX. J x) = swap_ell2 *S tensor_invariant (SUP xX. J x) I
    by simp
  also have  = swap_ell2 *S (SUP xX. tensor_invariant (J x) I)
    by (simp add: tensor_invariant_SUP_left)
  also have  = (SUP xX. swap_ell2 *S tensor_invariant (J x) I)
    using cblinfun_image_SUP by blast
  also have  = (SUP xX. tensor_invariant I (J x))
    by simp
  finally show ?thesis
    by -
qed

(* Simp rule obtained as the empty-index instance (X = {}) of tensor_invariant_SUP_left.
   NOTE(review): the subspace symbols in the statement are not visible in this view;
   by the lemma name this should be the bottom subspace on both sides - confirm. *)
lemma tensor_invariant_bot_left[simp]: tensor_invariant  J = 
  using tensor_invariant_SUP_left[where I=id and X={} and J=J]
  by simp

(* Simp rule obtained as the empty-index instance (X = {}) of tensor_invariant_SUP_right.
   NOTE(review): the subspace symbols in the statement are not visible in this view;
   by the lemma name this should be the bottom subspace on both sides - confirm. *)
lemma tensor_invariant_bot_right[simp]: tensor_invariant I  = 
  using tensor_invariant_SUP_right[where J=id and X={} and I=I]
  by simp

(* Set-supremum form of tensor_invariant_SUP_left (instance with I = id and X = II). *)
lemma tensor_invariant_Sup_left: tensor_invariant (Sup II) J = (SUP III. tensor_invariant I J)
  using tensor_invariant_SUP_left[where X=II and I=id and J=J]
  by simp

(* Set-supremum form of tensor_invariant_SUP_right (instance with J = id and X = JJ). *)
lemma tensor_invariant_Sup_right: tensor_invariant I (Sup JJ) = (SUP JJJ. tensor_invariant I J)
  using tensor_invariant_SUP_right[where X=JJ and I=I and J=id]
  by simp

(* Binary version of tensor_invariant_Sup_left: the two-element instance II = {I1, I2}. *)
lemma tensor_invariant_sup_left: tensor_invariant (I1  I2) J = tensor_invariant I1 J  tensor_invariant I2 J
  using tensor_invariant_Sup_left[where II={I1,I2}]
  by auto

(* Binary version of tensor_invariant_Sup_right: the two-element instance JJ = {J1, J2}. *)
lemma tensor_invariant_sup_right: tensor_invariant I (J1  J2) = tensor_invariant I J1  tensor_invariant I J2
  using tensor_invariant_Sup_right[where JJ={J1,J2}]
  by auto

(* Relates compatibility of F with I to compatibility of F with the complement -I.
   Proved by rewriting Proj (-I) via Proj_ortho_compl and distributing composition over
   the subtraction.
   NOTE(review): the connective between the two facts is not visible in this view
   (presumably an implication) - confirm. *)
lemma compatible_register_invariant_compl: compatible_register_invariant F I  compatible_register_invariant F (-I)
  by (simp add: compatible_register_invariant_def Proj_ortho_compl cblinfun_compose_minus_left cblinfun_compose_minus_right)

(* If every I x (for x in X) is compatible with the register F, so is their supremum.
   Proof outline:
   1. register_decomposition gives a unitary U with F θ = sandwich U applied to θ tensored
      with the identity (fact FU).
   2. Compatibility of I x means the projector onto the pulled-back space (adjoint of U
      applied to I x) commutes with every "A tensor id", so by commutant_tensor1 it has
      the form "id tensor π x", with π x a projector (π_proj).
   3. Hence each pulled-back space is a tensor_invariant with second component σ x (fact "**"),
      and tensor_invariant_SUP_right lets the SUP be taken inside.
   4. The projector onto the SUP therefore also has the form "id tensor projector", hence
      commutes with every "A tensor id"; conjugating back with U gives the claim.
   The auxiliary type 'd is introduced and removed with with_type_mp / cancel_with_type. *)
lemma compatible_register_invariant_SUP:
  assumes [simp]: register F
  assumes compat: x. x  X  compatible_register_invariant F (I x)
  shows compatible_register_invariant F (SUP xX. I x)
proof -
  from register_decomposition[OF register F]
  have let 'd::type = register_decomposition_basis F in ?thesis
  proof with_type_mp
    case with_type_mp
    then obtain U :: ('a × 'd) ell2 CL 'b ell2 
      where [iff]: unitary U and FU: F θ = sandwich U *V (θ o id_cblinfun) for θ
      by auto
    (* Compatibility of I x, rewritten through the decomposition of F. *)
    have *: Proj (I x) oCL U oCL (A o id_cblinfun) oCL U* = U oCL (A o id_cblinfun) oCL U* oCL Proj (I x) if x  X for x A
      using compat[OF that]
      by (simp add: compatible_register_invariant_def FU sandwich_apply cblinfun_compose_assoc)
    (* Conjugate with U to move the commutation statement to the tensor-product side. *)
    have (U* oCL Proj (I x) oCL U) oCL (A o id_cblinfun) = (A o id_cblinfun) oCL (U* oCL Proj (I x) oCL U) if x  X for x A
      using *[where A=A, OF that, THEN arg_cong, where f=λx. U* oCL x, THEN arg_cong, where f=λx. x oCL U]
      apply (simp add: cblinfun_compose_assoc)
      by (simp flip: cblinfun_compose_assoc)
    then have Proj (U* *S I x) oCL (A o id_cblinfun) = (A o id_cblinfun) oCL Proj (U* *S I x) if x  X for x A
      using that
      by (simp flip: Proj_sandwich add: sandwich_apply)
    then have Proj (U* *S I x)  commutant (range (λA. A o id_cblinfun)) if x  X for x
      unfolding commutant_def using that by auto
    then have Proj (U* *S I x)  range (λB. id_cblinfun o B) if x  X for x
      by (simp add: commutant_tensor1 that)
    (* Choose, for each x, the second tensor factor π x of that projector. *)
    then obtain π where *: Proj (U* *S I x) = id_cblinfun o π x if x  X for x
      apply atomize_elim
      apply (rule choice)
      by (simp add: image_iff)
    (* π x inherits self-adjointness (1) and idempotence (2) from the projector. *)
    have π_proj: is_Proj (π x) if x  X for x
    proof -
      have Proj (U* *S I x)* = Proj (U* *S I x)
        by (simp add: adj_Proj)
      then have (id_cblinfun :: 'a ell2 CL _) o π x = id_cblinfun o π x*
        by (simp add: *[OF that] tensor_op_adjoint)
      then have 1: π x = π x*
        using inj_tensor_right[OF id_cblinfun_not_0] injD by fastforce
      have Proj (U* *S I x) oCL Proj (U* *S I x) = Proj (U* *S I x)
        by simp
      then have (id_cblinfun :: 'a ell2 CL _) o (π x oCL π x) = id_cblinfun o π x
        by (simp add: *[OF that] comp_tensor_op)
      then have 2: π x oCL π x = π x
        using inj_tensor_right[OF id_cblinfun_not_0] injD by fastforce
      from 1 2 show is_Proj (π x)
        by (simp add: is_Proj_I)
    qed
    (* σ x is the range of π x; the pulled-back invariant is a tensor_invariant built from it. *)
    define σ where σ x = π x *S  for x
    have **: U* *S I x = tensor_invariant  (σ x) if x  X for x
      using *[OF that, THEN arg_cong, where f=λt. t *S ]
      by (simp add: tensor_invariant_via_Proj σ_def Proj_on_own_range π_proj that)
    (* Same analysis for the supremum, using that SUP commutes with images and tensor_invariant. *)
    have sandwich (U*) (Proj (SUP xX. I x)) = Proj (U* *S (SUP xX. I x))
      by (smt (verit) sandwich_apply Proj_lift_invariant Proj_range unitary U cblinfun_compose_image unitary_adj unitary_range unitary_sandwich_register)
    also have  = Proj (SUP xX. U* *S I x)
      by (simp add: cblinfun_image_SUP)
    also have  = Proj (SUP xX. tensor_invariant  (σ x))
      using "**" by auto
    also have  = Proj (tensor_invariant  (SUP xX. σ x))
      by (simp add: tensor_invariant_SUP_right)
    also have  = id_cblinfun o Proj (SUP xX. σ x)
      by (simp add: Proj_on_own_range' adj_Proj comp_tensor_op tensor_invariant_via_Proj tensor_op_adjoint)
    also have   commutant (range (λA. A o id_cblinfun))
      by (simp add: commutant_tensor1)
    finally have (U* oCL Proj (SUP xX. I x) oCL U) oCL (A o id_cblinfun) = (A o id_cblinfun) oCL (U* oCL Proj (SUP xX. I x) oCL U) for A
      by (simp add: sandwich_apply commutant_def)
    (* Conjugate back with U to obtain commutation with F A. *)
    from this[THEN arg_cong, where f=λx. U oCL x, THEN arg_cong, where f=λx. x oCL U*]
    have Proj (SUP xX. I x) oCL U oCL (A o id_cblinfun) oCL U* = U oCL (A o id_cblinfun) oCL U* oCL Proj (SUP xX. I x) for A
      apply (simp add: cblinfun_compose_assoc)
      by (simp flip: cblinfun_compose_assoc)
    then have Proj (SUP xX. I x) oCL F A = F A oCL Proj (SUP xX. I x) for A
      by (simp add: FU sandwich_apply cblinfun_compose_assoc)
    then show compatible_register_invariant F (SUP xX. I x)
      by (simp add: compatible_register_invariant_def)
  qed
  from this[cancel_with_type]
  show ?thesis
    by -
qed

(* If every I x (for x in X) is compatible with the register F, so is their infimum.
   Reduced to the SUP case by complementation: the INF is rewritten as the complement of
   the SUP of the complements (uminus_INF, ortho_involution). *)
lemma compatible_register_invariant_INF:
  assumes [simp]: register F
  assumes compat: x. x  X  compatible_register_invariant F (I x)
  shows compatible_register_invariant F (INF xX. I x)
proof -
  from compat have compatible_register_invariant F (- I x) if x  X for x
    by (simp add: compatible_register_invariant_compl that)
  then have compatible_register_invariant F (SUP xX. - I x)
    by (simp add: compatible_register_invariant_SUP)
  then have compatible_register_invariant F (- (SUP xX. - I x))
    by (simp add: compatible_register_invariant_compl)
  then show compatible_register_invariant F (INF xX. I x)
    by (metis Extra_General.uminus_INF ortho_involution)
qed

(* Set-supremum form of compatible_register_invariant_SUP (instance with I = id, X = II). *)
lemma compatible_register_invariant_Sup:
  assumes register F
  assumes I. III  compatible_register_invariant F I
  shows compatible_register_invariant F (Sup II)
  using compatible_register_invariant_SUP[where X=II and I=id and F=F] assms by simp

(* Set-infimum form of compatible_register_invariant_INF (instance with I = id, X = II). *)
lemma compatible_register_invariant_Inf:
  assumes register F
  assumes I. III  compatible_register_invariant F I
  shows compatible_register_invariant F (Inf II)
  using compatible_register_invariant_INF[where X=II and I=id and F=F] assms by simp

(* Binary version of compatible_register_invariant_Inf, using the two-element set {I, J}.
   NOTE(review): the operator combining I and J in the conclusion is not visible in this
   view; by the lemma name and the use of Inf it should be the infimum - confirm. *)
lemma compatible_register_invariant_inter:
  assumes register F
  assumes compatible_register_invariant F I
  assumes compatible_register_invariant F J
  shows compatible_register_invariant F (I  J)
  using compatible_register_invariant_Inf[where II={I,J}]
  using assms by auto

(* If I is compatible with both F and G, it is compatible with the register pair (F;G).
   Case split on "compatible F G":
   - True: commutation is checked on product operators via register_pair_apply and
     extended to all operators with tensor_extensionality;
   - False: the claim follows by unfolding register_pair_def and compatible_def (the
     compatibility warning simproc is switched off for that step). *)
lemma compatible_register_invariant_pair:
  assumes compatible_register_invariant F I
  assumes compatible_register_invariant G I
  shows compatible_register_invariant (F;G) I
proof (cases compatible F G)
  case True
  note this[simp]

  (* Commutation with Proj I on product operators. *)
  have *: Proj I oCL (F;G) (a o b) = (F;G) (a o b) oCL Proj I for a b
    using assms    
    apply (simp add: register_pair_apply compatible_register_invariant_def)
    by (metis cblinfun_compose_assoc)
  (* Extend from product operators to arbitrary A. *)
  have Proj I oCL (F;G) A = (F;G) A oCL Proj I for A
    apply (rule tensor_extensionality[THEN fun_cong[where x=A]])
    by (auto intro!: comp_preregister[unfolded comp_def, OF _ preregister_mult_left]
        comp_preregister[unfolded comp_def, OF _ preregister_mult_right] * )
  then show ?thesis
    using assms by (auto simp: compatible_register_invariant_def)
next
  case False
  then show ?thesis 
    using [[simproc del: Laws_Quantum.compatibility_warn]]
    by (auto simp: compatible_register_invariant_def register_pair_def compatible_def)
qed

(* If F is compatible with I and G with J, the tensor of the registers F and G is
   compatible with the tensor of the subspaces I and J.
   The commutation is checked on product operators (using tensor_ccsubspace_via_Proj and
   the component facts IF, JG) and extended with tensor_extensionality; the two [iff]
   preregister facts are the side conditions tensor_extensionality needs. *)
lemma compatible_register_invariant_tensor: 
  assumes [register]: register F register G
  assumes compatible_register_invariant F I
  assumes compatible_register_invariant G J
  shows compatible_register_invariant (F r G) (I S J)
proof -
  have [iff]: preregister (λab. Proj (I S J) oCL (F r G) ab)
    by (auto intro!: comp_preregister[unfolded comp_def, OF _ preregister_mult_left])
  have [iff]: preregister (λab. (F r G) ab oCL Proj (I S J))
    by (auto intro!: comp_preregister[unfolded comp_def, OF _ preregister_mult_right])
  have IF: Proj I oCL F a = F a oCL Proj I for a
    using assms(3) compatible_register_invariant_def by blast
  have JG: Proj J oCL G b = G b oCL Proj J for b
    using assms(4) compatible_register_invariant_def by blast
  have Proj (I S J) oCL (F r G) (a o b) = (F r G) (a o b) oCL Proj (I S J) for a b
    by (simp add: tensor_ccsubspace_via_Proj Proj_on_own_range is_Proj_tensor_op comp_tensor_op IF JG)
  then have (λab. Proj (I S J) oCL (F r G) ab) = (λab. (F r G) ab oCL Proj (I S J))
    apply (rule_tac tensor_extensionality)
    by auto
  then show ?thesis
    unfolding compatible_register_invariant_def
    by meson
qed


(* If F is compatible with I, then the image of I under F U stays inside I.
   Idea: write the image of I as the range of F U composed with Proj I, commute the
   projector to the front (compatibility), and bound by the range of Proj I, which is I. *)
lemma compatible_register_invariant_image_shrinks:
  assumes compatible_register_invariant F I
  shows F U *S I  I
proof -
  have F U *S I = (F U oCL Proj I) *S 
    by (simp add: cblinfun_compose_image)
  also have  = (Proj I oCL F U) *S 
    by (metis assms compatible_register_invariant_def)
  also have   Proj I *S 
    by (simp add: Proj_image_leq cblinfun_compose_image)
  also have  = I
    by simp
  finally show ?thesis
    by -
qed

(* For a finite index set X, the sum of the closed subspaces I x equals their supremum.
   Proved by induction over the finite set. *)
lemma sum_eq_SUP_ccsubspace:
  fixes I :: 'a  'b::complex_normed_vector ccsubspace
  assumes finite X
  shows (xX. I x) = (SUP xX. I x)
  using assms apply induction
  by simp_all

text ‹Variant of @{thm [source] invariant_splitting} (see there) that allows the 
  operation that is applied to depend on the state of some other register.›

(* Splitting rule for invariant preservation where the applied operation may depend on an index z.
   Assumptions, as far as visible in the statement:
   - U1_U:    on K z (z in M), U acts like the lifted operation Y z (U1 z);
   - pres_I1: each U1 z preserves I1 z / J1 z with error ε;
   - I_leq:   I is bounded by the SUP over z in M of the pieces built from K z and
              lift_invariant (Y z) (I1 z);
   - J_geq:   J bounds each piece built from K z and lift_invariant (Y z) (J1 z);
   - YK, regY: each Y z is a register and compatible with K z;
   - orthoK:  the K z are pairwise orthogonal; M is finite and ε is nonnegative.
   The proof instantiates invariant_splitting with S = S' = K and discharges the per-z
   preservation goal with preserves_register. *)
lemma inv_split_reg:
  fixes X :: 'x update  'm update ― ‹register containing the index for the unitary›
    and Y :: 'z  'y update  'm update ― ‹register on which the unitary operates›
    and K :: 'z  'm ell2 ccsubspace ― ‹additional invariants›
    and M :: 'z set
(* TODO rename U1 I1 J1 etc to U' etc. for easier instantiation *)
  assumes U1_U: z ψ. zM  ψ  space_as_set (K z)  (Y z (U1 z)) *V ψ = U *V ψ
  assumes pres_I1: z. zM  preserves (U1 z) (I1 z) (J1 z) ε
  assumes I_leq: I  (SUP zM. K z  lift_invariant (Y z) (I1 z))
  assumes J_geq: z. zM  J  K z  lift_invariant (Y z) (J1 z)
  assumes YK: z. zM  compatible_register_invariant (Y z) (K z)
  assumes regY: z. zM  register (Y z)
  assumes orthoK: z z'. zM  z'M  z  z'  orthogonal_spaces (K z) (K z')
  assumes ε  0
  assumes [iff]: finite M
  shows preserves U I J ε
proof -
  show ?thesis
  proof (rule invariant_splitting[where S=K and S'=K and I=λz. K z  lift_invariant (Y z) (I1 z)
          and J=λz. K z  lift_invariant (Y z) (J1 z) and X=M])
    (* Orthogonality is needed twice because S and S' are both instantiated with K. *)
    from orthoK
    show orthogonal_spaces (K z) (K z') if zM z'M z  z' for z z'
      using that by simp
    then show orthogonal_spaces (K z) (K z') if zM z'M z  z' for z z'
      using that by -
    show K z  lift_invariant (Y z) (I1 z)  K z for z
      by auto
    show K z  lift_invariant (Y z) (J1 z)  K z for z
      by auto
    (* U maps K z into itself: on K z it agrees with Y z (U1 z), which is compatible with K z. *)
    show U *S K z  K z if zM for z
    proof -
      from U1_U[OF that]
      have U *S K z = (Y z) (U1 z) *S K z
        apply (rule_tac space_as_set_inject[THEN iffD1])
        by (simp add: cblinfun_image.rep_eq)
      also from YK[OF that] have   K z
        by (simp add: compatible_register_invariant_image_shrinks)
      finally show ?thesis
        by -
    qed
    (* invariant_splitting is stated with sums; M is finite, so sums and SUPs coincide. *)
    from I_leq
    show I  (zM. K z  lift_invariant (Y z) (I1 z))
      apply (subst sum_eq_SUP_ccsubspace)
      by auto
    from J_geq
    show (zM. K z  lift_invariant (Y z) (J1 z))  J
      apply (subst sum_eq_SUP_ccsubspace)
      by (auto simp: SUP_le_iff)
    from assms show 0  ε
      by -
    (* Per-index preservation, obtained by lifting pres_I1 through the register Y z. *)
    show preserves U (K z  lift_invariant (Y z) (I1 z))
          (K z  lift_invariant (Y z) (J1 z)) ε if zM for z
    proof -
      show ?thesis
      proof (rule preserves_register[where U'=U1 z and I'=I1 z and J'=J1 z and F=Y z and K=K z])
        show preserves (U1 z) (I1 z) (J1 z) ε
          by (simp add: pres_I1[OF that])
        show register (Y z)
          using regY[OF that] by -
        from YK[OF that] show compatible_register_invariant (Y z) (K z)
          by -
        from U1_U[OF that]
        show ψspace_as_set (K z  lift_invariant (Y z) (I1 z)). (Y z) (U1 z) *V ψ = U *V ψ
          by auto
        show K z  lift_invariant (Y z) (I1 z)  lift_invariant (Y z) (I1 z)
          by auto
        show K z  lift_invariant (Y z) (I1 z)  K z
          by simp
        (* The leftover debugging option simp_trace was removed here; it only produced trace output. *)
        show lift_invariant (Y z) (J1 z)  K z  K z  lift_invariant (Y z) (J1 z)
          by simp
      qed
    qed
    show finite M
      by simp
  qed
qed


(* Action of the projector onto ket_invariant X on a basis vector ket i:
   it returns ket i or 0 depending on a membership condition on i (see the if-expression).
   Case True: ket i lies in ket_invariant X and is fixed by the projector.
   Case False: ket i lies in ket_invariant (-X); the projector onto X is rewritten as
   identity minus the projector onto the complement (Proj_ortho_compl, ket_invariant_compl). *)
lemma Proj_ket_invariant_ket: Proj (ket_invariant X) *V ket i = (if iX then ket i else 0)
proof (cases iX)
  case True
  then have ket i  space_as_set (ket_invariant X)
    by (simp add: ccspan_superset' ket_invariant_def)
  then have Proj (ket_invariant X) *V ket i = ket i
    by (rule Proj_fixes_image)
  also have ket i = (if iX then ket i else 0)
    using True by simp
  finally show ?thesis
    by -
next
  case False
  then have *: ket i  space_as_set (ket_invariant (-X))
    by (simp add: ccspan_superset' ket_invariant_def)
  have Proj (ket_invariant X) *V ket i = (id_cblinfun - Proj (ket_invariant (-X))) *V ket i
    by (simp add: Proj_ortho_compl ket_invariant_compl)
  also have  = ket i - Proj (ket_invariant (-X)) *V ket i
    by (simp add: minus_cblinfun.rep_eq)
  also from * have  = ket i - ket i
    by (simp add: Proj_fixes_image)
  also have  = (if iX then ket i else 0)
    using False by simp
  finally show ?thesis 
    by -
qed

(* Lifting ket_invariant I through the register function_at x gives the ket invariant of
   the set of functions f constrained at position x (see the set comprehension).
   Proof: show the two projectors agree on every basis vector ket f (equal_ket), by
   unfolding function_at into function_at_U and Fst and splitting f with
   puncture_function; then conclude equality of the subspaces by Proj_inj. *)
lemma lift_invariant_function_at_ket_inv: lift_invariant (function_at x) (ket_invariant I) = ket_invariant {f. f x  I}
proof -
  have Proj (lift_invariant (function_at x) (ket_invariant I)) = Proj (ket_invariant {f. f x  I})
  proof (rule equal_ket)
    fix f :: 'a  'b
    have Proj (lift_invariant (function_at x) (ket_invariant I)) (ket f) = function_at x (Proj (ket_invariant I)) (ket f)
      by (simp add: Proj_on_own_range lift_invariant_def register_projector)
    also have  = function_at_U x *V Fst (Proj (ket_invariant I)) *V (function_at_U x)* *V ket f
      by (simp add: function_at_def sandwich_apply comp_def)
    (* Split f into its value at x and the rest. *)
    also have  = function_at_U x *V Fst (Proj (ket_invariant I)) *V ket (f x, snd (puncture_function x f))
      by (simp flip: puncture_function_split)
    also have  = (if f x  I then function_at_U x *V (ket (f x) s ket (snd (puncture_function x f))) else 0)
      by (auto simp: Fst_def tensor_op_ell2 Proj_ket_invariant_ket simp flip: tensor_ell2_ket)
    also have  = (if f x  I then ket (fix_punctured_function x (f x, snd (puncture_function x f))) else 0)
      by (simp add: tensor_ell2_ket)
    (* Recombine the two parts into f. *)
    also have  = (if f x  I then ket f else 0)
      by (simp flip: puncture_function_split)
    also have  = Proj (ket_invariant {f. f x  I}) *V ket f
      by (simp add: Proj_ket_invariant_ket)
    finally show Proj (lift_invariant (function_at x) (ket_invariant I)) *V ket f = Proj (ket_invariant {f. f x  I}) *V ket f
      by -
  qed
  then show ?thesis
    by (rule Proj_inj)
qed

(* The projector onto ket_invariant (A × B) is the tensor of the projectors onto
   ket_invariant A and ket_invariant B. Checked on basis vectors (equal_ket). *)
lemma ket_invariant_prod: Proj (ket_invariant (A × B)) = Proj (ket_invariant A) o Proj (ket_invariant B)
  apply (rule equal_ket)
  by (auto simp: Proj_ket_invariant_ket tensor_op_ell2 simp flip: tensor_ell2_ket
      split: if_split_asm)

(* Lifting I through Fst gives a tensor subspace with I as the first factor.
   NOTE(review): the second factor is not visible in this view (presumably the full
   space) - confirm. Proved by comparing projectors (Proj_inj). *)
lemma lift_Fst_inv: lift_invariant Fst I = I S 
  apply (rule Proj_inj)
  by (simp add: lift_invariant_def Proj_on_own_range register_projector Fst_def tensor_ccsubspace_via_Proj)
(* Lifting I through Snd gives a tensor subspace with I as the second factor.
   NOTE(review): the first factor is not visible in this view (presumably the full
   space) - confirm. Proved by comparing projectors (Proj_inj). *)
lemma lift_Snd_inv: lift_invariant Snd I =  S I
  apply (rule Proj_inj)
  by (simp add: lift_invariant_def Proj_on_own_range register_projector Snd_def tensor_ccsubspace_via_Proj)

(* Lifting ket_invariant I through Snd gives ket_invariant (UNIV × I).
   Proved by comparing projectors, using ket_invariant_prod and Snd_def. *)
lemma lift_Snd_ket_inv: lift_invariant Snd (ket_invariant I) = ket_invariant (UNIV × I)
  apply (rule Proj_inj)
  apply (simp add: lift_invariant_def Proj_on_own_range register_projector ket_invariant_prod)
  by (simp add: Snd_def)
(* Lifting ket_invariant I through Fst gives ket_invariant (I × UNIV).
   Proved by comparing projectors, using ket_invariant_prod and Fst_def. *)
lemma lift_Fst_ket_inv: lift_invariant Fst (ket_invariant I) = ket_invariant (I × UNIV)
  apply (rule Proj_inj)
  apply (simp add: lift_invariant_def Proj_on_own_range register_projector ket_invariant_prod)
  by (simp add: Fst_def)

(* For compatible registers F and G, lifting ket_invariant (I × J) through the pair (F;G)
   is expressed through the separately lifted invariants of I (via F) and J (via G).
   NOTE(review): the operator combining the two lifted invariants is not visible in this
   view; the proof uses compatible_proj_intersect, suggesting an intersection - confirm. *)
lemma lift_inv_prod: 
  assumes [simp]: compatible F G
  shows lift_invariant (F;G) (ket_invariant (I × J)) = 
      lift_invariant F (ket_invariant I)  lift_invariant G (ket_invariant J)
  by (simp add: compatible_proj_intersect lift_invariant_def register_pair_apply ket_invariant_prod)

(* For registers F and G, lifting ket_invariant (I × J) through the tensor of F and G is
   the tensor of the subspaces obtained by lifting I through F and J through G.
   Proved from ket_invariant_prod and tensor_ccsubspace_image. *)
lemma lift_inv_tensor: 
  assumes [register]: register F register G
  shows lift_invariant (F r G) (ket_invariant (I × J)) = 
      lift_invariant F (ket_invariant I) S lift_invariant G (ket_invariant J)
  by (simp add: lift_invariant_def ket_invariant_prod tensor_ccsubspace_image)


(* lift_invariant through a register F distributes over a binary lattice operation on
   invariants (by the lemma name and the use of tensor_invariant_sup_left: the supremum).
   Proof: decompose F via register_decomposition as a sandwich with a unitary U, express
   lift_invariant F K through tensor_invariant (fact lift_F), then apply
   tensor_invariant_sup_left. The auxiliary type 'c is removed with cancel_with_type. *)
lemma lift_invariant_sup:
  fixes F :: ('a ell2 CL 'a ell2)  ('b ell2 CL 'b ell2)
  assumes [simp]: register F
  shows lift_invariant F (I  J) = lift_invariant F I  lift_invariant F J
proof -
  from register_decomposition[OF register F]
  have let 'c::type = register_decomposition_basis F in ?thesis
  proof with_type_mp
    case with_type_mp
    then obtain U :: ('a × 'c) ell2 CL 'b ell2 
    where unitary U and FU: F θ = sandwich U *V (θ o id_cblinfun) for θ
      by auto
    have lift_F: lift_invariant F K = U *S (Proj (tensor_invariant K )) *S  for K
      using unitary U
      by (simp add: lift_invariant_def FU sandwich_apply cblinfun_compose_image tensor_invariant_via_Proj)
    show lift_invariant F (I  J) = lift_invariant F I  lift_invariant F J
      by (auto simp: lift_F tensor_invariant_sup_left)
  qed
  from this[cancel_with_type]
  show ?thesis
    by -
qed

(* lift_invariant through a register F commutes with indexed suprema.
   Same proof pattern as lift_invariant_sup: decompose F via register_decomposition,
   express lift_invariant through tensor_invariant (fact lift_F), then use
   tensor_invariant_SUP_left and cblinfun_image_SUP. *)
lemma lift_invariant_SUP:
  fixes F :: ('a ell2 CL 'a ell2)  ('b ell2 CL 'b ell2)
  assumes register F
  shows lift_invariant F (SUP xX. I x) = (SUP xX. lift_invariant F (I x))
proof -
  from register_decomposition[OF register F]
  have let 'd::type = register_decomposition_basis F in ?thesis
  proof with_type_mp
    case with_type_mp
    then obtain U :: ('a × 'd) ell2 CL 'b ell2 
      where unitary U and FU: F θ = sandwich U *V (θ o id_cblinfun) for θ
      by auto
    have lift_F: lift_invariant F K = U *S (Proj (tensor_invariant K )) *S  for K
      using unitary U
      by (simp add: lift_invariant_def FU sandwich_apply cblinfun_compose_image tensor_invariant_via_Proj)
    show lift_invariant F (SUP xX. I x) = (SUP xX. lift_invariant F (I x))
      by (auto simp: lift_F tensor_invariant_SUP_left cblinfun_image_SUP)
  qed
  from this[cancel_with_type]
  show ?thesis
    by -
qed

(* For a register R, lifting commutes with taking the complement of an invariant.
   Proved by rewriting Proj (- U) with Proj_ortho_compl and using that registers preserve
   subtraction, the identity and projectors (see the metis fact list). *)
lemma lift_invariant_compl: lift_invariant R (- U) = - lift_invariant R U if register R
  apply (simp add: lift_invariant_def Proj_ortho_compl)
  by (metis (no_types, lifting) Proj_is_Proj Proj_on_own_range Proj_ortho_compl Proj_range register_minus register_of_id
      register_projector that)

(* lift_invariant through a register F commutes with an indexed lattice operation over A
   (by the lemma name: the infimum). Derived from lift_invariant_SUP applied to the
   complements, together with lift_invariant_compl and uminus_INF. *)
lemma lift_invariant_INF:
  assumes register F
  shows lift_invariant F (xA. I x) = (xA. lift_invariant F (I x))
  using lift_invariant_SUP[OF assms, where I=λx. - I x and X=A]
  by (simp add: lift_invariant_compl assms flip: uminus_INF)

(* TODO move *)
(* Binary version of lift_invariant_INF, obtained by indexing over the booleans
   (A = {False, True}, with I for False and J for True). *)
lemma lift_invariant_inf: 
  assumes register F
  shows lift_invariant F (I  J) = lift_invariant F I  lift_invariant F J
  using lift_invariant_INF[where A={False,True} and I=λb. if b then J else I] assms
  by simp


(* TODO move *)
(* Monotonicity of lift_invariant for a register F.
   Derived from lift_invariant_inf via the absorption characterisation of the order
   (inf.absorb_iff2). *)
lemma lift_invariant_mono:
  assumes register F
  assumes I  J
  shows lift_invariant F I  lift_invariant F J
  by (metis assms(1,2) inf.absorb_iff2 lift_invariant_inf)


(* Variant of lift_inv_prod for an arbitrary set I of pairs: the lifted invariant is the
   SUP over the pairs (x,y) in I of the combination of the singleton invariants lifted
   through F and G. Proved by rewriting backwards with lift_inv_prod, lift_invariant_SUP
   and ket_invariant_SUP. *)
lemma lift_inv_prod':
  fixes F :: ('a ell2 CL 'a ell2)  ('c ell2 CL 'c ell2)
  fixes G :: ('b ell2 CL 'b ell2)  ('c ell2 CL 'c ell2)
  assumes [simp]: compatible F G
  shows lift_invariant (F;G) (ket_invariant I) = 
      (SUP (x,y)I. lift_invariant F (ket_invariant {x})  lift_invariant G (ket_invariant {y}))
  by (simp flip: lift_inv_prod lift_invariant_SUP ket_invariant_SUP)

(* Variant of lift_inv_tensor for an arbitrary set I of pairs: the lifted invariant is the
   SUP over the pairs (x,y) in I of the tensor of the singleton invariants lifted through
   F and G. Proved by rewriting backwards with lift_inv_tensor, lift_invariant_SUP and
   ket_invariant_SUP. *)
lemma lift_inv_tensor':
  assumes [register]: register F register G
  shows lift_invariant (F r G) (ket_invariant I) = 
      (SUP (x,y)I. lift_invariant F (ket_invariant {x}) S lift_invariant G (ket_invariant {y}))
  by (simp add: register_tensor_is_register flip: lift_inv_tensor lift_invariant_SUP ket_invariant_SUP)

(* Image of ket_invariant I under classical_operator f, for f satisfying inj_map:
   it is the ket invariant of the set of values a such that Some a lies in the image of I under f.
   Proof: decompose the span of the images into a SUP over single elements, treat each
   element by a case split on f x (None gives 0, Some gives a ket), and recombine. *)
lemma classical_operator_ket_invariant:
  assumes inj_map f
  shows classical_operator f *S ket_invariant I = ket_invariant (Some -` f ` I)
proof -
  have ccspan ((λx. case f x of None  0 | Some x  ket x) ` I) = (xI. ccspan ((λx. case f x of None  0 | Some x  ket x) ` {x}))
    by (auto intro: arg_cong[where f=ccspan] simp add: SUP_ccspan)
  also have  = (xI. ccspan (ket ` Some -` f ` {x}))
  proof (rule SUP_cong[OF refl])
    fix x
    (* Preimages under Some of the two possible singleton shapes. *)
    have [simp]: Some -` {None} = {}
      by fastforce
    have [simp]: Some -` {Some a} = {a} for a
      by fastforce
    show ccspan ((λx. case f x of None  0 | Some x  ket x) ` {x}) = ccspan (ket ` Some -` f ` {x})
      apply (cases f x)
      by auto
  qed
  also have  = ccspan (ket ` Some -` f ` I)
    by (auto intro: arg_cong[where f=ccspan] simp add: SUP_ccspan)
  finally show ?thesis
    by (simp add: ket_invariant_def cblinfun_image_ccspan image_image classical_operator_ket assms
        classical_operator_exists_inj)
qed


(* The projector onto the ket invariant of a singleton {x} is selfbutter (ket x). *)
lemma Proj_ket_invariant_singleton: Proj (ket_invariant {x}) = selfbutter (ket x)
  by (simp add: ket_invariant_def butterfly_eq_proj)


(* Computes lift_invariant F (ket_invariant I) when, on the basis projectors
   selfbutter (ket x) for x in I, the register F acts as a sandwich with the classical
   operator of an injective function f: the result is ket_invariant (f ` (I × UNIV)).
   Proof: split the invariant into singletons (lift_invariant_SUP, ket_invariant_SUP),
   replace F by the sandwich (third assumption), turn the sandwiched projector into the
   projector onto an image (Proj_sandwich), evaluate that image with
   classical_operator_ket_invariant, and recombine the SUP. *)
lemma lift_inv_classical:
  fixes F :: 'a ell2 CL 'a ell2  'b ell2 CL 'b ell2 and f :: 'a × 'c  'b
  assumes [register]: register F
  assumes inj f
  assumes x::'a. x  I  F (selfbutter (ket x)) = sandwich (classical_operator (Some o f)) (selfbutter (ket x) o id_cblinfun)
  shows lift_invariant F (ket_invariant I) = ket_invariant (f ` (I × UNIV))
proof -
  (* Needed for Proj_sandwich below. *)
  have [iff]: isometry (classical_operator (Some  f))
    by (auto intro!: isometry_classical_operator assms)
  have lift_invariant F (ket_invariant I) = (SUP xI. lift_invariant F (ket_invariant {x}))
    by (simp add: flip: lift_invariant_SUP ket_invariant_SUP)
  also have  = (SUP xI. F (selfbutter (ket x)) *S )
    by (simp add: lift_invariant_def Proj_ket_invariant_singleton)
  also have  = (SUP xI. sandwich (classical_operator (Some o f)) (selfbutter (ket x) o id_cblinfun) *S )
    using assms by force
  also have  = (SUP xI. sandwich (classical_operator (Some o f)) (Proj (ket_invariant ({x} × UNIV))) *S )
    apply (simp add: flip: ket_invariant_tensor)
    by (metis (no_types, lifting) Proj_ket_invariant_singleton Proj_top ket_invariant_UNIV ket_invariant_prod ket_invariant_tensor)
  also have  = (SUP xI. Proj (classical_operator (Some o f) *S ket_invariant ({x} × UNIV)) *S )
    using Proj_sandwich by fastforce
  also have  = (SUP xI. classical_operator (Some o f) *S ket_invariant ({x} × UNIV))
    by auto
  also have  = (SUP xI. ket_invariant (f ` ({x} × UNIV)))
    apply (subst classical_operator_ket_invariant)
     apply (simp add: assms(2)) 
    by (simp add: inj_vimage_image_eq flip: image_image)
  also have  = ket_invariant (xI. f ` ({x} × UNIV))
    by (simp add: ket_invariant_SUP)
  also have  = ket_invariant (f ` (I × UNIV))
    by auto
  finally show ?thesis
    by -
qed

(* For a register F and an isometry U, applying F U to a lifted invariant equals lifting
   the image of the invariant under U.
   Proof: unfold lift_invariant_def, insert the adjoint of F U on the right (its range is
   everything, range_adjoint_isometry), recognise a sandwich, and use Proj_sandwich. *)
lemma register_image_lift_invariant: 
  assumes register F
  assumes isometry U
  shows F U *S lift_invariant F I = lift_invariant F (U *S I)
proof -
  have F U *S lift_invariant F I = F U *S F (Proj I) *S 
    by (simp add: lift_invariant_def)
  also have  = F U *S F (Proj I) *S (F U)* *S 
    by (simp add: assms(1,2) range_adjoint_isometry register_isometry)
  also have  = F (sandwich U (Proj I)) *S 
    by (smt (verit, best) Proj_lift_invariant Proj_range Proj_sandwich assms(1,2) range_adjoint_isometry
        register_isometry register_sandwich)
  also have  = F (Proj (U *S I)) *S 
    by (simp add: Proj_sandwich assms(2))
  also have  = lift_invariant F (U *S I)
    by (simp add: lift_invariant_def)
  finally show ?thesis
    by -
qed



(* A vector ψ in ket_invariant X equals the (infinite) sum of its components
   Rep_ell2 ψ i *C ket i restricted to indices i in X.
   Proof: ψ is fixed by the projector onto ket_invariant X; expand ψ in the ket basis
   (ell2_decompose_infsum), pull the projector into the sum, evaluate it on each ket with
   Proj_ket_invariant_ket, and drop the vanishing summands (infsum_cong_neutral). *)
lemma ell2_sum_ket_ket_invariant:
  fixes ψ :: 'a ell2
  assumes ψ  space_as_set (ket_invariant X)
  shows ψ = (iX. Rep_ell2 ψ i *C ket i)
proof -
  from assms have ψ = Proj (ket_invariant X) *V ψ
    by (simp add: Proj_fixes_image)
  also have  = Proj (ket_invariant X) *V (i. Rep_ell2 ψ i *C ket i)
    by (simp flip: ell2_decompose_infsum)
  also have  = (i. Rep_ell2 ψ i *C Proj (ket_invariant X) *V ket i)
    by (simp flip: infsum_cblinfun_apply add: ell2_decompose_summable cblinfun.scaleC_right)
  also have  = (i. Rep_ell2 ψ i *C (if iX then ket i else 0))
    by (simp add: Proj_ket_invariant_ket)
  also have  = (iX. Rep_ell2 ψ i *C ket i)
    apply (rule infsum_cong_neutral)
    by auto
  finally show ?thesis
    by simp
qed

(* Compatibility criterion for a register of the form Fst o F with a ket invariant on pairs:
   it suffices that F is compatible with every slice of I at a fixed second component y.
   Proof: rewrite I as a union over y of (slice × {y}), turn the union into a SUP of ket
   invariants (ket_invariant_SUP), apply compatible_register_invariant_SUP, and check each
   product piece using ket_invariant_prod. *)
lemma compatible_register_invariant_Fst_comp:
  fixes I :: ('a × 'b) set
  assumes [simp]: register F
  assumes y. compatible_register_invariant F (ket_invariant ((λx. (x,y)) -` I))
  shows compatible_register_invariant (Fst o F) (ket_invariant I)
  apply (subst asm_rl[of I = (y. ((λx. (x,y)) -` I) × {y})])
   apply fastforce
  apply (simp add: ket_invariant_SUP)
  apply (rule compatible_register_invariant_SUP, simp)
  apply (simp add: compatible_register_invariant_def ket_invariant_prod Fst_def comp_tensor_op)
  by (metis assms compatible_register_invariant_def)

(* Special case of compatible_register_invariant_Fst_comp with F = id: Fst is compatible
   with ket_invariant I when each slice of I at a fixed second component y satisfies the
   assumption (the slice being UNIV or {}; the connective between the two alternatives is
   not visible in this view - presumably a disjunction, confirm). *)
lemma compatible_register_invariant_Fst:
  assumes y. ((λx. (x,y)) -` I) = UNIV  ((λx. (x,y)) -` I) = {}
  shows compatible_register_invariant Fst (ket_invariant I)
  apply (subst asm_rl[of Fst = Fst o id], simp)
  apply (rule compatible_register_invariant_Fst_comp, simp)
  using assms by (rule compatible_register_invariant_id)

(* Compatibility criterion for a register of the form Snd o F with a ket invariant on pairs:
   it suffices that F is compatible with every slice of I at a fixed first component x.
   Proof: rewrite I as a union over x of ({x} × slice), turn the union into a SUP of ket
   invariants (ket_invariant_SUP), apply compatible_register_invariant_SUP, and check each
   product piece using ket_invariant_prod. *)
lemma compatible_register_invariant_Snd_comp:
  fixes I :: ('a × 'b) set
  assumes [simp]: register F
  assumes x. compatible_register_invariant F (ket_invariant ((λy. (x,y)) -` I))
  shows compatible_register_invariant (Snd o F) (ket_invariant I)
  apply (subst asm_rl[of I = (x. {x} × ((λy. (x,y)) -` I))])
   apply fastforce
  apply (simp add: ket_invariant_SUP)
  apply (rule compatible_register_invariant_SUP, simp)
  apply (simp add: compatible_register_invariant_def ket_invariant_prod Snd_def comp_tensor_op)
  by (metis assms compatible_register_invariant_def)

(* Special case of compatible_register_invariant_Snd_comp with F = id: Snd is compatible
   with ket_invariant I when each slice of I at a fixed first component x satisfies the
   assumption (the slice being UNIV or {}; the connective between the two alternatives is
   not visible in this view - presumably a disjunction, confirm). *)
lemma compatible_register_invariant_Snd:
  assumes x. ((λy. (x,y)) -` I) = UNIV  ((λy. (x,y)) -` I) = {}
  shows compatible_register_invariant Snd (ket_invariant I)
  apply (subst asm_rl[of Snd = Snd o id], simp)
  apply (rule compatible_register_invariant_Snd_comp, simp)
  using assms by (rule compatible_register_invariant_id)

(* Simp rule: Fst is compatible with a tensor subspace whose second factor is I.
   NOTE(review): the first factor is not visible in this view (presumably the full
   space) - confirm. Proved by expressing the projector via tensor_ccsubspace_via_Proj. *)
lemma compatible_register_invariant_Fst_tensor[simp]:
  shows compatible_register_invariant Fst ( S I)
  by (simp add: compatible_register_invariant_def Fst_def Proj_on_own_range comp_tensor_op is_Proj_tensor_op tensor_ccsubspace_via_Proj)

(* Simp rule: Snd is compatible with a tensor subspace whose first factor is I.
   NOTE(review): the second factor is not visible in this view (presumably the full
   space) - confirm. Proved by expressing the projector via tensor_ccsubspace_via_Proj. *)
lemma compatible_register_invariant_Snd_tensor[simp]:
  shows compatible_register_invariant Snd (I S )
  by (simp add: compatible_register_invariant_def Snd_def Proj_on_own_range comp_tensor_op is_Proj_tensor_op tensor_ccsubspace_via_Proj)

(* For a unitary U: if F is compatible with the pull-back of I along U (the image of I
   under the adjoint of U), then sandwich U o F is compatible with I.
   Proof: rewrite I as U applied to that pull-back (valid since U is unitary), then use
   Proj_sandwich backwards and multiplicativity of the sandwich register. *)
lemma compatible_register_invariant_sandwich_comp:
  fixes U :: 'a ell2 CL 'b ell2
  assumes [simp]: unitary U
  assumes compatible_register_invariant F (U* *S I)
  shows compatible_register_invariant (sandwich U o F) I
  apply (subst asm_rl[of I = U *S U* *S I])
   apply (simp add: cblinfun_assoc_left(2))
  using assms 
  by (simp add: compatible_register_invariant_def unitary_sandwich_register register_mult
      flip: Proj_sandwich[of U])

(* Compatibility criterion for registers of the form function_at x o F with a ket invariant
   on functions: it suffices that, for every z, F is compatible with the ket invariant of
   the set of values f x of those f in I that agree with z once position x is overwritten
   with undefined.
   Proof: a chain of set equalities identifies that set with a slice of the preimage of I
   under fix_punctured_function; then function_at is unfolded into a sandwich with a
   classical operator composed with Fst, and compatible_register_invariant_sandwich_comp,
   classical_operator_ket_invariant and compatible_register_invariant_Fst_comp are applied. *)
lemma compatible_register_invariant_function_at_comp:
  assumes [simp]: register F
  assumes z. compatible_register_invariant F (ket_invariant {f x |f. f  I  z(x := undefined) = f(x := undefined)})
  shows compatible_register_invariant (function_at x o F) (ket_invariant I)
proof -
  have (λa. (a, snd (puncture_function x z))) -` Some -` inv_map (Some  fix_punctured_function x) ` I 
        = (λa. (a, snd (puncture_function x z))) -` puncture_function x ` I (is ?lhs = _) for z
    by (simp add: inv_map_total bij_fix_punctured_function bij_is_surj inj_vimage_image_eq
        flip: image_image)
  also have  z = {f x | f. fI  snd (puncture_function x z) = snd (puncture_function x f)} for z
    apply (transfer fixing: I x)
    by auto
  also have  z = {f x | f. fI  z(x:=undefined) = f(x:=undefined)} for z
  proof -
    (* Helper for the goal produced by transfer; proved from swap_nilpotent. *)
    have aux: f  I 
         z(x := undefined)  Transposition.transpose x undefined =
         f(x := undefined)  Transposition.transpose x undefined 
         fa. f x = fa x  fa  I  z(x := undefined) = fa(x := undefined) for f
      by (metis swap_nilpotent)
    show ?thesis
      apply (transfer fixing: z x I)
      using aux by (auto simp: fun_upd_comp_left)
  qed
  finally have compatible_register_invariant F (ket_invariant ((λa. (a, snd (puncture_function x z))) -` Some -` inv_map (Some  fix_punctured_function x) ` I)) for z
    by (simp add: assms)
  (* Every second component y arises as snd (puncture_function x z) for some z. *)
  then have *: compatible_register_invariant F (ket_invariant ((λa. (a, y)) -` Some -` inv_map (Some  fix_punctured_function x) ` I)) for y
    by (metis fix_punctured_function_inverse snd_conv)
  show ?thesis
    unfolding function_at_def function_at_U_def Let_def comp_assoc
    apply (rule compatible_register_invariant_sandwich_comp)
     apply (simp add: bij_fix_punctured_function)
    apply (subst classical_operator_adjoint)
     apply (simp add: bij_fix_punctured_function bij_is_inj)
    apply (subst classical_operator_ket_invariant)
     apply (simp add: bij_fix_punctured_function bij_is_inj)
    apply (rule compatible_register_invariant_Fst_comp, simp)
    using * by simp
qed

(* Special case of compatible_register_invariant_function_at_comp with F = id:
   function_at x is compatible with ket_invariant I provided I is closed under
   updating the value at x (the assumption relates membership of f in I to
   membership of f(x:=y) in I for all f and y).
   NOTE(review): the connective in the assumption is not legible in this copy;
   presumably an implication. *)
lemma compatible_register_invariant_function_at:
  assumes f y. fI  f(x:=y)  I
  shows compatible_register_invariant (function_at x) (ket_invariant I)
  apply (subst asm_rl[of function_at x = function_at x o id], simp)
  apply (rule compatible_register_invariant_function_at_comp, simp)
  apply (rule compatible_register_invariant_id)
  using assms fun_upd_idem_iff by fastforce

text ‹The following lemma allows us to show that an invariant is preserved across several consecutive
  operations. Usually, both \<^term>‹norm V› and \<^term>‹norm U› are at most 1, so the lemma essentially says that
  the errors are additive.›

(* Composition rule for "preserves" (declared [trans]):
   if U preserves I into J with error ε and V preserves J into K with error δ,
   then V oCL U preserves I into K with error norm V * ε + norm U * δ.
   The proof works on the operator norm of (id - Proj K) o (V o U) o Proj I
   (via preserves_onorm) and inserts Proj J + (id - Proj J) between V and U. *)
lemma preserves_trans[trans]:
  assumes presU: preserves U I J ε
  assumes presV: preserves V J K δ
  shows preserves (V oCL U) I K (norm V * ε + norm U * δ)
proof -
  (* Insert the decomposition Proj J + (id - Proj J) between V and U. *)
  have norm ((id_cblinfun - Proj K) oCL (V oCL U) oCL Proj I)
    = norm ((id_cblinfun - Proj K) oCL V oCL (Proj J + (id_cblinfun - Proj J)) oCL U oCL Proj I)
    by (auto simp add: cblinfun_assoc_left(1))
  (* Split into the two summands with the triangle inequality. *)
  also have   norm ((id_cblinfun - Proj K) oCL V oCL Proj J oCL U oCL Proj I)
                + norm ((id_cblinfun - Proj K) oCL V oCL (id_cblinfun - Proj J) oCL U oCL Proj I)
    by (smt (verit) cblinfun_compose_add_left cblinfun_compose_add_right norm_triangle_ineq)
  (* Bound the summand containing (id - Proj J) by norm V * ε using presU. *)
  also have   norm ((id_cblinfun - Proj K) oCL V oCL Proj J oCL U oCL Proj I) + norm V * ε
  proof -
    have norm ((id_cblinfun - Proj K) oCL V oCL (id_cblinfun - Proj J) oCL U oCL Proj I)
        norm (id_cblinfun - Proj K) * norm (V oCL (id_cblinfun - Proj J) oCL U oCL Proj I)
      by (metis cblinfun_assoc_left(1) norm_cblinfun_compose)
    also have   norm (V oCL (id_cblinfun - Proj J) oCL U oCL Proj I)
      by (metis Groups.mult_ac(2) Proj_ortho_compl mult.right_neutral mult_left_mono norm_Proj_leq1 norm_ge_zero)
    also have   norm V * norm ((id_cblinfun - Proj J) oCL U oCL Proj I)
      by (metis cblinfun_assoc_left(1) norm_cblinfun_compose)
    also have   norm V * ε
      by (meson norm_ge_zero ordered_comm_semiring_class.comm_mult_left_mono presU preserves_onorm)
    finally show ?thesis
      by (rule add_left_mono)
  qed
  (* Strip Proj I and then U from the remaining summand, and bound what is
     left by norm U * δ using presV. *)
  also have   norm ((id_cblinfun - Proj K) oCL V oCL Proj J oCL U) * norm (Proj I) + norm V * ε
    by (simp add: norm_cblinfun_compose)
  also have   norm ((id_cblinfun - Proj K) oCL V oCL Proj J oCL U) + norm V * ε
    by (simp add: norm_is_Proj mult.commute mult_left_le_one_le)
  also have   norm ((id_cblinfun - Proj K) oCL V oCL Proj J) * norm U + norm V * ε
    by (simp add: norm_cblinfun_compose)
  also have   norm U * δ + norm V * ε
    by (metis add.commute add_le_cancel_left mult.commute mult_left_mono norm_ge_zero presV preserves_onorm)
  finally show ?thesis
    by (simp add: preserves_onorm)
qed

text ‹An operation that operates on a register that is outside the invariant preserves the invariant perfectly.›

(* If the register F is compatible with the invariant I, then F U preserves I
   (into I itself) with any admissible error ε.
   NOTE(review): the comparison in the side condition on ε is not legible in
   this copy; presumably it is the non-negativity of ε required by
   preserves_def.
   The proof shows that the component of F U *V ψ in the complement of I is
   exactly 0 for every ψ in I, because F U commutes with Proj I. *)
lemma preserves_compatible: 
  assumes compat: compatible_register_invariant F I
  assumes ε  0
  shows preserves (F U) I I ε
proof (rule preservesI')
  from assms show ε  0 by -
  fix ψ assume ψ  space_as_set I
  (* ψ is fixed by the projector onto I. *)
  then have ψI: ψ = Proj I *V ψ
    using Proj_fixes_image by force
  (* Compatibility lets F U commute with Proj I. *)
  from compat have FI: F U *V Proj I *V ψ = Proj I *V F U *V ψ
    by (metis cblinfun_apply_cblinfun_compose compatible_register_invariant_def)
  have Proj (- I) *V F U *V ψ = 0
    apply (subst ψI) apply (subst FI)
    by (metis FI Proj_ortho_compl ψI cancel_comm_monoid_add_class.diff_cancel cblinfun.diff_left id_cblinfun_apply)
  with ε  0 show norm (Proj (- I) *V F U *V ψ)  ε
    by simp
qed

(* The projector onto the ket invariant of a singleton set {x} is the
   butterfly operator selfbutter (ket x). *)
lemma Proj_ket_invariant_butterfly: Proj (ket_invariant {x}) = selfbutter (ket x)
  by (simp add: butterfly_eq_proj ket_invariant_def)

(* Introduction rule: the basis vector ket x belongs to the carrier set of
   ket_invariant I under the side condition relating x and I.
   NOTE(review): the relation symbols are not legible in this copy; presumably
   both are set membership. *)
lemma ket_in_ket_invariantI: ket x  space_as_set (ket_invariant I) if x  I
  by (metis Proj_ket_invariant_ket Proj_range cblinfun_apply_in_image that)

(* Introduction rule for bounding the image of a ket invariant under U by a
   subspace J: it suffices to check the basis vectors, i.e. that U *V ket x
   lies in J for every x in I. The proof unfolds ket_invariant_def and uses
   ccspan_leqI.
   NOTE(review): relation symbols are not legible in this copy; presumably
   membership in the assumption and the subspace order in the conclusion. *)
lemma cblinfun_image_ket_invariant_leqI:
  assumes x. x  I  U *V ket x  space_as_set J
  shows U *S ket_invariant I  J
  by (simp add: assms cblinfun_image_ccspan ccspan_leqI image_subset_iff ket_invariant_def)

(* Characterises error-free preservation: "preserves U I J 0" versus a
   statement about the image U *S I and J. The proof treats both directions,
   so the statement is presumably an equivalence with the image of I under U
   being contained in J.
   NOTE(review): the connective and the order symbol are not legible in this
   copy. *)
lemma preserves0I: preserves U I J 0  U *S I  J
proof
  (* Direction 1: from the operator formulation (via preserves_onorm). *)
  have (id_cblinfun - Proj J) oCL U oCL Proj I = 0  U *S I  J
    by (metis (no_types, lifting) Proj_range add_diff_cancel_left' cblinfun_assoc_left(2) cblinfun_compose_minus_left cblinfun_compose_id_left cblinfun_image_mono diff_add_cancel diff_zero top_greatest)
  then show preserves U I J 0  U *S I  J
    by (auto simp: preserves_onorm)
next
  (* Direction 2: every U *V ψ with ψ in I lies in J, hence is fixed by Proj J. *)
  assume leq: U *S I  J
  show preserves U I J 0
  proof (rule preservesI)
    show 0  (0::real) by simp
    fix ψ
    assume ψ  space_as_set I
    with leq have U *V ψ  space_as_set J
      by (metis (no_types, lifting) Proj_fixes_image Proj_range cblinfun_apply_cblinfun_compose cblinfun_apply_in_image cblinfun_compose_image less_eq_ccsubspace.rep_eq subset_iff)
    then have Proj J *V U *V ψ = U *V ψ
      by (simp add: Proj_fixes_image)
    then show norm (U *V ψ - Proj J *V U *V ψ)  0
      by simp
  qed
qed

(* Simp rule: lifting an invariant through the identity register leaves it
   unchanged. *)
lemma lift_invariant_id[simp]: lift_invariant id I = I
  by (simp add: lift_invariant_def)

(* For compatible registers X and Y: lifting a tensor-product invariant
   through the pair register (X;Y) is expressed by combining
   lift_invariant X I and lift_invariant Y J.
   NOTE(review): the operator combining the two lifted invariants is not
   legible in this copy; the last proof step uses compatible_proj_intersect,
   so it is presumably the intersection (infimum). *)
lemma lift_invariant_pair_tensor:
  assumes compatible X Y
  shows lift_invariant (X;Y) (I S J) = lift_invariant X I  lift_invariant Y J
proof -
  (* Unfold the lift, split the projector of the tensor product into a tensor
     of projectors, and push the pair register through it. *)
  have lift_invariant (X;Y) (I S J) = (X;Y) (Proj (I S J)) *S 
    by (simp add: lift_invariant_def)
  also have  = (X;Y) (Proj I o Proj J) *S 
    by (simp add: Proj_on_own_range is_Proj_tensor_op tensor_ccsubspace_via_Proj)
  also have  = (X (Proj I) oCL Y (Proj J)) *S 
    by (simp add: Laws_Quantum.register_pair_apply assms)
  also have  = lift_invariant X I  lift_invariant Y J
    by (simp add: assms compatible_proj_intersect lift_invariant_def)
  finally show ?thesis
    by -
qed

(* For registers X and Y: lifting a tensor-product invariant through the
   tensor of the two registers equals the tensor product of the individually
   lifted invariants (the final step uses tensor_ccsubspace_image).
   NOTE(review): the tensor symbols are not legible in this copy; "X r Y" is
   presumably the register tensor product. *)
lemma lift_invariant_tensor_tensor:
  assumes [register]: register X register Y
  shows lift_invariant (X r Y) (I S J) = lift_invariant X I S lift_invariant Y J
proof -
  have lift_invariant (X r Y) (I S J) = (X r Y) (Proj (I S J)) *S 
    by (simp add: lift_invariant_def)
  also have  = (X r Y) (Proj I o Proj J) *S 
    by (simp add: Proj_on_own_range is_Proj_tensor_op tensor_ccsubspace_via_Proj)
  also have  = (X (Proj I) o Y (Proj J)) *S 
    by (simp add: Laws_Quantum.register_pair_apply assms register_tensor_apply)
  also have  = lift_invariant X I S lift_invariant Y J
    by (simp add: lift_invariant_def tensor_ccsubspace_image)
  finally show ?thesis
    by -
qed

(* Simp rule: for a register Q, orthogonality of the lifted invariants
   lift_invariant Q S and lift_invariant Q T is related to orthogonality of
   S and T themselves. The proof is a chain through the operator equations
   Q (Proj S) oCL Q (Proj T) = 0 and Proj S oCL Proj T = 0.
   NOTE(review): the connective is not legible in this copy; presumably an
   equivalence. *)
lemma orthogonal_spaces_lift_invariant[simp]: 
  assumes register Q
  shows orthogonal_spaces (lift_invariant Q S) (lift_invariant Q T)  orthogonal_spaces S T
proof -
  have orthogonal_spaces (lift_invariant Q S) (lift_invariant Q T)  Q (Proj S) oCL Q (Proj T) = 0
    by (simp add: orthogonal_projectors_orthogonal_spaces lift_invariant_def Proj_on_own_range assms register_projector)
  also have   Proj S oCL Proj T = 0
    by (metis (no_types, lifting) assms norm_eq_zero register_mult register_norm)
  also have   orthogonal_spaces S T
    by (simp add: orthogonal_projectors_orthogonal_spaces)
  finally show ?thesis
    by -
qed



subsection ‹Distance from invariants›

(* dist_inv R I ψ: the norm of ψ after applying R (Proj (-I)), i.e. the
   operator obtained by passing the projector onto the complement of I through
   R. The lemmas below use it as the distance of ψ from the invariant I lifted
   through R (see dist_inv_0_iff). *)
definition dist_inv where dist_inv R I ψ = norm (R (Proj (-I)) *V ψ)
  for R :: ('a ell2 CL 'a ell2)  ('b ell2 CL 'b ell2)
(* dist_inv_avg R I ψ: for families I and ψ indexed by a finite type 'x, the
   square root of the aggregated squared distances
   (dist_inv R (I x) (ψ x))^2 over all x, divided by CARD('x).
   NOTE(review): the binder over UNIV is not legible in this copy; the lemmas
   below reason about it with sum lemmas (sum_nonneg, sum.cong), so it is
   presumably a finite sum, making this a root-mean-square. *)
definition dist_inv_avg where dist_inv_avg R I ψ = sqrt ((xUNIV. (dist_inv R (I x) (ψ x))2) / CARD('x)) for ψ :: 'x::finite  _

(* [iff] rule comparing dist_inv with 0; immediate from dist_inv_def (a norm).
   NOTE(review): the comparison symbol is not legible in this copy; presumably
   it states non-negativity. *)
lemma dist_inv_pos[iff]: dist_inv R I ψ  0
  by (simp add: dist_inv_def)
(* [iff] rule comparing dist_inv_avg with 0; follows from dist_inv_avg_def and
   sum_nonneg.
   NOTE(review): the comparison symbol is not legible in this copy; presumably
   it states non-negativity. *)
lemma dist_inv_avg_pos[iff]: dist_inv_avg R I ψ  0
  by (simp add: dist_inv_avg_def sum_nonneg)

(* For a register R: relates "dist_inv R I ψ = 0" to ψ lying in the lifted
   invariant lift_invariant R I (by its name and its proof chain, an
   equivalence).
   Proof chain: distance 0 means R (Proj (- I)) annihilates ψ; this is
   rephrased via the complement of the range of that projector, then via the
   complement of lift_invariant R (-I), and finally via lift_invariant R I. *)
lemma dist_inv_0_iff:
  assumes register R
  shows dist_inv R I ψ = 0  ψ  space_as_set (lift_invariant R I)
proof -
  have dist_inv R I ψ = 0  R (Proj (- I)) *V ψ = 0
    by (simp add: dist_inv_def)
  also have   Proj (R (Proj (- I)) *S ) ψ = 0
  by (simp add: Proj_on_own_range assms register_projector)
  also have   ψ  space_as_set (- (R (Proj (- I)) *S ))
    using Proj_0_compl kernel_memberI by fastforce
  also have   ψ  space_as_set (- lift_invariant R (-I))
  by (simp add: lift_invariant_def)
  also have   ψ  space_as_set (lift_invariant R I)
    by (metis (no_types, lifting) Proj_lift_invariant Proj_ortho_compl Proj_range assms
        ortho_involution register_minus register_of_id)
  finally show ?thesis
    by -
qed

(* Averaged analogue of dist_inv_0_iff: for a register R, relates
   "dist_inv_avg R I ψ = 0" to every component ψ h lying in
   lift_invariant R (I h). Proved by reducing to the pointwise squared
   distances being 0 (sum_nonneg_eq_0_iff) and then applying dist_inv_0_iff. *)
lemma dist_inv_avg_0_iff:
  assumes register R
  shows dist_inv_avg R I ψ = 0  (h. ψ h  space_as_set (lift_invariant R (I h)))
proof -
  have dist_inv_avg R I ψ = 0  (h. (dist_inv R (I h) (ψ h))2 = 0)
    by (simp add: dist_inv_avg_def sum_nonneg_eq_0_iff)
  also have   (h. ψ h  space_as_set (lift_invariant R (I h)))
    by (simp add: assms dist_inv_0_iff)
  finally show ?thesis
    by -
qed

(* Monotonicity of dist_inv in the invariant: given an order relation between
   I and J and a register Q, compares dist_inv Q J ψ with dist_inv Q I ψ.
   The proof uses Proj (-J) = Proj (-J) oCL Proj (-I) and the fact that a
   projector does not increase the norm (is_Proj_reduces_norm).
   NOTE(review): the order symbols are not legible in this copy; presumably
   I is below J and the distance from J is at most the distance from I. *)
lemma dist_inv_mono:
  assumes I  J
  assumes [register]: register Q
  shows dist_inv Q J ψ  dist_inv Q I ψ
proof -
  from assms
  have ProjJI: Proj (-J) = Proj (-J) oCL Proj (-I)
    by (simp add: Proj_o_Proj_subspace_left)
  have norm (Q (Proj (- J) oCL Proj (- I)) *V ψ)  norm (Q (Proj (- I)) *V ψ)
  by (metis Proj_is_Proj assms(2) is_Proj_reduces_norm register_mult'
      register_projector)
  then show ?thesis
    by (simp add: dist_inv_def flip: ProjJI)
qed


(* Averaged analogue of dist_inv_mono: a pointwise order relation between the
   families I and J yields the corresponding comparison of dist_inv_avg.
   Proved termwise via sum_mono and dist_inv_mono.
   NOTE(review): the order symbols are not legible in this copy. *)
lemma dist_inv_avg_mono:
  assumes h. I h  J h
  assumes [register]: register Q
  shows dist_inv_avg Q J ψ  dist_inv_avg Q I ψ
  by (auto intro!: sum_mono divide_right_mono dist_inv_mono assms
      simp: dist_inv_avg_def)

(* Tensoring the state with a unit vector φ does not change the distance when
   the register is embedded via Fst: the distance of the tensor state with
   respect to Fst o R equals the distance of ψ with respect to R. *)
lemma dist_inv_Fst_tensor:
  assumes norm φ = 1
  shows dist_inv (Fst o R) I (ψ s φ) = dist_inv R I ψ
proof -
  have (norm (Fst (R (Proj (- I))) *V ψ s φ))2 = (norm (R (Proj (- I)) *V ψ))2
    by (simp add: Fst_def tensor_op_ell2 norm_tensor_ell2 assms)
  then show ?thesis
    by (simp add: dist_inv_def)
qed

(* Averaged analogue of dist_inv_Fst_tensor: tensoring every component ψ h
   with a unit vector φ h leaves dist_inv_avg unchanged when the register is
   embedded via Fst. *)
lemma dist_inv_avg_Fst_tensor:
  assumes h. norm (φ h) = 1
  shows dist_inv_avg (Fst o R) I (λh. ψ h s φ h) = dist_inv_avg R I ψ
  by (simp add: assms dist_inv_avg_def dist_inv_Fst_tensor)

(* Two descriptions (register Q with invariant I, register R with invariant J)
   of the same lifted invariant give the same distance.
   Key step: the lifted complements agree as well (lift_invariant_compl), from
   which R (Proj (- J)) = Q (Proj (- I)) follows. *)
lemma dist_inv_register_rewrite:
  assumes register Q and register R
  assumes lift_invariant Q I = lift_invariant R J
  shows dist_inv Q I ψ = dist_inv R J ψ
proof -
  from assms
  have  lift_invariant Q (-I) = lift_invariant R (-J)
    by (simp add: lift_invariant_compl)
  then have Proj (Q (Proj (-I)) *S ) = Proj (R (Proj (-J)) *S )
    by (simp add: lift_invariant_def)
  then have R (Proj (- J)) = Q (Proj (- I))
    by (metis Proj_lift_invariant assms lift_invariant_def)
  with assms
  show ?thesis
    by (simp add: dist_inv_def)
qed



(* Averaged analogue of dist_inv_register_rewrite: if for every index h the
   lifted invariants agree, then the averaged distances agree. Proved
   summand-wise with sum.cong. *)
lemma dist_inv_avg_register_rewrite:
  assumes register Q and register R
  assumes h. lift_invariant Q (I h) = lift_invariant R (J h)
  shows dist_inv_avg Q I ψ = dist_inv_avg R J ψ
  using assms by (auto intro!: dist_inv_register_rewrite sum.cong simp add: dist_inv_avg_def)

(* Relates "dist_inv_avg Q I ψ = 0" to all pointwise distances
   dist_inv Q (I h) (ψ h) being 0 (via sum_nonneg_eq_0_iff).
   NOTE(review): the connective is not legible in this copy; presumably an
   equivalence. dist_inv_avg_intersect uses it to obtain the pointwise
   statement from the averaged one. *)
lemma distance_from_inv_avg0I: 
  dist_inv_avg Q I ψ = 0  (h. dist_inv Q (I h) (ψ h) = 0) for h :: 'h::finite and ψ :: 'h  _
  by (simp add: dist_inv_avg_def sum_nonneg_eq_0_iff)

(* Moves a unitary from the state onto the invariant: for registers Q and S
   with Q o S = R and a unitary U, the distance of R U *V ψ from I (through Q)
   equals the distance of ψ from the image of I under the adjoint of S U. *)
lemma dist_inv_apply:
  assumes [register]: register Q register S
  assumes [iff]: unitary U
  assumes QSR: Q o S = R
  shows dist_inv Q I (R U *V ψ) = dist_inv Q (S U* *S I) ψ
proof -
  have norm (Q (Proj (- I)) *V R U *V ψ) = norm (Q (Proj (- (S U* *S I))) *V ψ)
  proof -
    (* Prepend the adjoint of Q (S U); this is norm-preserving (the step cites
       isometry_preserves_norm). *)
    have norm (Q (Proj (- I)) *V R U *V ψ) = norm (Q (S U)* *V Q (Proj (- I)) *V Q (S U) *V ψ)
      by (metis assms(1,2,3,4) isometry_preserves_norm o_def register_unitary unitary_twosided_isometry)
    (* Recognise the sandwich and pull it inside the register Q. *)
    also have  = norm (sandwich (Q (S U)*) (Q (Proj (-I))) *V ψ)
      by (simp add: sandwich_apply)
    also have  = norm (Q (sandwich (S U*) *V Proj (- I)) *V ψ)
      by (simp add: flip: register_sandwich register_adj)
    (* A sandwiched projector is the projector onto the image; the image of a
       complement under a unitary is the complement of the image. *)
    also have  = norm (Q (Proj (S U* *S - I)) *V ψ)
      by (simp add: Proj_sandwich register_coisometry)
    also have  = norm (Q (Proj (- (S U* *S I))) *V ψ)
      by (simp add: unitary_image_ortho_compl register_unitary)
    finally show ?thesis
      by -
  qed
  then show ?thesis
    by (simp add: dist_inv_def)
qed



(* TODO remove → dist_inv_apply *)
(* Special case of dist_inv_apply with S = id: applying Q U to the state
   corresponds to replacing I by its image under the adjoint of U. *)
lemma dist_inv_apply_iff:
  assumes [register]: register Q
  assumes [iff]: unitary U
  shows dist_inv Q I (Q U *V ψ) = dist_inv Q (U* *S I) ψ
  apply (subst dist_inv_apply[where S=id])
  by auto


(* Averaged analogue of dist_inv_apply: for a family of unitaries U h and
   Q o S = R, applying R (U h) to each component ψ h corresponds to replacing
   each I h by its image under the adjoint of S (U h). Proved summand-wise
   with sum.cong and dist_inv_apply. *)
lemma dist_inv_avg_apply:
  assumes [register]: register Q register S
  assumes [iff]: h. unitary (U h)
  assumes Q o S = R
  shows dist_inv_avg Q I (λh. R (U h) *V ψ h) = dist_inv_avg Q (λh. S (U h)* *S I h) ψ
  using assms by (auto intro!: sum.cong simp: dist_inv_avg_def dist_inv_apply[where S=S])

(* TODO remove → dist_inv_avg_apply *)
(* Averaged analogue of dist_inv_apply_iff: the unitaries U h are applied
   through Q itself; proved summand-wise from dist_inv_apply_iff. *)
lemma dist_inv_avg_apply_iff:
  assumes [register]: register Q
  assumes [iff]: h. unitary (U h)
  shows dist_inv_avg Q I (λh. Q (U h) *V ψ h) = dist_inv_avg Q (λh. U h* *S I h) ψ
  by (auto intro!: sum.cong dist_inv_apply_iff simp: dist_inv_avg_def)


(* If ψ already has distance 0 from I, then combining J with I does not change
   the distance from J, provided I and J are compatible invariants.
   NOTE(review): the lattice operator between J and I is not legible in this
   copy; the proof rewrites its projector as Proj J oCL Proj I via
   compatible_invariants_inter, so it is presumably the intersection. *)
lemma dist_inv_intersect_onesided:
  assumes compatible_invariants I J
  assumes register Q
  assumes dist_inv Q I ψ = 0
  shows dist_inv Q (J  I) ψ = dist_inv Q J ψ
proof -
  (* Distance 0 means that ψ lies in the lifted invariant. *)
  have inside: ψ  space_as_set (lift_invariant Q I)
    using assms(2,3) dist_inv_0_iff by blast
  have norm (Q (Proj (- (J  I))) *V ψ) = norm (ψ - Q (Proj (J  I)) *V ψ)
    by (metis (no_types, lifting) Proj_ortho_compl assms(2) cblinfun.diff_left id_cblinfun.rep_eq register_minus
        register_of_id)
  also have  = norm (ψ - Q (Proj (J) oCL Proj (I)) *V ψ)
    by (metis assms compatible_invariants_def compatible_invariants_inter)
  also have  = norm (ψ - Q (Proj (J)) *V Q (Proj (I)) *V ψ)
    by (simp add: assms register_mult')
  (* Q (Proj I) fixes ψ because ψ is inside the lifted invariant. *)
  also have  = norm (ψ - Q (Proj (J)) *V ψ)
    by (metis Proj_fixes_image Proj_lift_invariant assms inside)
  also have  = norm (Q (Proj (- J)) *V ψ)
    by (simp add: Proj_ortho_compl assms cblinfun.diff_left register_minus)
  finally have norm (Q (Proj (- (J  I))) *V ψ) = norm (Q (Proj (- J)) *V ψ)
    by -
  then show ?thesis
    by (simp add: dist_inv_def)
qed



(* Averaged analogue of dist_inv_intersect_onesided: if the averaged distance
   from the family I is 0, combining each J h with I h does not change the
   averaged distance from J. The pointwise zero distances are obtained from
   distance_from_inv_avg0I. *)
lemma dist_inv_avg_intersect:
  assumes h. compatible_invariants (I h) (J h)
  assumes register Q
  assumes dist_inv_avg Q I ψ = 0
  shows dist_inv_avg Q (λh. J h  I h) ψ = dist_inv_avg Q J ψ
proof -
  have dist_inv Q (I h) (ψ h) = 0 for h
    using assms(3) distance_from_inv_avg0I by blast
  then show ?thesis
    by (auto intro!: sum.cong dist_inv_intersect_onesided assms simp: dist_inv_avg_def)
qed

(* For constant families (the same invariant I and the same state ψ at every
   index) the averaged distance coincides with the plain distance. *)
lemma dist_inv_avg_const: dist_inv_avg Q (λ_. I) (λ_. ψ) = dist_inv Q I ψ
  by (simp add: dist_inv_avg_def dist_inv_def)

(* TODO move *)
(* Registers are additive: Q (a + b) = Q a + Q b. Follows from
   clinear_register. *)
lemma register_plus:
  assumes register Q
  shows Q (a + b) = Q a + Q b
  by (simp add: assms clinear_register complex_vector.linear_add)

(* TODO move *)
(* Simp rule: complementing the left invariant does not affect
   compatible_invariants. Proved by writing Proj of the complement via
   Proj_ortho_compl and distributing composition over the difference.
   NOTE(review): the connective is not legible in this copy; presumably an
   equivalence. *)
lemma compatible_invariants_uminus_left[simp]: compatible_invariants (-I) J  compatible_invariants I J
  by (simp add: Proj_ortho_compl cblinfun_compose_minus_left cblinfun_compose_minus_right compatible_invariants_def)

(* TODO move *)
(* Simp rule: complementing the right invariant does not affect
   compatible_invariants. Same proof as the left-hand variant.
   NOTE(review): the connective is not legible in this copy; presumably an
   equivalence. *)
lemma compatible_invariants_uminus_right[simp]: compatible_invariants I (-J)  compatible_invariants I J
  by (simp add: Proj_ortho_compl cblinfun_compose_minus_left cblinfun_compose_minus_right compatible_invariants_def)

(* Inclusion-exclusion formula for projectors of compatible invariants:
   the projector of the combination of A and B equals
   Proj A + Proj B - Proj A oCL Proj B.
   NOTE(review): the lattice operator is not legible in this copy; by the
   lemma name it is presumably the supremum. The proof rewrites it as the
   complement of a combination of complements and then uses
   compatible_invariants_inter in the reverse direction. *)
lemma compatible_invariants_sup: Proj (A  B) = Proj A + Proj B - Proj A oCL Proj B if compatible_invariants A B
  apply (rewrite at A  B to - (-A  -B) DEADID.rel_mono_strong)
   apply simp
  apply (subst Proj_ortho_compl)
  by (simp add: that Proj_ortho_compl cblinfun_compose_minus_left cblinfun_compose_minus_right flip: compatible_invariants_inter )

(* Symmetry of compatible_invariants in its two arguments; immediate from
   compatible_invariants_def.
   NOTE(review): the connective is not legible in this copy. *)
lemma compatible_invariants_sym: compatible_invariants S T  compatible_invariants T S
  by (metis compatible_invariants_def)

(* [iff] rule: every invariant is compatible with itself. *)
lemma compatible_invariants_refl[iff]: compatible_invariants S S
  by (metis compatible_invariants_def)

(* Introduction rule: if S, T, U are pairwise compatible, then S is compatible
   with the combination of T and U.
   NOTE(review): the lattice operator is not legible in this copy; by the
   lemma name and the use of compatible_invariants_inter it is presumably the
   infimum (intersection). *)
lemma compatible_invariants_infI:
  assumes [iff]: compatible_invariants S U
  assumes [iff]: compatible_invariants S T
  assumes [iff]: compatible_invariants T U
  shows compatible_invariants S (T  U)
  by (smt (verit, del_insts) assms(1,2,3) cblinfun_compose_assoc compatible_invariants_def compatible_invariants_inter)


(* Introduction rule dual to compatible_invariants_infI: if S, T, U are
   pairwise compatible, then S is compatible with the combination of T and U.
   NOTE(review): the lattice operator is not legible in this copy; by the
   lemma name it is presumably the supremum. The proof rewrites it as a
   complement of a combination of complements and then applies
   compatible_invariants_infI. *)
lemma compatible_invariants_supI:
  assumes [iff]: compatible_invariants S U
  assumes [iff]: compatible_invariants S T
  assumes [iff]: compatible_invariants T U
  shows compatible_invariants S (T  U)
  apply (rewrite at T  U to - (-T  -U) DEADID.rel_mono_strong)
   apply simp
  by (auto intro!: compatible_invariants_infI simp del: compl_inf)

(* Distributivity law for pairwise compatible invariants S, T, U.
   NOTE(review): the lattice operators are not legible in this copy; by the
   lemma name this is presumably "inf distributes over sup" with S on the
   left.
   Proof idea: show that the projectors of both sides coincide (using
   compatible_invariants_sup and compatible_invariants_inter), then conclude
   by injectivity of Proj. *)
lemma compatible_invariants_inf_sup_distrib1:
  fixes S T U :: 'a::chilbert_space ccsubspace
  assumes compatible_invariants S U
  assumes compatible_invariants S T
  assumes compatible_invariants T U
  shows S  (T  U) = (S  T)  (S  U)
proof -
  have [iff]: compatible_invariants (S  T) (S  U)
    using assms by (auto intro!: compatible_invariants_infI simp: compatible_invariants_sym)
  have Proj (S  (T  U)) = Proj ((S  T)  (S  U))
    apply (simp add: assms compatible_invariants_sup compatible_invariants_supI flip: compatible_invariants_inter)
    by (metis (no_types, lifting) Proj_idempotent assms(2) cblinfun_compose_add_right cblinfun_compose_assoc cblinfun_compose_minus_right
        compatible_invariants_def)
  then show ?thesis
    using Proj_inj by blast
qed

(* Variant of compatible_invariants_inf_sup_distrib1 with S on the right;
   obtained from it by commutativity (inf_commute).
   NOTE(review): the lattice operators are not legible in this copy. *)
lemma compatible_invariants_inf_sup_distrib2:
  fixes S T U :: 'a::chilbert_space ccsubspace
  assumes [iff]: compatible_invariants S U
  assumes [iff]: compatible_invariants S T
  assumes [iff]: compatible_invariants T U
  shows (T  U)  S = (T  S)  (U  S)
  by (simp add: compatible_invariants_inf_sup_distrib1 inf_commute)

(* Dual distributivity law for pairwise compatible invariants S, T, U.
   NOTE(review): the lattice operators are not legible in this copy; by the
   lemma name this is presumably "sup distributes over inf" with S on the
   left. Derived from compatible_invariants_inf_sup_distrib1. *)
lemma compatible_invariants_sup_inf_distrib1:
  fixes S T U :: 'a::chilbert_space ccsubspace
  assumes compatible_invariants S U
  assumes compatible_invariants S T
  assumes compatible_invariants T U
  shows S  (T  U) = (S  T)  (S  U)
  by (smt (verit, ccfv_SIG) Groups.add_ac(1) assms(1,2,3) compatible_invariants_def compatible_invariants_inf_sup_distrib1
      compatible_invariants_supI inf_commute inf_sup_absorb plus_ccsubspace_def)

(* Variant of compatible_invariants_sup_inf_distrib1 with S on the right;
   obtained from it by commutativity.
   NOTE(review): the lattice operators are not legible in this copy. *)
lemma compatible_invariants_sup_inf_distrib2:
  fixes S T U :: 'a::chilbert_space ccsubspace
  assumes compatible_invariants S U
  assumes compatible_invariants S T
  assumes compatible_invariants T U
  shows (T  U)  S = (T  S)  (U  S)
  by (metis Groups.add_ac(2) assms(1,2,3) compatible_invariants_sup_inf_distrib1 plus_ccsubspace_def)

(* TODO move *)
(* Projections of the same vector ψ onto two orthogonal subspaces S and T are
   orthogonal vectors. *)
lemma is_orthogonal_Proj_orthogonal_spaces:
  assumes orthogonal_spaces S T
  shows is_orthogonal (Proj S *V ψ) (Proj T *V ψ)
  by (metis Proj_range assms cblinfun_apply_in_image orthogonal_spaces_def)

(* Relates the distance from the combination of two compatible invariants I
   and J to sqrt ((dist_inv Q I ψ)^2 + (dist_inv Q J ψ)^2).
   NOTE(review): the lattice operator and the comparison symbol are not
   legible in this copy; by the lemma name and the proof this is presumably an
   upper bound on the distance from the intersection.
   Proof idea: decompose the complement of the combined invariant into three
   mutually orthogonal pieces (in I but not J, in J but not I, in neither),
   apply the Pythagorean theorem, and compare with the analogous two-piece
   decompositions of the complements of I and of J. *)
lemma dist_inv_intersect:
  assumes [register]: register Q
  assumes [iff]: compatible_invariants I J
  shows dist_inv Q (I  J) ψ  sqrt ((dist_inv Q I ψ)2 + (dist_inv Q J ψ)2)
proof -
  (* Lifted projectors onto the pieces of the decomposition. *)
  define PInJ PJnI PnInJ PnI PnJ PnIJ where PInJ = Q (Proj (I  - J))
    and PJnI = Q (Proj (-I  J)) and PnInJ = Q (Proj (-I  -J))
    and PnI = Q (Proj (-I)) and PnJ = Q (Proj (-J))
    and PnIJ = Q (Proj (- (I  J)))

  have compat1: compatible_invariants (I  - J) J
    by (metis Proj_o_Proj_subspace_left Proj_o_Proj_subspace_right compatible_invariants_def compatible_invariants_uminus_right inf_le2)
  have compat2: compatible_invariants (I  - J) I
    by (simp add: Proj_o_Proj_subspace_left Proj_o_Proj_subspace_right compatible_invariants_def)

  (* Pairwise orthogonality of the pieces, first as subspaces (ortho1-4),
     then for the projected vectors (ortho5-7). *)
  have ortho1: orthogonal_spaces (I  - J) (- I  J)
    by (simp add: le_infI2 orthogonal_spaces_leq_compl)
  have ortho2: orthogonal_spaces (I  - J  - I  J) (- I  - J)
    by (metis inf_le1 inf_le2 ortho_involution orthocomplemented_lattice_class.compl_sup orthogonal_spaces_leq_compl sup.mono)
  have ortho3: orthogonal_spaces (- I  J) (- I  - J)
    by (simp add: le_infI2 orthogonal_spaces_leq_compl) 
  have ortho4: orthogonal_spaces (I  - J) (- I  - J)
    by (metis inf_sup_absorb le_infI2 ortho2 orthogonal_spaces_leq_compl)
  have ortho5: is_orthogonal (PInJ ψ) (PnInJ ψ)
    using ortho4 by (auto intro!: is_orthogonal_Proj_orthogonal_spaces simp: PInJ_def PnInJ_def simp flip: Proj_lift_invariant)
  have ortho6: is_orthogonal (PJnI ψ) (PnInJ ψ)
    using ortho3 by (auto intro!: is_orthogonal_Proj_orthogonal_spaces simp: PJnI_def PnInJ_def simp flip: Proj_lift_invariant)
  have ortho7: is_orthogonal (PInJ ψ) (PJnI ψ)
    using ortho1 by (auto intro!: is_orthogonal_Proj_orthogonal_spaces simp: PJnI_def PInJ_def simp flip: Proj_lift_invariant)

  (* Decomposition of the complement of I into two orthogonal pieces. *)
  have nI: -I  J  -I  -J = -I
    by (simp flip: compatible_invariants_inf_sup_distrib1)
  then have PnI_decomp: PnI = PJnI + PnInJ
    by (simp add: PnI_def PJnI_def PnInJ_def register_inj' ortho3
        flip: register_plus Proj_sup)

  (* Decomposition of the complement of J into two orthogonal pieces. *)
  have nJ: I  -J  -I  -J = -J
    by (metis (no_types, lifting) assms(2) compatible_invariants_inf_sup_distrib1 compatible_invariants_refl compatible_invariants_sym
        compatible_invariants_uminus_left complemented_lattice_class.sup_compl_top inf_aci(1) inf_top.comm_neutral)
  then have PnJ_decomp: PnJ = PInJ + PnInJ
    by (simp add: PnJ_def PInJ_def PnInJ_def register_inj' ortho4
        flip: register_plus Proj_sup)

  (* Decomposition of the complement of the combined invariant into three
     orthogonal pieces. *)
  have I  - J  - I  J  - I  - J = - I  - J
    by (metis (no_types, lifting) Groups.add_ac(1) boolean_algebra_cancel.sup2 nI nJ plus_ccsubspace_def sup_inf_absorb)
  then have PnIJ_decomp: PnIJ = PInJ + PJnI + PnInJ
    by (simp add: PnIJ_def PInJ_def PJnI_def PnInJ_def register_inj' ortho1 ortho2
        flip: register_plus Proj_sup)

  (* Pythagoras on the three-piece decomposition; the "neither" piece is then
     counted twice to match the two-piece decompositions. *)
  have (dist_inv Q (I  J) ψ)2 = (norm (PnIJ ψ))2
    by (simp add: PnIJ_def dist_inv_def)
  also have  = (norm (PInJ *V ψ))2 + (norm (PJnI *V ψ))2 + (norm (PnInJ *V ψ))2
    by (simp add: PnIJ_decomp cblinfun.add_left pythagorean_theorem cinner_add_left ortho5 ortho6 ortho7)
  also have   ((norm (PJnI *V ψ))2 + (norm (PnInJ *V ψ))2) + ((norm (PInJ *V ψ))2 + (norm (PnInJ *V ψ))2)
    by simp
  also have  = (norm (PnI ψ))2 + (norm (PnJ ψ))2
    by (simp add: ortho5 ortho6 PnI_decomp PnJ_decomp cblinfun.add_left pythagorean_theorem)
  also have   (dist_inv Q I ψ)2 + (dist_inv Q J ψ)2
    apply (rule add_mono)
    using assms
    by (simp_all add: PnI_def PnJ_def dist_inv_def)

  finally show ?thesis
    using real_le_rsqrt by presburger
qed


subsection ‹Preservation of invariants›


(* TODO move stuff from above *)

(* For a register Q, preservation of the lifted invariants by the lifted
   operation Q U is related to preservation of I and J by U itself, with the
   same error ε. Proved on the operator-norm formulation (preserves_onorm)
   using register_mult, register_norm and register_minus.
   NOTE(review): the connective is not legible in this copy; presumably an
   equivalence. *)
lemma preserves_lift_invariant:
  assumes [register]: register Q
  shows preserves (Q U) (lift_invariant Q I) (lift_invariant Q J) ε  preserves U I J ε
  using register_minus[OF assms, of id_cblinfun, symmetric]
  by (simp add: preserves_onorm Proj_lift_invariant register_mult register_norm)


(* Quantitative consequence of "preserves": if U preserves the lifted
   invariant (S, J) into the lifted invariant (R, I) with error γ, then the
   distance of U *V ψ from I (through R) is related to
   norm U * dist_inv S J ψ + γ * norm ψ.
   NOTE(review): the comparison symbol is not legible in this copy; by the
   lemma name and the proof it is presumably an upper bound.
   Proof idea: split ψ into the part inside the lifted J (ψgood) and the rest
   (ψbad); U ψgood is handled by the preservation assumption, and U ψbad by
   the operator norm of U. *)
lemma dist_inv_leq_if_preserves:
  assumes pres: preserves U (lift_invariant S J) (lift_invariant R I) γ
  assumes [register]: register S register R
  shows dist_inv R I (U *V ψ)  norm U * dist_inv S J ψ + γ * norm ψ
proof -
  note [[simproc del: Laws_Quantum.compatibility_warn]]
  (* Components of ψ inside / outside the lifted invariant J, and components
     of U ψgood inside / outside the lifted invariant I. *)
  define ψgood ψbad where ψgood = S (Proj J) *V ψ and ψbad = S (Proj (- J)) *V ψ
  define ψ' ψ'good ψ'bad where ψ' = U ψgood and ψ'good = R (Proj I) ψ' and ψ'bad = R (Proj (-I)) ψ'
  from pres have γ  0
    using preserves_def by blast
  have ψ_decomp: ψ = ψgood + ψbad
    by (simp add: ψgood_def ψbad_def Proj_ortho_compl register_minus flip: cblinfun.add_left)
  have ψ'_decomp: ψ' = ψ'good + ψ'bad
    by (simp add: ψ'good_def ψ'bad_def Proj_ortho_compl register_minus flip: cblinfun.add_left)
  (* By definition, the bad part of ψ has norm dist_inv S J ψ. *)
  define δ where δ = dist_inv S J ψ
  then have ψbad_bound: norm ψbad  δ
    unfolding dist_inv_def ψbad_def by blast
  (* Apply the preservation assumption to ψgood. *)
  have ψgood  space_as_set (lift_invariant S J)
    by (simp add: ψgood_def lift_invariant_def)
  with pres have norm (ψ' - Proj (lift_invariant R I) *V ψ')  γ * norm ψgood
    by (simp add: preserves_def ψ'_def)
  then have norm ψ'bad  γ * norm ψgood
    by (simp add: ψ'bad_def Proj_ortho_compl register_minus cblinfun.diff_left Proj_lift_invariant)
  also have γ * norm ψgood  γ * norm ψ
    by (auto intro!: mult_left_mono is_Proj_reduces_norm γ  0 intro: register_projector
        simp add: ψgood_def)
  finally have ψ'bad_bound: norm ψ'bad  γ * norm ψ
    by meson
  (* U ψ splits into three summands; R (Proj (- I)) kills the first one. *)
  have Uψ_decomp: U ψ = ψ'good + ψ'bad + U ψbad
    by (simp add: ψ_decomp ψ'_decomp cblinfun.add_right flip: ψ'_def)
  have mIψ'good0: R (Proj (- I)) ψ'good = 0
    by (metis Proj_fixes_image Proj_lift_invariant ψ'_decomp ψ'_def ψ'bad_def add_diff_cancel_right' assms
        cancel_comm_monoid_add_class.diff_cancel cblinfun.diff_right cblinfun_apply_in_image lift_invariant_def)
  have mIψ'bad: norm (R (Proj (- I)) ψ'bad)  γ * norm ψ
    by (metis ψ'bad_bound ψ'_decomp ψ'bad_def add_diff_cancel_left' cblinfun.diff_right diff_zero
        mIψ'good0)
  (* The contribution of ψbad is controlled by the operator norm of U. *)
  from ψbad_bound 
  have norm (U ψbad)  norm U * δ
    apply (rule_tac order_trans[OF norm_cblinfun[of U ψbad]])
    by (simp add: mult_left_mono)
  then have norm (R (Proj (- I)) *V U ψbad)  norm U * δ
    apply (rule_tac order_trans[OF norm_cblinfun])
    apply (subgoal_tac norm (R (Proj (- I)))  1)
     apply (smt (verit, best) mult_left_le_one_le norm_ge_zero) 
    by (simp add: norm_Proj_leq1 register_norm)
  (* Combine the two remaining summands with the triangle inequality. *)
  with mIψ'bad have dist_inv R I (U *V ψ)  norm U * δ + γ * norm ψ
    apply (simp add: dist_inv_def Uψ_decomp cblinfun.add_right mIψ'good0)
    by (smt (verit, del_insts) norm_triangle_ineq)
  then show ?thesis
    by (simp add: δ_def)
qed

(* Convenience form of dist_inv_leq_if_preserves with numeric side conditions
   on dist_inv S J ψ (versus ε), norm U and norm ψ (versus 1), and γ + ε
   (versus δ), concluding a relation between dist_inv R I (U *V ψ) and δ.
   NOTE(review): the comparison symbols are not legible in this copy;
   presumably all are upper bounds, so that the conclusion bounds the new
   distance by δ. *)
lemma dist_inv_preservesI:
  assumes dist_inv S J ψ  ε
  assumes pres: preserves U (lift_invariant S J) (lift_invariant R I) γ
  assumes norm U  1
  assumes norm ψ  1
  assumes γ + ε  δ
  assumes [register]: register S register R
  shows dist_inv R I (U *V ψ)  δ
proof -
  have γ  0
    using pres preserves_def by blast
  (* Arithmetic: the bound supplied by dist_inv_leq_if_preserves is compared
     with δ using the side conditions. *)
  with assms have norm U * dist_inv S J ψ + γ * norm ψ  δ
    by (smt (verit, ccfv_SIG) dist_inv_def mult_left_le mult_left_le_one_le norm_ge_zero)
  then show ?thesis
    apply (rule order_trans[rotated])
    by (rule dist_inv_leq_if_preserves[OF pres register S register R])
qed


(* Applying an operator U through a register R that is compatible with Q
   relates the new distance from I (through Q) to norm U times the old one.
   NOTE(review): the comparison symbol is not legible in this copy; presumably
   an upper bound.
   Proof idea: R U preserves the lifted invariant with error 0
   (preserves_compatible), so dist_inv_leq_if_preserves applies with γ = 0;
   finally norm (R U) is rewritten to norm U via register_norm. *)
lemma dist_inv_apply_compatible:
  assumes compatible Q R
  shows dist_inv Q I (R U *V ψ)  norm U * dist_inv Q I ψ
proof -
  have [register]: register Q
    using assms compatible_register1 by blast
  have [register]: register R
    using assms compatible_register2 by blast
  have preserves (R U) (lift_invariant Q I) (lift_invariant Q I) 0
    apply (rule preserves_compatible[of R])
    by (simp_all add: assms compatible_register_invariant_compatible_register compatible_sym)
  then have dist_inv Q I (R U *V ψ)  norm (R U) * dist_inv Q I ψ + 0 * norm ψ
    apply (rule dist_inv_leq_if_preserves)
    by simp_all
  also have   norm U * dist_inv Q I ψ
    by (simp add: register_norm)
  finally show ?thesis
    by -
qed



(* Averaged analogue of dist_inv_apply_compatible: if every register R h is
   compatible with Q, then applying R h (U h) to each component ψ h relates
   the new averaged distance to (MAX h. norm (U h)) times the old one.
   NOTE(review): the comparison symbol is not legible in this copy; presumably
   an upper bound.
   Proof idea: estimate each summand with dist_inv_apply_compatible, replace
   each norm (U h) by the maximum, and pull the constant factor out of the sum
   and the square root. *)
lemma dist_inv_avg_apply_compatible:
  assumes h. compatible Q (R h)
  shows dist_inv_avg Q I (λh. R h (U h) *V ψ h)  (MAX h. norm (U h)) * dist_inv_avg Q I ψ
proof -
  have [iff]: (MAX hUNIV. norm (U h))  0
    by (simp add: Max_ge_iff)
  have dist_inv_avg Q I (λh. R h (U h) *V ψ h)
   = sqrt ((hUNIV. (dist_inv Q (I h) (R h (U h) *V ψ h))2) / real CARD('a))
    by (simp add: dist_inv_avg_def)
  also have   sqrt ((hUNIV. (norm (U h) * dist_inv Q (I h) (ψ h))2) / real CARD('a))
    by (auto intro!: divide_right_mono sum_mono dist_inv_apply_compatible assms)
  also have   sqrt ((hUNIV. ((MAX h. norm (U h)) * dist_inv Q (I h) (ψ h))2) / real CARD('a))
    by (auto intro!: divide_right_mono power_mono sum_mono mult_right_mono)
  also have  = (MAX h. norm (U h)) * sqrt ((hUNIV. (dist_inv Q (I h) (ψ h))2) / real CARD('a))
    by (simp add: power_mult_distrib real_sqrt_mult real_sqrt_abs abs_of_nonneg flip: sum_distrib_left times_divide_eq_right)
  also have  = (MAX h. norm (U h)) * dist_inv_avg Q I ψ
    by (simp add: dist_inv_avg_def)
  finally show ?thesis
    by -
qed



end