Metal backend: Add Metal int4 quantization support to Parakeet #17235
Changes from all commits
```diff
@@ -17,7 +17,7 @@ def quantize_model_( # noqa: C901
     Args:
         module: The PyTorch module to quantize.
-        qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w").
+        qlinear_config: Quantization config for linear layers ("4w", "8w", "8da4w", "8da8w", "fpa4w").
         qlinear_group_size: Group size for linear quantization (default: 32).
         qlinear_packing_format: Packing format for linear layers (e.g., "tile_packed_to_4d").
         qembedding_config: Quantization config for embedding layers ("4w", "8w").
@@ -26,12 +26,41 @@ def quantize_model_( # noqa: C901
     if not qlinear_config and not qembedding_config:
         return

+    from torchao.quantization.quant_api import quantize_
+
+    # Metal (MPS) quantization uses different API
+    if qlinear_config == "fpa4w":
+        # Load MPS ops
+        import torchao.experimental.ops.mps  # noqa: F401
+        from torchao.experimental.quant_api import UIntxWeightOnlyConfig
+
+        config = UIntxWeightOnlyConfig(
+            group_size=qlinear_group_size,
+            bitwidth=4,
```
Review thread on the `UIntxWeightOnlyConfig` block:

- Contributor: Update the pin past pytorch/ao#3829, and set
- Contributor: Could be done in a follow-up PR too
- Author: yes, that's my plan, to do in a follow-up PR
- Author: here #17258
```diff
+        )
+
+        def linear_filter(m, fqn):
+            if isinstance(m, torch.nn.Linear):
+                if m.weight.shape[1] % qlinear_group_size != 0:
+                    raise ValueError(
+                        f"Metal int4 quantization requires weight dimension (K) to be multiple of group_size. "
+                        f"Layer {fqn} has weight shape {m.weight.shape} (K={m.weight.shape[1]}, group_size={qlinear_group_size})"  # noqa: E501
+                    )
+                return True
+            return False
+
+        print(
+            f" Applying {qlinear_config} linear quantization "
+            f"(group_size={qlinear_group_size})..."
+        )
+        quantize_(module, config, filter_fn=linear_filter)
+        return
+
     from torchao.quantization.granularity import PerAxis, PerGroup
     from torchao.quantization.quant_api import (
         Int4WeightOnlyConfig,
         Int8DynamicActivationIntxWeightConfig,
         IntxWeightOnlyConfig,
         quantize_,
     )

     # Quantize embedding layers first
```
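For readers without the surrounding file, here is a minimal, self-contained sketch of the torchao path this diff wires up. The toy model, its layer shapes, and the `metal_linear_filter` name are illustrative assumptions; the `UIntxWeightOnlyConfig` and `quantize_` calls mirror the added code, and running it assumes a torchao build with the experimental MPS ops available.

```python
import torch
import torchao.experimental.ops.mps  # noqa: F401  # registers the Metal int4 kernels
from torchao.experimental.quant_api import UIntxWeightOnlyConfig
from torchao.quantization.quant_api import quantize_

# Illustrative stand-in model; both K dims (256, 512) divide evenly by group_size=32.
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512),
    torch.nn.Linear(512, 64),
)

# Same config the diff builds: 4-bit weight-only, group-wise over the K dimension.
config = UIntxWeightOnlyConfig(group_size=32, bitwidth=4)

def metal_linear_filter(m, fqn):
    # Mirrors the diff's linear_filter: only Linear layers whose K dimension
    # (weight.shape[1]) is a multiple of the group size are quantized.
    return isinstance(m, torch.nn.Linear) and m.weight.shape[1] % 32 == 0

quantize_(model, config, filter_fn=metal_linear_filter)
```

Note one difference: the in-tree `linear_filter` raises a `ValueError` for a non-divisible Linear layer rather than silently skipping it, which surfaces shape problems at quantization time instead of leaving layers unexpectedly unquantized.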
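A caller-side sketch, for completeness: the import path below is an assumption (the diff does not show which module hosts `quantize_model_`), but the parameters match the docstring hunk above.

```python
import torch
# Assumed import path; the diff does not show where quantize_model_ lives.
from quantize import quantize_model_

model = torch.nn.Sequential(torch.nn.Linear(512, 512))

quantize_model_(
    model,
    qlinear_config="fpa4w",   # the new Metal int4 weight-only option added by this PR
    qlinear_group_size=32,    # docstring default; every Linear K must divide by it
)
```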