(*
  File:    Fisher_Yates.thy
  Author:  Manuel Eberl, TU München

  Definition and correctness proofs for two variants of the Fisher-Yates shuffle,
  an algorithm to shuffle an array in-place in linear time, i.e. produce a
  permutation uniformly at random.
*)
section ‹Fisher--Yates shuffle›
theory Fisher_Yates
  imports "HOL-Probability.Probability"
begin

(* TODO Move *)

(* Expectation of f under a uniform draw from the (non-empty) multiset A:
   each distinct element is weighted by its multiplicity, normalised by size A. *)
lemma integral_pmf_of_multiset:
  "A ≠ {#} ⟹ (∫x. (f x :: real) ∂measure_pmf (pmf_of_multiset A)) =
     (∑x∈set_mset A. of_nat (count A x) * f x) / of_nat (size A)"
  by (subst integral_measure_pmf[where A = "set_mset A"])
     (simp_all add: sum_divide_distrib mult_ac)

(* Point probability of a bind whose first stage is a uniform multiset draw. *)
lemma pmf_bind_pmf_of_multiset:
  "A ≠ {#} ⟹ pmf (pmf_of_multiset A ⤜ f) y =
     (∑x∈set_mset A. real (count A x) * pmf (f x) y) / real (size A)"
  by (simp add: pmf_bind integral_pmf_of_multiset)

(* Point probability of an image distribution map_pmf f p, given a left inverse f' of f.
   NOTE(review): the injectivity assumption already follows from the left-inverse
   assumption; it is kept so that existing uses of this lemma keep working. *)
lemma pmf_map_inj_inv:
  assumes "inj_on f (set_pmf p)"
  assumes "⋀x. f' (f x) = x"
  shows   "pmf (map_pmf f p) x = (if x ∈ range f then pmf p (f' x) else 0)"
proof (cases "x ∈ f ` set_pmf p")
  case True
  from this obtain y where y: "y ∈ set_pmf p" "x = f y" by blast
  with assms(1) have "pmf (map_pmf f p) x = pmf p y" by (simp add: pmf_map_inj)
  also from y assms(2)[of y] have "y = f' x" by simp
  finally show ?thesis using y by simp
next
  case False
  hence "x ∉ set_pmf (map_pmf f p)" by simp
  hence "pmf (map_pmf f p) x = 0" by (simp add: set_pmf_eq)
  also from False have "0 = (if x ∈ range f then pmf p (f' x) else 0)"
    by (auto simp: assms(2) set_pmf_eq)
  finally show ?thesis .
qed

(* END MOVE *)

subsection ‹Swapping elements in a list›

(* Exchange the entries at positions i and j via two list updates.
   The lemmas below assume both indices are < length xs. *)
definition swap where
  "swap xs i j = xs[i := xs!j, j := xs ! i]"

lemma length_swap [simp]: "length (swap xs i j) = length xs"
  by (simp add: swap_def)

lemma swap_eq_Nil_iff [simp]: "swap xs i j = [] ⟷ xs = []"
  by (simp add: swap_def)

lemma nth_swap:
  "i < length xs ⟹ j < length xs ⟹
     swap xs i j ! k = (if k = i then xs ! j else if k = j then xs ! i else xs ! k)"
  by (auto simp: swap_def nth_list_update)

lemma map_swap:
  "i < length xs ⟹ j < length xs ⟹ map f (swap xs i j) = swap (map f xs) i j"
  by (simp add: swap_def map_update map_nth)

lemma swap_swap:
  "i < length xs ⟹ j < length xs ⟹ swap (swap xs i j) j i = xs"
  by (intro nth_equalityI) (auto simp: nth_swap nth_list_update)

(* Swapping only reorders the list, so the underlying multiset is unchanged. *)
lemma mset_swap:
  "i < length xs ⟹ j < length xs ⟹ mset (swap xs i j) = mset xs"
  by (simp add: mset_update swap_def nth_list_update)

lemma hd_swap_0:
  "i < length xs ⟹ hd (swap xs 0 i) = xs ! i"
  unfolding swap_def by (subst hd_conv_nth) (subst nth_list_update | force)+


subsection ‹Random Permutations›

text ‹
  First, we prove the intuitively obvious fact that choosing a random permutation of a
  multiset can be done by first randomly choosing the first element and then randomly
  choosing the rest of the list.
›
lemma pmf_of_set_permutations_of_multiset_nonempty:
  assumes "(A :: 'a multiset) ≠ {#}"
  shows   "pmf_of_set (permutations_of_multiset A) =
             do {x ← pmf_of_multiset A;
                 xs ← pmf_of_set (permutations_of_multiset (A - {#x#}));
                 return_pmf (x#xs)
             }" (is "?lhs = ?rhs")
proof (rule pmf_eqI)
  fix xs :: "'a list"
  show "pmf ?lhs xs = pmf ?rhs xs"
  proof (cases "xs ∈ permutations_of_multiset A")
    case False
    (* xs is outside the support of both sides, so both probabilities are 0. *)
    with assms have "xs ∉ set_pmf ?lhs" by simp
    moreover from assms False have "xs ∉ set_pmf ?rhs"
      by (auto simp: permutations_of_multiset_Cons_iff)
    ultimately show ?thesis by (simp add: set_pmf_eq)
  next
    case True
    with assms have nonempty: "xs ≠ []" by (auto dest: permutations_of_multisetD)
    hence range_Cons: "xs ∈ range ((#) x) ⟷ hd xs = x" for x by (cases xs) auto
    from True nonempty
      have hd_tl: "hd xs ∈# A ∧ tl xs ∈ permutations_of_multiset (A - {#hd xs#})"
      by (cases xs) (auto simp: permutations_of_multiset_Cons_iff)
    from assms have "pmf ?rhs xs = (∑x∈set_mset A.
        real (count A x) *
          pmf (map_pmf ((#) x) (pmf_of_set (permutations_of_multiset (A - {#x#})))) xs) /
        real (size A)" (is "_ = ?S / _")
      unfolding map_pmf_def [symmetric] by (simp add: pmf_bind_pmf_of_multiset)
    (* Only the summand for x = hd xs can contribute. *)
    also have "?S = (∑x∈set_mset A. if x = hd xs then
                       real (count A (hd xs)) /
                         real (card (permutations_of_multiset (A - {#hd xs#})))
                     else 0)"
      using range_Cons hd_tl
      by (intro sum.cong refl, subst pmf_map_inj_inv[where f' = tl]) auto
    also have "… = real (count A (hd xs)) /
                     real (card (permutations_of_multiset (A - {#hd xs#})))"
      using hd_tl by (simp add: sum.delta)
    also from hd_tl have "… = real (size A) / real (card (permutations_of_multiset A))"
      by (simp add: divide_simps real_card_permutations_of_multiset_remove[of "hd xs"])
    also have " … / real (size A) = pmf (pmf_of_set (permutations_of_multiset A)) xs"
      using assms True by simp
    finally show ?thesis ..
  qed
qed


subsection ‹Shuffling Lists›

text ‹
  We define shuffling of a list as choosing from the set of all lists that correspond to
  the same multiset uniformly at random.
›
definition shuffle :: "'a list ⇒ 'a list pmf" where
  "shuffle xs = pmf_of_set (permutations_of_multiset (mset xs))"

lemma shuffle_empty [simp]: "shuffle [] = return_pmf []"
  by (simp add: shuffle_def pmf_of_set_singleton)

lemma shuffle_singleton [simp]: "shuffle [x] = return_pmf [x]"
  by (simp add: shuffle_def pmf_of_set_singleton)

text ‹
  The crucial ingredient of the Fisher--Yates shuffle is the following lemma, which
  decomposes a shuffle into two steps. First, swap the first element of the list with an
  element chosen uniformly at random from the whole list, possibly the first element
  itself. Then shuffle the new remaining list. With a random-access implementation of a
  list -- such as an array -- all of the required operations are cheap and the resulting
  algorithm runs in linear time.
›
lemma shuffle_fisher_yates_step:
  assumes xs_nonempty [simp]: "xs ≠ []"
  shows   "shuffle xs = do {i ← pmf_of_set {..<length xs}; let ys = swap xs 0 i;
                            zs ← shuffle (tl ys); return_pmf (hd ys # zs) }"
proof -
  have "shuffle xs =
          do {x ← pmf_of_multiset (mset xs);
              xs ← pmf_of_set (permutations_of_multiset (mset xs - {#x#}));
              return_pmf (x#xs) }"
    unfolding shuffle_def by (simp add: pmf_of_set_permutations_of_multiset_nonempty)
  (* Drawing an element from mset xs is the same as drawing a uniform index into xs. *)
  also have "pmf_of_multiset (mset xs) =
               pmf_of_multiset (image_mset ((!) xs) (mset (upt 0 (length xs))))"
    by (subst mset_map [symmetric]) (simp add: map_nth)
  also have "… = map_pmf ((!) xs) (pmf_of_set {..<length xs})"
    by (subst map_pmf_of_set)
       (auto simp add: map_pmf_of_set atLeast0LessThan lessThan_empty_iff)
  also have "do {x ← map_pmf ((!) xs) (pmf_of_set {..<length xs});
                 ys ← pmf_of_set (permutations_of_multiset (mset xs - {#x#}));
                 return_pmf (x # ys) } =
             do {i ← pmf_of_set {..<length xs};
                 ys ← pmf_of_set (permutations_of_multiset (mset xs - {#xs ! i#}));
                 return_pmf (xs ! i # ys) }"
    by (simp add: map_pmf_def bind_assoc_pmf bind_return_pmf)
  also have "… = do {i ← pmf_of_set {..<length xs};
                     let ys = swap xs 0 i;
                     zs ← shuffle (tl (swap xs 0 i));
                     return_pmf (hd ys # zs) }"
    unfolding Let_def shuffle_def
    by (intro bind_pmf_cong refl, subst (asm) set_pmf_of_set)
       (auto simp: lessThan_empty_iff mset_tl mset_swap hd_swap_0)
  finally show ?thesis by (simp add: Let_def)
qed


subsection ‹Forward Fisher--Yates Shuffle›

text ‹
  The actual Fisher--Yates shuffle is now merely a kind of tail-recursive version of the
  decomposition described above. Note that unlike the traditional Fisher--Yates shuffle,
  we shuffle the list from front to back, which is the more natural way to do it when
  working with linked lists.
›
(* fisher_yates_aux i xs leaves positions < i untouched. When at most one position
   remains (i + 1 ≥ length xs), there is nothing left to shuffle. Otherwise it swaps
   position i with a uniformly chosen j ∈ {i..<length xs} (i itself included) and
   continues at i + 1. *)
function fisher_yates_aux where
  "fisher_yates_aux i xs =
     (if i + 1 ≥ length xs then return_pmf xs
      else do {j ← pmf_of_set {i..<length xs};
               fisher_yates_aux (i + 1) (swap xs i j)})"
  by auto
(* swap preserves the length, so length xs - i strictly decreases on each call. *)
termination by (relation "Wellfounded.measure (λ(i,xs). length xs - i)") simp_all

declare fisher_yates_aux.simps [simp del]

lemma fisher_yates_aux_correct:
  "fisher_yates_aux i xs = map_pmf (λys. take i xs @ ys) (shuffle (drop i xs))"
proof (induction i xs rule: fisher_yates_aux.induct)
  case (1 i xs)
  show ?case
  proof (cases "i + 1 ≥ length xs")
    case True
    show ?thesis
    proof (cases "i ≥ length xs")
      case False
      (* Exactly one element remains: drop i xs is the singleton [last xs]. *)
      with True have "length xs = Suc i" and i: "i = length xs - 1" by simp_all
      hence "xs ≠ []" by auto
      hence "xs = butlast xs @ [last xs]" by (rule append_butlast_last_id [symmetric])
      also have "butlast xs = take i xs" by (simp add: butlast_conv_take i)
      finally have eq: "take i xs @ [last xs] = xs" ..
      moreover have "xs = take i xs @ drop i xs" by simp
      ultimately have "take i xs @ [last xs] = take i xs @ drop i xs" by (rule trans)
      hence "drop i xs = [last xs]" by (subst (asm) same_append_eq) simp_all
      with True show ?thesis by (simp add: eq fisher_yates_aux.simps)
    qed (simp_all add: fisher_yates_aux.simps)
  next
    case False
    from False have xs_nonempty [simp]: "xs ≠ []" by auto
    have "fisher_yates_aux i xs =
            pmf_of_set {i..<length xs} ⤜ (λj. fisher_yates_aux (i+1) (swap xs i j))"
      using False by (subst fisher_yates_aux.simps) simp
    (* Re-index j ∈ {i..<length xs} as j + i with j ∈ {..<length (drop i xs)}. *)
    also have "{i..<length xs} = ((λj. j + i) ` {..<length xs - i})"
      using False by (simp add: lessThan_atLeast0)
    also from False
      have "pmf_of_set … = map_pmf (λj. j + i) (pmf_of_set {..<length xs - i})"
      by (subst map_pmf_of_set_inj) (simp_all add: lessThan_empty_iff)
    also from False have "length xs - i = length (drop i xs)" by simp
    also have "map_pmf (λj. j + i) (pmf_of_set {..<length (drop i xs)}) ⤜
                 (λj. fisher_yates_aux (i + 1) (swap xs i j)) =
               pmf_of_set {..<length (drop i xs)} ⤜
                 (λj. fisher_yates_aux (i + 1) (swap xs i (j+i)))"
      by (simp add: map_pmf_def bind_return_pmf bind_assoc_pmf)
    also have "… = do {j ← pmf_of_set {..<length (drop i xs)};
                       let ys = swap (drop i xs) 0 j;
                       zs ← shuffle (tl ys);
                       return_pmf (take i xs @ hd ys # zs)}" (is "_ = bind_pmf _ ?T")
    proof (intro bind_pmf_cong refl)
      fix j assume "j ∈ set_pmf (pmf_of_set {..<length (drop i xs)})"
      with False have j: "j < length (drop i xs)" by (simp_all add: lessThan_empty_iff)
      define ys where "ys = swap xs i (j + i)"
      have "fisher_yates_aux (i + 1) ys =
              map_pmf ((@) (take (i+1) ys)) (shuffle (drop (i+1) ys))"
        using False j unfolding ys_def by (intro "1.IH") simp_all
      also from False have "take (i+1) ys = take i ys @ [hd (drop i ys)]"
        by (simp add: ys_def take_hd_drop)
      also have "drop (i+1) ys = tl (drop i ys)"
        by (simp add: ys_def tl_drop drop_Suc)
      also from False j have "drop i ys = swap (drop i xs) 0 j"
        by (simp add: ys_def swap_def drop_update_swap add_ac)
      also from False j have "take i ys = take i xs"
        by (simp add: ys_def swap_def)
      finally show "fisher_yates_aux (i + 1) ys = ?T j"
        by (simp add: ys_def map_pmf_def Let_def bind_assoc_pmf bind_return_pmf)
    qed
    (* The remaining do-block is exactly the step lemma applied to drop i xs. *)
    also from False have "… = map_pmf (λzs. take i xs @ zs) (shuffle (drop i xs))"
      by (subst shuffle_fisher_yates_step[of "drop i xs"])
         (simp_all add: map_pmf_def Let_def bind_return_pmf bind_assoc_pmf)
    finally show ?thesis .
  qed
qed

definition fisher_yates where
  "fisher_yates = fisher_yates_aux 0"

lemma fisher_yates_correct: "fisher_yates xs = shuffle xs"
  unfolding fisher_yates_def by (subst fisher_yates_aux_correct)
    (simp_all add: map_pmf_def bind_return_pmf')


subsection ‹Backwards Fisher--Yates Shuffle›

text ‹
  We can now easily derive the classical Fisher--Yates shuffle, which goes through the
  list from back to front, and we show its equivalence to the forward Fisher--Yates
  shuffle.
›
(* fisher_yates_alt_aux i xs stops at i = 0. Otherwise it swaps position i with a
   uniformly chosen j ∈ {..i} (i itself included) and continues at i - 1. *)
fun fisher_yates_alt_aux where
  "fisher_yates_alt_aux i xs =
     (if i = 0 then return_pmf xs
      else do {j ← pmf_of_set {..i};
               fisher_yates_alt_aux (i - 1) (swap xs i j)})"

declare fisher_yates_alt_aux.simps [simp del]

(* The backward shuffle is the forward shuffle run on the reversed list. *)
lemma fisher_yates_alt_aux_altdef:
  "i < length xs ⟹
     fisher_yates_alt_aux i xs = map_pmf rev (fisher_yates_aux (length xs - i - 1) (rev xs))"
proof (induction i xs rule: fisher_yates_alt_aux.induct)
  case (1 i xs)
  show ?case
  proof (cases "i = 0")
    case False
    with "1.prems" have "map_pmf rev (fisher_yates_aux (length xs - i - 1) (rev xs)) =
           pmf_of_set {length xs - Suc i..<length xs} ⤜
             (λx. fisher_yates_aux (Suc (length xs - Suc i))
                    (swap (rev xs) (length xs - Suc i) x) ⤜ (λx. return_pmf (rev x)))"
      by (subst fisher_yates_aux.simps) (auto simp: map_pmf_def bind_return_pmf bind_assoc_pmf)
    (* Mirror indices: j ∈ {..i} corresponds to length xs - Suc j in the reversed list. *)
    also from "1.prems" False
      have bij: "bij_betw (λj. length xs - Suc j) {..i} {length xs - Suc i..<length xs}"
      by (intro bij_betwI[where g = "λj. length xs - Suc j"]) auto
    from bij have "{length xs - Suc i..<length xs} = (λj. length xs - Suc j) ` {..i}"
      by (simp add: bij_betw_def)
    also from bij
      have "pmf_of_set … = map_pmf (λj. length xs - Suc j) (pmf_of_set {..i})"
      by (subst map_pmf_of_set_inj) (auto simp: bij_betw_def)
    also have "map_pmf (λj. length xs - Suc j) (pmf_of_set {..i}) ⤜
                 (λx. fisher_yates_aux (Suc (length xs - Suc i))
                        (swap (rev xs) (length xs - Suc i) x) ⤜ (λx. return_pmf (rev x))) =
               pmf_of_set {..i} ⤜
                 (λx. map_pmf rev ( fisher_yates_aux (length xs - i) (rev (swap xs i x))))"
      using "1.prems" False
      by (auto simp add: map_pmf_def bind_assoc_pmf bind_return_pmf Suc_diff_Suc swap_def
                         rev_update rev_nth intro!: bind_pmf_cong)
    also have "… = pmf_of_set {..i} ⤜ (λj. fisher_yates_alt_aux (i - 1) (swap xs i j))"
      using "1.prems" False "1.IH" [symmetric] by (auto intro!: bind_pmf_cong)
    also from "1.prems" False have "… = fisher_yates_alt_aux i xs"
      by (subst fisher_yates_alt_aux.simps[of i]) simp_all
    finally show ?thesis ..
  qed (insert "1.prems", simp_all add: fisher_yates_aux.simps fisher_yates_alt_aux.simps)
qed

(* For xs = [] the start index length xs - 1 is 0 (natural-number subtraction),
   so fisher_yates_alt_aux returns [] immediately. *)
definition fisher_yates_alt where
  "fisher_yates_alt xs = fisher_yates_alt_aux (length xs - 1) xs"

(* Correctness of fisher_yates_alt itself (the historical name mentions _aux). *)
lemma fisher_yates_alt_aux_correct: "fisher_yates_alt xs = shuffle xs"
proof (cases "xs = []")
  case True
  thus ?thesis by (simp add: fisher_yates_alt_def fisher_yates_alt_aux.simps)
next
  case False
  thus ?thesis unfolding fisher_yates_alt_def
    by (subst fisher_yates_alt_aux_altdef)
       (simp_all add: fisher_yates_aux_correct shuffle_def map_pmf_of_set_inj)
qed

(* Alias named consistently with fisher_yates_correct; the old name is kept for
   backward compatibility. *)
lemmas fisher_yates_alt_correct = fisher_yates_alt_aux_correct


subsection ‹Code generation test›

text ‹
  Isabelle's code generator allows us to produce executable code both for @{const shuffle}
  and for @{const fisher_yates} and @{const fisher_yates_alt}. However, this code does not
  produce a random sample (i.e. a single randomly permuted list) -- which is, in fact, the
  only purpose of the Fisher--Yates algorithm -- but the entire probability distribution.
  For a list of $n$ distinct elements (such as the examples below), this distribution
  consists of $n!$ lists, each with probability $1/n!$.

  In the future, it would be nice if Isabelle also had some code generation facility that
  supports generating sampling code.
›

value [code] "shuffle ''abcd''"
value [code] "fisher_yates ''abcd''"
value [code] "fisher_yates_alt ''abcd''"

end