From 42dba5cc981de5edb99d9c4ffafcc66af452ba7f Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 30 Oct 2014 23:56:38 -0700 Subject: [PATCH] feat(frontends/lean/calc): expose get_calc_subst_info and get_calc_refl_info APIs --- src/frontends/lean/calc.cpp | 39 ++++++++++++++++++++++++------------- src/frontends/lean/calc.h | 2 ++ 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/frontends/lean/calc.cpp b/src/frontends/lean/calc.cpp index d617a06e6..7aaee4d15 100644 --- a/src/frontends/lean/calc.cpp +++ b/src/frontends/lean/calc.cpp @@ -60,8 +60,8 @@ struct calc_entry { }; struct calc_state { - typedef name_map> refl_table; - typedef name_map> subst_table; + typedef name_map> refl_table; + typedef name_map> subst_table; typedef name_map> symm_table; typedef rb_map, name_pair_quick_cmp> trans_table; trans_table m_trans_table; @@ -72,22 +72,26 @@ struct calc_state { void add_calc_subst(environment const & env, name const & subst) { buffer arg_types; - expr r_type = extract_arg_types(env, subst, arg_types); - unsigned nargs = arg_types.size(); + auto p = extract_arg_types_core(env, subst, arg_types); + expr r_type = p.first; + unsigned nunivs = p.second; + unsigned nargs = arg_types.size(); if (nargs < 2) throw exception("invalid calc substitution theorem, it must have at least 2 arguments"); name const & rop = get_fn_const(arg_types[nargs-2], "invalid calc substitution theorem, argument penultimate argument must be an operator application"); - m_subst_table.insert(rop, mk_pair(subst, nargs)); + m_subst_table.insert(rop, std::make_tuple(subst, nargs, nunivs)); } void add_calc_refl(environment const & env, name const & refl) { buffer arg_types; - expr r_type = extract_arg_types(env, refl, arg_types); - unsigned nargs = arg_types.size(); + auto p = extract_arg_types_core(env, refl, arg_types); + expr r_type = p.first; + unsigned nunivs = p.second; + unsigned nargs = arg_types.size(); if (nargs < 1) throw exception("invalid calc reflexivity rule, it must have at least 1 argument"); name const & rop = get_fn_const(r_type, "invalid calc reflexivity rule, result type must be an operator application"); - m_refl_table.insert(rop, mk_pair(refl, nargs)); + m_refl_table.insert(rop, std::make_tuple(refl, nargs, nunivs)); } void add_calc_trans(environment const & env, name const & trans) { @@ -180,15 +184,24 @@ void register_calc_cmds(cmd_table & r) { add_cmd(r, cmd_info("calc_symm", "set the symmetry rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_symm_cmd)); } -optional> get_calc_symm_info(environment const & env, name const & rop) { - auto const & s = calc_ext::get_state(env); - if (auto it = s.m_symm_table.find(rop)) { +static optional> get_info(name_map> const & table, name const & op) { + if (auto it = table.find(op)) { return optional>(*it); } else { return optional>(); } } +optional> get_calc_refl_info(environment const & env, name const & op) { + return get_info(calc_ext::get_state(env).m_refl_table, op); +} +optional> get_calc_subst_info(environment const & env, name const & op) { + return get_info(calc_ext::get_state(env).m_subst_table, op); +} +optional> get_calc_symm_info(environment const & env, name const & op) { + return get_info(calc_ext::get_state(env).m_symm_table, op); +} + static expr mk_calc_annotation_core(expr const & e) { return mk_annotation(*g_calc_name, e); } static expr mk_calc_annotation(expr const & pr) { if (is_by(pr) || is_begin_end_annotation(pr) || is_sorry(pr)) { @@ -254,9 +267,9 @@ static void parse_calc_proof(parser & p, buffer const & preds, std::v for (auto const & pred : preds) { if (auto refl_it = state.m_refl_table.find(pred_op(pred))) { if (auto subst_it = state.m_subst_table.find(pred_op(pred))) { - expr refl = mk_op_fn(p, refl_it->first, refl_it->second-1, pos); + expr refl = mk_op_fn(p, std::get<0>(*refl_it), std::get<1>(*refl_it)-1, pos); expr refl_pr = p.mk_app(refl, pred_lhs(pred), pos); - expr subst = mk_op_fn(p, subst_it->first, subst_it->second-2, pos); + expr subst = mk_op_fn(p, std::get<0>(*subst_it), std::get<1>(*subst_it)-2, pos); expr subst_pr = p.mk_app({subst, pr, refl_pr}, pos); steps.emplace_back(pred, subst_pr); } diff --git a/src/frontends/lean/calc.h b/src/frontends/lean/calc.h index 073a8125f..4df3c0d50 100644 --- a/src/frontends/lean/calc.h +++ b/src/frontends/lean/calc.h @@ -15,6 +15,8 @@ bool is_calc_annotation(expr const & e); Return none if the operator does not have a symmetry rule associated with it. */ optional> get_calc_symm_info(environment const & env, name const & op); +optional> get_calc_refl_info(environment const & env, name const & op); +optional> get_calc_subst_info(environment const & env, name const & op); void initialize_calc(); void finalize_calc(); }