OlivierDehaene 2022-11-18 17:11:10 +01:00
parent 1c5365ce85
commit a4782da22b


@@ -58,7 +58,7 @@ def _insert_split_marker(m: re.Match):
     str - the text with the split token added
     """
     start_token, _, sequence, end_token = m.groups()
-    sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
+    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
     return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
@@ -75,13 +75,14 @@ def escape_custom_split_sequence(text):
""" """
return CUSTOM_SEQ_RE.sub(_insert_split_marker, text) return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
# END CREDIT # END CREDIT
class GalacticaCausalLMBatch(CausalLMBatch): class GalacticaCausalLMBatch(CausalLMBatch):
@classmethod @classmethod
def from_pb( def from_pb(
cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
) -> "CausalLMBatch": ) -> "CausalLMBatch":
inputs = [] inputs = []
next_token_choosers = [] next_token_choosers = []
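The rest of from_pb is not shown in this hunk. A hedged sketch of how the inputs list is presumably filled, with the request field name assumed rather than taken from this diff: the Galactica batch escapes custom [START_*]...[END_*] sequences in each input before tokenization.

# Assumed shape of the loop that follows (not part of this diff):
for r in pb.requests:
    # Keep special sequences intact for the tokenizer by inserting split markers.
    inputs.append(escape_custom_split_sequence(r.inputs))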
@@ -149,9 +150,7 @@ class GalacticaSharded(Galactica):
         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

-        config = AutoConfig.from_pretrained(
-            model_name, tp_parallel=True
-        )
+        config = AutoConfig.from_pretrained(model_name, tp_parallel=True)
         tokenizer.pad_token_id = config.pad_token_id

         # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
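The change here only collapses the AutoConfig call onto one line. As a rough standalone sketch of the same loading pattern (the checkpoint name is an assumption, not taken from this diff):

from transformers import AutoConfig, AutoTokenizer

model_name = "facebook/galactica-125m"  # placeholder checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# tp_parallel=True is carried over from the call in the diff.
config = AutoConfig.from_pretrained(model_name, tp_parallel=True)

# Copy the pad token id from the model config so left-padded batching has a valid pad token.
tokenizer.pad_token_id = config.pad_token_id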
@@ -192,17 +191,17 @@ class GalacticaSharded(Galactica):
     @staticmethod
     def load_weights(
         model,
         filenames: List[str],
         quantize: bool,
         device: torch.device,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
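For readers unfamiliar with the loop above: safetensors exposes a lazy, per-tensor reader, and the device argument decides where tensors are materialized (CPU first when quantizing). A reduced sketch of the same pattern; the file path and function name are illustrative only:

import torch
from safetensors import safe_open

def iter_weights(filename: str, device: torch.device, quantize: bool = False):
    # Quantized loading materializes tensors on CPU; otherwise load directly
    # onto the target device, mirroring the loop above.
    with safe_open(
        filename, framework="pt", device=str(device) if not quantize else "cpu"
    ) as f:
        for name in f.keys():
            module_name, param_name = name.rsplit(".", 1)
            yield module_name, param_name, f.get_tensor(name)

# Example usage (path is hypothetical):
# for module_name, param_name, tensor in iter_weights("model.safetensors", torch.device("cpu")):
#     print(f"{module_name}.{param_name}", tuple(tensor.shape))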
@@ -267,9 +266,9 @@ class GalacticaSharded(Galactica):
                             )

                         if (
                             type(module)
                             in [TensorParallelRowLinear, TensorParallelColumnLinear]
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor.transpose(1, 0),