-
Notifications
You must be signed in to change notification settings - Fork 238
[CK_TILE] Fixing Type Conversions in PassThroughPack8 #2769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
3c9798c
to
6d7c174
Compare
Most of the work in this branch went into fixing the type conversions in I did not modify the existing conversion from pkint4 to fp16 since it passes validation, but it is probably worthwhile to compare its performance to a lookup-table based one. |
2ded46c
to
d6b990c
Compare
I should mention that bf8 x pk_i4 was not part of the ticket but adding it was straightforward using the same approach as with fp8 x pk_i4. |
e904753
to
d960cdb
Compare
…_element_wise_operation.hpp
…up based converters
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR fixes type conversion issues in the PassThroughPack8 implementation by correcting bit shifting operations, implementing constexpr lookup tables for more reliable data type conversions, and updating test files to properly handle boolean return values from the run_gemm_combinations function.
- Updates the return type of run_gemm_combinations from int to bool and fixes return value handling in test files
- Fixes incorrect bit shifting in PassThroughPack8::operator() for bf16x8_t conversion
- Implements constexpr lookup table alternatives for fp8, bf8, and bf16 conversions to improve reliability
Reviewed Changes
Copilot reviewed 16 out of 16 changed files in this pull request and generated 4 comments.
Show a summary per file
File | Description |
---|---|
test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc | Changes return type and logic for run_gemm_combinations |
test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc | Changes return type and logic for run_gemm_combinations |
test/ck_tile/gemm/*.cpp | Updates main functions to handle boolean return values properly |
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp | Fixes bit shift bug and adds constexpr lookup table implementations |
include/ck_tile/core/numeric/bfloat16.hpp | Replaces union-based bit casting with constexpr bit_cast |
test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp | Adds new type configuration specializations |
test/ck_tile/elementwise/CMakeLists.txt | Removes incorrect conditional compilation logic |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
{ | ||
y.lo = i4_to_bhalf4(bit_cast<int>(x)); | ||
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16); | ||
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bit shift should be 16, not 8. The original code had >> 16
which correctly extracts the upper 16 bits for the high half of the bf16x8_t. Shifting by 8 will cause incorrect data extraction.
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 8); | |
y.hi = i4_to_bhalf4(bit_cast<int>(x) >> 16); |
Copilot uses AI. Check for mistakes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The functionality of the new i4_to_bhalf4
was modeled after the existing i4_to_half4
function. These make use of the same layout and data extraction for pk_int4_t
, and therefore the new operator
for bf16x8_t
shifts by the same amount as the existing operator
for fp16x8_t
(see lines #355 to #359).
Validation for both functions was added in test_gemm_pipeline_universal_bf16.cpp
and test_gemm_pipeline_universal_fp16.cpp
.
bf16_lookup_table[(q >> 16) & 0xf], | ||
bf16_lookup_table[(q >> 4) & 0xf], | ||
bf16_lookup_table[(q >> 20) & 0xf]}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bit extraction pattern is inconsistent with the expected packed int4 layout. The shifts should be 0, 4, 8, 12 to extract consecutive 4-bit values, not 0, 16, 4, 20. This will result in incorrect value extraction from the packed integer.
bf16_lookup_table[(q >> 16) & 0xf], | |
bf16_lookup_table[(q >> 4) & 0xf], | |
bf16_lookup_table[(q >> 20) & 0xf]}; | |
bf16_lookup_table[(q >> 4) & 0xf], | |
bf16_lookup_table[(q >> 8) & 0xf], | |
bf16_lookup_table[(q >> 12) & 0xf]}; |
Copilot uses AI. Check for mistakes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fp8_lookup_table[(q >> 16) & 0xf], | ||
fp8_lookup_table[(q >> 4) & 0xf], | ||
fp8_lookup_table[(q >> 20) & 0xf]}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bit extraction pattern is inconsistent with the expected packed int4 layout. The shifts should be 0, 4, 8, 12 to extract consecutive 4-bit values, not 0, 16, 4, 20. This will result in incorrect value extraction from the packed integer.
fp8_lookup_table[(q >> 16) & 0xf], | |
fp8_lookup_table[(q >> 4) & 0xf], | |
fp8_lookup_table[(q >> 20) & 0xf]}; | |
fp8_lookup_table[(q >> 4) & 0xf], | |
fp8_lookup_table[(q >> 8) & 0xf], | |
fp8_lookup_table[(q >> 12) & 0xf]}; |
Copilot uses AI. Check for mistakes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The situation here is identical to the above.
bf8_lookup_table[(q >> 16) & 0xf], | ||
bf8_lookup_table[(q >> 4) & 0xf], | ||
bf8_lookup_table[(q >> 20) & 0xf]}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The bit extraction pattern is inconsistent with the expected packed int4 layout. The shifts should be 0, 4, 8, 12 to extract consecutive 4-bit values, not 0, 16, 4, 20. This will result in incorrect value extraction from the packed integer.
bf8_lookup_table[(q >> 16) & 0xf], | |
bf8_lookup_table[(q >> 4) & 0xf], | |
bf8_lookup_table[(q >> 20) & 0xf]}; | |
bf8_lookup_table[(q >> 4) & 0xf], | |
bf8_lookup_table[(q >> 8) & 0xf], | |
bf8_lookup_table[(q >> 12) & 0xf]}; |
Copilot uses AI. Check for mistakes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another identical instance of the use of the existing layout for pk_int4_t
.
…nstexpr compliant
… lookup table for use in conversions from pk_int4 to bf16
deb46a6
to
28c69dc
Compare
Proposed changes
Please describe the motivation behind the pull request, whether it enables a new feature or fixes a bug. If there are associated pull requests or issues, please link them to the pull request.
Checklist
Please put an
x
into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-format
on all changed filesDiscussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered