feat(library/kernel_bindings): improve Pi and Fun Lua APIs, and allow users to provide binder information

Signed-off-by: Leonardo de Moura <leonardo@microsoft.com>
This commit is contained in:
Leonardo de Moura 2014-05-16 14:09:00 -07:00
parent 862c5e354d
commit 193aa4a83f
3 changed files with 72 additions and 22 deletions

View file

@ -41,19 +41,26 @@ inline expr abstract_p(expr const & e, expr const & s) { return abstract_p(e, 1,
/**
\brief Create a lambda expression (lambda (x : t) b), the term b is abstracted using abstract(b, constant(x)).
*/
inline expr Fun(name const & n, expr const & t, expr const & b) { return mk_lambda(n, t, abstract(b, mk_constant(n))); }
inline expr Fun(expr const & n, expr const & t, expr const & b) { return mk_lambda(const_name(n), t, abstract(b, n)); }
inline expr Fun(name const & n, expr const & t, expr const & b, binder_info const & bi = binder_info()) {
return mk_lambda(n, t, abstract(b, mk_constant(n)), bi);
}
inline expr Fun(expr const & n, expr const & t, expr const & b, binder_info const & bi = binder_info()) {
return mk_lambda(const_name(n), t, abstract(b, n), bi);
}
inline expr Fun(std::pair<expr const &, expr const &> const & p, expr const & b) { return Fun(p.first, p.second, b); }
expr Fun(std::initializer_list<std::pair<expr const &, expr const &>> const & l, expr const & b);
expr Fun(std::initializer_list<std::pair<expr const &, expr const &>> const & l, expr const & b);
/**
\brief Create a Pi expression (pi (x : t) b), the term b is abstracted using abstract(b, constant(x)).
*/
inline expr Pi(name const & n, expr const & t, expr const & b) { return mk_pi(n, t, abstract(b, mk_constant(n))); }
inline expr Pi(expr const & n, expr const & t, expr const & b) { return mk_pi(const_name(n), t, abstract(b, n)); }
inline expr Pi(name const & n, expr const & t, expr const & b, binder_info const & bi = binder_info()) {
return mk_pi(n, t, abstract(b, mk_constant(n)), bi);
}
inline expr Pi(expr const & n, expr const & t, expr const & b, binder_info const & bi = binder_info()) {
return mk_pi(const_name(n), t, abstract(b, n), bi);
}
inline expr Pi(std::pair<expr const &, expr const &> const & p, expr const & b) { return Pi(p.first, p.second, b); }
expr Pi(std::initializer_list<std::pair<expr const &, expr const &>> const & l, expr const & b);
expr Pi(std::initializer_list<std::pair<expr const &, expr const &>> const & l, expr const & b);
/**
- \brief Create a Let expression (Let x := v in b), the term b is abstracted using abstract(b, x).

View file

@ -300,24 +300,42 @@ static expr get_expr_from_table(lua_State * L, int t, int i) {
return r;
}
// t is a table of pairs {{a1, b1}, ..., {ak, bk}}
static void throw_invalid_binder_table(int t) {
throw exception(sstream() << "arg #" << t << " must be a table {e_1, ..., e_k} where each entry e_i is of the form: {expr, expr}, {expr, expr, bool}, or {expr, expr, binder_info}, each entry represents a binder, the first expression in each entry must be a (local) constant, the second expression is the type, the optional Boolean can be used to mark implicit arguments.");
}
// t is a table of tuples {{a1, b1}, ..., {ak, bk}}
// Each ai and bi is an expression
static std::pair<expr, expr> get_expr_pair_from_table(lua_State * L, int t, int i) {
// Each tuple represents a binder
static std::tuple<expr, expr, binder_info> get_binder_from_table(lua_State * L, int t, int i) {
lua_pushvalue(L, t); // push table on the top
lua_pushinteger(L, i);
lua_gettable(L, -2); // now table {ai, bi} is on the top
if (!lua_istable(L, -1) || objlen(L, -1) != 2)
throw exception(sstream() << "arg #" << t << " must be of the form '{{expr, expr}, ...}'");
int tuple_sz = objlen(L, -1);
if (!lua_istable(L, -1) || (tuple_sz != 2 && tuple_sz != 3))
throw_invalid_binder_table(t);
expr ai = get_expr_from_table(L, -1, 1);
if (!is_constant(ai) && !is_local(ai))
throw_invalid_binder_table(t);
expr bi = get_expr_from_table(L, -1, 2);
lua_pop(L, 2); // pop table {ai, bi} and t from stack
return mk_pair(ai, bi);
binder_info ii;
if (tuple_sz == 3) {
lua_pushinteger(L, 3);
lua_gettable(L, -2);
if (lua_isboolean(L, -1))
ii = binder_info(lua_toboolean(L, -1));
else
ii = to_binder_info(L, -1);
lua_pop(L, 1);
}
lua_pop(L, 2); // pop tuple and t from stack
return std::make_tuple(ai, bi, ii);
}
typedef expr (*MkAbst1)(expr const & n, expr const & t, expr const & b);
typedef expr (*MkAbst2)(name const & n, expr const & t, expr const & b);
template<MkAbst1 F1, MkAbst2 F2>
template<bool pi>
static int expr_abst(lua_State * L) {
int nargs = lua_gettop(L);
if (nargs < 2)
@ -330,8 +348,11 @@ static int expr_abst(lua_State * L) {
throw exception("function expects arg #1 to be a non-empty table");
expr r = to_expr(L, 2);
for (int i = len; i >= 1; i--) {
auto p = get_expr_pair_from_table(L, 1, i);
r = F1(p.first, p.second, r);
auto p = get_binder_from_table(L, 1, i);
if (pi)
r = Pi(std::get<0>(p), std::get<1>(p), r, std::get<2>(p));
else
r = Fun(std::get<0>(p), std::get<1>(p), r, std::get<2>(p));
}
return push_expr(L, r);
} else {
@ -339,17 +360,24 @@ static int expr_abst(lua_State * L) {
throw exception("function must have an odd number of arguments");
expr r = to_expr(L, nargs);
for (int i = nargs - 1; i >= 1; i-=2) {
if (is_expr(L, i - 1))
r = F1(to_expr(L, i - 1), to_expr(L, i), r);
else
r = F2(to_name_ext(L, i - 1), to_expr(L, i), r);
if (is_expr(L, i - 1)) {
if (pi)
r = Pi(to_expr(L, i - 1), to_expr(L, i), r);
else
r = Fun(to_expr(L, i - 1), to_expr(L, i), r);
} else {
if (pi)
r = Pi(to_name_ext(L, i - 1), to_expr(L, i), r);
else
r = Fun(to_name_ext(L, i - 1), to_expr(L, i), r);
}
}
return push_expr(L, r);
}
}
static int expr_fun(lua_State * L) { return expr_abst<Fun, Fun>(L); }
static int expr_pi(lua_State * L) { return expr_abst<Pi, Pi>(L); }
static int expr_fun(lua_State * L) { return expr_abst<false>(L); }
static int expr_pi(lua_State * L) { return expr_abst<true>(L); }
static int expr_mk_sort(lua_State * L) { return push_expr(L, mk_sort(to_level(L, 1))); }
static int expr_mk_metavar(lua_State * L) { return push_expr(L, mk_metavar(to_name_ext(L, 1), to_expr(L, 2))); }
static int expr_mk_local(lua_State * L) {

15
tests/lua/expr4.lua Normal file
View file

@ -0,0 +1,15 @@
local a = Const("a")
local b = Const("b")
local f = Const("f")
local vec = Const("vec")
print(Pi({{a, Type}, {b, vec(a), true}}, vec(b)))
print(Pi({{a, Type, binder_info(true, true)}, {b, vec(a), true}}, vec(b)))
assert(not pcall(function()
print(Pi({{a, Type}, {f(b), vec(a), true}}, vec(b)))
end
))
assert(not pcall(function()
print(Pi({{a, Type, a}, {b, vec(a), true}}, vec(b)))
end
))