Skip to content

Commit 225c71e

Browse files
committed
WIP: fix #17997, don't load packages in Main
[ci skip]
1 parent 75ec2b9 commit 225c71e

File tree

6 files changed

+283
-208
lines changed

6 files changed

+283
-208
lines changed

base/loading.jl

+81-21
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ function reload(name::AbstractString)
288288
error("use `include` instead of `reload` to load source files")
289289
else
290290
# reload("Package") is ok
291+
unreference_module(Symbol(name))
291292
require(Symbol(name))
292293
end
293294
end
@@ -315,25 +316,84 @@ all platforms, including those with case-insensitive filesystems like macOS and
315316
Windows.
316317
"""
317318
function require(mod::Symbol)
318-
_require(mod)
319+
existed = root_module_exists(mod)
320+
M = _require(mod)
319321
# After successfully loading, notify downstream consumers
320322
if toplevel_load[] && myid() == 1 && nprocs() > 1
321323
# broadcast top-level import/using from node 1 (only)
322324
@sync for p in procs()
323325
p == 1 && continue
324326
@async remotecall_wait(p) do
325-
if !isbindingresolved(Main, mod) || !isdefined(Main, mod)
326-
_require(mod)
327-
end
327+
_require(mod)
328+
nothing
328329
end
329330
end
330331
end
331-
for callback in package_callbacks
332-
invokelatest(callback, mod)
332+
if !existed
333+
for callback in package_callbacks
334+
invokelatest(callback, mod)
335+
end
336+
end
337+
return M
338+
end
339+
340+
const loaded_modules = ObjectIdDict()
341+
const module_keys = ObjectIdDict()
342+
343+
function register_root_module(key, m::Module)
344+
if haskey(loaded_modules, key)
345+
oldm = loaded_modules[key]
346+
if oldm !== m
347+
name = module_name(oldm)
348+
warn("replacing module $name.")
349+
end
350+
end
351+
loaded_modules[key] = m
352+
module_keys[m] = key
353+
nothing
354+
end
355+
356+
register_root_module(:Core, Core)
357+
register_root_module(:Base, Base)
358+
register_root_module(:Main, Main)
359+
360+
is_root_module(m::Module) = haskey(module_keys, m)
361+
362+
root_module_key(m::Module) = module_keys[m]
363+
364+
# This is used as the current module when loading top-level modules.
365+
# It has the special behavior that modules evaluated in it get added
366+
# to the loaded_modules table instead of getting bindings.
367+
baremodule __toplevel__
368+
using Base
369+
end
370+
371+
# get a top-level Module from the given key
372+
# for now keys can only be Symbols, but that will change
373+
root_module(key::Symbol) = loaded_modules[key]
374+
375+
root_module_exists(key::Symbol) = haskey(loaded_modules, key)
376+
377+
loaded_modules_array() = collect(values(loaded_modules))
378+
379+
function unreference_module(key)
380+
if haskey(loaded_modules, key)
381+
m = pop!(loaded_modules, key)
382+
# need to ensure all modules are GC rooted; will still be referenced
383+
# in module_keys
384+
end
385+
end
386+
387+
function register_all(a)
388+
for m in a
389+
register_root_module(module_name(m), m)
333390
end
334391
end
335392

336393
function _require(mod::Symbol)
394+
if root_module_exists(mod)
395+
return root_module(mod)
396+
end
337397
# dependency-tracking is only used for one top-level include(path),
338398
# and is not applied recursively to imported modules:
339399
old_track_dependencies = _track_dependencies[]
@@ -345,7 +405,7 @@ function _require(mod::Symbol)
345405
if loading !== false
346406
# load already in progress for this module
347407
wait(loading)
348-
return
408+
return root_module(mod)
349409
end
350410
package_locks[mod] = Condition()
351411

@@ -364,7 +424,8 @@ function _require(mod::Symbol)
364424
if JLOptions().use_compilecache != 0
365425
doneprecompile = _require_search_from_serialized(mod, path)
366426
if !isa(doneprecompile, Bool)
367-
return # success
427+
register_all(doneprecompile)
428+
return root_module(mod) # success
368429
end
369430
end
370431

@@ -391,14 +452,17 @@ function _require(mod::Symbol)
391452
warn(m, prefix="WARNING: ")
392453
# fall-through, TODO: disable __precompile__(true) error so that the normal include will succeed
393454
else
394-
return # success
455+
register_all(m)
456+
return root_module(mod) # success
395457
end
396458
end
397459

398460
# just load the file normally via include
399461
# for unknown dependencies
462+
local M
400463
try
401-
Base.include_relative(Main, path)
464+
Base.include_relative(__toplevel__, path)
465+
return root_module(mod)
402466
catch ex
403467
if doneprecompile === true || JLOptions().use_compilecache == 0 || !precompilableerror(ex, true)
404468
rethrow() # rethrow non-precompilable=true errors
@@ -411,6 +475,8 @@ function _require(mod::Symbol)
411475
# TODO: disable __precompile__(true) error and do normal include instead of error
412476
error("Module $mod declares __precompile__(true) but require failed to create a usable precompiled cache file.")
413477
end
478+
register_all(m)
479+
return root_module(mod)
414480
end
415481
finally
416482
toplevel_load[] = last
@@ -532,7 +598,7 @@ function create_expr_cache(input::String, output::String, concrete_deps::Vector{
532598
task_local_storage()[:SOURCE_PATH] = $(source)
533599
end)
534600
end
535-
serialize(in, :(Base.include(Main, $(abspath(input)))))
601+
serialize(in, :(Base.include(Base.__toplevel__, $(abspath(input)))))
536602
if source !== nothing
537603
serialize(in, :(delete!(task_local_storage(), :SOURCE_PATH)))
538604
end
@@ -570,15 +636,9 @@ function compilecache(name::String)
570636
cachefile::String = abspath(cachepath, name*".ji")
571637
# build up the list of modules that we want the precompile process to preserve
572638
concrete_deps = copy(_concrete_dependencies)
573-
for existing in names(Main)
574-
if isdefined(Main, existing)
575-
mod = getfield(Main, existing)
576-
if isa(mod, Module) && !(mod === Main || mod === Core || mod === Base)
577-
mod = mod::Module
578-
if module_parent(mod) === Main && module_name(mod) === existing
579-
push!(concrete_deps, (existing, module_uuid(mod)))
580-
end
581-
end
639+
for (key,mod) in loaded_modules
640+
if !(mod === Main || mod === Core || mod === Base)
641+
push!(concrete_deps, (key, module_uuid(mod)))
582642
end
583643
end
584644
# run the expression and cache the result
@@ -675,7 +735,7 @@ function stale_cachefile(modpath::String, cachefile::String)
675735
if mod == :Main || mod == :Core || mod == :Base
676736
continue
677737
# Module is already loaded
678-
elseif isbindingresolved(Main, mod)
738+
elseif root_module_exists(mod)
679739
continue
680740
end
681741
name = string(mod)

base/serialize.jl

+17-12
Original file line numberDiff line numberDiff line change
@@ -343,9 +343,10 @@ function serialize(s::AbstractSerializer, d::Dict)
343343
end
344344

345345
function serialize_mod_names(s::AbstractSerializer, m::Module)
346-
p = module_parent(m)
347-
if m !== p
348-
serialize_mod_names(s, p)
346+
if Base.is_root_module(m)
347+
serialize(s, Base.root_module_key(m))
348+
else
349+
serialize_mod_names(s, module_parent(m))
349350
serialize(s, module_name(m))
350351
end
351352
end
@@ -772,21 +773,25 @@ function deserialize_svec(s::AbstractSerializer)
772773
end
773774

774775
function deserialize_module(s::AbstractSerializer)
775-
path = deserialize(s)
776-
m = Main
777-
if isa(path,Tuple) && path !== ()
778-
# old version
779-
for mname in path
780-
m = getfield(m,mname)::Module
776+
mkey = deserialize(s)
777+
if isa(mkey, Tuple)
778+
# old version, TODO: remove
779+
if mkey === ()
780+
return Main
781+
end
782+
m = Base.root_module(mkey[1])
783+
for i = 2:length(mkey)
784+
m = getfield(m, mkey[i])::Module
781785
end
782786
else
783-
mname = path
787+
m = Base.root_module(mkey)
788+
mname = deserialize(s)
784789
while mname !== ()
785-
m = getfield(m,mname)::Module
790+
m = getfield(m, mname)::Module
786791
mname = deserialize(s)
787792
end
788793
end
789-
m
794+
return m
790795
end
791796

792797
function deserialize(s::AbstractSerializer, ::Type{Method})

base/show.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ function show(io::IO, p::Pair)
380380
end
381381

382382
function show(io::IO, m::Module)
383-
if m === Main
384-
print(io, "Main")
383+
if is_root_module(m)
384+
print(io, module_name(m))
385385
else
386386
print(io, join(fullname(m),"."))
387387
end

0 commit comments

Comments
 (0)