feat(library/unification_hint): basic handling of user-supplied unification hints
This commit is contained in:
parent
d8fb6f5082
commit
bb4b8da582
19 changed files with 585 additions and 28 deletions
|
@ -51,7 +51,7 @@ otherkeywords={
|
|||
[class], [parsing-only], [coercion], [unfold_full], [constructor],
|
||||
[reducible], [irreducible], [semireducible], [quasireducible], [wf],
|
||||
[whnf], [multiple_instances], [none], [decl], [declaration],
|
||||
[relation], [symm], [subst], [refl], [trans], [simp], [congr],
|
||||
[relation], [symm], [subst], [refl], [trans], [simp], [congr], [unify],
|
||||
[backward], [forward], [no_pattern], [begin_end], [tactic], [abbreviation],
|
||||
[reducible], [unfold], [alias], [eqv], [intro], [intro!], [elim], [grinder],
|
||||
[localrefinfo], [recursor]
|
||||
|
@ -274,4 +274,3 @@ stringstyle=\ttfamily,
|
|||
% commentstyle={\ttfamily\footnotesize },
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@
|
|||
"whnf" "multiple_instances" "none" "decl" "declaration"
|
||||
"relation" "symm" "subst" "refl" "trans" "simp" "congr"
|
||||
"backward" "forward" "no_pattern" "begin_end" "tactic" "abbreviation"
|
||||
"reducible" "unfold" "alias" "eqv" "intro" "intro!" "elim" "grinder"
|
||||
"reducible" "unfold" "alias" "eqv" "intro" "intro!" "elim" "grinder" "unify"
|
||||
"localrefinfo" "recursor"))
|
||||
"lean modifiers")
|
||||
(defconst lean-modifiers-regexp
|
||||
|
|
|
@ -779,6 +779,23 @@ static environment normalizer_cmd(parser & p) {
|
|||
return env;
|
||||
}
|
||||
|
||||
static environment unify_cmd(parser & p) {
|
||||
environment const & env = p.env();
|
||||
expr e1; level_param_names ls1;
|
||||
std::tie(e1, ls1) = parse_local_expr(p);
|
||||
p.check_token_next(get_comma_tk(), "invalid #unify command, proper usage \"#unify e1, e2\"");
|
||||
expr e2; level_param_names ls2;
|
||||
std::tie(e2, ls2) = parse_local_expr(p);
|
||||
default_type_context ctx(env, p.get_options());
|
||||
bool success = ctx.is_def_eq(e1, e2);
|
||||
flycheck_information info(p.regular_stream());
|
||||
if (info.enabled()) {
|
||||
p.display_information_pos(p.cmd_pos());
|
||||
}
|
||||
p.regular_stream() << (success ? "success" : "fail") << endl;
|
||||
return env;
|
||||
}
|
||||
|
||||
static environment abstract_expr_cmd(parser & p) {
|
||||
unsigned o = p.parse_small_nat();
|
||||
default_type_context ctx(p.env(), p.get_options());
|
||||
|
@ -841,6 +858,7 @@ void init_cmd_table(cmd_table & r) {
|
|||
add_cmd(r, cmd_info("#congr_simp", "(for debugging purposes)", congr_simp_cmd));
|
||||
add_cmd(r, cmd_info("#congr_rel", "(for debugging purposes)", congr_rel_cmd));
|
||||
add_cmd(r, cmd_info("#normalizer", "(for debugging purposes)", normalizer_cmd));
|
||||
add_cmd(r, cmd_info("#unify", "(for debugging purposes)", unify_cmd));
|
||||
add_cmd(r, cmd_info("#accessible", "(for debugging purposes) display number of accessible declarations for blast tactic", accessible_cmd));
|
||||
add_cmd(r, cmd_info("#simplify", "(for debugging purposes) simplify given expression", simplify_cmd));
|
||||
add_cmd(r, cmd_info("#abstract_expr", "(for debugging purposes) call abstract expr methods", abstract_expr_cmd));
|
||||
|
|
|
@ -26,6 +26,7 @@ Author: Leonardo de Moura
|
|||
#include "library/user_recursors.h"
|
||||
#include "library/relation_manager.h"
|
||||
#include "library/noncomputable.h"
|
||||
#include "library/unification_hint.h"
|
||||
#include "library/definitional/projection.h"
|
||||
#include "library/blast/blast.h"
|
||||
#include "library/blast/simplifier/simplifier.h"
|
||||
|
@ -506,6 +507,23 @@ static void print_reducible_info(parser & p, reducible_status s1) {
|
|||
out << n << "\n";
|
||||
}
|
||||
|
||||
static void print_unification_hints(parser & p) {
|
||||
io_state_stream out = p.regular_stream();
|
||||
unification_hints hints;
|
||||
name ns;
|
||||
if (p.curr_is_identifier()) {
|
||||
ns = p.get_name_val();
|
||||
p.next();
|
||||
hints = get_unification_hints(p.env(), ns);
|
||||
} else {
|
||||
hints = get_unification_hints(p.env());
|
||||
}
|
||||
format header;
|
||||
if (!ns.is_anonymous())
|
||||
header = format(" at namespace '") + format(ns) + format("'");
|
||||
out << pp_unification_hints(hints, out.get_formatter(), header);
|
||||
}
|
||||
|
||||
static void print_simp_rules(parser & p) {
|
||||
io_state_stream out = p.regular_stream();
|
||||
blast::scope_debug scope(p.env(), p.ios());
|
||||
|
@ -699,6 +717,9 @@ environment print_cmd(parser & p) {
|
|||
p.next();
|
||||
p.check_token_next(get_rbracket_tk(), "invalid 'print [recursor]', ']' expected");
|
||||
print_recursor_info(p);
|
||||
} else if (p.curr_is_token(get_unify_attr_tk())) {
|
||||
p.next();
|
||||
print_unification_hints(p);
|
||||
} else if (p.curr_is_token(get_simp_attr_tk())) {
|
||||
p.next();
|
||||
print_simp_rules(p);
|
||||
|
|
|
@ -120,7 +120,7 @@ void init_token_table(token_table & t) {
|
|||
"multiple_instances", "find_decl", "attribute", "persistent",
|
||||
"include", "omit", "migrate", "init_quotient", "init_hits", "#erase_cache", "#projections", "#telescope_eq",
|
||||
"#compile", "#accessible", "#decl_stats", "#relevant_thms", "#simplify", "#app_builder", "#refl", "#symm",
|
||||
"#trans", "#congr", "#hcongr", "#congr_simp", "#congr_rel", "#normalizer", "#abstract_expr", nullptr};
|
||||
"#trans", "#congr", "#hcongr", "#congr_simp", "#congr_rel", "#normalizer", "#abstract_expr", "#unify", nullptr};
|
||||
|
||||
pair<char const *, char const *> aliases[] =
|
||||
{{g_lambda_unicode, "fun"}, {"forall", "Pi"}, {g_forall_unicode, "Pi"}, {g_pi_unicode, "Pi"},
|
||||
|
|
|
@ -119,6 +119,7 @@ static name const * g_intro_attr_tk = nullptr;
|
|||
static name const * g_intro_bang_attr_tk = nullptr;
|
||||
static name const * g_elim_attr_tk = nullptr;
|
||||
static name const * g_recursor_tk = nullptr;
|
||||
static name const * g_unify_attr_tk = nullptr;
|
||||
static name const * g_attribute_tk = nullptr;
|
||||
static name const * g_with_tk = nullptr;
|
||||
static name const * g_class_tk = nullptr;
|
||||
|
@ -269,6 +270,7 @@ void initialize_tokens() {
|
|||
g_intro_bang_attr_tk = new name{"[intro!]"};
|
||||
g_elim_attr_tk = new name{"[elim]"};
|
||||
g_recursor_tk = new name{"[recursor"};
|
||||
g_unify_attr_tk = new name{"[unify]"};
|
||||
g_attribute_tk = new name{"attribute"};
|
||||
g_with_tk = new name{"with"};
|
||||
g_class_tk = new name{"[class]"};
|
||||
|
@ -420,6 +422,7 @@ void finalize_tokens() {
|
|||
delete g_intro_bang_attr_tk;
|
||||
delete g_elim_attr_tk;
|
||||
delete g_recursor_tk;
|
||||
delete g_unify_attr_tk;
|
||||
delete g_attribute_tk;
|
||||
delete g_with_tk;
|
||||
delete g_class_tk;
|
||||
|
@ -570,6 +573,7 @@ name const & get_intro_attr_tk() { return *g_intro_attr_tk; }
|
|||
name const & get_intro_bang_attr_tk() { return *g_intro_bang_attr_tk; }
|
||||
name const & get_elim_attr_tk() { return *g_elim_attr_tk; }
|
||||
name const & get_recursor_tk() { return *g_recursor_tk; }
|
||||
name const & get_unify_attr_tk() { return *g_unify_attr_tk; }
|
||||
name const & get_attribute_tk() { return *g_attribute_tk; }
|
||||
name const & get_with_tk() { return *g_with_tk; }
|
||||
name const & get_class_tk() { return *g_class_tk; }
|
||||
|
|
|
@ -121,6 +121,7 @@ name const & get_intro_attr_tk();
|
|||
name const & get_intro_bang_attr_tk();
|
||||
name const & get_elim_attr_tk();
|
||||
name const & get_recursor_tk();
|
||||
name const & get_unify_attr_tk();
|
||||
name const & get_attribute_tk();
|
||||
name const & get_with_tk();
|
||||
name const & get_class_tk();
|
||||
|
|
|
@ -114,6 +114,7 @@ intro_attr [intro]
|
|||
intro_bang_attr [intro!]
|
||||
elim_attr [elim]
|
||||
recursor [recursor
|
||||
unify_attr [unify]
|
||||
attribute attribute
|
||||
with with
|
||||
class [class]
|
||||
|
|
|
@ -18,4 +18,4 @@ add_library(library OBJECT deep_copy.cpp expr_lt.cpp io_state.cpp
|
|||
aux_recursors.cpp norm_num.cpp norm_num.cpp class_instance_resolution.cpp type_context.cpp
|
||||
tmp_type_context.cpp fun_info_manager.cpp congr_lemma_manager.cpp
|
||||
abstract_expr_manager.cpp light_lt_manager.cpp trace.cpp
|
||||
attribute_manager.cpp error_handling.cpp)
|
||||
attribute_manager.cpp error_handling.cpp unification_hint.cpp)
|
||||
|
|
|
@ -95,6 +95,8 @@ name const * g_lift_down = nullptr;
|
|||
name const * g_lift_up = nullptr;
|
||||
name const * g_linear_ordered_ring = nullptr;
|
||||
name const * g_linear_ordered_semiring = nullptr;
|
||||
name const * g_list_nil = nullptr;
|
||||
name const * g_list_cons = nullptr;
|
||||
name const * g_monoid = nullptr;
|
||||
name const * g_mul = nullptr;
|
||||
name const * g_mul_one = nullptr;
|
||||
|
@ -264,6 +266,10 @@ name const * g_trans_rel_left = nullptr;
|
|||
name const * g_trans_rel_right = nullptr;
|
||||
name const * g_true = nullptr;
|
||||
name const * g_true_intro = nullptr;
|
||||
name const * g_unification_hint = nullptr;
|
||||
name const * g_unification_hint_mk = nullptr;
|
||||
name const * g_unification_constraint = nullptr;
|
||||
name const * g_unification_constraint_mk = nullptr;
|
||||
name const * g_weak_order = nullptr;
|
||||
name const * g_well_founded = nullptr;
|
||||
name const * g_zero = nullptr;
|
||||
|
@ -363,6 +369,8 @@ void initialize_constants() {
|
|||
g_lift_up = new name{"lift", "up"};
|
||||
g_linear_ordered_ring = new name{"linear_ordered_ring"};
|
||||
g_linear_ordered_semiring = new name{"linear_ordered_semiring"};
|
||||
g_list_nil = new name{"list", "nil"};
|
||||
g_list_cons = new name{"list", "cons"};
|
||||
g_monoid = new name{"monoid"};
|
||||
g_mul = new name{"mul"};
|
||||
g_mul_one = new name{"mul_one"};
|
||||
|
@ -532,6 +540,10 @@ void initialize_constants() {
|
|||
g_trans_rel_right = new name{"trans_rel_right"};
|
||||
g_true = new name{"true"};
|
||||
g_true_intro = new name{"true", "intro"};
|
||||
g_unification_hint = new name{"unification_hint"};
|
||||
g_unification_hint_mk = new name{"unification_hint", "mk"};
|
||||
g_unification_constraint = new name{"unification_constraint"};
|
||||
g_unification_constraint_mk = new name{"unification_constraint", "mk"};
|
||||
g_weak_order = new name{"weak_order"};
|
||||
g_well_founded = new name{"well_founded"};
|
||||
g_zero = new name{"zero"};
|
||||
|
@ -632,6 +644,8 @@ void finalize_constants() {
|
|||
delete g_lift_up;
|
||||
delete g_linear_ordered_ring;
|
||||
delete g_linear_ordered_semiring;
|
||||
delete g_list_nil;
|
||||
delete g_list_cons;
|
||||
delete g_monoid;
|
||||
delete g_mul;
|
||||
delete g_mul_one;
|
||||
|
@ -801,6 +815,10 @@ void finalize_constants() {
|
|||
delete g_trans_rel_right;
|
||||
delete g_true;
|
||||
delete g_true_intro;
|
||||
delete g_unification_hint;
|
||||
delete g_unification_hint_mk;
|
||||
delete g_unification_constraint;
|
||||
delete g_unification_constraint_mk;
|
||||
delete g_weak_order;
|
||||
delete g_well_founded;
|
||||
delete g_zero;
|
||||
|
@ -900,6 +918,8 @@ name const & get_lift_down_name() { return *g_lift_down; }
|
|||
name const & get_lift_up_name() { return *g_lift_up; }
|
||||
name const & get_linear_ordered_ring_name() { return *g_linear_ordered_ring; }
|
||||
name const & get_linear_ordered_semiring_name() { return *g_linear_ordered_semiring; }
|
||||
name const & get_list_nil_name() { return *g_list_nil; }
|
||||
name const & get_list_cons_name() { return *g_list_cons; }
|
||||
name const & get_monoid_name() { return *g_monoid; }
|
||||
name const & get_mul_name() { return *g_mul; }
|
||||
name const & get_mul_one_name() { return *g_mul_one; }
|
||||
|
@ -1069,6 +1089,10 @@ name const & get_trans_rel_left_name() { return *g_trans_rel_left; }
|
|||
name const & get_trans_rel_right_name() { return *g_trans_rel_right; }
|
||||
name const & get_true_name() { return *g_true; }
|
||||
name const & get_true_intro_name() { return *g_true_intro; }
|
||||
name const & get_unification_hint_name() { return *g_unification_hint; }
|
||||
name const & get_unification_hint_mk_name() { return *g_unification_hint_mk; }
|
||||
name const & get_unification_constraint_name() { return *g_unification_constraint; }
|
||||
name const & get_unification_constraint_mk_name() { return *g_unification_constraint_mk; }
|
||||
name const & get_weak_order_name() { return *g_weak_order; }
|
||||
name const & get_well_founded_name() { return *g_well_founded; }
|
||||
name const & get_zero_name() { return *g_zero; }
|
||||
|
|
|
@ -97,6 +97,8 @@ name const & get_lift_down_name();
|
|||
name const & get_lift_up_name();
|
||||
name const & get_linear_ordered_ring_name();
|
||||
name const & get_linear_ordered_semiring_name();
|
||||
name const & get_list_nil_name();
|
||||
name const & get_list_cons_name();
|
||||
name const & get_monoid_name();
|
||||
name const & get_mul_name();
|
||||
name const & get_mul_one_name();
|
||||
|
@ -266,6 +268,10 @@ name const & get_trans_rel_left_name();
|
|||
name const & get_trans_rel_right_name();
|
||||
name const & get_true_name();
|
||||
name const & get_true_intro_name();
|
||||
name const & get_unification_hint_name();
|
||||
name const & get_unification_hint_mk_name();
|
||||
name const & get_unification_constraint_name();
|
||||
name const & get_unification_constraint_mk_name();
|
||||
name const & get_weak_order_name();
|
||||
name const & get_well_founded_name();
|
||||
name const & get_zero_name();
|
||||
|
|
|
@ -90,6 +90,8 @@ lift.down
|
|||
lift.up
|
||||
linear_ordered_ring
|
||||
linear_ordered_semiring
|
||||
list.nil
|
||||
list.cons
|
||||
monoid
|
||||
mul
|
||||
mul_one
|
||||
|
@ -259,6 +261,10 @@ trans_rel_left
|
|||
trans_rel_right
|
||||
true
|
||||
true.intro
|
||||
unification_hint
|
||||
unification_hint.mk
|
||||
unification_constraint
|
||||
unification_constraint.mk
|
||||
weak_order
|
||||
well_founded
|
||||
zero
|
||||
|
|
|
@ -47,6 +47,7 @@ Author: Leonardo de Moura
|
|||
#include "library/app_builder.h"
|
||||
#include "library/attribute_manager.h"
|
||||
#include "library/fun_info_manager.h"
|
||||
#include "library/unification_hint.h"
|
||||
|
||||
namespace lean {
|
||||
void initialize_library_module() {
|
||||
|
@ -93,9 +94,11 @@ void initialize_library_module() {
|
|||
initialize_congr_lemma_manager();
|
||||
initialize_app_builder();
|
||||
initialize_fun_info_manager();
|
||||
initialize_unification_hint();
|
||||
}
|
||||
|
||||
void finalize_library_module() {
|
||||
finalize_unification_hint();
|
||||
finalize_fun_info_manager();
|
||||
finalize_app_builder();
|
||||
finalize_congr_lemma_manager();
|
||||
|
|
|
@ -11,6 +11,7 @@ Author: Leonardo de Moura
|
|||
#include "kernel/instantiate.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
#include "kernel/replace_fn.h"
|
||||
#include "kernel/inductive/inductive.h"
|
||||
#include "library/trace.h"
|
||||
#include "library/util.h"
|
||||
|
@ -23,6 +24,7 @@ Author: Leonardo de Moura
|
|||
#include "library/generic_exception.h"
|
||||
#include "library/class.h"
|
||||
#include "library/constants.h"
|
||||
#include "library/unification_hint.h"
|
||||
|
||||
#ifndef LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH
|
||||
#define LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH 32
|
||||
|
@ -1013,10 +1015,7 @@ bool type_context::is_def_eq_core(expr const & t, expr const & s) {
|
|||
if (is_def_eq_proof_irrel(t_n, s_n))
|
||||
return true;
|
||||
|
||||
if (on_is_def_eq_failure(t_n, s_n))
|
||||
return is_def_eq_core(t_n, s_n);
|
||||
else
|
||||
return false;
|
||||
return on_is_def_eq_failure(t_n, s_n);
|
||||
}
|
||||
|
||||
bool type_context::process_postponed(unsigned old_sz) {
|
||||
|
@ -1405,19 +1404,112 @@ optional<pair<expr, expr>> type_context::find_unsynth_metavar(expr const & e) {
|
|||
}
|
||||
}
|
||||
|
||||
bool type_context::on_is_def_eq_failure(expr & e1, expr & e2) {
|
||||
bool type_context::on_is_def_eq_failure(expr const & e1, expr const & e2) {
|
||||
if (is_app(e1)) {
|
||||
if (auto p1 = find_unsynth_metavar(e1)) {
|
||||
if (mk_nested_instance(p1->first, p1->second)) {
|
||||
e1 = instantiate_uvars_mvars(e1);
|
||||
return true;
|
||||
return is_def_eq_core(instantiate_uvars_mvars(e1), e2);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (is_app(e2)) {
|
||||
if (auto p2 = find_unsynth_metavar(e2)) {
|
||||
if (mk_nested_instance(p2->first, p2->second)) {
|
||||
e2 = instantiate_uvars_mvars(e2);
|
||||
return is_def_eq_core(e1, instantiate_uvars_mvars(e2));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (try_unification_hints(e1, e2)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
struct type_context::unification_hint_fn {
|
||||
type_context & m_owner;
|
||||
unification_hint m_hint;
|
||||
buffer<optional<expr> > m_assignment;
|
||||
|
||||
unification_hint_fn(type_context & o, unification_hint const & hint):
|
||||
m_owner(o), m_hint(hint) { m_assignment.resize(m_hint.get_num_vars()); }
|
||||
|
||||
bool syntactic_match(expr const & pattern, expr const & e) {
|
||||
unsigned idx;
|
||||
switch (pattern.kind()) {
|
||||
case expr_kind::Var:
|
||||
idx = var_idx(pattern);
|
||||
if (!m_assignment[idx]) {
|
||||
m_assignment[idx] = some_expr(e);
|
||||
return true;
|
||||
} else {
|
||||
return m_owner.is_def_eq(*m_assignment[idx], e);
|
||||
}
|
||||
case expr_kind::Constant:
|
||||
return is_constant(e) && const_name(pattern) == const_name(e)
|
||||
&& m_owner.is_def_eq(const_levels(pattern), const_levels(e));
|
||||
case expr_kind::Sort:
|
||||
return is_sort(e) && m_owner.is_def_eq(sort_level(pattern), sort_level(e));
|
||||
case expr_kind::Pi: case expr_kind::Lambda: case expr_kind::Macro:
|
||||
// Remark: we do not traverse inside of binders.
|
||||
return pattern == e;
|
||||
case expr_kind::App:
|
||||
return is_app(e) && syntactic_match(app_fn(pattern), app_fn(e)) && syntactic_match(app_arg(pattern), app_arg(e));
|
||||
case expr_kind::Local: case expr_kind::Meta:
|
||||
break;
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
bool operator()(expr const & lhs, expr const & rhs) {
|
||||
if (!syntactic_match(m_hint.get_lhs(), lhs)) {
|
||||
lean_trace(name({"type_context", "unification_hint"}), tout() << "LHS does not match\n";);
|
||||
return false;
|
||||
} else if (!syntactic_match(m_hint.get_rhs(), rhs)) {
|
||||
lean_trace(name({"type_context", "unification_hint"}), tout() << "RHS does not match\n";);
|
||||
return false;
|
||||
} else {
|
||||
auto instantiate_assignment_fn = [&](expr const & e, unsigned offset) {
|
||||
if (is_var(e)) {
|
||||
unsigned idx = var_idx(e) + offset;
|
||||
if (idx < m_assignment.size()) {
|
||||
lean_assert(m_assignment[idx]);
|
||||
return m_assignment[idx];
|
||||
}
|
||||
}
|
||||
return none_expr();
|
||||
};
|
||||
buffer<expr_pair> constraints;
|
||||
to_buffer(m_hint.get_constraints(), constraints);
|
||||
for (expr_pair const & p : constraints) {
|
||||
expr new_lhs = replace(p.first, instantiate_assignment_fn);
|
||||
expr new_rhs = replace(p.second, instantiate_assignment_fn);
|
||||
expr new_lhs_inst = m_owner.instantiate_uvars_mvars(new_lhs);
|
||||
expr new_rhs_inst = m_owner.instantiate_uvars_mvars(new_rhs);
|
||||
bool success = m_owner.is_def_eq(new_lhs, new_rhs);
|
||||
lean_trace(name({"type_context", "unification_hint"}),
|
||||
tout() << new_lhs_inst << " =?= " << new_rhs_inst << "..."
|
||||
<< (success ? "success" : "failed") << "\n";);
|
||||
if (!success) return false;
|
||||
}
|
||||
lean_trace(name({"type_context", "unification_hint"}), tout() << "hint successfully applied\n";);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool type_context::try_unification_hints(expr const & e1, expr const & e2) {
|
||||
expr e1_fn = get_app_fn(e1);
|
||||
expr e2_fn = get_app_fn(e2);
|
||||
if (is_constant(e1_fn) && is_constant(e2_fn)) {
|
||||
buffer<unification_hint> hints;
|
||||
get_unification_hints(m_env, const_name(e1_fn), const_name(e2_fn), hints);
|
||||
for (unification_hint const & hint : hints) {
|
||||
scope s(*this);
|
||||
lean_trace(name({"type_context", "unification_hint"}),
|
||||
tout() << e1 << " =?= " << e2
|
||||
<< ", pattern: " << hint.get_lhs() << " =?= " << hint.get_rhs() << "\n";);
|
||||
if (unification_hint_fn(*this, hint)(e1, e2)) {
|
||||
s.commit();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -2042,6 +2134,7 @@ void initialize_type_context() {
|
|||
g_tmp_prefix = new name(name::mk_internal_unique_name());
|
||||
g_internal_prefix = new name(name::mk_internal_unique_name());
|
||||
register_trace_class("class_instances");
|
||||
register_trace_class(name({"type_context", "unification_hint"}));
|
||||
g_class_instance_max_depth = new name{"class", "instance_max_depth"};
|
||||
g_class_trans_instances = new name{"class", "trans_instances"};
|
||||
register_unsigned_option(*g_class_instance_max_depth, LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH,
|
||||
|
|
|
@ -413,7 +413,11 @@ public:
|
|||
|
||||
The default implementation tries to invoke type class resolution to
|
||||
assign unassigned metavariables in the given terms. */
|
||||
virtual bool on_is_def_eq_failure(expr &, expr &);
|
||||
virtual bool on_is_def_eq_failure(expr const &, expr const &);
|
||||
|
||||
bool try_unification_hints(expr const &, expr const &);
|
||||
struct unification_hint_fn;
|
||||
friend struct unification_hint_fn;
|
||||
|
||||
bool is_assigned(level const & u) const { return static_cast<bool>(get_assignment(u)); }
|
||||
bool is_assigned(expr const & m) const { return static_cast<bool>(get_assignment(m)); }
|
||||
|
|
239
src/library/unification_hint.cpp
Normal file
239
src/library/unification_hint.cpp
Normal file
|
@ -0,0 +1,239 @@
|
|||
/*
|
||||
Copyright (c) 2015 Daniel Selsam. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Daniel Selsam
|
||||
*/
|
||||
#include <string>
|
||||
#include "util/sexpr/format.h"
|
||||
#include "kernel/expr.h"
|
||||
#include "kernel/error_msgs.h"
|
||||
#include "library/attribute_manager.h"
|
||||
#include "library/constants.h"
|
||||
#include "library/unification_hint.h"
|
||||
#include "library/util.h"
|
||||
#include "library/expr_lt.h"
|
||||
#include "library/scoped_ext.h"
|
||||
|
||||
namespace lean {
|
||||
|
||||
/* Unification hints */
|
||||
|
||||
unification_hint::unification_hint(expr const & lhs, expr const & rhs, list<expr_pair> const & constraints, unsigned num_vars):
|
||||
m_lhs(lhs), m_rhs(rhs), m_constraints(constraints), m_num_vars(num_vars) {}
|
||||
|
||||
int unification_hint_cmp::operator()(unification_hint const & uh1, unification_hint const & uh2) const {
|
||||
if (uh1.get_lhs() != uh2.get_lhs()) {
|
||||
return expr_quick_cmp()(uh1.get_lhs(), uh2.get_lhs());
|
||||
} else if (uh1.get_rhs() != uh2.get_rhs()) {
|
||||
return expr_quick_cmp()(uh1.get_rhs(), uh2.get_rhs());
|
||||
} else {
|
||||
auto it1 = uh1.get_constraints().begin();
|
||||
auto it2 = uh2.get_constraints().begin();
|
||||
auto end1 = uh1.get_constraints().end();
|
||||
auto end2 = uh2.get_constraints().end();
|
||||
for (; it1 != end1 && it2 != end2; ++it1, ++it2) {
|
||||
if (unsigned cmp = expr_pair_quick_cmp()(*it1, *it2)) return cmp;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/* Environment extension */
|
||||
|
||||
static name * g_class_name = nullptr;
|
||||
static std::string * g_key = nullptr;
|
||||
|
||||
struct unification_hint_state {
|
||||
unification_hints m_hints;
|
||||
name_map<unsigned> m_decl_names_to_prio; // Note: redundant but convenient
|
||||
|
||||
void validate_type(expr const & decl_type) {
|
||||
expr type = decl_type;
|
||||
while (is_pi(type)) type = binding_body(type);
|
||||
if (!is_app_of(type, get_unification_hint_name(), 0)) {
|
||||
throw exception("invalid unification hint, must return element of type `unification hint`");
|
||||
}
|
||||
}
|
||||
|
||||
void register_hint(name const & decl_name, expr const & value, unsigned priority) {
|
||||
m_decl_names_to_prio.insert(decl_name, priority);
|
||||
|
||||
expr e_hint = value;
|
||||
unsigned num_vars = 0;
|
||||
while (is_lambda(e_hint)) {
|
||||
e_hint = binding_body(e_hint);
|
||||
num_vars++;
|
||||
}
|
||||
|
||||
if (!is_app_of(e_hint, get_unification_hint_mk_name(), 2)) {
|
||||
throw exception("invalid unification hint, body must be application of 'unification_hint.mk' to two arguments");
|
||||
}
|
||||
|
||||
// e_hint := unification_hint.mk pattern constraints
|
||||
expr e_pattern = app_arg(app_fn(e_hint));
|
||||
expr e_constraints = app_arg(e_hint);
|
||||
|
||||
// pattern := unification_constraint.mk _ lhs rhs
|
||||
expr e_pattern_lhs = app_arg(app_fn(e_pattern));
|
||||
expr e_pattern_rhs = app_arg(e_pattern);
|
||||
|
||||
expr e_pattern_lhs_fn = get_app_fn(e_pattern_lhs);
|
||||
expr e_pattern_rhs_fn = get_app_fn(e_pattern_rhs);
|
||||
|
||||
if (!is_constant(e_pattern_lhs_fn) || !is_constant(e_pattern_rhs_fn)) {
|
||||
throw exception("invalid unification hint, the heads of both sides of pattern must be constants");
|
||||
}
|
||||
|
||||
name_pair key = mk_pair(const_name(e_pattern_lhs_fn), const_name(e_pattern_rhs_fn));
|
||||
|
||||
buffer<expr_pair> constraints;
|
||||
while (is_app_of(e_constraints, get_list_cons_name(), 3)) {
|
||||
// e_constraints := cons _ constraint rest
|
||||
expr e_constraint = app_arg(app_fn(e_constraints));
|
||||
expr e_constraint_lhs = app_arg(app_fn(e_constraint));
|
||||
expr e_constraint_rhs = app_arg(e_constraint);
|
||||
constraints.push_back(mk_pair(e_constraint_lhs, e_constraint_rhs));
|
||||
e_constraints = app_arg(e_constraints);
|
||||
}
|
||||
|
||||
if (!is_app_of(e_constraints, get_list_nil_name(), 1)) {
|
||||
throw exception("invalid unification hint, must provide list of constraints explicitly");
|
||||
}
|
||||
|
||||
unification_hint hint(e_pattern_lhs, e_pattern_rhs, to_list(constraints), num_vars);
|
||||
unification_hint_queue q;
|
||||
if (auto const & q_ptr = m_hints.find(key)) q = *q_ptr;
|
||||
q.insert(hint, priority);
|
||||
m_hints.insert(key, q);
|
||||
}
|
||||
};
|
||||
|
||||
struct unification_hint_entry {
|
||||
name m_decl_name;
|
||||
unsigned m_priority;
|
||||
unification_hint_entry(name const & decl_name, unsigned priority):
|
||||
m_decl_name(decl_name), m_priority(priority) {}
|
||||
};
|
||||
|
||||
struct unification_hint_config {
|
||||
typedef unification_hint_entry entry;
|
||||
typedef unification_hint_state state;
|
||||
|
||||
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
|
||||
declaration decl = env.get(e.m_decl_name);
|
||||
s.validate_type(decl.get_type());
|
||||
// Note: only definitions should be tagged as [unify], so if it is not a definition,
|
||||
// there must have been an error when processing the definition. We return immediately
|
||||
// so as not to hide the original error.
|
||||
// TODO(dhs): the downside to this approach is that a [unify] tag on an actual axiom will be silently ignored.
|
||||
if (decl.is_definition()) s.register_hint(e.m_decl_name, decl.get_value(), e.m_priority);
|
||||
}
|
||||
static name const & get_class_name() {
|
||||
return *g_class_name;
|
||||
}
|
||||
static std::string const & get_serialization_key() {
|
||||
return *g_key;
|
||||
}
|
||||
static void write_entry(serializer & s, entry const & e) {
|
||||
s << e.m_decl_name << e.m_priority;
|
||||
}
|
||||
static entry read_entry(deserializer & d) {
|
||||
name decl_name; unsigned prio;
|
||||
d >> decl_name >> prio;
|
||||
return entry(decl_name, prio);
|
||||
}
|
||||
static optional<unsigned> get_fingerprint(entry const & e) {
|
||||
return some(hash(e.m_decl_name.hash(), e.m_priority));
|
||||
}
|
||||
};
|
||||
|
||||
typedef scoped_ext<unification_hint_config> unification_hint_ext;
|
||||
|
||||
environment add_unification_hint(environment const & env, io_state const & ios, name const & decl_name, unsigned prio, name const & ns, bool persistent) {
|
||||
return unification_hint_ext::add_entry(env, ios, unification_hint_entry(decl_name, prio), ns, persistent);
|
||||
}
|
||||
|
||||
bool is_unification_hint(environment const & env, name const & decl_name) {
|
||||
return unification_hint_ext::get_state(env).m_decl_names_to_prio.contains(decl_name);
|
||||
}
|
||||
|
||||
unification_hints get_unification_hints(environment const & env) {
|
||||
return unification_hint_ext::get_state(env).m_hints;
|
||||
}
|
||||
|
||||
unification_hints get_unification_hints(environment const & env, name const & ns) {
|
||||
list<unification_hint_entry> const * entries = unification_hint_ext::get_entries(env, ns);
|
||||
unification_hint_state s;
|
||||
if (entries) {
|
||||
for (auto const & e : *entries) {
|
||||
declaration decl = env.get(e.m_decl_name);
|
||||
s.register_hint(e.m_decl_name, decl.get_value(), e.m_priority);
|
||||
}
|
||||
}
|
||||
return s.m_hints;
|
||||
}
|
||||
|
||||
void get_unification_hints(environment const & env, name const & n1, name const & n2, buffer<unification_hint> & uhints) {
|
||||
unification_hints hints = unification_hint_ext::get_state(env).m_hints;
|
||||
if (auto const & q_ptr = hints.find(mk_pair(n1, n2))) {
|
||||
q_ptr->to_buffer(uhints);
|
||||
}
|
||||
if (auto const & q_ptr = hints.find(mk_pair(n2, n1))) {
|
||||
q_ptr->to_buffer(uhints);
|
||||
}
|
||||
}
|
||||
|
||||
/* Pretty-printing */
|
||||
|
||||
// TODO(dhs): I may not be using all the formatting functions correctly.
|
||||
format unification_hint::pp(unsigned prio, formatter const & fmt) const {
|
||||
format r;
|
||||
if (prio != LEAN_DEFAULT_PRIORITY)
|
||||
r += paren(format(prio)) + space();
|
||||
format r1 = fmt(get_lhs()) + space() + format("=?=") + pp_indent_expr(fmt, get_rhs());
|
||||
r1 += space() + lcurly();
|
||||
r += group(r1);
|
||||
for_each(m_constraints, [&](expr_pair p) {
|
||||
r += fmt(p.first) + space() + format("=?=");
|
||||
r += space() + fmt(p.second) + comma() + space();
|
||||
});
|
||||
r += rcurly();
|
||||
return r;
|
||||
}
|
||||
|
||||
format pp_unification_hints(unification_hints const & hints, formatter const & fmt, format const & header) {
|
||||
format r;
|
||||
r += format("unification hints");
|
||||
r += header + colon() + line();
|
||||
hints.for_each([&](name_pair const & names, unification_hint_queue const & q) {
|
||||
q.for_each([&](unification_hint const & hint) {
|
||||
r += lp() + format(names.first) + comma() + space() + format(names.second) + rp() + space();
|
||||
r += hint.pp(*q.get_prio(hint), fmt) + line();
|
||||
});
|
||||
});
|
||||
return r;
|
||||
}
|
||||
|
||||
void initialize_unification_hint() {
|
||||
g_class_name = new name("unification_hint");
|
||||
g_key = new std::string("UNIFICATION_HINT");
|
||||
|
||||
unification_hint_ext::initialize();
|
||||
|
||||
register_prio_attribute("unify", "unification hint",
|
||||
add_unification_hint,
|
||||
is_unification_hint,
|
||||
[](environment const & env, name const & decl_name) {
|
||||
if (auto p = unification_hint_ext::get_state(env).m_decl_names_to_prio.find(decl_name))
|
||||
return *p;
|
||||
else
|
||||
return LEAN_DEFAULT_PRIORITY;
|
||||
});
|
||||
}
|
||||
|
||||
void finalize_unification_hint() {
|
||||
unification_hint_ext::finalize();
|
||||
delete g_key;
|
||||
delete g_class_name;
|
||||
}
|
||||
}
|
69
src/library/unification_hint.h
Normal file
69
src/library/unification_hint.h
Normal file
|
@ -0,0 +1,69 @@
|
|||
/*
|
||||
Copyright (c) 2015 Daniel Selsam. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Author: Daniel Selsam
|
||||
*/
|
||||
#pragma once
|
||||
#include "kernel/environment.h"
|
||||
#include "library/expr_pair.h"
|
||||
#include "library/io_state.h"
|
||||
#include "library/head_map.h"
|
||||
#include "util/priority_queue.h"
|
||||
|
||||
namespace lean {
|
||||
|
||||
/*
|
||||
Users can declare unification hints using the following structures:
|
||||
|
||||
structure unification_constraint := {A : Type} (lhs : A) (rhs : A)
|
||||
structure unification_hint := (pattern : unification_constraint) (constraints : list unification_constraint)
|
||||
|
||||
Example:
|
||||
|
||||
definition both_zero_of_add_eq_zero [unify] (n₁ n₂ : ℕ) (s₁ : has_add ℕ) (s₂ : has_zero ℕ) : unification_hint :=
|
||||
unification_hint.mk (unification_constraint.mk (@add ℕ s₁ n₁ n₂) (@zero ℕ s₂))
|
||||
[unification_constraint.mk n₁ (@zero ℕ s₂),
|
||||
unification_constraint.mk n₂ (@zero ℕ s₂)]
|
||||
|
||||
creates the following unification hint:
|
||||
m_lhs: add nat #1 #3 #2
|
||||
m_rhs: zero nat #0
|
||||
m_constraints: [(#3, zero nat #0), (#2, zero nat #0)]
|
||||
m_num_vars: #4
|
||||
|
||||
Note that once we have an assignment to all variables from matching, we must substitute the assignments in the constraints.
|
||||
*/
|
||||
|
||||
class unification_hint {
|
||||
expr m_lhs;
|
||||
expr m_rhs;
|
||||
|
||||
list<expr_pair> m_constraints;
|
||||
unsigned m_num_vars;
|
||||
public:
|
||||
expr get_lhs() const { return m_lhs; }
|
||||
expr get_rhs() const { return m_rhs; }
|
||||
list<expr_pair> get_constraints() const { return m_constraints; }
|
||||
unsigned get_num_vars() const { return m_num_vars; }
|
||||
|
||||
unification_hint() {}
|
||||
unification_hint(expr const & lhs, expr const & rhs, list<expr_pair> const & constraints, unsigned num_vars);
|
||||
format pp(unsigned priority, formatter const & fmt) const;
|
||||
};
|
||||
|
||||
struct unification_hint_cmp {
|
||||
int operator()(unification_hint const & uh1, unification_hint const & uh2) const;
|
||||
};
|
||||
|
||||
typedef priority_queue<unification_hint, unification_hint_cmp> unification_hint_queue;
|
||||
typedef rb_map<name_pair, unification_hint_queue, name_pair_quick_cmp> unification_hints;
|
||||
|
||||
unification_hints get_unification_hints(environment const & env);
|
||||
unification_hints get_unification_hints(environment const & env, name const & ns);
|
||||
void get_unification_hints(environment const & env, name const & n1, name const & n2, buffer<unification_hint> & hints);
|
||||
|
||||
format pp_unification_hints(unification_hints const & hints, formatter const & fmt, format const & header);
|
||||
|
||||
void initialize_unification_hint();
|
||||
void finalize_unification_hint();
|
||||
}
|
54
tests/lean/unification_hints1.lean
Normal file
54
tests/lean/unification_hints1.lean
Normal file
|
@ -0,0 +1,54 @@
|
|||
import data.list data.nat
|
||||
open list nat
|
||||
|
||||
structure unification_constraint := {A : Type} (lhs : A) (rhs : A)
|
||||
structure unification_hint := (pattern : unification_constraint) (constraints : list unification_constraint)
|
||||
|
||||
namespace toy
|
||||
constants (A : Type.{1}) (f h : A → A) (x y z : A)
|
||||
definition g [irreducible] (x y : A) : A := f z
|
||||
|
||||
#unify (g x y), (f z)
|
||||
|
||||
definition toy_hint [unify] (x y : A) : unification_hint :=
|
||||
unification_hint.mk (unification_constraint.mk (g x y) (f z)) []
|
||||
|
||||
#unify (g x y), (f z)
|
||||
print [unify]
|
||||
|
||||
end toy
|
||||
|
||||
namespace add
|
||||
constants (n : ℕ)
|
||||
|
||||
#unify (n + 1), succ n
|
||||
|
||||
definition add_zero_hint [unify] (m n : ℕ) [has_add ℕ] [has_one ℕ] [has_zero ℕ] : unification_hint :=
|
||||
unification_hint.mk (unification_constraint.mk (m + 1) (succ n)) [unification_constraint.mk m n]
|
||||
|
||||
#unify (n + 1), (succ n)
|
||||
print [unify]
|
||||
|
||||
end add
|
||||
|
||||
namespace canonical
|
||||
structure Canonical := (carrier : Type) (op : carrier → carrier)
|
||||
attribute Canonical.carrier [irreducible]
|
||||
|
||||
constants (A : Type.{1}) (f : A → A) (x : A)
|
||||
definition A_canonical : Canonical := Canonical.mk A f
|
||||
|
||||
#unify (Canonical.carrier A_canonical), A
|
||||
|
||||
definition Canonical_hint [unify] (C : Canonical) : unification_hint :=
|
||||
unification_hint.mk (unification_constraint.mk (Canonical.carrier C) A) [unification_constraint.mk C A_canonical]
|
||||
|
||||
-- TODO(dhs): we mark carrier as irreducible and prove A_canonical explicitly to work around the fact that
|
||||
-- the default_type_context does not recognize the elaborator metavariables as metavariables,
|
||||
-- and so cannot perform the assignment.
|
||||
#unify (Canonical.carrier A_canonical), A
|
||||
print [unify]
|
||||
|
||||
end canonical
|
||||
|
||||
print [unify] canonical
|
15
tests/lean/unification_hints1.lean.expected.out
Normal file
15
tests/lean/unification_hints1.lean.expected.out
Normal file
|
@ -0,0 +1,15 @@
|
|||
fail
|
||||
success
|
||||
unification hints:
|
||||
(toy.g, toy.f) g #1 #0 =?= f z {}
|
||||
fail
|
||||
success
|
||||
unification hints:
|
||||
(add, nat.succ) #4 + 1 =?= succ #3 {#4 =?= #3, }
|
||||
fail
|
||||
success
|
||||
unification hints:
|
||||
(canonical.Canonical.carrier, canonical.A) Canonical.carrier #0 =?= A {#0 =?= A_canonical, }
|
||||
unification hints at namespace 'canonical':
|
||||
(canonical.Canonical.carrier, canonical.A) canonical.Canonical.carrier #0 =?=
|
||||
canonical.A {#0 =?= canonical.A_canonical, }
|
Loading…
Reference in a new issue