feat(library/blast/imp_extension): imperative branch extensions

This commit is contained in:
Daniel Selsam 2015-12-01 09:57:12 -08:00 committed by Leonardo de Moura
parent 83b9769225
commit 601dc544b6
5 changed files with 195 additions and 3 deletions

View file

@ -3,4 +3,4 @@ add_library(blast OBJECT expr.cpp state.cpp blast.cpp blast_tactic.cpp
options.cpp choice_point.cpp simple_strategy.cpp util.cpp options.cpp choice_point.cpp simple_strategy.cpp util.cpp
gexpr.cpp revert.cpp subst_action.cpp no_confusion_action.cpp gexpr.cpp revert.cpp subst_action.cpp no_confusion_action.cpp
strategy.cpp recursor_action.cpp congruence_closure.cpp strategy.cpp recursor_action.cpp congruence_closure.cpp
trace.cpp assert_cc_action.cpp) trace.cpp assert_cc_action.cpp imp_extension.cpp)

View file

@ -38,11 +38,38 @@ namespace blast {
static name * g_prefix = nullptr; static name * g_prefix = nullptr;
static name * g_tmp_prefix = nullptr; static name * g_tmp_prefix = nullptr;
class imp_extension_manager {
std::vector<pair<ext_state_maker &, unsigned> > m_entries;
public:
std::vector<pair<ext_state_maker &, unsigned> > const & get_entries() { return m_entries; }
unsigned register_imp_extension(ext_state_maker & state_maker) {
unsigned state_id = m_entries.size();
unsigned ext_id = register_branch_extension(new imp_extension(state_id));
m_entries.emplace_back(state_maker, ext_id);
return state_id;
}
};
static imp_extension_manager * g_imp_extension_manager = nullptr;
static imp_extension_manager & get_imp_extension_manager() {
return *g_imp_extension_manager;
}
struct imp_extension_entry {
std::unique_ptr<imp_extension_state> m_ext_state;
unsigned m_ext_id;
imp_extension * m_ext_of_ext_state;
imp_extension_entry(imp_extension_state * ext_state, unsigned ext_id, imp_extension * ext_of_ext_state):
m_ext_state(ext_state), m_ext_id(ext_id), m_ext_of_ext_state(ext_of_ext_state) {}
};
class blastenv { class blastenv {
friend class scope_assignment; friend class scope_assignment;
friend class scope_unfold_macro_pred; friend class scope_unfold_macro_pred;
typedef std::vector<tmp_type_context *> tmp_type_context_pool; typedef std::vector<tmp_type_context *> tmp_type_context_pool;
typedef std::unique_ptr<tmp_type_context> tmp_type_context_ptr; typedef std::unique_ptr<tmp_type_context> tmp_type_context_ptr;
typedef std::vector<imp_extension_entry> imp_extension_entries;
environment m_env; environment m_env;
io_state m_ios; io_state m_ios;
@ -66,6 +93,7 @@ class blastenv {
congr_lemma_manager m_congr_lemma_manager; congr_lemma_manager m_congr_lemma_manager;
abstract_expr_manager m_abstract_expr_manager; abstract_expr_manager m_abstract_expr_manager;
light_lt_manager m_light_lt_manager; light_lt_manager m_light_lt_manager;
imp_extension_entries m_imp_extension_entries;
relation_info_getter m_rel_getter; relation_info_getter m_rel_getter;
refl_info_getter m_refl_getter; refl_info_getter m_refl_getter;
symm_info_getter m_symm_getter; symm_info_getter m_symm_getter;
@ -461,6 +489,7 @@ public:
} }
~blastenv() { ~blastenv() {
finalize_imp_extension_entries();
for (auto ctx : m_tmp_ctx_pool) for (auto ctx : m_tmp_ctx_pool)
delete ctx; delete ctx;
} }
@ -480,6 +509,7 @@ public:
void init_state(goal const & g) { void init_state(goal const & g) {
init_curr_state(g); init_curr_state(g);
init_imp_extension_entries();
save_initial_context(); save_initial_context();
m_tctx.set_local_instances(m_initial_context); m_tctx.set_local_instances(m_initial_context);
m_tmp_ctx->set_local_instances(m_initial_context); m_tmp_ctx->set_local_instances(m_initial_context);
@ -568,6 +598,62 @@ public:
return m_abstract_expr_manager.hash(e); return m_abstract_expr_manager.hash(e);
} }
void init_imp_extension_entries() {
for (auto & p : get_imp_extension_manager().get_entries()) {
branch_extension & b_ext = curr_state().get_extension(p.second);
b_ext.inc_ref();
m_imp_extension_entries.emplace_back(p.first(), p.second, static_cast<imp_extension*>(&b_ext));
}
}
void finalize_imp_extension_entries() {
for (auto & e : m_imp_extension_entries) {
e.m_ext_of_ext_state->dec_ref();
}
}
list<imp_extension*> get_ext_path(imp_extension * _imp_ext) {
list<imp_extension*> path;
imp_extension * imp_ext = _imp_ext;
while (imp_ext != nullptr) {
path = cons(imp_ext, path);
imp_ext = imp_ext->m_parent;
}
return path;
}
imp_extension_state & get_imp_extension_state(unsigned state_id) {
lean_assert(state_id < m_imp_extension_entries.size());
imp_extension_entry & e = m_imp_extension_entries[state_id];
imp_extension_state * ext_state = e.m_ext_state.get();
imp_extension * ext_of_curr_state = static_cast<imp_extension*>(&curr_state().get_extension(e.m_ext_id));
lean_assert(e.m_ext_of_ext_state);
imp_extension * ext_of_ext_state = e.m_ext_of_ext_state;
list<imp_extension*> curr_state_path = get_ext_path(ext_of_curr_state);
list<imp_extension*> ext_state_path = get_ext_path(ext_of_ext_state);
while (true) {
if (is_nil(curr_state_path) || is_nil(ext_state_path)) break;
if (head(curr_state_path) != head(ext_state_path)) break;
curr_state_path = tail(curr_state_path);
ext_state_path = tail(ext_state_path);
}
for_each(reverse(ext_state_path), [&](imp_extension * imp_ext) {
ext_state->undo_actions(imp_ext);
});
for_each(curr_state_path, [&](imp_extension * imp_ext) {
ext_state->replay_actions(imp_ext);
});
ext_of_curr_state->inc_ref();
ext_of_ext_state->dec_ref();
e.m_ext_of_ext_state = ext_of_curr_state;
return *ext_state;
}
bool abstract_is_equal(expr const & e1, expr const & e2) { bool abstract_is_equal(expr const & e1, expr const & e2) {
return m_abstract_expr_manager.is_equal(e1, e2); return m_abstract_expr_manager.is_equal(e1, e2);
} }
@ -795,6 +881,15 @@ unsigned abstract_hash(expr const & e) {
return g_blastenv->abstract_hash(e); return g_blastenv->abstract_hash(e);
} }
unsigned register_imp_extension(std::function<imp_extension_state*()> & ext_state_maker) {
return get_imp_extension_manager().register_imp_extension(ext_state_maker);
}
imp_extension_state & get_imp_extension_state(unsigned state_id) {
lean_assert(g_blastenv);
return g_blastenv->get_imp_extension_state(state_id);
}
bool abstract_is_equal(expr const & e1, expr const & e2) { bool abstract_is_equal(expr const & e1, expr const & e2) {
lean_assert(g_blastenv); lean_assert(g_blastenv);
return g_blastenv->abstract_is_equal(e1, e2); return g_blastenv->abstract_is_equal(e1, e2);
@ -963,10 +1058,12 @@ optional<expr> blast_goal(environment const & env, io_state const & ios, list<na
return b(g); return b(g);
} }
void initialize_blast() { void initialize_blast() {
blast::g_prefix = new name(name::mk_internal_unique_name()); blast::g_prefix = new name(name::mk_internal_unique_name());
blast::g_tmp_prefix = new name(name::mk_internal_unique_name()); blast::g_tmp_prefix = new name(name::mk_internal_unique_name());
blast::g_imp_extension_manager = new blast::imp_extension_manager();
} }
void finalize_blast() { void finalize_blast() {
delete blast::g_imp_extension_manager;
delete blast::g_prefix; delete blast::g_prefix;
delete blast::g_tmp_prefix; delete blast::g_tmp_prefix;
} }

View file

@ -13,6 +13,7 @@ Author: Leonardo de Moura
#include "library/congr_lemma_manager.h" #include "library/congr_lemma_manager.h"
#include "library/fun_info_manager.h" #include "library/fun_info_manager.h"
#include "library/blast/state.h" #include "library/blast/state.h"
#include "library/blast/imp_extension.h"
namespace lean { namespace lean {
struct projection_info; struct projection_info;
@ -120,6 +121,14 @@ bool is_light_lt(expr const & e1, expr const & e2);
/** \brief Whether [classical] namespace is open. */ /** \brief Whether [classical] namespace is open. */
bool classical(); bool classical();
/** \brief This procedure must be invoked at Lean initialization time for each imperative branch extension.
The unique id returned should be used to retrieve the extension state associated with the current state. */
unsigned register_imp_extension(ext_state_maker & state_maker);
/** \brief This procedure returns a reference to the extension state associated with the current state.
It handles all the bookeeping so that the returned extension state is guaranteed to be in synch with
the current blast state. */
imp_extension_state & get_imp_extension_state(unsigned state_id);
/** \brief Display the current state of the blast tactic in the diagnostic channel. */ /** \brief Display the current state of the blast tactic in the diagnostic channel. */
void display_curr_state(); void display_curr_state();

View file

@ -0,0 +1,44 @@
/*
Copyright (c) 2015 Daniel Selsam. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#include "library/blast/imp_extension.h"
#include "library/blast/blast.h"
namespace lean {
namespace blast {
imp_extension::imp_extension(unsigned state_id): m_state_id(state_id), m_parent(nullptr) {}
imp_extension::imp_extension(imp_extension * parent):
m_state_id(parent->m_state_id), m_parent(parent) {
parent->inc_ref();
}
imp_extension::~imp_extension() { if (m_parent) m_parent->dec_ref(); }
imp_extension * imp_extension::clone() {
return new imp_extension(this);
}
void imp_extension::hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) {
imp_extension_state & state = get_imp_extension_state(m_state_id);
if (is_nil(m_asserts)) state.push();
m_asserts = cons(h, m_asserts);
state.assert(h);
}
void imp_extension_state::replay_actions(imp_extension * imp_ext) {
list<hypothesis> const & asserts = reverse(imp_ext->get_asserts());
if (is_nil(asserts)) return;
push();
for_each(asserts, [&](hypothesis const & h) { assert(h); });
}
void imp_extension_state::undo_actions(imp_extension * imp_ext) {
list<hypothesis> const & asserts = imp_ext->get_asserts();
if (is_nil(asserts)) return;
pop();
}
}}

View file

@ -0,0 +1,42 @@
/*
Copyright (c) 2015 Daniel Selsam. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#pragma once
#include "library/blast/state.h"
namespace lean {
namespace blast {
struct imp_extension : branch_extension {
unsigned m_state_id;
imp_extension * m_parent;
list<hypothesis> m_asserts;
imp_extension(unsigned state_id);
imp_extension(imp_extension * parent);
~imp_extension();
unsigned get_state_id() { return m_state_id; }
imp_extension * get_parent() { return m_parent; }
list<hypothesis> const & get_asserts() { return m_asserts; }
virtual imp_extension * clone() override;
virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) override;
};
struct imp_extension_state {
virtual void push() =0;
virtual void pop() =0;
virtual void assert(hypothesis const & h) =0;
virtual ~imp_extension_state() {}
void replay_actions(imp_extension * imp_ext);
void undo_actions(imp_extension * imp_ext);
};
typedef std::function<imp_extension_state*()> ext_state_maker;
}}