|
28 | 28 | TransportSecurityMiddleware, |
29 | 29 | TransportSecuritySettings, |
30 | 30 | ) |
31 | | -from mcp.shared.message import ServerMessageMetadata, SessionMessage |
| 31 | +from mcp.shared.message import CloseSSEStreamCallback, ServerMessageMetadata, SessionMessage |
32 | 32 | from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS |
33 | 33 | from mcp.types import ( |
34 | 34 | DEFAULT_NEGOTIATED_VERSION, |
@@ -449,9 +449,15 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re |
449 | 449 | self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover |
450 | 450 | request_stream_reader = self._request_streams[request_id][1] # pragma: no cover |
451 | 451 |
|
| 452 | + # Create close_sse_stream callback (only if event store is configured) |
| 453 | + close_sse_callback = self._create_close_sse_callback(request_id) # pragma: no cover |
| 454 | + |
452 | 455 | if self.is_json_response_enabled: # pragma: no cover |
453 | 456 | # Process the message |
454 | | - metadata = ServerMessageMetadata(request_context=request) |
| 457 | + metadata = ServerMessageMetadata( |
| 458 | + request_context=request, |
| 459 | + close_sse_stream=close_sse_callback, |
| 460 | + ) |
455 | 461 | session_message = SessionMessage(message, metadata=metadata) |
456 | 462 | await writer.send(session_message) |
457 | 463 | try: |
@@ -544,7 +550,10 @@ async def sse_writer(): |
544 | 550 | async with anyio.create_task_group() as tg: |
545 | 551 | tg.start_soon(response, scope, receive, send) |
546 | 552 | # Then send the message to be processed by the server |
547 | | - metadata = ServerMessageMetadata(request_context=request) |
| 553 | + metadata = ServerMessageMetadata( |
| 554 | + request_context=request, |
| 555 | + close_sse_stream=close_sse_callback, |
| 556 | + ) |
548 | 557 | session_message = SessionMessage(message, metadata=metadata) |
549 | 558 | await writer.send(session_message) |
550 | 559 | except Exception: |
@@ -716,6 +725,28 @@ async def terminate(self) -> None: |
716 | 725 | # During cleanup, we catch all exceptions since streams might be in various states |
717 | 726 | logger.debug(f"Error closing streams: {e}") |
718 | 727 |
|
| 728 | + def _create_close_sse_callback(self, request_id: str) -> CloseSSEStreamCallback | None: |
| 729 | + """Create a callback to close the SSE stream for a request. |
| 730 | +
|
| 731 | + Only creates a callback if event store is configured (resumability enabled). |
| 732 | + The callback allows handlers to trigger client reconnection during long-running ops. |
| 733 | +
|
| 734 | + Args: |
| 735 | + request_id: The request ID to create the callback for |
| 736 | +
|
| 737 | + Returns: |
| 738 | + The callback function, or None if event store is not configured |
| 739 | + """ |
| 740 | + if not self._event_store: |
| 741 | + return None |
| 742 | + |
| 743 | + async def close_callback(retry_interval_ms: int | None = None) -> None: |
| 744 | + # Note: retry_interval_ms is accepted for API compatibility but not used |
| 745 | + # The retry interval is set via the priming event |
| 746 | + await self.close_sse_stream(request_id) |
| 747 | + |
| 748 | + return close_callback |
| 749 | + |
719 | 750 | async def close_sse_stream(self, request_id: RequestId) -> None: |
720 | 751 | """Close an SSE stream for a specific request, triggering client reconnection. |
721 | 752 |
|
|
0 commit comments