mirror of
https://github.com/huggingface/text-generation-inference.git
synced 2025-06-19 15:52:08 +00:00
black
This commit is contained in:
parent
1c5365ce85
commit
a4782da22b
@ -58,7 +58,7 @@ def _insert_split_marker(m: re.Match):
|
||||
str - the text with the split token added
|
||||
"""
|
||||
start_token, _, sequence, end_token = m.groups()
|
||||
sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
|
||||
sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
|
||||
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
|
||||
|
||||
|
||||
@ -75,13 +75,14 @@ def escape_custom_split_sequence(text):
|
||||
"""
|
||||
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
|
||||
|
||||
|
||||
# END CREDIT
|
||||
|
||||
|
||||
class GalacticaCausalLMBatch(CausalLMBatch):
|
||||
@classmethod
|
||||
def from_pb(
|
||||
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
||||
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
|
||||
) -> "CausalLMBatch":
|
||||
inputs = []
|
||||
next_token_choosers = []
|
||||
@ -149,9 +150,7 @@ class GalacticaSharded(Galactica):
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_name, tp_parallel=True
|
||||
)
|
||||
config = AutoConfig.from_pretrained(model_name, tp_parallel=True)
|
||||
tokenizer.pad_token_id = config.pad_token_id
|
||||
|
||||
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
|
||||
@ -192,17 +191,17 @@ class GalacticaSharded(Galactica):
|
||||
|
||||
@staticmethod
|
||||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
@ -267,9 +266,9 @@ class GalacticaSharded(Galactica):
|
||||
)
|
||||
|
||||
if (
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
type(module)
|
||||
in [TensorParallelRowLinear, TensorParallelColumnLinear]
|
||||
and param_name == "weight"
|
||||
):
|
||||
tensor = Int8Params(
|
||||
tensor.transpose(1, 0),
|
||||
|
Loading…
Reference in New Issue
Block a user