feat(frontends/lean): nested dependent pattern matching

This commit is contained in:
Leonardo de Moura 2015-03-06 19:04:09 -08:00
parent 14ca2d407d
commit f966634910
5 changed files with 215 additions and 95 deletions

View file

@ -177,41 +177,51 @@ static expr parse_begin_end_core(parser & p, pos_info const & pos, name const &
} else if (p.curr_is_token(get_have_tk())) {
auto pos = p.pos();
p.next();
auto id_pos = p.pos();
name id = p.check_id_next("invalid 'have' tactic, identifier expected");
p.check_token_next(get_colon_tk(), "invalid 'have' tactic, ':' expected");
expr A = p.parse_expr();
p.check_token_next(get_comma_tk(), "invalid 'have' tactic, ',' expected");
expr assert_tac = p.save_pos(mk_assert_tactic_expr(id, A), pos);
tacs.push_back(mk_begin_end_element_annotation(assert_tac));
if (p.curr_is_token(get_from_tk())) {
// parse: 'from' expr
p.next();
auto pos = p.pos();
expr t = p.parse_expr();
if (p.curr_is_token(get_bar_tk())) {
expr local = p.save_pos(mk_local(id, A), id_pos);
expr t = parse_local_equations(p, local);
t = p.mk_app(get_exact_tac_fn(), t, pos);
t = p.save_pos(mk_begin_end_element_annotation(t), pos);
t = p.save_pos(mk_begin_end_annotation(t), pos);
add_tac(t, pos);
} else if (p.curr_is_token(get_proof_tk())) {
auto pos = p.pos();
p.next();
expr t = p.parse_expr();
p.check_token_next(get_qed_tk(), "invalid proof-qed, 'qed' expected");
t = p.mk_app(get_exact_tac_fn(), t, pos);
t = p.save_pos(mk_begin_end_element_annotation(t), pos);
t = p.save_pos(mk_begin_end_annotation(t), pos);
add_tac(t, pos);
} else if (p.curr_is_token(get_begin_tk())) {
auto pos = p.pos();
tacs.push_back(parse_begin_end_core(p, pos, get_end_tk(), true));
} else if (p.curr_is_token(get_by_tk())) {
// parse: 'by' tactic
auto pos = p.pos();
p.next();
expr t = p.parse_tactic();
add_tac(t, pos);
} else {
throw parser_error("invalid 'have' tactic, 'by', 'begin', 'proof', or 'from' expected", p.pos());
p.check_token_next(get_comma_tk(), "invalid 'have' tactic, ',' expected");
if (p.curr_is_token(get_from_tk())) {
// parse: 'from' expr
p.next();
auto pos = p.pos();
expr t = p.parse_expr();
t = p.mk_app(get_exact_tac_fn(), t, pos);
t = p.save_pos(mk_begin_end_element_annotation(t), pos);
t = p.save_pos(mk_begin_end_annotation(t), pos);
add_tac(t, pos);
} else if (p.curr_is_token(get_proof_tk())) {
auto pos = p.pos();
p.next();
expr t = p.parse_expr();
p.check_token_next(get_qed_tk(), "invalid proof-qed, 'qed' expected");
t = p.mk_app(get_exact_tac_fn(), t, pos);
t = p.save_pos(mk_begin_end_element_annotation(t), pos);
t = p.save_pos(mk_begin_end_annotation(t), pos);
add_tac(t, pos);
} else if (p.curr_is_token(get_begin_tk())) {
auto pos = p.pos();
tacs.push_back(parse_begin_end_core(p, pos, get_end_tk(), true));
} else if (p.curr_is_token(get_by_tk())) {
// parse: 'by' tactic
auto pos = p.pos();
p.next();
expr t = p.parse_tactic();
add_tac(t, pos);
} else {
throw parser_error("invalid 'have' tactic, 'by', 'begin', 'proof', or 'from' expected", p.pos());
}
}
} else if (p.curr_is_token(get_show_tk())) {
auto pos = p.pos();
@ -349,17 +359,22 @@ static expr parse_have_core(parser & p, pos_info const & pos, optional<expr> con
id = p.mk_fresh_name();
prop = p.parse_expr();
}
p.check_token_next(get_comma_tk(), "invalid 'have/assert' declaration, ',' expected");
expr proof;
if (prev_local) {
parser::local_scope scope(p);
p.add_local(*prev_local);
auto proof_pos = p.pos();
proof = parse_proof(p, prop);
proof = p.save_pos(Fun(*prev_local, proof), proof_pos);
proof = p.save_pos(mk_app(proof, *prev_local), proof_pos);
if (p.curr_is_token(get_bar_tk()) && !prev_local) {
expr fn = p.save_pos(mk_local(id, prop), id_pos);
proof = parse_local_equations(p, fn);
} else {
proof = parse_proof(p, prop);
p.check_token_next(get_comma_tk(), "invalid 'have/assert' declaration, ',' expected");
if (prev_local) {
parser::local_scope scope(p);
p.add_local(*prev_local);
auto proof_pos = p.pos();
proof = parse_proof(p, prop);
proof = p.save_pos(Fun(*prev_local, proof), proof_pos);
proof = p.save_pos(mk_app(proof, *prev_local), proof_pos);
} else {
proof = parse_proof(p, prop);
}
}
p.check_token_next(get_comma_tk(), "invalid 'have/assert' declaration, ',' expected");
parser::local_scope scope(p);
@ -398,11 +413,16 @@ static expr parse_assert(parser & p, unsigned, expr const *, pos_info const & po
static name * H_show = nullptr;
static expr parse_show(parser & p, unsigned, expr const *, pos_info const & pos) {
expr prop = p.parse_expr();
p.check_token_next(get_comma_tk(), "invalid 'show' declaration, ',' expected");
expr proof = parse_proof(p, prop);
expr b = p.save_pos(mk_lambda(*H_show, prop, Var(0)), pos);
expr r = p.mk_app(b, proof, pos);
return p.save_pos(mk_show_annotation(r), pos);
if (p.curr_is_token(get_bar_tk())) {
expr fn = p.save_pos(mk_local(*H_show, prop), pos);
return parse_local_equations(p, fn);
} else {
p.check_token_next(get_comma_tk(), "invalid 'show' declaration, ',' expected");
expr proof = parse_proof(p, prop);
expr b = p.save_pos(mk_lambda(*H_show, prop, Var(0)), pos);
expr r = p.mk_app(b, proof, pos);
return p.save_pos(mk_show_annotation(r), pos);
}
}
static expr parse_obtain(parser & p, unsigned, expr const *, pos_info const & pos) {

View file

@ -555,8 +555,8 @@ static expr merge_equation_lhs_vars(expr const & lhs, buffer<expr> & locals) {
<< n << "' in the left-hand-side does not correspond to function(s) being defined", p);
}
static bool is_eqn_prefix(parser & p) {
return p.curr_is_token(get_bar_tk()) || p.curr_is_token(get_comma_tk());
static bool is_eqn_prefix(parser & p, bool bar_only = false) {
return p.curr_is_token(get_bar_tk()) || (!bar_only && p.curr_is_token(get_comma_tk()));
}
static void check_eqn_prefix(parser & p) {
@ -580,11 +580,67 @@ static expr get_equation_fn(buffer<expr> const & fns, name const & fn_name, pos_
throw_invalid_equation_lhs(fn_name, lhs_pos);
}
static void parse_equations_core(parser & p, buffer<expr> const & fns, buffer<expr> & eqns, bool bar_only = false) {
for (expr const & fn : fns)
p.add_local(fn);
while (true) {
expr lhs;
unsigned prev_num_undef_ids = p.get_num_undef_ids();
buffer<expr> locals;
{
parser::undef_id_to_local_scope scope2(p);
buffer<expr> lhs_args;
auto lhs_pos = p.pos();
if (p.curr_is_token(get_explicit_tk())) {
p.next();
name fn_name = p.check_id_next("invalid recursive equation, identifier expected");
lhs_args.push_back(p.save_pos(mk_explicit(get_equation_fn(fns, fn_name, lhs_pos)), lhs_pos));
} else {
expr first = p.parse_expr(get_max_prec());
expr fn = first;
if (is_explicit(fn))
fn = get_explicit_arg(fn);
if (is_local(fn) && is_equation_fn(fns, local_pp_name(fn))) {
lhs_args.push_back(first);
} else if (fns.size() == 1) {
lhs_args.push_back(p.save_pos(mk_explicit(fns[0]), lhs_pos));
lhs_args.push_back(first);
} else {
throw parser_error("invalid recursive equation, head symbol in left-hand-side is not a constant",
lhs_pos);
}
}
while (!p.curr_is_token(get_assign_tk()))
lhs_args.push_back(p.parse_expr(get_max_prec()));
lhs = p.save_pos(mk_app(lhs_args.size(), lhs_args.data()), lhs_pos);
unsigned num_undef_ids = p.get_num_undef_ids();
for (unsigned i = prev_num_undef_ids; i < num_undef_ids; i++) {
locals.push_back(p.get_undef_id(i));
}
}
validate_equation_lhs(p, lhs, locals);
lhs = merge_equation_lhs_vars(lhs, locals);
auto assign_pos = p.pos();
p.check_token_next(get_assign_tk(), "invalid declaration, ':=' expected");
{
parser::local_scope scope2(p);
for (expr const & local : locals)
p.add_local(local);
expr rhs = p.parse_expr();
eqns.push_back(Fun(fns, Fun(locals, p.save_pos(mk_equation(lhs, rhs), assign_pos), p)));
}
if (!is_eqn_prefix(p, bar_only))
break;
p.next();
}
}
expr parse_equations(parser & p, name const & n, expr const & type, buffer<name> & auxs,
optional<local_environment> const & lenv, buffer<expr> const & ps,
pos_info const & def_pos) {
buffer<expr> eqns;
buffer<expr> fns;
buffer<expr> eqns;
{
parser::local_scope scope1(p, lenv);
for (expr const & param : ps)
@ -609,59 +665,7 @@ expr parse_equations(parser & p, name const & n, expr const & type, buffer<name>
p.next();
eqns.push_back(Fun(fns, mk_no_equation(), p));
} else {
for (expr const & fn : fns)
p.add_local(fn);
while (true) {
expr lhs;
unsigned prev_num_undef_ids = p.get_num_undef_ids();
buffer<expr> locals;
{
parser::undef_id_to_local_scope scope2(p);
buffer<expr> lhs_args;
auto lhs_pos = p.pos();
if (p.curr_is_token(get_explicit_tk())) {
p.next();
name fn_name = p.check_id_next("invalid recursive equation, identifier expected");
lhs_args.push_back(p.save_pos(mk_explicit(get_equation_fn(fns, fn_name, lhs_pos)), lhs_pos));
} else {
expr first = p.parse_expr(get_max_prec());
expr fn = first;
if (is_explicit(fn))
fn = get_explicit_arg(fn);
if (is_local(fn) && is_equation_fn(fns, local_pp_name(fn))) {
lhs_args.push_back(first);
} else if (fns.size() == 1) {
lhs_args.push_back(p.save_pos(mk_explicit(fns[0]), lhs_pos));
lhs_args.push_back(first);
} else {
throw parser_error("invalid recursive equation, head symbol in left-hand-side is not a constant",
lhs_pos);
}
}
while (!p.curr_is_token(get_assign_tk()))
lhs_args.push_back(p.parse_expr(get_max_prec()));
lhs = p.save_pos(mk_app(lhs_args.size(), lhs_args.data()), lhs_pos);
unsigned num_undef_ids = p.get_num_undef_ids();
for (unsigned i = prev_num_undef_ids; i < num_undef_ids; i++) {
locals.push_back(p.get_undef_id(i));
}
}
validate_equation_lhs(p, lhs, locals);
lhs = merge_equation_lhs_vars(lhs, locals);
auto assign_pos = p.pos();
p.check_token_next(get_assign_tk(), "invalid declaration, ':=' expected");
{
parser::local_scope scope2(p);
for (expr const & local : locals)
p.add_local(local);
expr rhs = p.parse_expr();
eqns.push_back(Fun(fns, Fun(locals, p.save_pos(mk_equation(lhs, rhs), assign_pos), p)));
}
if (!is_eqn_prefix(p))
break;
p.next();
}
parse_equations_core(p, fns, eqns);
}
}
if (p.curr_is_token(get_wf_tk())) {
@ -675,6 +679,19 @@ expr parse_equations(parser & p, name const & n, expr const & type, buffer<name>
}
}
/** \brief Parse a sequence of equations of the form <tt>| lhs := rhs</tt> */
expr parse_local_equations(parser & p, expr const & fn) {
lean_assert(p.curr_is_token(get_bar_tk()));
auto pos = p.pos();
p.next();
buffer<expr> fns;
buffer<expr> eqns;
fns.push_back(fn);
bool bar_only = true;
parse_equations_core(p, fns, eqns, bar_only);
return p.save_pos(mk_equations(fns.size(), eqns.size(), eqns.data()), pos);
}
/** \brief Use equations compiler infrastructure to implement match-with */
expr parse_match(parser & p, unsigned, expr const *, pos_info const & pos) {
expr t = p.parse_expr();

View file

@ -17,6 +17,8 @@ class parser;
*/
bool parse_univ_params(parser & p, buffer<name> & ps);
expr parse_match(parser & p, unsigned, expr const *, pos_info const & pos);
expr parse_local_equations(parser & p, expr const & fn);
/** \brief Add universe levels from \c found_ls to \c ls_buffer (only the levels that do not already occur in \c ls_buffer are added).
Then sort \c ls_buffer (using the order in which the universe levels were declared).
*/

View file

@ -0,0 +1,67 @@
import data.nat logic
open bool nat
check
show nat → bool
| 0 := tt
| (n+1) := ff
definition mult : nat → nat → nat :=
have plus : nat → nat → nat
| 0 b := b
| (succ a) b := succ (plus a b),
have mult : nat → nat → nat
| 0 b := 0
| (succ a) b := plus (mult a b) b,
mult
print definition mult
example : mult 3 7 = 21 := rfl
example : mult 8 7 = 56 := rfl
theorem add_eq_addl : ∀ x y, x + y = x ⊕ y
| 0 0 := rfl
| (succ x) 0 :=
begin
have addl_z : ∀ a : nat, a ⊕ 0 = a
| 0 := rfl
| (succ a) := calc
(succ a) ⊕ 0 = succ (a ⊕ 0) : rfl
... = succ a : addl_z,
rewrite addl_z
end
| 0 (succ y) :=
begin
have z_add : ∀ a : nat, 0 + a = a
| 0 := rfl
| (succ a) :=
begin
rewrite ▸ succ(0 + a) = _,
rewrite z_add
end,
rewrite z_add
end
| (succ x) (succ y) :=
begin
change (succ x + succ y = succ (x ⊕ succ y)),
have s_add : ∀ a b : nat, succ a + b = succ (a + b)
| 0 0 := rfl
| (succ a) 0 := rfl
| 0 (succ b) :=
begin
change (succ (succ 0 + b) = succ (succ (0 + b))),
rewrite -(s_add 0 b)
end
| (succ a) (succ b) :=
begin
change (succ (succ (succ a) + b) = succ (succ (succ a + b))),
apply (congr_arg succ),
rewrite (s_add (succ a) b),
end,
rewrite [s_add, add_eq_addl]
end
print definition add_eq_addl

View file

@ -0,0 +1,14 @@
import data.fin
open fin nat
definition nz_cases_on {C : Π n, fin (succ n) → Type}
(H₁ : Π n, C n (fz n))
(H₂ : Π n (f : fin n), C n (fs f))
{n : nat}
(f : fin (succ n)) : C n f :=
begin
reverts (n, f),
show ∀ (n : nat) (f : fin (succ n)), C n f
| m (fz m) := by apply H₁
| m (fs f') := by apply H₂
end