mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-04-21 14:52:20 +00:00
Improve vlm support (add idefics3 support) (#2437)
* feat: expand vlm support and add image token logic and tests
* fix: avoid unused perceiver config
* feat: integrate image tokens into inputs embeds
* feat: add simple idefics3 test
* feat: update docs, image token logic and weight names
* fix: improve image processing
* feat: improve prefix for idefics3
* fix: bump idefics3 tests and snapshots
* fix: improve text model loading
* feat: consolidate changes with existing vlms and add support and test for smolvlm
* fix: create new idefic3 file, simplify logic and adjust llama weight loading
* fix: lint with ruff
* fix: clean up idefics 3 and improve prefix handling
* fix: improve typing
* fix: improve prompt_split_image with ref to original impl
* fix: adjust ruff lints and small refactors
* fix: adjust FlashLlamaModel prefix logic
This commit is contained in: parent a9c7d2e3b6, commit da5ab46705
@@ -5,6 +5,7 @@ Text Generation Inference enables serving optimized models. The following sectio
 - [Deepseek V2](https://huggingface.co/deepseek-ai/DeepSeek-V2)
 - [Idefics 2](https://huggingface.co/HuggingFaceM4/idefics2-8b) (Multimodal)
+- [Idefics 3](https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3) (Multimodal)
 - [Llava Next (1.6)](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (Multimodal)
 - [Llama](https://huggingface.co/collections/meta-llama/llama-31-669fc079a0c406a149a5738f)
 - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
@@ -354,6 +354,7 @@ def launcher(event_loop):
         kv_cache_dtype: Optional[str] = None,
         revision: Optional[str] = None,
         max_input_length: Optional[int] = None,
+        max_input_tokens: Optional[int] = None,
         max_batch_prefill_tokens: Optional[int] = None,
         max_total_tokens: Optional[int] = None,
         lora_adapters: Optional[List[str]] = None,
@@ -402,6 +403,9 @@ def launcher(event_loop):
         if max_input_length:
             args.append("--max-input-length")
             args.append(str(max_input_length))
+        if max_input_tokens:
+            args.append("--max-input-tokens")
+            args.append(str(max_input_tokens))
         if max_batch_prefill_tokens:
             args.append("--max-batch-prefill-tokens")
             args.append(str(max_batch_prefill_tokens))
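For context, integration tests consume this fixture when spawning the server. A minimal sketch of a test module passing the new flag (the model id and the 4096 value are illustrative, not taken from this PR):

```python
import pytest


@pytest.fixture(scope="module")
def small_vlm_handle(launcher):
    # Hypothetical fixture: forwards --max-input-tokens through the
    # launcher exactly like the existing max_input_length plumbing above.
    with launcher(
        "HuggingFaceM4/Idefics3-8B-Llama3",
        max_input_tokens=4096,
    ) as handle:
        yield handle
```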
New file (67 lines):
@@ -0,0 +1,67 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 9,
    "prefill": [],
    "seed": null,
    "tokens": [
      {
        "id": 2684,
        "logprob": -0.24902344,
        "special": false,
        "text": " There"
      },
      {
        "id": 374,
        "logprob": -0.0703125,
        "special": false,
        "text": " is"
      },
      {
        "id": 264,
        "logprob": -0.23535156,
        "special": false,
        "text": " a"
      },
      {
        "id": 35372,
        "logprob": -0.125,
        "special": false,
        "text": " statue"
      },
      {
        "id": 304,
        "logprob": -0.30273438,
        "special": false,
        "text": " in"
      },
      {
        "id": 279,
        "logprob": -0.20507812,
        "special": false,
        "text": " the"
      },
      {
        "id": 2217,
        "logprob": -0.076171875,
        "special": false,
        "text": " image"
      },
      {
        "id": 13,
        "logprob": -0.053710938,
        "special": false,
        "text": "."
      },
      {
        "id": 128258,
        "logprob": -0.011352539,
        "special": true,
        "text": "<end_of_utterance>"
      }
    ],
    "top_tokens": null
  },
  "generated_text": " There is a statue in the image."
}
New file (61 lines):
@@ -0,0 +1,61 @@
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 8,
    "prefill": [],
    "seed": null,
    "tokens": [
      {
        "id": 330,
        "logprob": -0.118652344,
        "special": false,
        "text": " A"
      },
      {
        "id": 11426,
        "logprob": -0.28320312,
        "special": false,
        "text": " bee"
      },
      {
        "id": 335,
        "logprob": -0.95703125,
        "special": false,
        "text": " on"
      },
      {
        "id": 253,
        "logprob": -0.06982422,
        "special": false,
        "text": " a"
      },
      {
        "id": 11986,
        "logprob": -0.49414062,
        "special": false,
        "text": " pink"
      },
      {
        "id": 8525,
        "logprob": -0.07763672,
        "special": false,
        "text": " flower"
      },
      {
        "id": 30,
        "logprob": -1.0703125,
        "special": false,
        "text": "."
      },
      {
        "id": 49154,
        "logprob": -0.092285156,
        "special": true,
        "text": "<end_of_utterance>"
      }
    ],
    "top_tokens": null
  },
  "generated_text": " A bee on a pink flower."
}
integration-tests/models/test_idefics3.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import pytest


@pytest.fixture(scope="module")
def flash_idefics3_next_handle(launcher):
    with launcher("HuggingFaceM4/Idefics3-8B-Llama3") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_idefics3_next(flash_idefics3_next_handle):
    await flash_idefics3_next_handle.health(300)
    return flash_idefics3_next_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics3_next_simple_url(flash_idefics3_next, response_snapshot):
    ny_skyline = "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
    query = "What is in this image?"
    response = await flash_idefics3_next.generate(
        f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
        max_new_tokens=10,
        seed=1337,
    )
    print(response)
    assert (
        response.generated_text == " There is a statue in the image."
    ), f"{repr(response.generated_text)}"
    assert response.details.generated_tokens == 9
    assert response == response_snapshot

(The image viewer stripped the markdown image link from the prompt in the rendered page; it is restored above as `![]({ny_skyline})`, which is how the router expects images to be embedded.)
integration-tests/models/test_smolvlm.py (new file, 31 lines)
@@ -0,0 +1,31 @@
import pytest


@pytest.fixture(scope="module")
def flash_smolvlm_next_handle(launcher):
    with launcher("HuggingFaceTB/SmolVLM-Instruct") as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_smolvlm_next(flash_smolvlm_next_handle):
    await flash_smolvlm_next_handle.health(300)
    return flash_smolvlm_next_handle.client


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_smolvlm_next_simple_url(flash_smolvlm_next, response_snapshot):
    ny_skyline = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
    query = "What is in this image?"
    response = await flash_smolvlm_next.generate(
        f"<|begin_of_text|><|begin_of_text|>User:![]({ny_skyline}){query}<end_of_utterance>\nAssistant:",
        max_new_tokens=10,
        seed=1337,
    )
    print(response)
    assert (
        response.generated_text == " A bee on a pink flower."
    ), f"{repr(response.generated_text)}"
    assert response.details.generated_tokens == 8
    assert response == response_snapshot

(As in the idefics3 test, the markdown image link `![]({ny_skyline})` was stripped by the page renderer and is restored here.)
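Outside the pytest harness, the same request can be sent to a running server over the standard `/generate` route. A rough sketch (host, port, and timeout are placeholders; the payload shape follows the usual text-generation-inference HTTP API):

```python
import requests

# Placeholder endpoint for a locally launched text-generation-inference server.
TGI_URL = "http://localhost:8080/generate"

image_url = "https://huggingface.co/spaces/merve/chameleon-7b/resolve/main/bee.jpg"
# Images are passed inline as markdown image links; the router's regex in
# prepare_input (see the validation.rs hunk below) replaces them with the
# model's expanded image tokens before tokenization.
prompt = (
    "<|begin_of_text|><|begin_of_text|>User:"
    f"![]({image_url})What is in this image?<end_of_utterance>\nAssistant:"
)

resp = requests.post(
    TGI_URL,
    json={"inputs": prompt, "parameters": {"max_new_tokens": 10, "seed": 1337}},
    timeout=300,
)
print(resp.json()["generated_text"])
```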
@@ -110,6 +110,24 @@ pub struct ClipVisionModel {
     patch_size: usize,
 }
 
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Idefics3 {}
+
+impl Idefics3 {
+    pub fn get_max_longest_edge(&self) -> usize {
+        364
+    }
+
+    pub fn get_number_of_features(&self) -> usize {
+        169
+    }
+
+    pub fn get_max_longest_edge_for_image_resize(&self) -> usize {
+        1456
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(rename_all = "snake_case")]
 pub struct Idefics2 {}
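The hard-coded 364/169/1456 values mirror the Idefics3 preprocessing defaults. Assuming the released checkpoint's vision settings (364-pixel tiles, 14-pixel patches, pixel-shuffle scale factor 2 — values not restated in this hunk), the 169 features per tile can be recomputed with the same formula the Python server uses in `image_text_replacement` further down:

```python
# Assumed Idefics3 vision configuration (not part of this diff):
image_size = 364     # one tile, matches get_max_longest_edge()
patch_size = 14
scale_factor = 2     # pixel-shuffle factor used by the connector

patches_per_side = image_size // patch_size               # 26
patches_per_tile = patches_per_side**2                    # 676
features_per_tile = patches_per_tile // scale_factor**2   # 169, matches get_number_of_features()
print(patches_per_side, patches_per_tile, features_per_tile)
```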
@@ -178,6 +196,7 @@ pub enum Config {
     Idefics,
     Mllama,
     Idefics2(Idefics2),
+    Idefics3(Idefics3),
     Ssm,
     GptBigcode,
     Granite,
@@ -170,6 +170,7 @@ impl TokenizerConfigToken {
 #[serde(tag = "processor_class")]
 pub enum HubPreprocessorConfig {
     Idefics2Processor(Idefics2Preprocessor),
+    Idefics3Processor(Idefics2Preprocessor),
 }
 
 impl HubPreprocessorConfig {
@@ -614,6 +614,73 @@ fn image_tokens(
 
             image_string
         }
+        Idefics3(config) => {
+            const FAKE: &str = "<fake_token_around_image>";
+            const IMAGE: &str = "<image>";
+            const GLOBAL_IMG: &str = "<global-img>";
+
+            let max_longest_edge_for_image_resize = config.get_max_longest_edge_for_image_resize();
+
+            // resize image if it is larger than max_longest_edge_for_image_resize keeping aspect ratio
+            let (height, width) = if height > max_longest_edge_for_image_resize
+                || width > max_longest_edge_for_image_resize
+            {
+                let aspect_ratio = height as f32 / width as f32;
+                if height > width {
+                    (
+                        max_longest_edge_for_image_resize,
+                        (max_longest_edge_for_image_resize as f32 / aspect_ratio) as usize,
+                    )
+                } else {
+                    (
+                        (max_longest_edge_for_image_resize as f32 * aspect_ratio) as usize,
+                        max_longest_edge_for_image_resize,
+                    )
+                }
+            } else {
+                (height, width)
+            };
+
+            let image_seq_len = config.get_number_of_features();
+            let max_edge = config.get_max_longest_edge();
+
+            let (image_rows, image_cols) = if height > max_edge || width > max_edge {
+                (
+                    (height as f32 / max_edge as f32).ceil() as usize,
+                    (width as f32 / max_edge as f32).ceil() as usize,
+                )
+            } else {
+                (0, 0)
+            };
+
+            let mut image_string = String::new();
+
+            if image_rows == 0 && image_cols == 0 {
+                // Single image case
+                image_string.push_str(FAKE);
+                image_string.push_str(GLOBAL_IMG);
+                image_string.push_str(&IMAGE.repeat(image_seq_len));
+                image_string.push_str(FAKE);
+            } else {
+                // Split image case
+                for n_h in 0..image_rows {
+                    for n_w in 0..image_cols {
+                        image_string.push_str(FAKE);
+                        image_string.push_str(&format!("<row_{}_col_{}>", n_h + 1, n_w + 1));
+                        image_string.push_str(&IMAGE.repeat(image_seq_len));
+                    }
+                    image_string.push('\n');
+                }
+
+                image_string.push('\n');
+                image_string.push_str(FAKE);
+                image_string.push_str(GLOBAL_IMG);
+                image_string.push_str(&IMAGE.repeat(image_seq_len));
+                image_string.push_str(FAKE);
+            }
+
+            image_string
+        }
         Paligemma(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         LlavaNext(config) => "<image>".repeat(config.get_number_of_features(height, width)),
         Qwen2Vl(config) => format!(
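As a quick sanity check of the arm above, the same tiling arithmetic expressed in Python (a sketch mirroring the Rust logic; the 2000×1000 input size is just an example):

```python
import math

MAX_RESIZE = 1456    # get_max_longest_edge_for_image_resize()
MAX_EDGE = 364       # get_max_longest_edge()
IMAGE_SEQ_LEN = 169  # get_number_of_features()

def expanded_image_tokens(height: int, width: int) -> int:
    # Cap the longest edge at 1456 while keeping the aspect ratio.
    if height > MAX_RESIZE or width > MAX_RESIZE:
        aspect_ratio = height / width
        if height > width:
            height, width = MAX_RESIZE, int(MAX_RESIZE / aspect_ratio)
        else:
            height, width = int(MAX_RESIZE * aspect_ratio), MAX_RESIZE
    # Images larger than one 364-pixel tile are split into a ceil() grid;
    # otherwise only the global image is emitted.
    if height > MAX_EDGE or width > MAX_EDGE:
        rows, cols = math.ceil(height / MAX_EDGE), math.ceil(width / MAX_EDGE)
    else:
        rows, cols = 0, 0
    tiles = rows * cols + 1  # grid tiles plus the <global-img> copy
    return tiles * IMAGE_SEQ_LEN

print(expanded_image_tokens(2000, 1000))  # 2000x1000 -> 1456x728 -> 4x2 grid -> 9 * 169 = 1521
```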
@@ -647,7 +714,8 @@ fn prepare_input<T: TokenizerTrait>(
     static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
     let (tokenizer_query, input_chunks) = match config {
         Some(
-            config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_) | Qwen2Vl(_)),
+            config @ (Idefics | Mllama | Idefics2(_) | Idefics3(_) | Paligemma(_) | LlavaNext(_)
+            | Qwen2Vl(_)),
         ) => {
             let mut input_chunks = Vec::new();
             let mut tokenizer_query = String::with_capacity(inputs.len());
@@ -152,6 +152,9 @@ try:
     from text_generation_server.models.custom_modeling.idefics2 import (
         Idefics2ForConditionalGeneration,
     )
+    from text_generation_server.models.custom_modeling.idefics3 import (
+        Idefics3ForConditionalGeneration,
+    )
     from text_generation_server.models.custom_modeling.qwen2_vl import (
         Qwen2VLForConditionalGeneration,
     )
@@ -188,6 +191,12 @@ class ModelType(enum.Enum):
         "url": "https://huggingface.co/HuggingFaceM4/idefics2-8b",
         "multimodal": True,
     }
+    IDEFICS3 = {
+        "type": "idefics3",
+        "name": "Idefics 3",
+        "url": "https://huggingface.co/HuggingFaceM4/Idefics3-8B-Llama3",
+        "multimodal": True,
+    }
     LLAVA_NEXT = {
         "type": "llava_next",
         "name": "Llava Next (1.6)",
@@ -1253,6 +1262,24 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == IDEFICS3:
+        if FLASH_ATTENTION:
+            return VlmCausalLM(
+                model_id=model_id,
+                model_class=Idefics3ForConditionalGeneration,
+                revision=revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                default_dtype=torch.bfloat16,
+                trust_remote_code=trust_remote_code,
+                lora_adapter_ids=lora_adapter_ids,
+                # XXX: Extremely important to cap resolution in order to limit
+                # VRAM usage.
+                processor_kwargs={"size": {"longest_edge": 1456}},
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
     if model_type == PALIGEMMA:
         if FLASH_ATTENTION:
             return VlmCausalLM(
@@ -515,9 +515,7 @@ class FlashLlamaModel(torch.nn.Module):
         self.layers.append(
             FlashLlamaLayer(
                 index=0,
-                prefix=(
-                    "model.layers.0" if not prefix else f"{prefix}.model.layers.0"
-                ),
+                prefix=f"{prefix}.layers.0",
                 config=config,
                 weights=weights,
             )
@@ -533,11 +531,7 @@ class FlashLlamaModel(torch.nn.Module):
                 self.layers.append(
                     FlashLlamaCrossLayer(
                         index=layer_id,
-                        prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.model.layers.{layer_id}"
-                        ),
+                        prefix=(f"{prefix}.layers.{layer_id}"),
                         config=config,
                         weights=weights,
                     )
@@ -546,11 +540,7 @@ class FlashLlamaModel(torch.nn.Module):
                 self.layers.append(
                     FlashLlamaLayer(
                         index=layer_id,
-                        prefix=(
-                            f"model.layers.{layer_id}"
-                            if not prefix
-                            else f"{prefix}.model.layers.{layer_id}"
-                        ),
+                        prefix=(f"{prefix}.layers.{layer_id}"),
                         config=config,
                         weights=weights,
                     )
@@ -561,18 +551,14 @@ class FlashLlamaModel(torch.nn.Module):
             self.layers.append(
                 FlashLlamaLayer(
                     index=last_layer_id,
-                    prefix=(
-                        f"model.layers.{last_layer_id}"
-                        if not prefix
-                        else f"{prefix}.model.layers.{last_layer_id}"
-                    ),
+                    prefix=(f"{prefix}.layers.{last_layer_id}"),
                     config=config,
                     weights=weights,
                 )
             )
 
         self.norm = FastRMSNorm.load(
-            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+            prefix=f"{prefix}.norm",
             weights=weights,
             eps=config.rms_norm_eps,
         )
@@ -629,19 +615,24 @@ class FlashLlamaModel(torch.nn.Module):
 
 
 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix: str, config, weights):
+    def __init__(self, prefix: str, config, weights, name=None):
+        if name is None:
+            name = "model"
         super().__init__()
 
         with no_fp8(weights):
             self.embed_tokens = TensorParallelEmbedding(
                 prefix=(
-                    "model.embed_tokens"
+                    f"{name}.embed_tokens"
                     if not prefix
-                    else f"{prefix}.model.embed_tokens"
+                    else f"{prefix}.{name}.embed_tokens"
                 ),
                 weights=weights,
             )
-        self.model = FlashLlamaModel(prefix, config, weights)
+        self.model = FlashLlamaModel(
+            prefix=name if not prefix else f"{prefix}.{name}",
+            config=config,
+            weights=weights,
+        )
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:
@@ -652,11 +643,13 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         if embedding_multiplier is not None:
             self.embed_tokens.weight.data *= embedding_multiplier
 
+        prefix = "lm_head" if not prefix or name != "model" else f"{prefix}.{suffix}"
+
         with no_fp8(weights):
             self.lm_head = SpeculativeHead.load(
                 config,
-                prefix=suffix if not prefix else f"{prefix}.{suffix}",
-                weights=weights,
+                prefix,
+                weights,
             )
 
         # Used in Granite
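The net effect of the prefix/name rework above is easiest to see on the resulting weight names. A small sketch (the two call patterns follow this PR: a plain Llama checkpoint passes an empty prefix, while the Idefics3 wrapper calls `load_text_model(prefix="model", name="text_model")`):

```python
def embed_tokens_weight(prefix: str, name: str = "model") -> str:
    # Mirrors the prefix logic in FlashLlamaForCausalLM.__init__ above.
    return f"{name}.embed_tokens" if not prefix else f"{prefix}.{name}.embed_tokens"

# Standalone Llama checkpoint: no prefix, default name.
print(embed_tokens_weight(""))                     # model.embed_tokens
# Idefics3 text tower: nested under the multimodal checkpoint layout.
print(embed_tokens_weight("model", "text_model"))  # model.text_model.embed_tokens
```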
server/text_generation_server/models/custom_modeling/idefics3.py (new file, 584 lines)
@@ -0,0 +1,584 @@
# coding=utf-8
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Idefics3 model."""

from typing import List, Optional, Tuple

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
    load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)
from text_generation_server.utils.weights import DefaultWeightsLoader, UnquantizedWeight


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class Idefics3VisionEmbeddings(nn.Module):
    """
    This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
    resolution.

    The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
    which allows treating images in their native aspect ratio and without the need to resize them to the same
    fixed size. In particular, we start from the original pre-trained SigLIP model
    (which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
    """

    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
        )
        self.patch_embedding.bias = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
        )

        self.num_patches_per_side = self.image_size // self.patch_size
        self.num_patches = self.num_patches_per_side**2
        self.num_positions = self.num_patches
        self.position_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.position_embedding", weights=weights
        )

    def forward(
        self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
    ) -> torch.Tensor:
        batch_size, _, max_im_h, max_im_w = pixel_values.shape

        patch_embeds = self.patch_embedding(pixel_values)
        embeddings = patch_embeds.flatten(2).transpose(1, 2)

        max_nb_patches_h, max_nb_patches_w = (
            max_im_h // self.patch_size,
            max_im_w // self.patch_size,
        )
        boundaries = torch.arange(
            1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
        )
        position_ids = torch.full(
            size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
        )

        for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
            nb_patches_h = p_attn_mask[:, 0].sum()
            nb_patches_w = p_attn_mask[0].sum()

            fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
            fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

            bucket_coords_h = torch.bucketize(
                fractional_coords_h, boundaries, right=True
            )
            bucket_coords_w = torch.bucketize(
                fractional_coords_w, boundaries, right=True
            )

            pos_ids = (
                bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
            ).flatten()
            position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids

        position_ids = position_ids.to(self.position_embedding.weight.device)
        embeddings = embeddings + self.position_embedding(position_ids)
        return embeddings


class Idefics3VisionAttention(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_size = self.embed_dim // self.num_heads
        if self.head_size * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_size**-0.5
        self.dropout = config.attention_dropout

        self.num_heads = self.num_heads // weights.process_group.size()
        self.embed_dim = self.embed_dim // weights.process_group.size()

        self.qkv = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=True,
        )
        self.out_proj = TensorParallelRowLinear.load(
            config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
        )
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, q_len, _ = hidden_states.size()

        qkv = self.qkv(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [
                self.head_size * self.num_heads,
                self.head_size * self.num_heads,
                self.head_size * self.num_heads,
            ],
            dim=2,
        )

        query_states = query_states.view(
            batch_size, q_len, self.num_heads, self.head_size
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, q_len, self.num_heads, self.head_size
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, q_len, self.num_heads, self.head_size
        ).transpose(1, 2)

        k_v_seq_len = key_states.shape[-2]
        attn_weights = (
            torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
        )

        if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output


class Idefics3VisionMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Idefics3EncoderLayer(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = Idefics3VisionAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.layer_norm1 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
        )
        self.layer_norm2 = nn.LayerNorm.load(
            prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
        )
        self.mlp = Idefics3VisionMLP(
            prefix=f"{prefix}.mlp", config=config, weights=weights
        )

    # Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states


class Idefics3Encoder(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [
                Idefics3EncoderLayer(
                    prefix=f"{prefix}.layers.{i}", config=config, weights=weights
                )
                for i in range(config.num_hidden_layers)
            ]
        )

    # Ignore copy
    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask,
            )
        return hidden_states


class Idefics3VisionTransformer(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.config = config
        self.embeddings = Idefics3VisionEmbeddings(
            prefix=f"{prefix}.embeddings", config=config, weights=weights
        )
        self.encoder = Idefics3Encoder(
            prefix=f"{prefix}.encoder", config=config, weights=weights
        )
        self.post_layernorm = nn.LayerNorm.load(
            prefix=f"{prefix}.post_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )

    def forward(
        self,
        pixel_values,
        patch_attention_mask: Optional[torch.BoolTensor] = None,
    ):
        batch_size = pixel_values.size(0)
        if patch_attention_mask is None:
            patch_size = self.config.patch_size
            patch_attention_mask = torch.ones(
                (
                    batch_size,
                    pixel_values.size(2) // patch_size,
                    pixel_values.size(3) // patch_size,
                )
            )
            patch_attention_mask = patch_attention_mask.to(
                dtype=torch.bool, device=pixel_values.device
            )

        hidden_states = self.embeddings(
            pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
        )

        patch_attention_mask = patch_attention_mask.view(batch_size, -1)
        # The call to `_upad_input` in `_flash_attention_forward` is expensive
        # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
        # avoiding passing the attention_mask, which is equivalent to attending to the full sequence
        if not torch.any(~patch_attention_mask):
            patch_attention_mask = None
        else:
            patch_attention_mask = _prepare_4d_attention_mask(
                patch_attention_mask, hidden_states.dtype
            )

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=patch_attention_mask,
        )

        last_hidden_state = encoder_outputs
        last_hidden_state = self.post_layernorm(last_hidden_state)

        return last_hidden_state


class Idefics3SimpleMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        input_size = config.vision_config.hidden_size * (config.scale_factor**2)
        output_size = config.text_config.hidden_size
        proj = nn.Parameter(
            weights.get_tensor(f"{prefix}.modality_projection.proj.weight"),
            requires_grad=False,
        ).to(weights.dtype)
        self.proj = nn.Linear(input_size, output_size, bias=False)
        self.proj.weight = proj

    def forward(self, x):
        return self.proj(x)


class Idefics3Connector(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.modality_projection = Idefics3SimpleMLP(prefix, config, weights)
        self.scale_factor = config.scale_factor

    def pixel_shuffle(self, x, scale_factor=2):
        bsz, seq, embed_dim = x.size()
        height = width = int(seq**0.5)
        x = x.view(bsz, height, width, embed_dim)
        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(
            bsz,
            int(width / scale_factor),
            int(height / scale_factor),
            embed_dim * (scale_factor**2),
        )
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
        return x

    def forward(self, image_hidden_states):
        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        image_hidden_states = self.modality_projection(image_hidden_states)
        return image_hidden_states


class Idefics3ForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
        # set tie_word_embeddings to True to load `.embed_tokens.weight` instead of `.lm_head.weight`
        # since Idefics3 uses the `embed_tokens` for the final prediction
        # config.text_config.tie_word_embeddings = True

        vision_config = config.vision_config
        self.text_model = load_text_model(
            prefix="model" if not prefix else f"{prefix}.model",
            config=config.text_config,
            weights=weights,
            name="text_model",
        )
        self.dtype = weights.dtype

        # The vision and connector models are not quantized.
        with weights.use_loader(DefaultWeightsLoader(UnquantizedWeight)):
            self.vision_model = Idefics3VisionTransformer(
                prefix=(
                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
                ),
                config=vision_config,
                weights=weights,
            )

            config.quantize = None
            self.connector = Idefics3Connector(
                prefix=f"{prefix}.model.connector" if prefix else "model.connector",
                config=config,
                weights=weights,
            )

        self.config = config
        self.image_token_id = config.image_token_id
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )

    def _merge_input_ids_with_image_features(
        self,
        input_ids: torch.Tensor,
        inputs_embeds: torch.Tensor,
        image_features: torch.Tensor,
    ):
        """In place merges in vision_embeddings with inputs_embeds."""
        # mask = input_ids == self.config.image_token_index
        mask = input_ids == self.config.image_token_id
        # Let's pray we have enabled enough slots !
        inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        seqlen: Seqlen,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        pixel_values: torch.FloatTensor = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        # Unused here
        image_sizes: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        image_indices=None,
    ):
        inputs_embeds = self.text_model.embed_tokens(input_ids)
        if pixel_values is not None:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            all_states = []
            all_pixel_values = pixel_values
            all_pixel_mask = pixel_attention_mask
            for i in range(batch_size):
                pixel_values = all_pixel_values.to(
                    dtype=self.dtype
                )  # fp16 compatibility
                pixel_values = pixel_values[i : i + 1]
                pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])

                # Remove padding images - padding images are full 0.
                nb_values_per_image = pixel_values.shape[1:].numel()
                real_images_inds = (pixel_values == 0.0).sum(
                    dim=(-1, -2, -3)
                ) != nb_values_per_image
                pixel_values = pixel_values[real_images_inds].contiguous()
                # Handle the vision attention mask
                if pixel_attention_mask is None:
                    pixel_attention_mask = torch.ones(
                        size=(
                            pixel_values.size(0),
                            pixel_values.size(2),
                            pixel_values.size(3),
                        ),
                        dtype=torch.bool,
                        device=pixel_values.device,
                    )
                else:
                    # Remove padding images from the mask
                    pixel_attention_mask = all_pixel_mask[i : i + 1]
                    pixel_attention_mask = pixel_attention_mask.view(
                        1 * num_images, *pixel_attention_mask.shape[2:]
                    )
                    pixel_attention_mask = pixel_attention_mask[
                        real_images_inds
                    ].contiguous()

                patch_size = self.config.vision_config.patch_size
                patches_subgrid = pixel_attention_mask.unfold(
                    dimension=1, size=patch_size, step=patch_size
                )
                patches_subgrid = patches_subgrid.unfold(
                    dimension=2, size=patch_size, step=patch_size
                )
                patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

                # Get sequence from the vision encoder
                image_hidden_states = self.vision_model(
                    pixel_values=pixel_values,
                    patch_attention_mask=patch_attention_mask,
                )

                # Modality projection & resampling
                image_hidden_states = self.connector(
                    image_hidden_states,
                )

                all_states.append(image_hidden_states)
            image_hidden_states = torch.stack(all_states, dim=0)

            inputs_embeds = self._merge_input_ids_with_image_features(
                input_ids, inputs_embeds, image_hidden_states
            )

        hidden_states = self.text_model.model(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
            cu_seqlen_prefill=cu_seqlen_prefill,
            kv_cache=kv_cache,
            block_tables=block_tables,
            slots=slots,
            seqlen=seqlen,
            max_s=max_s,
            true_max_s=max_s,
            prefill_cache_indices=None,
            adapter_data=adapter_data,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.text_model.lm_head(hidden_states)
        return logits, speculative_logits
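The connector's pixel shuffle is what turns the 676 patch embeddings of a tile into 169 larger features before the linear projection into the text model's hidden size. A standalone sketch of the shape bookkeeping (the hidden size 1152 is an assumed SigLIP-style value, not stated in the diff):

```python
import torch

def pixel_shuffle(x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
    # Same reshaping as Idefics3Connector.pixel_shuffle above.
    bsz, seq, embed_dim = x.size()
    height = width = int(seq**0.5)
    x = x.view(bsz, height, width, embed_dim)
    x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
    x = x.permute(0, 2, 1, 3)
    x = x.reshape(
        bsz,
        int(width / scale_factor),
        int(height / scale_factor),
        embed_dim * (scale_factor**2),
    )
    x = x.permute(0, 2, 1, 3)
    return x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))

patches = torch.randn(1, 676, 1152)  # 26x26 patches from one 364px tile
print(pixel_shuffle(patches).shape)  # torch.Size([1, 169, 4608])
```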
@@ -4,7 +4,7 @@ def load_text_model(prefix, config, weights, name=None):
             FlashLlamaForCausalLM,
         )
 
-        return FlashLlamaForCausalLM(prefix, config, weights)
+        return FlashLlamaForCausalLM(prefix, config, weights, name=name)
     elif config.model_type == "mistral":
         from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
             FlashMistralForCausalLM,
@@ -1288,7 +1288,7 @@ class FlashCausalLM(Model):
             weights_loader=weights_loader,
         )
 
-        prefix = ""
+        prefix = None
         model = model_class(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
 
@@ -13,6 +13,7 @@ from text_generation_server.models.flash_causal_lm import (
     FlashCausalLM,
 )
 from text_generation_server.models.globals import PREFIX_CACHING, ATTENTION
+from loguru import logger
 from text_generation_server.utils.log import log_master
 from transformers import AutoProcessor
 from text_generation_server.layers.attention import Seqlen
|
|||||||
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
|
IDEFICS2_FAKE_TOKEN = "<fake_token_around_image>"
|
||||||
IDEFICS2_IMAGE_TOKEN = "<image>"
|
IDEFICS2_IMAGE_TOKEN = "<image>"
|
||||||
|
|
||||||
|
IDEFICS3_IMAGE_TOKEN = "<image>"
|
||||||
|
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
|
||||||
|
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"
|
||||||
|
|
||||||
|
|
||||||
|
# copied from: https://github.com/huggingface/transformers/blob/02ed609285c2448b3b54c31e362f2c389fa952ab/src/transformers/models/idefics3/processing_idefics3.py#L44-L60
|
||||||
|
def _prompt_split_image(
|
||||||
|
*,
|
||||||
|
image_seq_len: int,
|
||||||
|
image_rows: int,
|
||||||
|
image_cols: int,
|
||||||
|
fake_token_around_image: str,
|
||||||
|
image_token: str,
|
||||||
|
global_img_token: str,
|
||||||
|
):
|
||||||
|
"""Prompt with expanded image tokens for when the image is split into patches."""
|
||||||
|
text_split_images = ""
|
||||||
|
for n_h in range(image_rows):
|
||||||
|
for n_w in range(image_cols):
|
||||||
|
text_split_images += (
|
||||||
|
f"{fake_token_around_image}"
|
||||||
|
+ f"<row_{n_h + 1}_col_{n_w + 1}>"
|
||||||
|
+ f"{image_token}" * image_seq_len
|
||||||
|
)
|
||||||
|
text_split_images += "\n"
|
||||||
|
|
||||||
|
text_split_images += (
|
||||||
|
f"\n{fake_token_around_image}"
|
||||||
|
+ f"{global_img_token}"
|
||||||
|
+ f"{image_token}" * image_seq_len
|
||||||
|
+ f"{fake_token_around_image}"
|
||||||
|
)
|
||||||
|
return text_split_images
|
||||||
|
|
||||||
|
|
||||||
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
||||||
"""
|
"""
|
||||||
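A small usage example of the helper above, assuming the constants and function defined in this hunk are in scope (tiny numbers so the output stays readable; a real Idefics3 prompt uses image_seq_len=169):

```python
# Illustrative call with a 2x2 grid and a short image sequence length.
prompt_piece = _prompt_split_image(
    image_seq_len=2,
    image_rows=2,
    image_cols=2,
    fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
    image_token=IDEFICS3_IMAGE_TOKEN,
    global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
)
print(prompt_piece)
# One line per tile row, e.g. "<fake_token_around_image><row_1_col_1><image><image>
# <fake_token_around_image><row_1_col_2><image><image>" ..., followed by the
# <global-img> block wrapped in <fake_token_around_image> tokens.
```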
@@ -54,10 +89,26 @@ def image_text_replacement(processor, image_input, config, image_id: int) -> str
         if processor.image_processor.do_image_splitting:
             image_str *= 5
         return image_str
+    if config.model_type == "idefics3":
+        # TODO: implement this in a more general way
+        n_rows = image_input["rows"][0][image_id]
+        n_cols = image_input["cols"][0][image_id]
+        image_seq_len = int(
+            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
+            / (config.scale_factor**2)
+        )
+        image_str = _prompt_split_image(
+            image_seq_len=image_seq_len,
+            image_rows=n_rows,
+            image_cols=n_cols,
+            fake_token_around_image=IDEFICS3_FAKE_IMAGE_TOKEN,
+            image_token=IDEFICS3_IMAGE_TOKEN,
+            global_img_token=IDEFICS3_GLOBAL_IMG_TOKEN,
+        )
+        return image_str
     elif config.model_type == "llava_next":
         height, width = image_input["image_sizes"][image_id]
         num_features = get_number_of_features(height, width, config)
-        from loguru import logger
 
         log_master(
             logger.info,
@@ -194,12 +245,21 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 raise RuntimeError(f"Invalid chunk type {chunk_type}")
 
         if images:
-            image_inputs = processor.image_processor(images, return_tensors="pt")
+            kwargs = {}
+            if (
+                hasattr(processor, "image_processor_class")
+                and processor.image_processor_class == "Idefics3ImageProcessor"
+            ):
+                kwargs["return_row_col_info"] = True
+
+            image_inputs = processor.image_processor(
+                images, return_tensors="pt", **kwargs
+            )
         else:
             image_inputs = None
 
-        batch_inputs = []
-        max_truncation = 0
+        batch_tokenized_inputs = []
+        max_length = 0
         image_id = 0
         for r in requests:
             full_text = ""
@@ -214,16 +274,14 @@ class VlmCausalLMBatch(FlashCausalLMBatch):
                 image_id += 1
 
             full_text = image_text_replacement_fixup(config, full_text)
-            batch_inputs.append(full_text)
-            max_truncation = max(max_truncation, r.truncate)
-
-        batch_tokenized_inputs = tokenizer(
-            batch_inputs,
-            truncation=True,
-            max_length=max_truncation,
-            add_special_tokens=not config.model_type == "paligemma",
-        )["input_ids"]
+            input_ids = tokenizer(
+                full_text,
+                truncation=True,
+                max_length=r.truncate,
+                add_special_tokens=r.add_special_tokens,
+            )["input_ids"]
+            max_length = max(max_length, len(input_ids))
+            batch_tokenized_inputs.append(input_ids)
 
         return batch_tokenized_inputs, image_inputs