diff --git a/src/frontends/lean/builtin_cmds.cpp b/src/frontends/lean/builtin_cmds.cpp index 3f4544e0c..ca0a83d81 100644 --- a/src/frontends/lean/builtin_cmds.cpp +++ b/src/frontends/lean/builtin_cmds.cpp @@ -391,7 +391,7 @@ static void print_simp_sets(parser & p) { s = get_simp_rule_sets(p.env()); } name prev_eqv; - s.for_each([&](name const & eqv, simp_rule const & rw) { + s.for_each_simp([&](name const & eqv, simp_rule const & rw) { if (prev_eqv != eqv) { out << "simplification rules for " << eqv; if (!ns.is_anonymous()) diff --git a/src/library/simplifier/simp_rule_set.cpp b/src/library/simplifier/simp_rule_set.cpp index 5e569c55e..460e8df24 100644 --- a/src/library/simplifier/simp_rule_set.cpp +++ b/src/library/simplifier/simp_rule_set.cpp @@ -5,7 +5,9 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include +#include "util/sstream.h" #include "kernel/instantiate.h" +#include "kernel/find_fn.h" #include "kernel/error_msgs.h" #include "library/scoped_ext.h" #include "library/expr_pair.h" @@ -14,15 +16,19 @@ Author: Leonardo de Moura #include "library/simplifier/simp_rule_set.h" namespace lean { -bool operator==(simp_rule const & r1, simp_rule const & r2) { - return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs; -} +simp_rule_core::simp_rule_core(name const & id, levels const & univ_metas, list const & metas, + expr const & lhs, expr const & rhs, expr const & proof): + m_id(id), m_univ_metas(univ_metas), m_metas(metas), m_lhs(lhs), m_rhs(rhs), m_proof(proof) {} simp_rule::simp_rule(name const & id, levels const & univ_metas, list const & metas, expr const & lhs, expr const & rhs, expr const & proof, bool is_perm): - m_id(id), m_univ_metas(univ_metas), m_metas(metas), m_lhs(lhs), m_rhs(rhs), m_proof(proof), + simp_rule_core(id, univ_metas, metas, lhs, rhs, proof), m_is_permutation(is_perm) {} +bool operator==(simp_rule const & r1, simp_rule const & r2) { + return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs; +} + format simp_rule::pp(formatter const & fmt) const { format r; r += format("#") + format(length(m_metas)); @@ -34,26 +40,67 @@ format simp_rule::pp(formatter const & fmt) const { return r; } +congr_rule::congr_rule(name const & id, levels const & univ_metas, list const & metas, + expr const & lhs, expr const & rhs, expr const & proof, list const & congr_hyps): + simp_rule_core(id, univ_metas, metas, lhs, rhs, proof), + m_congr_hyps(congr_hyps) {} + +bool operator==(congr_rule const & r1, congr_rule const & r2) { + return r1.m_lhs == r2.m_lhs && r1.m_rhs == r2.m_rhs && r1.m_congr_hyps == r2.m_congr_hyps; +} + +format congr_rule::pp(formatter const & fmt) const { + format r; + r += format("#") + format(length(m_metas)); + format r1; + for (expr const & h : m_congr_hyps) { + r1 += space() + paren(pp_indent_expr(fmt, mlocal_type(h))); + } + r += group(r1); + r += space() + format(":") + space(); + format r2 = fmt(m_lhs); + r2 += space() + format("↦") + pp_indent_expr(fmt, m_rhs); + r += group(r2); + return r; +} + simp_rule_set::simp_rule_set(name const & eqv): m_eqv(eqv) {} void simp_rule_set::insert(simp_rule const & r) { - m_set.insert(r.get_lhs(), r); -} - -list const * simp_rule_set::find(head_index const & h) const { - return m_set.find(h); -} - -void simp_rule_set::for_each(std::function const & fn) const { - m_set.for_each_entry([&](head_index const &, simp_rule const & r) { fn(r); }); + m_simp_set.insert(r.get_lhs(), r); } void simp_rule_set::erase(simp_rule const & r) { - m_set.erase(r.get_lhs(), r); + m_simp_set.erase(r.get_lhs(), r); } -void simp_rule_sets::insert(name const & eqv, simp_rule const & r) { +void simp_rule_set::insert(congr_rule const & r) { + m_congr_set.insert(r.get_lhs(), r); +} + +void simp_rule_set::erase(congr_rule const & r) { + m_congr_set.erase(r.get_lhs(), r); +} + +list const * simp_rule_set::find_simp(head_index const & h) const { + return m_simp_set.find(h); +} + +void simp_rule_set::for_each_simp(std::function const & fn) const { + m_simp_set.for_each_entry([&](head_index const &, simp_rule const & r) { fn(r); }); +} + +list const * simp_rule_set::find_congr(head_index const & h) const { + return m_congr_set.find(h); +} + +void simp_rule_set::for_each_congr(std::function const & fn) const { + m_congr_set.for_each_entry([&](head_index const &, congr_rule const & r) { fn(r); }); +} + +template +void simp_rule_sets::insert_core(name const & eqv, R const & r) { simp_rule_set s(eqv); if (auto const * curr = m_sets.find(eqv)) { s = *curr; @@ -62,7 +109,8 @@ void simp_rule_sets::insert(name const & eqv, simp_rule const & r) { m_sets.insert(eqv, s); } -void simp_rule_sets::erase(name const & eqv, simp_rule const & r) { +template +void simp_rule_sets::erase_core(name const & eqv, R const & r) { if (auto const * curr = m_sets.find(eqv)) { simp_rule_set s = *curr; s.erase(r); @@ -73,6 +121,22 @@ void simp_rule_sets::erase(name const & eqv, simp_rule const & r) { } } +void simp_rule_sets::insert(name const & eqv, simp_rule const & r) { + return insert_core(eqv, r); +} + +void simp_rule_sets::erase(name const & eqv, simp_rule const & r) { + return erase_core(eqv, r); +} + +void simp_rule_sets::insert(name const & eqv, congr_rule const & r) { + return insert_core(eqv, r); +} + +void simp_rule_sets::erase(name const & eqv, congr_rule const & r) { + return erase_core(eqv, r); +} + void simp_rule_sets::get_relations(buffer & rs) const { m_sets.for_each([&](name const & r, simp_rule_set const &) { rs.push_back(r); @@ -83,15 +147,29 @@ simp_rule_set const * simp_rule_sets::find(name const & eqv) const { return m_sets.find(eqv); } -list const * simp_rule_sets::find(name const & eqv, head_index const & h) const { +list const * simp_rule_sets::find_simp(name const & eqv, head_index const & h) const { if (auto const * s = m_sets.find(eqv)) - return s->find(h); + return s->find_simp(h); return nullptr; } -void simp_rule_sets::for_each(std::function const & fn) const { +list const * simp_rule_sets::find_congr(name const & eqv, head_index const & h) const { + if (auto const * s = m_sets.find(eqv)) + return s->find_congr(h); + return nullptr; +} + +void simp_rule_sets::for_each_simp(std::function const & fn) const { m_sets.for_each([&](name const & eqv, simp_rule_set const & s) { - s.for_each([&](simp_rule const & r) { + s.for_each_simp([&](simp_rule const & r) { + fn(eqv, r); + }); + }); +} + +void simp_rule_sets::for_each_congr(std::function const & fn) const { + m_sets.for_each([&](name const & eqv, simp_rule_set const & s) { + s.for_each_congr([&](congr_rule const & r) { fn(eqv, r); }); }); @@ -130,7 +208,7 @@ simp_rule_sets add(type_checker & tc, simp_rule_sets const & s, name const & id, simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2) { simp_rule_sets new_s1 = s1; - s2.for_each([&](name const & eqv, simp_rule const & r) { + s2.for_each_simp([&](name const & eqv, simp_rule const & r) { new_s1.insert(eqv, r); }); return new_s1; @@ -152,22 +230,179 @@ static simp_rule_sets add_core(type_checker & tc, simp_rule_sets const & s, name return add_core(tc, s, cname, ls, e, h); } + +// Return true iff lhs is of the form (B (x : ?m1), ?m2) or (B (x : ?m1), ?m2 x), +// where B is lambda or Pi +static bool is_valid_congr_rule_binding_lhs(expr const & lhs, name_set & found_mvars) { + lean_assert(is_binding(lhs)); + expr const & d = binding_domain(lhs); + expr const & b = binding_body(lhs); + if (!is_metavar(d)) + return false; + if (is_metavar(b) && b != d) { + found_mvars.insert(mlocal_name(b)); + found_mvars.insert(mlocal_name(d)); + return true; + } + if (is_app(b) && is_metavar(app_fn(b)) && is_var(app_arg(b), 0) && app_fn(b) != d) { + found_mvars.insert(mlocal_name(app_fn(b))); + found_mvars.insert(mlocal_name(d)); + return true; + } + return false; +} + +// Return true iff all metavariables in e are in found_mvars +static bool only_found_mvars(expr const & e, name_set const & found_mvars) { + return !find(e, [&](expr const & m, unsigned) { + return is_metavar(m) && !found_mvars.contains(mlocal_name(m)); + }); +} + +// Check whether rhs is of the form (mvar l_1 ... l_n) where mvar is a metavariable, +// and l_i's are local constants, and mvar does not occur in found_mvars. +// If it is return true and update found_mvars +static bool is_valid_congr_hyp_rhs(expr const & rhs, name_set & found_mvars) { + buffer rhs_args; + expr const & rhs_fn = get_app_args(rhs, rhs_args); + if (!is_metavar(rhs_fn) || found_mvars.contains(mlocal_name(rhs_fn))) + return false; + for (expr const & arg : rhs_args) + if (!is_local(arg)) + return false; + found_mvars.insert(mlocal_name(rhs_fn)); + return true; +} + +void add_congr_core(environment const & env, simp_rule_sets & s, name const & n) { + declaration const & d = env.get(n); + type_checker tc(env); + buffer us; + unsigned num_univs = d.get_num_univ_params(); + for (unsigned i = 0; i < num_univs; i++) { + us.push_back(mk_meta_univ(name(*g_prefix, i))); + } + levels ls = to_list(us); + expr pr = mk_constant(n, ls); + expr e = instantiate_type_univ_params(d, ls); + buffer explicit_args; + buffer metas; + unsigned idx = 0; + while (is_pi(e)) { + expr mvar = mk_metavar(name(*g_prefix, idx), binding_domain(e)); + idx++; + explicit_args.push_back(is_explicit(binding_info(e))); + metas.push_back(mvar); + e = instantiate(binding_body(e), mvar); + pr = mk_app(pr, mvar); + } + expr rel, lhs, rhs; + if (!is_simp_relation(env, e, rel, lhs, rhs) || !is_constant(rel)) { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' resulting type is not of the form t ~ s, where '~' is a transitive and reflexive relation"); + } + name_set found_mvars; + buffer lhs_args, rhs_args; + expr const & lhs_fn = get_app_args(lhs, lhs_args); + expr const & rhs_fn = get_app_args(rhs, rhs_args); + if (is_constant(lhs_fn)) { + if (!is_constant(rhs_fn) || const_name(lhs_fn) != const_name(rhs_fn) || lhs_args.size() != rhs_args.size()) { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' resulting type is not of the form (" << const_name(lhs_fn) << " ...) " + << " ~ (" << const_name(lhs_fn) << " ...), where ~ is '" << const_name(rel) << "'"); + } + for (expr const & lhs_arg : lhs_args) { + if (is_sort(lhs_arg)) + continue; + if (!is_metavar(lhs_arg) || found_mvars.contains(mlocal_name(lhs_arg))) { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' the left-hand-side of the congruence resulting type must be of the form (" + << const_name(lhs_fn) << " x_1 ... x_n), where each x_i is a distinct variable or a sort"); + } + found_mvars.insert(mlocal_name(lhs_arg)); + } + } else if (is_binding(lhs)) { + if (lhs.kind() != rhs.kind()) { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' kinds of the left-hand-side and right-hand-side of " + << "the congruence resulting type do not match"); + } + if (!is_valid_congr_rule_binding_lhs(lhs, found_mvars)) { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' left-hand-side of the congruence resulting type must " + << "be of the form (fun/Pi (x : A), B x)"); + } + } else { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' left-hand-side is not an application nor a binding"); + } + + buffer congr_hyps; + lean_assert(metas.size() == explicit_args.size()); + for (unsigned i = 0; i < metas.size(); i++) { + expr const & mvar = metas[i]; + if (explicit_args[i] && !found_mvars.contains(mlocal_name(mvar))) { + buffer locals; + expr type = mlocal_type(mvar); + while (is_pi(type)) { + expr local = mk_local(tc.mk_fresh_name(), binding_domain(type)); + locals.push_back(local); + type = instantiate(binding_body(type), local); + } + expr h_rel, h_lhs, h_rhs; + if (!is_simp_relation(env, type, h_rel, h_lhs, h_rhs) || !is_constant(h_rel)) + continue; + unsigned j = 0; + for (expr const & local : locals) { + j++; + if (!only_found_mvars(mlocal_type(local), found_mvars)) { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' argument #" << j << " is a congruence hypothesis, but it contains " + << "unresolved parameters"); + } + } + if (!only_found_mvars(h_lhs, found_mvars)) { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' argument #" << j << " is not a valid hypothesis, the left-hand-side contains " + << "unresolved parameters"); + } + if (!is_valid_congr_hyp_rhs(h_rhs, found_mvars)) { + throw exception(sstream() << "invalid congruence rule, '" << n + << "' argument #" << j << " is not a valid hypothesis, the right-hand-side must be " + << "of the form (m l_1 ... l_n) where m is parameter that was not " + << "'assigned/resolved' yet and l_i's are locals"); + } + found_mvars.insert(mlocal_name(mvar)); + congr_hyps.push_back(mvar); + } + } + congr_rule rule(n, ls, to_list(metas), lhs, rhs, pr, to_list(congr_hyps)); + s.insert(const_name(rel), rule); +} + struct rrs_state { simp_rule_sets m_sets; name_set m_snames; - void add(environment const & env, name const & cname) { + void add_simp(environment const & env, name const & cname) { type_checker tc(env); m_sets = add_core(tc, m_sets, cname); m_snames.insert(cname); } + + void add_congr(environment const & env, name const & n) { + add_congr_core(env, m_sets, n); + } }; struct rrs_config { - typedef name entry; - typedef rrs_state state; + typedef pair entry; + typedef rrs_state state; static void add_entry(environment const & env, io_state const &, state & s, entry const & e) { - s.add(env, e); + if (e.first) + s.add_simp(env, e.second); + else + s.add_congr(env, e.second); } static name const & get_class_name() { return *g_class_name; @@ -176,13 +411,13 @@ struct rrs_config { return *g_key; } static void write_entry(serializer & s, entry const & e) { - s << e; + s << e.first << e.second; } static entry read_entry(deserializer & d) { - entry e; d >> e; return e; + entry e; d >> e.first >> e.second; return e; } static optional get_fingerprint(entry const & e) { - return some(e.hash()); + return some(hash(e.first ? 17 : 31, e.second.hash())); } }; @@ -190,12 +425,11 @@ template class scoped_ext; typedef scoped_ext rrs_ext; environment add_simp_rule(environment const & env, name const & n, bool persistent) { - return rrs_ext::add_entry(env, get_dummy_ios(), n, persistent); + return rrs_ext::add_entry(env, get_dummy_ios(), mk_pair(true, n), persistent); } environment add_congr_rule(environment const & env, name const & n, bool persistent) { - // TODO(Leo): - return env; + return rrs_ext::add_entry(env, get_dummy_ios(), mk_pair(false, n), persistent); } bool is_simp_rule(environment const & env, name const & n) { @@ -208,11 +442,11 @@ simp_rule_sets get_simp_rule_sets(environment const & env) { simp_rule_sets get_simp_rule_sets(environment const & env, name const & ns) { simp_rule_sets set; - list const * cnames = rrs_ext::get_entries(env, ns); + list> const * cnames = rrs_ext::get_entries(env, ns); if (cnames) { type_checker tc(env); - for (name const & cname : *cnames) { - set = add_core(tc, set, cname); + for (pair const & p : *cnames) { + set = add_core(tc, set, p.second); } } return set; diff --git a/src/library/simplifier/simp_rule_set.h b/src/library/simplifier/simp_rule_set.h index 62ea456ca..07d39d628 100644 --- a/src/library/simplifier/simp_rule_set.h +++ b/src/library/simplifier/simp_rule_set.h @@ -11,18 +11,16 @@ Author: Leonardo de Moura namespace lean { class simp_rule_sets; -class simp_rule { +class simp_rule_core { +protected: name m_id; levels m_univ_metas; list m_metas; expr m_lhs; expr m_rhs; expr m_proof; - bool m_is_permutation; - simp_rule(name const & id, levels const & univ_metas, list const & metas, - expr const & lhs, expr const & rhs, expr const & proof, bool is_perm); - friend simp_rule_sets add_core(type_checker & tc, simp_rule_sets const & s, name const & id, - levels const & univ_metas, expr const & e, expr const & h); + simp_rule_core(name const & id, levels const & univ_metas, list const & metas, + expr const & lhs, expr const & rhs, expr const & proof); public: name const & get_id() const { return m_id; } levels const & get_univ_metas() const { return m_univ_metas; } @@ -30,39 +28,74 @@ public: expr const & get_lhs() const { return m_lhs; } expr const & get_rhs() const { return m_rhs; } expr const & get_proof() const { return m_proof; } - bool is_perm() const { return m_is_permutation; } +}; + +class simp_rule : public simp_rule_core { + bool m_is_permutation; + simp_rule(name const & id, levels const & univ_metas, list const & metas, + expr const & lhs, expr const & rhs, expr const & proof, bool is_perm); + friend simp_rule_sets add_core(type_checker & tc, simp_rule_sets const & s, name const & id, + levels const & univ_metas, expr const & e, expr const & h); +public: friend bool operator==(simp_rule const & r1, simp_rule const & r2); + bool is_perm() const { return m_is_permutation; } format pp(formatter const & fmt) const; }; bool operator==(simp_rule const & r1, simp_rule const & r2); inline bool operator!=(simp_rule const & r1, simp_rule const & r2) { return !operator==(r1, r2); } +class congr_rule : public simp_rule_core { + list m_congr_hyps; + congr_rule(name const & id, levels const & univ_metas, list const & metas, + expr const & lhs, expr const & rhs, expr const & proof, list const & congr_hyps); + friend void add_congr_core(environment const & env, simp_rule_sets & s, name const & n); +public: + friend bool operator==(congr_rule const & r1, congr_rule const & r2); + list const & get_congr_hyps() const { return m_congr_hyps; } + format pp(formatter const & fmt) const; +}; + +bool operator==(congr_rule const & r1, congr_rule const & r2); +inline bool operator!=(congr_rule const & r1, congr_rule const & r2) { return !operator==(r1, r2); } + class simp_rule_set { - typedef head_map rule_set; - name m_eqv; - rule_set m_set; + typedef head_map simp_set; + typedef head_map congr_set; + name m_eqv; + simp_set m_simp_set; + congr_set m_congr_set; public: simp_rule_set() {} /** \brief Return the equivalence relation associated with this set */ simp_rule_set(name const & eqv); - bool empty() const { return m_set.empty(); } + bool empty() const { return m_simp_set.empty() && m_congr_set.empty(); } name const & get_eqv() const { return m_eqv; } void insert(simp_rule const & r); void erase(simp_rule const & r); - list const * find(head_index const & h) const; - void for_each(std::function const & fn) const; + void insert(congr_rule const & r); + void erase(congr_rule const & r); + list const * find_simp(head_index const & h) const; + void for_each_simp(std::function const & fn) const; + list const * find_congr(head_index const & h) const; + void for_each_congr(std::function const & fn) const; }; class simp_rule_sets { name_map m_sets; // mapping from relation name to simp_rule_set + template void insert_core(name const & eqv, R const & r); + template void erase_core(name const & eqv, R const & r); public: void insert(name const & eqv, simp_rule const & r); void erase(name const & eqv, simp_rule const & r); + void insert(name const & eqv, congr_rule const & r); + void erase(name const & eqv, congr_rule const & r); void get_relations(buffer & rs) const; simp_rule_set const * find(name const & eqv) const; - list const * find(name const & eqv, head_index const & h) const; - void for_each(std::function const & fn) const; + list const * find_simp(name const & eqv, head_index const & h) const; + list const * find_congr(name const & eqv, head_index const & h) const; + void for_each_simp(std::function const & fn) const; + void for_each_congr(std::function const & fn) const; }; simp_rule_sets add(type_checker & tc, simp_rule_sets const & s, name const & id, expr const & e, expr const & h); @@ -71,11 +104,11 @@ simp_rule_sets join(simp_rule_sets const & s1, simp_rule_sets const & s2); environment add_simp_rule(environment const & env, name const & n, bool persistent = true); environment add_congr_rule(environment const & env, name const & n, bool persistent = true); -/** \brief Return true if \c n is an active rewrite rule in \c env */ +/** \brief Return true if \c n is an active simplification rule in \c env */ bool is_simp_rule(environment const & env, name const & n); -/** \brief Get current rewrite rule sets */ +/** \brief Get current simplification rule sets */ simp_rule_sets get_simp_rule_sets(environment const & env); -/** \brief Get rewrite rule sets in the given namespace. */ +/** \brief Get simplification rule sets in the given namespace. */ simp_rule_sets get_simp_rule_sets(environment const & env, name const & ns); void initialize_simp_rule_set(); void finalize_simp_rule_set();