Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CompilerDevTools: add proof of concept for caching runtime calls #57193

Merged
merged 10 commits into from
Feb 28, 2025
50 changes: 39 additions & 11 deletions Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,74 @@
module CompilerDevTools

using Compiler
using Compiler: argextype, widenconst
using Core.IR
using Base: isexpr

# Cache-owner token for `SplitCacheInterp`.
#
# An instance is stored in `SplitCacheInterp.owner`, returned from
# `Compiler.cache_owner`, and used as the dispatch argument of the
# `typeinf` / `typeinf_edge` compiler-plugin methods in this module.
#
# The type is deliberately `mutable` and field-less: every `SplitCacheOwner()`
# is then a distinct object under `===`, whereas an immutable field-less struct
# would be a singleton and all owners would be indistinguishable.
# NOTE(review): presumably the compiler keys its code cache on owner identity,
# which is what lets a fresh owner start with an empty cache — confirm.
mutable struct SplitCacheOwner end
# Abstract interpreter that carries its own cache owner.
#
# Fields:
# - `world`:         world age reported by `Compiler.get_inference_world`.
# - `owner`:         the `SplitCacheOwner` reported by `Compiler.cache_owner`.
# - `inf_params`:    inference parameters for this interpreter.
# - `opt_params`:    optimization parameters for this interpreter.
# - `inf_cache`:     local inference cache (`Compiler.get_inference_cache`).
# - `codegen_cache`: `CodeInstance => CodeInfo` map; always created fresh by
#                    the constructor, it is not a constructor argument.
#
# All constructor arguments are keywords with defaults, so `SplitCacheInterp()`
# yields an interpreter at the current world counter with a brand-new owner.
struct SplitCacheInterp <: Compiler.AbstractInterpreter
    world::UInt
    owner::SplitCacheOwner
    inf_params::Compiler.InferenceParams
    opt_params::Compiler.OptimizationParams
    inf_cache::Vector{Compiler.InferenceResult}
    codegen_cache::IdDict{CodeInstance,CodeInfo}
    function SplitCacheInterp(;
            world::UInt = Base.get_world_counter(),
            owner::SplitCacheOwner = SplitCacheOwner(),
            inf_params::Compiler.InferenceParams = Compiler.InferenceParams(),
            opt_params::Compiler.OptimizationParams = Compiler.OptimizationParams(),
            inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[])
        # `new` is positional: the arguments must follow the field order above,
        # so `owner` has to sit between `world` and `inf_params`.
        return new(world, owner, inf_params, opt_params, inf_cache, IdDict{CodeInstance,CodeInfo}())
    end
end

# `Compiler.AbstractInterpreter` interface for `SplitCacheInterp`: each method
# simply forwards to the corresponding field of the interpreter.
Compiler.InferenceParams(interp::SplitCacheInterp) = interp.inf_params
Compiler.OptimizationParams(interp::SplitCacheInterp) = interp.opt_params
Compiler.get_inference_world(interp::SplitCacheInterp) = interp.world
Compiler.get_inference_cache(interp::SplitCacheInterp) = interp.inf_cache
# Report the owner stored on the interpreter rather than constructing a new
# `SplitCacheOwner()` on every call, so one interpreter always names the same
# owner object.
Compiler.cache_owner(interp::SplitCacheInterp) = interp.owner
Compiler.codegen_cache(interp::SplitCacheInterp) = interp.codegen_cache

import Core.OptimizedGenerics.CompilerPlugins: typeinf, typeinf_edge
# Compiler-plugin entry point for code owned by a `SplitCacheOwner`.
#
# Builds a `SplitCacheInterp` at the calling task's world age
# (`Base.tls_world_age()`) that reuses the `owner` it was dispatched on, then
# runs `Compiler.typeinf_ext_toplevel(interp, mi, source_mode)` through
# `Base.invoke_in_world`. The world used for that invocation is the
# `primary_world` of this very method, looked up with `which`.
@eval @noinline typeinf(owner::SplitCacheOwner, mi::MethodInstance, source_mode::UInt8) =
    Base.invoke_in_world(which(typeinf, Tuple{SplitCacheOwner, MethodInstance, UInt8}).primary_world, Compiler.typeinf_ext_toplevel, SplitCacheInterp(; world=Base.tls_world_age(), owner), mi, source_mode)

# Compiler-plugin hook for inferring an edge to `mi` on behalf of
# `parent_frame`, for code owned by a `SplitCacheOwner`.
#
# A `SplitCacheInterp` is created for the requested `world`, reusing the
# `owner` this method was dispatched on, and inference is delegated to
# `Compiler.typeinf_edge`. `source_mode` is accepted as part of the signature
# but is not used by this body.
@eval @noinline function typeinf_edge(owner::SplitCacheOwner, mi::MethodInstance, parent_frame::Compiler.InferenceState, world::UInt, source_mode::UInt8)
    # TODO: This isn't quite right, we're just sketching things for now
    interp = SplitCacheInterp(; world, owner)
    Compiler.typeinf_edge(interp, mi.def, mi.specTypes, Core.svec(), parent_frame, false, false)
end

# Return the `Core.MethodInstance` that the runtime's `jl_method_lookup`
# selects for calling `f` with the concrete values `args...`, at the calling
# task's current world age.
#
# `f` and the arguments are packed into one `Any[...]` buffer because the C
# entry point takes the callee and its arguments as a single array, together
# with the total element count (`1 + length(args)`).
function lookup_method_instance(f, args...)
    return @ccall jl_method_lookup(Any[f, args...]::Ptr{Any}, (1+length(args))::Csize_t, Base.tls_world_age()::Csize_t)::Ref{Core.MethodInstance}
end

# Optimization hook for `SplitCacheInterp`.
#
# First runs the generic `Compiler.AbstractInterpreter` method of
# `Compiler.optimize`, then walks the statements of `opt.ir` and rewrites every
# `:call` expression `f(args...)` into `with_new_compiler(f, owner, args...)`,
# where `owner` is this interpreter's cache owner.
#
# Two kinds of calls are left untouched:
# - calls whose callee is already the `with_new_compiler` `GlobalRef`;
# - calls whose callee is a `GlobalRef` with a widened type that is a
#   `Core.Builtin`.
function Compiler.optimize(interp::SplitCacheInterp, opt::Compiler.OptimizationState, caller::Compiler.InferenceResult)
    @invoke Compiler.optimize(interp::Compiler.AbstractInterpreter, opt::Compiler.OptimizationState, caller::Compiler.InferenceResult)
    ir = opt.ir::Compiler.IRCode
    redirect = GlobalRef(@__MODULE__(), :with_new_compiler)
    for inst in ir.stmts
        call = inst[:stmt]
        isexpr(call, :call) || continue
        callee = call.args[1]
        # Already routed through `with_new_compiler`: nothing to do.
        callee === redirect && continue
        if isa(callee, GlobalRef)
            callee_type = widenconst(argextype(callee, ir))
            callee_type <: Core.Builtin && continue
        end
        # f(args...)  ==>  with_new_compiler(f, owner, args...)
        pushfirst!(call.args, redirect)
        insert!(call.args, 3, interp.owner)
    end
end

# Call `f(args...)` using code compiled through the `SplitCacheOwner`
# compiler plugin.
#
# Keyword form: the cache owner is given by `owner` (a fresh `SplitCacheOwner()`
# by default) and forwarded to the positional form below.
with_new_compiler(f, args...; owner::SplitCacheOwner = SplitCacheOwner()) = with_new_compiler(f, owner, args...)

# Positional form: `owner` comes right after `f`. Looks up the method instance
# for `f(args...)`, asks the compiler plugin for a code instance under the
# *given* `owner` (not a freshly constructed one), and invokes `f` through
# that code instance.
function with_new_compiler(f, owner::SplitCacheOwner, args...)
    mi = lookup_method_instance(f, args...)
    new_compiler_ci = Core.OptimizedGenerics.CompilerPlugins.typeinf(
        owner, mi, Compiler.SOURCE_MODE_ABI
    )
    invoke(f, new_compiler_ci, args...)
end
Expand Down
20 changes: 20 additions & 0 deletions Compiler/extras/CompilerDevTools/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using Test
using Compiler: code_cache
using Base: inferencebarrier
using CompilerDevTools
using CompilerDevTools: lookup_method_instance, SplitCacheInterp

@testset "CompilerDevTools" begin
    # `entry` calls `add_pair` with arguments hidden behind inference
    # barriers.
    add_pair(x, y) = x + y
    entry() = add_pair(inferencebarrier(1), inferencebarrier(2))

    interp = SplitCacheInterp()
    cache = code_cache(interp)

    # Nothing has been compiled under this interpreter's owner yet.
    entry_mi = lookup_method_instance(entry)
    @test !haskey(cache, entry_mi)

    # Running through the new compiler caches the entry point under the owner.
    @test with_new_compiler(entry, interp.owner) === 3
    @test haskey(cache, entry_mi)

    # Here `add_pair` is compiled at runtime, and so must have
    # required extra work to be cached under the same cache owner.
    callee_mi = lookup_method_instance(add_pair, 1, 2)
    @test haskey(cache, callee_mi)
end;
6 changes: 6 additions & 0 deletions Compiler/extras/CompilerDevTools/test/testpkg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
using Pkg

# Run the package's test suite inside its own project environment: activate
# the directory one level above this file for the duration of the block,
# instantiate it, then include `runtests.jl`.
let project_root = dirname(@__DIR__)
    Pkg.activate(project_root) do
        Pkg.instantiate()
        include("runtests.jl")
    end
end
6 changes: 5 additions & 1 deletion test/choosetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ const STDLIB_DIR = Sys.STDLIB
const STDLIBS = filter!(x -> isfile(joinpath(STDLIB_DIR, x, "src", "$(x).jl")), readdir(STDLIB_DIR))

const TESTNAMES = [
"subarray", "core", "compiler", "worlds", "atomics",
"subarray", "core", "compiler", "compiler_extras", "worlds", "atomics",
"keywordargs", "numbers", "subtype",
"char", "strings", "triplequote", "unicode", "intrinsics",
"dict", "hashing", "iobuffer", "staged", "offsetarray",
Expand Down Expand Up @@ -54,6 +54,9 @@ function test_path(test)
else
return joinpath(pkgdir, "test", "runtests")
end
elseif t[1] == "Compiler" && length(t) ≥ 3 && t[2] == "extras"
testpath = length(t) >= 4 ? t[4:end] : ("runtests",)
return joinpath(@__DIR__, "..", t[1], t[2], t[3], "test", testpath...)
elseif t[1] == "Compiler"
testpath = length(t) >= 2 ? t[2:end] : ("runtests",)
return joinpath(@__DIR__, "..", t[1], "test", testpath...)
Expand Down Expand Up @@ -172,6 +175,7 @@ function choosetests(choices = [])
# do subarray before sparse but after linalg
filtertests!(tests, "subarray")
filtertests!(tests, "compiler", ["Compiler"])
filtertests!(tests, "compiler_extras", ["Compiler/extras/CompilerDevTools/testpkg"])
filtertests!(tests, "stdlib", STDLIBS)
filtertests!(tests, "internet_required", INTERNET_REQUIRED_LIST)
# do ambiguous first to avoid failing if ambiguities are introduced by other tests
Expand Down