fix(frontends/lean/calc): allow calc_subst to be defined for multiple operators, allow calc cmds to be organized into namespaces, fixes #65
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
2f699fa53a
commit
df60ab4ada
4 changed files with 99 additions and 138 deletions
|
@ -57,8 +57,6 @@ num_rec zero
|
|||
-- Successor and predecessor
|
||||
-- -------------------------
|
||||
|
||||
-- TODO: this looks like a calc bug -- calc is using subst for iff, instead of =
|
||||
calc_subst subst
|
||||
theorem succ_ne_zero (n : ℕ) : succ n ≠ 0 :=
|
||||
assume H : succ n = 0,
|
||||
have H2 : true = false, from
|
||||
|
|
|
@ -8,10 +8,6 @@ import tools.fake_simplifier
|
|||
using prod eq_ops
|
||||
using fake_simplifier
|
||||
|
||||
-- TODO: calc bug -- remove
|
||||
calc_subst subst
|
||||
|
||||
|
||||
namespace quotient
|
||||
|
||||
-- auxliary facts about products
|
||||
|
|
|
@ -147,8 +147,4 @@ calc
|
|||
... ↔ (a ∨ c) ∨ b : iff_symm (or_assoc _ _ _)
|
||||
|
||||
-- TODO: add or_left_comm, and_right_comm, and_left_comm
|
||||
|
||||
-- TODO: this is only temporary, until the calc bug is fixed
|
||||
calc_subst subst
|
||||
|
||||
end relation
|
||||
|
|
|
@ -18,31 +18,17 @@ Author: Leonardo de Moura
|
|||
#include "library/choice.h"
|
||||
#include "library/placeholder.h"
|
||||
#include "library/explicit.h"
|
||||
#include "library/scoped_ext.h"
|
||||
#include "frontends/lean/parser.h"
|
||||
#include "frontends/lean/util.h"
|
||||
|
||||
namespace lean {
|
||||
struct calc_ext : public environment_extension {
|
||||
typedef rb_map<name, pair<name, unsigned>, name_quick_cmp> refl_table;
|
||||
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table;
|
||||
optional<name> m_subst;
|
||||
unsigned m_subst_num_args;
|
||||
trans_table m_trans_table;
|
||||
refl_table m_refl_table;
|
||||
calc_ext():m_subst_num_args(0) {}
|
||||
};
|
||||
|
||||
struct calc_ext_reg {
|
||||
unsigned m_ext_id;
|
||||
calc_ext_reg() { m_ext_id = environment::register_extension(std::make_shared<calc_ext>()); }
|
||||
};
|
||||
|
||||
static calc_ext_reg g_ext;
|
||||
static calc_ext const & get_extension(environment const & env) {
|
||||
return static_cast<calc_ext const &>(env.get_extension(g_ext.m_ext_id));
|
||||
}
|
||||
static environment update(environment const & env, calc_ext const & ext) {
|
||||
return env.update(g_ext.m_ext_id, std::make_shared<calc_ext>(ext));
|
||||
// Check whether e is of the form (f ...) where f is a constant. If it is return f.
|
||||
static name const & get_fn_const(expr const & e, char const * msg) {
|
||||
expr const & fn = get_app_fn(e);
|
||||
if (!is_constant(fn))
|
||||
throw exception(msg);
|
||||
return const_name(fn);
|
||||
}
|
||||
|
||||
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) {
|
||||
|
@ -54,118 +40,104 @@ static expr extract_arg_types(environment const & env, name const & f, buffer<ex
|
|||
return f_type;
|
||||
}
|
||||
|
||||
// Check whether e is of the form (f ...) where f is a constant. If it is return f.
|
||||
static name const & get_fn_const(expr const & e, char const * msg) {
|
||||
expr const & fn = get_app_fn(e);
|
||||
if (!is_constant(fn))
|
||||
throw exception(msg);
|
||||
return const_name(fn);
|
||||
}
|
||||
enum class calc_cmd { Subst, Trans, Refl };
|
||||
|
||||
static std::string g_calc_subst_key("calcs");
|
||||
static std::string g_calc_refl_key("calcr");
|
||||
static std::string g_calc_trans_key("calct");
|
||||
struct calc_entry {
|
||||
calc_cmd m_cmd;
|
||||
name m_name;
|
||||
calc_entry() {}
|
||||
calc_entry(calc_cmd c, name const & n):m_cmd(c), m_name(n) {}
|
||||
};
|
||||
|
||||
environment add_calc_subst(environment const & env, name const & subst) {
|
||||
buffer<expr> arg_types;
|
||||
expr r_type = extract_arg_types(env, subst, arg_types);
|
||||
unsigned nargs = arg_types.size();
|
||||
if (nargs < 2)
|
||||
throw exception("invalid calc substitution theorem, it must have at least 2 arguments");
|
||||
calc_ext ext = get_extension(env);
|
||||
ext.m_subst = subst;
|
||||
ext.m_subst_num_args = nargs;
|
||||
environment new_env = module::add(env, g_calc_subst_key, [=](serializer & s) {
|
||||
s << subst << nargs;
|
||||
});
|
||||
return update(new_env, ext);
|
||||
}
|
||||
struct calc_state {
|
||||
typedef rb_map<name, pair<name, unsigned>, name_quick_cmp> refl_table;
|
||||
typedef rb_map<name, pair<name, unsigned>, name_quick_cmp> subst_table;
|
||||
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table;
|
||||
trans_table m_trans_table;
|
||||
refl_table m_refl_table;
|
||||
subst_table m_subst_table;
|
||||
calc_state() {}
|
||||
|
||||
environment add_calc_refl(environment const & env, name const & refl) {
|
||||
buffer<expr> arg_types;
|
||||
expr r_type = extract_arg_types(env, refl, arg_types);
|
||||
unsigned nargs = arg_types.size();
|
||||
if (nargs < 1)
|
||||
throw exception("invalid calc reflexivity rule, it must have at least 1 argument");
|
||||
name const & rop = get_fn_const(r_type, "invalid calc reflexivity rule, result type must be an operator application");
|
||||
calc_ext ext = get_extension(env);
|
||||
ext.m_refl_table.insert(rop, mk_pair(refl, nargs));
|
||||
environment new_env = module::add(env, g_calc_refl_key, [=](serializer & s) {
|
||||
s << rop << refl << nargs;
|
||||
});
|
||||
return update(new_env, ext);
|
||||
}
|
||||
|
||||
environment add_calc_trans(environment const & env, name const & trans) {
|
||||
buffer<expr> arg_types;
|
||||
expr r_type = extract_arg_types(env, trans, arg_types);
|
||||
unsigned nargs = arg_types.size();
|
||||
if (nargs < 5)
|
||||
throw exception("invalid calc transitivity rule, it must have at least 5 arguments");
|
||||
name const & rop = get_fn_const(r_type, "invalid calc transitivity rule, result type must be an operator application");
|
||||
name const & op1 = get_fn_const(arg_types[nargs-2], "invalid calc transitivity rule, penultimate argument must be an operator application");
|
||||
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application");
|
||||
calc_ext ext = get_extension(env);
|
||||
ext.m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
|
||||
environment new_env = module::add(env, g_calc_trans_key, [=](serializer & s) {
|
||||
s << op1 << op2 << trans << rop << nargs;
|
||||
});
|
||||
return update(new_env, ext);
|
||||
}
|
||||
void add_calc_subst(environment const & env, name const & subst) {
|
||||
buffer<expr> arg_types;
|
||||
expr r_type = extract_arg_types(env, subst, arg_types);
|
||||
unsigned nargs = arg_types.size();
|
||||
if (nargs < 2)
|
||||
throw exception("invalid calc substitution theorem, it must have at least 2 arguments");
|
||||
name const & rop = get_fn_const(arg_types[nargs-2], "invalid calc substitution theorem, argument penultimate argument must be an operator application");
|
||||
m_subst_table.insert(rop, mk_pair(subst, nargs));
|
||||
}
|
||||
|
||||
static void calc_subst_reader(deserializer & d, module_idx, shared_environment &,
|
||||
std::function<void(asynch_update_fn const &)> &,
|
||||
std::function<void(delayed_update_fn const &)> & add_delayed_update) {
|
||||
name subst; unsigned nargs;
|
||||
d >> subst >> nargs;
|
||||
add_delayed_update([=](environment const & env, io_state const &) -> environment {
|
||||
calc_ext ext = get_extension(env);
|
||||
ext.m_subst = subst;
|
||||
ext.m_subst_num_args = nargs;
|
||||
return update(env, ext);
|
||||
});
|
||||
}
|
||||
register_module_object_reader_fn g_calc_subst_reader(g_calc_subst_key, calc_subst_reader);
|
||||
void add_calc_refl(environment const & env, name const & refl) {
|
||||
buffer<expr> arg_types;
|
||||
expr r_type = extract_arg_types(env, refl, arg_types);
|
||||
unsigned nargs = arg_types.size();
|
||||
if (nargs < 1)
|
||||
throw exception("invalid calc reflexivity rule, it must have at least 1 argument");
|
||||
name const & rop = get_fn_const(r_type, "invalid calc reflexivity rule, result type must be an operator application");
|
||||
m_refl_table.insert(rop, mk_pair(refl, nargs));
|
||||
}
|
||||
|
||||
static void calc_refl_reader(deserializer & d, module_idx, shared_environment &,
|
||||
std::function<void(asynch_update_fn const &)> &,
|
||||
std::function<void(delayed_update_fn const &)> & add_delayed_update) {
|
||||
name rop, refl; unsigned nargs;
|
||||
d >> rop >> refl >> nargs;
|
||||
add_delayed_update([=](environment const & env, io_state const &) -> environment {
|
||||
calc_ext ext = get_extension(env);
|
||||
ext.m_refl_table.insert(rop, mk_pair(refl, nargs));
|
||||
return update(env, ext);
|
||||
});
|
||||
}
|
||||
register_module_object_reader_fn g_calc_refl_reader(g_calc_refl_key, calc_refl_reader);
|
||||
void add_calc_trans(environment const & env, name const & trans) {
|
||||
buffer<expr> arg_types;
|
||||
expr r_type = extract_arg_types(env, trans, arg_types);
|
||||
unsigned nargs = arg_types.size();
|
||||
if (nargs < 5)
|
||||
throw exception("invalid calc transitivity rule, it must have at least 5 arguments");
|
||||
name const & rop = get_fn_const(r_type, "invalid calc transitivity rule, result type must be an operator application");
|
||||
name const & op1 = get_fn_const(arg_types[nargs-2], "invalid calc transitivity rule, penultimate argument must be an operator application");
|
||||
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application");
|
||||
m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
|
||||
}
|
||||
};
|
||||
|
||||
static void calc_trans_reader(deserializer & d, module_idx, shared_environment &,
|
||||
std::function<void(asynch_update_fn const &)> &,
|
||||
std::function<void(delayed_update_fn const &)> & add_delayed_update) {
|
||||
name op1, op2, trans, rop; unsigned nargs;
|
||||
d >> op1 >> op2 >> trans >> rop >> nargs;
|
||||
add_delayed_update([=](environment const & env, io_state const &) -> environment {
|
||||
calc_ext ext = get_extension(env);
|
||||
ext.m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
|
||||
return update(env, ext);
|
||||
});
|
||||
}
|
||||
register_module_object_reader_fn g_calc_trans_reader(g_calc_trans_key, calc_trans_reader);
|
||||
struct calc_config {
|
||||
typedef calc_state state;
|
||||
typedef calc_entry entry;
|
||||
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
|
||||
switch (e.m_cmd) {
|
||||
case calc_cmd::Refl: s.add_calc_refl(env, e.m_name); break;
|
||||
case calc_cmd::Subst: s.add_calc_subst(env, e.m_name); break;
|
||||
case calc_cmd::Trans: s.add_calc_trans(env, e.m_name); break;
|
||||
}
|
||||
}
|
||||
static name const & get_class_name() {
|
||||
static name g_calc_name("calc");
|
||||
return g_calc_name;
|
||||
}
|
||||
static std::string const & get_serialization_key() {
|
||||
static std::string g_key("calc");
|
||||
return g_key;
|
||||
}
|
||||
static void write_entry(serializer & s, entry const & e) {
|
||||
s << static_cast<char>(e.m_cmd) << e.m_name;
|
||||
}
|
||||
static entry read_entry(deserializer & d) {
|
||||
entry e;
|
||||
char cmd;
|
||||
d >> cmd >> e.m_name;
|
||||
e.m_cmd = static_cast<calc_cmd>(cmd);
|
||||
return e;
|
||||
}
|
||||
};
|
||||
|
||||
template class scoped_ext<calc_config>;
|
||||
typedef scoped_ext<calc_config> calc_ext;
|
||||
|
||||
environment calc_subst_cmd(parser & p) {
|
||||
name id = p.check_constant_next("invalid 'calc_subst' command, constant expected");
|
||||
return add_calc_subst(p.env(), id);
|
||||
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Subst, id));
|
||||
}
|
||||
|
||||
environment calc_refl_cmd(parser & p) {
|
||||
name id = p.check_constant_next("invalid 'calc_refl' command, constant expected");
|
||||
return add_calc_refl(p.env(), id);
|
||||
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Refl, id));
|
||||
}
|
||||
|
||||
environment calc_trans_cmd(parser & p) {
|
||||
name id = p.check_constant_next("invalid 'calc_trans' command, constant expected");
|
||||
return add_calc_trans(p.env(), id);
|
||||
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Trans, id));
|
||||
}
|
||||
|
||||
void register_calc_cmds(cmd_table & r) {
|
||||
|
@ -229,21 +201,20 @@ static void parse_calc_proof(parser & p, buffer<calc_pred> const & preds, std::v
|
|||
p.next();
|
||||
expr pr = p.parse_expr();
|
||||
p.check_token_next(g_rcurly, "invalid 'calc' expression, '}' expected");
|
||||
calc_ext const & ext = get_extension(p.env());
|
||||
if (!ext.m_subst)
|
||||
throw parser_error("invalid 'calc' expression, substitution rule was not defined with calc_subst command", pos);
|
||||
calc_state const & state = calc_ext::get_state(p.env());
|
||||
for (auto const & pred : preds) {
|
||||
auto refl_it = ext.m_refl_table.find(pred_op(pred));
|
||||
if (refl_it) {
|
||||
expr refl = mk_op_fn(p, refl_it->first, refl_it->second-1, pos);
|
||||
expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos);
|
||||
expr subst = mk_op_fn(p, *ext.m_subst, ext.m_subst_num_args-2, pos);
|
||||
expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos);
|
||||
steps.emplace_back(pred, subst_pr);
|
||||
if (auto refl_it = state.m_refl_table.find(pred_op(pred))) {
|
||||
if (auto subst_it = state.m_subst_table.find(pred_op(pred))) {
|
||||
expr refl = mk_op_fn(p, refl_it->first, refl_it->second-1, pos);
|
||||
expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos);
|
||||
expr subst = mk_op_fn(p, subst_it->first, subst_it->second-2, pos);
|
||||
expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos);
|
||||
steps.emplace_back(pred, subst_pr);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (steps.empty())
|
||||
throw parser_error("invalid 'calc' expression, reflexivity rule is not defined for operator", pos);
|
||||
throw parser_error("invalid 'calc' expression, reflexivity and/or substitution rule is not defined for operator", pos);
|
||||
} else {
|
||||
expr pr = p.parse_expr();
|
||||
for (auto const & pred : preds)
|
||||
|
@ -266,7 +237,7 @@ static void collect_rhss(std::vector<calc_step> const & steps, buffer<expr> & rh
|
|||
static void join(parser & p, std::vector<calc_step> const & steps1, std::vector<calc_step> const & steps2, std::vector<calc_step> & res_steps,
|
||||
pos_info const & pos) {
|
||||
res_steps.clear();
|
||||
calc_ext const & ext = get_extension(p.env());
|
||||
calc_state const & state = calc_ext::get_state(p.env());
|
||||
for (calc_step const & s1 : steps1) {
|
||||
check_interrupted();
|
||||
calc_pred const & pred1 = step_pred(s1);
|
||||
|
@ -276,7 +247,7 @@ static void join(parser & p, std::vector<calc_step> const & steps1, std::vector<
|
|||
expr const & pr2 = step_proof(s2);
|
||||
if (!is_eqp(pred_rhs(pred1), pred_lhs(pred2)))
|
||||
continue;
|
||||
auto trans_it = ext.m_trans_table.find(name_pair(pred_op(pred1), pred_op(pred2)));
|
||||
auto trans_it = state.m_trans_table.find(name_pair(pred_op(pred1), pred_op(pred2)));
|
||||
if (!trans_it)
|
||||
continue;
|
||||
expr trans = mk_op_fn(p, std::get<0>(*trans_it), std::get<2>(*trans_it)-5, pos);
|
||||
|
|
Loading…
Reference in a new issue