 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Dict, List, Tuple, Union
+import logging
+import os
+
+from typing import Dict, List, Optional, Tuple, Union

 try:
     from lit_nlp.api import dataset as lit_dataset
+    from lit_nlp.api import dtypes as lit_dtypes
     from lit_nlp.api import model as lit_model
     from lit_nlp.api import types as lit_types
     from lit_nlp import notebook
@@ -82,6 +86,7 @@ def __init__(
         model: str,
         input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
         output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
+        attribution_method: str = "sampled_shapley",
     ):
         """Construct a VertexLitModel.
         Args:
@@ -94,39 +99,33 @@ def __init__(
             output_types:
                 Required. An OrderedDict of string names matching the labels of the model
                 as the key, and the associated LitType of the label.
+            attribution_method:
+                Optional. A string to choose what attribution configuration to
+                set up the explainer with. Valid options are 'sampled_shapley'
+                or 'integrated_gradients'.
         """
-        self._loaded_model = tf.saved_model.load(model)
-        serving_default = self._loaded_model.signatures[
-            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
-        ]
-        _, self._kwargs_signature = serving_default.structured_input_signature
-        self._output_signature = serving_default.structured_outputs
-
-        if len(self._kwargs_signature) != 1:
-            raise ValueError("Please use a model with only one input tensor.")
-
-        if len(self._output_signature) != 1:
-            raise ValueError("Please use a model with only one output tensor.")
-
+        self._load_model(model)
         self._input_types = input_types
         self._output_types = output_types
+        self._input_tensor_name = next(iter(self._kwargs_signature))
+        self._attribution_explainer = None
+        if os.environ.get("LIT_PROXY_URL"):
+            self._set_up_attribution_explainer(model, attribution_method)
+
+    @property
+    def attribution_explainer(self,) -> Optional["AttributionExplainer"]:  # noqa: F821
+        """Gets the attribution explainer property if set."""
+        return self._attribution_explainer

     def predict_minibatch(
         self, inputs: List[lit_types.JsonDict]
     ) -> List[lit_types.JsonDict]:
-        """Returns predictions for a single batch of examples.
-        Args:
-            inputs:
-                sequence of inputs, following model.input_spec()
-        Returns:
-            list of outputs, following model.output_spec()
-        """
         instances = []
         for input in inputs:
             instance = [input[feature] for feature in self._input_types]
             instances.append(instance)
         prediction_input_dict = {
-            next(iter(self._kwargs_signature)): tf.convert_to_tensor(instances)
+            self._input_tensor_name: tf.convert_to_tensor(instances)
         }
         prediction_dict = self._loaded_model.signatures[
             tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
@@ -140,6 +139,15 @@ def predict_minibatch(
                     for label, value in zip(self._output_types.keys(), prediction)
                 }
             )
+        # Get feature attributions
+        if self.attribution_explainer:
+            attributions = self.attribution_explainer.explain(
+                [{self._input_tensor_name: i} for i in instances]
+            )
+            for i, attribution in enumerate(attributions):
+                outputs[i]["feature_attribution"] = lit_dtypes.FeatureSalience(
+                    attribution.feature_importance()
+                )
         return outputs

     def input_spec(self) -> lit_types.Spec:
@@ -148,7 +156,70 @@ def input_spec(self) -> lit_types.Spec:

     def output_spec(self) -> lit_types.Spec:
         """Return a spec describing model outputs."""
-        return self._output_types
+        output_spec_dict = dict(self._output_types)
+        if self.attribution_explainer:
+            output_spec_dict["feature_attribution"] = lit_types.FeatureSalience(
+                signed=True
+            )
+        return output_spec_dict
+
+    def _load_model(self, model: str):
+        """Loads a TensorFlow saved model and populates the input and output signature attributes of the class.
+        Args:
+            model: Required. A string reference to a TensorFlow saved model directory.
+        Raises:
+            ValueError if the model has more than one input tensor or more than one output tensor.
+        """
+        self._loaded_model = tf.saved_model.load(model)
+        serving_default = self._loaded_model.signatures[
+            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+        ]
+        _, self._kwargs_signature = serving_default.structured_input_signature
+        self._output_signature = serving_default.structured_outputs
+
+        if len(self._kwargs_signature) != 1:
+            raise ValueError("Please use a model with only one input tensor.")
+
+        if len(self._output_signature) != 1:
+            raise ValueError("Please use a model with only one output tensor.")
+
+    def _set_up_attribution_explainer(
+        self, model: str, attribution_method: str = "integrated_gradients"
+    ):
+        """Populates the attribution explainer attribute of the class.
+        Args:
+            model: Required. A string reference to a TensorFlow saved model directory.
+            attribution_method:
+                Optional. A string to choose what attribution configuration to
+                set up the explainer with. Valid options are 'sampled_shapley'
+                or 'integrated_gradients'.
+        """
+        try:
+            import explainable_ai_sdk
+            from explainable_ai_sdk.metadata.tf.v2 import SavedModelMetadataBuilder
+        except ImportError:
+            logging.info(
+                "Skipping explanations because the Explainable AI SDK is not installed."
+                'Please install the SDK using "pip install explainable-ai-sdk"'
+            )
+            return
+
+        builder = SavedModelMetadataBuilder(model)
+        builder.get_metadata()
+        builder.set_numeric_metadata(
+            self._input_tensor_name,
+            index_feature_mapping=list(self._input_types.keys()),
+        )
+        builder.save_metadata(model)
+        if attribution_method == "integrated_gradients":
+            explainer_config = explainable_ai_sdk.IntegratedGradientsConfig()
+        else:
+            explainer_config = explainable_ai_sdk.SampledShapleyConfig()
+
+        self._attribution_explainer = explainable_ai_sdk.load_model_from_local_path(
+            model, explainer_config
+        )
+        self._load_model(model)


 def create_lit_dataset(
@@ -172,22 +243,27 @@ def create_lit_model(
     model: str,
     input_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
     output_types: "OrderedDict[str, lit_types.LitType]",  # noqa: F821
+    attribution_method: str = "sampled_shapley",
 ) -> lit_model.Model:
     """Creates a LIT Model object.
     Args:
         model:
-          Required. A string reference to a local TensorFlow saved model directory.
-          The model must have at most one input and one output tensor.
+            Required. A string reference to a local TensorFlow saved model directory.
+            The model must have at most one input and one output tensor.
         input_types:
-          Required. An OrderedDict of string names matching the features of the model
-          as the key, and the associated LitType of the feature.
+            Required. An OrderedDict of string names matching the features of the model
+            as the key, and the associated LitType of the feature.
         output_types:
-          Required. An OrderedDict of string names matching the labels of the model
-          as the key, and the associated LitType of the label.
+            Required. An OrderedDict of string names matching the labels of the model
+            as the key, and the associated LitType of the label.
+        attribution_method:
+            Optional. A string to choose what attribution configuration to
+            set up the explainer with. Valid options are 'sampled_shapley'
+            or 'integrated_gradients'.
     Returns:
         A LIT Model object that has the same functionality as the model provided.
     """
-    return _VertexLitModel(model, input_types, output_types)
+    return _VertexLitModel(model, input_types, output_types, attribution_method)


 def open_lit(
@@ -198,11 +274,11 @@ def open_lit(
198274 """Open LIT from the provided models and datasets.
199275 Args:
200276 models:
201- Required. A list of LIT models to open LIT with.
277+ Required. A list of LIT models to open LIT with.
202278 input_types:
203- Required. A lit of LIT datasets to open LIT with.
279+ Required. A lit of LIT datasets to open LIT with.
204280 open_in_new_tab:
205- Optional. A boolean to choose if LIT open in a new tab or not.
281+ Optional. A boolean to choose if LIT open in a new tab or not.
206282 Raises:
207283 ImportError if LIT is not installed.
208284 """
@@ -216,24 +292,31 @@ def set_up_and_open_lit(
     model: Union[str, lit_model.Model],
     input_types: Union[List[str], Dict[str, lit_types.LitType]],
     output_types: Union[str, List[str], Dict[str, lit_types.LitType]],
+    attribution_method: str = "sampled_shapley",
     open_in_new_tab: bool = True,
 ) -> Tuple[lit_dataset.Dataset, lit_model.Model]:
     """Creates a LIT dataset and model and opens LIT.
     Args:
-      dataset:
+        dataset:
             Required. A Pandas DataFrame that includes feature column names and data.
-      column_types:
+        column_types:
             Required. An OrderedDict of string names matching the columns of the dataset
             as the key, and the associated LitType of the column.
-      model:
+        model:
             Required. A string reference to a TensorFlow saved model directory.
             The model must have at most one input and one output tensor.
-      input_types:
+        input_types:
             Required. An OrderedDict of string names matching the features of the model
             as the key, and the associated LitType of the feature.
-      output_types:
+        output_types:
             Required. An OrderedDict of string names matching the labels of the model
             as the key, and the associated LitType of the label.
+        attribution_method:
+            Optional. A string to choose what attribution configuration to
+            set up the explainer with. Valid options are 'sampled_shapley'
+            or 'integrated_gradients'.
+        open_in_new_tab:
+            Optional. A boolean to choose if LIT open in a new tab or not.
     Returns:
         A Tuple of the LIT dataset and model created.
     Raises:
@@ -244,8 +327,12 @@ def set_up_and_open_lit(
         dataset = create_lit_dataset(dataset, column_types)

     if not isinstance(model, lit_model.Model):
-        model = create_lit_model(model, input_types, output_types)
+        model = create_lit_model(
+            model, input_types, output_types, attribution_method=attribution_method
+        )

-    open_lit({"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab)
+    open_lit(
+        {"model": model}, {"dataset": dataset}, open_in_new_tab=open_in_new_tab,
+    )

     return dataset, model
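
For context, a minimal usage sketch of the API after this change follows. It is not part of the diff: the import path (assumed to be google.cloud.aiplatform.explain.lit, the module this diff touches), the ./saved_model path, and the feature/label names are all hypothetical. Attributions are only set up when the LIT_PROXY_URL environment variable is set and the explainable-ai-sdk package is installed; otherwise the model behaves as before.

import collections

import pandas as pd
from lit_nlp.api import types as lit_types

# Assumed import path for the module modified in this diff.
from google.cloud.aiplatform.explain.lit import set_up_and_open_lit

# Hypothetical tabular data: two numeric features and one numeric label.
df = pd.DataFrame(
    {
        "age": [25.0, 32.0, 47.0],
        "income": [40000.0, 55000.0, 72000.0],
        "score": [0.2, 0.5, 0.9],
    }
)

column_types = collections.OrderedDict(
    [
        ("age", lit_types.Scalar()),
        ("income", lit_types.Scalar()),
        ("score", lit_types.RegressionScore()),
    ]
)
input_types = collections.OrderedDict(
    [("age", lit_types.Scalar()), ("income", lit_types.Scalar())]
)
output_types = collections.OrderedDict([("score", lit_types.RegressionScore())])

# "./saved_model" is a placeholder path to a TF saved model with exactly one
# input and one output tensor. attribution_method picks the Explainable AI
# config ('sampled_shapley' or 'integrated_gradients').
lit_dataset_obj, lit_model_obj = set_up_and_open_lit(
    df,
    column_types,
    "./saved_model",
    input_types,
    output_types,
    attribution_method="integrated_gradients",
    open_in_new_tab=True,
)

When the explainer is active, each prediction dict additionally carries a "feature_attribution" FeatureSalience entry, which LIT renders in its salience views.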