From 26fc232afb6fe99c5bf184eb4926d02ecdf22605 Mon Sep 17 00:00:00 2001
From: OlivierDehaene <23298448+OlivierDehaene@users.noreply.github.com>
Date: Thu, 6 Apr 2023 17:27:32 +0200
Subject: [PATCH] fix tp

---
 .../custom_modeling/flash_llama_modeling.py |  8 +++++
 .../custom_modeling/flash_neox_modeling.py  | 11 ++++---
 .../models/flash_llama.py                   | 13 +++-----
 .../models/flash_neox.py                    |  5 +--
 server/text_generation_server/server.py     | 33 ++++++++++++++-----
 5 files changed, 47 insertions(+), 23 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 9c4a487b..0eb260c9 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -575,6 +575,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
     def __init__(self, config, process_group=None):
         super().__init__()
 
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+            self.rank = self.process_group.rank()
+        else:
+            self.world_size = 1
+            self.rank = 0
+
         self.model = FlashLlamaModel(config, process_group)
 
         if self.model.tp_embeddings:
diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 5fb453b6..efbfa70b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -624,13 +624,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 
 
 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, process_group=None):
         super().__init__(config)
 
-        if config.tp_parallel:
-            process_group = torch.distributed.distributed_c10d._get_default_group()
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+            self.rank = self.process_group.rank()
         else:
-            process_group = None
+            self.world_size = 1
+            self.rank = 0
 
         self.gpt_neox = FlashGPTNeoXModel(config, process_group)
 
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index 962eaa52..85039cdf 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -43,7 +43,8 @@ class FlashLlama(FlashCausalLM):
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         # We do not use from_pretrained as we modified the model internal module layout
@@ -57,12 +58,7 @@
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config)
 
-        self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype
-        )
+        self.load_weights(model, filenames, device, dtype)
         self.model = model.eval()
 
         super(FlashCausalLM, self).__init__(
@@ -163,7 +159,8 @@ class FlashLlamaSharded(FlashLlama):
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         torch.distributed.barrier(group=self.process_group)
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index ed09eb65..ecf68442 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -49,14 +49,15 @@ class FlashNeoXSharded(FlashNeoX):
         )
 
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
 
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
 
         with init_empty_weights():
-            model = FlashGPTNeoXForCausalLM(config)
+            model = FlashGPTNeoXForCausalLM(config, self.process_group)
 
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index 3e3789bf..929481cf 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -39,11 +39,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()
 
     async def Prefill(self, request, context):
-        batch = self.model.batch_type.from_pb(
-            request.batch, self.model.tokenizer, self.model.device
-        )
+        from torch.profiler import profile, ProfilerActivity
+
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
+        ) as prefill_prof:
+            batch = self.model.batch_type.from_pb(
+                request.batch, self.model.tokenizer, self.model.device
+            )
+
+            generations, next_batch = self.model.generate_token(batch)
+        prefill_prof.export_chrome_trace("prefill.json")
 
-        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
         return generate_pb2.PrefillResponse(
@@ -62,12 +69,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
             batches.append(batch)
 
-        if len(batches) > 1:
-            batch = self.model.batch_type.concatenate(batches)
-        else:
-            batch = batches[0]
+        from torch.profiler import profile, ProfilerActivity
+
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
+        ) as decode_prof:
+
+            if len(batches) > 1:
+                batch = self.model.batch_type.concatenate(batches)
+            else:
+                batch = batches[0]
+
+            generations, next_batch = self.model.generate_token(batch)
+        decode_prof.export_chrome_trace("decode.json")
 
-        generations, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 
         return generate_pb2.DecodeResponse(
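
Note (illustrative, not part of the patch): the fix moves tensor-parallel wiring out of `config.tp_parallel`; the caller now creates the `torch.distributed` process group and passes it to the model constructor, which derives `rank`/`world_size` from it, or falls back to `world_size = 1`, `rank = 0` when no group is given. The sketch below shows one way a caller might exercise the new signature; the NCCL/env-var bootstrap and the `EleutherAI/gpt-neox-20b` model id are assumptions for illustration, not taken from the patch.

# Illustrative sketch only -- not part of the patch. It drives the new
# constructor signature with an explicitly created process group.
import os

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig

from text_generation_server.models.custom_modeling.flash_neox_modeling import (
    FlashGPTNeoXForCausalLM,
)

# Assumption: one process per GPU, with RANK / WORLD_SIZE / MASTER_ADDR /
# MASTER_PORT provided by the launcher.
torch.distributed.init_process_group(
    backend="nccl",
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)
process_group = torch.distributed.group.WORLD

# Model id is an arbitrary example; the patch does not pin one.
config = AutoConfig.from_pretrained("EleutherAI/gpt-neox-20b")

with init_empty_weights():
    # Sharded path: the model reads rank/world_size from the group it is given.
    sharded = FlashGPTNeoXForCausalLM(config, process_group)
    # Single-device path: no group, so the model falls back to world_size=1, rank=0.
    local = FlashGPTNeoXForCausalLM(config)

assert sharded.world_size == process_group.size()
assert sharded.rank == process_group.rank()
assert local.world_size == 1 and local.rank == 0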