From cca15f1390e11982349c2be8232c9231d26d9a9c Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 21 Jan 2014 21:16:23 -0800 Subject: [PATCH] feat(library/simplifier): congruence theorem compilation Signed-off-by: Leonardo de Moura --- src/library/simplifier/CMakeLists.txt | 2 +- src/library/simplifier/congr.cpp | 200 ++++++++++++++++++++ src/library/simplifier/congr.h | 143 ++++++++++++++ src/library/simplifier/rewrite_rule_set.cpp | 47 +++++ src/library/simplifier/rewrite_rule_set.h | 14 ++ src/library/simplifier/simplifier.cpp | 7 +- 6 files changed, 411 insertions(+), 2 deletions(-) create mode 100644 src/library/simplifier/congr.cpp create mode 100644 src/library/simplifier/congr.h diff --git a/src/library/simplifier/CMakeLists.txt b/src/library/simplifier/CMakeLists.txt index 1264715a2..dceec2e0a 100644 --- a/src/library/simplifier/CMakeLists.txt +++ b/src/library/simplifier/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(simplifier ceq.cpp simplifier.cpp rewrite_rule_set.cpp) +add_library(simplifier ceq.cpp congr.cpp rewrite_rule_set.cpp simplifier.cpp) target_link_libraries(simplifier ${LEAN_LIBS}) diff --git a/src/library/simplifier/congr.cpp b/src/library/simplifier/congr.cpp new file mode 100644 index 000000000..f1cff1600 --- /dev/null +++ b/src/library/simplifier/congr.cpp @@ -0,0 +1,200 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "util/sstream.h" +#include "kernel/kernel.h" +#include "library/equality.h" +#include "library/simplifier/congr.h" + +namespace lean { +typedef congr_theorem_info::app_arg_info app_arg_info; +/** + \brief Return true iff arg_info contains an entry s.t. m_proof_arg_pos or m_proof_new_arg_pos is pos. +*/ +static bool contains_pos(buffer const & arg_info, unsigned pos) { + return std::any_of(arg_info.begin(), arg_info.end(), + [&](app_arg_info const & info) { + return + info.get_pos_at_proof() == pos || + (info.get_new_pos_at_proof() && *info.get_new_pos_at_proof() == pos); + }); +} + +static void check_conclusion_lhs_rhs(expr const & lhs, expr const & rhs, unsigned num) { + if (!is_var(lhs) || !is_var(rhs)) + throw exception("invalid congruence theorem, the arguments in the left and right-hand-sides must be variables"); + if (var_idx(lhs) >= num) + throw exception("invalid congruence theorem, left-hand-side contains free variables"); + if (var_idx(rhs) >= num) + throw exception("invalid congruence theorem, right-hand-side contains free variables"); +} + +static void check_arg_lhs_rhs(expr const & lhs, expr const & rhs, unsigned num) { + if (!is_var(lhs) || !is_var(rhs)) + throw exception(sstream() << "invalid congruence theorem, type of argument #" << (num+1) << " is not an equality between variables"); + if (var_idx(lhs) >= num) + throw exception(sstream() << "invalid congruence theorem, left-hand-side of argument #" << (num+1) << " contains free variables"); + if (var_idx(rhs) >= num) + throw exception(sstream() << "invalid congruence theorem, right-hand-side of argument #" << (num+1) << " contains free variables"); +} + +static buffer::iterator find_arg_info(buffer & arg_infos, unsigned proof_arg_pos, unsigned proof_new_arg_pos) { + return std::find_if(arg_infos.begin(), arg_infos.end(), [&](app_arg_info const & info) { + return info.get_pos_at_proof() == proof_arg_pos && info.get_new_pos_at_proof() && *info.get_new_pos_at_proof() == proof_new_arg_pos; + }); +} + +static std::pair find_hypothesis(buffer & arg_infos, unsigned vidx, unsigned num) { + for (auto const & info : arg_infos) { + if (vidx == info.get_pos_at_proof()) { + return mk_pair(info.get_arg_pos(), false); + } else if (info.get_new_pos_at_proof() && vidx == *info.get_new_pos_at_proof()) { + return mk_pair(info.get_arg_pos(), true); + } + } + throw exception(sstream() << "invalid congruence theorem, invalid hypothesis for argument #" << num + << ", variable must occur in the left or right hand side of the conclusion of the theorem"); +} + +void congr_theorem_info::context::display(std::ostream & out) const { + if (!m_pos) + out << "!"; + out << "#" << m_arg; + if (m_new) + out << "'"; +} + +void congr_theorem_info::app_arg_info::display(std::ostream & out) const { + out << "#" << m_arg_pos << ": "; + if (m_context) { + m_context->display(out); + out << " -> "; + } + out << "#" << m_proof_arg_pos; + if (m_proof_new_arg_pos) + out << " #" << *m_proof_new_arg_pos << " #" << *m_proof_proof_pos; +} + +void congr_theorem_info::display(std::ostream & out) const { + out << m_fun << " " << m_num_proof_args << "\n" << m_proof << "\n"; + for (auto const & info : m_arg_info) { + info.display(out); + out << "\n"; + } + out << "\n"; +} + +congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e) { + expr t = env->infer_type(e); + expr b = t; + unsigned num = 0; + while (is_pi(b)) { + b = abst_body(b); + num++; + } + expr lhs, rhs; + if (!is_equality(b, lhs, rhs)) + throw exception("invalid congruence theorem, conclusion is not an equality"); + if (!is_app(lhs)) + throw exception("invalid congruence theorem, left-hand-side of the conclusion is not a function application"); + if (!is_app(rhs)) + throw exception("invalid congruence theorem, right-hand-side of the conclusion is not a function application"); + if (arg(lhs, 0) != arg(rhs, 0)) + throw exception("invalid congruence theorem, the functions in the left and right-hand-sides are different"); + if (num_args(lhs) != num_args(rhs)) + throw exception("invalid congruence theorem, the number of arguments in the left and right-hand-sides is different"); + + congr_theorem_info r; + r.m_fun = arg(lhs, 0); + r.m_proof = e; + r.m_num_proof_args = num; + + buffer arg_infos; + for (unsigned i = 1; i < num_args(lhs); i++) { + expr a = arg(lhs, i); + expr new_a = arg(rhs, i); + check_conclusion_lhs_rhs(a, new_a, num); + unsigned proof_arg_pos = num - var_idx(a) - 1; + unsigned proof_new_arg_pos = num - var_idx(new_a) - 1; + + if (contains_pos(arg_infos, proof_arg_pos) || + contains_pos(arg_infos, proof_new_arg_pos)) + throw exception("invalid congruence theorem, variables can only occur once in the left and right-hand sides"); + + if (proof_arg_pos == proof_new_arg_pos) { + // this argument does not need proof, then add it directly to + // r.m_arg_info + r.m_arg_info.push_back(app_arg_info(i, proof_arg_pos)); + } else { + // we have to find where the proof for this one is located + arg_infos.push_back(app_arg_info(i, proof_arg_pos, proof_new_arg_pos)); + } + } + + bool progress = true; + while (progress) { + progress = false; + expr b = t; + num = 0; + while (is_pi(b)) { + expr d = abst_domain(b); + expr lhs, rhs; + if (is_equality(d, lhs, rhs)) { + check_arg_lhs_rhs(lhs, rhs, num); + auto it = find_arg_info(arg_infos, num - var_idx(lhs) - 1, num - var_idx(rhs) - 1); + if (it == arg_infos.end()) + throw exception(sstream() << "invalid congruence theorem, argument #" << num << " does not match conclusion of the theorem"); + if (!it->m_proof_proof_pos) { + progress = true; + it->m_proof_proof_pos = num; + r.m_arg_info.push_back(*it); + } + } else if (is_pi(d) && is_equality(abst_body(d), lhs, rhs)) { + check_arg_lhs_rhs(lhs, rhs, num+1); + auto it = find_arg_info(arg_infos, num - var_idx(lhs), num - var_idx(rhs)); + if (it == arg_infos.end()) + throw exception(sstream() << "invalid congruence theorem, argument #" << num + << " does not match conclusion of the theorem"); + if (!it->m_proof_proof_pos) { + bool ctx_pos; + std::pair p; + if (is_var(abst_domain(d))) { + ctx_pos = true; + p = find_hypothesis(arg_infos, num - var_idx(abst_domain(d)) - 1, num); + } else if (is_not(abst_domain(d)) && is_var(arg(abst_domain(d), 1))) { + ctx_pos = false; + p = find_hypothesis(arg_infos, num - var_idx(arg(abst_domain(d), 1)) - 1, num); + } else { + throw exception(sstream() << "invalid congruence theorem, hypothesis for argument #" << num + << " must be a variable or the negation of variable"); + } + progress = true; + unsigned ctx_arg = p.first; + bool ctx_new = p.second; + it->m_proof_proof_pos = num; + it->m_context = congr_theorem_info::context(ctx_arg, ctx_pos, ctx_new); + r.m_arg_info.push_back(*it); + } + } + b = abst_body(b); + num++; + } + } + buffer found_args; + found_args.resize(num, false); + for (auto const & info : r.m_arg_info) { + found_args[info.get_pos_at_proof()] = true; + if (info.get_new_pos_at_proof()) + found_args[*info.get_new_pos_at_proof()] = true; + if (info.get_proof_pos_at_proof()) + found_args[*info.get_proof_pos_at_proof()] = true; + } + for (unsigned i = 0; i < num; i++) + if (!found_args[i]) + throw exception(sstream() << "invalid congruence theorem, cannot synthesize argument #" << i); + return r; +} +} diff --git a/src/library/simplifier/congr.h b/src/library/simplifier/congr.h new file mode 100644 index 000000000..6655237da --- /dev/null +++ b/src/library/simplifier/congr.h @@ -0,0 +1,143 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include +#include "kernel/environment.h" + +namespace lean { +/* + By default, Lean's simplifier will use the standard congruence theorem. + To simplify (f s), it will simplify f and s, and obtain the new terms + f' and s', and proofs H_f and H_s + H_f : f = f' + H_s : s = s' + Then, it uses the congr theorem to obtain + congr H_f H_s : f s = f' s' + + This behavior can be customize by providing specialized congruence rules + for specific operators. + + For example, kernel.lean contains the theorem: + + theorem or_congrr {a b c d : Bool} (H_ac : ∀ (H_nb : ¬ b), a = c) (H_bd : ∀ (H_nc : ¬ c), b = d) : a ∨ b ↔ c ∨ d + + It tells us that we can simplify a ∨ b, by first simplifying a under the assumption that b is false, + and then simplifying b under the assumption that the result of a simplification is false. + + We say or_congrr is a congruence theorem. This module is used to identify congruence theorems and + "compile" them into simple instructions that can be efficiently applied by the simplifier. +*/ +class congr_theorem_info { + friend congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e); +public: + /** + \brief Each argument may or may not be simplified under a new context. + For example, in or_congrr, b is simplified under a context containing not c. + + This class store this dependency. + */ + class context { + friend congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e); + /** + The position of the dependent argument. For or_congrr this field has value 0 for b, + since b depends on the new value c of a (after simplification). + */ + unsigned m_arg; + /** + Indicate whether is a positive or negative dependency. + For or_congrr, this field is false for b, since it depends negatively on c. + */ + bool m_pos; + /** + Indicate whether the dependecy is before/after simplification. + For or_congrr, this field is true for b, since it depends on the new value c of a (after simplification). + */ + bool m_new; + context(unsigned arg, bool pos, bool n):m_arg(arg), m_pos(pos), m_new(n) {} + public: + unsigned get_arg_pos() const { return m_arg; } + bool is_pos_dep() const { return m_pos; } + bool use_new_val() const { return m_new; } + void display(std::ostream & out) const; + }; + + /** + \brief This class indicates how to process an argument of the function application. + */ + class app_arg_info { + friend congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e); + /** + \brief Position of the argument to be processed. + For or_congrr, this field is 2 for b. + */ + unsigned m_arg_pos; + /** + \brief The context (if any) is used to simplify the argument + */ + optional m_context; + /** + \brief Position where this argument goes in the proof term. + For or_congrr, this field is 1 for b. + */ + unsigned m_proof_arg_pos; + /** + \brief Position where the simplified argument goes in the proof term. + If the argument should not be simplified, then this field is none. + + For or_congrr, this field is 3 for b. + */ + optional m_proof_new_arg_pos; + /** + \brief Position where the proof for new = old goes in the proof term. + + For or_congrr, this field is 5 for b. + */ + optional m_proof_proof_pos; + app_arg_info(unsigned arg_pos, unsigned proof_arg_pos):m_arg_pos(arg_pos), m_proof_arg_pos(proof_arg_pos) {} + app_arg_info(unsigned arg_pos, unsigned proof_arg_pos, unsigned proof_new_arg_pos): + m_arg_pos(arg_pos), m_proof_arg_pos(proof_arg_pos), m_proof_new_arg_pos(proof_new_arg_pos) {} + public: + unsigned get_arg_pos() const { return m_arg_pos; } + optional const & get_context() const { return m_context; } + unsigned get_pos_at_proof() const { return m_proof_arg_pos; } + optional const & get_new_pos_at_proof() const { return m_proof_new_arg_pos; } + optional const & get_proof_pos_at_proof() const { return m_proof_proof_pos; } + void display(std::ostream & out) const; + }; + +private: + /** + Indicate for which function this theorem is a congruence for. + */ + expr m_fun; + + /** + Proof term for the theorem, in most cases is just a constant (e.g., or_congrr) that references a theorem in a Lean environment. + */ + expr m_proof; + /** + Number of arguments the theorem takes. For example, or_congrr has 6 arguments. + */ + unsigned m_num_proof_args; + /** + \brief Store the sequence the application arguments should be processed. + */ + std::vector m_arg_info; +public: + expr const & get_fun() const { return m_fun; } + expr const & get_proof() const { return m_proof; } + unsigned get_num_proof_args() const { return m_num_proof_args; } + std::vector const & get_arg_info() const { return m_arg_info; } + void display(std::ostream & out) const; +}; + +/** + \brief Check whether \c e is a congruence theorem in the given environment. + If it is, then returns a congr_theorem_info object. Otherwise, throws an exception. +*/ +congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e); +} diff --git a/src/library/simplifier/rewrite_rule_set.cpp b/src/library/simplifier/rewrite_rule_set.cpp index de65867ea..b0a380a26 100644 --- a/src/library/simplifier/rewrite_rule_set.cpp +++ b/src/library/simplifier/rewrite_rule_set.cpp @@ -69,6 +69,15 @@ void rewrite_rule_set::enable(name const & id, bool f) { m_disabled_rules.insert(id); } +void rewrite_rule_set::insert_congr(expr const & e) { + ro_environment env(m_env); + m_congr_thms.emplace_front(check_congr_theorem(env, e)); +} + +void rewrite_rule_set::insert_congr(name const & th_name) { + insert_congr(mk_constant(th_name)); +} + bool rewrite_rule_set::find_match(expr const &, match_fn const & fn) const { auto l = m_rule_set; for (auto const & rule : l) { @@ -156,6 +165,22 @@ static void read_enable_rr(environment const & env, io_state const &, deserializ } static object_cell::register_deserializer_fn enable_rr_ds("enable_rr", read_enable_rr); +class add_congr_theorem_obj : public neutral_object_cell { + name m_rule_set_id; + name m_th_name; +public: + add_congr_theorem_obj(name const & rsid, name const & th_name):m_rule_set_id(rsid), m_th_name(th_name) {} + virtual ~add_congr_theorem_obj() {} + virtual char const * keyword() const { return "add_congr_theorem"; } + virtual void write(serializer & s) const { s << "add_ct" << m_rule_set_id << m_th_name; } +}; +static void read_ct(environment const & env, io_state const &, deserializer & d) { + name rsid = read_name(d); + name th = read_name(d); + add_congr_theorem(env, rsid, th); +} +static object_cell::register_deserializer_fn add_ct_ds("add_ct", read_ct); + /** \brief Extension for managing rewrite rule sets. */ @@ -218,6 +243,12 @@ struct rewrite_rule_set_extension : public environment_extension { env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag)); } + void add_congr_theorem(environment const & env, name const & rule_set_id, name const & th_name) { + auto & rs = find_rw(rule_set_id); + rs.insert_congr(th_name); + env->add_neutral_object(new add_congr_theorem_obj(rule_set_id, th_name)); + } + rewrite_rule_set get_rewrite_rule_set(name const & rule_set_id) const { return find_ro(rule_set_id); } @@ -259,6 +290,10 @@ void enable_rewrite_rules(environment const & env, name const & rule_set_id, nam to_ext(env).enable_rewrite_rules(env, rule_set_id, rule_id, flag); } +void add_congr_theorem(environment const & env, name const & rule_set_id, name const & th_name) { + to_ext(env).add_congr_theorem(env, rule_set_id, th_name); +} + rewrite_rule_set get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) { return to_ext(env).get_rewrite_rule_set(rule_set_id); } @@ -296,6 +331,17 @@ static int enable_rewrite_rules(lua_State * L) { return 0; } +static int add_congr_theorem(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs == 1) + add_congr_theorem(rw_shared_environment(L), to_name_ext(L, 1)); + else if (nargs == 2) + add_congr_theorem(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2)); + else + add_congr_theorem(rw_shared_environment(L, 3), to_name_ext(L, 1), to_name_ext(L, 2)); + return 0; +} + static int show_rewrite_rules(lua_State * L) { int nargs = lua_gettop(L); formatter fmt = get_global_formatter(L); @@ -319,6 +365,7 @@ void open_rewrite_rule_set(lua_State * L) { SET_GLOBAL_FUN(mk_rewrite_rule_set, "mk_rewrite_rule_set"); SET_GLOBAL_FUN(add_rewrite_rules, "add_rewrite_rules"); SET_GLOBAL_FUN(enable_rewrite_rules, "enable_rewrite_rules"); + SET_GLOBAL_FUN(add_congr_theorem, "add_congr_theorem"); SET_GLOBAL_FUN(show_rewrite_rules, "show_rewrite_rules"); } } diff --git a/src/library/simplifier/rewrite_rule_set.h b/src/library/simplifier/rewrite_rule_set.h index 9849f3176..02c8a9aee 100644 --- a/src/library/simplifier/rewrite_rule_set.h +++ b/src/library/simplifier/rewrite_rule_set.h @@ -14,6 +14,7 @@ Author: Leonardo de Moura #include "kernel/environment.h" #include "kernel/formatter.h" #include "library/io_state_stream.h" +#include "library/simplifier/congr.h" namespace lean { class rewrite_rule_set; @@ -46,6 +47,7 @@ class rewrite_rule_set { ro_environment::weak_ref m_env; list m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees name_set m_disabled_rules; + list m_congr_thms; // This is probably ok since we usually have very few congruence theorems bool enabled(rewrite_rule const & rule) const; public: @@ -73,6 +75,10 @@ public: /** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */ void enable(name const & id, bool f); + /** \brief Add a new congruence theorem. */ + void insert_congr(expr const & e); + void insert_congr(name const & th_name); + typedef std::function match_fn; // NOLINT typedef std::function visit_fn; @@ -121,6 +127,14 @@ inline void enable_rewrite_rules(environment const & env, name const & rule_id, enable_rewrite_rules(env, get_default_rewrite_rule_set_id(), rule_id, flag); } +/** + \brief Add a new congruence theorem to the given rewrite rule set. +*/ +void add_congr_theorem(environment const & env, name const & rule_set_id, name const & th_name); +inline void add_congr_theorem(environment const & env, name const & th_name) { + add_congr_theorem(env, get_default_rewrite_rule_set_id(), th_name); +} + /** \brief Return the rule set name \c rule_set_id in the given environment. diff --git a/src/library/simplifier/simplifier.cpp b/src/library/simplifier/simplifier.cpp index 3ad3f401f..2e0eeafa3 100644 --- a/src/library/simplifier/simplifier.cpp +++ b/src/library/simplifier/simplifier.cpp @@ -778,9 +778,14 @@ class simplifier_fn { public: simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs): - m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs) { + m_env(env), m_tc(env) { m_has_heq = m_env->imported("heq"); set_options(o); + if (m_contextual) { + // add a set of rewrite rules for contextual rewriting + m_rule_sets.push_back(rewrite_rule_set(env)); + } + m_rule_sets.insert(m_rule_sets.end(), rs, rs + num_rs); } expr_pair operator()(expr const & e, context const & ctx) {