mirror of https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00

commit a4782da22b
parent 1c5365ce85

    black
@@ -58,7 +58,7 @@ def _insert_split_marker(m: re.Match):
         str - the text with the split token added
     """
     start_token, _, sequence, end_token = m.groups()
-    sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
+    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
     return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"


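The only change in this hunk is Black normalizing the raw f-string prefix from fr to rf; both spellings build the same string. A quick self-contained check, using a stand-in marker value since the real SPLIT_MARKER is defined elsewhere in the file:

    import re

    SPLIT_MARKER = "\u200b"  # stand-in value for illustration only

    # fr"..." and rf"..." are the same raw f-string; Black only reorders the prefix.
    assert fr"{SPLIT_MARKER}\1" == rf"{SPLIT_MARKER}\1"

    # With re.DOTALL, "." also matches newlines, so the marker is inserted
    # before every character of the sequence, line breaks included.
    out = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", "AB\nC", flags=re.DOTALL)
    assert out == f"{SPLIT_MARKER}A{SPLIT_MARKER}B{SPLIT_MARKER}\n{SPLIT_MARKER}C"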
@@ -75,13 +75,14 @@ def escape_custom_split_sequence(text):
     """
     return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)


 # END CREDIT


 class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
         cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
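For context, these two helpers implement GALACTICA's escaping of custom sequence spans (the # END CREDIT marker closes a block credited to upstream GALACTICA code). A minimal runnable sketch, assuming a four-group pattern for CUSTOM_SEQ_RE and the same stand-in SPLIT_MARKER as above; the real definitions live earlier in the file:

    import re

    SPLIT_MARKER = "\u200b"  # stand-in; see note above
    # Assumed shape only: the real pattern must expose four groups to match
    # the (start_token, _, sequence, end_token) unpacking in the hunk above.
    CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES)\])(.*?)(\[END_\2\])", re.DOTALL)

    def _insert_split_marker(m: re.Match) -> str:
        start_token, _, sequence, end_token = m.groups()
        sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
        return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"

    def escape_custom_split_sequence(text: str) -> str:
        # Escape every [START_*]...[END_*] span so it is split character by character.
        return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)

    print(escape_custom_split_sequence("seq: [START_DNA]ACGT[END_DNA]"))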
@@ -149,9 +150,7 @@ class GalacticaSharded(Galactica):

         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

-        config = AutoConfig.from_pretrained(
-            model_name, tp_parallel=True
-        )
+        config = AutoConfig.from_pretrained(model_name, tp_parallel=True)
         tokenizer.pad_token_id = config.pad_token_id

         # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
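The AutoConfig change is cosmetic: Black joins the call because it fits within its 88-character line limit. The trailing comment refers to PyTorch's TF32 switch; for reference, the standard flag it describes (not necessarily the exact line in this file) is:

    import torch

    # TF32 trades a little matmul precision for speed on Ampere+ GPUs;
    # torch.backends.cuda.matmul.allow_tf32 defaults to False since PyTorch 1.12.
    torch.backends.cuda.matmul.allow_tf32 = True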
@@ -192,17 +191,17 @@ class GalacticaSharded(Galactica):

     @staticmethod
     def load_weights(
         model,
         filenames: List[str],
         quantize: bool,
         device: torch.device,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
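This hunk shows load_weights streaming each safetensors shard with safe_open, staging tensors on CPU when quantizing so the int8 conversion happens before anything is moved to the GPU. A sketch of that iteration pattern, with a hypothetical shard filename:

    from safetensors import safe_open

    # Tensors are loaded lazily; get_tensor materializes only the named one.
    with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
        for name in f.keys():
            module_name, param_name = name.rsplit(".", 1)  # e.g. "...q_proj", "weight"
            tensor = f.get_tensor(name)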
@@ -267,9 +266,9 @@ class GalacticaSharded(Galactica):
                         )

                         if (
                             type(module)
                             in [TensorParallelRowLinear, TensorParallelColumnLinear]
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor.transpose(1, 0),
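The hunk ends mid-call. Int8Params is the bitsandbytes wrapper that quantizes a weight to int8 when it is moved to a CUDA device; a sketch of the likely full call, with the keyword arguments assumed rather than read from this diff:

    from bitsandbytes.nn import Int8Params

    # Wrap the transposed fp16 weight for tensor-parallel linear layers;
    # quantization happens on .to(device). Keyword values are assumptions.
    tensor = Int8Params(
        tensor.transpose(1, 0),
        has_fp16_weights=False,
        requires_grad=False,
    ).to(device)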