Mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-09-12 04:44:52 +00:00
fix: revert non snapshot changes
parent 68854d11ef
commit b5f61e92b5
@@ -26,6 +26,3 @@ install-selective-scan: install-causal-conv1d build-selective-scan
 	cd mamba && pip install .
 
 build-all: build-causal-conv1d build-selective-scan
-
-install-ssm: install-causal-conv1d install-selective-scan
-	@echo "Selective scan model installed"
@@ -609,16 +609,15 @@ class CausalLM(Model):
             truncation_side="left",
             trust_remote_code=trust_remote_code,
         )
-        device_map = (
-            "auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None
-        )
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=device_map,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
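The hunk above reverts to computing the multi-GPU placement inline at the call site. A minimal runnable sketch of the restored pattern, with "gpt2" as a hypothetical stand-in for the served model id and quantization left out:

import torch
from transformers import AutoModelForCausalLM

model_id = "gpt2"  # hypothetical stand-in for the served model id

# Shard across GPUs with Accelerate's "auto" placement only when more than
# one GPU is present; otherwise load normally and move the model manually.
# device_map="auto" requires the accelerate package to be installed.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map=(
        "auto"
        if torch.cuda.is_available() and torch.cuda.device_count() > 1
        else None
    ),
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
    model = model.cuda()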
@@ -629,11 +628,6 @@ class CausalLM(Model):
         ):
             model = model.cuda()
 
-        # if device_map is "auto", it's unclear which device the model is on
-        # therefore, we need to get the device the model is on after loading
-        if device_map is not None:
-            device = next(model.parameters()).device
-
         if tokenizer.pad_token_id is None:
             if model.config.pad_token_id is not None:
                 tokenizer.pad_token_id = model.config.pad_token_id
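The deleted block queried the model's actual device after loading, since device_map="auto" only decides placement at load time. A minimal sketch of that general pattern on a plain module (a stand-in for the loaded model), independent of this codebase:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)  # stand-in for a model loaded with device_map="auto"

# When "auto" placement is used, the entry device is only known after
# loading; reading it off the first parameter is a common way to recover it.
device = next(model.parameters()).device
print(device)  # prints "cpu" on a machine without GPUs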
@@ -639,28 +639,21 @@ class Seq2SeqLM(Model):
             device = torch.device("cpu")
             dtype = torch.float32 if dtype is None else dtype
 
-        device_map = (
-            "auto"
-            if torch.cuda.is_available() and torch.cuda.device_count() > 1
-            else None
-        )
-
         model = AutoModelForSeq2SeqLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
-            device_map=device_map,
+            device_map=(
+                "auto"
+                if torch.cuda.is_available() and torch.cuda.device_count() > 1
+                else None
+            ),
             load_in_8bit=quantize == "bitsandbytes",
             trust_remote_code=trust_remote_code,
         )
         if torch.cuda.is_available() and torch.cuda.device_count() == 1:
             model = model.cuda()
 
-        # if device_map is "auto", it's unclear which device the model is on
-        # therefore, we need to get the device the model is on after loading
-        if device_map is not None:
-            device = next(model.parameters()).device
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id,
             revision=revision,
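The Seq2SeqLM hunk mirrors the CausalLM change for encoder-decoder models. A minimal sketch of the restored loading path, together with the left-truncating tokenizer setup shown in the CausalLM context above; "t5-small" is a hypothetical stand-in for the served model id:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "t5-small"  # hypothetical stand-in for the served model id

model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    device_map=(
        "auto"
        if torch.cuda.is_available() and torch.cuda.device_count() > 1
        else None
    ),
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
    model = model.cuda()

# truncation_side="left" keeps the end of over-long prompts, which is
# usually what a generation server wants.
tokenizer = AutoTokenizer.from_pretrained(model_id, truncation_side="left")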