diff --git a/src/kernel/inductive/inductive.cpp b/src/kernel/inductive/inductive.cpp index 8be087814..5bee4298a 100644 --- a/src/kernel/inductive/inductive.cpp +++ b/src/kernel/inductive/inductive.cpp @@ -47,8 +47,7 @@ struct add_inductive_fn { expr m_major_premise; // major premise for each inductive decl buffer m_minor_premises; // minor premise for each introduction rule }; - - buffer m_elim_info; // for each datatype being declared + buffer m_elim_info; // for each datatype being declared add_inductive_fn(environment env, level_param_names const & level_params, @@ -63,13 +62,13 @@ struct add_inductive_fn { /** \brief Return the number of inductive datatypes being defined. */ unsigned get_num_its() const { return m_decls_sz; } - /** \brief Make sure the latest environment is being used by m_tc */ + /** \brief Make sure the latest environment is being used by m_tc. */ void updt_type_checker() { type_checker tc(m_env); m_tc.swap(tc); } - /** \brief Return a fresh name */ + /** \brief Return a fresh name. */ name mk_fresh_name() { return m_ngen.next(); } /** \brief Create a local constant for the given binding. */ @@ -176,7 +175,7 @@ struct add_inductive_fn { std::any_of(m_it_consts.begin(), m_it_consts.end(), [&](expr const & c) { return const_name(e) == const_name(c); }); } - /** \brief Return true if \c t does not contain any occurrence of a datatype being declared */ + /** \brief Return true if \c t does not contain any occurrence of a datatype being declared. */ bool has_it_occ(expr const & t) { return (bool)find(t, [&](expr const & e, unsigned) { return is_it_occ(e); }); // NOLINT } @@ -275,7 +274,7 @@ struct add_inductive_fn { } } - /** \brief Add all introduction rules (aka constructors) to environment */ + /** \brief Add all introduction rules (aka constructors) to environment. */ void declare_intro_rules() { for (auto d : m_decls) { for (auto ir : inductive_decl_intros(d)) @@ -285,7 +284,7 @@ struct add_inductive_fn { } - /** \brief Initialize m_elim_level */ + /** \brief Initialize m_elim_level. */ void mk_elim_level() { if (m_env.impredicative() && is_zero(m_it_levels[0]) && (get_num_its() > 1 || length(inductive_decl_intros(head(m_decls))) != 1)) { // environment is impredicative, datatype maps to Bool/Prop, we have more than one introduction rule. @@ -301,7 +300,7 @@ struct add_inductive_fn { } } - /** \brief Initialize m_dep_elim flag */ + /** \brief Initialize m_dep_elim flag. */ void set_dep_elim() { if (m_env.impredicative() && is_zero(m_it_levels[0])) m_dep_elim = false; @@ -324,7 +323,7 @@ struct add_inductive_fn { return *r; } - /** \brief Populate m_elim_info */ + /** \brief Populate m_elim_info. */ void mk_elim_info() { unsigned d_idx = 0; // First, populate the fields, m_C, m_indices, m_major_premise @@ -368,7 +367,7 @@ struct add_inductive_fn { if (i < m_num_params) { t = instantiate(binding_body(t), m_param_consts[i]); } else { - expr l = mk_local(mk_fresh_name(), binding_name(t), binding_domain(t)); + expr l = mk_local_for(t); if (!is_rec_argument(binding_domain(t))) b.push_back(l); else @@ -390,7 +389,7 @@ struct add_inductive_fn { expr u_i_ty = m_tc.whnf(mlocal_type(u_i)); buffer xs; while (is_pi(u_i_ty)) { - expr x = mk_local(mk_fresh_name(), binding_name(u_i_ty), binding_domain(u_i_ty)); + expr x = mk_local_for(u_i_ty); xs.push_back(x); u_i_ty = m_tc.whnf(instantiate(binding_body(u_i_ty), x)); } @@ -414,7 +413,28 @@ struct add_inductive_fn { } } - /** \brief Declare elimination rule */ + /** \brief Return the name of the eliminator/recursor for \c d. */ + name get_elim_name(inductive_decl const & d) { return inductive_decl_name(d).append_after("_rec"); } + + name get_elim_name(unsigned d_idx) { return get_elim_name(get_ith(m_decls, d_idx)); } + + /** \brief Return the level parameter names for the eliminator. */ + level_param_names get_elim_level_param_names() { + if (is_param(m_elim_level)) + return level_param_names(param_id(m_elim_level), m_level_names); + else + return m_level_names; + } + + /** \brief Return the levels for the eliminator application. */ + levels get_elim_level_params() { + if (is_param(m_elim_level)) + return levels(m_elim_level, m_levels); + else + return m_levels; + } + + /** \brief Declare elimination rule. */ void declare_elim_rule(inductive_decl const & d, unsigned d_idx) { elim_info const & info = m_elim_info[d_idx]; expr C_app = mk_app(info.m_C, info.m_indices); @@ -438,16 +458,10 @@ struct add_inductive_fn { elim_ty = Pi(m_elim_info[i].m_C, elim_ty); } elim_ty = Pi(m_param_consts, elim_ty); - level_param_names ls; - if (is_param(m_elim_level)) - ls = level_param_names(param_id(m_elim_level), m_level_names); - else - ls = m_level_names; - name elim_name = inductive_decl_name(d).append_after("_rec"); - m_env = m_env.add(check(m_env, mk_var_decl(elim_name, ls, elim_ty))); + m_env = m_env.add(check(m_env, mk_var_decl(get_elim_name(d), get_elim_level_param_names(), elim_ty))); } - /** \brief Declare the eliminator/recursor for each datatype */ + /** \brief Declare the eliminator/recursor for each datatype. */ void declare_elim_rules() { set_dep_elim(); mk_elim_level(); @@ -457,6 +471,73 @@ struct add_inductive_fn { declare_elim_rule(d, i); i++; } + updt_type_checker(); + } + + /** \brief Store all set formers in \c Cs */ + void collect_Cs(buffer & Cs) { + for (unsigned i = 0; i < get_num_its(); i++) + Cs.push_back(m_elim_info[i].m_C); + } + + /** \brief Store all minor premises in \c es. */ + void collect_minor_premises(buffer & es) { + for (unsigned i = 0; i < get_num_its(); i++) + es.append(m_elim_info[i].m_minor_premises); + } + + /** \brief Create computional rules RHS. They are used by the normalizer extension. */ + void mk_comp_rules_rhs() { + unsigned d_idx = 0; + unsigned minor_idx = 0; + buffer C; collect_Cs(C); + buffer e; collect_minor_premises(e); + levels ls = get_elim_level_params(); + for (auto d : m_decls) { + for (auto ir : inductive_decl_intros(d)) { + buffer b; + buffer u; + expr t = intro_rule_type(ir); + unsigned i = 0; + while (is_pi(t)) { + if (i < m_num_params) { + t = instantiate(binding_body(t), m_param_consts[i]); + } else { + expr l = mk_local_for(t); + if (!is_rec_argument(binding_domain(t))) + b.push_back(l); + else + u.push_back(l); + t = instantiate(binding_body(t), l); + } + i++; + } + buffer v; + if (m_dep_elim) { + for (unsigned i = 0; i < u.size(); i++) { + expr u_i = u[i]; + expr u_i_ty = m_tc.whnf(mlocal_type(u_i)); + buffer xs; + while (is_pi(u_i_ty)) { + expr x = mk_local_for(u_i_ty); + xs.push_back(x); + u_i_ty = m_tc.whnf(instantiate(binding_body(u_i_ty), x)); + } + buffer it_indices; + unsigned it_idx = get_I_indices(u_i_ty, it_indices); + expr elim_app = mk_constant(get_elim_name(it_idx), ls); + elim_app = mk_app(mk_app(mk_app(mk_app(mk_app(elim_app, m_param_consts), C), e), it_indices), mk_app(u_i, xs)); + v.push_back(Fun(xs, elim_app)); + } + } + expr e_app = mk_app(mk_app(mk_app(e[minor_idx], b), u), v); + expr comp_rhs = Fun(m_param_consts, Fun(C, Fun(e, Fun(b, Fun(u, e_app))))); + m_tc.check(comp_rhs, get_elim_level_param_names()); + // TODO(Leo): store computational rule RHS + minor_idx++; + } + d_idx++; + } } environment operator()() { @@ -467,6 +548,7 @@ struct add_inductive_fn { check_intro_rules(); declare_intro_rules(); declare_elim_rules(); + mk_comp_rules_rhs(); return m_env; } }; diff --git a/tests/lua/ind1.lua b/tests/lua/ind1.lua index 742a439f0..53dea0666 100644 --- a/tests/lua/ind1.lua +++ b/tests/lua/ind1.lua @@ -33,8 +33,7 @@ env = add_inductive(env, "nil", Pi({{A, U_l, true}}, list_l(A)), "cons", Pi({{A, U_l, true}}, mk_arrow(A, list_l(A), list_l(A)))) env = add_inductive(env, - "vec", {l}, 1, - mk_arrow(U_l, Nat, U_l1), + "vec", {l}, 1, Pi({{A, U_l}, {n, Nat}}, U_l1), "vnil", Pi({{A, U_l, true}}, vec_l(A, zero)), "vcons", Pi({{A, U_l, true}, {n, Nat, true}}, mk_arrow(A, vec_l(A, n), vec_l(A, succ(n))))) @@ -87,9 +86,9 @@ env = add_inductive(env, {}, local flist_l = Const("flist", {l}) env = add_inductive(env, - "flist", {l}, 1, mk_arrow(U_l, U_l1), + "flist", {l}, 1, Pi(A, U_l, U_l1), "fnil", Pi({{A, U_l, true}}, flist_l(A)), - "fcons", Pi({{A, U_l, true}}, mk_arrow(A, mk_arrow(Nat, flist_l(A)), flist_l(A)))) + "fcons", Pi({{A, U_l, true}}, mk_arrow(mk_arrow(Nat, A), mk_arrow(Nat, Bool, flist_l(A)), flist_l(A)))) local eq_l = Const("eq", {l}) env = add_inductive(env, @@ -100,3 +99,6 @@ display_type(env, Const("exists_rec", {v, u})) display_type(env, Const("list_rec", {v, u})) display_type(env, Const("Even_rec")) display_type(env, Const("Odd_rec")) +display_type(env, Const("and_rec", {v})) +display_type(env, Const("vec_rec", {v, u})) +display_type(env, Const("flist_rec", {v, u}))