(* Theory Applicative_Probability_List *)

(* Author: Andreas Lochbihler, ETH Zurich *)

subsection ‹Probability mass functions implemented as lists with duplicates›

theory Applicative_Probability_List imports
  Applicative_List
  Complex_Main
begin

(* Summing a flattened list of lists gives the same result as first summing
   each inner list and then summing those partial sums. *)
lemma sum_list_concat_map: "sum_list (concat (map f xs)) = sum_list (map (λx. sum_list (f x)) xs)"
proof (induction xs)
  case Nil
  (* Both sides reduce to the sum of the empty list. *)
  show ?case by simp
next
  case (Cons y ys)
  (* The induction hypothesis for ys is chained in and used by simp. *)
  then show ?case by simp
qed

context includes applicative_syntax begin

(* Element set of an applicative list application: every function in f is
   applied to every argument in x, i.e. the image of (uncurried) function
   application over the Cartesian product of the two element sets.
   Declared [simp] so that later proofs rewrite with it automatically. *)
lemma set_ap_list [simp]: "set (f ⋄ x) = (λ(f, x). f x) ` (set f × set x)"
by(auto simp add: ap_list_def List.bind_def)

text ‹We call the implementation type ‹pfp› because it is the basis for the Haskell library
  Probability by Martin Erwig and Steve Kollmansberger (Probabilistic Functional Programming).›

(* The type 'a pfp consists of lists of (value, weight) pairs such that
   every weight is strictly positive and the weights sum to 1.
   Nothing in the invariant forbids the same value from occurring several times. *)
typedef 'a pfp = "{xs :: ('a × real) list. (∀(_, p) ∈ set xs. p > 0) ∧ sum_list (map snd xs) = 1}"
proof
  (* Non-emptiness witness: a single value carrying the whole weight 1. *)
  show "[(x, 1)] ∈ ?pfp" for x by simp
qed

(* Register the type definition with the Lifting package so that the
   lift_definition and transfer commands below can be used for pfp. *)
setup_lifting type_definition_pfp

(* pure for pfp: the singleton list holding x with weight 1.
   The closing simp discharges the invariant (positive weights, sum 1). *)
lift_definition pure_pfp :: "'a ⇒ 'a pfp" is "λx. [(x, 1)]" by simp

(* Applicative application for pfp: every weighted function (f, p) is combined
   with every weighted argument (x, q) into (f x, p * q), using the list
   applicative on the underlying representation.
   The proof obligation is that the result again satisfies the pfp invariant. *)
lift_definition ap_pfp :: "('a ⇒ 'b) pfp ⇒ 'a pfp ⇒ 'b pfp"
is "λfs xs. [λ(f, p) (x, q). (f x, p * q)] ⋄ fs ⋄ xs"
proof safe
  fix xs :: "(('a ⇒ 'b) × real) list" and ys :: "('a × real) list"
  (* Both arguments satisfy the invariant: positive weights that sum to 1. *)
  assume xs: "∀(x, y) ∈ set xs. 0 < y" "sum_list (map snd xs) = 1"
    and ys: "∀(x, y) ∈ set ys. 0 < y" "sum_list (map snd ys) = 1"
  let ?ap = "[λ(f, p) (x, q). (f x, p * q)] ⋄ xs ⋄ ys"
  (* Positivity: each resulting weight is a product of two positive weights. *)
  show "0 < b" if "(a, b) ∈ set ?ap" for a b using that xs ys
    by(auto intro!: mult_pos_pos)
  (* Normalisation: the resulting weights still sum to 1. *)
  show "sum_list (map snd ?ap) = 1" using xs ys
    by(simp add: ap_list_def List.bind_def map_concat o_def split_beta sum_list_concat_map sum_list_const_mult)
qed

(* Make ap_pfp a variant of the overloaded constant Applicative.ap, so that the
   applicative application syntax can be used on pfp values below. *)
adhoc_overloading Applicative.ap ap_pfp

(* Register pfp as an applicative functor with pure_pfp and ap_pfp.
   The four goals are the applicative laws; each is proved by transferring
   the statement to the underlying weighted lists. *)
applicative pfp
 for pure: pure_pfp
     ap: ap_pfp
proof -
  (* Identity law *)
  show "pure_pfp (λx. x) ⋄ x = x" for x :: "'a pfp"
    by transfer(simp add: ap_list_def List.bind_def)
  (* Homomorphism law *)
  show "pure_pfp f ⋄ pure_pfp x = pure_pfp (f x)" for f :: "'a ⇒ 'b" and x
    by transfer (applicative_lifting; simp)
  (* Composition law *)
  show "pure_pfp (λg f x. g (f x)) ⋄ g ⋄ f ⋄ x = g ⋄ (f ⋄ x)"
    for g :: "('b ⇒ 'c) pfp" and f :: "('a ⇒ 'b) pfp" and x
    by transfer(applicative_lifting; clarsimp)
  (* Interchange law *)
  show "f ⋄ pure_pfp x = pure_pfp (λf. f x) ⋄ f" for f :: "('a ⇒ 'b) pfp" and x
    by transfer(applicative_lifting; clarsimp)
qed

end

end