From 3a79fbc63eb674e4e132a973237a19a7acc10de3 Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 6 Dec 2023 16:41:04 +0000
Subject: [PATCH] Updated.

---
 load_tests/common.js                           | 21 ++++++++++++++-------
 load_tests/tgi.js                              |  4 ++--
 .../models/flash_causal_lm.py                  | 10 ++++++++++
 server/text_generation_server/utils/medusa.py |  2 +-
 4 files changed, 27 insertions(+), 10 deletions(-)

diff --git a/load_tests/common.js b/load_tests/common.js
index be812e9b..5d71abea 100644
--- a/load_tests/common.js
+++ b/load_tests/common.js
@@ -7,7 +7,9 @@ const seed = 0;
 
 const host = __ENV.HOST || '127.0.0.1:8000';
 const timePerToken = new Trend('time_per_token', true);
-const throughput = new Counter('tokens_per_s');
+const tokens = new Counter('tokens');
+const new_tokens = new Counter('new_tokens');
+const input_tokens = new Counter('input_tokens');
 
 randomSeed(seed);
 // const shareGPT = JSON.parse(open("ShareGPT_V3_unfiltered_cleaned_split.json"))
@@ -19,7 +21,7 @@ export function get_options(reference_latency_ms){
     thresholds: {
       http_req_failed: ['rate==0'],
       time_per_token: [{
-        threshold: `p(50)<${3 * reference_latency_ms}`,
+        threshold: `p(50)<${5 * reference_latency_ms}`,
         abortOnFail: true,
         delayAbortEval: '10s'
       }],
@@ -28,7 +30,7 @@ export function get_options(reference_latency_ms){
       load_test: {
         executor: 'constant-arrival-rate',
         duration: '60s',
-        preAllocatedVUs: 100,
+        preAllocatedVUs: 10,
         rate: 10,
         timeUnit: '1s',
       },
@@ -48,17 +50,22 @@ export function run(host, generate_payload, max_new_tokens) {
         return;
     }
 
+
     check(res, {
         'Post status is 200': (r) => res.status === 200,
     });
-    const n_tokens = max_new_tokens;
-    const timings = res.timings.duration;
+    const duration = res.timings.duration;
 
     if (res.status === 200) {
-        const latency_ms_per_token = timings / n_tokens;
+        const body = res.json();
+        const n_tokens = body.details.tokens.length;
+        const latency_ms_per_token = duration / n_tokens;
         timePerToken.add(latency_ms_per_token);
         const latency_in_s = latency_ms_per_token / 1000;
         const individual_throughput = 1 / latency_in_s;
-        throughput.add(individual_throughput);
+        const _input_tokens = body.details.prefill.length;
+        tokens.add(n_tokens + _input_tokens);
+        input_tokens.add(_input_tokens);
+        new_tokens.add(n_tokens);
     }
 }
diff --git a/load_tests/tgi.js b/load_tests/tgi.js
index 93a0e278..1db4ab6f 100644
--- a/load_tests/tgi.js
+++ b/load_tests/tgi.js
@@ -1,13 +1,13 @@
 import { get_options, run } from "./common.js";
 
-const reference_latency_ms = 30;
+const reference_latency_ms = 70;
 const host = __ENV.HOST || '127.0.0.1:8000';
 const max_new_tokens = 50;
 
 
 function generate_payload(gpt){
     const input = gpt["conversations"][0]["value"];
-    return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "temperature" : 0.5}}
+    return {"inputs": input, "parameters": {"max_new_tokens": max_new_tokens, "decoder_input_details": true}}
 }
 
 export const options = get_options(reference_latency_ms);
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 63e024ac..260f5e68 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -820,10 +820,20 @@ class FlashCausalLM(Model):
         else:
             next_token_logits = out
 
+        # import datetime
+        # from loguru import logger
+        # start = datetime.datetime.now()
         next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
             batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits
         )
+        # took = datetime.datetime.now() - start
+        # logger.info(f"Next token chooser {batch.all_input_ids_tensor.shape} took {took}")
+        # if batch.all_input_ids_tensor.shape[1] < 2000 and took > datetime.timedelta(milliseconds=5):
+        #     next_input_ids, next_token_logprobs, logprobs, accepted_ids, speculative_ids = batch.next_token_chooser(
+        #         batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits, get_speculate(), batch.speculative_ids, speculative_logits, verbose=True
+        #     )
+        # import ipdb;ipdb.set_trace()
 
         batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
             batch.top_n_tokens, batch.top_n_tokens_tensor, logprobs
diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py
index afa9bfc4..029de122 100644
--- a/server/text_generation_server/utils/medusa.py
+++ b/server/text_generation_server/utils/medusa.py
@@ -33,7 +33,7 @@ class MedusaModel(torch.nn.Module):
 
     def forward(self, x):
         logits = self.lm_head(x)
-        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) 
+        speculative_logits = torch.stack([head(x) for head in self.heads], dim=1)
         return logits, speculative_logits
 
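
---
Note (not part of the patch to apply): the reworked k6 metrics in common.js
read token counts out of the response body instead of trusting max_new_tokens,
which is why tgi.js now sends "decoder_input_details": true — that is what
makes details.prefill available alongside details.tokens. A minimal sketch of
the response shape run() now depends on (field names taken from the diff
above; the concrete ids and texts are hypothetical):

    // body = res.json() as consumed by run() in load_tests/common.js
    // {
    //   "generated_text": "...",
    //   "details": {
    //     "prefill": [ {"id": 3533, "text": "Test"}, ... ],    // length -> input_tokens
    //     "tokens":  [ {"id": 427,  "text": " reply"}, ... ]   // length -> new_tokens
    //   }
    // }

To exercise the updated script locally, point it at a running TGI instance
with `k6 run -e HOST=127.0.0.1:8000 load_tests/tgi.js` (k6 exposes `-e`
variables through __ENV, which both scripts read for HOST).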