@@ -1007,6 +1007,65 @@ def test_run_call_pipeline_service_create_with_tabular_dataset(
10071007
10081008 assert job ._has_logged_custom_job
10091009
1010+ def test_custom_training_tabular_done (
1011+ self ,
1012+ mock_pipeline_service_create ,
1013+ mock_pipeline_service_get ,
1014+ mock_python_package_to_gcs ,
1015+ mock_tabular_dataset ,
1016+ mock_model_service_get ,
1017+ ):
1018+ aiplatform .init (
1019+ project = _TEST_PROJECT ,
1020+ staging_bucket = _TEST_BUCKET_NAME ,
1021+ credentials = _TEST_CREDENTIALS ,
1022+ encryption_spec_key_name = _TEST_DEFAULT_ENCRYPTION_KEY_NAME ,
1023+ )
1024+
1025+ job = training_jobs .CustomTrainingJob (
1026+ display_name = _TEST_DISPLAY_NAME ,
1027+ labels = _TEST_LABELS ,
1028+ script_path = _TEST_LOCAL_SCRIPT_FILE_NAME ,
1029+ container_uri = _TEST_TRAINING_CONTAINER_IMAGE ,
1030+ model_serving_container_image_uri = _TEST_SERVING_CONTAINER_IMAGE ,
1031+ model_serving_container_predict_route = _TEST_SERVING_CONTAINER_PREDICTION_ROUTE ,
1032+ model_serving_container_health_route = _TEST_SERVING_CONTAINER_HEALTH_ROUTE ,
1033+ model_instance_schema_uri = _TEST_MODEL_INSTANCE_SCHEMA_URI ,
1034+ model_parameters_schema_uri = _TEST_MODEL_PARAMETERS_SCHEMA_URI ,
1035+ model_prediction_schema_uri = _TEST_MODEL_PREDICTION_SCHEMA_URI ,
1036+ model_serving_container_command = _TEST_MODEL_SERVING_CONTAINER_COMMAND ,
1037+ model_serving_container_args = _TEST_MODEL_SERVING_CONTAINER_ARGS ,
1038+ model_serving_container_environment_variables = _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES ,
1039+ model_serving_container_ports = _TEST_MODEL_SERVING_CONTAINER_PORTS ,
1040+ model_description = _TEST_MODEL_DESCRIPTION ,
1041+ )
1042+
1043+ job .run (
1044+ dataset = mock_tabular_dataset ,
1045+ base_output_dir = _TEST_BASE_OUTPUT_DIR ,
1046+ service_account = _TEST_SERVICE_ACCOUNT ,
1047+ network = _TEST_NETWORK ,
1048+ args = _TEST_RUN_ARGS ,
1049+ environment_variables = _TEST_ENVIRONMENT_VARIABLES ,
1050+ machine_type = _TEST_MACHINE_TYPE ,
1051+ accelerator_type = _TEST_ACCELERATOR_TYPE ,
1052+ accelerator_count = _TEST_ACCELERATOR_COUNT ,
1053+ model_display_name = _TEST_MODEL_DISPLAY_NAME ,
1054+ model_labels = _TEST_MODEL_LABELS ,
1055+ training_fraction_split = _TEST_TRAINING_FRACTION_SPLIT ,
1056+ validation_fraction_split = _TEST_VALIDATION_FRACTION_SPLIT ,
1057+ test_fraction_split = _TEST_TEST_FRACTION_SPLIT ,
1058+ timestamp_split_column_name = _TEST_TIMESTAMP_SPLIT_COLUMN_NAME ,
1059+ tensorboard = _TEST_TENSORBOARD_RESOURCE_NAME ,
1060+ sync = False ,
1061+ )
1062+
1063+ assert job .done () is False
1064+
1065+ job .wait ()
1066+
1067+ assert job .done () is True
1068+
10101069 @pytest .mark .parametrize ("sync" , [True , False ])
10111070 def test_run_call_pipeline_service_create_with_bigquery_destination (
10121071 self ,
@@ -2323,6 +2382,59 @@ def setup_method(self):
23232382 def teardown_method (self ):
23242383 initializer .global_pool .shutdown (wait = True )
23252384
2385+ def test_custom_container_training_tabular_done (
2386+ self ,
2387+ mock_pipeline_service_create ,
2388+ mock_pipeline_service_get ,
2389+ mock_tabular_dataset ,
2390+ mock_model_service_get ,
2391+ ):
2392+ aiplatform .init (
2393+ project = _TEST_PROJECT ,
2394+ staging_bucket = _TEST_BUCKET_NAME ,
2395+ encryption_spec_key_name = _TEST_DEFAULT_ENCRYPTION_KEY_NAME ,
2396+ )
2397+
2398+ job = training_jobs .CustomContainerTrainingJob (
2399+ display_name = _TEST_DISPLAY_NAME ,
2400+ labels = _TEST_LABELS ,
2401+ container_uri = _TEST_TRAINING_CONTAINER_IMAGE ,
2402+ command = _TEST_TRAINING_CONTAINER_CMD ,
2403+ model_serving_container_image_uri = _TEST_SERVING_CONTAINER_IMAGE ,
2404+ model_serving_container_predict_route = _TEST_SERVING_CONTAINER_PREDICTION_ROUTE ,
2405+ model_serving_container_health_route = _TEST_SERVING_CONTAINER_HEALTH_ROUTE ,
2406+ model_instance_schema_uri = _TEST_MODEL_INSTANCE_SCHEMA_URI ,
2407+ model_parameters_schema_uri = _TEST_MODEL_PARAMETERS_SCHEMA_URI ,
2408+ model_prediction_schema_uri = _TEST_MODEL_PREDICTION_SCHEMA_URI ,
2409+ model_serving_container_command = _TEST_MODEL_SERVING_CONTAINER_COMMAND ,
2410+ model_serving_container_args = _TEST_MODEL_SERVING_CONTAINER_ARGS ,
2411+ model_serving_container_environment_variables = _TEST_MODEL_SERVING_CONTAINER_ENVIRONMENT_VARIABLES ,
2412+ model_serving_container_ports = _TEST_MODEL_SERVING_CONTAINER_PORTS ,
2413+ model_description = _TEST_MODEL_DESCRIPTION ,
2414+ )
2415+
2416+ job .run (
2417+ dataset = mock_tabular_dataset ,
2418+ base_output_dir = _TEST_BASE_OUTPUT_DIR ,
2419+ args = _TEST_RUN_ARGS ,
2420+ environment_variables = _TEST_ENVIRONMENT_VARIABLES ,
2421+ machine_type = _TEST_MACHINE_TYPE ,
2422+ accelerator_type = _TEST_ACCELERATOR_TYPE ,
2423+ accelerator_count = _TEST_ACCELERATOR_COUNT ,
2424+ model_display_name = _TEST_MODEL_DISPLAY_NAME ,
2425+ model_labels = _TEST_MODEL_LABELS ,
2426+ predefined_split_column_name = _TEST_PREDEFINED_SPLIT_COLUMN_NAME ,
2427+ service_account = _TEST_SERVICE_ACCOUNT ,
2428+ tensorboard = _TEST_TENSORBOARD_RESOURCE_NAME ,
2429+ sync = False ,
2430+ )
2431+
2432+ assert job .done () is False
2433+
2434+ job .wait ()
2435+
2436+ assert job .done () is True
2437+
23262438 @pytest .mark .parametrize ("sync" , [True , False ])
23272439 def test_run_call_pipeline_service_create_with_tabular_dataset (
23282440 self ,
0 commit comments