feat(library/blast/forward/pattern): basic indexing for heuristic instantiation
This commit is contained in:
parent
30214af15c
commit
87c31acf8c
2 changed files with 132 additions and 22 deletions
|
@ -5,8 +5,11 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Author: Leonardo de Moura
|
||||
*/
|
||||
#include <string>
|
||||
#include "util/rb_multi_map.h"
|
||||
#include "kernel/find_fn.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "library/kernel_serializer.h"
|
||||
#include "library/tmp_type_context.h"
|
||||
#include "library/fun_info_manager.h"
|
||||
#include "library/annotation.h"
|
||||
|
@ -113,26 +116,64 @@ That is, the user should provide pattern-hints.
|
|||
*/
|
||||
static name * g_pattern_hint = nullptr;
|
||||
|
||||
expr mk_pattern_hint(expr const & e) { return mk_annotation(*g_pattern_hint, e); }
|
||||
bool is_pattern_hint(expr const & e) { return is_annotation(e, *g_pattern_hint); }
|
||||
expr const & get_pattern_hint_arg(expr const & e) { lean_assert(is_pattern_hint(e)); return get_annotation_arg(e); }
|
||||
bool has_pattern_hints(expr const & e) {
|
||||
return static_cast<bool>(find(e, [](expr const & e, unsigned) { return is_pattern_hint(e); }));
|
||||
}
|
||||
expr mk_pattern_hint(expr const & e) {
|
||||
if (has_pattern_hints(e))
|
||||
throw exception("invalid pattern hint, nested patterns hints are not allowed");
|
||||
return mk_annotation(*g_pattern_hint, e);
|
||||
}
|
||||
|
||||
static name * g_no_pattern_name = nullptr;
|
||||
static std::string * g_key = nullptr;
|
||||
static name * g_hi_name = nullptr;
|
||||
static std::string * g_key = nullptr;
|
||||
|
||||
struct no_pattern_config {
|
||||
typedef name_set state;
|
||||
typedef name entry;
|
||||
// "Poor man" union type
|
||||
struct hi_entry {
|
||||
optional<name> m_no_pattern;
|
||||
hi_lemma m_lemma;
|
||||
hi_entry() {}
|
||||
hi_entry(name const & n):m_no_pattern(n) {}
|
||||
hi_entry(hi_lemma const & l):m_lemma(l) {}
|
||||
};
|
||||
|
||||
struct hi_state {
|
||||
name_set m_no_patterns;
|
||||
hi_lemmas m_lemmas;
|
||||
};
|
||||
|
||||
serializer & operator<<(serializer & s, multi_pattern const & mp) {
|
||||
write_list(s, mp);
|
||||
return s;
|
||||
}
|
||||
|
||||
deserializer & operator>>(deserializer & d, multi_pattern & mp) {
|
||||
mp = read_list<expr>(d);
|
||||
return d;
|
||||
}
|
||||
|
||||
struct hi_config {
|
||||
typedef hi_entry entry;
|
||||
typedef hi_state state;
|
||||
|
||||
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
|
||||
s.insert(e);
|
||||
if (e.m_no_pattern) {
|
||||
s.m_no_patterns.insert(*e.m_no_pattern);
|
||||
} else {
|
||||
for (multi_pattern const & mp : e.m_lemma.m_multi_patterns) {
|
||||
for (expr const & p : mp) {
|
||||
lean_assert(is_app(p));
|
||||
lean_assert(is_constant(get_app_fn(p)));
|
||||
s.m_lemmas.insert(const_name(p), e.m_lemma);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static name const & get_class_name() {
|
||||
return *g_no_pattern_name;
|
||||
return *g_hi_name;
|
||||
}
|
||||
|
||||
static std::string const & get_serialization_key() {
|
||||
|
@ -140,33 +181,43 @@ struct no_pattern_config {
|
|||
}
|
||||
|
||||
static void write_entry(serializer & s, entry const & e) {
|
||||
s << e;
|
||||
s << e.m_no_pattern;
|
||||
if (!e.m_no_pattern) {
|
||||
hi_lemma const & l = e.m_lemma;
|
||||
s << l.m_num_uvars << l.m_num_mvars << l.m_priority << l.m_prop << l.m_proof;
|
||||
write_list(s, l.m_multi_patterns);
|
||||
}
|
||||
}
|
||||
|
||||
static entry read_entry(deserializer & d) {
|
||||
entry e;
|
||||
d >> e;
|
||||
d >> e.m_no_pattern;
|
||||
if (!e.m_no_pattern) {
|
||||
hi_lemma & l = e.m_lemma;
|
||||
d >> l.m_num_uvars >> l.m_num_mvars >> l.m_priority >> l.m_prop >> l.m_proof;
|
||||
l.m_multi_patterns = read_list<multi_pattern>(d);
|
||||
}
|
||||
return e;
|
||||
}
|
||||
|
||||
static optional<unsigned> get_fingerprint(entry const & e) {
|
||||
return some(e.hash());
|
||||
return e.m_no_pattern ? some(e.m_no_pattern->hash()) : some(e.m_lemma.m_prop.hash());
|
||||
}
|
||||
};
|
||||
|
||||
template class scoped_ext<no_pattern_config>;
|
||||
typedef scoped_ext<no_pattern_config> no_pattern_ext;
|
||||
template class scoped_ext<hi_config>;
|
||||
typedef scoped_ext<hi_config> hi_ext;
|
||||
|
||||
bool is_no_pattern(environment const & env, name const & n) {
|
||||
return no_pattern_ext::get_state(env).contains(n);
|
||||
return hi_ext::get_state(env).m_no_patterns.contains(n);
|
||||
}
|
||||
|
||||
environment add_no_pattern(environment const & env, name const & n, bool persistent) {
|
||||
return no_pattern_ext::add_entry(env, get_dummy_ios(), n, persistent);
|
||||
return hi_ext::add_entry(env, get_dummy_ios(), hi_entry(n), persistent);
|
||||
}
|
||||
|
||||
name_set const & get_no_patterns(environment const & env) {
|
||||
return no_pattern_ext::get_state(env);
|
||||
return hi_ext::get_state(env).m_no_patterns;
|
||||
}
|
||||
|
||||
typedef rb_tree<unsigned, unsigned_cmp> idx_metavar_set;
|
||||
|
@ -261,6 +312,30 @@ expr extract_trackable(tmp_type_context & ctx, expr const & type,
|
|||
return B;
|
||||
}
|
||||
|
||||
void collect_pattern_hints(expr const & e, buffer<expr> & hints) {
|
||||
for_each(e, [&](expr const & e, unsigned) {
|
||||
if (is_pattern_hint(e)) {
|
||||
hints.push_back(get_pattern_hint_arg(e));
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
hi_lemma mk_hi_lemma_core(tmp_type_context & ctx, fun_info_manager & fm, expr const & H, unsigned num_uvars) {
|
||||
// TODO(Leo):
|
||||
return hi_lemma();
|
||||
}
|
||||
|
||||
hi_lemma mk_hi_lemma(tmp_type_context & ctx, fun_info_manager & fm, expr const & H) {
|
||||
return mk_hi_lemma_core(ctx, fm, H, 0);
|
||||
}
|
||||
|
||||
environment add_hi_lemma(environment const & env, name const & c, unsigned priority, bool persistent) {
|
||||
// TODO(Leo):
|
||||
return env;
|
||||
}
|
||||
|
||||
/** pattern_le */
|
||||
struct pattern_le_fn {
|
||||
tmp_type_context & m_ctx;
|
||||
|
@ -404,21 +479,21 @@ struct collect_pattern_candidates_fn {
|
|||
typedef rb_tree<expr, expr_quick_cmp> candidates;
|
||||
|
||||
collect_pattern_candidates_fn(tmp_type_context & ctx, fun_info_manager & fm):
|
||||
m_ctx(ctx), m_fm(fm), m_no_patterns(no_pattern_ext::get_state(ctx.env())) {}
|
||||
m_ctx(ctx), m_fm(fm), m_no_patterns(get_no_patterns(ctx.env())) {}
|
||||
// TODO(Leo):
|
||||
};
|
||||
|
||||
void initialize_pattern() {
|
||||
g_no_pattern_name = new name("no_pattern");
|
||||
g_key = new std::string("no_pattern");
|
||||
no_pattern_ext::initialize();
|
||||
g_hi_name = new name("hi");
|
||||
g_key = new std::string("HI");
|
||||
hi_ext::initialize();
|
||||
g_pattern_hint = new name("pattern_hint");
|
||||
register_annotation(*g_pattern_hint);
|
||||
}
|
||||
|
||||
void finalize_pattern() {
|
||||
no_pattern_ext::finalize();
|
||||
delete g_no_pattern_name;
|
||||
hi_ext::finalize();
|
||||
delete g_hi_name;
|
||||
delete g_key;
|
||||
delete g_pattern_hint;
|
||||
}
|
||||
|
|
|
@ -5,7 +5,14 @@ Released under Apache 2.0 license as described in the file LICENSE.
|
|||
Author: Leonardo de Moura
|
||||
*/
|
||||
#pragma once
|
||||
#include "util/rb_multi_map.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "library/tmp_type_context.h"
|
||||
#include "library/fun_info_manager.h"
|
||||
|
||||
#ifndef LEAN_HI_LEMMA_DEFAULT_PRIORITY
|
||||
#define LEAN_HI_LEMMA_DEFAULT_PRIORITY 1000
|
||||
#endif
|
||||
|
||||
namespace lean {
|
||||
/** \brief Annotate \c e as a pattern hint */
|
||||
|
@ -25,6 +32,34 @@ bool is_no_pattern(environment const & env, name const & n);
|
|||
/** \brief Return the set of constants marked as no-patterns */
|
||||
name_set const & get_no_patterns(environment const & env);
|
||||
|
||||
typedef list<expr> multi_pattern;
|
||||
|
||||
/** Heuristic instantiation lemma */
|
||||
struct hi_lemma {
|
||||
unsigned m_num_uvars;
|
||||
unsigned m_num_mvars;
|
||||
unsigned m_priority;
|
||||
list<multi_pattern> m_multi_patterns;
|
||||
expr m_prop;
|
||||
expr m_proof;
|
||||
};
|
||||
|
||||
inline bool operator==(hi_lemma const & l1, hi_lemma const & l2) { return l1.m_prop == l2.m_prop; }
|
||||
inline bool operator!=(hi_lemma const & l1, hi_lemma const & l2) { return l1.m_prop != l2.m_prop; }
|
||||
|
||||
/** \brief Mapping c -> S, where c is a constant name and S is a set of hi_lemmas that contain
|
||||
a pattern where the head symbol is c. */
|
||||
typedef rb_multi_map<name, hi_lemma, name_quick_cmp> hi_lemmas;
|
||||
|
||||
/** \brief Add the given theorem as a heuristic instantiation lemma in the current environment. */
|
||||
environment add_hi_lemma(environment const & env, name const & c, unsigned priority, bool persistent);
|
||||
|
||||
/** \brief Retrieve the active set of heuristic instantiation lemmas. */
|
||||
hi_lemmas get_hi_lemma_index(environment const & env);
|
||||
|
||||
/** \brief Create a (local) heuristic instantiation lemma for \c H. */
|
||||
hi_lemma mk_hi_lemma(tmp_type_context & ctx, fun_info_manager & fm, expr const & H);
|
||||
|
||||
void initialize_pattern();
|
||||
void finalize_pattern();
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue