Skip to content

Commit

Permalink
Adjust to recent compiler changes (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno authored Nov 7, 2024
1 parent 4045157 commit f792e6f
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 48 deletions.
28 changes: 14 additions & 14 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src/analysis/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ function dae_result_for_inst(interp, inst::CC.Instruction)
info = inst[:info]
stmt = inst[:stmt]
mi = stmt.args[1]
if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt === Const(nothing)
info = info.info
end
if isa(info, CC.ConstCallInfo)
if length(info.results) != 1
# TODO: When does this happen? Union split?
Expand Down
48 changes: 18 additions & 30 deletions src/analysis/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using Core: CodeInfo, MethodInstance, CodeInstance, SimpleVector, MethodMatch, MethodTable
using .CC: AbstractInterpreter, NativeInterpreter, InferenceParams, OptimizationParams,
InferenceResult, InferenceState, OptimizationState, WorldRange, WorldView, ArgInfo,
StmtInfo, MethodCallResult, ConstCallResults, ConstPropResult, MethodTableView,
StmtInfo, MethodCallResult, ConstCallResult, ConstPropResult, MethodTableView,
CachedMethodTable, InternalMethodTable, OverlayMethodTable, CallMeta, CallInfo,
IRCode, LazyDomtree, IRInterpretationState, set_inlineable!, block_for_inst,
BitSetBoundedMinPrioritySet, AbsIntState, Future
Expand Down Expand Up @@ -117,7 +117,7 @@ struct DAEInterpreter <: AbstractInterpreter
ipo_analysis_mode::Bool = false,
in_analysis::Bool = false)
if code_cache === nothing
code_cache = get_code_cache(method_table, ipo_analysis_mode)
code_cache = get_code_cache(world, method_table, ipo_analysis_mode)
end
if method_table !== nothing
method_table = CachedMethodTable(OverlayMethodTable(world, method_table))
Expand Down Expand Up @@ -315,13 +315,10 @@ end
return Future{MethodCallResult}(mret, interp, sv) do ret, interp, sv
edge = ret.edge
if edge !== nothing
cache = CC.get(CC.code_cache(interp), edge, nothing)
if cache !== nothing
src = @atomic :monotonic cache.inferred
if isa(src, DAECache)
info = src.info
merge_daeinfo!(interp, sv.result, info)
end
src = @atomic :monotonic edge.inferred
if isa(src, DAECache)
info = src.info
merge_daeinfo!(interp, sv.result, info)
end
end
return ret
Expand All @@ -330,11 +327,11 @@ end

@override function CC.const_prop_call(interp::DAEInterpreter,
mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo,
sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResults})
sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResult})
ret = @invoke CC.const_prop_call(interp::AbstractInterpreter,
mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo,
sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResults})
if isa(ret, ConstCallResults)
sv::InferenceState, concrete_eval_result::Union{Nothing,ConstCallResult})
if isa(ret, ConstCallResult)
const_result = ret.const_result::ConstPropResult
info = interp.dae_cache[const_result.result]
merge_daeinfo!(interp, sv.result, info)
Expand All @@ -353,26 +350,20 @@ struct DAECache
new(inferred, ir, info)
end

@override CC.transform_result_for_cache(interp::DAEInterpreter,
mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool) =
_transform_result_for_cache(interp, mi, valid_worlds, result, cond)

function _transform_result_for_cache(interp::DAEInterpreter,
mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool=false)
function CC.transform_result_for_cache(interp::DAEInterpreter, result::InferenceResult)
src = result.src
if isa(src, DAECache)
return src
end
inferred = @invoke CC.transform_result_for_cache(interp::AbstractInterpreter,
mi::MethodInstance, valid_worlds::WorldRange, result::InferenceResult, cond::Bool)
inferred = @invoke CC.transform_result_for_cache(interp::AbstractInterpreter, result)
return DAECache(inferred, nothing, interp.dae_cache[result])
end

# inlining
# --------

function dae_inlining_policy(@nospecialize(src), @nospecialize(info::CallInfo), raise::Bool=true)
if isa(info, Diffractor.FRuleCallInfo)
if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt !== Const(nothing)
return nothing
end
osrc = src
Expand Down Expand Up @@ -502,13 +493,10 @@ end
result::MethodCallResult, si::StmtInfo, sv::InferenceState, force::Bool)
edge = result.edge
if edge !== nothing
cache = CC.get(CC.code_cache(interp), edge, nothing)
if cache !== nothing
src = @atomic :monotonic cache.inferred
if isa(src, DAECache)
src.info.has_dae_intrinsics && return true
src.info.has_scoperead && return true
end
src = @atomic :monotonic edge.inferred
if isa(src, DAECache)
src.info.has_dae_intrinsics && return true
src.info.has_scoperead && return true
end
end
return @invoke CC.const_prop_rettype_heuristic(interp::AbstractInterpreter,
Expand Down Expand Up @@ -1052,7 +1040,7 @@ end
using Cthulhu

function Cthulhu.get_optimized_codeinst(interp::DAEInterpreter, curs::Cthulhu.CthulhuCursor)
interp.code_cache.cache[curs.mi]
CC.getindex(CC.code_cache(interp), curs.mi)
end

function Cthulhu.lookup(interp::DAEInterpreter, curs::Cthulhu.CthulhuCursor, optimize::Bool)
Expand Down Expand Up @@ -1109,7 +1097,7 @@ function lookup_optimized(interp::DAEInterpreter, mi::MethodInstance, allow_no_s
end

Cthulhu.can_descend(interp::DAEInterpreter, @nospecialize(key), optimize::Bool) =
haskey(optimize ? interp.code_cache.cache : interp.unopt, key)
optimize ? CC.haskey(CC.code_cache(interp), key) : haskey(interp.unopt, key)

# TODO: Why does Cthulhu have this separately from the lookup logic, which already
# returns effects
Expand Down
10 changes: 6 additions & 4 deletions src/transform/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,16 @@ end
function remap_info(remap_ir!, info)
# TODO: This is pretty aweful, but it works for now.
# It'll go away when we switch to IPO.
if isa(info, Diffractor.FRuleCallInfo) && info.frule_call.rt === Const(nothing)
info = info.info
end
isa(info, CC.ConstCallInfo) || return info
results = map(info.results) do result
result === nothing && return result
if isa(result, CC.SemiConcreteResult)
let ir = copy(result.ir)
remap_ir!(ir)
CC.SemiConcreteResult(result.mi, ir, result.effects, result.spec_info)
CC.SemiConcreteResult(result.edge, ir, result.effects, result.spec_info)
end
elseif isa(result, CC.ConstPropResult)
if isa(result.result.src, DAECache)
Expand All @@ -76,10 +79,9 @@ function widen_extra_info!(ir)
for i = 1:length(ir.stmts)
info = ir.stmts[i][:info]
if isa(info, Diffractor.FRuleCallInfo)
ir.stmts[i][:info] = info.info
else
ir.stmts[i][:info] = remap_info(widen_extra_info!, info)
info = info.info
end
ir.stmts[i][:info] = remap_info(widen_extra_info!, info)
inst = ir.stmts[i][:inst]
if isa(inst, PiNode)
ir.stmts[i][:inst] = PiNode(inst.val, widenconst(inst.typ))
Expand Down
1 change: 1 addition & 0 deletions test/ipo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module ipo
using Test
using DAECompiler
using DAECompiler.Intrinsics
using DAECompiler.Intrinsics: state_ddt
using SciMLBase, OrdinaryDiffEq, Sundials

include(joinpath(Base.pkgdir(DAECompiler), "test", "testutils.jl"))
Expand Down

0 comments on commit f792e6f

Please sign in to comment.