feat(kernel): add method 'may_reduce_later' to normalizer_extension, and improve unifier
Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
parent
ce282a549a
commit
7fb2b0f6d8
8 changed files with 196 additions and 105 deletions
|
@ -54,6 +54,11 @@ struct default_converter : public converter {
|
|||
return m_env.norm_ext()(e, get_extension(c));
|
||||
}
|
||||
|
||||
/** \brief Return true if \c e may be reduced later after metavariables are instantiated. */
|
||||
bool may_reduce_later(expr const & e, type_checker & c) {
|
||||
return m_env.norm_ext().may_reduce_later(e, get_extension(c));
|
||||
}
|
||||
|
||||
/** \brief Try to apply eta-reduction to \c e. */
|
||||
expr try_eta(expr const & e) {
|
||||
lean_assert(is_lambda(e));
|
||||
|
@ -462,6 +467,11 @@ struct default_converter : public converter {
|
|||
}
|
||||
}
|
||||
|
||||
if (may_reduce_later(t_n, c) || may_reduce_later(s_n, c)) {
|
||||
add_cnstr(c, mk_eq_cnstr(t_n, s_n, jst.get()));
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,9 +17,8 @@ namespace lean {
|
|||
*/
|
||||
class noop_normalizer_extension : public normalizer_extension {
|
||||
public:
|
||||
virtual optional<expr> operator()(expr const &, extension_context &) const {
|
||||
return none_expr();
|
||||
}
|
||||
virtual optional<expr> operator()(expr const &, extension_context &) const { return none_expr(); }
|
||||
virtual bool may_reduce_later(expr const &, extension_context &) const { return false; }
|
||||
};
|
||||
|
||||
environment_header::environment_header(unsigned trust_lvl, bool prop_proof_irrel, bool eta, bool impredicative,
|
||||
|
|
|
@ -38,6 +38,8 @@ class normalizer_extension {
|
|||
public:
|
||||
virtual ~normalizer_extension() {}
|
||||
virtual optional<expr> operator()(expr const & e, extension_context & ctx) const = 0;
|
||||
/** \brief Return true if the extension may reduce \c e after metavariables are instantiated. */
|
||||
virtual bool may_reduce_later(expr const & e, extension_context & ctx) const = 0;
|
||||
};
|
||||
|
||||
/**
|
||||
|
|
|
@ -803,6 +803,23 @@ optional<expr> inductive_normalizer_extension::operator()(expr const & e, extens
|
|||
return some_expr(r);
|
||||
}
|
||||
|
||||
// Return true if \c e is of the form (elim ... (?m ...))
|
||||
bool inductive_normalizer_extension::may_reduce_later(expr const & e, extension_context & ctx) const {
|
||||
inductive_env_ext const & ext = get_extension(ctx.env());
|
||||
expr const & elim_fn = get_app_fn(e);
|
||||
if (!is_constant(elim_fn))
|
||||
return false;
|
||||
auto it1 = ext.m_elim_info.find(const_name(elim_fn));
|
||||
if (!it1)
|
||||
return false;
|
||||
buffer<expr> elim_args;
|
||||
get_app_args(e, elim_args);
|
||||
if (elim_args.size() != it1->m_num_ACe + it1->m_num_indices + 1)
|
||||
return false;
|
||||
expr intro_app = ctx.whnf(elim_args.back());
|
||||
return is_meta(intro_app);
|
||||
}
|
||||
|
||||
optional<inductive_decls> is_inductive_decl(environment const & env, name const & n) {
|
||||
inductive_env_ext const & ext = get_extension(env);
|
||||
if (auto it = ext.m_inductive_info.find(n))
|
||||
|
|
|
@ -17,6 +17,7 @@ namespace inductive {
|
|||
class inductive_normalizer_extension : public normalizer_extension {
|
||||
public:
|
||||
virtual optional<expr> operator()(expr const & e, extension_context & ctx) const;
|
||||
virtual bool may_reduce_later(expr const & e, extension_context & ctx) const;
|
||||
};
|
||||
|
||||
/** \brief Introduction rule */
|
||||
|
|
|
@ -480,105 +480,9 @@ struct unifier_fn {
|
|||
add_cnstr(c, mlvl_occs, mvar_occs, g_first_very_delayed);
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c e is of the form (elim ... (?m ...)) */
|
||||
bool is_elim_meta_app(expr const & e) {
|
||||
if (!is_app(e))
|
||||
return false;
|
||||
expr const & f = get_app_fn(e);
|
||||
if (!is_constant(f))
|
||||
return false;
|
||||
auto it_name = inductive::is_elim_rule(m_env, const_name(f));
|
||||
if (!it_name)
|
||||
return false;
|
||||
if (!is_meta(app_arg(e)))
|
||||
return false;
|
||||
if (is_pi(m_tc.whnf(m_tc.infer(e))))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Given (elim args) =?= t, where elim is the eliminator/recursor for the inductive declaration \c decl,
|
||||
and the last argument of args is of the form (?m ...), we create a case split where we try to assign (?m ...)
|
||||
to the different constructors of decl.
|
||||
*/
|
||||
void mk_inductice_cnstrs(inductive::inductive_decl const & decl, expr const & elim, buffer<expr> & args, expr const & t,
|
||||
justification const & j) {
|
||||
lean_assert(is_constant(elim));
|
||||
levels elim_lvls = const_levels(elim);
|
||||
unsigned elim_num_lvls = length(elim_lvls);
|
||||
unsigned num_args = args.size();
|
||||
expr meta = args[num_args - 1]; // save last argument, we will update it
|
||||
lean_assert(is_meta(meta));
|
||||
buffer<expr> margs;
|
||||
expr const & m = get_app_args(meta, margs);
|
||||
expr const & mtype = mlocal_type(m);
|
||||
buffer<constraints> alts;
|
||||
for (auto const & intro : inductive::inductive_decl_intros(decl)) {
|
||||
name const & intro_name = inductive::intro_rule_name(intro);
|
||||
declaration intro_decl = m_env.get(intro_name);
|
||||
levels intro_lvls;
|
||||
if (length(intro_decl.get_univ_params()) == elim_num_lvls) {
|
||||
intro_lvls = elim_lvls;
|
||||
} else {
|
||||
lean_assert(length(intro_decl.get_univ_params()) == elim_num_lvls - 1);
|
||||
intro_lvls = tail(elim_lvls);
|
||||
}
|
||||
expr intro_fn = mk_constant(inductive::intro_rule_name(intro), intro_lvls);
|
||||
expr hint = intro_fn;
|
||||
expr intro_type = m_tc.whnf(inductive::intro_rule_type(intro));
|
||||
while (is_pi(intro_type)) {
|
||||
hint = mk_app(hint, mk_app(mk_aux_metavar_for(mtype), margs));
|
||||
intro_type = m_tc.whnf(binding_body(intro_type));
|
||||
}
|
||||
constraint c1 = mk_eq_cnstr(meta, hint, j);
|
||||
args[num_args - 1] = hint;
|
||||
expr reduce_elim = m_tc.whnf(mk_app(elim, args));
|
||||
constraint c2 = mk_eq_cnstr(reduce_elim, t, j);
|
||||
alts.push_back(constraints({c1, c2}));
|
||||
}
|
||||
if (alts.empty()) {
|
||||
set_conflict(j);
|
||||
} else if (alts.size() == 1) {
|
||||
process_constraints(alts[0], justification());
|
||||
} else {
|
||||
justification a = mk_assumption_justification(m_next_assumption_idx);
|
||||
add_case_split(std::unique_ptr<case_split>(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end()))));
|
||||
process_constraints(alts[0], a);
|
||||
}
|
||||
}
|
||||
|
||||
bool try_inductive_hint_core(expr const & t1, expr const & t2, justification const & j) {
|
||||
if (!is_elim_meta_app(t1))
|
||||
return false;
|
||||
buffer<expr> args;
|
||||
expr const & elim = get_app_args(t1, args);
|
||||
auto it_name = *inductive::is_elim_rule(m_env, const_name(elim));
|
||||
auto decls = *inductive::is_inductive_decl(m_env, it_name);
|
||||
for (auto const & d : std::get<2>(decls)) {
|
||||
if (inductive::inductive_decl_name(d) == it_name) {
|
||||
mk_inductice_cnstrs(d, elim, args, t2, j);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Try to solve constraint of the form (elim ... (?m ...)) =?= t, by assigning (?m ...) to the introduction rules
|
||||
associated with the eliminator \c elim.
|
||||
*/
|
||||
bool try_inductive_hint(expr const & t1, expr const & t2, justification const & j) {
|
||||
return
|
||||
try_inductive_hint_core(t1, t2, j) ||
|
||||
try_inductive_hint_core(t2, t1, j);
|
||||
}
|
||||
|
||||
bool is_def_eq(expr const & t1, expr const & t2, justification const & j) {
|
||||
if (m_tc.is_def_eq(t1, t2, j)) {
|
||||
return true;
|
||||
} else if (try_inductive_hint(t1, t2, j)) {
|
||||
return true;
|
||||
} else {
|
||||
set_conflict(j);
|
||||
return false;
|
||||
|
@ -919,6 +823,107 @@ struct unifier_fn {
|
|||
}
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c e is of the form (elim ... (?m ...)) */
|
||||
bool is_elim_meta_app(expr const & e) {
|
||||
if (!is_app(e))
|
||||
return false;
|
||||
expr const & f = get_app_fn(e);
|
||||
if (!is_constant(f))
|
||||
return false;
|
||||
auto it_name = inductive::is_elim_rule(m_env, const_name(f));
|
||||
if (!it_name)
|
||||
return false;
|
||||
if (!is_meta(app_arg(e)))
|
||||
return false;
|
||||
if (is_pi(m_tc.whnf(m_tc.infer(e))))
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/** \brief Return true iff the lhs or rhs of the constraint c is of the form (elim ... (?m ...)) */
|
||||
bool is_elim_meta_cnstr(constraint const & c) {
|
||||
return is_eq_cnstr(c) && (is_elim_meta_app(cnstr_lhs_expr(c)) || is_elim_meta_app(cnstr_rhs_expr(c)));
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Given (elim args) =?= t, where elim is the eliminator/recursor for the inductive declaration \c decl,
|
||||
and the last argument of args is of the form (?m ...), we create a case split where we try to assign (?m ...)
|
||||
to the different constructors of decl.
|
||||
*/
|
||||
bool add_elim_meta_cnstrs(inductive::inductive_decl const & decl, expr const & elim, buffer<expr> & args, expr const & t,
|
||||
justification const & j) {
|
||||
lean_assert(is_constant(elim));
|
||||
levels elim_lvls = const_levels(elim);
|
||||
unsigned elim_num_lvls = length(elim_lvls);
|
||||
unsigned num_args = args.size();
|
||||
expr meta = args[num_args - 1]; // save last argument, we will update it
|
||||
lean_assert(is_meta(meta));
|
||||
buffer<expr> margs;
|
||||
expr const & m = get_app_args(meta, margs);
|
||||
expr const & mtype = mlocal_type(m);
|
||||
buffer<constraints> alts;
|
||||
for (auto const & intro : inductive::inductive_decl_intros(decl)) {
|
||||
name const & intro_name = inductive::intro_rule_name(intro);
|
||||
declaration intro_decl = m_env.get(intro_name);
|
||||
levels intro_lvls;
|
||||
if (length(intro_decl.get_univ_params()) == elim_num_lvls) {
|
||||
intro_lvls = elim_lvls;
|
||||
} else {
|
||||
lean_assert(length(intro_decl.get_univ_params()) == elim_num_lvls - 1);
|
||||
intro_lvls = tail(elim_lvls);
|
||||
}
|
||||
expr intro_fn = mk_constant(inductive::intro_rule_name(intro), intro_lvls);
|
||||
expr hint = intro_fn;
|
||||
expr intro_type = m_tc.whnf(inductive::intro_rule_type(intro));
|
||||
while (is_pi(intro_type)) {
|
||||
hint = mk_app(hint, mk_app(mk_aux_metavar_for(mtype), margs));
|
||||
intro_type = m_tc.whnf(binding_body(intro_type));
|
||||
}
|
||||
constraint c1 = mk_eq_cnstr(meta, hint, j);
|
||||
args[num_args - 1] = hint;
|
||||
expr reduce_elim = m_tc.whnf(mk_app(elim, args));
|
||||
constraint c2 = mk_eq_cnstr(reduce_elim, t, j);
|
||||
alts.push_back(constraints({c1, c2}));
|
||||
}
|
||||
if (alts.empty()) {
|
||||
set_conflict(j);
|
||||
return false;
|
||||
} else if (alts.size() == 1) {
|
||||
return process_constraints(alts[0], justification());
|
||||
} else {
|
||||
justification a = mk_assumption_justification(m_next_assumption_idx);
|
||||
add_case_split(std::unique_ptr<case_split>(new ho_case_split(*this, to_list(alts.begin() + 1, alts.end()))));
|
||||
return process_constraints(alts[0], a);
|
||||
}
|
||||
}
|
||||
|
||||
bool process_elim_meta_core(expr const & lhs, expr const & rhs, justification const & j) {
|
||||
lean_assert(is_elim_meta_app(lhs));
|
||||
buffer<expr> args;
|
||||
expr const & elim = get_app_args(lhs, args);
|
||||
auto it_name = *inductive::is_elim_rule(m_env, const_name(elim));
|
||||
auto decls = *inductive::is_inductive_decl(m_env, it_name);
|
||||
for (auto const & d : std::get<2>(decls)) {
|
||||
if (inductive::inductive_decl_name(d) == it_name)
|
||||
return add_elim_meta_cnstrs(d, elim, args, rhs, j);
|
||||
}
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
}
|
||||
|
||||
/**
|
||||
\brief Try to solve constraint of the form (elim ... (?m ...)) =?= t, by assigning (?m ...) to the introduction rules
|
||||
associated with the eliminator \c elim.
|
||||
*/
|
||||
bool process_elim_meta_cnstr(constraint const & c) {
|
||||
expr const & lhs = cnstr_lhs_expr(c);
|
||||
expr const & rhs = cnstr_rhs_expr(c);
|
||||
justification const & j = c.get_justification();
|
||||
if (is_elim_meta_app(lhs))
|
||||
return process_elim_meta_core(lhs, rhs, j);
|
||||
else
|
||||
return process_elim_meta_core(rhs, lhs, j);
|
||||
}
|
||||
|
||||
bool next_plugin_case_split(plugin_case_split & cs) {
|
||||
auto r = cs.m_tail.pull();
|
||||
if (r) {
|
||||
|
@ -1226,6 +1231,8 @@ struct unifier_fn {
|
|||
return process_flex_rigid(c);
|
||||
else if (is_flex_flex(c))
|
||||
return process_flex_flex(c);
|
||||
else if (is_elim_meta_cnstr(c))
|
||||
return process_elim_meta_cnstr(c);
|
||||
else
|
||||
return process_plugin_constraint(c);
|
||||
}
|
||||
|
@ -1304,13 +1311,17 @@ lazy_list<substitution> unify(environment const & env, expr const & lhs, expr co
|
|||
type_checker tc(env, new_ngen.mk_child());
|
||||
expr _lhs = s.instantiate(lhs);
|
||||
expr _rhs = s.instantiate(rhs);
|
||||
auto u = std::make_shared<unifier_fn>(env, 0, nullptr, ngen, s, p, false, max_steps);
|
||||
if (!u->is_def_eq(_lhs, _rhs, justification()) && !u->more_solutions())
|
||||
if (!tc.is_def_eq(_lhs, _rhs))
|
||||
return lazy_list<substitution>();
|
||||
u->consume_tc_cnstrs();
|
||||
if (!u->more_solutions())
|
||||
return lazy_list<substitution>();
|
||||
return unify(u);
|
||||
buffer<constraint> cs;
|
||||
while (auto c = tc.next_cnstr()) {
|
||||
cs.push_back(*c);
|
||||
}
|
||||
if (cs.empty()) {
|
||||
return lazy_list<substitution>(s);
|
||||
} else {
|
||||
return unify(std::make_shared<unifier_fn>(env, cs.size(), cs.data(), ngen, s, p, false, max_steps));
|
||||
}
|
||||
}
|
||||
|
||||
lazy_list<substitution> unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen,
|
||||
|
|
|
@ -124,6 +124,7 @@ public:
|
|||
// In a real implementation, we must check if proj1 and mk were defined in the environment.
|
||||
return some_expr(app_arg(app_fn(a_n)));
|
||||
}
|
||||
virtual bool may_reduce_later(expr const &, extension_context &) const { return false; }
|
||||
};
|
||||
|
||||
static void tst3() {
|
||||
|
|
50
tests/lean/run/uni2.lean
Normal file
50
tests/lean/run/uni2.lean
Normal file
|
@ -0,0 +1,50 @@
|
|||
import logic
|
||||
|
||||
inductive nat : Type :=
|
||||
| zero : nat
|
||||
| succ : nat → nat
|
||||
|
||||
variable f : nat → nat
|
||||
|
||||
check @nat_rec
|
||||
|
||||
(*
|
||||
local env = get_env()
|
||||
local nat_rec = Const("nat_rec", {1})
|
||||
local nat = Const("nat")
|
||||
local f = Const("f")
|
||||
local n = Local("n", nat)
|
||||
local C = Fun(n, Bool)
|
||||
local p = Local("p", Bool)
|
||||
local ff = Const("false")
|
||||
local tt = Const("true")
|
||||
local t = nat_rec(C, ff, Fun(n, p, tt))
|
||||
local zero = Const("zero")
|
||||
local succ = Const("succ")
|
||||
local one = succ(zero)
|
||||
local tc = type_checker(env)
|
||||
print(env:whnf(t(one)))
|
||||
print(env:whnf(t(zero)))
|
||||
local m = mk_metavar("m", nat)
|
||||
print(env:whnf(t(m)))
|
||||
|
||||
function test_unify(env, lhs, rhs, num_s)
|
||||
print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(num_s))
|
||||
local ss = unify(env, lhs, rhs, name_generator(), substitution(), options())
|
||||
local n = 0
|
||||
for s in ss do
|
||||
print("solution: ")
|
||||
s:for_each_expr(function(n, v, j)
|
||||
print(" " .. tostring(n) .. " := " .. tostring(v))
|
||||
end)
|
||||
s:for_each_level(function(n, v, j)
|
||||
print(" " .. tostring(n) .. " := " .. tostring(v))
|
||||
end)
|
||||
n = n + 1
|
||||
end
|
||||
if num_s ~= n then print("n: " .. n) end
|
||||
assert(num_s == n)
|
||||
end
|
||||
|
||||
test_unify(env, f(t(m)), f(tt), 1)
|
||||
*)
|
Loading…
Add table
Reference in a new issue