diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py
index bf193b4f..ee726faf 100644
--- a/server/text_generation_server/models/causal_lm.py
+++ b/server/text_generation_server/models/causal_lm.py
@@ -287,7 +287,7 @@ class CausalLMBatch(Batch):
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens

         if kv_tuple:
-            self.past_key_values = [tuple(layer) for layer in self.past_key_values]
+            self.past_key_values = tuple([tuple(layer) for layer in self.past_key_values])

         self.requests = requests
         self.requests_idx_mapping = requests_idx_mapping
@@ -374,7 +374,7 @@ class CausalLMBatch(Batch):
             # input_ids is always of shape [batch_size, 1]
             # We do not need to pad it
             if input_ids is None:
-                input_ids = batch.input_ids.new_empty((total_batch_size, 1))
+                input_ids = batch.input_ids.new_empty((total_batch_size, max_total_tokens))
             # Copy to correct indices
             input_ids[start_index:end_index] = batch.input_ids

@@ -522,6 +522,9 @@ class CausalLMBatch(Batch):
             else:
                 past_key_values.append([padded_past_keys, padded_past_values])

+        if kv_tuple:
+            past_key_values = tuple(past_key_values)
+
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
@@ -893,11 +896,10 @@ class CausalLM(Model):
                 batch.read_offsets[i] = read_offset
                 batch.max_input_length = max(batch.max_input_length, new_input_length)

-        next_tokens = torch.tensor(next_token_ids, dtype=torch.int64).to(self.device)
         if token_idx is None:
-            batch.input_ids[:, 0] = next_tokens[:, 0]
+            batch.input_ids[:, 0] = next_token_ids[:, 0]
         else:
-            batch.input_ids[:, token_idx] = next_tokens
+            batch.input_ids[:, token_idx] = next_token_ids
         # We finished all generations in the batch; there is no next batch
         if stopped:
             if self.hb_profer_started == True:
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index b7ab751b..61cde5be 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -153,6 +153,8 @@ def serve(
         data_type = torch.bfloat16
     else:
         data_type = torch.float
+    if revision == "None":
+        revision = None
     try:
         model = get_model(model_id, revision=revision, dtype=data_type)
     except Exception:
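
Notes on the changes above. The first and third causal_lm.py hunks normalize the KV cache layout: HF transformers models return past_key_values as a tuple of per-layer (key, value) tuples, while the padding code here builds it as nested lists, so both filter() and concatenate() now convert back to tuples when kv_tuple is set. A minimal sketch of that conversion; the helper name to_hf_kv and the cache shapes are illustrative, not from the patch:

    from typing import List, Tuple

    import torch

    # Tuple of per-layer (key, value) pairs, as transformers expects.
    KVCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]


    def to_hf_kv(past_key_values: List[List[torch.Tensor]]) -> KVCache:
        """Convert a list-of-lists KV cache into the tuple-of-tuples layout."""
        return tuple(tuple(layer) for layer in past_key_values)


    # Example: a fake 2-layer cache of [batch, heads, seq, head_dim] tensors.
    fake_cache = [
        [torch.zeros(1, 4, 8, 16), torch.zeros(1, 4, 8, 16)] for _ in range(2)
    ]
    hf_cache = to_hf_kv(fake_cache)
    assert isinstance(hf_cache, tuple) and isinstance(hf_cache[0], tuple)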
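
The second hunk widens the concatenated input_ids buffer from [total_batch_size, 1] to [total_batch_size, max_total_tokens], and the last causal_lm.py hunk then writes next_token_ids straight into that buffer at column token_idx, dropping the extra torch.tensor(...).to(self.device) copy (next_token_ids is presumably already a device tensor). A rough sketch of the static-shape decode write this enables, with illustrative sizes:

    import torch

    # Pre-allocate the full sequence buffer once, instead of growing a
    # [batch_size, 1] tensor every decode step.
    total_batch_size, max_total_tokens = 4, 16
    input_ids = torch.zeros(total_batch_size, max_total_tokens, dtype=torch.int64)

    # Each decode step writes its tokens at a fixed column, mirroring
    # batch.input_ids[:, token_idx] = next_token_ids in the hunk above.
    token_idx = 5
    next_token_ids = torch.tensor([11, 12, 13, 14])
    input_ids[:, token_idx] = next_token_ids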
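
The server.py hunk guards against the launcher forwarding the literal string "None" as the model revision, which would otherwise reach get_model() as a revision actually named "None". The same normalization as a standalone helper; normalize_revision is a hypothetical name used only for this sketch:

    from typing import Optional


    def normalize_revision(revision: Optional[str]) -> Optional[str]:
        """Map the CLI artifact "None" back to a real None before model loading."""
        if revision == "None":
            return None
        return revision


    assert normalize_revision("None") is None
    assert normalize_revision("main") == "main"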