| 
									
										
										
										
											2023-05-23 16:16:48 +00:00
										 |  |  | import pytest | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.fixture(scope="module") | 
					
						
							|  |  |  | def t5_sharded_handle(launcher): | 
					
						
							| 
									
										
										
										
											2024-04-12 12:20:31 +00:00
										 |  |  |     with launcher("google/flan-t5-xxl", num_shard=4) as handle: | 
					
						
							| 
									
										
										
										
											2023-05-23 16:16:48 +00:00
										 |  |  |         yield handle | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @pytest.fixture(scope="module") | 
					
						
							|  |  |  | async def t5_sharded(t5_sharded_handle): | 
					
						
							| 
									
										
										
										
											2023-05-31 08:55:59 +00:00
										 |  |  |     await t5_sharded_handle.health(300) | 
					
						
							| 
									
										
										
										
											2023-05-23 16:16:48 +00:00
										 |  |  |     return t5_sharded_handle.client | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-25 14:53:20 +00:00
										 |  |  | @pytest.mark.release | 
					
						
							| 
									
										
										
										
											2023-05-23 16:16:48 +00:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_t5_sharded(t5_sharded, response_snapshot): | 
					
						
							|  |  |  |     response = await t5_sharded.generate( | 
					
						
							|  |  |  |         "Please answer the following question. What is the boiling point of Nitrogen?", | 
					
						
							|  |  |  |         max_new_tokens=10, | 
					
						
							| 
									
										
										
										
											2023-06-02 15:12:30 +00:00
										 |  |  |         decoder_input_details=True, | 
					
						
							| 
									
										
										
										
											2023-05-23 16:16:48 +00:00
										 |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert response == response_snapshot | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-25 14:53:20 +00:00
										 |  |  | @pytest.mark.release | 
					
						
							| 
									
										
										
										
											2023-05-23 16:16:48 +00:00
										 |  |  | @pytest.mark.asyncio | 
					
						
							|  |  |  | async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot): | 
					
						
							|  |  |  |     responses = await generate_load( | 
					
						
							|  |  |  |         t5_sharded, | 
					
						
							|  |  |  |         "Please answer the following question. What is the boiling point of Nitrogen?", | 
					
						
							|  |  |  |         max_new_tokens=10, | 
					
						
							|  |  |  |         n=4, | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert len(responses) == 4 | 
					
						
							|  |  |  |     assert all([r.generated_text == responses[0].generated_text for r in responses]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     assert responses == response_snapshot |