Skip to content

Commit 57d6ccd

Browse files
committed
feat: add GPU support
Signed-off-by: kallebysantos <[email protected]>
1 parent 5ed28a2 commit 57d6ccd

File tree

2 files changed

+61
-9
lines changed

2 files changed

+61
-9
lines changed

Dockerfile

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,53 @@ RUN objcopy --strip-debug \
2626
--add-gnu-debuglink=/root/edge-runtime.debug \
2727
/root/edge-runtime
2828

29-
RUN ./scripts/install_onnx.sh $ONNXRUNTIME_VERSION $TARGETPLATFORM /root/onnxruntime
30-
RUN ./scripts/download_models.sh
3129

32-
FROM debian:bookworm-slim
30+
# Application runtime without ONNX
31+
FROM debian:bookworm-slim as edge-runtime-base
3332

3433
RUN apt-get update && apt-get install -y libssl-dev && rm -rf /var/lib/apt/lists/*
3534
RUN apt-get remove -y perl && apt-get autoremove -y
3635

3736
COPY --from=builder /root/edge-runtime /usr/local/bin/edge-runtime
3837
COPY --from=builder /root/edge-runtime.debug /usr/local/bin/edge-runtime.debug
39-
COPY --from=builder /root/onnxruntime /usr/local/bin/onnxruntime
40-
COPY --from=builder /usr/src/edge-runtime/models /etc/sb_ai/models
4138

4239
ENV ORT_DYLIB_PATH=/usr/local/bin/onnxruntime/lib/libonnxruntime.so
43-
ENV SB_AI_MODELS_DIR=/etc/sb_ai/models
40+
41+
42+
# ONNX Runtime provider
43+
# Build stage (derived from builder) that installs ONNX Runtime into /root/onnxruntime
44+
FROM builder as ort
45+
RUN ./scripts/install_onnx.sh $ONNXRUNTIME_VERSION $TARGETPLATFORM /root/onnxruntime
46+
47+
48+
# ONNX Runtime CUDA provider
49+
# Build stage (derived from builder) that installs ONNX Runtime with the --gpu flag into /root/onnxruntime
50+
FROM builder as ort-cuda
51+
RUN ./scripts/install_onnx.sh $ONNXRUNTIME_VERSION $TARGETPLATFORM /root/onnxruntime --gpu
52+
53+
54+
FROM builder as preload-models
55+
RUN ./scripts/download_models.sh
56+
57+
58+
# With CUDA
59+
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 as edge-runtime-cuda
60+
61+
COPY --from=edge-runtime-base /usr/local/bin/edge-runtime /usr/local/bin/edge-runtime
62+
COPY --from=builder /root/edge-runtime.debug /usr/local/bin/edge-runtime.debug
63+
COPY --from=ort-cuda /root/onnxruntime /usr/local/bin/onnxruntime
64+
COPY --from=preload-models /usr/src/edge-runtime/models /etc/sb_ai/models
65+
66+
ENV ORT_DYLIB_PATH=/usr/local/bin/onnxruntime/lib/libonnxruntime.so
67+
ENV NVIDIA_VISIBLE_DEVICES=all
68+
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
69+
70+
ENTRYPOINT ["edge-runtime"]
71+
72+
73+
# Base
74+
FROM edge-runtime-base as edge-runtime
75+
COPY --from=ort /root/onnxruntime /usr/local/bin/onnxruntime
76+
COPY --from=preload-models /usr/src/edge-runtime/models /etc/sb_ai/models
4477

4578
ENTRYPOINT ["edge-runtime"]

crates/sb_ai/session.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ use once_cell::sync::Lazy;
33
use std::collections::HashMap;
44
use std::sync::Mutex;
55
use std::{path::PathBuf, sync::Arc};
6+
use tracing::debug;
67

78
use anyhow::{anyhow, Error};
89
use ort::{
9-
CPUExecutionProvider, ExecutionProviderDispatch, GraphOptimizationLevel, Session,
10-
SessionBuilder,
10+
CPUExecutionProvider, CUDAExecutionProvider, ExecutionProvider, ExecutionProviderDispatch,
11+
GraphOptimizationLevel, Session, SessionBuilder,
1112
};
1213

1314
use crate::onnx::ensure_onnx_env_init;
@@ -49,14 +50,32 @@ fn cpu_execution_provider() -> Box<dyn Iterator<Item = ExecutionProviderDispatch
4950
)
5051
}
5152

53+
fn cuda_execution_provider() -> Box<dyn Iterator<Item = ExecutionProviderDispatch>> {
54+
let cuda = CUDAExecutionProvider::default();
55+
let providers = match cuda.is_available() {
56+
Ok(is_cuda_available) => {
57+
debug!(cuda_support = is_cuda_available);
58+
if is_cuda_available {
59+
vec![cuda.build()]
60+
} else {
61+
vec![]
62+
}
63+
}
64+
65+
_ => vec![],
66+
};
67+
68+
Box::new(providers.into_iter().chain(cpu_execution_provider()))
69+
}
70+
5271
fn create_session(model_bytes: &[u8]) -> Result<Arc<Session>, Error> {
5372
let session = {
5473
if let Some(err) = ensure_onnx_env_init() {
5574
return Err(anyhow!("failed to create onnx environment: {err}"));
5675
}
5776

5877
get_session_builder()?
59-
.with_execution_providers(cpu_execution_provider())?
78+
.with_execution_providers(cuda_execution_provider())?
6079
.commit_from_memory(model_bytes)?
6180
};
6281

0 commit comments

Comments
 (0)