diff --git a/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_load.json b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_load.json
new file mode 100644
index 00000000..1d2bec30
--- /dev/null
+++ b/integration-tests/models/__snapshots__/test_fused_kernel_mamba/test_fused_kernel_mamba_load.json
@@ -0,0 +1,338 @@
+[
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 5089,
+          "logprob": null,
+          "text": "Test"
+        },
+        {
+          "id": 2748,
+          "logprob": -9.7265625,
+          "text": " request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 187,
+          "logprob": -2.4746094,
+          "special": false,
+          "text": "\n"
+        },
+        {
+          "id": 187,
+          "logprob": -1.3857422,
+          "special": false,
+          "text": "\n"
+        },
+        {
+          "id": 510,
+          "logprob": -2.703125,
+          "special": false,
+          "text": "The"
+        },
+        {
+          "id": 806,
+          "logprob": -4.1992188,
+          "special": false,
+          "text": " first"
+        },
+        {
+          "id": 2181,
+          "logprob": -2.703125,
+          "special": false,
+          "text": " thing"
+        },
+        {
+          "id": 309,
+          "logprob": -1.4160156,
+          "special": false,
+          "text": " I"
+        },
+        {
+          "id": 8344,
+          "logprob": -1.6171875,
+          "special": false,
+          "text": " noticed"
+        },
+        {
+          "id": 369,
+          "logprob": -1.0039062,
+          "special": false,
+          "text": " was"
+        },
+        {
+          "id": 326,
+          "logprob": -0.8823242,
+          "special": false,
+          "text": " that"
+        },
+        {
+          "id": 253,
+          "logprob": -1.3173828,
+          "special": false,
+          "text": " the"
+        }
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "\n\nThe first thing I noticed was that the"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 5089,
+          "logprob": null,
+          "text": "Test"
+        },
+        {
+          "id": 2748,
+          "logprob": -9.7265625,
+          "text": " request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 187,
+          "logprob": -2.4941406,
+          "special": false,
+          "text": "\n"
+        },
+        {
+          "id": 187,
+          "logprob": -1.3857422,
+          "special": false,
+          "text": "\n"
+        },
+        {
+          "id": 510,
+          "logprob": -2.703125,
+          "special": false,
+          "text": "The"
+        },
+        {
+          "id": 806,
+          "logprob": -4.1992188,
+          "special": false,
+          "text": " first"
+        },
+        {
+          "id": 2181,
+          "logprob": -2.703125,
+          "special": false,
+          "text": " thing"
+        },
+        {
+          "id": 309,
+          "logprob": -1.4160156,
+          "special": false,
+          "text": " I"
+        },
+        {
+          "id": 8344,
+          "logprob": -1.6171875,
+          "special": false,
+          "text": " noticed"
+        },
+        {
+          "id": 369,
+          "logprob": -1.0039062,
+          "special": false,
+          "text": " was"
+        },
+        {
+          "id": 326,
+          "logprob": -0.8823242,
+          "special": false,
+          "text": " that"
+        },
+        {
+          "id": 253,
+          "logprob": -1.3173828,
+          "special": false,
+          "text": " the"
+        }
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "\n\nThe first thing I noticed was that the"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 5089,
+          "logprob": null,
+          "text": "Test"
+        },
+        {
+          "id": 2748,
+          "logprob": -9.7265625,
+          "text": " request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 187,
+          "logprob": -2.4941406,
+          "special": false,
+          "text": "\n"
+        },
+        {
+          "id": 187,
+          "logprob": -1.3857422,
+          "special": false,
+          "text": "\n"
+        },
+        {
+          "id": 510,
+          "logprob": -2.703125,
+          "special": false,
+          "text": "The"
+        },
+        {
+          "id": 806,
+          "logprob": -4.1992188,
+          "special": false,
+          "text": " first"
+        },
+        {
+          "id": 2181,
+          "logprob": -2.703125,
+          "special": false,
+          "text": " thing"
+        },
+        {
+          "id": 309,
+          "logprob": -1.4160156,
+          "special": false,
+          "text": " I"
+        },
+        {
+          "id": 8344,
+          "logprob": -1.6171875,
+          "special": false,
+          "text": " noticed"
+        },
+        {
+          "id": 369,
+          "logprob": -1.0039062,
+          "special": false,
+          "text": " was"
+        },
+        {
+          "id": 326,
+          "logprob": -0.8823242,
+          "special": false,
+          "text": " that"
+        },
+        {
+          "id": 253,
+          "logprob": -1.3173828,
+          "special": false,
+          "text": " the"
+        }
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "\n\nThe first thing I noticed was that the"
+  },
+  {
+    "details": {
+      "best_of_sequences": null,
+      "finish_reason": "length",
+      "generated_tokens": 10,
+      "prefill": [
+        {
+          "id": 5089,
+          "logprob": null,
+          "text": "Test"
+        },
+        {
+          "id": 2748,
+          "logprob": -9.7265625,
+          "text": " request"
+        }
+      ],
+      "seed": null,
+      "tokens": [
+        {
+          "id": 187,
+          "logprob": -2.4941406,
+          "special": false,
+          "text": "\n"
+        },
+        {
+          "id": 187,
+          "logprob": -1.3857422,
+          "special": false,
+          "text": "\n"
+        },
+        {
+          "id": 510,
+          "logprob": -2.703125,
+          "special": false,
+          "text": "The"
+        },
+        {
+          "id": 806,
+          "logprob": -4.1992188,
+          "special": false,
+          "text": " first"
+        },
+        {
+          "id": 2181,
+          "logprob": -2.703125,
+          "special": false,
+          "text": " thing"
+        },
+        {
+          "id": 309,
+          "logprob": -1.4160156,
+          "special": false,
+          "text": " I"
+        },
+        {
+          "id": 8344,
+          "logprob": -1.6171875,
+          "special": false,
+          "text": " noticed"
+        },
+        {
+          "id": 369,
+          "logprob": -1.0039062,
+          "special": false,
+          "text": " was"
+        },
+        {
+          "id": 326,
+          "logprob": -0.8823242,
+          "special": false,
+          "text": " that"
+        },
+        {
+          "id": 253,
+          "logprob": -1.3173828,
+          "special": false,
+          "text": " the"
+        }
+      ],
+      "top_tokens": null
+    },
+    "generated_text": "\n\nThe first thing I noticed was that the"
+  }
+]
diff --git a/integration-tests/models/test_fused_kernel_mamba.py b/integration-tests/models/test_fused_kernel_mamba.py
index 98431298..327113b2 100644
--- a/integration-tests/models/test_fused_kernel_mamba.py
+++ b/integration-tests/models/test_fused_kernel_mamba.py
@@ -47,13 +47,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
     assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in"
     assert response == response_snapshot
 
-# TODO: Fix batching
-# @pytest.mark.asyncio
-# @pytest.mark.private
-# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-#     responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=2)
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
+    responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
 
-#     assert len(responses) == 4
-#     assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
 
-#     assert responses == response_snapshot
+    assert responses == response_snapshot
diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
index 410a254a..567927f6 100644
--- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py
@@ -121,71 +121,36 @@ class MambaBlock(nn.Module):
         conv_out = causal_conv1d_fn(
             x=conv_state_new,
             weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)),
             bias=self.conv1d.bias,
             activation=self.activation)
         conv_state = conv_state_new[:, :, 1:]
-        handle_batched = False
-        if handle_batched:
-            bsz, seqlen, dim = hidden_states.shape
-            # empty output tensor for the loop
-            output_tensor = torch.zeros(
-                (bsz, seqlen, dim),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype
-            )
+        bsz, seqlen, dim = hidden_states.shape
+        # empty output tensor for the loop
+        output_tensor = torch.zeros(
+            (bsz, seqlen, dim),
+            device=hidden_states.device,
+            dtype=hidden_states.dtype
+        )
-            for i in range(0, bsz):
-                x = conv_out[:,:,i]
-                z = _z[:, i, :]
-                x_db = self.x_proj(x)
-                dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
-                dt = self.dt_proj_no_bias(dt)
-                dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
-                dA = torch.exp(dt * self.negA)
-                dB = dt * B.view(-1, B.size(-1))
-                x_shape = (-1, x.size(-1), 1)
-                # ssm_state = (ssm_state * dA + dB * x.view(x_shape))
-                ssm_state[i] = (ssm_state[i] * dA + dB * x)#.view(x_shape))
-                c_shape = (C.size(0), C.size(1), -1)
-                out_mm_shape = (C.size(0), -1)
-                out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
-                # in-place ops
-                out.add_((x * self.D).to(out.dtype))
-                out.mul_(F.silu(z))
-                out = self.out_proj(out)
-                output_tensor[i] = out
+        for i in range(0, bsz):
+            x = conv_out[i:i+1,:,-1]
+            z = _z[i:i+1, -1, :]
+            x_db = self.x_proj(x)
+            dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
+            dt = self.dt_proj_no_bias(dt)
+            dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
+            dA = torch.exp(dt * self.negA)
+            dB = dt * B.view(-1, B.size(-1))
+            x_shape = (-1, x.size(-1), 1)
+            ssm_state[i] = (ssm_state[i] * dA + dB * x.view(x_shape))
+            c_shape = (C.size(0), C.size(1), -1)
+            out_mm_shape = (C.size(0), -1)
+            out = torch.matmul(ssm_state[i].to(C.dtype), C.view(c_shape)).view(out_mm_shape)
+            # in-place ops
+            out.add_((x * self.D).to(out.dtype))
+            out.mul_(F.silu(z))
+            out = self.out_proj(out)
+            output_tensor[i] = out
-            return output_tensor, conv_state, ssm_state
+        return output_tensor, conv_state, ssm_state
-        # TODO: remove this code only left for reference
-        # # only support decoding with 1 token at a time
-        # xz = self.in_proj(hidden_states.view((1, -1)))
-        # x, z = xz.chunk(2, dim=-1)  # (B D)
-        # x = causal_conv1d_update(
-        #     x,
-        #     conv_state,
-        #     self.conv1d.weight.view(self.conv1d.weight.size(0), -1),
-        #     self.conv1d.bias,
-        #     self.activation,
-        # )
-
-        # TODO: prefer using batched logic in all cases
-        # this just pulls the last element of the batch
-        x = conv_out[:,:,-1]
-        z = _z[:, -1, :]
-        x_db = self.x_proj(x)
-        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
-        dt = self.dt_proj_no_bias(dt)
-        dt = F.softplus(dt + self.dt_proj.bias).view((dt.size(1), -1))
-        dA = torch.exp(dt * self.negA)
-        dB = dt * B.view(-1, B.size(-1))
-        x_shape = (-1, x.size(-1), 1)
-        ssm_state = (ssm_state * dA + dB * x.view(x_shape))
-        c_shape = (C.size(0), C.size(1), -1)
-        out_mm_shape = (C.size(0), -1)
-        out = torch.matmul(ssm_state.to(C.dtype), C.view(c_shape)).view(out_mm_shape)
-        # in-place ops
-        out.add_((x * self.D).to(out.dtype))
-        out.mul_(F.silu(z))
-        out = self.out_proj(out)
-        return out.view((1, -1, out.size(-1))), conv_state, ssm_state
 
 
 class ResidualBlock(nn.Module):
diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py
index 6c0568c3..ca3e6e92 100644
--- a/server/text_generation_server/models/mamba.py
+++ b/server/text_generation_server/models/mamba.py
@@ -557,7 +557,7 @@ class Mamba(Model):
                     top_tokens = None
 
                 generation = Generation(
-                    batch.batch_id,
+                    request.id,
                     prefill_tokens,
                     Tokens(
                         [next_token_id_squeezed],
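
Note on the mamba_modeling.py change above: the new per-sequence loop applies one single-token selective-SSM update per request, so each sequence in a batch keeps its own conv and ssm state, which is what lets the re-enabled load test pass with n=4. The sketch below is a minimal, self-contained illustration of that update in plain PyTorch; the function name ssm_decode_step, the tensor shapes, and the toy driver loop are assumptions for illustration, not the model's actual API.

# Minimal sketch (illustrative only): single-token selective-SSM update for one
# sequence, mirroring what the per-sequence loop in the patch computes per request.
import torch
import torch.nn.functional as F


def ssm_decode_step(x, z, h, A, B, C, D, dt):
    # x:  (d_inner,)          post-convolution activation for the new token
    # z:  (d_inner,)          gating branch for the same token
    # h:  (d_inner, d_state)  running SSM state for this sequence
    # A:  (d_inner, d_state)  negative state-transition parameter
    # B:  (d_state,)          input projection computed from this token
    # C:  (d_state,)          output projection computed from this token
    # D:  (d_inner,)          skip connection
    # dt: (d_inner,)          positive step size (already softplus-activated)
    dA = torch.exp(dt[:, None] * A)           # discretized transition
    dB = dt[:, None] * B[None, :]             # discretized input matrix
    h = h * dA + dB * x[:, None]              # state update for one token
    y = (h * C[None, :]).sum(-1) + D * x      # readout plus skip connection
    return F.silu(z) * y, h                   # gated output and new state


# Toy batched decode step: each sequence keeps and updates its own state,
# which is the property the per-request loop restores for batch sizes > 1.
bsz, d_inner, d_state = 2, 8, 4
states = torch.zeros(bsz, d_inner, d_state)
for i in range(bsz):
    x, z = torch.randn(d_inner), torch.randn(d_inner)
    B, C = torch.randn(d_state), torch.randn(d_state)
    A, D = -torch.rand(d_inner, d_state), torch.ones(d_inner)
    dt = F.softplus(torch.randn(d_inner))
    y, states[i] = ssm_decode_step(x, z, states[i], A, B, C, D, dt)

The per-request loop favors correctness over peak throughput; the same update could later be vectorized across the batch dimension.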