mirror of https://github.com/huggingface/text-generation-inference.git (synced 2025-04-24 16:32:12 +00:00)
Working loading state.
parent 7efcb5e0ed
commit 907906466a
server/text_generation_server/models/__init__.py

@@ -308,6 +308,12 @@ class ModelType(enum.Enum):
         "url": "https://huggingface.co/HuggingFaceM4/idefics-9b",
         "multimodal": True,
     }
+    MLLAMA = {
+        "type": "mllama",
+        "name": "Mllama",
+        "url": "https://huggingface.co/xxx/xx",
+        "multimodal": True,
+    }


 __GLOBALS = locals()
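The `__GLOBALS = locals()` line at the end of this hunk is what lets the new entry be used as a bare name in `get_model` (the next hunk compares `model_type == MLLAMA`). A minimal sketch of that registry trick, with illustrative dict values and assuming the enum's "type" strings are promoted to module-level constants:

import enum


class ModelType(enum.Enum):
    IDEFICS = {"type": "idefics", "multimodal": True}
    MLLAMA = {"type": "mllama", "multimodal": True}


# At module level, locals() is the module's globals(), so writing into
# it creates module-level constants such as MLLAMA == "mllama".
__GLOBALS = locals()
for entry in ModelType:
    __GLOBALS[entry.name] = entry.value["type"]

print(MLLAMA)  # noqa: F821 (injected via locals()) -> prints "mllama"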
@@ -1095,6 +1101,18 @@ def get_model(
             )
         else:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
+    if model_type == MLLAMA:
+        if FLASH_ATTENTION:
+            return IDEFICSSharded(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        else:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
     if model_type == IDEFICS2:
         if FLASH_ATTENTION:
             return VlmCausalLM(
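No dedicated loader class is added for Mllama: the branch reuses `IDEFICSSharded`, which the later hunks teach to dispatch on `config.model_type`. Like Idefics above it, the path is flash-attention-only. A stripped-down sketch of the guard shape, with an assumed error template rather than TGI's real wording:

FLASH_ATTENTION = True  # in TGI this reflects whether the kernels imported
FLASH_ATT_ERROR_MESSAGE = "{} requires flash attention"  # assumed wording


def load(model_type: str) -> str:
    # Every flash-only architecture gets the same guarded branch shape.
    if model_type == "mllama":
        if FLASH_ATTENTION:
            return "IDEFICSSharded"  # stands in for constructing the loader
        raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Mllama"))
    raise ValueError(f"Unknown model type {model_type}")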
server/text_generation_server/models/custom_modeling/mllama.py (new file, 1233 lines)
File diff suppressed because it is too large.
server/text_generation_server/models/idefics.py

@@ -4,14 +4,13 @@ import torch.distributed
 from typing import Optional


-from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_processing import (
-    IdeficsProcessor,
-)
-from transformers import LlamaTokenizerFast
+from transformers import AutoConfig, AutoProcessor, AutoTokenizer
 from text_generation_server.models.custom_modeling.idefics_modeling import (
     IdeficsForVisionText2Text,
 )
+from text_generation_server.models.custom_modeling.mllama import (
+    MllamaForConditionalGeneration,
+)
 from text_generation_server.models.idefics_causal_lm import IdeficsCausalLM
 from text_generation_server.utils import (
     initialize_torch_distributed,
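Swapping the Idefics-specific classes for `AutoConfig`, `AutoProcessor`, and `AutoTokenizer` is what lets one loader serve both architectures: the Auto classes inspect the checkpoint and return the matching concrete class. A quick illustration with the Idefics checkpoint named in the first hunk:

from transformers import AutoConfig

# AutoConfig reads the "model_type" field from the checkpoint's
# config.json and instantiates the matching config class.
config = AutoConfig.from_pretrained("HuggingFaceM4/idefics-9b")
print(config.model_type)  # "idefics"; an mllama checkpoint yields "mllama"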
@@ -53,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM):
         dtype = torch.float32 if dtype is None else dtype
         self.device, self.dtype = device, dtype

-        config = IdeficsConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id,
             revision=revision,
             trust_remote_code=trust_remote_code,
@@ -62,14 +61,14 @@ class IDEFICSSharded(IdeficsCausalLM):
         config.speculator = speculator
         config.vision_config.quantize = quantize

-        tokenizer = LlamaTokenizerFast.from_pretrained(
+        tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        self.processor = IdeficsProcessor.from_pretrained(
+        self.processor = AutoProcessor.from_pretrained(
             model_id,
             revision=revision,
             padding_side="left",
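The tokenizer and processor get the same Auto-class treatment. The left-side padding and truncation settings are kept: decoder-only batching wants each prompt's tail, where generation resumes, flush against the new tokens. Equivalent standalone calls:

from transformers import AutoProcessor, AutoTokenizer

checkpoint = "HuggingFaceM4/idefics-9b"

# padding_side/truncation_side="left" keep the end of every prompt
# adjacent to the generated tokens in a padded batch.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    padding_side="left",
    truncation_side="left",
)
processor = AutoProcessor.from_pretrained(checkpoint)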
@@ -90,7 +89,14 @@ class IDEFICSSharded(IdeficsCausalLM):
             weights_loader=weights_loader,
         )

-        model = IdeficsForVisionText2Text(config, weights)
+        if config.model_type == "idefics":
+            model = IdeficsForVisionText2Text(config, weights)
+        elif config.model_type == "mllama":
+            model = MllamaForConditionalGeneration(
+                prefix="", config=config, weights=weights
+            )
+        else:
+            raise RuntimeError(f"Unsupported model type {config.model_type}")

         torch.distributed.barrier(group=self.process_group)
         super(IdeficsCausalLM, self).__init__(
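One subtlety in the unchanged context: `super(IdeficsCausalLM, self).__init__(...)` starts the MRO lookup after `IdeficsCausalLM`, so that class's `__init__` (which would build a model on its own) is skipped and the grandparent's runs instead. A minimal illustration of the pattern:

class Model:
    def __init__(self) -> None:
        print("Model.__init__")


class IdeficsLike(Model):
    def __init__(self) -> None:
        print("IdeficsLike.__init__ (never runs below)")


class ShardedLike(IdeficsLike):
    def __init__(self) -> None:
        # Two-argument super() resolves to the class *after* IdeficsLike
        # in the MRO, i.e. Model, skipping IdeficsLike.__init__.
        super(IdeficsLike, self).__init__()


ShardedLike()  # prints only "Model.__init__"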