
fix(sb_ai): make sure the shared session is initialized before running inferences (#439)
nyannyacha authored Nov 8, 2024
1 parent b10fb3d commit 8095acf
Showing 2 changed files with 60 additions and 34 deletions.
9 changes: 7 additions & 2 deletions crates/sb_ai/js/ai.js
@@ -86,6 +86,7 @@ const parseJSONOverEventStream = async function* (itr, signal) {
 
 class Session {
   model;
+  init;
   is_ext_inference_api;
   inferenceAPIHost;
 
@@ -94,7 +95,7 @@ class Session {
     this.is_ext_inference_api = false;
 
     if (model === 'gte-small') {
-      core.ops.op_sb_ai_init_model(model);
+      this.init = core.ops.op_sb_ai_init_model(model);
    } else {
       this.inferenceAPIHost = core.ops.op_get_env('AI_INFERENCE_API_HOST');
       this.is_ext_inference_api = !!this.inferenceAPIHost; // only enable external inference API if env variable is set
@@ -183,7 +184,7 @@ class Session {
         case 'openaicompatible': {
           const finishReason = message.choices[0].finish_reason;
 
-          if (!!finishReason) {
+          if (finishReason) {
             if (finishReason !== 'stop') {
               throw new Error('Expected a completed response.');
             }
@@ -226,6 +227,10 @@ class Session {
       }
     }
 
+    if (this.init) {
+      await this.init;
+    }
+
     const mean_pool = opts.mean_pool ?? true;
     const normalize = opts.normalize ?? true;
     const result = await core.ops.op_sb_ai_run_model(
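The ai.js change above comes down to one pattern: the constructor kicks off initialization but keeps the returned promise, and the inference path awaits that promise before calling the run op, so nothing can race past an unfinished init. Below is a minimal standalone sketch of that pattern in plain JavaScript; loadModel and runModel are hypothetical stand-ins for the op_sb_ai_* ops, and this is not the actual ai.js code.

// Hypothetical stand-ins for core.ops.op_sb_ai_init_model / op_sb_ai_run_model.
const loadModel = async (model) => {
  await new Promise((resolve) => setTimeout(resolve, 50)); // pretend to load weights
  console.log(`${model} ready`);
};
const runModel = async (prompt) => `embedding for: ${prompt}`;

class LazySession {
  #init;

  constructor(model) {
    // Start initialization eagerly, but keep the promise instead of awaiting it here.
    this.#init = loadModel(model);
  }

  async run(prompt) {
    // Never run an inference before initialization has settled.
    await this.#init;
    return runModel(prompt);
  }
}

// Usage: the first run() transparently waits for initialization to finish.
const session = new LazySession('gte-small');
session.run('hello world').then(console.log);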
85 changes: 53 additions & 32 deletions crates/sb_ai/lib.rs
@@ -1,7 +1,7 @@
 mod onnx;
 mod session;
 
-use anyhow::{bail, Error};
+use anyhow::{anyhow, bail, Error};
 use base_rt::BlockingScopeCPUUsageMetricExt;
 use deno_core::error::AnyError;
 use deno_core::OpState;
@@ -15,7 +15,7 @@ use std::path::Path;
 use std::rc::Rc;
 use std::sync::Arc;
 use tokenizers::Tokenizer;
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, oneshot};
 use tokio::task;
 
 use tracing::{error, trace_span};
@@ -50,12 +50,18 @@ fn mean_pool(last_hidden_states: ArrayView3<f32>, attention_mask: ArrayView3<i64
     sum_hidden_states / sum_attention_mask
 }
 
-fn init_gte(state: &mut OpState) -> Result<(), Error> {
-    let spawner = state.borrow::<V8TaskSpawner>().clone();
-    let cross_thread_spawner = state.borrow::<V8CrossThreadTaskSpawner>().clone();
-
+async fn init_gte(state: Rc<RefCell<OpState>>) -> Result<(), Error> {
     let models_dir = std::env::var("SB_AI_MODELS_DIR").unwrap_or("/etc/sb_ai/models".to_string());
 
+    let (tx, rx) = oneshot::channel::<Option<Error>>();
+    let (spawner, cross_thread_spawner) = {
+        let state = state.borrow_mut();
+        let spawner = state.borrow::<V8TaskSpawner>().clone();
+        let cross_thread_spawner = state.borrow::<V8CrossThreadTaskSpawner>().clone();
+
+        (spawner, cross_thread_spawner)
+    };
+
     spawner.spawn(move |scope| {
         let state = JsRuntime::op_state_from(scope);
         let mut state = state.borrow_mut();
@@ -65,37 +71,42 @@ fn init_gte(state: &mut OpState) -> Result<(), Error> {
             let _ = state.try_take::<mpsc::UnboundedSender<GteModelRequest>>();
 
             state.put::<mpsc::UnboundedSender<GteModelRequest>>(req_tx);
 
             req_rx
         };
 
-        let session =
-            load_session_from_file(Path::new(&models_dir).join("gte-small").join("model.onnx"));
-
-        if session.is_err() {
-            let err = session.as_ref().unwrap_err();
-            error!(reason = ?err, "failed to create session");
-            return;
-        }
+        let (_, session) = match load_session_from_file(
+            Path::new(&models_dir).join("gte-small").join("model.onnx"),
+        ) {
+            Ok(session) => session,
+            Err(err) => {
+                error!(reason = ?err, "failed to create session");
+                let _ = tx.send(Some(err));
+                return;
+            }
+        };
 
-        let (_, session) = session.unwrap();
-        let tokenizer = Tokenizer::from_file(
+        let mut tokenizer = match Tokenizer::from_file(
             Path::new(&models_dir)
                 .join("gte-small")
                 .join("tokenizer.json"),
         )
-        .map_err(anyhow::Error::msg);
-
-        if tokenizer.is_err() {
-            let err = tokenizer.as_ref().unwrap_err();
-            error!(reason = ?err, "failed to create tokenizer");
-            return;
-        }
-
-        let mut tokenizer = tokenizer.unwrap();
+        .map_err(anyhow::Error::msg)
+        {
+            Ok(tokenizer) => tokenizer,
+            Err(err) => {
+                error!(reason = ?err, "failed to create tokenizer");
+                let _ = tx.send(Some(err));
+                return;
+            }
+        };
 
         // model's default max length is 128. Increase it to 512.
-        let truncation = tokenizer.get_truncation_mut().unwrap();
+        let Some(truncation) = tokenizer.get_truncation_mut() else {
+            let err = anyhow!("failed to get mutable truncation parameter");
+            error!(reason = ?err);
+            let _ = tx.send(Some(err));
+            return;
+        };
+
         truncation.max_length = 512;
 
@@ -155,6 +166,7 @@ fn init_gte(state: &mut OpState) -> Result<(), Error> {
                 } else {
                     result
                 };
+
                 Ok(result.view().to_slice().unwrap().to_vec())
             },
         );
@@ -182,10 +194,16 @@ fn init_gte(state: &mut OpState) -> Result<(), Error> {
                 });
             });
         }
-        }))
+        }));
+
+        let _ = tx.send(None);
     });
 
-    Ok(())
+    let Some(err) = rx.await.map_err(AnyError::from)? else {
+        return Ok(());
+    };
+
+    Err(err)
 }
 
 async fn run_gte(
@@ -217,11 +235,14 @@ async fn run_gte(
     result.unwrap()
 }
 
-#[op2]
+#[op2(async)]
 #[serde]
-pub fn op_sb_ai_init_model(state: &mut OpState, #[string] name: String) -> Result<(), AnyError> {
+pub async fn op_sb_ai_init_model(
+    state: Rc<RefCell<OpState>>,
+    #[string] name: String,
+) -> Result<(), AnyError> {
     if name == "gte-small" {
-        init_gte(state)
+        init_gte(state).await
     } else {
         bail!("model not supported")
     }
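On the Rust side, init_gte now hands a tokio::sync::oneshot sender into the spawned initialization task; the task sends Some(err) on any failure (or None on success), and the now-async op_sb_ai_init_model awaits that signal before resolving, so the JS-side await this.init actually surfaces initialization errors instead of them being logged and dropped inside the spawner. Keeping to JavaScript for the sketch, the closest analogue of that oneshot hand-off is a deferred promise given to background work; the sketch below only illustrates the shape, uses hypothetical names, and is not the lib.rs code.

// JavaScript analogue of the oneshot completion signal used in lib.rs:
// the background work reports success or failure exactly once, and the
// caller awaits that report before treating the session as initialized.
function initInBackground(shouldFail = false) {
  let signal; // plays the role of the oneshot sender (tx)
  const done = new Promise((resolve, reject) => {
    signal = { resolve, reject };
  });

  // Stand-in for the initialization work spawned onto the V8 task spawner.
  setTimeout(() => {
    if (shouldFail) {
      signal.reject(new Error('failed to create session')); // ~ tx.send(Some(err))
    } else {
      signal.resolve(); // ~ tx.send(None)
    }
  }, 50);

  return done; // the caller awaits this (the rx side) before running inferences
}

// Usage: an initialization error now reaches the caller instead of being
// swallowed inside the spawned task.
initInBackground(false)
  .then(() => console.log('session ready'))
  .catch((err) => console.error('init failed:', err));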
