ProgramDerivation: adding caches

This commit is contained in:
Adam Chlipala 2018-05-05 18:51:21 -04:00
parent 3ff400b780
commit 5f981335d9

View file

@ -501,3 +501,106 @@ Definition derived_counter : { counter' | adt_refine counter counter' }.
Defined.
Eval simpl in proj1_sig derived_counter.
(** * Another refinement strategy: introducing a cache (a.k.a. finite differencing) *)
Inductive CachingMethods {state} (name : string) (func : state -> nat)
: forall {names}, methods state names -> methods (state * nat) names -> Prop :=
| CmNil :
CachingMethods name func MethodsNil MethodsNil
| CmCached : forall names (ms1 : methods state names) (ms2 : methods _ names),
CachingMethods name func ms1 ms2
-> CachingMethods name func
(MethodsCons {| MethodName := name; MethodBody := (fun s _ => ret (s, func s)) |} ms1)
(MethodsCons {| MethodName := name; MethodBody := (fun s arg => ret (s, snd s)) |} ms2)
| CmDefault : forall name' names oldbody (ms1 : methods state names) (ms2 : methods _ names),
name' <> name
-> CachingMethods name func ms1 ms2
-> CachingMethods name func
(MethodsCons {| MethodName := name'; MethodBody := oldbody |} ms1)
(MethodsCons {| MethodName := name'; MethodBody := (fun s arg =>
p <- oldbody (fst s) arg;
new_cache <- pick c where (func (fst s) = snd s -> func (fst p) = c);
ret ((fst p, new_cache), snd p)) |} ms2).
Lemma CachingMethods_ok : forall state name (func : state -> nat)
names (ms1 : methods state names) (ms2 : methods (state * nat) names),
CachingMethods name func ms1 ms2
-> RefineMethods (fun s1 s2 => fst s2 = s1 /\ snd s2 = func s1) ms1 ms2.
Proof.
induct 1; eauto.
econstructor; eauto.
unfold ret, bind.
simplify; first_order; subst.
invert H1.
rewrite H2.
eauto.
econstructor; eauto.
unfold ret, bind.
simplify; first_order; subst.
invert H5.
unfold pick_ in H4.
cases x; simplify.
eauto.
Qed.
Hint Resolve CachingMethods_ok.
Theorem refine_cache : forall state name (func : state -> nat)
names (ms1 : methods state names) (ms2 : methods (state * nat) names)
constr,
CachingMethods name func ms1 ms2
-> adt_refine {| AdtState := state;
AdtConstructor := constr;
AdtMethods := ms1 |}
{| AdtState := state * nat;
AdtConstructor := s0 <- constr; ret (s0, func s0);
AdtMethods := ms2 |}.
Proof.
simplify.
choose_relation (fun s1 s2 => fst s2 = s1 /\ snd s2 = func s1); eauto.
unfold bind, ret in *.
first_order; subst.
simplify; eauto.
Qed.
Ltac refine_cache nam := eapply refine_trans; [ eapply refine_cache with (name := nam);
repeat (apply CmNil
|| refine (CmCached _ _ _ _ _ _)
|| (refine (CmDefault _ _ _ _ _ _ _ _ _); [ equality | ])) | ].
(** ** An example with lists of numbers *)
Definition sum := fold_right plus 0.
Definition nats := ADT {
rep = list nat
and constructor = ret []
and method "add"[[self, n]] = ret (n :: self, 0)
and method "sum"[[self, _]] = ret (self, sum self)
}.
Definition optimized_nats : { nats' | adt_refine nats nats' }.
unfold nats; eexists.
refine_cache "sum".
refine_constructor.
rewrite bind_ret.
finish.
refine_method "add".
rewrite bind_ret; simplify.
rewrite (pick_one (arg + snd s)).
rewrite bind_ret.
finish.
equality.
refine_finish.
Defined.
Eval simpl in proj1_sig optimized_nats.