refactor(library/blast/backward): use priority_queue, make sure head is normalized when building index

This commit is contained in:
Leonardo de Moura 2015-12-28 12:10:56 -08:00
parent 26d0a62052
commit c8b9c98eb6
15 changed files with 280 additions and 264 deletions

View file

@ -29,7 +29,7 @@ Author: Leonardo de Moura
#include "library/definitional/projection.h"
#include "library/blast/blast.h"
#include "library/blast/simplifier/simplifier.h"
#include "library/blast/backward/backward_rule_set.h"
#include "library/blast/backward/backward_lemmas.h"
#include "library/blast/forward/forward_lemma_set.h"
#include "library/blast/forward/pattern.h"
#include "library/blast/grinder/intro_elim_lemmas.h"
@ -543,16 +543,19 @@ static void print_elim_lemmas(parser & p) {
}
static void print_intro_lemmas(parser & p) {
io_state_stream out = p.regular_stream();
buffer<name> lemmas;
get_intro_lemmas(p.env(), lemmas);
for (auto n : lemmas)
p.regular_stream() << n << "\n";
out << n << "\n";
}
static void print_backward_rules(parser & p) {
static void print_backward_lemmas(parser & p) {
io_state_stream out = p.regular_stream();
blast::backward_rule_set brs = get_backward_rule_set(p.env());
out << brs;
buffer<name> lemmas;
get_backward_lemmas(p.env(), lemmas);
for (auto n : lemmas)
out << n << "\n";
}
static void print_no_patterns(parser & p) {
@ -712,7 +715,7 @@ environment print_cmd(parser & p) {
print_light_rules(p);
} else if (p.curr_is_token(get_intro_attr_tk())) {
p.next();
print_backward_rules(p);
print_backward_lemmas(p);
} else if (print_polymorphic(p)) {
} else {
throw parser_error("invalid print command", p.pos());

View file

@ -1 +1 @@
add_library(backward OBJECT init_module.cpp backward_action.cpp backward_rule_set.cpp backward_strategy.cpp)
add_library(backward OBJECT init_module.cpp backward_action.cpp backward_lemmas.cpp backward_strategy.cpp)

View file

@ -0,0 +1,172 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <string>
#include "util/sstream.h"
#include "util/priority_queue.h"
#include "kernel/instantiate.h"
#include "library/trace.h"
#include "library/scoped_ext.h"
#include "library/user_recursors.h"
#include "library/tmp_type_context.h"
#include "library/attribute_manager.h"
#include "library/blast/blast.h"
#include "library/blast/backward/backward_lemmas.h"
namespace lean {
static name * g_class_name = nullptr;
static std::string * g_key = nullptr;
typedef priority_queue<name, name_quick_cmp> backward_state;
typedef std::tuple<unsigned, name> backward_entry;
struct backward_config {
typedef backward_entry entry;
typedef backward_state state;
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
unsigned prio; name n;
std::tie(prio, n) = e;
s.insert(n, prio);
}
static name const & get_class_name() {
return *g_class_name;
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
unsigned prio; name n;
std::tie(prio, n) = e;
s << prio << n;
}
static entry read_entry(deserializer & d) {
unsigned prio; name n;
d >> prio >> n;
return entry(prio, n);
}
static optional<unsigned> get_fingerprint(entry const & e) {
unsigned prio; name n;
std::tie(prio, n) = e;
return some(hash(n.hash(), prio));
}
};
typedef scoped_ext<backward_config> backward_ext;
static optional<head_index> get_backward_target(tmp_type_context & ctx, expr type) {
while (is_pi(type)) {
expr local = ctx.mk_tmp_local(binding_domain(type));
type = ctx.try_to_pi(instantiate(binding_body(type), local));
}
expr fn = get_app_fn(ctx.whnf(type));
if (is_constant(fn) || is_local(fn))
return optional<head_index>(fn);
else
return optional<head_index>();
}
static optional<head_index> get_backward_target(tmp_type_context & ctx, name const & c) {
declaration const & d = ctx.env().get(c);
buffer<level> us;
unsigned num_us = d.get_num_univ_params();
for (unsigned i = 0; i < num_us; i++)
us.push_back(ctx.mk_uvar());
expr type = ctx.try_to_pi(instantiate_type_univ_params(d, to_list(us)));
return get_backward_target(ctx, type);
}
environment add_backward_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent) {
tmp_type_context ctx(env, ios.get_options());
auto index = get_backward_target(ctx, c);
if (!index || index->kind() != expr_kind::Constant)
throw exception(sstream() << "invalid [intro] attribute for '" << c << "', head symbol of resulting type must be a constant");
return backward_ext::add_entry(env, ios, backward_entry(prio, c), ns, persistent);
}
bool is_backward_lemma(environment const & env, name const & c) {
return backward_ext::get_state(env).contains(c);
}
void get_backward_lemmas(environment const & env, buffer<name> & r) {
return backward_ext::get_state(env).to_buffer(r);
}
void initialize_backward_lemmas() {
g_class_name = new name("backward");
g_key = new std::string("BWD");
backward_ext::initialize();
register_prio_attribute("intro", "introduction rule for backward chaining",
add_backward_lemma,
is_backward_lemma,
[](environment const & env, name const & d) {
if (auto p = backward_ext::get_state(env).get_prio(d))
return *p;
else
return LEAN_DEFAULT_PRIORITY;
});
}
void finalize_backward_lemmas() {
backward_ext::finalize();
delete g_key;
delete g_class_name;
}
namespace blast {
unsigned backward_lemma_prio_fn::operator()(backward_lemma const & r) const {
if (r.is_universe_polymorphic()) {
name const & n = r.to_name();
auto const & s = backward_ext::get_state(env());
if (auto prio = s.get_prio(n))
return *prio;
}
return LEAN_DEFAULT_PRIORITY;
}
void backward_lemma_index::init() {
m_index.clear();
buffer<name> lemmas;
blast_tmp_type_context ctx;
auto const & s = backward_ext::get_state(env());
s.to_buffer(lemmas);
unsigned i = lemmas.size();
while (i > 0) {
--i;
ctx->clear();
optional<head_index> target = get_backward_target(*ctx, lemmas[i]);
if (!target || target->kind() != expr_kind::Constant) {
lean_trace(name({"blast", "event"}),
tout() << "discarding [intro] lemma '" << lemmas[i] << "', failed to find target type\n";);
} else {
m_index.insert(*target, backward_lemma(lemmas[i]));
}
}
}
void backward_lemma_index::insert(expr const & href) {
blast_tmp_type_context ctx;
expr href_type = ctx->infer(href);
if (optional<head_index> target = get_backward_target(*ctx, href_type)) {
m_index.insert(*target, backward_lemma(gexpr(href)));
}
}
void backward_lemma_index::erase(expr const & href) {
blast_tmp_type_context ctx;
expr href_type = ctx->infer(href);
if (optional<head_index> target = get_backward_target(*ctx, href_type)) {
m_index.erase(*target, backward_lemma(gexpr(href)));
}
}
list<backward_lemma> backward_lemma_index::find(head_index const & h) const {
if (auto r = m_index.find(h))
return *r;
else
return list<backward_lemma>();
}
}}

View file

@ -0,0 +1,33 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "kernel/environment.h"
#include "library/attribute_manager.h"
#include "library/io_state.h"
#include "library/head_map.h"
#include "library/blast/gexpr.h"
namespace lean {
environment add_backward_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent);
bool is_backward_lemma(environment const & env, name const & n);
void get_backward_lemmas(environment const & env, buffer<name> & r);
void initialize_backward_lemmas();
void finalize_backward_lemmas();
namespace blast {
typedef gexpr backward_lemma;
struct backward_lemma_prio_fn { unsigned operator()(backward_lemma const & r) const; };
/* The following indices are based on blast current set of opaque/reducible constants. They
must be rebuilt whenever a key is "unfolded by blast */
class backward_lemma_index {
head_map_prio<backward_lemma, backward_lemma_prio_fn> m_index;
public:
void init();
void insert(expr const & href);
void erase(expr const & href);
list<backward_lemma> find(head_index const & h) const;
};
}}

View file

@ -1,171 +0,0 @@
/*
Copyright (c) 2015 Daniel Selsam. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#include <string>
#include "util/sstream.h"
#include "kernel/error_msgs.h"
#include "kernel/instantiate.h"
#include "library/scoped_ext.h"
#include "library/attribute_manager.h"
#include "library/blast/backward/backward_rule_set.h"
namespace lean {
using blast::backward_rule;
using blast::backward_rule_set;
using blast::gexpr;
using std::function;
static name * g_class_name = nullptr;
static std::string * g_key = nullptr;
struct brs_state {
backward_rule_set m_backward_rule_set;
name_set m_names;
void add(environment const & env, options const & o, name const & cname, unsigned prio) {
default_type_context tctx(env, o);
m_backward_rule_set.insert(tctx, cname, prio);
m_names.insert(cname);
}
};
struct brs_entry {
name m_name;
unsigned m_priority;
brs_entry() {}
brs_entry(name const & n, unsigned prio): m_name(n), m_priority(prio) { }
};
struct brs_config {
typedef brs_entry entry;
typedef brs_state state;
static void add_entry(environment const & env, io_state const & ios, state & s, entry const & e) {
s.add(env, ios.get_options(), e.m_name, e.m_priority);
}
static name const & get_class_name() {
return *g_class_name;
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << e.m_name << e.m_priority;
}
static entry read_entry(deserializer & d) {
entry e; d >> e.m_name >> e.m_priority; return e;
}
static optional<unsigned> get_fingerprint(entry const & e) {
return some(e.m_name.hash());
}
};
template class scoped_ext<brs_config>;
typedef scoped_ext<brs_config> brs_ext;
environment add_backward_rule(environment const & env, name const & n, unsigned priority, name const & ns, bool persistent) {
return brs_ext::add_entry(env, get_dummy_ios(), brs_entry(n, priority), ns, persistent);
}
bool is_backward_rule(environment const & env, name const & n) {
return brs_ext::get_state(env).m_names.contains(n);
}
backward_rule_set get_backward_rule_set(environment const & env) {
return brs_ext::get_state(env).m_backward_rule_set;
}
backward_rule_set get_backward_rule_sets(environment const & env, options const & o, name const & ns) {
backward_rule_set brs;
list<brs_entry> const * entries = brs_ext::get_entries(env, ns);
if (entries) {
for (auto const & e : *entries) {
default_type_context tctx(env, o);
brs.insert(tctx, e.m_name, e.m_priority);
}
}
return brs;
}
io_state_stream const & operator<<(io_state_stream const & out, backward_rule_set const & brs) {
out << "backward rules\n";
brs.for_each([&](head_index const & head_idx, backward_rule const & r) {
out << head_idx << " ==> " << r.get_proof().to_bare_expr() << "\n";
});
return out;
}
namespace blast {
bool operator==(backward_rule const & r1, backward_rule const & r2) {
return r1.get_proof() == r2.get_proof();
}
void backward_rule_set::insert(type_context & tctx, name const & id, gexpr const & proof, expr const & _thm, unsigned prio) {
expr thm = tctx.whnf(_thm);
while (is_pi(thm)) {
expr local = tctx.mk_tmp_local(binding_domain(thm), binding_info(thm));
thm = tctx.whnf(instantiate(binding_body(thm), local));
}
m_set.insert(head_index(thm), backward_rule(id, proof, prio));
}
void backward_rule_set::insert(type_context & tctx, name const & name, unsigned prio) {
gexpr proof(tctx.env(), name);
declaration const & d = tctx.env().get(name);
insert(tctx, name, proof, d.get_type(), prio);
}
void backward_rule_set::erase(name_set const & ids) {
// This method is not very smart and doesn't use any indexing or caching.
// So, it may be a bottleneck in the future
buffer<pair<head_index, backward_rule> > to_delete;
for_each([&](head_index const & h, backward_rule const & r) {
if (ids.contains(r.get_id())) {
to_delete.push_back(mk_pair(h, r));
}
});
for (auto const & hr : to_delete) {
m_set.erase(hr.first, hr.second);
}
}
void backward_rule_set::erase(name const & id) {
name_set ids;
ids.insert(id);
erase(ids);
}
void backward_rule_set::for_each(std::function<void(head_index const & h, backward_rule const & r)> const & fn) const {
m_set.for_each_entry(fn);
}
list<gexpr> backward_rule_set::find(head_index const & h) const {
list<backward_rule> const * rule_list = m_set.find(h);
if (!rule_list) return list<gexpr>();
return map2<gexpr, backward_rule, function<gexpr(backward_rule)>>(*rule_list, [&](backward_rule const & r) { return r.get_proof(); });
}
void initialize_backward_rule_set() {
g_class_name = new name("backward");
g_key = new std::string("BWD");
brs_ext::initialize();
register_prio_attribute("intro", "backward chaining",
[](environment const & env, io_state const &, name const & d, unsigned prio, name const & ns, bool persistent) {
return add_backward_rule(env, d, prio, ns, persistent);
},
is_backward_rule,
[](environment const &, name const &) {
// TODO(Leo): fix it after we refactor backward_rule_set
return LEAN_DEFAULT_PRIORITY;
});
}
void finalize_backward_rule_set() {
brs_ext::finalize();
delete g_key;
delete g_class_name;
}
}
}

View file

@ -1,64 +0,0 @@
/*
Copyright (c) 2015 Daniel Selsam. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#pragma once
#include <vector>
#include "library/type_context.h"
#include "library/head_map.h"
#include "library/io_state_stream.h"
#include "library/blast/gexpr.h"
namespace lean {
namespace blast {
class backward_rule {
name m_id;
gexpr m_proof;
unsigned m_priority;
public:
backward_rule(name const & id, gexpr const & proof, unsigned priority):
m_id(id), m_proof(proof), m_priority(priority) {}
name const & get_id() const { return m_id; }
gexpr const & get_proof() const { return m_proof; }
unsigned get_priority() const { return m_priority; }
};
bool operator==(backward_rule const & r1, backward_rule const & r2);
inline bool operator!=(backward_rule const & r1, backward_rule const & r2) { return !operator==(r1, r2); }
struct backward_rule_prio_fn { unsigned operator()(backward_rule const & r) const { return r.get_priority(); } };
class backward_rule_set {
head_map_prio<backward_rule, backward_rule_prio_fn> m_set;
public:
void insert(type_context & tctx, name const & id, gexpr const & proof, expr const & thm, unsigned prio);
void insert(type_context & tctx, name const & name, unsigned prio);
void erase(name_set const & ids);
void erase(name const & id);
void for_each(std::function<void(head_index const & h, backward_rule const & r)> const & fn) const;
list<gexpr> find(head_index const & h) const;
};
void initialize_backward_rule_set();
void finalize_backward_rule_set();
}
environment add_backward_rule(environment const & env, name const & n, unsigned priority, name const & ns, bool persistent);
/** \brief Return true if \c n is an active backward rule in \c env */
bool is_backward_rule(environment const & env, name const & n);
/** \brief Get current backward rule set */
blast::backward_rule_set get_backward_rule_set(environment const & env);
/** \brief Get backward rule set in the given namespace. */
blast::backward_rule_set get_backward_rule_sets(environment const & env, options const & o, name const & ns);
io_state_stream const & operator<<(io_state_stream const & out, blast::backward_rule_set const & r);
}

View file

@ -10,7 +10,7 @@ Author: Daniel Selsam
#include "library/blast/choice_point.h"
#include "library/blast/proof_expr.h"
#include "library/blast/strategy.h"
#include "library/blast/backward/backward_rule_set.h"
#include "library/blast/backward/backward_lemmas.h"
#include "library/blast/backward/backward_action.h"
#include "library/blast/actions/simple_actions.h"
#include "library/blast/actions/intros_action.h"
@ -21,21 +21,20 @@ namespace lean {
namespace blast {
static unsigned g_ext_id = 0;
struct backward_branch_extension : public branch_extension {
backward_rule_set m_backward_rule_set;
backward_lemma_index m_backward_lemmas;
backward_branch_extension() {}
backward_branch_extension(backward_branch_extension const & b):
m_backward_rule_set(b.m_backward_rule_set) {}
m_backward_lemmas(b.m_backward_lemmas) {}
virtual ~backward_branch_extension() {}
virtual branch_extension * clone() override { return new backward_branch_extension(*this); }
virtual void initialized() override { m_backward_rule_set = ::lean::get_backward_rule_set(env()); }
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) override {
m_backward_rule_set.insert(get_type_context(), h.get_name(), gexpr(mk_href(hidx)),
h.get_type(), LEAN_DEFAULT_PRIORITY);
virtual void initialized() override { m_backward_lemmas.init(); }
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx) override {
m_backward_lemmas.insert(h.get_self());
}
virtual void hypothesis_deleted(hypothesis const & h, hypothesis_idx) override {
m_backward_rule_set.erase(h.get_name());
m_backward_lemmas.erase(h.get_self());
}
backward_rule_set const & get_backward_rule_set() const { return m_backward_rule_set; }
backward_lemma_index const & get_backward_lemmas() const { return m_backward_lemmas; }
};
void initialize_backward_strategy() {
@ -70,7 +69,7 @@ class backward_strategy_fn : public strategy_fn {
Try(activate_hypothesis());
Try(trivial_action());
Try(assumption_action());
list<gexpr> backward_rules = get_extension().get_backward_rule_set().find(head_index(curr_state().get_target()));
list<gexpr> backward_rules = get_extension().get_backward_lemmas().find(head_index(curr_state().get_target()));
Try(backward_action(backward_rules, true));
return action_result::failed();
}

View file

@ -4,20 +4,20 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#include "library/blast/backward/init_module.h"
#include "library/blast/backward/backward_rule_set.h"
#include "library/blast/backward/backward_lemmas.h"
#include "library/blast/backward/backward_strategy.h"
namespace lean {
namespace blast {
void initialize_backward_module() {
initialize_backward_rule_set();
initialize_backward_lemmas();
initialize_backward_strategy();
}
void finalize_backward_module() {
finalize_backward_strategy();
finalize_backward_rule_set();
finalize_backward_lemmas();
}
}
}

View file

@ -45,6 +45,8 @@ public:
/** \brief Return "bare" expression (without adding fresh metavariables if universe polymorphic) */
expr to_bare_expr() const { return m_expr; }
name const & to_name() const { lean_assert(is_universe_polymorphic()); return const_name(m_expr); }
friend bool operator==(gexpr const & ge1, gexpr const & ge2);
friend std::ostream const & operator<<(std::ostream const & out, gexpr const & ge);
};

View file

@ -82,7 +82,7 @@ optional<name> get_intro_target(tmp_type_context & ctx, name const & c) {
expr local = ctx.mk_tmp_local(binding_domain(type));
type = ctx.try_to_pi(instantiate(binding_body(type), local));
}
expr const & fn = get_app_fn(type);
expr const & fn = get_app_fn(ctx.whnf(type));
if (is_constant(fn))
return optional<name>(const_name(fn));
else
@ -92,7 +92,7 @@ optional<name> get_intro_target(tmp_type_context & ctx, name const & c) {
environment add_intro_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent) {
tmp_type_context ctx(env, ios.get_options());
if (!get_intro_target(ctx, c))
throw exception(sstream() << "invalid [intro] attribute for '" << c << "', head symbol of resulting type must be a constant");
throw exception(sstream() << "invalid [intro!] attribute for '" << c << "', head symbol of resulting type must be a constant");
return intro_elim_ext::add_entry(env, ios, intro_elim_entry(false, prio, c), ns, persistent);
}
@ -157,7 +157,7 @@ head_map<gexpr> mk_intro_lemma_index() {
optional<name> target = get_intro_target(*ctx, lemmas[i]);
if (!target) {
lean_trace(name({"blast", "event"}),
tout() << "discarding [intro] lemma '" << lemmas[i] << "', failed to find target type\n";);
tout() << "discarding [intro!] lemma '" << lemmas[i] << "', failed to find target type\n";);
} else {
r.insert(head_index(*target), gexpr(lemmas[i]));
}

View file

@ -17,6 +17,7 @@ struct head_index {
explicit head_index(expr_kind k = expr_kind::Var):m_kind(k) {}
explicit head_index(name const & c):m_kind(expr_kind::Constant), m_name(c) {}
head_index(expr const & e);
expr_kind kind() const { return m_kind; }
struct cmp {
int operator()(head_index const & i1, head_index const & i2) const;
@ -53,6 +54,7 @@ class head_map_prio : private GetPrio {
public:
head_map_prio() {}
head_map_prio(GetPrio const & g):GetPrio(g) {}
void clear() { m_map = rb_map<head_index, list<V>, head_index::cmp>(); }
bool empty() const { return m_map.empty(); }
bool contains(head_index const & h) const { return m_map.contains(h); }
list<V> const * find(head_index const & h) const { return m_map.find(h); }

View file

@ -1,9 +1,8 @@
constant H [intro] : A → B
constant G [intro] : A → B → C
constant f [intro] : T → A
backward rules
exists_unique ==> exists_unique.intro
B ==> H
A ==> f
C ==> G
Exists ==> Exists.intro
f
G
H
exists_unique.intro
Exists.intro

View file

@ -0,0 +1,25 @@
constant r : nat → Prop
constant s : nat → Prop
constant p : nat → Prop
definition q (a : nat) := p a
lemma rq₁ [intro] [priority 20] : ∀ a, r a → q a :=
sorry
lemma rq₂ [intro] [priority 10] : ∀ a, s a → q a :=
sorry
attribute q [reducible]
definition lemma1 (a : nat) : r a → s a → p a :=
by blast
print lemma1
attribute rq₂ [intro] [priority 30]
definition lemma2 (a : nat) : r a → s a → p a :=
by blast
print lemma2

View file

@ -0,0 +1,4 @@
definition lemma1 : ∀ (a : ), r a → s a → p a :=
λ (a : ) (H.1 : r a) (H.2 : s a), rq₁ a H.1
definition lemma2 : ∀ (a : ), r a → s a → p a :=
λ (a : ) (H.1 : r a), rq₂ a

View file

@ -0,0 +1,12 @@
constant r : nat → Prop
constant p : nat → Prop
definition q (a : nat) := p a
lemma rq [intro] : ∀ a, r a → q a :=
sorry
attribute q [reducible]
example (a : nat) : r a → p a :=
by blast