fix(library/module): allow multiple calls to import_modules with the same modules

The idea is to store a set of already imported files.
This feature is useful when using the import_modules API directly (e.g.,
from javascript).
This commit is contained in:
Leonardo de Moura 2014-10-14 08:13:41 -07:00
parent d75a9c840c
commit 7231aa0d73
2 changed files with 32 additions and 8 deletions

View file

@ -44,6 +44,7 @@ struct module_ext : public environment_extension {
// directly imported files have changed
list<time_t> m_direct_imports_mod_time;
std::string m_base;
name_set m_imported;
};
struct module_ext_reg {
@ -252,10 +253,13 @@ struct import_modules_fn {
typedef std::shared_ptr<module_info> module_info_ptr;
name_map<module_info_ptr> m_module_info;
name_set m_visited;
name_set m_imported; // all imported files
import_modules_fn(environment const & env, unsigned num_threads, bool keep_proofs, io_state const & ios):
m_senv(env), m_num_threads(num_threads), m_keep_proofs(keep_proofs), m_ios(ios),
m_next_module_idx(1), m_import_counter(0), m_all_modules_imported(false) {
module_ext const & ext = get_extension(env);
m_imported = ext.m_imported;
if (m_num_threads == 0)
m_num_threads = 1;
#if !defined(LEAN_MULTI_THREAD)
@ -273,6 +277,8 @@ struct import_modules_fn {
auto it = m_module_info.find(fname);
if (it)
return *it;
if (m_imported.contains(fname)) // file was imported in previous call
return nullptr;
if (m_visited.contains(fname))
throw exception(sstream() << "circular dependency detected at '" << fname << "'");
m_visited.insert(fname);
@ -311,8 +317,8 @@ struct import_modules_fn {
std::string new_base = dirname(fname.c_str());
std::swap(r->m_obj_code, code);
for (auto i : imports) {
auto d = load_module_file(new_base, i);
d->m_dependents.push_back(r);
if (auto d = load_module_file(new_base, i))
d->m_dependents.push_back(r);
}
m_module_info.insert(fname, r);
r->m_module_idx = m_next_module_idx++;
@ -424,6 +430,7 @@ struct import_modules_fn {
add_import_module_task(d);
}
}
m_imported.insert(r->m_fname);
}
optional<asynch_update_fn> next_task() {
@ -506,12 +513,14 @@ struct import_modules_fn {
ext.m_base = base;
for (unsigned i = 0; i < num_modules; i++) {
module_name const & mname = modules[i];
ext.m_direct_imports = cons(mname, ext.m_direct_imports);
std::string fname = find_file(base, mname.get_k(), mname.get_name(), {".olean"});
struct stat st;
if (stat(fname.c_str(), &st) != 0)
throw exception(sstream() << "failed to access stats of file '" << fname << "'");
ext.m_direct_imports_mod_time = cons(st.st_mtime, ext.m_direct_imports_mod_time);
if (!m_imported.contains(fname)) {
ext.m_direct_imports = cons(mname, ext.m_direct_imports);
struct stat st;
if (stat(fname.c_str(), &st) != 0)
throw exception(sstream() << "failed to access stats of file '" << fname << "'");
ext.m_direct_imports_mod_time = cons(st.st_mtime, ext.m_direct_imports_mod_time);
}
}
return update(env, ext);
});
@ -522,7 +531,10 @@ struct import_modules_fn {
for (unsigned i = 0; i < num_modules; i++)
load_module_file(base, modules[i]);
process_asynch_tasks();
return process_delayed_tasks();
environment env = process_delayed_tasks();
module_ext ext = get_extension(env);
ext.m_imported = m_imported;
return update(env, ext);
}
};

View file

@ -55,6 +55,7 @@ using lean::expr;
using lean::options;
using lean::declaration_index;
using lean::keep_theorem_mode;
using lean::module_name;
enum class input_kind { Unspecified, Lean, Lua };
@ -222,6 +223,17 @@ static void export_as_cpp_file(std::string const & fname, char const * varname,
out << "}\n";
}
environment import_module(environment const & env, io_state const & ios, module_name const & mod, bool keep_proofs = true) {
std::string base = ".";
bool num_threads = 1;
return import_modules(env, base, 1, &mod, num_threads, keep_proofs, ios);
}
environment import_standard(environment const & env, io_state const & ios, bool keep_proofs = true) {
module_name std(lean::name("standard"));
return import_module(env, ios, std, keep_proofs);
}
int main(int argc, char ** argv) {
lean::initializer init;
bool export_objects = false;