feat(library/blast): add infer_type for blast tactic
This commit is contained in:
parent
ad51339a28
commit
aadac02bec
7 changed files with 165 additions and 11 deletions
|
@ -247,6 +247,9 @@ public:
|
|||
projection_info const * get_projection_info(name const & n) const {
|
||||
return m_projection_info.find(n);
|
||||
}
|
||||
|
||||
name mk_fresh_local_name() {
|
||||
}
|
||||
};
|
||||
|
||||
LEAN_THREAD_PTR(context, g_context);
|
||||
|
@ -334,6 +337,11 @@ extension_context & ext_ctx() {
|
|||
lean_assert(g_ext_context);
|
||||
return *g_ext_context;
|
||||
}
|
||||
|
||||
name mk_fresh_local_name() {
|
||||
lean_assert(g_ext_context);
|
||||
return g_ext_context->mk_fresh_name();
|
||||
}
|
||||
}
|
||||
optional<expr> blast_goal(environment const & env, io_state const & ios, list<name> const & ls, list<name> const & ds,
|
||||
goal const & g) {
|
||||
|
|
|
@ -24,6 +24,8 @@ io_state const & ios();
|
|||
state & curr_state();
|
||||
/** \brief Return the thead local extension context associated with the blast tactic. */
|
||||
extension_context & ext_ctx();
|
||||
/** \brief Return a thread local fresh name meant to be used to name local constants. */
|
||||
name mk_fresh_local_name();
|
||||
/** \brief Return true iff the given constant name is marked as reducible in env() */
|
||||
bool is_reducible(name const & n);
|
||||
/** \brief Return a nonnull projection_info object if \c n is the name of a projection in env() */
|
||||
|
|
|
@ -160,9 +160,9 @@ bool has_mref(expr const & e) {
|
|||
return lean::has_expr_metavar(e);
|
||||
}
|
||||
|
||||
expr mk_local(unsigned idx, expr const & t) {
|
||||
expr mk_local(name const & n, name const & pp_n, expr const & t, binder_info const & bi) {
|
||||
lean_assert(is_cached(t));
|
||||
return lean::mk_local(name(*g_prefix, idx), t);
|
||||
return lean::mk_local(n, pp_n, t, bi);
|
||||
}
|
||||
|
||||
bool is_local(expr const & e) {
|
||||
|
|
|
@ -42,7 +42,7 @@ expr mk_href(unsigned idx);
|
|||
expr mk_mref(unsigned idx);
|
||||
expr mk_sort(level const & l);
|
||||
expr mk_constant(name const & n, levels const & ls);
|
||||
expr mk_local(unsigned idx, expr const & t);
|
||||
expr mk_local(name const & n, name const & pp_n, expr const & t, binder_info const & bi);
|
||||
expr mk_app(expr const & f, expr const & a);
|
||||
expr mk_app(expr const & f, unsigned num_args, expr const * args);
|
||||
expr mk_app(unsigned num_args, expr const * args);
|
||||
|
|
|
@ -6,8 +6,10 @@ Author: Leonardo de Moura
|
|||
*/
|
||||
#include "util/interrupt.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "library/blast/infer_type.h"
|
||||
#include "library/blast/blast_context.h"
|
||||
#include "library/blast/blast_exception.h"
|
||||
|
||||
namespace lean {
|
||||
namespace blast {
|
||||
|
@ -108,8 +110,144 @@ bool is_def_eq(expr const & e1, expr const & e2) {
|
|||
return e1 == e2;
|
||||
}
|
||||
|
||||
static expr infer_constant(expr const & e) {
|
||||
declaration d = env().get(const_name(e));
|
||||
auto const & ps = d.get_univ_params();
|
||||
auto const & ls = const_levels(e);
|
||||
if (length(ps) != length(ls))
|
||||
throw blast_exception("infer type failed, incorrect number of universe levels", e);
|
||||
return instantiate_type_univ_params(d, ls);
|
||||
}
|
||||
|
||||
static expr infer_macro(expr const & e) {
|
||||
auto def = macro_def(e);
|
||||
bool infer_only = true;
|
||||
// Remark: we are ignoring constraints generated by the macro definition.
|
||||
return def.check_type(e, ext_ctx(), infer_only).first;
|
||||
}
|
||||
|
||||
static expr infer_lambda(expr e) {
|
||||
buffer<expr> es, ds, ls;
|
||||
while (is_lambda(e)) {
|
||||
es.push_back(e);
|
||||
ds.push_back(binding_domain(e));
|
||||
expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data());
|
||||
expr l = blast::mk_local(mk_fresh_local_name(), binding_name(e), d, binding_info(e));
|
||||
ls.push_back(l);
|
||||
e = binding_body(e);
|
||||
}
|
||||
expr t = infer_type(instantiate_rev(e, ls.size(), ls.data()));
|
||||
expr r = abstract_locals(t, ls.size(), ls.data());
|
||||
unsigned i = es.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
r = blast::mk_pi(binding_name(es[i]), ds[i], r, binding_info(es[i]));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
/** \brief Make sure \c e is a sort, if it is not throw an exception using \c ref as a reference */
|
||||
static void ensure_sort(expr const & e, expr const & ref) {
|
||||
// Remark: for simplicity reasons, we just fail if \c e is not a sort.
|
||||
if (!is_sort(e))
|
||||
throw blast_exception("infer type failed, sort expected", ref);
|
||||
}
|
||||
|
||||
static expr infer_pi(expr const & e0) {
|
||||
buffer<expr> ls;
|
||||
buffer<level> us;
|
||||
expr e = e0;
|
||||
while (is_pi(e)) {
|
||||
expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data());
|
||||
expr d_type = whnf(infer_type(d));
|
||||
ensure_sort(d_type, e0);
|
||||
us.push_back(sort_level(d_type));
|
||||
expr l = blast::mk_local(mk_fresh_local_name(), binding_name(e), d, binding_info(e));
|
||||
ls.push_back(l);
|
||||
e = binding_body(e);
|
||||
}
|
||||
e = instantiate_rev(e, ls.size(), ls.data());
|
||||
expr e_type = whnf(infer_type(e));
|
||||
ensure_sort(e_type, e0);
|
||||
level r = sort_level(e_type);
|
||||
unsigned i = ls.size();
|
||||
bool imp = env().impredicative();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
r = imp ? blast::mk_imax(us[i], r) : blast::mk_max(us[i], r);
|
||||
}
|
||||
return blast::mk_sort(r);
|
||||
}
|
||||
|
||||
/** \brief Make sure \c e is a Pi-expression, if it is not throw an exception using \c ref as a reference */
|
||||
static void ensure_pi(expr const & e, expr const & ref) {
|
||||
// Remark: for simplicity reasons, we just fail if \c e is not a Pi.
|
||||
if (!is_pi(e))
|
||||
throw blast_exception("infer type failed, Pi expected", ref);
|
||||
}
|
||||
|
||||
static expr infer_app(expr const & e) {
|
||||
buffer<expr> args;
|
||||
expr const & f = get_app_args(e, args);
|
||||
expr f_type = infer_type(f);
|
||||
unsigned j = 0;
|
||||
unsigned nargs = args.size();
|
||||
for (unsigned i = 0; i < nargs; i++) {
|
||||
if (is_pi(f_type)) {
|
||||
f_type = binding_body(f_type);
|
||||
} else {
|
||||
f_type = whnf(instantiate_rev(f_type, i-j, args.data()+j));
|
||||
ensure_pi(f_type, e);
|
||||
f_type = binding_body(f_type);
|
||||
j = i;
|
||||
}
|
||||
}
|
||||
return instantiate_rev(f_type, nargs-j, args.data()+j);
|
||||
}
|
||||
|
||||
expr infer_type(expr const & e) {
|
||||
// TODO(Leo)
|
||||
return e;
|
||||
lean_assert(!is_var(e));
|
||||
lean_assert(closed(e));
|
||||
check_system("infer_type");
|
||||
|
||||
expr r;
|
||||
switch (e.kind()) {
|
||||
case expr_kind::Local:
|
||||
if (is_href(e)) {
|
||||
if (hypothesis const * h = state().get_main_branch().get(e)) {
|
||||
r = h->get_type();
|
||||
} else {
|
||||
throw blast_exception("infer type failed, unknown hypothesis", e);
|
||||
}
|
||||
} else {
|
||||
r = mlocal_type(e);
|
||||
}
|
||||
break;
|
||||
case expr_kind::Meta:
|
||||
r = mlocal_type(e);
|
||||
break;
|
||||
case expr_kind::Var:
|
||||
lean_unreachable(); // LCOV_EXCL_LINE
|
||||
case expr_kind::Sort:
|
||||
r = blast::mk_sort(blast::mk_succ(sort_level(e)));
|
||||
break;
|
||||
case expr_kind::Constant:
|
||||
r = infer_constant(e);
|
||||
break;
|
||||
case expr_kind::Macro:
|
||||
r = infer_macro(e);
|
||||
break;
|
||||
case expr_kind::Lambda:
|
||||
r = infer_lambda(e);
|
||||
break;
|
||||
case expr_kind::Pi:
|
||||
r = infer_pi(e);
|
||||
break;
|
||||
case expr_kind::App:
|
||||
r = infer_app(e);
|
||||
break;
|
||||
}
|
||||
// TODO(Leo): cache results if we have performance problems
|
||||
return r;
|
||||
}
|
||||
}}
|
||||
|
|
|
@ -11,7 +11,7 @@ Author: Leonardo de Moura
|
|||
|
||||
namespace lean {
|
||||
namespace blast {
|
||||
state::state():m_next_mref_index(0) {}
|
||||
state::state():m_next_uvar_index(0), m_next_mref_index(0) {}
|
||||
|
||||
/** \brief Mark that hypothesis h with index hidx is fixed by the meta-variable midx.
|
||||
That is, `h` occurs in the type of `midx`. */
|
||||
|
|
|
@ -28,12 +28,15 @@ public:
|
|||
class state {
|
||||
friend class context;
|
||||
typedef metavar_idx_map<metavar_decl> metavar_decls;
|
||||
typedef metavar_idx_map<expr> assignment;
|
||||
typedef metavar_idx_map<expr> eassignment;
|
||||
typedef metavar_idx_map<level> uassignment;
|
||||
typedef hypothesis_idx_map<metavar_idx_set> fixed_by;
|
||||
unsigned m_next_mref_index;
|
||||
metavar_decls m_metavar_decls;
|
||||
assignment m_assignment;
|
||||
branch m_main;
|
||||
unsigned m_next_uvar_index; // index of the next universe metavariable
|
||||
uassignment m_uassignment;
|
||||
unsigned m_next_mref_index; // index of the next metavariable
|
||||
metavar_decls m_metavar_decls;
|
||||
eassignment m_eassignment;
|
||||
branch m_main;
|
||||
// In the following mapping, each entry (h -> {m_1 ... m_n}) means that hypothesis `h` cannot be cleared
|
||||
// in any branch where the metavariables m_1 ... m_n have not been replaced with the values assigned to them.
|
||||
// That is, to be able to clear `h` in a branch `B`, we first need to check whether it
|
||||
|
@ -68,6 +71,9 @@ public:
|
|||
return m_main.add_hypothesis(n, type, value, jst);
|
||||
}
|
||||
|
||||
branch & get_main_branch() { return m_main; }
|
||||
branch const & get_main_branch() const { return m_main; }
|
||||
|
||||
/** \brief Add a new hypothesis to the main branch */
|
||||
expr add_hypothesis(expr const & type, optional<expr> const & value, optional<expr> const & jst) {
|
||||
return m_main.add_hypothesis(type, value, jst);
|
||||
|
|
Loading…
Reference in a new issue