OlivierDehaene 2023-04-06 17:27:32 +02:00
parent 7816a47697
commit 26fc232afb
5 changed files with 47 additions and 23 deletions

View File

@@ -575,6 +575,14 @@ class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, config, process_group=None):
        super().__init__()
        self.process_group = process_group
        if self.process_group is not None:
            self.world_size = self.process_group.size()
            self.rank = self.process_group.rank()
        else:
            self.world_size = 1
            self.rank = 0
        self.model = FlashLlamaModel(config, process_group)
        if self.model.tp_embeddings:
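The constructor now takes an optional torch.distributed process group and derives world_size and rank from it, falling back to a single-rank configuration when none is given. A minimal, self-contained sketch of this pattern (MyShardedModule is illustrative and not part of the repository):

import torch


class MyShardedModule(torch.nn.Module):
    """Illustrative module using the same optional-process-group pattern."""

    def __init__(self, hidden_size, process_group=None):
        super().__init__()
        self.process_group = process_group
        if self.process_group is not None:
            # Tensor-parallel execution: this instance is one rank out of world_size.
            self.world_size = self.process_group.size()
            self.rank = self.process_group.rank()
        else:
            # Single-device execution behaves like a world of size 1.
            self.world_size = 1
            self.rank = 0
        # Each rank materializes only its slice of the output dimension.
        self.proj = torch.nn.Linear(hidden_size, hidden_size // self.world_size)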

View File

@@ -624,13 +624,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
 class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config, process_group=None):
         super().__init__(config)
-        if config.tp_parallel:
-            process_group = torch.distributed.distributed_c10d._get_default_group()
+        self.process_group = process_group
+        if self.process_group is not None:
+            self.world_size = self.process_group.size()
+            self.rank = self.process_group.rank()
         else:
-            process_group = None
+            self.world_size = 1
+            self.rank = 0
         self.gpt_neox = FlashGPTNeoXModel(config, process_group)
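Previously the model looked up the default distributed group itself when config.tp_parallel was set; now the caller constructs the group and passes it in explicitly. A sketch of what the calling side might look like under this contract (build_model is illustrative; pass e.g. model_cls=FlashGPTNeoXForCausalLM after importing it from the server package):

import torch


def build_model(model_cls, config, sharded: bool = False):
    # Illustrative caller: it owns the distributed setup and hands the group down.
    if sharded:
        # Assumes the launcher already called torch.distributed.init_process_group
        # (MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE set in the environment).
        process_group = torch.distributed.distributed_c10d._get_default_group()
    else:
        process_group = None
    return model_cls(config, process_group)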

View File

@@ -43,7 +43,8 @@ class FlashLlama(FlashCausalLM):
         )
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
         # We do not use from_pretrained as we modified the model internal module layout
@@ -57,12 +58,7 @@ class FlashLlama(FlashCausalLM):
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config)
-        self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype
-        )
+        self.load_weights(model, filenames, device, dtype)
         self.model = model.eval()
         super(FlashCausalLM, self).__init__(
@@ -163,7 +159,8 @@ class FlashLlamaSharded(FlashLlama):
         )
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
         torch.distributed.barrier(group=self.process_group)
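The config is now loaded without the tp_parallel flag, and the model is still instantiated under accelerate's init_empty_weights so that no weight memory is allocated before load_weights copies tensors in from the safetensors files. A minimal sketch of that empty-init pattern, with an illustrative TinyModel standing in for FlashLlamaForCausalLM:

import torch
from accelerate import init_empty_weights


class TinyModel(torch.nn.Module):
    # Illustrative stand-in for FlashLlamaForCausalLM.
    def __init__(self, hidden_size=1024):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)


# Parameters created inside this context live on the "meta" device, so nothing
# real is allocated until load_weights fills the module from safetensors shards.
with init_empty_weights():
    model = TinyModel()

print(model.proj.weight.device)  # meta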

View File

@@ -49,14 +49,15 @@ class FlashNeoXSharded(FlashNeoX):
         )
         config = AutoConfig.from_pretrained(
-            model_id, revision=revision, tp_parallel=True
+            model_id,
+            revision=revision,
         )
         torch.distributed.barrier(group=self.process_group)
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         with init_empty_weights():
-            model = FlashGPTNeoXForCausalLM(config)
+            model = FlashGPTNeoXForCausalLM(config, self.process_group)
         torch.distributed.barrier(group=self.process_group)
         self.load_weights(
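FlashGPTNeoXForCausalLM now receives self.process_group at construction time, so rank and world_size are known before load_weights runs and each rank can keep only its slice of every sharded tensor. A rough sketch of that slicing idea (shard_tensor is illustrative, not the repository's load_weights):

import torch


def shard_tensor(full_weight: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    # Keep only this rank's contiguous block along the given dimension.
    size = full_weight.shape[dim]
    assert size % world_size == 0, "weight must divide evenly across ranks"
    block = size // world_size
    return full_weight.narrow(dim, rank * block, block).contiguous()


# Example: an 8x4 weight split across 2 ranks along dim 0 -> each rank keeps 4x4.
weight = torch.randn(8, 4)
print(shard_tensor(weight, rank=0, world_size=2).shape)  # torch.Size([4, 4])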

View File

@@ -39,11 +39,18 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
        return generate_pb2.ClearCacheResponse()
    async def Prefill(self, request, context):
        from torch.profiler import profile, ProfilerActivity
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
        ) as prefill_prof:
            batch = self.model.batch_type.from_pb(
                request.batch, self.model.tokenizer, self.model.device
            )
            generations, next_batch = self.model.generate_token(batch)
        prefill_prof.export_chrome_trace("prefill.json")
        self.cache.set(next_batch)
        return generate_pb2.PrefillResponse(
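Prefill is now wrapped in a torch.profiler context that records CPU and CUDA activity and writes a Chrome trace after every call. A standalone sketch of the same wrapper, with a matmul standing in for generate_token:

import torch
from torch.profiler import profile, ProfilerActivity

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    # The diff always records CUDA; guarded here so the sketch also runs on CPU.
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities) as prof:
    x = torch.randn(512, 512)
    y = x @ x  # stand-in for the real prefill forward pass

# export_chrome_trace must run after the profiling context has exited.
prof.export_chrome_trace("prefill.json")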
@@ -62,12 +69,20 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
            batches.append(batch)
        from torch.profiler import profile, ProfilerActivity
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]
        ) as decode_prof:
            if len(batches) > 1:
                batch = self.model.batch_type.concatenate(batches)
            else:
                batch = batches[0]
            generations, next_batch = self.model.generate_token(batch)
        decode_prof.export_chrome_trace("decode.json")
        self.cache.set(next_batch)
        return generate_pb2.DecodeResponse(
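Decode gets the same wrapper, writing decode.json. Both files are in the Chrome trace format and can be inspected in chrome://tracing or the Perfetto UI; note that the file names are fixed, so every Prefill or Decode call overwrites the previous trace.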