fix(library/app_builder): many bugs, add use_cache option, add tests
This commit is contained in:
parent
e6f022e3f2
commit
e9e1f86b7f
3 changed files with 169 additions and 90 deletions
|
@ -66,7 +66,7 @@ struct app_builder::imp {
|
|||
// Make sure m_levels contains at least nlvls metavariable universe levels
|
||||
void ensure_levels(unsigned nlvls) {
|
||||
while (m_levels.size() <= nlvls) {
|
||||
level new_lvl = mk_idx_meta_univ(m_levels.size());
|
||||
level new_lvl = mk_idx_meta_univ(m_levels.size() - 1);
|
||||
levels new_lvls = append(m_levels.back(), levels(new_lvl));
|
||||
m_levels.push_back(new_lvls);
|
||||
}
|
||||
|
@ -91,57 +91,60 @@ struct app_builder::imp {
|
|||
}
|
||||
}
|
||||
|
||||
optional<expr> mk_app_core(declaration const & d, unsigned nargs, expr const * args) {
|
||||
optional<expr> mk_app_core(declaration const & d, unsigned nargs, expr const * args, bool use_cache) {
|
||||
unsigned num_univs = d.get_num_univ_params();
|
||||
ensure_levels(num_univs);
|
||||
expr type = instantiate_type_univ_params(d, m_levels[num_univs]);
|
||||
buffer<optional<level>> lsubst;
|
||||
buffer<optional<expr>> esubst;
|
||||
lsubst.resize(num_univs);
|
||||
lsubst.resize(num_univs, none_level());
|
||||
constraint_seq cs;
|
||||
buffer<unsigned> used_idxs;
|
||||
buffer<expr> used_types;
|
||||
buffer<bool> explicit_mask;
|
||||
unsigned idx = 0;
|
||||
unsigned arity = 0;
|
||||
bool has_unassigned_args = false;
|
||||
bool has_unassigned_lvls = num_univs > 0;
|
||||
buffer<expr> domain_types;
|
||||
while (is_pi(type)) {
|
||||
if (idx >= nargs)
|
||||
return none_expr();
|
||||
if (is_explicit(binding_info(type))) {
|
||||
explicit_mask.push_back(true);
|
||||
if (!has_unassigned_args && !has_unassigned_lvls) {
|
||||
esubst.push_back(some_expr(args[idx]));
|
||||
} else {
|
||||
expr arg_type = m_tc.infer(args[idx], cs);
|
||||
if (cs)
|
||||
return none_expr();
|
||||
bool assigned = false;
|
||||
if (!match(binding_domain(type), arg_type, esubst, lsubst,
|
||||
nullptr, nullptr, &m_plugin, &assigned))
|
||||
return none_expr();
|
||||
if (assigned) {
|
||||
used_idxs.push_back(idx);
|
||||
used_types.push_back(arg_type);
|
||||
has_unassigned_lvls = std::find(lsubst.begin(), lsubst.end(), none_level()) != lsubst.end();
|
||||
has_unassigned_args = std::find(esubst.begin(), esubst.end(), none_expr()) != esubst.end();
|
||||
}
|
||||
esubst.push_back(some_expr(args[idx]));
|
||||
}
|
||||
idx++;
|
||||
} else {
|
||||
explicit_mask.push_back(false);
|
||||
esubst.push_back(none_expr());
|
||||
has_unassigned_args = true;
|
||||
}
|
||||
arity++;
|
||||
explicit_mask.push_back(is_explicit(binding_info(type)));
|
||||
esubst.push_back(none_expr());
|
||||
domain_types.push_back(binding_domain(type));
|
||||
type = binding_body(type);
|
||||
}
|
||||
lean_assert(explicit_mask.size() == esubst.size());
|
||||
if (idx != nargs || has_unassigned_args || has_unassigned_lvls)
|
||||
unsigned i = domain_types.size();
|
||||
unsigned j = nargs;
|
||||
while (i > 0) {
|
||||
--i;
|
||||
if (explicit_mask[i]) {
|
||||
if (j == 0)
|
||||
return none_expr();
|
||||
--j;
|
||||
expr arg_type = m_tc.infer(args[j], cs);
|
||||
if (cs)
|
||||
return none_expr();
|
||||
bool assigned = false;
|
||||
if (!match(domain_types[i], arg_type, i, esubst.data(), lsubst.size(), lsubst.data(),
|
||||
nullptr, nullptr, &m_plugin, &assigned))
|
||||
return none_expr();
|
||||
if (assigned && use_cache) {
|
||||
used_idxs.push_back(j);
|
||||
used_types.push_back(arg_type);
|
||||
}
|
||||
esubst[i] = some_expr(args[j]);
|
||||
} else {
|
||||
if (!esubst[i])
|
||||
return none_expr();
|
||||
expr arg_type = m_tc.infer(*esubst[i], cs);
|
||||
if (cs)
|
||||
return none_expr();
|
||||
if (!match(domain_types[i], arg_type, i, esubst.data(), lsubst.size(), lsubst.data(),
|
||||
nullptr, nullptr, &m_plugin))
|
||||
return none_expr();
|
||||
}
|
||||
}
|
||||
bool has_unassigned_lvls = std::find(lsubst.begin(), lsubst.end(), none_level()) != lsubst.end();
|
||||
if (j > 0 || has_unassigned_lvls)
|
||||
return none_expr();
|
||||
save_decl_info(d, nargs, used_idxs);
|
||||
if (use_cache)
|
||||
save_decl_info(d, nargs, used_idxs);
|
||||
buffer<level> r_lvls;
|
||||
for (optional<level> const & l : lsubst)
|
||||
r_lvls.push_back(*l);
|
||||
|
@ -149,10 +152,14 @@ struct app_builder::imp {
|
|||
for (optional<expr> const & o : esubst)
|
||||
r_args.push_back(*o);
|
||||
lean_assert(explicit_mask.size() == r_args.size());
|
||||
cache_key k(d.get_name(), used_types.size(), used_types.data());
|
||||
if (is_simple_mask(explicit_mask)) {
|
||||
expr f = ::lean::mk_app(mk_constant(d.get_name(), to_list(r_lvls)), arity - nargs, r_args.data());
|
||||
m_cache.insert(k, f);
|
||||
if (!use_cache) {
|
||||
return some_expr(::lean::mk_app(mk_constant(d.get_name(), to_list(r_lvls)), r_args.size(), r_args.data()));
|
||||
} else if (is_simple_mask(explicit_mask)) {
|
||||
expr f = ::lean::mk_app(mk_constant(d.get_name(), to_list(r_lvls)), r_args.size() - nargs, r_args.data());
|
||||
if (use_cache) {
|
||||
cache_key k(d.get_name(), used_types.size(), used_types.data());
|
||||
m_cache.insert(k, f);
|
||||
}
|
||||
return some_expr(::lean::mk_app(f, nargs, r_args.end() - nargs));
|
||||
} else {
|
||||
buffer<expr> imp_args;
|
||||
|
@ -166,37 +173,41 @@ struct app_builder::imp {
|
|||
}
|
||||
}
|
||||
expr f = ::lean::mk_app(mk_constant(d.get_name(), to_list(r_lvls)), imp_args.size(), imp_args.data());
|
||||
m_cache.insert(k, f);
|
||||
return some_expr(instantiate_rev(f, expl_args.size(), expl_args.data()));
|
||||
if (use_cache) {
|
||||
cache_key k(d.get_name(), used_types.size(), used_types.data());
|
||||
m_cache.insert(k, f);
|
||||
}
|
||||
return some_expr(instantiate(f, expl_args.size(), expl_args.data()));
|
||||
}
|
||||
}
|
||||
|
||||
optional<expr> mk_app(declaration const & d, unsigned nargs, expr const * args) {
|
||||
if (auto info = m_decl_info.find(d.get_name())) {
|
||||
if (info->m_nargs != nargs)
|
||||
return none_expr();
|
||||
buffer<expr> arg_types;
|
||||
constraint_seq cs;
|
||||
for (unsigned idx : info->m_used_idxs) {
|
||||
lean_assert(idx < nargs);
|
||||
expr t = m_tc.infer(args[idx], cs);
|
||||
if (cs)
|
||||
return none_expr(); // constraint was generated
|
||||
arg_types.push_back(t);
|
||||
optional<expr> mk_app(declaration const & d, unsigned nargs, expr const * args, bool use_cache) {
|
||||
if (use_cache) {
|
||||
if (auto info = m_decl_info.find(d.get_name())) {
|
||||
if (info->m_nargs != nargs)
|
||||
return none_expr();
|
||||
buffer<expr> arg_types;
|
||||
constraint_seq cs;
|
||||
for (unsigned idx : info->m_used_idxs) {
|
||||
lean_assert(idx < nargs);
|
||||
expr t = m_tc.infer(args[idx], cs);
|
||||
if (cs)
|
||||
return none_expr(); // constraint was generated
|
||||
arg_types.push_back(t);
|
||||
}
|
||||
cache_key k(d.get_name(), arg_types.size(), arg_types.data());
|
||||
auto it = m_cache.find(k);
|
||||
if (it != m_cache.end()) {
|
||||
if (closed(it->second))
|
||||
return some_expr(::lean::mk_app(it->second, nargs, args));
|
||||
else
|
||||
return some_expr(instantiate(it->second, nargs, args));
|
||||
} else {
|
||||
return mk_app_core(d, nargs, args, use_cache);
|
||||
}
|
||||
}
|
||||
cache_key k(d.get_name(), arg_types.size(), arg_types.data());
|
||||
auto it = m_cache.find(k);
|
||||
if (it != m_cache.end()) {
|
||||
if (closed(it->second))
|
||||
return some_expr(::lean::mk_app(it->second, nargs, args));
|
||||
else
|
||||
return some_expr(instantiate_rev(it->second, nargs, args));
|
||||
} else {
|
||||
return mk_app_core(d, nargs, args);
|
||||
}
|
||||
} else {
|
||||
return mk_app_core(d, nargs, args);
|
||||
}
|
||||
return mk_app_core(d, nargs, args, use_cache);
|
||||
}
|
||||
|
||||
void push() {
|
||||
|
@ -209,24 +220,24 @@ struct app_builder::imp {
|
|||
};
|
||||
|
||||
app_builder::app_builder(type_checker & tc):m_ptr(new imp(tc)) {}
|
||||
optional<expr> app_builder::mk_app(declaration const & d, unsigned nargs, expr const * args) {
|
||||
return m_ptr->mk_app(d, nargs, args);
|
||||
optional<expr> app_builder::mk_app(declaration const & d, unsigned nargs, expr const * args, bool use_cache) {
|
||||
return m_ptr->mk_app(d, nargs, args, use_cache);
|
||||
}
|
||||
optional<expr> app_builder::mk_app(name const & n, unsigned nargs, expr const * args) {
|
||||
optional<expr> app_builder::mk_app(name const & n, unsigned nargs, expr const * args, bool use_cache) {
|
||||
declaration const & d = m_ptr->m_tc.env().get(n);
|
||||
return mk_app(d, nargs, args);
|
||||
return mk_app(d, nargs, args, use_cache);
|
||||
}
|
||||
optional<expr> app_builder::mk_app(name const & n, std::initializer_list<expr> const & args) {
|
||||
return mk_app(n, args.size(), args.begin());
|
||||
optional<expr> app_builder::mk_app(name const & n, std::initializer_list<expr> const & args, bool use_cache) {
|
||||
return mk_app(n, args.size(), args.begin(), use_cache);
|
||||
}
|
||||
optional<expr> app_builder::mk_app(name const & n, expr const & a1) {
|
||||
return mk_app(n, {a1});
|
||||
optional<expr> app_builder::mk_app(name const & n, expr const & a1, bool use_cache) {
|
||||
return mk_app(n, {a1}, use_cache);
|
||||
}
|
||||
optional<expr> app_builder::mk_app(name const & n, expr const & a1, expr const & a2) {
|
||||
return mk_app(n, {a1, a2});
|
||||
optional<expr> app_builder::mk_app(name const & n, expr const & a1, expr const & a2, bool use_cache) {
|
||||
return mk_app(n, {a1, a2}, use_cache);
|
||||
}
|
||||
optional<expr> app_builder::mk_app(name const & n, expr const & a1, expr const & a2, expr const & a3) {
|
||||
return mk_app(n, {a1, a2, a3});
|
||||
optional<expr> app_builder::mk_app(name const & n, expr const & a1, expr const & a2, expr const & a3, bool use_cache) {
|
||||
return mk_app(n, {a1, a2, a3}, use_cache);
|
||||
}
|
||||
void app_builder::push() { m_ptr->push(); }
|
||||
void app_builder::pop() { m_ptr->pop(); }
|
||||
|
@ -249,10 +260,15 @@ static int app_builder_mk_app(lua_State * L) {
|
|||
int nargs = lua_gettop(L);
|
||||
buffer<expr> args;
|
||||
app_builder & b = to_app_builder_ref(L, 1)->m_builder;
|
||||
bool use_cache = true;
|
||||
name n = to_name_ext(L, 2);
|
||||
for (int i = 3; i <= nargs; i++)
|
||||
args.push_back(to_expr(L, i));
|
||||
return push_optional_expr(L, b.mk_app(n, args.size(), args.data()));
|
||||
for (int i = 3; i <= nargs; i++) {
|
||||
if (i < nargs || is_expr(L, i))
|
||||
args.push_back(to_expr(L, i));
|
||||
else
|
||||
use_cache = lua_toboolean(L, i);
|
||||
}
|
||||
return push_optional_expr(L, b.mk_app(n, args.size(), args.data(), use_cache));
|
||||
}
|
||||
|
||||
static int app_builder_push(lua_State * L) {
|
||||
|
|
|
@ -37,12 +37,12 @@ public:
|
|||
|
||||
\remark This methods uses just higher-order pattern matching.
|
||||
*/
|
||||
optional<expr> mk_app(declaration const & d, unsigned nargs, expr const * args);
|
||||
optional<expr> mk_app(name const & n, unsigned nargs, expr const * args);
|
||||
optional<expr> mk_app(name const & n, std::initializer_list<expr> const & args);
|
||||
optional<expr> mk_app(name const & n, expr const & a1);
|
||||
optional<expr> mk_app(name const & n, expr const & a1, expr const & a2);
|
||||
optional<expr> mk_app(name const & n, expr const & a1, expr const & a2, expr const & a3);
|
||||
optional<expr> mk_app(declaration const & d, unsigned nargs, expr const * args, bool use_cache = true);
|
||||
optional<expr> mk_app(name const & n, unsigned nargs, expr const * args, bool use_cache = true);
|
||||
optional<expr> mk_app(name const & n, std::initializer_list<expr> const & args, bool use_cache = true);
|
||||
optional<expr> mk_app(name const & n, expr const & a1, bool use_cache = true);
|
||||
optional<expr> mk_app(name const & n, expr const & a1, expr const & a2, bool use_cache = true);
|
||||
optional<expr> mk_app(name const & n, expr const & a1, expr const & a2, expr const & a3, bool use_cache = true);
|
||||
/** \brief Create a backtracking point for cached information.
|
||||
\remark This method does not invoke tc->push()
|
||||
*/
|
||||
|
|
63
tests/lean/run/app_builder.lean
Normal file
63
tests/lean/run/app_builder.lean
Normal file
|
@ -0,0 +1,63 @@
|
|||
definition a := 10
|
||||
|
||||
constant b : num
|
||||
constant c : num
|
||||
constant H1 : a = b
|
||||
constant H2 : b = c
|
||||
constant d : nat
|
||||
constant f : nat → nat
|
||||
constant g : nat → nat
|
||||
|
||||
set_option pp.implicit true
|
||||
set_option pp.universes true
|
||||
set_option pp.notation false
|
||||
(*
|
||||
local env = get_env()
|
||||
local tc = non_irreducible_type_checker()
|
||||
local b = app_builder(tc)
|
||||
local a = Const("a")
|
||||
local c = Const("c")
|
||||
local d = Const("d")
|
||||
local f = Const("f")
|
||||
local g = Const("g")
|
||||
function tst(n, ...)
|
||||
local args = {...}
|
||||
local r = b:mk_app(n, unpack(args))
|
||||
print(tostring(r) .. " : " .. tostring(tc:check(r)))
|
||||
return r
|
||||
end
|
||||
tst("eq", a, c)
|
||||
tst("eq", a, c)
|
||||
tst("eq", c, a)
|
||||
tst("eq", a, a)
|
||||
tst("eq", d, d)
|
||||
tst({"eq", "refl"}, a)
|
||||
tst({"eq", "refl"}, a)
|
||||
tst({"eq", "refl"}, d)
|
||||
tst({"eq", "refl"}, d)
|
||||
tst({"eq", "refl"}, c)
|
||||
tst({"eq", "refl"}, c, false)
|
||||
tst({"eq", "refl"}, a)
|
||||
local H1 = Const("H1")
|
||||
local H2 = Const("H2")
|
||||
tst({"eq", "trans"}, H1, H2)
|
||||
H1sy = tst({"eq", "symm"}, H1)
|
||||
H2sy = tst({"eq", "symm"}, H2)
|
||||
tst({"eq", "trans"}, H2sy, H1sy)
|
||||
tst({"heq", "refl"}, a)
|
||||
H1h = tst({"heq", "of_eq"}, H1)
|
||||
H2h = tst({"heq", "of_eq"}, H2)
|
||||
tst({"heq", "trans"}, H1h, H2h)
|
||||
tst({"heq", "symm"}, H1h)
|
||||
tst({"heq", "symm"}, H1h)
|
||||
tst({"heq"}, a, c)
|
||||
tst({"heq"}, a, d)
|
||||
tst({"heq"}, d, a)
|
||||
tst({"heq"}, a, c)
|
||||
tst({"heq"}, a, d)
|
||||
tst({"heq"}, d, a)
|
||||
tst({"eq", "refl"}, f)
|
||||
tst({"eq", "refl"}, g)
|
||||
tst("eq", f, g)
|
||||
tst("eq", g, f)
|
||||
*)
|
Loading…
Reference in a new issue