4545from vertexai .language_models import (
4646 _language_models as tunable_models ,
4747)
48+ import warnings
4849
4950try :
5051 from PIL import Image as PIL_Image # pylint: disable=g-import-not-at-top
    def start_chat(
        self,
        *,
        history: Optional[List["Content"]] = None,
        response_validation: bool = True,
    ) -> "ChatSession":
        """Creates a stateful chat session.

        Args:
            history: Previous history to initialize the chat session.
            response_validation: Whether to validate responses before adding
                them to chat history. By default, `send_message` will raise
                error if the request or response is blocked or if the response
                is incomplete due to going over the max token limit.
                If set to `False`, the chat session history will always
                accumulate the request and response messages even if the
                response is blocked or incomplete. This can result in an
                unusable chat session state.

        Returns:
            A ChatSession object.
        """
        return ChatSession(
            model=self,
            history=history,
            response_validation=response_validation,
        )
622633
623634
@@ -628,6 +639,29 @@ def start_chat(
628639]
629640
630641
def _validate_response(
    response: "GenerationResponse",
    request_contents: Optional[List["Content"]] = None,
    response_chunks: Optional[List["GenerationResponse"]] = None,
) -> None:
    """Validates a model response, raising if it is blocked or incomplete.

    Args:
        response: The response (or final streaming chunk) to validate.
        request_contents: The request contents, attached to the raised error
            for debugging.
        response_chunks: All response chunks received so far, attached to the
            raised error for debugging.

    Raises:
        ResponseValidationError: If the response has no candidates or its
            first candidate did not finish for a successful reason.
    """
    # Robustness fix: a blocked request can come back with no candidates at
    # all, in which case `response.candidates[0]` would raise an opaque
    # IndexError instead of the documented validation error.
    if not response.candidates:
        raise ResponseValidationError(
            message=(
                "The model response did not contain any candidates.\n"
                "To protect the integrity of the chat session, the request and response were not added to chat history.\n"
                "To skip the response validation, specify `model.start_chat(response_validation=False)`.\n"
                "Note that letting blocked or otherwise incomplete responses into chat history might lead to future interactions being blocked by the service."
            ),
            request_contents=request_contents,
            responses=response_chunks,
        )
    candidate = response.candidates[0]
    if candidate.finish_reason not in _SUCCESSFUL_FINISH_REASONS:
        message = (
            # Grammar fix: was "did not completed successfully".
            "The model response did not complete successfully.\n"
            f"Finish reason: {candidate.finish_reason}.\n"
            f"Finish message: {candidate.finish_message}.\n"
            f"Safety ratings: {candidate.safety_ratings}.\n"
            "To protect the integrity of the chat session, the request and response were not added to chat history.\n"
            "To skip the response validation, specify `model.start_chat(response_validation=False)`.\n"
            "Note that letting blocked or otherwise incomplete responses into chat history might lead to future interactions being blocked by the service."
        )
        raise ResponseValidationError(
            message=message,
            request_contents=request_contents,
            responses=response_chunks,
        )
664+
631665class ChatSession :
632666 """Chat session holds the chat history."""
633667
    def __init__(
        self,
        model: _GenerativeModel,
        *,
        history: Optional[List["Content"]] = None,
        response_validation: bool = True,
    ):
        """Initializes a chat session.

        Args:
            model: The generative model that requests are sent through.
            history: Previous history to initialize the chat session with.
            response_validation: Whether to validate responses (raising
                `ResponseValidationError`) before adding them to history.

        Raises:
            ValueError: If `history` contains items that are not `Content`.
        """
        if history:
            if not all(isinstance(item, Content) for item in history):
                raise ValueError("history must be a list of Content objects.")

        self._model = model
        self._history = history or []
        # When validation is disabled the validator is `None`, so the
        # send-message paths skip validation entirely.
        self._response_validator = _validate_response if response_validation else None
652686 @property
653687 def history (self ) -> List ["Content" ]:
@@ -784,13 +818,12 @@ def _send_message(
784818 tools = tools ,
785819 )
786820 # By default we're not adding incomplete interactions to history.
787- if self ._raise_on_blocked :
788- if response .candidates [0 ].finish_reason not in _SUCCESSFUL_FINISH_REASONS :
789- raise ResponseBlockedError (
790- message = "The response was blocked." ,
791- request_contents = request_history ,
792- responses = [response ],
793- )
821+ if self ._response_validator is not None :
822+ self ._response_validator (
823+ response = response ,
824+ request_contents = request_history ,
825+ response_chunks = [response ],
826+ )
794827
795828 # Adding the request and the first response candidate to history
796829 response_message = response .candidates [0 ].content
@@ -841,13 +874,13 @@ async def _send_message_async(
841874 tools = tools ,
842875 )
843876 # By default we're not adding incomplete interactions to history.
844- if self ._raise_on_blocked :
845- if response . candidates [ 0 ]. finish_reason not in _SUCCESSFUL_FINISH_REASONS :
846- raise ResponseBlockedError (
847- message = "The response was blocked." ,
848- request_contents = request_history ,
849- responses = [ response ],
850- )
877+ if self ._response_validator is not None :
878+ self . _response_validator (
879+ response = response ,
880+ request_contents = request_history ,
881+ response_chunks = [ response ] ,
882+ )
883+
851884 # Adding the request and the first response candidate to history
852885 response_message = response .candidates [0 ].content
853886 # Response role is NOT set by the model.
@@ -905,13 +938,12 @@ def _send_message_streaming(
905938 else :
906939 full_response = chunk
907940 # By default we're not adding incomplete interactions to history.
908- if self ._raise_on_blocked :
909- if chunk .candidates [0 ].finish_reason not in _SUCCESSFUL_FINISH_REASONS :
910- raise ResponseBlockedError (
911- message = "The response was blocked." ,
912- request_contents = request_history ,
913- responses = chunks ,
914- )
941+ if self ._response_validator is not None :
942+ self ._response_validator (
943+ response = chunk ,
944+ request_contents = request_history ,
945+ response_chunks = chunks ,
946+ )
915947 yield chunk
916948 if not full_response :
917949 return
@@ -973,16 +1005,13 @@ async def async_generator():
9731005 else :
9741006 full_response = chunk
9751007 # By default we're not adding incomplete interactions to history.
976- if self ._raise_on_blocked :
977- if (
978- chunk .candidates [0 ].finish_reason
979- not in _SUCCESSFUL_FINISH_REASONS
980- ):
981- raise ResponseBlockedError (
982- message = "The response was blocked." ,
983- request_contents = request_history ,
984- responses = chunks ,
985- )
1008+ if self ._response_validator is not None :
1009+ self ._response_validator (
1010+ response = chunk ,
1011+ request_contents = request_history ,
1012+ response_chunks = chunks ,
1013+ )
1014+
9861015 yield chunk
9871016 if not full_response :
9881017 return
@@ -996,6 +1025,36 @@ async def async_generator():
9961025 return async_generator ()
9971026
9981027
class _PreviewChatSession(ChatSession):
    __doc__ = ChatSession.__doc__

    # This class preserves backwards compatibility with the deprecated
    # `raise_on_blocked` parameter.

    def __init__(
        self,
        model: _GenerativeModel,
        *,
        history: Optional[List["Content"]] = None,
        # `None` means "not explicitly passed"; it is resolved to `True`
        # below. Using a sentinel (instead of an eager `True` default) is
        # what lets us detect the conflicting case where the caller passed
        # both `response_validation` and the deprecated `raise_on_blocked`.
        response_validation: Optional[bool] = None,
        # Deprecated
        raise_on_blocked: Optional[bool] = None,
    ):
        """Initializes the chat session.

        Args:
            model: The generative model that requests are sent through.
            history: Previous history to initialize the chat session with.
            response_validation: Whether to validate responses before adding
                them to chat history. Defaults to True when not specified.
            raise_on_blocked: Deprecated. Use `response_validation` instead.

        Raises:
            ValueError: If both `response_validation` and `raise_on_blocked`
                are specified.
        """
        if raise_on_blocked is not None:
            warnings.warn(
                message="Use `response_validation` instead of `raise_on_blocked`."
            )
            # Bug fix: the previous code compared `response_validation`
            # against `None` while its default was `True`, so ANY use of
            # `raise_on_blocked` incorrectly raised ValueError even when
            # `response_validation` had not been passed at all.
            if response_validation is not None:
                raise ValueError(
                    "Cannot use `response_validation` when `raise_on_blocked` is set."
                )
            response_validation = raise_on_blocked
        if response_validation is None:
            response_validation = True
        super().__init__(
            model=model,
            history=history,
            response_validation=response_validation,
        )
1057+
9991058class ResponseBlockedError (Exception ):
10001059 def __init__ (
10011060 self ,
@@ -1008,6 +1067,10 @@ def __init__(
10081067 self .responses = responses
10091068
10101069
class ResponseValidationError(ResponseBlockedError):
    """Raised when a model response fails validation before being added to
    chat history (see `_validate_response`).

    Subclasses `ResponseBlockedError` so existing callers that catch the
    older exception type keep working.
    """

    pass
1072+
1073+
10111074### Structures
10121075
10131076
0 commit comments