diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index 988066a0b..6d1ac55c1 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -89,7 +89,7 @@ public: return some_level(*it); } else { level new_p = mk_new_univ_param(); - m_subst.d_assign(l, new_p); + m_subst.assign(l, new_p); return some_level(new_p); } } @@ -307,9 +307,9 @@ class elaborator { substitution subst = next->first.get_subst(); buffer cs; expr const & mvar = get_app_fn(m_meta); - cs.push_back(mk_eq_cnstr(mvar, subst.d_instantiate(mvar), m_jst)); + cs.push_back(mk_eq_cnstr(mvar, subst.instantiate(mvar), m_jst)); for (auto const & mvar : m_mvars_in_meta_type) - cs.push_back(mk_eq_cnstr(mvar, subst.d_instantiate(mvar), m_jst)); + cs.push_back(mk_eq_cnstr(mvar, subst.instantiate(mvar), m_jst)); return optional(to_list(cs.begin(), cs.end())); } return optional(); @@ -414,9 +414,7 @@ public: if (!m_accumulated.is_none()) c = update_justification(c, mk_composite1(c.get_justification(), m_accumulated)); add_cnstr_core(c); - auto ss = unify_simple(m_subst, c); - m_subst = ss.second; - if (ss.first == unify_status::Failed) + if (unify_simple(m_subst, c) == unify_status::Failed) throw unifier_exception(c.get_justification(), m_subst); } @@ -424,7 +422,7 @@ public: \remark We update \c m_accumulated with any justifications used. */ expr instantiate_metavars(expr const & e) { - auto e_j = m_subst.d_instantiate_metavars(e); + auto e_j = m_subst.instantiate_metavars(e); m_accumulated = mk_composite1(m_accumulated, e_j.second); return e_j.first; } @@ -538,7 +536,7 @@ public: return ::lean::is_class(m_env, cls_name) || !empty(get_tactic_hints(m_env, cls_name)); } - static expr instantiate_meta(expr const & meta, substitution const & subst) { + static expr instantiate_meta(expr const & meta, substitution & subst) { buffer locals; expr mvar = get_app_args(meta, locals); mvar = update_mlocal(mvar, subst.instantiate(mlocal_type(mvar))); @@ -551,7 +549,8 @@ public: justification mk_failed_to_synthesize_jst(expr const & m) { environment env = m_env; return mk_justification(m, [=](formatter const & fmt, substitution const & subst) { - expr new_m = instantiate_meta(m, subst); + substitution tmp(subst); + expr new_m = instantiate_meta(m, tmp); expr new_type = type_checker(env).infer(new_m); proof_state ps(goals(goal(new_m, new_type)), substitution(), name_generator("dontcare")); return format({format("failed to synthesize placeholder"), line(), ps.pp(fmt)}); @@ -1065,7 +1064,7 @@ public: } else { subst = r->first.get_subst(); expr v = subst.instantiate(mvar); - subst = subst.assign(mlocal_name(mvar), v); + subst.assign(mlocal_name(mvar), v); return true; } } catch (tactic_exception & ex) { @@ -1110,12 +1109,13 @@ public: void display_unassigned_mvars(expr const & e, substitution const & s) { if (m_check_unassigned && has_metavar(e)) { + substitution tmp_s(s); for_each(e, [&](expr const & e, unsigned) { if (!is_metavar(e)) return has_metavar(e); if (auto it = m_mvar2meta.find(mlocal_name(e))) { - expr meta = s.instantiate(*it); - expr meta_type = s.instantiate(type_checker(m_env).infer(meta)); + expr meta = tmp_s.instantiate(*it); + expr meta_type = tmp_s.instantiate(type_checker(m_env).infer(meta)); goal g(meta, meta_type); display_unsolved_proof_state(e, proof_state(goals(g), substitution(), m_ngen), "don't know how to synthesize it"); @@ -1127,7 +1127,7 @@ public: /** \brief Apply substitution and solve remaining metavariables using tactics. */ expr apply(substitution & s, expr const & e, name_set & univ_params, buffer & new_params) { - expr r = s.d_instantiate(e); + expr r = s.instantiate(e); if (has_univ_metavar(r)) r = univ_metavars_to_params_fn(m_env, m_lls, s, univ_params, new_params)(r); r = solve_unassigned_mvars(s, r); @@ -1158,7 +1158,8 @@ public: expr r_v = visit(v); expr r_v_type = infer_type(r_v); justification j = mk_justification(v, [=](formatter const & fmt, substitution const & subst) { - return pp_def_type_mismatch(fmt, n, subst.instantiate(r_t), subst.instantiate(r_v_type)); + substitution s(subst); + return pp_def_type_mismatch(fmt, n, s.instantiate(r_t), s.instantiate(r_v_type)); }); if (!m_tc->is_def_eq(r_v_type, r_t, j)) { throw_kernel_exception(m_env, v, [=](formatter const & fmt) { return pp_def_type_mismatch(fmt, n, r_t, r_v_type); }); diff --git a/src/kernel/metavar.cpp b/src/kernel/metavar.cpp index c1e3efc43..17b6ef0f4 100644 --- a/src/kernel/metavar.cpp +++ b/src/kernel/metavar.cpp @@ -50,40 +50,20 @@ optional substitution::get_level(name const & m) const { return it ? some_level(*it) : none_level(); } -void substitution::d_assign(name const & m, expr const & t, justification const & j) { +void substitution::assign(name const & m, expr const & t, justification const & j) { lean_assert(closed(t)); m_expr_subst.insert(m, t); if (!j.is_none()) m_expr_jsts.insert(m, j); } -void substitution::d_assign(name const & m, level const & l, justification const & j) { +void substitution::assign(name const & m, level const & l, justification const & j) { m_level_subst.insert(m, l); if (!j.is_none()) m_level_jsts.insert(m, j); } -substitution substitution::assign(name const & m, expr const & t, justification const & j) const { - substitution s(*this); - s.d_assign(m, t, j); - return s; -} - -substitution substitution::assign(name const & m, expr const & t) const { - return assign(m, t, justification()); -} - -substitution substitution::assign(name const & m, level const & l, justification const & j) const { - substitution s(*this); - s.d_assign(m, l, j); - return s; -} - -substitution substitution::assign(name const & m, level const & l) const { - return assign(m, l, justification()); -} - -std::pair substitution::d_instantiate_metavars(level const & l, bool use_jst, bool updt) { +std::pair substitution::instantiate_metavars(level const & l, bool use_jst) { if (!has_meta(l)) return mk_pair(l, justification()); justification j; @@ -94,14 +74,13 @@ std::pair substitution::d_instantiate_metavars(level const } else if (is_meta(l)) { auto p1 = get_assignment(l); if (p1) { - auto p2 = d_instantiate_metavars(p1->first, use_jst, updt); + auto p2 = instantiate_metavars(p1->first, use_jst); if (use_jst) { justification new_jst = mk_composite1(p1->second, p2.second); - if (updt) - d_assign(meta_id(l), p2.first, new_jst); + assign(meta_id(l), p2.first, new_jst); save_jst(new_jst); - } else if (updt) { - d_assign(meta_id(l), p2.first); + } else { + assign(meta_id(l), p2.first); } return some_level(p2.first); } @@ -116,13 +95,11 @@ protected: substitution & m_subst; justification m_jst; bool m_use_jst; - bool m_update; - bool m_only_expr; void save_jst(justification const & j) { m_jst = mk_composite1(m_jst, j); } level visit_level(level const & l) { - auto p1 = m_subst.d_instantiate_metavars(l, m_use_jst, m_update); + auto p1 = m_subst.instantiate_metavars(l, m_use_jst); if (m_use_jst) save_jst(p1.second); return p1.first; @@ -133,12 +110,10 @@ protected: } virtual expr visit_sort(expr const & s) { - lean_assert(!m_only_expr); return update_sort(s, visit_level(sort_level(s))); } virtual expr visit_constant(expr const & c) { - lean_assert(!m_only_expr); return update_constant(c, visit_levels(const_levels(c))); } @@ -151,25 +126,15 @@ protected: save_jst(p1->second); return p1->first; } else if (m_use_jst) { - if (m_update) { - auto p2 = m_subst.d_instantiate_metavars(p1->first); - justification new_jst = mk_composite1(p1->second, p2.second); - m_subst.d_assign(m_name, p2.first, new_jst); - save_jst(new_jst); - return p2.first; - } else { - auto p2 = m_subst.instantiate_metavars(p1->first); - save_jst(mk_composite1(p1->second, p2.second)); - return p2.first; - } + auto p2 = m_subst.instantiate_metavars(p1->first); + justification new_jst = mk_composite1(p1->second, p2.second); + m_subst.assign(m_name, p2.first, new_jst); + save_jst(new_jst); + return p2.first; } else { - if (m_update) { - auto p2 = m_subst.d_instantiate_metavars(p1->first); - m_subst.d_assign(m_name, p2.first, mk_composite1(p1->second, p2.second)); - return p2.first; - } else { - return m_subst.instantiate(p1->first); - } + auto p2 = m_subst.instantiate_metavars(p1->first); + m_subst.assign(m_name, p2.first, mk_composite1(p1->second, p2.second)); + return p2.first; } } else { return m; @@ -194,7 +159,7 @@ protected: } virtual expr visit(expr const & e) { - if ((m_only_expr && !has_expr_metavar(e)) || (!m_only_expr && !has_metavar(e))) { + if (!has_metavar(e)) { return e; } else { return replace_visitor::visit(e); @@ -202,57 +167,23 @@ protected: } public: - instantiate_metavars_fn(substitution & s, bool use_jst, bool updt, bool only_expr = false): - m_subst(s), m_use_jst(use_jst), m_update(updt), m_only_expr(only_expr) {} + instantiate_metavars_fn(substitution & s, bool use_jst): + m_subst(s), m_use_jst(use_jst) {} justification const & get_justification() const { return m_jst; } }; -std::pair substitution::instantiate_metavars(expr const & e) const { - substitution s(*this); - instantiate_metavars_fn fn(s, true, false); - expr r = fn(e); - return mk_pair(r, fn.get_justification()); -} - -std::pair substitution::d_instantiate_metavars(expr const & e, bool only_expr) { - if ((only_expr && !has_expr_metavar(e)) || (!only_expr && !has_metavar(e))) { +std::pair substitution::instantiate_metavars(expr const & e) { + if (!has_metavar(e)) { return mk_pair(e, justification()); } else { - instantiate_metavars_fn fn(*this, true, true, only_expr); + instantiate_metavars_fn fn(*this, true); expr r = fn(e); return mk_pair(r, fn.get_justification()); } } -std::tuple substitution::updt_instantiate_metavars(expr const & e) const { - substitution s(*this); - instantiate_metavars_fn fn(s, true, true); - expr r = fn(e); - return std::make_tuple(r, fn.get_justification(), s); -} - -std::pair substitution::instantiate_metavars(level const & l) const { - substitution s(*this); - return s.d_instantiate_metavars(l, true, false); -} - -expr substitution::instantiate_metavars_wo_jst(expr const & e) const { - substitution s(*this); - return instantiate_metavars_fn(s, false, false)(e); -} - -expr substitution::d_instantiate_metavars_wo_jst(expr const & e) { - return instantiate_metavars_fn(*this, false, true)(e); -} - -std::pair substitution::updt_instantiate_metavars_wo_jst(expr const & e) const { - substitution s(*this); - return mk_pair(instantiate_metavars_fn(s, false, true)(e), s); -} - -level substitution::instantiate_metavars_wo_jst(level const & l) const { - substitution s(*this); - return s.d_instantiate_metavars(l, false, false).first; +expr substitution::instantiate_metavars_wo_jst(expr const & e) { + return instantiate_metavars_fn(*this, false)(e); } bool substitution::occurs_expr(name const & m, expr const & e) const { diff --git a/src/kernel/metavar.h b/src/kernel/metavar.h index 92e7fab66..c1663cf06 100644 --- a/src/kernel/metavar.h +++ b/src/kernel/metavar.h @@ -24,7 +24,8 @@ class substitution { jst_map m_level_jsts; friend class instantiate_metavars_fn; - std::pair d_instantiate_metavars(level const & l, bool use_jst, bool updt); + std::pair instantiate_metavars(level const & l, bool use_jst); + expr instantiate_metavars_wo_jst(expr const & e); public: substitution(); @@ -42,27 +43,19 @@ public: justification get_expr_jst(name const & m) const { if (auto it = m_expr_jsts.find(m)) return *it; else return justification(); } justification get_level_jst(name const & m) const { if (auto it = m_level_jsts.find(m)) return *it; else return justification(); } - substitution assign(name const & m, expr const & t, justification const & j) const; - substitution assign(name const & m, expr const & t) const; + void assign(name const & m, expr const & t, justification const & j); + void assign(name const & m, expr const & t) { assign(m, t, justification()); } + void assign(expr const & m, expr const & t, justification const & j) { assign(mlocal_name(m), t, j); } + void assign(expr const & m, expr const & t) { assign(m, t, justification()); } + void assign(name const & m, level const & t, justification const & j); + void assign(name const & m, level const & t) { assign(m, t, justification ()); } + void assign(level const & m, level const & t, justification const & j) { assign(meta_id(m), t, j); } + void assign(level const & m, level const & t) { assign(m, t, justification ()); } - substitution assign(name const & m, level const & t, justification const & j) const; - substitution assign(name const & m, level const & t) const; + std::pair instantiate_metavars(expr const & e); + std::pair instantiate_metavars(level const & l) { return instantiate_metavars(l, true); } - void d_assign(name const & m, expr const & t, justification const & j); - void d_assign(name const & m, expr const & t) { d_assign(m, t, justification()); } - void d_assign(expr const & m, expr const & t, justification const & j) { d_assign(mlocal_name(m), t, j); } - void d_assign(expr const & m, expr const & t) { d_assign(m, t, justification()); } - void d_assign(name const & m, level const & t, justification const & j); - void d_assign(name const & m, level const & t) { d_assign(m, t, justification ()); } - void d_assign(level const & m, level const & t, justification const & j) { d_assign(meta_id(m), t, j); } - void d_assign(level const & m, level const & t) { d_assign(m, t, justification ()); } - - std::pair d_instantiate_metavars(expr const & e, bool only_expr = false); - expr d_instantiate_metavars_wo_jst(expr const & e); - std::pair d_instantiate_metavars(level const & l) { return d_instantiate_metavars(l, true, true); } - - void d_forget_justifications() { m_expr_jsts = jst_map(); m_level_jsts = jst_map(); } - substitution forget_justifications() const { substitution s(*this); s.d_forget_justifications(); return s; } + void forget_justifications() { m_expr_jsts = jst_map(); m_level_jsts = jst_map(); } template void for_each_expr(F && fn) const { @@ -77,45 +70,12 @@ public: bool is_assigned(expr const & m) const { lean_assert(is_metavar(m)); return is_expr_assigned(mlocal_name(m)); } opt_expr_jst get_assignment(expr const & m) const { lean_assert(is_metavar(m)); return get_expr_assignment(mlocal_name(m)); } optional get_expr(expr const & m) const { lean_assert(is_metavar(m)); return get_expr(mlocal_name(m)); } - substitution assign(expr const & m, expr const & t, justification const & j) { - lean_assert(is_metavar(m)); return assign(mlocal_name(m), t, j); } - substitution assign(expr const & m, expr const & t) const { lean_assert(is_metavar(m)); return assign(mlocal_name(m), t); } bool is_assigned(level const & m) const { lean_assert(is_meta(m)); return is_level_assigned(meta_id(m)); } opt_level_jst get_assignment(level const & m) const { lean_assert(is_meta(m)); return get_level_assignment(meta_id(m)); } optional get_level(level const & m) const { lean_assert(is_meta(m)); return get_level(meta_id(m)); } - substitution assign(level const & m, level const & l, justification const & j) const { - lean_assert(is_meta(m)); return assign(meta_id(m), l, j); } - substitution assign(level const & m, level const & l) { lean_assert(is_meta(m)); return assign(meta_id(m), l); } - /** - \brief Instantiate metavariables in \c e assigned in this substitution. - */ - std::pair instantiate_metavars(expr const & e) const; - - /** - \brief Similar to the previous function, but it compress the substitution. - By compress, we mean, for any metavariable \c m reachable from \c e, - if s[m] = t, and t has asssigned metavariables, then s[m] <- instantiate_metavars(t, s). - The updated substitution is returned. - */ - std::tuple updt_instantiate_metavars(expr const & e) const; - - /** \brief Instantiate level metavariables in \c l. */ - std::pair instantiate_metavars(level const & l) const; - - /** - \brief Instantiate metavariables in \c e assigned in the substitution \c s, - but does not return a justification object for the new expression. - */ - expr instantiate_metavars_wo_jst(expr const & e) const; - expr instantiate(expr const & e) const { return instantiate_metavars_wo_jst(e); } - expr d_instantiate(expr const & e) { return d_instantiate_metavars_wo_jst(e); } - - std::pair updt_instantiate_metavars_wo_jst(expr const & e) const; - - /** \brief Instantiate level metavariables in \c l, but does not return justification object. */ - level instantiate_metavars_wo_jst(level const & l) const; + expr instantiate(expr const & e) { return instantiate_metavars_wo_jst(e); } /** \brief Return true iff the metavariable \c m occurrs (directly or indirectly) in \c e. */ bool occurs_expr(name const & m, expr const & e) const; diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 4011d2ee0..344fd6020 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -134,7 +134,7 @@ expr type_checker::ensure_sort_core(expr e, expr const & s) { expr r = mk_sort(mk_meta_univ(m_gen.next())); justification j = mk_justification(s, [=](formatter const & fmt, substitution const & subst) { - return pp_type_expected(fmt, subst.instantiate(s)); + return pp_type_expected(fmt, substitution(subst).instantiate(s)); }); add_cnstr(mk_eq_cnstr(e, r, j)); return r; @@ -153,7 +153,7 @@ expr type_checker::ensure_pi_core(expr e, expr const & s) { } else if (is_meta(e)) { expr r = mk_pi_for(m_gen, e); justification j = mk_justification(s, [=](formatter const & fmt, substitution const & subst) { - return pp_function_expected(fmt, subst.instantiate(s)); + return pp_function_expected(fmt, substitution(subst).instantiate(s)); }); add_cnstr(mk_eq_cnstr(e, r, j)); return r; @@ -182,7 +182,8 @@ app_delayed_justification::app_delayed_justification(expr const & e, expr const justification mk_app_justification(expr const & e, expr const & d_type, expr const & a_type) { auto pp_fn = [=](formatter const & fmt, substitution const & subst) { - return pp_app_type_mismatch(fmt, subst.instantiate(e), subst.instantiate(d_type), subst.instantiate(a_type)); + substitution s(subst); + return pp_app_type_mismatch(fmt, s.instantiate(e), s.instantiate(d_type), s.instantiate(a_type)); }; return mk_justification(e, pp_fn); } diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 337cee4c0..f93f5ba6d 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -1402,7 +1402,7 @@ static int mk_justification(lua_State * L) { environment env = to_environment(L, 2); expr e = to_expr(L, 3); justification j = mk_justification(some_expr(e), [=](formatter const & fmt, substitution const & subst) { - expr new_e = subst.instantiate(e); + expr new_e = substitution(subst).instantiate(e); format r; r += format(s.c_str()); r += pp_indent_expr(fmt, new_e); @@ -1605,28 +1605,29 @@ static int subst_assign(lua_State * L) { if (nargs == 3) { if (is_expr(L, 3)) { if (is_expr(L, 2)) - return push_substitution(L, to_substitution(L, 1).assign(to_expr(L, 2), to_expr(L, 3))); + to_substitution(L, 1).assign(to_expr(L, 2), to_expr(L, 3)); else - return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_expr(L, 3))); + to_substitution(L, 1).assign(to_name_ext(L, 2), to_expr(L, 3)); } else { if (is_level(L, 2)) - return push_substitution(L, to_substitution(L, 1).assign(to_level(L, 2), to_level(L, 3))); + to_substitution(L, 1).assign(to_level(L, 2), to_level(L, 3)); else - return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_level(L, 3))); + to_substitution(L, 1).assign(to_name_ext(L, 2), to_level(L, 3)); } } else { if (is_expr(L, 3)) { if (is_expr(L, 2)) - return push_substitution(L, to_substitution(L, 1).assign(to_expr(L, 2), to_expr(L, 3), to_justification(L, 4))); + to_substitution(L, 1).assign(to_expr(L, 2), to_expr(L, 3), to_justification(L, 4)); else - return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_expr(L, 3), to_justification(L, 4))); + to_substitution(L, 1).assign(to_name_ext(L, 2), to_expr(L, 3), to_justification(L, 4)); } else { if (is_level(L, 2)) - return push_substitution(L, to_substitution(L, 1).assign(to_level(L, 2), to_level(L, 3), to_justification(L, 4))); + to_substitution(L, 1).assign(to_level(L, 2), to_level(L, 3), to_justification(L, 4)); else - return push_substitution(L, to_substitution(L, 1).assign(to_name_ext(L, 2), to_level(L, 3), to_justification(L, 4))); + to_substitution(L, 1).assign(to_name_ext(L, 2), to_level(L, 3), to_justification(L, 4)); } } + return 0; } static int subst_is_assigned(lua_State * L) { if (is_expr(L, 2)) @@ -1711,8 +1712,13 @@ static int subst_for_each_level(lua_State * L) { return 0; } +static int subst_copy(lua_State * L) { + return push_substitution(L, substitution(to_substitution(L, 1))); +} + static const struct luaL_Reg substitution_m[] = { {"__gc", substitution_gc}, + {"copy", safe_function}, {"get_expr", safe_function}, {"get_level", safe_function}, {"assign", safe_function}, diff --git a/src/library/tactic/apply_tactic.cpp b/src/library/tactic/apply_tactic.cpp index abe10becd..826a7a4a4 100644 --- a/src/library/tactic/apply_tactic.cpp +++ b/src/library/tactic/apply_tactic.cpp @@ -123,10 +123,11 @@ proof_state_seq apply_tactic_core(environment const & env, io_state const & ios, return map2(substs, [=](substitution const & subst) -> proof_state { name_generator new_ngen(ngen); type_checker tc(env, new_ngen.mk_child()); - expr new_e = subst.instantiate(e); + substitution new_subst = subst; + expr new_e = new_subst.instantiate(e); expr new_p = g.abstract(new_e); check_has_no_local(new_p, _e, "apply"); - substitution new_subst = subst.assign(g.get_name(), new_p); + new_subst.assign(g.get_name(), new_p); goals new_gs = tail_gs; if (add_subgoals) { buffer metas; @@ -138,7 +139,7 @@ proof_state_seq apply_tactic_core(environment const & env, io_state const & ios, unsigned i = metas.size(); while (i > 0) { --i; - new_gs = cons(goal(metas[i], subst.instantiate(tc.infer(metas[i]))), new_gs); + new_gs = cons(goal(metas[i], new_subst.instantiate(tc.infer(metas[i]))), new_gs); } } return proof_state(new_gs, new_subst, new_ngen); @@ -184,9 +185,10 @@ expr refresh_univ_metavars(expr const & e, name_generator & ngen) { tactic apply_tactic(expr const & e, bool refresh_univ_mvars) { return tactic([=](environment const & env, io_state const & ios, proof_state const & s) { if (refresh_univ_mvars) { - name_generator ngen = s.get_ngen(); - expr new_e = refresh_univ_metavars(s.get_subst().instantiate(e), ngen); - proof_state new_s(s, ngen); + name_generator ngen = s.get_ngen(); + substitution new_subst = s.get_subst(); + expr new_e = refresh_univ_metavars(new_subst.instantiate(e), ngen); + proof_state new_s(s.get_goals(), new_subst, ngen); return apply_tactic_core(env, ios, new_s, new_e, true, true); } else { return apply_tactic_core(env, ios, s, e, true, true); diff --git a/src/library/tactic/tactic.cpp b/src/library/tactic/tactic.cpp index fb70d7a60..192fa3ec8 100644 --- a/src/library/tactic/tactic.cpp +++ b/src/library/tactic/tactic.cpp @@ -227,7 +227,7 @@ tactic assumption_tactic() { } } if (h) { - subst = subst.assign(g.get_mvar(), g.abstract(*h), justification()); + subst.assign(g.get_mvar(), g.abstract(*h), justification()); solved = true; return optional(); } else { @@ -253,8 +253,8 @@ tactic exact_tactic(expr const & _e) { if (tc.is_def_eq(e_t, t) && !tc.next_cnstr()) { expr new_p = g.abstract(e); check_has_no_local(new_p, _e, "exact"); - substitution new_subst = subst.assign(g.get_name(), new_p); - return some(proof_state(s, tail(gs), new_subst)); + subst.assign(g.get_name(), new_p); + return some(proof_state(s, tail(gs), subst)); } else { return none_proof_state(); } diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp index 6601f6252..881a9badf 100644 --- a/src/library/unifier.cpp +++ b/src/library/unifier.cpp @@ -110,36 +110,36 @@ expr lambda_abstract_locals(expr const & e, buffer const & locals) { return v; } -static std::pair unify_simple_core(substitution const & s, expr const & lhs, expr const & rhs, - justification const & j) { +unify_status unify_simple_core(substitution & s, expr const & lhs, expr const & rhs, justification const & j) { lean_assert(is_meta(lhs)); buffer args; auto m = is_simple_meta(lhs, args); if (!m || is_meta(rhs)) { - return mk_pair(unify_status::Unsupported, s); + return unify_status::Unsupported; } else { switch (occurs_context_check(s, rhs, *m, args)) { - case l_false: return mk_pair(unify_status::Failed, s); - case l_undef: mk_pair(unify_status::Unsupported, s); + case l_false: return unify_status::Failed; + case l_undef: return unify_status::Unsupported; case l_true: { expr v = lambda_abstract_locals(rhs, args); - return mk_pair(unify_status::Solved, s.assign(mlocal_name(*m), v, j)); + s.assign(mlocal_name(*m), v, j); + return unify_status::Solved; }} } lean_unreachable(); // LCOV_EXCL_LINE } -std::pair unify_simple(substitution const & s, expr const & lhs, expr const & rhs, justification const & j) { +unify_status unify_simple(substitution & s, expr const & lhs, expr const & rhs, justification const & j) { if (lhs == rhs) - return mk_pair(unify_status::Solved, s); + return unify_status::Solved; else if (!has_metavar(lhs) && !has_metavar(rhs)) - return mk_pair(unify_status::Failed, s); + return unify_status::Failed; else if (is_meta(lhs)) return unify_simple_core(s, lhs, rhs, j); else if (is_meta(rhs)) return unify_simple_core(s, rhs, lhs, j); else - return mk_pair(unify_status::Unsupported, s); + return unify_status::Unsupported; } // Return true if m occurs in e @@ -158,23 +158,24 @@ bool occurs_meta(level const & m, level const & e) { return contains; } -std::pair unify_simple_core(substitution const & s, level const & lhs, level const & rhs, justification const & j) { +unify_status unify_simple_core(substitution & s, level const & lhs, level const & rhs, justification const & j) { lean_assert(is_meta(lhs)); bool contains = occurs_meta(lhs, rhs); if (contains) { if (is_succ(rhs)) - return mk_pair(unify_status::Failed, s); + return unify_status::Failed; else - return mk_pair(unify_status::Unsupported, s); + return unify_status::Unsupported; } - return mk_pair(unify_status::Solved, s.assign(meta_id(lhs), rhs, j)); + s.assign(meta_id(lhs), rhs, j); + return unify_status::Solved; } -std::pair unify_simple(substitution const & s, level const & lhs, level const & rhs, justification const & j) { +unify_status unify_simple(substitution & s, level const & lhs, level const & rhs, justification const & j) { if (lhs == rhs) - return mk_pair(unify_status::Solved, s); + return unify_status::Solved; else if (!has_meta(lhs) && !has_meta(rhs)) - return mk_pair(unify_status::Failed, s); + return unify_status::Failed; else if (is_meta(lhs)) return unify_simple_core(s, lhs, rhs, j); else if (is_meta(rhs)) @@ -182,16 +183,16 @@ std::pair unify_simple(substitution const & s, level else if (is_succ(lhs) && is_succ(rhs)) return unify_simple(s, succ_of(lhs), succ_of(rhs), j); else - return mk_pair(unify_status::Unsupported, s); + return unify_status::Unsupported; } -std::pair unify_simple(substitution const & s, constraint const & c) { +unify_status unify_simple(substitution & s, constraint const & c) { if (is_eq_cnstr(c)) return unify_simple(s, cnstr_lhs_expr(c), cnstr_rhs_expr(c), c.get_justification()); else if (is_level_eq_cnstr(c)) return unify_simple(s, cnstr_lhs_level(c), cnstr_rhs_level(c), c.get_justification()); else - return mk_pair(unify_status::Unsupported, s); + return unify_status::Unsupported; } static constraint g_dont_care_cnstr = mk_eq_cnstr(expr(), expr(), justification()); @@ -421,7 +422,7 @@ struct unifier_fn { */ bool assign(expr const & m, expr const & v, justification const & j) { lean_assert(is_metavar(m)); - m_subst.d_assign(m, v, j); + m_subst.assign(m, v, j); expr m_type = mlocal_type(m); expr v_type; try { @@ -433,7 +434,8 @@ struct unifier_fn { if (in_conflict()) return false; justification j1 = mk_justification(m, [=](formatter const & fmt, substitution const & subst) { - return pp_type_mismatch(fmt, subst.instantiate(m_type), subst.instantiate(v_type)); + substitution s(subst); + return pp_type_mismatch(fmt, s.instantiate(m_type), s.instantiate(v_type)); }); if (!is_def_eq(m_type, v_type, mk_composite1(j1, j))) return false; @@ -456,7 +458,7 @@ struct unifier_fn { */ bool assign(level const & m, level const & v, justification const & j) { lean_assert(is_meta(m)); - m_subst.d_assign(m, v, j); + m_subst.assign(m, v, j); return true; } @@ -507,8 +509,8 @@ struct unifier_fn { std::pair instantiate_metavars(constraint const & c) { if (is_eq_cnstr(c)) { - auto lhs_jst = m_subst.d_instantiate_metavars(cnstr_lhs_expr(c)); - auto rhs_jst = m_subst.d_instantiate_metavars(cnstr_rhs_expr(c)); + auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_expr(c)); + auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_expr(c)); expr lhs = lhs_jst.first; expr rhs = rhs_jst.first; if (lhs != cnstr_lhs_expr(c) || rhs != cnstr_rhs_expr(c)) { @@ -517,8 +519,8 @@ struct unifier_fn { true); } } else if (is_level_eq_cnstr(c)) { - auto lhs_jst = m_subst.d_instantiate_metavars(cnstr_lhs_level(c)); - auto rhs_jst = m_subst.d_instantiate_metavars(cnstr_rhs_level(c)); + auto lhs_jst = m_subst.instantiate_metavars(cnstr_lhs_level(c)); + auto rhs_jst = m_subst.instantiate_metavars(cnstr_rhs_level(c)); level lhs = lhs_jst.first; level rhs = rhs_jst.first; if (lhs != cnstr_lhs_level(c) || rhs != cnstr_rhs_level(c)) { @@ -839,7 +841,7 @@ struct unifier_fn { set_conflict(c.get_justification()); return false; } - auto m_type_jst = m_subst.d_instantiate_metavars(m_type); + auto m_type_jst = m_subst.instantiate_metavars(m_type); lazy_list alts = fn(m, m_type_jst.first, m_subst, m_ngen.mk_child()); return process_lazy_constraints(alts, mk_composite1(c.get_justification(), m_type_jst.second)); } @@ -1377,7 +1379,9 @@ struct unifier_fn { } lean_assert(!in_conflict()); lean_assert(m_cnstrs.empty()); - return optional(m_subst.forget_justifications()); + substitution s = m_subst; + s.forget_justifications(); + return optional(s); } }; @@ -1409,8 +1413,8 @@ lazy_list unify(environment const & env, unsigned num_cs, constra lazy_list unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, substitution const & s, unsigned max_steps) { substitution new_s = s; - expr _lhs = new_s.d_instantiate(lhs); - expr _rhs = new_s.d_instantiate(rhs); + expr _lhs = new_s.instantiate(lhs); + expr _rhs = new_s.instantiate(rhs); auto u = std::make_shared(env, 0, nullptr, ngen, new_s, false, max_steps); if (!u->m_tc->is_def_eq(_lhs, _rhs)) return lazy_list(); @@ -1425,7 +1429,7 @@ lazy_list unify(environment const & env, expr const & lhs, expr co static int unify_simple(lua_State * L) { int nargs = lua_gettop(L); - std::pair r; + unify_status r; if (nargs == 2) r = unify_simple(to_substitution(L, 1), to_constraint(L, 2)); else if (nargs == 3 && is_expr(L, 2)) @@ -1436,9 +1440,7 @@ static int unify_simple(lua_State * L) { r = unify_simple(to_substitution(L, 1), to_expr(L, 2), to_expr(L, 3), to_justification(L, 4)); else r = unify_simple(to_substitution(L, 1), to_level(L, 2), to_level(L, 3), to_justification(L, 4)); - push_integer(L, static_cast(r.first)); - push_substitution(L, r.second); - return 2; + return push_integer(L, static_cast(r)); } typedef lazy_list substitution_seq; diff --git a/src/library/unifier.h b/src/library/unifier.h index 935e0b4ea..26e09f3e9 100644 --- a/src/library/unifier.h +++ b/src/library/unifier.h @@ -32,9 +32,9 @@ enum class unify_status { Solved, Failed, Unsupported }; This function assumes that all assigned metavariables have been substituted. */ -std::pair unify_simple(substitution const & s, expr const & lhs, expr const & rhs, justification const & j); -std::pair unify_simple(substitution const & s, level const & lhs, level const & rhs, justification const & j); -std::pair unify_simple(substitution const & s, constraint const & c); +unify_status unify_simple(substitution & s, expr const & lhs, expr const & rhs, justification const & j); +unify_status unify_simple(substitution & s, level const & lhs, level const & rhs, justification const & j); +unify_status unify_simple(substitution & s, constraint const & c); lazy_list unify(environment const & env, unsigned num_cs, constraint const * cs, name_generator const & ngen, bool use_exception = true, unsigned max_steps = LEAN_DEFAULT_UNIFIER_MAX_STEPS); diff --git a/src/tests/kernel/metavar.cpp b/src/tests/kernel/metavar.cpp index c39c6998e..9d92b409e 100644 --- a/src/tests/kernel/metavar.cpp +++ b/src/tests/kernel/metavar.cpp @@ -75,7 +75,7 @@ static void tst1() { lean_assert(m1 != m2); expr f = Const("f"); expr a = Const("a"); - subst = subst.assign(m1, f(a)); + subst.assign(m1, f(a)); lean_assert(subst.is_assigned(m1)); lean_assert(!subst.is_assigned(m2)); lean_assert(*subst.get_expr(m1) == f(a)); @@ -91,8 +91,8 @@ static void tst2() { expr f = Const("f"); expr g = Const("g"); expr a = Const("a"); - s = s.assign(m1, f(m2), mk_assumption_justification(1)); - s = s.assign(m2, g(a), mk_assumption_justification(2)); + s.assign(m1, f(m2), mk_assumption_justification(1)); + s.assign(m2, g(a), mk_assumption_justification(2)); lean_assert(check_assumptions(s.get_assignment(m1)->second, {1})); lean_assert(s.occurs(m1, f(m1))); lean_assert(s.occurs(m2, f(m1))); @@ -102,22 +102,8 @@ static void tst2() { std::cout << s << "\n"; auto p1 = s.instantiate_metavars(g(m1)); check_assumptions(p1.second, {1, 2}); - lean_assert(check_assumptions(s.get_assignment(m1)->second, {1})); - lean_assert(p1.first == g(f(g(a)))); - auto ts = s.updt_instantiate_metavars(g(m1, m3)); - s = std::get<2>(ts); - check_assumptions(std::get<1>(ts), {1, 2}); - std::cout << std::get<0>(ts) << "\n"; - std::cout << s << "\n"; lean_assert(check_assumptions(s.get_assignment(m1)->second, {1, 2})); - lean_assert(std::get<0>(ts) == g(f(g(a)), m3)); - s = s.assign(m3, f(m1, m2), mk_assumption_justification(3)); - auto p3 = s.instantiate_metavars(g(m1, m3)); - lean_assert(check_assumptions(p3.second, {1, 2, 3})); - std::cout << p3.first << "\n"; - std::cout << p3.second << "\n"; - std::cout << s << "\n"; - lean_assert_eq(p3.first, g(f(g(a)), f(f(g(a)), g(a)))); + lean_assert(p1.first == g(f(g(a)))); } static void tst3() { @@ -129,7 +115,7 @@ static void tst3() { expr b = Const("b"); expr x = Local("x", Prop); expr y = Local("y", Prop); - s = s.assign(m1, Fun({x, y}, f(y, x))); + s.assign(m1, Fun({x, y}, f(y, x))); lean_assert_eq(s.instantiate_metavars(m1(a, b, g(a))).first, f(b, a, g(a))); lean_assert_eq(s.instantiate_metavars(m1(a)).first, Fun(y, f(y, a))); lean_assert_eq(s.instantiate_metavars(m1(a, b)).first, f(b, a)); @@ -150,10 +136,10 @@ static void tst4() { expr T2 = mk_sort(u); expr t = f(T1, T2, m1, m2); lean_assert(s.instantiate_metavars(t).first == t); - s = s.assign(m1, a, justification()); - s = s.assign(m2, m3, justification()); + s.assign(m1, a, justification()); + s.assign(m2, m3, justification()); lean_assert(s.instantiate_metavars(t).first == f(T1, T2, a, m3)); - s = s.assign(l1, level(), justification()); + s.assign(l1, level(), justification()); lean_assert(s.instantiate_metavars(t).first == f(Prop, T2, a, m3)); } diff --git a/tests/lua/jst1.lua b/tests/lua/jst1.lua index 34de22934..271ee22e3 100644 --- a/tests/lua/jst1.lua +++ b/tests/lua/jst1.lua @@ -12,7 +12,7 @@ local m = mk_metavar("m", Prop) local j2 = justification("expresion must be a Proposition", env, g(m)) print(j2:pp()) local s = substitution() -s = s:assign(m, f(a)) +s:assign(m, f(a)) print(j2:pp(s)) local j3 = assumption_justification(1) assert(not j2:depends_on(1)) diff --git a/tests/lua/subst1.lua b/tests/lua/subst1.lua index 1826c6a26..b450ad372 100644 --- a/tests/lua/subst1.lua +++ b/tests/lua/subst1.lua @@ -7,22 +7,22 @@ local f = Const("f") local g = Const("g") local a = Const("a") local t = f(f(a)) -s = s:assign(m, t) +s:assign(m, t) assert(s:is_assigned(m)) assert(s:is_expr_assigned("m")) assert(not s:is_level_assigned("m")) assert(s:instantiate(g(m)) == g(t)) -s = s:assign("m", a) +s:assign("m", a) assert(s:instantiate(g(m)) == g(a)) local l = mk_level_one() local u = mk_meta_univ("u") -s = s:assign(u, l) +s:assign(u, l) assert(s:is_assigned(u)) assert(s:is_level_assigned("u")) assert(not s:is_expr_assigned("u")) assert(s:get_expr("m") == a) local m2 = mk_metavar("m2", Prop) -s = s:assign(m2, f(m)) +s:assign(m2, f(m)) print(s:get_expr("m2")) assert(s:occurs(m, f(m2))) assert(s:occurs_expr("m", f(m2))) diff --git a/tests/lua/unify1.lua b/tests/lua/unify1.lua index cad9e61fd..f202315ee 100644 --- a/tests/lua/unify1.lua +++ b/tests/lua/unify1.lua @@ -1,6 +1,7 @@ function test_unify_simple(lhs, rhs, expected) print(tostring(lhs) .. " =?= " .. tostring(rhs) .. ", expected: " .. tostring(expected)) - r, s = unify_simple(substitution(), lhs, rhs, justification()) + s = substitution() + r = unify_simple(s, lhs, rhs, justification()) if r == unify_status.Solved then s:for_each_expr(function(n, v, j) print(" " .. tostring(n) .. " := " .. tostring(v))