5555
5656from vertexai .preview import language_models
5757from google .cloud .aiplatform_v1 import Execution as GapicExecution
58+ from google .cloud .aiplatform .compat .types import (
59+ encryption_spec as gca_encryption_spec ,
60+ )
5861
5962_TEST_PROJECT = "test-project"
6063_TEST_LOCATION = "us-central1"
6164
65+ # CMEK encryption
66+ _TEST_ENCRYPTION_KEY_NAME = "key_1234"
67+ _TEST_ENCRYPTION_SPEC = gca_encryption_spec .EncryptionSpec (
68+ kms_key_name = _TEST_ENCRYPTION_KEY_NAME
69+ )
70+
6271_TEXT_BISON_PUBLISHER_MODEL_DICT = {
6372 "name" : "publishers/google/models/text-bison" ,
6473 "version_id" : "001" ,
166175 "dag" : {"tasks" : {}},
167176 "inputDefinitions" : {
168177 "parameters" : {
169- "project" : {"parameterType" : "STRING" },
170- "location" : {
178+ "api_endpoint" : {
179+ "defaultValue" : "aiplatform.googleapis.com/ui" ,
180+ "isOptional" : True ,
171181 "parameterType" : "STRING" ,
172182 },
173- "large_model_reference" : {
183+ "dataset_name" : {
184+ "defaultValue" : "" ,
185+ "isOptional" : True ,
186+ "parameterType" : "STRING" ,
187+ },
188+ "dataset_uri" : {
189+ "defaultValue" : "" ,
190+ "isOptional" : True ,
174191 "parameterType" : "STRING" ,
175192 },
176- "model_display_name" : {
193+ "encryption_spec_key_name" : {
194+ "defaultValue" : "" ,
195+ "isOptional" : True ,
177196 "parameterType" : "STRING" ,
178197 },
198+ "large_model_reference" : {
199+ "defaultValue" : "text-bison-001" ,
200+ "isOptional" : True ,
201+ "parameterType" : "STRING" ,
202+ },
203+ "learning_rate" : {
204+ "defaultValue" : 3 ,
205+ "isOptional" : True ,
206+ "parameterType" : "NUMBER_DOUBLE" ,
207+ },
208+ "location" : {"parameterType" : "STRING" },
209+ "model_display_name" : {"parameterType" : "STRING" },
210+ "project" : {"parameterType" : "STRING" },
179211 "train_steps" : {
212+ "defaultValue" : 1000 ,
213+ "isOptional" : True ,
180214 "parameterType" : "NUMBER_INTEGER" ,
181215 },
182- "dataset_uri" : {"parameterType" : "STRING" },
183- "dataset_name" : {"parameterType" : "STRING" },
184216 }
185217 },
186218 },
@@ -480,6 +512,7 @@ def test_tune_model(
480512 aiplatform .init (
481513 project = _TEST_PROJECT ,
482514 location = _TEST_LOCATION ,
515+ encryption_spec_key_name = _TEST_ENCRYPTION_KEY_NAME ,
483516 )
484517 with mock .patch .object (
485518 target = model_garden_service_client_v1beta1 .ModelGardenServiceClient ,
@@ -497,6 +530,11 @@ def test_tune_model(
497530 tuning_job_location = "europe-west4" ,
498531 tuned_model_location = "us-central1" ,
499532 )
533+ call_kwargs = mock_pipeline_service_create .call_args [1 ]
534+ assert (
535+ call_kwargs ["pipeline_job" ].encryption_spec .kms_key_name
536+ == _TEST_ENCRYPTION_KEY_NAME
537+ )
500538
501539 @pytest .mark .usefixtures (
502540 "get_model_with_tuned_version_label_mock" ,
@@ -518,7 +556,6 @@ def test_get_tuned_model(
518556 _TEXT_BISON_PUBLISHER_MODEL_DICT
519557 ),
520558 ):
521-
522559 tuned_model = language_models .TextGenerationModel .get_tuned_model (
523560 test_constants .ModelConstants ._TEST_MODEL_RESOURCE_NAME
524561 )
0 commit comments