diff --git a/src/library/simplifier/rewrite_rule_set.cpp b/src/library/simplifier/rewrite_rule_set.cpp index 5a4979865..5b5ff6bc3 100644 --- a/src/library/simplifier/rewrite_rule_set.cpp +++ b/src/library/simplifier/rewrite_rule_set.cpp @@ -15,15 +15,10 @@ Author: Leonardo de Moura #include "library/simplifier/rewrite_rule_set.h" namespace lean { -struct rewrite_rule_set::rewrite_rule { - name m_id; - expr m_lhs; - expr m_ceq; - expr m_proof; - bool m_is_permutation; - rewrite_rule(name const & id, expr const & lhs, expr const & ceq, expr const & proof, bool is_permutation): - m_id(id), m_lhs(lhs), m_ceq(ceq), m_proof(proof), m_is_permutation(is_permutation) {} -}; +rewrite_rule::rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof, + unsigned num_args, bool is_permutation): + m_id(id), m_lhs(lhs), m_rhs(rhs), m_ceq(ceq), m_proof(proof), m_num_args(num_args), m_is_permutation(is_permutation) { +} rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {} rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other): @@ -36,13 +31,16 @@ void rewrite_rule_set::insert(name const & id, expr const & th, expr const & pro expr const & ceq = p.first; expr const & proof = p.second; bool is_perm = is_permutation_ceq(ceq); - expr lhs = ceq; - while (is_pi(lhs)) { - lhs = abst_body(lhs); + expr eq = ceq; + unsigned num = 0; + while (is_pi(eq)) { + eq = abst_body(eq); + num++; } - lean_assert(is_equality(lhs)); - lhs = arg(lhs, num_args(lhs) - 2); - m_rule_set.emplace_front(id, lhs, ceq, proof, is_perm); + lean_assert(is_equality(eq)); + m_rule_set = cons(rewrite_rule(id, arg(eq, num_args(eq) - 2), arg(eq, num_args(eq) - 1), + ceq, proof, num, is_perm), + m_rule_set); } } @@ -57,7 +55,7 @@ void rewrite_rule_set::insert(name const & th_name) { } bool rewrite_rule_set::enabled(rewrite_rule const & rule) const { - return !m_disabled_rules.contains(rule.m_id); + return !m_disabled_rules.contains(rule.get_id()); } bool rewrite_rule_set::enabled(name const & id) const { @@ -71,18 +69,19 @@ void rewrite_rule_set::enable(name const & id, bool f) { m_disabled_rules.insert(id); } -void rewrite_rule_set::for_each_match_candidate(expr const &, match_fn const & fn) const { +bool rewrite_rule_set::find_match(expr const &, match_fn const & fn) const { auto l = m_rule_set; for (auto const & rule : l) { - if (enabled(rule) && fn(rule.m_lhs, rule.m_ceq, rule.m_is_permutation, rule.m_proof)) - return; + if (enabled(rule) && fn(rule)) + return true; } + return false; } void rewrite_rule_set::for_each(visit_fn const & fn) const { auto l = m_rule_set; for (auto const & rule : l) { - fn(rule.m_id, rule.m_ceq, rule.m_proof, enabled(rule)); + fn(rule, enabled(rule)); } } @@ -90,16 +89,16 @@ format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const { format r; bool first = true; unsigned indent = get_pp_indent(opts); - for_each([&](name const & name, expr const & ceq, expr const &, bool enabled) { + for_each([&](rewrite_rule const & rule, bool enabled) { if (first) first = false; else r += line(); - r += format(name); + r += format(rule.get_id()); if (!enabled) r += format(" [disabled]"); r += format{space(), colon(), space()}; - r += nest(indent, fmt(ceq, opts)); + r += nest(indent, fmt(rule.get_ceq(), opts)); }); return r; } diff --git a/src/library/simplifier/rewrite_rule_set.h b/src/library/simplifier/rewrite_rule_set.h index 50a6a4608..73b2deafa 100644 --- a/src/library/simplifier/rewrite_rule_set.h +++ b/src/library/simplifier/rewrite_rule_set.h @@ -15,11 +15,32 @@ Author: Leonardo de Moura #include "kernel/formatter.h" namespace lean { +class rewrite_rule_set; +class rewrite_rule { + friend class rewrite_rule_set; + name m_id; + expr m_lhs; + expr m_rhs; + expr m_ceq; + expr m_proof; + unsigned m_num_args; + bool m_is_permutation; + rewrite_rule(name const & id, expr const & lhs, expr const & rhs, expr const & ceq, expr const & proof, + unsigned num_args, bool is_permutation); +public: + name const & get_id() const { return m_id; } + expr const & get_lhs() const { return m_lhs; } + expr const & get_rhs() const { return m_rhs; } + expr const & get_ceq() const { return m_ceq; } + expr const & get_proof() const { return m_proof; } + unsigned get_num_args() const { return m_num_args; } + bool is_permutation() const { return m_is_permutation; } +}; + /** \brief Actual implementation of the \c rewrite_rule_set class. */ class rewrite_rule_set { - struct rewrite_rule; typedef splay_tree name_set; ro_environment::weak_ref m_env; list m_rule_set; // TODO(Leo): use better data-structure, e.g., discrimination trees @@ -51,22 +72,17 @@ public: /** \brief Enable/disable the conditional rewrite rules tagged with the given identifier. */ void enable(name const & id, bool f); - typedef std::function match_fn; // NOLINT - typedef std::function visit_fn; + typedef std::function match_fn; // NOLINT + typedef std::function visit_fn; /** - \brief Execute fn(lhs, ceq, is_perm, proof) for each (enabled) rule whose the left-hand-side may + \brief Execute fn(rule) for each (enabled) rule whose the left-hand-side may match \c e. The traversal is interrupted as soon as \c fn returns true. - - The redundant argument \c lhs is the left-hand-side of \c ceq. - The redundant argument \c is_perm is true iff \c ceq is a permutation rule. - - The argument \c proof is the proof for \c ceq. */ - void for_each_match_candidate(expr const &, match_fn const & fn) const; + bool find_match(expr const &, match_fn const & fn) const; - /** \brief Execute fn(id, ceq, proof, enabled) for each rule in this rule set. */ + /** \brief Execute fn(rule, enabled) for each rule in this rule set. */ void for_each(visit_fn const & fn) const; /** \brief Pretty print this rule set. */ diff --git a/src/library/simplifier/simplifier.cpp b/src/library/simplifier/simplifier.cpp index 9f19fb066..119418573 100644 --- a/src/library/simplifier/simplifier.cpp +++ b/src/library/simplifier/simplifier.cpp @@ -92,6 +92,16 @@ class simplifier_fn { bool m_unfold; unsigned m_max_steps; + struct match_fn { + simplifier_fn & m_simp; + match_fn(simplifier_fn & s):m_simp(s) {} + bool operator()(rewrite_rule const & rule) const { + return m_simp.match(rule); + } + }; + + match_fn m_match_fn; + struct result { expr m_out; optional m_proof; @@ -126,20 +136,20 @@ class simplifier_fn { expr mk_congr1_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & Heq_f) { expr const & A = abst_domain(f_type); - expr const & B = lower_free_vars(abst_body(f_type), 1, 1); + expr B = lower_free_vars(abst_body(f_type), 1, 1); return ::lean::mk_congr1_th(A, B, f, new_f, a, Heq_f); } expr mk_congr2_th(expr const & f_type, expr const & a, expr const & new_a, expr const & f, expr const & Heq_a) { expr const & A = abst_domain(f_type); - expr const & B = lower_free_vars(abst_body(f_type), 1, 1); + expr B = lower_free_vars(abst_body(f_type), 1, 1); return ::lean::mk_congr2_th(A, B, a, new_a, f, Heq_a); } expr mk_congr_th(expr const & f_type, expr const & f, expr const & new_f, expr const & a, expr const & new_a, expr const & Heq_f, expr const & Heq_a) { expr const & A = abst_domain(f_type); - expr const & B = lower_free_vars(abst_body(f_type), 1, 1); + expr B = lower_free_vars(abst_body(f_type), 1, 1); return ::lean::mk_congr_th(A, B, f, new_f, a, new_a, Heq_f, Heq_a); } @@ -152,6 +162,65 @@ class simplifier_fn { f, new_f, a, new_a, Heq_f, Heq_a); } + result mk_trans_result(expr const & a, result const & b_res, expr const & c, expr const & H_bc) { + if (m_proofs_enabled) { + if (!b_res.m_proof) { + // The proof of a = b is reflexivity + return result(c, H_bc); + } else { + expr const & b = b_res.m_out; + expr new_proof; + bool heq_proof = false; + if (b_res.m_heq_proof) { + expr b_type = infer_type(b); + new_proof = ::lean::mk_htrans_th(infer_type(a), b_type, b_type, /* b and c must have the same type */ + a, b, c, *b_res.m_proof, mk_to_heq_th(b_type, b, c, H_bc)); + heq_proof = true; + } else { + new_proof = ::lean::mk_trans_th(infer_type(a), a, b, c, *b_res.m_proof, H_bc); + } + return result(c, new_proof, heq_proof); + } + } else { + return result(c); + } + } + + result mk_trans_result(expr const & a, result const & b_res, result const & c_res) { + if (m_proofs_enabled) { + if (!b_res.m_proof) { + // the proof of a == b is reflexivity + return c_res; + } else if (!c_res.m_proof) { + // the proof of b == c is reflexivity + return result(c_res.m_out, *b_res.m_proof, b_res.m_heq_proof); + } else { + bool heq_proof = b_res.m_heq_proof || c_res.m_heq_proof; + expr new_proof; + expr const & b = b_res.m_out; + expr const & c = c_res.m_out; + if (heq_proof) { + expr a_type = infer_type(a); + expr b_type = infer_type(b); + expr c_type = infer_type(c); + expr H_ab = *b_res.m_proof; + if (!b_res.m_heq_proof) + H_ab = mk_to_heq_th(a_type, a, b, H_ab); + expr H_bc = *c_res.m_proof; + if (!c_res.m_heq_proof) + H_bc = mk_to_heq_th(b_type, b, c, H_bc); + new_proof = ::lean::mk_htrans_th(a_type, b_type, c_type, a, b, c, H_ab, H_bc); + } else { + new_proof = ::lean::mk_trans_th(infer_type(a), a, b, c, *b_res.m_proof, *c_res.m_proof); + } + return result(c, new_proof, heq_proof); + } + } else { + // proof generation is disabled + return c_res; + } + } + expr mk_app_prefix(unsigned i, expr const & a) { lean_assert(i > 0); if (i == 1) @@ -225,9 +294,9 @@ class simplifier_fn { } if (!changed) { - return rewrite_app(result(e)); + return rewrite(e, result(e)); } else if (!m_proofs_enabled) { - return rewrite_app(result(mk_app(new_args))); + return rewrite(e, result(mk_app(new_args))); } else { expr out = mk_app(new_args); unsigned i = 0; @@ -272,11 +341,76 @@ class simplifier_fn { pr = mk_congr1_th(f_types[i-1], f, new_f, arg(e, i), pr); } } - return rewrite_app(result(out, pr, heq_proof)); + return rewrite(e, result(out, pr, heq_proof)); } } - result rewrite_app(result const & r) { + expr m_target; // temp field + buffer> m_subst; // temp field + buffer m_new_args; // temp field + expr m_new_rhs; // temp field + expr m_new_proof; // temp field + + void reset_subst(unsigned num_args) { + if (m_subst.size() < num_args) { + m_subst.resize(num_args); + m_new_args.resize(num_args+1); + } + for (unsigned i = 0; i < num_args; i++) + m_subst[i] = none_expr(); + } + + bool found_all_args(unsigned num_args) { + for (unsigned i = 0; i < num_args; i++) { + if (!m_subst[i]) + return false; + m_new_args[i+1] = *m_subst[i]; + } + return true; + } + + /** + \brief Auxiliary function used by m_match_fn, it tries to match the given rule and + the expression in the temporary field \c m_target. + If it succeeds, then the resultant expression is stored in the temporary field m_new_rhs, + and the proof in m_new_proof (if proofs are enabled). + */ + bool match(rewrite_rule const & rule) { + unsigned num = rule.get_num_args(); + reset_subst(num); + if (hop_match(rule.get_lhs(), m_target, m_subst)) { + if (found_all_args(num)) { + // easy case: all arguments found + m_new_rhs = instantiate(rule.get_rhs(), num, m_new_args.data() + 1); + if (m_proofs_enabled) { + if (num > 0) { + m_new_args[0] = rule.get_proof(); + m_new_proof = mk_app(m_new_args); + } else { + m_new_proof = rule.get_proof(); + } + } + return true; + } + // TODO(Leo): conditional rewriting + } + return false; + } + + result rewrite(expr const & e, result const & r) { + m_target = r.m_out; + for (rewrite_rule_set const & rs : m_rule_sets) { + if (rs.find_match(m_target, m_match_fn)) { + // the result is in m_new_rhs and proof at m_new_proof + result new_r1 = mk_trans_result(e, r, m_new_rhs, m_new_proof); + if (m_single_pass) { + return new_r1; + } else { + result new_r2 = simplify(new_r1.m_out); + return mk_trans_result(e, new_r1, new_r2); + } + } + } return r; } @@ -302,17 +436,7 @@ class simplifier_fn { } } } -#if 1 - if (const_name(e) == "a") { - auto obj = m_env->find_object("a_eq_0"); - if (obj) { - expr r = arg(obj->get_type(), 3); - return result(r, mk_constant("a_eq_0")); - } - } -#endif - - return result(e); + return rewrite(e, result(e)); } result simplify_lambda(expr const & e) { @@ -328,11 +452,11 @@ class simplifier_fn { if (is_eqp(new_body, abst_body(e))) return result(e); expr out = mk_lambda(e, new_body); - if (!m_proofs_enabled) + if (!m_proofs_enabled || !res_body.m_proof) return result(out); expr body_type = infer_type(abst_body(e)); expr pr = mk_funext_th(abst_domain(e), mk_lambda(e, body_type), e, out, - mk_lambda(e, *(res_body.m_proof))); + mk_lambda(e, *res_body.m_proof)); return result(out, pr); } } @@ -350,12 +474,12 @@ class simplifier_fn { if (is_eqp(new_body, abst_body(e))) return result(e); expr out = mk_pi(abst_name(e), abst_domain(e), new_body); - if (!m_proofs_enabled) + if (!m_proofs_enabled || !res_body.m_proof) return result(out); expr pr = mk_allext_th(abst_domain(e), mk_lambda(e, abst_body(e)), mk_lambda(e, abst_body(out)), - mk_lambda(e, *(res_body.m_proof))); + mk_lambda(e, *res_body.m_proof)); return result(out, pr); } else { // if the environment does not contain heq axioms, then we don't simplify Pi's that are not forall's @@ -384,13 +508,13 @@ class simplifier_fn { m_contextual = get_simplifier_contextual(o); m_single_pass = get_simplifier_single_pass(o); m_beta = get_simplifier_beta(o); - m_unfold = true; // get_simplifier_unfold(o); + m_unfold = get_simplifier_unfold(o); m_max_steps = get_simplifier_max_steps(o); } public: simplifier_fn(ro_environment const & env, options const & o, unsigned num_rs, rewrite_rule_set const * rs): - m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs) { + m_env(env), m_tc(env), m_rule_sets(rs, rs + num_rs), m_match_fn(*this) { m_has_heq = m_env->imported("heq"); set_options(o); } diff --git a/tests/lua/simp1.lua b/tests/lua/simp1.lua new file mode 100644 index 000000000..ec1a8fbec --- /dev/null +++ b/tests/lua/simp1.lua @@ -0,0 +1,30 @@ +mk_rewrite_rule_set() +add_rewrite_rules({"Nat", "add_zerol"}) +add_rewrite_rules({"Nat", "add_zeror"}) +parse_lean_cmds([[ + variable f : Nat -> Nat -> Nat + variable g : Nat -> Nat + variable b : Nat + definition a := 1 + theorem a_eq_1 : a = 1 + := refl a + definition c := 1 + set_opaque a true + + axiom f_id (x : Nat) : f x 1 = 2*x +]]) +add_rewrite_rules("a_eq_1") +add_rewrite_rules("f_id") +-- set_option({"lean", "pp", "implicit"}, true) +e, pr = simplify(parse_lean('fun x, f (f x (0 + a)) (g (b + 0))')) +print(e) +print(pr) +local env = get_environment() +print(env:type_check(pr)) + + +e, pr = simplify(parse_lean('forall x, let d := a + 1 in f x a >= d')) +print(e) +print(pr) +local env = get_environment() +print(env:type_check(pr))