refactor(library/blast,frontends/lean): forward pattern index

This commit is contained in:
Leonardo de Moura 2015-12-02 19:28:45 -07:00
parent 562d7b3e4a
commit 950f356d9a
14 changed files with 222 additions and 171 deletions

View file

@ -50,6 +50,7 @@ Author: Leonardo de Moura
#include "library/blast/simplifier/simplifier.h"
#include "library/blast/backward/backward_rule_set.h"
#include "library/blast/forward/pattern.h"
#include "library/blast/forward/forward_lemma_set.h"
#include "compiler/preprocess_rec.h"
#include "frontends/lean/util.h"
#include "frontends/lean/parser.h"
@ -231,20 +232,26 @@ static void print_metaclasses(parser const & p) {
p.regular_stream() << "[" << n << "]" << endl;
}
static void print_patterns(io_state_stream const & out, environment const & env, name const & n) {
if (auto lemma = get_hi_lemma(env, n)) {
if (lemma->m_multi_patterns) {
out << "(multi-)patterns:\n";
for (multi_pattern const & mp : lemma->m_multi_patterns) {
out << "{";
bool first = true;
for (expr const & p : mp) {
if (first) first = false; else out << ", ";
out << p;
static void print_patterns(parser const & p, name const & n) {
if (is_forward_lemma(p.env(), n)) {
blast::scope_debug scope(p.env(), p.ios());
try {
// we regenerate the patterns to make sure they reflect the current set of reducible constants
auto lemma = blast::mk_hi_lemma(n, LEAN_FORWARD_LEMMA_DEFAULT_PRIORITY);
if (lemma.m_multi_patterns) {
io_state_stream out = p.regular_stream();
out << "(multi-)patterns:\n";
for (multi_pattern const & mp : lemma.m_multi_patterns) {
out << "{";
bool first = true;
for (expr const & p : mp) {
if (first) first = false; else out << ", ";
out << p;
}
out << "}\n";
}
out << "}\n";
}
}
} catch (exception &) {}
}
}
@ -287,7 +294,7 @@ static void print_attributes(parser const & p, name const & n) {
out << " [backward]";
if (is_no_pattern(env, n))
out << " [no_pattern]";
if (get_hi_lemma(env, n))
if (is_forward_lemma(env, n))
out << " [forward]";
switch (get_reducible_status(env, n)) {
case reducible_status::Reducible: out << " [reducible]"; break;
@ -424,7 +431,7 @@ bool print_id_info(parser const & p, name const & id, bool show_value, pos_info
if (show_value)
print_definition(p, c, pos);
}
print_patterns(out, env, c);
print_patterns(p, c);
}
return true;
} catch (exception & ex) {}

View file

@ -14,6 +14,7 @@ Author: Leonardo de Moura
#include "library/blast/simplifier/simp_rule_set.h"
#include "library/blast/backward/backward_rule_set.h"
#include "library/blast/forward/pattern.h"
#include "library/blast/forward/forward_lemma_set.h"
#include "frontends/lean/decl_attributes.h"
#include "frontends/lean/parser.h"
#include "frontends/lean/tokens.h"
@ -241,9 +242,9 @@ environment decl_attributes::apply(environment env, io_state const & ios, name c
}
if (forward) {
if (m_priority)
env = add_hi_lemma(env, ios.get_options(), d, *m_priority, m_persistent);
env = add_forward_lemma(env, d, *m_priority, m_persistent);
else
env = add_hi_lemma(env, ios.get_options(), d, LEAN_HI_LEMMA_DEFAULT_PRIORITY, m_persistent);
env = add_forward_lemma(env, d, LEAN_FORWARD_LEMMA_DEFAULT_PRIORITY, m_persistent);
}
if (m_no_pattern) {
env = add_no_pattern(env, d, m_persistent);

View file

@ -1,2 +1,2 @@
add_library(forward OBJECT init_module.cpp forward_extension.cpp qcf.cpp pattern.cpp
ematch.cpp)
ematch.cpp forward_lemma_set.cpp)

View file

@ -13,6 +13,7 @@ Author: Leonardo de Moura
#include "library/blast/options.h"
#include "library/blast/congruence_closure.h"
#include "library/blast/forward/pattern.h"
#include "library/blast/forward/forward_lemma_set.h"
namespace lean {
namespace blast {
@ -45,7 +46,6 @@ struct ematch_branch_extension : public branch_extension {
hi_lemma_set m_lemmas;
hi_lemma_set m_new_lemmas;
rb_map<head_index, expr_set, head_index::cmp> m_apps;
name_set m_initialized;
expr_set m_instances;
ematch_branch_extension() {}
@ -70,14 +70,6 @@ struct ematch_branch_extension : public branch_extension {
case expr_kind::App: {
buffer<expr> args;
expr const & f = get_app_args(e, args);
if (is_constant(f) && !m_initialized.contains(const_name(f))) {
m_initialized.insert(const_name(f));
if (auto lemmas = get_hi_lemma_index(env()).find(const_name(f))) {
for (hi_lemma const & lemma : *lemmas) {
m_new_lemmas.insert(lemma);
}
}
}
if ((is_constant(f) && !is_no_pattern(env(), const_name(f))) ||
(is_local(f))) {
expr_set s;
@ -95,16 +87,22 @@ struct ematch_branch_extension : public branch_extension {
void register_lemma(hypothesis const & h) {
if (is_pi(h.get_type()) && !is_arrow(h.get_type())) {
blast_tmp_type_context ctx;
try {
m_new_lemmas.insert(mk_hi_lemma(*ctx, h.get_self()));
m_new_lemmas.insert(mk_hi_lemma(h.get_self()));
} catch (exception &) {}
}
}
virtual ~ematch_branch_extension() {}
virtual branch_extension * clone() override { return new ematch_branch_extension(*this); }
virtual void initialized() override {}
virtual void initialized() override {
forward_lemma_set s = get_forward_lemma_set(env());
s.for_each([&](name const & n, unsigned prio) {
try {
m_new_lemmas.insert(mk_hi_lemma(n, prio));
} catch (exception &) {}
});
}
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx) override {
collect_apps(h.get_type());
register_lemma(h);

View file

@ -0,0 +1,78 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "library/scoped_ext.h"
#include "library/blast/forward/forward_lemma_set.h"
namespace lean {
static name * g_name = nullptr;
static std::string * g_key = nullptr;
struct forward_lemma {
name m_name;
unsigned m_priority;
forward_lemma() {}
forward_lemma(name const & n, unsigned p):m_name(n), m_priority(p) {}
};
struct forward_lemma_set_config {
typedef forward_lemma entry;
typedef forward_lemma_set state;
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
s.insert(e.m_name, e.m_priority);
}
static name const & get_class_name() {
return *g_name;
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << e.m_name << e.m_priority;
}
static entry read_entry(deserializer & d) {
name n; unsigned p;
d >> n >> p;
return entry(n, p);
}
static optional<unsigned> get_fingerprint(entry const & e) {
return some(hash(e.m_name.hash(), e.m_priority));
}
};
template class scoped_ext<forward_lemma_set_config>;
typedef scoped_ext<forward_lemma_set_config> forward_lemma_set_ext;
environment add_forward_lemma(environment const & env, name const & n, unsigned priority, bool persistent) {
return forward_lemma_set_ext::add_entry(env, get_dummy_ios(), forward_lemma(n, priority), persistent);
}
bool is_forward_lemma(environment const & env, name const & n) {
return forward_lemma_set_ext::get_state(env).contains(n);
}
forward_lemma_set get_forward_lemma_set(environment const & env) {
return forward_lemma_set_ext::get_state(env);
}
void initialize_forward_lemma_set() {
g_name = new name("forward");
g_key = new std::string("FWD");
forward_lemma_set_ext::initialize();
}
void finalize_forward_lemma_set() {
forward_lemma_set_ext::finalize();
delete g_name;
delete g_key;
}
}

View file

@ -0,0 +1,25 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "util/rb_tree.h"
#include "kernel/expr.h"
#ifndef LEAN_FORWARD_LEMMA_DEFAULT_PRIORITY
#define LEAN_FORWARD_LEMMA_DEFAULT_PRIORITY 1000
#endif
namespace lean {
/** \brief The forward lemma set is actually a mapping from lemma name to priority */
typedef rb_map<name, unsigned, name_quick_cmp> forward_lemma_set;
environment add_forward_lemma(environment const & env, name const & n, unsigned priority, bool persistent);
bool is_forward_lemma(environment const & env, name const & n);
forward_lemma_set get_forward_lemma_set(environment const & env);
void initialize_forward_lemma_set();
void finalize_forward_lemma_set();
}

View file

@ -5,6 +5,7 @@ Author: Daniel Selsam
*/
#include "library/blast/forward/init_module.h"
#include "library/blast/forward/forward_extension.h"
#include "library/blast/forward/forward_lemma_set.h"
#include "library/blast/forward/pattern.h"
#include "library/blast/forward/ematch.h"
@ -14,11 +15,13 @@ namespace blast {
void initialize_forward_module() {
initialize_forward_extension();
initialize_pattern();
initialize_forward_lemma_set();
initialize_ematch();
}
void finalize_forward_module() {
finalize_ematch();
finalize_forward_lemma_set();
finalize_pattern();
finalize_forward_extension();
}

View file

@ -6,7 +6,6 @@ Author: Leonardo de Moura
*/
#include <string>
#include "util/sstream.h"
#include "util/sexpr/option_declarations.h"
#include "library/expr_lt.h"
#include "kernel/find_fn.h"
#include "kernel/for_each_fn.h"
@ -21,19 +20,11 @@ Author: Leonardo de Moura
#include "library/scoped_ext.h"
#include "library/idx_metavar.h"
#include "library/blast/options.h"
#include "library/blast/blast.h"
#include "library/blast/forward/pattern.h"
#ifndef LEAN_DEFAULT_PATTERN_MAX_STEPS
#define LEAN_DEFAULT_PATTERN_MAX_STEPS 1024
#endif
#include "library/blast/forward/forward_lemma_set.h"
namespace lean {
static name * g_pattern_max_steps = nullptr;
unsigned get_pattern_max_steps(options const & o) {
return o.get_unsigned(*g_pattern_max_steps, LEAN_DEFAULT_PATTERN_MAX_STEPS);
}
/*
Step 1: Selecting which variables we should track.
@ -143,66 +134,19 @@ expr mk_pattern_hint(expr const & e) {
return mk_annotation(*g_pattern_hint, e);
}
static name * g_hi_name = nullptr;
static name * g_name = nullptr;
static std::string * g_key = nullptr;
// "Poor man" union type
struct hi_entry {
optional<name> m_no_pattern;
hi_lemma m_lemma;
hi_entry() {}
hi_entry(name const & n):m_no_pattern(n) {}
hi_entry(hi_lemma const & l):m_lemma(l) {}
};
struct hi_state {
name_set m_no_patterns;
name_map<hi_lemma> m_name_to_lemma;
hi_lemmas m_lemmas;
};
serializer & operator<<(serializer & s, multi_pattern const & mp) {
write_list(s, mp);
return s;
}
deserializer & operator>>(deserializer & d, multi_pattern & mp) {
mp = read_list<expr>(d);
return d;
}
static optional<name> get_hi_lemma_name(expr const & H) {
if (is_lambda(H))
return get_hi_lemma_name(binding_body(H));
expr const & f = get_app_fn(H);
if (is_constant(f))
return optional<name>(const_name(f));
else
return optional<name>();
}
struct hi_config {
typedef hi_entry entry;
typedef hi_state state;
struct no_pattern_config {
typedef name entry;
typedef name_set state;
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
if (e.m_no_pattern) {
s.m_no_patterns.insert(*e.m_no_pattern);
} else {
if (auto n = get_hi_lemma_name(e.m_lemma.m_proof))
s.m_name_to_lemma.insert(*n, e.m_lemma);
for (multi_pattern const & mp : e.m_lemma.m_multi_patterns) {
for (expr const & p : mp) {
lean_assert(is_app(p));
lean_assert(is_constant(get_app_fn(p)));
s.m_lemmas.insert(const_name(get_app_fn(p)), e.m_lemma);
}
}
}
s.insert(e);
}
static name const & get_class_name() {
return *g_hi_name;
return *g_name;
}
static std::string const & get_serialization_key() {
@ -210,49 +154,36 @@ struct hi_config {
}
static void write_entry(serializer & s, entry const & e) {
s << e.m_no_pattern;
if (!e.m_no_pattern) {
hi_lemma const & l = e.m_lemma;
s << l.m_num_uvars << l.m_num_mvars << l.m_priority << l.m_prop << l.m_proof;
write_list(s, l.m_mvars);
write_list(s, l.m_is_inst_implicit);
write_list(s, l.m_multi_patterns);
}
s << e;
}
static entry read_entry(deserializer & d) {
entry e;
d >> e.m_no_pattern;
if (!e.m_no_pattern) {
hi_lemma & l = e.m_lemma;
d >> l.m_num_uvars >> l.m_num_mvars >> l.m_priority >> l.m_prop >> l.m_proof;
l.m_mvars = read_list<expr>(d);
l.m_is_inst_implicit = read_list<bool>(d);
l.m_multi_patterns = read_list<multi_pattern>(d);
}
d >> e;
return e;
}
static optional<unsigned> get_fingerprint(entry const & e) {
return e.m_no_pattern ? some(e.m_no_pattern->hash()) : some(e.m_lemma.m_prop.hash());
return some(e.hash());
}
};
template class scoped_ext<hi_config>;
typedef scoped_ext<hi_config> hi_ext;
template class scoped_ext<no_pattern_config>;
typedef scoped_ext<no_pattern_config> no_pattern_ext;
bool is_no_pattern(environment const & env, name const & n) {
return hi_ext::get_state(env).m_no_patterns.contains(n);
return no_pattern_ext::get_state(env).contains(n);
}
environment add_no_pattern(environment const & env, name const & n, bool persistent) {
return hi_ext::add_entry(env, get_dummy_ios(), hi_entry(n), persistent);
return no_pattern_ext::add_entry(env, get_dummy_ios(), n, persistent);
}
name_set const & get_no_patterns(environment const & env) {
return hi_ext::get_state(env).m_no_patterns;
return no_pattern_ext::get_state(env);
}
namespace blast {
typedef rb_tree<unsigned, unsigned_cmp> idx_metavar_set;
static bool is_higher_order(tmp_type_context & ctx, expr const & e) {
@ -363,7 +294,7 @@ struct mk_hi_lemma_fn {
mk_hi_lemma_fn(tmp_type_context & ctx, expr const & H,
unsigned num_uvars, unsigned prio, unsigned max_steps):
m_ctx(ctx), m_no_patterns(hi_ext::get_state(ctx.env()).m_no_patterns),
m_ctx(ctx), m_no_patterns(no_pattern_ext::get_state(ctx.env())),
m_H(H), m_num_uvars(num_uvars), m_priority(prio), m_max_steps(max_steps) {}
struct candidate {
@ -599,7 +530,12 @@ struct mk_hi_lemma_fn {
}
hi_lemma operator()() {
expr H_type = m_ctx.infer(m_H);
expr H_type;
{
// preserve pattern hints
scope_unfold_macro_pred scope1([](expr const & e) { return !is_pattern_hint(e); });
H_type = normalize(m_ctx.infer(m_H));
}
buffer<bool> inst_implicit_flags;
expr B = extract_trackable(m_ctx, H_type, m_mvars, inst_implicit_flags, m_trackable, m_residue);
lean_assert(m_mvars.size() == inst_implicit_flags.size());
@ -652,52 +588,37 @@ hi_lemma mk_hi_lemma_core(tmp_type_context & ctx, expr const & H, unsigned num_u
return mk_hi_lemma_fn(ctx, H, num_uvars, priority, max_steps)();
}
namespace blast {
hi_lemma mk_hi_lemma(tmp_type_context & ctx, expr const & H) {
hi_lemma mk_hi_lemma(expr const & H) {
blast_tmp_type_context ctx;
unsigned max_steps = get_config().m_pattern_max_steps;
ctx.clear();
return mk_hi_lemma_core(ctx, H, 0, LEAN_HI_LEMMA_DEFAULT_PRIORITY, max_steps);
}}
return mk_hi_lemma_core(*ctx, H, 0, LEAN_FORWARD_LEMMA_DEFAULT_PRIORITY, max_steps);
}
environment add_hi_lemma(environment const & env, options const & o, name const & c, unsigned priority, bool persistent) {
tmp_type_context ctx(env, get_dummy_ios());
declaration const & d = env.get(c);
hi_lemma mk_hi_lemma(name const & c, unsigned priority) {
blast_tmp_type_context ctx;
unsigned max_steps = get_config().m_pattern_max_steps;
declaration const & d = env().get(c);
buffer<level> us;
unsigned num_us = d.get_num_univ_params();
for (unsigned i = 0; i < num_us; i++)
us.push_back(ctx.mk_uvar());
expr H = mk_constant(c, to_list(us));
unsigned max_steps = get_pattern_max_steps(o);
return hi_ext::add_entry(env, get_dummy_ios(), hi_entry(mk_hi_lemma_core(ctx, H, num_us, priority, max_steps)),
persistent);
us.push_back(ctx->mk_uvar());
expr H = mk_constant(c, to_list(us));
return mk_hi_lemma_core(*ctx, H, num_us, priority, max_steps);
}
hi_lemma const * get_hi_lemma(environment const & env, name const & c) {
return hi_ext::get_state(env).m_name_to_lemma.find(c);
}
hi_lemmas get_hi_lemma_index(environment const & env) {
return hi_ext::get_state(env).m_lemmas;
}
void initialize_pattern() {
g_hi_name = new name("hi");
g_key = new std::string("HI");
hi_ext::initialize();
g_name = new name("no_pattern");
g_key = new std::string("NOPAT");
no_pattern_ext::initialize();
g_pattern_hint = new name("pattern_hint");
register_annotation(*g_pattern_hint);
g_pattern_max_steps = new name{"pattern", "max_steps"};
register_unsigned_option(*g_pattern_max_steps, LEAN_DEFAULT_PATTERN_MAX_STEPS,
"(pattern) max number of steps performed by pattern inference procedure, "
"we have this threshold because in the worst case this procedure may take "
"an exponetial number of steps");
}
void finalize_pattern() {
hi_ext::finalize();
delete g_hi_name;
no_pattern_ext::finalize();
delete g_name;
delete g_key;
delete g_pattern_hint;
delete g_pattern_max_steps;
}
}

View file

@ -10,10 +10,6 @@ Author: Leonardo de Moura
#include "library/expr_lt.h"
#include "library/tmp_type_context.h"
#ifndef LEAN_HI_LEMMA_DEFAULT_PRIORITY
#define LEAN_HI_LEMMA_DEFAULT_PRIORITY 1000
#endif
namespace lean {
/** \brief Annotate \c e as a pattern hint */
expr mk_pattern_hint(expr const & e);
@ -52,27 +48,11 @@ struct hi_lemma_cmp {
int operator()(hi_lemma const & l1, hi_lemma const & l2) const { return expr_quick_cmp()(l1.m_prop, l2.m_prop); }
};
/** \brief Mapping c -> S, where c is a constant name and S is a set of hi_lemmas that contain
a pattern where the head symbol is c. */
typedef rb_multi_map<name, hi_lemma, name_quick_cmp> hi_lemmas;
/** \brief Add the given theorem as a heuristic instantiation lemma in the current environment. */
environment add_hi_lemma(environment const & env, options const & o, name const & c, unsigned priority, bool persistent);
/** \brief Return the heuristic instantiation lemma data associated with constant \c c */
hi_lemma const * get_hi_lemma(environment const & env, name const & c);
/** \brief Retrieve the active set of heuristic instantiation lemmas. */
hi_lemmas get_hi_lemma_index(environment const & env);
hi_lemma mk_hi_lemma(tmp_type_context & ctx, expr const & H, unsigned max_steps);
unsigned get_pattern_max_steps(options const & o);
namespace blast {
/** \brief Create a (local) heuristic instantiation lemma for \c H.
The maximum number of steps is extracted from the blast config object. */
hi_lemma mk_hi_lemma(tmp_type_context & ctx, expr const & H);
hi_lemma mk_hi_lemma(expr const & H);
hi_lemma mk_hi_lemma(name const & n, unsigned prio);
}
void initialize_pattern();

View file

@ -44,7 +44,9 @@ Author: Leonardo de Moura
#ifndef LEAN_DEFAULT_BLAST_BACKWARD
#define LEAN_DEFAULT_BLAST_BACKWARD true
#endif
#ifndef LEAN_DEFAULT_PATTERN_MAX_STEPS
#define LEAN_DEFAULT_PATTERN_MAX_STEPS 1024
#endif
namespace lean {
namespace blast {
@ -61,6 +63,7 @@ static name * g_blast_recursor = nullptr;
static name * g_blast_ematch = nullptr;
static name * g_blast_backward = nullptr;
static name * g_blast_show_failure = nullptr;
static name * g_pattern_max_steps = nullptr;
unsigned get_blast_max_depth(options const & o) {
return o.get_unsigned(*g_blast_max_depth, LEAN_DEFAULT_BLAST_MAX_DEPTH);
@ -98,6 +101,9 @@ bool get_blast_backward(options const & o) {
bool get_blast_show_failure(options const & o) {
return o.get_bool(*g_blast_show_failure, LEAN_DEFAULT_BLAST_SHOW_FAILURE);
}
unsigned get_pattern_max_steps(options const & o) {
return o.get_unsigned(*g_pattern_max_steps, LEAN_DEFAULT_PATTERN_MAX_STEPS);
}
config::config(options const & o) {
m_max_depth = get_blast_max_depth(o);
@ -145,6 +151,7 @@ void initialize_options() {
g_blast_ematch = new name{"blast", "ematch"};
g_blast_backward = new name{"blast", "backward"};
g_blast_show_failure = new name{"blast", "show_failure"};
g_pattern_max_steps = new name{"pattern", "max_steps"};
register_unsigned_option(*blast::g_blast_max_depth, LEAN_DEFAULT_BLAST_MAX_DEPTH,
"(blast) max search depth for blast");
@ -170,6 +177,10 @@ void initialize_options() {
"(blast) enable backward chaining");
register_bool_option(*blast::g_blast_show_failure, LEAN_DEFAULT_BLAST_SHOW_FAILURE,
"(blast) show failure state");
register_unsigned_option(*g_pattern_max_steps, LEAN_DEFAULT_PATTERN_MAX_STEPS,
"(pattern) max number of steps performed by pattern inference procedure, "
"we have this threshold because in the worst case this procedure may take "
"an exponetial number of steps");
}
void finalize_options() {
delete g_blast_max_depth;
@ -184,5 +195,6 @@ void finalize_options() {
delete g_blast_ematch;
delete g_blast_backward;
delete g_blast_show_failure;
delete g_pattern_max_steps;
}
}}

View file

@ -0,0 +1,9 @@
constants f g : nat → Prop
definition foo₁ [forward] : ∀ x, f x ∧ g x := sorry
definition foo₂ [forward] : ∀ x, (: f x :) ∧ g x := sorry
definition foo₃ [forward] : ∀ x, (: f (id x) :) ∧ g x := sorry
print foo₁
print foo₂
print foo₃ -- id is unfolded

View file

@ -0,0 +1,13 @@
definition foo₁ [forward] : ∀ (x : ), f x ∧ g x :=
sorry
(multi-)patterns:
{f ?M_1}
{g ?M_1}
definition foo₂ [forward] : ∀ (x : ), (:f x:) ∧ g x :=
sorry
(multi-)patterns:
{f ?M_1}
definition foo₃ [forward] : ∀ (x : ), (:f (id x):) ∧ g x :=
sorry
(multi-)patterns:
{f ?M_1}

View file

@ -0,0 +1,4 @@
definition foo [forward] : ∀ (m n k : ), P (f m) → P (g n) → P (f k) → P k ∧ R (g m) (f n) ∧ P (g m) ∧ P (f n) :=
λ (m n k : ), sorry
(multi-)patterns:
{P ?M_1, R (g ?M_2) (f ?M_3)}