diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 356ca668..315b6831 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -233,8 +233,9 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
 
-            adapter_indices_list.append(torch.full((input_length,), r.adapter_index))
-            adapter_set.add(r.adapter_index)
+            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+            adapter_indices_list.append(torch.full((input_length,), adapter_index))
+            adapter_set.add(adapter_index)
 
             # Paged attention
             # Remove one as the first token des not have a past
@@ -498,7 +499,10 @@ class FlashCausalLMBatch(Batch):
 
             top_n_tokens.append(self.top_n_tokens[idx])
 
-            adapter_set.add(self.requests[idx].adapter_index)
+            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
+                self.requests[idx].adapter_id, 0
+            )
+            adapter_set.add(adapter_index)
 
             remaining_tokens = (
                 stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
diff --git a/server/text_generation_server/models/globals.py b/server/text_generation_server/models/globals.py
index 11a9f030..ce57fd5c 100644
--- a/server/text_generation_server/models/globals.py
+++ b/server/text_generation_server/models/globals.py
@@ -25,3 +25,14 @@ MODEL_ID = None
 def set_model_id(model_id: str):
     global MODEL_ID
     MODEL_ID = model_id
+
+
+# NOTE: eventually we should move this into the router and pass back the
+# index in all cases.
+global ADAPTER_TO_INDEX
+ADAPTER_TO_INDEX = None
+
+
+def set_adapter_to_index(adapter_to_index: dict):
+    global ADAPTER_TO_INDEX
+    ADAPTER_TO_INDEX = adapter_to_index
diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py
index a343dc02..36524c82 100644
--- a/server/text_generation_server/server.py
+++ b/server/text_generation_server/server.py
@@ -29,7 +29,7 @@ except (ImportError, NotImplementedError):
 
 from text_generation_server.pb import generate_pb2_grpc, generate_pb2
 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
-from text_generation_server.models.globals import set_model_id
+from text_generation_server.models.globals import set_model_id, set_adapter_to_index
 from text_generation_server.utils.adapter import (
     AdapterParameters,
 )
@@ -216,6 +216,7 @@ def serve(
     trust_remote_code: bool = False,
 ):
     unix_socket_template = "unix://{}-{}"
+    adapter_to_index = {}
     if sharded:
         server_urls = [
             unix_socket_template.format(uds_path, rank)
@@ -251,6 +252,7 @@
                     majority_sign_method=0,
                 )
                 adapter_index = index
+                adapter_to_index[adapter_id] = adapter_index
                 model.load_adapter(
                     adapter_parameters,
                     None,  # adapter_source
@@ -263,6 +265,7 @@
             logger.exception("Error when initializing model")
             raise
 
+        set_adapter_to_index(adapter_to_index)
         server = aio.server(
             interceptors=[
                 ExceptionInterceptor(),
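
For context, here is a minimal, self-contained sketch of the id-to-index resolution this patch introduces. The `ADAPTER_TO_INDEX`/`set_adapter_to_index` names mirror `globals.py` above; the module is stubbed inline and the adapter ids are made up, so the snippet runs without a TGI install:

```python
# Stand-in for text_generation_server/models/globals.py (imported as
# tgi_globals in flash_causal_lm.py); stubbed here so the sketch is runnable.
ADAPTER_TO_INDEX = None  # populated once, at server startup


def set_adapter_to_index(adapter_to_index: dict):
    global ADAPTER_TO_INDEX
    ADAPTER_TO_INDEX = adapter_to_index


# serve() assigns indices in load order (from enumerate) and publishes the
# mapping before the gRPC server starts. The ids below are hypothetical.
set_adapter_to_index({"adapter-a": 0, "adapter-b": 1})

# Per request, the batch code resolves the string adapter_id to an index,
# defaulting to 0 for ids that were never registered (e.g. an empty id).
for adapter_id in ("adapter-b", ""):
    adapter_index = ADAPTER_TO_INDEX.get(adapter_id, 0)
    print(repr(adapter_id), "->", adapter_index)  # 'adapter-b' -> 1, '' -> 0
```

The `.get(..., 0)` fallback keeps requests with an unregistered or absent adapter id on index 0; per the NOTE in `globals.py`, the longer-term plan is to resolve ids in the router and pass the index back in all cases.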