diff --git a/server/text_generation/models/galactica.py b/server/text_generation/models/galactica.py
index 07163121..81aac649 100644
--- a/server/text_generation/models/galactica.py
+++ b/server/text_generation/models/galactica.py
@@ -58,7 +58,7 @@ def _insert_split_marker(m: re.Match):
         str - the text with the split token added
     """
     start_token, _, sequence, end_token = m.groups()
-    sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
+    sequence = re.sub(r"(.)", rf"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
     return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"
 
 
@@ -75,13 +75,14 @@ def escape_custom_split_sequence(text):
     """
     return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)
 
+
 # END CREDIT
 
 
 class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
-            cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "CausalLMBatch":
         inputs = []
         next_token_choosers = []
@@ -149,9 +150,7 @@ class GalacticaSharded(Galactica):
 
         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
 
-        config = AutoConfig.from_pretrained(
-            model_name, tp_parallel=True
-        )
+        config = AutoConfig.from_pretrained(model_name, tp_parallel=True)
         tokenizer.pad_token_id = config.pad_token_id
 
         # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
@@ -192,17 +191,17 @@ class GalacticaSharded(Galactica):
 
     @staticmethod
     def load_weights(
-            model,
-            filenames: List[str],
-            quantize: bool,
-            device: torch.device,
-            rank: int,
-            world_size: int,
+        model,
+        filenames: List[str],
+        quantize: bool,
+        device: torch.device,
+        rank: int,
+        world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                    file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
@@ -267,9 +266,9 @@ class GalacticaSharded(Galactica):
                             )
 
                         if (
-                                type(module)
-                                in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                                and param_name == "weight"
+                            type(module)
+                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                            and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor.transpose(1, 0),