feat(frontends/lean/calc): add parse_calc function

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-06-17 17:15:38 -07:00
parent 037cfcf622
commit 4cbc429192
7 changed files with 216 additions and 9 deletions

View file

@ -8,6 +8,7 @@ Author: Leonardo de Moura
#include "library/placeholder.h" #include "library/placeholder.h"
#include "frontends/lean/builtin_exprs.h" #include "frontends/lean/builtin_exprs.h"
#include "frontends/lean/token_table.h" #include "frontends/lean/token_table.h"
#include "frontends/lean/calc.h"
#include "frontends/lean/parser.h" #include "frontends/lean/parser.h"
namespace lean { namespace lean {
@ -139,6 +140,10 @@ static expr parse_show(parser & p, unsigned, expr const *, pos_info const & pos)
return p.save_pos(mk_let(H_show, prop, proof, Var(0)), pos); return p.save_pos(mk_let(H_show, prop, proof, Var(0)), pos);
} }
static expr parse_calc_expr(parser & p, unsigned, expr const *, pos_info const &) {
return parse_calc(p);
}
parse_table init_nud_table() { parse_table init_nud_table() {
action Expr(mk_expr_action()); action Expr(mk_expr_action());
action Skip(mk_skip_action()); action Skip(mk_skip_action());
@ -154,6 +159,7 @@ parse_table init_nud_table() {
r = r.add({transition("Pi", Binders), transition(",", mk_scoped_expr_action(x0, 0, false))}, x0); r = r.add({transition("Pi", Binders), transition(",", mk_scoped_expr_action(x0, 0, false))}, x0);
r = r.add({transition("Type", mk_ext_action(parse_Type))}, x0); r = r.add({transition("Type", mk_ext_action(parse_Type))}, x0);
r = r.add({transition("let", mk_ext_action(parse_let_expr))}, x0); r = r.add({transition("let", mk_ext_action(parse_let_expr))}, x0);
r = r.add({transition("calc", mk_ext_action(parse_calc_expr))}, x0);
return r; return r;
} }

View file

@ -6,12 +6,17 @@ Author: Leonardo de Moura
*/ */
#include <string> #include <string>
#include <utility> #include <utility>
#include <algorithm>
#include <vector>
#include "util/optional.h" #include "util/optional.h"
#include "util/name.h" #include "util/name.h"
#include "util/rb_map.h" #include "util/rb_map.h"
#include "util/buffer.h" #include "util/buffer.h"
#include "util/interrupt.h"
#include "kernel/environment.h" #include "kernel/environment.h"
#include "library/module.h" #include "library/module.h"
#include "library/choice.h"
#include "library/placeholder.h"
#include "frontends/lean/parser.h" #include "frontends/lean/parser.h"
namespace lean { namespace lean {
@ -96,8 +101,8 @@ environment add_calc_trans(environment const & env, name const & trans) {
if (nargs < 5) if (nargs < 5)
throw exception("invalid calc transitivity rule, it must have at least 5 arguments"); throw exception("invalid calc transitivity rule, it must have at least 5 arguments");
name const & rop = get_fn_const(r_type, "invalid calc transitivity rule, result type must be an operator application"); name const & rop = get_fn_const(r_type, "invalid calc transitivity rule, result type must be an operator application");
name const & op1 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application"); name const & op1 = get_fn_const(arg_types[nargs-2], "invalid calc transitivity rule, penultimate argument must be an operator application");
name const & op2 = get_fn_const(arg_types[nargs-2], "invalid calc transitivity rule, penultimate argument must be an operator application"); name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application");
calc_ext ext = get_extension(env); calc_ext ext = get_extension(env);
ext.m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs)); ext.m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
environment new_env = module::add(env, g_calc_trans_key, [=](serializer & s) { environment new_env = module::add(env, g_calc_trans_key, [=](serializer & s) {
@ -166,4 +171,150 @@ void register_calc_cmds(cmd_table & r) {
add_cmd(r, cmd_info("calc_refl", "set the reflexivity rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_refl_cmd)); add_cmd(r, cmd_info("calc_refl", "set the reflexivity rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_refl_cmd));
add_cmd(r, cmd_info("calc_trans", "set the transitivity rule for a pair of operators, this command is relevant for the calculational proof '{...}' notation", calc_trans_cmd)); add_cmd(r, cmd_info("calc_trans", "set the transitivity rule for a pair of operators, this command is relevant for the calculational proof '{...}' notation", calc_trans_cmd));
} }
typedef std::tuple<name, expr, expr> calc_pred;
typedef std::pair<calc_pred, expr> calc_step;
inline name const & pred_op(calc_pred const & p) { return std::get<0>(p); }
inline expr const & pred_lhs(calc_pred const & p) { return std::get<1>(p); }
inline expr const & pred_rhs(calc_pred const & p) { return std::get<2>(p); }
inline calc_pred const & step_pred(calc_step const & s) { return s.first; }
inline expr const & step_proof(calc_step const & s) { return s.second; }
static name g_lcurly("{");
static name g_rcurly("}");
static name g_ellipsis("...");
static name g_colon(":");
static void decode_expr_core(expr const & e, buffer<calc_pred> & preds) {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (!is_constant(fn))
return;
unsigned nargs = args.size();
if (nargs < 2)
return;
preds.emplace_back(const_name(fn), args[nargs-2], args[nargs-1]);
}
// Check whether e is of the form (f ...) where f is a constant. If it is return f.
static void decode_expr(expr const & e, buffer<calc_pred> & preds, pos_info const & pos) {
preds.clear();
if (is_choice(e)) {
for (unsigned i = 0; i < get_num_choices(e); i++)
decode_expr_core(get_choice(e, i), preds);
} else {
decode_expr_core(e, preds);
}
if (preds.empty())
throw parser_error("invalid 'calc' expression, expression must be a function application 'f a_1 ... a_k' "
"where f is a constant, and k >= 2", pos);
}
// Create (op _ _ ... _)
static expr mk_op_fn(parser & p, name const & op, unsigned num_placeholders, pos_info const & pos) {
expr r = p.save_pos(mk_constant(op), pos);
while (num_placeholders > 0) {
num_placeholders--;
r = p.mk_app(r, p.save_pos(mk_expr_placeholder(), pos), pos);
}
return r;
}
static void parse_calc_proof(parser & p, buffer<calc_pred> const & preds, std::vector<calc_step> & steps) {
steps.clear();
auto pos = p.pos();
p.check_token_next(g_colon, "invalid 'calc' expression, ':' expected");
if (p.curr_is_token(g_lcurly)) {
p.next();
expr pr = p.parse_expr();
p.check_token_next(g_rcurly, "invalid 'calc' expression, '}' expected");
calc_ext const & ext = get_extension(p.env());
if (!ext.m_subst)
throw parser_error("invalid 'calc' expression, substitution rule was not defined with calc_subst command", pos);
for (auto const & pred : preds) {
auto refl_it = ext.m_refl_table.find(pred_op(pred));
if (refl_it) {
expr refl = mk_op_fn(p, refl_it->first, refl_it->second-1, pos);
expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos);
expr subst = mk_op_fn(p, *ext.m_subst, ext.m_subst_num_args-2, pos);
expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos);
steps.emplace_back(pred, subst_pr);
}
}
if (steps.empty())
throw parser_error("invalid 'calc' expression, reflexivity rule is not defined for operator", pos);
} else {
expr pr = p.parse_expr();
for (auto const & pred : preds)
steps.emplace_back(pred, pr);
}
}
/** \brief Collect distinct rhs's */
static void collect_rhss(std::vector<calc_step> const & steps, buffer<expr> & rhss) {
rhss.clear();
for (auto const & step : steps) {
calc_pred const & pred = step_pred(step);
expr const & rhs = pred_rhs(pred);
if (std::find(rhss.begin(), rhss.end(), rhs) == rhss.end())
rhss.push_back(rhs);
}
lean_assert(!rhss.empty());
}
static void join(parser & p, std::vector<calc_step> const & steps1, std::vector<calc_step> const & steps2, std::vector<calc_step> & res_steps,
pos_info const & pos) {
res_steps.clear();
calc_ext const & ext = get_extension(p.env());
for (calc_step const & s1 : steps1) {
check_interrupted();
calc_pred const & pred1 = step_pred(s1);
expr const & pr1 = step_proof(s1);
for (calc_step const & s2 : steps2) {
calc_pred const & pred2 = step_pred(s2);
expr const & pr2 = step_proof(s2);
if (!is_eqp(pred_rhs(pred1), pred_lhs(pred2)))
continue;
auto trans_it = ext.m_trans_table.find(name_pair(pred_op(pred1), pred_op(pred2)));
if (!trans_it)
continue;
expr trans = mk_op_fn(p, std::get<0>(*trans_it), std::get<2>(*trans_it)-5, pos);
expr trans_pr = p.mk_app({trans, pred_lhs(pred1), pred_rhs(pred1), pred_rhs(pred2), pr1, pr2}, pos);
res_steps.emplace_back(calc_pred(std::get<1>(*trans_it), pred_lhs(pred1), pred_rhs(pred2)), trans_pr);
}
}
}
expr parse_calc(parser & p) {
buffer<calc_pred> preds, new_preds;
buffer<expr> rhss;
std::vector<calc_step> steps, new_steps, next_steps;
auto pos = p.pos();
decode_expr(p.parse_expr(), preds, pos);
parse_calc_proof(p, preds, steps);
expr dummy = mk_expr_placeholder();
while (p.curr_is_token(g_ellipsis)) {
pos = p.pos();
p.next();
decode_expr(p.parse_led(dummy), preds, pos);
collect_rhss(steps, rhss);
new_steps.clear();
for (auto const & pred : preds) {
if (is_eqp(pred_lhs(pred), dummy)) {
for (expr const & rhs : rhss)
new_preds.emplace_back(pred_op(pred), rhs, pred_rhs(pred));
}
}
if (new_preds.empty())
throw parser_error("invalid 'calc' expression, invalid expression", pos);
parse_calc_proof(p, new_preds, new_steps);
join(p, steps, new_steps, next_steps, pos);
if (next_steps.empty())
throw parser_error("invalid 'calc' expression, transitivity rule is not defined for current step", pos);
steps.swap(next_steps);
}
buffer<expr> choices;
for (auto const & s : steps)
choices.push_back(step_proof(s));
return mk_choice(choices.size(), choices.data());
}
} }

View file

@ -7,5 +7,7 @@ Author: Leonardo de Moura
#pragma once #pragma once
#include "frontends/lean/cmd_table.h" #include "frontends/lean/cmd_table.h"
namespace lean { namespace lean {
class parser;
void register_calc_cmds(cmd_table & r); void register_calc_cmds(cmd_table & r);
expr parse_calc(parser & p);
} }

View file

@ -281,6 +281,17 @@ expr parser::mk_app(expr fn, expr arg, pos_info const & p) {
return save_pos(::lean::mk_app(fn, arg), p); return save_pos(::lean::mk_app(fn, arg), p);
} }
expr parser::mk_app(std::initializer_list<expr> const & args, pos_info const & p) {
unsigned nargs = args.size();
lean_assert(nargs >= 2);
auto it = args.begin();
expr r = *it;
it++;
for (; it != args.end(); it++)
r = mk_app(r, *it, p);
return r;
}
void parser::push_local_scope() { void parser::push_local_scope() {
if (m_type_use_placeholder) if (m_type_use_placeholder)
m_local_level_decls.push(); m_local_level_decls.push();

View file

@ -124,7 +124,6 @@ class parser {
parameter parse_binder_core(binder_info const & bi); parameter parse_binder_core(binder_info const & bi);
void parse_binder_block(buffer<parameter> & r, binder_info const & bi); void parse_binder_block(buffer<parameter> & r, binder_info const & bi);
void parse_binders_core(buffer<parameter> & r); void parse_binders_core(buffer<parameter> & r);
expr mk_app(expr fn, expr arg, pos_info const & p);
friend environment section_cmd(parser & p); friend environment section_cmd(parser & p);
friend environment end_scoped_cmd(parser & p); friend environment end_scoped_cmd(parser & p);
@ -159,6 +158,9 @@ public:
pos_info cmd_pos() const { return m_last_cmd_pos; } pos_info cmd_pos() const { return m_last_cmd_pos; }
void set_line(unsigned p) { return m_scanner.set_line(p); } void set_line(unsigned p) { return m_scanner.set_line(p); }
expr mk_app(expr fn, expr arg, pos_info const & p);
expr mk_app(std::initializer_list<expr> const & args, pos_info const & p);
/** \brief Read the next token. */ /** \brief Read the next token. */
void scan() { m_curr = m_scanner.scan(m_env); } void scan() { m_curr = m_scanner.scan(m_env); }
/** \brief Return the current token */ /** \brief Return the current token */

View file

@ -2,17 +2,48 @@ variable A : Type.{1}
definition [inline] bool : Type.{1} := Type.{0} definition [inline] bool : Type.{1} := Type.{0}
variable eq : A → A → bool variable eq : A → A → bool
infixl `=` 50 := eq infixl `=` 50 := eq
variable subst (P : A → bool) (a b : A) (H1 : a = b) (H2 : P a) : P b axiom subst (P : A → bool) (a b : A) (H1 : a = b) (H2 : P a) : P b
variable eq_trans (a b c : A) (H1 : a = b) (H2 : b = c) : a = c axiom eq_trans (a b c : A) (H1 : a = b) (H2 : b = c) : a = c
variable eq_refl (a : A) : a = a axiom eq_refl (a : A) : a = a
variable le : A → A → bool variable le : A → A → bool
infixl `≤` 50 := le infixl `≤` 50 := le
variable le_trans (a b c : A) (H1 : a ≤ b) (H2 : b ≤ c) : a ≤ c axiom le_trans (a b c : A) (H1 : a ≤ b) (H2 : b ≤ c) : a ≤ c
variable le_refl (a : A) : a ≤ a axiom le_refl (a : A) : a ≤ a
variable eq_le_trans (a b c : A) (H1 : a = b) (H2 : b ≤ c) : a ≤ c axiom eq_le_trans (a b c : A) (H1 : a = b) (H2 : b ≤ c) : a ≤ c
axiom le_eq_trans (a b c : A) (H1 : a ≤ b) (H2 : b = c) : a ≤ c
calc_subst subst calc_subst subst
calc_refl eq_refl calc_refl eq_refl
calc_refl le_refl calc_refl le_refl
calc_trans eq_trans calc_trans eq_trans
calc_trans le_trans calc_trans le_trans
calc_trans eq_le_trans calc_trans eq_le_trans
calc_trans le_eq_trans
variables a b c d e f : A
axiom H1 : a = b
axiom H2 : b ≤ c
axiom H3 : c ≤ d
axiom H4 : d = e
check calc a = b : H1
... ≤ c : H2
... ≤ d : H3
... = e : H4
variable lt : A → A → bool
infixl `<` 50 := lt
axiom lt_trans (a b c : A) (H1 : a < b) (H2 : b < c) : a < c
axiom le_lt_trans (a b c : A) (H1 : a ≤ b) (H2 : b < c) : a < c
axiom lt_le_trans (a b c : A) (H1 : a < b) (H2 : b ≤ c) : a < c
axiom H5 : c < d
check calc b ≤ c : H2
... < d : H5 -- Error le_lt_trans was not registered yet
calc_trans le_lt_trans
check calc b ≤ c : H2
... < d : H5
variable le2 : A → A → bool
infixl `≤` 50 := le2
variable le2_trans (a b c : A) (H1 : le2 a b) (H2 : le2 b c) : le2 a c
calc_trans le2_trans
print raw calc b ≤ c : H2
... ≤ d : H3
... ≤ e : H4

View file

@ -0,0 +1,4 @@
le_eq_trans a d e (le_trans a c d (eq_le_trans a b c H1 H2) H3) H4 : le a e
calc1.lean:38:10: error: invalid 'calc' expression, transitivity rule is not defined for current step
le_lt_trans b c d H2 H5 : lt b d
choice (le2_trans b d e (le2_trans b c d H2 H3) H4) (le_trans b d e (le_trans b c d H2 H3) H4)