183183_TEST_SPLIT_PREDEFINED_COLUMN_NAME = "split"
184184_TEST_SPLIT_TIMESTAMP_COLUMN_NAME = "timestamp"
185185
186+ _FORECASTING_JOB_MODEL_TYPES = [
187+ training_jobs .AutoMLForecastingTrainingJob ,
188+ training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
189+ training_jobs .TemporalFusionTransformerForecastingTrainingJob ,
190+ ]
191+
186192
187193@pytest .fixture
188194def mock_pipeline_service_create ():
@@ -293,13 +299,7 @@ def teardown_method(self):
293299 @mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
294300 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
295301 @pytest .mark .parametrize ("sync" , [True , False ])
296- @pytest .mark .parametrize (
297- "training_job" ,
298- [
299- training_jobs .AutoMLForecastingTrainingJob ,
300- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
301- ],
302- )
302+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
303303 def test_run_call_pipeline_service_create (
304304 self ,
305305 mock_pipeline_service_create ,
@@ -401,13 +401,7 @@ def test_run_call_pipeline_service_create(
401401 @mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
402402 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
403403 @pytest .mark .parametrize ("sync" , [True , False ])
404- @pytest .mark .parametrize (
405- "training_job" ,
406- [
407- training_jobs .AutoMLForecastingTrainingJob ,
408- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
409- ],
410- )
404+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
411405 def test_run_call_pipeline_service_create_with_timeout (
412406 self ,
413407 mock_pipeline_service_create ,
@@ -496,13 +490,7 @@ def test_run_call_pipeline_service_create_with_timeout(
496490 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
497491 @pytest .mark .usefixtures ("mock_pipeline_service_get" )
498492 @pytest .mark .parametrize ("sync" , [True , False ])
499- @pytest .mark .parametrize (
500- "training_job" ,
501- [
502- training_jobs .AutoMLForecastingTrainingJob ,
503- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
504- ],
505- )
493+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
506494 def test_run_call_pipeline_if_no_model_display_name_nor_model_labels (
507495 self ,
508496 mock_pipeline_service_create ,
@@ -584,13 +572,7 @@ def test_run_call_pipeline_if_no_model_display_name_nor_model_labels(
584572 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
585573 @pytest .mark .usefixtures ("mock_pipeline_service_get" )
586574 @pytest .mark .parametrize ("sync" , [True , False ])
587- @pytest .mark .parametrize (
588- "training_job" ,
589- [
590- training_jobs .AutoMLForecastingTrainingJob ,
591- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
592- ],
593- )
575+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
594576 def test_run_call_pipeline_if_set_additional_experiments (
595577 self ,
596578 mock_pipeline_service_create ,
@@ -675,13 +657,7 @@ def test_run_call_pipeline_if_set_additional_experiments(
675657 "mock_model_service_get" ,
676658 )
677659 @pytest .mark .parametrize ("sync" , [True , False ])
678- @pytest .mark .parametrize (
679- "training_job" ,
680- [
681- training_jobs .AutoMLForecastingTrainingJob ,
682- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
683- ],
684- )
660+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
685661 def test_run_called_twice_raises (
686662 self ,
687663 mock_dataset_time_series ,
@@ -762,13 +738,7 @@ def test_run_called_twice_raises(
762738 @mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
763739 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
764740 @pytest .mark .parametrize ("sync" , [True , False ])
765- @pytest .mark .parametrize (
766- "training_job" ,
767- [
768- training_jobs .AutoMLForecastingTrainingJob ,
769- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
770- ],
771- )
741+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
772742 def test_run_raises_if_pipeline_fails (
773743 self ,
774744 mock_pipeline_service_create_and_get_with_fail ,
@@ -823,13 +793,7 @@ def test_run_raises_if_pipeline_fails(
823793 with pytest .raises (RuntimeError ):
824794 job .get_model ()
825795
826- @pytest .mark .parametrize (
827- "training_job" ,
828- [
829- training_jobs .AutoMLForecastingTrainingJob ,
830- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
831- ],
832- )
796+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
833797 def test_raises_before_run_is_called (
834798 self ,
835799 mock_pipeline_service_create ,
@@ -855,13 +819,7 @@ def test_raises_before_run_is_called(
855819 @mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
856820 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
857821 @pytest .mark .parametrize ("sync" , [True , False ])
858- @pytest .mark .parametrize (
859- "training_job" ,
860- [
861- training_jobs .AutoMLForecastingTrainingJob ,
862- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
863- ],
864- )
822+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
865823 def test_splits_fraction (
866824 self ,
867825 mock_pipeline_service_create ,
@@ -960,13 +918,7 @@ def test_splits_fraction(
960918 @mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
961919 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
962920 @pytest .mark .parametrize ("sync" , [True , False ])
963- @pytest .mark .parametrize (
964- "training_job" ,
965- [
966- training_jobs .AutoMLForecastingTrainingJob ,
967- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
968- ],
969- )
921+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
970922 def test_splits_timestamp (
971923 self ,
972924 mock_pipeline_service_create ,
@@ -1067,13 +1019,7 @@ def test_splits_timestamp(
10671019 @mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
10681020 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
10691021 @pytest .mark .parametrize ("sync" , [True , False ])
1070- @pytest .mark .parametrize (
1071- "training_job" ,
1072- [
1073- training_jobs .AutoMLForecastingTrainingJob ,
1074- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
1075- ],
1076- )
1022+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
10771023 def test_splits_predefined (
10781024 self ,
10791025 mock_pipeline_service_create ,
@@ -1168,13 +1114,7 @@ def test_splits_predefined(
11681114 @mock .patch .object (training_jobs , "_JOB_WAIT_TIME" , 1 )
11691115 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
11701116 @pytest .mark .parametrize ("sync" , [True , False ])
1171- @pytest .mark .parametrize (
1172- "training_job" ,
1173- [
1174- training_jobs .AutoMLForecastingTrainingJob ,
1175- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
1176- ],
1177- )
1117+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
11781118 def test_splits_default (
11791119 self ,
11801120 mock_pipeline_service_create ,
@@ -1264,13 +1204,7 @@ def test_splits_default(
12641204 @mock .patch .object (training_jobs , "_LOG_WAIT_TIME" , 1 )
12651205 @pytest .mark .usefixtures ("mock_pipeline_service_get" )
12661206 @pytest .mark .parametrize ("sync" , [True , False ])
1267- @pytest .mark .parametrize (
1268- "training_job" ,
1269- [
1270- training_jobs .AutoMLForecastingTrainingJob ,
1271- training_jobs .SequenceToSequencePlusForecastingTrainingJob ,
1272- ],
1273- )
1207+ @pytest .mark .parametrize ("training_job" , _FORECASTING_JOB_MODEL_TYPES )
12741208 def test_run_call_pipeline_if_set_additional_experiments_probabilistic_inference (
12751209 self ,
12761210 mock_pipeline_service_create ,
0 commit comments