Skip to content

Commit 426773f

Browse files
Add close_sse_stream callback to request context for ergonomic stream closing
This change refactors the SSE stream closing API to follow the TypeScript SDK pattern (PR #1166), making it easier for tool handlers to trigger client reconnection during long-running operations. Changes: - Add CloseSSEStreamCallback type alias to message.py - Add close_sse_stream field to ServerMessageMetadata and RequestContext - Create callback closure in StreamableHTTPServerTransport._handle_post_request - Thread callback through lowlevel/server.py to RequestContext - Expose ctx.close_sse_stream() method in FastMCP Context class - Update tests to use new callback API instead of session_manager_ref hack - Fix dead code: remove unused return values from client _handle_sse_response Usage in FastMCP tools: @mcp.tool() async def long_running_tool(ctx: Context) -> str: await ctx.close_sse_stream() # Trigger client reconnection # Continue processing... return "Done"
1 parent 21ce52a commit 426773f

File tree

7 files changed

+71
-22
lines changed

7 files changed

+71
-22
lines changed

src/mcp/client/streamable_http.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -380,14 +380,8 @@ async def _handle_sse_response(
380380
ctx: RequestContext,
381381
is_initialization: bool = False,
382382
attempt: int = 0,
383-
) -> tuple[bool, str | None]:
384-
"""Handle SSE response from the server with automatic reconnection.
385-
386-
Returns:
387-
Tuple of (has_priming_event, last_event_id) where:
388-
- has_priming_event: True if any event had an ID (priming event received)
389-
- last_event_id: The last event ID received, for resumption
390-
"""
383+
) -> None:
384+
"""Handle SSE response from the server with automatic reconnection."""
391385
has_priming_event = False
392386
last_event_id: str | None = None
393387
is_complete = False
@@ -422,8 +416,6 @@ async def _handle_sse_response(
422416
if not is_complete and has_priming_event and last_event_id: # pragma: no cover
423417
await self._attempt_sse_reconnection(ctx, last_event_id, attempt)
424418

425-
return has_priming_event, last_event_id
426-
427419
async def _attempt_sse_reconnection( # pragma: no cover
428420
self,
429421
ctx: RequestContext,

src/mcp/server/fastmcp/server.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,3 +1303,24 @@ async def warning(self, message: str, **extra: Any) -> None:
13031303
async def error(self, message: str, **extra: Any) -> None:
13041304
"""Send an error log message."""
13051305
await self.log("error", message, **extra)
1306+
1307+
async def close_sse_stream(self, retry_interval_ms: int | None = None) -> None:
1308+
"""Close the SSE stream for this request, triggering client reconnection.
1309+
1310+
Use this to implement polling behavior during long-running operations.
1311+
The client will reconnect after the retry interval and receive any events
1312+
that were stored while disconnected.
1313+
1314+
This is only available when using StreamableHTTP transport with an
1315+
event store configured for resumability. It is a no-op otherwise,
1316+
allowing portable tool code.
1317+
1318+
Args:
1319+
retry_interval_ms: Optional retry interval in milliseconds to suggest
1320+
to the client before closing (currently unused,
1321+
reserved for future use - the retry interval is
1322+
configured at the transport level)
1323+
"""
1324+
callback = self._request_context.close_sse_stream if self._request_context else None
1325+
if callback:
1326+
await callback(retry_interval_ms)

src/mcp/server/lowlevel/server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -680,12 +680,14 @@ async def _handle_request(
680680

681681
token = None
682682
try:
683-
# Extract request context from message metadata
683+
# Extract request context and close_sse_stream callback from metadata
684684
request_data = None
685+
close_sse_stream_callback = None
685686
if message.message_metadata is not None and isinstance(
686687
message.message_metadata, ServerMessageMetadata
687688
): # pragma: no cover
688689
request_data = message.message_metadata.request_context
690+
close_sse_stream_callback = message.message_metadata.close_sse_stream
689691

690692
# Set our global state that can be retrieved via
691693
# app.get_request_context()
@@ -696,6 +698,7 @@ async def _handle_request(
696698
session,
697699
lifespan_context,
698700
request=request_data,
701+
close_sse_stream=close_sse_stream_callback,
699702
)
700703
)
701704
response = await handler(req)

src/mcp/server/streamable_http.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
TransportSecurityMiddleware,
2929
TransportSecuritySettings,
3030
)
31-
from mcp.shared.message import ServerMessageMetadata, SessionMessage
31+
from mcp.shared.message import CloseSSEStreamCallback, ServerMessageMetadata, SessionMessage
3232
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
3333
from mcp.types import (
3434
DEFAULT_NEGOTIATED_VERSION,
@@ -449,9 +449,15 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
449449
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover
450450
request_stream_reader = self._request_streams[request_id][1] # pragma: no cover
451451

452+
# Create close_sse_stream callback (only if event store is configured)
453+
close_sse_callback = self._create_close_sse_callback(request_id) # pragma: no cover
454+
452455
if self.is_json_response_enabled: # pragma: no cover
453456
# Process the message
454-
metadata = ServerMessageMetadata(request_context=request)
457+
metadata = ServerMessageMetadata(
458+
request_context=request,
459+
close_sse_stream=close_sse_callback,
460+
)
455461
session_message = SessionMessage(message, metadata=metadata)
456462
await writer.send(session_message)
457463
try:
@@ -544,7 +550,10 @@ async def sse_writer():
544550
async with anyio.create_task_group() as tg:
545551
tg.start_soon(response, scope, receive, send)
546552
# Then send the message to be processed by the server
547-
metadata = ServerMessageMetadata(request_context=request)
553+
metadata = ServerMessageMetadata(
554+
request_context=request,
555+
close_sse_stream=close_sse_callback,
556+
)
548557
session_message = SessionMessage(message, metadata=metadata)
549558
await writer.send(session_message)
550559
except Exception:
@@ -716,6 +725,28 @@ async def terminate(self) -> None:
716725
# During cleanup, we catch all exceptions since streams might be in various states
717726
logger.debug(f"Error closing streams: {e}")
718727

728+
def _create_close_sse_callback(self, request_id: str) -> CloseSSEStreamCallback | None:
729+
"""Create a callback to close the SSE stream for a request.
730+
731+
Only creates a callback if event store is configured (resumability enabled).
732+
The callback allows handlers to trigger client reconnection during long-running ops.
733+
734+
Args:
735+
request_id: The request ID to create the callback for
736+
737+
Returns:
738+
The callback function, or None if event store is not configured
739+
"""
740+
if not self._event_store:
741+
return None
742+
743+
async def close_callback(retry_interval_ms: int | None = None) -> None:
744+
# Note: retry_interval_ms is accepted for API compatibility but not used
745+
# The retry interval is set via the priming event
746+
await self.close_sse_stream(request_id)
747+
748+
return close_callback
749+
719750
async def close_sse_stream(self, request_id: RequestId) -> None:
720751
"""Close an SSE stream for a specific request, triggering client reconnection.
721752

src/mcp/shared/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from typing_extensions import TypeVar
55

6+
from mcp.shared.message import CloseSSEStreamCallback
67
from mcp.shared.session import BaseSession
78
from mcp.types import RequestId, RequestParams
89

@@ -18,3 +19,4 @@ class RequestContext(Generic[SessionT, LifespanContextT, RequestT]):
1819
session: SessionT
1920
lifespan_context: LifespanContextT
2021
request: RequestT | None = None
22+
close_sse_stream: CloseSSEStreamCallback | None = None

src/mcp/shared/message.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
ResumptionTokenUpdateCallback = Callable[[ResumptionToken], Awaitable[None]]
1616

17+
# Callback type for closing SSE streams (takes optional retry_interval_ms)
18+
CloseSSEStreamCallback = Callable[[int | None], Awaitable[None]]
19+
1720

1821
@dataclass
1922
class ClientMessageMetadata:
@@ -30,6 +33,8 @@ class ServerMessageMetadata:
3033
related_request_id: RequestId | None = None
3134
# Request-specific context (e.g., headers, auth info)
3235
request_context: object | None = None
36+
# Callback to close SSE stream for this request (triggers client reconnection)
37+
close_sse_stream: CloseSSEStreamCallback | None = None
3338

3439

3540
MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None

tests/shared/test_streamable_http.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,14 +269,9 @@ async def handle_call_tool(name: str, args: dict[str, Any]) -> list[TextContent]
269269
related_request_id=ctx.request_id,
270270
)
271271

272-
# Trigger server-initiated SSE disconnect
273-
if self._session_manager_ref:
274-
session_manager = self._session_manager_ref[0]
275-
request = ctx.request
276-
if isinstance(request, Request):
277-
session_id = request.headers.get("mcp-session-id")
278-
if session_id:
279-
await session_manager.close_sse_stream(session_id, ctx.request_id)
272+
# Trigger server-initiated SSE disconnect using the callback
273+
if ctx.close_sse_stream:
274+
await ctx.close_sse_stream(None)
280275

281276
# Wait a bit for client to reconnect
282277
await anyio.sleep(0.2)

0 commit comments

Comments
 (0)