feat(library/blast/forward/ematch): ematching skeleton

This commit is contained in:
Leonardo de Moura 2015-11-29 06:40:19 -07:00
parent 001f8084a9
commit 7fa2b7cace
8 changed files with 339 additions and 1 deletions

View file

@ -1256,6 +1256,14 @@ expr congruence_closure::get_next(name const & R, expr const & e) const {
}
}
unsigned congruence_closure::get_mt(name const & R, expr const & e) const {
if (auto n = m_entries.find(eqc_key(R, e))) {
return n->m_mt;
} else {
return m_gmt;
}
}
void congruence_closure::freeze_partitions() {
m_froze_partitions = true;
entries new_entries;

View file

@ -216,6 +216,9 @@ public:
void inc_gmt() { m_gmt++; }
unsigned get_gmt() const { return m_gmt; }
unsigned get_mt(name const & R, expr const & e) const;
/** \brief dump for debugging purposes. */
void display() const;
void display_eqcs() const;

View file

@ -1 +1,2 @@
add_library(forward OBJECT init_module.cpp forward_extension.cpp qcf.cpp pattern.cpp)
add_library(forward OBJECT init_module.cpp forward_extension.cpp qcf.cpp pattern.cpp
ematch.cpp)

View file

@ -0,0 +1,300 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "library/constants.h"
#include "library/blast/blast.h"
#include "library/blast/trace.h"
#include "library/blast/congruence_closure.h"
#include "library/blast/forward/pattern.h"
namespace lean {
namespace blast {
/*
When a hypothesis hidx is activated:
1- Traverse its type and for each f-application.
If it is the first f-application found, and f is a constant then
retrieve lemmas which contain a multi-pattern starting with f.
2- If hypothesis is a proposition and a quantifier,
try to create a hi-lemma for it, and add it to
set of recently activated hi_lemmas
E-match round action
1- For each active hi-lemma L, and mulit-pattern P,
If L has been recently activated, then we ematch ignoring
gmt.
If L has been processed before, we try to ematch starting
at each each element of the multi-pattern.
We only consider the head f-applications that have a mt
equal to gmt
*/
typedef rb_tree<expr, expr_quick_cmp> expr_set;
typedef rb_tree<hi_lemma, hi_lemma_cmp> hi_lemma_set;
static unsigned g_ext_id = 0;
struct ematch_branch_extension : public branch_extension {
hi_lemma_set m_lemmas;
hi_lemma_set m_new_lemmas;
rb_map<expr, expr_set, expr_quick_cmp> m_apps;
name_set m_initialized;
ematch_branch_extension() {}
ematch_branch_extension(ematch_branch_extension const &) {}
void collect_apps(expr const & e) {
switch (e.kind()) {
case expr_kind::Var: case expr_kind::Sort:
case expr_kind::Constant: case expr_kind::Meta:
case expr_kind::Local: case expr_kind::Lambda:
break;
case expr_kind::Pi:
if (is_arrow(e) && is_prop(e)) {
collect_apps(binding_domain(e));
collect_apps(binding_body(e));
}
break;
case expr_kind::Macro:
for (unsigned i = 0; i < macro_num_args(e); i++)
collect_apps(macro_arg(e, i));
break;
case expr_kind::App: {
buffer<expr> args;
expr const & f = get_app_args(e, args);
if (is_constant(f) && !m_initialized.contains(const_name(f))) {
m_initialized.insert(const_name(f));
if (auto lemmas = get_hi_lemma_index(env()).find(const_name(f))) {
for (hi_lemma const & lemma : *lemmas) {
m_new_lemmas.insert(lemma);
}
}
}
if ((is_constant(f) && !is_no_pattern(env(), const_name(f))) ||
(is_local(f))) {
expr_set s;
if (auto old_s = m_apps.find(f))
s = *old_s;
s.insert(e);
m_apps.insert(f, s);
}
for (expr const & arg : args) {
collect_apps(arg);
}
break;
}}
}
void register_lemma(hypothesis const & h) {
if (is_pi(h.get_type()) && !is_arrow(h.get_type())) {
blast_tmp_type_context ctx;
try {
m_new_lemmas.insert(mk_hi_lemma(*ctx, h.get_self()));
} catch (exception &) {}
}
}
virtual ~ematch_branch_extension() {}
virtual branch_extension * clone() override { return new ematch_branch_extension(*this); }
virtual void initialized() override {}
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx) override {
collect_apps(h.get_type());
register_lemma(h);
}
virtual void hypothesis_deleted(hypothesis const &, hypothesis_idx) override {}
virtual void target_updated() override { collect_apps(curr_state().get_target()); }
};
void initialize_ematch() {
g_ext_id = register_branch_extension(new ematch_branch_extension());
}
void finalize_ematch() {}
struct ematch_fn {
ematch_branch_extension & m_ext;
blast_tmp_type_context m_ctx;
congruence_closure & m_cc;
enum frame_kind { DefEqOnly, Match, Continue };
typedef std::tuple<name, frame_kind, expr, expr> entry;
typedef list<entry> state;
typedef list<state> choice;
state m_state;
buffer<choice> m_choice_stack;
bool m_new_instances;
ematch_fn():
m_ext(static_cast<ematch_branch_extension&>(curr_state().get_extension(g_ext_id))),
m_cc(get_cc()),
m_new_instances(false) {
}
bool is_done() const {
return !m_state;
}
bool is_eqv(name const & R, expr const & p, expr const & t) {
if (!has_expr_metavar(p))
return m_cc.is_eqv(R, p, t) || m_ctx->is_def_eq(p, t);
else
return m_ctx->is_def_eq(p, t);
}
bool process_match(name const & R, expr const & p, expr const & t) {
if (!is_app(p))
return is_eqv(R, p, t);
buffer<expr> p_args;
expr const & fn = get_app_args(p, p_args);
if (m_ctx->is_mvar(fn))
return is_eqv(R, p, t);
buffer<expr> candidates;
expr it = t;
do {
if (m_cc.is_congr_root(R, t) && m_ctx->is_def_eq(get_app_fn(it), fn) &&
get_app_num_args(it) == p_args.size()) {
candidates.push_back(it);
}
it = m_cc.get_next(R, it);
} while (it != t);
if (candidates.empty())
return false;
optional<ext_congr_lemma> lemma = mk_ext_congr_lemma(R, fn, p_args.size());
if (!lemma)
return false;
buffer<state> new_states;
for (expr const & c : candidates) {
buffer<expr> c_args;
get_app_args(c, c_args);
lean_assert(c_args.size() == p_args.size());
state new_state = m_state;
auto const * r_names = &lemma->m_rel_names;
for (unsigned i = 0; i < p_args.size(); i++) {
lean_assert(*r_names);
if (auto Rc = head(*r_names)) {
new_state = cons(entry(*Rc, Match, p_args[i], c_args[i]), new_state);
} else {
new_state = cons(entry(get_eq_name(), DefEqOnly, p_args[i], c_args[i]), new_state);
}
r_names = &tail(*r_names);
}
new_states.push_back(new_state);
}
lean_assert(candidates.size() == new_states.size());
if (candidates.size() == 1) {
m_state = new_states[0];
return true;
} else {
m_state = new_states.back();
new_states.pop_back();
choice c = to_list(new_states);
m_choice_stack.push_back(c);
m_ctx->push();
return true;
}
}
bool process_continue(expr const &) {
// TODO(Leo):
return false;
}
bool process_next() {
lean_assert(!is_done());
name R; frame_kind kind; expr p, t;
std::tie(R, kind, p, t) = head(m_state);
m_state = tail(m_state);
switch (kind) {
case DefEqOnly:
return m_ctx->is_def_eq(p, t);
case Match:
return process_match(R, p, t);
case Continue:
return process_continue(p);
}
lean_unreachable();
}
bool match() {
// TODO(Leo)
return false;
}
void instantiate_lemma_using(hi_lemma const & lemma, buffer<expr> const & ps, bool filter) {
expr const & p0 = ps[0];
expr const & f = get_app_fn(p0);
name const & R = is_prop(p0) ? get_iff_name() : get_eq_name();
unsigned gmt = m_cc.get_gmt();
if (auto s = m_ext.m_apps.find(f)) {
s->for_each([&](expr const & t) {
if (m_cc.is_congr_root(R, t) && (!filter || m_cc.get_mt(R, t) == gmt)) {
m_ctx->set_next_uvar_idx(lemma.m_num_uvars);
m_ctx->set_next_mvar_idx(lemma.m_num_mvars);
state s;
unsigned i = ps.size();
while (i > 1) {
--i;
s = cons(entry(name(), Continue, ps[i], expr()), s);
}
s = cons(entry(R, Match, p0, t), s);
diagnostic(env(), ios()) << "ematch " << ppb(p0) << " =?= " << ppb(t) << "\n";
if (match()) {
// TODO(Leo): add instance
}
}
});
}
}
void instantiate_lemma_using(hi_lemma const & lemma, multi_pattern const & mp, bool filter) {
buffer<expr> ps;
to_buffer(mp, ps);
if (filter) {
for (unsigned i = 0; i < ps.size(); i++) {
std::swap(ps[0], ps[i]);
instantiate_lemma_using(lemma, ps, filter);
std::swap(ps[0], ps[i]);
}
} else {
instantiate_lemma_using(lemma, ps, filter);
}
}
void instantiate_lemma(hi_lemma const & lemma, bool filter) {
for (multi_pattern const & mp : lemma.m_multi_patterns) {
instantiate_lemma_using(lemma, mp, filter);
}
}
/* (Try to) instantiate lemmas in \c s. If \c filter is true, then use gmt optimization. */
void instantiate_lemmas(hi_lemma_set const & s, bool filter) {
s.for_each([&](hi_lemma const & l) {
instantiate_lemma(l, filter);
});
}
action_result operator()() {
instantiate_lemmas(m_ext.m_new_lemmas, false);
instantiate_lemmas(m_ext.m_lemmas, true);
m_ext.m_lemmas.merge(m_ext.m_new_lemmas);
m_ext.m_new_lemmas = hi_lemma_set();
m_cc.inc_gmt();
if (m_new_instances) {
return action_result::new_branch();
} else {
return action_result::failed();
}
}
};
action_result ematch_action() {
return ematch_fn()();
}
}}

View file

@ -0,0 +1,15 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "library/blast/action_result.h"
namespace lean {
namespace blast {
action_result ematch_action();
void initialize_ematch();
void finalize_ematch();
}}

View file

@ -6,6 +6,7 @@ Author: Daniel Selsam
#include "library/blast/forward/init_module.h"
#include "library/blast/forward/forward_extension.h"
#include "library/blast/forward/pattern.h"
#include "library/blast/forward/ematch.h"
namespace lean {
namespace blast {
@ -13,9 +14,11 @@ namespace blast {
void initialize_forward_module() {
initialize_forward_extension();
initialize_pattern();
initialize_ematch();
}
void finalize_forward_module() {
finalize_ematch();
finalize_pattern();
finalize_forward_extension();
}

View file

@ -656,6 +656,10 @@ hi_lemma const * get_hi_lemma(environment const & env, name const & c) {
return hi_ext::get_state(env).m_name_to_lemma.find(c);
}
hi_lemmas get_hi_lemma_index(environment const & env) {
return hi_ext::get_state(env).m_lemmas;
}
void initialize_pattern() {
g_hi_name = new name("hi");
g_key = new std::string("HI");

View file

@ -7,6 +7,7 @@ Author: Leonardo de Moura
#pragma once
#include "util/rb_multi_map.h"
#include "kernel/environment.h"
#include "library/expr_lt.h"
#include "library/tmp_type_context.h"
#ifndef LEAN_HI_LEMMA_DEFAULT_PRIORITY
@ -45,6 +46,9 @@ struct hi_lemma {
inline bool operator==(hi_lemma const & l1, hi_lemma const & l2) { return l1.m_prop == l2.m_prop; }
inline bool operator!=(hi_lemma const & l1, hi_lemma const & l2) { return l1.m_prop != l2.m_prop; }
struct hi_lemma_cmp {
int operator()(hi_lemma const & l1, hi_lemma const & l2) const { return expr_quick_cmp()(l1.m_prop, l2.m_prop); }
};
/** \brief Mapping c -> S, where c is a constant name and S is a set of hi_lemmas that contain
a pattern where the head symbol is c. */