feat(library/congr_lemma_manager): add heterogeneous equality congruence lemmas

This commit is contained in:
Leonardo de Moura 2016-01-09 15:41:08 -08:00
parent 403966792d
commit 42cdda227a
6 changed files with 133 additions and 2 deletions

View file

@ -707,6 +707,27 @@ static environment congr_rel_cmd(parser & p) {
return congr_cmd_core(p, congr_kind::Rel);
}
static environment hcongr_cmd(parser & p) {
environment const & env = p.env();
auto pos = p.pos();
expr e; level_param_names ls;
std::tie(e, ls) = parse_local_expr(p);
tmp_type_context ctx(env, p.get_options());
app_builder b(ctx);
fun_info_manager infom(ctx);
congr_lemma_manager cm(b, infom);
optional<hcongr_lemma> r = cm.mk_hcongr(e);
if (!r)
throw parser_error("failed to generated heterogeneous congruence lemma", pos);
auto out = p.regular_stream();
out << r->get_proof() << "\n:\n" << r->get_type() << "\n";;
type_checker tc(env);
expr type = tc.check(r->get_proof(), ls).first;
if (!tc.is_def_eq(type, r->get_type()).first)
throw parser_error("heterogeneous congruence lemma reported type does not match given type", pos);
return env;
}
static environment simplify_cmd(parser & p) {
name rel = p.check_constant_next("invalid #simplify command, constant expected");
name ns = p.check_id_next("invalid #simplify command, id expected");
@ -815,6 +836,7 @@ void init_cmd_table(cmd_table & r) {
add_cmd(r, cmd_info("#symm", "(for debugging purposes)", symm_cmd));
add_cmd(r, cmd_info("#compile", "(for debugging purposes)", compile_cmd));
add_cmd(r, cmd_info("#congr", "(for debugging purposes)", congr_cmd));
add_cmd(r, cmd_info("#hcongr", "(for debugging purposes)", hcongr_cmd));
add_cmd(r, cmd_info("#congr_simp", "(for debugging purposes)", congr_simp_cmd));
add_cmd(r, cmd_info("#congr_rel", "(for debugging purposes)", congr_rel_cmd));
add_cmd(r, cmd_info("#normalizer", "(for debugging purposes)", normalizer_cmd));

View file

@ -120,7 +120,7 @@ void init_token_table(token_table & t) {
"multiple_instances", "find_decl", "attribute", "persistent",
"include", "omit", "migrate", "init_quotient", "init_hits", "#erase_cache", "#projections", "#telescope_eq",
"#compile", "#accessible", "#decl_stats", "#relevant_thms", "#simplify", "#app_builder", "#refl", "#symm",
"#trans", "#congr", "#congr_simp", "#congr_rel", "#normalizer", "#abstract_expr", nullptr};
"#trans", "#congr", "#hcongr", "#congr_simp", "#congr_rel", "#normalizer", "#abstract_expr", nullptr};
pair<char const *, char const *> aliases[] =
{{g_lambda_unicode, "fun"}, {"forall", "Pi"}, {g_forall_unicode, "Pi"}, {g_pi_unicode, "Pi"},

View file

@ -791,6 +791,10 @@ expr app_builder::mk_iff_refl(expr const & a) {
return m_ptr->mk_iff_refl(a);
}
expr app_builder::mk_heq_refl(expr const & a) {
return m_ptr->mk_heq_refl(a);
}
expr app_builder::mk_symm(name const & relname, expr const & H) {
return m_ptr->mk_symm(relname, H);
}

View file

@ -100,6 +100,7 @@ public:
expr mk_refl(name const & relname, expr const & a);
expr mk_eq_refl(expr const & a);
expr mk_iff_refl(expr const & a);
expr mk_heq_refl(expr const & a);
/** \brief Similar a symmetry proof for the given relation */
expr mk_symm(name const & relname, expr const & H);

View file

@ -437,6 +437,32 @@ struct congr_lemma_manager::imp {
return result(r.get_type(), r.get_proof(), new_arg_kinds);
}
expr mk_hcongr_proof(expr type) {
expr A, B, a, b;
if (is_eq(type, a, b)) {
return m_builder.mk_eq_refl(a);
} else if (is_heq(type, A, a, B, b)) {
return m_builder.mk_heq_refl(a);
} else {
lean_assert(is_pi(type) && is_pi(binding_body(type)) && is_pi(binding_body(binding_body(type))));
expr a = m_ctx.mk_tmp_local(binding_name(type), binding_domain(type));
type = instantiate(binding_body(type), a);
expr b = m_ctx.mk_tmp_local(binding_name(type), binding_domain(type));
expr motive = instantiate(binding_body(type), b);
type = instantiate(binding_body(type), a);
expr eq_pr = m_ctx.mk_tmp_local(binding_name(motive), binding_domain(motive));
type = binding_body(type);
motive = binding_body(motive);
lean_assert(closed(type) && closed(motive));
expr minor = mk_hcongr_proof(type);
expr major = eq_pr;
if (is_heq(mlocal_type(eq_pr)))
major = m_builder.mk_eq_of_heq(eq_pr);
motive = Fun(b, motive);
return Fun({a, b, eq_pr}, m_builder.mk_eq_rec(motive, minor, major));
}
}
public:
imp(app_builder & b, fun_info_manager & fm):
m_builder(b), m_fmanager(fm), m_ctx(fm.ctx()),
@ -497,6 +523,60 @@ public:
return optional<result>(new_r);
}
optional<hresult> mk_hcongr(expr const & fn, unsigned nargs) {
try {
expr fn_type_lhs = relaxed_whnf(infer(fn));
expr fn_type_rhs = fn_type_lhs;
name e_name("e");
buffer<expr> lhss;
buffer<expr> rhss;
buffer<expr> eqs;
buffer<expr> hyps; // contains lhss + rhss + eqs
buffer<hcongr_arg_kind> kinds;
for (unsigned i = 0; i < nargs; i++) {
if (!is_pi(fn_type_lhs)) {
trace_too_many_arguments(fn, nargs);
return optional<hresult>();
}
expr lhs = m_ctx.mk_tmp_local(binding_name(fn_type_lhs), binding_domain(fn_type_lhs));
lhss.push_back(lhs); hyps.push_back(lhs);
expr rhs = m_ctx.mk_tmp_local(binding_name(fn_type_rhs).append_after("'"), binding_domain(fn_type_rhs));
rhss.push_back(rhs); hyps.push_back(rhs);
expr eq_type;
if (binding_domain(fn_type_lhs) == binding_domain(fn_type_rhs)) {
eq_type = m_builder.mk_eq(lhs, rhs);
kinds.push_back(hcongr_arg_kind::Eq);
} else {
eq_type = m_builder.mk_heq(lhs, rhs);
kinds.push_back(hcongr_arg_kind::HEq);
}
expr h_eq = m_ctx.mk_tmp_local(e_name.append_after(i), eq_type);
eqs.push_back(h_eq); hyps.push_back(h_eq);
fn_type_lhs = relaxed_whnf(instantiate(binding_body(fn_type_lhs), lhs));
fn_type_rhs = relaxed_whnf(instantiate(binding_body(fn_type_rhs), rhs));
}
expr lhs = mk_app(fn, lhss);
expr rhs = mk_app(fn, rhss);
expr eq_type;
if (fn_type_lhs == fn_type_rhs) {
eq_type = m_builder.mk_eq(lhs, rhs);
} else {
eq_type = m_builder.mk_heq(lhs, rhs);
}
expr result_type = Pi(hyps, eq_type);
expr result_proof = mk_hcongr_proof(result_type);
return optional<hresult>(result_type, result_proof, to_list(kinds));
} catch (app_builder_exception &) {
trace_app_builder_failure(fn);
return optional<hresult>();
}
}
optional<hresult> mk_hcongr(expr const & fn) {
fun_info finfo = m_fmanager.get(fn);
return mk_hcongr(fn, finfo.get_arity());
}
/** \brief Given an equivalence relation \c R, create the congruence lemma
forall a1 a2 b1 b2, R a1 a2 -> R b1 b2 -> (R a1 b1 <-> R a2 b2)
@ -619,6 +699,12 @@ auto congr_lemma_manager::mk_congr(expr const & fn, unsigned nargs) -> optional<
auto congr_lemma_manager::mk_specialized_congr(expr const & fn) -> optional<result> {
return m_ptr->mk_specialized_congr(fn);
}
auto congr_lemma_manager::mk_hcongr(expr const & fn) -> optional<hresult> {
return m_ptr->mk_hcongr(fn);
}
auto congr_lemma_manager::mk_hcongr(expr const & fn, unsigned nargs) -> optional<hresult> {
return m_ptr->mk_hcongr(fn, nargs);
}
auto congr_lemma_manager::mk_rel_iff_congr(expr const & R) -> optional<result> {
return m_ptr->mk_rel_iff_congr(R);
}

View file

@ -37,13 +37,28 @@ public:
bool all_eq_kind() const;
};
enum class hcongr_arg_kind { Eq, HEq };
class hcongr_lemma {
expr m_type;
expr m_proof;
list<hcongr_arg_kind> m_arg_kinds;
public:
hcongr_lemma(expr const & type, expr const & proof, list<hcongr_arg_kind> const & ks):
m_type(type), m_proof(proof), m_arg_kinds(ks) {}
expr const & get_type() const { return m_type; }
expr const & get_proof() const { return m_proof; }
list<hcongr_arg_kind> const & get_arg_kinds() const { return m_arg_kinds; }
};
class congr_lemma_manager {
struct imp;
std::unique_ptr<imp> m_ptr;
public:
congr_lemma_manager(app_builder & b, fun_info_manager & fm);
~congr_lemma_manager();
typedef congr_lemma result;
typedef congr_lemma result;
typedef hcongr_lemma hresult;
type_context & ctx();
unsigned get_specialization_prefix_size(expr const & fn, unsigned nargs);
@ -58,6 +73,9 @@ public:
/* Create a specialized theorem using (a prefix of) the arguments of the given application. */
optional<result> mk_specialized_congr(expr const & a);
optional<hresult> mk_hcongr(expr const & fn);
optional<hresult> mk_hcongr(expr const & fn, unsigned nargs);
/** \brief If R is an equivalence relation, construct the congruence lemma
R a1 a2 -> R b1 b2 -> (R a1 b1) <-> (R a2 b2) */