Update 'orelse' and 'then' rewriter to take a list of rewriters

This commit is contained in:
Soonho Kong 2013-09-25 16:46:39 -07:00
parent a50f5f92b8
commit 1d8b7dc193
3 changed files with 62 additions and 30 deletions

View file

@ -100,27 +100,42 @@ pair<expr, expr> theorem_rewriter_cell::operator()(environment const &, context
// OrElse Rewriter // OrElse Rewriter
orelse_rewriter_cell::orelse_rewriter_cell(rewriter const & rw1, rewriter const & rw2) orelse_rewriter_cell::orelse_rewriter_cell(rewriter const & rw1, rewriter const & rw2)
:rewriter_cell(rewriter_kind::OrElse), m_rw1(rw1), m_rw2(rw2) { } :rewriter_cell(rewriter_kind::OrElse), m_rwlist({rw1, rw2}) { }
orelse_rewriter_cell::orelse_rewriter_cell(std::initializer_list<rewriter> const & l)
:rewriter_cell(rewriter_kind::OrElse), m_rwlist(l) {
lean_assert(l.size() >= 2);
}
orelse_rewriter_cell::~orelse_rewriter_cell() { } orelse_rewriter_cell::~orelse_rewriter_cell() { }
pair<expr, expr> orelse_rewriter_cell::operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception) { pair<expr, expr> orelse_rewriter_cell::operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception) {
for (rewriter const & rw : m_rwlist) {
try { try {
return m_rw1(env, ctx, v); return rw(env, ctx, v);
} catch (rewriter_exception & ) { } catch (rewriter_exception & ) {
return m_rw2(env, ctx, v); // Do nothing
} }
} }
// If the execution reaches here, it means every rewriter failed.
throw rewriter_exception();
}
// Then Rewriter // Then Rewriter
then_rewriter_cell::then_rewriter_cell(rewriter const & rw1, rewriter const & rw2) then_rewriter_cell::then_rewriter_cell(rewriter const & rw1, rewriter const & rw2)
:rewriter_cell(rewriter_kind::Then), m_rw1(rw1), m_rw2(rw2) { } :rewriter_cell(rewriter_kind::Then), m_rwlist({rw1, rw2}) { }
then_rewriter_cell::then_rewriter_cell(std::initializer_list<rewriter> const & l)
:rewriter_cell(rewriter_kind::Then), m_rwlist(l) {
lean_assert(l.size() >= 2);
}
then_rewriter_cell::~then_rewriter_cell() { } then_rewriter_cell::~then_rewriter_cell() { }
pair<expr, expr> then_rewriter_cell::operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception) { pair<expr, expr> then_rewriter_cell::operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception) {
pair<expr, expr> result1 = m_rw1(env, ctx, v); pair<expr, expr> result = car(m_rwlist)(env, ctx, v);
pair<expr, expr> result2 = m_rw2(env, ctx, result1.first); pair<expr, expr> new_result = result;
light_checker lc(env); for (rewriter const & rw : cdr(m_rwlist)) {
expr const & t = lc(v, ctx); new_result = rw(env, ctx, result.first);
return make_pair(result2.first, expr const & t = light_checker(env)(v, ctx);
Trans(t, v, result1.first, result2.first, result1.second, result2.second)); result = make_pair(new_result.first,
Trans(t, v, result.first, new_result.first, result.second, new_result.second));
}
return result;
} }
// App Rewriter // App Rewriter
@ -232,9 +247,15 @@ rewriter mk_theorem_rewriter(expr const & type, expr const & body) {
rewriter mk_then_rewriter(rewriter const & rw1, rewriter const & rw2) { rewriter mk_then_rewriter(rewriter const & rw1, rewriter const & rw2) {
return rewriter(new then_rewriter_cell(rw1, rw2)); return rewriter(new then_rewriter_cell(rw1, rw2));
} }
rewriter mk_then_rewriter(std::initializer_list<rewriter> const & l) {
return rewriter(new then_rewriter_cell(l));
}
rewriter mk_orelse_rewriter(rewriter const & rw1, rewriter const & rw2) { rewriter mk_orelse_rewriter(rewriter const & rw1, rewriter const & rw2) {
return rewriter(new orelse_rewriter_cell(rw1, rw2)); return rewriter(new orelse_rewriter_cell(rw1, rw2));
} }
rewriter mk_orelse_rewriter(std::initializer_list<rewriter> const & l) {
return rewriter(new orelse_rewriter_cell(l));
}
rewriter mk_app_rewriter(rewriter const & rw) { rewriter mk_app_rewriter(rewriter const & rw) {
return rewriter(new app_rewriter_cell(rw)); return rewriter(new app_rewriter_cell(rw));
} }

View file

@ -80,20 +80,20 @@ public:
class orelse_rewriter_cell : public rewriter_cell { class orelse_rewriter_cell : public rewriter_cell {
private: private:
rewriter m_rw1; list<rewriter> m_rwlist;
rewriter m_rw2;
public: public:
orelse_rewriter_cell(rewriter const & rw1, rewriter const & rw2); orelse_rewriter_cell(rewriter const & rw1, rewriter const & rw2);
orelse_rewriter_cell(std::initializer_list<rewriter> const & l);
~orelse_rewriter_cell(); ~orelse_rewriter_cell();
std::pair<expr, expr> operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception); std::pair<expr, expr> operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception);
}; };
class then_rewriter_cell : public rewriter_cell { class then_rewriter_cell : public rewriter_cell {
private: private:
rewriter m_rw1; list<rewriter> m_rwlist;
rewriter m_rw2;
public: public:
then_rewriter_cell(rewriter const & rw1, rewriter const & rw2); then_rewriter_cell(rewriter const & rw1, rewriter const & rw2);
then_rewriter_cell(std::initializer_list<rewriter> const & l);
~then_rewriter_cell(); ~then_rewriter_cell();
std::pair<expr, expr> operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception); std::pair<expr, expr> operator()(environment const & env, context & ctx, expr const & v) const throw(rewriter_exception);
}; };
@ -159,7 +159,9 @@ public:
rewriter mk_theorem_rewriter(expr const & type, expr const & body); rewriter mk_theorem_rewriter(expr const & type, expr const & body);
rewriter mk_then_rewriter(rewriter const & rw1, rewriter const & rw2); rewriter mk_then_rewriter(rewriter const & rw1, rewriter const & rw2);
rewriter mk_then_rewriter(std::initializer_list<rewriter> const & l);
rewriter mk_orelse_rewriter(rewriter const & rw1, rewriter const & rw2); rewriter mk_orelse_rewriter(rewriter const & rw1, rewriter const & rw2);
rewriter mk_orelse_rewriter(std::initializer_list<rewriter> const & l);
rewriter mk_app_rewriter(rewriter const & rw); rewriter mk_app_rewriter(rewriter const & rw);
rewriter mk_lambda_rewriter(rewriter const & rw); rewriter mk_lambda_rewriter(rewriter const & rw);
rewriter mk_pi_rewriter(rewriter const & rw); rewriter mk_pi_rewriter(rewriter const & rw);

View file

@ -51,7 +51,7 @@ static void theorem_rewriter1_tst() {
cout << " " << concl << " := " << proof << endl; cout << " " << concl << " := " << proof << endl;
lean_assert(concl == mk_eq(a_plus_b, b_plus_a)); lean_assert(concl == mk_eq(a_plus_b, b_plus_a));
lean_assert(proof == mk_app(mk_app(Const("ADD_COMM"), a), b)); lean_assert(proof == Const("ADD_COMM")(a, b));
env.add_theorem("New_theorem1", concl, proof); env.add_theorem("New_theorem1", concl, proof);
} }
@ -82,7 +82,7 @@ static void theorem_rewriter2_tst() {
cout << " " << concl << " := " << proof << endl; cout << " " << concl << " := " << proof << endl;
lean_assert(concl == mk_eq(a_plus_zero, a)); lean_assert(concl == mk_eq(a_plus_zero, a));
lean_assert(proof == mk_app(Const("ADD_ID"), a)); lean_assert(proof == Const("ADD_ID")(a));
env.add_theorem("New_theorem2", concl, proof); env.add_theorem("New_theorem2", concl, proof);
} }
@ -125,7 +125,7 @@ static void then_rewriter1_tst() {
lean_assert(concl == mk_eq(zero_plus_a, a)); lean_assert(concl == mk_eq(zero_plus_a, a));
lean_assert(proof == Trans(Nat, zero_plus_a, a_plus_zero, a, lean_assert(proof == Trans(Nat, zero_plus_a, a_plus_zero, a,
mk_app(mk_app(Const("ADD_COMM"), zero), a), mk_app(Const("ADD_ID"), a))); Const("ADD_COMM")(zero, a), Const("ADD_ID")(a)));
env.add_theorem("New_theorem3", concl, proof); env.add_theorem("New_theorem3", concl, proof);
} }
@ -171,8 +171,10 @@ static void then_rewriter2_tst() {
rewriter add_assoc_thm_rewriter = mk_theorem_rewriter(add_assoc_thm_type, add_assoc_thm_body); rewriter add_assoc_thm_rewriter = mk_theorem_rewriter(add_assoc_thm_type, add_assoc_thm_body);
rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body); rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body);
rewriter add_id_thm_rewriter = mk_theorem_rewriter(add_id_thm_type, add_id_thm_body); rewriter add_id_thm_rewriter = mk_theorem_rewriter(add_id_thm_type, add_id_thm_body);
rewriter then_rewriter2 = mk_then_rewriter(mk_then_rewriter(add_assoc_thm_rewriter, add_id_thm_rewriter), rewriter then_rewriter2 = mk_then_rewriter({add_assoc_thm_rewriter,
mk_then_rewriter(add_comm_thm_rewriter, add_id_thm_rewriter)); add_id_thm_rewriter,
add_comm_thm_rewriter,
add_id_thm_rewriter});
context ctx; context ctx;
pair<expr, expr> result = then_rewriter2(env, ctx, zero_plus_a_plus_zero); pair<expr, expr> result = then_rewriter2(env, ctx, zero_plus_a_plus_zero);
expr concl = mk_eq(zero_plus_a_plus_zero, result.first); expr concl = mk_eq(zero_plus_a_plus_zero, result.first);
@ -183,13 +185,12 @@ static void then_rewriter2_tst() {
cout << " " << concl << " := " << proof << endl; cout << " " << concl << " := " << proof << endl;
lean_assert(concl == mk_eq(zero_plus_a_plus_zero, a)); lean_assert(concl == mk_eq(zero_plus_a_plus_zero, a));
lean_assert(proof == Trans(Nat, zero_plus_a_plus_zero, zero_plus_a, a, lean_assert(proof == Trans(Nat, zero_plus_a_plus_zero, a_plus_zero, a,
Trans(Nat, zero_plus_a_plus_zero, zero_plus_a, a_plus_zero,
Trans(Nat, zero_plus_a_plus_zero, zero_plus_a_plus_zero_, zero_plus_a, Trans(Nat, zero_plus_a_plus_zero, zero_plus_a_plus_zero_, zero_plus_a,
mk_app(mk_app(mk_app(Const("ADD_ASSOC"), zero), a), zero), Const("ADD_ASSOC")(zero, a, zero), Const("ADD_ID")(zero_plus_a)),
mk_app(Const("ADD_ID"), zero_plus_a)), Const("ADD_COMM")(zero, a)),
Trans(Nat, zero_plus_a, a_plus_zero, a, Const("ADD_ID")(a)));
mk_app(mk_app(Const("ADD_COMM"), zero), a),
mk_app(Const("ADD_ID"), a))));
env.add_theorem("New_theorem4", concl, proof); env.add_theorem("New_theorem4", concl, proof);
} }
@ -203,6 +204,7 @@ static void orelse_rewriter1_tst() {
// Result : (b + a, ADD_COMM a b) // Result : (b + a, ADD_COMM a b)
expr a = Const("a"); // a : Nat expr a = Const("a"); // a : Nat
expr b = Const("b"); // b : Nat expr b = Const("b"); // b : Nat
expr zero = nVal(0); // zero : Nat
expr a_plus_b = nAdd(a, b); expr a_plus_b = nAdd(a, b);
expr b_plus_a = nAdd(b, a); expr b_plus_a = nAdd(b, a);
expr add_assoc_thm_type = Pi("x", Nat, expr add_assoc_thm_type = Pi("x", Nat,
@ -215,6 +217,9 @@ static void orelse_rewriter1_tst() {
Pi("y", Nat, Pi("y", Nat,
Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x"))))); Eq(nAdd(Const("x"), Const("y")), nAdd(Const("y"), Const("x")))));
expr add_comm_thm_body = Const("ADD_COMM"); expr add_comm_thm_body = Const("ADD_COMM");
expr add_id_thm_type = Pi("x", Nat,
Eq(nAdd(Const("x"), zero), Const("x")));
expr add_id_thm_body = Const("ADD_ID");
environment env = mk_toplevel(); environment env = mk_toplevel();
env.add_var("a", Nat); env.add_var("a", Nat);
@ -224,7 +229,10 @@ static void orelse_rewriter1_tst() {
// Rewriting // Rewriting
rewriter add_assoc_thm_rewriter = mk_theorem_rewriter(add_assoc_thm_type, add_assoc_thm_body); rewriter add_assoc_thm_rewriter = mk_theorem_rewriter(add_assoc_thm_type, add_assoc_thm_body);
rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body); rewriter add_comm_thm_rewriter = mk_theorem_rewriter(add_comm_thm_type, add_comm_thm_body);
rewriter add_assoc_or_comm_thm_rewriter = mk_orelse_rewriter(add_assoc_thm_rewriter, add_comm_thm_rewriter); rewriter add_id_thm_rewriter = mk_theorem_rewriter(add_id_thm_type, add_id_thm_body);
rewriter add_assoc_or_comm_thm_rewriter = mk_orelse_rewriter({add_assoc_thm_rewriter,
add_comm_thm_rewriter,
add_id_thm_rewriter});
context ctx; context ctx;
pair<expr, expr> result = add_assoc_or_comm_thm_rewriter(env, ctx, a_plus_b); pair<expr, expr> result = add_assoc_or_comm_thm_rewriter(env, ctx, a_plus_b);
expr concl = mk_eq(a_plus_b, result.first); expr concl = mk_eq(a_plus_b, result.first);
@ -232,10 +240,11 @@ static void orelse_rewriter1_tst() {
cout << "Theorem: " << add_assoc_thm_type << " := " << add_assoc_thm_body << endl; cout << "Theorem: " << add_assoc_thm_type << " := " << add_assoc_thm_body << endl;
cout << "Theorem: " << add_comm_thm_type << " := " << add_comm_thm_body << endl; cout << "Theorem: " << add_comm_thm_type << " := " << add_comm_thm_body << endl;
cout << "Theorem: " << add_id_thm_type << " := " << add_id_thm_body << endl;
cout << " " << concl << " := " << proof << endl; cout << " " << concl << " := " << proof << endl;
lean_assert(concl == mk_eq(a_plus_b, b_plus_a)); lean_assert(concl == mk_eq(a_plus_b, b_plus_a));
lean_assert(proof == mk_app(mk_app(Const("ADD_COMM"), a), b)); lean_assert(proof == Const("ADD_COMM")(a, b));
env.add_theorem("New_theorem5", concl, proof); env.add_theorem("New_theorem5", concl, proof);
} }