Skip to content

Commit

Permalink
Resolve error when steam=True
Browse files Browse the repository at this point in the history
  • Loading branch information
jaypei committed Mar 20, 2023
1 parent 00bd873 commit 6494eae
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 27 deletions.
38 changes: 12 additions & 26 deletions chatgpt_cli/chatapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import datetime
import enum
from typing import List, Dict
import typing as t

import attrs
import openai
Expand Down Expand Up @@ -36,16 +36,6 @@ def system_content_by_act_mode(act_mode: ActMode) -> str:
return "You are an English translator, spelling corrector and improver."
raise ValueError("Invalid act_mode")

# pylint: disable=line-too-long
messages = [
{"role": "system", "content": "You are a helpful, pattern-following assistant that translates corporate jargon into plain English."},
{"role": "system", "name":"example_user", "content": "New synergies will help drive top-line growth."},
{"role": "system", "name": "example_assistant", "content": "Things working well together will increase revenue."},
{"role": "system", "name":"example_user", "content": "Let's circle back when we have more bandwidth to touch base on opportunities for increased leverage."},
{"role": "system", "name": "example_assistant", "content": "Let's talk later when we're less busy about how to do better."},
{"role": "user", "content": "This late pivot means we don't have time to boil the ocean for the client deliverable."},
]


@attrs.define
class ChatMessage:
Expand All @@ -62,19 +52,19 @@ def to_message_json(self) -> dict:

class ChatSession:

histories : List[ChatMessage] = attrs.field(factory=list)
histories : t.List[ChatMessage] = attrs.field(factory=list)

def __init__(self, session_name: str, act_mode: ActMode):
self.session_name : str = session_name
self.act_mode : ActMode = act_mode
self.histories : List[ChatMessage] = []
self.histories : t.List[ChatMessage] = []
self.conversation_count : int = 0
self.histories.append(ChatMessage(
message=system_content_by_act_mode(self.act_mode),
message_type=ChatMessageType.SYSTEM,
))

def generate_query_messages(self) -> list:
def generate_query_messages(self) -> t.List:
query_messages = []
for message in self.histories:
query_messages.append(message.to_message_json())
Expand All @@ -89,7 +79,7 @@ def add_message(self, message: ChatMessage):
class ChatSessionManager:

def __init__(self):
self.sessions : Dict[str, ChatSession] = {}
self.sessions : t.Dict[str, ChatSession] = {}
self.current_session : ChatSession

def get_session(self, session_name: str) -> ChatSession:
Expand Down Expand Up @@ -127,15 +117,8 @@ def _new_chat_completion(self, stream: bool) -> openai.ChatCompletion:
raise CommandError("Rate limit exceeded", 2)

def _single_output(self, console: term.Console, response: openai.ChatCompletion) -> str:
output = []
for chunk in response:
delta_obj = chunk['choices'][0]['delta']
content = delta_obj.get("content")
if content is None:
continue
output.append(content)
message = "".join(output)
console.print(message)
message = response['choices'][0]['message']["content"]
console.print(Markdown(message))
return message

def _live_output(self, console: term.Console, response: openai.ChatCompletion) -> str:
Expand All @@ -156,8 +139,11 @@ def ask(self, question: str, stream: bool, console: term.Console) -> str:
message=question,
message_type=ChatMessageType.USER,
))
response = self._new_chat_completion(stream=stream)
answer = ""
response : openai.ChatCompletion
with term.make_progress_bar(console) as progress:
progress.add_task(":thinking_face: [green]Thinking ...", total=None)
response = self._new_chat_completion(stream=stream)
answer = ""
if not stream:
answer = self._single_output(console, response)
else:
Expand Down
2 changes: 1 addition & 1 deletion chatgpt_cli/cmds/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def get_question(self):

def ask_openai(self, question):
session_mgr = chatapi.get_session_manager()
session_mgr.ask(question, stream=False, console=term.console)
session_mgr.ask(question, stream=True, console=term.console)
term.console.print("\n")

def run_chat_loop(self):
Expand Down
24 changes: 24 additions & 0 deletions chatgpt_cli/term.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@


import os
from datetime import timedelta

from rich.console import Console
from rich.text import Text
from rich.progress import Progress, Task, ProgressColumn
from rich.emoji import EMOJI
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
Expand Down Expand Up @@ -42,3 +45,24 @@ def init():

def get_emoji(name):
return EMOJI.get(name, name)


class TimeElapsedColumn(ProgressColumn):
"""Renders time elapsed."""

def render(self, task: "Task") -> Text:
"""Show time elapsed."""
elapsed = task.finished_time if task.finished else task.elapsed
if elapsed is None:
return Text("-:--:--", style="progress.elapsed")
delta = timedelta(seconds=int(elapsed))
return Text(str(delta), style="progress.elapsed")


def make_progress_bar(client_console: Console) -> Progress:
return Progress(
*Progress.get_default_columns(),
TimeElapsedColumn(),
console=client_console,
transient=True,
)
50 changes: 50 additions & 0 deletions tests/test_chatapi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import unittest
import json
from datetime import datetime

from chatgpt_cli.chatapi import ChatMessage, ChatMessageType, ChatSession, ActMode
from unittest.mock import Mock


class TestChatMessage(unittest.TestCase):

def test_to_message_json(self):

test_message = ChatMessage(
message="You are a helpful, pattern-following assistant that translates corporate jargon into plain English.",
message_type=ChatMessageType.SYSTEM,
)

expected_output = {
"role": "system",
"content": "You are a helpful, pattern-following assistant that translates corporate jargon into plain English."
}

self.assertEqual(test_message.to_message_json(), expected_output)



class TestChatSession(unittest.TestCase):
def test_generate_query_messages(self):
chat_session = ChatSession("test_session", ActMode.ASSISTANT)
chat_session.add_message(ChatMessage("test message", ChatMessageType.USER))
chat_session.add_message(ChatMessage("system response", ChatMessageType.SYSTEM))
chat_session.add_message(ChatMessage("test message 2", ChatMessageType.USER))

expected_result = [
ChatMessage("You are a friendly and helpful teaching assistant.", ChatMessageType.SYSTEM).to_message_json(),
ChatMessage("test message", ChatMessageType.USER).to_message_json(),
ChatMessage("system response", ChatMessageType.SYSTEM).to_message_json(),
ChatMessage("test message 2", ChatMessageType.USER).to_message_json()
]
self.assertEqual(chat_session.generate_query_messages(), expected_result)

def test_add_message(self):
chat_session = ChatSession("test_session", ActMode.ASSISTANT)
chat_session.add_message(ChatMessage("test message", ChatMessageType.SYSTEM))
chat_session.add_message(ChatMessage("test message 2", ChatMessageType.USER))

self.assertEqual(len(chat_session.histories), 3) # 3 messages including initial system message
self.assertEqual(chat_session.histories[1].message, "test message")
self.assertEqual(chat_session.histories[2].message, "test message 2")
self.assertEqual(chat_session.conversation_count, 1) # one user message added

0 comments on commit 6494eae

Please sign in to comment.