diff --git a/hott/init/logic.hlean b/hott/init/logic.hlean index 3f28fff11..bd1d19d3c 100644 --- a/hott/init/logic.hlean +++ b/hott/init/logic.hlean @@ -266,7 +266,7 @@ iff_empty_intro (not_not_intro star) definition not_empty_iff [simp] : (¬ empty) ↔ unit := iff_unit_intro not_empty -definition not_congr [congr] (H : a ↔ b) : ¬a ↔ ¬b := +definition not_congr (H : a ↔ b) : ¬a ↔ ¬b := iff.intro (λ H₁ H₂, H₁ (iff.mpr H H₂)) (λ H₁ H₂, H₁ (iff.mp H H₂)) definition ne_self_iff_empty [simp] {A : Type} (a : A) : (not (a = a)) ↔ empty := diff --git a/library/data/list/perm.lean b/library/data/list/perm.lean index 970917de4..fdc7b9051 100644 --- a/library/data/list/perm.lean +++ b/library/data/list/perm.lean @@ -501,7 +501,7 @@ section foldr variable lcomm : left_commutative f include lcomm - theorem foldr_eq_of_perm [congr] : l₁ ~ l₂ → ∀ b, foldr f b l₁ = foldr f b l₂ := + theorem foldr_eq_of_perm : l₁ ~ l₂ → ∀ b, foldr f b l₁ = foldr f b l₂ := assume p, perm_induction_on p (λ b, by rewrite *foldl_nil) (λ x t₁ t₂ p r b, calc diff --git a/src/frontends/lean/builtin_cmds.cpp b/src/frontends/lean/builtin_cmds.cpp index 6a1dc1469..640c2de56 100644 --- a/src/frontends/lean/builtin_cmds.cpp +++ b/src/frontends/lean/builtin_cmds.cpp @@ -536,7 +536,7 @@ static environment accessible_cmd(parser & p) { name const & n = d.get_name(); total++; if ((d.is_theorem() || d.is_definition()) && - !is_instance(env, n) && !is_simp_rule(env, n) && !is_congr_rule(env, n) && + !is_instance(env, n) && !is_simp_lemma(env, n) && !is_congr_lemma(env, n) && !is_user_defined_recursor(env, n) && !is_aux_recursor(env, n) && !is_projection(env, n) && !is_private(env, n) && !is_user_defined_recursor(env, n) && !is_aux_recursor(env, n) && @@ -799,12 +799,12 @@ static environment simplify_cmd(parser & p) { std::tie(e, ls) = parse_local_expr(p); blast::scope_debug scope(p.env(), p.ios()); - simp_rule_sets srss; + blast::simp_lemmas srss; if (ns == name("null")) { } else if (ns == name("env")) { - srss = get_simp_rule_sets(p.env()); + srss = blast::get_simp_lemmas(); } else { - srss = get_simp_rule_sets(p.env(), p.get_options(), ns); + srss = blast::get_simp_lemmas(ns); } blast::simp::result r = blast::simplify(rel, e, srss); diff --git a/src/frontends/lean/print_cmd.cpp b/src/frontends/lean/print_cmd.cpp index 713c97966..5d8c45afd 100644 --- a/src/frontends/lean/print_cmd.cpp +++ b/src/frontends/lean/print_cmd.cpp @@ -508,14 +508,15 @@ static void print_reducible_info(parser & p, reducible_status s1) { static void print_simp_rules(parser & p) { io_state_stream out = p.regular_stream(); - simp_rule_sets s; + blast::scope_debug scope(p.env(), p.ios()); + blast::simp_lemmas s; name ns; if (p.curr_is_identifier()) { ns = p.get_name_val(); p.next(); - s = get_simp_rule_sets(p.env(), p.get_options(), ns); + s = blast::get_simp_lemmas(ns); } else { - s = get_simp_rule_sets(p.env()); + s = blast::get_simp_lemmas(); } format header; if (!ns.is_anonymous()) @@ -525,7 +526,8 @@ static void print_simp_rules(parser & p) { static void print_congr_rules(parser & p) { io_state_stream out = p.regular_stream(); - simp_rule_sets s = get_simp_rule_sets(p.env()); + blast::scope_debug scope(p.env(), p.ios()); + blast::simp_lemmas s = blast::get_simp_lemmas(); out << s.pp_congr(out.get_formatter()); } diff --git a/src/library/blast/blast.cpp b/src/library/blast/blast.cpp index cfbc6208d..71b34a628 100644 --- a/src/library/blast/blast.cpp +++ b/src/library/blast/blast.cpp @@ -32,6 +32,7 @@ Author: Leonardo de Moura #include "library/blast/trace.h" #include "library/blast/options.h" #include "library/blast/strategies/portfolio.h" +#include "library/blast/simplifier/simp_lemmas.h" namespace lean { namespace blast { @@ -1069,6 +1070,7 @@ struct scope_debug::imp { scope_blastenv m_scope2; scope_congruence_closure m_scope3; scope_config m_scope4; + scope_simp m_scope5; imp(environment const & env, io_state const & ios): m_scope1(true), m_benv(env, ios, list(), list()), @@ -1168,6 +1170,7 @@ optional blast_goal(environment const & env, io_state const & ios, list const & es) { } /* Try to convert user-defined congruence rule into an ext_congr_lemma */ -static optional to_ext_congr_lemma(name const & R, expr const & fn, unsigned nargs, congr_rule const & r) { +static optional to_ext_congr_lemma(name const & R, expr const & fn, unsigned nargs, user_congr_lemma const & r) { buffer lhs_args, rhs_args; expr const & lhs_fn = get_app_args(r.get_lhs(), lhs_args); expr const & rhs_fn = get_app_args(r.get_rhs(), rhs_args); @@ -280,11 +280,11 @@ static optional to_ext_congr_lemma(name const & R, expr const & } static optional mk_ext_congr_lemma_core(name const & R, expr const & fn, unsigned nargs) { - simp_rule_set const * sr = get_simp_rule_sets(env()).find(R); + simp_lemmas_for const * sr = get_simp_lemmas().find(R); if (sr) { - list const * crs = sr->find_congr(fn); + list const * crs = sr->find_congr(fn); if (crs) { - for (congr_rule const & r : *crs) { + for (user_congr_lemma const & r : *crs) { if (auto lemma = to_ext_congr_lemma(R, fn, nargs, r)) return lemma; } diff --git a/src/library/blast/simplifier/CMakeLists.txt b/src/library/blast/simplifier/CMakeLists.txt index 0e13294b4..4985882a8 100644 --- a/src/library/blast/simplifier/CMakeLists.txt +++ b/src/library/blast/simplifier/CMakeLists.txt @@ -1,2 +1,2 @@ -add_library(simplifier OBJECT init_module.cpp simp_rule_set.cpp ceqv.cpp simplifier.cpp +add_library(simplifier OBJECT init_module.cpp ceqv.cpp simplifier.cpp simp_lemmas.cpp simplifier_actions.cpp simplifier_strategies.cpp) diff --git a/src/library/blast/simplifier/init_module.cpp b/src/library/blast/simplifier/init_module.cpp index 20d36a44f..4841465f9 100644 --- a/src/library/blast/simplifier/init_module.cpp +++ b/src/library/blast/simplifier/init_module.cpp @@ -4,21 +4,21 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Daniel Selsam */ #include "library/blast/simplifier/simplifier_actions.h" -#include "library/blast/simplifier/simp_rule_set.h" +#include "library/blast/simplifier/simp_lemmas.h" #include "library/blast/simplifier/simplifier.h" namespace lean { namespace blast { void initialize_simplifier_module() { + initialize_simp_lemmas(); initialize_simplifier(); - initialize_simplifier_rule_set(); initialize_simplifier_actions(); } void finalize_simplifier_module() { finalize_simplifier_actions(); - finalize_simplifier_rule_set(); finalize_simplifier(); + finalize_simp_lemmas(); } }} diff --git a/src/library/blast/simplifier/simp_lemmas.cpp b/src/library/blast/simplifier/simp_lemmas.cpp new file mode 100644 index 000000000..8eb05ab80 --- /dev/null +++ b/src/library/blast/simplifier/simp_lemmas.cpp @@ -0,0 +1,709 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include +#include +#include "util/priority_queue.h" +#include "util/sstream.h" +#include "kernel/error_msgs.h" +#include "kernel/find_fn.h" +#include "kernel/instantiate.h" +#include "library/trace.h" +#include "library/scoped_ext.h" +#include "library/attribute_manager.h" +#include "library/blast/blast.h" +#include "library/blast/simplifier/ceqv.h" +#include "library/blast/simplifier/simp_lemmas.h" + +namespace lean { +static name * g_class_name = nullptr; +static std::string * g_key = nullptr; + +struct simp_state { + priority_queue m_simp_lemmas; + priority_queue m_congr_lemmas; +}; + +typedef std::tuple simp_entry; + +struct simp_config { + typedef simp_entry entry; + typedef simp_state state; + + static void add_entry(environment const &, io_state const &, state & s, entry const & e) { + bool is_simp; unsigned prio; name n; + std::tie(is_simp, prio, n) = e; + if (is_simp) { + s.m_simp_lemmas.insert(n, prio); + } else { + s.m_congr_lemmas.insert(n, prio); + } + } + static name const & get_class_name() { + return *g_class_name; + } + static std::string const & get_serialization_key() { + return *g_key; + } + static void write_entry(serializer & s, entry const & e) { + bool is_simp; unsigned prio; name n; + std::tie(is_simp, prio, n) = e; + s << is_simp << prio << n; + } + static entry read_entry(deserializer & d) { + bool is_simp; unsigned prio; name n; + d >> is_simp >> prio >> n; + return entry(is_simp, prio, n); + } + static optional get_fingerprint(entry const & e) { + bool is_simp; unsigned prio; name n; + std::tie(is_simp, prio, n) = e; + return some(hash(hash(n.hash(), prio), is_simp ? 17u : 31u)); + } +}; + +typedef scoped_ext simp_ext; + +void validate_simp(environment const & env, io_state const & ios, name const & n); +void validate_congr(environment const & env, io_state const & ios, name const & n); + +environment add_simp_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent) { + validate_simp(env, ios, c); + return simp_ext::add_entry(env, ios, simp_entry(true, prio, c), ns, persistent); +} + +environment add_congr_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent) { + validate_congr(env, ios, c); + return simp_ext::add_entry(env, ios, simp_entry(false, prio, c), ns, persistent); +} + +bool is_simp_lemma(environment const & env, name const & c) { + return simp_ext::get_state(env).m_simp_lemmas.contains(c); +} + +bool is_congr_lemma(environment const & env, name const & c) { + return simp_ext::get_state(env).m_congr_lemmas.contains(c); +} + +void get_simp_lemmas(environment const & env, buffer & r) { + return simp_ext::get_state(env).m_simp_lemmas.to_buffer(r); +} + +void get_congr_lemmas(environment const & env, buffer & r) { + return simp_ext::get_state(env).m_congr_lemmas.to_buffer(r); +} + +static std::vector> * g_simp_lemma_ns = nullptr; + +unsigned register_simp_lemmas(std::initializer_list const & nss) { + unsigned r = g_simp_lemma_ns->size(); + g_simp_lemma_ns->push_back(std::vector(nss)); + return r; +} + +namespace blast { +LEAN_THREAD_VALUE(bool, g_throw_ex, false); + +static void report_failure(sstream const & strm) { + if (g_throw_ex){ + throw exception(strm); + } else { + lean_trace(name({"simplifier", "failure"}), + tout() << strm.str() << "\n";); + } +} + +simp_lemmas add_core(tmp_type_context & tctx, simp_lemmas const & s, + name const & id, levels const & univ_metas, expr const & e, expr const & h, + unsigned priority) { + list ceqvs = to_ceqvs(tctx, e, h); + if (is_nil(ceqvs)) { + report_failure(sstream() << "invalid [simp] lemma '" << id << "'"); + } + environment const & env = tctx.env(); + simp_lemmas new_s = s; + for (expr_pair const & p : ceqvs) { + expr rule = normalizer(tctx)(p.first); + expr proof = tctx.whnf(p.second); + bool is_perm = is_permutation_ceqv(env, rule); + buffer emetas; + buffer instances; + while (is_pi(rule)) { + expr mvar = tctx.mk_mvar(binding_domain(rule)); + emetas.push_back(mvar); + instances.push_back(binding_info(rule).is_inst_implicit()); + rule = tctx.whnf(instantiate(binding_body(rule), mvar)); + proof = mk_app(proof, mvar); + } + expr rel, lhs, rhs; + if (is_simp_relation(env, rule, rel, lhs, rhs) && is_constant(rel)) { + new_s.insert(const_name(rel), simp_lemma(id, univ_metas, reverse_to_list(emetas), + reverse_to_list(instances), lhs, rhs, proof, is_perm, priority)); + } + } + return new_s; +} + +simp_lemmas add(tmp_type_context & tctx, simp_lemmas const & s, name const & id, expr const & e, expr const & h, unsigned priority) { + return add_core(tctx, s, id, list(), e, h, priority); +} + +simp_lemmas join(simp_lemmas const & s1, simp_lemmas const & s2) { + simp_lemmas new_s1 = s1; + s2.for_each_simp([&](name const & eqv, simp_lemma const & r) { + new_s1.insert(eqv, r); + }); + return new_s1; +} + +static simp_lemmas add_core(tmp_type_context & tctx, simp_lemmas const & s, name const & cname, unsigned priority) { + declaration const & d = tctx.env().get(cname); + buffer us; + unsigned num_univs = d.get_num_univ_params(); + for (unsigned i = 0; i < num_univs; i++) { + us.push_back(tctx.mk_uvar()); + } + levels ls = to_list(us); + expr e = tctx.whnf(instantiate_type_univ_params(d, ls)); + expr h = mk_constant(cname, ls); + return add_core(tctx, s, cname, ls, e, h, priority); +} + +// Return true iff lhs is of the form (B (x : ?m1), ?m2) or (B (x : ?m1), ?m2 x), +// where B is lambda or Pi +static bool is_valid_congr_rule_binding_lhs(expr const & lhs, name_set & found_mvars) { + lean_assert(is_binding(lhs)); + expr const & d = binding_domain(lhs); + expr const & b = binding_body(lhs); + if (!is_metavar(d)) + return false; + if (is_metavar(b) && b != d) { + found_mvars.insert(mlocal_name(b)); + found_mvars.insert(mlocal_name(d)); + return true; + } + if (is_app(b) && is_metavar(app_fn(b)) && is_var(app_arg(b), 0) && app_fn(b) != d) { + found_mvars.insert(mlocal_name(app_fn(b))); + found_mvars.insert(mlocal_name(d)); + return true; + } + return false; +} + +// Return true iff all metavariables in e are in found_mvars +static bool only_found_mvars(expr const & e, name_set const & found_mvars) { + return !find(e, [&](expr const & m, unsigned) { + return is_metavar(m) && !found_mvars.contains(mlocal_name(m)); + }); +} + +// Check whether rhs is of the form (mvar l_1 ... l_n) where mvar is a metavariable, +// and l_i's are local constants, and mvar does not occur in found_mvars. +// If it is return true and update found_mvars +static bool is_valid_congr_hyp_rhs(expr const & rhs, name_set & found_mvars) { + buffer rhs_args; + expr const & rhs_fn = get_app_args(rhs, rhs_args); + if (!is_metavar(rhs_fn) || found_mvars.contains(mlocal_name(rhs_fn))) + return false; + for (expr const & arg : rhs_args) + if (!is_local(arg)) + return false; + found_mvars.insert(mlocal_name(rhs_fn)); + return true; +} + +simp_lemmas add_congr_core(tmp_type_context & tctx, simp_lemmas const & s, name const & n, unsigned prio) { + declaration const & d = tctx.env().get(n); + buffer us; + unsigned num_univs = d.get_num_univ_params(); + for (unsigned i = 0; i < num_univs; i++) { + us.push_back(tctx.mk_uvar()); + } + levels ls = to_list(us); + expr rule = normalizer(tctx)(instantiate_type_univ_params(d, ls)); + expr proof = mk_constant(n, ls); + + buffer emetas; + buffer instances, explicits; + + while (is_pi(rule)) { + expr mvar = tctx.mk_mvar(binding_domain(rule)); + emetas.push_back(mvar); + explicits.push_back(is_explicit(binding_info(rule))); + instances.push_back(binding_info(rule).is_inst_implicit()); + rule = tctx.whnf(instantiate(binding_body(rule), mvar)); + proof = mk_app(proof, mvar); + } + expr rel, lhs, rhs; + if (!is_simp_relation(tctx.env(), rule, rel, lhs, rhs) || !is_constant(rel)) { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation"); + } + name_set found_mvars; + buffer lhs_args, rhs_args; + expr const & lhs_fn = get_app_args(lhs, lhs_args); + expr const & rhs_fn = get_app_args(rhs, rhs_args); + if (is_constant(lhs_fn)) { + if (!is_constant(rhs_fn) || const_name(lhs_fn) != const_name(rhs_fn) || lhs_args.size() != rhs_args.size()) { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' resulting type is not of the form (" << const_name(lhs_fn) << " ...) " + << "~ (" << const_name(lhs_fn) << " ...), where ~ is '" << const_name(rel) << "'"); + } + for (expr const & lhs_arg : lhs_args) { + if (is_sort(lhs_arg)) + continue; + if (!is_metavar(lhs_arg) || found_mvars.contains(mlocal_name(lhs_arg))) { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' the left-hand-side of the congruence resulting type must be of the form (" + << const_name(lhs_fn) << " x_1 ... x_n), where each x_i is a distinct variable or a sort"); + } + found_mvars.insert(mlocal_name(lhs_arg)); + } + } else if (is_binding(lhs)) { + if (lhs.kind() != rhs.kind()) { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' kinds of the left-hand-side and right-hand-side of " + << "the congruence resulting type do not match"); + } + if (!is_valid_congr_rule_binding_lhs(lhs, found_mvars)) { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' left-hand-side of the congruence resulting type must " + << "be of the form (fun/Pi (x : A), B x)"); + } + } else { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' left-hand-side is not an application nor a binding"); + } + + buffer congr_hyps; + lean_assert(emetas.size() == explicits.size()); + for (unsigned i = 0; i < emetas.size(); i++) { + expr const & mvar = emetas[i]; + if (explicits[i] && !found_mvars.contains(mlocal_name(mvar))) { + buffer locals; + expr type = mlocal_type(mvar); + while (is_pi(type)) { + expr local = tctx.mk_tmp_local(binding_domain(type)); + locals.push_back(local); + type = instantiate(binding_body(type), local); + } + expr h_rel, h_lhs, h_rhs; + if (!is_simp_relation(tctx.env(), type, h_rel, h_lhs, h_rhs) || !is_constant(h_rel)) + continue; + unsigned j = 0; + for (expr const & local : locals) { + j++; + if (!only_found_mvars(mlocal_type(local), found_mvars)) { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' argument #" << j << " of parameter #" << (i+1) << " contains " + << "unresolved parameters"); + } + } + if (!only_found_mvars(h_lhs, found_mvars)) { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' argument #" << (i+1) << " is not a valid hypothesis, the left-hand-side contains " + << "unresolved parameters"); + } + if (!is_valid_congr_hyp_rhs(h_rhs, found_mvars)) { + report_failure(sstream() << "invalid [congr] lemma, '" << n + << "' argument #" << (i+1) << " is not a valid hypothesis, the right-hand-side must be " + << "of the form (m l_1 ... l_n) where m is parameter that was not " + << "'assigned/resolved' yet and l_i's are locals"); + } + found_mvars.insert(mlocal_name(mvar)); + congr_hyps.push_back(mvar); + } + } + simp_lemmas new_s = s; + new_s.insert(const_name(rel), user_congr_lemma(n, ls, reverse_to_list(emetas), + reverse_to_list(instances), lhs, rhs, proof, to_list(congr_hyps), prio)); + return new_s; +} + +simp_lemma_core::simp_lemma_core(name const & id, levels const & umetas, list const & emetas, + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + unsigned priority): + m_id(id), m_umetas(umetas), m_emetas(emetas), m_instances(instances), + m_lhs(lhs), m_rhs(rhs), m_proof(proof), m_priority(priority) {} + +simp_lemma::simp_lemma(name const & id, levels const & umetas, list const & emetas, + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + bool is_perm, unsigned priority): + simp_lemma_core(id, umetas, emetas, instances, lhs, rhs, proof, priority), + m_is_permutation(is_perm) {} + +bool operator==(simp_lemma const & r1, simp_lemma const & r2) { + return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs; +} + +format simp_lemma::pp(formatter const & fmt) const { + format r; + r += format("#") + format(get_num_emeta()); + if (m_priority != LEAN_DEFAULT_PRIORITY) + r += space() + paren(format(m_priority)); + if (m_is_permutation) + r += space() + format("perm"); + format r1 = comma() + space() + fmt(get_lhs()); + r1 += space() + format("↦") + pp_indent_expr(fmt, get_rhs()); + r += group(r1); + return r; +} + +user_congr_lemma::user_congr_lemma(name const & id, levels const & umetas, list const & emetas, + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + list const & congr_hyps, unsigned priority): + simp_lemma_core(id, umetas, emetas, instances, lhs, rhs, proof, priority), + m_congr_hyps(congr_hyps) {} + +bool operator==(user_congr_lemma const & r1, user_congr_lemma const & r2) { + return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs && r1.m_congr_hyps == r2.m_congr_hyps; +} + +format user_congr_lemma::pp(formatter const & fmt) const { + format r; + r += format("#") + format(get_num_emeta()); + if (m_priority != LEAN_DEFAULT_PRIORITY) + r += space() + paren(format(m_priority)); + format r1; + for (expr const & h : m_congr_hyps) { + r1 += space() + paren(fmt(mlocal_type(h))); + } + r += group(r1); + r += space() + format(":") + space(); + format r2 = paren(fmt(get_lhs())); + r2 += space() + format("↦") + space() + paren(fmt(get_rhs())); + r += group(r2); + return r; +} + +simp_lemmas_for::simp_lemmas_for(name const & eqv): + m_eqv(eqv) {} + +void simp_lemmas_for::insert(simp_lemma const & r) { + m_simp_set.insert(r.get_lhs(), r); +} + +void simp_lemmas_for::erase(simp_lemma const & r) { + m_simp_set.erase(r.get_lhs(), r); +} + +void simp_lemmas_for::insert(user_congr_lemma const & r) { + m_congr_set.insert(r.get_lhs(), r); +} + +void simp_lemmas_for::erase(user_congr_lemma const & r) { + m_congr_set.erase(r.get_lhs(), r); +} + +list const * simp_lemmas_for::find_simp(head_index const & h) const { + return m_simp_set.find(h); +} + +void simp_lemmas_for::for_each_simp(std::function const & fn) const { + m_simp_set.for_each_entry([&](head_index const &, simp_lemma const & r) { fn(r); }); +} + +list const * simp_lemmas_for::find_congr(head_index const & h) const { + return m_congr_set.find(h); +} + +void simp_lemmas_for::for_each_congr(std::function const & fn) const { + m_congr_set.for_each_entry([&](head_index const &, user_congr_lemma const & r) { fn(r); }); +} + +void simp_lemmas_for::erase_simp(name_set const & ids) { + // This method is not very smart and doesn't use any indexing or caching. + // So, it may be a bottleneck in the future + buffer to_delete; + for_each_simp([&](simp_lemma const & r) { + if (ids.contains(r.get_id())) { + to_delete.push_back(r); + } + }); + for (simp_lemma const & r : to_delete) { + erase(r); + } +} + +void simp_lemmas_for::erase_simp(buffer const & ids) { + erase_simp(to_name_set(ids)); +} + +template +void simp_lemmas::insert_core(name const & eqv, R const & r) { + simp_lemmas_for s(eqv); + if (auto const * curr = m_sets.find(eqv)) { + s = *curr; + } + s.insert(r); + m_sets.insert(eqv, s); +} + +template +void simp_lemmas::erase_core(name const & eqv, R const & r) { + if (auto const * curr = m_sets.find(eqv)) { + simp_lemmas_for s = *curr; + s.erase(r); + if (s.empty()) + m_sets.erase(eqv); + else + m_sets.insert(eqv, s); + } +} + +void simp_lemmas::insert(name const & eqv, simp_lemma const & r) { + return insert_core(eqv, r); +} + +void simp_lemmas::erase(name const & eqv, simp_lemma const & r) { + return erase_core(eqv, r); +} + +void simp_lemmas::insert(name const & eqv, user_congr_lemma const & r) { + return insert_core(eqv, r); +} + +void simp_lemmas::erase(name const & eqv, user_congr_lemma const & r) { + return erase_core(eqv, r); +} + +void simp_lemmas::get_relations(buffer & rs) const { + m_sets.for_each([&](name const & r, simp_lemmas_for const &) { + rs.push_back(r); + }); +} + +void simp_lemmas::erase_simp(name_set const & ids) { + name_map new_sets; + m_sets.for_each([&](name const & n, simp_lemmas_for const & s) { + simp_lemmas_for new_s = s; + new_s.erase_simp(ids); + new_sets.insert(n, new_s); + }); + m_sets = new_sets; +} + +void simp_lemmas::erase_simp(buffer const & ids) { + erase_simp(to_name_set(ids)); +} + +simp_lemmas_for const * simp_lemmas::find(name const & eqv) const { + return m_sets.find(eqv); +} + +list const * simp_lemmas::find_simp(name const & eqv, head_index const & h) const { + if (auto const * s = m_sets.find(eqv)) + return s->find_simp(h); + return nullptr; +} + +list const * simp_lemmas::find_congr(name const & eqv, head_index const & h) const { + if (auto const * s = m_sets.find(eqv)) + return s->find_congr(h); + return nullptr; +} + +void simp_lemmas::for_each_simp(std::function const & fn) const { + m_sets.for_each([&](name const & eqv, simp_lemmas_for const & s) { + s.for_each_simp([&](simp_lemma const & r) { + fn(eqv, r); + }); + }); +} + +void simp_lemmas::for_each_congr(std::function const & fn) const { + m_sets.for_each([&](name const & eqv, simp_lemmas_for const & s) { + s.for_each_congr([&](user_congr_lemma const & r) { + fn(eqv, r); + }); + }); +} + +format simp_lemmas::pp(formatter const & fmt, format const & header, bool simp, bool congr) const { + format r; + if (simp) { + name prev_eqv; + for_each_simp([&](name const & eqv, simp_lemma const & rw) { + if (prev_eqv != eqv) { + r += format("simplification rules for ") + format(eqv); + r += header; + r += line(); + prev_eqv = eqv; + } + r += rw.pp(fmt) + line(); + }); + } + + if (congr) { + name prev_eqv; + for_each_congr([&](name const & eqv, user_congr_lemma const & cr) { + if (prev_eqv != eqv) { + r += format("congruencec rules for ") + format(eqv) + line(); + prev_eqv = eqv; + } + r += cr.pp(fmt) + line(); + }); + } + return r; +} + +format simp_lemmas::pp_simp(formatter const & fmt, format const & header) const { + return pp(fmt, header, true, false); +} + +format simp_lemmas::pp_simp(formatter const & fmt) const { + return pp(fmt, format(), true, false); +} + +format simp_lemmas::pp_congr(formatter const & fmt) const { + return pp(fmt, format(), false, true); +} + +format simp_lemmas::pp(formatter const & fmt) const { + return pp(fmt, format(), true, true); +} + +struct simp_lemmas_cache { + simp_lemmas m_main_cache; + std::vector> m_key_cache; +}; + +LEAN_THREAD_PTR(simp_lemmas_cache, g_simp_lemmas_cache); + +scope_simp::scope_simp() { + m_old_cache = g_simp_lemmas_cache; + g_simp_lemmas_cache = nullptr; +} + +scope_simp::~scope_simp() { + delete g_simp_lemmas_cache; + g_simp_lemmas_cache = m_old_cache; +} + +simp_lemmas get_simp_lemmas_core() { + simp_lemmas r; + buffer simp_lemmas, congr_lemmas; + blast_tmp_type_context ctx; + auto const & s = simp_ext::get_state(env()); + s.m_simp_lemmas.to_buffer(simp_lemmas); + s.m_congr_lemmas.to_buffer(congr_lemmas); + unsigned i = simp_lemmas.size(); + while (i > 0) { + --i; + ctx->clear(); + r = add_core(*ctx, r, simp_lemmas[i], *s.m_simp_lemmas.get_prio(simp_lemmas[i])); + } + i = congr_lemmas.size(); + while (i > 0) { + --i; + ctx->clear(); + r = add_congr_core(*ctx, r, congr_lemmas[i], *s.m_congr_lemmas.get_prio(congr_lemmas[i])); + } + return r; +} + +static simp_lemmas_cache & get_cache() { + if (!g_simp_lemmas_cache) { + g_simp_lemmas_cache = new simp_lemmas_cache(); + g_simp_lemmas_cache->m_main_cache = get_simp_lemmas_core(); + } + return *g_simp_lemmas_cache; +} + +simp_lemmas get_simp_lemmas() { + return get_cache().m_main_cache; +} + +template +simp_lemmas get_simp_lemmas_core(NSS const & nss) { + simp_lemmas r; + blast_tmp_type_context ctx; + for (name const & ns : nss) { + list const * entries = simp_ext::get_entries(env(), ns); + if (entries) { + for (auto const & e : *entries) { + bool is_simp; unsigned prio; name n; + std::tie(is_simp, prio, n) = e; + if (is_simp) { + ctx->clear(); + r = add_core(*ctx, r, n, prio); + } + } + } + } + return r; +} + +simp_lemmas get_simp_lemmas(std::initializer_list const & nss) { + return get_simp_lemmas_core(nss); +} + +simp_lemmas get_simp_lemmas(name const & ns) { + return get_simp_lemmas({ns}); +} + +simp_lemmas get_simp_lemmas(unsigned key) { + simp_lemmas_cache & cache = get_cache(); + if (key >= g_simp_lemma_ns->size()) + throw exception("invalid simp lemma cache key"); + if (key >= cache.m_key_cache.size()) + cache.m_key_cache.resize(key+1); + if (!cache.m_key_cache[key]) + cache.m_key_cache[key] = get_simp_lemmas_core((*g_simp_lemma_ns)[key]); + return *cache.m_key_cache[key]; +} +} + +void validate_simp(environment const & env, io_state const & ios, name const & n) { + blast::simp_lemmas s; + tmp_type_context ctx(env, ios.get_options()); + flet set_ex(blast::g_throw_ex, true); + blast::add_core(ctx, s, n, LEAN_DEFAULT_PRIORITY); +} + +void validate_congr(environment const & env, io_state const & ios, name const & n) { + blast::simp_lemmas s; + tmp_type_context ctx(env, ios.get_options()); + flet set_ex(blast::g_throw_ex, true); + blast::add_congr_core(ctx, s, n, LEAN_DEFAULT_PRIORITY); +} + +void initialize_simp_lemmas() { + g_class_name = new name("simp"); + g_key = new std::string("SIMP"); + g_simp_lemma_ns = new std::vector>(); + simp_ext::initialize(); + + register_prio_attribute("simp", "simplification lemma", + add_simp_lemma, + is_simp_lemma, + [](environment const & env, name const & d) { + if (auto p = simp_ext::get_state(env).m_simp_lemmas.get_prio(d)) + return *p; + else + return LEAN_DEFAULT_PRIORITY; + }); + + register_prio_attribute("congr", "congruence lemma", + add_congr_lemma, + is_congr_lemma, + [](environment const & env, name const & d) { + if (auto p = simp_ext::get_state(env).m_congr_lemmas.get_prio(d)) + return *p; + else + return LEAN_DEFAULT_PRIORITY; + }); + + blast::g_simp_lemmas_cache = nullptr; +} + +void finalize_simp_lemmas() { + simp_ext::finalize(); + delete g_key; + delete g_class_name; + delete g_simp_lemma_ns; +} +} diff --git a/src/library/blast/simplifier/simp_lemmas.h b/src/library/blast/simplifier/simp_lemmas.h new file mode 100644 index 000000000..3dc238131 --- /dev/null +++ b/src/library/blast/simplifier/simp_lemmas.h @@ -0,0 +1,171 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "kernel/environment.h" +#include "library/io_state.h" +#include "library/tmp_type_context.h" +#include "library/head_map.h" +#include "library/blast/gexpr.h" + +namespace lean { +environment add_simp_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent); +environment add_congr_lemma(environment const & env, io_state const & ios, name const & c, unsigned prio, name const & ns, bool persistent); +bool is_simp_lemma(environment const & env, name const & n); +bool is_congr_lemma(environment const & env, name const & n); +void get_simp_lemmas(environment const & env, buffer & r); +void get_congr_lemmas(environment const & env, buffer & r); +void initialize_simp_lemmas(); +void finalize_simp_lemmas(); + +/** Generate a unique id for a set of namespaces containing [simp] and [congr] lemmas */ +unsigned register_simp_lemmas(std::initializer_list const & nss); + +namespace blast { +class simp_lemmas; +class simp_lemma_core { +protected: + name m_id; + levels m_umetas; + list m_emetas; + list m_instances; + + expr m_lhs; + expr m_rhs; + expr m_proof; + unsigned m_priority; + simp_lemma_core(name const & id, levels const & umetas, list const & emetas, + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + unsigned priority); +public: + name const & get_id() const { return m_id; } + unsigned get_num_umeta() const { return length(m_umetas); } + unsigned get_num_emeta() const { return length(m_emetas); } + + /** \brief Return a list containing the expression metavariables in reverse order. */ + list const & get_emetas() const { return m_emetas; } + + /** \brief Return a list of bools indicating whether or not each expression metavariable + in get_emetas() is an instance. */ + list const & get_instances() const { return m_instances; } + + unsigned get_priority() const { return m_priority; } + + expr const & get_lhs() const { return m_lhs; } + expr const & get_rhs() const { return m_rhs; } + expr const & get_proof() const { return m_proof; } +}; + +class simp_lemma : public simp_lemma_core { + bool m_is_permutation; + simp_lemma(name const & id, levels const & umetas, list const & emetas, + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + bool is_perm, unsigned priority); + + friend simp_lemmas add_core(tmp_type_context & tctx, simp_lemmas const & s, name const & id, + levels const & univ_metas, expr const & e, expr const & h, unsigned priority); +public: + friend bool operator==(simp_lemma const & r1, simp_lemma const & r2); + bool is_perm() const { return m_is_permutation; } + format pp(formatter const & fmt) const; +}; + +bool operator==(simp_lemma const & r1, simp_lemma const & r2); +inline bool operator!=(simp_lemma const & r1, simp_lemma const & r2) { return !operator==(r1, r2); } + +// We use user_congr_lemma to avoid a confusion with ::lemma::congr_lemma +class user_congr_lemma : public simp_lemma_core { + list m_congr_hyps; + user_congr_lemma(name const & id, levels const & umetas, list const & emetas, + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + list const & congr_hyps, unsigned priority); + friend simp_lemmas add_congr_core(tmp_type_context & tctx, simp_lemmas const & s, name const & n, unsigned priority); +public: + friend bool operator==(user_congr_lemma const & r1, user_congr_lemma const & r2); + list const & get_congr_hyps() const { return m_congr_hyps; } + format pp(formatter const & fmt) const; +}; + +struct simp_lemma_core_prio_fn { unsigned operator()(simp_lemma_core const & s) const { return s.get_priority(); } }; + +bool operator==(user_congr_lemma const & r1, user_congr_lemma const & r2); +inline bool operator!=(user_congr_lemma const & r1, user_congr_lemma const & r2) { return !operator==(r1, r2); } + +/** \brief Simplification and congruence lemmas for a given equivalence relation */ +class simp_lemmas_for { + typedef head_map_prio simp_set; + typedef head_map_prio congr_set; + name m_eqv; + simp_set m_simp_set; + congr_set m_congr_set; +public: + simp_lemmas_for() {} + /** \brief Return the equivalence relation associated with this set */ + simp_lemmas_for(name const & eqv); + bool empty() const { return m_simp_set.empty() && m_congr_set.empty(); } + name const & get_eqv() const { return m_eqv; } + void insert(simp_lemma const & r); + void erase(simp_lemma const & r); + void insert(user_congr_lemma const & r); + void erase(user_congr_lemma const & r); + void erase_simp(name_set const & ids); + void erase_simp(buffer const & ids); + list const * find_simp(head_index const & h) const; + void for_each_simp(std::function const & fn) const; + list const * find_congr(head_index const & h) const; + void for_each_congr(std::function const & fn) const; +}; + +class simp_lemmas { + name_map m_sets; // mapping from relation name to simp_lemmas_for + template void insert_core(name const & eqv, R const & r); + template void erase_core(name const & eqv, R const & r); +public: + void insert(name const & eqv, simp_lemma const & r); + void erase(name const & eqv, simp_lemma const & r); + void insert(name const & eqv, user_congr_lemma const & r); + void erase(name const & eqv, user_congr_lemma const & r); + void erase_simp(name_set const & ids); + void erase_simp(buffer const & ids); + void get_relations(buffer & rs) const; + simp_lemmas_for const * find(name const & eqv) const; + list const * find_simp(name const & eqv, head_index const & h) const; + list const * find_congr(name const & eqv, head_index const & h) const; + void for_each_simp(std::function const & fn) const; + void for_each_congr(std::function const & fn) const; + format pp(formatter const & fmt, format const & header, bool simp, bool congr) const; + format pp_simp(formatter const & fmt, format const & header) const; + format pp_simp(formatter const & fmt) const; + format pp_congr(formatter const & fmt) const; + format pp(formatter const & fmt) const; +}; + +struct simp_lemmas_cache; + +/** \brief Auxiliary class for initializing simp lemmas during blast initialization. */ +class scope_simp { + simp_lemmas_cache * m_old_cache; +public: + scope_simp(); + ~scope_simp(); +}; + +simp_lemmas add(tmp_type_context & tctx, simp_lemmas const & s, name const & id, expr const & e, expr const & h, unsigned priority); +simp_lemmas join(simp_lemmas const & s1, simp_lemmas const & s2); + +/** \brief Get (active) simplification lemmas. */ +simp_lemmas get_simp_lemmas(); +/** \brief Get simplification lemmas in the given namespace. */ +simp_lemmas get_simp_lemmas(name const & ns); +/** \brief Get simplification lemmas in the given namespaces. */ +// simp_lemmas get_simp_lemmas(std::initializer_list const & nss); +/** \brief Get simplification lemmas in the namespaces registered at key. + The key is created using procedure #register_simp_lemmas at initialization time. + This is more efficient than get_simp_lemmas(std::initializer_list const & nss), because + results are cached. */ +simp_lemmas get_simp_lemmas(unsigned key); +} +} diff --git a/src/library/blast/simplifier/simp_rule_set.cpp b/src/library/blast/simplifier/simp_rule_set.cpp deleted file mode 100644 index 020842fdc..000000000 --- a/src/library/blast/simplifier/simp_rule_set.cpp +++ /dev/null @@ -1,612 +0,0 @@ -/* -Copyright (c) 2015 Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. - -Author: Leonardo de Moura -*/ -#include -#include -#include "util/sstream.h" -#include "kernel/instantiate.h" -#include "kernel/find_fn.h" -#include "kernel/error_msgs.h" -#include "library/scoped_ext.h" -#include "library/expr_pair.h" -#include "library/attribute_manager.h" -#include "library/relation_manager.h" -#include "library/blast/simplifier/ceqv.h" -#include "library/blast/simplifier/simp_rule_set.h" - -namespace lean { -simp_rule_core::simp_rule_core(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof, - unsigned priority): - m_id(id), m_umetas(umetas), m_emetas(emetas), m_instances(instances), - m_lhs(lhs), m_rhs(rhs), m_proof(proof), m_priority(priority) {} - -simp_rule::simp_rule(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof, - bool is_perm, unsigned priority): - simp_rule_core(id, umetas, emetas, instances, lhs, rhs, proof, priority), - m_is_permutation(is_perm) {} - -bool operator==(simp_rule const & r1, simp_rule const & r2) { - return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs; -} - -format simp_rule::pp(formatter const & fmt) const { - format r; - r += format("#") + format(get_num_emeta()); - if (m_priority != LEAN_DEFAULT_PRIORITY) - r += space() + paren(format(m_priority)); - if (m_is_permutation) - r += space() + format("perm"); - format r1 = comma() + space() + fmt(get_lhs()); - r1 += space() + format("↦") + pp_indent_expr(fmt, get_rhs()); - r += group(r1); - return r; -} - -congr_rule::congr_rule(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof, - list const & congr_hyps, unsigned priority): - simp_rule_core(id, umetas, emetas, instances, lhs, rhs, proof, priority), - m_congr_hyps(congr_hyps) {} - -bool operator==(congr_rule const & r1, congr_rule const & r2) { - return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs && r1.m_congr_hyps == r2.m_congr_hyps; -} - -format congr_rule::pp(formatter const & fmt) const { - format r; - r += format("#") + format(get_num_emeta()); - if (m_priority != LEAN_DEFAULT_PRIORITY) - r += space() + paren(format(m_priority)); - format r1; - for (expr const & h : m_congr_hyps) { - r1 += space() + paren(fmt(mlocal_type(h))); - } - r += group(r1); - r += space() + format(":") + space(); - format r2 = paren(fmt(get_lhs())); - r2 += space() + format("↦") + space() + paren(fmt(get_rhs())); - r += group(r2); - return r; -} - -simp_rule_set::simp_rule_set(name const & eqv): - m_eqv(eqv) {} - -void simp_rule_set::insert(simp_rule const & r) { - m_simp_set.insert(r.get_lhs(), r); -} - -void simp_rule_set::erase(simp_rule const & r) { - m_simp_set.erase(r.get_lhs(), r); -} - -void simp_rule_set::insert(congr_rule const & r) { - m_congr_set.insert(r.get_lhs(), r); -} - -void simp_rule_set::erase(congr_rule const & r) { - m_congr_set.erase(r.get_lhs(), r); -} - -list const * simp_rule_set::find_simp(head_index const & h) const { - return m_simp_set.find(h); -} - -void simp_rule_set::for_each_simp(std::function const & fn) const { - m_simp_set.for_each_entry([&](head_index const &, simp_rule const & r) { fn(r); }); -} - -list const * simp_rule_set::find_congr(head_index const & h) const { - return m_congr_set.find(h); -} - -void simp_rule_set::for_each_congr(std::function const & fn) const { - m_congr_set.for_each_entry([&](head_index const &, congr_rule const & r) { fn(r); }); -} - -void simp_rule_set::erase_simp(name_set const & ids) { - // This method is not very smart and doesn't use any indexing or caching. - // So, it may be a bottleneck in the future - buffer to_delete; - for_each_simp([&](simp_rule const & r) { - if (ids.contains(r.get_id())) { - to_delete.push_back(r); - } - }); - for (simp_rule const & r : to_delete) { - erase(r); - } -} - -void simp_rule_set::erase_simp(buffer const & ids) { - erase_simp(to_name_set(ids)); -} - -template -void simp_rule_sets::insert_core(name const & eqv, R const & r) { - simp_rule_set s(eqv); - if (auto const * curr = m_sets.find(eqv)) { - s = *curr; - } - s.insert(r); - m_sets.insert(eqv, s); -} - -template -void simp_rule_sets::erase_core(name const & eqv, R const & r) { - if (auto const * curr = m_sets.find(eqv)) { - simp_rule_set s = *curr; - s.erase(r); - if (s.empty()) - m_sets.erase(eqv); - else - m_sets.insert(eqv, s); - } -} - -void simp_rule_sets::insert(name const & eqv, simp_rule const & r) { - return insert_core(eqv, r); -} - -void simp_rule_sets::erase(name const & eqv, simp_rule const & r) { - return erase_core(eqv, r); -} - -void simp_rule_sets::insert(name const & eqv, congr_rule const & r) { - return insert_core(eqv, r); -} - -void simp_rule_sets::erase(name const & eqv, congr_rule const & r) { - return erase_core(eqv, r); -} - -void simp_rule_sets::get_relations(buffer & rs) const { - m_sets.for_each([&](name const & r, simp_rule_set const &) { - rs.push_back(r); - }); -} - -void simp_rule_sets::erase_simp(name_set const & ids) { - name_map new_sets; - m_sets.for_each([&](name const & n, simp_rule_set const & s) { - simp_rule_set new_s = s; - new_s.erase_simp(ids); - new_sets.insert(n, new_s); - }); - m_sets = new_sets; -} - -void simp_rule_sets::erase_simp(buffer const & ids) { - erase_simp(to_name_set(ids)); -} - -simp_rule_set const * simp_rule_sets::find(name const & eqv) const { - return m_sets.find(eqv); -} - -list const * simp_rule_sets::find_simp(name const & eqv, head_index const & h) const { - if (auto const * s = m_sets.find(eqv)) - return s->find_simp(h); - return nullptr; -} - -list const * simp_rule_sets::find_congr(name const & eqv, head_index const & h) const { - if (auto const * s = m_sets.find(eqv)) - return s->find_congr(h); - return nullptr; -} - -void simp_rule_sets::for_each_simp(std::function const & fn) const { - m_sets.for_each([&](name const & eqv, simp_rule_set const & s) { - s.for_each_simp([&](simp_rule const & r) { - fn(eqv, r); - }); - }); -} - -void simp_rule_sets::for_each_congr(std::function const & fn) const { - m_sets.for_each([&](name const & eqv, simp_rule_set const & s) { - s.for_each_congr([&](congr_rule const & r) { - fn(eqv, r); - }); - }); -} - -format simp_rule_sets::pp(formatter const & fmt, format const & header, bool simp, bool congr) const { - format r; - if (simp) { - name prev_eqv; - for_each_simp([&](name const & eqv, simp_rule const & rw) { - if (prev_eqv != eqv) { - r += format("simplification rules for ") + format(eqv); - r += header; - r += line(); - prev_eqv = eqv; - } - r += rw.pp(fmt) + line(); - }); - } - - if (congr) { - name prev_eqv; - for_each_congr([&](name const & eqv, congr_rule const & cr) { - if (prev_eqv != eqv) { - r += format("congruencec rules for ") + format(eqv) + line(); - prev_eqv = eqv; - } - r += cr.pp(fmt) + line(); - }); - } - return r; -} - -format simp_rule_sets::pp_simp(formatter const & fmt, format const & header) const { - return pp(fmt, header, true, false); -} - -format simp_rule_sets::pp_simp(formatter const & fmt) const { - return pp(fmt, format(), true, false); -} - -format simp_rule_sets::pp_congr(formatter const & fmt) const { - return pp(fmt, format(), false, true); -} - -format simp_rule_sets::pp(formatter const & fmt) const { - return pp(fmt, format(), true, true); -} - -static name * g_prefix = nullptr; - -simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, - name const & id, levels const & univ_metas, expr const & e, expr const & h, - unsigned priority) { - list ceqvs = to_ceqvs(tctx, e, h); - if (is_nil(ceqvs)) throw exception("[simp] rule invalid"); - environment const & env = tctx.env(); - simp_rule_sets new_s = s; - for (expr_pair const & p : ceqvs) { - expr rule = tctx.whnf(p.first); - expr proof = tctx.whnf(p.second); - bool is_perm = is_permutation_ceqv(env, rule); - buffer emetas; - buffer instances; - while (is_pi(rule)) { - expr mvar = tctx.mk_mvar(binding_domain(rule)); - emetas.push_back(mvar); - instances.push_back(binding_info(rule).is_inst_implicit()); - rule = tctx.whnf(instantiate(binding_body(rule), mvar)); - proof = mk_app(proof, mvar); - } - expr rel, lhs, rhs; - if (is_simp_relation(env, rule, rel, lhs, rhs) && is_constant(rel)) { - new_s.insert(const_name(rel), simp_rule(id, univ_metas, reverse_to_list(emetas), - reverse_to_list(instances), lhs, rhs, proof, is_perm, priority)); - } - } - return new_s; -} - -simp_rule_sets add(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, expr const & e, expr const & h, unsigned priority) { - return add_core(tctx, s, id, list(), e, h, priority); -} - -simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2) { - simp_rule_sets new_s1 = s1; - s2.for_each_simp([&](name const & eqv, simp_rule const & r) { - new_s1.insert(eqv, r); - }); - return new_s1; -} - -static name * g_class_name = nullptr; -static std::string * g_key = nullptr; - -static simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, name const & cname, unsigned priority) { - declaration const & d = tctx.env().get(cname); - buffer us; - unsigned num_univs = d.get_num_univ_params(); - for (unsigned i = 0; i < num_univs; i++) { - us.push_back(tctx.mk_uvar()); - } - levels ls = to_list(us); - expr e = tctx.whnf(instantiate_type_univ_params(d, ls)); - expr h = mk_constant(cname, ls); - return add_core(tctx, s, cname, ls, e, h, priority); -} - - -// Return true iff lhs is of the form (B (x : ?m1), ?m2) or (B (x : ?m1), ?m2 x), -// where B is lambda or Pi -static bool is_valid_congr_rule_binding_lhs(expr const & lhs, name_set & found_mvars) { - lean_assert(is_binding(lhs)); - expr const & d = binding_domain(lhs); - expr const & b = binding_body(lhs); - if (!is_metavar(d)) - return false; - if (is_metavar(b) && b != d) { - found_mvars.insert(mlocal_name(b)); - found_mvars.insert(mlocal_name(d)); - return true; - } - if (is_app(b) && is_metavar(app_fn(b)) && is_var(app_arg(b), 0) && app_fn(b) != d) { - found_mvars.insert(mlocal_name(app_fn(b))); - found_mvars.insert(mlocal_name(d)); - return true; - } - return false; -} - -// Return true iff all metavariables in e are in found_mvars -static bool only_found_mvars(expr const & e, name_set const & found_mvars) { - return !find(e, [&](expr const & m, unsigned) { - return is_metavar(m) && !found_mvars.contains(mlocal_name(m)); - }); -} - -// Check whether rhs is of the form (mvar l_1 ... l_n) where mvar is a metavariable, -// and l_i's are local constants, and mvar does not occur in found_mvars. -// If it is return true and update found_mvars -static bool is_valid_congr_hyp_rhs(expr const & rhs, name_set & found_mvars) { - buffer rhs_args; - expr const & rhs_fn = get_app_args(rhs, rhs_args); - if (!is_metavar(rhs_fn) || found_mvars.contains(mlocal_name(rhs_fn))) - return false; - for (expr const & arg : rhs_args) - if (!is_local(arg)) - return false; - found_mvars.insert(mlocal_name(rhs_fn)); - return true; -} - -void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n, unsigned prio) { - declaration const & d = tctx.env().get(n); - buffer us; - unsigned num_univs = d.get_num_univ_params(); - for (unsigned i = 0; i < num_univs; i++) { - us.push_back(tctx.mk_uvar()); - } - levels ls = to_list(us); - expr rule = instantiate_type_univ_params(d, ls); - expr proof = mk_constant(n, ls); - - buffer emetas; - buffer instances, explicits; - - while (is_pi(rule)) { - expr mvar = tctx.mk_mvar(binding_domain(rule)); - emetas.push_back(mvar); - explicits.push_back(is_explicit(binding_info(rule))); - instances.push_back(binding_info(rule).is_inst_implicit()); - rule = tctx.whnf(instantiate(binding_body(rule), mvar)); - proof = mk_app(proof, mvar); - } - expr rel, lhs, rhs; - if (!is_simp_relation(tctx.env(), rule, rel, lhs, rhs) || !is_constant(rel)) { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation"); - } - name_set found_mvars; - buffer lhs_args, rhs_args; - expr const & lhs_fn = get_app_args(lhs, lhs_args); - expr const & rhs_fn = get_app_args(rhs, rhs_args); - if (is_constant(lhs_fn)) { - if (!is_constant(rhs_fn) || const_name(lhs_fn) != const_name(rhs_fn) || lhs_args.size() != rhs_args.size()) { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' resulting type is not of the form (" << const_name(lhs_fn) << " ...) " - << "~ (" << const_name(lhs_fn) << " ...), where ~ is '" << const_name(rel) << "'"); - } - for (expr const & lhs_arg : lhs_args) { - if (is_sort(lhs_arg)) - continue; - if (!is_metavar(lhs_arg) || found_mvars.contains(mlocal_name(lhs_arg))) { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' the left-hand-side of the congruence resulting type must be of the form (" - << const_name(lhs_fn) << " x_1 ... x_n), where each x_i is a distinct variable or a sort"); - } - found_mvars.insert(mlocal_name(lhs_arg)); - } - } else if (is_binding(lhs)) { - if (lhs.kind() != rhs.kind()) { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' kinds of the left-hand-side and right-hand-side of " - << "the congruence resulting type do not match"); - } - if (!is_valid_congr_rule_binding_lhs(lhs, found_mvars)) { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' left-hand-side of the congruence resulting type must " - << "be of the form (fun/Pi (x : A), B x)"); - } - } else { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' left-hand-side is not an application nor a binding"); - } - - buffer congr_hyps; - lean_assert(emetas.size() == explicits.size()); - for (unsigned i = 0; i < emetas.size(); i++) { - expr const & mvar = emetas[i]; - if (explicits[i] && !found_mvars.contains(mlocal_name(mvar))) { - buffer locals; - expr type = mlocal_type(mvar); - while (is_pi(type)) { - expr local = tctx.mk_tmp_local(binding_domain(type)); - locals.push_back(local); - type = instantiate(binding_body(type), local); - } - expr h_rel, h_lhs, h_rhs; - if (!is_simp_relation(tctx.env(), type, h_rel, h_lhs, h_rhs) || !is_constant(h_rel)) - continue; - unsigned j = 0; - for (expr const & local : locals) { - j++; - if (!only_found_mvars(mlocal_type(local), found_mvars)) { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' argument #" << j << " of parameter #" << (i+1) << " contains " - << "unresolved parameters"); - } - } - if (!only_found_mvars(h_lhs, found_mvars)) { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' argument #" << (i+1) << " is not a valid hypothesis, the left-hand-side contains " - << "unresolved parameters"); - } - if (!is_valid_congr_hyp_rhs(h_rhs, found_mvars)) { - throw exception(sstream() << "invalid congruence rule, '" << n - << "' argument #" << (i+1) << " is not a valid hypothesis, the right-hand-side must be " - << "of the form (m l_1 ... l_n) where m is parameter that was not " - << "'assigned/resolved' yet and l_i's are locals"); - } - found_mvars.insert(mlocal_name(mvar)); - congr_hyps.push_back(mvar); - } - } - s.insert(const_name(rel), congr_rule(n, ls, reverse_to_list(emetas), - reverse_to_list(instances), lhs, rhs, proof, to_list(congr_hyps), prio)); -} - -struct rrs_state { - simp_rule_sets m_sets; - name_set m_simp_names; - name_set m_congr_names; - - void add_simp(environment const & env, io_state const & ios, name const & cname, unsigned prio) { - tmp_type_context tctx(env, ios.get_options()); - m_sets = add_core(tctx, m_sets, cname, prio); - m_simp_names.insert(cname); - } - - void add_congr(environment const & env, io_state const & ios, name const & n, unsigned prio) { - tmp_type_context tctx(env, ios.get_options()); - add_congr_core(tctx, m_sets, n, prio); - m_congr_names.insert(n); - } -}; - -struct rrs_entry { - bool m_is_simp; - name m_name; - unsigned m_priority; - rrs_entry() {} - rrs_entry(bool is_simp, name const & n, unsigned prio):m_is_simp(is_simp), m_name(n), m_priority(prio) {} -}; - -struct rrs_config { - typedef rrs_entry entry; - typedef rrs_state state; - static void add_entry(environment const & env, io_state const & ios, state & s, entry const & e) { - if (e.m_is_simp) - s.add_simp(env, ios, e.m_name, e.m_priority); - else - s.add_congr(env, ios, e.m_name, e.m_priority); - } - static name const & get_class_name() { - return *g_class_name; - } - static std::string const & get_serialization_key() { - return *g_key; - } - static void write_entry(serializer & s, entry const & e) { - s << e.m_is_simp << e.m_name << e.m_priority; - } - static entry read_entry(deserializer & d) { - entry e; d >> e.m_is_simp >> e.m_name >> e.m_priority; return e; - } - static optional get_fingerprint(entry const & e) { - return some(hash(e.m_is_simp ? 17 : 31, e.m_name.hash())); - } -}; - -template class scoped_ext; -typedef scoped_ext rrs_ext; - -environment add_simp_rule(environment const & env, name const & n, unsigned prio, name const & ns, bool persistent) { - return rrs_ext::add_entry(env, get_dummy_ios(), rrs_entry(true, n, prio), ns, persistent); -} - -environment add_congr_rule(environment const & env, name const & n, unsigned prio, name const & ns, bool persistent) { - return rrs_ext::add_entry(env, get_dummy_ios(), rrs_entry(false, n, prio), ns, persistent); -} - -bool is_simp_rule(environment const & env, name const & n) { - return rrs_ext::get_state(env).m_simp_names.contains(n); -} - -bool is_congr_rule(environment const & env, name const & n) { - return rrs_ext::get_state(env).m_congr_names.contains(n); -} - -simp_rule_sets get_simp_rule_sets(environment const & env) { - return rrs_ext::get_state(env).m_sets; -} - -simp_rule_sets get_simp_rule_sets(environment const & env, options const & o, name const & ns) { - simp_rule_sets set; - list const * entries = rrs_ext::get_entries(env, ns); - if (entries) { - for (auto const & e : *entries) { - tmp_type_context tctx(env, o); - set = add_core(tctx, set, e.m_name, e.m_priority); - } - } - return set; -} - -simp_rule_sets get_simp_rule_sets(environment const & env, options const & o, std::initializer_list const & nss) { - simp_rule_sets set; - for (name const & ns : nss) { - list const * entries = rrs_ext::get_entries(env, ns); - if (entries) { - for (auto const & e : *entries) { - tmp_type_context tctx(env, o); - set = add_core(tctx, set, e.m_name, e.m_priority); - } - } - } - return set; -} - -io_state_stream const & operator<<(io_state_stream const & out, simp_rule_sets const & s) { - options const & opts = out.get_options(); - out.get_stream() << mk_pair(s.pp(out.get_formatter()), opts); - return out; -} - -void initialize_simplifier_rule_set() { - g_prefix = new name(name::mk_internal_unique_name()); - g_class_name = new name("simp"); - g_key = new std::string("SIMP"); - rrs_ext::initialize(); - register_prio_attribute("simp", "simplification rule", - [](environment const & env, io_state const &, name const & d, unsigned prio, name const & ns, bool persistent) { - return add_simp_rule(env, d, prio, ns, persistent); - }, - is_simp_rule, - [](environment const &, name const &) { - // TODO(Leo): fix it after we refactor simp_rule_set - return LEAN_DEFAULT_PRIORITY; - }); - - register_prio_attribute("congr", "congruence rule", - [](environment const & env, io_state const &, name const & d, unsigned prio, name const & ns, bool persistent) { - return add_congr_rule(env, d, prio, ns, persistent); - }, - is_congr_rule, - [](environment const &, name const &) { - // TODO(Leo): fix it after we refactor simp_rule_set - return LEAN_DEFAULT_PRIORITY; - }); -} - -void finalize_simplifier_rule_set() { - rrs_ext::finalize(); - delete g_key; - delete g_class_name; - delete g_prefix; -} -} diff --git a/src/library/blast/simplifier/simp_rule_set.h b/src/library/blast/simplifier/simp_rule_set.h deleted file mode 100644 index 860904583..000000000 --- a/src/library/blast/simplifier/simp_rule_set.h +++ /dev/null @@ -1,151 +0,0 @@ -/* -Copyright (c) 2015 Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. - -Author: Leonardo de Moura -*/ -#pragma once -#include "library/tmp_type_context.h" -#include "library/head_map.h" -#include "library/io_state_stream.h" - -namespace lean { -class simp_rule_sets; - -class simp_rule_core { -protected: - name m_id; - levels m_umetas; - list m_emetas; - list m_instances; - - expr m_lhs; - expr m_rhs; - expr m_proof; - unsigned m_priority; - simp_rule_core(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof, - unsigned priority); -public: - name const & get_id() const { return m_id; } - unsigned get_num_umeta() const { return length(m_umetas); } - unsigned get_num_emeta() const { return length(m_emetas); } - - /** \brief Return a list containing the expression metavariables in reverse order. */ - list const & get_emetas() const { return m_emetas; } - - /** \brief Return a list of bools indicating whether or not each expression metavariable - in get_emetas() is an instance. */ - list const & get_instances() const { return m_instances; } - - unsigned get_priority() const { return m_priority; } - - expr const & get_lhs() const { return m_lhs; } - expr const & get_rhs() const { return m_rhs; } - expr const & get_proof() const { return m_proof; } -}; - -class simp_rule : public simp_rule_core { - bool m_is_permutation; - simp_rule(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof, - bool is_perm, unsigned priority); - - friend simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, - levels const & univ_metas, expr const & e, expr const & h, unsigned priority); -public: - friend bool operator==(simp_rule const & r1, simp_rule const & r2); - bool is_perm() const { return m_is_permutation; } - format pp(formatter const & fmt) const; -}; - -bool operator==(simp_rule const & r1, simp_rule const & r2); -inline bool operator!=(simp_rule const & r1, simp_rule const & r2) { return !operator==(r1, r2); } - -class congr_rule : public simp_rule_core { - list m_congr_hyps; - congr_rule(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof, - list const & congr_hyps, unsigned priority); - friend void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n, unsigned priority); -public: - friend bool operator==(congr_rule const & r1, congr_rule const & r2); - list const & get_congr_hyps() const { return m_congr_hyps; } - format pp(formatter const & fmt) const; -}; - -struct simp_rule_core_prio_fn { unsigned operator()(simp_rule_core const & s) const { return s.get_priority(); } }; - -bool operator==(congr_rule const & r1, congr_rule const & r2); -inline bool operator!=(congr_rule const & r1, congr_rule const & r2) { return !operator==(r1, r2); } - -class simp_rule_set { - typedef head_map_prio simp_set; - typedef head_map_prio congr_set; - name m_eqv; - simp_set m_simp_set; - congr_set m_congr_set; -public: - simp_rule_set() {} - /** \brief Return the equivalence relation associated with this set */ - simp_rule_set(name const & eqv); - bool empty() const { return m_simp_set.empty() && m_congr_set.empty(); } - name const & get_eqv() const { return m_eqv; } - void insert(simp_rule const & r); - void erase(simp_rule const & r); - void insert(congr_rule const & r); - void erase(congr_rule const & r); - void erase_simp(name_set const & ids); - void erase_simp(buffer const & ids); - list const * find_simp(head_index const & h) const; - void for_each_simp(std::function const & fn) const; - list const * find_congr(head_index const & h) const; - void for_each_congr(std::function const & fn) const; -}; - -class simp_rule_sets { - name_map m_sets; // mapping from relation name to simp_rule_set - template void insert_core(name const & eqv, R const & r); - template void erase_core(name const & eqv, R const & r); -public: - void insert(name const & eqv, simp_rule const & r); - void erase(name const & eqv, simp_rule const & r); - void insert(name const & eqv, congr_rule const & r); - void erase(name const & eqv, congr_rule const & r); - void erase_simp(name_set const & ids); - void erase_simp(buffer const & ids); - void get_relations(buffer & rs) const; - simp_rule_set const * find(name const & eqv) const; - list const * find_simp(name const & eqv, head_index const & h) const; - list const * find_congr(name const & eqv, head_index const & h) const; - void for_each_simp(std::function const & fn) const; - void for_each_congr(std::function const & fn) const; - format pp(formatter const & fmt, format const & header, bool simp, bool congr) const; - format pp_simp(formatter const & fmt, format const & header) const; - format pp_simp(formatter const & fmt) const; - format pp_congr(formatter const & fmt) const; - format pp(formatter const & fmt) const; -}; - -simp_rule_sets add(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, expr const & e, expr const & h, unsigned priority); -simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2); - -environment add_simp_rule(environment const & env, name const & n, unsigned priority, name const & ns, bool persistent); -environment add_congr_rule(environment const & env, name const & n, unsigned priority, name const & ns, bool persistent); - -/** \brief Return true if \c n is an active simplification rule in \c env */ -bool is_simp_rule(environment const & env, name const & n); -/** \brief Return true if \c n is an active congruence rule in \c env */ -bool is_congr_rule(environment const & env, name const & n); -/** \brief Get current simplification rule sets */ -simp_rule_sets get_simp_rule_sets(environment const & env); -/** \brief Get simplification rule sets in the given namespace. */ -simp_rule_sets get_simp_rule_sets(environment const & env, options const & o, name const & ns); -/** \brief Get simplification rule sets in the given namespaces. */ -simp_rule_sets get_simp_rule_sets(environment const & env, options const & o, std::initializer_list const & nss); - -io_state_stream const & operator<<(io_state_stream const & out, simp_rule_sets const & s); - -void initialize_simplifier_rule_set(); -void finalize_simplifier_rule_set(); -} diff --git a/src/library/blast/simplifier/simplifier.cpp b/src/library/blast/simplifier/simplifier.cpp index 29defd928..2c2c03786 100644 --- a/src/library/blast/simplifier/simplifier.cpp +++ b/src/library/blast/simplifier/simplifier.cpp @@ -27,7 +27,7 @@ Author: Daniel Selsam #include "library/blast/trace.h" #include "library/blast/blast_exception.h" #include "library/blast/simplifier/simplifier.h" -#include "library/blast/simplifier/simp_rule_set.h" +#include "library/blast/simplifier/simp_lemmas.h" #include "library/blast/simplifier/ceqv.h" #ifndef LEAN_DEFAULT_SIMPLIFY_MAX_STEPS @@ -60,14 +60,11 @@ namespace blast { using simp::result; -/* Names */ +/* Keys */ -static name * g_simplify_prove_namespace = nullptr; -static name * g_simplify_neg_namespace = nullptr; -static name * g_simplify_unit_namespace = nullptr; -static name * g_simplify_ac_namespace = nullptr; -static name * g_simplify_distrib_namespace = nullptr; -static name * g_simplify_numeral_namespace = nullptr; +static unsigned g_ac_key; +static unsigned g_som_key; +static unsigned g_numeral_key; /* Options */ @@ -133,8 +130,8 @@ class simplifier { name m_rel; expr_predicate m_simp_pred; - simp_rule_sets m_srss; - simp_rule_sets m_ctx_srss; + simp_lemmas m_srss; + simp_lemmas m_ctx_srss; /* Logging */ unsigned m_num_steps{0}; @@ -184,8 +181,8 @@ class simplifier { return has_free_vars(binding_body(f_type)); } - simp_rule_sets add_to_srss(simp_rule_sets const & _srss, buffer & ls) { - simp_rule_sets srss = _srss; + simp_lemmas add_to_srss(simp_lemmas const & _srss, buffer & ls) { + simp_lemmas srss = _srss; for (unsigned i = 0; i < ls.size(); i++) { expr & l = ls[i]; blast_tmp_type_context tctx; @@ -211,7 +208,7 @@ class simplifier { result finalize(result const & r); /* Simplification */ - result simplify(expr const & e, simp_rule_sets const & srss); + result simplify(expr const & e, simp_lemmas const & srss); result simplify(expr const & e, bool is_root); result simplify_lambda(expr const & e); result simplify_pi(expr const & e); @@ -221,12 +218,12 @@ class simplifier { /* Proving */ optional prove(expr const & thm); - optional prove(expr const & thm, simp_rule_sets const & srss); + optional prove(expr const & thm, simp_lemmas const & srss); /* Rewriting */ result rewrite(expr const & e); - result rewrite(expr const & e, simp_rule_sets const & srss); - result rewrite(expr const & e, simp_rule const & sr); + result rewrite(expr const & e, simp_lemmas const & srss); + result rewrite(expr const & e, simp_lemma const & sr); /* Congruence */ result congr_fun_arg(result const & r_f, result const & r_arg); @@ -236,7 +233,7 @@ class simplifier { result congr_funs(result const & r_f, buffer const & args); result try_congrs(expr const & e); - result try_congr(expr const & e, congr_rule const & cr); + result try_congr(expr const & e, user_congr_lemma const & cr); template optional synth_congr(expr const & e, F && simp); @@ -254,7 +251,7 @@ class simplifier { public: simplifier(name const & rel, expr_predicate const & simp_pred): m_rel(rel), m_simp_pred(simp_pred) { } - result operator()(expr const & e, simp_rule_sets const & srss) { return simplify(e, srss); } + result operator()(expr const & e, simp_lemmas const & srss) { return simplify(e, srss); } }; /* Cache */ @@ -365,8 +362,8 @@ expr simplifier::whnf_eta(expr const & e) { /* Simplification */ -result simplifier::simplify(expr const & e, simp_rule_sets const & srss) { - flet set_srss(m_srss, srss); +result simplifier::simplify(expr const & e, simp_lemmas const & srss) { + flet set_srss(m_srss, srss); freset reset(m_cache); return simplify(e, true); } @@ -552,7 +549,7 @@ optional simplifier::prove(expr const & thm) { return none_expr(); } -optional simplifier::prove(expr const & thm, simp_rule_sets const & srss) { +optional simplifier::prove(expr const & thm, simp_lemmas const & srss) { flet set_name(m_rel, get_iff_name()); result r_cond = simplify(thm, srss); if (is_constant(r_cond.get_new()) && const_name(r_cond.get_new()) == get_true_name()) { @@ -577,16 +574,16 @@ result simplifier::rewrite(expr const & e) { return r; } -result simplifier::rewrite(expr const & e, simp_rule_sets const & srss) { +result simplifier::rewrite(expr const & e, simp_lemmas const & srss) { result r(e); - simp_rule_set const * sr = srss.find(m_rel); + simp_lemmas_for const * sr = srss.find(m_rel); if (!sr) return r; - list const * srs = sr->find_simp(e); + list const * srs = sr->find_simp(e); if (!srs) return r; - for_each(*srs, [&](simp_rule const & sr) { + for_each(*srs, [&](simp_lemma const & sr) { result r_new = rewrite(r.get_new(), sr); if (!r_new.has_proof()) return; r = join(r, r_new); @@ -594,7 +591,7 @@ result simplifier::rewrite(expr const & e, simp_rule_sets const & srss) { return r; } -result simplifier::rewrite(expr const & e, simp_rule const & sr) { +result simplifier::rewrite(expr const & e, simp_lemma const & sr) { blast_tmp_type_context tmp_tctx(sr.get_num_umeta(), sr.get_num_emeta()); if (!tmp_tctx->is_def_eq(e, sr.get_lhs())) return result(e); @@ -672,21 +669,21 @@ result simplifier::congr_funs(result const & r_f, buffer const & args) { } result simplifier::try_congrs(expr const & e) { - simp_rule_set const * sr = get_simp_rule_sets(env()).find(m_rel); + simp_lemmas_for const * sr = get_simp_lemmas().find(m_rel); if (!sr) return result(e); - list const * crs = sr->find_congr(e); + list const * crs = sr->find_congr(e); if (!crs) return result(e); result r(e); - for_each(*crs, [&](congr_rule const & cr) { + for_each(*crs, [&](user_congr_lemma const & cr) { if (r.has_proof()) return; r = try_congr(e, cr); }); return r; } -result simplifier::try_congr(expr const & e, congr_rule const & cr) { +result simplifier::try_congr(expr const & e, user_congr_lemma const & cr) { blast_tmp_type_context tmp_tctx(cr.get_num_umeta(), cr.get_num_emeta()); if (!tmp_tctx->is_def_eq(e, cr.get_lhs())) return result(e); @@ -720,7 +717,7 @@ result simplifier::try_congr(expr const & e, congr_rule const & cr) { { flet set_name(m_rel, const_name(h_rel)); - flet set_ctx_srss(m_ctx_srss, m_contextual ? add_to_srss(m_ctx_srss, ls) : m_ctx_srss); + flet set_ctx_srss(m_ctx_srss, m_contextual ? add_to_srss(m_ctx_srss, ls) : m_ctx_srss); h_lhs = tmp_tctx->instantiate_uvars_mvars(h_lhs); lean_assert(!has_metavar(h_lhs)); @@ -976,9 +973,8 @@ result simplifier::fuse(expr const & e) { /* Prove (1) == (3) using simplify with [ac] */ flet no_simplify_numerals(m_numerals, false); auto pf_1_3 = prove(get_app_builder().mk_eq(e, e_grp), - get_simp_rule_sets(env(), ios().get_options(), - {*g_simplify_prove_namespace, *g_simplify_unit_namespace, - *g_simplify_neg_namespace, *g_simplify_ac_namespace})); + get_simp_lemmas(g_ac_key)); + if (!pf_1_3) { diagnostic(env(), ios()) << ppb(e) << "\n\n =?=\n\n" << ppb(e_grp) << "\n"; throw blast_exception("Failed to prove (1) == (3) during fusion", e); @@ -986,10 +982,8 @@ result simplifier::fuse(expr const & e) { /* Prove (4) == (5) using simplify with [som] */ auto pf_4_5 = prove(get_app_builder().mk_eq(e_grp_ls, e_fused_ls), - get_simp_rule_sets(env(), ios().get_options(), - {*g_simplify_prove_namespace, *g_simplify_unit_namespace, - *g_simplify_neg_namespace, *g_simplify_ac_namespace, - *g_simplify_distrib_namespace})); + get_simp_lemmas(g_som_key)); + if (!pf_4_5) { diagnostic(env(), ios()) << ppb(e_grp_ls) << "\n\n =?=\n\n" << ppb(e_fused_ls) << "\n"; throw blast_exception("Failed to prove (4) == (5) during fusion", e); @@ -997,9 +991,7 @@ result simplifier::fuse(expr const & e) { /* Prove (5) == (6) using simplify with [numeral] */ flet simplify_numerals(m_numerals, true); - result r_simp_ls = simplify(e_fused_ls, get_simp_rule_sets(env(), ios().get_options(), - {*g_simplify_unit_namespace, *g_simplify_neg_namespace, - *g_simplify_ac_namespace})); + result r_simp_ls = simplify(e_fused_ls, get_simp_lemmas(g_numeral_key)); /* Prove (4) == (6) by transitivity of proofs (2) and (3) */ expr pf_4_6; @@ -1057,12 +1049,11 @@ void initialize_simplifier() { register_trace_class(name({"simplifier", "congruence"})); register_trace_class(name({"simplifier", "failure"})); - g_simplify_prove_namespace = new name{"simplifier", "prove"}; - g_simplify_neg_namespace = new name{"simplifier", "neg"}; - g_simplify_unit_namespace = new name{"simplifier", "unit"}; - g_simplify_ac_namespace = new name{"simplifier", "ac"}; - g_simplify_distrib_namespace = new name{"simplifier", "distrib"}; - g_simplify_numeral_namespace = new name{"simplifier", "numeral"}; + g_ac_key = register_simp_lemmas({name{"simplifier", "prove"}, name{"simplifier", "unit"}, + name{"simplifier", "neg"}, name{"simplifier", "ac"}}); + g_som_key = register_simp_lemmas({name{"simplifier", "prove"}, name{"simplifier", "unit"}, + name{"simplifier", "neg"}, name{"simplifier", "ac"}, name{"simplifier", "distrib"}}); + g_numeral_key = register_simp_lemmas({name{"simplifier", "unit"}, name{"simplifier", "neg"}, name{"simplifier", "ac"}}); g_simplify_max_steps = new name{"simplify", "max_steps"}; g_simplify_top_down = new name{"simplify", "top_down"}; @@ -1100,23 +1091,16 @@ void finalize_simplifier() { delete g_simplify_exhaustive; delete g_simplify_top_down; delete g_simplify_max_steps; - - delete g_simplify_numeral_namespace; - delete g_simplify_distrib_namespace; - delete g_simplify_ac_namespace; - delete g_simplify_unit_namespace; - delete g_simplify_neg_namespace; - delete g_simplify_prove_namespace; } /* Entry points */ static bool simplify_all_pred(expr const &) { return true; } -result simplify(name const & rel, expr const & e, simp_rule_sets const & srss) { +result simplify(name const & rel, expr const & e, simp_lemmas const & srss) { return simplifier(rel, simplify_all_pred)(e, srss); } -result simplify(name const & rel, expr const & e, simp_rule_sets const & srss, expr_predicate const & simp_pred) { +result simplify(name const & rel, expr const & e, simp_lemmas const & srss, expr_predicate const & simp_pred) { return simplifier(rel, simp_pred)(e, srss); } diff --git a/src/library/blast/simplifier/simplifier.h b/src/library/blast/simplifier/simplifier.h index 239a69df9..57c0863c2 100644 --- a/src/library/blast/simplifier/simplifier.h +++ b/src/library/blast/simplifier/simplifier.h @@ -6,7 +6,7 @@ Author: Daniel Selsam #pragma once #include "kernel/expr_pair.h" #include "library/blast/state.h" -#include "library/blast/simplifier/simp_rule_set.h" +#include "library/blast/simplifier/simp_lemmas.h" namespace lean { namespace blast { @@ -40,8 +40,8 @@ public: // TODO(dhs): put this outside of blast module typedef std::function expr_predicate; // NOLINT -simp::result simplify(name const & rel, expr const & e, simp_rule_sets const & srss); -simp::result simplify(name const & rel, expr const & e, simp_rule_sets const & srss, expr_predicate const & simp_pred); +simp::result simplify(name const & rel, expr const & e, simp_lemmas const & srss); +simp::result simplify(name const & rel, expr const & e, simp_lemmas const & srss, expr_predicate const & simp_pred); void initialize_simplifier(); void finalize_simplifier(); diff --git a/src/library/blast/simplifier/simplifier_actions.cpp b/src/library/blast/simplifier/simplifier_actions.cpp index b12169eed..8ff8d0819 100644 --- a/src/library/blast/simplifier/simplifier_actions.cpp +++ b/src/library/blast/simplifier/simplifier_actions.cpp @@ -16,18 +16,18 @@ namespace lean { namespace blast { static unsigned g_ext_id = 0; struct simplifier_branch_extension : public branch_extension { - simp_rule_sets m_srss; - bool m_simp_target{false}; // true if target needs to be simplified again + simp_lemmas m_simp_lemmas; + bool m_simp_target{false}; // true if target needs to be simplified again simplifier_branch_extension() {} simplifier_branch_extension(simplifier_branch_extension const & b): - m_srss(b.m_srss) {} + m_simp_lemmas(b.m_simp_lemmas) {} virtual ~simplifier_branch_extension() {} virtual branch_extension * clone() override { return new simplifier_branch_extension(*this); } - virtual void initialized() override { m_srss = ::lean::get_simp_rule_sets(env()); } + virtual void initialized() override { m_simp_lemmas = ::lean::blast::get_simp_lemmas(); } virtual void target_updated() override { m_simp_target = true; } virtual void hypothesis_activated(hypothesis const &, hypothesis_idx) override { } virtual void hypothesis_deleted(hypothesis const &, hypothesis_idx) override { } - simp_rule_sets const & get_simp_rule_sets() const { return m_srss; } + simp_lemmas const & get_simp_lemmas() const { return m_simp_lemmas; } }; void initialize_simplifier_actions() { @@ -78,7 +78,7 @@ action_result simplify_target_action() { expr target = s.get_target(); bool iff = use_iff(target); name rname = iff ? get_iff_name() : get_eq_name(); - auto r = simplify(rname, target, ext.get_simp_rule_sets()); + auto r = simplify(rname, target, ext.get_simp_lemmas()); if (r.get_new() == target) return action_result::failed(); // did nothing if (r.has_proof()) { @@ -108,7 +108,7 @@ action_result simplify_hypothesis_action(hypothesis_idx hidx) { return action_result::failed(); } auto & ext = get_extension(); - auto r = simplify(get_iff_name(), h.get_type(), ext.get_simp_rule_sets()); + auto r = simplify(get_iff_name(), h.get_type(), ext.get_simp_lemmas()); if (r.get_new() == h.get_type()) return action_result::failed(); // did nothing expr new_h_proof; @@ -123,7 +123,7 @@ action_result simplify_hypothesis_action(hypothesis_idx hidx) { return action_result::new_branch(); } -action_result add_simp_rule_action(hypothesis_idx hidx) { +action_result add_simp_lemma_action(hypothesis_idx hidx) { if (!get_config().m_simp) return action_result::failed(); blast_tmp_type_context ctx; @@ -136,8 +136,8 @@ action_result add_simp_rule_action(hypothesis_idx hidx) { bool added = false; for (auto const & p : ps) { try { - ext.m_srss = add(*ctx, ext.m_srss, h.get_name(), p.first, p.second, LEAN_DEFAULT_PRIORITY); - added = true; + ext.m_simp_lemmas = add(*ctx, ext.m_simp_lemmas, h.get_name(), p.first, p.second, LEAN_DEFAULT_PRIORITY); + added = true; } catch (exception &) { // TODO(Leo, Daniel): store event } diff --git a/src/library/blast/simplifier/simplifier_actions.h b/src/library/blast/simplifier/simplifier_actions.h index 600532ed7..30f7d07ac 100644 --- a/src/library/blast/simplifier/simplifier_actions.h +++ b/src/library/blast/simplifier/simplifier_actions.h @@ -11,7 +11,7 @@ namespace lean { namespace blast { action_result simplify_hypothesis_action(hypothesis_idx hidx); action_result simplify_target_action(); -action_result add_simp_rule_action(hypothesis_idx hidx); +action_result add_simp_lemma_action(hypothesis_idx hidx); void initialize_simplifier_actions(); void finalize_simplifier_actions(); diff --git a/src/library/blast/simplifier/simplifier_strategies.cpp b/src/library/blast/simplifier/simplifier_strategies.cpp index 4b6caa0f6..0b0237f05 100644 --- a/src/library/blast/simplifier/simplifier_strategies.cpp +++ b/src/library/blast/simplifier/simplifier_strategies.cpp @@ -35,7 +35,7 @@ class simplifier_strategy_fn : public strategy_fn { virtual action_result hypothesis_post_activation(hypothesis_idx hidx) override { if (m_use_hyps) - add_simp_rule_action(hidx); + add_simp_lemma_action(hidx); return action_result::new_branch(); } diff --git a/src/library/blast/unit/unit_preprocess.cpp b/src/library/blast/unit/unit_preprocess.cpp index ec685076b..805141d88 100644 --- a/src/library/blast/unit/unit_preprocess.cpp +++ b/src/library/blast/unit/unit_preprocess.cpp @@ -15,16 +15,15 @@ Author: Daniel Selsam #include "library/blast/choice_point.h" #include "library/blast/hypothesis.h" #include "library/blast/util.h" -#include "library/blast/simplifier/simp_rule_set.h" +#include "library/blast/simplifier/simp_lemmas.h" #include "library/blast/simplifier/simplifier.h" #include "library/expr_lt.h" #include "util/rb_multi_map.h" namespace lean { namespace blast { - -static name * g_simplify_unit_simp_namespace = nullptr; -static name * g_simplify_contextual = nullptr; +static unsigned g_unit_simp_key; +static name * g_simplify_contextual = nullptr; static bool is_propositional(expr const & e) { // TODO(dhs): This predicate will need to evolve, and will eventually be thrown out @@ -50,7 +49,7 @@ action_result unit_preprocess(unsigned hidx) { return action_result::failed(); } - simp_rule_sets srss = get_simp_rule_sets(env(), ios().get_options(), *g_simplify_unit_simp_namespace); + simp_lemmas srss = get_simp_lemmas(g_unit_simp_key); // TODO(dhs): disable contextual rewriting auto r = simplify(get_iff_name(), h.get_type(), srss, is_propositional); @@ -68,13 +67,11 @@ action_result unit_preprocess(unsigned hidx) { } void initialize_unit_preprocess() { - g_simplify_unit_simp_namespace = new name{"simplifier", "unit_simp"}; + g_unit_simp_key = register_simp_lemmas({name{"simplifier", "unit_simp"}}); g_simplify_contextual = new name{"simplify", "contextual"}; } void finalize_unit_preprocess() { delete g_simplify_contextual; - delete g_simplify_unit_simp_namespace; } - }} diff --git a/tests/lean/congr_error_msg.lean b/tests/lean/congr_error_msg.lean index 0436d186e..541e9f7a7 100644 --- a/tests/lean/congr_error_msg.lean +++ b/tests/lean/congr_error_msg.lean @@ -15,7 +15,7 @@ sorry lemma C₃ [congr] (a b : nat) : R (g a b) (g 0 0) := -- ERROR sorry -lemma C₄ [congr] (A B : Type) : (A → B) = (λ a : nat, B → A) 0 := -- ERROR +lemma C₄ [congr] (A B : Type) : (A → B) = (λ a : nat, B → A) 0 := sorry lemma C₅ [congr] (A B : Type₁) : (A → nat) = (B → nat) := -- ERROR diff --git a/tests/lean/congr_error_msg.lean.expected.out b/tests/lean/congr_error_msg.lean.expected.out index 391ca11ee..dc28eb4c9 100644 --- a/tests/lean/congr_error_msg.lean.expected.out +++ b/tests/lean/congr_error_msg.lean.expected.out @@ -1,10 +1,9 @@ -congr_error_msg.lean:9:0: error: invalid congruence rule, 'C₁' the left-hand-side of the congruence resulting type must be of the form (g x_1 ... x_n), where each x_i is a distinct variable or a sort -congr_error_msg.lean:12:0: error: invalid congruence rule, 'C₂' resulting type is not of the form (g ...) ~ (g ...), where ~ is 'eq' -congr_error_msg.lean:15:0: error: invalid congruence rule, 'C₃' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation -congr_error_msg.lean:18:0: error: invalid congruence rule, 'C₄' kinds of the left-hand-side and right-hand-side of the congruence resulting type do not match -congr_error_msg.lean:21:0: error: invalid congruence rule, 'C₅' left-hand-side of the congruence resulting type must be of the form (fun/Pi (x : A), B x) -congr_error_msg.lean:24:0: error: invalid congruence rule, 'C₆' left-hand-side is not an application nor a binding -congr_error_msg.lean:27:0: error: invalid congruence rule, 'C₇' argument #2 of parameter #5 contains unresolved parameters -congr_error_msg.lean:30:0: error: invalid congruence rule, 'C₈' argument #5 is not a valid hypothesis, the left-hand-side contains unresolved parameters -congr_error_msg.lean:33:0: error: invalid congruence rule, 'C₉' argument #6 is not a valid hypothesis, the right-hand-side must be of the form (m l_1 ... l_n) where m is parameter that was not 'assigned/resolved' yet and l_i's are locals +congr_error_msg.lean:9:0: error: invalid [congr] lemma, 'C₁' the left-hand-side of the congruence resulting type must be of the form (g x_1 ... x_n), where each x_i is a distinct variable or a sort +congr_error_msg.lean:12:0: error: invalid [congr] lemma, 'C₂' resulting type is not of the form (g ...) ~ (g ...), where ~ is 'eq' +congr_error_msg.lean:15:0: error: invalid [congr] lemma, 'C₃' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation +congr_error_msg.lean:21:0: error: invalid [congr] lemma, 'C₅' left-hand-side of the congruence resulting type must be of the form (fun/Pi (x : A), B x) +congr_error_msg.lean:24:0: error: invalid [congr] lemma, 'C₆' left-hand-side is not an application nor a binding +congr_error_msg.lean:27:0: error: invalid [congr] lemma, 'C₇' argument #2 of parameter #5 contains unresolved parameters +congr_error_msg.lean:30:0: error: invalid [congr] lemma, 'C₈' argument #5 is not a valid hypothesis, the left-hand-side contains unresolved parameters +congr_error_msg.lean:33:0: error: invalid [congr] lemma, 'C₉' argument #6 is not a valid hypothesis, the right-hand-side must be of the form (m l_1 ... l_n) where m is parameter that was not 'assigned/resolved' yet and l_i's are locals congr_error_msg.lean:33:0: error: unknown declaration 'C₁' diff --git a/tests/lean/run/blast_simp5.lean b/tests/lean/run/blast_simp5.lean new file mode 100644 index 000000000..91821100d --- /dev/null +++ b/tests/lean/run/blast_simp5.lean @@ -0,0 +1,8 @@ +definition f : nat → nat := sorry +definition g (a : nat) := f a +lemma gax [simp] : ∀ a, g a = 0 := sorry + +attribute g [reducible] + +example (a : nat) : f (a + a) = 0 := +by simp