[Inference API] Add Chat Completion to Amazon Bedrock for the Inference API #139411
…Completions-support' into Add-Amazon-Bedrock-Unified-Chat-Completions-support # Conflicts: # x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/request/completion/ToolAwareUnifiedPublisher.java
…Completions-support' into Add-Amazon-Bedrock-Unified-Chat-Completions-support
Hi @jonathan-buttner, I've created a changelog YAML for you.
inferenceResultsListener
);
chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener);
// Chat completions only supports streaming
This is new
} catch (IOException e) {
listener.onFailure(new RuntimeException(e));
}
throw new UnsupportedOperationException("Unsupported operation, use streaming execution instead");
This is new. This class is only used for chat completion. Chat completion doesn't support non-streaming requests for any provider (not just Bedrock), so we don't need this method.
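To illustrate the streaming-only contract described above, here is a minimal sketch using only `java.util.concurrent.Flow`. The class and method names (`StreamingOnlyExecutorSketch`, `execute`, `executeStream`) are hypothetical stand-ins, not the executor in this PR, and the publisher is deliberately simplistic (synchronous, no validation of the requested demand).

```java
import java.util.List;
import java.util.concurrent.Flow;

// Hypothetical stand-in for a chat-completion-only executor; not the PR's class.
final class StreamingOnlyExecutorSketch {

    // The non-streaming entry point is rejected outright, mirroring the
    // exception message shown in the diff above.
    void execute() {
        throw new UnsupportedOperationException("Unsupported operation, use streaming execution instead");
    }

    // The streaming entry point returns a publisher that emits each chunk on
    // demand and completes once the chunks are exhausted.
    Flow.Publisher<String> executeStream(List<String> chunks) {
        return subscriber -> subscriber.onSubscribe(new Flow.Subscription() {
            private int next = 0;
            private boolean done = false;

            @Override
            public void request(long n) {
                // Emit up to n items, then signal completion when nothing is left.
                while (n-- > 0 && done == false) {
                    if (next < chunks.size()) {
                        subscriber.onNext(chunks.get(next++));
                    } else {
                        done = true;
                        subscriber.onComplete();
                    }
                }
            }

            @Override
            public void cancel() {
                done = true;
            }
        });
    }
}
```

A caller that only has the non-streaming path available gets an immediate, explicit failure instead of a silently unused code path.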
.modelId(amazonBedrockModel.model())
.messages(getConverseMessageList(requestEntity.messages()))
.additionalModelResponseFieldPaths(requestEntity.additionalModelFields());
public Flow.Publisher<StreamingUnifiedChatCompletionResults.Results> executeStreamChatCompletionRequest(
This logic is all new
package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion;

import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.core.document.Document;
Please review the changes in this class carefully.
package org.elasticsearch.xpack.inference.services.amazonbedrock.request.completion;

import software.amazon.awssdk.core.document.Document;
Please review the changes in this class carefully; it's new content added after the other PR.
/**
 * The task types that the {@link InferenceAction.Request} can accept.
 */
private static final EnumSet<TaskType> SUPPORTED_INFERENCE_ACTION_TASK_TYPES = EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.COMPLETION);
This is new so we can return an error if we call a chat_completion endpoint without _stream.
action.execute(inputs, timeout, listener);
} else {
if (SUPPORTED_INFERENCE_ACTION_TASK_TYPES.contains(model.getTaskType()) == false) {
listener.onFailure(createUnsupportedTaskTypeStatusException(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES));
New change is here
responseString = responseString + " " + useChatCompletionUrlMessage(model);
}
listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
listener.onFailure(createUnsupportedTaskTypeStatusException(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES));
Using the same helper
responseString = responseString + " " + useChatCompletionUrlMessage(model);
}
listener.onFailure(new ElasticsearchStatusException(responseString, RestStatus.BAD_REQUEST));
listener.onFailure(createUnsupportedTaskTypeStatusException(model, SUPPORTED_INFERENCE_ACTION_TASK_TYPES));
Using the same helper and adding a return because we had a fall-through bug here.
Nice catch. Is it possible to add a test that would fail without the return for this (maybe you already have)?
Added one
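The fall-through bug and the kind of test requested above can be sketched as follows. Every name here (`SketchTaskType`, `SketchListener`, `RecordingListener`, `FallThroughSketch`) is a hypothetical stand-in rather than an Elasticsearch type, and the error message is illustrative only. The point is that without the `return`, the listener would be notified twice: once with the failure and once with a response.

```java
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;

// Hypothetical stand-ins for the task type and listener; not the Elasticsearch types.
enum SketchTaskType { TEXT_EMBEDDING, COMPLETION, CHAT_COMPLETION }

interface SketchListener {
    void onResponse(String result);

    void onFailure(Exception e);
}

// Records every callback so a test can assert the listener fired exactly once.
final class RecordingListener implements SketchListener {
    final List<String> events = new ArrayList<>();

    @Override
    public void onResponse(String result) {
        events.add("response:" + result);
    }

    @Override
    public void onFailure(Exception e) {
        events.add("failure:" + e.getMessage());
    }
}

final class FallThroughSketch {
    static final EnumSet<SketchTaskType> SUPPORTED = EnumSet.of(SketchTaskType.TEXT_EMBEDDING, SketchTaskType.COMPLETION);

    static void infer(SketchTaskType taskType, SketchListener listener) {
        if (SUPPORTED.contains(taskType) == false) {
            listener.onFailure(new IllegalArgumentException("unsupported task type: " + taskType));
            // Without this return, execution falls through to onResponse below
            // and the listener is notified a second time.
            return;
        }
        listener.onResponse("ok");
    }
}
```

A test that asserts the recorded events contain exactly one failure and nothing else would fail if the `return` were removed.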
);
}

public static ElasticsearchStatusException createUnsupportedTaskTypeStatusException(Model model, EnumSet<TaskType> supportedTaskTypes) {
New helper
…elasticsearch into ia-bedrock-completions
assertThat(exceptionThrown.getCause().getMessage(), containsString("test exception"));
}

public void testExecute_ChatCompletionRequest_NonStreaming_Fails() {
We're lacking coverage around returning streamed Converse responses and the stream processor logic that parses them. I'll work on those in a follow-up PR because I think we're going to need to refactor how we mock the internal client. Unfortunately, it's not as straightforward as testing streaming for other services that don't use an SDK, like OpenAI.
Makes sense. I'm fine with this being in a follow-up PR.
Pinging @elastic/search-inference-team (Team:Search - Inference)
DonalEvans left a comment
I wasn't able to finish looking at all the test changes today, but I have a few comments/questions for the non-test code.
...earch/xpack/inference/services/amazonbedrock/client/AmazonBedrockChatCompletionExecutor.java
...ticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockStreamingProcessor.java
var requestTaskSettings = AmazonBedrockCompletionRequestTaskSettings.fromMap(taskSettings);
var taskSettingsToUse = AmazonBedrockCompletionTaskSettings.of(completionModel.getTaskSettings(), requestTaskSettings);
return new AmazonBedrockChatCompletionModel(completionModel, taskSettingsToUse);
If requestTaskSettings is equal to taskSettingsToUse, we can return the original model and avoid creating an identical object.
AmazonBedrockCompletionRequestTaskSettings and AmazonBedrockCompletionTaskSettings are separate classes. I'll check whether the result of AmazonBedrockCompletionTaskSettings.of is the same as the task settings from the model passed in; if so, we can return the same completionModel.
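The check proposed in this reply can be sketched as below. This is a minimal illustration under stated assumptions: `SketchTaskSettings` and `SketchModel` are hypothetical records, the two settings fields are invented for the example, and the merge rule (a non-null request value overrides the model's value) is an assumption, not the behavior of `AmazonBedrockCompletionTaskSettings.of`.

```java
// Hypothetical stand-in for the task settings; the fields are illustrative only.
record SketchTaskSettings(Double temperature, Integer maxNewTokens) {

    // Assumed merge rule: a non-null request value overrides the original.
    static SketchTaskSettings of(SketchTaskSettings original, SketchTaskSettings requestOverrides) {
        return new SketchTaskSettings(
            requestOverrides.temperature() != null ? requestOverrides.temperature() : original.temperature(),
            requestOverrides.maxNewTokens() != null ? requestOverrides.maxNewTokens() : original.maxNewTokens()
        );
    }
}

// Hypothetical stand-in for the chat completion model.
record SketchModel(String modelId, SketchTaskSettings taskSettings) {

    static SketchModel of(SketchModel model, SketchTaskSettings requestOverrides) {
        var merged = SketchTaskSettings.of(model.taskSettings(), requestOverrides);
        // Compare the merged settings with the model's own settings, not with the
        // request settings, and reuse the model when nothing changed.
        if (merged.equals(model.taskSettings())) {
            return model;
        }
        return new SketchModel(model.modelId(), merged);
    }
}
```

Because records compare by component value, an override that repeats the model's existing value also returns the original instance.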
...arch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java
...h/xpack/inference/services/amazonbedrock/completion/AmazonBedrockCompletionTaskSettings.java
...ch/xpack/inference/services/amazonbedrock/request/completion/AmazonBedrockConverseUtils.java
...java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
…elasticsearch into ia-bedrock-completions
This PR implements chat completion for Amazon Bedrock. It's based on this PR: #133697
Testing
Create the endpoint
Complex request