@@ -553,6 +553,63 @@ def test(self, model_w_task, injection_policy, query, inf_kwargs, assert_fn, dty
553
553
assert assert_fn (bs_output , ds_output )
554
554
555
555
556
+ @pytest .mark .seq_inference
557
+ @pytest .mark .parametrize ("model_w_task" , [("meta-llama/Llama-2-7b-hf" , "text-generation" )], ids = ["llama" ])
558
+ @pytest .mark .parametrize ("dtype" , [torch .half ], ids = ["fp16" ])
559
+ class TestLlamaInjection (DistributedTest ):
560
+ world_size = 1
561
+
562
+ def test (self , model_w_task , dtype , query , inf_kwargs , assert_fn ):
563
+ invalid_test_msg = validate_test (model_w_task , dtype , enable_cuda_graph = False , enable_triton = False )
564
+ if invalid_test_msg :
565
+ pytest .skip (invalid_test_msg )
566
+
567
+ if dtype not in get_accelerator ().supported_dtypes ():
568
+ pytest .skip (f"Accelerator { get_accelerator ().device_name ()} does not support { dtype } ." )
569
+
570
+ if not deepspeed .ops .__compatible_ops__ [InferenceBuilder .NAME ]:
571
+ pytest .skip ("This op had not been implemented on this system." , allow_module_level = True )
572
+
573
+ model , task = model_w_task
574
+
575
+ local_rank = int (os .getenv ("LOCAL_RANK" , "0" ))
576
+ device = torch .device (get_accelerator ().device_name (local_rank ))
577
+
578
+ pipe = pipeline (task ,
579
+ model = model ,
580
+ device = torch .device ("cpu" ),
581
+ model_kwargs = {"low_cpu_mem_usage" : True },
582
+ framework = "pt" )
583
+
584
+ if dtype == torch .half :
585
+ pipe .model .half ()
586
+
587
+ pipe .device = device
588
+ pipe .model .to (device )
589
+ bs_output = pipe (query , ** inf_kwargs )
590
+
591
+ try :
592
+ pipe .model = deepspeed .init_inference (
593
+ pipe .model ,
594
+ mp_size = self .world_size ,
595
+ dtype = dtype ,
596
+ replace_with_kernel_inject = True
597
+ )
598
+ check_injection (pipe .model )
599
+ except AttributeError as e :
600
+ if "'LlamaAttention' object has no attribute 'num_heads'" in e :
601
+ pytest .skip ("Skipping due to transformers version compatibility issue with self-attention" )
602
+ raise e
603
+
604
+ ds_output = pipe (query , ** inf_kwargs )
605
+
606
+ print (local_rank , "baseline" , bs_output )
607
+ print (local_rank , "deepspeed" , ds_output )
608
+ # Llama models are not matching baseline exactly
609
+ # We skip the result check for now, since this is irrelevant to this test
610
+ # assert assert_fn(bs_output, ds_output)
611
+
612
+
556
613
@pytest .mark .seq_inference
557
614
@pytest .mark .parametrize ('keep_module_on_host' , [True , False ])
558
615
@pytest .mark .parametrize (
0 commit comments