-
Notifications
You must be signed in to change notification settings - Fork 404
/
Copy pathazure.lua
54 lines (45 loc) · 1.54 KB
/
azure.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
---@class AvanteAzureProvider: AvanteDefaultBaseProvider
---@field deployment string
---@field api_version string
---@field temperature number
---@field max_tokens number
local Utils = require("avante.utils")
local P = require("avante.providers")
local O = require("avante.providers").openai
---@class AvanteProviderFunctor
local M = {
  -- Name under which this provider's API key is looked up
  -- (presumably an environment variable; confirm in avante.providers).
  api_key_name = "AZURE_OPENAI_API_KEY",

  -- Message serialisation and response parsing are delegated unchanged to
  -- the OpenAI provider; only URL and header construction is Azure-specific
  -- (see parse_curl_args below).
  parse_messages = O.parse_messages,
  parse_response = O.parse_response,
  parse_response_without_stream = O.parse_response_without_stream,
}
--- Build the curl request arguments for an Azure OpenAI chat-completions call.
---
--- The request URL is
--- `<endpoint>/openai/deployments/<deployment>/chat/completions?api-version=<api_version>`.
--- Azure addresses the model through its deployment name. When an API key is
--- required, it is sent in the `api-key` header.
---
---@param provider table provider table; passed to `P.parse_config` and used for `parse_api_key()`
---@param prompt_opts table prompt options forwarded to `M.parse_messages`
---@return table args curl arguments: `url`, `proxy`, `insecure`, `headers`, `body`
function M.parse_curl_args(provider, prompt_opts)
  local provider_conf, request_body = P.parse_config(provider)

  local headers = {
    ["Content-Type"] = "application/json",
  }

  if P.env.require_api_key(provider_conf) then headers["api-key"] = provider.parse_api_key() end

  -- NOTE: "o" series models do not accept `max_tokens` and only support the
  -- default temperature. Instead of silently discarding the configured token
  -- budget, move it to `max_completion_tokens`, the parameter these models
  -- accept. A `max_completion_tokens` the user configured explicitly is kept.
  if O.is_o_series_model(provider_conf.model) then
    if request_body.max_completion_tokens == nil then
      request_body.max_completion_tokens = request_body.max_tokens
    end
    request_body.max_tokens = nil
    request_body.temperature = 1
  end

  return {
    url = Utils.url_join(
      provider_conf.endpoint,
      "/openai/deployments/"
        ---@diagnostic disable-next-line: undefined-field
        .. provider_conf.deployment
        .. "/chat/completions?api-version="
        ---@diagnostic disable-next-line: undefined-field
        .. provider_conf.api_version
    ),
    proxy = provider_conf.proxy,
    insecure = provider_conf.allow_insecure,
    headers = headers,
    -- Keys from request_body override the defaults here ("force"), so a
    -- config-level value wins over `stream = true` if one is present.
    body = vim.tbl_deep_extend("force", {
      messages = M.parse_messages(prompt_opts),
      stream = true,
    }, request_body),
  }
end

return M