diff --git a/src/kernel/metavar.cpp b/src/kernel/metavar.cpp index b898c5717..1724ba7a5 100644 --- a/src/kernel/metavar.cpp +++ b/src/kernel/metavar.cpp @@ -255,6 +255,20 @@ expr substitution::instantiate_metavars_wo_jst(expr const & e, bool inst_local_t return instantiate_metavars_fn(*this, false, inst_local_types)(e); } +auto substitution::expand_metavar_app(expr const & e) -> opt_expr_jst { + expr const & f = get_app_fn(e); + if (!is_metavar(f)) + return opt_expr_jst(); + name const & f_name = mlocal_name(f); + auto f_value = get_expr(f_name); + if (!f_value) + return opt_expr_jst(); + buffer args; + get_app_rev_args(e, args); + expr new_e = apply_beta(*f_value, args.size(), args.data()); + return opt_expr_jst(new_e, get_expr_jst(f_name)); +} + static name_set merge(name_set s1, name_set const & s2) { s2.for_each([&](name const & n) { s1.insert(n); }); return s1; diff --git a/src/kernel/metavar.h b/src/kernel/metavar.h index 2e73c083c..2be50bef4 100644 --- a/src/kernel/metavar.h +++ b/src/kernel/metavar.h @@ -15,6 +15,10 @@ Author: Leonardo de Moura namespace lean { class substitution { +public: + typedef optional> opt_expr_jst; + typedef optional> opt_level_jst; +private: typedef name_map expr_map; typedef name_map level_map; typedef name_map jst_map; @@ -37,19 +41,11 @@ class substitution { pair instantiate_metavars_core(expr const & e, bool inst_local_types); bool occurs_expr_core(name const & m, expr const & e, name_set & visited) const; name_set get_occs(name const & m, name_set & fresh); -public: - substitution(); - typedef optional> opt_expr_jst; - typedef optional> opt_level_jst; - bool is_expr_assigned(name const & m) const; opt_expr_jst get_expr_assignment(name const & m) const; - - bool is_level_assigned(name const & m) const; + optional get_expr(name const & m) const; opt_level_jst get_level_assignment(name const & m) const; - optional get_expr(name const & m) const; - optional get_level(name const & m) const; justification get_expr_jst(name const & m) const { if (auto it = m_expr_jsts.find(m)) return *it; else return justification(); } @@ -57,6 +53,13 @@ public: if (auto it = m_level_jsts.find(m)) return *it; else return justification(); } +public: + substitution(); + + optional get_level(name const & m) const; + bool is_expr_assigned(name const & m) const; + bool is_level_assigned(name const & m) const; + void assign(name const & m, expr const & t, justification const & j); void assign(name const & m, expr const & t) { assign(m, t, justification()); } void assign(expr const & m, expr const & t, justification const & j) { assign(mlocal_name(m), t, j); } @@ -66,6 +69,9 @@ public: void assign(level const & m, level const & t, justification const & j) { assign(meta_id(m), t, j); } void assign(level const & m, level const & t) { assign(m, t, justification ()); } + /** \brief Given e of the form ?m t1 ... t2, expand ?m and apply beta-reduction */ + opt_expr_jst expand_metavar_app(expr const & e); + pair instantiate_metavars(level const & l) { return instantiate_metavars(l, true); } level instantiate(level const & l) { return instantiate_metavars(l, false).first; } diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 620deb9e8..677b3c035 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -1646,18 +1646,6 @@ static void open_constraint(lua_State * L) { // Substitution DECL_UDATA(substitution) static int mk_substitution(lua_State * L) { return push_substitution(L, substitution()); } -static int subst_get_expr(lua_State * L) { - if (is_expr(L, 2)) - return push_optional_expr(L, to_substitution(L, 1).get_expr(to_expr(L, 2))); - else - return push_optional_expr(L, to_substitution(L, 1).get_expr(to_name_ext(L, 2))); -} -static int subst_get_level(lua_State * L) { - if (is_level(L, 2)) - return push_optional_level(L, to_substitution(L, 1).get_level(to_level(L, 2))); - else - return push_optional_level(L, to_substitution(L, 1).get_level(to_name_ext(L, 2))); -} static int subst_assign(lua_State * L) { int nargs = lua_gettop(L); if (nargs == 3) { @@ -1697,46 +1685,6 @@ static int subst_is_expr_assigned(lua_State * L) { return push_boolean(L, to_sub static int subst_is_level_assigned(lua_State * L) { return push_boolean(L, to_substitution(L, 1).is_level_assigned(to_name_ext(L, 2))); } static int subst_occurs(lua_State * L) { return push_boolean(L, to_substitution(L, 1).occurs(to_expr(L, 2), to_expr(L, 3))); } static int subst_occurs_expr(lua_State * L) { return push_boolean(L, to_substitution(L, 1).occurs_expr(to_name_ext(L, 2), to_expr(L, 3))); } -static int subst_get_expr_assignment(lua_State * L) { - auto r = to_substitution(L, 1).get_expr_assignment(to_name_ext(L, 2)); - if (r) { - push_expr(L, r->first); - push_justification(L, r->second); - } else { - push_nil(L); push_nil(L); - } - return 2; -} -static int subst_get_level_assignment(lua_State * L) { - auto r = to_substitution(L, 1).get_level_assignment(to_name_ext(L, 2)); - if (r) { - push_level(L, r->first); - push_justification(L, r->second); - } else { - push_nil(L); push_nil(L); - } - return 2; -} -static int subst_get_assignment(lua_State * L) { - if (is_expr(L, 2)) { - auto r = to_substitution(L, 1).get_assignment(to_expr(L, 2)); - if (r) { - push_expr(L, r->first); - push_justification(L, r->second); - } else { - push_nil(L); push_nil(L); - } - } else { - auto r = to_substitution(L, 1).get_assignment(to_level(L, 2)); - if (r) { - push_level(L, r->first); - push_justification(L, r->second); - } else { - push_nil(L); push_nil(L); - } - } - return 2; -} static int subst_instantiate(lua_State * L) { if (is_expr(L, 2)) { auto r = to_substitution(L, 1).instantiate_metavars(to_expr(L, 2)); @@ -1783,17 +1731,12 @@ static int subst_copy(lua_State * L) { static const struct luaL_Reg substitution_m[] = { {"__gc", substitution_gc}, {"copy", safe_function}, - {"get_expr", safe_function}, - {"get_level", safe_function}, {"assign", safe_function}, {"is_assigned", safe_function}, {"is_expr_assigned", safe_function}, {"is_level_assigned", safe_function}, {"occurs", safe_function}, {"occurs_expr", safe_function}, - {"get_expr_assignment", safe_function}, - {"get_level_assignment", safe_function}, - {"get_assignment", safe_function}, {"instantiate", safe_function}, {"instantiate_all", safe_function}, {"for_each_expr", safe_function}, diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp index 927c1c758..9bedb486e 100644 --- a/src/library/unifier.cpp +++ b/src/library/unifier.cpp @@ -867,17 +867,12 @@ struct unifier_fn { expr instantiate_meta(expr e, justification & j) { while (true) { - expr const & f = get_app_fn(e); - if (!is_metavar(f)) + if (auto p = m_subst.expand_metavar_app(e)) { + e = p->first; + j = mk_composite1(j, p->second); + } else { return e; - name const & f_name = mlocal_name(f); - auto f_value = m_subst.get_expr(f_name); - if (!f_value) - return e; - j = mk_composite1(j, m_subst.get_expr_jst(f_name)); - buffer args; - get_app_rev_args(e, args); - e = apply_beta(*f_value, args.size(), args.data()); + } } } diff --git a/tests/lua/subst1.lua b/tests/lua/subst1.lua index 235ee3dad..3724e5cd1 100644 --- a/tests/lua/subst1.lua +++ b/tests/lua/subst1.lua @@ -20,21 +20,13 @@ s:assign(u, l) assert(s:is_assigned(u)) assert(s:is_level_assigned("u")) assert(not s:is_expr_assigned("u")) -assert(s:get_expr("m") == a) local m2 = mk_metavar("m2", Prop) s:assign(m2, f(m)) -print(s:get_expr("m2")) assert(s:is_expr_assigned("m")) -- m is assigned, so it is does not occur in f(m2) -- assert(s:occurs_expr("m", f(m2))) -print(s:get_level("u")) print(s:instantiate(mk_sort(u))) assert(s:instantiate(mk_sort(u)) == mk_sort(l)) -assert(s:get_assignment(m) == a) -assert(s:get_assignment(u) == l) -assert(s:get_expr_assignment("m") == a) -assert(s:get_level_assignment("u") == l) - local s = substitution() local m2 = mk_metavar("m2", Prop) s:assign(m2, f(m))