mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-22 23:42:06 +00:00
parent
8a223eb6ac
commit
e955f7b536
@ -34,6 +34,7 @@ from text_generation_server.layers import (
|
|||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
TensorParallelRowLinear,
|
TensorParallelRowLinear,
|
||||||
)
|
)
|
||||||
|
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
|
|||||||
class Idefics2ForConditionalGeneration(nn.Module):
|
class Idefics2ForConditionalGeneration(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config.vision_config.quantize = config.quantize
|
config.vision_config.quantize = None
|
||||||
config.vision_config.speculator = config.speculator
|
config.vision_config.speculator = config.speculator
|
||||||
config.text_config.quantize = config.quantize
|
config.text_config.quantize = config.quantize
|
||||||
config.text_config.speculator = config.speculator
|
config.text_config.speculator = config.speculator
|
||||||
@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||||||
name="text_model",
|
name="text_model",
|
||||||
)
|
)
|
||||||
self.dtype = weights.dtype
|
self.dtype = weights.dtype
|
||||||
|
|
||||||
|
# The vision and connector models are not quantized.
|
||||||
|
with weights.use_loader(DefaultWeightsLoader()):
|
||||||
self.vision_model = Idefics2VisionTransformer(
|
self.vision_model = Idefics2VisionTransformer(
|
||||||
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
|
prefix=(
|
||||||
|
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
|
||||||
|
),
|
||||||
config=vision_config,
|
config=vision_config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
quantize = config.quantize
|
||||||
|
try:
|
||||||
|
config.quantize = None
|
||||||
self.connector = Idefics2Connector(
|
self.connector = Idefics2Connector(
|
||||||
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
config.quantize = quantize
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
||||||
self.image_token_id = config.image_token_id
|
self.image_token_id = config.image_token_id
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
@ -306,6 +307,20 @@ class Weights:
|
|||||||
def get_weights_row(self, prefix: str):
|
def get_weights_row(self, prefix: str):
|
||||||
return self.weights_loader.get_weights_row(self, prefix)
|
return self.weights_loader.get_weights_row(self, prefix)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def use_loader(self, weights_loader: WeightsLoader):
|
||||||
|
"""
|
||||||
|
This method is a context manager that can be used to use `Weights` with
|
||||||
|
a different loader for the duration of the context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
old_loader = self.weights_loader
|
||||||
|
self.weights_loader = weights_loader
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self.weights_loader = old_loader
|
||||||
|
|
||||||
|
|
||||||
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user