From 09818adf9018c24d0a5e9f3f2029a1853f6bb584 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 4 Feb 2015 13:51:32 -0800 Subject: [PATCH] feat(library/tactic/rewrite_tactic): elaborate rewrite rule using unifier --- src/library/tactic/rewrite_tactic.cpp | 145 +++++++++++++++++++++----- 1 file changed, 117 insertions(+), 28 deletions(-) diff --git a/src/library/tactic/rewrite_tactic.cpp b/src/library/tactic/rewrite_tactic.cpp index 16bc6d25c..5be24ec37 100644 --- a/src/library/tactic/rewrite_tactic.cpp +++ b/src/library/tactic/rewrite_tactic.cpp @@ -15,9 +15,13 @@ Author: Leonardo de Moura #include "library/util.h" #include "library/match.h" #include "library/projection.h" +#include "library/local_context.h" +#include "library/unifier.h" +#include "library/util.h" #include "library/generic_exception.h" #include "library/tactic/rewrite_tactic.h" #include "library/tactic/expr_to_tactic.h" +#include "library/tactic/class_instance_synth.h" // #define TRACE_MATCH_PLUGIN @@ -339,6 +343,7 @@ class rewrite_fn { type_checker_ptr m_tc; rewrite_match_plugin m_mplugin; goal m_g; + local_context m_ctx; substitution m_subst; expr m_expr_loc; // auxiliary expression used for error localization @@ -353,6 +358,13 @@ class rewrite_fn { throw_generic_exception(strm, m_expr_loc); } + void update_goal(goal const & g) { + m_g = g; + buffer hyps; + g.get_hyps(hyps); + m_ctx = local_context(to_list(hyps)); + } + expr mk_meta(expr const & type) { return m_g.mk_meta(m_ngen.next(), type); } @@ -473,10 +485,87 @@ class rewrite_fn { e = none_expr(); } - optional find_target(expr const & e, expr const & pattern) { - optional found; + pair mk_class_instance_elaborator(expr const & type) { + unifier_config cfg; + cfg.m_conservative = true; + bool use_local_instances = true; + bool is_strict = false; + return ::lean::mk_class_instance_elaborator(m_env, m_ios, m_ctx, m_ngen.next(), optional(), + m_ps.relax_main_opaque(), use_local_instances, is_strict, + some_expr(type), m_expr_loc.get_tag(), cfg, nullptr); + } + + // rule, new_t + typedef optional> unify_result; + + unify_result unify_target(expr const & t, expr const & pre_elem) { + try { + expr rule = get_rewrite_rule(pre_elem); + auto rcs = m_elab(m_g, m_ngen.mk_child(), rule, false); + rule = rcs.first; + buffer cs; + to_buffer(rcs.second, cs); + constraint_seq cs_seq; + expr rule_type = m_tc->whnf(m_tc->infer(rule, cs_seq), cs_seq); + while (is_pi(rule_type)) { + expr meta; + if (binding_info(rule_type).is_inst_implicit()) { + auto mc = mk_class_instance_elaborator(binding_domain(rule_type)); + meta = mc.first; + cs_seq += mc.second; + } else { + meta = mk_meta(binding_domain(rule_type)); + } + rule_type = m_tc->whnf(instantiate(binding_body(rule_type), meta), cs_seq); + rule = mk_app(rule, meta); + } + lean_assert(is_eq(rule_type)); + bool symm = get_rewrite_info(pre_elem).symm(); + expr src; + if (symm) + src = app_arg(rule_type); + else + src = app_arg(app_fn(rule_type)); + if (!m_tc->is_def_eq(t, src, justification(), cs_seq)) + return unify_result(); + cs_seq.linearize(cs); + unifier_config cfg; + cfg.m_conservative = false; + unify_result_seq rseq = unify(m_env, cs.size(), cs.data(), m_ngen.mk_child(), m_subst, cfg); + if (auto p = rseq.pull()) { + substitution new_subst = p->first.first; + constraints new_postponed = p->first.second; + if (new_postponed) + return unify_result(); // all constraints must be solved + rule = new_subst.instantiate_all(rule); + rule_type = new_subst.instantiate_all(rule_type); + if (has_expr_metavar_strict(rule) || has_expr_metavar_strict(rule_type)) + return unify_result(); // rule was not completely instantiate. + m_subst = new_subst; + expr lhs = app_arg(app_fn(rule_type)); + expr rhs = app_arg(rule_type); + if (symm) { + rule = mk_symm(*m_tc, rule); + return unify_result(rule, lhs); + } else { + return unify_result(rule, rhs); + } + } + } catch (exception&) {} + return unify_result(); + } + + // target, rule, new_target : represents the rewrite rule : target = new_target + typedef optional> find_result; + + // Search for \c pattern in \c e. If \c t is a match, then try to unify the type of the rule + // in the rewrite step \c pre_elem with \c t. + // When successful, this method returns the target \c t, the fully elaborated rule \c r, + // and the new value \c new_t (i.e., the expression that will replace \c t). + find_result find_target(expr const & e, expr const & pattern, expr const & pre_elem) { + find_result result; for_each(e, [&](expr const & t, unsigned) { - if (found) + if (result) return false; // stop search if (closed(t)) { lean_assert(std::all_of(m_esubst.begin(), m_esubst.end(), [&](optional const & e) { return !e; })); @@ -485,52 +574,52 @@ class rewrite_fn { if (assigned) reset_subst(); if (r) { - found = t; - return false; + if (auto p = unify_target(t, pre_elem)) { + result = std::make_tuple(t, p->first, p->second); + return false; + } } } return true; }); - return found; + return result; } - bool process_rewrite_hypothesis(expr const & hyp, expr const & elem, expr const & pattern) { + bool process_rewrite_hypothesis(expr const & hyp, expr const & elem, expr const & pre_elem, expr const & pattern) { // TODO(Leo) return false; } - bool process_rewrite_goal(expr const & elem, expr const & pattern) { + bool process_rewrite_goal(expr const & elem, expr const & pre_elem, expr const & pattern) { expr goal_type = m_g.get_type(); - if (auto it = find_target(goal_type, pattern)) { - regular(m_env, m_ios) << "found: " << *it << "\n"; + if (auto it = find_target(goal_type, pattern, pre_elem)) { + regular(m_env, m_ios) << "FOUND\n" << std::get<0>(*it) << "\n==\n" << std::get<1>(*it) << "\n==>\n" << std::get<2>(*it) << "\n"; // TODO(Leo) - return false; - } else { - return false; } + return false; } - bool process_rewrite_single_step(expr const & elem, expr const & pattern) { + bool process_rewrite_single_step(expr const & elem, expr const & pre_elem, expr const & pattern) { check_system("rewrite tactic"); rewrite_info const & info = get_rewrite_info(elem); location const & loc = info.get_location(); if (loc.is_goal_only()) - return process_rewrite_goal(elem, pattern); + return process_rewrite_goal(elem, pre_elem, pattern); bool progress = false; buffer hyps; m_g.get_hyps(hyps); for (expr const & h : hyps) { if (!loc.includes_hypothesis(local_pp_name(h))) continue; - if (process_rewrite_hypothesis(h, elem, pattern)) + if (process_rewrite_hypothesis(h, elem, pre_elem, pattern)) progress = true; } - if (loc.includes_goal() && process_rewrite_goal(elem, pattern)) + if (loc.includes_goal() && process_rewrite_goal(elem, pre_elem, pattern)) progress = true; return progress; } - bool process_rewrite_step(expr const & elem) { + bool process_rewrite_step(expr const & elem, expr const & pre_elem) { lean_assert(is_rewrite_step(elem)); expr pattern = get_pattern(elem); regular(m_env, m_ios) << "pattern: " << pattern << "\n"; @@ -538,31 +627,31 @@ class rewrite_fn { unsigned num; switch (info.get_multiplicity()) { case rewrite_info::Once: - return process_rewrite_single_step(elem, pattern); + return process_rewrite_single_step(elem, pre_elem, pattern); case rewrite_info::AtMostN: num = info.num(); for (unsigned i = 0; i < num; i++) { - if (!process_rewrite_single_step(elem, pattern)) + if (!process_rewrite_single_step(elem, pre_elem, pattern)) return true; } return true; case rewrite_info::ExactlyN: num = info.num(); for (unsigned i = 0; i < num; i++) { - if (!process_rewrite_single_step(elem, pattern)) + if (!process_rewrite_single_step(elem, pre_elem, pattern)) return false; } return true; case rewrite_info::ZeroOrMore: while (true) { - if (!process_rewrite_single_step(elem, pattern)) + if (!process_rewrite_single_step(elem, pre_elem, pattern)) return true; } case rewrite_info::OneOrMore: - if (!process_rewrite_single_step(elem, pattern)) + if (!process_rewrite_single_step(elem, pre_elem, pattern)) return false; while (true) { - if (!process_rewrite_single_step(elem, pattern)) + if (!process_rewrite_single_step(elem, pre_elem, pattern)) return true; } } @@ -571,11 +660,11 @@ class rewrite_fn { // Process the given rewrite element/step. This method destructively update // m_g, m_subst, m_ngen. It returns true if it succeeded and false otherwise. - bool process_step(expr const & elem) { + bool process_step(expr const & elem, expr const & pre_elem) { if (is_rewrite_unfold_step(elem)) { return process_unfold_step(elem); } else { - return process_rewrite_step(elem); + return process_rewrite_step(elem, pre_elem); } } @@ -586,7 +675,7 @@ public: m_mplugin(m_ios, *m_tc) { goals const & gs = m_ps.get_goals(); lean_assert(gs); - m_g = head(gs); + update_goal(head(gs)); m_subst = m_ps.get_subst(); } @@ -598,7 +687,7 @@ public: lean_assert(elems.size() == new_elems.size()); for (unsigned i = 0; i < new_elems.size(); i++) { flet set1(m_expr_loc, elems[i]); - if (!process_step(new_elems[i])) { + if (!process_step(new_elems[i], elems[i])) { return proof_state_seq(); } }