@@ -220,6 +220,9 @@ def __init__(
220220 if self .public_endpoint_domain_name :
221221 self ._public_match_client = self ._instantiate_public_match_client ()
222222
223+ self ._match_grpc_stub_cache = {}
224+ self ._private_service_connect_ip_address = None
225+
223226 @classmethod
224227 def create (
225228 cls ,
@@ -521,40 +524,85 @@ def _instantiate_public_match_client(
521524
522525 def _instantiate_private_match_service_stub (
523526 self ,
524- deployed_index_id : str ,
527+ deployed_index_id : Optional [str ] = None ,
528+ ip_address : Optional [str ] = None ,
525529 ) -> match_service_pb2_grpc .MatchServiceStub :
526530 """Helper method to instantiate private match service stub.
527531 Args:
528532 deployed_index_id (str):
529- Required. The user specified ID of the
530- DeployedIndex.
533+ Optional. Required for private service access endpoint.
534+ The user specified ID of the DeployedIndex.
535+ ip_address (str):
536+ Optional. Required for private service connect. The ip address
537+ the forwarding rule makes use of.
531538 Returns:
532539 stub (match_service_pb2_grpc.MatchServiceStub):
533540 Initialized match service stub.
541+ Raises:
542+ RuntimeError: No deployed index with id deployed_index_id found
543+ ValueError: Should not set ip address for networks other than
544+ private service connect.
534545 """
535- # Find the deployed index by id
536- deployed_indexes = [
537- deployed_index
538- for deployed_index in self .deployed_indexes
539- if deployed_index .id == deployed_index_id
540- ]
546+ if ip_address :
547+ # Should only set for Private Service Connect
548+ if self .public_endpoint_domain_name :
549+ raise ValueError (
550+ "MatchingEngineIndexEndpoint is set to use " ,
551+ "public network. Could not establish connection using "
552+ "provided ip address" ,
553+ )
554+ elif self .private_service_access_network :
555+ raise ValueError (
556+ "MatchingEngineIndexEndpoint is set to use " ,
557+ "private service access network. Could not establish "
558+ "connection using provided ip address" ,
559+ )
560+ else :
561+ # Private Service Access, find server ip for deployed index
562+ deployed_indexes = [
563+ deployed_index
564+ for deployed_index in self .deployed_indexes
565+ if deployed_index .id == deployed_index_id
566+ ]
541567
542- if not deployed_indexes :
543- raise RuntimeError (f"No deployed index with id '{ deployed_index_id } ' found" )
568+ if not deployed_indexes :
569+ raise RuntimeError (
570+ f"No deployed index with id '{ deployed_index_id } ' found"
571+ )
544572
545- # Retrieve server ip from deployed index
546- server_ip = deployed_indexes [0 ].private_endpoints .match_grpc_address
573+ # Retrieve server ip from deployed index
574+ ip_address = deployed_indexes [0 ].private_endpoints .match_grpc_address
547575
548- # Set up channel and stub
549- channel = grpc .insecure_channel ("{}:10000" .format (server_ip ))
550- return match_service_pb2_grpc .MatchServiceStub (channel )
576+ if ip_address not in self ._match_grpc_stub_cache :
577+ # Set up channel and stub
578+ channel = grpc .insecure_channel ("{}:10000" .format (ip_address ))
579+ self ._match_grpc_stub_cache [
580+ ip_address
581+ ] = match_service_pb2_grpc .MatchServiceStub (channel )
582+ return self ._match_grpc_stub_cache [ip_address ]
551583
552584 @property
553585 def public_endpoint_domain_name (self ) -> Optional [str ]:
554586 """Public endpoint DNS name."""
555587 self ._assert_gca_resource_is_available ()
556588 return self ._gca_resource .public_endpoint_domain_name
557589
590+ @property
591+ def private_service_access_network (self ) -> Optional [str ]:
592+ """ "Private service access network."""
593+ self ._assert_gca_resource_is_available ()
594+ return self ._gca_resource .network
595+
596+ @property
597+ def private_service_connect_ip_address (self ) -> Optional [str ]:
598+ """ "Private service connect ip address."""
599+ return self ._private_service_connect_ip_address
600+
601+ @private_service_connect_ip_address .setter
602+ def private_service_connect_ip_address (self , ip_address : str ) -> Optional [str ]:
603+ """ "Setter for private service connect ip address."""
604+ self ._private_service_connect_ip_address = ip_address
605+
558606 def update (
559607 self ,
560608 display_name : str ,
@@ -1300,7 +1348,8 @@ def read_index_datapoints(
13001348 if not self ._public_match_client :
13011349 # Call private match service stub with BatchGetEmbeddings request
13021350 embeddings = self ._batch_get_embeddings (
1303- deployed_index_id = deployed_index_id , ids = ids
1351+ deployed_index_id = deployed_index_id ,
1352+ ids = ids ,
13041353 )
13051354
13061355 response = []
@@ -1362,7 +1411,8 @@ def _batch_get_embeddings(
13621411 List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
13631412 """
13641413 stub = self ._instantiate_private_match_service_stub (
1365- deployed_index_id = deployed_index_id
1414+ deployed_index_id = deployed_index_id ,
1415+ ip_address = self ._private_service_connect_ip_address ,
13661416 )
13671417
13681418 # Create the batch get embeddings request
@@ -1420,7 +1470,8 @@ def match(
14201470 List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
14211471 """
14221472 stub = self ._instantiate_private_match_service_stub (
1423- deployed_index_id = deployed_index_id
1473+ deployed_index_id = deployed_index_id ,
1474+ ip_address = self ._private_service_connect_ip_address ,
14241475 )
14251476
14261477 # Create the batch match request
0 commit comments