From 41f5e2a067a80fd246ad9b724431eff36765a5af Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 29 Jan 2014 18:32:40 -0800 Subject: [PATCH] feat(library/simplifier): statically check (conditional) equations (aka rewrite rules) to verify whether we can skip type checking when using them in the simplifier Signed-off-by: Leonardo de Moura --- src/library/simplifier/ceq.cpp | 113 ++++++++++++++++++++ src/library/simplifier/ceq.h | 39 +++++++ src/library/simplifier/rewrite_rule_set.cpp | 4 +- src/library/simplifier/simplifier.cpp | 53 +++------ src/library/simplifier/simplifier.h | 11 +- src/library/tactic/simplify_tactic.cpp | 2 +- tests/lean/rs.lean | 8 ++ tests/lean/rs.lean.expected.out | 23 ++++ 8 files changed, 209 insertions(+), 44 deletions(-) create mode 100644 tests/lean/rs.lean create mode 100644 tests/lean/rs.lean.expected.out diff --git a/src/library/simplifier/ceq.cpp b/src/library/simplifier/ceq.cpp index 8449bad2b..552f26137 100644 --- a/src/library/simplifier/ceq.cpp +++ b/src/library/simplifier/ceq.cpp @@ -256,6 +256,119 @@ bool is_permutation_ceq(expr e) { } } +// Quick approximate test for e == (Type U). +// If the result is true, then \c e is definitionally equal to TypeU. +// If the result is false, then it may or may not be. +static bool is_TypeU(ro_environment const & env, expr const & e) { + if (is_type(e)) { + return e == TypeU; + } else if (is_constant(e)) { + auto obj = env->find_object(const_name(e)); + return obj && obj->is_definition() && is_TypeU(obj->get_value()); + } else { + return false; + } +} + +bool is_safe_to_skip_check_ceq_types(ro_environment const & env, optional const & menv, expr ceq) { + lean_assert(is_ceq(env, menv, ceq)); + type_checker tc(env); + buffer args; + buffer skip; + unsigned next_idx = 0; + bool to_check = false; + while (is_pi(ceq)) { + expr d = abst_domain(ceq); + expr a = mk_constant(name(g_unique, next_idx), d); + args.push_back(a); + if (tc.is_proposition(d, context(), menv) || + is_TypeU(env, d)) { + // See comment at ceq.h + // 1- The argument has type (Type U). In Lean, (Type U) is the maximal universe. + // 2- The argument is a proposition. + skip.push_back(true); + } else { + skip.push_back(false); + to_check = true; + } + ceq = instantiate(abst_body(ceq), a); + next_idx++; + } + if (!to_check) + return true; + + expr lhs, rhs; + lean_verify(is_equality(ceq, lhs, rhs)); + + auto arg_idx_core_fn = [&](expr const & e) -> optional { + if (is_constant(e)) { + name const & n = const_name(e); + if (!n.is_atomic() && n.get_prefix() == g_unique) { + return some(n.get_numeral()); + } + } + return optional(); + }; + + auto arg_idx_fn = [&](expr const & e) -> optional { + if (is_app(e)) + return arg_idx_core_fn(arg(e, 0)); + else if (is_lambda(e)) + return arg_idx_core_fn(abst_body(e)); + else + return arg_idx_core_fn(e); + }; + + // Return true if the application \c e has an argument or an + // application (f ...) where f is an argument. + auto has_target_fn = [&](expr const & e) -> bool { + lean_assert(is_app(e)); + for (unsigned i = 1; i < num_args(e); i++) { + expr const & a = arg(e, i); + if (arg_idx_fn(a)) + return true; + } + return false; + }; + + // 3- There is an application (f x) in the left-hand-side, and + // the type expected by f is definitionally equal to the argument type. + // 4- There is an application (f (x ...)) in the left-hand-side, and + // the type expected by f is definitionally equal to the type of (x ...) + // 5- There is an application (f (fun y, x)) in the left-hand-side, + // and the type expected by f is definitionally equal to the type of (fun y, x) + std::function visit_fn = + [&](expr const & e, context const & ctx) { + if (is_app(e)) { + expr const & f = arg(e, 0); + if (has_target_fn(e)) { + expr f_type = tc.infer_type(f, ctx, menv); + for (unsigned i = 1; i < num_args(e); i++) { + f_type = tc.ensure_pi(f_type, ctx, menv); + expr const & a = arg(e, i); + auto arg_idx = arg_idx_fn(a); + if (arg_idx && !skip[*arg_idx]) { + expr const & expected_type = abst_domain(f_type); + expr const & given_type = tc.infer_type(a, ctx, menv); + if (tc.is_definitionally_equal(given_type, expected_type)) { + skip[*arg_idx] = true; + } + } + f_type = instantiate(abst_body(f_type), a); + } + } + for (expr const & a : ::lean::args(e)) + visit_fn(a, ctx); + } else if (is_abstraction(e)) { + visit_fn(abst_domain(e), ctx); + visit_fn(abst_body(e), extend(ctx, abst_name(e), abst_body(e))); + } + }; + + visit_fn(lhs, context()); + return std::all_of(skip.begin(), skip.end(), [](bool b) { return b; }); +} + static int to_ceqs(lua_State * L) { ro_shared_environment env(L, 1); optional menv; diff --git a/src/library/simplifier/ceq.h b/src/library/simplifier/ceq.h index 13bbbe8e9..4e10312aa 100644 --- a/src/library/simplifier/ceq.h +++ b/src/library/simplifier/ceq.h @@ -47,5 +47,44 @@ bool is_ceq(ro_environment const & env, optional const & menv, e permutation of the conditional equation arguments. */ bool is_permutation_ceq(expr e); +/* + Given a ceq C, in principle, whenever we want to create an application (C t1 ... tn), + we must check whether the types of t1 ... tn are convertible to the expected types by C. + + This check is needed because of universe cumulativity. + Here is an example that illustrates the issue: + + universe U >= 2 + variable f (A : (Type 1)) : (Type 1) + axiom Ax1 (a : Type) : f a = a + rewrite_set S + add_rewrite Ax1 eq_id : S + theorem T1 (A : (Type 1)) : f A = A + := by simp S + + In this example, Ax1 is a ceq. It has an argument of type Type. + Note that f expects an element of type (Type 1). So, the term (f a) is type correct. + + The axiom Ax1 is only for arguments convertible to Type (i.e., Type 0), but + argument A in T1 lives in (Type 1) + + Scenarios like the one above do not occur very frequently. Moveover, it is quite expensive + to check if the types are convertible for each application of a ceq. + + In most cases, we can statically determine that the checks are not needed when applying + a ceq. Here is a sufficient condition for skipping the test: if for all + arguments x of ceq, one of the following conditions must hold: + 1- The argument has type (Type U). In Lean, (Type U) is the maximal universe. + 2- The argument is a proposition. + 3- There is an application (f x) in the left-hand-side, and + the type expected by f is definitionally equal to the argument type. + 4- There is an application (f (x ...)) in the left-hand-side, and + the type expected by f is definitionally equal to the type of (x ...) + 5- There is an application (f (fun y, x)) in the left-hand-side, + and the type expected by f is definitionally equal to the type of (fun y, x) + \pre is_ceq(env, menv, ceq) +*/ +bool is_safe_to_skip_check_ceq_types(ro_environment const & env, optional const & menv, expr ceq); + void open_ceq(lua_State * L); } diff --git a/src/library/simplifier/rewrite_rule_set.cpp b/src/library/simplifier/rewrite_rule_set.cpp index 90d8cdd85..1a476b0fb 100644 --- a/src/library/simplifier/rewrite_rule_set.cpp +++ b/src/library/simplifier/rewrite_rule_set.cpp @@ -39,7 +39,7 @@ void rewrite_rule_set::insert(name const & id, expr const & th, expr const & pro num++; } lean_assert(is_equality(eq)); - bool must_check = true; // TODO(Leo): call procedure to test whether we must check types or not. + bool must_check = !is_safe_to_skip_check_ceq_types(m_env, menv, ceq); m_rule_set = cons(rewrite_rule(id, arg(eq, num_args(eq) - 2), arg(eq, num_args(eq) - 1), ceq, proof, num, is_perm, must_check), m_rule_set); @@ -114,6 +114,8 @@ format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { r += format(rule.get_id()); if (!enabled) r += format(" [disabled]"); + if (rule.must_check_types()) + r += format(" [check]"); r += format{space(), colon(), space()}; r += nest(indent, fmt(rule.get_ceq(), opts)); }); diff --git a/src/library/simplifier/simplifier.cpp b/src/library/simplifier/simplifier.cpp index 4043ac92b..784045787 100644 --- a/src/library/simplifier/simplifier.cpp +++ b/src/library/simplifier/simplifier.cpp @@ -132,7 +132,6 @@ class simplifier_cell::imp { type_checker m_tc; bool m_has_heq; bool m_has_cast; - context m_ctx; rule_sets m_rule_sets; cache m_cache; max_sharing_fn m_max_sharing; @@ -201,16 +200,16 @@ class simplifier_cell::imp { return mk_lambda(abst_name(abst), abst_domain(abst), new_body); } - bool is_proposition(expr const & e) { return m_tc.is_proposition(e, m_ctx, m_menv.to_some_menv()); } - bool is_convertible(expr const & t1, expr const & t2) { return m_tc.is_convertible(t1, t2, m_ctx, m_menv.to_some_menv()); } + bool is_proposition(expr const & e) { return m_tc.is_proposition(e, context(), m_menv.to_some_menv()); } + bool is_convertible(expr const & t1, expr const & t2) { return m_tc.is_convertible(t1, t2, context(), m_menv.to_some_menv()); } bool is_definitionally_equal(expr const & t1, expr const & t2) { - return m_tc.is_definitionally_equal(t1, t2, m_ctx, m_menv.to_some_menv()); + return m_tc.is_definitionally_equal(t1, t2, context(), m_menv.to_some_menv()); } - expr infer_type(expr const & e) { return m_tc.infer_type(e, m_ctx, m_menv.to_some_menv()); } - expr ensure_pi(expr const & e) { return m_tc.ensure_pi(e, m_ctx, m_menv.to_some_menv()); } + expr infer_type(expr const & e) { return m_tc.infer_type(e, context(), m_menv.to_some_menv()); } + expr ensure_pi(expr const & e) { return m_tc.ensure_pi(e, context(), m_menv.to_some_menv()); } expr normalize(expr const & e) { normalizer & proc = m_tc.get_normalizer(); - return proc(e, m_ctx, m_menv.to_some_menv(), true); + return proc(e, context(), m_menv.to_some_menv(), true); } expr lift_free_vars(expr const & e, unsigned s, unsigned d) { return ::lean::lift_free_vars(e, s, d, m_menv.to_some_menv()); } expr lower_free_vars(expr const & e, unsigned s, unsigned d) { return ::lean::lower_free_vars(e, s, d, m_menv.to_some_menv()); } @@ -1092,7 +1091,7 @@ class simplifier_cell::imp { } new_rhs = lower_free_vars(new_rhs, 1, 1); expr new_rhs_type = ensure_pi(infer_type(new_rhs)); - if (m_tc.is_definitionally_equal(abst_domain(new_rhs_type), abst_domain(rhs.m_expr), m_ctx)) { + if (is_definitionally_equal(abst_domain(new_rhs_type), abst_domain(rhs.m_expr))) { if (m_proofs_enabled) { expr new_proof = mk_eta_th(abst_domain(rhs.m_expr), mk_lambda(rhs.m_expr, abst_body(new_rhs_type)), @@ -1491,13 +1490,6 @@ class simplifier_cell::imp { } } - void set_ctx(context const & ctx) { - if (!is_eqp(m_ctx, ctx)) { - m_cache.clear(); - m_ctx = ctx; - } - } - void set_options(options const & o) { m_proofs_enabled = get_simplifier_proofs(o); m_contextual = get_simplifier_contextual(o); @@ -1528,8 +1520,7 @@ public: m_next_idx = 0; } - result operator()(expr const & e, context const & ctx, optional const & menv) { - set_ctx(ctx); + result operator()(expr const & e, optional const & menv) { if (m_menv.update(menv)) m_cache.clear(); m_num_steps = 0; @@ -1551,12 +1542,11 @@ simplifier_cell::simplifier_cell(ro_environment const & env, options const & o, m_ptr(new imp(env, o, num_rs, rs, monitor)) { } -simplifier_cell::result simplifier_cell::operator()(expr const & e, context const & ctx, optional const & menv) { - return m_ptr->operator()(e, ctx, menv); +simplifier_cell::result simplifier_cell::operator()(expr const & e, optional const & menv) { + return m_ptr->operator()(e, menv); } void simplifier_cell::clear() { return m_ptr->m_cache.clear(); } unsigned simplifier_cell::get_depth() const { return m_ptr->m_depth; } -context const & simplifier_cell::get_context() const { return m_ptr->m_ctx; } ro_environment const & simplifier_cell::get_environment() const { return m_ptr->m_env; } options const & simplifier_cell::get_options() const { return m_ptr->m_options; } @@ -1576,21 +1566,21 @@ ro_simplifier::ro_simplifier(weak_ref const & r) { m_ptr = r.lock(); } -simplifier::result simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts, +simplifier::result simplify(expr const & e, ro_environment const & env, options const & opts, unsigned num_rs, rewrite_rule_set const * rs, optional const & menv, std::shared_ptr const & monitor) { - return simplifier(env, opts, num_rs, rs, monitor)(e, ctx, menv); + return simplifier(env, opts, num_rs, rs, monitor)(e, menv); } -simplifier::result simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts, +simplifier::result simplify(expr const & e, ro_environment const & env, options const & opts, unsigned num_ns, name const * ns, optional const & menv, std::shared_ptr const & monitor) { buffer rules; for (unsigned i = 0; i < num_ns; i++) rules.push_back(get_rewrite_rule_set(env, ns[i])); - return simplify(e, env, ctx, opts, num_ns, rules.data(), menv, monitor); + return simplify(e, env, opts, num_ns, rules.data(), menv, monitor); } simplifier_stack_space_exception::simplifier_stack_space_exception():stack_space_exception("simplifier") {} @@ -1767,11 +1757,9 @@ static int simplifier_apply(lua_State * L) { int nargs = lua_gettop(L); simplifier::result r; if (nargs == 2) - r = to_simplifier(L, 1)(to_expr(L, 2), context(), none_ro_menv()); - else if (nargs == 3) - r = to_simplifier(L, 1)(to_expr(L, 2), to_context(L, 3), none_ro_menv()); + r = to_simplifier(L, 1)(to_expr(L, 2), none_ro_menv()); else - r = to_simplifier(L, 1)(to_expr(L, 2), to_context(L, 3), some_ro_menv(to_metavar_env(L, 4))); + r = to_simplifier(L, 1)(to_expr(L, 2), some_ro_menv(to_metavar_env(L, 3))); push_expr(L, r.get_expr()); push_optional_expr(L, r.get_proof()); lua_pushboolean(L, r.is_heq_proof()); @@ -1780,11 +1768,9 @@ static int simplifier_apply(lua_State * L) { static int simplifier_clear(lua_State * L) { to_simplifier(L, 1)->clear(); return 0; } static int simplifier_depth(lua_State * L) { lua_pushinteger(L, to_simplifier(L, 1)->get_depth()); return 1; } -static int simplifier_context(lua_State * L) { return push_context(L, to_simplifier(L, 1)->get_context()); } static int simplifier_environment(lua_State * L) { return push_environment(L, to_simplifier(L, 1)->get_environment()); } static int simplifier_options(lua_State * L) { return push_options(L, to_simplifier(L, 1)->get_options()); } static int ro_simplifier_depth(lua_State * L) { lua_pushinteger(L, to_ro_simplifier(L, 1)->get_depth()); return 1; } -static int ro_simplifier_context(lua_State * L) { return push_context(L, to_ro_simplifier(L, 1)->get_context()); } static int ro_simplifier_environment(lua_State * L) { return push_environment(L, to_ro_simplifier(L, 1)->get_environment()); } static int ro_simplifier_options(lua_State * L) { return push_options(L, to_ro_simplifier(L, 1)->get_options()); } @@ -1794,7 +1780,6 @@ static const struct luaL_Reg simplifier_m[] = { {"clear", safe_function}, {"depth", safe_function}, {"get_environment", safe_function}, - {"get_context", safe_function}, {"get_options", safe_function}, {0, 0} }; @@ -1803,7 +1788,6 @@ static const struct luaL_Reg ro_simplifier_m[] = { {"__gc", ro_simplifier_gc}, {"depth", safe_function}, {"get_environment", safe_function}, - {"get_context", safe_function}, {"get_options", safe_function}, {0, 0} }; @@ -1813,13 +1797,10 @@ static int simplify_core(lua_State * L, ro_shared_environment const & env) { expr const & e = to_expr(L, 1); buffer rules; get_rewrite_rule_set(L, 2, env, rules); - context ctx; options opts; if (nargs >= 3) opts = to_options(L, 3); - if (nargs >= 5) - ctx = to_context(L, 5); - auto r = simplify(e, env, ctx, opts, rules.size(), rules.data()); + auto r = simplify(e, env, opts, rules.size(), rules.data()); push_expr(L, r.get_expr()); push_optional_expr(L, r.get_proof()); lua_pushboolean(L, r.is_heq_proof()); diff --git a/src/library/simplifier/simplifier.h b/src/library/simplifier/simplifier.h index 5924e2922..3569a1cb7 100644 --- a/src/library/simplifier/simplifier.h +++ b/src/library/simplifier/simplifier.h @@ -44,11 +44,10 @@ public: simplifier_cell(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs, std::shared_ptr const & monitor); - result operator()(expr const & e, context const & ctx, optional const & menv); + result operator()(expr const & e, optional const & menv); void clear(); unsigned get_depth() const; - context const & get_context() const; ro_environment const & get_environment() const; options const & get_options() const; }; @@ -63,8 +62,8 @@ public: std::shared_ptr const & monitor); simplifier_cell * operator->() const { return m_ptr.get(); } simplifier_cell & operator*() const { return *(m_ptr.get()); } - result operator()(expr const & e, context const & ctx, optional const & menv) { - return (*m_ptr)(e, ctx, menv); + result operator()(expr const & e, optional const & menv) { + return (*m_ptr)(e, menv); } }; @@ -137,11 +136,11 @@ public: virtual void rethrow() const; }; -simplifier::result simplify(expr const & e, ro_environment const & env, context const & ctx, options const & pts, +simplifier::result simplify(expr const & e, ro_environment const & env, options const & pts, unsigned num_rs, rewrite_rule_set const * rs, optional const & menv = none_ro_menv(), std::shared_ptr const & monitor = std::shared_ptr()); -simplifier::result simplify(expr const & e, ro_environment const & env, context const & ctx, options const & opts, +simplifier::result simplify(expr const & e, ro_environment const & env, options const & opts, unsigned num_ns, name const * ns, optional const & menv = none_ro_menv(), std::shared_ptr const & monitor = std::shared_ptr()); diff --git a/src/library/tactic/simplify_tactic.cpp b/src/library/tactic/simplify_tactic.cpp index 5fe14af0c..279d2ec5e 100644 --- a/src/library/tactic/simplify_tactic.cpp +++ b/src/library/tactic/simplify_tactic.cpp @@ -53,7 +53,7 @@ static optional simplify_tactic(ro_environment const & env, io_stat } expr conclusion = g.get_conclusion(); - auto r = simplify(conclusion, env, context(), opts, rule_sets.size(), rule_sets.data(), some_ro_menv(menv)); + auto r = simplify(conclusion, env, opts, rule_sets.size(), rule_sets.data(), some_ro_menv(menv)); expr new_conclusion = r.get_expr(); if (new_conclusion == g.get_conclusion()) return optional(s); diff --git a/tests/lean/rs.lean b/tests/lean/rs.lean new file mode 100644 index 000000000..640534001 --- /dev/null +++ b/tests/lean/rs.lean @@ -0,0 +1,8 @@ +rewrite_set S +variable bracket : Type → Bool +axiom bracket_eq (a : Bool) : bracket a = a +add_rewrite bracket_eq : S +add_rewrite and_truer and_comm not_true not_neq not_and exists_or_distribute exists_and_distributel : S +add_rewrite exists_rem eq_id forall_rem : S +add_rewrite Nat::add_zeror Nat::add_comm Nat::add_assoc Nat::mul_comm not_true not_false : S +print rewrite_set S \ No newline at end of file diff --git a/tests/lean/rs.lean.expected.out b/tests/lean/rs.lean.expected.out new file mode 100644 index 000000000..9de9a95e3 --- /dev/null +++ b/tests/lean/rs.lean.expected.out @@ -0,0 +1,23 @@ + Set: pp::colors + Set: pp::unicode + Assumed: bracket + Assumed: bracket_eq +not_false : ¬ ⊥ ↔ ⊤ +not_true : ¬ ⊤ ↔ ⊥ +Nat::mul_comm : ∀ a b : ℕ, a * b = b * a +Nat::add_assoc : ∀ a b c : ℕ, a + b + c = a + (b + c) +Nat::add_comm : ∀ a b : ℕ, a + b = b + a +Nat::add_zeror : ∀ a : ℕ, a + 0 = a +forall_rem [check] : ∀ (A : TypeU) (H : nonempty A) (p : Bool), (A → p) ↔ p +eq_id : ∀ (A : TypeU) (a : A), a = a ↔ ⊤ +exists_rem : ∀ (A : TypeU) (H : nonempty A) (p : Bool), (∃ x : A, p) ↔ p +exists_and_distributel : ∀ (A : TypeU) (p : Bool) (φ : A → Bool), + (∃ x : A, φ x ∧ p) ↔ (∃ x : A, φ x) ∧ p +exists_or_distribute : ∀ (A : TypeU) (φ ψ : A → Bool), + (∃ x : A, φ x ∨ ψ x) ↔ (∃ x : A, φ x) ∨ (∃ x : A, ψ x) +not_and : ∀ a b : Bool, ¬ (a ∧ b) ↔ ¬ a ∨ ¬ b +not_neq : ∀ (A : TypeU) (a b : A), ¬ a ≠ b ↔ a = b +not_true : ¬ ⊤ ↔ ⊥ +and_comm : ∀ a b : Bool, a ∧ b ↔ b ∧ a +and_truer : ∀ a : Bool, a ∧ ⊤ ↔ a +bracket_eq [check] : ∀ a : Bool, bracket a = a