This commit is contained in:
OlivierDehaene 2022-11-18 17:11:10 +01:00
parent 1c5365ce85
commit a4782da22b

View File

@@ -58,7 +58,7 @@ def _insert_split_marker(m: re.Match):
str - the text with the split token added
"""
start_token, _, sequence, end_token = m.groups()
sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
@@ -75,13 +75,14 @@ def escape_custom_split_sequence(text):
"""
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
# END CREDIT
class GalacticaCausalLMBatch(CausalLMBatch):
@classmethod
def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "CausalLMBatch":
inputs = []
next_token_choosers = []
@@ -149,9 +150,7 @@ class GalacticaSharded(Galactica):
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
config = AutoConfig.from_pretrained(
model_name, tp_parallel=True
)
config = AutoConfig.from_pretrained(model_name, tp_parallel=True)
tokenizer.pad_token_id = config.pad_token_id
# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
@@ -192,17 +191,17 @@ class GalacticaSharded(Galactica):
@staticmethod
def load_weights(
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
model,
filenames: List[str],
quantize: bool,
device: torch.device,
rank: int,
world_size: int,
):
parameters = dict(model.named_parameters())
for file in filenames:
with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys():
module_name, param_name = name.rsplit(".", 1)
@@ -267,9 +266,9 @@ class GalacticaSharded(Galactica):
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor.transpose(1, 0),