refactor(library/relation_manager): cleanup and add API for declaring a relation that may not be reflexive, symmetric nor transitive

This commit is contained in:
Leonardo de Moura 2015-07-07 15:58:24 -07:00
parent fb833a724b
commit 991ff67b45
6 changed files with 92 additions and 77 deletions

View file

@ -31,16 +31,6 @@ Author: Leonardo de Moura
#include "frontends/lean/begin_end_ext.h"
namespace lean {
optional<std::tuple<name, unsigned, unsigned>> get_calc_refl_info(environment const & env, name const & op) {
return get_refl_extra_info(env, op);
}
optional<std::tuple<name, unsigned, unsigned>> get_calc_subst_info(environment const & env, name const & op) {
return get_subst_extra_info(env, op);
}
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op) {
return get_symm_extra_info(env, op);
}
static name * g_calc_name = nullptr;
static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); }
@ -108,9 +98,9 @@ static void parse_calc_proof(parser & p, buffer<calc_pred> const & preds, std::v
for (auto const & pred : preds) {
if (auto refl_it = get_refl_extra_info(env, pred_op(pred))) {
if (auto subst_it = get_subst_extra_info(env, pred_op(pred))) {
expr refl = mk_op_fn(p, std::get<0>(*refl_it), std::get<1>(*refl_it)-1, pos);
expr refl = mk_op_fn(p, refl_it->m_name, refl_it->m_num_args-1, pos);
expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos);
expr subst = mk_op_fn(p, std::get<0>(*subst_it), std::get<1>(*subst_it)-2, pos);
expr subst = mk_op_fn(p, subst_it->m_name, subst_it->m_num_args-2, pos);
expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos);
steps.emplace_back(pred, subst_pr);
}
@ -156,9 +146,9 @@ static void join(parser & p, std::vector<calc_step> const & steps1, std::vector<
continue;
auto trans_it = get_trans_extra_info(env, pred_op(pred1), pred_op(pred2));
if (trans_it) {
expr trans = mk_op_fn(p, std::get<0>(*trans_it), std::get<2>(*trans_it)-5, pos);
expr trans = mk_op_fn(p, trans_it->m_name, trans_it->m_num_args-5, pos);
expr trans_pr = p.mk_app({trans, pred_lhs(pred1), pred_rhs(pred1), pred_rhs(pred2), pr1, pr2}, pos);
res_steps.emplace_back(calc_pred(std::get<1>(*trans_it), pred_lhs(pred1), pred_rhs(pred2)), trans_pr);
res_steps.emplace_back(calc_pred(trans_it->m_res_relation, pred_lhs(pred1), pred_rhs(pred2)), trans_pr);
} else if (pred_op(pred1) == get_eq_name()) {
expr trans_right = mk_op_fn(p, get_trans_rel_right_name(), 1, pos);
expr R = mk_op_fn(p, pred_op(pred2), get_arity_of(p, pred_op(pred2))-2, pos);

View file

@ -10,12 +10,6 @@ namespace lean {
class parser;
expr parse_calc(parser & p);
bool is_calc_annotation(expr const & e);
/** \brief Given an operator name \c op, return the symmetry rule associated with, number of arguments, and universe parameters.
Return none if the operator does not have a symmetry rule associated with it.
*/
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op);
optional<std::tuple<name, unsigned, unsigned>> get_calc_refl_info(environment const & env, name const & op);
optional<std::tuple<name, unsigned, unsigned>> get_calc_subst_info(environment const & env, name const & op);
void initialize_calc();
void finalize_calc();
}

View file

@ -12,6 +12,7 @@ Author: Leonardo de Moura
#include "library/reducible.h"
#include "library/metavar_closure.h"
#include "library/local_context.h"
#include "library/relation_manager.h"
#include "frontends/lean/util.h"
#include "frontends/lean/info_manager.h"
#include "frontends/lean/calc.h"
@ -75,10 +76,9 @@ static optional<pair<expr, expr>> apply_symmetry(environment const & env, local_
buffer<expr> args;
expr const & op = get_app_args(e_type, args);
if (is_constant(op)) {
if (auto t = get_calc_symm_info(env, const_name(op))) {
name symm; unsigned nargs; unsigned nunivs;
std::tie(symm, nargs, nunivs) = *t;
return mk_op(env, ctx, ngen, tc, symm, nunivs, nargs-1, {e}, cs, g);
if (auto info = get_symm_extra_info(env, const_name(op))) {
return mk_op(env, ctx, ngen, tc, info->m_name,
info->m_num_univs, info->m_num_args-1, {e}, cs, g);
}
}
return optional<pair<expr, expr>>();
@ -95,15 +95,12 @@ static optional<pair<expr, expr>> apply_subst(environment const & env, local_con
buffer<expr> args;
expr const & op = get_app_args(e_type, args);
if (is_constant(op) && args.size() >= 2) {
if (auto subst_it = get_calc_subst_info(env, const_name(op))) {
name subst; unsigned subst_nargs; unsigned subst_univs;
std::tie(subst, subst_nargs, subst_univs) = *subst_it;
if (auto refl_it = get_calc_refl_info(env, const_name(op))) {
name refl; unsigned refl_nargs; unsigned refl_univs;
std::tie(refl, refl_nargs, refl_univs) = *refl_it;
if (auto refl_pair = mk_op(env, ctx, ngen, tc, refl, refl_univs, refl_nargs-1,
{ pred_args[npargs-2] }, cs, g)) {
return mk_op(env, ctx, ngen, tc, subst, subst_univs, subst_nargs-2, {e, refl_pair->first}, cs, g);
if (auto sinfo = get_subst_extra_info(env, const_name(op))) {
if (auto rinfo = get_refl_extra_info(env, const_name(op))) {
if (auto refl_pair = mk_op(env, ctx, ngen, tc, rinfo->m_name, rinfo->m_num_univs,
rinfo->m_num_args-1, { pred_args[npargs-2] }, cs, g)) {
return mk_op(env, ctx, ngen, tc, sinfo->m_name, sinfo->m_num_univs,
sinfo->m_num_args-2, {e, refl_pair->first}, cs, g);
}
}
}

View file

@ -32,11 +32,7 @@ static pair<expr, unsigned> extract_arg_types_core(environment const & env, name
return mk_pair(f_type, d.get_num_univ_params());
}
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) {
return extract_arg_types_core(env, f, arg_types).first;
}
enum class op_kind { Subst, Trans, Refl, Symm };
enum class op_kind { Relation, Subst, Trans, Refl, Symm };
struct rel_entry {
op_kind m_kind;
@ -46,10 +42,10 @@ struct rel_entry {
};
struct rel_state {
typedef name_map<std::tuple<name, unsigned, unsigned>> refl_table;
typedef name_map<std::tuple<name, unsigned, unsigned>> subst_table;
typedef name_map<std::tuple<name, unsigned, unsigned>> symm_table;
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table;
typedef name_map<refl_info> refl_table;
typedef name_map<subst_info> subst_table;
typedef name_map<symm_info> symm_table;
typedef rb_map<name_pair, trans_info, name_pair_quick_cmp> trans_table;
typedef name_map<relation_info> rop_table;
trans_table m_trans_table;
refl_table m_refl_table;
@ -77,12 +73,14 @@ struct rel_state {
expr type = d.get_type();
while (is_pi(type)) {
if (is_explicit(binding_info(type))) {
if (!lhs_pos)
if (!lhs_pos) {
lhs_pos = i;
else if (!rhs_pos)
} else if (!rhs_pos) {
rhs_pos = i;
else
throw_invalid_relation(rop);
} else {
lhs_pos = rhs_pos;
rhs_pos = i;
}
}
type = binding_body(type);
i++;
@ -103,7 +101,7 @@ struct rel_state {
if (nargs < 2)
throw exception("invalid substitution theorem, it must have at least 2 arguments");
name const & rop = get_fn_const(arg_types[nargs-2], "invalid substitution theorem, penultimate argument must be an operator application");
m_subst_table.insert(rop, std::make_tuple(subst, nargs, nunivs));
m_subst_table.insert(rop, subst_info(subst, nunivs, nargs));
}
void add_refl(environment const & env, name const & refl) {
@ -116,12 +114,14 @@ struct rel_state {
throw exception("invalid reflexivity rule, it must have at least 1 argument");
name const & rop = get_fn_const(r_type, "invalid reflexivity rule, result type must be an operator application");
register_rop(env, rop);
m_refl_table.insert(rop, std::make_tuple(refl, nargs, nunivs));
m_refl_table.insert(rop, refl_info(refl, nunivs, nargs));
}
void add_trans(environment const & env, name const & trans) {
buffer<expr> arg_types;
expr r_type = extract_arg_types(env, trans, arg_types);
auto p = extract_arg_types_core(env, trans, arg_types);
expr r_type = p.first;
unsigned nunivs = p.second;
unsigned nargs = arg_types.size();
if (nargs < 5)
throw exception("invalid transitivity rule, it must have at least 5 arguments");
@ -129,7 +129,7 @@ struct rel_state {
name const & op1 = get_fn_const(arg_types[nargs-2], "invalid transitivity rule, penultimate argument must be an operator application");
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid transitivity rule, last argument must be an operator application");
register_rop(env, rop);
m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs));
m_trans_table.insert(name_pair(op1, op2), trans_info(trans, nunivs, nargs, rop));
}
void add_symm(environment const & env, name const & symm) {
@ -142,7 +142,7 @@ struct rel_state {
throw exception("invalid symmetry rule, it must have at least 1 argument");
name const & rop = get_fn_const(r_type, "invalid symmetry rule, result type must be an operator application");
register_rop(env, rop);
m_symm_table.insert(rop, std::make_tuple(symm, nargs, nunivs));
m_symm_table.insert(rop, symm_info(symm, nunivs, nargs));
}
};
@ -154,6 +154,7 @@ struct rel_config {
typedef rel_entry entry;
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
switch (e.m_kind) {
case op_kind::Relation: s.register_rop(env, e.m_name); break;
case op_kind::Refl: s.add_refl(env, e.m_name); break;
case op_kind::Subst: s.add_subst(env, e.m_name); break;
case op_kind::Trans: s.add_trans(env, e.m_name); break;
@ -184,6 +185,10 @@ struct rel_config {
template class scoped_ext<rel_config>;
typedef scoped_ext<rel_config> rel_ext;
environment add_relation(environment const & env, name const & n, bool persistent) {
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Relation, n), persistent);
}
environment add_subst(environment const & env, name const & n, bool persistent) {
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Subst, n), persistent);
}
@ -200,29 +205,29 @@ environment add_trans(environment const & env, name const & n, bool persistent)
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Trans, n), persistent);
}
static optional<std::tuple<name, unsigned, unsigned>> get_info(name_map<std::tuple<name, unsigned, unsigned>> const & table, name const & op) {
static optional<relation_lemma_info> get_info(name_map<relation_lemma_info> const & table, name const & op) {
if (auto it = table.find(op)) {
return optional<std::tuple<name, unsigned, unsigned>>(*it);
return optional<relation_lemma_info>(*it);
} else {
return optional<std::tuple<name, unsigned, unsigned>>();
return optional<relation_lemma_info>();
}
}
optional<std::tuple<name, unsigned, unsigned>> get_refl_extra_info(environment const & env, name const & op) {
optional<refl_info> get_refl_extra_info(environment const & env, name const & op) {
return get_info(rel_ext::get_state(env).m_refl_table, op);
}
optional<std::tuple<name, unsigned, unsigned>> get_subst_extra_info(environment const & env, name const & op) {
optional<subst_info> get_subst_extra_info(environment const & env, name const & op) {
return get_info(rel_ext::get_state(env).m_subst_table, op);
}
optional<std::tuple<name, unsigned, unsigned>> get_symm_extra_info(environment const & env, name const & op) {
optional<symm_info> get_symm_extra_info(environment const & env, name const & op) {
return get_info(rel_ext::get_state(env).m_symm_table, op);
}
optional<std::tuple<name, name, unsigned>> get_trans_extra_info(environment const & env, name const & op1, name const & op2) {
optional<trans_info> get_trans_extra_info(environment const & env, name const & op1, name const & op2) {
if (auto it = rel_ext::get_state(env).m_trans_table.find(mk_pair(op1, op2))) {
return optional<std::tuple<name, name, unsigned>>(*it);
return optional<trans_info>(*it);
} else {
return optional<std::tuple<name, name, unsigned>>();
return optional<trans_info>();
}
}
@ -232,21 +237,21 @@ bool is_subst_relation(environment const & env, name const & op) {
optional<name> get_refl_info(environment const & env, name const & op) {
if (auto it = get_refl_extra_info(env, op))
return optional<name>(std::get<0>(*it));
return optional<name>(it->m_name);
else
return optional<name>();
}
optional<name> get_symm_info(environment const & env, name const & op) {
if (auto it = get_symm_extra_info(env, op))
return optional<name>(std::get<0>(*it));
return optional<name>(it->m_name);
else
return optional<name>();
}
optional<name> get_trans_info(environment const & env, name const & op) {
if (auto it = get_trans_extra_info(env, op, op))
return optional<name>(std::get<0>(*it));
return optional<name>(it->m_name);
else
return optional<name>();
}

View file

@ -27,17 +27,46 @@ public:
/** \brief Return true if \c rop is a registered equivalence relation in the given manager */
bool is_equivalence(environment const & env, name const & rop);
/** \brief If \c rop is a registered relation, then return a non-null pointer to the associated information */
/** \brief If \c rop is a registered relation, then return a non-null pointer to the associated information
Lean assumes that the arguments of the binary relation are the last two explicit arguments.
Everything else is assumed to be a parameter.
*/
relation_info const * get_relation_info(environment const & env, name const & rop);
inline bool is_relation(environment const & env, name const & rop) { return get_relation_info(env, rop) != nullptr; }
/** \brief Declare a new binary relation named \c n */
environment add_relation(environment const & env, name const & n, bool persistent = true);
/** \brief Declare subst/refl/symm/trans lemmas for a binary relation,
* it also declares the relation if it has not been declared yet */
environment add_subst(environment const & env, name const & n, bool persistent = true);
environment add_refl(environment const & env, name const & n, bool persistent = true);
environment add_symm(environment const & env, name const & n, bool persistent = true);
environment add_trans(environment const & env, name const & n, bool persistent = true);
optional<std::tuple<name, unsigned, unsigned>> get_refl_extra_info(environment const & env, name const & op);
optional<std::tuple<name, unsigned, unsigned>> get_subst_extra_info(environment const & env, name const & op);
optional<std::tuple<name, unsigned, unsigned>> get_symm_extra_info(environment const & env, name const & op);
optional<std::tuple<name, name, unsigned>> get_trans_extra_info(environment const & env, name const & op1, name const & op2);
struct relation_lemma_info {
name m_name;
unsigned m_num_univs;
unsigned m_num_args;
relation_lemma_info() {}
relation_lemma_info(name const & n, unsigned nunivs, unsigned nargs):m_name(n), m_num_univs(nunivs), m_num_args(nargs) {}
};
typedef relation_lemma_info refl_info;
typedef relation_lemma_info symm_info;
typedef relation_lemma_info subst_info;
struct trans_info : public relation_lemma_info {
name m_res_relation;
trans_info() {}
trans_info(name const & n, unsigned nunivs, unsigned nargs, name const & rel):
relation_lemma_info(n, nunivs, nargs), m_res_relation(rel) {}
};
optional<subst_info> get_subst_extra_info(environment const & env, name const & op);
optional<refl_info> get_refl_extra_info(environment const & env, name const & op);
optional<symm_info> get_symm_extra_info(environment const & env, name const & op);
optional<trans_info> get_trans_extra_info(environment const & env, name const & op1, name const & op2);
optional<name> get_refl_info(environment const & env, name const & op);
optional<name> get_symm_info(environment const & env, name const & op);
optional<name> get_trans_info(environment const & env, name const & op);

View file

@ -65,10 +65,10 @@ tactic trans_tactic(elaborate_fn const & elab, expr const & e) {
if (!op)
return proof_state_seq();
if (auto info = get_trans_extra_info(env, *op, *op)) {
expr pr = mk_explicit(mk_constant(std::get<0>(*info)));
unsigned nparams = std::get<2>(*info);
lean_assert(nparams >= 5);
for (unsigned i = 0; i < nparams - 4; i++)
expr pr = mk_explicit(mk_constant(info->m_name));
unsigned nargs = info->m_num_args;
lean_assert(nargs >= 5);
for (unsigned i = 0; i < nargs - 4; i++)
pr = mk_app(pr, mk_expr_placeholder());
pr = mk_app(pr, e);
return apply_tactic(elab, pr)(env, ios, s);