Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions max/python/max/pipelines/lib/speculative_decoding/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ def update(
self.total_acceptance_lengths += new_metrics.total_acceptance_lengths
self.num_generations += new_metrics.num_generations

def reset(self) -> None:
    """Zero out every accumulated counter.

    Called after a consumer has read the counters so the next batch
    reports per-batch deltas rather than cumulative totals.
    """
    for counter_name in (
        "bonus_tokens_used",
        "draft_tokens_accepted",
        "draft_tokens_generated",
        "total_acceptance_lengths",
        "num_generations",
    ):
        setattr(self, counter_name, 0)

@property
def acceptance_rate(self) -> float:
"""Get the acceptance rate."""
Expand Down
14 changes: 14 additions & 0 deletions max/python/max/serve/scheduler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class BatchMetrics:

draft_tokens_generated: int
draft_tokens_accepted: int
bonus_tokens_used: int

@classmethod
def create(
Expand Down Expand Up @@ -164,13 +165,16 @@ def create(

draft_tokens_generated = 0
draft_tokens_accepted = 0
bonus_tokens_used = 0
if speculative_decoding_metrics is not None:
draft_tokens_generated = (
speculative_decoding_metrics.draft_tokens_generated
)
draft_tokens_accepted = (
speculative_decoding_metrics.draft_tokens_accepted
)
bonus_tokens_used = speculative_decoding_metrics.bonus_tokens_used
speculative_decoding_metrics.reset()

return cls(
batch_type=inputs.batch_type,
Expand Down Expand Up @@ -201,6 +205,7 @@ def create(
disk_blocks_read=disk_blocks_read,
draft_tokens_generated=draft_tokens_generated,
draft_tokens_accepted=draft_tokens_accepted,
bonus_tokens_used=bonus_tokens_used,
)

def pretty_format(self) -> str:
Expand Down Expand Up @@ -267,6 +272,15 @@ def publish_metrics(self) -> None:
METRICS.cache_hits(self.cache_hit_tokens)
METRICS.cache_misses(self.cache_miss_tokens)

if self.draft_tokens_generated > 0:
METRICS.speculative_draft_tokens_accepted(
self.draft_tokens_accepted
)
METRICS.speculative_draft_tokens_generated(
self.draft_tokens_generated
)
METRICS.speculative_bonus_tokens_used(self.bonus_tokens_used)


class SchedulerLogger:
"""Class to periodically log batch-level metrics to console."""
Expand Down
45 changes: 45 additions & 0 deletions max/python/max/serve/telemetry/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,21 @@
unit="ms",
description="Audio output length in milliseconds",
), # type: ignore
"maxserve.speculative.draft_tokens_accepted": _meter.create_counter(
"maxserve.speculative.draft_tokens_accepted",
unit="tokens",
description="Count of accepted draft tokens",
), # type: ignore
"maxserve.speculative.draft_tokens_generated": _meter.create_counter(
"maxserve.speculative.draft_tokens_generated",
unit="tokens",
description="Count of generated draft tokens",
), # type: ignore
"maxserve.speculative.bonus_tokens_used": _meter.create_counter(
"maxserve.speculative.bonus_tokens_used",
unit="tokens",
description="Count of bonus tokens used when all draft tokens accepted",
), # type: ignore
"maxserve.input_tokens_per_request": _meter.create_histogram(
"maxserve.input_tokens_per_request",
unit="tokens",
Expand Down Expand Up @@ -506,6 +521,36 @@ def audio_output_length(self, length_ms: int) -> None:
MetricLevel.DETAILED,
)

def speculative_draft_tokens_accepted(self, value: int) -> None:
    """Record *value* accepted draft tokens on the speculative counter."""
    measurement = MaxMeasurement(
        "maxserve.speculative.draft_tokens_accepted",
        value,
        self.extra_attributes,
    )
    # DETAILED level: only emitted when fine-grained metrics are enabled.
    self.client.send_measurement(measurement, MetricLevel.DETAILED)

def speculative_draft_tokens_generated(self, value: int) -> None:
    """Record *value* generated draft tokens on the speculative counter."""
    measurement = MaxMeasurement(
        "maxserve.speculative.draft_tokens_generated",
        value,
        self.extra_attributes,
    )
    # DETAILED level: only emitted when fine-grained metrics are enabled.
    self.client.send_measurement(measurement, MetricLevel.DETAILED)

def speculative_bonus_tokens_used(self, value: int) -> None:
    """Record *value* bonus tokens on the speculative counter."""
    measurement = MaxMeasurement(
        "maxserve.speculative.bonus_tokens_used",
        value,
        self.extra_attributes,
    )
    # DETAILED level: only emitted when fine-grained metrics are enabled.
    self.client.send_measurement(measurement, MetricLevel.DETAILED)

def input_tokens_per_request(self, value: int) -> None:
self.client.send_measurement(
MaxMeasurement(
Expand Down
99 changes: 98 additions & 1 deletion max/tests/tests/serve/scheduler/test_scheduler_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from max.interfaces import BatchType
import numpy as np
from max.interfaces import (
BatchType,
RequestID,
TextGenerationInputs,
TokenBuffer,
)
from max.pipelines.core import TextContext
from max.pipelines.lib.speculative_decoding.utils import (
SpeculativeDecodingMetrics,
)
from max.serve.scheduler.config import TokenGenerationSchedulerConfig
from max.serve.scheduler.utils import BatchMetrics


Expand Down Expand Up @@ -45,6 +56,7 @@ def test_metric_to_string() -> None:
disk_blocks_read=0,
draft_tokens_generated=0,
draft_tokens_accepted=0,
bonus_tokens_used=0,
)

assert (
Expand All @@ -65,3 +77,88 @@ def test_metric_to_string() -> None:
metrics.pretty_format()
== r"Executed CE batch with 1 reqs | Terminated: 4 reqs, Pending: 5 reqs | Input Tokens: 6/7 toks | Context Tokens: 8/9 toks | Prompt Tput: 12.0 tok/s, Generation Tput: 13.0 tok/s | Batch creation: 10.00s, Execution: 11.00s | Draft Tokens: 5/10 (50.00%) accepted | All Preemptions: 14 reqs"
)


def _make_inputs() -> TextGenerationInputs[TextContext]:
    """Create minimal TextGenerationInputs for BatchMetrics.create()."""
    # A 10-token all-ones buffer is the smallest payload the context accepts.
    token_buffer = TokenBuffer(np.ones(10, dtype=np.int64))
    context = TextContext(
        request_id=RequestID(),
        max_length=100,
        tokens=token_buffer,
    )
    # One batch containing one context, advanced a single step.
    return TextGenerationInputs(batches=[[context]], num_steps=1)


# Minimal scheduler configuration shared by the BatchMetrics.create() tests
# below. The specific values are arbitrary but must be valid for the
# constructor; nothing in the tests asserts on them.
_SCHEDULER_CONFIG = TokenGenerationSchedulerConfig(
    max_batch_size=4,
    max_forward_steps_tg=1,
    target_tokens_per_batch_ce=32,
)


def test_create_resets_speculative_metrics_between_batches() -> None:
    """BatchMetrics.create() must reset SpeculativeDecodingMetrics after
    reading, so consecutive calls emit per-batch values, not cumulative
    totals that would double-count when added to OTEL counters."""

    def snapshot(metrics: SpeculativeDecodingMetrics) -> BatchMetrics:
        # Mirrors one scheduler iteration handing the shared metrics
        # object to BatchMetrics.create().
        return BatchMetrics.create(
            sch_config=_SCHEDULER_CONFIG,
            inputs=_make_inputs(),
            kv_cache=None,
            batch_creation_time_s=0.01,
            batch_execution_time_s=0.05,
            num_pending_reqs=0,
            num_terminated_reqs=0,
            total_preemption_count=0,
            speculative_decoding_metrics=metrics,
        )

    spec_metrics = SpeculativeDecodingMetrics(
        bonus_tokens_used=0,
        draft_tokens_accepted=0,
        draft_tokens_generated=0,
        total_acceptance_lengths=0,
        num_generations=0,
    )

    # Batch 1: pipeline updates metrics with 20 generated, 15 accepted.
    spec_metrics.update(
        SpeculativeDecodingMetrics(
            bonus_tokens_used=3,
            draft_tokens_accepted=15,
            draft_tokens_generated=20,
            total_acceptance_lengths=10,
            num_generations=2,
        )
    )
    batch1 = snapshot(spec_metrics)

    assert batch1.draft_tokens_generated == 20
    assert batch1.draft_tokens_accepted == 15
    assert batch1.bonus_tokens_used == 3

    # Batch 2: pipeline updates metrics with 10 generated, 8 accepted.
    spec_metrics.update(
        SpeculativeDecodingMetrics(
            bonus_tokens_used=1,
            draft_tokens_accepted=8,
            draft_tokens_generated=10,
            total_acceptance_lengths=5,
            num_generations=1,
        )
    )
    batch2 = snapshot(spec_metrics)

    # Must reflect only batch 2's values, not batch 1 + batch 2.
    assert batch2.draft_tokens_generated == 10
    assert batch2.draft_tokens_accepted == 8
    assert batch2.bonus_tokens_used == 1
Loading