Re-implement rewrite module using rewrite_cell

This commit is contained in:
Soonho Kong 2013-09-24 11:00:35 -07:00
parent ba0528c298
commit 57e9e2c658
4 changed files with 200 additions and 108 deletions

View file

@ -54,13 +54,13 @@ bool fo_match::match_var(expr const & p, expr const & t, unsigned o, subst_map &
}
}
bool fo_match::match_constant(expr const & p, expr const & t, unsigned o, subst_map & s) {
lean_trace("fo_match", tout << "match_constant : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;);
bool fo_match::match_constant(expr const & p, expr const & t, unsigned, subst_map &) {
lean_trace("fo_match", tout << "match_constant : (" << p << ", " << t << ")" << endl;);
return p == t;
}
bool fo_match::match_value(expr const & p, expr const & t, unsigned o, subst_map & s) {
lean_trace("fo_match", tout << "match_value : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;);
bool fo_match::match_value(expr const & p, expr const & t, unsigned, subst_map &) {
lean_trace("fo_match", tout << "match_value : (" << p << ", " << t << ")" << endl;);
return p == t;
}
@ -119,8 +119,8 @@ bool fo_match::match_pi(expr const & p, expr const & t, unsigned o, subst_map &
}
}
bool fo_match::match_type(expr const & p, expr const & t, unsigned o, subst_map & s) {
lean_trace("fo_match", tout << "match_type : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;);
bool fo_match::match_type(expr const & p, expr const & t, unsigned, subst_map &) {
lean_trace("fo_match", tout << "match_type : (" << p << ", " << t << ")" << endl;);
return p == t;
}
@ -154,8 +154,8 @@ bool fo_match::match_let(expr const & p, expr const & t, unsigned o, subst_map &
return match(p_body, t_body, o + 1, s);
}
}
bool fo_match::match_metavar(expr const & p, expr const & t, unsigned o, subst_map & s) {
lean_trace("fo_match", tout << "match_meta : (" << p << ", " << t << ", " << o << ", " << s << ")" << endl;);
bool fo_match::match_metavar(expr const & p, expr const & t, unsigned, subst_map &) {
lean_trace("fo_match", tout << "match_meta : (" << p << ", " << t << ")" << endl;);
return p == t;
}

View file

@ -23,61 +23,74 @@ using std::make_pair;
namespace lean {
theorem_rewrite::theorem_rewrite(expr const & type, expr const & body)
: thm_type(type), thm_body(body), num_args(0) {
lean_trace("rewrite", tout << "Type = " << thm_type << endl;);
lean_trace("rewrite", tout << "Body = " << thm_body << endl;);
void rewrite_cell::dealloc() {
delete this;
}
rewrite_cell::rewrite_cell(rewrite_kind k):m_kind(k), m_rc(1) { }
rewrite_cell::~rewrite_cell() {
}
// Theorem Rewrite
theorem_rewrite_cell::theorem_rewrite_cell(expr const & type, expr const & body)
: rewrite_cell(rewrite_kind::Theorem), m_type(type), m_body(body), m_num_args(0) {
lean_trace("rewrite", tout << "Type = " << m_type << endl;);
lean_trace("rewrite", tout << "Body = " << m_body << endl;);
// We expect the theorem is in the form of
// Pi (x_1 : t_1 ... x_n : t_n), pattern = rhs
pattern = type;
while (is_pi(pattern)) {
pattern = abst_body(pattern);
num_args++;
m_pattern = m_type;
while (is_pi(m_pattern)) {
m_pattern = abst_body(m_pattern);
m_num_args++;
}
if (!is_eq(pattern)) {
lean_trace("rewrite", tout << "Theorem " << thm_type << " is not in the form of "
if (!is_eq(m_pattern)) {
lean_trace("rewrite", tout << "Theorem " << m_type << " is not in the form of "
<< "Pi (x_1 : t_1 ... x_n : t_n), pattern = rhs" << endl;);
}
rhs = eq_rhs(pattern);
pattern = eq_lhs(pattern);
m_rhs = eq_rhs(m_pattern);
m_pattern = eq_lhs(m_pattern);
lean_trace("rewrite", tout << "Number of Arg = " << num_args << endl;);
lean_trace("rewrite", tout << "Number of Arg = " << m_num_args << endl;);
}
pair<expr, expr> theorem_rewrite::operator()(context & ctx, expr const & v, environment const & ) const throw(rewrite_exception) {
lean_trace("rewrite", tout << "Context = " << ctx << endl;);
theorem_rewrite_cell::~theorem_rewrite_cell() { }
pair<expr, expr> theorem_rewrite_cell::operator()(context &, expr const & v, environment const & ) const throw(rewrite_exception) {
// lean_trace("rewrite", tout << "Context = " << ctx << endl;);
lean_trace("rewrite", tout << "Term = " << v << endl;);
lean_trace("rewrite", tout << "Pattern = " << pattern << endl;);
lean_trace("rewrite", tout << "Num Args = " << num_args << endl;);
lean_trace("rewrite", tout << "Pattern = " << m_pattern << endl;);
lean_trace("rewrite", tout << "Num Args = " << m_num_args << endl;);
fo_match fm;
subst_map s;
if (!fm.match(pattern, v, 0, s)) {
if (!fm.match(m_pattern, v, 0, s)) {
throw rewrite_exception();
}
// apply s to rhs
auto f = [&s](expr const & e, unsigned offset) -> expr {
if (is_var(e)) {
lean_trace("rewrite", tout << "Inside of apply : offset = " << offset
<< ", e = " << e
<< ", idx = " << var_idx(e) << endl;);
unsigned idx = var_idx(e);
auto it = s.find(idx);
if (it != s.end()) {
lean_trace("rewrite", tout << "Inside of apply : s[" << idx << "] = " << s[idx] << endl;);
return s[idx];
}
if (!is_var(e)) {
return e;
}
unsigned idx = var_idx(e);
if (idx < offset) {
return e;
}
lean_trace("rewrite", tout << "Inside of apply : offset = " << offset
<< ", e = " << e
<< ", idx = " << var_idx(e) << endl;);
auto it = s.find(idx);
if (it != s.end()) {
lean_trace("rewrite", tout << "Inside of apply : s[" << idx << "] = " << s[idx] << endl;);
return s[idx];
}
return e;
};
expr new_rhs = replace_fn<decltype(f)>(f)(rhs);
expr new_rhs = replace_fn<decltype(f)>(f)(m_rhs);
lean_trace("rewrite", tout << "New RHS = " << new_rhs << endl;);
expr proof = thm_body;
for (int i = num_args -1 ; i >= 0; i--) {
expr proof = m_body;
for (int i = m_num_args -1 ; i >= 0; i--) {
proof = mk_app(proof, s[i]);
lean_trace("rewrite", tout << "proof: " << i << "\t" << s[i] << "\t" << proof << endl;);
}
@ -85,63 +98,88 @@ pair<expr, expr> theorem_rewrite::operator()(context & ctx, expr const & v, envi
return make_pair(new_rhs, proof);
}
pair<expr, expr> orelse_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
// OrElse Rewrite
orelse_rewrite_cell::orelse_rewrite_cell(rewrite const & rw1, rewrite const & rw2)
:rewrite_cell(rewrite_kind::OrElse), m_rw1(rw1), m_rw2(rw2) { }
orelse_rewrite_cell::~orelse_rewrite_cell() { }
pair<expr, expr> orelse_rewrite_cell::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
try {
return rw1(ctx, v, env);
return m_rw1(ctx, v, env);
} catch (rewrite_exception & ) {
return rw2(ctx, v, env);
return m_rw2(ctx, v, env);
}
}
pair<expr, expr> then_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
pair<expr, expr> result1 = rw1(ctx, v, env);
pair<expr, expr> result2 = rw2(ctx, result1.first, env);
expr const & t = light_checker(env)(v, ctx);
// Then Rewrite
then_rewrite_cell::then_rewrite_cell(rewrite const & rw1, rewrite const & rw2)
:rewrite_cell(rewrite_kind::Then), m_rw1(rw1), m_rw2(rw2) { }
then_rewrite_cell::~then_rewrite_cell() { }
pair<expr, expr> then_rewrite_cell::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
pair<expr, expr> result1 = m_rw1(ctx, v, env);
pair<expr, expr> result2 = m_rw2(ctx, result1.first, env);
light_checker lc(env);
expr const & t = lc(v, ctx);
return make_pair(result2.first,
Trans(t, v, result1.first, result2.first, result1.second, result2.second));
}
pair<expr, expr> app_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
// App Rewrite
app_rewrite_cell::app_rewrite_cell(rewrite const & rw)
:rewrite_cell(rewrite_kind::App), m_rw(rw) { }
app_rewrite_cell::~app_rewrite_cell() { }
pair<expr, expr> app_rewrite_cell::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
if (!is_app(v))
throw rewrite_exception();
unsigned n = num_args(v);
for (unsigned i = 0; i < n; i++) {
auto result = rw(ctx, arg(v, i), env);
auto result = m_rw(ctx, arg(v, i), env);
}
// TODO(soonhok)
throw rewrite_exception();
}
pair<expr, expr> lambda_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
// Lambda Rewrite
lambda_rewrite_cell::lambda_rewrite_cell(rewrite const & rw)
:rewrite_cell(rewrite_kind::Lambda), m_rw(rw) { }
lambda_rewrite_cell::~lambda_rewrite_cell() { }
pair<expr, expr> lambda_rewrite_cell::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
if (!is_lambda(v))
throw rewrite_exception();
expr const & domain = abst_domain(v);
expr const & body = abst_body(v);
auto result_domain = rw(ctx, domain, env);
auto result_body = rw(ctx, body, env); // TODO(soonhok): add to context!
auto result_domain = m_rw(ctx, domain, env);
auto result_body = m_rw(ctx, body, env); // TODO(soonhok): add to context!
// TODO(soonhok)
throw rewrite_exception();
}
pair<expr, expr> pi_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
pi_rewrite_cell::pi_rewrite_cell(rewrite const & rw)
:rewrite_cell(rewrite_kind::Pi), m_rw(rw) { }
pi_rewrite_cell::~pi_rewrite_cell() { }
pair<expr, expr> pi_rewrite_cell::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
if (!is_pi(v))
throw rewrite_exception();
expr const & domain = abst_domain(v);
expr const & body = abst_body(v);
auto result_domain = rw(ctx, domain, env);
auto result_body = rw(ctx, body, env); // TODO(soonhok): add to context!
auto result_domain = m_rw(ctx, domain, env);
auto result_body = m_rw(ctx, body, env); // TODO(soonhok): add to context!
// TODO(soonhok)
throw rewrite_exception();
}
pair<expr, expr> let_rewrite::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
let_rewrite_cell::let_rewrite_cell(rewrite const & rw)
:rewrite_cell(rewrite_kind::Let), m_rw(rw) { }
let_rewrite_cell::~let_rewrite_cell() { }
pair<expr, expr> let_rewrite_cell::operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) {
if (!is_let(v))
throw rewrite_exception();
@ -149,11 +187,22 @@ pair<expr, expr> let_rewrite::operator()(context & ctx, expr const & v, environm
expr const & value = let_value(v);
expr const & body = let_body(v);
auto result_ty = rw(ctx, ty, env);
auto result_value = rw(ctx, value, env);
auto result_body = rw(ctx, body, env); // TODO(soonhok): add to context!
auto result_ty = m_rw(ctx, ty, env);
auto result_value = m_rw(ctx, value, env);
auto result_body = m_rw(ctx, body, env); // TODO(soonhok): add to context!
// TODO(soonhok)
throw rewrite_exception();
}
rewrite mk_theorem_rewrite(expr const & type, expr const & body) {
return rewrite(new theorem_rewrite_cell(type, body));
}
rewrite mk_then_rewrite(rewrite const & rw1, rewrite const & rw2) {
return rewrite(new then_rewrite_cell(rw1, rw2));
}
rewrite mk_orelse_rewrite(rewrite const & rw1, rewrite const & rw2) {
return rewrite(new orelse_rewrite_cell(rw1, rw2));
}
}

View file

@ -25,85 +25,128 @@ namespace lean {
class rewrite_exception : public exception {
};
enum class rewrite_kind { Theorem, OrElse, Then, App, Lambda, Pi, Let };
class rewrite;
class rewrite_cell {
protected:
rewrite_kind m_kind;
MK_LEAN_RC();
void dealloc();
public:
rewrite_cell(rewrite_kind k);
virtual ~rewrite_cell();
rewrite_kind kind() const { return m_kind; }
// unsigned hash() const { return m_hash; }
virtual std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception) = 0;
};
class rewrite {
private:
rewrite_cell * m_ptr;
public:
virtual std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const = 0;
explicit rewrite(rewrite_cell * ptr):m_ptr(ptr) {}
rewrite():m_ptr(nullptr) {}
rewrite(rewrite const & r):m_ptr(r.m_ptr) {
if (m_ptr) m_ptr->inc_ref();
}
rewrite(rewrite && r):m_ptr(r.m_ptr) { r.m_ptr = nullptr; }
~rewrite() { if (m_ptr) m_ptr->dec_ref(); }
void release() { if (m_ptr) m_ptr->dec_ref(); m_ptr = nullptr; }
friend void swap(rewrite & a, rewrite & b) { std::swap(a.m_ptr, b.m_ptr); }
rewrite_kind kind() const { return m_ptr->kind(); }
rewrite & operator=(rewrite const & s) { LEAN_COPY_REF(rewrite, s); }
rewrite & operator=(rewrite && s) { LEAN_MOVE_REF(rewrite, s); }
std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const {
return (*m_ptr)(ctx, v, env);
}
};
class theorem_rewrite : public rewrite {
class theorem_rewrite_cell : public rewrite_cell {
private:
expr const & thm_type;
expr const & thm_body;
expr pattern;
expr rhs;
unsigned num_args;
expr const & m_type;
expr const & m_body;
expr m_pattern;
expr m_rhs;
unsigned m_num_args;
public:
theorem_rewrite(expr const & type, expr const & body);
theorem_rewrite_cell(expr const & type, expr const & body);
~theorem_rewrite_cell();
std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception);
};
class orelse_rewrite : public rewrite {
class orelse_rewrite_cell : public rewrite_cell {
private:
rewrite const & rw1;
rewrite const & rw2;
rewrite m_rw1;
rewrite m_rw2;
public:
orelse_rewrite(rewrite const & rw_1, rewrite const & rw_2) :
rw1(rw_1), rw2(rw_2) { }
orelse_rewrite_cell(rewrite const & rw1, rewrite const & rw2);
~orelse_rewrite_cell();
std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception);
};
class then_rewrite : public rewrite {
class then_rewrite_cell : public rewrite_cell {
private:
rewrite const & rw1;
rewrite const & rw2;
rewrite m_rw1;
rewrite m_rw2;
public:
then_rewrite(rewrite const & rw_1, rewrite const & rw_2) :
rw1(rw_1), rw2(rw_2) { }
then_rewrite_cell(rewrite const & rw1, rewrite const & rw2);
~then_rewrite_cell();
std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception);
};
class app_rewrite : public rewrite {
class app_rewrite_cell : public rewrite_cell {
private:
rewrite const & rw;
rewrite m_rw;
public:
app_rewrite(rewrite const & rw_) :
rw(rw_) { }
app_rewrite_cell(rewrite const & rw);
~app_rewrite_cell();
std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception);
};
class lambda_rewrite : public rewrite {
class lambda_rewrite_cell : public rewrite_cell {
private:
rewrite const & rw;
rewrite m_rw;
public:
lambda_rewrite(rewrite const & rw_) :
rw(rw_) { }
lambda_rewrite_cell(rewrite const & rw);
~lambda_rewrite_cell();
std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception);
};
class pi_rewrite : public rewrite {
class pi_rewrite_cell : public rewrite_cell {
private:
rewrite const & rw;
rewrite m_rw;
public:
pi_rewrite(rewrite const & rw_) :
rw(rw_) { }
pi_rewrite_cell(rewrite const & rw);
~pi_rewrite_cell();
std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception);
};
class let_rewrite : public rewrite {
class let_rewrite_cell : public rewrite_cell {
private:
rewrite const & rw;
rewrite m_rw;
public:
let_rewrite(rewrite const & rw_) :
rw(rw_) { }
let_rewrite_cell(rewrite const & rw);
~let_rewrite_cell();
std::pair<expr, expr> operator()(context & ctx, expr const & v, environment const & env) const throw(rewrite_exception);
};
class fail_rewrite : public rewrite {
class fail_rewrite_cell : public rewrite_cell {
public:
fail_rewrite_cell(rewrite const & rw1, rewrite const & rw2);
std::pair<expr, expr> operator()(context &, expr const &) const throw(rewrite_exception) {
throw rewrite_exception();
}
};
rewrite mk_theorem_rewrite(expr const & type, expr const & body);
rewrite mk_then_rewrite(rewrite const & rw1, rewrite const & rw2);
rewrite mk_orelse_rewrite(rewrite const & rw1, rewrite const & rw2);
rewrite mk_app_rewrite(rewrite const & rw);
rewrite mk_lambda_rewrite(rewrite const & rw);
rewrite mk_pi_rewrite(rewrite const & rw);
rewrite mk_let_rewrite(rewrite const & rw);
rewrite mk_fail_rewrite();
}

View file

@ -41,7 +41,7 @@ static void theorem_rewrite1_tst() {
env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z
// Rewriting
theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body);
rewrite add_comm_thm_rewriter = mk_theorem_rewrite(add_comm_thm_type, add_comm_thm_body);
context ctx;
pair<expr, expr> result = add_comm_thm_rewriter(ctx, a_plus_b, env);
expr concl = mk_eq(a_plus_b, result.first);
@ -72,7 +72,7 @@ static void theorem_rewrite2_tst() {
env.add_axiom("ADD_ID", add_id_thm_type); // ADD_ID : Pi (x : N), x = x + 0
// Rewriting
theorem_rewrite add_id_thm_rewriter(add_id_thm_type, add_id_thm_body);
rewrite add_id_thm_rewriter = mk_theorem_rewrite(add_id_thm_type, add_id_thm_body);
context ctx;
pair<expr, expr> result = add_id_thm_rewriter(ctx, a_plus_zero, env);
expr concl = mk_eq(a_plus_zero, result.first);
@ -111,9 +111,9 @@ static void then_rewrite1_tst() {
env.add_axiom("ADD_ID", add_id_thm_type); // ADD_ID : Pi (x : N), x = x + 0
// Rewriting
theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body);
theorem_rewrite add_id_thm_rewriter(add_id_thm_type, add_id_thm_body);
then_rewrite then_rewriter1(add_comm_thm_rewriter, add_id_thm_rewriter);
rewrite add_comm_thm_rewriter = mk_theorem_rewrite(add_comm_thm_type, add_comm_thm_body);
rewrite add_id_thm_rewriter = mk_theorem_rewrite(add_id_thm_type, add_id_thm_body);
rewrite then_rewriter1 = mk_then_rewrite(add_comm_thm_rewriter, add_id_thm_rewriter);
context ctx;
pair<expr, expr> result = then_rewriter1(ctx, zero_plus_a, env);
expr concl = mk_eq(zero_plus_a, result.first);
@ -168,11 +168,11 @@ static void then_rewrite2_tst() {
env.add_axiom("ADD_ID", add_id_thm_type); // ADD_ID : Pi (x : N), x = x + 0
// Rewriting
theorem_rewrite add_assoc_thm_rewriter(add_assoc_thm_type, add_assoc_thm_body);
theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body);
theorem_rewrite add_id_thm_rewriter(add_id_thm_type, add_id_thm_body);
then_rewrite then_rewriter2(then_rewrite(add_assoc_thm_rewriter, add_id_thm_rewriter),
then_rewrite(add_comm_thm_rewriter, add_id_thm_rewriter));
rewrite add_assoc_thm_rewriter = mk_theorem_rewrite(add_assoc_thm_type, add_assoc_thm_body);
rewrite add_comm_thm_rewriter = mk_theorem_rewrite(add_comm_thm_type, add_comm_thm_body);
rewrite add_id_thm_rewriter = mk_theorem_rewrite(add_id_thm_type, add_id_thm_body);
rewrite then_rewriter2 = mk_then_rewrite(mk_then_rewrite(add_assoc_thm_rewriter, add_id_thm_rewriter),
mk_then_rewrite(add_comm_thm_rewriter, add_id_thm_rewriter));
context ctx;
pair<expr, expr> result = then_rewriter2(ctx, zero_plus_a_plus_zero, env);
expr concl = mk_eq(zero_plus_a_plus_zero, result.first);
@ -222,9 +222,9 @@ static void orelse_rewrite1_tst() {
env.add_axiom("ADD_COMM", add_comm_thm_type); // ADD_COMM : Pi (x, y: N), x + y = y + z
// Rewriting
theorem_rewrite add_assoc_thm_rewriter(add_assoc_thm_type, add_assoc_thm_body);
theorem_rewrite add_comm_thm_rewriter(add_comm_thm_type, add_comm_thm_body);
orelse_rewrite add_assoc_or_comm_thm_rewriter(add_assoc_thm_rewriter, add_comm_thm_rewriter);
rewrite add_assoc_thm_rewriter = mk_theorem_rewrite(add_assoc_thm_type, add_assoc_thm_body);
rewrite add_comm_thm_rewriter = mk_theorem_rewrite(add_comm_thm_type, add_comm_thm_body);
rewrite add_assoc_or_comm_thm_rewriter = mk_orelse_rewrite(add_assoc_thm_rewriter, add_comm_thm_rewriter);
context ctx;
pair<expr, expr> result = add_assoc_or_comm_thm_rewriter(ctx, a_plus_b, env);
expr concl = mk_eq(a_plus_b, result.first);