diff --git a/src/library/simplifier/congr.h b/src/library/simplifier/congr.h index 6655237da..77851723a 100644 --- a/src/library/simplifier/congr.h +++ b/src/library/simplifier/congr.h @@ -106,6 +106,7 @@ public: unsigned get_pos_at_proof() const { return m_proof_arg_pos; } optional const & get_new_pos_at_proof() const { return m_proof_new_arg_pos; } optional const & get_proof_pos_at_proof() const { return m_proof_proof_pos; } + bool should_simplify() const { return static_cast(get_new_pos_at_proof()); } void display(std::ostream & out) const; }; diff --git a/src/library/simplifier/rewrite_rule_set.cpp b/src/library/simplifier/rewrite_rule_set.cpp index b0a380a26..88ef760b2 100644 --- a/src/library/simplifier/rewrite_rule_set.cpp +++ b/src/library/simplifier/rewrite_rule_set.cpp @@ -22,7 +22,7 @@ rewrite_rule::rewrite_rule(name const & id, expr const & lhs, expr const & rhs, rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {} rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other): - m_env(other.m_env), m_rule_set(other.m_rule_set), m_disabled_rules(other.m_disabled_rules) {} + m_env(other.m_env), m_rule_set(other.m_rule_set), m_disabled_rules(other.m_disabled_rules), m_congr_thms(other.m_congr_thms) {} rewrite_rule_set::~rewrite_rule_set() {} void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) { @@ -94,6 +94,12 @@ void rewrite_rule_set::for_each(visit_fn const & fn) const { } } +void rewrite_rule_set::for_each_congr(visit_congr_fn const & fn) const { + for (auto const & congr_th : m_congr_thms) { + fn(congr_th); + } +} + format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { format r; bool first = true; diff --git a/src/library/simplifier/rewrite_rule_set.h b/src/library/simplifier/rewrite_rule_set.h index 02c8a9aee..efc7fb533 100644 --- a/src/library/simplifier/rewrite_rule_set.h +++ b/src/library/simplifier/rewrite_rule_set.h @@ -92,6 +92,11 @@ public: /** \brief Execute fn(rule, enabled) for each rule in this rule set. */ void for_each(visit_fn const & fn) const; + typedef std::function visit_congr_fn; // NOLINT + + /** \brief Execute fn(congr_th) for each congruence theorem in this rule set. */ + void for_each_congr(visit_congr_fn const & fn) const; + /** \brief Pretty print this rule set. */ format pp(formatter const & fmt, options const & opts) const; }; diff --git a/src/library/simplifier/simplifier.cpp b/src/library/simplifier/simplifier.cpp index 2e0eeafa3..02b79bd2e 100644 --- a/src/library/simplifier/simplifier.cpp +++ b/src/library/simplifier/simplifier.cpp @@ -12,6 +12,7 @@ Author: Leonardo de Moura #include "kernel/type_checker.h" #include "kernel/free_vars.h" #include "kernel/instantiate.h" +#include "kernel/abstract.h" #include "kernel/normalizer.h" #include "kernel/kernel.h" #include "kernel/max_sharing.h" @@ -85,36 +86,20 @@ RegisterBoolOption(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL, "(simp RegisterBoolOption(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE, "(simplifier) memoize/cache intermediate results"); RegisterUnsignedOption(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS, "(simplifier) maximum number of steps"); -bool get_simplifier_proofs(options const & opts) { - return opts.get_bool(g_simplifier_proofs, LEAN_SIMPLIFIER_PROOFS); -} -bool get_simplifier_contextual(options const & opts) { - return opts.get_bool(g_simplifier_contextual, LEAN_SIMPLIFIER_CONTEXTUAL); -} -bool get_simplifier_single_pass(options const & opts) { - return opts.get_bool(g_simplifier_single_pass, LEAN_SIMPLIFIER_SINGLE_PASS); -} -bool get_simplifier_beta(options const & opts) { - return opts.get_bool(g_simplifier_beta, LEAN_SIMPLIFIER_BETA); -} -bool get_simplifier_eta(options const & opts) { - return opts.get_bool(g_simplifier_eta, LEAN_SIMPLIFIER_ETA); -} -bool get_simplifier_eval(options const & opts) { - return opts.get_bool(g_simplifier_eval, LEAN_SIMPLIFIER_EVAL); -} -bool get_simplifier_unfold(options const & opts) { - return opts.get_bool(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD); -} -bool get_simplifier_conditional(options const & opts) { - return opts.get_bool(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL); -} -bool get_simplifier_memoize(options const & opts) { - return opts.get_bool(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE); -} -unsigned get_simplifier_max_steps(options const & opts) { - return opts.get_unsigned(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS); -} +bool get_simplifier_proofs(options const & opts) { return opts.get_bool(g_simplifier_proofs, LEAN_SIMPLIFIER_PROOFS); } +bool get_simplifier_contextual(options const & opts) { return opts.get_bool(g_simplifier_contextual, LEAN_SIMPLIFIER_CONTEXTUAL); } +bool get_simplifier_single_pass(options const & opts) { return opts.get_bool(g_simplifier_single_pass, LEAN_SIMPLIFIER_SINGLE_PASS); } +bool get_simplifier_beta(options const & opts) { return opts.get_bool(g_simplifier_beta, LEAN_SIMPLIFIER_BETA); } +bool get_simplifier_eta(options const & opts) { return opts.get_bool(g_simplifier_eta, LEAN_SIMPLIFIER_ETA); } +bool get_simplifier_eval(options const & opts) { return opts.get_bool(g_simplifier_eval, LEAN_SIMPLIFIER_EVAL); } +bool get_simplifier_unfold(options const & opts) { return opts.get_bool(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD); } +bool get_simplifier_conditional(options const & opts) { return opts.get_bool(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL); } +bool get_simplifier_memoize(options const & opts) { return opts.get_bool(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE); } +unsigned get_simplifier_max_steps(options const & opts) { return opts.get_unsigned(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS); } + +static name g_local("local"); +static name g_C("C"); +static name g_unique = name::mk_internal_unique_name(); class simplifier_fn { struct result { @@ -131,6 +116,7 @@ class simplifier_fn { typedef std::vector rule_sets; typedef expr_map cache; + typedef std::vector congr_thms; ro_environment m_env; type_checker m_tc; bool m_has_heq; @@ -138,7 +124,8 @@ class simplifier_fn { rule_sets m_rule_sets; cache m_cache; max_sharing_fn m_max_sharing; - + congr_thms m_congr_thms; + unsigned m_contextual_depth; // number of contextual simplification steps in the current branch unsigned m_num_steps; // number of steps performed // Configuration @@ -159,6 +146,15 @@ class simplifier_fn { set_context(simplifier_fn & s, context const & new_ctx):m_set(s.m_ctx, new_ctx), m_reset_cache(s.m_cache) {} }; + struct updt_rule_set { + rewrite_rule_set & m_rs; + rewrite_rule_set m_old; + updt_rule_set(rewrite_rule_set & rs, expr const & fact, expr const & proof):m_rs(rs), m_old(m_rs) { + m_rs.insert(g_local, fact, proof); + } + ~updt_rule_set() { m_rs = m_old; } + }; + /** \brief Return a lambda with body \c new_body, and name and domain from abst. */ @@ -170,6 +166,10 @@ class simplifier_fn { return m_tc.is_proposition(e, m_ctx); } + bool is_eq_convertible(expr const & t1, expr const & t2) { + return m_tc.is_eq_convertible(t1, t2, m_ctx); + } + expr infer_type(expr const & e) { return m_tc.infer_type(e, m_ctx); } @@ -300,6 +300,129 @@ class simplifier_fn { } result simplify_app(expr const & e) { + if (m_contextual) { + expr const & f = arg(e, 0); + for (auto congr_th : m_congr_thms) { + if (congr_th->get_fun() == f) + return simplify_app_congr(e, *congr_th); + } + } + return simplify_app_default(e); + } + + /** + \brief Make sure the proof in rhs is using homogeneous equality, and return true. + If it is not possible to transform it in a homogeneous equality proof, then return false. + */ + bool ensure_homogeneous(expr const & lhs, result & rhs) { + if (rhs.m_heq_proof) { + // try to convert back to homogeneous + lean_assert(rhs.m_proof); + expr lhs_type = infer_type(lhs); + expr rhs_type = infer_type(rhs.m_out); + if (is_eq_convertible(lhs_type, rhs_type)) { + // move back to homogeneous equality using to_eq + rhs.m_proof = mk_to_eq_th(lhs_type, lhs, rhs.m_out, *rhs.m_proof); + return true; + } else { + return false; + } + } else { + return true; + } + } + + expr get_proof(result const & rhs) { + if (rhs.m_proof) { + return *rhs.m_proof; + } else { + // lhs and rhs are definitionally equal + return mk_refl_th(infer_type(rhs.m_out), rhs.m_out); + } + } + + /** + \brief Simplify \c e using the given congruence theorem. + See congr.h for a description of congr_theorem_info. + */ + result simplify_app_congr(expr const & e, congr_theorem_info const & cg_thm) { + lean_assert(is_app(e)); + lean_assert(arg(e, 0) == cg_thm.get_fun()); + buffer new_args; + bool changed = false; + new_args.resize(num_args(e)); + new_args[0] = arg(e, 0); + buffer proof_args_buf; + expr * proof_args; + if (m_proofs_enabled) { + proof_args_buf.resize(cg_thm.get_num_proof_args() + 1); + proof_args_buf[0] = cg_thm.get_proof(); + proof_args = proof_args_buf.data()+1; + } + for (auto const & info : cg_thm.get_arg_info()) { + unsigned pos = info.get_arg_pos(); + expr const & a = arg(e, pos); + if (info.should_simplify()) { + optional const & ctx = info.get_context(); + if (!ctx) { + // argument does not have a context + result res_a = simplify(a); + new_args[pos] = res_a.m_out; + if (m_proofs_enabled) { + if (!ensure_homogeneous(a, res_a)) + return simplify_app_default(e); // fallback to default congruence + proof_args[info.get_pos_at_proof()] = a; + proof_args[*info.get_new_pos_at_proof()] = new_args[pos]; + proof_args[*info.get_proof_pos_at_proof()] = get_proof(res_a); + } + } else { + unsigned dep_pos = ctx->get_arg_pos(); + expr H = ctx->use_new_val() ? new_args[dep_pos] : arg(e, dep_pos); + if (!ctx->is_pos_dep()) + H = mk_not(H); + // We will simplify the \c a under the hypothesis H + if (!m_proofs_enabled) { + // Contextual reasoning without proofs. + expr dummy_proof; // we don't need a proof + updt_rule_set update(m_rule_sets[0], H, dummy_proof); + result res_a = simplify(a); + new_args[pos] = res_a.m_out; + } else { + // We have to introduce H in the context, so first we lift the free variables in \c a + flet set_depth(m_contextual_depth, m_contextual_depth+1); + expr H_proof = mk_constant(name(g_unique, m_contextual_depth)); + updt_rule_set update(m_rule_sets[0], H, H_proof); + freset m_reset_cache(m_cache); // must reset cache for the recursive call because we updated the rule_sets + result res_a = simplify(a); + if (!ensure_homogeneous(a, res_a)) + return simplify_app_default(e); // fallback to default congruence + new_args[pos] = res_a.m_out; + proof_args[info.get_pos_at_proof()] = a; + proof_args[*info.get_new_pos_at_proof()] = new_args[pos]; + name C_name(g_C, m_contextual_depth); // H_name is a cryptic unique name + proof_args[*info.get_proof_pos_at_proof()] = ::lean::mk_lambda(C_name, H, abstract(get_proof(res_a), H_proof)); + } + } + if (new_args[pos] != a) + changed = true; + } else { + // argument should not be simplified + new_args[pos] = arg(e, pos); + if (m_proofs_enabled) + proof_args[info.get_pos_at_proof()] = arg(e, pos); + } + } + + if (!changed) { + return rewrite_app(e, result(e)); + } else if (!m_proofs_enabled) { + return rewrite_app(e, result(mk_app(new_args))); + } else { + return rewrite_app(e, result(mk_app(new_args), mk_app(proof_args_buf))); + } + } + + result simplify_app_default(expr const & e) { lean_assert(is_app(e)); buffer new_args; buffer> proofs; // used only if m_proofs_enabled @@ -763,6 +886,20 @@ class simplifier_fn { lean_unreachable(); } + void collect_congr_thms() { + if (m_contextual) { + for (auto const & rs : m_rule_sets) { + rs.for_each_congr([&](congr_theorem_info const & info) { + if (std::all_of(m_congr_thms.begin(), m_congr_thms.end(), + [&](congr_theorem_info const * info2) { + return info2->get_fun() != info.get_fun(); })) { + m_congr_thms.push_back(&info); + } + }); + } + } + } + void set_options(options const & o) { m_proofs_enabled = get_simplifier_proofs(o); m_contextual = get_simplifier_contextual(o); @@ -786,17 +923,15 @@ public: m_rule_sets.push_back(rewrite_rule_set(env)); } m_rule_sets.insert(m_rule_sets.end(), rs, rs + num_rs); + collect_congr_thms(); + m_contextual_depth = 0; } expr_pair operator()(expr const & e, context const & ctx) { set_context set(*this, ctx); m_num_steps = 0; auto r = simplify(e); - if (r.m_proof) { - return mk_pair(r.m_out, *(r.m_proof)); - } else { - return mk_pair(r.m_out, mk_refl_th(infer_type(r.m_out), r.m_out)); - } + return mk_pair(r.m_out, get_proof(r)); } }; diff --git a/tests/lean/simp13.lean b/tests/lean/simp13.lean new file mode 100644 index 000000000..a890abea1 --- /dev/null +++ b/tests/lean/simp13.lean @@ -0,0 +1,15 @@ +rewrite_set simple +add_rewrite and_truer and_truel and_falser and_falsel : simple +(* + add_congr_theorem("simple", "and_congr") +*) + +variables a b c : Nat + +(* +local t = parse_lean([[a = 1 ∧ b = 0 ∧ b > a]]) +local s, pr = simplify(t, "simple") +print(s) +print(pr) +print(get_environment():type_check(pr)) +*) \ No newline at end of file diff --git a/tests/lean/simp13.lean.expected.out b/tests/lean/simp13.lean.expected.out new file mode 100644 index 000000000..7e82aa04b --- /dev/null +++ b/tests/lean/simp13.lean.expected.out @@ -0,0 +1,13 @@ + Set: pp::colors + Set: pp::unicode + Assumed: a + Assumed: b + Assumed: c +⊥ +trans (and_congr + (refl (a = 1)) + (λ C::1 : a = 1, + trans (and_congr (refl (b = 0)) (λ C::2 : b = 0, congr (congr2 Nat::gt C::2) C::1)) + (and_falsel (b = 0)))) + (and_falsel (a = 1)) +(a = 1 ∧ b = 0 ∧ b > a) = ⊥