Add support for AWQ-quantized Idefics2 (#2233)

Fixes #2036.
This commit is contained in:
Daniël de Kok 2024-07-16 07:58:25 +02:00 committed by yuanwu
parent 8a223eb6ac
commit e955f7b536
2 changed files with 39 additions and 11 deletions

View File

@ -34,6 +34,7 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
TensorParallelRowLinear, TensorParallelRowLinear,
) )
from text_generation_server.utils.weights import DefaultWeightsLoader
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
class Idefics2ForConditionalGeneration(nn.Module): class Idefics2ForConditionalGeneration(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights):
super().__init__() super().__init__()
config.vision_config.quantize = config.quantize config.vision_config.quantize = None
config.vision_config.speculator = config.speculator config.vision_config.speculator = config.speculator
config.text_config.quantize = config.quantize config.text_config.quantize = config.quantize
config.text_config.speculator = config.speculator config.text_config.speculator = config.speculator
@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
name="text_model", name="text_model",
) )
self.dtype = weights.dtype self.dtype = weights.dtype
# The vision and connector models are not quantized.
with weights.use_loader(DefaultWeightsLoader()):
self.vision_model = Idefics2VisionTransformer( self.vision_model = Idefics2VisionTransformer(
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model", prefix=(
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
),
config=vision_config, config=vision_config,
weights=weights, weights=weights,
) )
quantize = config.quantize
try:
config.quantize = None
self.connector = Idefics2Connector( self.connector = Idefics2Connector(
prefix=f"{prefix}.model.connector" if prefix else "model.connector", prefix=f"{prefix}.model.connector" if prefix else "model.connector",
config=config, config=config,
weights=weights, weights=weights,
) )
finally:
config.quantize = quantize
self.config = config self.config = config
self.image_seq_len = config.perceiver_config.resampler_n_latents self.image_seq_len = config.perceiver_config.resampler_n_latents
self.image_token_id = config.image_token_id self.image_token_id = config.image_token_id

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from safetensors import safe_open from safetensors import safe_open
@ -306,6 +307,20 @@ class Weights:
def get_weights_row(self, prefix: str): def get_weights_row(self, prefix: str):
return self.weights_loader.get_weights_row(self, prefix) return self.weights_loader.get_weights_row(self, prefix)
@contextmanager
def use_loader(self, weights_loader: WeightsLoader):
"""
This method is a context manager that can be used to use `Weights` with
a different loader for the duration of the context.
"""
old_loader = self.weights_loader
self.weights_loader = weights_loader
try:
yield
finally:
self.weights_loader = old_loader
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]: def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
""" """