feat(library/simplifier): congruence theorem compilation

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-21 21:16:23 -08:00
parent 029d74ec11
commit cca15f1390
6 changed files with 411 additions and 2 deletions

View file

@ -1,2 +1,2 @@
add_library(simplifier ceq.cpp simplifier.cpp rewrite_rule_set.cpp) add_library(simplifier ceq.cpp congr.cpp rewrite_rule_set.cpp simplifier.cpp)
target_link_libraries(simplifier ${LEAN_LIBS}) target_link_libraries(simplifier ${LEAN_LIBS})

View file

@ -0,0 +1,200 @@
/*
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "util/sstream.h"
#include "kernel/kernel.h"
#include "library/equality.h"
#include "library/simplifier/congr.h"
namespace lean {
typedef congr_theorem_info::app_arg_info app_arg_info;
/**
\brief Return true iff arg_info contains an entry s.t. m_proof_arg_pos or m_proof_new_arg_pos is pos.
*/
static bool contains_pos(buffer<app_arg_info> const & arg_info, unsigned pos) {
return std::any_of(arg_info.begin(), arg_info.end(),
[&](app_arg_info const & info) {
return
info.get_pos_at_proof() == pos ||
(info.get_new_pos_at_proof() && *info.get_new_pos_at_proof() == pos);
});
}
static void check_conclusion_lhs_rhs(expr const & lhs, expr const & rhs, unsigned num) {
if (!is_var(lhs) || !is_var(rhs))
throw exception("invalid congruence theorem, the arguments in the left and right-hand-sides must be variables");
if (var_idx(lhs) >= num)
throw exception("invalid congruence theorem, left-hand-side contains free variables");
if (var_idx(rhs) >= num)
throw exception("invalid congruence theorem, right-hand-side contains free variables");
}
static void check_arg_lhs_rhs(expr const & lhs, expr const & rhs, unsigned num) {
if (!is_var(lhs) || !is_var(rhs))
throw exception(sstream() << "invalid congruence theorem, type of argument #" << (num+1) << " is not an equality between variables");
if (var_idx(lhs) >= num)
throw exception(sstream() << "invalid congruence theorem, left-hand-side of argument #" << (num+1) << " contains free variables");
if (var_idx(rhs) >= num)
throw exception(sstream() << "invalid congruence theorem, right-hand-side of argument #" << (num+1) << " contains free variables");
}
static buffer<app_arg_info>::iterator find_arg_info(buffer<app_arg_info> & arg_infos, unsigned proof_arg_pos, unsigned proof_new_arg_pos) {
return std::find_if(arg_infos.begin(), arg_infos.end(), [&](app_arg_info const & info) {
return info.get_pos_at_proof() == proof_arg_pos && info.get_new_pos_at_proof() && *info.get_new_pos_at_proof() == proof_new_arg_pos;
});
}
static std::pair<unsigned, bool> find_hypothesis(buffer<app_arg_info> & arg_infos, unsigned vidx, unsigned num) {
for (auto const & info : arg_infos) {
if (vidx == info.get_pos_at_proof()) {
return mk_pair(info.get_arg_pos(), false);
} else if (info.get_new_pos_at_proof() && vidx == *info.get_new_pos_at_proof()) {
return mk_pair(info.get_arg_pos(), true);
}
}
throw exception(sstream() << "invalid congruence theorem, invalid hypothesis for argument #" << num
<< ", variable must occur in the left or right hand side of the conclusion of the theorem");
}
void congr_theorem_info::context::display(std::ostream & out) const {
if (!m_pos)
out << "!";
out << "#" << m_arg;
if (m_new)
out << "'";
}
void congr_theorem_info::app_arg_info::display(std::ostream & out) const {
out << "#" << m_arg_pos << ": ";
if (m_context) {
m_context->display(out);
out << " -> ";
}
out << "#" << m_proof_arg_pos;
if (m_proof_new_arg_pos)
out << " #" << *m_proof_new_arg_pos << " #" << *m_proof_proof_pos;
}
void congr_theorem_info::display(std::ostream & out) const {
out << m_fun << " " << m_num_proof_args << "\n" << m_proof << "\n";
for (auto const & info : m_arg_info) {
info.display(out);
out << "\n";
}
out << "\n";
}
congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e) {
expr t = env->infer_type(e);
expr b = t;
unsigned num = 0;
while (is_pi(b)) {
b = abst_body(b);
num++;
}
expr lhs, rhs;
if (!is_equality(b, lhs, rhs))
throw exception("invalid congruence theorem, conclusion is not an equality");
if (!is_app(lhs))
throw exception("invalid congruence theorem, left-hand-side of the conclusion is not a function application");
if (!is_app(rhs))
throw exception("invalid congruence theorem, right-hand-side of the conclusion is not a function application");
if (arg(lhs, 0) != arg(rhs, 0))
throw exception("invalid congruence theorem, the functions in the left and right-hand-sides are different");
if (num_args(lhs) != num_args(rhs))
throw exception("invalid congruence theorem, the number of arguments in the left and right-hand-sides is different");
congr_theorem_info r;
r.m_fun = arg(lhs, 0);
r.m_proof = e;
r.m_num_proof_args = num;
buffer<app_arg_info> arg_infos;
for (unsigned i = 1; i < num_args(lhs); i++) {
expr a = arg(lhs, i);
expr new_a = arg(rhs, i);
check_conclusion_lhs_rhs(a, new_a, num);
unsigned proof_arg_pos = num - var_idx(a) - 1;
unsigned proof_new_arg_pos = num - var_idx(new_a) - 1;
if (contains_pos(arg_infos, proof_arg_pos) ||
contains_pos(arg_infos, proof_new_arg_pos))
throw exception("invalid congruence theorem, variables can only occur once in the left and right-hand sides");
if (proof_arg_pos == proof_new_arg_pos) {
// this argument does not need proof, then add it directly to
// r.m_arg_info
r.m_arg_info.push_back(app_arg_info(i, proof_arg_pos));
} else {
// we have to find where the proof for this one is located
arg_infos.push_back(app_arg_info(i, proof_arg_pos, proof_new_arg_pos));
}
}
bool progress = true;
while (progress) {
progress = false;
expr b = t;
num = 0;
while (is_pi(b)) {
expr d = abst_domain(b);
expr lhs, rhs;
if (is_equality(d, lhs, rhs)) {
check_arg_lhs_rhs(lhs, rhs, num);
auto it = find_arg_info(arg_infos, num - var_idx(lhs) - 1, num - var_idx(rhs) - 1);
if (it == arg_infos.end())
throw exception(sstream() << "invalid congruence theorem, argument #" << num << " does not match conclusion of the theorem");
if (!it->m_proof_proof_pos) {
progress = true;
it->m_proof_proof_pos = num;
r.m_arg_info.push_back(*it);
}
} else if (is_pi(d) && is_equality(abst_body(d), lhs, rhs)) {
check_arg_lhs_rhs(lhs, rhs, num+1);
auto it = find_arg_info(arg_infos, num - var_idx(lhs), num - var_idx(rhs));
if (it == arg_infos.end())
throw exception(sstream() << "invalid congruence theorem, argument #" << num
<< " does not match conclusion of the theorem");
if (!it->m_proof_proof_pos) {
bool ctx_pos;
std::pair<unsigned, bool> p;
if (is_var(abst_domain(d))) {
ctx_pos = true;
p = find_hypothesis(arg_infos, num - var_idx(abst_domain(d)) - 1, num);
} else if (is_not(abst_domain(d)) && is_var(arg(abst_domain(d), 1))) {
ctx_pos = false;
p = find_hypothesis(arg_infos, num - var_idx(arg(abst_domain(d), 1)) - 1, num);
} else {
throw exception(sstream() << "invalid congruence theorem, hypothesis for argument #" << num
<< " must be a variable or the negation of variable");
}
progress = true;
unsigned ctx_arg = p.first;
bool ctx_new = p.second;
it->m_proof_proof_pos = num;
it->m_context = congr_theorem_info::context(ctx_arg, ctx_pos, ctx_new);
r.m_arg_info.push_back(*it);
}
}
b = abst_body(b);
num++;
}
}
buffer<bool> found_args;
found_args.resize(num, false);
for (auto const & info : r.m_arg_info) {
found_args[info.get_pos_at_proof()] = true;
if (info.get_new_pos_at_proof())
found_args[*info.get_new_pos_at_proof()] = true;
if (info.get_proof_pos_at_proof())
found_args[*info.get_proof_pos_at_proof()] = true;
}
for (unsigned i = 0; i < num; i++)
if (!found_args[i])
throw exception(sstream() << "invalid congruence theorem, cannot synthesize argument #" << i);
return r;
}
}

View file

@ -0,0 +1,143 @@
/*
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include <vector>
#include "kernel/environment.h"
namespace lean {
/*
By default, Lean's simplifier will use the standard congruence theorem.
To simplify (f s), it will simplify f and s, and obtain the new terms
f' and s', and proofs H_f and H_s
H_f : f = f'
H_s : s = s'
Then, it uses the congr theorem to obtain
congr H_f H_s : f s = f' s'
This behavior can be customize by providing specialized congruence rules
for specific operators.
For example, kernel.lean contains the theorem:
theorem or_congrr {a b c d : Bool} (H_ac : (H_nb : ¬ b), a = c) (H_bd : (H_nc : ¬ c), b = d) : a b c d
It tells us that we can simplify a b, by first simplifying a under the assumption that b is false,
and then simplifying b under the assumption that the result of a simplification is false.
We say or_congrr is a congruence theorem. This module is used to identify congruence theorems and
"compile" them into simple instructions that can be efficiently applied by the simplifier.
*/
class congr_theorem_info {
friend congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e);
public:
/**
\brief Each argument may or may not be simplified under a new context.
For example, in or_congrr, b is simplified under a context containing not c.
This class store this dependency.
*/
class context {
friend congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e);
/**
The position of the dependent argument. For or_congrr this field has value 0 for b,
since b depends on the new value c of a (after simplification).
*/
unsigned m_arg;
/**
Indicate whether is a positive or negative dependency.
For or_congrr, this field is false for b, since it depends negatively on c.
*/
bool m_pos;
/**
Indicate whether the dependecy is before/after simplification.
For or_congrr, this field is true for b, since it depends on the new value c of a (after simplification).
*/
bool m_new;
context(unsigned arg, bool pos, bool n):m_arg(arg), m_pos(pos), m_new(n) {}
public:
unsigned get_arg_pos() const { return m_arg; }
bool is_pos_dep() const { return m_pos; }
bool use_new_val() const { return m_new; }
void display(std::ostream & out) const;
};
/**
\brief This class indicates how to process an argument of the function application.
*/
class app_arg_info {
friend congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e);
/**
\brief Position of the argument to be processed.
For or_congrr, this field is 2 for b.
*/
unsigned m_arg_pos;
/**
\brief The context (if any) is used to simplify the argument
*/
optional<context> m_context;
/**
\brief Position where this argument goes in the proof term.
For or_congrr, this field is 1 for b.
*/
unsigned m_proof_arg_pos;
/**
\brief Position where the simplified argument goes in the proof term.
If the argument should not be simplified, then this field is none.
For or_congrr, this field is 3 for b.
*/
optional<unsigned> m_proof_new_arg_pos;
/**
\brief Position where the proof for new = old goes in the proof term.
For or_congrr, this field is 5 for b.
*/
optional<unsigned> m_proof_proof_pos;
app_arg_info(unsigned arg_pos, unsigned proof_arg_pos):m_arg_pos(arg_pos), m_proof_arg_pos(proof_arg_pos) {}
app_arg_info(unsigned arg_pos, unsigned proof_arg_pos, unsigned proof_new_arg_pos):
m_arg_pos(arg_pos), m_proof_arg_pos(proof_arg_pos), m_proof_new_arg_pos(proof_new_arg_pos) {}
public:
unsigned get_arg_pos() const { return m_arg_pos; }
optional<context> const & get_context() const { return m_context; }
unsigned get_pos_at_proof() const { return m_proof_arg_pos; }
optional<unsigned> const & get_new_pos_at_proof() const { return m_proof_new_arg_pos; }
optional<unsigned> const & get_proof_pos_at_proof() const { return m_proof_proof_pos; }
void display(std::ostream & out) const;
};
private:
/**
Indicate for which function this theorem is a congruence for.
*/
expr m_fun;
/**
Proof term for the theorem, in most cases is just a constant (e.g., or_congrr) that references a theorem in a Lean environment.
*/
expr m_proof;
/**
Number of arguments the theorem takes. For example, or_congrr has 6 arguments.
*/
unsigned m_num_proof_args;
/**
\brief Store the sequence the application arguments should be processed.
*/
std::vector<app_arg_info> m_arg_info;
public:
expr const & get_fun() const { return m_fun; }
expr const & get_proof() const { return m_proof; }
unsigned get_num_proof_args() const { return m_num_proof_args; }
std::vector<app_arg_info> const & get_arg_info() const { return m_arg_info; }
void display(std::ostream & out) const;
};
/**
\brief Check whether \c e is a congruence theorem in the given environment.
If it is, then returns a congr_theorem_info object. Otherwise, throws an exception.
*/
congr_theorem_info check_congr_theorem(ro_environment const & env, expr const & e);
}

View file

@ -69,6 +69,15 @@ void rewrite_rule_set::enable(name const & id, bool f) {
m_disabled_rules.insert(id); m_disabled_rules.insert(id);
} }
void rewrite_rule_set::insert_congr(expr const & e) {
ro_environment env(m_env);
m_congr_thms.emplace_front(check_congr_theorem(env, e));
}
void rewrite_rule_set::insert_congr(name const & th_name) {
insert_congr(mk_constant(th_name));
}
bool rewrite_rule_set::find_match(expr const &, match_fn const & fn) const { bool rewrite_rule_set::find_match(expr const &, match_fn const & fn) const {
auto l = m_rule_set; auto l = m_rule_set;
for (auto const & rule : l) { for (auto const & rule : l) {
@ -156,6 +165,22 @@ static void read_enable_rr(environment const & env, io_state const &, deserializ
} }
static object_cell::register_deserializer_fn enable_rr_ds("enable_rr", read_enable_rr); static object_cell::register_deserializer_fn enable_rr_ds("enable_rr", read_enable_rr);
class add_congr_theorem_obj : public neutral_object_cell {
name m_rule_set_id;
name m_th_name;
public:
add_congr_theorem_obj(name const & rsid, name const & th_name):m_rule_set_id(rsid), m_th_name(th_name) {}
virtual ~add_congr_theorem_obj() {}
virtual char const * keyword() const { return "add_congr_theorem"; }
virtual void write(serializer & s) const { s << "add_ct" << m_rule_set_id << m_th_name; }
};
static void read_ct(environment const & env, io_state const &, deserializer & d) {
name rsid = read_name(d);
name th = read_name(d);
add_congr_theorem(env, rsid, th);
}
static object_cell::register_deserializer_fn add_ct_ds("add_ct", read_ct);
/** /**
\brief Extension for managing rewrite rule sets. \brief Extension for managing rewrite rule sets.
*/ */
@ -218,6 +243,12 @@ struct rewrite_rule_set_extension : public environment_extension {
env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag)); env->add_neutral_object(new enable_rewrite_rules_obj(rule_set_id, rule_id, flag));
} }
void add_congr_theorem(environment const & env, name const & rule_set_id, name const & th_name) {
auto & rs = find_rw(rule_set_id);
rs.insert_congr(th_name);
env->add_neutral_object(new add_congr_theorem_obj(rule_set_id, th_name));
}
rewrite_rule_set get_rewrite_rule_set(name const & rule_set_id) const { rewrite_rule_set get_rewrite_rule_set(name const & rule_set_id) const {
return find_ro(rule_set_id); return find_ro(rule_set_id);
} }
@ -259,6 +290,10 @@ void enable_rewrite_rules(environment const & env, name const & rule_set_id, nam
to_ext(env).enable_rewrite_rules(env, rule_set_id, rule_id, flag); to_ext(env).enable_rewrite_rules(env, rule_set_id, rule_id, flag);
} }
void add_congr_theorem(environment const & env, name const & rule_set_id, name const & th_name) {
to_ext(env).add_congr_theorem(env, rule_set_id, th_name);
}
rewrite_rule_set get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) { rewrite_rule_set get_rewrite_rule_set(ro_environment const & env, name const & rule_set_id) {
return to_ext(env).get_rewrite_rule_set(rule_set_id); return to_ext(env).get_rewrite_rule_set(rule_set_id);
} }
@ -296,6 +331,17 @@ static int enable_rewrite_rules(lua_State * L) {
return 0; return 0;
} }
static int add_congr_theorem(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs == 1)
add_congr_theorem(rw_shared_environment(L), to_name_ext(L, 1));
else if (nargs == 2)
add_congr_theorem(rw_shared_environment(L), to_name_ext(L, 1), to_name_ext(L, 2));
else
add_congr_theorem(rw_shared_environment(L, 3), to_name_ext(L, 1), to_name_ext(L, 2));
return 0;
}
static int show_rewrite_rules(lua_State * L) { static int show_rewrite_rules(lua_State * L) {
int nargs = lua_gettop(L); int nargs = lua_gettop(L);
formatter fmt = get_global_formatter(L); formatter fmt = get_global_formatter(L);
@ -319,6 +365,7 @@ void open_rewrite_rule_set(lua_State * L) {
SET_GLOBAL_FUN(mk_rewrite_rule_set, "mk_rewrite_rule_set"); SET_GLOBAL_FUN(mk_rewrite_rule_set, "mk_rewrite_rule_set");
SET_GLOBAL_FUN(add_rewrite_rules, "add_rewrite_rules"); SET_GLOBAL_FUN(add_rewrite_rules, "add_rewrite_rules");
SET_GLOBAL_FUN(enable_rewrite_rules, "enable_rewrite_rules"); SET_GLOBAL_FUN(enable_rewrite_rules, "enable_rewrite_rules");
SET_GLOBAL_FUN(add_congr_theorem, "add_congr_theorem");
SET_GLOBAL_FUN(show_rewrite_rules, "show_rewrite_rules"); SET_GLOBAL_FUN(show_rewrite_rules, "show_rewrite_rules");
} }
} }

View file

@ -14,6 +14,7 @@ Author: Leonardo de Moura
#include "kernel/environment.h" #include "kernel/environment.h"
#include "kernel/formatter.h" #include "kernel/formatter.h"
#include "library/io_state_stream.h" #include "library/io_state_stream.h"
#include "library/simplifier/congr.h"
namespace lean { namespace lean {
class rewrite_rule_set; class rewrite_rule_set;
@ -46,6 +47,7 @@ class rewrite_rule_set {
ro_environment::weak_ref m_env; ro_environment::weak_ref m_env;
list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees list<rewrite_rule> m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees
name_set m_disabled_rules; name_set m_disabled_rules;
list<congr_theorem_info> m_congr_thms; // This is probably ok since we usually have very few congruence theorems
bool enabled(rewrite_rule const & rule) const; bool enabled(rewrite_rule const & rule) const;
public: public:
@ -73,6 +75,10 @@ public:
/** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */ /** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */
void enable(name const & id, bool f); void enable(name const & id, bool f);
/** \brief Add a new congruence theorem. */
void insert_congr(expr const & e);
void insert_congr(name const & th_name);
typedef std::function<bool(rewrite_rule const &)> match_fn; // NOLINT typedef std::function<bool(rewrite_rule const &)> match_fn; // NOLINT
typedef std::function<void(rewrite_rule const &, bool)> visit_fn; typedef std::function<void(rewrite_rule const &, bool)> visit_fn;
@ -121,6 +127,14 @@ inline void enable_rewrite_rules(environment const & env, name const & rule_id,
enable_rewrite_rules(env, get_default_rewrite_rule_set_id(), rule_id, flag); enable_rewrite_rules(env, get_default_rewrite_rule_set_id(), rule_id, flag);
} }
/**
\brief Add a new congruence theorem to the given rewrite rule set.
*/
void add_congr_theorem(environment const & env, name const & rule_set_id, name const & th_name);
inline void add_congr_theorem(environment const & env, name const & th_name) {
add_congr_theorem(env, get_default_rewrite_rule_set_id(), th_name);
}
/** /**
\brief Return the rule set name \c rule_set_id in the given environment. \brief Return the rule set name \c rule_set_id in the given environment.

View file

@ -778,9 +778,14 @@ class simplifier_fn {
public: public:
simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs): simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs):
m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs) { m_env(env), m_tc(env) {
m_has_heq = m_env->imported("heq"); m_has_heq = m_env->imported("heq");
set_options(o); set_options(o);
if (m_contextual) {
// add a set of rewrite rules for contextual rewriting
m_rule_sets.push_back(rewrite_rule_set(env));
}
m_rule_sets.insert(m_rule_sets.end(), rs, rs + num_rs);
} }
expr_pair operator()(expr const & e, context const & ctx) { expr_pair operator()(expr const & e, context const & ctx) {