Skip to content

Commit

Permalink
feat: package vllm runtime into image (kaito-project#655)
Browse files Browse the repository at this point in the history
**Reason for Change**:
<!-- What does this PR improve or fix in Kaito? Why is it needed? -->

- pack vllm runtime and chat_template files
- align the requirements of the two runtimes
- support loading a chat_template in the hf runtime

Signed-off-by: jerryzhuang <[email protected]>
  • Loading branch information
zhuangqh authored Oct 31, 2024
1 parent f64b35a commit 1709ba0
Show file tree
Hide file tree
Showing 20 changed files with 245 additions and 123 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/unit-tests-ragengine.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# CI workflow: run the RAG engine Python unit tests on pushes and PRs.
name: unit-tests-ragengine

# Cancel in-flight runs for the same branch/PR so only the latest commit is tested.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

on:
  push:
    branches: ["main"]
    # Skip doc/image-only changes; they cannot affect the test outcome.
    paths-ignore: ["docs/**", "**.md", "**.mdx", "**.png", "**.jpg"]
  pull_request:
    branches: ["main", "release-**"]
    paths-ignore: ["docs/**", "**.md", "**.mdx", "**.png", "**.jpg"]

# Least privilege: unit tests only need to read the repository.
# (Dropped `packages: write` — this job never publishes packages.)
permissions:
  contents: read

env:
  # NOTE(review): `make rag-service-test` runs pip/pytest, not Go — confirm
  # whether GO_VERSION is actually consumed before removing it entirely.
  GO_VERSION: "1.22"

jobs:
  unit-tests:
    runs-on: ubuntu-latest
    environment: unit-tests
    steps:
      # Egress auditing for supply-chain hardening of the runner.
      - name: Harden Runner
        uses: step-security/harden-runner@5c7944e73c4c2a096b17a9cb74d65b6c2bbafbde # v2.9.1
        with:
          egress-policy: audit

      - name: Check out the code
        uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
        with:
          submodules: true
          # Full history; shallow clones can break version/diff-based tooling.
          fetch-depth: 0

      - name: Run unit tests
        run: |
          make rag-service-test
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ jobs:
- name: Run unit tests & Generate coverage
run: |
make unit-test
make rag-service-test
make tuning-metrics-server-test
- name: Run inference api e2e tests
Expand Down
20 changes: 9 additions & 11 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -97,25 +97,23 @@ unit-test: ## Run unit tests.
-race -coverprofile=coverage.txt -covermode=atomic
go tool cover -func=coverage.txt

.PHONY: virtualenv
virtualenv:
pip install virtualenv

.PHONY: rag-service-test
rag-service-test: virtualenv
./hack/run-pytest-in-venv.sh ragengine/tests ragengine/requirements.txt
rag-service-test:
pip install -r ragengine/requirements.txt
pytest -o log_cli=true -o log_cli_level=INFO ragengine/tests

.PHONY: tuning-metrics-server-test
tuning-metrics-server-test: virtualenv
./hack/run-pytest-in-venv.sh presets/tuning/text-generation/metrics presets/tuning/text-generation/requirements.txt
tuning-metrics-server-test:
pip install -r ./presets/dependencies/requirements-test.txt
pytest -o log_cli=true -o log_cli_level=INFO presets/tuning/text-generation/metrics

## --------------------------------------
## E2E tests
## --------------------------------------

inference-api-e2e: virtualenv
./hack/run-pytest-in-venv.sh presets/inference/vllm presets/inference/vllm/requirements.txt
./hack/run-pytest-in-venv.sh presets/inference/text-generation presets/inference/text-generation/requirements.txt
inference-api-e2e:
pip install -r ./presets/dependencies/requirements-test.txt
pytest -o log_cli=true -o log_cli_level=INFO presets/inference

# Ginkgo configurations
GINKGO_FOCUS ?=
Expand Down
42 changes: 20 additions & 22 deletions docker/presets/models/tfs/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,29 @@ ARG MODEL_TYPE
ARG VERSION

# Set the working directory
WORKDIR /workspace/tfs
WORKDIR /workspace

# Write the version to a file
RUN echo $VERSION > /workspace/tfs/version.txt
COPY kaito/presets/dependencies/requirements.txt /workspace/requirements.txt

# First, copy just the preset files and install dependencies
# This is done before copying the code to utilize Docker's layer caching and
# avoid reinstalling dependencies unless the requirements file changes.
# Inference
COPY kaito/presets/inference/${MODEL_TYPE}/requirements.txt /workspace/tfs/inference-requirements.txt
RUN pip install --no-cache-dir -r inference-requirements.txt
RUN pip install --no-cache-dir -r /workspace/requirements.txt

COPY kaito/presets/inference/${MODEL_TYPE}/inference_api.py /workspace/tfs/inference_api.py
# 1. Huggingface transformers
COPY kaito/presets/inference/${MODEL_TYPE}/inference_api.py \
kaito/presets/tuning/${MODEL_TYPE}/cli.py \
kaito/presets/tuning/${MODEL_TYPE}/fine_tuning.py \
kaito/presets/tuning/${MODEL_TYPE}/parser.py \
kaito/presets/tuning/${MODEL_TYPE}/dataset.py \
kaito/presets/tuning/${MODEL_TYPE}/metrics/metrics_server.py \
/workspace/tfs/

# Fine Tuning
COPY kaito/presets/tuning/${MODEL_TYPE}/requirements.txt /workspace/tfs/tuning-requirements.txt
RUN pip install --no-cache-dir -r tuning-requirements.txt
# 2. vLLM
COPY kaito/presets/inference/vllm/inference_api.py /workspace/vllm/inference_api.py

COPY kaito/presets/tuning/${MODEL_TYPE}/cli.py /workspace/tfs/cli.py
COPY kaito/presets/tuning/${MODEL_TYPE}/fine_tuning.py /workspace/tfs/fine_tuning.py
COPY kaito/presets/tuning/${MODEL_TYPE}/parser.py /workspace/tfs/parser.py
COPY kaito/presets/tuning/${MODEL_TYPE}/dataset.py /workspace/tfs/dataset.py
# Chat template
ADD kaito/presets/inference/chat_templates /workspace/chat_templates

# Copy the metrics server
COPY kaito/presets/tuning/${MODEL_TYPE}/metrics/metrics_server.py /workspace/tfs/metrics_server.py

# Copy the entire model weights to the weights directory
COPY ${WEIGHTS_PATH} /workspace/tfs/weights
# Model weights
COPY ${WEIGHTS_PATH} /workspace/weights
RUN echo $VERSION > /workspace/version.txt && \
ln -s /workspace/weights /workspace/tfs/weights && \
ln -s /workspace/weights /workspace/vllm/weights
38 changes: 0 additions & 38 deletions hack/run-pytest-in-venv.sh

This file was deleted.

3 changes: 3 additions & 0 deletions presets/dependencies/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Requirements

Pip dependencies for both Kaito inference and tuning.
8 changes: 8 additions & 0 deletions presets/dependencies/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Common dependencies
-r requirements.txt

# Test dependencies
pytest
httpx
peft
requests
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
# Dependencies for TFS

# Core Dependencies
transformers==4.41.2
torch==2.2.0
vllm==0.6.3
transformers >= 4.45.0
torch==2.4.0
accelerate==0.30.1
fastapi>=0.111.0,<0.112.0 # Allow patch updates
pydantic==2.7.4
pydantic>=2.9
uvicorn[standard]>=0.29.0,<0.30.0 # Allow patch updates
uvloop
peft==0.11.1
numpy==1.22.4
numpy<3.0,>=1.25.0
sentencepiece==0.2.0
jinja2>=3.1.0

# Utility libraries
datasets==2.19.1
peft==0.11.1
bitsandbytes==0.42.0
sentencepiece==0.2.0

# Less critical, can be latest
gputil
psutil
# For UTs
pytest
httpx
peft
trl
20 changes: 20 additions & 0 deletions presets/inference/chat_templates/falcon-instruct.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{#- Chat template for Falcon-instruct models: an optional leading system
    message, then blank-line-separated "User: ..." / "Assistant: ..." turns
    with strict role alternation. CR/LF and doubled newlines in message
    content are collapsed before rendering.
    NOTE(review): the `-#}` trim makes this comment output-neutral; tag-only
    lines below presumably rely on trim_blocks/lstrip_blocks in the consuming
    Jinja environment (as HF transformers configures) — confirm. -#}
{% if messages[0]['role'] == 'system' %}
{% set system_message = messages[0]['content'] %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}

{{ system_message | trim }}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}

{% set content = message['content'].replace('\r\n', '\n').replace('\n\n', '\n') %}
{{ '\n\n' + message['role'] | capitalize + ': ' + content | trim }}
{% endfor %}

{% if add_generation_prompt %}
{{ '\n\nAssistant:' }}
{% endif %}
24 changes: 24 additions & 0 deletions presets/inference/chat_templates/llama-2-chat.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{#- Chat template for Llama-2-chat models: a system message (if present) is
    wrapped in <<SYS>> markers and folded into the first user turn; each user
    turn is wrapped as `bos [INST] ... [/INST]` and each assistant turn is
    followed by eos. Roles must strictly alternate user/assistant.
    NOTE(review): the `-#}` trim makes this comment output-neutral; tag-only
    lines below presumably rely on trim_blocks/lstrip_blocks in the consuming
    Jinja environment — confirm. -#}
{% if messages[0]['role'] == 'system' %}
{% set system_message = '<<SYS>>\n' + messages[0]['content'] | trim + '\n<</SYS>>\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}

{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}

{% if loop.index0 == 0 %}
{% set content = system_message + message['content'] %}
{% else %}
{% set content = message['content'] %}
{% endif %}

{% if message['role'] == 'user' %}
{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + content | trim + ' ' + eos_token }}
{% endif %}
{% endfor %}
18 changes: 18 additions & 0 deletions presets/inference/chat_templates/llama-3-instruct.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{#- Chat template for Llama-3-instruct models: emits bos, then every message
    (including any system message) as
    <|start_header_id|>role<|end_header_id|>\n\ncontent<|eot_id|>.
    `offset` shifts the user/assistant alternation check when a system
    message occupies index 0. Optionally appends the assistant header to
    prompt generation.
    NOTE(review): the `-#}` trim makes this comment output-neutral; tag-only
    lines below presumably rely on trim_blocks/lstrip_blocks in the consuming
    Jinja environment — confirm. -#}
{% if messages[0]['role'] == 'system' %}
{% set offset = 1 %}
{% else %}
{% set offset = 0 %}
{% endif %}

{{ bos_token }}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}

{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}
{% endfor %}

{% if add_generation_prompt %}
{{ '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
{% endif %}
19 changes: 19 additions & 0 deletions presets/inference/chat_templates/mistral-instruct.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{#- Chat template for Mistral-instruct models: an optional system message is
    emitted (after bos) before the conversation; user turns are wrapped as
    `[INST] ... [/INST]` and assistant turns are followed by eos. Roles must
    strictly alternate user/assistant.
    NOTE(review): the `-#}` trim makes this comment output-neutral; tag-only
    lines below presumably rely on trim_blocks/lstrip_blocks in the consuming
    Jinja environment — confirm. -#}
{% if messages[0]['role'] == 'system' %}
{% set system_message = messages[0]['content'] | trim + '\n\n' %}
{% set messages = messages[1:] %}
{% else %}
{% set system_message = '' %}
{% endif %}

{{ bos_token + system_message}}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}

{% if message['role'] == 'user' %}
{{ '[INST] ' + message['content'] | trim + ' [/INST]' }}
{% elif message['role'] == 'assistant' %}
{{ ' ' + message['content'] | trim + eos_token }}
{% endif %}
{% endfor %}
18 changes: 18 additions & 0 deletions presets/inference/chat_templates/phi-3-small.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{#- Chat template for Phi-3-small models: emits bos, then every message
    (including any system message) as <|role|>\ncontent<|end|>\n.
    `offset` shifts the user/assistant alternation check when a system
    message occupies index 0. Differs from the plain phi-3 template only in
    emitting bos_token first.
    NOTE(review): the `-#}` trim makes this comment output-neutral; tag-only
    lines below presumably rely on trim_blocks/lstrip_blocks in the consuming
    Jinja environment — confirm. -#}
{% if messages[0]['role'] == 'system' %}
{% set offset = 1 %}
{% else %}
{% set offset = 0 %}
{% endif %}

{{ bos_token }}
{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}

{{ '<|' + message['role'] + '|>\n' + message['content'] | trim + '<|end|>' + '\n' }}
{% endfor %}

{% if add_generation_prompt %}
{{ '<|assistant|>\n' }}
{% endif %}
17 changes: 17 additions & 0 deletions presets/inference/chat_templates/phi-3.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{#- Chat template for Phi-3 models: emits every message (including any
    system message) as <|role|>\ncontent<|end|>\n, with no leading bos_token
    (unlike the phi-3-small variant). `offset` shifts the user/assistant
    alternation check when a system message occupies index 0.
    NOTE(review): the `-#}` trim makes this comment output-neutral; tag-only
    lines below presumably rely on trim_blocks/lstrip_blocks in the consuming
    Jinja environment — confirm. -#}
{% if messages[0]['role'] == 'system' %}
{% set offset = 1 %}
{% else %}
{% set offset = 0 %}
{% endif %}

{% for message in messages %}
{% if (message['role'] == 'user') != (loop.index0 % 2 == offset) %}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif %}

{{ '<|' + message['role'] + '|>\n' + message['content'] | trim + '<|end|>' + '\n' }}
{% endfor %}

{% if add_generation_prompt %}
{{ '<|assistant|>\n' }}
{% endif %}
Loading

0 comments on commit 1709ba0

Please sign in to comment.