Add AutoCausalLM

Currently `BLOOMSharded` is a subclass of `CausalLM`, yet it skips `CausalLM`'s constructor by calling `super(CausalLM, self).__init__(...)`. This is surprising behavior that we might want to avoid.

This PR extracts `CausalLM`'s constructor into a new subclass, `AutoCausalLM`, which detects its settings from `model_id`. Sharded subclasses such as `BLOOMSharded` then no longer have to skip `CausalLM`'s constructor and can call a plain `super().__init__(...)`.
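For illustration, here is a minimal, self-contained sketch of the hierarchy before and after this change. The stub constructors and the dummy model/tokenizer values are placeholders, not the real `text_generation_server` code:

```python
# Simplified before/after sketch of the class hierarchy in this PR.
# The f-string "weights"/"tokenizer" values are placeholders, not the real
# text_generation_server implementations.


class Model:
    # Lightweight base constructor shared by every model class.
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer


class CausalLM(Model):
    # Before this PR: defined a heavy __init__ that loaded weights from
    # model_id, so sharded subclasses bypassed it with
    # super(CausalLM, self).__init__(...), jumping straight to Model.__init__.
    # After this PR: keeps only the shared causal-LM logic and simply
    # inherits Model.__init__.
    pass


class AutoCausalLM(CausalLM):
    # The extracted constructor: detect settings and load everything from
    # model_id, then hand the results to the base classes.
    def __init__(self, model_id: str):
        model = f"weights({model_id})"        # placeholder for AutoModelForCausalLM
        tokenizer = f"tokenizer({model_id})"  # placeholder for AutoTokenizer
        super().__init__(model=model, tokenizer=tokenizer)


class BLOOMSharded(CausalLM):
    # Builds its own sharded model, then calls a plain super().__init__(),
    # which now reaches Model.__init__ without skipping any constructor.
    def __init__(self, model_id: str):
        model = f"sharded({model_id})"        # placeholder for BloomForCausalLM
        tokenizer = f"tokenizer({model_id})"
        super().__init__(model=model, tokenizer=tokenizer)


print(AutoCausalLM("gpt2").model)              # weights(gpt2)
print(BLOOMSharded("bigscience/bloom").model)  # sharded(bigscience/bloom)
```

With this layout, a plain `super().__init__()` in a sharded subclass resolves through `CausalLM` to `Model.__init__`, which is exactly what the old `super(CausalLM, self).__init__(...)` call was emulating by skipping `CausalLM`.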
Author: Yang, Bo
Date: 2023-07-12 01:07:10 +00:00
parent b4024edd45
commit 9bb64c92a9
10 changed files with 72 additions and 70 deletions


@@ -5,12 +5,12 @@ from copy import copy
 from transformers import AutoTokenizer
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
+from text_generation_server.models.causal_lm import AutoCausalLM, CausalLMBatch
 @pytest.fixture(scope="session")
 def default_causal_lm():
-    return CausalLM("gpt2")
+    return AutoCausalLM("gpt2")
 @pytest.fixture(scope="session")


@@ -7,7 +7,7 @@ from transformers.models.auto import modeling_auto
 from typing import Optional
 from text_generation_server.models.model import Model
-from text_generation_server.models.causal_lm import CausalLM
+from text_generation_server.models.causal_lm import CausalLM, AutoCausalLM
 from text_generation_server.models.flash_causal_lm import FlashCausalLM
 from text_generation_server.models.bloom import BLOOMSharded
 from text_generation_server.models.mpt import MPTSharded
@@ -33,6 +33,7 @@ __all__ = [
     "Model",
     "BLOOMSharded",
     "CausalLM",
+    "AutoCausalLM",
     "FlashCausalLM",
     "GalacticaSharded",
     "Seq2SeqLM",
@@ -202,7 +203,7 @@ def get_model(
                 trust_remote_code=trust_remote_code,
             )
         else:
-            return CausalLM(
+            return AutoCausalLM(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -222,7 +223,7 @@ def get_model(
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama"))
         else:
-            return CausalLM(
+            return AutoCausalLM(
                 model_id,
                 revision,
                 quantize=quantize,
@@ -292,7 +293,7 @@ def get_model(
         )
     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-        return CausalLM(
+        return AutoCausalLM(
            model_id,
            revision,
            quantize=quantize,
@@ -311,7 +312,7 @@ def get_model(
     auto_map = config_dict.get("auto_map", None)
     if trust_remote_code and auto_map is not None:
         if "AutoModelForCausalLM" in auto_map.keys():
-            return CausalLM(
+            return AutoCausalLM(
                 model_id,
                 revision,
                 quantize=quantize,


@@ -80,7 +80,7 @@ class BLOOMSharded(CausalLM):
         model = BloomForCausalLM(config, weights)
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,


@@ -449,62 +449,6 @@ class CausalLMBatch(Batch):
 class CausalLM(Model):
-    def __init__(
-        self,
-        model_id: str,
-        revision: Optional[str] = None,
-        quantize: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        trust_remote_code: bool = False,
-    ):
-        if torch.cuda.is_available():
-            device = torch.device("cuda")
-            dtype = torch.float16 if dtype is None else dtype
-        else:
-            if quantize:
-                raise ValueError("quantization is not available on CPU")
-            device = torch.device("cpu")
-            dtype = torch.float32
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            revision=revision,
-            padding_side="left",
-            truncation_side="left",
-            trust_remote_code=trust_remote_code,
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            revision=revision,
-            torch_dtype=dtype,
-            device_map="auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None,
-            load_in_8bit=quantize == "bitsandbytes",
-            trust_remote_code=trust_remote_code,
-        )
-        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
-            model = model.cuda()
-        if tokenizer.pad_token_id is None:
-            if model.config.pad_token_id is not None:
-                tokenizer.pad_token_id = model.config.pad_token_id
-            elif model.config.eos_token_id is not None:
-                tokenizer.pad_token_id = model.config.eos_token_id
-            elif tokenizer.eos_token_id is not None:
-                tokenizer.pad_token_id = tokenizer.eos_token_id
-            else:
-                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        super(CausalLM, self).__init__(
-            model=model,
-            tokenizer=tokenizer,
-            requires_padding=True,
-            dtype=dtype,
-            device=device,
-        )
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
@@ -676,3 +620,60 @@ class CausalLM(Model):
         batch.past_key_values = past
         return generations, batch
+class AutoCausalLM(CausalLM):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+            device = torch.device("cpu")
+            dtype = torch.float32
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            device_map="auto"
+            if torch.cuda.is_available() and torch.cuda.device_count() > 1
+            else None,
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+        )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        super().__init__(
+            model=model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+        )


@@ -195,7 +195,7 @@ class GalacticaSharded(CausalLM):
         model = OPTForCausalLM(config, weights)
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,


@@ -60,7 +60,7 @@ class GPTNeoxSharded(CausalLM):
         model = GPTNeoxForCausalLM(config, weights)
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,


@@ -83,7 +83,7 @@ class MPTSharded(CausalLM):
         model = MPTForCausalLM(config, weights)
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=False,


@@ -58,7 +58,7 @@ class OPTSharded(CausalLM):
         model = OPTForCausalLM(config, weights)
         torch.distributed.barrier(group=self.process_group)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,


@@ -55,7 +55,7 @@ class RW(CausalLM):
             else:
                 tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        super(CausalLM, self).__init__(
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,


@@ -60,7 +60,7 @@ class SantaCoder(CausalLM):
             trust_remote_code=trust_remote_code,
         ).to(device)
-        super(CausalLM, self).__init__(
+        super().__init__(
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,