Working loading state.

Nicolas Patry 2024-09-18 17:01:36 +02:00
parent 7efcb5e0ed
commit 907906466a
3 changed files with 1266 additions and 9 deletions


@@ -308,6 +308,12 @@ class ModelType(enum.Enum):
         "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
         "multimodal": True,
     }
+    MLLAMA = {
+        "type": "mllama",
+        "name": "Mllama",
+        "url": "https://huggingface.co/xxx/xx",
+        "multimodal": True,
+    }
 
 __GLOBALS = locals()
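The hunk above only adds the enum entry; what makes `model_type == MLLAMA` work in get_model below is the `__GLOBALS = locals()` trick, whose rebinding loop is not part of this hunk. A minimal, self-contained sketch of that registry pattern, under the assumption that the surrounding file rebinds each member name to its "type" string:

import enum

class ModelType(enum.Enum):
    # Each member's value is a metadata dict; "type" matches the
    # "model_type" field in a checkpoint's config.json.
    IDEFICS = {"type": "idefics", "name": "Idefics", "multimodal": True}
    MLLAMA = {"type": "mllama", "name": "Mllama", "multimodal": True}

__GLOBALS = locals()
# Rebind each member name to its bare "type" string, so later code can
# compare a config's model_type string against MLLAMA directly.
for data in ModelType:
    __GLOBALS[data.name] = data.value["type"]

assert MLLAMA == "mllama"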
@@ -1095,6 +1101,18 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == MLLAMA:
+        if FLASH_ATTENTION:
+            return IDEFICSSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return VlmCausalLM(

File diff suppressed because it is too large.
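The suppressed diff is the new custom_modeling/mllama.py module, which accounts for the bulk of the 1266 added lines. Only its constructor shape matters for the loader changes below; this is an interface sketch and purely an assumption about the hidden file:

from torch import nn

class MllamaForConditionalGeneration(nn.Module):
    # Sketch of the interface the loader relies on; the real
    # implementation presumably builds the vision tower and language
    # model from the sharded `weights` object.
    def __init__(self, prefix: str, config, weights):
        super().__init__()
        self.config = config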


@@ -4,14 +4,13 @@ import torch.distributed
 from typing import Optional
-from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_processing import (
-    IdeficsProcessor,
-)
-from transformers import LlamaTokenizerFast
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
+from text_generation_server.models.custom_modeling.mllama import (
+    MllamaForConditionalGeneration,
+)
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
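Swapping the hard-coded Idefics classes for AutoConfig/AutoTokenizer/AutoProcessor is what lets one loader serve both architectures: each Auto* class reads the checkpoint's model_type and returns the matching concrete class. A standalone sketch of that behavior (the model id is a placeholder, just like the xxx/xx URL above):

from transformers import AutoConfig, AutoProcessor, AutoTokenizer

model_id = "some-org/some-mllama-checkpoint"  # hypothetical id

# AutoConfig reads config.json's "model_type" ("idefics" or "mllama")
# and returns the matching config class, so the loader no longer needs
# the hard-coded IdeficsConfig import.
config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(
    model_id, padding_side="left", truncation_side="left"
)
processor = AutoProcessor.from_pretrained(model_id)
print(config.model_type)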
@@ -53,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype = torch.float32 if dtype is None else dtype
         self.device, self.dtype = device, dtype
-        config = IdeficsConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
@@ -62,14 +61,14 @@ class IDEFICSSharded(IdeficsCausalLM):
         config.speculator = speculator
         config.vision_config.quantize = quantize
-        tokenizer = LlamaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        self.processor = IdeficsProcessor.from_pretrained(
+        self.processor = AutoProcessor.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
@@ -90,7 +89,14 @@ class IDEFICSSharded(IdeficsCausalLM):
             weights_loader=weights_loader,
         )
-        model = IdeficsForVisionText2Text(config, weights)
+        if config.model_type == "idefics":
+            model = IdeficsForVisionText2Text(config, weights)
+        elif config.model_type == "mllama":
+            model = MllamaForConditionalGeneration(
+                prefix="", config=config, weights=weights
+            )
+        else:
+            raise RuntimeError(f"Unsupported model type {config.model_type}")
         torch.distributed.barrier(group=self.process_group)
         super(IdeficsCausalLM, self).__init__(
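The final context line uses the two-argument form super(IdeficsCausalLM, self), which starts the method lookup after IdeficsCausalLM in the MRO and therefore skips the parent's __init__ entirely. A minimal illustration of that Python behavior, with stand-in classes:

class Model:
    def __init__(self):
        print("Model.__init__")

class IdeficsCausalLM(Model):
    def __init__(self):
        print("IdeficsCausalLM.__init__")
        super().__init__()

class IDEFICSSharded(IdeficsCausalLM):
    def __init__(self):
        # super(IdeficsCausalLM, self) resolves methods *after*
        # IdeficsCausalLM in the MRO, so the parent's __init__ is
        # bypassed and Model.__init__ runs directly.
        super(IdeficsCausalLM, self).__init__()

IDEFICSSharded()  # prints only "Model.__init__"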