feat(library/unification_hint): basic handling of user-supplied unification hints

This commit is contained in:
Daniel Selsam 2016-02-12 09:03:06 -08:00 committed by Leonardo de Moura
parent d8fb6f5082
commit bb4b8da582
19 changed files with 585 additions and 28 deletions

View file

@ -5,7 +5,7 @@
\lstdefinelanguage{lean} {
% Anything betweeen $ becomes LaTeX math mode
mathescape=true,
mathescape=true,
% Comments may or not include Latex commands
texcl=false,
@ -24,7 +24,7 @@ using, namespace, section, fields, find_decl,
attribute, local, set_option, extends, include, omit, classes,
instances, coercions, metaclasses, raw, migrate, replacing,
calc, have, obtains, show, suffices, by, by+, in, at, let, forall, Pi, fun,
exists, if, dif, then, else, assume, assert, take,
exists, if, dif, then, else, assume, assert, take,
obtain, from, aliases
},
@ -36,22 +36,22 @@ morekeywords=[3]{
Cond, or_else, then, try, when, assumption, eassumption, rapply,
apply, fapply, eapply, rename, intro, intros, all_goals, fold, focus, focus_at,
generalize, generalizes, clear, clears, revert, reverts, back, beta, done, exact, rexact,
refine, repeat, whnf, rotate, rotate_left, rotate_right, inversion, cases, rewrite,
xrewrite, krewrite, blast, simp, esimp, unfold, change, check_expr, contradiction,
exfalso, split, existsi, constructor, fconstructor, left, right, injection, congruence, reflexivity,
symmetry, transitivity, state, induction, induction_using, fail, append,
refine, repeat, whnf, rotate, rotate_left, rotate_right, inversion, cases, rewrite,
xrewrite, krewrite, blast, simp, esimp, unfold, change, check_expr, contradiction,
exfalso, split, existsi, constructor, fconstructor, left, right, injection, congruence, reflexivity,
symmetry, transitivity, state, induction, induction_using, fail, append,
substvars, now, with_options, with_attributes, with_attrs, note
},
% modifiers, taken from lean-syntax.el
% modifiers, taken from lean-syntax.el
% note: 'otherkeywords' is needed because these use a different symbol.
% this command doesn't allow us to specify a number -- they are put with [1]
otherkeywords={
[persistent], [notation], [visible], [instance], [trans_instance],
[class], [parsing-only], [coercion], [unfold_full], [constructor],
[class], [parsing-only], [coercion], [unfold_full], [constructor],
[reducible], [irreducible], [semireducible], [quasireducible], [wf],
[whnf], [multiple_instances], [none], [decl], [declaration],
[relation], [symm], [subst], [refl], [trans], [simp], [congr],
[whnf], [multiple_instances], [none], [decl], [declaration],
[relation], [symm], [subst], [refl], [trans], [simp], [congr], [unify],
[backward], [forward], [no_pattern], [begin_end], [tactic], [abbreviation],
[reducible], [unfold], [alias], [eqv], [intro], [intro!], [elim], [grinder],
[localrefinfo], [recursor]
@ -228,13 +228,13 @@ morestring=[b]",
morestring=[d]’,
% Size of tabulations
tabsize=3,
tabsize=3,
% Enables ASCII chars 128 to 255
extendedchars=false,
% Case sensitivity
sensitive=true,
sensitive=true,
% Automatic breaking of long lines
breaklines=true,
@ -243,9 +243,9 @@ breaklines=true,
basicstyle=\ttfamily,
% Position of captions is bottom
captionpos=b,
captionpos=b,
% Full flexible columns
% Full flexible columns
columns=[l]fullflexible,
@ -258,7 +258,7 @@ identifierstyle={\ttfamily\color{black}},
% Style for declaration keywords
keywordstyle=[1]{\ttfamily\color{keywordcolor}},
% Style for sorts
% Style for sorts
keywordstyle=[2]{\ttfamily\color{sortcolor}},
% Style for tactics keywords
@ -274,4 +274,3 @@ stringstyle=\ttfamily,
% commentstyle={\ttfamily\footnotesize },
}

View file

@ -57,7 +57,7 @@
"whnf" "multiple_instances" "none" "decl" "declaration"
"relation" "symm" "subst" "refl" "trans" "simp" "congr"
"backward" "forward" "no_pattern" "begin_end" "tactic" "abbreviation"
"reducible" "unfold" "alias" "eqv" "intro" "intro!" "elim" "grinder"
"reducible" "unfold" "alias" "eqv" "intro" "intro!" "elim" "grinder" "unify"
"localrefinfo" "recursor"))
"lean modifiers")
(defconst lean-modifiers-regexp

View file

@ -779,6 +779,23 @@ static environment normalizer_cmd(parser & p) {
return env;
}
static environment unify_cmd(parser & p) {
environment const & env = p.env();
expr e1; level_param_names ls1;
std::tie(e1, ls1) = parse_local_expr(p);
p.check_token_next(get_comma_tk(), "invalid #unify command, proper usage \"#unify e1, e2\"");
expr e2; level_param_names ls2;
std::tie(e2, ls2) = parse_local_expr(p);
default_type_context ctx(env, p.get_options());
bool success = ctx.is_def_eq(e1, e2);
flycheck_information info(p.regular_stream());
if (info.enabled()) {
p.display_information_pos(p.cmd_pos());
}
p.regular_stream() << (success ? "success" : "fail") << endl;
return env;
}
static environment abstract_expr_cmd(parser & p) {
unsigned o = p.parse_small_nat();
default_type_context ctx(p.env(), p.get_options());
@ -841,6 +858,7 @@ void init_cmd_table(cmd_table & r) {
add_cmd(r, cmd_info("#congr_simp", "(for debugging purposes)", congr_simp_cmd));
add_cmd(r, cmd_info("#congr_rel", "(for debugging purposes)", congr_rel_cmd));
add_cmd(r, cmd_info("#normalizer", "(for debugging purposes)", normalizer_cmd));
add_cmd(r, cmd_info("#unify", "(for debugging purposes)", unify_cmd));
add_cmd(r, cmd_info("#accessible", "(for debugging purposes) display number of accessible declarations for blast tactic", accessible_cmd));
add_cmd(r, cmd_info("#simplify", "(for debugging purposes) simplify given expression", simplify_cmd));
add_cmd(r, cmd_info("#abstract_expr", "(for debugging purposes) call abstract expr methods", abstract_expr_cmd));

View file

@ -26,6 +26,7 @@ Author: Leonardo de Moura
#include "library/user_recursors.h"
#include "library/relation_manager.h"
#include "library/noncomputable.h"
#include "library/unification_hint.h"
#include "library/definitional/projection.h"
#include "library/blast/blast.h"
#include "library/blast/simplifier/simplifier.h"
@ -506,6 +507,23 @@ static void print_reducible_info(parser & p, reducible_status s1) {
out << n << "\n";
}
static void print_unification_hints(parser & p) {
io_state_stream out = p.regular_stream();
unification_hints hints;
name ns;
if (p.curr_is_identifier()) {
ns = p.get_name_val();
p.next();
hints = get_unification_hints(p.env(), ns);
} else {
hints = get_unification_hints(p.env());
}
format header;
if (!ns.is_anonymous())
header = format(" at namespace '") + format(ns) + format("'");
out << pp_unification_hints(hints, out.get_formatter(), header);
}
static void print_simp_rules(parser & p) {
io_state_stream out = p.regular_stream();
blast::scope_debug scope(p.env(), p.ios());
@ -699,6 +717,9 @@ environment print_cmd(parser & p) {
p.next();
p.check_token_next(get_rbracket_tk(), "invalid 'print [recursor]', ']' expected");
print_recursor_info(p);
} else if (p.curr_is_token(get_unify_attr_tk())) {
p.next();
print_unification_hints(p);
} else if (p.curr_is_token(get_simp_attr_tk())) {
p.next();
print_simp_rules(p);

View file

@ -120,7 +120,7 @@ void init_token_table(token_table & t) {
"multiple_instances", "find_decl", "attribute", "persistent",
"include", "omit", "migrate", "init_quotient", "init_hits", "#erase_cache", "#projections", "#telescope_eq",
"#compile", "#accessible", "#decl_stats", "#relevant_thms", "#simplify", "#app_builder", "#refl", "#symm",
"#trans", "#congr", "#hcongr", "#congr_simp", "#congr_rel", "#normalizer", "#abstract_expr", nullptr};
"#trans", "#congr", "#hcongr", "#congr_simp", "#congr_rel", "#normalizer", "#abstract_expr", "#unify", nullptr};
pair<char const *, char const *> aliases[] =
{{g_lambda_unicode, "fun"}, {"forall", "Pi"}, {g_forall_unicode, "Pi"}, {g_pi_unicode, "Pi"},

View file

@ -119,6 +119,7 @@ static name const * g_intro_attr_tk = nullptr;
static name const * g_intro_bang_attr_tk = nullptr;
static name const * g_elim_attr_tk = nullptr;
static name const * g_recursor_tk = nullptr;
static name const * g_unify_attr_tk = nullptr;
static name const * g_attribute_tk = nullptr;
static name const * g_with_tk = nullptr;
static name const * g_class_tk = nullptr;
@ -269,6 +270,7 @@ void initialize_tokens() {
g_intro_bang_attr_tk = new name{"[intro!]"};
g_elim_attr_tk = new name{"[elim]"};
g_recursor_tk = new name{"[recursor"};
g_unify_attr_tk = new name{"[unify]"};
g_attribute_tk = new name{"attribute"};
g_with_tk = new name{"with"};
g_class_tk = new name{"[class]"};
@ -420,6 +422,7 @@ void finalize_tokens() {
delete g_intro_bang_attr_tk;
delete g_elim_attr_tk;
delete g_recursor_tk;
delete g_unify_attr_tk;
delete g_attribute_tk;
delete g_with_tk;
delete g_class_tk;
@ -570,6 +573,7 @@ name const & get_intro_attr_tk() { return *g_intro_attr_tk; }
name const & get_intro_bang_attr_tk() { return *g_intro_bang_attr_tk; }
name const & get_elim_attr_tk() { return *g_elim_attr_tk; }
name const & get_recursor_tk() { return *g_recursor_tk; }
name const & get_unify_attr_tk() { return *g_unify_attr_tk; }
name const & get_attribute_tk() { return *g_attribute_tk; }
name const & get_with_tk() { return *g_with_tk; }
name const & get_class_tk() { return *g_class_tk; }

View file

@ -121,6 +121,7 @@ name const & get_intro_attr_tk();
name const & get_intro_bang_attr_tk();
name const & get_elim_attr_tk();
name const & get_recursor_tk();
name const & get_unify_attr_tk();
name const & get_attribute_tk();
name const & get_with_tk();
name const & get_class_tk();

View file

@ -114,6 +114,7 @@ intro_attr [intro]
intro_bang_attr [intro!]
elim_attr [elim]
recursor [recursor
unify_attr [unify]
attribute attribute
with with
class [class]

View file

@ -18,4 +18,4 @@ add_library(library OBJECT deep_copy.cpp expr_lt.cpp io_state.cpp
aux_recursors.cpp norm_num.cpp norm_num.cpp class_instance_resolution.cpp type_context.cpp
tmp_type_context.cpp fun_info_manager.cpp congr_lemma_manager.cpp
abstract_expr_manager.cpp light_lt_manager.cpp trace.cpp
attribute_manager.cpp error_handling.cpp)
attribute_manager.cpp error_handling.cpp unification_hint.cpp)

View file

@ -95,6 +95,8 @@ name const * g_lift_down = nullptr;
name const * g_lift_up = nullptr;
name const * g_linear_ordered_ring = nullptr;
name const * g_linear_ordered_semiring = nullptr;
name const * g_list_nil = nullptr;
name const * g_list_cons = nullptr;
name const * g_monoid = nullptr;
name const * g_mul = nullptr;
name const * g_mul_one = nullptr;
@ -264,6 +266,10 @@ name const * g_trans_rel_left = nullptr;
name const * g_trans_rel_right = nullptr;
name const * g_true = nullptr;
name const * g_true_intro = nullptr;
name const * g_unification_hint = nullptr;
name const * g_unification_hint_mk = nullptr;
name const * g_unification_constraint = nullptr;
name const * g_unification_constraint_mk = nullptr;
name const * g_weak_order = nullptr;
name const * g_well_founded = nullptr;
name const * g_zero = nullptr;
@ -363,6 +369,8 @@ void initialize_constants() {
g_lift_up = new name{"lift", "up"};
g_linear_ordered_ring = new name{"linear_ordered_ring"};
g_linear_ordered_semiring = new name{"linear_ordered_semiring"};
g_list_nil = new name{"list", "nil"};
g_list_cons = new name{"list", "cons"};
g_monoid = new name{"monoid"};
g_mul = new name{"mul"};
g_mul_one = new name{"mul_one"};
@ -532,6 +540,10 @@ void initialize_constants() {
g_trans_rel_right = new name{"trans_rel_right"};
g_true = new name{"true"};
g_true_intro = new name{"true", "intro"};
g_unification_hint = new name{"unification_hint"};
g_unification_hint_mk = new name{"unification_hint", "mk"};
g_unification_constraint = new name{"unification_constraint"};
g_unification_constraint_mk = new name{"unification_constraint", "mk"};
g_weak_order = new name{"weak_order"};
g_well_founded = new name{"well_founded"};
g_zero = new name{"zero"};
@ -632,6 +644,8 @@ void finalize_constants() {
delete g_lift_up;
delete g_linear_ordered_ring;
delete g_linear_ordered_semiring;
delete g_list_nil;
delete g_list_cons;
delete g_monoid;
delete g_mul;
delete g_mul_one;
@ -801,6 +815,10 @@ void finalize_constants() {
delete g_trans_rel_right;
delete g_true;
delete g_true_intro;
delete g_unification_hint;
delete g_unification_hint_mk;
delete g_unification_constraint;
delete g_unification_constraint_mk;
delete g_weak_order;
delete g_well_founded;
delete g_zero;
@ -900,6 +918,8 @@ name const & get_lift_down_name() { return *g_lift_down; }
name const & get_lift_up_name() { return *g_lift_up; }
name const & get_linear_ordered_ring_name() { return *g_linear_ordered_ring; }
name const & get_linear_ordered_semiring_name() { return *g_linear_ordered_semiring; }
name const & get_list_nil_name() { return *g_list_nil; }
name const & get_list_cons_name() { return *g_list_cons; }
name const & get_monoid_name() { return *g_monoid; }
name const & get_mul_name() { return *g_mul; }
name const & get_mul_one_name() { return *g_mul_one; }
@ -1069,6 +1089,10 @@ name const & get_trans_rel_left_name() { return *g_trans_rel_left; }
name const & get_trans_rel_right_name() { return *g_trans_rel_right; }
name const & get_true_name() { return *g_true; }
name const & get_true_intro_name() { return *g_true_intro; }
name const & get_unification_hint_name() { return *g_unification_hint; }
name const & get_unification_hint_mk_name() { return *g_unification_hint_mk; }
name const & get_unification_constraint_name() { return *g_unification_constraint; }
name const & get_unification_constraint_mk_name() { return *g_unification_constraint_mk; }
name const & get_weak_order_name() { return *g_weak_order; }
name const & get_well_founded_name() { return *g_well_founded; }
name const & get_zero_name() { return *g_zero; }

View file

@ -97,6 +97,8 @@ name const & get_lift_down_name();
name const & get_lift_up_name();
name const & get_linear_ordered_ring_name();
name const & get_linear_ordered_semiring_name();
name const & get_list_nil_name();
name const & get_list_cons_name();
name const & get_monoid_name();
name const & get_mul_name();
name const & get_mul_one_name();
@ -266,6 +268,10 @@ name const & get_trans_rel_left_name();
name const & get_trans_rel_right_name();
name const & get_true_name();
name const & get_true_intro_name();
name const & get_unification_hint_name();
name const & get_unification_hint_mk_name();
name const & get_unification_constraint_name();
name const & get_unification_constraint_mk_name();
name const & get_weak_order_name();
name const & get_well_founded_name();
name const & get_zero_name();

View file

@ -90,6 +90,8 @@ lift.down
lift.up
linear_ordered_ring
linear_ordered_semiring
list.nil
list.cons
monoid
mul
mul_one
@ -259,6 +261,10 @@ trans_rel_left
trans_rel_right
true
true.intro
unification_hint
unification_hint.mk
unification_constraint
unification_constraint.mk
weak_order
well_founded
zero

View file

@ -47,6 +47,7 @@ Author: Leonardo de Moura
#include "library/app_builder.h"
#include "library/attribute_manager.h"
#include "library/fun_info_manager.h"
#include "library/unification_hint.h"
namespace lean {
void initialize_library_module() {
@ -93,9 +94,11 @@ void initialize_library_module() {
initialize_congr_lemma_manager();
initialize_app_builder();
initialize_fun_info_manager();
initialize_unification_hint();
}
void finalize_library_module() {
finalize_unification_hint();
finalize_fun_info_manager();
finalize_app_builder();
finalize_congr_lemma_manager();

View file

@ -11,6 +11,7 @@ Author: Leonardo de Moura
#include "kernel/instantiate.h"
#include "kernel/abstract.h"
#include "kernel/for_each_fn.h"
#include "kernel/replace_fn.h"
#include "kernel/inductive/inductive.h"
#include "library/trace.h"
#include "library/util.h"
@ -23,6 +24,7 @@ Author: Leonardo de Moura
#include "library/generic_exception.h"
#include "library/class.h"
#include "library/constants.h"
#include "library/unification_hint.h"
#ifndef LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH
#define LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH 32
@ -1013,10 +1015,7 @@ bool type_context::is_def_eq_core(expr const & t, expr const & s) {
if (is_def_eq_proof_irrel(t_n, s_n))
return true;
if (on_is_def_eq_failure(t_n, s_n))
return is_def_eq_core(t_n, s_n);
else
return false;
return on_is_def_eq_failure(t_n, s_n);
}
bool type_context::process_postponed(unsigned old_sz) {
@ -1405,19 +1404,112 @@ optional<pair<expr, expr>> type_context::find_unsynth_metavar(expr const & e) {
}
}
bool type_context::on_is_def_eq_failure(expr & e1, expr & e2) {
bool type_context::on_is_def_eq_failure(expr const & e1, expr const & e2) {
if (is_app(e1)) {
if (auto p1 = find_unsynth_metavar(e1)) {
if (mk_nested_instance(p1->first, p1->second)) {
e1 = instantiate_uvars_mvars(e1);
return true;
return is_def_eq_core(instantiate_uvars_mvars(e1), e2);
}
}
}
if (is_app(e2)) {
if (auto p2 = find_unsynth_metavar(e2)) {
if (mk_nested_instance(p2->first, p2->second)) {
e2 = instantiate_uvars_mvars(e2);
return is_def_eq_core(e1, instantiate_uvars_mvars(e2));
}
}
}
if (try_unification_hints(e1, e2)) {
return true;
}
return false;
}
struct type_context::unification_hint_fn {
type_context & m_owner;
unification_hint m_hint;
buffer<optional<expr> > m_assignment;
unification_hint_fn(type_context & o, unification_hint const & hint):
m_owner(o), m_hint(hint) { m_assignment.resize(m_hint.get_num_vars()); }
bool syntactic_match(expr const & pattern, expr const & e) {
unsigned idx;
switch (pattern.kind()) {
case expr_kind::Var:
idx = var_idx(pattern);
if (!m_assignment[idx]) {
m_assignment[idx] = some_expr(e);
return true;
} else {
return m_owner.is_def_eq(*m_assignment[idx], e);
}
case expr_kind::Constant:
return is_constant(e) && const_name(pattern) == const_name(e)
&& m_owner.is_def_eq(const_levels(pattern), const_levels(e));
case expr_kind::Sort:
return is_sort(e) && m_owner.is_def_eq(sort_level(pattern), sort_level(e));
case expr_kind::Pi: case expr_kind::Lambda: case expr_kind::Macro:
// Remark: we do not traverse inside of binders.
return pattern == e;
case expr_kind::App:
return is_app(e) && syntactic_match(app_fn(pattern), app_fn(e)) && syntactic_match(app_arg(pattern), app_arg(e));
case expr_kind::Local: case expr_kind::Meta:
break;
}
lean_unreachable();
}
bool operator()(expr const & lhs, expr const & rhs) {
if (!syntactic_match(m_hint.get_lhs(), lhs)) {
lean_trace(name({"type_context", "unification_hint"}), tout() << "LHS does not match\n";);
return false;
} else if (!syntactic_match(m_hint.get_rhs(), rhs)) {
lean_trace(name({"type_context", "unification_hint"}), tout() << "RHS does not match\n";);
return false;
} else {
auto instantiate_assignment_fn = [&](expr const & e, unsigned offset) {
if (is_var(e)) {
unsigned idx = var_idx(e) + offset;
if (idx < m_assignment.size()) {
lean_assert(m_assignment[idx]);
return m_assignment[idx];
}
}
return none_expr();
};
buffer<expr_pair> constraints;
to_buffer(m_hint.get_constraints(), constraints);
for (expr_pair const & p : constraints) {
expr new_lhs = replace(p.first, instantiate_assignment_fn);
expr new_rhs = replace(p.second, instantiate_assignment_fn);
expr new_lhs_inst = m_owner.instantiate_uvars_mvars(new_lhs);
expr new_rhs_inst = m_owner.instantiate_uvars_mvars(new_rhs);
bool success = m_owner.is_def_eq(new_lhs, new_rhs);
lean_trace(name({"type_context", "unification_hint"}),
tout() << new_lhs_inst << " =?= " << new_rhs_inst << "..."
<< (success ? "success" : "failed") << "\n";);
if (!success) return false;
}
lean_trace(name({"type_context", "unification_hint"}), tout() << "hint successfully applied\n";);
return true;
}
}
};
bool type_context::try_unification_hints(expr const & e1, expr const & e2) {
expr e1_fn = get_app_fn(e1);
expr e2_fn = get_app_fn(e2);
if (is_constant(e1_fn) && is_constant(e2_fn)) {
buffer<unification_hint> hints;
get_unification_hints(m_env, const_name(e1_fn), const_name(e2_fn), hints);
for (unification_hint const & hint : hints) {
scope s(*this);
lean_trace(name({"type_context", "unification_hint"}),
tout() << e1 << " =?= " << e2
<< ", pattern: " << hint.get_lhs() << " =?= " << hint.get_rhs() << "\n";);
if (unification_hint_fn(*this, hint)(e1, e2)) {
s.commit();
return true;
}
}
@ -2042,6 +2134,7 @@ void initialize_type_context() {
g_tmp_prefix = new name(name::mk_internal_unique_name());
g_internal_prefix = new name(name::mk_internal_unique_name());
register_trace_class("class_instances");
register_trace_class(name({"type_context", "unification_hint"}));
g_class_instance_max_depth = new name{"class", "instance_max_depth"};
g_class_trans_instances = new name{"class", "trans_instances"};
register_unsigned_option(*g_class_instance_max_depth, LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH,

View file

@ -413,7 +413,11 @@ public:
The default implementation tries to invoke type class resolution to
assign unassigned metavariables in the given terms. */
virtual bool on_is_def_eq_failure(expr &, expr &);
virtual bool on_is_def_eq_failure(expr const &, expr const &);
bool try_unification_hints(expr const &, expr const &);
struct unification_hint_fn;
friend struct unification_hint_fn;
bool is_assigned(level const & u) const { return static_cast<bool>(get_assignment(u)); }
bool is_assigned(expr const & m) const { return static_cast<bool>(get_assignment(m)); }

View file

@ -0,0 +1,239 @@
/*
Copyright (c) 2015 Daniel Selsam. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#include <string>
#include "util/sexpr/format.h"
#include "kernel/expr.h"
#include "kernel/error_msgs.h"
#include "library/attribute_manager.h"
#include "library/constants.h"
#include "library/unification_hint.h"
#include "library/util.h"
#include "library/expr_lt.h"
#include "library/scoped_ext.h"
namespace lean {
/* Unification hints */
unification_hint::unification_hint(expr const & lhs, expr const & rhs, list<expr_pair> const & constraints, unsigned num_vars):
m_lhs(lhs), m_rhs(rhs), m_constraints(constraints), m_num_vars(num_vars) {}
int unification_hint_cmp::operator()(unification_hint const & uh1, unification_hint const & uh2) const {
if (uh1.get_lhs() != uh2.get_lhs()) {
return expr_quick_cmp()(uh1.get_lhs(), uh2.get_lhs());
} else if (uh1.get_rhs() != uh2.get_rhs()) {
return expr_quick_cmp()(uh1.get_rhs(), uh2.get_rhs());
} else {
auto it1 = uh1.get_constraints().begin();
auto it2 = uh2.get_constraints().begin();
auto end1 = uh1.get_constraints().end();
auto end2 = uh2.get_constraints().end();
for (; it1 != end1 && it2 != end2; ++it1, ++it2) {
if (unsigned cmp = expr_pair_quick_cmp()(*it1, *it2)) return cmp;
}
return 0;
}
}
/* Environment extension */
static name * g_class_name = nullptr;
static std::string * g_key = nullptr;
struct unification_hint_state {
unification_hints m_hints;
name_map<unsigned> m_decl_names_to_prio; // Note: redundant but convenient
void validate_type(expr const & decl_type) {
expr type = decl_type;
while (is_pi(type)) type = binding_body(type);
if (!is_app_of(type, get_unification_hint_name(), 0)) {
throw exception("invalid unification hint, must return element of type `unification hint`");
}
}
void register_hint(name const & decl_name, expr const & value, unsigned priority) {
m_decl_names_to_prio.insert(decl_name, priority);
expr e_hint = value;
unsigned num_vars = 0;
while (is_lambda(e_hint)) {
e_hint = binding_body(e_hint);
num_vars++;
}
if (!is_app_of(e_hint, get_unification_hint_mk_name(), 2)) {
throw exception("invalid unification hint, body must be application of 'unification_hint.mk' to two arguments");
}
// e_hint := unification_hint.mk pattern constraints
expr e_pattern = app_arg(app_fn(e_hint));
expr e_constraints = app_arg(e_hint);
// pattern := unification_constraint.mk _ lhs rhs
expr e_pattern_lhs = app_arg(app_fn(e_pattern));
expr e_pattern_rhs = app_arg(e_pattern);
expr e_pattern_lhs_fn = get_app_fn(e_pattern_lhs);
expr e_pattern_rhs_fn = get_app_fn(e_pattern_rhs);
if (!is_constant(e_pattern_lhs_fn) || !is_constant(e_pattern_rhs_fn)) {
throw exception("invalid unification hint, the heads of both sides of pattern must be constants");
}
name_pair key = mk_pair(const_name(e_pattern_lhs_fn), const_name(e_pattern_rhs_fn));
buffer<expr_pair> constraints;
while (is_app_of(e_constraints, get_list_cons_name(), 3)) {
// e_constraints := cons _ constraint rest
expr e_constraint = app_arg(app_fn(e_constraints));
expr e_constraint_lhs = app_arg(app_fn(e_constraint));
expr e_constraint_rhs = app_arg(e_constraint);
constraints.push_back(mk_pair(e_constraint_lhs, e_constraint_rhs));
e_constraints = app_arg(e_constraints);
}
if (!is_app_of(e_constraints, get_list_nil_name(), 1)) {
throw exception("invalid unification hint, must provide list of constraints explicitly");
}
unification_hint hint(e_pattern_lhs, e_pattern_rhs, to_list(constraints), num_vars);
unification_hint_queue q;
if (auto const & q_ptr = m_hints.find(key)) q = *q_ptr;
q.insert(hint, priority);
m_hints.insert(key, q);
}
};
struct unification_hint_entry {
name m_decl_name;
unsigned m_priority;
unification_hint_entry(name const & decl_name, unsigned priority):
m_decl_name(decl_name), m_priority(priority) {}
};
struct unification_hint_config {
typedef unification_hint_entry entry;
typedef unification_hint_state state;
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
declaration decl = env.get(e.m_decl_name);
s.validate_type(decl.get_type());
// Note: only definitions should be tagged as [unify], so if it is not a definition,
// there must have been an error when processing the definition. We return immediately
// so as not to hide the original error.
// TODO(dhs): the downside to this approach is that a [unify] tag on an actual axiom will be silently ignored.
if (decl.is_definition()) s.register_hint(e.m_decl_name, decl.get_value(), e.m_priority);
}
static name const & get_class_name() {
return *g_class_name;
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << e.m_decl_name << e.m_priority;
}
static entry read_entry(deserializer & d) {
name decl_name; unsigned prio;
d >> decl_name >> prio;
return entry(decl_name, prio);
}
static optional<unsigned> get_fingerprint(entry const & e) {
return some(hash(e.m_decl_name.hash(), e.m_priority));
}
};
typedef scoped_ext<unification_hint_config> unification_hint_ext;
environment add_unification_hint(environment const & env, io_state const & ios, name const & decl_name, unsigned prio, name const & ns, bool persistent) {
return unification_hint_ext::add_entry(env, ios, unification_hint_entry(decl_name, prio), ns, persistent);
}
bool is_unification_hint(environment const & env, name const & decl_name) {
return unification_hint_ext::get_state(env).m_decl_names_to_prio.contains(decl_name);
}
unification_hints get_unification_hints(environment const & env) {
return unification_hint_ext::get_state(env).m_hints;
}
unification_hints get_unification_hints(environment const & env, name const & ns) {
list<unification_hint_entry> const * entries = unification_hint_ext::get_entries(env, ns);
unification_hint_state s;
if (entries) {
for (auto const & e : *entries) {
declaration decl = env.get(e.m_decl_name);
s.register_hint(e.m_decl_name, decl.get_value(), e.m_priority);
}
}
return s.m_hints;
}
void get_unification_hints(environment const & env, name const & n1, name const & n2, buffer<unification_hint> & uhints) {
unification_hints hints = unification_hint_ext::get_state(env).m_hints;
if (auto const & q_ptr = hints.find(mk_pair(n1, n2))) {
q_ptr->to_buffer(uhints);
}
if (auto const & q_ptr = hints.find(mk_pair(n2, n1))) {
q_ptr->to_buffer(uhints);
}
}
/* Pretty-printing */
// TODO(dhs): I may not be using all the formatting functions correctly.
format unification_hint::pp(unsigned prio, formatter const & fmt) const {
format r;
if (prio != LEAN_DEFAULT_PRIORITY)
r += paren(format(prio)) + space();
format r1 = fmt(get_lhs()) + space() + format("=?=") + pp_indent_expr(fmt, get_rhs());
r1 += space() + lcurly();
r += group(r1);
for_each(m_constraints, [&](expr_pair p) {
r += fmt(p.first) + space() + format("=?=");
r += space() + fmt(p.second) + comma() + space();
});
r += rcurly();
return r;
}
format pp_unification_hints(unification_hints const & hints, formatter const & fmt, format const & header) {
format r;
r += format("unification hints");
r += header + colon() + line();
hints.for_each([&](name_pair const & names, unification_hint_queue const & q) {
q.for_each([&](unification_hint const & hint) {
r += lp() + format(names.first) + comma() + space() + format(names.second) + rp() + space();
r += hint.pp(*q.get_prio(hint), fmt) + line();
});
});
return r;
}
void initialize_unification_hint() {
g_class_name = new name("unification_hint");
g_key = new std::string("UNIFICATION_HINT");
unification_hint_ext::initialize();
register_prio_attribute("unify", "unification hint",
add_unification_hint,
is_unification_hint,
[](environment const & env, name const & decl_name) {
if (auto p = unification_hint_ext::get_state(env).m_decl_names_to_prio.find(decl_name))
return *p;
else
return LEAN_DEFAULT_PRIORITY;
});
}
void finalize_unification_hint() {
unification_hint_ext::finalize();
delete g_key;
delete g_class_name;
}
}

View file

@ -0,0 +1,69 @@
/*
Copyright (c) 2015 Daniel Selsam. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#pragma once
#include "kernel/environment.h"
#include "library/expr_pair.h"
#include "library/io_state.h"
#include "library/head_map.h"
#include "util/priority_queue.h"
namespace lean {
/*
Users can declare unification hints using the following structures:
structure unification_constraint := {A : Type} (lhs : A) (rhs : A)
structure unification_hint := (pattern : unification_constraint) (constraints : list unification_constraint)
Example:
definition both_zero_of_add_eq_zero [unify] (n n : ) (s : has_add ) (s : has_zero ) : unification_hint :=
unification_hint.mk (unification_constraint.mk (@add s n n) (@zero s))
[unification_constraint.mk n (@zero s),
unification_constraint.mk n (@zero s)]
creates the following unification hint:
m_lhs: add nat #1 #3 #2
m_rhs: zero nat #0
m_constraints: [(#3, zero nat #0), (#2, zero nat #0)]
m_num_vars: #4
Note that once we have an assignment to all variables from matching, we must substitute the assignments in the constraints.
*/
class unification_hint {
expr m_lhs;
expr m_rhs;
list<expr_pair> m_constraints;
unsigned m_num_vars;
public:
expr get_lhs() const { return m_lhs; }
expr get_rhs() const { return m_rhs; }
list<expr_pair> get_constraints() const { return m_constraints; }
unsigned get_num_vars() const { return m_num_vars; }
unification_hint() {}
unification_hint(expr const & lhs, expr const & rhs, list<expr_pair> const & constraints, unsigned num_vars);
format pp(unsigned priority, formatter const & fmt) const;
};
struct unification_hint_cmp {
int operator()(unification_hint const & uh1, unification_hint const & uh2) const;
};
typedef priority_queue<unification_hint, unification_hint_cmp> unification_hint_queue;
typedef rb_map<name_pair, unification_hint_queue, name_pair_quick_cmp> unification_hints;
unification_hints get_unification_hints(environment const & env);
unification_hints get_unification_hints(environment const & env, name const & ns);
void get_unification_hints(environment const & env, name const & n1, name const & n2, buffer<unification_hint> & hints);
format pp_unification_hints(unification_hints const & hints, formatter const & fmt, format const & header);
void initialize_unification_hint();
void finalize_unification_hint();
}

View file

@ -0,0 +1,54 @@
import data.list data.nat
open list nat
structure unification_constraint := {A : Type} (lhs : A) (rhs : A)
structure unification_hint := (pattern : unification_constraint) (constraints : list unification_constraint)
namespace toy
constants (A : Type.{1}) (f h : A → A) (x y z : A)
definition g [irreducible] (x y : A) : A := f z
#unify (g x y), (f z)
definition toy_hint [unify] (x y : A) : unification_hint :=
unification_hint.mk (unification_constraint.mk (g x y) (f z)) []
#unify (g x y), (f z)
print [unify]
end toy
namespace add
constants (n : )
#unify (n + 1), succ n
definition add_zero_hint [unify] (m n : ) [has_add ] [has_one ] [has_zero ] : unification_hint :=
unification_hint.mk (unification_constraint.mk (m + 1) (succ n)) [unification_constraint.mk m n]
#unify (n + 1), (succ n)
print [unify]
end add
namespace canonical
structure Canonical := (carrier : Type) (op : carrier → carrier)
attribute Canonical.carrier [irreducible]
constants (A : Type.{1}) (f : A → A) (x : A)
definition A_canonical : Canonical := Canonical.mk A f
#unify (Canonical.carrier A_canonical), A
definition Canonical_hint [unify] (C : Canonical) : unification_hint :=
unification_hint.mk (unification_constraint.mk (Canonical.carrier C) A) [unification_constraint.mk C A_canonical]
-- TODO(dhs): we mark carrier as irreducible and prove A_canonical explicitly to work around the fact that
-- the default_type_context does not recognize the elaborator metavariables as metavariables,
-- and so cannot perform the assignment.
#unify (Canonical.carrier A_canonical), A
print [unify]
end canonical
print [unify] canonical

View file

@ -0,0 +1,15 @@
fail
success
unification hints:
(toy.g, toy.f) g #1 #0 =?= f z {}
fail
success
unification hints:
(add, nat.succ) #4 + 1 =?= succ #3 {#4 =?= #3, }
fail
success
unification hints:
(canonical.Canonical.carrier, canonical.A) Canonical.carrier #0 =?= A {#0 =?= A_canonical, }
unification hints at namespace 'canonical':
(canonical.Canonical.carrier, canonical.A) canonical.Canonical.carrier #0 =?=
canonical.A {#0 =?= canonical.A_canonical, }