feat(library/blast): add infer_type for blast tactic

This commit is contained in:
Leonardo de Moura 2015-10-02 13:11:17 -07:00
parent ad51339a28
commit aadac02bec
7 changed files with 165 additions and 11 deletions

View file

@ -247,6 +247,9 @@ public:
projection_info const * get_projection_info(name const & n) const {
return m_projection_info.find(n);
}
name mk_fresh_local_name() {
}
};
LEAN_THREAD_PTR(context, g_context);
@ -334,6 +337,11 @@ extension_context & ext_ctx() {
lean_assert(g_ext_context);
return *g_ext_context;
}
name mk_fresh_local_name() {
lean_assert(g_ext_context);
return g_ext_context->mk_fresh_name();
}
}
optional<expr> blast_goal(environment const & env, io_state const & ios, list<name> const & ls, list<name> const & ds,
goal const & g) {

View file

@ -24,6 +24,8 @@ io_state const & ios();
state & curr_state();
/** \brief Return the thead local extension context associated with the blast tactic. */
extension_context & ext_ctx();
/** \brief Return a thread local fresh name meant to be used to name local constants. */
name mk_fresh_local_name();
/** \brief Return true iff the given constant name is marked as reducible in env() */
bool is_reducible(name const & n);
/** \brief Return a nonnull projection_info object if \c n is the name of a projection in env() */

View file

@ -160,9 +160,9 @@ bool has_mref(expr const & e) {
return lean::has_expr_metavar(e);
}
expr mk_local(unsigned idx, expr const & t) {
expr mk_local(name const & n, name const & pp_n, expr const & t, binder_info const & bi) {
lean_assert(is_cached(t));
return lean::mk_local(name(*g_prefix, idx), t);
return lean::mk_local(n, pp_n, t, bi);
}
bool is_local(expr const & e) {

View file

@ -42,7 +42,7 @@ expr mk_href(unsigned idx);
expr mk_mref(unsigned idx);
expr mk_sort(level const & l);
expr mk_constant(name const & n, levels const & ls);
expr mk_local(unsigned idx, expr const & t);
expr mk_local(name const & n, name const & pp_n, expr const & t, binder_info const & bi);
expr mk_app(expr const & f, expr const & a);
expr mk_app(expr const & f, unsigned num_args, expr const * args);
expr mk_app(unsigned num_args, expr const * args);

View file

@ -6,8 +6,10 @@ Author: Leonardo de Moura
*/
#include "util/interrupt.h"
#include "kernel/instantiate.h"
#include "kernel/abstract.h"
#include "library/blast/infer_type.h"
#include "library/blast/blast_context.h"
#include "library/blast/blast_exception.h"
namespace lean {
namespace blast {
@ -108,8 +110,144 @@ bool is_def_eq(expr const & e1, expr const & e2) {
return e1 == e2;
}
static expr infer_constant(expr const & e) {
declaration d = env().get(const_name(e));
auto const & ps = d.get_univ_params();
auto const & ls = const_levels(e);
if (length(ps) != length(ls))
throw blast_exception("infer type failed, incorrect number of universe levels", e);
return instantiate_type_univ_params(d, ls);
}
static expr infer_macro(expr const & e) {
auto def = macro_def(e);
bool infer_only = true;
// Remark: we are ignoring constraints generated by the macro definition.
return def.check_type(e, ext_ctx(), infer_only).first;
}
static expr infer_lambda(expr e) {
buffer<expr> es, ds, ls;
while (is_lambda(e)) {
es.push_back(e);
ds.push_back(binding_domain(e));
expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data());
expr l = blast::mk_local(mk_fresh_local_name(), binding_name(e), d, binding_info(e));
ls.push_back(l);
e = binding_body(e);
}
expr t = infer_type(instantiate_rev(e, ls.size(), ls.data()));
expr r = abstract_locals(t, ls.size(), ls.data());
unsigned i = es.size();
while (i > 0) {
--i;
r = blast::mk_pi(binding_name(es[i]), ds[i], r, binding_info(es[i]));
}
return r;
}
/** \brief Make sure \c e is a sort, if it is not throw an exception using \c ref as a reference */
static void ensure_sort(expr const & e, expr const & ref) {
// Remark: for simplicity reasons, we just fail if \c e is not a sort.
if (!is_sort(e))
throw blast_exception("infer type failed, sort expected", ref);
}
static expr infer_pi(expr const & e0) {
buffer<expr> ls;
buffer<level> us;
expr e = e0;
while (is_pi(e)) {
expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data());
expr d_type = whnf(infer_type(d));
ensure_sort(d_type, e0);
us.push_back(sort_level(d_type));
expr l = blast::mk_local(mk_fresh_local_name(), binding_name(e), d, binding_info(e));
ls.push_back(l);
e = binding_body(e);
}
e = instantiate_rev(e, ls.size(), ls.data());
expr e_type = whnf(infer_type(e));
ensure_sort(e_type, e0);
level r = sort_level(e_type);
unsigned i = ls.size();
bool imp = env().impredicative();
while (i > 0) {
--i;
r = imp ? blast::mk_imax(us[i], r) : blast::mk_max(us[i], r);
}
return blast::mk_sort(r);
}
/** \brief Make sure \c e is a Pi-expression, if it is not throw an exception using \c ref as a reference */
static void ensure_pi(expr const & e, expr const & ref) {
// Remark: for simplicity reasons, we just fail if \c e is not a Pi.
if (!is_pi(e))
throw blast_exception("infer type failed, Pi expected", ref);
}
static expr infer_app(expr const & e) {
buffer<expr> args;
expr const & f = get_app_args(e, args);
expr f_type = infer_type(f);
unsigned j = 0;
unsigned nargs = args.size();
for (unsigned i = 0; i < nargs; i++) {
if (is_pi(f_type)) {
f_type = binding_body(f_type);
} else {
f_type = whnf(instantiate_rev(f_type, i-j, args.data()+j));
ensure_pi(f_type, e);
f_type = binding_body(f_type);
j = i;
}
}
return instantiate_rev(f_type, nargs-j, args.data()+j);
}
expr infer_type(expr const & e) {
// TODO(Leo)
return e;
lean_assert(!is_var(e));
lean_assert(closed(e));
check_system("infer_type");
expr r;
switch (e.kind()) {
case expr_kind::Local:
if (is_href(e)) {
if (hypothesis const * h = state().get_main_branch().get(e)) {
r = h->get_type();
} else {
throw blast_exception("infer type failed, unknown hypothesis", e);
}
} else {
r = mlocal_type(e);
}
break;
case expr_kind::Meta:
r = mlocal_type(e);
break;
case expr_kind::Var:
lean_unreachable(); // LCOV_EXCL_LINE
case expr_kind::Sort:
r = blast::mk_sort(blast::mk_succ(sort_level(e)));
break;
case expr_kind::Constant:
r = infer_constant(e);
break;
case expr_kind::Macro:
r = infer_macro(e);
break;
case expr_kind::Lambda:
r = infer_lambda(e);
break;
case expr_kind::Pi:
r = infer_pi(e);
break;
case expr_kind::App:
r = infer_app(e);
break;
}
// TODO(Leo): cache results if we have performance problems
return r;
}
}}

View file

@ -11,7 +11,7 @@ Author: Leonardo de Moura
namespace lean {
namespace blast {
state::state():m_next_mref_index(0) {}
state::state():m_next_uvar_index(0), m_next_mref_index(0) {}
/** \brief Mark that hypothesis h with index hidx is fixed by the meta-variable midx.
That is, `h` occurs in the type of `midx`. */

View file

@ -28,12 +28,15 @@ public:
class state {
friend class context;
typedef metavar_idx_map<metavar_decl> metavar_decls;
typedef metavar_idx_map<expr> assignment;
typedef metavar_idx_map<expr> eassignment;
typedef metavar_idx_map<level> uassignment;
typedef hypothesis_idx_map<metavar_idx_set> fixed_by;
unsigned m_next_mref_index;
metavar_decls m_metavar_decls;
assignment m_assignment;
branch m_main;
unsigned m_next_uvar_index; // index of the next universe metavariable
uassignment m_uassignment;
unsigned m_next_mref_index; // index of the next metavariable
metavar_decls m_metavar_decls;
eassignment m_eassignment;
branch m_main;
// In the following mapping, each entry (h -> {m_1 ... m_n}) means that hypothesis `h` cannot be cleared
// in any branch where the metavariables m_1 ... m_n have not been replaced with the values assigned to them.
// That is, to be able to clear `h` in a branch `B`, we first need to check whether it
@ -68,6 +71,9 @@ public:
return m_main.add_hypothesis(n, type, value, jst);
}
branch & get_main_branch() { return m_main; }
branch const & get_main_branch() const { return m_main; }
/** \brief Add a new hypothesis to the main branch */
expr add_hypothesis(expr const & type, optional<expr> const & value, optional<expr> const & jst) {
return m_main.add_hypothesis(type, value, jst);