FirstClassFunctions: flattenS_ok

This commit is contained in:
Adam Chlipala 2018-02-19 14:02:17 -05:00
parent 0047d49139
commit 63836dad24

View file

@ -49,19 +49,19 @@ Definition languages := [pascal; c; gallina; haskell; ocaml].
(** * Classic list functions *) (** * Classic list functions *)
Fixpoint map {A B : Set} (f : A -> B) (ls : list A) : list B := Fixpoint map {A B} (f : A -> B) (ls : list A) : list B :=
match ls with match ls with
| nil => nil | nil => nil
| x :: ls' => f x :: map f ls' | x :: ls' => f x :: map f ls'
end. end.
Fixpoint filter {A : Set} (f : A -> bool) (ls : list A) : list A := Fixpoint filter {A} (f : A -> bool) (ls : list A) : list A :=
match ls with match ls with
| nil => nil | nil => nil
| x :: ls' => if f x then x :: filter f ls' else filter f ls' | x :: ls' => if f x then x :: filter f ls' else filter f ls'
end. end.
Fixpoint fold_left {A B : Set} (f : B -> A -> B) (ls : list A) (acc : B) : B := Fixpoint fold_left {A B} (f : B -> A -> B) (ls : list A) (acc : B) : B :=
match ls with match ls with
| nil => acc | nil => acc
| x :: ls' => fold_left f ls' (f acc x) | x :: ls' => fold_left f ls' (f acc x)
@ -79,7 +79,7 @@ Reset map.
(** * Motivating continuations with search problems *) (** * Motivating continuations with search problems *)
Fixpoint allSublists {A : Set} (ls : list A) : list (list A) := Fixpoint allSublists {A} (ls : list A) : list (list A) :=
match ls with match ls with
| [] => [[]] | [] => [[]]
| x :: ls' => | x :: ls' =>
@ -103,7 +103,7 @@ Fixpoint countingDown (from : nat) :=
Time Compute sublistSummingTo (countingDown 20) 1. Time Compute sublistSummingTo (countingDown 20) 1.
Fixpoint allSublistsK {A B : Set} (ls : list A) Fixpoint allSublistsK {A B} (ls : list A)
(failed : unit -> B) (failed : unit -> B)
(found : list A -> (unit -> B) -> B) : B := (found : list A -> (unit -> B) -> B) : B :=
match ls with match ls with
@ -123,7 +123,7 @@ Definition sublistSummingToK (ns : list nat) (target : nat) : option (list nat)
Time Compute sublistSummingToK (countingDown 20) 1. Time Compute sublistSummingToK (countingDown 20) 1.
Theorem allSublistsK_ok : forall {A B : Set} (ls : list A) (failed : unit -> B) found, Theorem allSublistsK_ok : forall {A B} (ls : list A) (failed : unit -> B) found,
(forall sol, (exists ans, (forall failed', found sol failed' = ans) (forall sol, (exists ans, (forall failed', found sol failed' = ans)
/\ ans <> failed tt) /\ ans <> failed tt)
\/ (forall failed', found sol failed' = failed' tt)) \/ (forall failed', found sol failed' = failed' tt))
@ -238,31 +238,31 @@ Qed.
(** * The classics in continuation-passing style *) (** * The classics in continuation-passing style *)
Fixpoint mapK {A B R : Set} (f : A -> (B -> R) -> R) (ls : list A) (k : list B -> R) : R := Fixpoint mapK {A B R} (f : A -> (B -> R) -> R) (ls : list A) (k : list B -> R) : R :=
match ls with match ls with
| nil => k nil | nil => k nil
| x :: ls' => f x (fun x' => mapK f ls' (fun ls'' => k (x' :: ls''))) | x :: ls' => f x (fun x' => mapK f ls' (fun ls'' => k (x' :: ls'')))
end. end.
Fixpoint filterK {A R : Set} (f : A -> (bool -> R) -> R) (ls : list A) (k : list A -> R) : R := Fixpoint filterK {A R} (f : A -> (bool -> R) -> R) (ls : list A) (k : list A -> R) : R :=
match ls with match ls with
| nil => k nil | nil => k nil
| x :: ls' => f x (fun b => filterK f ls' (fun ls'' => k (if b then x :: ls'' else ls''))) | x :: ls' => f x (fun b => filterK f ls' (fun ls'' => k (if b then x :: ls'' else ls'')))
end. end.
Fixpoint fold_leftK {A B R : Set} (f : B -> A -> (B -> R) -> R) (ls : list A) (acc : B) (k : B -> R) : R := Fixpoint fold_leftK {A B R} (f : B -> A -> (B -> R) -> R) (ls : list A) (acc : B) (k : B -> R) : R :=
match ls with match ls with
| nil => k acc | nil => k acc
| x :: ls' => f acc x (fun x' => fold_leftK f ls' x' k) | x :: ls' => f acc x (fun x' => fold_leftK f ls' x' k)
end. end.
Definition NameK {R : Set} (l : programming_language) (k : string -> R) : R := Definition NameK {R} (l : programming_language) (k : string -> R) : R :=
k (Name l). k (Name l).
Definition PurelyFunctionalK {R : Set} (l : programming_language) (k : bool -> R) : R := Definition PurelyFunctionalK {R} (l : programming_language) (k : bool -> R) : R :=
k (PurelyFunctional l). k (PurelyFunctional l).
Definition AppearedInYearK {R : Set} (l : programming_language) (k : nat -> R) : R := Definition AppearedInYearK {R} (l : programming_language) (k : nat -> R) : R :=
k (AppearedInYear l). k (AppearedInYear l).
Definition maxK {R : Set} (n1 n2 : nat) (k : nat -> R) : R := Definition maxK {R} (n1 n2 : nat) (k : nat -> R) : R :=
k (max n1 n2). k (max n1 n2).
Compute mapK NameK languages (fun ls => ls). Compute mapK NameK languages (fun ls => ls).
@ -272,7 +272,7 @@ Compute filterK PurelyFunctionalK languages
(fun ls1 => mapK AppearedInYearK ls1 (fun ls1 => mapK AppearedInYearK ls1
(fun ls2 => fold_leftK maxK ls2 0 (fun x => x))). (fun ls2 => fold_leftK maxK ls2 0 (fun x => x))).
Theorem mapK_ok : forall {A B R : Set} (f : A -> (B -> R) -> R) (f_base : A -> B), Theorem mapK_ok : forall {A B R} (f : A -> (B -> R) -> R) (f_base : A -> B),
(forall x k, f x k = k (f_base x)) (forall x k, f x k = k (f_base x))
-> forall (ls : list A) (k : list B -> R), -> forall (ls : list A) (k : list B -> R),
mapK f ls k = k (map f_base ls). mapK f ls k = k (map f_base ls).
@ -292,7 +292,7 @@ Proof.
trivial. trivial.
Qed. Qed.
Theorem filterK_ok : forall {A R : Set} (f : A -> (bool -> R) -> R) (f_base : A -> bool), Theorem filterK_ok : forall {A R} (f : A -> (bool -> R) -> R) (f_base : A -> bool),
(forall x k, f x k = k (f_base x)) (forall x k, f x k = k (f_base x))
-> forall (ls : list A) (k : list A -> R), -> forall (ls : list A) (k : list A -> R),
filterK f ls k = k (filter f_base ls). filterK f ls k = k (filter f_base ls).
@ -312,7 +312,7 @@ Proof.
apply mapK_ok with (f_base := Name); trivial. apply mapK_ok with (f_base := Name); trivial.
Qed. Qed.
Theorem fold_leftK_ok : forall {A B R : Set} (f : B -> A -> (B -> R) -> R) (f_base : B -> A -> B), Theorem fold_leftK_ok : forall {A B R} (f : B -> A -> (B -> R) -> R) (f_base : B -> A -> B),
(forall x acc k, f x acc k = k (f_base x acc)) (forall x acc k, f x acc k = k (f_base x acc))
-> forall (ls : list A) (acc : B) (k : B -> R), -> forall (ls : list A) (acc : B) (k : B -> R),
fold_leftK f ls acc k = k (fold_left f_base ls acc). fold_leftK f ls acc k = k (fold_left f_base ls acc).
@ -352,10 +352,10 @@ Inductive tree {A} :=
| Node (l : tree) (d : A) (r : tree). | Node (l : tree) (d : A) (r : tree).
Arguments tree : clear implicits. Arguments tree : clear implicits.
Fixpoint depth {A} (t : tree A) : nat := Fixpoint size {A} (t : tree A) : nat :=
match t with match t with
| Leaf => 0 | Leaf => 0
| Node l _ r => 2 + depth l + depth r | Node l _ r => 2 + size l + size r
end. end.
Fixpoint flatten {A} (t : tree A) : list A := Fixpoint flatten {A} (t : tree A) : list A :=
@ -421,10 +421,10 @@ Fixpoint flattenKD {A} (fuel : nat) (t : tree A) (acc : list A)
end end
end. end.
Fixpoint continuation_depth {A} (k : flatten_continuation A) : nat := Fixpoint continuation_size {A} (k : flatten_continuation A) : nat :=
match k with match k with
| KDone => 0 | KDone => 0
| KMore l d k' => 1 + depth l + continuation_depth k' | KMore l d k' => 1 + size l + continuation_size k'
end. end.
Fixpoint flatten_cont {A} (k : flatten_continuation A) : list A := Fixpoint flatten_cont {A} (k : flatten_continuation A) : list A :=
@ -434,7 +434,7 @@ Fixpoint flatten_cont {A} (k : flatten_continuation A) : list A :=
end. end.
Lemma flattenKD_ok' : forall {A} fuel fuel' (t : tree A) acc k, Lemma flattenKD_ok' : forall {A} fuel fuel' (t : tree A) acc k,
depth t + continuation_depth k < fuel' < fuel size t + continuation_size k < fuel' < fuel
-> flattenKD fuel' t acc k -> flattenKD fuel' t acc k
= flatten_cont k ++ flatten t ++ acc. = flatten_cont k ++ flatten t ++ acc.
Proof. Proof.
@ -458,12 +458,61 @@ Proof.
Qed. Qed.
Theorem flattenKD_ok : forall {A} (t : tree A), Theorem flattenKD_ok : forall {A} (t : tree A),
flattenKD (depth t + 1) t [] KDone = flatten t. flattenKD (size t + 1) t [] KDone = flatten t.
Proof. Proof.
simplify. simplify.
rewrite flattenKD_ok' with (fuel := depth t + 2). rewrite flattenKD_ok' with (fuel := size t + 2).
simplify. simplify.
apply app_nil_r. apply app_nil_r.
simplify. simplify.
linear_arithmetic. linear_arithmetic.
Qed. Qed.
Definition call_stack A := list (tree A * A).
Definition pop_call_stack {A} (acc : list A) (st : call_stack A)
(flattenS : tree A -> list A -> call_stack A -> list A)
: list A :=
match st with
| [] => acc
| (l, d) :: st' => flattenS l (d :: acc) st'
end.
Fixpoint flattenS {A} (fuel : nat) (t : tree A) (acc : list A)
(st : call_stack A) : list A :=
match fuel with
| O => []
| S fuel' =>
match t with
| Leaf => pop_call_stack acc st (flattenS fuel')
| Node l d r => flattenS fuel' r acc ((l, d) :: st)
end
end.
Fixpoint call_stack_to_continuation {A} (st : call_stack A) : flatten_continuation A :=
match st with
| [] => KDone
| (l, d) :: st' => KMore l d (call_stack_to_continuation st')
end.
Lemma flattenS_flattenKD : forall {A} fuel (t : tree A) acc st,
flattenS fuel t acc st = flattenKD fuel t acc (call_stack_to_continuation st).
Proof.
induct fuel; simplify; trivial.
cases t.
cases st; simplify; trivial.
cases p; simplify.
apply IHfuel.
apply IHfuel.
Qed.
Theorem flattenS_ok : forall {A} (t : tree A),
flattenS (size t + 1) t [] [] = flatten t.
Proof.
simplify.
rewrite flattenS_flattenKD.
apply flattenKD_ok.
Qed.