Skip to content

Commit

Permalink
Adjust to compiler excision (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno authored Nov 13, 2024
1 parent f792e6f commit 24b782b
Show file tree
Hide file tree
Showing 22 changed files with 162 additions and 168 deletions.
18 changes: 13 additions & 5 deletions Manifest.toml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
CentralizedCaches = "d1073d05-2d26-4019-b855-dfa0385fef5e"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
Cthulhu = "f68482b8-f384-11e8-15f7-abe071a5a75f"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand All @@ -33,17 +34,17 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[sources]
SciMLBase = {rev = "os/dae-get-du2", url = "https://github.com/CedarEDA/SciMLBase.jl"}
SciMLSensitivity = {rev = "kf/mindep2", url = "https://github.com/CedarEDA/SciMLSensitivity.jl"}

[weakdeps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

[extensions]
DAECompilerModelingToolkitExt = "ModelingToolkit"

Expand All @@ -52,6 +53,7 @@ Accessors = "0.1.36"
CentralizedCaches = "1.1.0"
ChainRules = "1.50"
ChainRulesCore = "1.20"
Compiler = "0.0.1"
Cthulhu = "2.10.1"
DiffEqBase = "6.149.2"
Diffractor = "0.2.7"
Expand Down
3 changes: 2 additions & 1 deletion src/DAECompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ function reconstruct_sensitivities(args...)
error("This method requires SciMLSensitivity")
end

const CC = Core.Compiler
import Compiler
const CC = Compiler
import .CC: get_inference_world
using Base: get_world_counter

Expand Down
36 changes: 18 additions & 18 deletions src/analysis/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using ForwardDiff
using Base.Meta
using Graphs
using Core.IR
using Core.Compiler: InferenceState, bbidxiter, dominates, tmerge, typeinf_lattice
using .CC: InferenceState, bbidxiter, dominates, tmerge, typeinf_lattice

@breadcrumb "ir_levels" function run_dae_passes(
interp::DAEInterpreter, ir::IRCode, debug_config::DebugConfig = DebugConfig())
Expand Down Expand Up @@ -317,7 +317,7 @@ has_any_genscope(sc::Scope) = isdefined(sc, :parent) && has_any_genscope(sc.pare
has_any_genscope(sc::PartialScope) = false
has_any_genscope(sc::PartialStruct) = false # TODO

function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
function _make_argument_lattice_elem(𝕃, which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
if isa(argt, Const)
#@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
return argt
Expand All @@ -331,7 +331,7 @@ function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_v
inc = Incidence(add_variable!(which))
return argt === Float64 ? inc : Incidence(argt, inc.row, inc.eps)
elseif isa(argt, PartialStruct)
return PartialStruct(argt.typ, Any[make_argument_lattice_elem(which, f, add_variable!, add_equation!, add_scope!) for f in argt.fields])
return PartialStruct(𝕃, argt.typ, Any[make_argument_lattice_elem(𝕃, which, f, add_variable!, add_equation!, add_scope!) for f in argt.fields])
elseif isabstracttype(argt) || ismutabletype(argt) || !isa(argt, DataType)
return nothing
else
Expand All @@ -344,20 +344,20 @@ function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_v
for i = 1:length(fieldtypes(argt))
# TODO: Can we make this lazy?
ft = fieldtype(argt, i)
mft = _make_argument_lattice_elem(which, ft, add_variable!, add_equation!, add_scope!)
mft = _make_argument_lattice_elem(𝕃, which, ft, add_variable!, add_equation!, add_scope!)
if mft === nothing
push!(fields, Incidence(ft))
else
any = true
push!(fields, mft)
end
end
return any ? PartialStruct(argt, fields) : nothing
return any ? PartialStruct(𝕃, argt, fields) : nothing
end
end

function make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
mft = _make_argument_lattice_elem(which, argt, add_variable!, add_equation!, add_scope!)
function make_argument_lattice_elem(𝕃, which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
mft = _make_argument_lattice_elem(𝕃, which, argt, add_variable!, add_equation!, add_scope!)
mft === nothing ? Incidence(argt) : mft
end

Expand Down Expand Up @@ -532,7 +532,7 @@ end
nexternalvars = 0 # number of variables that we expect to come in
nexternaleqs = 0 # number of equation references that we expect to come in
if caller !== nothing
argtypes = Any[make_argument_lattice_elem(Argument(i), argt, add_variable!, add_equation!, add_scope!) for (i, argt) in enumerate(ir.argtypes)]
argtypes = Any[make_argument_lattice_elem(CC.typeinf_lattice(interp), Argument(i), argt, add_variable!, add_equation!, add_scope!) for (i, argt) in enumerate(ir.argtypes)]
nexternalvars = length(var_to_diff)
nexternaleqs = length(eqssas)
else
Expand Down Expand Up @@ -571,7 +571,7 @@ end
end
end

cur_scope_lattice = PartialStruct(Base.ScopedValues.Scope,
cur_scope_lattice = PartialStruct(CC.typeinf_lattice(interp), Base.ScopedValues.Scope,
Any[PartialKeyValue(Incidence(Base.PersistentDict{Base.ScopedValues.ScopedValue, Any}))])

# Scan the IR, computing equations, variables, diffgraph, etc.
Expand Down Expand Up @@ -1017,7 +1017,7 @@ end
for eq = 1:length(result.eq_kind)
mapped_eq = mapping.eqs[eq]
mapped_eq == 0 && continue
mapped_inc = apply_linear_incidence(result.total_incidence[eq], result, var_to_diff, var_kind, eq_kind, mapping)
mapped_inc = apply_linear_incidence(CC.typeinf_lattice(interp), result.total_incidence[eq], result, var_to_diff, var_kind, eq_kind, mapping)
if isassigned(total_incidence, mapped_eq)
total_incidence[mapped_eq] = tfunc(Val(Core.Intrinsics.add_float),
total_incidence[mapped_eq],
Expand All @@ -1033,7 +1033,7 @@ end

for (ieq, inc) in enumerate(result.total_incidence[(result.nexternaleqs+1):end])
mapping.eqs[ieq] == 0 || continue
push!(total_incidence, apply_linear_incidence(inc, result, var_to_diff, var_kind, eq_kind, mapping))
push!(total_incidence, apply_linear_incidence(CC.typeinf_lattice(interp), inc, result, var_to_diff, var_kind, eq_kind, mapping))
push!(eq_callee_mapping, [SSAValue(i)=>ieq])
push!(eq_kind, CalleeInternal)
mapping.eqs[ieq] = length(total_incidence)
Expand Down Expand Up @@ -1115,7 +1115,7 @@ end

nimplicitoutpairs = 0
if caller !== nothing
ultimate_rt, nimplicitoutpairs = process_ipo_return!(ultimate_rt, eq_kind, var_kind,
ultimate_rt, nimplicitoutpairs = process_ipo_return!(CC.typeinf_lattice(interp), ultimate_rt, eq_kind, var_kind,
var_to_diff, total_incidence, eq_callee_mapping)
end

Expand All @@ -1135,7 +1135,7 @@ end
Dict{TornCacheKey, CodeInstance}())
end

function process_ipo_return!(ultimate_rt::Incidence, eq_kind, var_kind, var_to_diff, total_incidence, eq_callee_mapping)
function process_ipo_return!(𝕃, ultimate_rt::Incidence, eq_kind, var_kind, var_to_diff, total_incidence, eq_callee_mapping)
nonlinrepl = nothing
nimplicitoutpairs = 0
function get_nonlinrepl()
Expand Down Expand Up @@ -1179,20 +1179,20 @@ function process_ipo_return!(ultimate_rt::Incidence, eq_kind, var_kind, var_to_d
return ultimate_rt, nimplicitoutpairs
end

function process_ipo_return!(ultimate_rt::Eq, eq_kind, args...)
function process_ipo_return!(𝕃, ultimate_rt::Eq, eq_kind, args...)
eq_kind[ultimate_rt.id] = External
return ultimate_rt, 0
end
process_ipo_return!(ultimate_rt::Union{Type, PartialScope, PartialKeyValue, Const}, args...) = ultimate_rt, 0
function process_ipo_return!(ultimate_rt::PartialStruct, args...)
process_ipo_return!(𝕃, ultimate_rt::Union{Type, PartialScope, PartialKeyValue, Const}, args...) = ultimate_rt, 0
function process_ipo_return!(𝕃, ultimate_rt::PartialStruct, args...)
nimplicitoutpairs = 0
fields = Any[]
for f in ultimate_rt.fields
(rt, n) = process_ipo_return!(f, args...)
(rt, n) = process_ipo_return!(𝕃, f, args...)
nimplicitoutpairs += n
push!(fields, rt)
end
return PartialStruct(ultimate_rt.typ, fields), nimplicitoutpairs
return PartialStruct(𝕃, ultimate_rt.typ, fields), nimplicitoutpairs
end

function get_variable_name(names::OrderedDict, var_to_diff, var_idx)
Expand Down
23 changes: 1 addition & 22 deletions src/analysis/compiler_reexports.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,9 @@
using Core.IR
using Core.Compiler: IRCode, Instruction, InstructionStream, IncrementalCompact,
using .CC: IRCode, Instruction, InstructionStream, IncrementalCompact,
NewInstruction, DomTree, BBIdxIter, AnySSAValue, UseRef, UseRefIterator,
block_for_inst, cfg_simplify!, is_known_call, argextype, getfield_tfunc, finish,
singleton_type, widenconst, dominates_ssa, , userefs

# TODO: This really needs to go into a uniform compiler stdlib.
Base.iterate(compact::IncrementalCompact, state) = Core.Compiler.iterate(compact, state)
Base.iterate(compact::IncrementalCompact) = Core.Compiler.iterate(compact)
Base.iterate(abu::CC.AbsIntStackUnwind, state...) = CC.iterate(abu, state...)

Base.setindex!(compact::IncrementalCompact, @nospecialize(v), idx::SSAValue) = Core.Compiler.setindex!(compact,v,idx)
Base.setindex!(ir::IRCode, @nospecialize(v), idx::SSAValue) = Core.Compiler.setindex!(ir,v,idx)
Base.setindex!(inst::Instruction, @nospecialize(v), sym::Symbol) = Core.Compiler.setindex!(inst,v,sym)
Base.getindex(compact::IncrementalCompact, idx::AnySSAValue) = Core.Compiler.getindex(compact,idx)

Base.setindex!(urs::InstructionStream, @nospecialize args...) = Core.Compiler.setindex!(urs, args...)
Base.setindex!(ir::IRCode, @nospecialize args...) = Core.Compiler.setindex!(ir, args...)
Base.getindex(ir::IRCode, @nospecialize args...) = Core.Compiler.getindex(ir, args...)

Base.IteratorSize(::Type{CC.AbsIntStackUnwind}) = Base.SizeUnknown()

# TODO: Move this to Core.Compiler
CC.block_for_inst(ir::IRCode, s::SSAValue) = block_for_inst(ir, s.id)
function CC.dominates_ssa(ir::IRCode, domtree::DomTree, x::SSAValue, y::SSAValue; dominates_after=false)
xb = block_for_inst(ir, x)
yb = block_for_inst(ir, y)
Expand Down Expand Up @@ -82,6 +64,3 @@ function replace_argument!(compact::IncrementalCompact, idx::Int, argn::Argument
compact[ssa] = urs[]
end

Base.copy(phi::PhiNode) = Core.PhiNode(copy(phi.edges), copy(phi.values))
Base.push!(bs::CC.BitSet, i::Int) = CC.push!(bs, i)
Base.push!(bs::CC.BitSetBoundedMinPrioritySet, i::Int) = CC.push!(bs, i)
Loading

0 comments on commit 24b782b

Please sign in to comment.