feat(kernel/inductive): generate computational rules RHS for inductive datatypes

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-05-19 09:08:19 -07:00
parent eb409a9ce3
commit 2aacb769dd
2 changed files with 108 additions and 24 deletions

View file

@ -47,8 +47,7 @@ struct add_inductive_fn {
expr m_major_premise; // major premise for each inductive decl
buffer<expr> m_minor_premises; // minor premise for each introduction rule
};
buffer<elim_info> m_elim_info; // for each datatype being declared
buffer<elim_info> m_elim_info; // for each datatype being declared
add_inductive_fn(environment env,
level_param_names const & level_params,
@ -63,13 +62,13 @@ struct add_inductive_fn {
/** \brief Return the number of inductive datatypes being defined. */
unsigned get_num_its() const { return m_decls_sz; }
/** \brief Make sure the latest environment is being used by m_tc */
/** \brief Make sure the latest environment is being used by m_tc. */
void updt_type_checker() {
type_checker tc(m_env);
m_tc.swap(tc);
}
/** \brief Return a fresh name */
/** \brief Return a fresh name. */
name mk_fresh_name() { return m_ngen.next(); }
/** \brief Create a local constant for the given binding. */
@ -176,7 +175,7 @@ struct add_inductive_fn {
std::any_of(m_it_consts.begin(), m_it_consts.end(), [&](expr const & c) { return const_name(e) == const_name(c); });
}
/** \brief Return true if \c t does not contain any occurrence of a datatype being declared */
/** \brief Return true if \c t does not contain any occurrence of a datatype being declared. */
bool has_it_occ(expr const & t) {
return (bool)find(t, [&](expr const & e, unsigned) { return is_it_occ(e); }); // NOLINT
}
@ -275,7 +274,7 @@ struct add_inductive_fn {
}
}
/** \brief Add all introduction rules (aka constructors) to environment */
/** \brief Add all introduction rules (aka constructors) to environment. */
void declare_intro_rules() {
for (auto d : m_decls) {
for (auto ir : inductive_decl_intros(d))
@ -285,7 +284,7 @@ struct add_inductive_fn {
}
/** \brief Initialize m_elim_level */
/** \brief Initialize m_elim_level. */
void mk_elim_level() {
if (m_env.impredicative() && is_zero(m_it_levels[0]) && (get_num_its() > 1 || length(inductive_decl_intros(head(m_decls))) != 1)) {
// environment is impredicative, datatype maps to Bool/Prop, we have more than one introduction rule.
@ -301,7 +300,7 @@ struct add_inductive_fn {
}
}
/** \brief Initialize m_dep_elim flag */
/** \brief Initialize m_dep_elim flag. */
void set_dep_elim() {
if (m_env.impredicative() && is_zero(m_it_levels[0]))
m_dep_elim = false;
@ -324,7 +323,7 @@ struct add_inductive_fn {
return *r;
}
/** \brief Populate m_elim_info */
/** \brief Populate m_elim_info. */
void mk_elim_info() {
unsigned d_idx = 0;
// First, populate the fields, m_C, m_indices, m_major_premise
@ -368,7 +367,7 @@ struct add_inductive_fn {
if (i < m_num_params) {
t = instantiate(binding_body(t), m_param_consts[i]);
} else {
expr l = mk_local(mk_fresh_name(), binding_name(t), binding_domain(t));
expr l = mk_local_for(t);
if (!is_rec_argument(binding_domain(t)))
b.push_back(l);
else
@ -390,7 +389,7 @@ struct add_inductive_fn {
expr u_i_ty = m_tc.whnf(mlocal_type(u_i));
buffer<expr> xs;
while (is_pi(u_i_ty)) {
expr x = mk_local(mk_fresh_name(), binding_name(u_i_ty), binding_domain(u_i_ty));
expr x = mk_local_for(u_i_ty);
xs.push_back(x);
u_i_ty = m_tc.whnf(instantiate(binding_body(u_i_ty), x));
}
@ -414,7 +413,28 @@ struct add_inductive_fn {
}
}
/** \brief Declare elimination rule */
/** \brief Return the name of the eliminator/recursor for \c d. */
name get_elim_name(inductive_decl const & d) { return inductive_decl_name(d).append_after("_rec"); }
name get_elim_name(unsigned d_idx) { return get_elim_name(get_ith(m_decls, d_idx)); }
/** \brief Return the level parameter names for the eliminator. */
level_param_names get_elim_level_param_names() {
if (is_param(m_elim_level))
return level_param_names(param_id(m_elim_level), m_level_names);
else
return m_level_names;
}
/** \brief Return the levels for the eliminator application. */
levels get_elim_level_params() {
if (is_param(m_elim_level))
return levels(m_elim_level, m_levels);
else
return m_levels;
}
/** \brief Declare elimination rule. */
void declare_elim_rule(inductive_decl const & d, unsigned d_idx) {
elim_info const & info = m_elim_info[d_idx];
expr C_app = mk_app(info.m_C, info.m_indices);
@ -438,16 +458,10 @@ struct add_inductive_fn {
elim_ty = Pi(m_elim_info[i].m_C, elim_ty);
}
elim_ty = Pi(m_param_consts, elim_ty);
level_param_names ls;
if (is_param(m_elim_level))
ls = level_param_names(param_id(m_elim_level), m_level_names);
else
ls = m_level_names;
name elim_name = inductive_decl_name(d).append_after("_rec");
m_env = m_env.add(check(m_env, mk_var_decl(elim_name, ls, elim_ty)));
m_env = m_env.add(check(m_env, mk_var_decl(get_elim_name(d), get_elim_level_param_names(), elim_ty)));
}
/** \brief Declare the eliminator/recursor for each datatype */
/** \brief Declare the eliminator/recursor for each datatype. */
void declare_elim_rules() {
set_dep_elim();
mk_elim_level();
@ -457,6 +471,73 @@ struct add_inductive_fn {
declare_elim_rule(d, i);
i++;
}
updt_type_checker();
}
/** \brief Store all set formers in \c Cs */
void collect_Cs(buffer<expr> & Cs) {
for (unsigned i = 0; i < get_num_its(); i++)
Cs.push_back(m_elim_info[i].m_C);
}
/** \brief Store all minor premises in \c es. */
void collect_minor_premises(buffer<expr> & es) {
for (unsigned i = 0; i < get_num_its(); i++)
es.append(m_elim_info[i].m_minor_premises);
}
/** \brief Create computional rules RHS. They are used by the normalizer extension. */
void mk_comp_rules_rhs() {
unsigned d_idx = 0;
unsigned minor_idx = 0;
buffer<expr> C; collect_Cs(C);
buffer<expr> e; collect_minor_premises(e);
levels ls = get_elim_level_params();
for (auto d : m_decls) {
for (auto ir : inductive_decl_intros(d)) {
buffer<expr> b;
buffer<expr> u;
expr t = intro_rule_type(ir);
unsigned i = 0;
while (is_pi(t)) {
if (i < m_num_params) {
t = instantiate(binding_body(t), m_param_consts[i]);
} else {
expr l = mk_local_for(t);
if (!is_rec_argument(binding_domain(t)))
b.push_back(l);
else
u.push_back(l);
t = instantiate(binding_body(t), l);
}
i++;
}
buffer<expr> v;
if (m_dep_elim) {
for (unsigned i = 0; i < u.size(); i++) {
expr u_i = u[i];
expr u_i_ty = m_tc.whnf(mlocal_type(u_i));
buffer<expr> xs;
while (is_pi(u_i_ty)) {
expr x = mk_local_for(u_i_ty);
xs.push_back(x);
u_i_ty = m_tc.whnf(instantiate(binding_body(u_i_ty), x));
}
buffer<expr> it_indices;
unsigned it_idx = get_I_indices(u_i_ty, it_indices);
expr elim_app = mk_constant(get_elim_name(it_idx), ls);
elim_app = mk_app(mk_app(mk_app(mk_app(mk_app(elim_app, m_param_consts), C), e), it_indices), mk_app(u_i, xs));
v.push_back(Fun(xs, elim_app));
}
}
expr e_app = mk_app(mk_app(mk_app(e[minor_idx], b), u), v);
expr comp_rhs = Fun(m_param_consts, Fun(C, Fun(e, Fun(b, Fun(u, e_app)))));
m_tc.check(comp_rhs, get_elim_level_param_names());
// TODO(Leo): store computational rule RHS
minor_idx++;
}
d_idx++;
}
}
environment operator()() {
@ -467,6 +548,7 @@ struct add_inductive_fn {
check_intro_rules();
declare_intro_rules();
declare_elim_rules();
mk_comp_rules_rhs();
return m_env;
}
};

View file

@ -33,8 +33,7 @@ env = add_inductive(env,
"nil", Pi({{A, U_l, true}}, list_l(A)),
"cons", Pi({{A, U_l, true}}, mk_arrow(A, list_l(A), list_l(A))))
env = add_inductive(env,
"vec", {l}, 1,
mk_arrow(U_l, Nat, U_l1),
"vec", {l}, 1, Pi({{A, U_l}, {n, Nat}}, U_l1),
"vnil", Pi({{A, U_l, true}}, vec_l(A, zero)),
"vcons", Pi({{A, U_l, true}, {n, Nat, true}}, mk_arrow(A, vec_l(A, n), vec_l(A, succ(n)))))
@ -87,9 +86,9 @@ env = add_inductive(env, {},
local flist_l = Const("flist", {l})
env = add_inductive(env,
"flist", {l}, 1, mk_arrow(U_l, U_l1),
"flist", {l}, 1, Pi(A, U_l, U_l1),
"fnil", Pi({{A, U_l, true}}, flist_l(A)),
"fcons", Pi({{A, U_l, true}}, mk_arrow(A, mk_arrow(Nat, flist_l(A)), flist_l(A))))
"fcons", Pi({{A, U_l, true}}, mk_arrow(mk_arrow(Nat, A), mk_arrow(Nat, Bool, flist_l(A)), flist_l(A))))
local eq_l = Const("eq", {l})
env = add_inductive(env,
@ -100,3 +99,6 @@ display_type(env, Const("exists_rec", {v, u}))
display_type(env, Const("list_rec", {v, u}))
display_type(env, Const("Even_rec"))
display_type(env, Const("Odd_rec"))
display_type(env, Const("and_rec", {v}))
display_type(env, Const("vec_rec", {v, u}))
display_type(env, Const("flist_rec", {v, u}))