feat(kernel): add method 'may_reduce_later' to normalizer_extension, and improve unifier

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-07-03 22:31:05 -07:00
parent ce282a549a
commit 7fb2b0f6d8
8 changed files with 196 additions and 105 deletions

View file

@ -54,6 +54,11 @@ struct default_converter : public converter {
return m_env.norm_ext()(e, get_extension(c));
}
/** \brief Return true if \c e may be reduced later after metavariables are instantiated. */
bool may_reduce_later(expr const & e, type_checker & c) {
return m_env.norm_ext().may_reduce_later(e, get_extension(c));
}
/** \brief Try to apply eta-reduction to \c e. */
expr try_eta(expr const & e) {
lean_assert(is_lambda(e));
@ -462,6 +467,11 @@ struct default_converter : public converter {
}
}
if (may_reduce_later(t_n, c) || may_reduce_later(s_n, c)) {
add_cnstr(c, mk_eq_cnstr(t_n, s_n, jst.get()));
return true;
}
return false;
}

View file

@ -17,9 +17,8 @@ namespace lean {
*/
class noop_normalizer_extension : public normalizer_extension {
public:
virtual optional<expr> operator()(expr const &, extension_context &) const {
return none_expr();
}
virtual optional<expr> operator()(expr const &, extension_context &) const { return none_expr(); }
virtual bool may_reduce_later(expr const &, extension_context &) const { return false; }
};
environment_header::environment_header(unsigned trust_lvl, bool prop_proof_irrel, bool eta, bool impredicative,

View file

@ -38,6 +38,8 @@ class normalizer_extension {
public:
virtual ~normalizer_extension() {}
virtual optional<expr> operator()(expr const & e, extension_context & ctx) const = 0;
/** \brief Return true if the extension may reduce \c e after metavariables are instantiated. */
virtual bool may_reduce_later(expr const & e, extension_context & ctx) const = 0;
};
/**

View file

@ -803,6 +803,23 @@ optional<expr> inductive_normalizer_extension::operator()(expr const & e, extens
return some_expr(r);
}
// Return true if \c e is of the form (elim ... (?m ...))
bool inductive_normalizer_extension::may_reduce_later(expr const & e, extension_context & ctx) const {
inductive_env_ext const & ext = get_extension(ctx.env());
expr const & elim_fn = get_app_fn(e);
if (!is_constant(elim_fn))
return false;
auto it1 = ext.m_elim_info.find(const_name(elim_fn));
if (!it1)
return false;
buffer<expr> elim_args;
get_app_args(e, elim_args);
if (elim_args.size() != it1->m_num_ACe + it1->m_num_indices + 1)
return false;
expr intro_app = ctx.whnf(elim_args.back());
return is_meta(intro_app);
}
optional<inductive_decls> is_inductive_decl(environment const & env, name const & n) {
inductive_env_ext const & ext = get_extension(env);
if (auto it = ext.m_inductive_info.find(n))

View file

@ -17,6 +17,7 @@ namespace inductive {
class inductive_normalizer_extension : public normalizer_extension {
public:
virtual optional<expr> operator()(expr const & e, extension_context & ctx) const;
virtual bool may_reduce_later(expr const & e, extension_context & ctx) const;
};
/** \brief Introduction rule */

View file

@ -480,105 +480,9 @@ struct unifier_fn {
add_cnstr(c, mlvl_occs, mvar_occs, g_first_very_delayed);
}
/** \brief Return true iff \c e is of the form (elim ... (?m ...)) */
bool is_elim_meta_app(expr const & e) {
if (!is_app(e))
return false;
expr const & f = get_app_fn(e);
if (!is_constant(f))
return false;
auto it_name = inductive::is_elim_rule(m_env, const_name(f));
if (!it_name)
return false;
if (!is_meta(app_arg(e)))
return false;
if (is_pi(m_tc.whnf(m_tc.infer(e))))
return false;
return true;
}
/**
\brief Given (elim args) =?= t, where elim is the eliminator/recursor for the inductive declaration \c decl,
and the last argument of args is of the form (?m ...), we create a case split where we try to assign (?m ...)
to the different constructors of decl.
*/
void mk_inductice_cnstrs(inductive::inductive_decl const & decl, expr const & elim, buffer<expr> & args, expr const & t,
justification const & j) {
lean_assert(is_constant(elim));
levels elim_lvls = const_levels(elim);
unsigned elim_num_lvls = length(elim_lvls);
unsigned num_args = args.size();
expr meta = args[num_args - 1]; // save last argument, we will update it
lean_assert(is_meta(meta));
buffer<expr> margs;
expr const & m = get_app_args(meta, margs);
expr const & mtype = mlocal_type(m);
buffer<constraints> alts;
for (auto const & intro : inductive::inductive_decl_intros(decl)) {
name const & intro_name = inductive::intro_rule_name(intro);
declaration intro_decl = m_env.get(intro_name);
levels intro_lvls;
if (length(intro_decl.get_univ_params()) == elim_num_lvls) {
intro_lvls = elim_lvls;
} else {
lean_assert(length(intro_decl.get_univ_params()) == elim_num_lvls - 1);
intro_lvls = tail(elim_lvls);
}
expr intro_fn = mk_constant(inductive::intro_rule_name(intro), intro_lvls);
expr hint = intro_fn;
expr intro_type = m_tc.whnf(inductive::intro_rule_type(intro));
while (is_pi(intro_type)) {
hint = mk_app(hint, mk_app(mk_aux_metavar_for(mtype), margs));
intro_type = m_tc.whnf(binding_body(intro_type));
}
constraint c1 = mk_eq_cnstr(meta, hint, j);
args[num_args - 1] = hint;
expr reduce_elim = m_tc.whnf(mk_app(elim, args));
constraint c2 = mk_eq_cnstr(reduce_elim, t, j);
alts.push_back(constraints({c1, c2}));
}
if (alts.empty()) {
set_conflict(j);
} else if (alts.size() == 1) {
process_constraints(alts[0], justification());
} else {
justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end()))));
process_constraints(alts[0], a);
}
}
bool try_inductive_hint_core(expr const & t1, expr const & t2, justification const & j) {
if (!is_elim_meta_app(t1))
return false;
buffer<expr> args;
expr const & elim = get_app_args(t1, args);
auto it_name = *inductive::is_elim_rule(m_env, const_name(elim));
auto decls = *inductive::is_inductive_decl(m_env, it_name);
for (auto const & d : std::get<2>(decls)) {
if (inductive::inductive_decl_name(d) == it_name) {
mk_inductice_cnstrs(d, elim, args, t2, j);
return true;
}
}
lean_unreachable(); // LCOV_EXCL_LINE
}
/**
\brief Try to solve constraint of the form (elim ... (?m ...)) =?= t, by assigning (?m ...) to the introduction rules
associated with the eliminator \c elim.
*/
bool try_inductive_hint(expr const & t1, expr const & t2, justification const & j) {
return
try_inductive_hint_core(t1, t2, j) ||
try_inductive_hint_core(t2, t1, j);
}
bool is_def_eq(expr const & t1, expr const & t2, justification const & j) {
if (m_tc.is_def_eq(t1, t2, j)) {
return true;
} else if (try_inductive_hint(t1, t2, j)) {
return true;
} else {
set_conflict(j);
return false;
@ -919,6 +823,107 @@ struct unifier_fn {
}
}
/** \brief Return true iff \c e is of the form (elim ... (?m ...)) */
bool is_elim_meta_app(expr const & e) {
if (!is_app(e))
return false;
expr const & f = get_app_fn(e);
if (!is_constant(f))
return false;
auto it_name = inductive::is_elim_rule(m_env, const_name(f));
if (!it_name)
return false;
if (!is_meta(app_arg(e)))
return false;
if (is_pi(m_tc.whnf(m_tc.infer(e))))
return false;
return true;
}
/** \brief Return true iff the lhs or rhs of the constraint c is of the form (elim ... (?m ...)) */
bool is_elim_meta_cnstr(constraint const & c) {
return is_eq_cnstr(c) && (is_elim_meta_app(cnstr_lhs_expr(c)) || is_elim_meta_app(cnstr_rhs_expr(c)));
}
/**
\brief Given (elim args) =?= t, where elim is the eliminator/recursor for the inductive declaration \c decl,
and the last argument of args is of the form (?m ...), we create a case split where we try to assign (?m ...)
to the different constructors of decl.
*/
bool add_elim_meta_cnstrs(inductive::inductive_decl const & decl, expr const & elim, buffer<expr> & args, expr const & t,
justification const & j) {
lean_assert(is_constant(elim));
levels elim_lvls = const_levels(elim);
unsigned elim_num_lvls = length(elim_lvls);
unsigned num_args = args.size();
expr meta = args[num_args - 1]; // save last argument, we will update it
lean_assert(is_meta(meta));
buffer<expr> margs;
expr const & m = get_app_args(meta, margs);
expr const & mtype = mlocal_type(m);
buffer<constraints> alts;
for (auto const & intro : inductive::inductive_decl_intros(decl)) {
name const & intro_name = inductive::intro_rule_name(intro);
declaration intro_decl = m_env.get(intro_name);
levels intro_lvls;
if (length(intro_decl.get_univ_params()) == elim_num_lvls) {
intro_lvls = elim_lvls;
} else {
lean_assert(length(intro_decl.get_univ_params()) == elim_num_lvls - 1);
intro_lvls = tail(elim_lvls);
}
expr intro_fn = mk_constant(inductive::intro_rule_name(intro), intro_lvls);
expr hint = intro_fn;
expr intro_type = m_tc.whnf(inductive::intro_rule_type(intro));
while (is_pi(intro_type)) {
hint = mk_app(hint, mk_app(mk_aux_metavar_for(mtype), margs));
intro_type = m_tc.whnf(binding_body(intro_type));
}
constraint c1 = mk_eq_cnstr(meta, hint, j);
args[num_args - 1] = hint;
expr reduce_elim = m_tc.whnf(mk_app(elim, args));
constraint c2 = mk_eq_cnstr(reduce_elim, t, j);
alts.push_back(constraints({c1, c2}));
}
if (alts.empty()) {
set_conflict(j);
return false;
} else if (alts.size() == 1) {
return process_constraints(alts[0], justification());
} else {
justification a = mk_assumption_justification(m_next_assumption_idx);
add_case_split(std::unique_ptr<case_split>(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end()))));
return process_constraints(alts[0], a);
}
}
bool process_elim_meta_core(expr const & lhs, expr const & rhs, justification const & j) {
lean_assert(is_elim_meta_app(lhs));
buffer<expr> args;
expr const & elim = get_app_args(lhs, args);
auto it_name = *inductive::is_elim_rule(m_env, const_name(elim));
auto decls = *inductive::is_inductive_decl(m_env, it_name);
for (auto const & d : std::get<2>(decls)) {
if (inductive::inductive_decl_name(d) == it_name)
return add_elim_meta_cnstrs(d, elim, args, rhs, j);
}
lean_unreachable(); // LCOV_EXCL_LINE
}
/**
\brief Try to solve constraint of the form (elim ... (?m ...)) =?= t, by assigning (?m ...) to the introduction rules
associated with the eliminator \c elim.
*/
bool process_elim_meta_cnstr(constraint const & c) {
expr const & lhs = cnstr_lhs_expr(c);
expr const & rhs = cnstr_rhs_expr(c);
justification const & j = c.get_justification();
if (is_elim_meta_app(lhs))
return process_elim_meta_core(lhs, rhs, j);
else
return process_elim_meta_core(rhs, lhs, j);
}
bool next_plugin_case_split(plugin_case_split & cs) {
auto r = cs.m_tail.pull();
if (r) {
@ -1226,6 +1231,8 @@ struct unifier_fn {
return process_flex_rigid(c);
else if (is_flex_flex(c))
return process_flex_flex(c);
else if (is_elim_meta_cnstr(c))
return process_elim_meta_cnstr(c);
else
return process_plugin_constraint(c);
}
@ -1304,13 +1311,17 @@ lazy_list<substitution> unify(environment const & env, expr const & lhs, expr co
type_checker tc(env, new_ngen.mk_child());
expr _lhs = s.instantiate(lhs);
expr _rhs = s.instantiate(rhs);
auto u = std::make_shared<unifier_fn>(env, 0, nullptr, ngen, s, p, false, max_steps);
if (!u->is_def_eq(_lhs, _rhs, justification()) && !u->more_solutions())
if (!tc.is_def_eq(_lhs, _rhs))
return lazy_list<substitution>();
u->consume_tc_cnstrs();
if (!u->more_solutions())
return lazy_list<substitution>();
return unify(u);
buffer<constraint> cs;
while (auto c = tc.next_cnstr()) {
cs.push_back(*c);
}
if (cs.empty()) {
return lazy_list<substitution>(s);
} else {
return unify(std::make_shared<unifier_fn>(env, cs.size(), cs.data(), ngen, s, p, false, max_steps));
}
}
lazy_list<substitution> unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen,

View file

@ -124,6 +124,7 @@ public:
// In a real implementation, we must check if proj1 and mk were defined in the environment.
return some_expr(app_arg(app_fn(a_n)));
}
virtual bool may_reduce_later(expr const &, extension_context &) const { return false; }
};
static void tst3() {

50
tests/lean/run/uni2.lean Normal file
View file

@ -0,0 +1,50 @@
import logic
inductive nat : Type :=
| zero : nat
| succ : nat → nat
variable f : nat → nat
check @nat_rec
(*
local env = get_env()
local nat_rec = Const("nat_rec", {1})
local nat = Const("nat")
local f = Const("f")
local n = Local("n", nat)
local C = Fun(n, Bool)
local p = Local("p", Bool)
local ff = Const("false")
local tt = Const("true")
local t = nat_rec(C, ff, Fun(n, p, tt))
local zero = Const("zero")
local succ = Const("succ")
local one = succ(zero)
local tc = type_checker(env)
print(env:whnf(t(one)))
print(env:whnf(t(zero)))
local m = mk_metavar("m", nat)
print(env:whnf(t(m)))
function test_unify(env, lhs, rhs, num_s)
print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s))
local ss = unify(env, lhs, rhs, name_generator(), substitution(), options())
local n = 0
for s in ss do
print("solution: ")
s:for_each_expr(function(n, v, j)
print(" " .. tostring(n) .. " := " .. tostring(v))
end)
s:for_each_level(function(n, v, j)
print(" " .. tostring(n) .. " := " .. tostring(v))
end)
n = n + 1
end
if num_s ~= n then print("n: " .. n) end
assert(num_s == n)
end
test_unify(env, f(t(m)), f(tt), 1)
*)