feat(library/unifier): add support for unification constraints of the form "(elim ... (?m ...)) =?= t", where elim is an eliminator
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
c854ad3d65
commit
110b622b83
2 changed files with 171 additions and 15 deletions
|
@ -16,6 +16,7 @@ Author: Leonardo de Moura
|
|||
#include "kernel/abstract.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "kernel/inductive/inductive.h"
|
||||
#include "library/occurs.h"
|
||||
#include "library/unifier.h"
|
||||
#include "library/kernel_bindings.h"
|
||||
|
@ -479,9 +480,105 @@ struct unifier_fn {
|
|||
add_cnstr(c, mlvl_occs, mvar_occs, g_first_very_delayed);
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c e is of the form (elim ... (?m ...)) */
|
||||
bool is_elim_meta_app(expr const & e) {
|
||||
if (!is_app(e))
|
||||
return false;
|
||||
expr const & f = get_app_fn(e);
|
||||
if (!is_constant(f))
|
||||
return false;
|
||||
auto it_name = inductive::is_elim_rule(m_env, const_name(f));
|
||||
if (!it_name)
|
||||
return false;
|
||||
if (!is_meta(app_arg(e)))
|
||||
return false;
|
||||
if (is_pi(m_tc.whnf(m_tc.infer(e))))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Given (elim args) =?= t, where elim is the eliminator/recursor for the inductive declaration \c decl,
|
||||
and the last argument of args is of the form (?m ...), we create a case split where we try to assign (?m ...)
|
||||
to the different constructors of decl.
|
||||
*/
|
||||
void mk_inductice_cnstrs(inductive::inductive_decl const & decl, expr const & elim, buffer<expr> & args, expr const & t,
|
||||
justification const & j) {
|
||||
lean_assert(is_constant(elim));
|
||||
levels elim_lvls = const_levels(elim);
|
||||
unsigned elim_num_lvls = length(elim_lvls);
|
||||
unsigned num_args = args.size();
|
||||
expr meta = args[num_args - 1]; // save last argument, we will update it
|
||||
lean_assert(is_meta(meta));
|
||||
buffer<expr> margs;
|
||||
expr const & m = get_app_args(meta, margs);
|
||||
expr const & mtype = mlocal_type(m);
|
||||
buffer<constraints> alts;
|
||||
for (auto const & intro : inductive::inductive_decl_intros(decl)) {
|
||||
name const & intro_name = inductive::intro_rule_name(intro);
|
||||
declaration intro_decl = m_env.get(intro_name);
|
||||
levels intro_lvls;
|
||||
if (length(intro_decl.get_univ_params()) == elim_num_lvls) {
|
||||
intro_lvls = elim_lvls;
|
||||
} else {
|
||||
lean_assert(length(intro_decl.get_univ_params()) == elim_num_lvls - 1);
|
||||
intro_lvls = tail(elim_lvls);
|
||||
}
|
||||
expr intro_fn = mk_constant(inductive::intro_rule_name(intro), intro_lvls);
|
||||
expr hint = intro_fn;
|
||||
expr intro_type = m_tc.whnf(inductive::intro_rule_type(intro));
|
||||
while (is_pi(intro_type)) {
|
||||
hint = mk_app(hint, mk_app(mk_aux_metavar_for(mtype), margs));
|
||||
intro_type = m_tc.whnf(binding_body(intro_type));
|
||||
}
|
||||
constraint c1 = mk_eq_cnstr(meta, hint, j);
|
||||
args[num_args - 1] = hint;
|
||||
expr reduce_elim = m_tc.whnf(mk_app(elim, args));
|
||||
constraint c2 = mk_eq_cnstr(reduce_elim, t, j);
|
||||
alts.push_back(constraints({c1, c2}));
|
||||
}
|
||||
if (alts.empty()) {
|
||||
set_conflict(j);
|
||||
} else if (alts.size() == 1) {
|
||||
process_constraints(alts[0], justification());
|
||||
} else {
|
||||
justification a = mk_assumption_justification(m_next_assumption_idx);
|
||||
add_case_split(std::unique_ptr<case_split>(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end()))));
|
||||
process_constraints(alts[0], a);
|
||||
}
|
||||
}
|
||||
|
||||
bool try_inductive_hint_core(expr const & t1, expr const & t2, justification const & j) {
|
||||
if (!is_elim_meta_app(t1))
|
||||
return false;
|
||||
buffer<expr> args;
|
||||
expr const & elim = get_app_args(t1, args);
|
||||
auto it_name = *inductive::is_elim_rule(m_env, const_name(elim));
|
||||
auto decls = *inductive::is_inductive_decl(m_env, it_name);
|
||||
for (auto const & d : std::get<2>(decls)) {
|
||||
if (inductive::inductive_decl_name(d) == it_name) {
|
||||
mk_inductice_cnstrs(d, elim, args, t2, j);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Try to solve constraint of the form (elim ... (?m ...)) =?= t, by assigning (?m ...) to the introduction rules
|
||||
associated with the eliminator \c elim.
|
||||
*/
|
||||
bool try_inductive_hint(expr const & t1, expr const & t2, justification const & j) {
|
||||
return
|
||||
try_inductive_hint_core(t1, t2, j) ||
|
||||
try_inductive_hint_core(t2, t1, j);
|
||||
}
|
||||
|
||||
bool is_def_eq(expr const & t1, expr const & t2, justification const & j) {
|
||||
if (m_tc.is_def_eq(t1, t2, j)) {
|
||||
return true;
|
||||
} else if (try_inductive_hint(t1, t2, j)) {
|
||||
return true;
|
||||
} else {
|
||||
set_conflict(j);
|
||||
return false;
|
||||
|
@ -595,6 +692,12 @@ struct unifier_fn {
|
|||
rhs = m_tc.whnf(rhs);
|
||||
lhs = m_tc.whnf(lhs);
|
||||
|
||||
// We delay constraints where lhs or rhs are of the form (elim ... (?m ...))
|
||||
if (is_elim_meta_app(lhs) || is_elim_meta_app(rhs)) {
|
||||
add_very_delayed_cnstr(c, &unassigned_lvls, &unassigned_exprs);
|
||||
return true;
|
||||
}
|
||||
|
||||
// If lhs or rhs were updated, then invoke is_def_eq again.
|
||||
if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) {
|
||||
// some metavariables were instantiated, try is_def_eq again
|
||||
|
@ -1101,10 +1204,14 @@ struct unifier_fn {
|
|||
}
|
||||
|
||||
void consume_tc_cnstrs() {
|
||||
while (auto c = m_tc.next_cnstr()) {
|
||||
while (true) {
|
||||
if (in_conflict())
|
||||
return;
|
||||
process_constraint(*c);
|
||||
if (auto c = m_tc.next_cnstr()) {
|
||||
process_constraint(*c);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1123,11 +1230,16 @@ struct unifier_fn {
|
|||
return process_plugin_constraint(c);
|
||||
}
|
||||
|
||||
/** \brief Return true if unifier may be able to produce more solutions */
|
||||
bool more_solutions() const {
|
||||
return !in_conflict() || !m_case_splits.empty();
|
||||
}
|
||||
|
||||
/** \brief Produce the next solution */
|
||||
optional<substitution> next() {
|
||||
if (in_conflict())
|
||||
if (!more_solutions())
|
||||
return failure();
|
||||
if (!m_case_splits.empty()) {
|
||||
if (!m_first && !m_case_splits.empty()) {
|
||||
justification all_assumptions;
|
||||
for (auto const & cs : m_case_splits)
|
||||
all_assumptions = mk_composite1(all_assumptions, mk_assumption_justification(cs->m_assumption_idx));
|
||||
|
@ -1162,7 +1274,7 @@ unifier_plugin get_noop_unifier_plugin() {
|
|||
}
|
||||
|
||||
lazy_list<substitution> unify(std::shared_ptr<unifier_fn> u) {
|
||||
if (u->in_conflict()) {
|
||||
if (!u->more_solutions()) {
|
||||
u->failure(); // make sure exception is thrown if u->m_use_exception is true
|
||||
return lazy_list<substitution>();
|
||||
} else {
|
||||
|
@ -1192,17 +1304,13 @@ lazy_list<substitution> unify(environment const & env, expr const & lhs, expr co
|
|||
type_checker tc(env, new_ngen.mk_child());
|
||||
expr _lhs = s.instantiate(lhs);
|
||||
expr _rhs = s.instantiate(rhs);
|
||||
if (!tc.is_def_eq(_lhs, _rhs))
|
||||
auto u = std::make_shared<unifier_fn>(env, 0, nullptr, ngen, s, p, false, max_steps);
|
||||
if (!u->is_def_eq(_lhs, _rhs, justification()) && !u->more_solutions())
|
||||
return lazy_list<substitution>();
|
||||
buffer<constraint> cs;
|
||||
while (auto c = tc.next_cnstr()) {
|
||||
cs.push_back(*c);
|
||||
}
|
||||
if (cs.empty()) {
|
||||
return lazy_list<substitution>(s);
|
||||
} else {
|
||||
return unify(std::make_shared<unifier_fn>(env, cs.size(), cs.data(), ngen, s, p, false, max_steps));
|
||||
}
|
||||
u->consume_tc_cnstrs();
|
||||
if (!u->more_solutions())
|
||||
return lazy_list<substitution>();
|
||||
return unify(u);
|
||||
}
|
||||
|
||||
lazy_list<substitution> unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen,
|
||||
|
|
48
tests/lean/run/uni.lean
Normal file
48
tests/lean/run/uni.lean
Normal file
|
@ -0,0 +1,48 @@
|
|||
import logic
|
||||
|
||||
inductive nat : Type :=
|
||||
| zero : nat
|
||||
| succ : nat → nat
|
||||
|
||||
check @nat_rec
|
||||
|
||||
(*
|
||||
local env = get_env()
|
||||
local nat_rec = Const("nat_rec", {1})
|
||||
local nat = Const("nat")
|
||||
local n = Local("n", nat)
|
||||
local C = Fun(n, Bool)
|
||||
local p = Local("p", Bool)
|
||||
local ff = Const("false")
|
||||
local tt = Const("true")
|
||||
local t = nat_rec(C, ff, Fun(n, p, tt))
|
||||
local zero = Const("zero")
|
||||
local succ = Const("succ")
|
||||
local one = succ(zero)
|
||||
local tc = type_checker(env)
|
||||
print(env:whnf(t(one)))
|
||||
print(env:whnf(t(zero)))
|
||||
local m = mk_metavar("m", nat)
|
||||
print(env:whnf(t(m)))
|
||||
|
||||
function test_unify(env, lhs, rhs, num_s)
|
||||
print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s))
|
||||
local ss = unify(env, lhs, rhs, name_generator(), substitution(), options())
|
||||
local n = 0
|
||||
for s in ss do
|
||||
print("solution: ")
|
||||
s:for_each_expr(function(n, v, j)
|
||||
print(" " .. tostring(n) .. " := " .. tostring(v))
|
||||
end)
|
||||
s:for_each_level(function(n, v, j)
|
||||
print(" " .. tostring(n) .. " := " .. tostring(v))
|
||||
end)
|
||||
n = n + 1
|
||||
end
|
||||
if num_s ~= n then print("n: " .. n) end
|
||||
assert(num_s == n)
|
||||
end
|
||||
|
||||
test_unify(env, t(m), tt, 1)
|
||||
test_unify(env, t(m), ff, 1)
|
||||
*)
|
Loading…
Reference in a new issue