feat(frontends/lean): nested dependent pattern matching
This commit is contained in:
parent
14ca2d407d
commit
f966634910
5 changed files with 215 additions and 95 deletions
|
@ -177,41 +177,51 @@ static expr parse_begin_end_core(parser & p, pos_info const & pos, name const &
|
|||
} else if (p.curr_is_token(get_have_tk())) {
|
||||
auto pos = p.pos();
|
||||
p.next();
|
||||
auto id_pos = p.pos();
|
||||
name id = p.check_id_next("invalid 'have' tactic, identifier expected");
|
||||
p.check_token_next(get_colon_tk(), "invalid 'have' tactic, ':' expected");
|
||||
expr A = p.parse_expr();
|
||||
p.check_token_next(get_comma_tk(), "invalid 'have' tactic, ',' expected");
|
||||
expr assert_tac = p.save_pos(mk_assert_tactic_expr(id, A), pos);
|
||||
tacs.push_back(mk_begin_end_element_annotation(assert_tac));
|
||||
if (p.curr_is_token(get_from_tk())) {
|
||||
// parse: 'from' expr
|
||||
p.next();
|
||||
auto pos = p.pos();
|
||||
expr t = p.parse_expr();
|
||||
if (p.curr_is_token(get_bar_tk())) {
|
||||
expr local = p.save_pos(mk_local(id, A), id_pos);
|
||||
expr t = parse_local_equations(p, local);
|
||||
t = p.mk_app(get_exact_tac_fn(), t, pos);
|
||||
t = p.save_pos(mk_begin_end_element_annotation(t), pos);
|
||||
t = p.save_pos(mk_begin_end_annotation(t), pos);
|
||||
add_tac(t, pos);
|
||||
} else if (p.curr_is_token(get_proof_tk())) {
|
||||
auto pos = p.pos();
|
||||
p.next();
|
||||
expr t = p.parse_expr();
|
||||
p.check_token_next(get_qed_tk(), "invalid proof-qed, 'qed' expected");
|
||||
t = p.mk_app(get_exact_tac_fn(), t, pos);
|
||||
t = p.save_pos(mk_begin_end_element_annotation(t), pos);
|
||||
t = p.save_pos(mk_begin_end_annotation(t), pos);
|
||||
add_tac(t, pos);
|
||||
} else if (p.curr_is_token(get_begin_tk())) {
|
||||
auto pos = p.pos();
|
||||
tacs.push_back(parse_begin_end_core(p, pos, get_end_tk(), true));
|
||||
} else if (p.curr_is_token(get_by_tk())) {
|
||||
// parse: 'by' tactic
|
||||
auto pos = p.pos();
|
||||
p.next();
|
||||
expr t = p.parse_tactic();
|
||||
add_tac(t, pos);
|
||||
} else {
|
||||
throw parser_error("invalid 'have' tactic, 'by', 'begin', 'proof', or 'from' expected", p.pos());
|
||||
p.check_token_next(get_comma_tk(), "invalid 'have' tactic, ',' expected");
|
||||
if (p.curr_is_token(get_from_tk())) {
|
||||
// parse: 'from' expr
|
||||
p.next();
|
||||
auto pos = p.pos();
|
||||
expr t = p.parse_expr();
|
||||
t = p.mk_app(get_exact_tac_fn(), t, pos);
|
||||
t = p.save_pos(mk_begin_end_element_annotation(t), pos);
|
||||
t = p.save_pos(mk_begin_end_annotation(t), pos);
|
||||
add_tac(t, pos);
|
||||
} else if (p.curr_is_token(get_proof_tk())) {
|
||||
auto pos = p.pos();
|
||||
p.next();
|
||||
expr t = p.parse_expr();
|
||||
p.check_token_next(get_qed_tk(), "invalid proof-qed, 'qed' expected");
|
||||
t = p.mk_app(get_exact_tac_fn(), t, pos);
|
||||
t = p.save_pos(mk_begin_end_element_annotation(t), pos);
|
||||
t = p.save_pos(mk_begin_end_annotation(t), pos);
|
||||
add_tac(t, pos);
|
||||
} else if (p.curr_is_token(get_begin_tk())) {
|
||||
auto pos = p.pos();
|
||||
tacs.push_back(parse_begin_end_core(p, pos, get_end_tk(), true));
|
||||
} else if (p.curr_is_token(get_by_tk())) {
|
||||
// parse: 'by' tactic
|
||||
auto pos = p.pos();
|
||||
p.next();
|
||||
expr t = p.parse_tactic();
|
||||
add_tac(t, pos);
|
||||
} else {
|
||||
throw parser_error("invalid 'have' tactic, 'by', 'begin', 'proof', or 'from' expected", p.pos());
|
||||
}
|
||||
}
|
||||
} else if (p.curr_is_token(get_show_tk())) {
|
||||
auto pos = p.pos();
|
||||
|
@ -349,17 +359,22 @@ static expr parse_have_core(parser & p, pos_info const & pos, optional<expr> con
|
|||
id = p.mk_fresh_name();
|
||||
prop = p.parse_expr();
|
||||
}
|
||||
p.check_token_next(get_comma_tk(), "invalid 'have/assert' declaration, ',' expected");
|
||||
expr proof;
|
||||
if (prev_local) {
|
||||
parser::local_scope scope(p);
|
||||
p.add_local(*prev_local);
|
||||
auto proof_pos = p.pos();
|
||||
proof = parse_proof(p, prop);
|
||||
proof = p.save_pos(Fun(*prev_local, proof), proof_pos);
|
||||
proof = p.save_pos(mk_app(proof, *prev_local), proof_pos);
|
||||
if (p.curr_is_token(get_bar_tk()) && !prev_local) {
|
||||
expr fn = p.save_pos(mk_local(id, prop), id_pos);
|
||||
proof = parse_local_equations(p, fn);
|
||||
} else {
|
||||
proof = parse_proof(p, prop);
|
||||
p.check_token_next(get_comma_tk(), "invalid 'have/assert' declaration, ',' expected");
|
||||
if (prev_local) {
|
||||
parser::local_scope scope(p);
|
||||
p.add_local(*prev_local);
|
||||
auto proof_pos = p.pos();
|
||||
proof = parse_proof(p, prop);
|
||||
proof = p.save_pos(Fun(*prev_local, proof), proof_pos);
|
||||
proof = p.save_pos(mk_app(proof, *prev_local), proof_pos);
|
||||
} else {
|
||||
proof = parse_proof(p, prop);
|
||||
}
|
||||
}
|
||||
p.check_token_next(get_comma_tk(), "invalid 'have/assert' declaration, ',' expected");
|
||||
parser::local_scope scope(p);
|
||||
|
@ -398,11 +413,16 @@ static expr parse_assert(parser & p, unsigned, expr const *, pos_info const & po
|
|||
static name * H_show = nullptr;
|
||||
static expr parse_show(parser & p, unsigned, expr const *, pos_info const & pos) {
|
||||
expr prop = p.parse_expr();
|
||||
p.check_token_next(get_comma_tk(), "invalid 'show' declaration, ',' expected");
|
||||
expr proof = parse_proof(p, prop);
|
||||
expr b = p.save_pos(mk_lambda(*H_show, prop, Var(0)), pos);
|
||||
expr r = p.mk_app(b, proof, pos);
|
||||
return p.save_pos(mk_show_annotation(r), pos);
|
||||
if (p.curr_is_token(get_bar_tk())) {
|
||||
expr fn = p.save_pos(mk_local(*H_show, prop), pos);
|
||||
return parse_local_equations(p, fn);
|
||||
} else {
|
||||
p.check_token_next(get_comma_tk(), "invalid 'show' declaration, ',' expected");
|
||||
expr proof = parse_proof(p, prop);
|
||||
expr b = p.save_pos(mk_lambda(*H_show, prop, Var(0)), pos);
|
||||
expr r = p.mk_app(b, proof, pos);
|
||||
return p.save_pos(mk_show_annotation(r), pos);
|
||||
}
|
||||
}
|
||||
|
||||
static expr parse_obtain(parser & p, unsigned, expr const *, pos_info const & pos) {
|
||||
|
|
|
@ -555,8 +555,8 @@ static expr merge_equation_lhs_vars(expr const & lhs, buffer<expr> & locals) {
|
|||
<< n << "' in the left-hand-side does not correspond to function(s) being defined", p);
|
||||
}
|
||||
|
||||
static bool is_eqn_prefix(parser & p) {
|
||||
return p.curr_is_token(get_bar_tk()) || p.curr_is_token(get_comma_tk());
|
||||
static bool is_eqn_prefix(parser & p, bool bar_only = false) {
|
||||
return p.curr_is_token(get_bar_tk()) || (!bar_only && p.curr_is_token(get_comma_tk()));
|
||||
}
|
||||
|
||||
static void check_eqn_prefix(parser & p) {
|
||||
|
@ -580,11 +580,67 @@ static expr get_equation_fn(buffer<expr> const & fns, name const & fn_name, pos_
|
|||
throw_invalid_equation_lhs(fn_name, lhs_pos);
|
||||
}
|
||||
|
||||
static void parse_equations_core(parser & p, buffer<expr> const & fns, buffer<expr> & eqns, bool bar_only = false) {
|
||||
for (expr const & fn : fns)
|
||||
p.add_local(fn);
|
||||
while (true) {
|
||||
expr lhs;
|
||||
unsigned prev_num_undef_ids = p.get_num_undef_ids();
|
||||
buffer<expr> locals;
|
||||
{
|
||||
parser::undef_id_to_local_scope scope2(p);
|
||||
buffer<expr> lhs_args;
|
||||
auto lhs_pos = p.pos();
|
||||
if (p.curr_is_token(get_explicit_tk())) {
|
||||
p.next();
|
||||
name fn_name = p.check_id_next("invalid recursive equation, identifier expected");
|
||||
lhs_args.push_back(p.save_pos(mk_explicit(get_equation_fn(fns, fn_name, lhs_pos)), lhs_pos));
|
||||
} else {
|
||||
expr first = p.parse_expr(get_max_prec());
|
||||
expr fn = first;
|
||||
if (is_explicit(fn))
|
||||
fn = get_explicit_arg(fn);
|
||||
if (is_local(fn) && is_equation_fn(fns, local_pp_name(fn))) {
|
||||
lhs_args.push_back(first);
|
||||
} else if (fns.size() == 1) {
|
||||
lhs_args.push_back(p.save_pos(mk_explicit(fns[0]), lhs_pos));
|
||||
lhs_args.push_back(first);
|
||||
} else {
|
||||
throw parser_error("invalid recursive equation, head symbol in left-hand-side is not a constant",
|
||||
lhs_pos);
|
||||
}
|
||||
}
|
||||
while (!p.curr_is_token(get_assign_tk()))
|
||||
lhs_args.push_back(p.parse_expr(get_max_prec()));
|
||||
lhs = p.save_pos(mk_app(lhs_args.size(), lhs_args.data()), lhs_pos);
|
||||
|
||||
unsigned num_undef_ids = p.get_num_undef_ids();
|
||||
for (unsigned i = prev_num_undef_ids; i < num_undef_ids; i++) {
|
||||
locals.push_back(p.get_undef_id(i));
|
||||
}
|
||||
}
|
||||
validate_equation_lhs(p, lhs, locals);
|
||||
lhs = merge_equation_lhs_vars(lhs, locals);
|
||||
auto assign_pos = p.pos();
|
||||
p.check_token_next(get_assign_tk(), "invalid declaration, ':=' expected");
|
||||
{
|
||||
parser::local_scope scope2(p);
|
||||
for (expr const & local : locals)
|
||||
p.add_local(local);
|
||||
expr rhs = p.parse_expr();
|
||||
eqns.push_back(Fun(fns, Fun(locals, p.save_pos(mk_equation(lhs, rhs), assign_pos), p)));
|
||||
}
|
||||
if (!is_eqn_prefix(p, bar_only))
|
||||
break;
|
||||
p.next();
|
||||
}
|
||||
}
|
||||
|
||||
expr parse_equations(parser & p, name const & n, expr const & type, buffer<name> & auxs,
|
||||
optional<local_environment> const & lenv, buffer<expr> const & ps,
|
||||
pos_info const & def_pos) {
|
||||
buffer<expr> eqns;
|
||||
buffer<expr> fns;
|
||||
buffer<expr> eqns;
|
||||
{
|
||||
parser::local_scope scope1(p, lenv);
|
||||
for (expr const & param : ps)
|
||||
|
@ -609,59 +665,7 @@ expr parse_equations(parser & p, name const & n, expr const & type, buffer<name>
|
|||
p.next();
|
||||
eqns.push_back(Fun(fns, mk_no_equation(), p));
|
||||
} else {
|
||||
for (expr const & fn : fns)
|
||||
p.add_local(fn);
|
||||
while (true) {
|
||||
expr lhs;
|
||||
unsigned prev_num_undef_ids = p.get_num_undef_ids();
|
||||
buffer<expr> locals;
|
||||
{
|
||||
parser::undef_id_to_local_scope scope2(p);
|
||||
buffer<expr> lhs_args;
|
||||
auto lhs_pos = p.pos();
|
||||
if (p.curr_is_token(get_explicit_tk())) {
|
||||
p.next();
|
||||
name fn_name = p.check_id_next("invalid recursive equation, identifier expected");
|
||||
lhs_args.push_back(p.save_pos(mk_explicit(get_equation_fn(fns, fn_name, lhs_pos)), lhs_pos));
|
||||
} else {
|
||||
expr first = p.parse_expr(get_max_prec());
|
||||
expr fn = first;
|
||||
if (is_explicit(fn))
|
||||
fn = get_explicit_arg(fn);
|
||||
if (is_local(fn) && is_equation_fn(fns, local_pp_name(fn))) {
|
||||
lhs_args.push_back(first);
|
||||
} else if (fns.size() == 1) {
|
||||
lhs_args.push_back(p.save_pos(mk_explicit(fns[0]), lhs_pos));
|
||||
lhs_args.push_back(first);
|
||||
} else {
|
||||
throw parser_error("invalid recursive equation, head symbol in left-hand-side is not a constant",
|
||||
lhs_pos);
|
||||
}
|
||||
}
|
||||
while (!p.curr_is_token(get_assign_tk()))
|
||||
lhs_args.push_back(p.parse_expr(get_max_prec()));
|
||||
lhs = p.save_pos(mk_app(lhs_args.size(), lhs_args.data()), lhs_pos);
|
||||
|
||||
unsigned num_undef_ids = p.get_num_undef_ids();
|
||||
for (unsigned i = prev_num_undef_ids; i < num_undef_ids; i++) {
|
||||
locals.push_back(p.get_undef_id(i));
|
||||
}
|
||||
}
|
||||
validate_equation_lhs(p, lhs, locals);
|
||||
lhs = merge_equation_lhs_vars(lhs, locals);
|
||||
auto assign_pos = p.pos();
|
||||
p.check_token_next(get_assign_tk(), "invalid declaration, ':=' expected");
|
||||
{
|
||||
parser::local_scope scope2(p);
|
||||
for (expr const & local : locals)
|
||||
p.add_local(local);
|
||||
expr rhs = p.parse_expr();
|
||||
eqns.push_back(Fun(fns, Fun(locals, p.save_pos(mk_equation(lhs, rhs), assign_pos), p)));
|
||||
}
|
||||
if (!is_eqn_prefix(p))
|
||||
break;
|
||||
p.next();
|
||||
}
|
||||
parse_equations_core(p, fns, eqns);
|
||||
}
|
||||
}
|
||||
if (p.curr_is_token(get_wf_tk())) {
|
||||
|
@ -675,6 +679,19 @@ expr parse_equations(parser & p, name const & n, expr const & type, buffer<name>
|
|||
}
|
||||
}
|
||||
|
||||
/** \brief Parse a sequence of equations of the form <tt>| lhs := rhs</tt> */
|
||||
expr parse_local_equations(parser & p, expr const & fn) {
|
||||
lean_assert(p.curr_is_token(get_bar_tk()));
|
||||
auto pos = p.pos();
|
||||
p.next();
|
||||
buffer<expr> fns;
|
||||
buffer<expr> eqns;
|
||||
fns.push_back(fn);
|
||||
bool bar_only = true;
|
||||
parse_equations_core(p, fns, eqns, bar_only);
|
||||
return p.save_pos(mk_equations(fns.size(), eqns.size(), eqns.data()), pos);
|
||||
}
|
||||
|
||||
/** \brief Use equations compiler infrastructure to implement match-with */
|
||||
expr parse_match(parser & p, unsigned, expr const *, pos_info const & pos) {
|
||||
expr t = p.parse_expr();
|
||||
|
|
|
@ -17,6 +17,8 @@ class parser;
|
|||
*/
|
||||
bool parse_univ_params(parser & p, buffer<name> & ps);
|
||||
expr parse_match(parser & p, unsigned, expr const *, pos_info const & pos);
|
||||
expr parse_local_equations(parser & p, expr const & fn);
|
||||
|
||||
/** \brief Add universe levels from \c found_ls to \c ls_buffer (only the levels that do not already occur in \c ls_buffer are added).
|
||||
Then sort \c ls_buffer (using the order in which the universe levels were declared).
|
||||
*/
|
||||
|
|
67
tests/lean/run/local_eqns.lean
Normal file
67
tests/lean/run/local_eqns.lean
Normal file
|
@ -0,0 +1,67 @@
|
|||
import data.nat logic
|
||||
|
||||
open bool nat
|
||||
|
||||
check
|
||||
show nat → bool
|
||||
| 0 := tt
|
||||
| (n+1) := ff
|
||||
|
||||
definition mult : nat → nat → nat :=
|
||||
have plus : nat → nat → nat
|
||||
| 0 b := b
|
||||
| (succ a) b := succ (plus a b),
|
||||
have mult : nat → nat → nat
|
||||
| 0 b := 0
|
||||
| (succ a) b := plus (mult a b) b,
|
||||
mult
|
||||
|
||||
print definition mult
|
||||
|
||||
example : mult 3 7 = 21 := rfl
|
||||
|
||||
example : mult 8 7 = 56 := rfl
|
||||
|
||||
theorem add_eq_addl : ∀ x y, x + y = x ⊕ y
|
||||
| 0 0 := rfl
|
||||
| (succ x) 0 :=
|
||||
begin
|
||||
have addl_z : ∀ a : nat, a ⊕ 0 = a
|
||||
| 0 := rfl
|
||||
| (succ a) := calc
|
||||
(succ a) ⊕ 0 = succ (a ⊕ 0) : rfl
|
||||
... = succ a : addl_z,
|
||||
rewrite addl_z
|
||||
end
|
||||
| 0 (succ y) :=
|
||||
begin
|
||||
have z_add : ∀ a : nat, 0 + a = a
|
||||
| 0 := rfl
|
||||
| (succ a) :=
|
||||
begin
|
||||
rewrite ▸ succ(0 + a) = _,
|
||||
rewrite z_add
|
||||
end,
|
||||
rewrite z_add
|
||||
end
|
||||
| (succ x) (succ y) :=
|
||||
begin
|
||||
change (succ x + succ y = succ (x ⊕ succ y)),
|
||||
have s_add : ∀ a b : nat, succ a + b = succ (a + b)
|
||||
| 0 0 := rfl
|
||||
| (succ a) 0 := rfl
|
||||
| 0 (succ b) :=
|
||||
begin
|
||||
change (succ (succ 0 + b) = succ (succ (0 + b))),
|
||||
rewrite -(s_add 0 b)
|
||||
end
|
||||
| (succ a) (succ b) :=
|
||||
begin
|
||||
change (succ (succ (succ a) + b) = succ (succ (succ a + b))),
|
||||
apply (congr_arg succ),
|
||||
rewrite (s_add (succ a) b),
|
||||
end,
|
||||
rewrite [s_add, add_eq_addl]
|
||||
end
|
||||
|
||||
print definition add_eq_addl
|
14
tests/lean/run/local_eqns2.lean
Normal file
14
tests/lean/run/local_eqns2.lean
Normal file
|
@ -0,0 +1,14 @@
|
|||
import data.fin
|
||||
open fin nat
|
||||
|
||||
definition nz_cases_on {C : Π n, fin (succ n) → Type}
|
||||
(H₁ : Π n, C n (fz n))
|
||||
(H₂ : Π n (f : fin n), C n (fs f))
|
||||
{n : nat}
|
||||
(f : fin (succ n)) : C n f :=
|
||||
begin
|
||||
reverts (n, f),
|
||||
show ∀ (n : nat) (f : fin (succ n)), C n f
|
||||
| m (fz m) := by apply H₁
|
||||
| m (fs f') := by apply H₂
|
||||
end
|
Loading…
Reference in a new issue