diff --git a/src/frontends/lean/builtin_exprs.cpp b/src/frontends/lean/builtin_exprs.cpp index 42e7304a0..a43da1b5c 100644 --- a/src/frontends/lean/builtin_exprs.cpp +++ b/src/frontends/lean/builtin_exprs.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include "library/placeholder.h" #include "frontends/lean/builtin_exprs.h" #include "frontends/lean/token_table.h" +#include "frontends/lean/calc.h" #include "frontends/lean/parser.h" namespace lean { @@ -139,6 +140,10 @@ static expr parse_show(parser & p, unsigned, expr const *, pos_info const & pos) return p.save_pos(mk_let(H_show, prop, proof, Var(0)), pos); } +static expr parse_calc_expr(parser & p, unsigned, expr const *, pos_info const &) { + return parse_calc(p); +} + parse_table init_nud_table() { action Expr(mk_expr_action()); action Skip(mk_skip_action()); @@ -154,6 +159,7 @@ parse_table init_nud_table() { r = r.add({transition("Pi", Binders), transition(",", mk_scoped_expr_action(x0, 0, false))}, x0); r = r.add({transition("Type", mk_ext_action(parse_Type))}, x0); r = r.add({transition("let", mk_ext_action(parse_let_expr))}, x0); + r = r.add({transition("calc", mk_ext_action(parse_calc_expr))}, x0); return r; } diff --git a/src/frontends/lean/calc.cpp b/src/frontends/lean/calc.cpp index 851f6bb9d..8bb771ae8 100644 --- a/src/frontends/lean/calc.cpp +++ b/src/frontends/lean/calc.cpp @@ -6,12 +6,17 @@ Author: Leonardo de Moura */ #include #include +#include +#include #include "util/optional.h" #include "util/name.h" #include "util/rb_map.h" #include "util/buffer.h" +#include "util/interrupt.h" #include "kernel/environment.h" #include "library/module.h" +#include "library/choice.h" +#include "library/placeholder.h" #include "frontends/lean/parser.h" namespace lean { @@ -96,8 +101,8 @@ environment add_calc_trans(environment const & env, name const & trans) { if (nargs < 5) throw exception("invalid calc transitivity rule, it must have at least 5 arguments"); name const & rop = get_fn_const(r_type, "invalid calc transitivity rule, result type must be an operator application"); - name const & op1 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application"); - name const & op2 = get_fn_const(arg_types[nargs-2], "invalid calc transitivity rule, penultimate argument must be an operator application"); + name const & op1 = get_fn_const(arg_types[nargs-2], "invalid calc transitivity rule, penultimate argument must be an operator application"); + name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application"); calc_ext ext = get_extension(env); ext.m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs)); environment new_env = module::add(env, g_calc_trans_key, [=](serializer & s) { @@ -166,4 +171,150 @@ void register_calc_cmds(cmd_table & r) { add_cmd(r, cmd_info("calc_refl", "set the reflexivity rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_refl_cmd)); add_cmd(r, cmd_info("calc_trans", "set the transitivity rule for a pair of operators, this command is relevant for the calculational proof '{...}' notation", calc_trans_cmd)); } + +typedef std::tuple calc_pred; +typedef std::pair calc_step; +inline name const & pred_op(calc_pred const & p) { return std::get<0>(p); } +inline expr const & pred_lhs(calc_pred const & p) { return std::get<1>(p); } +inline expr const & pred_rhs(calc_pred const & p) { return std::get<2>(p); } +inline calc_pred const & step_pred(calc_step const & s) { return s.first; } +inline expr const & step_proof(calc_step const & s) { return s.second; } +static name g_lcurly("{"); +static name g_rcurly("}"); +static name g_ellipsis("..."); +static name g_colon(":"); + +static void decode_expr_core(expr const & e, buffer & preds) { + buffer args; + expr const & fn = get_app_args(e, args); + if (!is_constant(fn)) + return; + unsigned nargs = args.size(); + if (nargs < 2) + return; + preds.emplace_back(const_name(fn), args[nargs-2], args[nargs-1]); +} + +// Check whether e is of the form (f ...) where f is a constant. If it is return f. +static void decode_expr(expr const & e, buffer & preds, pos_info const & pos) { + preds.clear(); + if (is_choice(e)) { + for (unsigned i = 0; i < get_num_choices(e); i++) + decode_expr_core(get_choice(e, i), preds); + } else { + decode_expr_core(e, preds); + } + if (preds.empty()) + throw parser_error("invalid 'calc' expression, expression must be a function application 'f a_1 ... a_k' " + "where f is a constant, and k >= 2", pos); +} + +// Create (op _ _ ... _) +static expr mk_op_fn(parser & p, name const & op, unsigned num_placeholders, pos_info const & pos) { + expr r = p.save_pos(mk_constant(op), pos); + while (num_placeholders > 0) { + num_placeholders--; + r = p.mk_app(r, p.save_pos(mk_expr_placeholder(), pos), pos); + } + return r; +} + +static void parse_calc_proof(parser & p, buffer const & preds, std::vector & steps) { + steps.clear(); + auto pos = p.pos(); + p.check_token_next(g_colon, "invalid 'calc' expression, ':' expected"); + if (p.curr_is_token(g_lcurly)) { + p.next(); + expr pr = p.parse_expr(); + p.check_token_next(g_rcurly, "invalid 'calc' expression, '}' expected"); + calc_ext const & ext = get_extension(p.env()); + if (!ext.m_subst) + throw parser_error("invalid 'calc' expression, substitution rule was not defined with calc_subst command", pos); + for (auto const & pred : preds) { + auto refl_it = ext.m_refl_table.find(pred_op(pred)); + if (refl_it) { + expr refl = mk_op_fn(p, refl_it->first, refl_it->second-1, pos); + expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos); + expr subst = mk_op_fn(p, *ext.m_subst, ext.m_subst_num_args-2, pos); + expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos); + steps.emplace_back(pred, subst_pr); + } + } + if (steps.empty()) + throw parser_error("invalid 'calc' expression, reflexivity rule is not defined for operator", pos); + } else { + expr pr = p.parse_expr(); + for (auto const & pred : preds) + steps.emplace_back(pred, pr); + } +} + +/** \brief Collect distinct rhs's */ +static void collect_rhss(std::vector const & steps, buffer & rhss) { + rhss.clear(); + for (auto const & step : steps) { + calc_pred const & pred = step_pred(step); + expr const & rhs = pred_rhs(pred); + if (std::find(rhss.begin(), rhss.end(), rhs) == rhss.end()) + rhss.push_back(rhs); + } + lean_assert(!rhss.empty()); +} + +static void join(parser & p, std::vector const & steps1, std::vector const & steps2, std::vector & res_steps, + pos_info const & pos) { + res_steps.clear(); + calc_ext const & ext = get_extension(p.env()); + for (calc_step const & s1 : steps1) { + check_interrupted(); + calc_pred const & pred1 = step_pred(s1); + expr const & pr1 = step_proof(s1); + for (calc_step const & s2 : steps2) { + calc_pred const & pred2 = step_pred(s2); + expr const & pr2 = step_proof(s2); + if (!is_eqp(pred_rhs(pred1), pred_lhs(pred2))) + continue; + auto trans_it = ext.m_trans_table.find(name_pair(pred_op(pred1), pred_op(pred2))); + if (!trans_it) + continue; + expr trans = mk_op_fn(p, std::get<0>(*trans_it), std::get<2>(*trans_it)-5, pos); + expr trans_pr = p.mk_app({trans, pred_lhs(pred1), pred_rhs(pred1), pred_rhs(pred2), pr1, pr2}, pos); + res_steps.emplace_back(calc_pred(std::get<1>(*trans_it), pred_lhs(pred1), pred_rhs(pred2)), trans_pr); + } + } +} + +expr parse_calc(parser & p) { + buffer preds, new_preds; + buffer rhss; + std::vector steps, new_steps, next_steps; + auto pos = p.pos(); + decode_expr(p.parse_expr(), preds, pos); + parse_calc_proof(p, preds, steps); + expr dummy = mk_expr_placeholder(); + while (p.curr_is_token(g_ellipsis)) { + pos = p.pos(); + p.next(); + decode_expr(p.parse_led(dummy), preds, pos); + collect_rhss(steps, rhss); + new_steps.clear(); + for (auto const & pred : preds) { + if (is_eqp(pred_lhs(pred), dummy)) { + for (expr const & rhs : rhss) + new_preds.emplace_back(pred_op(pred), rhs, pred_rhs(pred)); + } + } + if (new_preds.empty()) + throw parser_error("invalid 'calc' expression, invalid expression", pos); + parse_calc_proof(p, new_preds, new_steps); + join(p, steps, new_steps, next_steps, pos); + if (next_steps.empty()) + throw parser_error("invalid 'calc' expression, transitivity rule is not defined for current step", pos); + steps.swap(next_steps); + } + buffer choices; + for (auto const & s : steps) + choices.push_back(step_proof(s)); + return mk_choice(choices.size(), choices.data()); +} } diff --git a/src/frontends/lean/calc.h b/src/frontends/lean/calc.h index cfed35885..8b08c00b1 100644 --- a/src/frontends/lean/calc.h +++ b/src/frontends/lean/calc.h @@ -7,5 +7,7 @@ Author: Leonardo de Moura #pragma once #include "frontends/lean/cmd_table.h" namespace lean { +class parser; void register_calc_cmds(cmd_table & r); +expr parse_calc(parser & p); } diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index adfb7cdee..d6d9df29f 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -281,6 +281,17 @@ expr parser::mk_app(expr fn, expr arg, pos_info const & p) { return save_pos(::lean::mk_app(fn, arg), p); } +expr parser::mk_app(std::initializer_list const & args, pos_info const & p) { + unsigned nargs = args.size(); + lean_assert(nargs >= 2); + auto it = args.begin(); + expr r = *it; + it++; + for (; it != args.end(); it++) + r = mk_app(r, *it, p); + return r; +} + void parser::push_local_scope() { if (m_type_use_placeholder) m_local_level_decls.push(); diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index c8691b3fc..a9ce8ffd9 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -124,7 +124,6 @@ class parser { parameter parse_binder_core(binder_info const & bi); void parse_binder_block(buffer & r, binder_info const & bi); void parse_binders_core(buffer & r); - expr mk_app(expr fn, expr arg, pos_info const & p); friend environment section_cmd(parser & p); friend environment end_scoped_cmd(parser & p); @@ -159,6 +158,9 @@ public: pos_info cmd_pos() const { return m_last_cmd_pos; } void set_line(unsigned p) { return m_scanner.set_line(p); } + expr mk_app(expr fn, expr arg, pos_info const & p); + expr mk_app(std::initializer_list const & args, pos_info const & p); + /** \brief Read the next token. */ void scan() { m_curr = m_scanner.scan(m_env); } /** \brief Return the current token */ diff --git a/tests/lean/calc1.lean b/tests/lean/calc1.lean index b013ff647..ccac41090 100644 --- a/tests/lean/calc1.lean +++ b/tests/lean/calc1.lean @@ -2,17 +2,48 @@ variable A : Type.{1} definition [inline] bool : Type.{1} := Type.{0} variable eq : A → A → bool infixl `=` 50 := eq -variable subst (P : A → bool) (a b : A) (H1 : a = b) (H2 : P a) : P b -variable eq_trans (a b c : A) (H1 : a = b) (H2 : b = c) : a = c -variable eq_refl (a : A) : a = a +axiom subst (P : A → bool) (a b : A) (H1 : a = b) (H2 : P a) : P b +axiom eq_trans (a b c : A) (H1 : a = b) (H2 : b = c) : a = c +axiom eq_refl (a : A) : a = a variable le : A → A → bool infixl `≤` 50 := le -variable le_trans (a b c : A) (H1 : a ≤ b) (H2 : b ≤ c) : a ≤ c -variable le_refl (a : A) : a ≤ a -variable eq_le_trans (a b c : A) (H1 : a = b) (H2 : b ≤ c) : a ≤ c +axiom le_trans (a b c : A) (H1 : a ≤ b) (H2 : b ≤ c) : a ≤ c +axiom le_refl (a : A) : a ≤ a +axiom eq_le_trans (a b c : A) (H1 : a = b) (H2 : b ≤ c) : a ≤ c +axiom le_eq_trans (a b c : A) (H1 : a ≤ b) (H2 : b = c) : a ≤ c calc_subst subst calc_refl eq_refl calc_refl le_refl calc_trans eq_trans calc_trans le_trans calc_trans eq_le_trans +calc_trans le_eq_trans +variables a b c d e f : A +axiom H1 : a = b +axiom H2 : b ≤ c +axiom H3 : c ≤ d +axiom H4 : d = e +check calc a = b : H1 + ... ≤ c : H2 + ... ≤ d : H3 + ... = e : H4 + +variable lt : A → A → bool +infixl `<` 50 := lt +axiom lt_trans (a b c : A) (H1 : a < b) (H2 : b < c) : a < c +axiom le_lt_trans (a b c : A) (H1 : a ≤ b) (H2 : b < c) : a < c +axiom lt_le_trans (a b c : A) (H1 : a < b) (H2 : b ≤ c) : a < c +axiom H5 : c < d +check calc b ≤ c : H2 + ... < d : H5 -- Error le_lt_trans was not registered yet +calc_trans le_lt_trans +check calc b ≤ c : H2 + ... < d : H5 + +variable le2 : A → A → bool +infixl `≤` 50 := le2 +variable le2_trans (a b c : A) (H1 : le2 a b) (H2 : le2 b c) : le2 a c +calc_trans le2_trans +print raw calc b ≤ c : H2 + ... ≤ d : H3 + ... ≤ e : H4 diff --git a/tests/lean/calc1.lean.expected.out b/tests/lean/calc1.lean.expected.out index e69de29bb..eef23965f 100644 --- a/tests/lean/calc1.lean.expected.out +++ b/tests/lean/calc1.lean.expected.out @@ -0,0 +1,4 @@ +le_eq_trans a d e (le_trans a c d (eq_le_trans a b c H1 H2) H3) H4 : le a e +calc1.lean:38:10: error: invalid 'calc' expression, transitivity rule is not defined for current step +le_lt_trans b c d H2 H5 : lt b d +choice (le2_trans b d e (le2_trans b c d H2 H3) H4) (le_trans b d e (le_trans b c d H2 H3) H4)