Skip to content

Commit 61cff4b

Browse files
lingyinw authored and copybara-github committed
feat: Support private service connect for MatchingEngineIndexEndpoint match() and read_index_datapoints().
PiperOrigin-RevId: 596852286
1 parent 776d0da commit 61cff4b

File tree

2 files changed

+146
-20
lines changed

2 files changed

+146
-20
lines changed

google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,9 @@ def __init__(
220220
if self.public_endpoint_domain_name:
221221
self._public_match_client = self._instantiate_public_match_client()
222222

223+
self._match_grpc_stub_cache = {}
224+
self._private_service_connect_ip_address = None
225+
223226
@classmethod
224227
def create(
225228
cls,
@@ -521,40 +524,85 @@ def _instantiate_public_match_client(
521524

522525
def _instantiate_private_match_service_stub(
    self,
    deployed_index_id: Optional[str] = None,
    ip_address: Optional[str] = None,
) -> match_service_pb2_grpc.MatchServiceStub:
    """Helper method to instantiate private match service stub.

    Stubs are cached per ip address in ``self._match_grpc_stub_cache`` so
    repeated calls for the same address reuse the same stub.

    Args:
        deployed_index_id (str):
            Optional. Required for private service access endpoint.
            The user specified ID of the DeployedIndex.
        ip_address (str):
            Optional. Required for private service connect. The ip address
            the forwarding rule makes use of.
    Returns:
        stub (match_service_pb2_grpc.MatchServiceStub):
            Initialized match service stub.
    Raises:
        RuntimeError: No deployed index with id deployed_index_id found
        ValueError: Should not set ip address for networks other than
            private service connect.
    """
    if ip_address:
        # An explicit ip address is only valid for Private Service Connect.
        # NOTE: the message parts are joined by implicit string
        # concatenation; a comma between them would pass several separate
        # arguments to ValueError instead of a single message.
        if self.public_endpoint_domain_name:
            raise ValueError(
                "MatchingEngineIndexEndpoint is set to use "
                "public network. Could not establish connection using "
                "provided ip address"
            )
        elif self.private_service_access_network:
            raise ValueError(
                "MatchingEngineIndexEndpoint is set to use "
                "private service access network. Could not establish "
                "connection using provided ip address"
            )
    else:
        # Private Service Access: find the server ip for the deployed index.
        deployed_index = next(
            (
                candidate
                for candidate in self.deployed_indexes
                if candidate.id == deployed_index_id
            ),
            None,
        )

        if deployed_index is None:
            raise RuntimeError(
                f"No deployed index with id '{deployed_index_id}' found"
            )

        # Retrieve server ip from deployed index
        ip_address = deployed_index.private_endpoints.match_grpc_address

    if ip_address not in self._match_grpc_stub_cache:
        # Set up channel and stub
        channel = grpc.insecure_channel("{}:10000".format(ip_address))
        self._match_grpc_stub_cache[
            ip_address
        ] = match_service_pb2_grpc.MatchServiceStub(channel)
    return self._match_grpc_stub_cache[ip_address]
551583

552584
@property
def public_endpoint_domain_name(self) -> Optional[str]:
    """Public endpoint DNS name.

    Calls ``_assert_gca_resource_is_available()`` before reading the
    ``public_endpoint_domain_name`` field of the underlying resource.
    """
    self._assert_gca_resource_is_available()
    resource = self._gca_resource
    return resource.public_endpoint_domain_name
557589

590+
@property
def private_service_access_network(self) -> Optional[str]:
    """Private service access network.

    Calls ``_assert_gca_resource_is_available()`` before reading the
    ``network`` field of the underlying resource.
    """
    self._assert_gca_resource_is_available()
    return self._gca_resource.network
595+
596+
@property
def private_service_connect_ip_address(self) -> Optional[str]:
    """Private service connect ip address.

    Returns the value last stored through the setter, as held in
    ``self._private_service_connect_ip_address``.
    """
    return self._private_service_connect_ip_address


@private_service_connect_ip_address.setter
def private_service_connect_ip_address(self, ip_address: str) -> None:
    """Setter for private service connect ip address.

    Args:
        ip_address (str):
            The ip address to store. No validation is performed here.
    """
    self._private_service_connect_ip_address = ip_address
605+
558606
def update(
559607
self,
560608
display_name: str,
@@ -1300,7 +1348,8 @@ def read_index_datapoints(
13001348
if not self._public_match_client:
13011349
# Call private match service stub with BatchGetEmbeddings request
13021350
embeddings = self._batch_get_embeddings(
1303-
deployed_index_id=deployed_index_id, ids=ids
1351+
deployed_index_id=deployed_index_id,
1352+
ids=ids,
13041353
)
13051354

13061355
response = []
@@ -1362,7 +1411,8 @@ def _batch_get_embeddings(
13621411
List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs.
13631412
"""
13641413
stub = self._instantiate_private_match_service_stub(
1365-
deployed_index_id=deployed_index_id
1414+
deployed_index_id=deployed_index_id,
1415+
ip_address=self._private_service_connect_ip_address,
13661416
)
13671417

13681418
# Create the batch get embeddings request
@@ -1420,7 +1470,8 @@ def match(
14201470
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
14211471
"""
14221472
stub = self._instantiate_private_match_service_stub(
1423-
deployed_index_id=deployed_index_id
1473+
deployed_index_id=deployed_index_id,
1474+
ip_address=self._private_service_connect_ip_address,
14241475
)
14251476

14261477
# Create the batch match request

tests/unit/aiplatform/test_matching_engine_index_endpoint.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@
246246
_TEST_RETURN_FULL_DATAPOINT = True
247247
_TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name"
248248
_TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"]
249+
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS = "10.128.0.5"
249250
_TEST_READ_INDEX_DATAPOINTS_RESPONSE = [
250251
gca_index_v1beta1.IndexDatapoint(
251252
datapoint_id="1",
@@ -1137,6 +1138,54 @@ def test_private_index_endpoint_find_neighbor_queries(
11371138
)
11381139
index_endpoint_match_queries_mock.assert_called_with(batch_match_request)
11391140

1141+
@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_private_service_connect_endpoint_match_queries(
    self, index_endpoint_match_queries_mock
):
    """match() with a private service connect ip sends the expected request."""
    aiplatform.init(project=_TEST_PROJECT)

    endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
    )
    # Route the call through the private service connect ip address.
    endpoint.private_service_connect_ip_address = (
        _TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
    )

    endpoint.match(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        queries=_TEST_QUERIES,
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        filter=_TEST_FILTER,
        per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
        approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
    )

    # Build the expected request bottom-up: restrict -> query -> per-index -> batch.
    expected_restrict = match_service_pb2.Namespace(
        name="class",
        allow_tokens=["token_1"],
        deny_tokens=["token_2"],
    )
    expected_query = match_service_pb2.MatchRequest(
        num_neighbors=_TEST_NUM_NEIGHBOURS,
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        float_val=_TEST_QUERIES[0],
        restricts=[expected_restrict],
        per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
        approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
    )
    expected_per_index = (
        match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex(
            deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
            requests=[expected_query],
        )
    )
    expected_batch = match_service_pb2.BatchMatchRequest(
        requests=[expected_per_index]
    )

    index_endpoint_match_queries_mock.assert_called_with(expected_batch)
1188+
11401189
@pytest.mark.usefixtures("get_index_public_endpoint_mock")
11411190
def test_index_public_endpoint_match_queries(
11421191
self, index_public_endpoint_match_queries_mock
@@ -1330,7 +1379,7 @@ def test_index_endpoint_batch_get_embeddings(
13301379
index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)
13311380

13321381
@pytest.mark.usefixtures("get_index_endpoint_mock")
1333-
def test_index_private_endpoint_read_index_datapoints(
1382+
def test_index_endpoint_find_neighbors_for_private_service_access(
13341383
self, index_endpoint_batch_get_embeddings_mock
13351384
):
13361385
aiplatform.init(project=_TEST_PROJECT)
@@ -1350,3 +1399,29 @@ def test_index_private_endpoint_read_index_datapoints(
13501399
index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)
13511400

13521401
assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE
1402+
1403+
@pytest.mark.usefixtures("get_index_endpoint_mock")
def test_index_endpoint_find_neighbors_for_private_service_connect(
    self, index_endpoint_batch_get_embeddings_mock
):
    """read_index_datapoints() works through a private service connect ip.

    NOTE(review): despite the name, this test exercises
    ``read_index_datapoints`` rather than a find-neighbors call.
    """
    aiplatform.init(project=_TEST_PROJECT)

    my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
        index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
    )

    # Must use the real property name; assigning to any other attribute
    # leaves the private service connect ip unset, so the call would not
    # go through the private service connect path at all.
    my_index_endpoint.private_service_connect_ip_address = (
        _TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS
    )
    response = my_index_endpoint.read_index_datapoints(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
        ids=["1", "2"],
    )

    batch_request = match_service_pb2.BatchGetEmbeddingsRequest(
        deployed_index_id=_TEST_DEPLOYED_INDEX_ID, id=["1", "2"]
    )

    index_endpoint_batch_get_embeddings_mock.assert_called_with(batch_request)

    assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE

0 commit comments

Comments
 (0)