@@ -409,6 +409,73 @@ def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method
409409 metadata = _TEST_REQUEST_METADATA ,
410410 )
411411
412+ @pytest .mark .usefixtures ("get_index_mock" )
413+ @pytest .mark .parametrize ("sync" , [True , False ])
414+ @pytest .mark .parametrize (
415+ "index_update_method" ,
416+ [
417+ _TEST_INDEX_STREAM_UPDATE_METHOD ,
418+ _TEST_INDEX_BATCH_UPDATE_METHOD ,
419+ _TEST_INDEX_EMPTY_UPDATE_METHOD ,
420+ _TEST_INDEX_INVALID_UPDATE_METHOD ,
421+ ],
422+ )
423+ def test_create_tree_ah_index_with_empty_index (
424+ self , create_index_mock , sync , index_update_method
425+ ):
426+ aiplatform .init (project = _TEST_PROJECT )
427+
428+ my_index = aiplatform .MatchingEngineIndex .create_tree_ah_index (
429+ display_name = _TEST_INDEX_DISPLAY_NAME ,
430+ contents_delta_uri = None ,
431+ dimensions = _TEST_INDEX_CONFIG_DIMENSIONS ,
432+ approximate_neighbors_count = _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT ,
433+ distance_measure_type = _TEST_INDEX_DISTANCE_MEASURE_TYPE ,
434+ leaf_node_embedding_count = _TEST_LEAF_NODE_EMBEDDING_COUNT ,
435+ leaf_nodes_to_search_percent = _TEST_LEAF_NODES_TO_SEARCH_PERCENT ,
436+ description = _TEST_INDEX_DESCRIPTION ,
437+ labels = _TEST_LABELS ,
438+ sync = sync ,
439+ index_update_method = index_update_method ,
440+ encryption_spec_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
441+ )
442+
443+ if not sync :
444+ my_index .wait ()
445+
446+ config = {
447+ "treeAhConfig" : {
448+ "leafNodeEmbeddingCount" : _TEST_LEAF_NODE_EMBEDDING_COUNT ,
449+ "leafNodesToSearchPercent" : _TEST_LEAF_NODES_TO_SEARCH_PERCENT ,
450+ }
451+ }
452+
453+ expected = gca_index .Index (
454+ display_name = _TEST_INDEX_DISPLAY_NAME ,
455+ metadata = {
456+ "config" : {
457+ "algorithmConfig" : config ,
458+ "dimensions" : _TEST_INDEX_CONFIG_DIMENSIONS ,
459+ "approximateNeighborsCount" : _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT ,
460+ "distanceMeasureType" : _TEST_INDEX_DISTANCE_MEASURE_TYPE ,
461+ },
462+ },
463+ description = _TEST_INDEX_DESCRIPTION ,
464+ labels = _TEST_LABELS ,
465+ index_update_method = _TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP [
466+ index_update_method
467+ ],
468+ encryption_spec = gca_encryption_spec .EncryptionSpec (
469+ kms_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME
470+ ),
471+ )
472+
473+ create_index_mock .assert_called_once_with (
474+ parent = _TEST_PARENT ,
475+ index = expected ,
476+ metadata = _TEST_REQUEST_METADATA ,
477+ )
478+
412479 @pytest .mark .usefixtures ("get_index_mock" )
413480 def test_create_tree_ah_index_backward_compatibility (self , create_index_mock ):
414481 aiplatform .init (project = _TEST_PROJECT )
@@ -513,6 +580,64 @@ def test_create_brute_force_index(
513580 metadata = _TEST_REQUEST_METADATA ,
514581 )
515582
583+ @pytest .mark .usefixtures ("get_index_mock" )
584+ @pytest .mark .parametrize ("sync" , [True , False ])
585+ @pytest .mark .parametrize (
586+ "index_update_method" ,
587+ [
588+ _TEST_INDEX_STREAM_UPDATE_METHOD ,
589+ _TEST_INDEX_BATCH_UPDATE_METHOD ,
590+ _TEST_INDEX_EMPTY_UPDATE_METHOD ,
591+ _TEST_INDEX_INVALID_UPDATE_METHOD ,
592+ ],
593+ )
594+ def test_create_brute_force_index_with_empty_index (
595+ self , create_index_mock , sync , index_update_method
596+ ):
597+ aiplatform .init (project = _TEST_PROJECT )
598+
599+ my_index = aiplatform .MatchingEngineIndex .create_brute_force_index (
600+ display_name = _TEST_INDEX_DISPLAY_NAME ,
601+ dimensions = _TEST_INDEX_CONFIG_DIMENSIONS ,
602+ distance_measure_type = _TEST_INDEX_DISTANCE_MEASURE_TYPE ,
603+ description = _TEST_INDEX_DESCRIPTION ,
604+ labels = _TEST_LABELS ,
605+ sync = sync ,
606+ index_update_method = index_update_method ,
607+ encryption_spec_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
608+ )
609+
610+ if not sync :
611+ my_index .wait ()
612+
613+ config = {"bruteForceConfig" : {}}
614+
615+ expected = gca_index .Index (
616+ display_name = _TEST_INDEX_DISPLAY_NAME ,
617+ metadata = {
618+ "config" : {
619+ "algorithmConfig" : config ,
620+ "dimensions" : _TEST_INDEX_CONFIG_DIMENSIONS ,
621+ "approximateNeighborsCount" : None ,
622+ "distanceMeasureType" : _TEST_INDEX_DISTANCE_MEASURE_TYPE ,
623+ },
624+ },
625+ description = _TEST_INDEX_DESCRIPTION ,
626+ labels = _TEST_LABELS ,
627+ index_update_method = _TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP [
628+ index_update_method
629+ ],
630+ encryption_spec = gca_encryption_spec .EncryptionSpec (
631+ kms_key_name = _TEST_ENCRYPTION_SPEC_KEY_NAME ,
632+ ),
633+ )
634+
635+ create_index_mock .assert_called_once_with (
636+ parent = _TEST_PARENT ,
637+ index = expected ,
638+ metadata = _TEST_REQUEST_METADATA ,
639+ )
640+
516641 @pytest .mark .usefixtures ("get_index_mock" )
517642 def test_create_brute_force_index_backward_compatibility (self , create_index_mock ):
518643 aiplatform .init (project = _TEST_PROJECT )
0 commit comments