diff --git a/extras/latex/lstlean.tex b/extras/latex/lstlean.tex index 7453ff988..d28462d63 100644 --- a/extras/latex/lstlean.tex +++ b/extras/latex/lstlean.tex @@ -5,7 +5,7 @@ \lstdefinelanguage{lean} { % Anything betweeen $ becomes LaTeX math mode -mathescape=true, +mathescape=true, % Comments may or not include Latex commands texcl=false, @@ -24,7 +24,7 @@ using, namespace, section, fields, find_decl, attribute, local, set_option, extends, include, omit, classes, instances, coercions, metaclasses, raw, migrate, replacing, calc, have, obtains, show, suffices, by, by+, in, at, let, forall, Pi, fun, -exists, if, dif, then, else, assume, assert, take, +exists, if, dif, then, else, assume, assert, take, obtain, from, aliases }, @@ -36,22 +36,22 @@ morekeywords=[3]{ Cond, or_else, then, try, when, assumption, eassumption, rapply, apply, fapply, eapply, rename, intro, intros, all_goals, fold, focus, focus_at, generalize, generalizes, clear, clears, revert, reverts, back, beta, done, exact, rexact, -refine, repeat, whnf, rotate, rotate_left, rotate_right, inversion, cases, rewrite, -xrewrite, krewrite, blast, simp, esimp, unfold, change, check_expr, contradiction, -exfalso, split, existsi, constructor, fconstructor, left, right, injection, congruence, reflexivity, -symmetry, transitivity, state, induction, induction_using, fail, append, +refine, repeat, whnf, rotate, rotate_left, rotate_right, inversion, cases, rewrite, +xrewrite, krewrite, blast, simp, esimp, unfold, change, check_expr, contradiction, +exfalso, split, existsi, constructor, fconstructor, left, right, injection, congruence, reflexivity, +symmetry, transitivity, state, induction, induction_using, fail, append, substvars, now, with_options, with_attributes, with_attrs, note }, -% modifiers, taken from lean-syntax.el +% modifiers, taken from lean-syntax.el % note: 'otherkeywords' is needed because these use a different symbol. % this command doesn't allow us to specify a number -- they are put with [1] otherkeywords={ [persistent], [notation], [visible], [instance], [trans_instance], -[class], [parsing-only], [coercion], [unfold_full], [constructor], +[class], [parsing-only], [coercion], [unfold_full], [constructor], [reducible], [irreducible], [semireducible], [quasireducible], [wf], -[whnf], [multiple_instances], [none], [decl], [declaration], -[relation], [symm], [subst], [refl], [trans], [simp], [congr], +[whnf], [multiple_instances], [none], [decl], [declaration], +[relation], [symm], [subst], [refl], [trans], [simp], [congr], [unify], [backward], [forward], [no_pattern], [begin_end], [tactic], [abbreviation], [reducible], [unfold], [alias], [eqv], [intro], [intro!], [elim], [grinder], [localrefinfo], [recursor] @@ -228,13 +228,13 @@ morestring=[b]", morestring=[d]’, % Size of tabulations -tabsize=3, +tabsize=3, % Enables ASCII chars 128 to 255 extendedchars=false, % Case sensitivity -sensitive=true, +sensitive=true, % Automatic breaking of long lines breaklines=true, @@ -243,9 +243,9 @@ breaklines=true, basicstyle=\ttfamily, % Position of captions is bottom -captionpos=b, +captionpos=b, -% Full flexible columns +% Full flexible columns columns=[l]fullflexible, @@ -258,7 +258,7 @@ identifierstyle={\ttfamily\color{black}}, % Style for declaration keywords keywordstyle=[1]{\ttfamily\color{keywordcolor}}, -% Style for sorts +% Style for sorts keywordstyle=[2]{\ttfamily\color{sortcolor}}, % Style for tactics keywords @@ -274,4 +274,3 @@ stringstyle=\ttfamily, % commentstyle={\ttfamily\footnotesize }, } - diff --git a/src/emacs/lean-syntax.el b/src/emacs/lean-syntax.el index a0e0e4038..623305a9e 100644 --- a/src/emacs/lean-syntax.el +++ b/src/emacs/lean-syntax.el @@ -57,7 +57,7 @@ "whnf" "multiple_instances" "none" "decl" "declaration" "relation" "symm" "subst" "refl" "trans" "simp" "congr" "backward" "forward" "no_pattern" "begin_end" "tactic" "abbreviation" - "reducible" "unfold" "alias" "eqv" "intro" "intro!" "elim" "grinder" + "reducible" "unfold" "alias" "eqv" "intro" "intro!" "elim" "grinder" "unify" "localrefinfo" "recursor")) "lean modifiers") (defconst lean-modifiers-regexp diff --git a/src/frontends/lean/builtin_cmds.cpp b/src/frontends/lean/builtin_cmds.cpp index 63d14c470..6168f7826 100644 --- a/src/frontends/lean/builtin_cmds.cpp +++ b/src/frontends/lean/builtin_cmds.cpp @@ -779,6 +779,23 @@ static environment normalizer_cmd(parser & p) { return env; } +static environment unify_cmd(parser & p) { + environment const & env = p.env(); + expr e1; level_param_names ls1; + std::tie(e1, ls1) = parse_local_expr(p); + p.check_token_next(get_comma_tk(), "invalid #unify command, proper usage \"#unify e1, e2\""); + expr e2; level_param_names ls2; + std::tie(e2, ls2) = parse_local_expr(p); + default_type_context ctx(env, p.get_options()); + bool success = ctx.is_def_eq(e1, e2); + flycheck_information info(p.regular_stream()); + if (info.enabled()) { + p.display_information_pos(p.cmd_pos()); + } + p.regular_stream() << (success ? "success" : "fail") << endl; + return env; +} + static environment abstract_expr_cmd(parser & p) { unsigned o = p.parse_small_nat(); default_type_context ctx(p.env(), p.get_options()); @@ -841,6 +858,7 @@ void init_cmd_table(cmd_table & r) { add_cmd(r, cmd_info("#congr_simp", "(for debugging purposes)", congr_simp_cmd)); add_cmd(r, cmd_info("#congr_rel", "(for debugging purposes)", congr_rel_cmd)); add_cmd(r, cmd_info("#normalizer", "(for debugging purposes)", normalizer_cmd)); + add_cmd(r, cmd_info("#unify", "(for debugging purposes)", unify_cmd)); add_cmd(r, cmd_info("#accessible", "(for debugging purposes) display number of accessible declarations for blast tactic", accessible_cmd)); add_cmd(r, cmd_info("#simplify", "(for debugging purposes) simplify given expression", simplify_cmd)); add_cmd(r, cmd_info("#abstract_expr", "(for debugging purposes) call abstract expr methods", abstract_expr_cmd)); diff --git a/src/frontends/lean/print_cmd.cpp b/src/frontends/lean/print_cmd.cpp index 5d8c45afd..dc9433bdd 100644 --- a/src/frontends/lean/print_cmd.cpp +++ b/src/frontends/lean/print_cmd.cpp @@ -26,6 +26,7 @@ Author: Leonardo de Moura #include "library/user_recursors.h" #include "library/relation_manager.h" #include "library/noncomputable.h" +#include "library/unification_hint.h" #include "library/definitional/projection.h" #include "library/blast/blast.h" #include "library/blast/simplifier/simplifier.h" @@ -506,6 +507,23 @@ static void print_reducible_info(parser & p, reducible_status s1) { out << n << "\n"; } +static void print_unification_hints(parser & p) { + io_state_stream out = p.regular_stream(); + unification_hints hints; + name ns; + if (p.curr_is_identifier()) { + ns = p.get_name_val(); + p.next(); + hints = get_unification_hints(p.env(), ns); + } else { + hints = get_unification_hints(p.env()); + } + format header; + if (!ns.is_anonymous()) + header = format(" at namespace '") + format(ns) + format("'"); + out << pp_unification_hints(hints, out.get_formatter(), header); +} + static void print_simp_rules(parser & p) { io_state_stream out = p.regular_stream(); blast::scope_debug scope(p.env(), p.ios()); @@ -699,6 +717,9 @@ environment print_cmd(parser & p) { p.next(); p.check_token_next(get_rbracket_tk(), "invalid 'print [recursor]', ']' expected"); print_recursor_info(p); + } else if (p.curr_is_token(get_unify_attr_tk())) { + p.next(); + print_unification_hints(p); } else if (p.curr_is_token(get_simp_attr_tk())) { p.next(); print_simp_rules(p); diff --git a/src/frontends/lean/token_table.cpp b/src/frontends/lean/token_table.cpp index cea1da1e6..bd8616364 100644 --- a/src/frontends/lean/token_table.cpp +++ b/src/frontends/lean/token_table.cpp @@ -120,7 +120,7 @@ void init_token_table(token_table & t) { "multiple_instances", "find_decl", "attribute", "persistent", "include", "omit", "migrate", "init_quotient", "init_hits", "#erase_cache", "#projections", "#telescope_eq", "#compile", "#accessible", "#decl_stats", "#relevant_thms", "#simplify", "#app_builder", "#refl", "#symm", - "#trans", "#congr", "#hcongr", "#congr_simp", "#congr_rel", "#normalizer", "#abstract_expr", nullptr}; + "#trans", "#congr", "#hcongr", "#congr_simp", "#congr_rel", "#normalizer", "#abstract_expr", "#unify", nullptr}; pair aliases[] = {{g_lambda_unicode, "fun"}, {"forall", "Pi"}, {g_forall_unicode, "Pi"}, {g_pi_unicode, "Pi"}, diff --git a/src/frontends/lean/tokens.cpp b/src/frontends/lean/tokens.cpp index b6ae85cc3..800a9d737 100644 --- a/src/frontends/lean/tokens.cpp +++ b/src/frontends/lean/tokens.cpp @@ -119,6 +119,7 @@ static name const * g_intro_attr_tk = nullptr; static name const * g_intro_bang_attr_tk = nullptr; static name const * g_elim_attr_tk = nullptr; static name const * g_recursor_tk = nullptr; +static name const * g_unify_attr_tk = nullptr; static name const * g_attribute_tk = nullptr; static name const * g_with_tk = nullptr; static name const * g_class_tk = nullptr; @@ -269,6 +270,7 @@ void initialize_tokens() { g_intro_bang_attr_tk = new name{"[intro!]"}; g_elim_attr_tk = new name{"[elim]"}; g_recursor_tk = new name{"[recursor"}; + g_unify_attr_tk = new name{"[unify]"}; g_attribute_tk = new name{"attribute"}; g_with_tk = new name{"with"}; g_class_tk = new name{"[class]"}; @@ -420,6 +422,7 @@ void finalize_tokens() { delete g_intro_bang_attr_tk; delete g_elim_attr_tk; delete g_recursor_tk; + delete g_unify_attr_tk; delete g_attribute_tk; delete g_with_tk; delete g_class_tk; @@ -570,6 +573,7 @@ name const & get_intro_attr_tk() { return *g_intro_attr_tk; } name const & get_intro_bang_attr_tk() { return *g_intro_bang_attr_tk; } name const & get_elim_attr_tk() { return *g_elim_attr_tk; } name const & get_recursor_tk() { return *g_recursor_tk; } +name const & get_unify_attr_tk() { return *g_unify_attr_tk; } name const & get_attribute_tk() { return *g_attribute_tk; } name const & get_with_tk() { return *g_with_tk; } name const & get_class_tk() { return *g_class_tk; } diff --git a/src/frontends/lean/tokens.h b/src/frontends/lean/tokens.h index fbcafd71e..791de64e9 100644 --- a/src/frontends/lean/tokens.h +++ b/src/frontends/lean/tokens.h @@ -121,6 +121,7 @@ name const & get_intro_attr_tk(); name const & get_intro_bang_attr_tk(); name const & get_elim_attr_tk(); name const & get_recursor_tk(); +name const & get_unify_attr_tk(); name const & get_attribute_tk(); name const & get_with_tk(); name const & get_class_tk(); diff --git a/src/frontends/lean/tokens.txt b/src/frontends/lean/tokens.txt index c202502db..316214ae5 100644 --- a/src/frontends/lean/tokens.txt +++ b/src/frontends/lean/tokens.txt @@ -114,6 +114,7 @@ intro_attr [intro] intro_bang_attr [intro!] elim_attr [elim] recursor [recursor +unify_attr [unify] attribute attribute with with class [class] diff --git a/src/library/CMakeLists.txt b/src/library/CMakeLists.txt index f252f72d0..eddf9ea9b 100644 --- a/src/library/CMakeLists.txt +++ b/src/library/CMakeLists.txt @@ -18,4 +18,4 @@ add_library(library OBJECT deep_copy.cpp expr_lt.cpp io_state.cpp aux_recursors.cpp norm_num.cpp norm_num.cpp class_instance_resolution.cpp type_context.cpp tmp_type_context.cpp fun_info_manager.cpp congr_lemma_manager.cpp abstract_expr_manager.cpp light_lt_manager.cpp trace.cpp - attribute_manager.cpp error_handling.cpp) + attribute_manager.cpp error_handling.cpp unification_hint.cpp) diff --git a/src/library/constants.cpp b/src/library/constants.cpp index 7a234e40c..84346f194 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -95,6 +95,8 @@ name const * g_lift_down = nullptr; name const * g_lift_up = nullptr; name const * g_linear_ordered_ring = nullptr; name const * g_linear_ordered_semiring = nullptr; +name const * g_list_nil = nullptr; +name const * g_list_cons = nullptr; name const * g_monoid = nullptr; name const * g_mul = nullptr; name const * g_mul_one = nullptr; @@ -264,6 +266,10 @@ name const * g_trans_rel_left = nullptr; name const * g_trans_rel_right = nullptr; name const * g_true = nullptr; name const * g_true_intro = nullptr; +name const * g_unification_hint = nullptr; +name const * g_unification_hint_mk = nullptr; +name const * g_unification_constraint = nullptr; +name const * g_unification_constraint_mk = nullptr; name const * g_weak_order = nullptr; name const * g_well_founded = nullptr; name const * g_zero = nullptr; @@ -363,6 +369,8 @@ void initialize_constants() { g_lift_up = new name{"lift", "up"}; g_linear_ordered_ring = new name{"linear_ordered_ring"}; g_linear_ordered_semiring = new name{"linear_ordered_semiring"}; + g_list_nil = new name{"list", "nil"}; + g_list_cons = new name{"list", "cons"}; g_monoid = new name{"monoid"}; g_mul = new name{"mul"}; g_mul_one = new name{"mul_one"}; @@ -532,6 +540,10 @@ void initialize_constants() { g_trans_rel_right = new name{"trans_rel_right"}; g_true = new name{"true"}; g_true_intro = new name{"true", "intro"}; + g_unification_hint = new name{"unification_hint"}; + g_unification_hint_mk = new name{"unification_hint", "mk"}; + g_unification_constraint = new name{"unification_constraint"}; + g_unification_constraint_mk = new name{"unification_constraint", "mk"}; g_weak_order = new name{"weak_order"}; g_well_founded = new name{"well_founded"}; g_zero = new name{"zero"}; @@ -632,6 +644,8 @@ void finalize_constants() { delete g_lift_up; delete g_linear_ordered_ring; delete g_linear_ordered_semiring; + delete g_list_nil; + delete g_list_cons; delete g_monoid; delete g_mul; delete g_mul_one; @@ -801,6 +815,10 @@ void finalize_constants() { delete g_trans_rel_right; delete g_true; delete g_true_intro; + delete g_unification_hint; + delete g_unification_hint_mk; + delete g_unification_constraint; + delete g_unification_constraint_mk; delete g_weak_order; delete g_well_founded; delete g_zero; @@ -900,6 +918,8 @@ name const & get_lift_down_name() { return *g_lift_down; } name const & get_lift_up_name() { return *g_lift_up; } name const & get_linear_ordered_ring_name() { return *g_linear_ordered_ring; } name const & get_linear_ordered_semiring_name() { return *g_linear_ordered_semiring; } +name const & get_list_nil_name() { return *g_list_nil; } +name const & get_list_cons_name() { return *g_list_cons; } name const & get_monoid_name() { return *g_monoid; } name const & get_mul_name() { return *g_mul; } name const & get_mul_one_name() { return *g_mul_one; } @@ -1069,6 +1089,10 @@ name const & get_trans_rel_left_name() { return *g_trans_rel_left; } name const & get_trans_rel_right_name() { return *g_trans_rel_right; } name const & get_true_name() { return *g_true; } name const & get_true_intro_name() { return *g_true_intro; } +name const & get_unification_hint_name() { return *g_unification_hint; } +name const & get_unification_hint_mk_name() { return *g_unification_hint_mk; } +name const & get_unification_constraint_name() { return *g_unification_constraint; } +name const & get_unification_constraint_mk_name() { return *g_unification_constraint_mk; } name const & get_weak_order_name() { return *g_weak_order; } name const & get_well_founded_name() { return *g_well_founded; } name const & get_zero_name() { return *g_zero; } diff --git a/src/library/constants.h b/src/library/constants.h index 0872a185e..9fccbe000 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -97,6 +97,8 @@ name const & get_lift_down_name(); name const & get_lift_up_name(); name const & get_linear_ordered_ring_name(); name const & get_linear_ordered_semiring_name(); +name const & get_list_nil_name(); +name const & get_list_cons_name(); name const & get_monoid_name(); name const & get_mul_name(); name const & get_mul_one_name(); @@ -266,6 +268,10 @@ name const & get_trans_rel_left_name(); name const & get_trans_rel_right_name(); name const & get_true_name(); name const & get_true_intro_name(); +name const & get_unification_hint_name(); +name const & get_unification_hint_mk_name(); +name const & get_unification_constraint_name(); +name const & get_unification_constraint_mk_name(); name const & get_weak_order_name(); name const & get_well_founded_name(); name const & get_zero_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index d71ac62ec..de299eed7 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -90,6 +90,8 @@ lift.down lift.up linear_ordered_ring linear_ordered_semiring +list.nil +list.cons monoid mul mul_one @@ -259,6 +261,10 @@ trans_rel_left trans_rel_right true true.intro +unification_hint +unification_hint.mk +unification_constraint +unification_constraint.mk weak_order well_founded zero diff --git a/src/library/init_module.cpp b/src/library/init_module.cpp index 6a50c2434..5318016b1 100644 --- a/src/library/init_module.cpp +++ b/src/library/init_module.cpp @@ -47,6 +47,7 @@ Author: Leonardo de Moura #include "library/app_builder.h" #include "library/attribute_manager.h" #include "library/fun_info_manager.h" +#include "library/unification_hint.h" namespace lean { void initialize_library_module() { @@ -93,9 +94,11 @@ void initialize_library_module() { initialize_congr_lemma_manager(); initialize_app_builder(); initialize_fun_info_manager(); + initialize_unification_hint(); } void finalize_library_module() { + finalize_unification_hint(); finalize_fun_info_manager(); finalize_app_builder(); finalize_congr_lemma_manager(); diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index 91aa09b02..ec487551c 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -11,6 +11,7 @@ Author: Leonardo de Moura #include "kernel/instantiate.h" #include "kernel/abstract.h" #include "kernel/for_each_fn.h" +#include "kernel/replace_fn.h" #include "kernel/inductive/inductive.h" #include "library/trace.h" #include "library/util.h" @@ -23,6 +24,7 @@ Author: Leonardo de Moura #include "library/generic_exception.h" #include "library/class.h" #include "library/constants.h" +#include "library/unification_hint.h" #ifndef LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH #define LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH 32 @@ -1013,10 +1015,7 @@ bool type_context::is_def_eq_core(expr const & t, expr const & s) { if (is_def_eq_proof_irrel(t_n, s_n)) return true; - if (on_is_def_eq_failure(t_n, s_n)) - return is_def_eq_core(t_n, s_n); - else - return false; + return on_is_def_eq_failure(t_n, s_n); } bool type_context::process_postponed(unsigned old_sz) { @@ -1405,19 +1404,112 @@ optional> type_context::find_unsynth_metavar(expr const & e) { } } -bool type_context::on_is_def_eq_failure(expr & e1, expr & e2) { +bool type_context::on_is_def_eq_failure(expr const & e1, expr const & e2) { if (is_app(e1)) { if (auto p1 = find_unsynth_metavar(e1)) { if (mk_nested_instance(p1->first, p1->second)) { - e1 = instantiate_uvars_mvars(e1); - return true; + return is_def_eq_core(instantiate_uvars_mvars(e1), e2); } } } if (is_app(e2)) { if (auto p2 = find_unsynth_metavar(e2)) { if (mk_nested_instance(p2->first, p2->second)) { - e2 = instantiate_uvars_mvars(e2); + return is_def_eq_core(e1, instantiate_uvars_mvars(e2)); + } + } + } + if (try_unification_hints(e1, e2)) { + return true; + } + return false; +} + +struct type_context::unification_hint_fn { + type_context & m_owner; + unification_hint m_hint; + buffer > m_assignment; + + unification_hint_fn(type_context & o, unification_hint const & hint): + m_owner(o), m_hint(hint) { m_assignment.resize(m_hint.get_num_vars()); } + + bool syntactic_match(expr const & pattern, expr const & e) { + unsigned idx; + switch (pattern.kind()) { + case expr_kind::Var: + idx = var_idx(pattern); + if (!m_assignment[idx]) { + m_assignment[idx] = some_expr(e); + return true; + } else { + return m_owner.is_def_eq(*m_assignment[idx], e); + } + case expr_kind::Constant: + return is_constant(e) && const_name(pattern) == const_name(e) + && m_owner.is_def_eq(const_levels(pattern), const_levels(e)); + case expr_kind::Sort: + return is_sort(e) && m_owner.is_def_eq(sort_level(pattern), sort_level(e)); + case expr_kind::Pi: case expr_kind::Lambda: case expr_kind::Macro: + // Remark: we do not traverse inside of binders. + return pattern == e; + case expr_kind::App: + return is_app(e) && syntactic_match(app_fn(pattern), app_fn(e)) && syntactic_match(app_arg(pattern), app_arg(e)); + case expr_kind::Local: case expr_kind::Meta: + break; + } + lean_unreachable(); + } + + bool operator()(expr const & lhs, expr const & rhs) { + if (!syntactic_match(m_hint.get_lhs(), lhs)) { + lean_trace(name({"type_context", "unification_hint"}), tout() << "LHS does not match\n";); + return false; + } else if (!syntactic_match(m_hint.get_rhs(), rhs)) { + lean_trace(name({"type_context", "unification_hint"}), tout() << "RHS does not match\n";); + return false; + } else { + auto instantiate_assignment_fn = [&](expr const & e, unsigned offset) { + if (is_var(e)) { + unsigned idx = var_idx(e) + offset; + if (idx < m_assignment.size()) { + lean_assert(m_assignment[idx]); + return m_assignment[idx]; + } + } + return none_expr(); + }; + buffer constraints; + to_buffer(m_hint.get_constraints(), constraints); + for (expr_pair const & p : constraints) { + expr new_lhs = replace(p.first, instantiate_assignment_fn); + expr new_rhs = replace(p.second, instantiate_assignment_fn); + expr new_lhs_inst = m_owner.instantiate_uvars_mvars(new_lhs); + expr new_rhs_inst = m_owner.instantiate_uvars_mvars(new_rhs); + bool success = m_owner.is_def_eq(new_lhs, new_rhs); + lean_trace(name({"type_context", "unification_hint"}), + tout() << new_lhs_inst << " =?= " << new_rhs_inst << "..." + << (success ? "success" : "failed") << "\n";); + if (!success) return false; + } + lean_trace(name({"type_context", "unification_hint"}), tout() << "hint successfully applied\n";); + return true; + } + } +}; + +bool type_context::try_unification_hints(expr const & e1, expr const & e2) { + expr e1_fn = get_app_fn(e1); + expr e2_fn = get_app_fn(e2); + if (is_constant(e1_fn) && is_constant(e2_fn)) { + buffer hints; + get_unification_hints(m_env, const_name(e1_fn), const_name(e2_fn), hints); + for (unification_hint const & hint : hints) { + scope s(*this); + lean_trace(name({"type_context", "unification_hint"}), + tout() << e1 << " =?= " << e2 + << ", pattern: " << hint.get_lhs() << " =?= " << hint.get_rhs() << "\n";); + if (unification_hint_fn(*this, hint)(e1, e2)) { + s.commit(); return true; } } @@ -2042,6 +2134,7 @@ void initialize_type_context() { g_tmp_prefix = new name(name::mk_internal_unique_name()); g_internal_prefix = new name(name::mk_internal_unique_name()); register_trace_class("class_instances"); + register_trace_class(name({"type_context", "unification_hint"})); g_class_instance_max_depth = new name{"class", "instance_max_depth"}; g_class_trans_instances = new name{"class", "trans_instances"}; register_unsigned_option(*g_class_instance_max_depth, LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH, diff --git a/src/library/type_context.h b/src/library/type_context.h index f78700ae7..e5cfa4f3a 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -413,7 +413,11 @@ public: The default implementation tries to invoke type class resolution to assign unassigned metavariables in the given terms. */ - virtual bool on_is_def_eq_failure(expr &, expr &); + virtual bool on_is_def_eq_failure(expr const &, expr const &); + + bool try_unification_hints(expr const &, expr const &); + struct unification_hint_fn; + friend struct unification_hint_fn; bool is_assigned(level const & u) const { return static_cast(get_assignment(u)); } bool is_assigned(expr const & m) const { return static_cast(get_assignment(m)); } diff --git a/src/library/unification_hint.cpp b/src/library/unification_hint.cpp new file mode 100644 index 000000000..c39c28647 --- /dev/null +++ b/src/library/unification_hint.cpp @@ -0,0 +1,239 @@ +/* +Copyright (c) 2015 Daniel Selsam. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author: Daniel Selsam +*/ +#include +#include "util/sexpr/format.h" +#include "kernel/expr.h" +#include "kernel/error_msgs.h" +#include "library/attribute_manager.h" +#include "library/constants.h" +#include "library/unification_hint.h" +#include "library/util.h" +#include "library/expr_lt.h" +#include "library/scoped_ext.h" + +namespace lean { + +/* Unification hints */ + +unification_hint::unification_hint(expr const & lhs, expr const & rhs, list const & constraints, unsigned num_vars): + m_lhs(lhs), m_rhs(rhs), m_constraints(constraints), m_num_vars(num_vars) {} + +int unification_hint_cmp::operator()(unification_hint const & uh1, unification_hint const & uh2) const { + if (uh1.get_lhs() != uh2.get_lhs()) { + return expr_quick_cmp()(uh1.get_lhs(), uh2.get_lhs()); + } else if (uh1.get_rhs() != uh2.get_rhs()) { + return expr_quick_cmp()(uh1.get_rhs(), uh2.get_rhs()); + } else { + auto it1 = uh1.get_constraints().begin(); + auto it2 = uh2.get_constraints().begin(); + auto end1 = uh1.get_constraints().end(); + auto end2 = uh2.get_constraints().end(); + for (; it1 != end1 && it2 != end2; ++it1, ++it2) { + if (unsigned cmp = expr_pair_quick_cmp()(*it1, *it2)) return cmp; + } + return 0; + } +} + +/* Environment extension */ + +static name * g_class_name = nullptr; +static std::string * g_key = nullptr; + +struct unification_hint_state { + unification_hints m_hints; + name_map m_decl_names_to_prio; // Note: redundant but convenient + + void validate_type(expr const & decl_type) { + expr type = decl_type; + while (is_pi(type)) type = binding_body(type); + if (!is_app_of(type, get_unification_hint_name(), 0)) { + throw exception("invalid unification hint, must return element of type `unification hint`"); + } + } + + void register_hint(name const & decl_name, expr const & value, unsigned priority) { + m_decl_names_to_prio.insert(decl_name, priority); + + expr e_hint = value; + unsigned num_vars = 0; + while (is_lambda(e_hint)) { + e_hint = binding_body(e_hint); + num_vars++; + } + + if (!is_app_of(e_hint, get_unification_hint_mk_name(), 2)) { + throw exception("invalid unification hint, body must be application of 'unification_hint.mk' to two arguments"); + } + + // e_hint := unification_hint.mk pattern constraints + expr e_pattern = app_arg(app_fn(e_hint)); + expr e_constraints = app_arg(e_hint); + + // pattern := unification_constraint.mk _ lhs rhs + expr e_pattern_lhs = app_arg(app_fn(e_pattern)); + expr e_pattern_rhs = app_arg(e_pattern); + + expr e_pattern_lhs_fn = get_app_fn(e_pattern_lhs); + expr e_pattern_rhs_fn = get_app_fn(e_pattern_rhs); + + if (!is_constant(e_pattern_lhs_fn) || !is_constant(e_pattern_rhs_fn)) { + throw exception("invalid unification hint, the heads of both sides of pattern must be constants"); + } + + name_pair key = mk_pair(const_name(e_pattern_lhs_fn), const_name(e_pattern_rhs_fn)); + + buffer constraints; + while (is_app_of(e_constraints, get_list_cons_name(), 3)) { + // e_constraints := cons _ constraint rest + expr e_constraint = app_arg(app_fn(e_constraints)); + expr e_constraint_lhs = app_arg(app_fn(e_constraint)); + expr e_constraint_rhs = app_arg(e_constraint); + constraints.push_back(mk_pair(e_constraint_lhs, e_constraint_rhs)); + e_constraints = app_arg(e_constraints); + } + + if (!is_app_of(e_constraints, get_list_nil_name(), 1)) { + throw exception("invalid unification hint, must provide list of constraints explicitly"); + } + + unification_hint hint(e_pattern_lhs, e_pattern_rhs, to_list(constraints), num_vars); + unification_hint_queue q; + if (auto const & q_ptr = m_hints.find(key)) q = *q_ptr; + q.insert(hint, priority); + m_hints.insert(key, q); + } +}; + +struct unification_hint_entry { + name m_decl_name; + unsigned m_priority; + unification_hint_entry(name const & decl_name, unsigned priority): + m_decl_name(decl_name), m_priority(priority) {} +}; + +struct unification_hint_config { + typedef unification_hint_entry entry; + typedef unification_hint_state state; + + static void add_entry(environment const & env, io_state const &, state & s, entry const & e) { + declaration decl = env.get(e.m_decl_name); + s.validate_type(decl.get_type()); + // Note: only definitions should be tagged as [unify], so if it is not a definition, + // there must have been an error when processing the definition. We return immediately + // so as not to hide the original error. + // TODO(dhs): the downside to this approach is that a [unify] tag on an actual axiom will be silently ignored. + if (decl.is_definition()) s.register_hint(e.m_decl_name, decl.get_value(), e.m_priority); + } + static name const & get_class_name() { + return *g_class_name; + } + static std::string const & get_serialization_key() { + return *g_key; + } + static void write_entry(serializer & s, entry const & e) { + s << e.m_decl_name << e.m_priority; + } + static entry read_entry(deserializer & d) { + name decl_name; unsigned prio; + d >> decl_name >> prio; + return entry(decl_name, prio); + } + static optional get_fingerprint(entry const & e) { + return some(hash(e.m_decl_name.hash(), e.m_priority)); + } +}; + +typedef scoped_ext unification_hint_ext; + +environment add_unification_hint(environment const & env, io_state const & ios, name const & decl_name, unsigned prio, name const & ns, bool persistent) { + return unification_hint_ext::add_entry(env, ios, unification_hint_entry(decl_name, prio), ns, persistent); +} + +bool is_unification_hint(environment const & env, name const & decl_name) { + return unification_hint_ext::get_state(env).m_decl_names_to_prio.contains(decl_name); +} + +unification_hints get_unification_hints(environment const & env) { + return unification_hint_ext::get_state(env).m_hints; +} + +unification_hints get_unification_hints(environment const & env, name const & ns) { + list const * entries = unification_hint_ext::get_entries(env, ns); + unification_hint_state s; + if (entries) { + for (auto const & e : *entries) { + declaration decl = env.get(e.m_decl_name); + s.register_hint(e.m_decl_name, decl.get_value(), e.m_priority); + } + } + return s.m_hints; +} + +void get_unification_hints(environment const & env, name const & n1, name const & n2, buffer & uhints) { + unification_hints hints = unification_hint_ext::get_state(env).m_hints; + if (auto const & q_ptr = hints.find(mk_pair(n1, n2))) { + q_ptr->to_buffer(uhints); + } + if (auto const & q_ptr = hints.find(mk_pair(n2, n1))) { + q_ptr->to_buffer(uhints); + } +} + +/* Pretty-printing */ + +// TODO(dhs): I may not be using all the formatting functions correctly. +format unification_hint::pp(unsigned prio, formatter const & fmt) const { + format r; + if (prio != LEAN_DEFAULT_PRIORITY) + r += paren(format(prio)) + space(); + format r1 = fmt(get_lhs()) + space() + format("=?=") + pp_indent_expr(fmt, get_rhs()); + r1 += space() + lcurly(); + r += group(r1); + for_each(m_constraints, [&](expr_pair p) { + r += fmt(p.first) + space() + format("=?="); + r += space() + fmt(p.second) + comma() + space(); + }); + r += rcurly(); + return r; +} + +format pp_unification_hints(unification_hints const & hints, formatter const & fmt, format const & header) { + format r; + r += format("unification hints"); + r += header + colon() + line(); + hints.for_each([&](name_pair const & names, unification_hint_queue const & q) { + q.for_each([&](unification_hint const & hint) { + r += lp() + format(names.first) + comma() + space() + format(names.second) + rp() + space(); + r += hint.pp(*q.get_prio(hint), fmt) + line(); + }); + }); + return r; +} + +void initialize_unification_hint() { + g_class_name = new name("unification_hint"); + g_key = new std::string("UNIFICATION_HINT"); + + unification_hint_ext::initialize(); + + register_prio_attribute("unify", "unification hint", + add_unification_hint, + is_unification_hint, + [](environment const & env, name const & decl_name) { + if (auto p = unification_hint_ext::get_state(env).m_decl_names_to_prio.find(decl_name)) + return *p; + else + return LEAN_DEFAULT_PRIORITY; + }); +} + +void finalize_unification_hint() { + unification_hint_ext::finalize(); + delete g_key; + delete g_class_name; +} +} diff --git a/src/library/unification_hint.h b/src/library/unification_hint.h new file mode 100644 index 000000000..7f531a989 --- /dev/null +++ b/src/library/unification_hint.h @@ -0,0 +1,69 @@ +/* +Copyright (c) 2015 Daniel Selsam. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Author: Daniel Selsam +*/ +#pragma once +#include "kernel/environment.h" +#include "library/expr_pair.h" +#include "library/io_state.h" +#include "library/head_map.h" +#include "util/priority_queue.h" + +namespace lean { + +/* +Users can declare unification hints using the following structures: + +structure unification_constraint := {A : Type} (lhs : A) (rhs : A) +structure unification_hint := (pattern : unification_constraint) (constraints : list unification_constraint) + +Example: + +definition both_zero_of_add_eq_zero [unify] (n₁ n₂ : ℕ) (s₁ : has_add ℕ) (s₂ : has_zero ℕ) : unification_hint := + unification_hint.mk (unification_constraint.mk (@add ℕ s₁ n₁ n₂) (@zero ℕ s₂)) + [unification_constraint.mk n₁ (@zero ℕ s₂), + unification_constraint.mk n₂ (@zero ℕ s₂)] + +creates the following unification hint: +m_lhs: add nat #1 #3 #2 +m_rhs: zero nat #0 +m_constraints: [(#3, zero nat #0), (#2, zero nat #0)] +m_num_vars: #4 + +Note that once we have an assignment to all variables from matching, we must substitute the assignments in the constraints. +*/ + +class unification_hint { + expr m_lhs; + expr m_rhs; + + list m_constraints; + unsigned m_num_vars; +public: + expr get_lhs() const { return m_lhs; } + expr get_rhs() const { return m_rhs; } + list get_constraints() const { return m_constraints; } + unsigned get_num_vars() const { return m_num_vars; } + + unification_hint() {} + unification_hint(expr const & lhs, expr const & rhs, list const & constraints, unsigned num_vars); + format pp(unsigned priority, formatter const & fmt) const; +}; + +struct unification_hint_cmp { + int operator()(unification_hint const & uh1, unification_hint const & uh2) const; +}; + +typedef priority_queue unification_hint_queue; +typedef rb_map unification_hints; + +unification_hints get_unification_hints(environment const & env); +unification_hints get_unification_hints(environment const & env, name const & ns); +void get_unification_hints(environment const & env, name const & n1, name const & n2, buffer & hints); + +format pp_unification_hints(unification_hints const & hints, formatter const & fmt, format const & header); + +void initialize_unification_hint(); +void finalize_unification_hint(); +} diff --git a/tests/lean/unification_hints1.lean b/tests/lean/unification_hints1.lean new file mode 100644 index 000000000..6e4ce71ba --- /dev/null +++ b/tests/lean/unification_hints1.lean @@ -0,0 +1,54 @@ +import data.list data.nat +open list nat + +structure unification_constraint := {A : Type} (lhs : A) (rhs : A) +structure unification_hint := (pattern : unification_constraint) (constraints : list unification_constraint) + +namespace toy +constants (A : Type.{1}) (f h : A → A) (x y z : A) +definition g [irreducible] (x y : A) : A := f z + +#unify (g x y), (f z) + +definition toy_hint [unify] (x y : A) : unification_hint := + unification_hint.mk (unification_constraint.mk (g x y) (f z)) [] + +#unify (g x y), (f z) +print [unify] + +end toy + +namespace add +constants (n : ℕ) + +#unify (n + 1), succ n + +definition add_zero_hint [unify] (m n : ℕ) [has_add ℕ] [has_one ℕ] [has_zero ℕ] : unification_hint := + unification_hint.mk (unification_constraint.mk (m + 1) (succ n)) [unification_constraint.mk m n] + +#unify (n + 1), (succ n) +print [unify] + +end add + +namespace canonical +structure Canonical := (carrier : Type) (op : carrier → carrier) +attribute Canonical.carrier [irreducible] + +constants (A : Type.{1}) (f : A → A) (x : A) +definition A_canonical : Canonical := Canonical.mk A f + +#unify (Canonical.carrier A_canonical), A + +definition Canonical_hint [unify] (C : Canonical) : unification_hint := + unification_hint.mk (unification_constraint.mk (Canonical.carrier C) A) [unification_constraint.mk C A_canonical] + +-- TODO(dhs): we mark carrier as irreducible and prove A_canonical explicitly to work around the fact that +-- the default_type_context does not recognize the elaborator metavariables as metavariables, +-- and so cannot perform the assignment. +#unify (Canonical.carrier A_canonical), A +print [unify] + +end canonical + +print [unify] canonical diff --git a/tests/lean/unification_hints1.lean.expected.out b/tests/lean/unification_hints1.lean.expected.out new file mode 100644 index 000000000..cf3972b12 --- /dev/null +++ b/tests/lean/unification_hints1.lean.expected.out @@ -0,0 +1,15 @@ +fail +success +unification hints: +(toy.g, toy.f) g #1 #0 =?= f z {} +fail +success +unification hints: +(add, nat.succ) #4 + 1 =?= succ #3 {#4 =?= #3, } +fail +success +unification hints: +(canonical.Canonical.carrier, canonical.A) Canonical.carrier #0 =?= A {#0 =?= A_canonical, } +unification hints at namespace 'canonical': +(canonical.Canonical.carrier, canonical.A) canonical.Canonical.carrier #0 =?= + canonical.A {#0 =?= canonical.A_canonical, }