feat(library/tactic/rewrite_tactic): add custom matcher pluging for rewriter
This commit is contained in:
parent
8912c759dd
commit
0e05c239a5
3 changed files with 372 additions and 16 deletions
|
@ -71,6 +71,8 @@ public:
|
|||
static location mk_hypotheses_at(buffer<name> const & hs, buffer<occurrence> const & occs);
|
||||
static location mk_at(occurrence const & g_occs, buffer<name> const & hs, buffer<occurrence> const & hs_occs);
|
||||
|
||||
bool is_goal_only() const { return m_kind == GoalOnly; }
|
||||
|
||||
optional<occurrence> includes_goal() const;
|
||||
optional<occurrence> includes_hypothesis(name const & h) const;
|
||||
void get_explicit_hypotheses_names(buffer<name> & r) const;
|
||||
|
|
|
@ -5,10 +5,22 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <string>
|
||||
#include "util/interrupt.h"
|
||||
#include "util/list_fn.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/replace_fn.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "library/kernel_serializer.h"
|
||||
#include "library/reducible.h"
|
||||
#include "library/util.h"
|
||||
#include "library/match.h"
|
||||
#include "library/projection.h"
|
||||
#include "library/generic_exception.h"
|
||||
#include "library/tactic/rewrite_tactic.h"
|
||||
#include "library/tactic/expr_to_tactic.h"
|
||||
|
||||
// #define TRACE_MATCH_PLUGIN
|
||||
|
||||
namespace lean {
|
||||
class unfold_info {
|
||||
name m_name;
|
||||
|
@ -33,7 +45,9 @@ deserializer & operator>>(deserializer & d, unfold_info & e) {
|
|||
}
|
||||
|
||||
class rewrite_info {
|
||||
public:
|
||||
enum multiplicity { Once, AtMostN, ExactlyN, ZeroOrMore, OneOrMore };
|
||||
private:
|
||||
bool m_symm;
|
||||
multiplicity m_multiplicity;
|
||||
optional<unsigned> m_num;
|
||||
|
@ -60,7 +74,7 @@ public:
|
|||
}
|
||||
|
||||
static rewrite_info mk_one_or_more(bool symm, location const & loc) {
|
||||
return rewrite_info(symm, ZeroOrMore, optional<unsigned>(), loc);
|
||||
return rewrite_info(symm, OneOrMore, optional<unsigned>(), loc);
|
||||
}
|
||||
|
||||
bool symm() const {
|
||||
|
@ -80,7 +94,7 @@ public:
|
|||
return *m_num;
|
||||
}
|
||||
|
||||
location get_location() const { return m_location; }
|
||||
location const & get_location() const { return m_location; }
|
||||
|
||||
friend serializer & operator<<(serializer & s, rewrite_info const & elem);
|
||||
friend deserializer & operator>>(deserializer & d, rewrite_info & e);
|
||||
|
@ -135,6 +149,11 @@ bool is_rewrite_unfold_step(expr const & e) {
|
|||
return is_macro(e) && macro_def(e).get_name() == *g_rewrite_unfold_name;
|
||||
}
|
||||
|
||||
unfold_info const & get_rewrite_unfold_info(expr const & e) {
|
||||
lean_assert(is_rewrite_unfold_step(e));
|
||||
return static_cast<rewrite_unfold_macro_cell const*>(macro_def(e).raw())->get_info();
|
||||
}
|
||||
|
||||
class rewrite_element_macro_cell : public macro_definition_cell {
|
||||
rewrite_info m_info;
|
||||
public:
|
||||
|
@ -201,6 +220,11 @@ expr const & get_rewrite_pattern(expr const & e) {
|
|||
return macro_arg(e, 1);
|
||||
}
|
||||
|
||||
rewrite_info const & get_rewrite_info(expr const & e) {
|
||||
lean_assert(is_rewrite_step(e));
|
||||
return static_cast<rewrite_element_macro_cell const*>(macro_def(e).raw())->get_info();
|
||||
}
|
||||
|
||||
expr mk_rewrite_tactic_expr(buffer<expr> const & elems) {
|
||||
lean_assert(std::all_of(elems.begin(), elems.end(), [](expr const & e) {
|
||||
return is_rewrite_step(e) || is_rewrite_unfold_step(e);
|
||||
|
@ -208,16 +232,349 @@ expr mk_rewrite_tactic_expr(buffer<expr> const & elems) {
|
|||
return mk_app(*g_rewrite_tac, mk_expr_list(elems.size(), elems.data()));
|
||||
}
|
||||
|
||||
tactic mk_rewrite_tactic(buffer<expr> const & elems) {
|
||||
// TODO(Leo)
|
||||
std::cout << "rewrite_tactic\n";
|
||||
for (auto const & e : elems) {
|
||||
if (is_rewrite_step(e))
|
||||
std::cout << ">> " << get_rewrite_rule(e) << "\n";
|
||||
else
|
||||
std::cout << ">> unfold\n";
|
||||
class rewrite_match_plugin : public match_plugin {
|
||||
#ifdef TRACE_MATCH_PLUGIN
|
||||
io_state m_ios;
|
||||
#endif
|
||||
type_checker & m_tc;
|
||||
public:
|
||||
#ifdef TRACE_MATCH_PLUGIN
|
||||
rewrite_match_plugin(io_state const & ios, type_checker & tc):
|
||||
m_ios(ios), m_tc(tc) {}
|
||||
#else
|
||||
rewrite_match_plugin(io_state const &, type_checker & tc):
|
||||
m_tc(tc) {}
|
||||
#endif
|
||||
|
||||
bool is_projection_app(expr const & e) const {
|
||||
expr const & f = get_app_fn(e);
|
||||
return is_constant(f) && is_projection(m_tc.env(), const_name(f));
|
||||
}
|
||||
return id_tactic();
|
||||
|
||||
virtual bool on_failure(expr const & p, expr const & t, match_context & ctx) const {
|
||||
try {
|
||||
constraint_seq cs;
|
||||
expr p1 = is_projection_app(p) ? p : m_tc.whnf(p, cs);
|
||||
expr t1 = is_projection_app(t) ? t : m_tc.whnf(t, cs);
|
||||
return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1);
|
||||
} catch (exception&) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
virtual lbool pre(expr const & p, expr const & t, match_context & ctx) const {
|
||||
if (!is_app(p) || !is_app(t))
|
||||
return l_undef;
|
||||
expr const & p_fn = get_app_fn(p);
|
||||
if (!is_constant(p_fn))
|
||||
return l_false;
|
||||
expr const & t_fn = get_app_fn(t);
|
||||
if (!is_constant(t_fn))
|
||||
return l_false;
|
||||
if (!ctx.match(p_fn, t_fn))
|
||||
return l_false;
|
||||
projection_info const * info = get_projection_info(m_tc.env(), const_name(p_fn));
|
||||
if (!info || !info->m_inst_implicit)
|
||||
return l_undef; // use default matcher
|
||||
buffer<expr> p_args, t_args;
|
||||
get_app_args(p, p_args);
|
||||
get_app_args(t, t_args);
|
||||
if (p_args.size() != t_args.size())
|
||||
return l_false;
|
||||
for (unsigned i = 0; i < p_args.size(); i++) {
|
||||
if (i == info->m_nparams)
|
||||
continue; // skip structure
|
||||
if (!ctx.match(p_args[i], t_args[i]))
|
||||
return l_false;
|
||||
}
|
||||
return l_true;
|
||||
}
|
||||
};
|
||||
|
||||
class rewrite_fn {
|
||||
environment m_env;
|
||||
io_state m_ios;
|
||||
elaborate_fn m_elab;
|
||||
proof_state m_ps;
|
||||
name_generator m_ngen;
|
||||
type_checker_ptr m_tc;
|
||||
rewrite_match_plugin m_mplugin;
|
||||
goal m_g;
|
||||
substitution m_subst;
|
||||
expr m_expr_loc; // auxiliary expression used for error localization
|
||||
|
||||
buffer<optional<level>> m_lsubst; // auxiliary buffer for pattern matching
|
||||
buffer<optional<expr>> m_esubst; // auxiliary buffer for pattern matching
|
||||
|
||||
[[ noreturn ]] void throw_rewrite_exception(char const * msg) {
|
||||
throw_generic_exception(msg, m_expr_loc);
|
||||
}
|
||||
|
||||
[[ noreturn ]] void throw_rewrite_exception(sstream const & strm) {
|
||||
throw_generic_exception(strm, m_expr_loc);
|
||||
}
|
||||
|
||||
expr mk_meta(expr const & type) {
|
||||
return m_g.mk_meta(m_ngen.next(), type);
|
||||
}
|
||||
|
||||
void elaborate_elems(buffer<expr> const & elems, buffer<expr> & new_elems) {
|
||||
for (expr const & elem : elems) {
|
||||
if (is_rewrite_unfold_step(elem)) {
|
||||
// nothing to be done
|
||||
new_elems.push_back(elem);
|
||||
} else {
|
||||
expr rule = get_rewrite_rule(elem);
|
||||
auto rcs = m_elab(m_g, m_ngen.mk_child(), rule, false);
|
||||
rule = rcs.first;
|
||||
if (has_rewrite_pattern(elem)) {
|
||||
auto pcs = m_elab(m_g, m_ngen.mk_child(), get_rewrite_pattern(elem), false);
|
||||
expr pattern = pcs.first;
|
||||
// We ignore any constraints generated when elaborating patterns.
|
||||
// The pattern is just a hint to locate positions where the rule should be applied.
|
||||
expr new_args[2] = { rule, pattern };
|
||||
new_elems.push_back(mk_macro(macro_def(elem), 2, new_args));
|
||||
} else {
|
||||
new_elems.push_back(mk_macro(macro_def(elem), 1, &rule));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool process_unfold_step(expr const & elem) {
|
||||
lean_assert(is_rewrite_unfold_step(elem));
|
||||
// TODO(Leo)
|
||||
return false;
|
||||
}
|
||||
|
||||
// Replace metavariables with special metavariables for the higher-order matcher. This is method is used when
|
||||
// converting an expression into a pattern.
|
||||
expr to_meta_idx(expr const & e) {
|
||||
m_lsubst.clear();
|
||||
m_esubst.clear();
|
||||
name_map<expr> emap;
|
||||
name_map<level> lmap;
|
||||
|
||||
auto to_meta_idx = [&](level const & l) {
|
||||
return replace(l, [&](level const & l) {
|
||||
if (!has_meta(l)) {
|
||||
return some_level(l);
|
||||
} else if (is_meta(l)) {
|
||||
if (auto it = lmap.find(meta_id(l))) {
|
||||
return some_level(*it);
|
||||
} else {
|
||||
unsigned next_idx = m_lsubst.size();
|
||||
level r = mk_idx_meta_univ(next_idx);
|
||||
m_lsubst.push_back(none_level());
|
||||
lmap.insert(meta_id(l), r);
|
||||
return some_level(r);
|
||||
}
|
||||
} else {
|
||||
return none_level();
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
return replace(e, [&](expr const & e, unsigned) {
|
||||
if (!has_metavar(e)) {
|
||||
return some_expr(e); // done
|
||||
} else if (is_binding(e)) {
|
||||
throw_rewrite_exception("invalid rewrite tactic, pattern contains binders");
|
||||
} else if (is_meta(e)) {
|
||||
expr const & fn = get_app_fn(e);
|
||||
lean_assert(is_metavar(fn));
|
||||
name const & n = mlocal_name(fn);
|
||||
if (auto it = emap.find(n)) {
|
||||
return some_expr(*it);
|
||||
} else {
|
||||
unsigned next_idx = m_esubst.size();
|
||||
expr r = mk_idx_meta(next_idx, m_tc->infer(e).first);
|
||||
m_esubst.push_back(none_expr());
|
||||
emap.insert(n, r);
|
||||
return some_expr(r);
|
||||
}
|
||||
} else if (is_constant(e)) {
|
||||
levels ls = map(const_levels(e), [&](level const & l) { return to_meta_idx(l); });
|
||||
return some_expr(update_constant(e, ls));
|
||||
} else {
|
||||
return none_expr();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Given the rewrite step \c e, return a pattern to be used to locate the term to be rewritten.
|
||||
expr get_pattern(expr const & e) {
|
||||
lean_assert(is_rewrite_step(e));
|
||||
if (has_rewrite_pattern(e)) {
|
||||
return to_meta_idx(get_rewrite_pattern(e));
|
||||
} else {
|
||||
// Remark: we discard constraints generated producing the pattern.
|
||||
// Patterns are only used to locate positions where the rule should be applied.
|
||||
expr rule = get_rewrite_rule(e);
|
||||
expr rule_type = m_tc->whnf(m_tc->infer(rule).first).first;
|
||||
while (is_pi(rule_type)) {
|
||||
expr meta = mk_meta(binding_domain(rule_type));
|
||||
rule_type = m_tc->whnf(instantiate(binding_body(rule_type), meta)).first;
|
||||
}
|
||||
if (!is_eq(rule_type))
|
||||
throw_rewrite_exception("invalid rewrite tactic, given lemma is not an equality");
|
||||
if (get_rewrite_info(e).symm()) {
|
||||
return to_meta_idx(app_arg(rule_type));
|
||||
} else {
|
||||
return to_meta_idx(app_arg(app_fn(rule_type)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set m_esubst and m_lsubst elements to none
|
||||
void reset_subst() {
|
||||
for (optional<level> & l : m_lsubst)
|
||||
l = none_level();
|
||||
for (optional<expr> & e : m_esubst)
|
||||
e = none_expr();
|
||||
}
|
||||
|
||||
optional<expr> find_target(expr const & e, expr const & pattern) {
|
||||
optional<expr> found;
|
||||
for_each(e, [&](expr const & t, unsigned) {
|
||||
if (found)
|
||||
return false; // stop search
|
||||
if (closed(t)) {
|
||||
lean_assert(std::all_of(m_esubst.begin(), m_esubst.end(), [&](optional<expr> const & e) { return !e; }));
|
||||
bool assigned = false;
|
||||
bool r = match(pattern, t, m_lsubst, m_esubst, nullptr, nullptr, &m_mplugin, &assigned);
|
||||
if (assigned)
|
||||
reset_subst();
|
||||
if (r) {
|
||||
found = t;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
|
||||
bool process_rewrite_hypothesis(expr const & hyp, expr const & elem, expr const & pattern) {
|
||||
// TODO(Leo)
|
||||
return false;
|
||||
}
|
||||
|
||||
bool process_rewrite_goal(expr const & elem, expr const & pattern) {
|
||||
expr goal_type = m_g.get_type();
|
||||
if (auto it = find_target(goal_type, pattern)) {
|
||||
regular(m_env, m_ios) << "found: " << *it << "\n";
|
||||
// TODO(Leo)
|
||||
return false;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool process_rewrite_single_step(expr const & elem, expr const & pattern) {
|
||||
check_system("rewrite tactic");
|
||||
rewrite_info const & info = get_rewrite_info(elem);
|
||||
location const & loc = info.get_location();
|
||||
if (loc.is_goal_only())
|
||||
return process_rewrite_goal(elem, pattern);
|
||||
bool progress = false;
|
||||
buffer<expr> hyps;
|
||||
m_g.get_hyps(hyps);
|
||||
for (expr const & h : hyps) {
|
||||
if (!loc.includes_hypothesis(local_pp_name(h)))
|
||||
continue;
|
||||
if (process_rewrite_hypothesis(h, elem, pattern))
|
||||
progress = true;
|
||||
}
|
||||
if (loc.includes_goal() && process_rewrite_goal(elem, pattern))
|
||||
progress = true;
|
||||
return progress;
|
||||
}
|
||||
|
||||
bool process_rewrite_step(expr const & elem) {
|
||||
lean_assert(is_rewrite_step(elem));
|
||||
expr pattern = get_pattern(elem);
|
||||
regular(m_env, m_ios) << "pattern: " << pattern << "\n";
|
||||
rewrite_info const & info = get_rewrite_info(elem);
|
||||
unsigned num;
|
||||
switch (info.get_multiplicity()) {
|
||||
case rewrite_info::Once:
|
||||
return process_rewrite_single_step(elem, pattern);
|
||||
case rewrite_info::AtMostN:
|
||||
num = info.num();
|
||||
for (unsigned i = 0; i < num; i++) {
|
||||
if (!process_rewrite_single_step(elem, pattern))
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
case rewrite_info::ExactlyN:
|
||||
num = info.num();
|
||||
for (unsigned i = 0; i < num; i++) {
|
||||
if (!process_rewrite_single_step(elem, pattern))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
case rewrite_info::ZeroOrMore:
|
||||
while (true) {
|
||||
if (!process_rewrite_single_step(elem, pattern))
|
||||
return true;
|
||||
}
|
||||
case rewrite_info::OneOrMore:
|
||||
if (!process_rewrite_single_step(elem, pattern))
|
||||
return false;
|
||||
while (true) {
|
||||
if (!process_rewrite_single_step(elem, pattern))
|
||||
return true;
|
||||
}
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
// Process the given rewrite element/step. This method destructively update
|
||||
// m_g, m_subst, m_ngen. It returns true if it succeeded and false otherwise.
|
||||
bool process_step(expr const & elem) {
|
||||
if (is_rewrite_unfold_step(elem)) {
|
||||
return process_unfold_step(elem);
|
||||
} else {
|
||||
return process_rewrite_step(elem);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
rewrite_fn(environment const & env, io_state const & ios, elaborate_fn const & elab, proof_state const & ps):
|
||||
m_env(env), m_ios(ios), m_elab(elab), m_ps(ps), m_ngen(ps.get_ngen()),
|
||||
m_tc(mk_type_checker(m_env, m_ngen.mk_child(), ps.relax_main_opaque())),
|
||||
m_mplugin(m_ios, *m_tc) {
|
||||
goals const & gs = m_ps.get_goals();
|
||||
lean_assert(gs);
|
||||
m_g = head(gs);
|
||||
m_subst = m_ps.get_subst();
|
||||
}
|
||||
|
||||
proof_state_seq operator()(buffer<expr> const & elems) {
|
||||
std::cout << "rewrite_tactic\n";
|
||||
buffer<expr> new_elems;
|
||||
elaborate_elems(elems, new_elems);
|
||||
|
||||
lean_assert(elems.size() == new_elems.size());
|
||||
for (unsigned i = 0; i < new_elems.size(); i++) {
|
||||
flet<expr> set1(m_expr_loc, elems[i]);
|
||||
if (!process_step(new_elems[i])) {
|
||||
return proof_state_seq();
|
||||
}
|
||||
}
|
||||
|
||||
return proof_state_seq(m_ps);
|
||||
}
|
||||
};
|
||||
|
||||
tactic mk_rewrite_tactic(elaborate_fn const & elab, buffer<expr> const & elems) {
|
||||
return tactic([=](environment const & env, io_state const & ios, proof_state const & s) {
|
||||
goals const & gs = s.get_goals();
|
||||
if (empty(gs))
|
||||
return proof_state_seq();
|
||||
return rewrite_fn(env, ios, elab, s)(elems);
|
||||
});
|
||||
}
|
||||
|
||||
void initialize_rewrite_tactic() {
|
||||
|
@ -246,14 +603,14 @@ void initialize_rewrite_tactic() {
|
|||
return mk_macro(def, num, args);
|
||||
});
|
||||
register_tac(rewrite_tac_name,
|
||||
[](type_checker &, elaborate_fn const &, expr const & e, pos_info_provider const *) {
|
||||
[](type_checker &, elaborate_fn const & elab, expr const & e, pos_info_provider const *) {
|
||||
buffer<expr> args;
|
||||
get_tactic_expr_list_elements(app_arg(e), args, "invalid 'rewrite' tactic, invalid argument");
|
||||
for (expr const & arg : args) {
|
||||
if (!is_rewrite_step(arg) && !is_rewrite_unfold_step(arg))
|
||||
throw expr_to_tactic_exception(e, "invalid 'rewrite' tactic, invalid argument");
|
||||
}
|
||||
return mk_rewrite_tactic(args);
|
||||
return mk_rewrite_tactic(elab, args);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
|
@ -21,9 +21,6 @@ bool is_rewrite_step(expr const & e);
|
|||
/** \brief Create a rewrite tactic expression, where elems was created using \c mk_rewrite_* procedures. */
|
||||
expr mk_rewrite_tactic_expr(buffer<expr> const & elems);
|
||||
|
||||
/** \brief Create rewrite tactic that applies the given rewrite elements */
|
||||
tactic mk_rewrite_tactic(buffer<expr> const & elems);
|
||||
|
||||
void initialize_rewrite_tactic();
|
||||
void finalize_rewrite_tactic();
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue