Run ML inference with multiple differently-trained models


Running inference with multiple differently-trained models performing the same task is useful in many scenarios, including the following examples:

  • You want to compare the performance of multiple different models.
  • You have models trained on different datasets that you want to use conditionally based on additional metadata.

In Apache Beam, the recommended way to run inference is to use the RunInference transform. By using a KeyedModelHandler, you can efficiently run inference with models numbering in the hundreds without having to manage model memory yourself.

This notebook demonstrates how to use a KeyedModelHandler to run inference in an Apache Beam pipeline with multiple different models on a per-key basis. This notebook uses pretrained pipelines from Hugging Face. Before continuing with this notebook, it is recommended that you walk through the Use RunInference in Apache Beam notebook.

Install dependencies

Install both Apache Beam and the dependencies needed by Hugging Face.

# Quote the requirement so the shell does not treat ">" as output redirection.
!pip install "apache_beam[gcp]>=2.51.0" --quiet
!pip install torch --quiet
!pip install transformers --quiet

# To use the newly installed versions, restart the runtime.
exit()

from typing import Dict
from typing import Iterable
from typing import Tuple

from transformers import pipeline

import apache_beam as beam
from apache_beam.ml.inference.base import KeyedModelHandler
from apache_beam.ml.inference.base import KeyModelMapping
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler

Define the model configurations

A model handler is the Apache Beam object that defines the configuration needed to load and invoke a model. Because this example uses two models, you define two model handlers, one for each model. Because both models are encapsulated within Hugging Face pipelines, both use HuggingFacePipelineModelHandler.

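To see how the models differ, load each one directly with the Hugging Face pipeline function, and then run both against the same example. The models produce different outputs: distilbert-base-uncased-finetuned-sst-2-english is fine-tuned for binary sentiment classification (POSITIVE or NEGATIVE), while roberta-large-mnli is fine-tuned for natural language inference, so its labels are CONTRADICTION, NEUTRAL, and ENTAILMENT.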

# Model handlers that RunInference uses to load and invoke each model.
distilbert_mh = HuggingFacePipelineModelHandler('text-classification', model="distilbert-base-uncased-finetuned-sst-2-english")
roberta_mh = HuggingFacePipelineModelHandler('text-classification', model="roberta-large-mnli")

# Standalone Hugging Face pipelines for a quick comparison outside of Beam.
distilbert_pipe = pipeline('text-classification', model="distilbert-base-uncased-finetuned-sst-2-english")
roberta_large_pipe = pipeline('text-classification', model="roberta-large-mnli")
distilbert_pipe("This restaurant is awesome")
[{'label': 'POSITIVE', 'score': 0.9998743534088135}]
roberta_large_pipe("This restaurant is awesome")
[{'label': 'NEUTRAL', 'score': 0.7313134670257568}]

Define the examples

Define examples to input into the pipeline. Each example pairs a sentence with its correct sentiment classification.

examples = [
    ("This restaurant is awesome", "positive"),
    ("This restaurant is bad", "negative"),
    ("I feel fine", "neutral"),
    ("I love chocolate", "positive"),
]

To feed the examples into RunInference, each element needs a key that identifies the model that should process it. To make it possible to recover the actual sentiment of each example later, define keys in the form <model_name>-<actual_sentiment>. Each example is emitted once per model.

class FormatExamples(beam.DoFn):
  """
  Map each example to a tuple of ('<model_name>-<actual_sentiment>', 'example').
  These keys route each element to the correct model.
  """
  def process(self, element: Tuple[str, str]) -> Iterable[Tuple[str, str]]:
    yield (f'distilbert-{element[1]}', element[0])
    yield (f'roberta-{element[1]}', element[0])
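To see the fan-out that FormatExamples produces, the following sketch reimplements its process method as a plain Python function and checks that every generated key is one of the six keys covered by the KeyModelMapping objects defined in the next step. The name format_example is hypothetical and not part of the notebook; the sketch runs without Beam or any model downloads.

```python
from typing import Iterable, List, Set, Tuple

# Hypothetical plain-Python mirror of FormatExamples.process, so the key
# scheme can be inspected without running a Beam pipeline.
def format_example(element: Tuple[str, str]) -> Iterable[Tuple[str, str]]:
  text, sentiment = element
  for model_name in ('distilbert', 'roberta'):
    yield (f'{model_name}-{sentiment}', text)

examples = [
    ("This restaurant is awesome", "positive"),
    ("This restaurant is bad", "negative"),
    ("I feel fine", "neutral"),
    ("I love chocolate", "positive"),
]

# Each example fans out into one keyed element per model.
keyed: List[Tuple[str, str]] = [kv for ex in examples for kv in format_example(ex)]

# The six keys routed by the KeyModelMapping objects in this notebook.
mapped_keys: Set[str] = {
    f'{model}-{sentiment}'
    for model in ('distilbert', 'roberta')
    for sentiment in ('positive', 'neutral', 'negative')
}
unmapped = {key for key, _ in keyed} - mapped_keys

print(len(keyed))        # 8: four examples times two models
print(keyed[:2])
print(sorted(unmapped))  # []
```

An empty unmapped set means every element has a model to route to. If you add a new sentiment label to the examples, this check flags the keys that the mappings do not yet cover.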

Use the formatted keys to define a KeyedModelHandler that maps each key to the ModelHandler used for that key. The KeyedModelHandler constructor accepts an optional max_models_per_worker_hint, which limits the number of models held in memory by a single worker process at one time. If your workers might run out of memory, set this option. For more information about managing memory, see Use a keyed ModelHandler.

per_key_mhs = [
    KeyModelMapping(['distilbert-positive', 'distilbert-neutral', 'distilbert-negative'], distilbert_mh),
    KeyModelMapping(['roberta-positive', 'roberta-neutral', 'roberta-negative'], roberta_mh)
]
mh = KeyedModelHandler(per_key_mhs, max_models_per_worker_hint=2)
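Conceptually, max_models_per_worker_hint bounds a per-worker cache of loaded models: when another model is needed and the limit has been reached, a previously loaded model is evicted. The following is a minimal, hypothetical sketch of a least-recently-used (LRU) cache that illustrates the idea. LruModelCache and the third model name bert are illustrative assumptions, the lambdas stand in for real model loading, and this is not Beam's implementation. With two models and a hint of 2, as in this notebook, both models should fit without eviction.

```python
from collections import OrderedDict
from typing import Callable, Dict, List


class LruModelCache:
  """Hypothetical per-worker model cache bounded by a max-models hint.

  Illustration only; not Beam's implementation.
  """

  def __init__(self, loaders: Dict[str, Callable[[], object]], max_models: int):
    self._loaders = loaders      # model id -> function that loads the model
    self._max_models = max_models
    self._loaded = OrderedDict()  # model id -> loaded model, oldest use first
    self.load_log: List[str] = []  # every (re)load, recorded for inspection

  def get(self, model_id: str) -> object:
    if model_id in self._loaded:
      # Cache hit: mark the model as most recently used.
      self._loaded.move_to_end(model_id)
      return self._loaded[model_id]
    if len(self._loaded) >= self._max_models:
      # Cache full: evict the least recently used model to free memory.
      self._loaded.popitem(last=False)
    model = self._loaders[model_id]()
    self.load_log.append(model_id)
    self._loaded[model_id] = model
    return model


cache = LruModelCache(
    {'distilbert': lambda: 'distilbert-model',
     'roberta': lambda: 'roberta-model',
     'bert': lambda: 'bert-model'},  # hypothetical third model
    max_models=2)
for model_id in ['distilbert', 'roberta', 'distilbert', 'bert', 'roberta']:
  cache.get(model_id)
print(cache.load_log)  # ['distilbert', 'roberta', 'bert', 'roberta']
```

The log shows roberta loaded twice: with room for only two models, loading the third model evicts roberta as the least recently used entry, so the next request for it reloads the model. A small hint trades lower memory use for more reloads.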

Postprocess the results

For each input element, the RunInference transform outputs a tuple that contains the following objects:

  • the original key
  • a PredictionResult object containing the original example and the inference

Use those outputs to extract the relevant data. Then, to compare each model's prediction, group this data by the original example.

class ExtractResults(beam.DoFn):
  """
  Extract the relevant data from the PredictionResult object.
  """
  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[Tuple[str, Dict[str, str]]]:
    actual_sentiment = element[0].split('-')[1]
    model = element[0].split('-')[0]
    result = element[1]
    example = result.example
    predicted_sentiment = result.inference[0]['label']

    yield (example, {'model': model, 'actual_sentiment': actual_sentiment, 'predicted_sentiment': predicted_sentiment})
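ExtractResults recovers the model name and actual sentiment by splitting the key on every hyphen. That works here because the short names distilbert and roberta contain no hyphens, but if you keyed elements by full Hugging Face model IDs such as roberta-large-mnli, split('-')[1] would return large. A more defensive variant, sketched below with a hypothetical parse_key helper, splits only on the last hyphen, assuming the sentiment label itself never contains one.

```python
from typing import Tuple


def parse_key(key: str) -> Tuple[str, str]:
  """Split '<model_name>-<actual_sentiment>' on the last hyphen only.

  Assumes the sentiment label contains no hyphen.
  """
  model, _, sentiment = key.rpartition('-')
  return model, sentiment


print(parse_key('distilbert-positive'))          # ('distilbert', 'positive')
print(parse_key('roberta-large-mnli-negative'))  # ('roberta-large-mnli', 'negative')
```

Inside ExtractResults, a line such as `model, actual_sentiment = parse_key(element[0])` could replace the two split calls.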

Finally, print the results produced by each model.

class PrintResults(beam.DoFn):
  """
  Print the results produced by each model along with the actual sentiment.
  """
  def process(self, element: Tuple[str, Iterable[Dict[str, str]]]):
    example, results = element
    # Beam only guarantees that grouped values are iterable, so materialize
    # them as a list before indexing.
    results = list(results)
    actual_sentiment = results[0]['actual_sentiment']
    # Look up each model's prediction by name so the output does not depend
    # on the order in which GroupByKey returns the values.
    predictions = {r['model']: r['predicted_sentiment'] for r in results}

    print(f'Example: {example}\nActual Sentiment: {actual_sentiment}\n'
          f'Distilbert Prediction: {predictions["distilbert"]}\n'
          f'Roberta Prediction: {predictions["roberta"]}\n------------')
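Because one reason to run several models is to compare their performance, you might aggregate the ExtractResults records into a per-model score instead of only printing them. The sketch below uses a hypothetical plain-Python helper, accuracy_per_model (not part of Beam), that compares labels case-insensitively; the records reproduce the predictions from the sample pipeline output in this notebook. Keep in mind that roberta-large-mnli predicts natural language inference labels, so its score against sentiment labels reflects a label mismatch rather than a like-for-like comparison.

```python
from collections import defaultdict
from typing import Dict, Iterable


def accuracy_per_model(records: Iterable[Dict[str, str]]) -> Dict[str, float]:
  """Return, for each model, the fraction of records whose predicted label
  matches the actual sentiment (case-insensitive)."""
  correct: Dict[str, int] = defaultdict(int)
  total: Dict[str, int] = defaultdict(int)
  for record in records:
    model = record['model']
    total[model] += 1
    if record['predicted_sentiment'].lower() == record['actual_sentiment'].lower():
      correct[model] += 1
  return {model: correct[model] / total[model] for model in total}


# (actual sentiment, DistilBERT prediction, RoBERTa prediction), copied from
# the sample output of the pipeline in this notebook.
sample_output = [
    ('positive', 'POSITIVE', 'NEUTRAL'),
    ('negative', 'NEGATIVE', 'NEUTRAL'),
    ('positive', 'POSITIVE', 'NEUTRAL'),
    ('neutral', 'POSITIVE', 'ENTAILMENT'),
]

# Expand into records shaped like the dictionaries ExtractResults yields.
records = []
for actual, distilbert_pred, roberta_pred in sample_output:
  records.append({'model': 'distilbert', 'actual_sentiment': actual,
                  'predicted_sentiment': distilbert_pred})
  records.append({'model': 'roberta', 'actual_sentiment': actual,
                  'predicted_sentiment': roberta_pred})

print(accuracy_per_model(records))  # {'distilbert': 0.75, 'roberta': 0.0}
```

Inside a pipeline, the same aggregation could be expressed as a combiner keyed by model name (for example, with beam.CombinePerKey) so that the counting is distributed rather than collected on one machine.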

Run the pipeline

Combine the previous steps into a single Apache Beam pipeline and run it.

with beam.Pipeline() as beam_pipeline:

  formatted_examples = (
      beam_pipeline
      | "ReadExamples" >> beam.Create(examples)
      | "FormatExamples" >> beam.ParDo(FormatExamples()))
  inferences = (
      formatted_examples
      | "Run Inference" >> RunInference(mh)
      | "ExtractResults" >> beam.ParDo(ExtractResults())
      | "GroupByExample" >> beam.GroupByKey())

  inferences | "PrintResults" >> beam.ParDo(PrintResults())
Example: This restaurant is awesome
Actual Sentiment: positive
Distilbert Prediction: POSITIVE
Roberta Prediction: NEUTRAL
------------
Example: This restaurant is bad
Actual Sentiment: negative
Distilbert Prediction: NEGATIVE
Roberta Prediction: NEUTRAL
------------
Example: I love chocolate
Actual Sentiment: positive
Distilbert Prediction: POSITIVE
Roberta Prediction: NEUTRAL
------------
Example: I feel fine
Actual Sentiment: neutral
Distilbert Prediction: POSITIVE
Roberta Prediction: ENTAILMENT
------------