(*
 * Demo theory: reasoning about the AutoCorres translation of union.c.
 *
 * The C code keeps a global parent array (parent_array_'). Following parent
 * links from an element until reaching an element that is its own parent is
 * modelled by the "find" functions below; repointing the elements on the path
 * from x to y so that they point directly at y is modelled by the "optimise"
 * functions. Several models are given (pure functions, recursive monadic
 * functions, whileLoop-based monadic functions) and related to each other and
 * to the generated code. Some proofs are unfinished and end in oops.
 *)
theory week10A_demo
imports AutoCorres
begin

install_C_file "union.c"

find_theorems name: "body_def"

autocorres "union.c"

find_theorems name: do_union

(* AutoCorres' condition combinator is a state read followed by an if. *)
lemma condition_eq_gets_if:
  "condition C a b = do flag \<leftarrow> gets C; if flag then a else b od"
  by (rule ext, simp add: condition_def bind_def simpler_gets_def)

context union begin

(* Pure model of find: follow parent links in arr from x until reaching an
   element that is its own parent, and return it. No termination proof is
   given, so only the partial equations (psimps) are available. *)
function fun_find where
  "fun_find x arr = (let y = Arrays.index arr x in
      if x = y then x else fun_find y arr)"
  by auto

(* Pure model of path compression: walk the parent links from x until reaching
   y, repointing every element visited to y, and return the updated array.
   The recursive call must use the updated array arr'. Passing the original
   arr (as this definition previously did) left arr' unused, so the function
   always returned its input array unchanged. As with fun_find, no
   termination proof is given. *)
function fun_optimise where
  "fun_optimise x y arr = (if x = y then arr
      else let z = Arrays.index arr x;
               arr' = Arrays.update arr x y
           in fun_optimise z y arr')"
  by auto

(* Inspect the state-monad primitives used below. *)
term get
thm gets_def
term put
thm modify_def
term gets

(* Read the parent of x as a natural number (via unat), after asserting that
   x is in bounds. *)
definition
  "fetch_arr x = do
     assert (x < unat max_elements);
     gets (\<lambda>s. unat (Arrays.index (parent_array_' s) x))
   od"

(* Set the parent of x to v (converted with of_nat), after asserting that x
   is in bounds. *)
definition
  "upd_arr x v = do
     assert (x < unat max_elements);
     modify (parent_array_'_update (\<lambda>arr. Arrays.update arr x (of_nat v)))
   od"

(* Monadic find, defined by (partial) recursion. *)
function monad_find1 where
  "monad_find1 x = do
     y \<leftarrow> fetch_arr x;
     if y = x then return x else monad_find1 y
   od"
  by auto

(* Inspect how do-notation, bind and binders are parsed and printed. *)
term "do x \<leftarrow> f; g x x od"
term "f >>= (\<lambda>x. g x x)"
term "op >>="
term "All (\<lambda>x. P x)"
term "gets f"

(* Two independent state reads can be swapped. *)
lemma
  "do x \<leftarrow> gets f; y \<leftarrow> gets g; return (h x y) od
   = do y \<leftarrow> gets g; x \<leftarrow> gets f; return (h x y) od"
  apply (rule ext)
  apply (simp add: exec_gets)
  done

(* bind_assoc restated in do-notation (proved directly by bind_assoc). *)
lemma bind_assoc_print_better_please:
  "do y \<leftarrow> (do x \<leftarrow> f; g od); h y od = do x \<leftarrow> f; y \<leftarrow> g; h y od"
  by (rule bind_assoc)

(* Monadic path compression, defined by (partial) recursion: repoint x to y
   and continue from x's old parent, until reaching y. *)
function monad_optimise1 where
  "monad_optimise1 x y = (if x = y then return ()
     else do
       z \<leftarrow> fetch_arr x;
       upd_arr x y;
       monad_optimise1 z y
     od)"
  by auto

(* Loop-based find. The loop state (a, b) holds an element and its parent.
   The loop runs while they differ, stepping to (b, parent of b), and the
   final a (which equals its own parent) is returned. *)
definition
  "monad_find2 x = do
     y \<leftarrow> fetch_arr x;
     (y, _) \<leftarrow> whileLoop (\<lambda>(x, y) _. x \<noteq> y)
       (\<lambda>(x, y). do z \<leftarrow> fetch_arr y; return (y, z) od)
       (x, y);
     return y
   od"

(* monad_find2 satisfies the same recursion equation as monad_find1. *)
lemma monad_find2_unfold:
  "monad_find2 x = do
     y \<leftarrow> fetch_arr x;
     if y = x then return x else monad_find2 y
   od"
  apply (simp add: monad_find2_def cong: if_cong)
  apply (rule trans)
  apply (subst whileLoop_unroll)
  apply (rule refl)
  apply (simp add: bind_assoc condition_eq_gets_if eq_commute if_swap[symmetric])
  apply (simp add: if_to_top_of_bind bind_assoc cong: if_cong)
  done

(* Loop-based path compression: starting from x, repoint the current element
   to y and move to its old parent, until reaching y. *)
definition
  "monad_optimise2 x y = do
     whileLoop (\<lambda>i _. i \<noteq> y)
       (\<lambda>i. do j \<leftarrow> fetch_arr i; upd_arr i y; return j od)
       x;
     return ()
   od"

(* monad_optimise2 satisfies the same recursion equation as monad_optimise1. *)
lemma monad_optimise2_unfold:
  "monad_optimise2 x y = (if x = y then return ()
     else do
       z \<leftarrow> fetch_arr x;
       upd_arr x y;
       monad_optimise2 z y
     od)"
  apply (unfold monad_optimise2_def)
  apply (rule trans, rule bind_cong, rule whileLoop_unroll, rule refl)
  apply (clarsimp simp: condition_eq_gets_if bind_assoc)
  done

(* The parent relation of state s: pairs (x, y) of in-bounds indices such
   that y is the parent of x. *)
definition
  "array_step s = {(x, y). x < unat max_elements \<and> y < unat max_elements
      \<and> unat (Arrays.index (parent_array_' s) x) = y}"

(* Invariant: the parent of every in-bounds element is in bounds, and every
   in-bounds element reaches, in one or more parent steps, an element that is
   its own parent. *)
definition
  "union_find_inv \<equiv> \<lambda>s.
     (\<forall>x < unat max_elements. Arrays.index (parent_array_' s) x < max_elements)
     \<and> (\<forall>x < unat max_elements. \<exists>y. (x, y) \<in> trancl (array_step s)
            \<and> (y, y) \<in> array_step s)"

(* Reading an updated array at an in-bounds index, expressed as an if. *)
lemma index_update_If:
  "z < CARD ('b) \<Longrightarrow> index (Arrays.update arr x y :: 'a['b :: finite]) z
     = (if x = z then y else index arr z)"
  by simp

(* The parent relation after repointing in-bounds x to in-bounds y. *)
lemma array_step_update:
  "x < unat max_elements \<Longrightarrow> y < unat max_elements \<Longrightarrow>
   array_step (new_globals.parent_array_'_update
                 (\<lambda>arr. Arrays.update arr x (of_nat y)) s)
     = {(a, b). if a = x then b = y \<and> y < unat max_elements
                else (a, b) \<in> array_step s}"
  apply (rule set_eqI)
  apply (clarsimp simp: array_step_def max_elements_def cong: conj_cong)
  apply (auto simp: unat_of_nat index_update_If)
  done

(* Writing back the value already stored at x leaves every in-bounds entry
   unchanged. *)
lemma index_update_is_same:
  "z < CARD('b) \<Longrightarrow> index arr x = y \<Longrightarrow>
   index (Arrays.update arr x y :: 'a['b :: finite]) z = index arr z"
  by (clarsimp simp add: index_update_If)

(* Repointing an element y that is already its own parent to itself does not
   change the parent relation. *)
lemma array_step_update_noop:
  "y < unat max_elements \<Longrightarrow> (y, y) \<in> array_step s \<Longrightarrow>
   array_step (new_globals.parent_array_'_update
                 (\<lambda>arr. Arrays.update arr y (of_nat y)) s)
     = array_step s"
  apply (simp add: array_step_update)
  apply (auto simp: array_step_def)
  done

(* Unfinished (ends in oops): repointing x directly at an element y that is
   reachable from x and is its own parent should preserve union_find_inv. *)
lemma optimise_step:
  "(x, y) \<in> rtrancl (array_step s) \<Longrightarrow> (y, y) \<in> array_step s
   \<Longrightarrow> union_find_inv s
   \<Longrightarrow> union_find_inv (parent_array_'_update
                            (\<lambda>arr. Arrays.update arr x (of_nat y)) s)"
  apply (subgoal_tac "x < unat max_elements \<and> y < unat max_elements")
  apply (cases "x = y")
  apply clarsimp
  apply (clarsimp simp add: union_find_inv_def array_step_update_noop)
  apply (clarsimp simp: max_elements_def index_update_If word_less_nat_alt unat_of_nat)
  apply (clarsimp simp: union_find_inv_def)
  apply (rule conjI)
  apply (clarsimp simp: index_update_If max_elements_def word_less_nat_alt unat_of_nat)
  apply clarsimp
  apply (frule spec, drule(1) mp, erule exE)
  apply (rule_tac x=ya in exI)
  apply (clarsimp simp: array_step_update split del: split_if)
  apply (subst if_not_P)
  back
  apply clarsimp
  oops

(* Unfinished (ends in oops): Hoare triple stating that monad_optimise2 x y
   preserves union_find_inv when x is in bounds, y is reachable from x in one
   or more parent steps, and y is its own parent. *)
lemma monad_optimise_inv:
  "\<lbrace>\<lambda>s. x < unat max_elements \<and> (x, y) \<in> trancl (array_step s)
        \<and> (y, y) \<in> array_step s \<and> union_find_inv s\<rbrace>
     monad_optimise2 x y
   \<lbrace>\<lambda>_. union_find_inv\<rbrace>"
  apply (simp add: monad_optimise2_def)
  apply wp
  apply (rule whileLoop_wp, simp_all)
  apply (simp add: upd_arr_def fetch_arr_def, wp)
  apply (clarsimp)
  apply (intro conjI)
  apply (simp add: union_find_inv_def word_less_nat_alt)
  oops

(* Unfinished (ends in oops): relate find' to monad_find2 followed by
   monad_optimise2, converting between words and naturals with unat/of_nat.
   NOTE(review): find' is presumably the AutoCorres translation of the C
   find function -- confirm against the autocorres output. *)
lemma find_explode:
  "find' x = do
     y \<leftarrow> monad_find2 (unat x);
     monad_optimise2 (unat x) y;
     return (of_nat y)
   od"
  apply (simp add: find'_def monad_find2_def monad_optimise2_def bind_assoc)
  apply (simp add: bind_assoc[symmetric])
  apply (rule_tac subst=of_nat in bind_eq_split_subst)
  apply (simp add: liftM_def bind_assoc)
  defer
  thm whileLoop_cong
  apply (rule_tac subst=of_nat in bind_eq_split_subst)
  apply (rule bind_cong)
  thm whileLoop_cong
  oops

end

end