refactor(library/relation_manager): cleanup and add API for declaring a relation that may not be reflexive, symmetric nor transitive
This commit is contained in:
parent
fb833a724b
commit
991ff67b45
6 changed files with 92 additions and 77 deletions
|
@ -31,16 +31,6 @@ Author: Leonardo de Moura
|
|||
#include "frontends/lean/begin_end_ext.h"
|
||||
|
||||
namespace lean {
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_calc_refl_info(environment const & env, name const & op) {
|
||||
return get_refl_extra_info(env, op);
|
||||
}
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_calc_subst_info(environment const & env, name const & op) {
|
||||
return get_subst_extra_info(env, op);
|
||||
}
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op) {
|
||||
return get_symm_extra_info(env, op);
|
||||
}
|
||||
|
||||
static name * g_calc_name = nullptr;
|
||||
|
||||
static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); }
|
||||
|
@ -108,9 +98,9 @@ static void parse_calc_proof(parser & p, buffer<calc_pred> const & preds, std::v
|
|||
for (auto const & pred : preds) {
|
||||
if (auto refl_it = get_refl_extra_info(env, pred_op(pred))) {
|
||||
if (auto subst_it = get_subst_extra_info(env, pred_op(pred))) {
|
||||
expr refl = mk_op_fn(p, std::get<0>(*refl_it), std::get<1>(*refl_it)-1, pos);
|
||||
expr refl = mk_op_fn(p, refl_it->m_name, refl_it->m_num_args-1, pos);
|
||||
expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos);
|
||||
expr subst = mk_op_fn(p, std::get<0>(*subst_it), std::get<1>(*subst_it)-2, pos);
|
||||
expr subst = mk_op_fn(p, subst_it->m_name, subst_it->m_num_args-2, pos);
|
||||
expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos);
|
||||
steps.emplace_back(pred, subst_pr);
|
||||
}
|
||||
|
@ -156,9 +146,9 @@ static void join(parser & p, std::vector<calc_step> const & steps1, std::vector<
|
|||
continue;
|
||||
auto trans_it = get_trans_extra_info(env, pred_op(pred1), pred_op(pred2));
|
||||
if (trans_it) {
|
||||
expr trans = mk_op_fn(p, std::get<0>(*trans_it), std::get<2>(*trans_it)-5, pos);
|
||||
expr trans = mk_op_fn(p, trans_it->m_name, trans_it->m_num_args-5, pos);
|
||||
expr trans_pr = p.mk_app({trans, pred_lhs(pred1), pred_rhs(pred1), pred_rhs(pred2), pr1, pr2}, pos);
|
||||
res_steps.emplace_back(calc_pred(std::get<1>(*trans_it), pred_lhs(pred1), pred_rhs(pred2)), trans_pr);
|
||||
res_steps.emplace_back(calc_pred(trans_it->m_res_relation, pred_lhs(pred1), pred_rhs(pred2)), trans_pr);
|
||||
} else if (pred_op(pred1) == get_eq_name()) {
|
||||
expr trans_right = mk_op_fn(p, get_trans_rel_right_name(), 1, pos);
|
||||
expr R = mk_op_fn(p, pred_op(pred2), get_arity_of(p, pred_op(pred2))-2, pos);
|
||||
|
|
|
@ -10,12 +10,6 @@ namespace lean {
|
|||
class parser;
|
||||
expr parse_calc(parser & p);
|
||||
bool is_calc_annotation(expr const & e);
|
||||
/** \brief Given an operator name \c op, return the symmetry rule associated with, number of arguments, and universe parameters.
|
||||
Return none if the operator does not have a symmetry rule associated with it.
|
||||
*/
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op);
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_calc_refl_info(environment const & env, name const & op);
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_calc_subst_info(environment const & env, name const & op);
|
||||
void initialize_calc();
|
||||
void finalize_calc();
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ Author: Leonardo de Moura
|
|||
#include "library/reducible.h"
|
||||
#include "library/metavar_closure.h"
|
||||
#include "library/local_context.h"
|
||||
#include "library/relation_manager.h"
|
||||
#include "frontends/lean/util.h"
|
||||
#include "frontends/lean/info_manager.h"
|
||||
#include "frontends/lean/calc.h"
|
||||
|
@ -75,10 +76,9 @@ static optional<pair<expr, expr>> apply_symmetry(environment const & env, local_
|
|||
buffer<expr> args;
|
||||
expr const & op = get_app_args(e_type, args);
|
||||
if (is_constant(op)) {
|
||||
if (auto t = get_calc_symm_info(env, const_name(op))) {
|
||||
name symm; unsigned nargs; unsigned nunivs;
|
||||
std::tie(symm, nargs, nunivs) = *t;
|
||||
return mk_op(env, ctx, ngen, tc, symm, nunivs, nargs-1, {e}, cs, g);
|
||||
if (auto info = get_symm_extra_info(env, const_name(op))) {
|
||||
return mk_op(env, ctx, ngen, tc, info->m_name,
|
||||
info->m_num_univs, info->m_num_args-1, {e}, cs, g);
|
||||
}
|
||||
}
|
||||
return optional<pair<expr, expr>>();
|
||||
|
@ -95,15 +95,12 @@ static optional<pair<expr, expr>> apply_subst(environment const & env, local_con
|
|||
buffer<expr> args;
|
||||
expr const & op = get_app_args(e_type, args);
|
||||
if (is_constant(op) && args.size() >= 2) {
|
||||
if (auto subst_it = get_calc_subst_info(env, const_name(op))) {
|
||||
name subst; unsigned subst_nargs; unsigned subst_univs;
|
||||
std::tie(subst, subst_nargs, subst_univs) = *subst_it;
|
||||
if (auto refl_it = get_calc_refl_info(env, const_name(op))) {
|
||||
name refl; unsigned refl_nargs; unsigned refl_univs;
|
||||
std::tie(refl, refl_nargs, refl_univs) = *refl_it;
|
||||
if (auto refl_pair = mk_op(env, ctx, ngen, tc, refl, refl_univs, refl_nargs-1,
|
||||
{ pred_args[npargs-2] }, cs, g)) {
|
||||
return mk_op(env, ctx, ngen, tc, subst, subst_univs, subst_nargs-2, {e, refl_pair->first}, cs, g);
|
||||
if (auto sinfo = get_subst_extra_info(env, const_name(op))) {
|
||||
if (auto rinfo = get_refl_extra_info(env, const_name(op))) {
|
||||
if (auto refl_pair = mk_op(env, ctx, ngen, tc, rinfo->m_name, rinfo->m_num_univs,
|
||||
rinfo->m_num_args-1, { pred_args[npargs-2] }, cs, g)) {
|
||||
return mk_op(env, ctx, ngen, tc, sinfo->m_name, sinfo->m_num_univs,
|
||||
sinfo->m_num_args-2, {e, refl_pair->first}, cs, g);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,11 +32,7 @@ static pair<expr, unsigned> extract_arg_types_core(environment const & env, name
|
|||
return mk_pair(f_type, d.get_num_univ_params());
|
||||
}
|
||||
|
||||
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) {
|
||||
return extract_arg_types_core(env, f, arg_types).first;
|
||||
}
|
||||
|
||||
enum class op_kind { Subst, Trans, Refl, Symm };
|
||||
enum class op_kind { Relation, Subst, Trans, Refl, Symm };
|
||||
|
||||
struct rel_entry {
|
||||
op_kind m_kind;
|
||||
|
@ -46,10 +42,10 @@ struct rel_entry {
|
|||
};
|
||||
|
||||
struct rel_state {
|
||||
typedef name_map<std::tuple<name, unsigned, unsigned>> refl_table;
|
||||
typedef name_map<std::tuple<name, unsigned, unsigned>> subst_table;
|
||||
typedef name_map<std::tuple<name, unsigned, unsigned>> symm_table;
|
||||
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table;
|
||||
typedef name_map<refl_info> refl_table;
|
||||
typedef name_map<subst_info> subst_table;
|
||||
typedef name_map<symm_info> symm_table;
|
||||
typedef rb_map<name_pair, trans_info, name_pair_quick_cmp> trans_table;
|
||||
typedef name_map<relation_info> rop_table;
|
||||
trans_table m_trans_table;
|
||||
refl_table m_refl_table;
|
||||
|
@ -77,12 +73,14 @@ struct rel_state {
|
|||
expr type = d.get_type();
|
||||
while (is_pi(type)) {
|
||||
if (is_explicit(binding_info(type))) {
|
||||
if (!lhs_pos)
|
||||
if (!lhs_pos) {
|
||||
lhs_pos = i;
|
||||
else if (!rhs_pos)
|
||||
} else if (!rhs_pos) {
|
||||
rhs_pos = i;
|
||||
else
|
||||
throw_invalid_relation(rop);
|
||||
} else {
|
||||
lhs_pos = rhs_pos;
|
||||
rhs_pos = i;
|
||||
}
|
||||
}
|
||||
type = binding_body(type);
|
||||
i++;
|
||||
|
@ -103,7 +101,7 @@ struct rel_state {
|
|||
if (nargs < 2)
|
||||
throw exception("invalid substitution theorem, it must have at least 2 arguments");
|
||||
name const & rop = get_fn_const(arg_types[nargs-2], "invalid substitution theorem, penultimate argument must be an operator application");
|
||||
m_subst_table.insert(rop, std::make_tuple(subst, nargs, nunivs));
|
||||
m_subst_table.insert(rop, subst_info(subst, nunivs, nargs));
|
||||
}
|
||||
|
||||
void add_refl(environment const & env, name const & refl) {
|
||||
|
@ -116,20 +114,22 @@ struct rel_state {
|
|||
throw exception("invalid reflexivity rule, it must have at least 1 argument");
|
||||
name const & rop = get_fn_const(r_type, "invalid reflexivity rule, result type must be an operator application");
|
||||
register_rop(env, rop);
|
||||
m_refl_table.insert(rop, std::make_tuple(refl, nargs, nunivs));
|
||||
m_refl_table.insert(rop, refl_info(refl, nunivs, nargs));
|
||||
}
|
||||
|
||||
void add_trans(environment const & env, name const & trans) {
|
||||
buffer<expr> arg_types;
|
||||
expr r_type = extract_arg_types(env, trans, arg_types);
|
||||
unsigned nargs = arg_types.size();
|
||||
auto p = extract_arg_types_core(env, trans, arg_types);
|
||||
expr r_type = p.first;
|
||||
unsigned nunivs = p.second;
|
||||
unsigned nargs = arg_types.size();
|
||||
if (nargs < 5)
|
||||
throw exception("invalid transitivity rule, it must have at least 5 arguments");
|
||||
name const & rop = get_fn_const(r_type, "invalid transitivity rule, result type must be an operator application");
|
||||
name const & op1 = get_fn_const(arg_types[nargs-2], "invalid transitivity rule, penultimate argument must be an operator application");
|
||||
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid transitivity rule, last argument must be an operator application");
|
||||
register_rop(env, rop);
|
||||
m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
|
||||
m_trans_table.insert(name_pair(op1, op2), trans_info(trans, nunivs, nargs, rop));
|
||||
}
|
||||
|
||||
void add_symm(environment const & env, name const & symm) {
|
||||
|
@ -142,7 +142,7 @@ struct rel_state {
|
|||
throw exception("invalid symmetry rule, it must have at least 1 argument");
|
||||
name const & rop = get_fn_const(r_type, "invalid symmetry rule, result type must be an operator application");
|
||||
register_rop(env, rop);
|
||||
m_symm_table.insert(rop, std::make_tuple(symm, nargs, nunivs));
|
||||
m_symm_table.insert(rop, symm_info(symm, nunivs, nargs));
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -154,10 +154,11 @@ struct rel_config {
|
|||
typedef rel_entry entry;
|
||||
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
|
||||
switch (e.m_kind) {
|
||||
case op_kind::Refl: s.add_refl(env, e.m_name); break;
|
||||
case op_kind::Subst: s.add_subst(env, e.m_name); break;
|
||||
case op_kind::Trans: s.add_trans(env, e.m_name); break;
|
||||
case op_kind::Symm: s.add_symm(env, e.m_name); break;
|
||||
case op_kind::Relation: s.register_rop(env, e.m_name); break;
|
||||
case op_kind::Refl: s.add_refl(env, e.m_name); break;
|
||||
case op_kind::Subst: s.add_subst(env, e.m_name); break;
|
||||
case op_kind::Trans: s.add_trans(env, e.m_name); break;
|
||||
case op_kind::Symm: s.add_symm(env, e.m_name); break;
|
||||
}
|
||||
}
|
||||
static name const & get_class_name() {
|
||||
|
@ -184,6 +185,10 @@ struct rel_config {
|
|||
template class scoped_ext<rel_config>;
|
||||
typedef scoped_ext<rel_config> rel_ext;
|
||||
|
||||
environment add_relation(environment const & env, name const & n, bool persistent) {
|
||||
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Relation, n), persistent);
|
||||
}
|
||||
|
||||
environment add_subst(environment const & env, name const & n, bool persistent) {
|
||||
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Subst, n), persistent);
|
||||
}
|
||||
|
@ -200,29 +205,29 @@ environment add_trans(environment const & env, name const & n, bool persistent)
|
|||
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Trans, n), persistent);
|
||||
}
|
||||
|
||||
static optional<std::tuple<name, unsigned, unsigned>> get_info(name_map<std::tuple<name, unsigned, unsigned>> const & table, name const & op) {
|
||||
static optional<relation_lemma_info> get_info(name_map<relation_lemma_info> const & table, name const & op) {
|
||||
if (auto it = table.find(op)) {
|
||||
return optional<std::tuple<name, unsigned, unsigned>>(*it);
|
||||
return optional<relation_lemma_info>(*it);
|
||||
} else {
|
||||
return optional<std::tuple<name, unsigned, unsigned>>();
|
||||
return optional<relation_lemma_info>();
|
||||
}
|
||||
}
|
||||
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_refl_extra_info(environment const & env, name const & op) {
|
||||
optional<refl_info> get_refl_extra_info(environment const & env, name const & op) {
|
||||
return get_info(rel_ext::get_state(env).m_refl_table, op);
|
||||
}
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_subst_extra_info(environment const & env, name const & op) {
|
||||
optional<subst_info> get_subst_extra_info(environment const & env, name const & op) {
|
||||
return get_info(rel_ext::get_state(env).m_subst_table, op);
|
||||
}
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_symm_extra_info(environment const & env, name const & op) {
|
||||
optional<symm_info> get_symm_extra_info(environment const & env, name const & op) {
|
||||
return get_info(rel_ext::get_state(env).m_symm_table, op);
|
||||
}
|
||||
|
||||
optional<std::tuple<name, name, unsigned>> get_trans_extra_info(environment const & env, name const & op1, name const & op2) {
|
||||
optional<trans_info> get_trans_extra_info(environment const & env, name const & op1, name const & op2) {
|
||||
if (auto it = rel_ext::get_state(env).m_trans_table.find(mk_pair(op1, op2))) {
|
||||
return optional<std::tuple<name, name, unsigned>>(*it);
|
||||
return optional<trans_info>(*it);
|
||||
} else {
|
||||
return optional<std::tuple<name, name, unsigned>>();
|
||||
return optional<trans_info>();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -232,21 +237,21 @@ bool is_subst_relation(environment const & env, name const & op) {
|
|||
|
||||
optional<name> get_refl_info(environment const & env, name const & op) {
|
||||
if (auto it = get_refl_extra_info(env, op))
|
||||
return optional<name>(std::get<0>(*it));
|
||||
return optional<name>(it->m_name);
|
||||
else
|
||||
return optional<name>();
|
||||
}
|
||||
|
||||
optional<name> get_symm_info(environment const & env, name const & op) {
|
||||
if (auto it = get_symm_extra_info(env, op))
|
||||
return optional<name>(std::get<0>(*it));
|
||||
return optional<name>(it->m_name);
|
||||
else
|
||||
return optional<name>();
|
||||
}
|
||||
|
||||
optional<name> get_trans_info(environment const & env, name const & op) {
|
||||
if (auto it = get_trans_extra_info(env, op, op))
|
||||
return optional<name>(std::get<0>(*it));
|
||||
return optional<name>(it->m_name);
|
||||
else
|
||||
return optional<name>();
|
||||
}
|
||||
|
|
|
@ -27,17 +27,46 @@ public:
|
|||
/** \brief Return true if \c rop is a registered equivalence relation in the given manager */
|
||||
bool is_equivalence(environment const & env, name const & rop);
|
||||
|
||||
/** \brief If \c rop is a registered relation, then return a non-null pointer to the associated information */
|
||||
/** \brief If \c rop is a registered relation, then return a non-null pointer to the associated information
|
||||
Lean assumes that the arguments of the binary relation are the last two explicit arguments.
|
||||
Everything else is assumed to be a parameter.
|
||||
*/
|
||||
relation_info const * get_relation_info(environment const & env, name const & rop);
|
||||
inline bool is_relation(environment const & env, name const & rop) { return get_relation_info(env, rop) != nullptr; }
|
||||
|
||||
/** \brief Declare a new binary relation named \c n */
|
||||
environment add_relation(environment const & env, name const & n, bool persistent = true);
|
||||
|
||||
/** \brief Declare subst/refl/symm/trans lemmas for a binary relation,
|
||||
* it also declares the relation if it has not been declared yet */
|
||||
environment add_subst(environment const & env, name const & n, bool persistent = true);
|
||||
environment add_refl(environment const & env, name const & n, bool persistent = true);
|
||||
environment add_symm(environment const & env, name const & n, bool persistent = true);
|
||||
environment add_trans(environment const & env, name const & n, bool persistent = true);
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_refl_extra_info(environment const & env, name const & op);
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_subst_extra_info(environment const & env, name const & op);
|
||||
optional<std::tuple<name, unsigned, unsigned>> get_symm_extra_info(environment const & env, name const & op);
|
||||
optional<std::tuple<name, name, unsigned>> get_trans_extra_info(environment const & env, name const & op1, name const & op2);
|
||||
|
||||
struct relation_lemma_info {
|
||||
name m_name;
|
||||
unsigned m_num_univs;
|
||||
unsigned m_num_args;
|
||||
relation_lemma_info() {}
|
||||
relation_lemma_info(name const & n, unsigned nunivs, unsigned nargs):m_name(n), m_num_univs(nunivs), m_num_args(nargs) {}
|
||||
};
|
||||
|
||||
typedef relation_lemma_info refl_info;
|
||||
typedef relation_lemma_info symm_info;
|
||||
typedef relation_lemma_info subst_info;
|
||||
|
||||
struct trans_info : public relation_lemma_info {
|
||||
name m_res_relation;
|
||||
trans_info() {}
|
||||
trans_info(name const & n, unsigned nunivs, unsigned nargs, name const & rel):
|
||||
relation_lemma_info(n, nunivs, nargs), m_res_relation(rel) {}
|
||||
};
|
||||
|
||||
optional<subst_info> get_subst_extra_info(environment const & env, name const & op);
|
||||
optional<refl_info> get_refl_extra_info(environment const & env, name const & op);
|
||||
optional<symm_info> get_symm_extra_info(environment const & env, name const & op);
|
||||
optional<trans_info> get_trans_extra_info(environment const & env, name const & op1, name const & op2);
|
||||
optional<name> get_refl_info(environment const & env, name const & op);
|
||||
optional<name> get_symm_info(environment const & env, name const & op);
|
||||
optional<name> get_trans_info(environment const & env, name const & op);
|
||||
|
|
|
@ -65,10 +65,10 @@ tactic trans_tactic(elaborate_fn const & elab, expr const & e) {
|
|||
if (!op)
|
||||
return proof_state_seq();
|
||||
if (auto info = get_trans_extra_info(env, *op, *op)) {
|
||||
expr pr = mk_explicit(mk_constant(std::get<0>(*info)));
|
||||
unsigned nparams = std::get<2>(*info);
|
||||
lean_assert(nparams >= 5);
|
||||
for (unsigned i = 0; i < nparams - 4; i++)
|
||||
expr pr = mk_explicit(mk_constant(info->m_name));
|
||||
unsigned nargs = info->m_num_args;
|
||||
lean_assert(nargs >= 5);
|
||||
for (unsigned i = 0; i < nargs - 4; i++)
|
||||
pr = mk_app(pr, mk_expr_placeholder());
|
||||
pr = mk_app(pr, e);
|
||||
return apply_tactic(elab, pr)(env, ios, s);
|
||||
|
|
Loading…
Reference in a new issue