refactor(library/blast/simplifier): use priority_queue to store simp/congr lemmas, use name convention used at forward/backward lemmas, normalize lemmas when blast starts, cache get_simp_lemmas

This commit is contained in:
Leonardo de Moura 2015-12-28 17:52:57 -08:00
parent 5b7dc31ad1
commit b117a10f82
21 changed files with 981 additions and 871 deletions

View file

@ -266,7 +266,7 @@ iff_empty_intro (not_not_intro star)
definition not_empty_iff [simp] : (¬ empty) ↔ unit :=
iff_unit_intro not_empty
definition not_congr [congr] (H : a ↔ b) : ¬a ↔ ¬b :=
definition not_congr (H : a ↔ b) : ¬a ↔ ¬b :=
iff.intro (λ H₁ H₂, H₁ (iff.mpr H H₂)) (λ H₁ H₂, H₁ (iff.mp H H₂))
definition ne_self_iff_empty [simp] {A : Type} (a : A) : (not (a = a)) ↔ empty :=

View file

@ -501,7 +501,7 @@ section foldr
variable lcomm : left_commutative f
include lcomm
theorem foldr_eq_of_perm [congr] : l₁ ~ l₂ → ∀ b, foldr f b l₁ = foldr f b l₂ :=
theorem foldr_eq_of_perm : l₁ ~ l₂ → ∀ b, foldr f b l₁ = foldr f b l₂ :=
assume p, perm_induction_on p
(λ b, by rewrite *foldl_nil)
(λ x t₁ t₂ p r b, calc

View file

@ -536,7 +536,7 @@ static environment accessible_cmd(parser & p) {
name const & n = d.get_name();
total++;
if ((d.is_theorem() || d.is_definition()) &&
!is_instance(env, n) && !is_simp_rule(env, n) && !is_congr_rule(env, n) &&
!is_instance(env, n) && !is_simp_lemma(env, n) && !is_congr_lemma(env, n) &&
!is_user_defined_recursor(env, n) && !is_aux_recursor(env, n) &&
!is_projection(env, n) && !is_private(env, n) &&
!is_user_defined_recursor(env, n) && !is_aux_recursor(env, n) &&
@ -799,12 +799,12 @@ static environment simplify_cmd(parser & p) {
std::tie(e, ls) = parse_local_expr(p);
blast::scope_debug scope(p.env(), p.ios());
simp_rule_sets srss;
blast::simp_lemmas srss;
if (ns == name("null")) {
} else if (ns == name("env")) {
srss = get_simp_rule_sets(p.env());
srss = blast::get_simp_lemmas();
} else {
srss = get_simp_rule_sets(p.env(), p.get_options(), ns);
srss = blast::get_simp_lemmas(ns);
}
blast::simp::result r = blast::simplify(rel, e, srss);

View file

@ -508,14 +508,15 @@ static void print_reducible_info(parser & p, reducible_status s1) {
static void print_simp_rules(parser & p) {
io_state_stream out = p.regular_stream();
simp_rule_sets s;
blast::scope_debug scope(p.env(), p.ios());
blast::simp_lemmas s;
name ns;
if (p.curr_is_identifier()) {
ns = p.get_name_val();
p.next();
s = get_simp_rule_sets(p.env(), p.get_options(), ns);
s = blast::get_simp_lemmas(ns);
} else {
s = get_simp_rule_sets(p.env());
s = blast::get_simp_lemmas();
}
format header;
if (!ns.is_anonymous())
@ -525,7 +526,8 @@ static void print_simp_rules(parser & p) {
static void print_congr_rules(parser & p) {
io_state_stream out = p.regular_stream();
simp_rule_sets s = get_simp_rule_sets(p.env());
blast::scope_debug scope(p.env(), p.ios());
blast::simp_lemmas s = blast::get_simp_lemmas();
out << s.pp_congr(out.get_formatter());
}

View file

@ -32,6 +32,7 @@ Author: Leonardo de Moura
#include "library/blast/trace.h"
#include "library/blast/options.h"
#include "library/blast/strategies/portfolio.h"
#include "library/blast/simplifier/simp_lemmas.h"
namespace lean {
namespace blast {
@ -1069,6 +1070,7 @@ struct scope_debug::imp {
scope_blastenv m_scope2;
scope_congruence_closure m_scope3;
scope_config m_scope4;
scope_simp m_scope5;
imp(environment const & env, io_state const & ios):
m_scope1(true),
m_benv(env, ios, list<name>(), list<name>()),
@ -1168,6 +1170,7 @@ optional<expr> blast_goal(environment const & env, io_state const & ios, list<na
blast::scope_congruence_closure scope3;
blast::scope_config scope4(ios.get_options());
scope_trace_env scope5(env, ios);
blast::scope_simp scope6;
return b(g);
}
void initialize_blast() {

View file

@ -9,7 +9,7 @@ Author: Leonardo de Moura
#include "kernel/abstract.h"
#include "library/trace.h"
#include "library/constants.h"
#include "library/blast/simplifier/simp_rule_set.h"
#include "library/blast/simplifier/simp_lemmas.h"
#include "library/blast/congruence_closure.h"
#include "library/blast/util.h"
#include "library/blast/blast.h"
@ -159,7 +159,7 @@ static bool all_distinct(buffer<expr> const & es) {
}
/* Try to convert user-defined congruence rule into an ext_congr_lemma */
static optional<ext_congr_lemma> to_ext_congr_lemma(name const & R, expr const & fn, unsigned nargs, congr_rule const & r) {
static optional<ext_congr_lemma> to_ext_congr_lemma(name const & R, expr const & fn, unsigned nargs, user_congr_lemma const & r) {
buffer<expr> lhs_args, rhs_args;
expr const & lhs_fn = get_app_args(r.get_lhs(), lhs_args);
expr const & rhs_fn = get_app_args(r.get_rhs(), rhs_args);
@ -280,11 +280,11 @@ static optional<ext_congr_lemma> to_ext_congr_lemma(name const & R, expr const &
}
static optional<ext_congr_lemma> mk_ext_congr_lemma_core(name const & R, expr const & fn, unsigned nargs) {
simp_rule_set const * sr = get_simp_rule_sets(env()).find(R);
simp_lemmas_for const * sr = get_simp_lemmas().find(R);
if (sr) {
list<congr_rule> const * crs = sr->find_congr(fn);
list<user_congr_lemma> const * crs = sr->find_congr(fn);
if (crs) {
for (congr_rule const & r : *crs) {
for (user_congr_lemma const & r : *crs) {
if (auto lemma = to_ext_congr_lemma(R, fn, nargs, r))
return lemma;
}

View file

@ -1,2 +1,2 @@
add_library(simplifier OBJECT init_module.cpp simp_rule_set.cpp ceqv.cpp simplifier.cpp
add_library(simplifier OBJECT init_module.cpp ceqv.cpp simplifier.cpp simp_lemmas.cpp
simplifier_actions.cpp simplifier_strategies.cpp)

View file

@ -4,21 +4,21 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#include "library/blast/simplifier/simplifier_actions.h"
#include "library/blast/simplifier/simp_rule_set.h"
#include "library/blast/simplifier/simp_lemmas.h"
#include "library/blast/simplifier/simplifier.h"
namespace lean {
namespace blast {
void initialize_simplifier_module() {
initialize_simp_lemmas();
initialize_simplifier();
initialize_simplifier_rule_set();
initialize_simplifier_actions();
}
void finalize_simplifier_module() {
finalize_simplifier_actions();
finalize_simplifier_rule_set();
finalize_simplifier();
finalize_simp_lemmas();
}
}}

View file

@ -0,0 +1,709 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <vector>
#include <string>
#include "util/priority_queue.h"
#include "util/sstream.h"
#include "kernel/error_msgs.h"
#include "kernel/find_fn.h"
#include "kernel/instantiate.h"
#include "library/trace.h"
#include "library/scoped_ext.h"
#include "library/attribute_manager.h"
#include "library/blast/blast.h"
#include "library/blast/simplifier/ceqv.h"
#include "library/blast/simplifier/simp_lemmas.h"
namespace lean {
static name * g_class_name = nullptr;
static std::string * g_key = nullptr;
struct simp_state {
priority_queue<name, name_quick_cmp> m_simp_lemmas;
priority_queue<name, name_quick_cmp> m_congr_lemmas;
};
typedef std::tuple<bool, unsigned, name> simp_entry;
struct simp_config {
typedef simp_entry entry;
typedef simp_state state;
static void add_entry(environment const &, io_state const &, state & s, entry const & e) {
bool is_simp; unsigned prio; name n;
std::tie(is_simp, prio, n) = e;
if (is_simp) {
s.m_simp_lemmas.insert(n, prio);
} else {
s.m_congr_lemmas.insert(n, prio);
}
}
static name const & get_class_name() {
return *g_class_name;
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
bool is_simp; unsigned prio; name n;
std::tie(is_simp, prio, n) = e;
s << is_simp << prio << n;
}
static entry read_entry(deserializer & d) {
bool is_simp; unsigned prio; name n;
d >> is_simp >> prio >> n;
return entry(is_simp, prio, n);
}
static optional<unsigned> get_fingerprint(entry const & e) {
bool is_simp; unsigned prio; name n;
std::tie(is_simp, prio, n) = e;
return some(hash(hash(n.hash(), prio), is_simp ? 17u : 31u));
}
};
typedef scoped_ext<simp_config> simp_ext;
void validate_simp(environment const & env, io_state const & ios, name const & n);
void validate_congr(environment const & env, io_state const & ios, name const & n);
environment add_simp_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent) {
validate_simp(env, ios, c);
return simp_ext::add_entry(env, ios, simp_entry(true, prio, c), ns, persistent);
}
environment add_congr_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent) {
validate_congr(env, ios, c);
return simp_ext::add_entry(env, ios, simp_entry(false, prio, c), ns, persistent);
}
bool is_simp_lemma(environment const & env, name const & c) {
return simp_ext::get_state(env).m_simp_lemmas.contains(c);
}
bool is_congr_lemma(environment const & env, name const & c) {
return simp_ext::get_state(env).m_congr_lemmas.contains(c);
}
void get_simp_lemmas(environment const & env, buffer<name> & r) {
return simp_ext::get_state(env).m_simp_lemmas.to_buffer(r);
}
void get_congr_lemmas(environment const & env, buffer<name> & r) {
return simp_ext::get_state(env).m_congr_lemmas.to_buffer(r);
}
static std::vector<std::vector<name>> * g_simp_lemma_ns = nullptr;
unsigned register_simp_lemmas(std::initializer_list<name> const & nss) {
unsigned r = g_simp_lemma_ns->size();
g_simp_lemma_ns->push_back(std::vector<name>(nss));
return r;
}
namespace blast {
LEAN_THREAD_VALUE(bool, g_throw_ex, false);
static void report_failure(sstream const & strm) {
if (g_throw_ex){
throw exception(strm);
} else {
lean_trace(name({"simplifier", "failure"}),
tout() << strm.str() << "\n";);
}
}
simp_lemmas add_core(tmp_type_context & tctx, simp_lemmas const & s,
name const & id, levels const & univ_metas, expr const & e, expr const & h,
unsigned priority) {
list<expr_pair> ceqvs = to_ceqvs(tctx, e, h);
if (is_nil(ceqvs)) {
report_failure(sstream() << "invalid [simp] lemma '" << id << "'");
}
environment const & env = tctx.env();
simp_lemmas new_s = s;
for (expr_pair const & p : ceqvs) {
expr rule = normalizer(tctx)(p.first);
expr proof = tctx.whnf(p.second);
bool is_perm = is_permutation_ceqv(env, rule);
buffer<expr> emetas;
buffer<bool> instances;
while (is_pi(rule)) {
expr mvar = tctx.mk_mvar(binding_domain(rule));
emetas.push_back(mvar);
instances.push_back(binding_info(rule).is_inst_implicit());
rule = tctx.whnf(instantiate(binding_body(rule), mvar));
proof = mk_app(proof, mvar);
}
expr rel, lhs, rhs;
if (is_simp_relation(env, rule, rel, lhs, rhs) && is_constant(rel)) {
new_s.insert(const_name(rel), simp_lemma(id, univ_metas, reverse_to_list(emetas),
reverse_to_list(instances), lhs, rhs, proof, is_perm, priority));
}
}
return new_s;
}
simp_lemmas add(tmp_type_context & tctx, simp_lemmas const & s, name const & id, expr const & e, expr const & h, unsigned priority) {
return add_core(tctx, s, id, list<level>(), e, h, priority);
}
simp_lemmas join(simp_lemmas const & s1, simp_lemmas const & s2) {
simp_lemmas new_s1 = s1;
s2.for_each_simp([&](name const & eqv, simp_lemma const & r) {
new_s1.insert(eqv, r);
});
return new_s1;
}
static simp_lemmas add_core(tmp_type_context & tctx, simp_lemmas const & s, name const & cname, unsigned priority) {
declaration const & d = tctx.env().get(cname);
buffer<level> us;
unsigned num_univs = d.get_num_univ_params();
for (unsigned i = 0; i < num_univs; i++) {
us.push_back(tctx.mk_uvar());
}
levels ls = to_list(us);
expr e = tctx.whnf(instantiate_type_univ_params(d, ls));
expr h = mk_constant(cname, ls);
return add_core(tctx, s, cname, ls, e, h, priority);
}
// Return true iff lhs is of the form (B (x : ?m1), ?m2) or (B (x : ?m1), ?m2 x),
// where B is lambda or Pi
static bool is_valid_congr_rule_binding_lhs(expr const & lhs, name_set & found_mvars) {
lean_assert(is_binding(lhs));
expr const & d = binding_domain(lhs);
expr const & b = binding_body(lhs);
if (!is_metavar(d))
return false;
if (is_metavar(b) && b != d) {
found_mvars.insert(mlocal_name(b));
found_mvars.insert(mlocal_name(d));
return true;
}
if (is_app(b) && is_metavar(app_fn(b)) && is_var(app_arg(b), 0) && app_fn(b) != d) {
found_mvars.insert(mlocal_name(app_fn(b)));
found_mvars.insert(mlocal_name(d));
return true;
}
return false;
}
// Return true iff all metavariables in e are in found_mvars
static bool only_found_mvars(expr const & e, name_set const & found_mvars) {
return !find(e, [&](expr const & m, unsigned) {
return is_metavar(m) && !found_mvars.contains(mlocal_name(m));
});
}
// Check whether rhs is of the form (mvar l_1 ... l_n) where mvar is a metavariable,
// and l_i's are local constants, and mvar does not occur in found_mvars.
// If it is return true and update found_mvars
static bool is_valid_congr_hyp_rhs(expr const & rhs, name_set & found_mvars) {
buffer<expr> rhs_args;
expr const & rhs_fn = get_app_args(rhs, rhs_args);
if (!is_metavar(rhs_fn) || found_mvars.contains(mlocal_name(rhs_fn)))
return false;
for (expr const & arg : rhs_args)
if (!is_local(arg))
return false;
found_mvars.insert(mlocal_name(rhs_fn));
return true;
}
simp_lemmas add_congr_core(tmp_type_context & tctx, simp_lemmas const & s, name const & n, unsigned prio) {
declaration const & d = tctx.env().get(n);
buffer<level> us;
unsigned num_univs = d.get_num_univ_params();
for (unsigned i = 0; i < num_univs; i++) {
us.push_back(tctx.mk_uvar());
}
levels ls = to_list(us);
expr rule = normalizer(tctx)(instantiate_type_univ_params(d, ls));
expr proof = mk_constant(n, ls);
buffer<expr> emetas;
buffer<bool> instances, explicits;
while (is_pi(rule)) {
expr mvar = tctx.mk_mvar(binding_domain(rule));
emetas.push_back(mvar);
explicits.push_back(is_explicit(binding_info(rule)));
instances.push_back(binding_info(rule).is_inst_implicit());
rule = tctx.whnf(instantiate(binding_body(rule), mvar));
proof = mk_app(proof, mvar);
}
expr rel, lhs, rhs;
if (!is_simp_relation(tctx.env(), rule, rel, lhs, rhs) || !is_constant(rel)) {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation");
}
name_set found_mvars;
buffer<expr> lhs_args, rhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args);
expr const & rhs_fn = get_app_args(rhs, rhs_args);
if (is_constant(lhs_fn)) {
if (!is_constant(rhs_fn) || const_name(lhs_fn) != const_name(rhs_fn) || lhs_args.size() != rhs_args.size()) {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' resulting type is not of the form (" << const_name(lhs_fn) << " ...) "
<< "~ (" << const_name(lhs_fn) << " ...), where ~ is '" << const_name(rel) << "'");
}
for (expr const & lhs_arg : lhs_args) {
if (is_sort(lhs_arg))
continue;
if (!is_metavar(lhs_arg) || found_mvars.contains(mlocal_name(lhs_arg))) {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' the left-hand-side of the congruence resulting type must be of the form ("
<< const_name(lhs_fn) << " x_1 ... x_n), where each x_i is a distinct variable or a sort");
}
found_mvars.insert(mlocal_name(lhs_arg));
}
} else if (is_binding(lhs)) {
if (lhs.kind() != rhs.kind()) {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' kinds of the left-hand-side and right-hand-side of "
<< "the congruence resulting type do not match");
}
if (!is_valid_congr_rule_binding_lhs(lhs, found_mvars)) {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' left-hand-side of the congruence resulting type must "
<< "be of the form (fun/Pi (x : A), B x)");
}
} else {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' left-hand-side is not an application nor a binding");
}
buffer<expr> congr_hyps;
lean_assert(emetas.size() == explicits.size());
for (unsigned i = 0; i < emetas.size(); i++) {
expr const & mvar = emetas[i];
if (explicits[i] && !found_mvars.contains(mlocal_name(mvar))) {
buffer<expr> locals;
expr type = mlocal_type(mvar);
while (is_pi(type)) {
expr local = tctx.mk_tmp_local(binding_domain(type));
locals.push_back(local);
type = instantiate(binding_body(type), local);
}
expr h_rel, h_lhs, h_rhs;
if (!is_simp_relation(tctx.env(), type, h_rel, h_lhs, h_rhs) || !is_constant(h_rel))
continue;
unsigned j = 0;
for (expr const & local : locals) {
j++;
if (!only_found_mvars(mlocal_type(local), found_mvars)) {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' argument #" << j << " of parameter #" << (i+1) << " contains "
<< "unresolved parameters");
}
}
if (!only_found_mvars(h_lhs, found_mvars)) {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' argument #" << (i+1) << " is not a valid hypothesis, the left-hand-side contains "
<< "unresolved parameters");
}
if (!is_valid_congr_hyp_rhs(h_rhs, found_mvars)) {
report_failure(sstream() << "invalid [congr] lemma, '" << n
<< "' argument #" << (i+1) << " is not a valid hypothesis, the right-hand-side must be "
<< "of the form (m l_1 ... l_n) where m is parameter that was not "
<< "'assigned/resolved' yet and l_i's are locals");
}
found_mvars.insert(mlocal_name(mvar));
congr_hyps.push_back(mvar);
}
}
simp_lemmas new_s = s;
new_s.insert(const_name(rel), user_congr_lemma(n, ls, reverse_to_list(emetas),
reverse_to_list(instances), lhs, rhs, proof, to_list(congr_hyps), prio));
return new_s;
}
simp_lemma_core::simp_lemma_core(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
unsigned priority):
m_id(id), m_umetas(umetas), m_emetas(emetas), m_instances(instances),
m_lhs(lhs), m_rhs(rhs), m_proof(proof), m_priority(priority) {}
simp_lemma::simp_lemma(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
bool is_perm, unsigned priority):
simp_lemma_core(id, umetas, emetas, instances, lhs, rhs, proof, priority),
m_is_permutation(is_perm) {}
bool operator==(simp_lemma const & r1, simp_lemma const & r2) {
return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs;
}
format simp_lemma::pp(formatter const & fmt) const {
format r;
r += format("#") + format(get_num_emeta());
if (m_priority != LEAN_DEFAULT_PRIORITY)
r += space() + paren(format(m_priority));
if (m_is_permutation)
r += space() + format("perm");
format r1 = comma() + space() + fmt(get_lhs());
r1 += space() + format("") + pp_indent_expr(fmt, get_rhs());
r += group(r1);
return r;
}
user_congr_lemma::user_congr_lemma(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
list<expr> const & congr_hyps, unsigned priority):
simp_lemma_core(id, umetas, emetas, instances, lhs, rhs, proof, priority),
m_congr_hyps(congr_hyps) {}
bool operator==(user_congr_lemma const & r1, user_congr_lemma const & r2) {
return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs && r1.m_congr_hyps == r2.m_congr_hyps;
}
format user_congr_lemma::pp(formatter const & fmt) const {
format r;
r += format("#") + format(get_num_emeta());
if (m_priority != LEAN_DEFAULT_PRIORITY)
r += space() + paren(format(m_priority));
format r1;
for (expr const & h : m_congr_hyps) {
r1 += space() + paren(fmt(mlocal_type(h)));
}
r += group(r1);
r += space() + format(":") + space();
format r2 = paren(fmt(get_lhs()));
r2 += space() + format("") + space() + paren(fmt(get_rhs()));
r += group(r2);
return r;
}
simp_lemmas_for::simp_lemmas_for(name const & eqv):
m_eqv(eqv) {}
void simp_lemmas_for::insert(simp_lemma const & r) {
m_simp_set.insert(r.get_lhs(), r);
}
void simp_lemmas_for::erase(simp_lemma const & r) {
m_simp_set.erase(r.get_lhs(), r);
}
void simp_lemmas_for::insert(user_congr_lemma const & r) {
m_congr_set.insert(r.get_lhs(), r);
}
void simp_lemmas_for::erase(user_congr_lemma const & r) {
m_congr_set.erase(r.get_lhs(), r);
}
list<simp_lemma> const * simp_lemmas_for::find_simp(head_index const & h) const {
return m_simp_set.find(h);
}
void simp_lemmas_for::for_each_simp(std::function<void(simp_lemma const &)> const & fn) const {
m_simp_set.for_each_entry([&](head_index const &, simp_lemma const & r) { fn(r); });
}
list<user_congr_lemma> const * simp_lemmas_for::find_congr(head_index const & h) const {
return m_congr_set.find(h);
}
void simp_lemmas_for::for_each_congr(std::function<void(user_congr_lemma const &)> const & fn) const {
m_congr_set.for_each_entry([&](head_index const &, user_congr_lemma const & r) { fn(r); });
}
void simp_lemmas_for::erase_simp(name_set const & ids) {
// This method is not very smart and doesn't use any indexing or caching.
// So, it may be a bottleneck in the future
buffer<simp_lemma> to_delete;
for_each_simp([&](simp_lemma const & r) {
if (ids.contains(r.get_id())) {
to_delete.push_back(r);
}
});
for (simp_lemma const & r : to_delete) {
erase(r);
}
}
void simp_lemmas_for::erase_simp(buffer<name> const & ids) {
erase_simp(to_name_set(ids));
}
template<typename R>
void simp_lemmas::insert_core(name const & eqv, R const & r) {
simp_lemmas_for s(eqv);
if (auto const * curr = m_sets.find(eqv)) {
s = *curr;
}
s.insert(r);
m_sets.insert(eqv, s);
}
template<typename R>
void simp_lemmas::erase_core(name const & eqv, R const & r) {
if (auto const * curr = m_sets.find(eqv)) {
simp_lemmas_for s = *curr;
s.erase(r);
if (s.empty())
m_sets.erase(eqv);
else
m_sets.insert(eqv, s);
}
}
void simp_lemmas::insert(name const & eqv, simp_lemma const & r) {
return insert_core(eqv, r);
}
void simp_lemmas::erase(name const & eqv, simp_lemma const & r) {
return erase_core(eqv, r);
}
void simp_lemmas::insert(name const & eqv, user_congr_lemma const & r) {
return insert_core(eqv, r);
}
void simp_lemmas::erase(name const & eqv, user_congr_lemma const & r) {
return erase_core(eqv, r);
}
void simp_lemmas::get_relations(buffer<name> & rs) const {
m_sets.for_each([&](name const & r, simp_lemmas_for const &) {
rs.push_back(r);
});
}
void simp_lemmas::erase_simp(name_set const & ids) {
name_map<simp_lemmas_for> new_sets;
m_sets.for_each([&](name const & n, simp_lemmas_for const & s) {
simp_lemmas_for new_s = s;
new_s.erase_simp(ids);
new_sets.insert(n, new_s);
});
m_sets = new_sets;
}
void simp_lemmas::erase_simp(buffer<name> const & ids) {
erase_simp(to_name_set(ids));
}
simp_lemmas_for const * simp_lemmas::find(name const & eqv) const {
return m_sets.find(eqv);
}
list<simp_lemma> const * simp_lemmas::find_simp(name const & eqv, head_index const & h) const {
if (auto const * s = m_sets.find(eqv))
return s->find_simp(h);
return nullptr;
}
list<user_congr_lemma> const * simp_lemmas::find_congr(name const & eqv, head_index const & h) const {
if (auto const * s = m_sets.find(eqv))
return s->find_congr(h);
return nullptr;
}
void simp_lemmas::for_each_simp(std::function<void(name const &, simp_lemma const &)> const & fn) const {
m_sets.for_each([&](name const & eqv, simp_lemmas_for const & s) {
s.for_each_simp([&](simp_lemma const & r) {
fn(eqv, r);
});
});
}
void simp_lemmas::for_each_congr(std::function<void(name const &, user_congr_lemma const &)> const & fn) const {
m_sets.for_each([&](name const & eqv, simp_lemmas_for const & s) {
s.for_each_congr([&](user_congr_lemma const & r) {
fn(eqv, r);
});
});
}
format simp_lemmas::pp(formatter const & fmt, format const & header, bool simp, bool congr) const {
format r;
if (simp) {
name prev_eqv;
for_each_simp([&](name const & eqv, simp_lemma const & rw) {
if (prev_eqv != eqv) {
r += format("simplification rules for ") + format(eqv);
r += header;
r += line();
prev_eqv = eqv;
}
r += rw.pp(fmt) + line();
});
}
if (congr) {
name prev_eqv;
for_each_congr([&](name const & eqv, user_congr_lemma const & cr) {
if (prev_eqv != eqv) {
r += format("congruencec rules for ") + format(eqv) + line();
prev_eqv = eqv;
}
r += cr.pp(fmt) + line();
});
}
return r;
}
format simp_lemmas::pp_simp(formatter const & fmt, format const & header) const {
return pp(fmt, header, true, false);
}
format simp_lemmas::pp_simp(formatter const & fmt) const {
return pp(fmt, format(), true, false);
}
format simp_lemmas::pp_congr(formatter const & fmt) const {
return pp(fmt, format(), false, true);
}
format simp_lemmas::pp(formatter const & fmt) const {
return pp(fmt, format(), true, true);
}
struct simp_lemmas_cache {
simp_lemmas m_main_cache;
std::vector<optional<simp_lemmas>> m_key_cache;
};
LEAN_THREAD_PTR(simp_lemmas_cache, g_simp_lemmas_cache);
scope_simp::scope_simp() {
m_old_cache = g_simp_lemmas_cache;
g_simp_lemmas_cache = nullptr;
}
scope_simp::~scope_simp() {
delete g_simp_lemmas_cache;
g_simp_lemmas_cache = m_old_cache;
}
simp_lemmas get_simp_lemmas_core() {
simp_lemmas r;
buffer<name> simp_lemmas, congr_lemmas;
blast_tmp_type_context ctx;
auto const & s = simp_ext::get_state(env());
s.m_simp_lemmas.to_buffer(simp_lemmas);
s.m_congr_lemmas.to_buffer(congr_lemmas);
unsigned i = simp_lemmas.size();
while (i > 0) {
--i;
ctx->clear();
r = add_core(*ctx, r, simp_lemmas[i], *s.m_simp_lemmas.get_prio(simp_lemmas[i]));
}
i = congr_lemmas.size();
while (i > 0) {
--i;
ctx->clear();
r = add_congr_core(*ctx, r, congr_lemmas[i], *s.m_congr_lemmas.get_prio(congr_lemmas[i]));
}
return r;
}
static simp_lemmas_cache & get_cache() {
if (!g_simp_lemmas_cache) {
g_simp_lemmas_cache = new simp_lemmas_cache();
g_simp_lemmas_cache->m_main_cache = get_simp_lemmas_core();
}
return *g_simp_lemmas_cache;
}
simp_lemmas get_simp_lemmas() {
return get_cache().m_main_cache;
}
template<typename NSS>
simp_lemmas get_simp_lemmas_core(NSS const & nss) {
simp_lemmas r;
blast_tmp_type_context ctx;
for (name const & ns : nss) {
list<simp_entry> const * entries = simp_ext::get_entries(env(), ns);
if (entries) {
for (auto const & e : *entries) {
bool is_simp; unsigned prio; name n;
std::tie(is_simp, prio, n) = e;
if (is_simp) {
ctx->clear();
r = add_core(*ctx, r, n, prio);
}
}
}
}
return r;
}
simp_lemmas get_simp_lemmas(std::initializer_list<name> const & nss) {
return get_simp_lemmas_core(nss);
}
simp_lemmas get_simp_lemmas(name const & ns) {
return get_simp_lemmas({ns});
}
simp_lemmas get_simp_lemmas(unsigned key) {
simp_lemmas_cache & cache = get_cache();
if (key >= g_simp_lemma_ns->size())
throw exception("invalid simp lemma cache key");
if (key >= cache.m_key_cache.size())
cache.m_key_cache.resize(key+1);
if (!cache.m_key_cache[key])
cache.m_key_cache[key] = get_simp_lemmas_core((*g_simp_lemma_ns)[key]);
return *cache.m_key_cache[key];
}
}
void validate_simp(environment const & env, io_state const & ios, name const & n) {
blast::simp_lemmas s;
tmp_type_context ctx(env, ios.get_options());
flet<bool> set_ex(blast::g_throw_ex, true);
blast::add_core(ctx, s, n, LEAN_DEFAULT_PRIORITY);
}
void validate_congr(environment const & env, io_state const & ios, name const & n) {
blast::simp_lemmas s;
tmp_type_context ctx(env, ios.get_options());
flet<bool> set_ex(blast::g_throw_ex, true);
blast::add_congr_core(ctx, s, n, LEAN_DEFAULT_PRIORITY);
}
void initialize_simp_lemmas() {
g_class_name = new name("simp");
g_key = new std::string("SIMP");
g_simp_lemma_ns = new std::vector<std::vector<name>>();
simp_ext::initialize();
register_prio_attribute("simp", "simplification lemma",
add_simp_lemma,
is_simp_lemma,
[](environment const & env, name const & d) {
if (auto p = simp_ext::get_state(env).m_simp_lemmas.get_prio(d))
return *p;
else
return LEAN_DEFAULT_PRIORITY;
});
register_prio_attribute("congr", "congruence lemma",
add_congr_lemma,
is_congr_lemma,
[](environment const & env, name const & d) {
if (auto p = simp_ext::get_state(env).m_congr_lemmas.get_prio(d))
return *p;
else
return LEAN_DEFAULT_PRIORITY;
});
blast::g_simp_lemmas_cache = nullptr;
}
void finalize_simp_lemmas() {
simp_ext::finalize();
delete g_key;
delete g_class_name;
delete g_simp_lemma_ns;
}
}

View file

@ -0,0 +1,171 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "kernel/environment.h"
#include "library/io_state.h"
#include "library/tmp_type_context.h"
#include "library/head_map.h"
#include "library/blast/gexpr.h"
namespace lean {
environment add_simp_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent);
environment add_congr_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent);
bool is_simp_lemma(environment const & env, name const & n);
bool is_congr_lemma(environment const & env, name const & n);
void get_simp_lemmas(environment const & env, buffer<name> & r);
void get_congr_lemmas(environment const & env, buffer<name> & r);
void initialize_simp_lemmas();
void finalize_simp_lemmas();
/** Generate a unique id for a set of namespaces containing [simp] and [congr] lemmas */
unsigned register_simp_lemmas(std::initializer_list<name> const & nss);
namespace blast {
class simp_lemmas;
class simp_lemma_core {
protected:
name m_id;
levels m_umetas;
list<expr> m_emetas;
list<bool> m_instances;
expr m_lhs;
expr m_rhs;
expr m_proof;
unsigned m_priority;
simp_lemma_core(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
unsigned priority);
public:
name const & get_id() const { return m_id; }
unsigned get_num_umeta() const { return length(m_umetas); }
unsigned get_num_emeta() const { return length(m_emetas); }
/** \brief Return a list containing the expression metavariables in reverse order. */
list<expr> const & get_emetas() const { return m_emetas; }
/** \brief Return a list of bools indicating whether or not each expression metavariable
in <tt>get_emetas()</tt> is an instance. */
list<bool> const & get_instances() const { return m_instances; }
unsigned get_priority() const { return m_priority; }
expr const & get_lhs() const { return m_lhs; }
expr const & get_rhs() const { return m_rhs; }
expr const & get_proof() const { return m_proof; }
};
class simp_lemma : public simp_lemma_core {
bool m_is_permutation;
simp_lemma(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
bool is_perm, unsigned priority);
friend simp_lemmas add_core(tmp_type_context & tctx, simp_lemmas const & s, name const & id,
levels const & univ_metas, expr const & e, expr const & h, unsigned priority);
public:
friend bool operator==(simp_lemma const & r1, simp_lemma const & r2);
bool is_perm() const { return m_is_permutation; }
format pp(formatter const & fmt) const;
};
bool operator==(simp_lemma const & r1, simp_lemma const & r2);
inline bool operator!=(simp_lemma const & r1, simp_lemma const & r2) { return !operator==(r1, r2); }
// We use user_congr_lemma to avoid a confusion with ::lemma::congr_lemma
class user_congr_lemma : public simp_lemma_core {
list<expr> m_congr_hyps;
user_congr_lemma(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
list<expr> const & congr_hyps, unsigned priority);
friend simp_lemmas add_congr_core(tmp_type_context & tctx, simp_lemmas const & s, name const & n, unsigned priority);
public:
friend bool operator==(user_congr_lemma const & r1, user_congr_lemma const & r2);
list<expr> const & get_congr_hyps() const { return m_congr_hyps; }
format pp(formatter const & fmt) const;
};
struct simp_lemma_core_prio_fn { unsigned operator()(simp_lemma_core const & s) const { return s.get_priority(); } };
bool operator==(user_congr_lemma const & r1, user_congr_lemma const & r2);
inline bool operator!=(user_congr_lemma const & r1, user_congr_lemma const & r2) { return !operator==(r1, r2); }
/** \brief Simplification and congruence lemmas for a given equivalence relation */
class simp_lemmas_for {
typedef head_map_prio<simp_lemma, simp_lemma_core_prio_fn> simp_set;
typedef head_map_prio<user_congr_lemma, simp_lemma_core_prio_fn> congr_set;
name m_eqv;
simp_set m_simp_set;
congr_set m_congr_set;
public:
simp_lemmas_for() {}
/** \brief Return the equivalence relation associated with this set */
simp_lemmas_for(name const & eqv);
bool empty() const { return m_simp_set.empty() && m_congr_set.empty(); }
name const & get_eqv() const { return m_eqv; }
void insert(simp_lemma const & r);
void erase(simp_lemma const & r);
void insert(user_congr_lemma const & r);
void erase(user_congr_lemma const & r);
void erase_simp(name_set const & ids);
void erase_simp(buffer<name> const & ids);
list<simp_lemma> const * find_simp(head_index const & h) const;
void for_each_simp(std::function<void(simp_lemma const &)> const & fn) const;
list<user_congr_lemma> const * find_congr(head_index const & h) const;
void for_each_congr(std::function<void(user_congr_lemma const &)> const & fn) const;
};
class simp_lemmas {
name_map<simp_lemmas_for> m_sets; // mapping from relation name to simp_lemmas_for
template<typename R> void insert_core(name const & eqv, R const & r);
template<typename R> void erase_core(name const & eqv, R const & r);
public:
void insert(name const & eqv, simp_lemma const & r);
void erase(name const & eqv, simp_lemma const & r);
void insert(name const & eqv, user_congr_lemma const & r);
void erase(name const & eqv, user_congr_lemma const & r);
void erase_simp(name_set const & ids);
void erase_simp(buffer<name> const & ids);
void get_relations(buffer<name> & rs) const;
simp_lemmas_for const * find(name const & eqv) const;
list<simp_lemma> const * find_simp(name const & eqv, head_index const & h) const;
list<user_congr_lemma> const * find_congr(name const & eqv, head_index const & h) const;
void for_each_simp(std::function<void(name const &, simp_lemma const &)> const & fn) const;
void for_each_congr(std::function<void(name const &, user_congr_lemma const &)> const & fn) const;
format pp(formatter const & fmt, format const & header, bool simp, bool congr) const;
format pp_simp(formatter const & fmt, format const & header) const;
format pp_simp(formatter const & fmt) const;
format pp_congr(formatter const & fmt) const;
format pp(formatter const & fmt) const;
};
struct simp_lemmas_cache;
/** \brief Auxiliary class for initializing simp lemmas during blast initialization. */
class scope_simp {
simp_lemmas_cache * m_old_cache;
public:
scope_simp();
~scope_simp();
};
simp_lemmas add(tmp_type_context & tctx, simp_lemmas const & s, name const & id, expr const & e, expr const & h, unsigned priority);
simp_lemmas join(simp_lemmas const & s1, simp_lemmas const & s2);
/** \brief Get (active) simplification lemmas. */
simp_lemmas get_simp_lemmas();
/** \brief Get simplification lemmas in the given namespace. */
simp_lemmas get_simp_lemmas(name const & ns);
/** \brief Get simplification lemmas in the given namespaces. */
// simp_lemmas get_simp_lemmas(std::initializer_list<name> const & nss);
/** \brief Get simplification lemmas in the namespaces registered at key.
The key is created using procedure #register_simp_lemmas at initialization time.
This is more efficient than get_simp_lemmas(std::initializer_list<name> const & nss), because
results are cached. */
simp_lemmas get_simp_lemmas(unsigned key);
}
}

View file

@ -1,612 +0,0 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <string>
#include <vector>
#include "util/sstream.h"
#include "kernel/instantiate.h"
#include "kernel/find_fn.h"
#include "kernel/error_msgs.h"
#include "library/scoped_ext.h"
#include "library/expr_pair.h"
#include "library/attribute_manager.h"
#include "library/relation_manager.h"
#include "library/blast/simplifier/ceqv.h"
#include "library/blast/simplifier/simp_rule_set.h"
namespace lean {
simp_rule_core::simp_rule_core(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
unsigned priority):
m_id(id), m_umetas(umetas), m_emetas(emetas), m_instances(instances),
m_lhs(lhs), m_rhs(rhs), m_proof(proof), m_priority(priority) {}
simp_rule::simp_rule(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
bool is_perm, unsigned priority):
simp_rule_core(id, umetas, emetas, instances, lhs, rhs, proof, priority),
m_is_permutation(is_perm) {}
bool operator==(simp_rule const & r1, simp_rule const & r2) {
return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs;
}
format simp_rule::pp(formatter const & fmt) const {
format r;
r += format("#") + format(get_num_emeta());
if (m_priority != LEAN_DEFAULT_PRIORITY)
r += space() + paren(format(m_priority));
if (m_is_permutation)
r += space() + format("perm");
format r1 = comma() + space() + fmt(get_lhs());
r1 += space() + format("") + pp_indent_expr(fmt, get_rhs());
r += group(r1);
return r;
}
congr_rule::congr_rule(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
list<expr> const & congr_hyps, unsigned priority):
simp_rule_core(id, umetas, emetas, instances, lhs, rhs, proof, priority),
m_congr_hyps(congr_hyps) {}
bool operator==(congr_rule const & r1, congr_rule const & r2) {
return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs && r1.m_congr_hyps == r2.m_congr_hyps;
}
format congr_rule::pp(formatter const & fmt) const {
format r;
r += format("#") + format(get_num_emeta());
if (m_priority != LEAN_DEFAULT_PRIORITY)
r += space() + paren(format(m_priority));
format r1;
for (expr const & h : m_congr_hyps) {
r1 += space() + paren(fmt(mlocal_type(h)));
}
r += group(r1);
r += space() + format(":") + space();
format r2 = paren(fmt(get_lhs()));
r2 += space() + format("") + space() + paren(fmt(get_rhs()));
r += group(r2);
return r;
}
simp_rule_set::simp_rule_set(name const & eqv):
m_eqv(eqv) {}
void simp_rule_set::insert(simp_rule const & r) {
m_simp_set.insert(r.get_lhs(), r);
}
void simp_rule_set::erase(simp_rule const & r) {
m_simp_set.erase(r.get_lhs(), r);
}
void simp_rule_set::insert(congr_rule const & r) {
m_congr_set.insert(r.get_lhs(), r);
}
void simp_rule_set::erase(congr_rule const & r) {
m_congr_set.erase(r.get_lhs(), r);
}
list<simp_rule> const * simp_rule_set::find_simp(head_index const & h) const {
return m_simp_set.find(h);
}
void simp_rule_set::for_each_simp(std::function<void(simp_rule const &)> const & fn) const {
m_simp_set.for_each_entry([&](head_index const &, simp_rule const & r) { fn(r); });
}
list<congr_rule> const * simp_rule_set::find_congr(head_index const & h) const {
return m_congr_set.find(h);
}
void simp_rule_set::for_each_congr(std::function<void(congr_rule const &)> const & fn) const {
m_congr_set.for_each_entry([&](head_index const &, congr_rule const & r) { fn(r); });
}
void simp_rule_set::erase_simp(name_set const & ids) {
// This method is not very smart and doesn't use any indexing or caching.
// So, it may be a bottleneck in the future
buffer<simp_rule> to_delete;
for_each_simp([&](simp_rule const & r) {
if (ids.contains(r.get_id())) {
to_delete.push_back(r);
}
});
for (simp_rule const & r : to_delete) {
erase(r);
}
}
void simp_rule_set::erase_simp(buffer<name> const & ids) {
erase_simp(to_name_set(ids));
}
template<typename R>
void simp_rule_sets::insert_core(name const & eqv, R const & r) {
simp_rule_set s(eqv);
if (auto const * curr = m_sets.find(eqv)) {
s = *curr;
}
s.insert(r);
m_sets.insert(eqv, s);
}
template<typename R>
void simp_rule_sets::erase_core(name const & eqv, R const & r) {
if (auto const * curr = m_sets.find(eqv)) {
simp_rule_set s = *curr;
s.erase(r);
if (s.empty())
m_sets.erase(eqv);
else
m_sets.insert(eqv, s);
}
}
void simp_rule_sets::insert(name const & eqv, simp_rule const & r) {
return insert_core(eqv, r);
}
void simp_rule_sets::erase(name const & eqv, simp_rule const & r) {
return erase_core(eqv, r);
}
void simp_rule_sets::insert(name const & eqv, congr_rule const & r) {
return insert_core(eqv, r);
}
void simp_rule_sets::erase(name const & eqv, congr_rule const & r) {
return erase_core(eqv, r);
}
void simp_rule_sets::get_relations(buffer<name> & rs) const {
m_sets.for_each([&](name const & r, simp_rule_set const &) {
rs.push_back(r);
});
}
void simp_rule_sets::erase_simp(name_set const & ids) {
name_map<simp_rule_set> new_sets;
m_sets.for_each([&](name const & n, simp_rule_set const & s) {
simp_rule_set new_s = s;
new_s.erase_simp(ids);
new_sets.insert(n, new_s);
});
m_sets = new_sets;
}
void simp_rule_sets::erase_simp(buffer<name> const & ids) {
erase_simp(to_name_set(ids));
}
simp_rule_set const * simp_rule_sets::find(name const & eqv) const {
return m_sets.find(eqv);
}
list<simp_rule> const * simp_rule_sets::find_simp(name const & eqv, head_index const & h) const {
if (auto const * s = m_sets.find(eqv))
return s->find_simp(h);
return nullptr;
}
list<congr_rule> const * simp_rule_sets::find_congr(name const & eqv, head_index const & h) const {
if (auto const * s = m_sets.find(eqv))
return s->find_congr(h);
return nullptr;
}
void simp_rule_sets::for_each_simp(std::function<void(name const &, simp_rule const &)> const & fn) const {
m_sets.for_each([&](name const & eqv, simp_rule_set const & s) {
s.for_each_simp([&](simp_rule const & r) {
fn(eqv, r);
});
});
}
void simp_rule_sets::for_each_congr(std::function<void(name const &, congr_rule const &)> const & fn) const {
m_sets.for_each([&](name const & eqv, simp_rule_set const & s) {
s.for_each_congr([&](congr_rule const & r) {
fn(eqv, r);
});
});
}
format simp_rule_sets::pp(formatter const & fmt, format const & header, bool simp, bool congr) const {
format r;
if (simp) {
name prev_eqv;
for_each_simp([&](name const & eqv, simp_rule const & rw) {
if (prev_eqv != eqv) {
r += format("simplification rules for ") + format(eqv);
r += header;
r += line();
prev_eqv = eqv;
}
r += rw.pp(fmt) + line();
});
}
if (congr) {
name prev_eqv;
for_each_congr([&](name const & eqv, congr_rule const & cr) {
if (prev_eqv != eqv) {
r += format("congruencec rules for ") + format(eqv) + line();
prev_eqv = eqv;
}
r += cr.pp(fmt) + line();
});
}
return r;
}
format simp_rule_sets::pp_simp(formatter const & fmt, format const & header) const {
return pp(fmt, header, true, false);
}
format simp_rule_sets::pp_simp(formatter const & fmt) const {
return pp(fmt, format(), true, false);
}
format simp_rule_sets::pp_congr(formatter const & fmt) const {
return pp(fmt, format(), false, true);
}
format simp_rule_sets::pp(formatter const & fmt) const {
return pp(fmt, format(), true, true);
}
static name * g_prefix = nullptr;
simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s,
name const & id, levels const & univ_metas, expr const & e, expr const & h,
unsigned priority) {
list<expr_pair> ceqvs = to_ceqvs(tctx, e, h);
if (is_nil(ceqvs)) throw exception("[simp] rule invalid");
environment const & env = tctx.env();
simp_rule_sets new_s = s;
for (expr_pair const & p : ceqvs) {
expr rule = tctx.whnf(p.first);
expr proof = tctx.whnf(p.second);
bool is_perm = is_permutation_ceqv(env, rule);
buffer<expr> emetas;
buffer<bool> instances;
while (is_pi(rule)) {
expr mvar = tctx.mk_mvar(binding_domain(rule));
emetas.push_back(mvar);
instances.push_back(binding_info(rule).is_inst_implicit());
rule = tctx.whnf(instantiate(binding_body(rule), mvar));
proof = mk_app(proof, mvar);
}
expr rel, lhs, rhs;
if (is_simp_relation(env, rule, rel, lhs, rhs) && is_constant(rel)) {
new_s.insert(const_name(rel), simp_rule(id, univ_metas, reverse_to_list(emetas),
reverse_to_list(instances), lhs, rhs, proof, is_perm, priority));
}
}
return new_s;
}
simp_rule_sets add(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, expr const & e, expr const & h, unsigned priority) {
return add_core(tctx, s, id, list<level>(), e, h, priority);
}
simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2) {
simp_rule_sets new_s1 = s1;
s2.for_each_simp([&](name const & eqv, simp_rule const & r) {
new_s1.insert(eqv, r);
});
return new_s1;
}
static name * g_class_name = nullptr;
static std::string * g_key = nullptr;
static simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, name const & cname, unsigned priority) {
declaration const & d = tctx.env().get(cname);
buffer<level> us;
unsigned num_univs = d.get_num_univ_params();
for (unsigned i = 0; i < num_univs; i++) {
us.push_back(tctx.mk_uvar());
}
levels ls = to_list(us);
expr e = tctx.whnf(instantiate_type_univ_params(d, ls));
expr h = mk_constant(cname, ls);
return add_core(tctx, s, cname, ls, e, h, priority);
}
// Return true iff lhs is of the form (B (x : ?m1), ?m2) or (B (x : ?m1), ?m2 x),
// where B is lambda or Pi
static bool is_valid_congr_rule_binding_lhs(expr const & lhs, name_set & found_mvars) {
lean_assert(is_binding(lhs));
expr const & d = binding_domain(lhs);
expr const & b = binding_body(lhs);
if (!is_metavar(d))
return false;
if (is_metavar(b) && b != d) {
found_mvars.insert(mlocal_name(b));
found_mvars.insert(mlocal_name(d));
return true;
}
if (is_app(b) && is_metavar(app_fn(b)) && is_var(app_arg(b), 0) && app_fn(b) != d) {
found_mvars.insert(mlocal_name(app_fn(b)));
found_mvars.insert(mlocal_name(d));
return true;
}
return false;
}
// Return true iff all metavariables in e are in found_mvars
static bool only_found_mvars(expr const & e, name_set const & found_mvars) {
return !find(e, [&](expr const & m, unsigned) {
return is_metavar(m) && !found_mvars.contains(mlocal_name(m));
});
}
// Check whether rhs is of the form (mvar l_1 ... l_n) where mvar is a metavariable,
// and l_i's are local constants, and mvar does not occur in found_mvars.
// If it is return true and update found_mvars
static bool is_valid_congr_hyp_rhs(expr const & rhs, name_set & found_mvars) {
buffer<expr> rhs_args;
expr const & rhs_fn = get_app_args(rhs, rhs_args);
if (!is_metavar(rhs_fn) || found_mvars.contains(mlocal_name(rhs_fn)))
return false;
for (expr const & arg : rhs_args)
if (!is_local(arg))
return false;
found_mvars.insert(mlocal_name(rhs_fn));
return true;
}
void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n, unsigned prio) {
declaration const & d = tctx.env().get(n);
buffer<level> us;
unsigned num_univs = d.get_num_univ_params();
for (unsigned i = 0; i < num_univs; i++) {
us.push_back(tctx.mk_uvar());
}
levels ls = to_list(us);
expr rule = instantiate_type_univ_params(d, ls);
expr proof = mk_constant(n, ls);
buffer<expr> emetas;
buffer<bool> instances, explicits;
while (is_pi(rule)) {
expr mvar = tctx.mk_mvar(binding_domain(rule));
emetas.push_back(mvar);
explicits.push_back(is_explicit(binding_info(rule)));
instances.push_back(binding_info(rule).is_inst_implicit());
rule = tctx.whnf(instantiate(binding_body(rule), mvar));
proof = mk_app(proof, mvar);
}
expr rel, lhs, rhs;
if (!is_simp_relation(tctx.env(), rule, rel, lhs, rhs) || !is_constant(rel)) {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation");
}
name_set found_mvars;
buffer<expr> lhs_args, rhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args);
expr const & rhs_fn = get_app_args(rhs, rhs_args);
if (is_constant(lhs_fn)) {
if (!is_constant(rhs_fn) || const_name(lhs_fn) != const_name(rhs_fn) || lhs_args.size() != rhs_args.size()) {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' resulting type is not of the form (" << const_name(lhs_fn) << " ...) "
<< "~ (" << const_name(lhs_fn) << " ...), where ~ is '" << const_name(rel) << "'");
}
for (expr const & lhs_arg : lhs_args) {
if (is_sort(lhs_arg))
continue;
if (!is_metavar(lhs_arg) || found_mvars.contains(mlocal_name(lhs_arg))) {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' the left-hand-side of the congruence resulting type must be of the form ("
<< const_name(lhs_fn) << " x_1 ... x_n), where each x_i is a distinct variable or a sort");
}
found_mvars.insert(mlocal_name(lhs_arg));
}
} else if (is_binding(lhs)) {
if (lhs.kind() != rhs.kind()) {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' kinds of the left-hand-side and right-hand-side of "
<< "the congruence resulting type do not match");
}
if (!is_valid_congr_rule_binding_lhs(lhs, found_mvars)) {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' left-hand-side of the congruence resulting type must "
<< "be of the form (fun/Pi (x : A), B x)");
}
} else {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' left-hand-side is not an application nor a binding");
}
buffer<expr> congr_hyps;
lean_assert(emetas.size() == explicits.size());
for (unsigned i = 0; i < emetas.size(); i++) {
expr const & mvar = emetas[i];
if (explicits[i] && !found_mvars.contains(mlocal_name(mvar))) {
buffer<expr> locals;
expr type = mlocal_type(mvar);
while (is_pi(type)) {
expr local = tctx.mk_tmp_local(binding_domain(type));
locals.push_back(local);
type = instantiate(binding_body(type), local);
}
expr h_rel, h_lhs, h_rhs;
if (!is_simp_relation(tctx.env(), type, h_rel, h_lhs, h_rhs) || !is_constant(h_rel))
continue;
unsigned j = 0;
for (expr const & local : locals) {
j++;
if (!only_found_mvars(mlocal_type(local), found_mvars)) {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' argument #" << j << " of parameter #" << (i+1) << " contains "
<< "unresolved parameters");
}
}
if (!only_found_mvars(h_lhs, found_mvars)) {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' argument #" << (i+1) << " is not a valid hypothesis, the left-hand-side contains "
<< "unresolved parameters");
}
if (!is_valid_congr_hyp_rhs(h_rhs, found_mvars)) {
throw exception(sstream() << "invalid congruence rule, '" << n
<< "' argument #" << (i+1) << " is not a valid hypothesis, the right-hand-side must be "
<< "of the form (m l_1 ... l_n) where m is parameter that was not "
<< "'assigned/resolved' yet and l_i's are locals");
}
found_mvars.insert(mlocal_name(mvar));
congr_hyps.push_back(mvar);
}
}
s.insert(const_name(rel), congr_rule(n, ls, reverse_to_list(emetas),
reverse_to_list(instances), lhs, rhs, proof, to_list(congr_hyps), prio));
}
struct rrs_state {
simp_rule_sets m_sets;
name_set m_simp_names;
name_set m_congr_names;
void add_simp(environment const & env, io_state const & ios, name const & cname, unsigned prio) {
tmp_type_context tctx(env, ios.get_options());
m_sets = add_core(tctx, m_sets, cname, prio);
m_simp_names.insert(cname);
}
void add_congr(environment const & env, io_state const & ios, name const & n, unsigned prio) {
tmp_type_context tctx(env, ios.get_options());
add_congr_core(tctx, m_sets, n, prio);
m_congr_names.insert(n);
}
};
struct rrs_entry {
bool m_is_simp;
name m_name;
unsigned m_priority;
rrs_entry() {}
rrs_entry(bool is_simp, name const & n, unsigned prio):m_is_simp(is_simp), m_name(n), m_priority(prio) {}
};
struct rrs_config {
typedef rrs_entry entry;
typedef rrs_state state;
static void add_entry(environment const & env, io_state const & ios, state & s, entry const & e) {
if (e.m_is_simp)
s.add_simp(env, ios, e.m_name, e.m_priority);
else
s.add_congr(env, ios, e.m_name, e.m_priority);
}
static name const & get_class_name() {
return *g_class_name;
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << e.m_is_simp << e.m_name << e.m_priority;
}
static entry read_entry(deserializer & d) {
entry e; d >> e.m_is_simp >> e.m_name >> e.m_priority; return e;
}
static optional<unsigned> get_fingerprint(entry const & e) {
return some(hash(e.m_is_simp ? 17 : 31, e.m_name.hash()));
}
};
template class scoped_ext<rrs_config>;
typedef scoped_ext<rrs_config> rrs_ext;
environment add_simp_rule(environment const & env, name const & n, unsigned prio, name const & ns, bool persistent) {
return rrs_ext::add_entry(env, get_dummy_ios(), rrs_entry(true, n, prio), ns, persistent);
}
environment add_congr_rule(environment const & env, name const & n, unsigned prio, name const & ns, bool persistent) {
return rrs_ext::add_entry(env, get_dummy_ios(), rrs_entry(false, n, prio), ns, persistent);
}
bool is_simp_rule(environment const & env, name const & n) {
return rrs_ext::get_state(env).m_simp_names.contains(n);
}
bool is_congr_rule(environment const & env, name const & n) {
return rrs_ext::get_state(env).m_congr_names.contains(n);
}
simp_rule_sets get_simp_rule_sets(environment const & env) {
return rrs_ext::get_state(env).m_sets;
}
simp_rule_sets get_simp_rule_sets(environment const & env, options const & o, name const & ns) {
simp_rule_sets set;
list<rrs_entry> const * entries = rrs_ext::get_entries(env, ns);
if (entries) {
for (auto const & e : *entries) {
tmp_type_context tctx(env, o);
set = add_core(tctx, set, e.m_name, e.m_priority);
}
}
return set;
}
simp_rule_sets get_simp_rule_sets(environment const & env, options const & o, std::initializer_list<name> const & nss) {
simp_rule_sets set;
for (name const & ns : nss) {
list<rrs_entry> const * entries = rrs_ext::get_entries(env, ns);
if (entries) {
for (auto const & e : *entries) {
tmp_type_context tctx(env, o);
set = add_core(tctx, set, e.m_name, e.m_priority);
}
}
}
return set;
}
io_state_stream const & operator<<(io_state_stream const & out, simp_rule_sets const & s) {
options const & opts = out.get_options();
out.get_stream() << mk_pair(s.pp(out.get_formatter()), opts);
return out;
}
void initialize_simplifier_rule_set() {
g_prefix = new name(name::mk_internal_unique_name());
g_class_name = new name("simp");
g_key = new std::string("SIMP");
rrs_ext::initialize();
register_prio_attribute("simp", "simplification rule",
[](environment const & env, io_state const &, name const & d, unsigned prio, name const & ns, bool persistent) {
return add_simp_rule(env, d, prio, ns, persistent);
},
is_simp_rule,
[](environment const &, name const &) {
// TODO(Leo): fix it after we refactor simp_rule_set
return LEAN_DEFAULT_PRIORITY;
});
register_prio_attribute("congr", "congruence rule",
[](environment const & env, io_state const &, name const & d, unsigned prio, name const & ns, bool persistent) {
return add_congr_rule(env, d, prio, ns, persistent);
},
is_congr_rule,
[](environment const &, name const &) {
// TODO(Leo): fix it after we refactor simp_rule_set
return LEAN_DEFAULT_PRIORITY;
});
}
void finalize_simplifier_rule_set() {
rrs_ext::finalize();
delete g_key;
delete g_class_name;
delete g_prefix;
}
}

View file

@ -1,151 +0,0 @@
/*
Copyright (c) 2015 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "library/tmp_type_context.h"
#include "library/head_map.h"
#include "library/io_state_stream.h"
namespace lean {
class simp_rule_sets;
class simp_rule_core {
protected:
name m_id;
levels m_umetas;
list<expr> m_emetas;
list<bool> m_instances;
expr m_lhs;
expr m_rhs;
expr m_proof;
unsigned m_priority;
simp_rule_core(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
unsigned priority);
public:
name const & get_id() const { return m_id; }
unsigned get_num_umeta() const { return length(m_umetas); }
unsigned get_num_emeta() const { return length(m_emetas); }
/** \brief Return a list containing the expression metavariables in reverse order. */
list<expr> const & get_emetas() const { return m_emetas; }
/** \brief Return a list of bools indicating whether or not each expression metavariable
in <tt>get_emetas()</tt> is an instance. */
list<bool> const & get_instances() const { return m_instances; }
unsigned get_priority() const { return m_priority; }
expr const & get_lhs() const { return m_lhs; }
expr const & get_rhs() const { return m_rhs; }
expr const & get_proof() const { return m_proof; }
};
class simp_rule : public simp_rule_core {
bool m_is_permutation;
simp_rule(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
bool is_perm, unsigned priority);
friend simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, name const & id,
levels const & univ_metas, expr const & e, expr const & h, unsigned priority);
public:
friend bool operator==(simp_rule const & r1, simp_rule const & r2);
bool is_perm() const { return m_is_permutation; }
format pp(formatter const & fmt) const;
};
bool operator==(simp_rule const & r1, simp_rule const & r2);
inline bool operator!=(simp_rule const & r1, simp_rule const & r2) { return !operator==(r1, r2); }
class congr_rule : public simp_rule_core {
list<expr> m_congr_hyps;
congr_rule(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, expr const & proof,
list<expr> const & congr_hyps, unsigned priority);
friend void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n, unsigned priority);
public:
friend bool operator==(congr_rule const & r1, congr_rule const & r2);
list<expr> const & get_congr_hyps() const { return m_congr_hyps; }
format pp(formatter const & fmt) const;
};
struct simp_rule_core_prio_fn { unsigned operator()(simp_rule_core const & s) const { return s.get_priority(); } };
bool operator==(congr_rule const & r1, congr_rule const & r2);
inline bool operator!=(congr_rule const & r1, congr_rule const & r2) { return !operator==(r1, r2); }
class simp_rule_set {
typedef head_map_prio<simp_rule, simp_rule_core_prio_fn> simp_set;
typedef head_map_prio<congr_rule, simp_rule_core_prio_fn> congr_set;
name m_eqv;
simp_set m_simp_set;
congr_set m_congr_set;
public:
simp_rule_set() {}
/** \brief Return the equivalence relation associated with this set */
simp_rule_set(name const & eqv);
bool empty() const { return m_simp_set.empty() && m_congr_set.empty(); }
name const & get_eqv() const { return m_eqv; }
void insert(simp_rule const & r);
void erase(simp_rule const & r);
void insert(congr_rule const & r);
void erase(congr_rule const & r);
void erase_simp(name_set const & ids);
void erase_simp(buffer<name> const & ids);
list<simp_rule> const * find_simp(head_index const & h) const;
void for_each_simp(std::function<void(simp_rule const &)> const & fn) const;
list<congr_rule> const * find_congr(head_index const & h) const;
void for_each_congr(std::function<void(congr_rule const &)> const & fn) const;
};
class simp_rule_sets {
name_map<simp_rule_set> m_sets; // mapping from relation name to simp_rule_set
template<typename R> void insert_core(name const & eqv, R const & r);
template<typename R> void erase_core(name const & eqv, R const & r);
public:
void insert(name const & eqv, simp_rule const & r);
void erase(name const & eqv, simp_rule const & r);
void insert(name const & eqv, congr_rule const & r);
void erase(name const & eqv, congr_rule const & r);
void erase_simp(name_set const & ids);
void erase_simp(buffer<name> const & ids);
void get_relations(buffer<name> & rs) const;
simp_rule_set const * find(name const & eqv) const;
list<simp_rule> const * find_simp(name const & eqv, head_index const & h) const;
list<congr_rule> const * find_congr(name const & eqv, head_index const & h) const;
void for_each_simp(std::function<void(name const &, simp_rule const &)> const & fn) const;
void for_each_congr(std::function<void(name const &, congr_rule const &)> const & fn) const;
format pp(formatter const & fmt, format const & header, bool simp, bool congr) const;
format pp_simp(formatter const & fmt, format const & header) const;
format pp_simp(formatter const & fmt) const;
format pp_congr(formatter const & fmt) const;
format pp(formatter const & fmt) const;
};
simp_rule_sets add(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, expr const & e, expr const & h, unsigned priority);
simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2);
environment add_simp_rule(environment const & env, name const & n, unsigned priority, name const & ns, bool persistent);
environment add_congr_rule(environment const & env, name const & n, unsigned priority, name const & ns, bool persistent);
/** \brief Return true if \c n is an active simplification rule in \c env */
bool is_simp_rule(environment const & env, name const & n);
/** \brief Return true if \c n is an active congruence rule in \c env */
bool is_congr_rule(environment const & env, name const & n);
/** \brief Get current simplification rule sets */
simp_rule_sets get_simp_rule_sets(environment const & env);
/** \brief Get simplification rule sets in the given namespace. */
simp_rule_sets get_simp_rule_sets(environment const & env, options const & o, name const & ns);
/** \brief Get simplification rule sets in the given namespaces. */
simp_rule_sets get_simp_rule_sets(environment const & env, options const & o, std::initializer_list<name> const & nss);
io_state_stream const & operator<<(io_state_stream const & out, simp_rule_sets const & s);
void initialize_simplifier_rule_set();
void finalize_simplifier_rule_set();
}

View file

@ -27,7 +27,7 @@ Author: Daniel Selsam
#include "library/blast/trace.h"
#include "library/blast/blast_exception.h"
#include "library/blast/simplifier/simplifier.h"
#include "library/blast/simplifier/simp_rule_set.h"
#include "library/blast/simplifier/simp_lemmas.h"
#include "library/blast/simplifier/ceqv.h"
#ifndef LEAN_DEFAULT_SIMPLIFY_MAX_STEPS
@ -60,14 +60,11 @@ namespace blast {
using simp::result;
/* Names */
/* Keys */
static name * g_simplify_prove_namespace = nullptr;
static name * g_simplify_neg_namespace = nullptr;
static name * g_simplify_unit_namespace = nullptr;
static name * g_simplify_ac_namespace = nullptr;
static name * g_simplify_distrib_namespace = nullptr;
static name * g_simplify_numeral_namespace = nullptr;
static unsigned g_ac_key;
static unsigned g_som_key;
static unsigned g_numeral_key;
/* Options */
@ -133,8 +130,8 @@ class simplifier {
name m_rel;
expr_predicate m_simp_pred;
simp_rule_sets m_srss;
simp_rule_sets m_ctx_srss;
simp_lemmas m_srss;
simp_lemmas m_ctx_srss;
/* Logging */
unsigned m_num_steps{0};
@ -184,8 +181,8 @@ class simplifier {
return has_free_vars(binding_body(f_type));
}
simp_rule_sets add_to_srss(simp_rule_sets const & _srss, buffer<expr> & ls) {
simp_rule_sets srss = _srss;
simp_lemmas add_to_srss(simp_lemmas const & _srss, buffer<expr> & ls) {
simp_lemmas srss = _srss;
for (unsigned i = 0; i < ls.size(); i++) {
expr & l = ls[i];
blast_tmp_type_context tctx;
@ -211,7 +208,7 @@ class simplifier {
result finalize(result const & r);
/* Simplification */
result simplify(expr const & e, simp_rule_sets const & srss);
result simplify(expr const & e, simp_lemmas const & srss);
result simplify(expr const & e, bool is_root);
result simplify_lambda(expr const & e);
result simplify_pi(expr const & e);
@ -221,12 +218,12 @@ class simplifier {
/* Proving */
optional<expr> prove(expr const & thm);
optional<expr> prove(expr const & thm, simp_rule_sets const & srss);
optional<expr> prove(expr const & thm, simp_lemmas const & srss);
/* Rewriting */
result rewrite(expr const & e);
result rewrite(expr const & e, simp_rule_sets const & srss);
result rewrite(expr const & e, simp_rule const & sr);
result rewrite(expr const & e, simp_lemmas const & srss);
result rewrite(expr const & e, simp_lemma const & sr);
/* Congruence */
result congr_fun_arg(result const & r_f, result const & r_arg);
@ -236,7 +233,7 @@ class simplifier {
result congr_funs(result const & r_f, buffer<expr> const & args);
result try_congrs(expr const & e);
result try_congr(expr const & e, congr_rule const & cr);
result try_congr(expr const & e, user_congr_lemma const & cr);
template<typename F>
optional<result> synth_congr(expr const & e, F && simp);
@ -254,7 +251,7 @@ class simplifier {
public:
simplifier(name const & rel, expr_predicate const & simp_pred): m_rel(rel), m_simp_pred(simp_pred) { }
result operator()(expr const & e, simp_rule_sets const & srss) { return simplify(e, srss); }
result operator()(expr const & e, simp_lemmas const & srss) { return simplify(e, srss); }
};
/* Cache */
@ -365,8 +362,8 @@ expr simplifier::whnf_eta(expr const & e) {
/* Simplification */
result simplifier::simplify(expr const & e, simp_rule_sets const & srss) {
flet<simp_rule_sets> set_srss(m_srss, srss);
result simplifier::simplify(expr const & e, simp_lemmas const & srss) {
flet<simp_lemmas> set_srss(m_srss, srss);
freset<simplify_cache> reset(m_cache);
return simplify(e, true);
}
@ -552,7 +549,7 @@ optional<expr> simplifier::prove(expr const & thm) {
return none_expr();
}
optional<expr> simplifier::prove(expr const & thm, simp_rule_sets const & srss) {
optional<expr> simplifier::prove(expr const & thm, simp_lemmas const & srss) {
flet<name> set_name(m_rel, get_iff_name());
result r_cond = simplify(thm, srss);
if (is_constant(r_cond.get_new()) && const_name(r_cond.get_new()) == get_true_name()) {
@ -577,16 +574,16 @@ result simplifier::rewrite(expr const & e) {
return r;
}
result simplifier::rewrite(expr const & e, simp_rule_sets const & srss) {
result simplifier::rewrite(expr const & e, simp_lemmas const & srss) {
result r(e);
simp_rule_set const * sr = srss.find(m_rel);
simp_lemmas_for const * sr = srss.find(m_rel);
if (!sr) return r;
list<simp_rule> const * srs = sr->find_simp(e);
list<simp_lemma> const * srs = sr->find_simp(e);
if (!srs) return r;
for_each(*srs, [&](simp_rule const & sr) {
for_each(*srs, [&](simp_lemma const & sr) {
result r_new = rewrite(r.get_new(), sr);
if (!r_new.has_proof()) return;
r = join(r, r_new);
@ -594,7 +591,7 @@ result simplifier::rewrite(expr const & e, simp_rule_sets const & srss) {
return r;
}
result simplifier::rewrite(expr const & e, simp_rule const & sr) {
result simplifier::rewrite(expr const & e, simp_lemma const & sr) {
blast_tmp_type_context tmp_tctx(sr.get_num_umeta(), sr.get_num_emeta());
if (!tmp_tctx->is_def_eq(e, sr.get_lhs())) return result(e);
@ -672,21 +669,21 @@ result simplifier::congr_funs(result const & r_f, buffer<expr> const & args) {
}
result simplifier::try_congrs(expr const & e) {
simp_rule_set const * sr = get_simp_rule_sets(env()).find(m_rel);
simp_lemmas_for const * sr = get_simp_lemmas().find(m_rel);
if (!sr) return result(e);
list<congr_rule> const * crs = sr->find_congr(e);
list<user_congr_lemma> const * crs = sr->find_congr(e);
if (!crs) return result(e);
result r(e);
for_each(*crs, [&](congr_rule const & cr) {
for_each(*crs, [&](user_congr_lemma const & cr) {
if (r.has_proof()) return;
r = try_congr(e, cr);
});
return r;
}
result simplifier::try_congr(expr const & e, congr_rule const & cr) {
result simplifier::try_congr(expr const & e, user_congr_lemma const & cr) {
blast_tmp_type_context tmp_tctx(cr.get_num_umeta(), cr.get_num_emeta());
if (!tmp_tctx->is_def_eq(e, cr.get_lhs())) return result(e);
@ -720,7 +717,7 @@ result simplifier::try_congr(expr const & e, congr_rule const & cr) {
{
flet<name> set_name(m_rel, const_name(h_rel));
flet<simp_rule_sets> set_ctx_srss(m_ctx_srss, m_contextual ? add_to_srss(m_ctx_srss, ls) : m_ctx_srss);
flet<simp_lemmas> set_ctx_srss(m_ctx_srss, m_contextual ? add_to_srss(m_ctx_srss, ls) : m_ctx_srss);
h_lhs = tmp_tctx->instantiate_uvars_mvars(h_lhs);
lean_assert(!has_metavar(h_lhs));
@ -976,9 +973,8 @@ result simplifier::fuse(expr const & e) {
/* Prove (1) == (3) using simplify with [ac] */
flet<bool> no_simplify_numerals(m_numerals, false);
auto pf_1_3 = prove(get_app_builder().mk_eq(e, e_grp),
get_simp_rule_sets(env(), ios().get_options(),
{*g_simplify_prove_namespace, *g_simplify_unit_namespace,
*g_simplify_neg_namespace, *g_simplify_ac_namespace}));
get_simp_lemmas(g_ac_key));
if (!pf_1_3) {
diagnostic(env(), ios()) << ppb(e) << "\n\n =?=\n\n" << ppb(e_grp) << "\n";
throw blast_exception("Failed to prove (1) == (3) during fusion", e);
@ -986,10 +982,8 @@ result simplifier::fuse(expr const & e) {
/* Prove (4) == (5) using simplify with [som] */
auto pf_4_5 = prove(get_app_builder().mk_eq(e_grp_ls, e_fused_ls),
get_simp_rule_sets(env(), ios().get_options(),
{*g_simplify_prove_namespace, *g_simplify_unit_namespace,
*g_simplify_neg_namespace, *g_simplify_ac_namespace,
*g_simplify_distrib_namespace}));
get_simp_lemmas(g_som_key));
if (!pf_4_5) {
diagnostic(env(), ios()) << ppb(e_grp_ls) << "\n\n =?=\n\n" << ppb(e_fused_ls) << "\n";
throw blast_exception("Failed to prove (4) == (5) during fusion", e);
@ -997,9 +991,7 @@ result simplifier::fuse(expr const & e) {
/* Prove (5) == (6) using simplify with [numeral] */
flet<bool> simplify_numerals(m_numerals, true);
result r_simp_ls = simplify(e_fused_ls, get_simp_rule_sets(env(), ios().get_options(),
{*g_simplify_unit_namespace, *g_simplify_neg_namespace,
*g_simplify_ac_namespace}));
result r_simp_ls = simplify(e_fused_ls, get_simp_lemmas(g_numeral_key));
/* Prove (4) == (6) by transitivity of proofs (2) and (3) */
expr pf_4_6;
@ -1057,12 +1049,11 @@ void initialize_simplifier() {
register_trace_class(name({"simplifier", "congruence"}));
register_trace_class(name({"simplifier", "failure"}));
g_simplify_prove_namespace = new name{"simplifier", "prove"};
g_simplify_neg_namespace = new name{"simplifier", "neg"};
g_simplify_unit_namespace = new name{"simplifier", "unit"};
g_simplify_ac_namespace = new name{"simplifier", "ac"};
g_simplify_distrib_namespace = new name{"simplifier", "distrib"};
g_simplify_numeral_namespace = new name{"simplifier", "numeral"};
g_ac_key = register_simp_lemmas({name{"simplifier", "prove"}, name{"simplifier", "unit"},
name{"simplifier", "neg"}, name{"simplifier", "ac"}});
g_som_key = register_simp_lemmas({name{"simplifier", "prove"}, name{"simplifier", "unit"},
name{"simplifier", "neg"}, name{"simplifier", "ac"}, name{"simplifier", "distrib"}});
g_numeral_key = register_simp_lemmas({name{"simplifier", "unit"}, name{"simplifier", "neg"}, name{"simplifier", "ac"}});
g_simplify_max_steps = new name{"simplify", "max_steps"};
g_simplify_top_down = new name{"simplify", "top_down"};
@ -1100,23 +1091,16 @@ void finalize_simplifier() {
delete g_simplify_exhaustive;
delete g_simplify_top_down;
delete g_simplify_max_steps;
delete g_simplify_numeral_namespace;
delete g_simplify_distrib_namespace;
delete g_simplify_ac_namespace;
delete g_simplify_unit_namespace;
delete g_simplify_neg_namespace;
delete g_simplify_prove_namespace;
}
/* Entry points */
static bool simplify_all_pred(expr const &) { return true; }
result simplify(name const & rel, expr const & e, simp_rule_sets const & srss) {
result simplify(name const & rel, expr const & e, simp_lemmas const & srss) {
return simplifier(rel, simplify_all_pred)(e, srss);
}
result simplify(name const & rel, expr const & e, simp_rule_sets const & srss, expr_predicate const & simp_pred) {
result simplify(name const & rel, expr const & e, simp_lemmas const & srss, expr_predicate const & simp_pred) {
return simplifier(rel, simp_pred)(e, srss);
}

View file

@ -6,7 +6,7 @@ Author: Daniel Selsam
#pragma once
#include "kernel/expr_pair.h"
#include "library/blast/state.h"
#include "library/blast/simplifier/simp_rule_set.h"
#include "library/blast/simplifier/simp_lemmas.h"
namespace lean {
namespace blast {
@ -40,8 +40,8 @@ public:
// TODO(dhs): put this outside of blast module
typedef std::function<bool(expr const &)> expr_predicate; // NOLINT
simp::result simplify(name const & rel, expr const & e, simp_rule_sets const & srss);
simp::result simplify(name const & rel, expr const & e, simp_rule_sets const & srss, expr_predicate const & simp_pred);
simp::result simplify(name const & rel, expr const & e, simp_lemmas const & srss);
simp::result simplify(name const & rel, expr const & e, simp_lemmas const & srss, expr_predicate const & simp_pred);
void initialize_simplifier();
void finalize_simplifier();

View file

@ -16,18 +16,18 @@ namespace lean {
namespace blast {
static unsigned g_ext_id = 0;
struct simplifier_branch_extension : public branch_extension {
simp_rule_sets m_srss;
bool m_simp_target{false}; // true if target needs to be simplified again
simp_lemmas m_simp_lemmas;
bool m_simp_target{false}; // true if target needs to be simplified again
simplifier_branch_extension() {}
simplifier_branch_extension(simplifier_branch_extension const & b):
m_srss(b.m_srss) {}
m_simp_lemmas(b.m_simp_lemmas) {}
virtual ~simplifier_branch_extension() {}
virtual branch_extension * clone() override { return new simplifier_branch_extension(*this); }
virtual void initialized() override { m_srss = ::lean::get_simp_rule_sets(env()); }
virtual void initialized() override { m_simp_lemmas = ::lean::blast::get_simp_lemmas(); }
virtual void target_updated() override { m_simp_target = true; }
virtual void hypothesis_activated(hypothesis const &, hypothesis_idx) override { }
virtual void hypothesis_deleted(hypothesis const &, hypothesis_idx) override { }
simp_rule_sets const & get_simp_rule_sets() const { return m_srss; }
simp_lemmas const & get_simp_lemmas() const { return m_simp_lemmas; }
};
void initialize_simplifier_actions() {
@ -78,7 +78,7 @@ action_result simplify_target_action() {
expr target = s.get_target();
bool iff = use_iff(target);
name rname = iff ? get_iff_name() : get_eq_name();
auto r = simplify(rname, target, ext.get_simp_rule_sets());
auto r = simplify(rname, target, ext.get_simp_lemmas());
if (r.get_new() == target)
return action_result::failed(); // did nothing
if (r.has_proof()) {
@ -108,7 +108,7 @@ action_result simplify_hypothesis_action(hypothesis_idx hidx) {
return action_result::failed();
}
auto & ext = get_extension();
auto r = simplify(get_iff_name(), h.get_type(), ext.get_simp_rule_sets());
auto r = simplify(get_iff_name(), h.get_type(), ext.get_simp_lemmas());
if (r.get_new() == h.get_type())
return action_result::failed(); // did nothing
expr new_h_proof;
@ -123,7 +123,7 @@ action_result simplify_hypothesis_action(hypothesis_idx hidx) {
return action_result::new_branch();
}
action_result add_simp_rule_action(hypothesis_idx hidx) {
action_result add_simp_lemma_action(hypothesis_idx hidx) {
if (!get_config().m_simp)
return action_result::failed();
blast_tmp_type_context ctx;
@ -136,8 +136,8 @@ action_result add_simp_rule_action(hypothesis_idx hidx) {
bool added = false;
for (auto const & p : ps) {
try {
ext.m_srss = add(*ctx, ext.m_srss, h.get_name(), p.first, p.second, LEAN_DEFAULT_PRIORITY);
added = true;
ext.m_simp_lemmas = add(*ctx, ext.m_simp_lemmas, h.get_name(), p.first, p.second, LEAN_DEFAULT_PRIORITY);
added = true;
} catch (exception &) {
// TODO(Leo, Daniel): store event
}

View file

@ -11,7 +11,7 @@ namespace lean {
namespace blast {
action_result simplify_hypothesis_action(hypothesis_idx hidx);
action_result simplify_target_action();
action_result add_simp_rule_action(hypothesis_idx hidx);
action_result add_simp_lemma_action(hypothesis_idx hidx);
void initialize_simplifier_actions();
void finalize_simplifier_actions();

View file

@ -35,7 +35,7 @@ class simplifier_strategy_fn : public strategy_fn {
virtual action_result hypothesis_post_activation(hypothesis_idx hidx) override {
if (m_use_hyps)
add_simp_rule_action(hidx);
add_simp_lemma_action(hidx);
return action_result::new_branch();
}

View file

@ -15,16 +15,15 @@ Author: Daniel Selsam
#include "library/blast/choice_point.h"
#include "library/blast/hypothesis.h"
#include "library/blast/util.h"
#include "library/blast/simplifier/simp_rule_set.h"
#include "library/blast/simplifier/simp_lemmas.h"
#include "library/blast/simplifier/simplifier.h"
#include "library/expr_lt.h"
#include "util/rb_multi_map.h"
namespace lean {
namespace blast {
static name * g_simplify_unit_simp_namespace = nullptr;
static name * g_simplify_contextual = nullptr;
static unsigned g_unit_simp_key;
static name * g_simplify_contextual = nullptr;
static bool is_propositional(expr const & e) {
// TODO(dhs): This predicate will need to evolve, and will eventually be thrown out
@ -50,7 +49,7 @@ action_result unit_preprocess(unsigned hidx) {
return action_result::failed();
}
simp_rule_sets srss = get_simp_rule_sets(env(), ios().get_options(), *g_simplify_unit_simp_namespace);
simp_lemmas srss = get_simp_lemmas(g_unit_simp_key);
// TODO(dhs): disable contextual rewriting
auto r = simplify(get_iff_name(), h.get_type(), srss, is_propositional);
@ -68,13 +67,11 @@ action_result unit_preprocess(unsigned hidx) {
}
void initialize_unit_preprocess() {
g_simplify_unit_simp_namespace = new name{"simplifier", "unit_simp"};
g_unit_simp_key = register_simp_lemmas({name{"simplifier", "unit_simp"}});
g_simplify_contextual = new name{"simplify", "contextual"};
}
void finalize_unit_preprocess() {
delete g_simplify_contextual;
delete g_simplify_unit_simp_namespace;
}
}}

View file

@ -15,7 +15,7 @@ sorry
lemma C₃ [congr] (a b : nat) : R (g a b) (g 0 0) := -- ERROR
sorry
lemma C₄ [congr] (A B : Type) : (A → B) = (λ a : nat, B → A) 0 := -- ERROR
lemma C₄ [congr] (A B : Type) : (A → B) = (λ a : nat, B → A) 0 :=
sorry
lemma C₅ [congr] (A B : Type₁) : (A → nat) = (B → nat) := -- ERROR

View file

@ -1,10 +1,9 @@
congr_error_msg.lean:9:0: error: invalid congruence rule, 'C₁' the left-hand-side of the congruence resulting type must be of the form (g x_1 ... x_n), where each x_i is a distinct variable or a sort
congr_error_msg.lean:12:0: error: invalid congruence rule, 'C₂' resulting type is not of the form (g ...) ~ (g ...), where ~ is 'eq'
congr_error_msg.lean:15:0: error: invalid congruence rule, 'C₃' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation
congr_error_msg.lean:18:0: error: invalid congruence rule, 'C₄' kinds of the left-hand-side and right-hand-side of the congruence resulting type do not match
congr_error_msg.lean:21:0: error: invalid congruence rule, 'C₅' left-hand-side of the congruence resulting type must be of the form (fun/Pi (x : A), B x)
congr_error_msg.lean:24:0: error: invalid congruence rule, 'C₆' left-hand-side is not an application nor a binding
congr_error_msg.lean:27:0: error: invalid congruence rule, 'C₇' argument #2 of parameter #5 contains unresolved parameters
congr_error_msg.lean:30:0: error: invalid congruence rule, 'C₈' argument #5 is not a valid hypothesis, the left-hand-side contains unresolved parameters
congr_error_msg.lean:33:0: error: invalid congruence rule, 'C₉' argument #6 is not a valid hypothesis, the right-hand-side must be of the form (m l_1 ... l_n) where m is parameter that was not 'assigned/resolved' yet and l_i's are locals
congr_error_msg.lean:9:0: error: invalid [congr] lemma, 'C₁' the left-hand-side of the congruence resulting type must be of the form (g x_1 ... x_n), where each x_i is a distinct variable or a sort
congr_error_msg.lean:12:0: error: invalid [congr] lemma, 'C₂' resulting type is not of the form (g ...) ~ (g ...), where ~ is 'eq'
congr_error_msg.lean:15:0: error: invalid [congr] lemma, 'C₃' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation
congr_error_msg.lean:21:0: error: invalid [congr] lemma, 'C₅' left-hand-side of the congruence resulting type must be of the form (fun/Pi (x : A), B x)
congr_error_msg.lean:24:0: error: invalid [congr] lemma, 'C₆' left-hand-side is not an application nor a binding
congr_error_msg.lean:27:0: error: invalid [congr] lemma, 'C₇' argument #2 of parameter #5 contains unresolved parameters
congr_error_msg.lean:30:0: error: invalid [congr] lemma, 'C₈' argument #5 is not a valid hypothesis, the left-hand-side contains unresolved parameters
congr_error_msg.lean:33:0: error: invalid [congr] lemma, 'C₉' argument #6 is not a valid hypothesis, the right-hand-side must be of the form (m l_1 ... l_n) where m is parameter that was not 'assigned/resolved' yet and l_i's are locals
congr_error_msg.lean:33:0: error: unknown declaration 'C₁'

View file

@ -0,0 +1,8 @@
definition f : nat → nat := sorry
definition g (a : nat) := f a
lemma gax [simp] : ∀ a, g a = 0 := sorry
attribute g [reducible]
example (a : nat) : f (a + a) = 0 :=
by simp