From 601dc544b69a6471681b26f8f6d316e6c805149a Mon Sep 17 00:00:00 2001 From: Daniel Selsam Date: Tue, 1 Dec 2015 09:57:12 -0800 Subject: [PATCH] feat(library/blast/imp_extension): imperative branch extensions --- src/library/blast/CMakeLists.txt | 2 +- src/library/blast/blast.cpp | 101 +++++++++++++++++++++++++++- src/library/blast/blast.h | 9 +++ src/library/blast/imp_extension.cpp | 44 ++++++++++++ src/library/blast/imp_extension.h | 42 ++++++++++++ 5 files changed, 195 insertions(+), 3 deletions(-) create mode 100644 src/library/blast/imp_extension.cpp create mode 100644 src/library/blast/imp_extension.h diff --git a/src/library/blast/CMakeLists.txt b/src/library/blast/CMakeLists.txt index 8242cfa8b..d7dbe38ae 100644 --- a/src/library/blast/CMakeLists.txt +++ b/src/library/blast/CMakeLists.txt @@ -3,4 +3,4 @@ add_library(blast OBJECT expr.cpp state.cpp blast.cpp blast_tactic.cpp options.cpp choice_point.cpp simple_strategy.cpp util.cpp gexpr.cpp revert.cpp subst_action.cpp no_confusion_action.cpp strategy.cpp recursor_action.cpp congruence_closure.cpp - trace.cpp assert_cc_action.cpp) + trace.cpp assert_cc_action.cpp imp_extension.cpp) diff --git a/src/library/blast/blast.cpp b/src/library/blast/blast.cpp index 2928902cb..7e047d505 100644 --- a/src/library/blast/blast.cpp +++ b/src/library/blast/blast.cpp @@ -38,11 +38,38 @@ namespace blast { static name * g_prefix = nullptr; static name * g_tmp_prefix = nullptr; +class imp_extension_manager { + std::vector > m_entries; +public: + std::vector > const & get_entries() { return m_entries; } + + unsigned register_imp_extension(ext_state_maker & state_maker) { + unsigned state_id = m_entries.size(); + unsigned ext_id = register_branch_extension(new imp_extension(state_id)); + m_entries.emplace_back(state_maker, ext_id); + return state_id; + } +}; + +static imp_extension_manager * g_imp_extension_manager = nullptr; +static imp_extension_manager & get_imp_extension_manager() { + return *g_imp_extension_manager; +} + +struct imp_extension_entry { + std::unique_ptr m_ext_state; + unsigned m_ext_id; + imp_extension * m_ext_of_ext_state; + imp_extension_entry(imp_extension_state * ext_state, unsigned ext_id, imp_extension * ext_of_ext_state): + m_ext_state(ext_state), m_ext_id(ext_id), m_ext_of_ext_state(ext_of_ext_state) {} +}; + class blastenv { friend class scope_assignment; friend class scope_unfold_macro_pred; typedef std::vector tmp_type_context_pool; typedef std::unique_ptr tmp_type_context_ptr; + typedef std::vector imp_extension_entries; environment m_env; io_state m_ios; @@ -66,6 +93,7 @@ class blastenv { congr_lemma_manager m_congr_lemma_manager; abstract_expr_manager m_abstract_expr_manager; light_lt_manager m_light_lt_manager; + imp_extension_entries m_imp_extension_entries; relation_info_getter m_rel_getter; refl_info_getter m_refl_getter; symm_info_getter m_symm_getter; @@ -461,6 +489,7 @@ public: } ~blastenv() { + finalize_imp_extension_entries(); for (auto ctx : m_tmp_ctx_pool) delete ctx; } @@ -480,6 +509,7 @@ public: void init_state(goal const & g) { init_curr_state(g); + init_imp_extension_entries(); save_initial_context(); m_tctx.set_local_instances(m_initial_context); m_tmp_ctx->set_local_instances(m_initial_context); @@ -568,6 +598,62 @@ public: return m_abstract_expr_manager.hash(e); } + void init_imp_extension_entries() { + for (auto & p : get_imp_extension_manager().get_entries()) { + branch_extension & b_ext = curr_state().get_extension(p.second); + b_ext.inc_ref(); + m_imp_extension_entries.emplace_back(p.first(), p.second, static_cast(&b_ext)); + } + } + + void finalize_imp_extension_entries() { + for (auto & e : m_imp_extension_entries) { + e.m_ext_of_ext_state->dec_ref(); + } + } + + list get_ext_path(imp_extension * _imp_ext) { + list path; + imp_extension * imp_ext = _imp_ext; + while (imp_ext != nullptr) { + path = cons(imp_ext, path); + imp_ext = imp_ext->m_parent; + } + return path; + } + + imp_extension_state & get_imp_extension_state(unsigned state_id) { + lean_assert(state_id < m_imp_extension_entries.size()); + imp_extension_entry & e = m_imp_extension_entries[state_id]; + imp_extension_state * ext_state = e.m_ext_state.get(); + imp_extension * ext_of_curr_state = static_cast(&curr_state().get_extension(e.m_ext_id)); + lean_assert(e.m_ext_of_ext_state); + imp_extension * ext_of_ext_state = e.m_ext_of_ext_state; + + list curr_state_path = get_ext_path(ext_of_curr_state); + list ext_state_path = get_ext_path(ext_of_ext_state); + + while (true) { + if (is_nil(curr_state_path) || is_nil(ext_state_path)) break; + if (head(curr_state_path) != head(ext_state_path)) break; + curr_state_path = tail(curr_state_path); + ext_state_path = tail(ext_state_path); + } + + for_each(reverse(ext_state_path), [&](imp_extension * imp_ext) { + ext_state->undo_actions(imp_ext); + }); + + for_each(curr_state_path, [&](imp_extension * imp_ext) { + ext_state->replay_actions(imp_ext); + }); + + ext_of_curr_state->inc_ref(); + ext_of_ext_state->dec_ref(); + e.m_ext_of_ext_state = ext_of_curr_state; + return *ext_state; + } + bool abstract_is_equal(expr const & e1, expr const & e2) { return m_abstract_expr_manager.is_equal(e1, e2); } @@ -795,6 +881,15 @@ unsigned abstract_hash(expr const & e) { return g_blastenv->abstract_hash(e); } +unsigned register_imp_extension(std::function & ext_state_maker) { + return get_imp_extension_manager().register_imp_extension(ext_state_maker); +} + +imp_extension_state & get_imp_extension_state(unsigned state_id) { + lean_assert(g_blastenv); + return g_blastenv->get_imp_extension_state(state_id); +} + bool abstract_is_equal(expr const & e1, expr const & e2) { lean_assert(g_blastenv); return g_blastenv->abstract_is_equal(e1, e2); @@ -963,10 +1058,12 @@ optional blast_goal(environment const & env, io_state const & ios, listm_state_id), m_parent(parent) { + parent->inc_ref(); +} + +imp_extension::~imp_extension() { if (m_parent) m_parent->dec_ref(); } + +imp_extension * imp_extension::clone() { + return new imp_extension(this); +} + +void imp_extension::hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) { + imp_extension_state & state = get_imp_extension_state(m_state_id); + if (is_nil(m_asserts)) state.push(); + m_asserts = cons(h, m_asserts); + state.assert(h); +} + +void imp_extension_state::replay_actions(imp_extension * imp_ext) { + list const & asserts = reverse(imp_ext->get_asserts()); + if (is_nil(asserts)) return; + push(); + for_each(asserts, [&](hypothesis const & h) { assert(h); }); +} + +void imp_extension_state::undo_actions(imp_extension * imp_ext) { + list const & asserts = imp_ext->get_asserts(); + if (is_nil(asserts)) return; + pop(); +} + +}} diff --git a/src/library/blast/imp_extension.h b/src/library/blast/imp_extension.h new file mode 100644 index 000000000..71cd6b1ec --- /dev/null +++ b/src/library/blast/imp_extension.h @@ -0,0 +1,42 @@ +/* +Copyright (c) 2015 Daniel Selsam. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author: Daniel Selsam +*/ +#pragma once +#include "library/blast/state.h" + +namespace lean { +namespace blast { + +struct imp_extension : branch_extension { + unsigned m_state_id; + imp_extension * m_parent; + list m_asserts; + + imp_extension(unsigned state_id); + imp_extension(imp_extension * parent); + ~imp_extension(); + + unsigned get_state_id() { return m_state_id; } + imp_extension * get_parent() { return m_parent; } + list const & get_asserts() { return m_asserts; } + + virtual imp_extension * clone() override; + virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx hidx) override; +}; + +struct imp_extension_state { + virtual void push() =0; + virtual void pop() =0; + virtual void assert(hypothesis const & h) =0; + + virtual ~imp_extension_state() {} + + void replay_actions(imp_extension * imp_ext); + void undo_actions(imp_extension * imp_ext); +}; + +typedef std::function ext_state_maker; + +}}