feat(library/unifier): add support for unification constraints of the form "(elim ... (?m ...)) =?= t", where elim is an eliminator

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-07-03 20:41:42 -07:00
parent c854ad3d65
commit 110b622b83
2 changed files with 171 additions and 15 deletions

View file

@ -16,6 +16,7 @@ Author: Leonardo de Moura
#include "kernel/abstract.h" #include "kernel/abstract.h"
#include "kernel/instantiate.h" #include "kernel/instantiate.h"
#include "kernel/type_checker.h" #include "kernel/type_checker.h"
#include "kernel/inductive/inductive.h"
#include "library/occurs.h" #include "library/occurs.h"
#include "library/unifier.h" #include "library/unifier.h"
#include "library/kernel_bindings.h" #include "library/kernel_bindings.h"
@ -479,9 +480,105 @@ struct unifier_fn {
add_cnstr(c, mlvl_occs, mvar_occs, g_first_very_delayed); add_cnstr(c, mlvl_occs, mvar_occs, g_first_very_delayed);
} }
/** \brief Return true iff \c e is of the form (elim ... (?m ...)) */
bool is_elim_meta_app(expr const & e) {
if (!is_app(e))
return false;
expr const & f = get_app_fn(e);
if (!is_constant(f))
return false;
auto it_name = inductive::is_elim_rule(m_env, const_name(f));
if (!it_name)
return false;
if (!is_meta(app_arg(e)))
return false;
if (is_pi(m_tc.whnf(m_tc.infer(e))))
return false;
return true;
}
/**
\brief Given (elim args) =?= t, where elim is the eliminator/recursor for the inductive declaration \c decl,
and the last argument of args is of the form (?m ...), we create a case split where we try to assign (?m ...)
to the different constructors of decl.
*/
void mk_inductice_cnstrs(inductive::inductive_decl const & decl, expr const & elim, buffer<expr> & args, expr const & t,
justification const & j) {
lean_assert(is_constant(elim));
levels elim_lvls = const_levels(elim);
unsigned elim_num_lvls = length(elim_lvls);
unsigned num_args = args.size();
expr meta = args[num_args - 1]; // save last argument, we will update it
lean_assert(is_meta(meta));
buffer<expr> margs;
expr const & m = get_app_args(meta, margs);
expr const & mtype = mlocal_type(m);
buffer<constraints> alts;
for (auto const & intro : inductive::inductive_decl_intros(decl)) {
name const & intro_name = inductive::intro_rule_name(intro);
declaration intro_decl = m_env.get(intro_name);
levels intro_lvls;
if (length(intro_decl.get_univ_params()) == elim_num_lvls) {
intro_lvls = elim_lvls;
} else {
lean_assert(length(intro_decl.get_univ_params()) == elim_num_lvls - 1);
intro_lvls = tail(elim_lvls);
}
expr intro_fn = mk_constant(inductive::intro_rule_name(intro), intro_lvls);
expr hint = intro_fn;
expr intro_type = m_tc.whnf(inductive::intro_rule_type(intro));
while (is_pi(intro_type)) {
hint = mk_app(hint, mk_app(mk_aux_metavar_for(mtype), margs));
intro_type = m_tc.whnf(binding_body(intro_type));
}
constraint c1 = mk_eq_cnstr(meta, hint, j);
args[num_args - 1] = hint;
expr reduce_elim = m_tc.whnf(mk_app(elim, args));
constraint c2 = mk_eq_cnstr(reduce_elim, t, j);
alts.push_back(constraints({c1, c2}));
}
if (alts.empty()) {
set_conflict(j);
} else if (alts.size() == 1) {
process_constraints(alts[0], justification());
} else {
justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end()))));
process_constraints(alts[0], a);
}
}
bool try_inductive_hint_core(expr const & t1, expr const & t2, justification const & j) {
if (!is_elim_meta_app(t1))
return false;
buffer<expr> args;
expr const & elim = get_app_args(t1, args);
auto it_name = *inductive::is_elim_rule(m_env, const_name(elim));
auto decls = *inductive::is_inductive_decl(m_env, it_name);
for (auto const & d : std::get<2>(decls)) {
if (inductive::inductive_decl_name(d) == it_name) {
mk_inductice_cnstrs(d, elim, args, t2, j);
return true;
}
}
lean_unreachable(); // LCOV_EXCL_LINE
}
/**
\brief Try to solve constraint of the form (elim ... (?m ...)) =?= t, by assigning (?m ...) to the introduction rules
associated with the eliminator \c elim.
*/
bool try_inductive_hint(expr const & t1, expr const & t2, justification const & j) {
return
try_inductive_hint_core(t1, t2, j) ||
try_inductive_hint_core(t2, t1, j);
}
bool is_def_eq(expr const & t1, expr const & t2, justification const & j) { bool is_def_eq(expr const & t1, expr const & t2, justification const & j) {
if (m_tc.is_def_eq(t1, t2, j)) { if (m_tc.is_def_eq(t1, t2, j)) {
return true; return true;
} else if (try_inductive_hint(t1, t2, j)) {
return true;
} else { } else {
set_conflict(j); set_conflict(j);
return false; return false;
@ -595,6 +692,12 @@ struct unifier_fn {
rhs = m_tc.whnf(rhs); rhs = m_tc.whnf(rhs);
lhs = m_tc.whnf(lhs); lhs = m_tc.whnf(lhs);
// We delay constraints where lhs or rhs are of the form (elim ... (?m ...))
if (is_elim_meta_app(lhs) || is_elim_meta_app(rhs)) {
add_very_delayed_cnstr(c, &unassigned_lvls, &unassigned_exprs);
return true;
}
// If lhs or rhs were updated, then invoke is_def_eq again. // If lhs or rhs were updated, then invoke is_def_eq again.
if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) { if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) {
// some metavariables were instantiated, try is_def_eq again // some metavariables were instantiated, try is_def_eq again
@ -1101,10 +1204,14 @@ struct unifier_fn {
} }
void consume_tc_cnstrs() { void consume_tc_cnstrs() {
while (auto c = m_tc.next_cnstr()) { while (true) {
if (in_conflict()) if (in_conflict())
return; return;
process_constraint(*c); if (auto c = m_tc.next_cnstr()) {
process_constraint(*c);
} else {
break;
}
} }
} }
@ -1123,11 +1230,16 @@ struct unifier_fn {
return process_plugin_constraint(c); return process_plugin_constraint(c);
} }
/** \brief Return true if unifier may be able to produce more solutions */
bool more_solutions() const {
return !in_conflict() || !m_case_splits.empty();
}
/** \brief Produce the next solution */ /** \brief Produce the next solution */
optional<substitution> next() { optional<substitution> next() {
if (in_conflict()) if (!more_solutions())
return failure(); return failure();
if (!m_case_splits.empty()) { if (!m_first && !m_case_splits.empty()) {
justification all_assumptions; justification all_assumptions;
for (auto const & cs : m_case_splits) for (auto const & cs : m_case_splits)
all_assumptions = mk_composite1(all_assumptions, mk_assumption_justification(cs->m_assumption_idx)); all_assumptions = mk_composite1(all_assumptions, mk_assumption_justification(cs->m_assumption_idx));
@ -1162,7 +1274,7 @@ unifier_plugin get_noop_unifier_plugin() {
} }
lazy_list<substitution> unify(std::shared_ptr<unifier_fn> u) { lazy_list<substitution> unify(std::shared_ptr<unifier_fn> u) {
if (u->in_conflict()) { if (!u->more_solutions()) {
u->failure(); // make sure exception is thrown if u->m_use_exception is true u->failure(); // make sure exception is thrown if u->m_use_exception is true
return lazy_list<substitution>(); return lazy_list<substitution>();
} else { } else {
@ -1192,17 +1304,13 @@ lazy_list<substitution> unify(environment const & env, expr const & lhs, expr co
type_checker tc(env, new_ngen.mk_child()); type_checker tc(env, new_ngen.mk_child());
expr _lhs = s.instantiate(lhs); expr _lhs = s.instantiate(lhs);
expr _rhs = s.instantiate(rhs); expr _rhs = s.instantiate(rhs);
if (!tc.is_def_eq(_lhs, _rhs)) auto u = std::make_shared<unifier_fn>(env, 0, nullptr, ngen, s, p, false, max_steps);
if (!u->is_def_eq(_lhs, _rhs, justification()) && !u->more_solutions())
return lazy_list<substitution>(); return lazy_list<substitution>();
buffer<constraint> cs; u->consume_tc_cnstrs();
while (auto c = tc.next_cnstr()) { if (!u->more_solutions())
cs.push_back(*c); return lazy_list<substitution>();
} return unify(u);
if (cs.empty()) {
return lazy_list<substitution>(s);
} else {
return unify(std::make_shared<unifier_fn>(env, cs.size(), cs.data(), ngen, s, p, false, max_steps));
}
} }
lazy_list<substitution> unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, lazy_list<substitution> unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen,

48
tests/lean/run/uni.lean Normal file
View file

@ -0,0 +1,48 @@
import logic
inductive nat : Type :=
| zero : nat
| succ : nat → nat
check @nat_rec
(*
local env = get_env()
local nat_rec = Const("nat_rec", {1})
local nat = Const("nat")
local n = Local("n", nat)
local C = Fun(n, Bool)
local p = Local("p", Bool)
local ff = Const("false")
local tt = Const("true")
local t = nat_rec(C, ff, Fun(n, p, tt))
local zero = Const("zero")
local succ = Const("succ")
local one = succ(zero)
local tc = type_checker(env)
print(env:whnf(t(one)))
print(env:whnf(t(zero)))
local m = mk_metavar("m", nat)
print(env:whnf(t(m)))
function test_unify(env, lhs, rhs, num_s)
print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s))
local ss = unify(env, lhs, rhs, name_generator(), substitution(), options())
local n = 0
for s in ss do
print("solution: ")
s:for_each_expr(function(n, v, j)
print(" " .. tostring(n) .. " := " .. tostring(v))
end)
s:for_each_level(function(n, v, j)
print(" " .. tostring(n) .. " := " .. tostring(v))
end)
n = n + 1
end
if num_s ~= n then print("n: " .. n) end
assert(num_s == n)
end
test_unify(env, t(m), tt, 1)
test_unify(env, t(m), ff, 1)
*)