feat(library/simplifier): contextual simplifications

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-01-23 12:23:22 -08:00
parent 1638a7bb02
commit d6692264e8
6 changed files with 212 additions and 37 deletions

View file

@ -106,6 +106,7 @@ public:
unsigned get_pos_at_proof() const { return m_proof_arg_pos; }
optional<unsigned> const & get_new_pos_at_proof() const { return m_proof_new_arg_pos; }
optional<unsigned> const & get_proof_pos_at_proof() const { return m_proof_proof_pos; }
bool should_simplify() const { return static_cast<bool>(get_new_pos_at_proof()); }
void display(std::ostream & out) const;
};

View file

@ -22,7 +22,7 @@ rewrite_rule::rewrite_rule(name const & id, expr const & lhs, expr const & rhs,
rewrite_rule_set::rewrite_rule_set(ro_environment const & env):m_env(env.to_weak_ref()) {}
rewrite_rule_set::rewrite_rule_set(rewrite_rule_set const & other):
m_env(other.m_env), m_rule_set(other.m_rule_set), m_disabled_rules(other.m_disabled_rules) {}
m_env(other.m_env), m_rule_set(other.m_rule_set), m_disabled_rules(other.m_disabled_rules), m_congr_thms(other.m_congr_thms) {}
rewrite_rule_set::~rewrite_rule_set() {}
void rewrite_rule_set::insert(name const & id, expr const & th, expr const & proof) {
@ -94,6 +94,12 @@ void rewrite_rule_set::for_each(visit_fn const & fn) const {
}
}
void rewrite_rule_set::for_each_congr(visit_congr_fn const & fn) const {
for (auto const & congr_th : m_congr_thms) {
fn(congr_th);
}
}
format rewrite_rule_set::pp(formatter const & fmt, options const & opts) const {
format r;
bool first = true;

View file

@ -92,6 +92,11 @@ public:
/** \brief Execute <tt>fn(rule, enabled)</tt> for each rule in this rule set. */
void for_each(visit_fn const & fn) const;
typedef std::function<void(congr_theorem_info const &)> visit_congr_fn; // NOLINT
/** \brief Execute <tt>fn(congr_th)</tt> for each congruence theorem in this rule set. */
void for_each_congr(visit_congr_fn const & fn) const;
/** \brief Pretty print this rule set. */
format pp(formatter const & fmt, options const & opts) const;
};

View file

@ -12,6 +12,7 @@ Author: Leonardo de Moura
#include "kernel/type_checker.h"
#include "kernel/free_vars.h"
#include "kernel/instantiate.h"
#include "kernel/abstract.h"
#include "kernel/normalizer.h"
#include "kernel/kernel.h"
#include "kernel/max_sharing.h"
@ -85,36 +86,20 @@ RegisterBoolOption(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL, "(simp
RegisterBoolOption(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE, "(simplifier) memoize/cache intermediate results");
RegisterUnsignedOption(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS, "(simplifier) maximum number of steps");
bool get_simplifier_proofs(options const & opts) {
return opts.get_bool(g_simplifier_proofs, LEAN_SIMPLIFIER_PROOFS);
}
bool get_simplifier_contextual(options const & opts) {
return opts.get_bool(g_simplifier_contextual, LEAN_SIMPLIFIER_CONTEXTUAL);
}
bool get_simplifier_single_pass(options const & opts) {
return opts.get_bool(g_simplifier_single_pass, LEAN_SIMPLIFIER_SINGLE_PASS);
}
bool get_simplifier_beta(options const & opts) {
return opts.get_bool(g_simplifier_beta, LEAN_SIMPLIFIER_BETA);
}
bool get_simplifier_eta(options const & opts) {
return opts.get_bool(g_simplifier_eta, LEAN_SIMPLIFIER_ETA);
}
bool get_simplifier_eval(options const & opts) {
return opts.get_bool(g_simplifier_eval, LEAN_SIMPLIFIER_EVAL);
}
bool get_simplifier_unfold(options const & opts) {
return opts.get_bool(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD);
}
bool get_simplifier_conditional(options const & opts) {
return opts.get_bool(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL);
}
bool get_simplifier_memoize(options const & opts) {
return opts.get_bool(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE);
}
unsigned get_simplifier_max_steps(options const & opts) {
return opts.get_unsigned(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS);
}
bool get_simplifier_proofs(options const & opts) { return opts.get_bool(g_simplifier_proofs, LEAN_SIMPLIFIER_PROOFS); }
bool get_simplifier_contextual(options const & opts) { return opts.get_bool(g_simplifier_contextual, LEAN_SIMPLIFIER_CONTEXTUAL); }
bool get_simplifier_single_pass(options const & opts) { return opts.get_bool(g_simplifier_single_pass, LEAN_SIMPLIFIER_SINGLE_PASS); }
bool get_simplifier_beta(options const & opts) { return opts.get_bool(g_simplifier_beta, LEAN_SIMPLIFIER_BETA); }
bool get_simplifier_eta(options const & opts) { return opts.get_bool(g_simplifier_eta, LEAN_SIMPLIFIER_ETA); }
bool get_simplifier_eval(options const & opts) { return opts.get_bool(g_simplifier_eval, LEAN_SIMPLIFIER_EVAL); }
bool get_simplifier_unfold(options const & opts) { return opts.get_bool(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD); }
bool get_simplifier_conditional(options const & opts) { return opts.get_bool(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL); }
bool get_simplifier_memoize(options const & opts) { return opts.get_bool(g_simplifier_memoize, LEAN_SIMPLIFIER_MEMOIZE); }
unsigned get_simplifier_max_steps(options const & opts) { return opts.get_unsigned(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS); }
static name g_local("local");
static name g_C("C");
static name g_unique = name::mk_internal_unique_name();
class simplifier_fn {
struct result {
@ -131,6 +116,7 @@ class simplifier_fn {
typedef std::vector<rewrite_rule_set> rule_sets;
typedef expr_map<result> cache;
typedef std::vector<congr_theorem_info const *> congr_thms;
ro_environment m_env;
type_checker m_tc;
bool m_has_heq;
@ -138,7 +124,8 @@ class simplifier_fn {
rule_sets m_rule_sets;
cache m_cache;
max_sharing_fn m_max_sharing;
congr_thms m_congr_thms;
unsigned m_contextual_depth; // number of contextual simplification steps in the current branch
unsigned m_num_steps; // number of steps performed
// Configuration
@ -159,6 +146,15 @@ class simplifier_fn {
set_context(simplifier_fn & s, context const & new_ctx):m_set(s.m_ctx, new_ctx), m_reset_cache(s.m_cache) {}
};
struct updt_rule_set {
rewrite_rule_set & m_rs;
rewrite_rule_set m_old;
updt_rule_set(rewrite_rule_set & rs, expr const & fact, expr const & proof):m_rs(rs), m_old(m_rs) {
m_rs.insert(g_local, fact, proof);
}
~updt_rule_set() { m_rs = m_old; }
};
/**
\brief Return a lambda with body \c new_body, and name and domain from abst.
*/
@ -170,6 +166,10 @@ class simplifier_fn {
return m_tc.is_proposition(e, m_ctx);
}
bool is_eq_convertible(expr const & t1, expr const & t2) {
return m_tc.is_eq_convertible(t1, t2, m_ctx);
}
expr infer_type(expr const & e) {
return m_tc.infer_type(e, m_ctx);
}
@ -300,6 +300,129 @@ class simplifier_fn {
}
result simplify_app(expr const & e) {
if (m_contextual) {
expr const & f = arg(e, 0);
for (auto congr_th : m_congr_thms) {
if (congr_th->get_fun() == f)
return simplify_app_congr(e, *congr_th);
}
}
return simplify_app_default(e);
}
/**
\brief Make sure the proof in rhs is using homogeneous equality, and return true.
If it is not possible to transform it in a homogeneous equality proof, then return false.
*/
bool ensure_homogeneous(expr const & lhs, result & rhs) {
if (rhs.m_heq_proof) {
// try to convert back to homogeneous
lean_assert(rhs.m_proof);
expr lhs_type = infer_type(lhs);
expr rhs_type = infer_type(rhs.m_out);
if (is_eq_convertible(lhs_type, rhs_type)) {
// move back to homogeneous equality using to_eq
rhs.m_proof = mk_to_eq_th(lhs_type, lhs, rhs.m_out, *rhs.m_proof);
return true;
} else {
return false;
}
} else {
return true;
}
}
expr get_proof(result const & rhs) {
if (rhs.m_proof) {
return *rhs.m_proof;
} else {
// lhs and rhs are definitionally equal
return mk_refl_th(infer_type(rhs.m_out), rhs.m_out);
}
}
/**
\brief Simplify \c e using the given congruence theorem.
See congr.h for a description of congr_theorem_info.
*/
result simplify_app_congr(expr const & e, congr_theorem_info const & cg_thm) {
lean_assert(is_app(e));
lean_assert(arg(e, 0) == cg_thm.get_fun());
buffer<expr> new_args;
bool changed = false;
new_args.resize(num_args(e));
new_args[0] = arg(e, 0);
buffer<expr> proof_args_buf;
expr * proof_args;
if (m_proofs_enabled) {
proof_args_buf.resize(cg_thm.get_num_proof_args() + 1);
proof_args_buf[0] = cg_thm.get_proof();
proof_args = proof_args_buf.data()+1;
}
for (auto const & info : cg_thm.get_arg_info()) {
unsigned pos = info.get_arg_pos();
expr const & a = arg(e, pos);
if (info.should_simplify()) {
optional<congr_theorem_info::context> const & ctx = info.get_context();
if (!ctx) {
// argument does not have a context
result res_a = simplify(a);
new_args[pos] = res_a.m_out;
if (m_proofs_enabled) {
if (!ensure_homogeneous(a, res_a))
return simplify_app_default(e); // fallback to default congruence
proof_args[info.get_pos_at_proof()] = a;
proof_args[*info.get_new_pos_at_proof()] = new_args[pos];
proof_args[*info.get_proof_pos_at_proof()] = get_proof(res_a);
}
} else {
unsigned dep_pos = ctx->get_arg_pos();
expr H = ctx->use_new_val() ? new_args[dep_pos] : arg(e, dep_pos);
if (!ctx->is_pos_dep())
H = mk_not(H);
// We will simplify the \c a under the hypothesis H
if (!m_proofs_enabled) {
// Contextual reasoning without proofs.
expr dummy_proof; // we don't need a proof
updt_rule_set update(m_rule_sets[0], H, dummy_proof);
result res_a = simplify(a);
new_args[pos] = res_a.m_out;
} else {
// We have to introduce H in the context, so first we lift the free variables in \c a
flet<unsigned> set_depth(m_contextual_depth, m_contextual_depth+1);
expr H_proof = mk_constant(name(g_unique, m_contextual_depth));
updt_rule_set update(m_rule_sets[0], H, H_proof);
freset<cache> m_reset_cache(m_cache); // must reset cache for the recursive call because we updated the rule_sets
result res_a = simplify(a);
if (!ensure_homogeneous(a, res_a))
return simplify_app_default(e); // fallback to default congruence
new_args[pos] = res_a.m_out;
proof_args[info.get_pos_at_proof()] = a;
proof_args[*info.get_new_pos_at_proof()] = new_args[pos];
name C_name(g_C, m_contextual_depth); // H_name is a cryptic unique name
proof_args[*info.get_proof_pos_at_proof()] = ::lean::mk_lambda(C_name, H, abstract(get_proof(res_a), H_proof));
}
}
if (new_args[pos] != a)
changed = true;
} else {
// argument should not be simplified
new_args[pos] = arg(e, pos);
if (m_proofs_enabled)
proof_args[info.get_pos_at_proof()] = arg(e, pos);
}
}
if (!changed) {
return rewrite_app(e, result(e));
} else if (!m_proofs_enabled) {
return rewrite_app(e, result(mk_app(new_args)));
} else {
return rewrite_app(e, result(mk_app(new_args), mk_app(proof_args_buf)));
}
}
result simplify_app_default(expr const & e) {
lean_assert(is_app(e));
buffer<expr> new_args;
buffer<optional<expr>> proofs; // used only if m_proofs_enabled
@ -763,6 +886,20 @@ class simplifier_fn {
lean_unreachable();
}
void collect_congr_thms() {
if (m_contextual) {
for (auto const & rs : m_rule_sets) {
rs.for_each_congr([&](congr_theorem_info const & info) {
if (std::all_of(m_congr_thms.begin(), m_congr_thms.end(),
[&](congr_theorem_info const * info2) {
return info2->get_fun() != info.get_fun(); })) {
m_congr_thms.push_back(&info);
}
});
}
}
}
void set_options(options const & o) {
m_proofs_enabled = get_simplifier_proofs(o);
m_contextual = get_simplifier_contextual(o);
@ -786,17 +923,15 @@ public:
m_rule_sets.push_back(rewrite_rule_set(env));
}
m_rule_sets.insert(m_rule_sets.end(), rs, rs + num_rs);
collect_congr_thms();
m_contextual_depth = 0;
}
expr_pair operator()(expr const & e, context const & ctx) {
set_context set(*this, ctx);
m_num_steps = 0;
auto r = simplify(e);
if (r.m_proof) {
return mk_pair(r.m_out, *(r.m_proof));
} else {
return mk_pair(r.m_out, mk_refl_th(infer_type(r.m_out), r.m_out));
}
return mk_pair(r.m_out, get_proof(r));
}
};

15
tests/lean/simp13.lean Normal file
View file

@ -0,0 +1,15 @@
rewrite_set simple
add_rewrite and_truer and_truel and_falser and_falsel : simple
(*
add_congr_theorem("simple", "and_congr")
*)
variables a b c : Nat
(*
local t = parse_lean([[a = 1 ∧ b = 0 ∧ b > a]])
local s, pr = simplify(t, "simple")
print(s)
print(pr)
print(get_environment():type_check(pr))
*)

View file

@ -0,0 +1,13 @@
Set: pp::colors
Set: pp::unicode
Assumed: a
Assumed: b
Assumed: c
trans (and_congr
(refl (a = 1))
(λ C::1 : a = 1,
trans (and_congr (refl (b = 0)) (λ C::2 : b = 0, congr (congr2 Nat::gt C::2) C::1))
(and_falsel (b = 0))))
(and_falsel (a = 1))
(a = 1 ∧ b = 0 ∧ b > a) = ⊥