Add fp8/fp6/fp4 types #158
Conversation
Thanks! Leaving it open for one week for comments.
Summarizing the current discussion. The main sticking point as of now is how we deal with sub-byte data types, and/or sub-byte data types whose total width is not byte-aligned. Two situations arise: either bits × lanes works out to a whole number of bytes, or it does not (most notably when lanes = 1).
The main question is: when we have a sub-byte data type with lanes = 1, what is the underlying data storage? There are two options:
- A0: each element is padded out to a full byte.
- A1: elements are packed back-to-back, bit-contiguously.
For performance and memory-saving reasons, A1 is likely more desirable. However, frameworks like NumPy may not be designed to support sub-byte storage, mainly because their strides are specified in units of bytes rather than in units of elements (which DLPack adopts). So for certain implementations the array may indeed be A0; for example, current ml_dtypes would produce an F4 NumPy array as padded, as @hawkinsp suggested. During the discussion @seberg also suggested that we could potentially reuse an existing field for this.
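To make the two layouts concrete, here is a minimal C sketch of the resulting storage sizes (the helper names are illustrative, not part of dlpack.h):

```c
#include <stdint.h>
#include <stddef.h>

/* A1 (packed): elements are stored back-to-back; only the whole buffer
 * is rounded up to a byte, so 5 float4 elements occupy
 * ceil(5 * 4 / 8) = 3 bytes. */
static inline size_t a1_packed_nbytes(int64_t numel, uint8_t bits) {
  return (size_t)(((uint64_t)numel * bits + 7) / 8);
}

/* A0 (padded): each element is padded out to the next byte boundary,
 * so 5 float4 elements occupy 5 bytes -- the layout a byte-strided
 * framework like NumPy can represent today. */
static inline size_t a0_padded_nbytes(int64_t numel, uint8_t bits) {
  return (size_t)numel * (size_t)((bits + 7) / 8);
}
```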
Alternatively, we can have a byte padding flag in the DLManagedTensor to indicate byte-level padding.
The rationale I think we should adopt is twofold: F8 and F4 are important enough that we should expedite enabling them, and the most motivating cases are for A1. So I feel we could first go with A1, then A3 (or A2; both are compatible with A1). We can also phase the stages so that F6 comes in a separate PR after F4, given there might be a bit more discussion there (though it could also land very quickly once we converge on the padding spec). Let me know if that would be a reasonable approach.
So, I am OK with either; re-using the existing fields seems fine. One other thing I am not sure about for common use: for sub-byte dtypes you would also need to store a bit-level offset. EDIT: Although... adding a flag for bit-sized/unpadded is easy enough that we could also just do it. It needs bumping the minor version, but that has almost zero cost, so we could even (ab)use an existing field if we wanted to.
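For reference, DLManagedTensorVersioned already carries a flags bitmask, so such indicators would just be new bits in it. A sketch with illustrative names (only the read-only flag exists as of v1.0; the other two are hypothetical):

```c
/* Existing flag in dlpack.h (v1.0): */
#define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
/* Hypothetical additions sketching the idea above (names illustrative): */
#define DLPACK_FLAG_BITMASK_IS_BIT_SIZED   (1UL << 1UL) /* unpadded, bit-sized */
#define DLPACK_FLAG_BITMASK_IS_BYTE_PADDED (1UL << 2UL) /* padded to bytes */
```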
Thanks all. In commit 46122b1 I made some updates to ensure we unblock the A1 use cases (packed layouts) while leaving room for padded layouts in the future. @hawkinsp @seberg my question would be: even if DLPack allows both packed and padded layouts in the future, would NumPy's dtype registration mechanism allow ml_dtypes to register this info for NumPy to pick up during exchange through DLPack? IIUC NumPy's current DLPack implementation does not query this information from the dtype object (likewise, it does not query the memory type from the memory allocator). That seems like another gap (in addition to NumPy not supporting sub-byte types) that we need to consider sooner rather than later.
FWIW @tqchen and I discussed offline. If everyone is OK with this PR, let's merge it and cut a v1.1 release.
Approving; let us wait another three days to give everyone a chance to chime in.
The problem is mainly representing arbitrary sub-byte offsets and strides in the array object. Padded vs. non-padded is in principle fully supported in NumPy.
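A small sketch of why byte strides fall short: with DLPack's element-counted strides, an element's exact location in a packed sub-byte array is a bit offset, which byte-counted stride models like NumPy's cannot land on (the helper is illustrative, not part of any API):

```c
#include <stdint.h>

/* Bit offset of the element at `index` in a packed sub-byte array with
 * element-counted strides; for bits < 8 this is generally not a whole
 * number of bytes. */
static inline uint64_t element_bit_offset(const int64_t *strides,
                                          const int64_t *index,
                                          int ndim, uint8_t bits) {
  uint64_t elem = 0;
  for (int d = 0; d < ndim; ++d)
    elem += (uint64_t)index[d] * (uint64_t)strides[d];
  return elem * bits; /* byte offset = bit offset / 8 may be fractional */
}
```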
I am OK with this; my only concern, as always, is possible future regret. We might come to regret not having two new flags rather than just one (i.e., a bit-sized indicator and a padding indicator).
How about defining a byte-padding flag explicitly?
LGTM
Right, but I feel that if we add a flag, it should likely also indicate that there is a new bit-offset field. Similarly for a "padded" flag: I am not sure that "padded to the next byte" is the only thing one would need. So I would hold off unless we are sure about what exactly such flags should indicate. EDIT: To be clear, I do not know of an actual use-case for bit-offsets. I am not aware of any library that would ever slice bit-sized dtypes!
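Purely as a thought experiment (no such struct exists in dlpack.h), a flag signaling bit-addressed tensors would likely have to advertise an extra field along these lines:

```c
#include <stdint.h>
#include <dlpack/dlpack.h>

/* Hypothetical extension: alongside DLTensor's byte_offset, a bit-level
 * offset (0..7) into the first byte would be needed to slice packed
 * sub-byte data -- the use-case questioned in the EDIT above. */
typedef struct {
  DLTensor dl_tensor;
  uint8_t bit_offset; /* hypothetical; not part of the DLPack ABI */
} DLBitAddressedTensor;
```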
Thinking about it more, I am actually wondering if we really need the flag at all.
@hawkinsp brought up a good point about pack ordering, which is related to endianness, so stating it explicitly here for awareness: elements are packed starting from the least-significant bits.
Note that this order is consistent with little-endian conventions (both bit order and byte order) in general: when you have a packed data word D, the shift/masking operations defined on D can be defined consistently whether D is one byte or multiple bytes.
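As an illustration of that consistency, extracting the i-th 4-bit element from a packed buffer with LSB-first ordering uses the same shift/mask logic whether you view the data a byte or a word at a time (a sketch, not part of any DLPack API):

```c
#include <stdint.h>

/* LSB-first packing: element 0 lives in the low nibble of byte 0,
 * element 1 in the high nibble, element 2 in the low nibble of byte 1,
 * and so on. */
static inline uint8_t get_fp4_bits(const uint8_t *data, int64_t i) {
  uint8_t byte = data[i >> 1];
  return (i & 1) ? (uint8_t)(byte >> 4) : (uint8_t)(byte & 0x0F);
}
```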
I think it is reasonable to go with this little-endian pack ordering.
```c
kDLFloat8_e4m3 = 8U,
kDLFloat8_e4m3b11fnuz = 9U,
kDLFloat8_e4m3fn = 10U,
```
One quick question: what's the spec difference between e4m3 and e4m3fn?
https://github.com/jax-ml/ml_dtypes (linking there may make sense). In short, the fn suffix marks the finite-only variants: e4m3fn supports only finite values plus NaN (no infinities), whereas plain e4m3 follows IEEE-style conventions with infinity encodings.
Done in commit a0a1597.
Thanks @leofang, I think all outstanding comments are addressed; going to merge before Monday if nobody objects.
Thanks everyone, this change is now merged.
Close #156.
As discussed in the Python Array API meeting last week (Feb 6), I am adding additional data type enumerators based on user requests from, e.g., #156, pytorch/pytorch#146414, and NVIDIA internal teams.
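Assuming the enumerator names this PR adds (and DLPack's existing DLDataType fields), a producer would describe a packed float4 tensor roughly like this minimal sketch:

```c
#include <dlpack/dlpack.h>

/* A float4_e2m1fn element: 4 bits wide, one lane. Under the packed (A1)
 * convention discussed above, consecutive elements share bytes. */
DLDataType dt = { kDLFloat4_e2m1fn, /* code  */
                  4,                /* bits  */
                  1 };              /* lanes */
```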
cc @tqchen @rgommers @seberg @oleksandr-pavlyk @hawkinsp @jakevdp @kmaehashi @kkraus14 @ptrblck @nouiz