feat(library/tactic/rewrite_tactic): elaborate rewrite rule using unifier
This commit is contained in:
parent
49323ab598
commit
09818adf90
1 changed files with 117 additions and 28 deletions
|
@ -15,9 +15,13 @@ Author: Leonardo de Moura
|
||||||
#include "library/util.h"
|
#include "library/util.h"
|
||||||
#include "library/match.h"
|
#include "library/match.h"
|
||||||
#include "library/projection.h"
|
#include "library/projection.h"
|
||||||
|
#include "library/local_context.h"
|
||||||
|
#include "library/unifier.h"
|
||||||
|
#include "library/util.h"
|
||||||
#include "library/generic_exception.h"
|
#include "library/generic_exception.h"
|
||||||
#include "library/tactic/rewrite_tactic.h"
|
#include "library/tactic/rewrite_tactic.h"
|
||||||
#include "library/tactic/expr_to_tactic.h"
|
#include "library/tactic/expr_to_tactic.h"
|
||||||
|
#include "library/tactic/class_instance_synth.h"
|
||||||
|
|
||||||
// #define TRACE_MATCH_PLUGIN
|
// #define TRACE_MATCH_PLUGIN
|
||||||
|
|
||||||
|
@ -339,6 +343,7 @@ class rewrite_fn {
|
||||||
type_checker_ptr m_tc;
|
type_checker_ptr m_tc;
|
||||||
rewrite_match_plugin m_mplugin;
|
rewrite_match_plugin m_mplugin;
|
||||||
goal m_g;
|
goal m_g;
|
||||||
|
local_context m_ctx;
|
||||||
substitution m_subst;
|
substitution m_subst;
|
||||||
expr m_expr_loc; // auxiliary expression used for error localization
|
expr m_expr_loc; // auxiliary expression used for error localization
|
||||||
|
|
||||||
|
@ -353,6 +358,13 @@ class rewrite_fn {
|
||||||
throw_generic_exception(strm, m_expr_loc);
|
throw_generic_exception(strm, m_expr_loc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void update_goal(goal const & g) {
|
||||||
|
m_g = g;
|
||||||
|
buffer<expr> hyps;
|
||||||
|
g.get_hyps(hyps);
|
||||||
|
m_ctx = local_context(to_list(hyps));
|
||||||
|
}
|
||||||
|
|
||||||
expr mk_meta(expr const & type) {
|
expr mk_meta(expr const & type) {
|
||||||
return m_g.mk_meta(m_ngen.next(), type);
|
return m_g.mk_meta(m_ngen.next(), type);
|
||||||
}
|
}
|
||||||
|
@ -473,10 +485,87 @@ class rewrite_fn {
|
||||||
e = none_expr();
|
e = none_expr();
|
||||||
}
|
}
|
||||||
|
|
||||||
optional<expr> find_target(expr const & e, expr const & pattern) {
|
pair<expr, constraint> mk_class_instance_elaborator(expr const & type) {
|
||||||
optional<expr> found;
|
unifier_config cfg;
|
||||||
|
cfg.m_conservative = true;
|
||||||
|
bool use_local_instances = true;
|
||||||
|
bool is_strict = false;
|
||||||
|
return ::lean::mk_class_instance_elaborator(m_env, m_ios, m_ctx, m_ngen.next(), optional<name>(),
|
||||||
|
m_ps.relax_main_opaque(), use_local_instances, is_strict,
|
||||||
|
some_expr(type), m_expr_loc.get_tag(), cfg, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// rule, new_t
|
||||||
|
typedef optional<pair<expr, expr>> unify_result;
|
||||||
|
|
||||||
|
unify_result unify_target(expr const & t, expr const & pre_elem) {
|
||||||
|
try {
|
||||||
|
expr rule = get_rewrite_rule(pre_elem);
|
||||||
|
auto rcs = m_elab(m_g, m_ngen.mk_child(), rule, false);
|
||||||
|
rule = rcs.first;
|
||||||
|
buffer<constraint> cs;
|
||||||
|
to_buffer(rcs.second, cs);
|
||||||
|
constraint_seq cs_seq;
|
||||||
|
expr rule_type = m_tc->whnf(m_tc->infer(rule, cs_seq), cs_seq);
|
||||||
|
while (is_pi(rule_type)) {
|
||||||
|
expr meta;
|
||||||
|
if (binding_info(rule_type).is_inst_implicit()) {
|
||||||
|
auto mc = mk_class_instance_elaborator(binding_domain(rule_type));
|
||||||
|
meta = mc.first;
|
||||||
|
cs_seq += mc.second;
|
||||||
|
} else {
|
||||||
|
meta = mk_meta(binding_domain(rule_type));
|
||||||
|
}
|
||||||
|
rule_type = m_tc->whnf(instantiate(binding_body(rule_type), meta), cs_seq);
|
||||||
|
rule = mk_app(rule, meta);
|
||||||
|
}
|
||||||
|
lean_assert(is_eq(rule_type));
|
||||||
|
bool symm = get_rewrite_info(pre_elem).symm();
|
||||||
|
expr src;
|
||||||
|
if (symm)
|
||||||
|
src = app_arg(rule_type);
|
||||||
|
else
|
||||||
|
src = app_arg(app_fn(rule_type));
|
||||||
|
if (!m_tc->is_def_eq(t, src, justification(), cs_seq))
|
||||||
|
return unify_result();
|
||||||
|
cs_seq.linearize(cs);
|
||||||
|
unifier_config cfg;
|
||||||
|
cfg.m_conservative = false;
|
||||||
|
unify_result_seq rseq = unify(m_env, cs.size(), cs.data(), m_ngen.mk_child(), m_subst, cfg);
|
||||||
|
if (auto p = rseq.pull()) {
|
||||||
|
substitution new_subst = p->first.first;
|
||||||
|
constraints new_postponed = p->first.second;
|
||||||
|
if (new_postponed)
|
||||||
|
return unify_result(); // all constraints must be solved
|
||||||
|
rule = new_subst.instantiate_all(rule);
|
||||||
|
rule_type = new_subst.instantiate_all(rule_type);
|
||||||
|
if (has_expr_metavar_strict(rule) || has_expr_metavar_strict(rule_type))
|
||||||
|
return unify_result(); // rule was not completely instantiate.
|
||||||
|
m_subst = new_subst;
|
||||||
|
expr lhs = app_arg(app_fn(rule_type));
|
||||||
|
expr rhs = app_arg(rule_type);
|
||||||
|
if (symm) {
|
||||||
|
rule = mk_symm(*m_tc, rule);
|
||||||
|
return unify_result(rule, lhs);
|
||||||
|
} else {
|
||||||
|
return unify_result(rule, rhs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (exception&) {}
|
||||||
|
return unify_result();
|
||||||
|
}
|
||||||
|
|
||||||
|
// target, rule, new_target : represents the rewrite rule : target = new_target
|
||||||
|
typedef optional<std::tuple<expr, expr, expr>> find_result;
|
||||||
|
|
||||||
|
// Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule
|
||||||
|
// in the rewrite step \c pre_elem with \c t.
|
||||||
|
// When successful, this method returns the target \c t, the fully elaborated rule \c r,
|
||||||
|
// and the new value \c new_t (i.e., the expression that will replace \c t).
|
||||||
|
find_result find_target(expr const & e, expr const & pattern, expr const & pre_elem) {
|
||||||
|
find_result result;
|
||||||
for_each(e, [&](expr const & t, unsigned) {
|
for_each(e, [&](expr const & t, unsigned) {
|
||||||
if (found)
|
if (result)
|
||||||
return false; // stop search
|
return false; // stop search
|
||||||
if (closed(t)) {
|
if (closed(t)) {
|
||||||
lean_assert(std::all_of(m_esubst.begin(), m_esubst.end(), [&](optional<expr> const & e) { return !e; }));
|
lean_assert(std::all_of(m_esubst.begin(), m_esubst.end(), [&](optional<expr> const & e) { return !e; }));
|
||||||
|
@ -485,52 +574,52 @@ class rewrite_fn {
|
||||||
if (assigned)
|
if (assigned)
|
||||||
reset_subst();
|
reset_subst();
|
||||||
if (r) {
|
if (r) {
|
||||||
found = t;
|
if (auto p = unify_target(t, pre_elem)) {
|
||||||
return false;
|
result = std::make_tuple(t, p->first, p->second);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
});
|
});
|
||||||
return found;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool process_rewrite_hypothesis(expr const & hyp, expr const & elem, expr const & pattern) {
|
bool process_rewrite_hypothesis(expr const & hyp, expr const & elem, expr const & pre_elem, expr const & pattern) {
|
||||||
// TODO(Leo)
|
// TODO(Leo)
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool process_rewrite_goal(expr const & elem, expr const & pattern) {
|
bool process_rewrite_goal(expr const & elem, expr const & pre_elem, expr const & pattern) {
|
||||||
expr goal_type = m_g.get_type();
|
expr goal_type = m_g.get_type();
|
||||||
if (auto it = find_target(goal_type, pattern)) {
|
if (auto it = find_target(goal_type, pattern, pre_elem)) {
|
||||||
regular(m_env, m_ios) << "found: " << *it << "\n";
|
regular(m_env, m_ios) << "FOUND\n" << std::get<0>(*it) << "\n==\n" << std::get<1>(*it) << "\n==>\n" << std::get<2>(*it) << "\n";
|
||||||
// TODO(Leo)
|
// TODO(Leo)
|
||||||
return false;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool process_rewrite_single_step(expr const & elem, expr const & pattern) {
|
bool process_rewrite_single_step(expr const & elem, expr const & pre_elem, expr const & pattern) {
|
||||||
check_system("rewrite tactic");
|
check_system("rewrite tactic");
|
||||||
rewrite_info const & info = get_rewrite_info(elem);
|
rewrite_info const & info = get_rewrite_info(elem);
|
||||||
location const & loc = info.get_location();
|
location const & loc = info.get_location();
|
||||||
if (loc.is_goal_only())
|
if (loc.is_goal_only())
|
||||||
return process_rewrite_goal(elem, pattern);
|
return process_rewrite_goal(elem, pre_elem, pattern);
|
||||||
bool progress = false;
|
bool progress = false;
|
||||||
buffer<expr> hyps;
|
buffer<expr> hyps;
|
||||||
m_g.get_hyps(hyps);
|
m_g.get_hyps(hyps);
|
||||||
for (expr const & h : hyps) {
|
for (expr const & h : hyps) {
|
||||||
if (!loc.includes_hypothesis(local_pp_name(h)))
|
if (!loc.includes_hypothesis(local_pp_name(h)))
|
||||||
continue;
|
continue;
|
||||||
if (process_rewrite_hypothesis(h, elem, pattern))
|
if (process_rewrite_hypothesis(h, elem, pre_elem, pattern))
|
||||||
progress = true;
|
progress = true;
|
||||||
}
|
}
|
||||||
if (loc.includes_goal() && process_rewrite_goal(elem, pattern))
|
if (loc.includes_goal() && process_rewrite_goal(elem, pre_elem, pattern))
|
||||||
progress = true;
|
progress = true;
|
||||||
return progress;
|
return progress;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool process_rewrite_step(expr const & elem) {
|
bool process_rewrite_step(expr const & elem, expr const & pre_elem) {
|
||||||
lean_assert(is_rewrite_step(elem));
|
lean_assert(is_rewrite_step(elem));
|
||||||
expr pattern = get_pattern(elem);
|
expr pattern = get_pattern(elem);
|
||||||
regular(m_env, m_ios) << "pattern: " << pattern << "\n";
|
regular(m_env, m_ios) << "pattern: " << pattern << "\n";
|
||||||
|
@ -538,31 +627,31 @@ class rewrite_fn {
|
||||||
unsigned num;
|
unsigned num;
|
||||||
switch (info.get_multiplicity()) {
|
switch (info.get_multiplicity()) {
|
||||||
case rewrite_info::Once:
|
case rewrite_info::Once:
|
||||||
return process_rewrite_single_step(elem, pattern);
|
return process_rewrite_single_step(elem, pre_elem, pattern);
|
||||||
case rewrite_info::AtMostN:
|
case rewrite_info::AtMostN:
|
||||||
num = info.num();
|
num = info.num();
|
||||||
for (unsigned i = 0; i < num; i++) {
|
for (unsigned i = 0; i < num; i++) {
|
||||||
if (!process_rewrite_single_step(elem, pattern))
|
if (!process_rewrite_single_step(elem, pre_elem, pattern))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
case rewrite_info::ExactlyN:
|
case rewrite_info::ExactlyN:
|
||||||
num = info.num();
|
num = info.num();
|
||||||
for (unsigned i = 0; i < num; i++) {
|
for (unsigned i = 0; i < num; i++) {
|
||||||
if (!process_rewrite_single_step(elem, pattern))
|
if (!process_rewrite_single_step(elem, pre_elem, pattern))
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
case rewrite_info::ZeroOrMore:
|
case rewrite_info::ZeroOrMore:
|
||||||
while (true) {
|
while (true) {
|
||||||
if (!process_rewrite_single_step(elem, pattern))
|
if (!process_rewrite_single_step(elem, pre_elem, pattern))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
case rewrite_info::OneOrMore:
|
case rewrite_info::OneOrMore:
|
||||||
if (!process_rewrite_single_step(elem, pattern))
|
if (!process_rewrite_single_step(elem, pre_elem, pattern))
|
||||||
return false;
|
return false;
|
||||||
while (true) {
|
while (true) {
|
||||||
if (!process_rewrite_single_step(elem, pattern))
|
if (!process_rewrite_single_step(elem, pre_elem, pattern))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -571,11 +660,11 @@ class rewrite_fn {
|
||||||
|
|
||||||
// Process the given rewrite element/step. This method destructively update
|
// Process the given rewrite element/step. This method destructively update
|
||||||
// m_g, m_subst, m_ngen. It returns true if it succeeded and false otherwise.
|
// m_g, m_subst, m_ngen. It returns true if it succeeded and false otherwise.
|
||||||
bool process_step(expr const & elem) {
|
bool process_step(expr const & elem, expr const & pre_elem) {
|
||||||
if (is_rewrite_unfold_step(elem)) {
|
if (is_rewrite_unfold_step(elem)) {
|
||||||
return process_unfold_step(elem);
|
return process_unfold_step(elem);
|
||||||
} else {
|
} else {
|
||||||
return process_rewrite_step(elem);
|
return process_rewrite_step(elem, pre_elem);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -586,7 +675,7 @@ public:
|
||||||
m_mplugin(m_ios, *m_tc) {
|
m_mplugin(m_ios, *m_tc) {
|
||||||
goals const & gs = m_ps.get_goals();
|
goals const & gs = m_ps.get_goals();
|
||||||
lean_assert(gs);
|
lean_assert(gs);
|
||||||
m_g = head(gs);
|
update_goal(head(gs));
|
||||||
m_subst = m_ps.get_subst();
|
m_subst = m_ps.get_subst();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -598,7 +687,7 @@ public:
|
||||||
lean_assert(elems.size() == new_elems.size());
|
lean_assert(elems.size() == new_elems.size());
|
||||||
for (unsigned i = 0; i < new_elems.size(); i++) {
|
for (unsigned i = 0; i < new_elems.size(); i++) {
|
||||||
flet<expr> set1(m_expr_loc, elems[i]);
|
flet<expr> set1(m_expr_loc, elems[i]);
|
||||||
if (!process_step(new_elems[i])) {
|
if (!process_step(new_elems[i], elems[i])) {
|
||||||
return proof_state_seq();
|
return proof_state_seq();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue