refactor(library/relation_manager): cleanup and add API for declaring a relation that may not be reflexive, symmetric nor transitive

This commit is contained in:
Leonardo de Moura 2015-07-07 15:58:24 -07:00
parent fb833a724b
commit 991ff67b45
6 changed files with 92 additions and 77 deletions

View file

@ -31,16 +31,6 @@ Author: Leonardo de Moura
#include "frontends/lean/begin_end_ext.h" #include "frontends/lean/begin_end_ext.h"
namespace lean { namespace lean {
optional<std::tuple<name, unsigned, unsigned>> get_calc_refl_info(environment const & env, name const & op) {
return get_refl_extra_info(env, op);
}
optional<std::tuple<name, unsigned, unsigned>> get_calc_subst_info(environment const & env, name const & op) {
return get_subst_extra_info(env, op);
}
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op) {
return get_symm_extra_info(env, op);
}
static name * g_calc_name = nullptr; static name * g_calc_name = nullptr;
static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); } static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); }
@ -108,9 +98,9 @@ static void parse_calc_proof(parser & p, buffer<calc_pred> const & preds, std::v
for (auto const & pred : preds) { for (auto const & pred : preds) {
if (auto refl_it = get_refl_extra_info(env, pred_op(pred))) { if (auto refl_it = get_refl_extra_info(env, pred_op(pred))) {
if (auto subst_it = get_subst_extra_info(env, pred_op(pred))) { if (auto subst_it = get_subst_extra_info(env, pred_op(pred))) {
expr refl = mk_op_fn(p, std::get<0>(*refl_it), std::get<1>(*refl_it)-1, pos); expr refl = mk_op_fn(p, refl_it->m_name, refl_it->m_num_args-1, pos);
expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos); expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos);
expr subst = mk_op_fn(p, std::get<0>(*subst_it), std::get<1>(*subst_it)-2, pos); expr subst = mk_op_fn(p, subst_it->m_name, subst_it->m_num_args-2, pos);
expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos); expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos);
steps.emplace_back(pred, subst_pr); steps.emplace_back(pred, subst_pr);
} }
@ -156,9 +146,9 @@ static void join(parser & p, std::vector<calc_step> const & steps1, std::vector<
continue; continue;
auto trans_it = get_trans_extra_info(env, pred_op(pred1), pred_op(pred2)); auto trans_it = get_trans_extra_info(env, pred_op(pred1), pred_op(pred2));
if (trans_it) { if (trans_it) {
expr trans = mk_op_fn(p, std::get<0>(*trans_it), std::get<2>(*trans_it)-5, pos); expr trans = mk_op_fn(p, trans_it->m_name, trans_it->m_num_args-5, pos);
expr trans_pr = p.mk_app({trans, pred_lhs(pred1), pred_rhs(pred1), pred_rhs(pred2), pr1, pr2}, pos); expr trans_pr = p.mk_app({trans, pred_lhs(pred1), pred_rhs(pred1), pred_rhs(pred2), pr1, pr2}, pos);
res_steps.emplace_back(calc_pred(std::get<1>(*trans_it), pred_lhs(pred1), pred_rhs(pred2)), trans_pr); res_steps.emplace_back(calc_pred(trans_it->m_res_relation, pred_lhs(pred1), pred_rhs(pred2)), trans_pr);
} else if (pred_op(pred1) == get_eq_name()) { } else if (pred_op(pred1) == get_eq_name()) {
expr trans_right = mk_op_fn(p, get_trans_rel_right_name(), 1, pos); expr trans_right = mk_op_fn(p, get_trans_rel_right_name(), 1, pos);
expr R = mk_op_fn(p, pred_op(pred2), get_arity_of(p, pred_op(pred2))-2, pos); expr R = mk_op_fn(p, pred_op(pred2), get_arity_of(p, pred_op(pred2))-2, pos);

View file

@ -10,12 +10,6 @@ namespace lean {
class parser; class parser;
expr parse_calc(parser & p); expr parse_calc(parser & p);
bool is_calc_annotation(expr const & e); bool is_calc_annotation(expr const & e);
/** \brief Given an operator name \c op, return the symmetry rule associated with, number of arguments, and universe parameters.
Return none if the operator does not have a symmetry rule associated with it.
*/
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op);
optional<std::tuple<name, unsigned, unsigned>> get_calc_refl_info(environment const & env, name const & op);
optional<std::tuple<name, unsigned, unsigned>> get_calc_subst_info(environment const & env, name const & op);
void initialize_calc(); void initialize_calc();
void finalize_calc(); void finalize_calc();
} }

View file

@ -12,6 +12,7 @@ Author: Leonardo de Moura
#include "library/reducible.h" #include "library/reducible.h"
#include "library/metavar_closure.h" #include "library/metavar_closure.h"
#include "library/local_context.h" #include "library/local_context.h"
#include "library/relation_manager.h"
#include "frontends/lean/util.h" #include "frontends/lean/util.h"
#include "frontends/lean/info_manager.h" #include "frontends/lean/info_manager.h"
#include "frontends/lean/calc.h" #include "frontends/lean/calc.h"
@ -75,10 +76,9 @@ static optional<pair<expr, expr>> apply_symmetry(environment const & env, local_
buffer<expr> args; buffer<expr> args;
expr const & op = get_app_args(e_type, args); expr const & op = get_app_args(e_type, args);
if (is_constant(op)) { if (is_constant(op)) {
if (auto t = get_calc_symm_info(env, const_name(op))) { if (auto info = get_symm_extra_info(env, const_name(op))) {
name symm; unsigned nargs; unsigned nunivs; return mk_op(env, ctx, ngen, tc, info->m_name,
std::tie(symm, nargs, nunivs) = *t; info->m_num_univs, info->m_num_args-1, {e}, cs, g);
return mk_op(env, ctx, ngen, tc, symm, nunivs, nargs-1, {e}, cs, g);
} }
} }
return optional<pair<expr, expr>>(); return optional<pair<expr, expr>>();
@ -95,15 +95,12 @@ static optional<pair<expr, expr>> apply_subst(environment const & env, local_con
buffer<expr> args; buffer<expr> args;
expr const & op = get_app_args(e_type, args); expr const & op = get_app_args(e_type, args);
if (is_constant(op) && args.size() >= 2) { if (is_constant(op) && args.size() >= 2) {
if (auto subst_it = get_calc_subst_info(env, const_name(op))) { if (auto sinfo = get_subst_extra_info(env, const_name(op))) {
name subst; unsigned subst_nargs; unsigned subst_univs; if (auto rinfo = get_refl_extra_info(env, const_name(op))) {
std::tie(subst, subst_nargs, subst_univs) = *subst_it; if (auto refl_pair = mk_op(env, ctx, ngen, tc, rinfo->m_name, rinfo->m_num_univs,
if (auto refl_it = get_calc_refl_info(env, const_name(op))) { rinfo->m_num_args-1, { pred_args[npargs-2] }, cs, g)) {
name refl; unsigned refl_nargs; unsigned refl_univs; return mk_op(env, ctx, ngen, tc, sinfo->m_name, sinfo->m_num_univs,
std::tie(refl, refl_nargs, refl_univs) = *refl_it; sinfo->m_num_args-2, {e, refl_pair->first}, cs, g);
if (auto refl_pair = mk_op(env, ctx, ngen, tc, refl, refl_univs, refl_nargs-1,
{ pred_args[npargs-2] }, cs, g)) {
return mk_op(env, ctx, ngen, tc, subst, subst_univs, subst_nargs-2, {e, refl_pair->first}, cs, g);
} }
} }
} }

View file

@ -32,11 +32,7 @@ static pair<expr, unsigned> extract_arg_types_core(environment const & env, name
return mk_pair(f_type, d.get_num_univ_params()); return mk_pair(f_type, d.get_num_univ_params());
} }
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) { enum class op_kind { Relation, Subst, Trans, Refl, Symm };
return extract_arg_types_core(env, f, arg_types).first;
}
enum class op_kind { Subst, Trans, Refl, Symm };
struct rel_entry { struct rel_entry {
op_kind m_kind; op_kind m_kind;
@ -46,10 +42,10 @@ struct rel_entry {
}; };
struct rel_state { struct rel_state {
typedef name_map<std::tuple<name, unsigned, unsigned>> refl_table; typedef name_map<refl_info> refl_table;
typedef name_map<std::tuple<name, unsigned, unsigned>> subst_table; typedef name_map<subst_info> subst_table;
typedef name_map<std::tuple<name, unsigned, unsigned>> symm_table; typedef name_map<symm_info> symm_table;
typedef rb_map<name_pair, std::tuple<name, name, unsigned>, name_pair_quick_cmp> trans_table; typedef rb_map<name_pair, trans_info, name_pair_quick_cmp> trans_table;
typedef name_map<relation_info> rop_table; typedef name_map<relation_info> rop_table;
trans_table m_trans_table; trans_table m_trans_table;
refl_table m_refl_table; refl_table m_refl_table;
@ -77,12 +73,14 @@ struct rel_state {
expr type = d.get_type(); expr type = d.get_type();
while (is_pi(type)) { while (is_pi(type)) {
if (is_explicit(binding_info(type))) { if (is_explicit(binding_info(type))) {
if (!lhs_pos) if (!lhs_pos) {
lhs_pos = i; lhs_pos = i;
else if (!rhs_pos) } else if (!rhs_pos) {
rhs_pos = i; rhs_pos = i;
else } else {
throw_invalid_relation(rop); lhs_pos = rhs_pos;
rhs_pos = i;
}
} }
type = binding_body(type); type = binding_body(type);
i++; i++;
@ -103,7 +101,7 @@ struct rel_state {
if (nargs < 2) if (nargs < 2)
throw exception("invalid substitution theorem, it must have at least 2 arguments"); throw exception("invalid substitution theorem, it must have at least 2 arguments");
name const & rop = get_fn_const(arg_types[nargs-2], "invalid substitution theorem, penultimate argument must be an operator application"); name const & rop = get_fn_const(arg_types[nargs-2], "invalid substitution theorem, penultimate argument must be an operator application");
m_subst_table.insert(rop, std::make_tuple(subst, nargs, nunivs)); m_subst_table.insert(rop, subst_info(subst, nunivs, nargs));
} }
void add_refl(environment const & env, name const & refl) { void add_refl(environment const & env, name const & refl) {
@ -116,12 +114,14 @@ struct rel_state {
throw exception("invalid reflexivity rule, it must have at least 1 argument"); throw exception("invalid reflexivity rule, it must have at least 1 argument");
name const & rop = get_fn_const(r_type, "invalid reflexivity rule, result type must be an operator application"); name const & rop = get_fn_const(r_type, "invalid reflexivity rule, result type must be an operator application");
register_rop(env, rop); register_rop(env, rop);
m_refl_table.insert(rop, std::make_tuple(refl, nargs, nunivs)); m_refl_table.insert(rop, refl_info(refl, nunivs, nargs));
} }
void add_trans(environment const & env, name const & trans) { void add_trans(environment const & env, name const & trans) {
buffer<expr> arg_types; buffer<expr> arg_types;
expr r_type = extract_arg_types(env, trans, arg_types); auto p = extract_arg_types_core(env, trans, arg_types);
expr r_type = p.first;
unsigned nunivs = p.second;
unsigned nargs = arg_types.size(); unsigned nargs = arg_types.size();
if (nargs < 5) if (nargs < 5)
throw exception("invalid transitivity rule, it must have at least 5 arguments"); throw exception("invalid transitivity rule, it must have at least 5 arguments");
@ -129,7 +129,7 @@ struct rel_state {
name const & op1 = get_fn_const(arg_types[nargs-2], "invalid transitivity rule, penultimate argument must be an operator application"); name const & op1 = get_fn_const(arg_types[nargs-2], "invalid transitivity rule, penultimate argument must be an operator application");
name const & op2 = get_fn_const(arg_types[nargs-1], "invalid transitivity rule, last argument must be an operator application"); name const & op2 = get_fn_const(arg_types[nargs-1], "invalid transitivity rule, last argument must be an operator application");
register_rop(env, rop); register_rop(env, rop);
m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs)); m_trans_table.insert(name_pair(op1, op2), trans_info(trans, nunivs, nargs, rop));
} }
void add_symm(environment const & env, name const & symm) { void add_symm(environment const & env, name const & symm) {
@ -142,7 +142,7 @@ struct rel_state {
throw exception("invalid symmetry rule, it must have at least 1 argument"); throw exception("invalid symmetry rule, it must have at least 1 argument");
name const & rop = get_fn_const(r_type, "invalid symmetry rule, result type must be an operator application"); name const & rop = get_fn_const(r_type, "invalid symmetry rule, result type must be an operator application");
register_rop(env, rop); register_rop(env, rop);
m_symm_table.insert(rop, std::make_tuple(symm, nargs, nunivs)); m_symm_table.insert(rop, symm_info(symm, nunivs, nargs));
} }
}; };
@ -154,6 +154,7 @@ struct rel_config {
typedef rel_entry entry; typedef rel_entry entry;
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) { static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
switch (e.m_kind) { switch (e.m_kind) {
case op_kind::Relation: s.register_rop(env, e.m_name); break;
case op_kind::Refl: s.add_refl(env, e.m_name); break; case op_kind::Refl: s.add_refl(env, e.m_name); break;
case op_kind::Subst: s.add_subst(env, e.m_name); break; case op_kind::Subst: s.add_subst(env, e.m_name); break;
case op_kind::Trans: s.add_trans(env, e.m_name); break; case op_kind::Trans: s.add_trans(env, e.m_name); break;
@ -184,6 +185,10 @@ struct rel_config {
template class scoped_ext<rel_config>; template class scoped_ext<rel_config>;
typedef scoped_ext<rel_config> rel_ext; typedef scoped_ext<rel_config> rel_ext;
environment add_relation(environment const & env, name const & n, bool persistent) {
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Relation, n), persistent);
}
environment add_subst(environment const & env, name const & n, bool persistent) { environment add_subst(environment const & env, name const & n, bool persistent) {
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Subst, n), persistent); return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Subst, n), persistent);
} }
@ -200,29 +205,29 @@ environment add_trans(environment const & env, name const & n, bool persistent)
return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Trans, n), persistent); return rel_ext::add_entry(env, get_dummy_ios(), rel_entry(op_kind::Trans, n), persistent);
} }
static optional<std::tuple<name, unsigned, unsigned>> get_info(name_map<std::tuple<name, unsigned, unsigned>> const & table, name const & op) { static optional<relation_lemma_info> get_info(name_map<relation_lemma_info> const & table, name const & op) {
if (auto it = table.find(op)) { if (auto it = table.find(op)) {
return optional<std::tuple<name, unsigned, unsigned>>(*it); return optional<relation_lemma_info>(*it);
} else { } else {
return optional<std::tuple<name, unsigned, unsigned>>(); return optional<relation_lemma_info>();
} }
} }
optional<std::tuple<name, unsigned, unsigned>> get_refl_extra_info(environment const & env, name const & op) { optional<refl_info> get_refl_extra_info(environment const & env, name const & op) {
return get_info(rel_ext::get_state(env).m_refl_table, op); return get_info(rel_ext::get_state(env).m_refl_table, op);
} }
optional<std::tuple<name, unsigned, unsigned>> get_subst_extra_info(environment const & env, name const & op) { optional<subst_info> get_subst_extra_info(environment const & env, name const & op) {
return get_info(rel_ext::get_state(env).m_subst_table, op); return get_info(rel_ext::get_state(env).m_subst_table, op);
} }
optional<std::tuple<name, unsigned, unsigned>> get_symm_extra_info(environment const & env, name const & op) { optional<symm_info> get_symm_extra_info(environment const & env, name const & op) {
return get_info(rel_ext::get_state(env).m_symm_table, op); return get_info(rel_ext::get_state(env).m_symm_table, op);
} }
optional<std::tuple<name, name, unsigned>> get_trans_extra_info(environment const & env, name const & op1, name const & op2) { optional<trans_info> get_trans_extra_info(environment const & env, name const & op1, name const & op2) {
if (auto it = rel_ext::get_state(env).m_trans_table.find(mk_pair(op1, op2))) { if (auto it = rel_ext::get_state(env).m_trans_table.find(mk_pair(op1, op2))) {
return optional<std::tuple<name, name, unsigned>>(*it); return optional<trans_info>(*it);
} else { } else {
return optional<std::tuple<name, name, unsigned>>(); return optional<trans_info>();
} }
} }
@ -232,21 +237,21 @@ bool is_subst_relation(environment const & env, name const & op) {
optional<name> get_refl_info(environment const & env, name const & op) { optional<name> get_refl_info(environment const & env, name const & op) {
if (auto it = get_refl_extra_info(env, op)) if (auto it = get_refl_extra_info(env, op))
return optional<name>(std::get<0>(*it)); return optional<name>(it->m_name);
else else
return optional<name>(); return optional<name>();
} }
optional<name> get_symm_info(environment const & env, name const & op) { optional<name> get_symm_info(environment const & env, name const & op) {
if (auto it = get_symm_extra_info(env, op)) if (auto it = get_symm_extra_info(env, op))
return optional<name>(std::get<0>(*it)); return optional<name>(it->m_name);
else else
return optional<name>(); return optional<name>();
} }
optional<name> get_trans_info(environment const & env, name const & op) { optional<name> get_trans_info(environment const & env, name const & op) {
if (auto it = get_trans_extra_info(env, op, op)) if (auto it = get_trans_extra_info(env, op, op))
return optional<name>(std::get<0>(*it)); return optional<name>(it->m_name);
else else
return optional<name>(); return optional<name>();
} }

View file

@ -27,17 +27,46 @@ public:
/** \brief Return true if \c rop is a registered equivalence relation in the given manager */ /** \brief Return true if \c rop is a registered equivalence relation in the given manager */
bool is_equivalence(environment const & env, name const & rop); bool is_equivalence(environment const & env, name const & rop);
/** \brief If \c rop is a registered relation, then return a non-null pointer to the associated information */ /** \brief If \c rop is a registered relation, then return a non-null pointer to the associated information
Lean assumes that the arguments of the binary relation are the last two explicit arguments.
Everything else is assumed to be a parameter.
*/
relation_info const * get_relation_info(environment const & env, name const & rop); relation_info const * get_relation_info(environment const & env, name const & rop);
inline bool is_relation(environment const & env, name const & rop) { return get_relation_info(env, rop) != nullptr; }
/** \brief Declare a new binary relation named \c n */
environment add_relation(environment const & env, name const & n, bool persistent = true);
/** \brief Declare subst/refl/symm/trans lemmas for a binary relation,
* it also declares the relation if it has not been declared yet */
environment add_subst(environment const & env, name const & n, bool persistent = true); environment add_subst(environment const & env, name const & n, bool persistent = true);
environment add_refl(environment const & env, name const & n, bool persistent = true); environment add_refl(environment const & env, name const & n, bool persistent = true);
environment add_symm(environment const & env, name const & n, bool persistent = true); environment add_symm(environment const & env, name const & n, bool persistent = true);
environment add_trans(environment const & env, name const & n, bool persistent = true); environment add_trans(environment const & env, name const & n, bool persistent = true);
optional<std::tuple<name, unsigned, unsigned>> get_refl_extra_info(environment const & env, name const & op);
optional<std::tuple<name, unsigned, unsigned>> get_subst_extra_info(environment const & env, name const & op); struct relation_lemma_info {
optional<std::tuple<name, unsigned, unsigned>> get_symm_extra_info(environment const & env, name const & op); name m_name;
optional<std::tuple<name, name, unsigned>> get_trans_extra_info(environment const & env, name const & op1, name const & op2); unsigned m_num_univs;
unsigned m_num_args;
relation_lemma_info() {}
relation_lemma_info(name const & n, unsigned nunivs, unsigned nargs):m_name(n), m_num_univs(nunivs), m_num_args(nargs) {}
};
typedef relation_lemma_info refl_info;
typedef relation_lemma_info symm_info;
typedef relation_lemma_info subst_info;
struct trans_info : public relation_lemma_info {
name m_res_relation;
trans_info() {}
trans_info(name const & n, unsigned nunivs, unsigned nargs, name const & rel):
relation_lemma_info(n, nunivs, nargs), m_res_relation(rel) {}
};
optional<subst_info> get_subst_extra_info(environment const & env, name const & op);
optional<refl_info> get_refl_extra_info(environment const & env, name const & op);
optional<symm_info> get_symm_extra_info(environment const & env, name const & op);
optional<trans_info> get_trans_extra_info(environment const & env, name const & op1, name const & op2);
optional<name> get_refl_info(environment const & env, name const & op); optional<name> get_refl_info(environment const & env, name const & op);
optional<name> get_symm_info(environment const & env, name const & op); optional<name> get_symm_info(environment const & env, name const & op);
optional<name> get_trans_info(environment const & env, name const & op); optional<name> get_trans_info(environment const & env, name const & op);

View file

@ -65,10 +65,10 @@ tactic trans_tactic(elaborate_fn const & elab, expr const & e) {
if (!op) if (!op)
return proof_state_seq(); return proof_state_seq();
if (auto info = get_trans_extra_info(env, *op, *op)) { if (auto info = get_trans_extra_info(env, *op, *op)) {
expr pr = mk_explicit(mk_constant(std::get<0>(*info))); expr pr = mk_explicit(mk_constant(info->m_name));
unsigned nparams = std::get<2>(*info); unsigned nargs = info->m_num_args;
lean_assert(nparams >= 5); lean_assert(nargs >= 5);
for (unsigned i = 0; i < nparams - 4; i++) for (unsigned i = 0; i < nargs - 4; i++)
pr = mk_app(pr, mk_expr_placeholder()); pr = mk_app(pr, mk_expr_placeholder());
pr = mk_app(pr, e); pr = mk_app(pr, e);
return apply_tactic(elab, pr)(env, ios, s); return apply_tactic(elab, pr)(env, ios, s);