
Add fp8/fp6/fp4 types #158
Merged — 4 commits into dmlc:main on Mar 10, 2025
Conversation

@leofang (Collaborator) commented Feb 13, 2025

Close #156.

As discussed in the Python Array API meeting last week (Feb 6), I am adding additional data type enumerators based on user requests from, e.g., #156, pytorch/pytorch#146414, and NVIDIA internal teams.

cc @tqchen @rgommers @seberg @oleksandr-pavlyk @hawkinsp @jakevdp @kmaehashi @kkraus14 @ptrblck @nouiz

@tqchen (Member) commented Feb 13, 2025

Thanks! Leaving it open for one week for comments.

@tqchen (Member) commented Feb 15, 2025

Summarizing the current discussion. The main sticking point as of now is how we deal with sub-byte data types, and in particular sub-byte data types whose multiples are not byte-aligned. Two situations:

  • For F6, it is less about performance gain and more about memory saving; it is sub-byte, and its multiples do not align to a word.
  • For F4, its multiples align to a byte, and there are performance gains leveraged by accelerated instructions on the latest GPUs.

The main question is: when we have a sub-byte data type with lanes=1, what is the underlying data storage? There are two options:

  • A0: Padded, each element is padded to byte level.
  • A1: Packed, elements are packed contiguously in memory without padding.
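To make the two options concrete, here is a small Python sketch (illustrative only, not part of DLPack; the helper names are made up) that stores six 4-bit values both ways:

```python
# Illustrative sketch of the two storage options for 4-bit elements.
# Raw nibble values stand in for fp4 bit patterns; helper names are hypothetical.

def pack_a1(nibbles):
    """A1 (packed): two 4-bit elements per byte, low nibble first."""
    out = bytearray((len(nibbles) + 1) // 2)
    for i, v in enumerate(nibbles):
        out[i // 2] |= (v & 0xF) << (4 * (i % 2))
    return bytes(out)

def pad_a0(nibbles):
    """A0 (padded): each 4-bit element occupies a full byte."""
    return bytes(v & 0xF for v in nibbles)

vals = [1, 2, 3, 4, 5, 6]
assert len(pack_a1(vals)) == 3          # packed: 6 elements in 3 bytes
assert len(pad_a0(vals)) == 6           # padded: 6 elements in 6 bytes
assert pack_a1(vals) == bytes([0x21, 0x43, 0x65])
```

The packed layout halves the footprint for 4-bit data, which is why A1 is the more attractive default for the accelerated use cases.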

For performance and memory-saving reasons, A1 is likely more desirable. However, frameworks like NumPy may not be designed to support the sub-byte level, mainly because their strides are specified in units of bytes instead of units of elements (which DLPack adopts). So for certain implementations the array may indeed be A0. For example, current ml_dtypes would bring an F4 NumPy array in as padded, as @hawkinsp suggested.

During the discussion @seberg also suggested that we could potentially use the bits field to indicate padding:

  • A2: When bits=8 and we are looking at F4, for example, it indicates padding to bytes. But as @seberg noted, it is a bit of a patch on the protocol.

Alternatively, we can have a byte-padding flag in the DLManagedTensor to indicate byte-level padding:

  • A3: Have a byte_padding flag in DLManagedTensor to indicate that we would like to pad, defaulting to False (mainly because most use cases are A1, for acceleration).

The main rationale I think we should follow is two-fold:

  • Enable the latest high-performance accelerated use cases in the array API.
  • Design minimal and future-compatible specs (e.g. if we specify things one way and end up wanting another way in the future, that would be bad, but we can always incrementally add more as long as a future spec does not contradict the current one).

Given the importance of F8 and F4, I think we should expedite enabling them, and the most motivating cases are A1. I feel we could first go with A1, then A3 (or A2; both are compatible with A1). We can also phase the stages so F6 comes in a separate PR after F4, given there might be a bit more discussion there (but it can also land very quickly once we converge on the padding spec).

Let me know if that would be a reasonable approach.

@seberg (Collaborator) commented Feb 20, 2025

So, I am OK with either, but it seems like not padding is the more useful thing for now. So maybe we should just do that and then follow up, indeed.

Re-using .bits here does lead to two spellings for a packed int6 (not a big deal, but true), since you would need a new enum entry for a padded version! I think right now we also never said what an int with bits=2 is (i.e. is it padded to byte size or not?).

One other thing I am not sure about in common use: for sub-byte dtypes you also need to store a bit_offset (or re-interpret the byte_offset as bits). If I think of NumPy supporting bit-sized dtypes, slicing + view semantics mean that you need a bit offset, but likely nobody needs it for these (or even for bit-fields, e.g. in the dataframe world).

EDIT: Although... adding a flag for bit-sized/unpadded is easy enough that we could also just do it. It needs bumping the minor version, but that has almost zero cost. So e.g. if we want to (ab)use the byte_offset field as a bit_offset, then we can do so.
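To illustrate why slicing a packed sub-byte array would need a bit-level offset (a hypothetical sketch; nothing here is DLPack API): a view like arr_f4[1:5] starts mid-byte, so element lookup needs the start position in bits:

```python
# Hypothetical sketch: reading element i of a packed 4-bit view that
# starts bit_offset bits into the buffer (e.g. arr_f4[1:5] -> bit_offset = 4).
def get_nibble(buf: bytes, bit_offset: int, i: int) -> int:
    pos = bit_offset + 4 * i          # absolute bit position of element i
    byte, shift = divmod(pos, 8)      # which byte, and where inside it
    return (buf[byte] >> shift) & 0xF

buf = bytes([0x21, 0x43, 0x65])       # elements 1..6 packed low-nibble-first
view = [get_nibble(buf, 4, i) for i in range(4)]  # the arr_f4[1:5] view
assert view == [2, 3, 4, 5]
```

Note the data pointer alone cannot express this view: it would point at the byte 0x21 while the view logically begins at its high nibble.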

@leofang (Collaborator, author) commented Feb 22, 2025

Thanks all. In commit 46122b1 I made some updates to ensure we unblock the A1 use cases (packed layouts), while leaving room for padded layouts in the future.

@hawkinsp @seberg my question would be: even if DLPack allows both packed and padded layouts in the future, would NumPy's dtype registration mechanism allow ml_dtypes to register this info for NumPy to pick up during exchange through DLPack? IIUC, NumPy's current DLPack implementation does not query this information from the dtype object (likewise, it does not query the memory type from the memory allocator). It seems like another gap to me (in addition to NumPy not supporting sub-byte types) that we need to consider sooner rather than later?

@leofang (Collaborator, author) commented Feb 22, 2025

FWIW @tqchen and I discussed offline. If everyone is OK with this PR, let's merge it and cut a v1.1 release.

@tqchen (Member) commented Feb 22, 2025

Approving; let us wait another three days to give everyone a chance to chime in.

@seberg (Collaborator) commented Feb 23, 2025

Even if DLPack allows both packed and padded layouts in the future, would NumPy's dtype registration mechanism allow ml_dtypes to register this info for NumPy to pick up during exchange through DLPack?

The problem is mainly representing arbitrary sub-byte offsets and strides in the array object. Having padded vs. non-padded is in principle fully supported in NumPy.

@tqchen (Member) commented Feb 24, 2025

@seberg @hawkinsp @szha @isVoid please see if you can take another look

@seberg (Collaborator) commented Feb 24, 2025

I am OK with this; my only concern, as always, is possible future regret. There might be regret in not having two new flags rather than just one (i.e. a bit-sized indicator and a padding indicator).
But since padding might make sense more broadly, I half suspect we need that either way if it becomes a serious use case.

@leofang (Collaborator, author) commented Feb 24, 2025

How about defining DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (defaulting to false) and requiring consumers to check the flag when accessing fp6/fp4/int4/...?
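A consumer-side check could look like the following sketch; the flag name comes from the proposal above, but the bit value chosen here is an assumption for illustration only:

```python
# Hypothetical sketch of a consumer-side flag check. The actual bit
# position assigned in the DLPack header may differ; only the flag
# name is taken from the proposal.
DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED = 1 << 2  # assumed bit position

def subbyte_is_padded(flags: int) -> bool:
    """Check before interpreting fp6/fp4/int4 storage: padded (A0) or packed (A1)?"""
    return bool(flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED)

assert subbyte_is_padded(0) is False   # default: packed (A1)
assert subbyte_is_padded(DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) is True
```

Defaulting the flag to false keeps existing producers (which all export packed or byte-sized data) valid without changes.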

@szha (Member) left a review comment:

LGTM

@seberg (Collaborator) commented Feb 24, 2025

How about defining DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (default it to false)

Right, but I feel that if we add a flag, it should likely also indicate that there is a new bit_offset field (or that byte_offset has a bit-offset meaning), so that one can represent an arr_f4[1:5] whose initial point is ptr + 4 bits.

Similarly for a "padded" flag: I'm not sure that "padded to the next byte" is the only thing one would need.

So unless we are sure about indicating e.g. a byte_offset -> bit_offset flag, it seems just as well to put this in as is?

EDIT: To be clear, I do not know that there is an actual use case for bit offsets. I am not aware of any library that would ever slice bit-sized dtypes!
I.e. a bit-sized NumPy might, but that doesn't exist, and I doubt e.g. pyarrow ever creates bit-offset views.

@leofang (Collaborator, author) commented Feb 24, 2025

Thinking about it more, I am actually wondering if we really need bit_offset/byte_offset. I would assume that the only reasonable padding is to pad to 1 byte, in which case the consumer should know very well how to compute the offset based on the bit width?

@tqchen (Member) commented Feb 26, 2025

@hawkinsp brought up a good point about pack ordering, which is related to endianness, so I am explicitly stating it here to raise awareness:

For a byte B,

((B >> (i * 4)) & bit_mask)

stores the i-th element (bitwise &, not logical &&).

Note that this order is consistent with little-endian (bit order/byte order) in general: when you have a big packed data set D with shift/masking operations defined on D,

((D >> (i * bits)) & bit_mask)

stores the i-th element, and can be defined consistently whether D is one byte or multiple bytes.
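This consistency can be checked directly in a few lines of Python: extracting nibbles byte-by-byte matches extracting them from the whole buffer read as a single little-endian integer:

```python
# Demonstration that the little-endian nibble order is self-consistent:
# per-byte extraction ((B >> (i*4)) & mask) agrees with extraction from
# the whole buffer interpreted as one little-endian integer D.
buf = bytes([0x21, 0x43])            # packed 4-bit elements 1, 2, 3, 4
bit_mask, bits = 0xF, 4

# Extract from each byte B separately, low nibble first.
per_byte = [(B >> (i * 4)) & bit_mask for B in buf for i in range(2)]

# Extract from the whole buffer as one little-endian integer D.
D = int.from_bytes(buf, "little")
whole = [(D >> (i * bits)) & bit_mask for i in range(4)]

assert per_byte == whole == [1, 2, 3, 4]
```

A big-endian byte interpretation of D would break this agreement, which is why the spec pins the little-endian convention.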

@tqchen (Member) commented Feb 26, 2025

I think it is reasonable to go with DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED, defaulting to false. We can also do it in a follow-up.

Comment on lines +162 to +164
kDLFloat8_e4m3 = 8U,
kDLFloat8_e4m3b11fnuz = 9U,
kDLFloat8_e4m3fn = 10U,


One quick question: what's the spec difference between e4m3 and e4m3fn?

A collaborator replied:

https://github.com/jax-ml/ml_dtypes (linking there may make sense)

@leofang (Collaborator, author) commented Mar 7, 2025

I think it is reasonable to go with DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED and default to false. We can either also do it in a followup

Done in commit a0a1597.

@tqchen (Member) commented Mar 7, 2025

Thanks @leofang, I think all outstanding comments are addressed; going to merge before Monday if nobody objects.

@tqchen merged commit 5b474f9 into dmlc:main on Mar 10, 2025 (3 checks passed).
@tqchen (Member) commented Mar 10, 2025

Thanks everyone, this change is now merged

@leofang deleted the exotic branch on March 10, 2025 at 17:16.
Linked issue: [Feature Request] FP8 Support (#156)