diff --git a/src/library/definitional/brec_on.cpp b/src/library/definitional/brec_on.cpp index 12136315f..f69039b68 100644 --- a/src/library/definitional/brec_on.cpp +++ b/src/library/definitional/brec_on.cpp @@ -47,15 +47,17 @@ static environment mk_below(environment const & env, name const & n, bool ibelow unsigned nminors = *inductive::get_num_minor_premises(env, n); unsigned ntypeformers = length(std::get<2>(decls)); level_param_names lps = rec_decl.get_univ_params(); - level lvl = mk_param_univ(head(lps)); // universe we are eliminating too + bool is_reflexive = is_reflexive_datatype(tc, n); + level lvl = mk_param_univ(head(lps)); // universe we are eliminating to levels lvls = param_names_to_levels(tail(lps)); - levels blvls; - level rlvl; + levels blvls; // universe level parameters of ibelow/below + level rlvl; // universe level of the resultant type name prod_name; - expr unit, outer_prod; + expr unit, outer_prod, inner_prod; // The arguments of below (ibelow) are the ones in the recursor - minor premises. // The universe we map to is also different (l+1 for below) and (0 fo ibelow). expr ref_type; + expr Type_result; if (ibelow) { // we are eliminating to Prop blvls = lvls; @@ -63,8 +65,10 @@ static environment mk_below(environment const & env, name const & n, bool ibelow unit = mk_constant("true"); prod_name = name("and"); outer_prod = mk_constant(prod_name); + inner_prod = outer_prod; ref_type = instantiate_univ_param(rec_decl.get_type(), param_id(lvl), mk_level_zero()); - } else { + Type_result = mk_sort(rlvl); + } else if (is_reflexive) { blvls = cons(lvl, lvls); rlvl = get_datatype_level(ind_decl.get_type()); // if rlvl is of the form (max 1 l), then rlvl <- l @@ -75,8 +79,18 @@ static environment mk_below(environment const & env, name const & n, bool ibelow prod_name = name("prod"); outer_prod = mk_constant(prod_name, {rlvl, rlvl}); ref_type = instantiate_univ_param(rec_decl.get_type(), param_id(lvl), mk_succ(lvl)); + Type_result = mk_sort(rlvl); + } else { + // we can simplify the universe levels for non-reflexive datatypes + blvls = cons(lvl, lvls); + rlvl = mk_max(mk_level_one(), lvl); + unit = mk_constant("unit", rlvl); + prod_name = name("prod"); + outer_prod = mk_constant(prod_name, {rlvl, rlvl}); + inner_prod = mk_constant(prod_name, {lvl, rlvl}); + ref_type = rec_decl.get_type(); + Type_result = mk_sort(rlvl); } - expr Type_result = mk_sort(rlvl); buffer ref_args; to_telescope(ngen, ref_type, ref_args); if (ref_args.size() != nparams + ntypeformers + nminors + nindices + 1) @@ -121,10 +135,8 @@ static environment mk_below(environment const & env, name const & n, bool ibelow expr r = minor_arg; expr fst = mlocal_type(minor_arg); expr snd = Pi(minor_arg_args, mk_app(r, minor_arg_args)); - expr inner_prod; - if (ibelow) { - inner_prod = outer_prod; // and - } else { + if (!ibelow && is_reflexive) { + // inner product is not constant level fst_lvl = sort_level(tc.ensure_type(fst).first); inner_prod = mk_constant(prod_name, {fst_lvl, rlvl}); } diff --git a/src/library/definitional/util.cpp b/src/library/definitional/util.cpp index c17c30e14..0cd8466c4 100644 --- a/src/library/definitional/util.cpp +++ b/src/library/definitional/util.cpp @@ -64,6 +64,28 @@ bool is_recursive_datatype(environment const & env, name const & n) { return false; } +bool is_reflexive_datatype(type_checker & tc, name const & n) { + environment const & env = tc.env(); + name_generator ngen = tc.mk_ngen(); + optional decls = inductive::is_inductive_decl(env, n); + if (!decls) + return false; + for (inductive::inductive_decl const & decl : std::get<2>(*decls)) { + for (inductive::intro_rule const & intro : inductive::inductive_decl_intros(decl)) { + expr type = inductive::intro_rule_type(intro); + while (is_pi(type)) { + expr arg = tc.whnf(binding_domain(type)).first; + if (is_pi(arg) && find(arg, [&](expr const & e, unsigned) { return is_constant(e) && const_name(e) == n; })) { + return true; + } + expr local = mk_local(ngen.next(), binding_domain(type)); + type = instantiate(binding_body(type), local); + } + } + } + return false; +} + level get_datatype_level(expr ind_type) { while (is_pi(ind_type)) ind_type = binding_body(ind_type); diff --git a/src/library/definitional/util.h b/src/library/definitional/util.h index f54e4ac40..ccc3b7a9c 100644 --- a/src/library/definitional/util.h +++ b/src/library/definitional/util.h @@ -22,6 +22,13 @@ bool has_prod_decls(environment const & env); */ bool is_recursive_datatype(environment const & env, name const & n); +/** \brief Return true if \c n is a recursive *and* reflexive datatype. + + We say an inductive type T is reflexive if it contains at least one constructor that + takes as an argument a function returning T. +*/ +bool is_reflexive_datatype(type_checker & tc, name const & n); + /** \brief Return true iff \c n is an inductive predicate, i.e., an inductive datatype that is in Prop. \remark If \c env does not have Prop (i.e., Type.{0} is not impredicative), then this method always return false. diff --git a/tests/lean/run/vector.lean b/tests/lean/run/vector.lean index b15308780..019729dc6 100644 --- a/tests/lean/run/vector.lean +++ b/tests/lean/run/vector.lean @@ -6,7 +6,7 @@ vnil {} : vector A zero, vcons : Π {n : nat}, A → vector A n → vector A (succ n) namespace vector - print definition no_confusion + -- print definition no_confusion infixr `::` := vcons theorem vcons.inj₁ {A : Type} {n : nat} (a₁ a₂ : A) (v₁ v₂ : vector A n) : vcons a₁ v₁ = vcons a₂ v₂ → a₁ = a₂ := @@ -19,10 +19,13 @@ namespace vector intro h, apply heq.to_eq, apply (no_confusion h), intros, eassumption, end + set_option pp.universes true + check @below + section universe variables l₁ l₂ variable {A : Type.{l₁}} - variable {C : Π (n : nat), vector A n → Type.{l₂+1}} + variable {C : Π (n : nat), vector A n → Type.{l₂}} definition brec_on {n : nat} (v : vector A n) (H : Π (n : nat) (v : vector A n), @below A C n v → C n v) : C n v := have general : C n v × @below A C n v, from rec_on v @@ -36,7 +39,7 @@ namespace vector pr₁ general end - check brec_on + -- check brec_on definition bw := @below @@ -94,7 +97,7 @@ namespace vector example : add (1 :: 2 :: vnil) (3 :: 5 :: vnil) = 4 :: 7 :: vnil := rfl - definition map {A B C : Type'} {n : nat} (f : A → B → C) (w : vector A n) (v : vector B n) : vector C n := + definition map {A B C : Type} {n : nat} (f : A → B → C) (w : vector A n) (v : vector B n) : vector C n := let P := λ (n : nat) (v : vector A n), vector B n → vector C n in @brec_on A P n w (λ (n : nat) (w : vector A n), @@ -111,6 +114,13 @@ namespace vector end end) v + theorem map_nil_nil {A B C : Type} (f : A → B → C) : map f vnil vnil = vnil := + rfl + + theorem map_cons_cons {A B C : Type} (f : A → B → C) (a : A) (b : B) {n : nat} (va : vector A n) (vb : vector B n) : + map f (a :: va) (b :: vb) = f a b :: map f va vb := + rfl + example : map nat.add (1 :: 2 :: vnil) (3 :: 5 :: vnil) = 4 :: 7 :: vnil := rfl