fix(frontends/lean/calc): allow calc_subst to be defined for multiple operators, allow calc cmds to be organized into namespaces, fixes #65

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-08-23 16:44:06 -07:00
parent 2f699fa53a
commit df60ab4ada
4 changed files with 99 additions and 138 deletions

View file

@ -57,8 +57,6 @@ num_rec zero
-- Successor and predecessor
-- -------------------------
-- TODO: this looks like a calc bug -- calc is using subst for iff, instead of =
calc_subst subst
theorem succ_ne_zero (n : ) : succ n ≠ 0 :=
assume H : succ n = 0,
have H2 : true = false, from

View file

@ -8,10 +8,6 @@ import tools.fake_simplifier
using prod eq_ops
using fake_simplifier
-- TODO: calc bug -- remove
calc_subst subst
namespace quotient
-- auxliary facts about products

View file

@ -147,8 +147,4 @@ calc
... ↔ (a c) b : iff_symm (or_assoc _ _ _)
-- TODO: add or_left_comm, and_right_comm, and_left_comm
-- TODO: this is only temporary, until the calc bug is fixed
calc_subst subst
end relation

View file

@ -18,31 +18,17 @@ Author: Leonardo de Moura
#include "library/choice.h"
#include "library/placeholder.h"
#include "library/explicit.h"
#include "library/scoped_ext.h"
#include "frontends/lean/parser.h"
#include "frontends/lean/util.h"
namespace lean {
struct calc_ext : public environment_extension {
typedef rb_map<name, pair<name, unsigned>, name_quick_cmp> refl_table;
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table;
optional<name> m_subst;
unsigned m_subst_num_args;
trans_table m_trans_table;
refl_table m_refl_table;
calc_ext():m_subst_num_args(0) {}
};
struct calc_ext_reg {
unsigned m_ext_id;
calc_ext_reg() { m_ext_id = environment::register_extension(std::make_shared<calc_ext>()); }
};
static calc_ext_reg g_ext;
static calc_ext const & get_extension(environment const & env) {
return static_cast<calc_ext const &>(env.get_extension(g_ext.m_ext_id));
}
static environment update(environment const & env, calc_ext const & ext) {
return env.update(g_ext.m_ext_id, std::make_shared<calc_ext>(ext));
// Check whether e is of the form (f ...) where f is a constant. If it is return f.
static name const & get_fn_const(expr const & e, char const * msg) {
expr const & fn = get_app_fn(e);
if (!is_constant(fn))
throw exception(msg);
return const_name(fn);
}
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) {
@ -54,118 +40,104 @@ static expr extract_arg_types(environment const & env, name const & f, buffer<ex
return f_type;
}
// Check whether e is of the form (f ...) where f is a constant. If it is return f.
static name const & get_fn_const(expr const & e, char const * msg) {
expr const & fn = get_app_fn(e);
if (!is_constant(fn))
throw exception(msg);
return const_name(fn);
}
enum class calc_cmd { Subst, Trans, Refl };
static std::string g_calc_subst_key("calcs");
static std::string g_calc_refl_key("calcr");
static std::string g_calc_trans_key("calct");
struct calc_entry {
calc_cmd m_cmd;
name m_name;
calc_entry() {}
calc_entry(calc_cmd c, name const & n):m_cmd(c), m_name(n) {}
};
environment add_calc_subst(environment const & env, name const & subst) {
buffer<expr> arg_types;
expr r_type = extract_arg_types(env, subst, arg_types);
unsigned nargs = arg_types.size();
if (nargs < 2)
throw exception("invalid calc substitution theorem, it must have at least 2 arguments");
calc_ext ext = get_extension(env);
ext.m_subst = subst;
ext.m_subst_num_args = nargs;
environment new_env = module::add(env, g_calc_subst_key, [=](serializer & s) {
s << subst << nargs;
});
return update(new_env, ext);
}
struct calc_state {
typedef rb_map<name, pair<name, unsigned>, name_quick_cmp> refl_table;
typedef rb_map<name, pair<name, unsigned>, name_quick_cmp> subst_table;
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table;
trans_table m_trans_table;
refl_table m_refl_table;
subst_table m_subst_table;
calc_state() {}
environment add_calc_refl(environment const & env, name const & refl) {
buffer<expr> arg_types;
expr r_type = extract_arg_types(env, refl, arg_types);
unsigned nargs = arg_types.size();
if (nargs < 1)
throw exception("invalid calc reflexivity rule, it must have at least 1 argument");
name const & rop = get_fn_const(r_type, "invalid calc reflexivity rule, result type must be an operator application");
calc_ext ext = get_extension(env);
ext.m_refl_table.insert(rop, mk_pair(refl, nargs));
environment new_env = module::add(env, g_calc_refl_key, [=](serializer & s) {
s << rop << refl << nargs;
});
return update(new_env, ext);
}
environment add_calc_trans(environment const & env, name const & trans) {
buffer<expr> arg_types;
expr r_type = extract_arg_types(env, trans, arg_types);
unsigned nargs = arg_types.size();
if (nargs < 5)
throw exception("invalid calc transitivity rule, it must have at least 5 arguments");
name const & rop = get_fn_const(r_type, "invalid calc transitivity rule, result type must be an operator application");
name const & op1 = get_fn_const(arg_types[nargs-2], "invalid calc transitivity rule, penultimate argument must be an operator application");
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application");
calc_ext ext = get_extension(env);
ext.m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
environment new_env = module::add(env, g_calc_trans_key, [=](serializer & s) {
s << op1 << op2 << trans << rop << nargs;
});
return update(new_env, ext);
}
void add_calc_subst(environment const & env, name const & subst) {
buffer<expr> arg_types;
expr r_type = extract_arg_types(env, subst, arg_types);
unsigned nargs = arg_types.size();
if (nargs < 2)
throw exception("invalid calc substitution theorem, it must have at least 2 arguments");
name const & rop = get_fn_const(arg_types[nargs-2], "invalid calc substitution theorem, argument penultimate argument must be an operator application");
m_subst_table.insert(rop, mk_pair(subst, nargs));
}
static void calc_subst_reader(deserializer & d, module_idx, shared_environment &,
std::function<void(asynch_update_fn const &)> &,
std::function<void(delayed_update_fn const &)> & add_delayed_update) {
name subst; unsigned nargs;
d >> subst >> nargs;
add_delayed_update([=](environment const & env, io_state const &) -> environment {
calc_ext ext = get_extension(env);
ext.m_subst = subst;
ext.m_subst_num_args = nargs;
return update(env, ext);
});
}
register_module_object_reader_fn g_calc_subst_reader(g_calc_subst_key, calc_subst_reader);
void add_calc_refl(environment const & env, name const & refl) {
buffer<expr> arg_types;
expr r_type = extract_arg_types(env, refl, arg_types);
unsigned nargs = arg_types.size();
if (nargs < 1)
throw exception("invalid calc reflexivity rule, it must have at least 1 argument");
name const & rop = get_fn_const(r_type, "invalid calc reflexivity rule, result type must be an operator application");
m_refl_table.insert(rop, mk_pair(refl, nargs));
}
static void calc_refl_reader(deserializer & d, module_idx, shared_environment &,
std::function<void(asynch_update_fn const &)> &,
std::function<void(delayed_update_fn const &)> & add_delayed_update) {
name rop, refl; unsigned nargs;
d >> rop >> refl >> nargs;
add_delayed_update([=](environment const & env, io_state const &) -> environment {
calc_ext ext = get_extension(env);
ext.m_refl_table.insert(rop, mk_pair(refl, nargs));
return update(env, ext);
});
}
register_module_object_reader_fn g_calc_refl_reader(g_calc_refl_key, calc_refl_reader);
void add_calc_trans(environment const & env, name const & trans) {
buffer<expr> arg_types;
expr r_type = extract_arg_types(env, trans, arg_types);
unsigned nargs = arg_types.size();
if (nargs < 5)
throw exception("invalid calc transitivity rule, it must have at least 5 arguments");
name const & rop = get_fn_const(r_type, "invalid calc transitivity rule, result type must be an operator application");
name const & op1 = get_fn_const(arg_types[nargs-2], "invalid calc transitivity rule, penultimate argument must be an operator application");
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid calc transitivity rule, last argument must be an operator application");
m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
}
};
static void calc_trans_reader(deserializer & d, module_idx, shared_environment &,
std::function<void(asynch_update_fn const &)> &,
std::function<void(delayed_update_fn const &)> & add_delayed_update) {
name op1, op2, trans, rop; unsigned nargs;
d >> op1 >> op2 >> trans >> rop >> nargs;
add_delayed_update([=](environment const & env, io_state const &) -> environment {
calc_ext ext = get_extension(env);
ext.m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
return update(env, ext);
});
}
register_module_object_reader_fn g_calc_trans_reader(g_calc_trans_key, calc_trans_reader);
struct calc_config {
typedef calc_state state;
typedef calc_entry entry;
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
switch (e.m_cmd) {
case calc_cmd::Refl: s.add_calc_refl(env, e.m_name); break;
case calc_cmd::Subst: s.add_calc_subst(env, e.m_name); break;
case calc_cmd::Trans: s.add_calc_trans(env, e.m_name); break;
}
}
static name const & get_class_name() {
static name g_calc_name("calc");
return g_calc_name;
}
static std::string const & get_serialization_key() {
static std::string g_key("calc");
return g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << static_cast<char>(e.m_cmd) << e.m_name;
}
static entry read_entry(deserializer & d) {
entry e;
char cmd;
d >> cmd >> e.m_name;
e.m_cmd = static_cast<calc_cmd>(cmd);
return e;
}
};
template class scoped_ext<calc_config>;
typedef scoped_ext<calc_config> calc_ext;
environment calc_subst_cmd(parser & p) {
name id = p.check_constant_next("invalid 'calc_subst' command, constant expected");
return add_calc_subst(p.env(), id);
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Subst, id));
}
environment calc_refl_cmd(parser & p) {
name id = p.check_constant_next("invalid 'calc_refl' command, constant expected");
return add_calc_refl(p.env(), id);
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Refl, id));
}
environment calc_trans_cmd(parser & p) {
name id = p.check_constant_next("invalid 'calc_trans' command, constant expected");
return add_calc_trans(p.env(), id);
return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(calc_cmd::Trans, id));
}
void register_calc_cmds(cmd_table & r) {
@ -229,21 +201,20 @@ static void parse_calc_proof(parser & p, buffer<calc_pred> const & preds, std::v
p.next();
expr pr = p.parse_expr();
p.check_token_next(g_rcurly, "invalid 'calc' expression, '}' expected");
calc_ext const & ext = get_extension(p.env());
if (!ext.m_subst)
throw parser_error("invalid 'calc' expression, substitution rule was not defined with calc_subst command", pos);
calc_state const & state = calc_ext::get_state(p.env());
for (auto const & pred : preds) {
auto refl_it = ext.m_refl_table.find(pred_op(pred));
if (refl_it) {
expr refl = mk_op_fn(p, refl_it->first, refl_it->second-1, pos);
expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos);
expr subst = mk_op_fn(p, *ext.m_subst, ext.m_subst_num_args-2, pos);
expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos);
steps.emplace_back(pred, subst_pr);
if (auto refl_it = state.m_refl_table.find(pred_op(pred))) {
if (auto subst_it = state.m_subst_table.find(pred_op(pred))) {
expr refl = mk_op_fn(p, refl_it->first, refl_it->second-1, pos);
expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos);
expr subst = mk_op_fn(p, subst_it->first, subst_it->second-2, pos);
expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos);
steps.emplace_back(pred, subst_pr);
}
}
}
if (steps.empty())
throw parser_error("invalid 'calc' expression, reflexivity rule is not defined for operator", pos);
throw parser_error("invalid 'calc' expression, reflexivity and/or substitution rule is not defined for operator", pos);
} else {
expr pr = p.parse_expr();
for (auto const & pred : preds)
@ -266,7 +237,7 @@ static void collect_rhss(std::vector<calc_step> const & steps, buffer<expr> & rh
static void join(parser & p, std::vector<calc_step> const & steps1, std::vector<calc_step> const & steps2, std::vector<calc_step> & res_steps,
pos_info const & pos) {
res_steps.clear();
calc_ext const & ext = get_extension(p.env());
calc_state const & state = calc_ext::get_state(p.env());
for (calc_step const & s1 : steps1) {
check_interrupted();
calc_pred const & pred1 = step_pred(s1);
@ -276,7 +247,7 @@ static void join(parser & p, std::vector<calc_step> const & steps1, std::vector<
expr const & pr2 = step_proof(s2);
if (!is_eqp(pred_rhs(pred1), pred_lhs(pred2)))
continue;
auto trans_it = ext.m_trans_table.find(name_pair(pred_op(pred1), pred_op(pred2)));
auto trans_it = state.m_trans_table.find(name_pair(pred_op(pred1), pred_op(pred2)));
if (!trans_it)
continue;
expr trans = mk_op_fn(p, std::get<0>(*trans_it), std::get<2>(*trans_it)-5, pos);