
Unexpected output type when multiplying reshaped ConcreteRArrays #778

Open
alexiscltrn opened this issue Feb 20, 2025 · 2 comments

@alexiscltrn

Description

When reshaping a ConcreteRArray and then broadcasting an elementwise multiplication over the result, the output is a plain Array rather than a ConcreteRArray. Ideally such operations would preserve the ConcreteRArray type so that further operations stay consistent.

This becomes a real problem inside Lux neural networks. When matrix operations run on reshaped ConcreteRArray objects (Lux.Utils.make_abstract_matrix is involved internally), the mix of ConcreteRArray and Array triggers an ArgumentError because the objects appear to be on different devices (ReactantDevice and CPUDevice), which aborts execution of the model.

Steps to reproduce

using Reactant, Lux, Random

Reactant.set_default_backend("cpu")
const xdev = reactant_device()
x = rand(Float32, 2, 10, 3) |> xdev  # ConcreteRArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}
reshape(x, 2, :) .* reshape(x, 2, :)  # returns Matrix{Float32} instead of a ConcreteRArray
model = Chain(Parallel(.*, Dense(2, 16, tanh), Dense(2, 16, tanh)), Dense(16, 1))
ps, st = Lux.setup(Random.default_rng(), model) |> xdev
model(x, ps, st)

This triggers an error:

ERROR: ArgumentError: Objects are on devices with different types: ReactantDevice and CPUDevice.
Stacktrace:
  [1] combine_devices(T1::Type{ReactantDevice}, T2::Type{CPUDevice})
    @ MLDataDevices.Internal ~/.julia/packages/MLDataDevices/uhCbD/src/internal.jl:125
  [2] macro expansion
    @ ~/.julia/packages/MLDataDevices/uhCbD/src/internal.jl:205 [inlined]
  [3] unrolled_mapreduce
    @ ~/.julia/packages/MLDataDevices/uhCbD/src/internal.jl:192 [inlined]
  [4] unrolled_mapreduce(f::typeof(get_device_type), op::typeof(MLDataDevices.Internal.combine_devices), itr::Tuple{ConcreteRArray{…}, Matrix{…}, ConcreteRArray{…}})
    @ MLDataDevices.Internal ~/.julia/packages/MLDataDevices/uhCbD/src/internal.jl:183
  [5] get_device_type(x::Tuple{ConcreteRArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{…}}, Matrix{Float32}, ConcreteRArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{…}}})
    @ MLDataDevices.Internal ~/.julia/packages/MLDataDevices/uhCbD/src/internal.jl:160
  [6] get_device_type(x::Tuple{ConcreteRArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{…}}, Matrix{Float32}, ConcreteRArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{…}}})
    @ MLDataDevices ~/.julia/packages/MLDataDevices/uhCbD/src/public.jl:370
  [7] internal_operation_mode(xs::Tuple{ConcreteRArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{…}}, Matrix{Float32}, ConcreteRArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{…}}})
    @ LuxLib ~/.julia/packages/LuxLib/kH9PB/src/traits.jl:217
  [8] select_fastest_activation(::typeof(identity), ::ConcreteRArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{…}}, ::Matrix{Float32}, ::ConcreteRArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{…}})
    @ LuxLib.Impl ~/.julia/packages/LuxLib/kH9PB/src/impl/activation.jl:129
  [9] fused_dense_bias_activation(σ::typeof(identity), weight::ConcreteRArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{…}}, x::Matrix{Float32}, b::ConcreteRArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{…}})
    @ LuxLib.API ~/.julia/packages/LuxLib/kH9PB/src/api/dense.jl:35
 [10] (::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True})(x::Array{Float32, 3}, ps::@NamedTuple{weight::ConcreteRArray{…}, bias::ConcreteRArray{…}}, st::@NamedTuple{})
    @ Lux ~/.julia/packages/Lux/TbS8R/src/layers/basic.jl:343
 [11] apply
    @ ~/.julia/packages/LuxCore/8mVob/src/LuxCore.jl:155 [inlined]
 [12] macro expansion
    @ ~/.julia/packages/Lux/TbS8R/src/layers/containers.jl:0 [inlined]
 [13] applychain(layers::@NamedTuple{}, x::ConcreteRArray{…}, ps::@NamedTuple{}, st::@NamedTuple{})
    @ Lux ~/.julia/packages/Lux/TbS8R/src/layers/containers.jl:482
 [14] (::Chain{@NamedTuple{…}, Nothing})(x::ConcreteRArray{Float32, 3, 1, Reactant.Sharding.ShardInfo{…}}, ps::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, st::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}})
    @ Lux ~/.julia/packages/Lux/TbS8R/src/layers/containers.jl:480
 [15] top-level scope
    @ REPL[47]:1
Some type information was truncated. Use `show(err)` to see complete types.
Package versions:

  [b2108857] Lux v1.7.0
  [3c362404] Reactant v0.2.31
@avik-pal
Collaborator

model(x, ps, st) is mostly going to error or lead to slow performance. You need to do @jit model(x, ps, st) to run it with proper compilation. See https://lux.csail.mit.edu/stable/manual/compiling_lux_models
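
For reference, a minimal sketch of this fix, reusing model, x, ps, and st from the reproduction above (@jit and @compile are exported by Reactant):

# Compile and run in one step:
y, st_ = @jit model(x, ps, st)

# Or compile once and reuse the compiled function:
compiled_model = @compile model(x, ps, st)
y, st_ = compiled_model(x, ps, st)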

There are a few other issues here (though above is the solution to your actual problem)

  1. Make sure broadcasting generates a ConcreteRArray (outside of @compile); see the sketch after this list
  2. Relax MLDataDevices.combine_devices for ReactantDevice and CPUDevice (LuxDL/Lux.jl#1244)
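
A hedged sketch of the behavior item 1 asks for, reusing x from the reproduction (the assertion fails today because the broadcast returns a Matrix{Float32}):

y = reshape(x, 2, :) .* reshape(x, 2, :)
@assert y isa Reactant.ConcreteRArray  # desired: the result stays on-device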

@wsmoses
Member

wsmoses commented Feb 20, 2025

That said, we can probably also define reshape(ConcreteRArray) -> another ConcreteRArray that shares the same XLA buffer.

Though ++ all of @avik-pal's points
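
Purely to illustrate the buffer-sharing idea, a hypothetical sketch (the constructor signature and the data field below are assumptions, not Reactant's actual internals):

# Illustrative only: reshape returns a new ConcreteRArray wrapper whose
# shape metadata changes while the underlying XLA buffer is reused.
function Base.reshape(a::ConcreteRArray{T}, dims::Dims{N}) where {T,N}
    prod(dims) == length(a) ||
        throw(DimensionMismatch("new dimensions $dims must be consistent with array length $(length(a))"))
    return ConcreteRArray{T,N}(a.data, dims)  # hypothetical constructor reusing a.data
end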
