From 491c7c55e1df492dc9fd1bf4d2740fc1b2fbd771 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 16 Nov 2015 22:34:06 -0800 Subject: [PATCH] feat(library/simplifier/simp_rule_set): add priorities for simp and congr rules --- hott/algebra/category/functor/basic.hlean | 4 +- hott/function.hlean | 2 +- src/frontends/lean/decl_attributes.cpp | 42 ++++-------- src/frontends/lean/decl_attributes.h | 3 - src/frontends/lean/decl_cmds.cpp | 2 +- src/library/blast/simplifier.cpp | 3 +- src/library/simplifier/simp_rule_set.cpp | 83 +++++++++++++---------- src/library/simplifier/simp_rule_set.h | 34 +++++++--- tests/lean/704.lean | 2 +- tests/lean/run/priority_test2.lean | 4 +- 10 files changed, 96 insertions(+), 83 deletions(-) diff --git a/hott/algebra/category/functor/basic.hlean b/hott/algebra/category/functor/basic.hlean index 8a013a4ce..c7110b707 100644 --- a/hott/algebra/category/functor/basic.hlean +++ b/hott/algebra/category/functor/basic.hlean @@ -181,8 +181,8 @@ namespace functor q (respect_comp F g f) end qed section - local attribute precategory.is_hset_hom [priority 1001] - local attribute trunctype.struct [priority 1] -- remove after #842 is closed + local attribute precategory.is_hset_hom [instance] [priority 1001] + local attribute trunctype.struct [instance] [priority 1] -- remove after #842 is closed protected theorem is_hset_functor [instance] [HD : is_hset D] : is_hset (functor C D) := by apply is_trunc_equiv_closed; apply functor.sigma_char diff --git a/hott/function.hlean b/hott/function.hlean index 8c5fb1559..2ac85a9e4 100644 --- a/hott/function.hlean +++ b/hott/function.hlean @@ -252,7 +252,7 @@ namespace function section local attribute is_equiv_of_is_section_of_is_retraction [instance] [priority 10000] - local attribute trunctype.struct [priority 1] -- remove after #842 is closed + local attribute trunctype.struct [instance] [priority 1] -- remove after #842 is closed variable (f) definition is_hprop_is_retraction_prod_is_section : is_hprop (is_retraction f × is_section f) := begin diff --git a/src/frontends/lean/decl_attributes.cpp b/src/frontends/lean/decl_attributes.cpp index 68b16af48..0c8091e3f 100644 --- a/src/frontends/lean/decl_attributes.cpp +++ b/src/frontends/lean/decl_attributes.cpp @@ -43,7 +43,7 @@ decl_attributes::decl_attributes(bool is_abbrev, bool persistent): m_congr = false; } -void decl_attributes::parse(buffer const & ns, parser & p) { +void decl_attributes::parse(parser & p) { while (true) { auto pos = p.pos(); if (p.curr_is_token(get_instance_tk())) { @@ -87,17 +87,8 @@ void decl_attributes::parse(buffer const & ns, parser & p) { p.next(); } else if (auto it = parse_priority(p)) { m_priority = *it; - if (!m_is_instance) { - if (ns.empty()) { - throw parser_error("invalid '[priority]' attribute, declaration must be marked as an '[instance]'", pos); - } else { - for (name const & n : ns) { - if (!is_instance(p.env(), n)) - throw parser_error(sstream() << "invalid '[priority]' attribute, declaration '" << n - << "' must be marked as an '[instance]'", pos); - } - m_is_instance = true; - } + if (!m_is_instance && !m_simp && !m_congr) { + throw parser_error("invalid '[priority]' attribute, declaration must be marked as an '[instance]', '[simp]' or '[congr]'", pos); } } else if (p.curr_is_token(get_parsing_only_tk())) { if (!m_is_abbrev) @@ -158,17 +149,6 @@ void decl_attributes::parse(buffer const & ns, parser & p) { } } -void decl_attributes::parse(name const & n, parser & p) { - buffer ns; - ns.push_back(n); - parse(ns, p); -} - -void decl_attributes::parse(parser & p) { - buffer ns; - parse(ns, p); -} - environment decl_attributes::apply(environment env, io_state const & ios, name const & d) const { if (m_is_instance) { if (m_priority) { @@ -221,10 +201,18 @@ environment decl_attributes::apply(environment env, io_state const & ios, name c env = add_user_recursor(env, d, m_recursor_major_pos, m_persistent); if (m_is_class) env = add_class(env, d, m_persistent); - if (m_simp) - env = add_simp_rule(env, d, m_persistent); - if (m_congr) - env = add_congr_rule(env, d, m_persistent); + if (m_simp) { + if (m_priority) + env = add_simp_rule(env, d, *m_priority, m_persistent); + else + env = add_simp_rule(env, d, LEAN_SIMP_DEFAULT_PRIORITY, m_persistent); + } + if (m_congr) { + if (m_priority) + env = add_congr_rule(env, d, *m_priority, m_persistent); + else + env = add_congr_rule(env, d, LEAN_SIMP_DEFAULT_PRIORITY, m_persistent); + } if (m_has_multiple_instances) env = mark_multiple_instances(env, d, m_persistent); return env; diff --git a/src/frontends/lean/decl_attributes.h b/src/frontends/lean/decl_attributes.h index 507c43b62..d12e811ff 100644 --- a/src/frontends/lean/decl_attributes.h +++ b/src/frontends/lean/decl_attributes.h @@ -34,11 +34,8 @@ class decl_attributes { optional m_recursor_major_pos; optional m_priority; list m_unfold_hint; - - void parse(name const & n, parser & p); public: decl_attributes(bool is_abbrev = false, bool persistent = true); - void parse(buffer const & ns, parser & p); void parse(parser & p); environment apply(environment env, io_state const & ios, name const & d) const; bool is_parsing_only() const { return m_is_parsing_only; } diff --git a/src/frontends/lean/decl_cmds.cpp b/src/frontends/lean/decl_cmds.cpp index 41b138193..357a963e9 100644 --- a/src/frontends/lean/decl_cmds.cpp +++ b/src/frontends/lean/decl_cmds.cpp @@ -1259,7 +1259,7 @@ static environment attribute_cmd_core(parser & p, bool persistent) { } bool abbrev = false; decl_attributes attributes(abbrev, persistent); - attributes.parse(ds, p); + attributes.parse(p); environment env = p.env(); for (name const & d : ds) env = attributes.apply(env, p.ios(), d); diff --git a/src/library/blast/simplifier.cpp b/src/library/blast/simplifier.cpp index e00345a66..8e879bd12 100644 --- a/src/library/blast/simplifier.cpp +++ b/src/library/blast/simplifier.cpp @@ -197,7 +197,8 @@ class simplifier { } tmp_type_context tctx(env(), ios()); try { - srss = add(tctx, srss, mlocal_name(l), tctx.infer(l), l); + // TODO(Leo,Daniel): should we allow the user to set the priority of local lemmas + srss = add(tctx, srss, mlocal_name(l), tctx.infer(l), l, LEAN_SIMP_DEFAULT_PRIORITY); } catch (exception e) { } } diff --git a/src/library/simplifier/simp_rule_set.cpp b/src/library/simplifier/simp_rule_set.cpp index b56f68c82..1a75d0f02 100644 --- a/src/library/simplifier/simp_rule_set.cpp +++ b/src/library/simplifier/simp_rule_set.cpp @@ -18,13 +18,15 @@ Author: Leonardo de Moura namespace lean { simp_rule_core::simp_rule_core(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof): + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + unsigned priority): m_id(id), m_umetas(umetas), m_emetas(emetas), m_instances(instances), - m_lhs(lhs), m_rhs(rhs), m_proof(proof) {} + m_lhs(lhs), m_rhs(rhs), m_proof(proof), m_priority(priority) {} simp_rule::simp_rule(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof, bool is_perm): - simp_rule_core(id, umetas, emetas, instances, lhs, rhs, proof), + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + bool is_perm, unsigned priority): + simp_rule_core(id, umetas, emetas, instances, lhs, rhs, proof, priority), m_is_permutation(is_perm) {} bool operator==(simp_rule const & r1, simp_rule const & r2) { @@ -34,6 +36,8 @@ bool operator==(simp_rule const & r1, simp_rule const & r2) { format simp_rule::pp(formatter const & fmt) const { format r; r += format("#") + format(get_num_emeta()); + if (m_priority != LEAN_SIMP_DEFAULT_PRIORITY) + r += space() + paren(format(m_priority)); if (m_is_permutation) r += space() + format("perm"); format r1 = comma() + space() + fmt(get_lhs()); @@ -44,8 +48,8 @@ format simp_rule::pp(formatter const & fmt) const { congr_rule::congr_rule(name const & id, levels const & umetas, list const & emetas, list const & instances, expr const & lhs, expr const & rhs, expr const & proof, - list const & congr_hyps): - simp_rule_core(id, umetas, emetas, instances, lhs, rhs, proof), + list const & congr_hyps, unsigned priority): + simp_rule_core(id, umetas, emetas, instances, lhs, rhs, proof, priority), m_congr_hyps(congr_hyps) {} bool operator==(congr_rule const & r1, congr_rule const & r2) { @@ -55,6 +59,8 @@ bool operator==(congr_rule const & r1, congr_rule const & r2) { format congr_rule::pp(formatter const & fmt) const { format r; r += format("#") + format(get_num_emeta()); + if (m_priority != LEAN_SIMP_DEFAULT_PRIORITY) + r += space() + paren(format(m_priority)); format r1; for (expr const & h : m_congr_hyps) { r1 += space() + paren(fmt(mlocal_type(h))); @@ -257,7 +263,8 @@ format simp_rule_sets::pp(formatter const & fmt) const { static name * g_prefix = nullptr; simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, - name const & id, levels const & univ_metas, expr const & e, expr const & h) { + name const & id, levels const & univ_metas, expr const & e, expr const & h, + unsigned priority) { list ceqvs = to_ceqvs(tctx, e, h); if (is_nil(ceqvs)) throw exception("[simp] rule invalid"); environment const & env = tctx.env(); @@ -278,14 +285,14 @@ simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, expr rel, lhs, rhs; if (is_simp_relation(env, rule, rel, lhs, rhs) && is_constant(rel)) { new_s.insert(const_name(rel), simp_rule(id, univ_metas, reverse_to_list(emetas), - reverse_to_list(instances), lhs, rhs, proof, is_perm)); + reverse_to_list(instances), lhs, rhs, proof, is_perm, priority)); } } return new_s; } -simp_rule_sets add(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, expr const & e, expr const & h) { - return add_core(tctx, s, id, list(), e, h); +simp_rule_sets add(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, expr const & e, expr const & h, unsigned priority) { + return add_core(tctx, s, id, list(), e, h, priority); } simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2) { @@ -299,7 +306,7 @@ simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2) { static name * g_class_name = nullptr; static std::string * g_key = nullptr; -static simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, name const & cname) { +static simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, name const & cname, unsigned priority) { declaration const & d = tctx.env().get(cname); buffer us; unsigned num_univs = d.get_num_univ_params(); @@ -309,7 +316,7 @@ static simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s levels ls = to_list(us); expr e = tctx.whnf(instantiate_type_univ_params(d, ls)); expr h = mk_constant(cname, ls); - return add_core(tctx, s, cname, ls, e, h); + return add_core(tctx, s, cname, ls, e, h, priority); } @@ -356,7 +363,7 @@ static bool is_valid_congr_hyp_rhs(expr const & rhs, name_set & found_mvars) { return true; } -void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n) { +void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n, unsigned prio) { declaration const & d = tctx.env().get(n); buffer us; unsigned num_univs = d.get_num_univ_params(); @@ -459,7 +466,7 @@ void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n) } } s.insert(const_name(rel), congr_rule(n, ls, reverse_to_list(emetas), - reverse_to_list(instances), lhs, rhs, proof, to_list(congr_hyps))); + reverse_to_list(instances), lhs, rhs, proof, to_list(congr_hyps), prio)); } struct rrs_state { @@ -467,27 +474,35 @@ struct rrs_state { name_set m_simp_names; name_set m_congr_names; - void add_simp(environment const & env, io_state const & ios, name const & cname) { + void add_simp(environment const & env, io_state const & ios, name const & cname, unsigned prio) { tmp_type_context tctx{env, ios}; - m_sets = add_core(tctx, m_sets, cname); + m_sets = add_core(tctx, m_sets, cname, prio); m_simp_names.insert(cname); } - void add_congr(environment const & env, io_state const & ios, name const & n) { + void add_congr(environment const & env, io_state const & ios, name const & n, unsigned prio) { tmp_type_context tctx{env, ios}; - add_congr_core(tctx, m_sets, n); + add_congr_core(tctx, m_sets, n, prio); m_congr_names.insert(n); } }; +struct rrs_entry { + bool m_is_simp; + name m_name; + unsigned m_priority; + rrs_entry() {} + rrs_entry(bool is_simp, name const & n, unsigned prio):m_is_simp(is_simp), m_name(n), m_priority(prio) {} +}; + struct rrs_config { - typedef pair entry; - typedef rrs_state state; + typedef rrs_entry entry; + typedef rrs_state state; static void add_entry(environment const & env, io_state const & ios, state & s, entry const & e) { - if (e.first) - s.add_simp(env, ios, e.second); + if (e.m_is_simp) + s.add_simp(env, ios, e.m_name, e.m_priority); else - s.add_congr(env, ios, e.second); + s.add_congr(env, ios, e.m_name, e.m_priority); } static name const & get_class_name() { return *g_class_name; @@ -496,25 +511,25 @@ struct rrs_config { return *g_key; } static void write_entry(serializer & s, entry const & e) { - s << e.first << e.second; + s << e.m_is_simp << e.m_name << e.m_priority; } static entry read_entry(deserializer & d) { - entry e; d >> e.first >> e.second; return e; + entry e; d >> e.m_is_simp >> e.m_name >> e.m_priority; return e; } static optional get_fingerprint(entry const & e) { - return some(hash(e.first ? 17 : 31, e.second.hash())); + return some(hash(e.m_is_simp ? 17 : 31, e.m_name.hash())); } }; template class scoped_ext; typedef scoped_ext rrs_ext; -environment add_simp_rule(environment const & env, name const & n, bool persistent) { - return rrs_ext::add_entry(env, get_dummy_ios(), mk_pair(true, n), persistent); +environment add_simp_rule(environment const & env, name const & n, unsigned prio, bool persistent) { + return rrs_ext::add_entry(env, get_dummy_ios(), rrs_entry(true, n, prio), persistent); } -environment add_congr_rule(environment const & env, name const & n, bool persistent) { - return rrs_ext::add_entry(env, get_dummy_ios(), mk_pair(false, n), persistent); +environment add_congr_rule(environment const & env, name const & n, unsigned prio, bool persistent) { + return rrs_ext::add_entry(env, get_dummy_ios(), rrs_entry(false, n, prio), persistent); } bool is_simp_rule(environment const & env, name const & n) { @@ -531,11 +546,11 @@ simp_rule_sets get_simp_rule_sets(environment const & env) { simp_rule_sets get_simp_rule_sets(environment const & env, io_state const & ios, name const & ns) { simp_rule_sets set; - list> const * cnames = rrs_ext::get_entries(env, ns); - if (cnames) { - for (pair const & p : *cnames) { + list const * entries = rrs_ext::get_entries(env, ns); + if (entries) { + for (auto const & e : *entries) { tmp_type_context tctx(env, ios); - set = add_core(tctx, set, p.second); + set = add_core(tctx, set, e.m_name, e.m_priority); } } return set; diff --git a/src/library/simplifier/simp_rule_set.h b/src/library/simplifier/simp_rule_set.h index 950b55542..9147d4fa4 100644 --- a/src/library/simplifier/simp_rule_set.h +++ b/src/library/simplifier/simp_rule_set.h @@ -10,6 +10,10 @@ Author: Leonardo de Moura #include "library/io_state_stream.h" #include +#ifndef LEAN_SIMP_DEFAULT_PRIORITY +#define LEAN_SIMP_DEFAULT_PRIORITY 1000 +#endif + namespace lean { class simp_rule_sets; @@ -23,8 +27,10 @@ protected: expr m_lhs; expr m_rhs; expr m_proof; + unsigned m_priority; simp_rule_core(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof); + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + unsigned priority); public: name const & get_id() const { return m_id; } unsigned get_num_umeta() const { return length(m_umetas); } @@ -33,10 +39,12 @@ public: /** \brief Return a list containing the expression metavariables in reverse order. */ list const & get_emetas() const { return m_emetas; } - /** \brief Return a list of bools indicating whether or not each expression metavariable + /** \brief Return a list of bools indicating whether or not each expression metavariable in get_emetas() is an instance. */ list const & get_instances() const { return m_instances; } + unsigned get_priority() const { return m_priority; } + expr const & get_lhs() const { return m_lhs; } expr const & get_rhs() const { return m_rhs; } expr const & get_proof() const { return m_proof; } @@ -45,9 +53,11 @@ public: class simp_rule : public simp_rule_core { bool m_is_permutation; simp_rule(name const & id, levels const & umetas, list const & emetas, - list const & instances, expr const & lhs, expr const & rhs, expr const & proof, bool is_perm); + list const & instances, expr const & lhs, expr const & rhs, expr const & proof, + bool is_perm, unsigned priority); + friend simp_rule_sets add_core(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, - levels const & univ_metas, expr const & e, expr const & h); + levels const & univ_metas, expr const & e, expr const & h, unsigned priority); public: friend bool operator==(simp_rule const & r1, simp_rule const & r2); bool is_perm() const { return m_is_permutation; } @@ -61,20 +71,22 @@ class congr_rule : public simp_rule_core { list m_congr_hyps; congr_rule(name const & id, levels const & umetas, list const & emetas, list const & instances, expr const & lhs, expr const & rhs, expr const & proof, - list const & congr_hyps); - friend void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n); + list const & congr_hyps, unsigned priority); + friend void add_congr_core(tmp_type_context & tctx, simp_rule_sets & s, name const & n, unsigned priority); public: friend bool operator==(congr_rule const & r1, congr_rule const & r2); list const & get_congr_hyps() const { return m_congr_hyps; } format pp(formatter const & fmt) const; }; +struct simp_rule_core_prio_fn { unsigned operator()(simp_rule_core const & s) const { return s.get_priority(); } }; + bool operator==(congr_rule const & r1, congr_rule const & r2); inline bool operator!=(congr_rule const & r1, congr_rule const & r2) { return !operator==(r1, r2); } class simp_rule_set { - typedef head_map simp_set; - typedef head_map congr_set; + typedef head_map_prio simp_set; + typedef head_map_prio congr_set; name m_eqv; simp_set m_simp_set; congr_set m_congr_set; @@ -120,11 +132,11 @@ public: format pp(formatter const & fmt) const; }; -simp_rule_sets add(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, expr const & e, expr const & h); +simp_rule_sets add(tmp_type_context & tctx, simp_rule_sets const & s, name const & id, expr const & e, expr const & h, unsigned priority); simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2); -environment add_simp_rule(environment const & env, name const & n, bool persistent = true); -environment add_congr_rule(environment const & env, name const & n, bool persistent = true); +environment add_simp_rule(environment const & env, name const & n, unsigned priority, bool persistent); +environment add_congr_rule(environment const & env, name const & n, unsigned priority, bool persistent); /** \brief Return true if \c n is an active simplification rule in \c env */ bool is_simp_rule(environment const & env, name const & n); diff --git a/tests/lean/704.lean b/tests/lean/704.lean index 0457fcfb7..1e4a305e7 100644 --- a/tests/lean/704.lean +++ b/tests/lean/704.lean @@ -1,4 +1,4 @@ open classical eval if true then 1 else (0:num) -attribute prop_decidable [priority 0] +attribute prop_decidable [instance] [priority 0] eval if true then 1 else (0:num) diff --git a/tests/lean/run/priority_test2.lean b/tests/lean/run/priority_test2.lean index 05f9121c4..d19d508a4 100644 --- a/tests/lean/run/priority_test2.lean +++ b/tests/lean/run/priority_test2.lean @@ -27,12 +27,12 @@ foo.mk 4 4 example : foo.a = 3 := rfl -attribute i4 [priority std.priority.default+2] +attribute i4 [instance] [priority std.priority.default+2] example : foo.a = 4 := rfl -attribute i1 [priority std.priority.default+3] +attribute i1 [instance] [priority std.priority.default+3] example : foo.a = 1 := rfl