2014-06-17 20:11:06 +00:00
/*
Copyright ( c ) 2014 Microsoft Corporation . All rights reserved .
Released under Apache 2.0 license as described in the file LICENSE .
Author : Leonardo de Moura
*/
# include <string>
# include <utility>
2014-06-18 00:15:38 +00:00
# include <algorithm>
# include <vector>
2014-06-17 20:11:06 +00:00
# include "util/optional.h"
# include "util/name.h"
# include "util/rb_map.h"
# include "util/buffer.h"
2014-06-18 00:15:38 +00:00
# include "util/interrupt.h"
2014-06-17 20:11:06 +00:00
# include "kernel/environment.h"
# include "library/module.h"
2014-06-18 00:15:38 +00:00
# include "library/choice.h"
# include "library/placeholder.h"
2014-06-25 15:30:09 +00:00
# include "library/explicit.h"
2014-08-23 23:44:06 +00:00
# include "library/scoped_ext.h"
2014-10-31 01:33:47 +00:00
# include "library/annotation.h"
2014-11-29 23:35:09 +00:00
# include "library/typed_expr.h"
2014-10-31 01:33:47 +00:00
# include "library/sorry.h"
# include "library/tactic/expr_to_tactic.h"
2014-06-17 20:35:31 +00:00
# include "frontends/lean/parser.h"
2014-06-18 17:36:21 +00:00
# include "frontends/lean/util.h"
2014-09-23 00:30:29 +00:00
# include "frontends/lean/tokens.h"
2014-10-31 01:33:47 +00:00
# include "frontends/lean/begin_end_ext.h"
2014-06-17 20:11:06 +00:00
namespace lean {
2014-08-23 23:44:06 +00:00
/** \brief Return the name of the head symbol of \c e, which must be a constant.

    The head is obtained with get_app_fn. If it is not a constant, an exception
    carrying \c msg is thrown.
    NOTE(review): the returned reference is obtained from data reachable from \c e,
    so presumably \c e must outlive it — confirm before storing the reference.
*/
static name const & get_fn_const(expr const & e, char const * msg) {
    expr const & head = get_app_fn(e);
    if (is_constant(head))
        return const_name(head);
    throw exception(msg);
}
2014-10-31 06:24:09 +00:00
/** \brief Split the type of declaration \c f (looked up in \c env) into binder domains and result.

    The domain of every binder in the leading chain of Pi binders is appended, in order,
    to \c arg_types. Domains are pushed as they appear; they are not instantiated.
    \return the remaining non-Pi type, paired with the number of universe parameters of \c f.
*/
static pair<expr, unsigned> extract_arg_types_core(environment const & env, name const & f, buffer<expr> & arg_types) {
    declaration decl = env.get(f);
    expr type = decl.get_type();
    for (; is_pi(type); type = binding_body(type))
        arg_types.push_back(binding_domain(type));
    return mk_pair(type, decl.get_num_univ_params());
}
/** \brief Like extract_arg_types_core, but return only the result type.

    The number of universe parameters is discarded.
*/
static expr extract_arg_types(environment const & env, name const & f, buffer<expr> & arg_types) {
    auto result_and_univs = extract_arg_types_core(env, f, arg_types);
    return result_and_univs.first;
}
2014-10-31 06:24:09 +00:00
enum class calc_cmd { Subst , Trans , Refl , Symm } ;
2014-08-23 23:44:06 +00:00
struct calc_entry {
calc_cmd m_cmd ;
name m_name ;
calc_entry ( ) { }
calc_entry ( calc_cmd c , name const & n ) : m_cmd ( c ) , m_name ( n ) { }
} ;
2014-06-17 20:11:06 +00:00
2014-08-23 23:44:06 +00:00
struct calc_state {
2014-10-31 06:56:38 +00:00
typedef name_map < std : : tuple < name , unsigned , unsigned > > refl_table ;
typedef name_map < std : : tuple < name , unsigned , unsigned > > subst_table ;
2014-10-31 06:24:09 +00:00
typedef name_map < std : : tuple < name , unsigned , unsigned > > symm_table ;
2014-08-23 23:44:06 +00:00
typedef rb_map < name_pair , std : : tuple < name , name , unsigned > , name_pair_quick_cmp > trans_table ;
trans_table m_trans_table ;
refl_table m_refl_table ;
subst_table m_subst_table ;
2014-10-31 06:24:09 +00:00
symm_table m_symm_table ;
2014-08-23 23:44:06 +00:00
calc_state ( ) { }
2014-06-17 20:11:06 +00:00
2014-08-23 23:44:06 +00:00
void add_calc_subst ( environment const & env , name const & subst ) {
buffer < expr > arg_types ;
2014-10-31 06:56:38 +00:00
auto p = extract_arg_types_core ( env , subst , arg_types ) ;
expr r_type = p . first ;
unsigned nunivs = p . second ;
unsigned nargs = arg_types . size ( ) ;
2014-08-23 23:44:06 +00:00
if ( nargs < 2 )
throw exception ( " invalid calc substitution theorem, it must have at least 2 arguments " ) ;
name const & rop = get_fn_const ( arg_types [ nargs - 2 ] , " invalid calc substitution theorem, argument penultimate argument must be an operator application " ) ;
2014-10-31 06:56:38 +00:00
m_subst_table . insert ( rop , std : : make_tuple ( subst , nargs , nunivs ) ) ;
2014-08-23 23:44:06 +00:00
}
2014-06-17 20:11:06 +00:00
2014-08-23 23:44:06 +00:00
void add_calc_refl ( environment const & env , name const & refl ) {
buffer < expr > arg_types ;
2014-10-31 06:56:38 +00:00
auto p = extract_arg_types_core ( env , refl , arg_types ) ;
expr r_type = p . first ;
unsigned nunivs = p . second ;
unsigned nargs = arg_types . size ( ) ;
2014-08-23 23:44:06 +00:00
if ( nargs < 1 )
throw exception ( " invalid calc reflexivity rule, it must have at least 1 argument " ) ;
name const & rop = get_fn_const ( r_type , " invalid calc reflexivity rule, result type must be an operator application " ) ;
2014-10-31 06:56:38 +00:00
m_refl_table . insert ( rop , std : : make_tuple ( refl , nargs , nunivs ) ) ;
2014-08-23 23:44:06 +00:00
}
2014-06-17 20:11:06 +00:00
2014-08-23 23:44:06 +00:00
void add_calc_trans ( environment const & env , name const & trans ) {
buffer < expr > arg_types ;
expr r_type = extract_arg_types ( env , trans , arg_types ) ;
unsigned nargs = arg_types . size ( ) ;
if ( nargs < 5 )
throw exception ( " invalid calc transitivity rule, it must have at least 5 arguments " ) ;
name const & rop = get_fn_const ( r_type , " invalid calc transitivity rule, result type must be an operator application " ) ;
name const & op1 = get_fn_const ( arg_types [ nargs - 2 ] , " invalid calc transitivity rule, penultimate argument must be an operator application " ) ;
name const & op2 = get_fn_const ( arg_types [ nargs - 1 ] , " invalid calc transitivity rule, last argument must be an operator application " ) ;
m_trans_table . insert ( name_pair ( op1 , op2 ) , std : : make_tuple ( trans , rop , nargs ) ) ;
}
2014-10-31 06:24:09 +00:00
void add_calc_symm ( environment const & env , name const & symm ) {
buffer < expr > arg_types ;
auto p = extract_arg_types_core ( env , symm , arg_types ) ;
expr r_type = p . first ;
unsigned nunivs = p . second ;
unsigned nargs = arg_types . size ( ) ;
2014-11-01 14:30:04 +00:00
if ( nargs < 1 )
throw exception ( " invalid calc symmetry rule, it must have at least 1 argument " ) ;
2014-10-31 06:24:09 +00:00
name const & rop = get_fn_const ( r_type , " invalid calc symmetry rule, result type must be an operator application " ) ;
m_symm_table . insert ( rop , std : : make_tuple ( symm , nargs , nunivs ) ) ;
}
2014-08-23 23:44:06 +00:00
} ;
2014-06-17 20:11:06 +00:00
2014-09-23 00:30:29 +00:00
// Name "calc": used as the calc annotation tag and as the class name of the calc extension.
// Allocated in initialize_calc and released in finalize_calc.
static name *        g_calc_name{nullptr};
// Serialization key of the calc extension; same lifetime as g_calc_name.
static std::string * g_key{nullptr};
2014-08-23 23:44:06 +00:00
struct calc_config {
typedef calc_state state ;
typedef calc_entry entry ;
static void add_entry ( environment const & env , io_state const & , state & s , entry const & e ) {
switch ( e . m_cmd ) {
case calc_cmd : : Refl : s . add_calc_refl ( env , e . m_name ) ; break ;
case calc_cmd : : Subst : s . add_calc_subst ( env , e . m_name ) ; break ;
case calc_cmd : : Trans : s . add_calc_trans ( env , e . m_name ) ; break ;
2014-10-31 06:24:09 +00:00
case calc_cmd : : Symm : s . add_calc_symm ( env , e . m_name ) ; break ;
2014-08-23 23:44:06 +00:00
}
}
static name const & get_class_name ( ) {
2014-09-23 00:30:29 +00:00
return * g_calc_name ;
2014-08-23 23:44:06 +00:00
}
static std : : string const & get_serialization_key ( ) {
2014-09-23 00:30:29 +00:00
return * g_key ;
2014-08-23 23:44:06 +00:00
}
static void write_entry ( serializer & s , entry const & e ) {
s < < static_cast < char > ( e . m_cmd ) < < e . m_name ;
}
static entry read_entry ( deserializer & d ) {
entry e ;
char cmd ;
d > > cmd > > e . m_name ;
e . m_cmd = static_cast < calc_cmd > ( cmd ) ;
return e ;
}
2014-09-30 01:26:53 +00:00
static optional < unsigned > get_fingerprint ( entry const & ) {
return optional < unsigned > ( ) ;
}
2014-08-23 23:44:06 +00:00
} ;
2014-06-17 20:11:06 +00:00
2014-08-23 23:44:06 +00:00
// Instantiate the scoped extension holding calc rules in this translation unit,
// and give it the short alias used by the rest of the file.
template class scoped_ext<calc_config>;
using calc_ext = scoped_ext<calc_config>;
2014-06-17 20:35:31 +00:00
/** \brief Shared implementation of the calc_subst/calc_refl/calc_trans/calc_symm commands.

    Reads the constant that follows the command (reporting \c err_msg if there is none).
    Returns the environment extended with a calc entry of kind \c cmd for that constant.
*/
static environment add_calc_entry_cmd(parser & p, calc_cmd cmd, char const * err_msg) {
    name const id = p.check_constant_next(err_msg);
    return calc_ext::add_entry(p.env(), get_dummy_ios(), calc_entry(cmd, id));
}
/** \brief Command 'calc_subst c': register c as a substitution rule. */
environment calc_subst_cmd(parser & p) {
    return add_calc_entry_cmd(p, calc_cmd::Subst, "invalid 'calc_subst' command, constant expected");
}
/** \brief Command 'calc_refl c': register c as a reflexivity rule. */
environment calc_refl_cmd(parser & p) {
    return add_calc_entry_cmd(p, calc_cmd::Refl, "invalid 'calc_refl' command, constant expected");
}
/** \brief Command 'calc_trans c': register c as a transitivity rule. */
environment calc_trans_cmd(parser & p) {
    return add_calc_entry_cmd(p, calc_cmd::Trans, "invalid 'calc_trans' command, constant expected");
}
/** \brief Command 'calc_symm c': register c as a symmetry rule. */
environment calc_symm_cmd(parser & p) {
    return add_calc_entry_cmd(p, calc_cmd::Symm, "invalid 'calc_symm' command, constant expected");
}
2014-06-17 20:35:31 +00:00
/** \brief Add the four calc rule commands to \c r, in the order subst, refl, trans, symm. */
void register_calc_cmds(cmd_table & r) {
    struct calc_cmd_desc {
        char const * m_name;
        char const * m_descr;
        environment (*m_fn)(parser &);
    };
    calc_cmd_desc const descs[] = {
        {"calc_subst", "set the substitution rule that is used by the calculational proof '{...}' notation", calc_subst_cmd},
        {"calc_refl", "set the reflexivity rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_refl_cmd},
        {"calc_trans", "set the transitivity rule for a pair of operators, this command is relevant for the calculational proof '{...}' notation", calc_trans_cmd},
        {"calc_symm", "set the symmetry rule for an operator, this command is relevant for the calculational proof '{...}' notation", calc_symm_cmd}
    };
    for (calc_cmd_desc const & d : descs)
        add_cmd(r, cmd_info(d.m_name, d.m_descr, d.m_fn));
}
2014-10-31 06:56:38 +00:00
/** \brief Return a copy of the entry stored for \c op in \c table, or none if there is no entry. */
static optional<std::tuple<name, unsigned, unsigned>> get_info(name_map<std::tuple<name, unsigned, unsigned>> const & table, name const & op) {
    typedef std::tuple<name, unsigned, unsigned> info;
    auto entry = table.find(op);
    if (!entry)
        return optional<info>();
    return optional<info>(*entry);
}
2014-06-18 00:15:38 +00:00
2014-10-31 06:56:38 +00:00
/** \brief Reflexivity rule registered for \c op in \c env: (rule, #arguments, #universe parameters). */
optional<std::tuple<name, unsigned, unsigned>> get_calc_refl_info(environment const & env, name const & op) {
    calc_state const & s = calc_ext::get_state(env);
    return get_info(s.m_refl_table, op);
}
/** \brief Substitution rule registered for \c op in \c env: (rule, #arguments, #universe parameters). */
optional<std::tuple<name, unsigned, unsigned>> get_calc_subst_info(environment const & env, name const & op) {
    calc_state const & s = calc_ext::get_state(env);
    return get_info(s.m_subst_table, op);
}
/** \brief Symmetry rule registered for \c op in \c env: (rule, #arguments, #universe parameters). */
optional<std::tuple<name, unsigned, unsigned>> get_calc_symm_info(environment const & env, name const & op) {
    calc_state const & s = calc_ext::get_state(env);
    return get_info(s.m_symm_table, op);
}
2014-10-31 05:22:04 +00:00
/** \brief Wrap \c e in the "calc" annotation. */
static expr mk_calc_annotation_core(expr const & e) {
    return mk_annotation(*g_calc_name, e);
}
/** \brief Annotate the proof \c pr of a calc step.

    Proofs accepted by is_by, is_begin_end_annotation or is_sorry are returned unchanged.
*/
static expr mk_calc_annotation(expr const & pr) {
    bool const keep_as_is = is_by(pr) || is_begin_end_annotation(pr) || is_sorry(pr);
    return keep_as_is ? pr : mk_calc_annotation_core(pr);
}
/** \brief Return true iff \c e carries the "calc" annotation. */
bool is_calc_annotation(expr const & e) {
    return is_annotation(e, *g_calc_name);
}
2014-06-18 00:15:38 +00:00
typedef std : : tuple < name , expr , expr > calc_pred ;
2014-08-19 23:28:58 +00:00
typedef pair < calc_pred , expr > calc_step ;
2014-06-18 00:15:38 +00:00
inline name const & pred_op ( calc_pred const & p ) { return std : : get < 0 > ( p ) ; }
inline expr const & pred_lhs ( calc_pred const & p ) { return std : : get < 1 > ( p ) ; }
inline expr const & pred_rhs ( calc_pred const & p ) { return std : : get < 2 > ( p ) ; }
inline calc_pred const & step_pred ( calc_step const & s ) { return s . first ; }
inline expr const & step_proof ( calc_step const & s ) { return s . second ; }
/** \brief Append (f, a_{k-1}, a_k) to \c preds when \c e has the form <tt>(f a_1 ... a_k)</tt>.

    The predicate is added only when \c f is a constant and k >= 2.
    Otherwise \c preds is left unchanged.
*/
static void decode_expr_core(expr const & e, buffer<calc_pred> & preds) {
    buffer<expr> args;
    expr const & fn  = get_app_args(e, args);
    unsigned const n = args.size();
    if (is_constant(fn) && n >= 2)
        preds.emplace_back(const_name(fn), args[n-2], args[n-1]);
}
/** \brief Decode the candidate relations denoted by \c e into \c preds.

    \c preds is cleared first. If \c e is a choice expression, every alternative is
    decoded with decode_expr_core; otherwise \c e itself is.
    Throws a parser_error at \c pos if no candidate of the form <tt>(f a_1 ... a_k)</tt>,
    with \c f a constant and k >= 2, was found.
*/
static void decode_expr(expr const & e, buffer<calc_pred> & preds, pos_info const & pos) {
    preds.clear();
    if (!is_choice(e)) {
        decode_expr_core(e, preds);
    } else {
        unsigned const num = get_num_choices(e);
        for (unsigned i = 0; i < num; ++i)
            decode_expr_core(get_choice(e, i), preds);
    }
    if (preds.empty())
        throw parser_error("invalid 'calc' expression, expression must be a function application 'f a_1 ... a_k' "
                           "where f is a constant, and k >= 2", pos);
}
/** \brief Build <tt>(op _ ... _)</tt> with the constant \c op wrapped by mk_explicit.

    \c op is applied to \c num_placeholders expression placeholders. Every created node
    is tagged with position \c pos.
*/
static expr mk_op_fn(parser & p, name const & op, unsigned num_placeholders, pos_info const & pos) {
    expr fn = p.save_pos(mk_explicit(mk_constant(op)), pos);
    for (unsigned i = 0; i < num_placeholders; ++i)
        fn = p.mk_app(fn, p.save_pos(mk_expr_placeholder(), pos), pos);
    return fn;
}
static void parse_calc_proof ( parser & p , buffer < calc_pred > const & preds , std : : vector < calc_step > & steps ) {
steps . clear ( ) ;
auto pos = p . pos ( ) ;
2014-09-23 00:30:29 +00:00
p . check_token_next ( get_colon_tk ( ) , " invalid 'calc' expression, ':' expected " ) ;
if ( p . curr_is_token ( get_lcurly_tk ( ) ) ) {
2014-06-18 00:15:38 +00:00
p . next ( ) ;
expr pr = p . parse_expr ( ) ;
2014-09-23 00:30:29 +00:00
p . check_token_next ( get_rcurly_tk ( ) , " invalid 'calc' expression, '}' expected " ) ;
2014-08-23 23:44:06 +00:00
calc_state const & state = calc_ext : : get_state ( p . env ( ) ) ;
2014-06-18 00:15:38 +00:00
for ( auto const & pred : preds ) {
2014-08-23 23:44:06 +00:00
if ( auto refl_it = state . m_refl_table . find ( pred_op ( pred ) ) ) {
if ( auto subst_it = state . m_subst_table . find ( pred_op ( pred ) ) ) {
2014-10-31 06:56:38 +00:00
expr refl = mk_op_fn ( p , std : : get < 0 > ( * refl_it ) , std : : get < 1 > ( * refl_it ) - 1 , pos ) ;
2014-08-23 23:44:06 +00:00
expr refl_pr = p . mk_app ( refl , pred_lhs ( pred ) , pos ) ;
2014-10-31 06:56:38 +00:00
expr subst = mk_op_fn ( p , std : : get < 0 > ( * subst_it ) , std : : get < 1 > ( * subst_it ) - 2 , pos ) ;
2014-08-23 23:44:06 +00:00
expr subst_pr = p . mk_app ( { subst , pr , refl_pr } , pos ) ;
steps . emplace_back ( pred , subst_pr ) ;
}
2014-06-18 00:15:38 +00:00
}
}
if ( steps . empty ( ) )
2014-08-23 23:44:06 +00:00
throw parser_error ( " invalid 'calc' expression, reflexivity and/or substitution rule is not defined for operator " , pos ) ;
2014-06-18 00:15:38 +00:00
} else {
expr pr = p . parse_expr ( ) ;
for ( auto const & pred : preds )
2014-10-31 01:33:47 +00:00
steps . emplace_back ( pred , mk_calc_annotation ( pr ) ) ;
2014-06-18 00:15:38 +00:00
}
}
/** \brief Collect distinct rhs's */
static void collect_rhss ( std : : vector < calc_step > const & steps , buffer < expr > & rhss ) {
rhss . clear ( ) ;
for ( auto const & step : steps ) {
calc_pred const & pred = step_pred ( step ) ;
expr const & rhs = pred_rhs ( pred ) ;
if ( std : : find ( rhss . begin ( ) , rhss . end ( ) , rhs ) = = rhss . end ( ) )
rhss . push_back ( rhs ) ;
}
lean_assert ( ! rhss . empty ( ) ) ;
}
2014-11-29 23:35:09 +00:00
static void join ( parser & p , std : : vector < calc_step > const & steps1 , std : : vector < calc_step > const & steps2 ,
std : : vector < calc_step > & res_steps , pos_info const & pos ) {
2014-06-18 00:15:38 +00:00
res_steps . clear ( ) ;
2014-08-23 23:44:06 +00:00
calc_state const & state = calc_ext : : get_state ( p . env ( ) ) ;
2014-06-18 00:15:38 +00:00
for ( calc_step const & s1 : steps1 ) {
check_interrupted ( ) ;
calc_pred const & pred1 = step_pred ( s1 ) ;
expr const & pr1 = step_proof ( s1 ) ;
for ( calc_step const & s2 : steps2 ) {
calc_pred const & pred2 = step_pred ( s2 ) ;
expr const & pr2 = step_proof ( s2 ) ;
if ( ! is_eqp ( pred_rhs ( pred1 ) , pred_lhs ( pred2 ) ) )
continue ;
2014-08-23 23:44:06 +00:00
auto trans_it = state . m_trans_table . find ( name_pair ( pred_op ( pred1 ) , pred_op ( pred2 ) ) ) ;
2014-06-18 00:15:38 +00:00
if ( ! trans_it )
continue ;
expr trans = mk_op_fn ( p , std : : get < 0 > ( * trans_it ) , std : : get < 2 > ( * trans_it ) - 5 , pos ) ;
expr trans_pr = p . mk_app ( { trans , pred_lhs ( pred1 ) , pred_rhs ( pred1 ) , pred_rhs ( pred2 ) , pr1 , pr2 } , pos ) ;
res_steps . emplace_back ( calc_pred ( std : : get < 1 > ( * trans_it ) , pred_lhs ( pred1 ) , pred_rhs ( pred2 ) ) , trans_pr ) ;
}
}
}
/** \brief Parse a calculational proof.

    The proof has the form <tt>rel_0 : pr_0 ... op_1 rhs_1 : pr_1 ...</tt>.
    - The first relation is parsed with parse_expr and decoded into candidate relations.
    - Each subsequent "..." step is parsed with parse_led on a placeholder standing
      for the previous right-hand side.
    - Candidates are chained with join using the registered transitivity rules.
    \return a choice over all resulting proofs. If the calc has a single step, each proof
    is wrapped with mk_typed_expr, using the first relation as its type.
    Throws a parser_error when a step yields no usable relation, or no transitivity
    rule connects it to the previous steps.
*/
expr parse_calc(parser & p) {
    buffer<calc_pred> preds, new_preds;
    buffer<expr> rhss;
    std::vector<calc_step> steps, new_steps, next_steps;
    auto pos = p.pos();
    expr first_pred = p.parse_expr();
    decode_expr(first_pred, preds, pos);
    parse_calc_proof(p, preds, steps);
    bool single = true; // true if calc has only one step
    // Left operand passed to parse_led for "... op rhs". Candidate relations whose lhs is
    // this very object are re-targeted at each rhs of the previous steps.
    expr dummy = mk_expr_placeholder();
    while (p.curr_is_token(get_ellipsis_tk())) {
        single = false;
        pos = p.pos();
        p.next();
        decode_expr(p.parse_led(dummy), preds, pos);
        collect_rhss(steps, rhss);
        new_steps.clear();
        // new_preds must contain only the candidates of the current step. Without this
        // reset, candidates from earlier steps accumulate: they are handed again to
        // parse_calc_proof at every later step, and the emptiness check below can
        // never fire after the first step.
        new_preds.clear();
        for (auto const & pred : preds) {
            if (is_eqp(pred_lhs(pred), dummy)) {
                for (expr const & rhs : rhss)
                    new_preds.emplace_back(pred_op(pred), rhs, pred_rhs(pred));
            }
        }
        if (new_preds.empty())
            throw parser_error("invalid 'calc' expression, invalid expression", pos);
        parse_calc_proof(p, new_preds, new_steps);
        join(p, steps, new_steps, next_steps, pos);
        if (next_steps.empty())
            throw parser_error("invalid 'calc' expression, transitivity rule is not defined for current step", pos);
        steps.swap(next_steps);
    }
    buffer<expr> choices;
    for (auto const & s : steps) {
        if (single) {
            // A one-step calc: record the user-written relation as the expected type.
            expr new_s = p.save_pos(mk_typed_expr(first_pred, step_proof(s)), pos);
            choices.push_back(new_s);
        } else {
            choices.push_back(step_proof(s));
        }
    }
    return p.save_pos(mk_choice(choices.size(), choices.data()), pos);
}
2014-09-23 00:30:29 +00:00
void initialize_calc ( ) {
g_calc_name = new name ( " calc " ) ;
g_key = new std : : string ( " calc " ) ;
calc_ext : : initialize ( ) ;
2014-10-31 05:22:04 +00:00
register_annotation ( * g_calc_name ) ;
2014-09-23 00:30:29 +00:00
}
void finalize_calc ( ) {
calc_ext : : finalize ( ) ;
delete g_key ;
delete g_calc_name ;
}
2014-06-17 20:11:06 +00:00
}