Fix include_self for scatter_reduce #2090
base: main
Conversation
❌ 19 Tests Failed:
if not include_self:
    if onnx_reduce == "max":
        value = onh.from_array(
            np.array([np.finfo(src.dtype.numpy()).min], dtype=src.dtype.numpy())
E.g. could you add a comment on why we needed to use np.finfo min/max for this reduction type, for future readers?
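To illustrate the point the reviewer is raising (a sketch of the assumed semantics, not the PR's code): with include_self=False the self value has to be replaced by the reduction's identity element so it cannot affect the result, and for a max reduction that identity is the dtype's smallest representable value, which is what np.finfo(...).min supplies.

```python
import numpy as np

# Sketch: with include_self=False, the self value must be replaced by the
# identity element of the reduction so it cannot affect the result.
# For "max" that identity is the smallest representable value of the dtype.
dtype = np.float32
identity = np.finfo(dtype).min  # smallest representable float32

x = np.array([3.0, -7.5, 1.2], dtype=dtype)
# max(identity, v) == v for every finite v of this dtype:
assert np.maximum(identity, x).tolist() == x.tolist()
```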
# ONNX has no include_self parameter; its default behavior is include_self=True
matcher=lambda sample: sample.kwargs.get("include_self") is False,
reason="ONNX doesn't support the include_self=False option",
)
.xfail(
    variant_name="amax",
I wonder if we should set dtypes=(torch.float16,) as well.
Tried:
value = onh.from_array(
    np.array([np.finfo(src.dtype.numpy()).max], dtype=src.dtype.numpy())
)
value = ir.tensor([np.finfo(src.dtype.numpy()).max], dtype=src.dtype)
Does this work for bfloat16 or integer types?
If ml-dtypes is used, that should work as well. I can also switch to PyTorch to find the minimum value for each type.
I just tested it; it seems it doesn't work:
>>> import ml_dtypes
>>> import numpy as np
>>> np.finfo(ml_dtypes.bfloat16)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.12/site-packages/numpy/_core/getlimits.py", line 525, in __new__
raise ValueError("data type %r not inexact" % (dtype))
ValueError: data type <class 'ml_dtypes.bfloat16'> not inexact
>>> np.finfo(np.int32)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/justinchu/anaconda3/envs/pytorch/lib/python3.12/site-packages/numpy/_core/getlimits.py", line 525, in __new__
raise ValueError("data type %r not inexact" % (dtype))
ValueError: data type <class 'numpy.int32'> not inexact
>>> np.finfo(np.float32)
finfo(resolution=1e-06, min=-3.4028235e+38, max=3.4028235e+38, dtype=float32)
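For plain NumPy dtypes, the failure above can be worked around by dispatching between np.iinfo (integers) and np.finfo (floats). A hypothetical helper sketching that dispatch (it does not cover ml_dtypes types such as bfloat16, which would need ml_dtypes.finfo):

```python
import numpy as np

def dtype_extremes(dtype):
    """Hypothetical helper: return (min, max) for an integer or float dtype.

    np.finfo only accepts inexact (floating-point) types, so integer dtypes
    must go through np.iinfo instead. ml_dtypes types (e.g. bfloat16) are
    not handled here.
    """
    dtype = np.dtype(dtype)
    info = np.iinfo(dtype) if np.issubdtype(dtype, np.integer) else np.finfo(dtype)
    return info.min, info.max

print(dtype_extremes(np.int32))    # (-2147483648, 2147483647)
print(dtype_extremes(np.float32))  # float32 min/max via np.finfo
```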
Would it be possible to do ir.tensor(np.inf, dtype=src.dtype)? Seems like it would work:
>>> ir.tensor(np.inf, dtype=ir.DataType.BFLOAT16)
Tensor<BFLOAT16,[]>(array(inf, dtype=bfloat16), name=None)
For int types I think we need some special handling. Maybe we should store max and min values on the ir.DataType class?
# whether or not it takes part in it.
# It is -inf if the reduction is max, inf for min, 0 for add, 1 for mul.
# mean is not supported.
dtype = src.dtype or cst.dtype
Check failure: Code scanning / CodeQL: Potentially uninitialized local variable
Check failure: Code scanning / lintrunner: RUFF/F821 (undefined name). See https://docs.astral.sh/ruff/rules/undefined-name