From 1b5366cfb72daa88dee9d9bff209526e03423e09 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 30 May 2014 21:05:47 -0700 Subject: [PATCH] feat(library): add module for implementing aliases and 'using' command Signed-off-by: Leonardo de Moura --- src/kernel/abstract.h | 4 +- src/library/CMakeLists.txt | 2 +- src/library/aliases.cpp | 186 ++++++++++++++++++++++++++++++++++ src/library/aliases.h | 42 ++++++++ src/library/register_module.h | 2 + tests/lua/alias1.lua | 61 +++++++++++ 6 files changed, 295 insertions(+), 2 deletions(-) create mode 100644 src/library/aliases.cpp create mode 100644 src/library/aliases.h create mode 100644 tests/lua/alias1.lua diff --git a/src/kernel/abstract.h b/src/kernel/abstract.h index d4887bbd4..1836d9ee3 100644 --- a/src/kernel/abstract.h +++ b/src/kernel/abstract.h @@ -46,7 +46,9 @@ expr Fun(std::initializer_list> const & l, /** \brief Create a lambda-expression by abstracting the given local constants over b */ expr Fun(unsigned num, expr const * locals, expr const & b); template expr Fun(T const & locals, expr const & b) { return Fun(locals.size(), locals.data(), b); } -inline expr Fun(expr const & local, expr const & b) { return Fun(1, &local, b); } +inline expr Fun(expr const & local, expr const & b, binder_info const & bi = binder_info()) { + return Fun(local_pp_name(local), mlocal_type(local), abstract(b, local), bi); +} /** \brief Create a Pi expression (pi (x : t) b), the term b is abstracted using abstract(b, constant(x)). diff --git a/src/library/CMakeLists.txt b/src/library/CMakeLists.txt index f048504da..04b69d3a0 100644 --- a/src/library/CMakeLists.txt +++ b/src/library/CMakeLists.txt @@ -2,7 +2,7 @@ add_library(library deep_copy.cpp expr_lt.cpp io_state.cpp occurs.cpp kernel_bindings.cpp io_state_stream.cpp bin_app.cpp resolve_macro.cpp kernel_serializer.cpp max_sharing.cpp normalize.cpp shared_environment.cpp module.cpp coercion.cpp - private.cpp placeholder.cpp) + private.cpp placeholder.cpp aliases.cpp) # fo_unify.cpp hop_match.cpp) target_link_libraries(library ${LEAN_LIBS}) diff --git a/src/library/aliases.cpp b/src/library/aliases.cpp new file mode 100644 index 000000000..e4ad75c75 --- /dev/null +++ b/src/library/aliases.cpp @@ -0,0 +1,186 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include +#include "util/rb_map.h" +#include "util/name_generator.h" +#include "kernel/abstract.h" +#include "kernel/instantiate.h" +#include "library/expr_lt.h" +#include "library/kernel_bindings.h" +#include "library/aliases.h" +#include "library/io_state_stream.h" +#include "library/placeholder.h" + +namespace lean { +struct aliases_ext : public environment_extension { + rb_map m_aliases; + rb_map m_inv_aliases; +}; + +struct aliases_ext_reg { + unsigned m_ext_id; + aliases_ext_reg() { m_ext_id = environment::register_extension(std::make_shared()); } +}; +static aliases_ext_reg g_ext; +static aliases_ext const & get_extension(environment const & env) { + return static_cast(env.get_extension(g_ext.m_ext_id)); +} +static environment update(environment const & env, aliases_ext const & ext) { + return env.update(g_ext.m_ext_id, std::make_shared(ext)); +} + +static void check_name(environment const & env, name const & a, io_state const & ios) { + if (get_extension(env).m_aliases.contains(a)) + diagnostic(env, ios) << "alias '" << a << "' shadows existing alias\n"; + if (env.find(a)) + diagnostic(env, ios) << "alias '" << a << "' shadows existing declaration\n"; +} + +environment add_alias(environment const & env, name const & a, expr const & e, io_state const & ios) { + check_name(env, a, ios); + aliases_ext ext = get_extension(env); + ext.m_aliases.insert(a, e); + ext.m_inv_aliases.insert(e, a); + return update(env, ext); +} + +environment add_aliases(environment const & env, name const & prefix, optional const & new_prefix, io_state const & ios) { + return add_aliases(env, prefix, new_prefix, 0, nullptr, ios); +} + +static name replace_prefix(name const & n, name const & prefix, optional const & new_prefix) { + if (n == prefix) + return new_prefix ? *new_prefix : name(); + name p = replace_prefix(n.get_prefix(), prefix, new_prefix); + if (n.is_string()) + return name(p, n.get_string()); + else + return name(p, n.get_numeral()); +} + +static optional get_fix_param(unsigned num_fix_params, std::pair const * fix_params, name const & n) { + for (unsigned i = 0; i < num_fix_params; i++) { + if (fix_params[i].first == n) + return some_expr(fix_params[i].second); + } + return none_expr(); +} + +static name g_local_name = name::mk_internal_unique_name(); + +environment add_aliases(environment const & env, name const & prefix, optional const & new_prefix, + unsigned num_fix_params, std::pair const * fix_params, io_state const & ios) { + aliases_ext ext = get_extension(env); + env.for_each([&](declaration const & d) { + if (is_prefix_of(prefix, d.get_name())) { + name a = replace_prefix(d.get_name(), prefix, new_prefix); + check_name(env, a, ios); + levels ls = map2(d.get_params(), [](name const &) { return mk_level_placeholder(); }); + expr c = mk_constant(d.get_name(), ls); + if (num_fix_params > 0) { + expr t = d.get_type(); + buffer locals; + buffer infos; + buffer args; + name_generator ngen(g_local_name); + bool found_free = false; + bool found_fix = false; + bool easy = true; + while (is_pi(t)) { + if (auto p = get_fix_param(num_fix_params, fix_params, binding_name(t))) { + args.push_back(*p); + if (found_free) + easy = false; + found_fix = true; + t = instantiate(binding_body(t), *p); + } else { + found_free = true; + expr l = mk_local(ngen.next(), binding_name(t), binding_domain(t)); + infos.push_back(binding_info(t)); + locals.push_back(l); + args.push_back(l); + t = instantiate(binding_body(t), l); + } + } + if (found_fix) { + if (easy) { + args.shrink(args.size() - locals.size()); + c = mk_app(c, args); + } else { + c = mk_app(c, args); + unsigned i = locals.size(); + while (i > 0) { + --i; + c = Fun(locals[i], c, infos[i]); + } + } + } + } + ext.m_aliases.insert(a, c); + ext.m_inv_aliases.insert(c, a); + } + }); + return update(env, ext); +} + +optional is_aliased(environment const & env, expr const & t) { + auto it = get_extension(env).m_inv_aliases.find(t); + return it ? optional(*it) : optional(); +} + +optional get_alias(environment const & env, name const & n) { + auto it = get_extension(env).m_aliases.find(n); + return it ? optional(*it) : optional(); +} + +static int add_alias(lua_State * L) { + return push_environment(L, add_alias(to_environment(L, 1), to_name_ext(L, 2), to_expr(L, 3), to_io_state_ext(L, 4))); +} + +static int add_aliases(lua_State * L) { + int nargs = lua_gettop(L); + if (nargs == 2) { + return push_environment(L, add_aliases(to_environment(L, 1), to_name_ext(L, 2), optional(), get_io_state(L))); + } else if (nargs == 3) { + return push_environment(L, add_aliases(to_environment(L, 1), to_name_ext(L, 2), to_optional_name(L, 3), get_io_state(L))); + } else if (nargs == 4 && is_io_state(L, 4)) { + return push_environment(L, add_aliases(to_environment(L, 1), to_name_ext(L, 2), to_optional_name(L, 3), to_io_state(L, 4))); + } else { + buffer> fix_params; + luaL_checktype(L, 4, LUA_TTABLE); + int sz = objlen(L, 4); + for (int i = 1; i <= sz; i++) { + lua_rawgeti(L, 4, i); + luaL_checktype(L, -1, LUA_TTABLE); + lua_rawgeti(L, -1, 1); + name n = to_name_ext(L, -1); + lua_pop(L, 1); + lua_rawgeti(L, -1, 2); + expr e = to_expr(L, -1); + lua_pop(L, 2); + fix_params.emplace_back(n, e); + } + return push_environment(L, add_aliases(to_environment(L, 1), to_name_ext(L, 2), to_optional_name(L, 3), + fix_params.size(), fix_params.data(), to_io_state_ext(L, 5))); + } +} + +static int is_aliased(lua_State * L) { + return push_optional_name(L, is_aliased(to_environment(L, 1), to_expr(L, 2))); +} + +static int get_alias(lua_State * L) { + return push_optional_expr(L, get_alias(to_environment(L, 1), to_name_ext(L, 2))); +} + +void open_aliases(lua_State * L) { + SET_GLOBAL_FUN(add_alias, "add_alias"); + SET_GLOBAL_FUN(add_aliases, "add_aliases"); + SET_GLOBAL_FUN(is_aliased, "is_aliased"); + SET_GLOBAL_FUN(get_alias, "get_alias"); +} +} diff --git a/src/library/aliases.h b/src/library/aliases.h new file mode 100644 index 000000000..9f923e17e --- /dev/null +++ b/src/library/aliases.h @@ -0,0 +1,42 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include +#include "util/lua.h" +#include "kernel/environment.h" +#include "library/io_state.h" + +namespace lean { +/** + \brief Add the alias \c a for expression \c e. \c e must not have + free variables. Warning messages are generated if the new alias shadows + existing aliases and/or declarations. +*/ +environment add_alias(environment const & env, name const & a, expr const & e, io_state const & ios); +/** + \brief Create an alias for each declaration named prefix.rest. If \c new_prefix is not none, + then the alias for prefix.rest is new_prefix.rest. Otherwise, it is just \c rest. + Warning messages are generated if the new aliases shadow existing aliases and/or declarations. +*/ +environment add_aliases(environment const & env, name const & prefix, optional const & new_prefix, io_state const & ios); +/** + \brief Create an alias for each declaration named prefix.rest, the alias will also fix the value of parameters + in \c fix_params. The argument \c fix_params is a sequence of pairs (name, expr), where the \c name is the + name of the parameter to be fixed. + Warning messages are generated if the new aliases shadow existing aliases and/or declarations. +*/ +environment add_aliases(environment const & env, name const & prefix, optional const & new_prefix, + unsigned num_fix_params, std::pair const * fix_params, io_state const & ios); + +/** \brief If \c t is aliased in \c env, then return its name. Otherwise, return none. */ +optional is_aliased(environment const & env, expr const & t); + +/** \brief Return expression associated with the given alias. */ +optional get_alias(environment const & env, name const & n); + +void open_aliases(lua_State * L); +} diff --git a/src/library/register_module.h b/src/library/register_module.h index 02b5ad228..c3e43f5a3 100644 --- a/src/library/register_module.h +++ b/src/library/register_module.h @@ -11,6 +11,7 @@ Author: Leonardo de Moura #include "library/coercion.h" #include "library/private.h" #include "library/placeholder.h" +#include "library/aliases.h" // #include "library/fo_unify.h" // #include "library/hop_match.h" @@ -21,6 +22,7 @@ inline void open_core_module(lua_State * L) { open_coercion(L); open_private(L); open_placeholder(L); + open_aliases(L); // open_fo_unify(L); // open_hop_match(L); } diff --git a/tests/lua/alias1.lua b/tests/lua/alias1.lua new file mode 100644 index 000000000..eced6897c --- /dev/null +++ b/tests/lua/alias1.lua @@ -0,0 +1,61 @@ +local env = environment() + +local env = environment() +local l = mk_param_univ("l") +local A = Const("A") +local U_l = mk_sort(l) +local U_l1 = mk_sort(max_univ(l, 1)) -- Make sure U_l1 is not Bool/Prop +local nat = Const({"nat", "nat"}) +local vec_l = Const({"vec", "vec"}, {l}) -- vec.{l} +local zero = Const({"nat", "zero"}) +local succ = Const({"nat", "succ"}) +local one = succ(zero) +local list_l = Const({"list", "list"}, {l}) -- list.{l} + +env = add_inductive(env, + name("nat", "nat"), Type, + name("nat", "zero"), nat, + name("nat", "succ"), mk_arrow(nat, nat)) + +env:for_each(function(d) print(d:name()) end) +env = add_aliases(env, "nat", "natural") +assert(get_alias(env, {"natural", "zero"}) == zero) +assert(get_alias(env, {"natural", "nat"}) == nat) +assert(is_aliased(env, nat) == name("natural", "nat")) + +env = add_inductive(env, + name("list", "list"), {l}, 1, Pi(A, U_l, U_l1), + name("list", "nil"), Pi({{A, U_l, true}}, list_l(A)), + name("list", "cons"), Pi({{A, U_l, true}}, mk_arrow(A, list_l(A), list_l(A)))) + +env = add_aliases(env, "list", "lst") +print(get_alias(env, {"lst", "list_rec"})) +env = add_aliases(env, "list") +print(get_alias(env, "list_rec")) +assert(get_alias(env, "list_rec")) +assert(get_alias(env, {"lst", "list_rec"})) + +env = add_aliases(env, "list", "lnat", {{"A", nat}}) +print(get_alias(env, {"lnat", "list"})) +print(get_alias(env, {"lnat", "cons"})) +assert(get_alias(env, {"lnat", "cons"}) == Const({"list", "cons"}, { mk_level_placeholder() })(nat)) + +local A = Local("A", mk_sort(1)) +local R = Local("R", mk_arrow(A, A, Bool)) +local a = Local("a", A) +local b = Local("b", A) + +env = add_decl(env, mk_var_decl({"foo", "pred"}, Pi({A, R, a, b}, Bool))) +env = add_aliases(env, "foo", nil, {{"A", nat}, {"a", zero}}) +local Rn = Local("R", mk_arrow(nat, nat, Bool)) +local bn = Local("b", nat) +local a1 = get_alias(env, "pred") +local a2 = Fun({Rn, bn}, Const({"foo", "pred"})(nat, Rn, zero, bn)) +print(a1) +print(a2) +assert(a1 == a2) +env = add_aliases(env, "foo", nil, {{"A", nat}, {"a", zero}, {"b", one}}) +print(get_alias(env, "pred")) +env = add_alias(env, "z", zero) +assert(get_alias(env, "z") == zero) +assert(not get_alias(env, "zz"))