diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index 23038c087..992882a07 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -30,7 +30,7 @@ Author: Leonardo de Moura namespace lean { class elaborator { typedef list context; - typedef std::vector constraints; + typedef std::vector constraint_vect; typedef name_map tactic_hints; typedef name_map mvar2meta; @@ -43,7 +43,7 @@ class elaborator { context m_ctx; pos_info_provider * m_pos_provider; justification m_accumulated; // accumulate justification of eagerly used substitutions - constraints m_constraints; + constraint_vect m_constraints; tactic_hints m_tactic_hints; mvar2meta m_mvar2meta; name_set m_displayed_errors; // set of metavariables that we already reported unsolved/unassigned @@ -84,15 +84,16 @@ class elaborator { struct choice_elaborator { elaborator & m_elab; + expr m_mvar; expr m_choice; context m_ctx; substitution m_subst; unsigned m_idx; - choice_elaborator(elaborator & elab, expr const & c, context const & ctx, substitution const & s): - m_elab(elab), m_choice(c), m_ctx(ctx), m_subst(s), m_idx(0) { + choice_elaborator(elaborator & elab, expr const & mvar, expr const & c, context const & ctx, substitution const & s): + m_elab(elab), m_mvar(mvar), m_choice(c), m_ctx(ctx), m_subst(s), m_idx(0) { } - optional next() { + optional next() { while (m_idx < get_num_choices(m_choice)) { expr const & c = get_choice(m_choice, m_idx); m_idx++; @@ -102,20 +103,21 @@ class elaborator { justification j = m_elab.m_accumulated; m_elab.consume_tc_cnstrs(); list cs = to_list(m_elab.m_constraints.begin(), m_elab.m_constraints.end()); - return optional(r, j, cs); + cs = cons(mk_eq_cnstr(m_mvar, r, j), cs); + return optional(cs); } catch (exception &) {} } - return optional(); + return optional(); } }; - lazy_list choose(std::shared_ptr c) { - return mk_lazy_list([=]() { + lazy_list choose(std::shared_ptr c) { + return mk_lazy_list([=]() { auto s = c->next(); if (s) return some(mk_pair(*s, choose(c))); else - return lazy_list::maybe_pair(); + return lazy_list::maybe_pair(); }); } @@ -326,11 +328,11 @@ public: // Possible optimization: try to lookahead and discard some of the alternatives. expr m = mk_meta(t, e.get_tag()); context ctx = m_ctx; - auto choice_fn = [=](expr const & /* t */, substitution const & s, name_generator const & /* ngen */) { - return choose(std::make_shared(*this, e, ctx, s)); + auto fn = [=](expr const & mvar, expr const & /* type */, substitution const & s, name_generator const & /* ngen */) { + return choose(std::make_shared(*this, mvar, e, ctx, s)); }; justification j = mk_justification("none of the overloads is applicable", some_expr(e)); - add_cnstr(mk_choice_cnstr(m, choice_fn, false, j)); + add_cnstr(mk_choice_cnstr(m, fn, false, j)); return m; } @@ -404,10 +406,9 @@ public: expr mk_delayed_coercion(expr const & e, expr const & d_type, expr const & a_type) { expr a = app_arg(e); expr m = mk_meta(some_expr(d_type), a.get_tag()); - auto choice_fn = [=](expr const & new_d_type, substitution const & /* s */, name_generator const & /* ngen */) { + auto choice_fn = [=](expr const & mvar, expr const & new_d_type, substitution const & /* s */, name_generator const & /* ngen */) { expr r = apply_coercion(a, a_type, new_d_type); - a_choice c(r, justification(), list()); - return lazy_list(c); + return lazy_list(constraints(mk_eq_cnstr(mvar, r, justification()))); }; justification j = mk_app_justification(m_env, e, d_type, a_type); add_cnstr(mk_choice_cnstr(m, choice_fn, false, j)); diff --git a/src/kernel/constraint.h b/src/kernel/constraint.h index b731599f6..1e103bf71 100644 --- a/src/kernel/constraint.h +++ b/src/kernel/constraint.h @@ -34,22 +34,18 @@ namespace lean { enum class constraint_kind { Eq, LevelEq, Choice }; class constraint; typedef list constraints; -typedef std::tuple a_choice; // a choice produced by the choice_fn /** \brief A choice_fn is used to enumerate the possible solutions for a metavariable. The input arguments are: - - an inferred type + - metavariable that should be inferred + - the metavariable type - substitution map (metavar -> value) - name generator - - The result is a lazy_list of choices, i.e., tuples containing: - - an expression representing one of the possible solutions - - a justification for it (this is used to accumulate the justification for the substitutions used). - - a list of new constraints (that is, the solution is only valid if the additional constraints can be solved) + The result is a lazy_list of constraints One application of choice constraints is overloaded notation. */ -typedef std::function(expr const &, substitution const &, name_generator const &)> choice_fn; +typedef std::function(expr const &, expr const &, substitution const &, name_generator const &)> choice_fn; struct constraint_cell; class constraint { @@ -76,6 +72,9 @@ public: constraint_cell * raw() const { return m_ptr; } }; +inline bool operator==(constraint const & c1, constraint const & c2) { return c1.raw() == c2.raw(); } +inline bool operator!=(constraint const & c1, constraint const & c2) { return !(c1 == c2); } + constraint mk_eq_cnstr(expr const & lhs, expr const & rhs, justification const & j); constraint mk_level_eq_cnstr(level const & lhs, level const & rhs, justification const & j); constraint mk_choice_cnstr(expr const & m, choice_fn const & fn, bool delayed, justification const & j); diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 26ce2fd2c..16a514871 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -1431,6 +1431,7 @@ static void open_justification(lua_State * L) { // Constraint DECL_UDATA(constraint) +DEFINE_LUA_LIST(constraint, push_constraint, to_constraint) int push_optional_constraint(lua_State * L, optional const & c) { return c ? push_constraint(L, *c) : push_nil(L); } #define CNSTR_PRED(P) static int constraint_ ## P(lua_State * L) { check_num_args(L, 1); return push_boolean(L, P(to_constraint(L, 1))); } CNSTR_PRED(is_eq_cnstr) @@ -1475,14 +1476,15 @@ static int mk_level_eq_cnstr(lua_State * L) { static choice_fn to_choice_fn(lua_State * L, int idx) { luaL_checktype(L, idx, LUA_TFUNCTION); // user-fun luaref f(L, idx); - return choice_fn([=](expr const & e, substitution const & s, name_generator const & ngen) { + return choice_fn([=](expr const & mvar, expr const & mvar_type, substitution const & s, name_generator const & ngen) { lua_State * L = f.get_state(); f.push(); - push_expr(L, e); + push_expr(L, mvar); + push_expr(L, mvar_type); push_substitution(L, s); push_name_generator(L, ngen); - pcall(L, 3, 1, 0); - buffer r; + pcall(L, 4, 1, 0); + buffer r; if (lua_isnil(L, -1)) { // do nothing } else if (lua_istable(L, -1)) { @@ -1490,35 +1492,12 @@ static choice_fn to_choice_fn(lua_State * L, int idx) { // each entry is an alternative for (int i = 1; i <= num; i++) { lua_rawgeti(L, -1, i); - if (is_expr(L, -1)) { - r.push_back(a_choice(to_expr(L, -1), justification(), constraints())); - } else if (lua_istable(L, -1) && objlen(L, -1) == 3) { - lua_rawgeti(L, -1, 1); - expr c = to_expr(L, -1); - lua_pop(L, 1); - lua_rawgeti(L, -1, 2); - justification j = to_justification(L, -1); - lua_pop(L, 1); - lua_rawgeti(L, -1, 3); - buffer cs; - if (lua_isnil(L, -1)) { - // do nothing - } else if (lua_istable(L, -1)) { - int num_cs = objlen(L, -1); - for (int i = 1; i <= num_cs; i++) { - lua_rawgeti(L, -1, i); - cs.push_back(to_constraint(L, -1)); - lua_pop(L, 1); - } - } else { - throw exception("invalid choice function, result must be an array of triples, " - "where the third element of each triple is an array of constraints"); - } - lua_pop(L, 1); - r.push_back(a_choice(c, j, to_list(cs.begin(), cs.end()))); - } else { - throw exception("invalid choice function, result must be an array of triples"); - } + if (is_constraint(L, -1)) + r.push_back(constraints(to_constraint(L, -1))); + else if (is_expr(L, -1)) + r.push_back(constraints(mk_eq_cnstr(mvar, to_expr(L, -1), justification()))); + else + r.push_back(to_list_constraint_ext(L, -1)); lua_pop(L, 1); } } else { @@ -1960,6 +1939,7 @@ void open_kernel_module(lua_State * L) { open_io_state(L); open_justification(L); open_constraint(L); + open_list_constraint(L); open_substitution(L); open_type_checker(L); open_inductive(L); diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp index 496fb3e29..0c8f2fbda 100644 --- a/src/library/unifier.cpp +++ b/src/library/unifier.cpp @@ -288,19 +288,10 @@ struct unifier_fn { }; typedef std::vector> case_split_stack; - struct plugin_case_split : public case_split { + struct lazy_constraints_case_split : public case_split { lazy_list m_tail; - plugin_case_split(unifier_fn & u, lazy_list const & tail):case_split(u), m_tail(tail) {} - virtual bool next(unifier_fn & u) { return u.next_plugin_case_split(*this); } - }; - - struct choice_case_split : public case_split { - expr m_expr; - justification m_jst; - lazy_list m_tail; - choice_case_split(unifier_fn & u, expr const & expr, justification const & j, lazy_list const & tail): - case_split(u), m_expr(expr), m_jst(j), m_tail(tail) {} - virtual bool next(unifier_fn & u) { return u.next_choice_case_split(*this); } + lazy_constraints_case_split(unifier_fn & u, lazy_list const & tail):case_split(u), m_tail(tail) {} + virtual bool next(unifier_fn & u) { return u.next_lazy_constraints_case_split(*this); } }; struct ho_case_split : public case_split { @@ -775,21 +766,13 @@ struct unifier_fn { return !in_conflict(); } - bool process_choice_result(expr const & m, a_choice const & r, justification j) { - j = mk_composite1(j, std::get<1>(r)); - return - process_constraint(mk_eq_cnstr(m, std::get<0>(r), j)) && - process_constraints(std::get<2>(r), j); - } - - bool next_choice_case_split(choice_case_split & cs) { + bool next_lazy_constraints_case_split(lazy_constraints_case_split & cs) { auto r = cs.m_tail.pull(); if (r) { cs.restore_state(*this); lean_assert(!in_conflict()); cs.m_tail = r->second; - justification a = mk_assumption_justification(cs.m_assumption_idx); - return process_choice_result(cs.m_expr, r->first, mk_composite1(cs.m_jst, a)); + return process_constraints(r->first, mk_assumption_justification(cs.m_assumption_idx)); } else { // update conflict update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications)); @@ -797,22 +780,16 @@ struct unifier_fn { } } - bool process_choice_constraint(constraint const & c) { - lean_assert(is_choice_cnstr(c)); - expr const & m = cnstr_expr(c); - choice_fn const & fn = cnstr_choice_fn(c); - auto m_type_jst = m_subst.instantiate_metavars(m_tc.infer(m), nullptr, nullptr); - auto rlist = fn(m_type_jst.first, m_subst, m_ngen.mk_child()); - auto r = rlist.pull(); - justification j = mk_composite1(c.get_justification(), m_type_jst.second); + bool process_lazy_constraints(lazy_list const & l, justification const & j) { + auto r = l.pull(); if (r) { if (r->second.is_nil()) { // there is only one alternative - return process_choice_result(m, r->first, j); + return process_constraints(r->first, j); } else { justification a = mk_assumption_justification(m_next_assumption_idx); - add_case_split(std::unique_ptr(new choice_case_split(*this, m, m_type_jst.second, r->second))); - return process_choice_result(m, r->first, mk_composite1(j, a)); + add_case_split(std::unique_ptr(new lazy_constraints_case_split(*this, r->second))); + return process_constraints(r->first, mk_composite1(j, a)); } } else { set_conflict(j); @@ -820,6 +797,21 @@ struct unifier_fn { } } + bool process_plugin_constraint(constraint const & c) { + lean_assert(!is_choice_cnstr(c)); + lazy_list alts = m_plugin(c, m_ngen.mk_child()); + return process_lazy_constraints(alts, c.get_justification()); + } + + bool process_choice_constraint(constraint const & c) { + lean_assert(is_choice_cnstr(c)); + expr const & m = cnstr_expr(c); + choice_fn const & fn = cnstr_choice_fn(c); + auto m_type_jst = m_subst.instantiate_metavars(m_tc.infer(m), nullptr, nullptr); + lazy_list alts = fn(m, m_type_jst.first, m_subst, m_ngen.mk_child()); + return process_lazy_constraints(alts, mk_composite1(c.get_justification(), m_type_jst.second)); + } + /** \brief Return true iff \c e is of the form (elim ... (?m ...)) */ bool is_elim_meta_app(expr const & e) { if (!is_app(e)) @@ -921,35 +913,6 @@ struct unifier_fn { return process_elim_meta_core(rhs, lhs, j); } - bool next_plugin_case_split(plugin_case_split & cs) { - auto r = cs.m_tail.pull(); - if (r) { - cs.restore_state(*this); - lean_assert(!in_conflict()); - cs.m_tail = r->second; - return process_constraints(r->first, mk_assumption_justification(cs.m_assumption_idx)); - } else { - // update conflict - update_conflict(mk_composite1(*m_conflict, cs.m_failed_justifications)); - return false; - } - } - - bool process_plugin_constraint(constraint const & c) { - lean_assert(!is_choice_cnstr(c)); - lazy_list alts = m_plugin(c, m_ngen.mk_child()); - auto r = alts.pull(); - if (!r) { - set_conflict(c.get_justification()); - return false; - } else { - // create a backtracking point - justification a = mk_assumption_justification(m_next_assumption_idx); - add_case_split(std::unique_ptr(new plugin_case_split(*this, r->second))); - return process_constraints(r->first, a); - } - } - bool next_ho_case_split(ho_case_split & cs) { if (!is_nil(cs.m_tail)) { cs.restore_state(*this); diff --git a/tests/lua/unify5.lua b/tests/lua/unify5.lua index ed8c1c7b2..04d061f04 100644 --- a/tests/lua/unify5.lua +++ b/tests/lua/unify5.lua @@ -29,9 +29,9 @@ function display_solutions(m, ss) end cs = { mk_eq_cnstr(m1, f(m2, f(m3, m4))), - mk_choice_cnstr(m2, function(e, s, ngen) return {{a, justification(), {}}, {f(a, a), justification(), {}}} end), - mk_choice_cnstr(m3, function(e, s, ngen) return {{b, justification(), {}}, {f(b, b), justification(), {}}} end), - mk_choice_cnstr(m4, function(e, s, ngen) return {a, b} end) + mk_choice_cnstr(m2, function(m, e, s, ngen) return {{mk_eq_cnstr(m, a)}, {mk_eq_cnstr(m, f(a, a))}} end), + mk_choice_cnstr(m3, function(m, e, s, ngen) return {mk_eq_cnstr(m, b), mk_eq_cnstr(m, f(b, b))} end), + mk_choice_cnstr(m4, function(m, e, s, ngen) return {a, b} end) } display_solutions(m1, unify(env, cs, name_generator(), o))