Working loading state.

Nicolas Patry 2024-09-18 17:01:36 +02:00
parent 7efcb5e0ed
commit 907906466a
3 changed files with 1266 additions and 9 deletions


@@ -308,6 +308,12 @@ class ModelType(enum.Enum):
         "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
         "multimodal": True,
     }
+    MLLAMA = {
+        "type": "mllama",
+        "name": "Mllama",
+        "url": "https://huggingface.co/xxx/xx",
+        "multimodal": True,
+    }
 
     __GLOBALS = locals()
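Note: the `__GLOBALS = locals()` context line is what makes the new entry reachable from `get_model` below. A minimal sketch of that registry pattern follows; the export loop is an assumption about the surrounding code, not part of this diff:

    import enum

    class ModelType(enum.Enum):
        MLLAMA = {
            "type": "mllama",
            "name": "Mllama",
            "url": "https://huggingface.co/xxx/xx",  # placeholder URL kept from the diff
            "multimodal": True,
        }

    # Hypothetical export step: bind each member's name to its "type" string
    # at module level, so `model_type == MLLAMA` is a plain string comparison.
    _globals = globals()
    for member in ModelType:
        _globals[member.name] = member.value["type"]

    print(MLLAMA)  # -> "mllama"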
@@ -1095,6 +1101,18 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == MLLAMA:
+        if FLASH_ATTENTION:
+            return IDEFICSSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return VlmCausalLM(
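For orientation, a hypothetical call that would reach the new branch; `get_model`'s full signature is not visible in this hunk, so the arguments below are limited to the ones the branch forwards, and the model id reuses the diff's own placeholder:

    # Sketch only: the real get_model may take additional parameters.
    model = get_model(
        model_id="xxx/xx",
        revision=None,
        quantize=None,
        speculator=None,
        dtype=None,
        trust_remote_code=False,
    )
    # With FLASH_ATTENTION available this returns an IDEFICSSharded instance;
    # otherwise it raises NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama")).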

(File diff suppressed because it is too large.)


@@ -4,14 +4,13 @@ import torch.distributed
 from typing import Optional
-from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_processing import (
-    IdeficsProcessor,
-)
-from transformers import LlamaTokenizerFast
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
+from text_generation_server.models.custom_modeling.mllama import (
+    MllamaForConditionalGeneration,
+)
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
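The import swap is what lets one loader serve both architectures: the `Auto*` classes resolve the concrete config, tokenizer, and processor classes from the checkpoint's `config.json` instead of hard-coding Idefics. A small illustration, using the idefics-9b id from the enum table above:

    from transformers import AutoConfig

    # AutoConfig reads `model_type` from the checkpoint's config.json.
    config = AutoConfig.from_pretrained("HuggingFaceM4/idefics-9b")
    print(config.model_type)  # -> "idefics"; an mllama checkpoint reports "mllama"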
@@ -53,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype = torch.float32 if dtype is None else dtype
         self.device, self.dtype = device, dtype
 
-        config = IdeficsConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
@@ -62,14 +61,14 @@ class IDEFICSSharded(IdeficsCausalLM):
         config.speculator = speculator
         config.vision_config.quantize = quantize
 
-        tokenizer = LlamaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        self.processor = IdeficsProcessor.from_pretrained(
+        self.processor = AutoProcessor.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
@@ -90,7 +89,14 @@ class IDEFICSSharded(IdeficsCausalLM):
             weights_loader=weights_loader,
         )
 
-        model = IdeficsForVisionText2Text(config, weights)
+        if config.model_type == "idefics":
+            model = IdeficsForVisionText2Text(config, weights)
+        elif config.model_type == "mllama":
+            model = MllamaForConditionalGeneration(
+                prefix="", config=config, weights=weights
+            )
+        else:
+            raise RuntimeError(f"Unsupported model type {config.model_type}")
 
         torch.distributed.barrier(group=self.process_group)
         super(IdeficsCausalLM, self).__init__(
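The net effect is that the sharded loader keys the model class off `config.model_type` rather than hard-coding Idefics. A standalone mirror of that dispatch, with a hypothetical helper standing in for the real constructors:

    from transformers import AutoConfig

    def pick_model_class_name(config) -> str:
        # Mirrors the branch added in IDEFICSSharded above.
        if config.model_type == "idefics":
            return "IdeficsForVisionText2Text"
        elif config.model_type == "mllama":
            return "MllamaForConditionalGeneration"
        raise RuntimeError(f"Unsupported model type {config.model_type}")

    config = AutoConfig.from_pretrained("HuggingFaceM4/idefics-9b")
    print(pick_model_class_name(config))  # -> IdeficsForVisionText2Text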