diff --git a/src/kernel/metavar.cpp b/src/kernel/metavar.cpp index d3b1ad904..99bca100d 100644 --- a/src/kernel/metavar.cpp +++ b/src/kernel/metavar.cpp @@ -14,6 +14,11 @@ Author: Leonardo de Moura #include "kernel/level.h" namespace lean { +substitution::substitution(expr_map const & em, level_map const & lm): + m_expr_subst(em), m_level_subst(lm) {} + +substitution::substitution() {} + bool substitution::is_expr_assigned(name const & m) const { return m_expr_subst.contains(m); } @@ -54,23 +59,40 @@ optional substitution::get_level(name const & m) const { return none_level(); } -void substitution::assign(name const & m, expr const & t, justification const & j) { +void substitution::d_assign(name const & m, expr const & t, justification const & j) { lean_assert(closed(t)); m_expr_subst.insert(m, mk_pair(t, j)); } -void substitution::assign(name const & m, expr const & t) { +void substitution::d_assign(name const & m, expr const & t) { assign(m, t, justification()); } -void substitution::assign(name const & m, level const & l, justification const & j) { +void substitution::d_assign(name const & m, level const & l, justification const & j) { m_level_subst.insert(m, mk_pair(l, j)); } -void substitution::assign(name const & m, level const & l) { +void substitution::d_assign(name const & m, level const & l) { assign(m, l, justification()); } +substitution substitution::assign(name const & m, expr const & t, justification const & j) const { + lean_assert(closed(t)); + return substitution(insert(m_expr_subst, m, mk_pair(t, j)), m_level_subst); +} + +substitution substitution::assign(name const & m, expr const & t) const { + return assign(m, t, justification()); +} + +substitution substitution::assign(name const & m, level const & l, justification const & j) const { + return substitution(m_expr_subst, insert(m_level_subst, m, mk_pair(l, j))); +} + +substitution substitution::assign(name const & m, level const & l) const { + return assign(m, l, justification()); +} + void substitution::for_each(std::function const & fn) const { m_expr_subst.for_each([=](name const & n, std::pair const & a) { fn(n, a.first, a.second); @@ -83,7 +105,7 @@ void substitution::for_each(std::function instantiate_metavars(level const & l, substitution & s, bool use_jst, bool updt) { +std::pair substitution::d_instantiate_metavars(level const & l, bool use_jst, bool updt) { if (!has_param(l)) return mk_pair(l, justification()); justification j; @@ -92,16 +114,16 @@ std::pair instantiate_metavars(level const & l, substituti if (!has_meta(l)) { return some_level(l); } else if (is_meta(l)) { - auto p1 = s.get_assignment(l); + auto p1 = get_assignment(l); if (p1) { - auto p2 = instantiate_metavars(p1->first, s, use_jst, updt); + auto p2 = d_instantiate_metavars(p1->first, use_jst, updt); if (use_jst) { justification new_jst = mk_composite1(p1->second, p2.second); if (updt) - s.assign(l, p2.first, new_jst); + d_assign(meta_id(l), p2.first, new_jst); save_jst(new_jst); } else if (updt) { - s.assign(l, p2.first); + d_assign(meta_id(l), p2.first); } return some_level(p2.first); } @@ -121,7 +143,7 @@ protected: void save_jst(justification const & j) { m_jst = mk_composite1(m_jst, j); } level visit_level(level const & l) { - auto p1 = instantiate_metavars(l, m_subst, m_use_jst, m_update); + auto p1 = m_subst.d_instantiate_metavars(l, m_use_jst, m_update); if (m_use_jst) save_jst(p1.second); return p1.first; @@ -151,7 +173,7 @@ protected: if (m_update) { auto p2 = m_subst.d_instantiate_metavars(p1->first); justification new_jst = mk_composite1(p1->second, p2.second); - m_subst.assign(m_name, p2.first, new_jst); + m_subst.d_assign(m_name, p2.first, new_jst); save_jst(new_jst); return p2.first; } else { @@ -162,7 +184,7 @@ protected: } else { if (m_update) { expr r = m_subst.d_instantiate_metavars_wo_jst(p1->first); - m_subst.assign(m_name, r); + m_subst.d_assign(m_name, r); return r; } else { return m_subst.instantiate_metavars_wo_jst(p1->first); @@ -205,7 +227,8 @@ public: }; std::pair substitution::instantiate_metavars(expr const & e) const { - instantiate_metavars_fn fn(const_cast(*this), true, false); + substitution s(*this); + instantiate_metavars_fn fn(s, true, false); expr r = fn(e); return mk_pair(r, fn.get_justification()); } @@ -216,20 +239,35 @@ std::pair substitution::d_instantiate_metavars(expr const & return mk_pair(r, fn.get_justification()); } +std::tuple substitution::updt_instantiate_metavars(expr const & e) const { + substitution s(*this); + instantiate_metavars_fn fn(s, true, true); + expr r = fn(e); + return std::make_tuple(r, fn.get_justification(), s); +} + std::pair substitution::instantiate_metavars(level const & l) const { - return lean::instantiate_metavars(l, const_cast(*this), true, false); + substitution s(*this); + return s.d_instantiate_metavars(l, true, false); } expr substitution::instantiate_metavars_wo_jst(expr const & e) const { - return instantiate_metavars_fn(const_cast(*this), false, false)(e); + substitution s(*this); + return instantiate_metavars_fn(s, false, false)(e); } expr substitution::d_instantiate_metavars_wo_jst(expr const & e) { return instantiate_metavars_fn(*this, false, true)(e); } +std::pair substitution::updt_instantiate_metavars_wo_jst(expr const & e) const { + substitution s(*this); + return mk_pair(instantiate_metavars_fn(s, false, true)(e), s); +} + level substitution::instantiate_metavars_wo_jst(level const & l) const { - return lean::instantiate_metavars(l, const_cast(*this), false, false).first; + substitution s(*this); + return s.d_instantiate_metavars(l, false, false).first; } bool substitution::occurs_expr(name const & m, expr const & e) const { diff --git a/src/kernel/metavar.h b/src/kernel/metavar.h index f00487fc9..dfabb356e 100644 --- a/src/kernel/metavar.h +++ b/src/kernel/metavar.h @@ -13,9 +13,22 @@ Author: Leonardo de Moura namespace lean { class substitution { - rb_map, name_quick_cmp> m_expr_subst; - rb_map, name_quick_cmp> m_level_subst; + typedef rb_map, name_quick_cmp> expr_map; + typedef rb_map, name_quick_cmp> level_map; + expr_map m_expr_subst; + level_map m_level_subst; + + substitution(expr_map const & em, level_map const & lm); + void d_assign(name const & m, expr const & t, justification const & j); + void d_assign(name const & m, expr const & t); + void d_assign(name const & m, level const & t, justification const & j); + void d_assign(name const & m, level const & t); + std::pair d_instantiate_metavars(expr const & e); + expr d_instantiate_metavars_wo_jst(expr const & e); + std::pair d_instantiate_metavars(level const & l, bool use_jst, bool updt); + friend class instantiate_metavars_fn; public: + substitution(); typedef optional> opt_expr_jst; typedef optional> opt_level_jst; @@ -28,11 +41,11 @@ public: optional get_expr(name const & m) const; optional get_level(name const & m) const; - void assign(name const & m, expr const & t, justification const & j); - void assign(name const & m, expr const & t); + substitution assign(name const & m, expr const & t, justification const & j) const; + substitution assign(name const & m, expr const & t) const; - void assign(name const & m, level const & t, justification const & j); - void assign(name const & m, level const & t); + substitution assign(name const & m, level const & t, justification const & j) const; + substitution assign(name const & m, level const & t) const; void for_each(std::function const & fn) const; void for_each(std::function const & fn) const; @@ -40,14 +53,14 @@ public: bool is_assigned(expr const & m) const { lean_assert(is_metavar(m)); return is_expr_assigned(mlocal_name(m)); } opt_expr_jst get_assignment(expr const & m) const { lean_assert(is_metavar(m)); return get_expr_assignment(mlocal_name(m)); } optional get_expr(expr const & m) const { lean_assert(is_metavar(m)); return get_expr(mlocal_name(m)); } - void assign(expr const & m, expr const & t, justification const & j) { lean_assert(is_metavar(m)); assign(mlocal_name(m), t, j); } - void assign(expr const & m, expr const & t) { lean_assert(is_metavar(m)); return assign(mlocal_name(m), t); } + substitution assign(expr const & m, expr const & t, justification const & j) { lean_assert(is_metavar(m)); return assign(mlocal_name(m), t, j); } + substitution assign(expr const & m, expr const & t) const { lean_assert(is_metavar(m)); return assign(mlocal_name(m), t); } bool is_assigned(level const & m) const { lean_assert(is_meta(m)); return is_level_assigned(meta_id(m)); } opt_level_jst get_assignment(level const & m) const { lean_assert(is_meta(m)); return get_level_assignment(meta_id(m)); } optional get_level(level const & m) const { lean_assert(is_meta(m)); return get_level(meta_id(m)); } - void assign(level const & m, level const & l, justification const & j) { lean_assert(is_meta(m)); assign(meta_id(m), l, j); } - void assign(level const & m, level const & l) { lean_assert(is_meta(m)); return assign(meta_id(m), l); } + substitution assign(level const & m, level const & l, justification const & j) const { lean_assert(is_meta(m)); return assign(meta_id(m), l, j); } + substitution assign(level const & m, level const & l) { lean_assert(is_meta(m)); return assign(meta_id(m), l); } /** \brief Instantiate metavariables in \c e assigned in this substitution. */ std::pair instantiate_metavars(expr const & e) const; @@ -56,8 +69,9 @@ public: \brief Similar to the previous function, but it compress the substitution. By compress, we mean, for any metavariable \c m reachable from \c e, if s[m] = t, and t has asssigned metavariables, then s[m] <- instantiate_metavars(t, s). + The updated substitution is returned. */ - std::pair d_instantiate_metavars(expr const & e); + std::tuple updt_instantiate_metavars(expr const & e) const; /** \brief Instantiate level metavariables in \c l. */ std::pair instantiate_metavars(level const & l) const; @@ -68,7 +82,7 @@ public: */ expr instantiate_metavars_wo_jst(expr const & e) const; - expr d_instantiate_metavars_wo_jst(expr const & e); + std::pair updt_instantiate_metavars_wo_jst(expr const & e) const; /** \brief Instantiate level metavariables in \c l, but does not return justification object. */ level instantiate_metavars_wo_jst(level const & l) const; diff --git a/src/tests/kernel/metavar.cpp b/src/tests/kernel/metavar.cpp index 59f2bb7d7..f82ea7a48 100644 --- a/src/tests/kernel/metavar.cpp +++ b/src/tests/kernel/metavar.cpp @@ -75,7 +75,7 @@ static void tst1() { lean_assert(m1 != m2); expr f = Const("f"); expr a = Const("a"); - subst.assign(m1, f(a)); + subst = subst.assign(m1, f(a)); lean_assert(subst.is_assigned(m1)); lean_assert(!subst.is_assigned(m2)); lean_assert(*subst.get_expr(m1) == f(a)); @@ -91,8 +91,8 @@ static void tst2() { expr f = Const("f"); expr g = Const("g"); expr a = Const("a"); - s.assign(m1, f(m2), mk_assumption_justification(1)); - s.assign(m2, g(a), mk_assumption_justification(2)); + s = s.assign(m1, f(m2), mk_assumption_justification(1)); + s = s.assign(m2, g(a), mk_assumption_justification(2)); lean_assert(check_assumptions(s.get_assignment(m1)->second, {1})); lean_assert(s.occurs(m1, f(m1))); lean_assert(s.occurs(m2, f(m1))); @@ -104,13 +104,14 @@ static void tst2() { check_assumptions(p1.second, {1, 2}); lean_assert(check_assumptions(s.get_assignment(m1)->second, {1})); lean_assert(p1.first == g(f(g(a)))); - auto p2 = s.d_instantiate_metavars(g(m1, m3)); - check_assumptions(p2.second, {1, 2}); - std::cout << p2.first << "\n"; + auto ts = s.updt_instantiate_metavars(g(m1, m3)); + s = std::get<2>(ts); + check_assumptions(std::get<1>(ts), {1, 2}); + std::cout << std::get<0>(ts) << "\n"; std::cout << s << "\n"; lean_assert(check_assumptions(s.get_assignment(m1)->second, {1, 2})); - lean_assert(p2.first == g(f(g(a)), m3)); - s.assign(m3, f(m1, m2), mk_assumption_justification(3)); + lean_assert(std::get<0>(ts) == g(f(g(a)), m3)); + s = s.assign(m3, f(m1, m2), mk_assumption_justification(3)); auto p3 = s.instantiate_metavars(g(m1, m3)); lean_assert(check_assumptions(p3.second, {1, 2, 3})); std::cout << p3.first << "\n"; @@ -128,7 +129,7 @@ static void tst3() { expr y = Const("y"); expr a = Const("a"); expr b = Const("b"); - s.assign(m1, Fun({{x, Bool}, {y, Bool}}, f(y, x))); + s = s.assign(m1, Fun({{x, Bool}, {y, Bool}}, f(y, x))); lean_assert_eq(s.instantiate_metavars(m1(a, b, g(a))).first, f(b, a, g(a))); lean_assert_eq(s.instantiate_metavars(m1(a)).first, Fun({y, Bool}, f(y, a))); lean_assert_eq(s.instantiate_metavars(m1(a, b)).first, f(b, a));