From 150c7a1d70cd40d1dc73d463ee641e199732e968 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Belmant?= Date: Tue, 25 Feb 2025 18:57:45 +0100 Subject: [PATCH] Use `_return_type` from Base, not from `Core.Compiler` (#819) --- Project.toml | 2 +- src/rulesets/Base/broadcast.jl | 4 ++-- src/rulesets/Base/mapreduce.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index f1ac45605..9d47625a1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.72.2" +version = "1.72.3" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index 4fb83c4e7..d1610ce24 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -48,7 +48,7 @@ end # Path 2: This is roughly what `derivatives_given_output` is designed for, should be fast. function may_bc_derivatives(::Type{T}, f::F, args::Vararg{Any,N}) where {T,F,N} - TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...}) + TΔ = Core.Compiler.return_type(derivatives_given_output, Tuple{T, F, map(_eltype, args)...}) return isconcretetype(TΔ) end @@ -98,7 +98,7 @@ function may_bc_forwards(cfg::C, f::F, arg) where {C,F} TA = _eltype(arg) TA <: Real || return false cfg isa RuleConfig{>:HasForwardsMode} && return true # allows frule_via_ad - TF = Core.Compiler._return_type(frule, Tuple{C, Tuple{NoTangent, TA}, F, TA}) + TF = Core.Compiler.return_type(frule, Tuple{C, Tuple{NoTangent, TA}, F, TA}) return isconcretetype(TF) && TF <: Tuple end diff --git a/src/rulesets/Base/mapreduce.jl b/src/rulesets/Base/mapreduce.jl index f56ffa607..999cfbb73 100644 --- a/src/rulesets/Base/mapreduce.jl +++ b/src/rulesets/Base/mapreduce.jl @@ -139,7 +139,7 @@ Works by seeing if the result of `derivatives_given_output(nothing, f, x)` can b The method of `derivatives_given_output` usually comes from `@scalar_rule`. """ function _uses_input_only(f::F, ::Type{xT}) where {F,xT} - gT = Core.Compiler._return_type(derivatives_given_output, Tuple{Nothing, F, xT}) + gT = Core.Compiler.return_type(derivatives_given_output, Tuple{Nothing, F, xT}) # Here we must check `<: Number`, to avoid this, the one rule which can return the `nothing`: # ChainRules.derivatives_given_output("anything", exp, 1) == (("anything",),) return isconcretetype(gT) && gT <: Tuple{Tuple{Number}}