Skip to content

Commit

Permalink
add commands in chat mode
Browse files Browse the repository at this point in the history
  • Loading branch information
jaypei committed Mar 30, 2023
1 parent 5f81a9d commit 799e678
Show file tree
Hide file tree
Showing 12 changed files with 377 additions and 61 deletions.
58 changes: 33 additions & 25 deletions chatgpt_cli/chatapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,6 @@ class ChatMessageType(enum.Enum):
ASSISTANT = "assistant"


class ActMode(enum.Enum):
ASSISTANT = "Assistant"
TRANSLATOR = "Translator"


def system_content_by_act_mode(act_mode: ActMode) -> str:
if act_mode == ActMode.ASSISTANT:
return "You are a friendly and helpful teaching assistant."
if act_mode == ActMode.TRANSLATOR:
return "You are an English translator, spelling corrector and improver."
raise ValueError("Invalid act_mode")


@attrs.define
class ChatMessage:
message : str
Expand All @@ -54,23 +41,34 @@ class ChatSession:

histories : t.List[ChatMessage] = attrs.field(factory=list)

def __init__(self, session_name: str, act_mode: ActMode):
def __init__(self, session_name: str, prompt: str):
self.session_name : str = session_name
self.act_mode : ActMode = act_mode
self.prompt : str = prompt
self.histories : t.List[ChatMessage] = []
self.conversation_count : int = 0
self.histories.append(ChatMessage(
message=system_content_by_act_mode(self.act_mode),
message_type=ChatMessageType.SYSTEM,
))
self.no_context : bool = not config.get_config().getboolean('CLI', 'default_enable_context')

def __str__(self):
return f"ChatSession({self.session_name}, {self.prompt}, {self.no_context})"

def generate_query_messages(self) -> t.List:
query_messages = []
if self.no_context:
# reversed query for the first user message
for message in reversed(self.histories):
if message.message_type == ChatMessageType.USER:
query_messages.append(message.to_message_json())
break
return query_messages
for message in self.histories:
query_messages.append(message.to_message_json())
return query_messages

def add_message(self, message: ChatMessage):
if self.prompt and message.message_type == ChatMessageType.USER:
prompt_message = config.get_prompt_message(self.prompt)
if prompt_message:
message.message = f"{prompt_message}\n\n{message.message}"
self.histories.append(message)
if message.message_type == ChatMessageType.USER:
self.conversation_count += 1
Expand All @@ -84,7 +82,8 @@ def __init__(self):

def get_session(self, session_name: str) -> ChatSession:
if session_name not in self.sessions:
self.sessions[session_name] = ChatSession(session_name, act_mode=ActMode.ASSISTANT)
self.sessions[session_name] = ChatSession(
session_name, prompt=config.get_config()['CLI']['default_prompt'])
return self.sessions[session_name]

def switch(self, session_name: str) -> ChatSession:
Expand All @@ -96,8 +95,13 @@ def rename(self, old_name: str, new_name: str):
self.sessions[new_name] = self.sessions[old_name]
del self.sessions[old_name]

def create(self, session_name: str, auto_switch: bool=True) -> ChatSession:
new_session = ChatSession(session_name, act_mode=ActMode.ASSISTANT)
def create(
self, session_name: str, auto_switch: bool=True,
prompt: t.Optional[str]=None
) -> ChatSession:
if prompt is None:
prompt = config.get_config()['CLI']['default_prompt']
new_session = ChatSession(session_name, prompt)
self.sessions[session_name] = new_session
if auto_switch:
self.switch(session_name)
Expand All @@ -106,9 +110,13 @@ def create(self, session_name: str, auto_switch: bool=True) -> ChatSession:
def _new_chat_completion(self, stream: bool) -> openai.ChatCompletion:
try:
response = openai.ChatCompletion.create(
model=config.get_config().get('DEFAULT', 'CHATGPT_MODEL'),
model=config.get_config().get('API', 'CHATGPT_MODEL'),
messages=self.current_session.generate_query_messages(),
temperature=0,
temperature=int(config.get_config().get('API', 'TEMPERATURE')),
top_p=1,
n=1,
presence_penalty=0,
frequency_penalty=0,
stream=stream,
)
return response
Expand Down Expand Up @@ -165,7 +173,7 @@ def get_session_manager() -> ChatSessionManager:

def init():
global _session_manager
openai.api_key = config.get_config().get('DEFAULT', 'OPENAI_API_KEY')
openai.api_key = config.get_config().get('API', 'OPENAI_API_KEY')
_session_manager = ChatSessionManager()
_session_manager.create('Chat01', auto_switch=True)

Expand Down
19 changes: 18 additions & 1 deletion chatgpt_cli/cmds/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import importlib
from gettext import gettext as _

from click import Command, Context, Parameter
from click import Command, Group, Context, Parameter

from chatgpt_cli import term

Expand Down Expand Up @@ -38,6 +38,23 @@ def run(self, **kwargs) -> t.Any:
raise NotImplementedError


class BaseMultiCmd(Group):

name: str = "base_multi_cmd"
help: str = "Base multi command."
opts: t.Optional[t.List[Parameter]] = None
subcommands: t.Optional[t.List["BaseCmd"]] = None

def __init__(self, *args, **kwargs):
kwargs.setdefault("name", self.name)
kwargs.setdefault("help", self.help)
super().__init__(*args, **kwargs)
self.params.extend(self.opts or [])
if self.subcommands:
for subcmd in self.subcommands:
self.add_command(subcmd)


def load_cmd(cmd_name: str) -> BaseCmd:
# remove "Command" suffix
if not cmd_name.endswith("Command"):
Expand Down
169 changes: 153 additions & 16 deletions chatgpt_cli/cmds/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
import sys
import typing as t

from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.formatted_text import AnyFormattedText
from prompt_toolkit.completion import Completer, Completion, FuzzyCompleter

from chatgpt_cli import chatapi
from chatgpt_cli import term
from chatgpt_cli import error
from chatgpt_cli import config
from chatgpt_cli.cmds.base import BaseCmd
from chatgpt_cli.term import split_command_line, multiline_input_with_editor


CHAT_BANNER_LOGO = fr"""
Expand All @@ -26,7 +29,8 @@
CHAT_BANNER_INTRO = """\
Welcome to ChatGPT-CLI, the command-line tool for ChatGPT!
Type '/help' to see a list of available commands.
Type '/exit' or <Ctrl-D> to exit the program.
Type '/exit' or [green]Ctrl-D[/green] to exit the program.
Press [green]F10[/green] to enter the EDITOR mode, which can edit multiple lines.
"""


Expand All @@ -53,17 +57,31 @@ def print_chat_banner(self):
term.console.print(CHAT_BANNER_INTRO)

def get_question(self):
current_session = chatapi.get_session_manager().current_session
session_manager = chatapi.get_session_manager()
while True:
current_session = session_manager.current_session
session_name = current_session.session_name
conversation_count = current_session.conversation_count
conversation_count = str(current_session.conversation_count + 1)
if current_session.no_context:
conversation_count = "*"
prompt_message :AnyFormattedText = [
('class:prompt_name', f'({current_session.prompt}) '),
('class:session', f"{session_name} "),
('class:conversation_count', f'{conversation_count} '),
('class:conversation_count', f'{conversation_count} '),
('class:prompt_sep', '> '),
]
question = term.prompt.prompt(prompt_message)
yield question

kb = KeyBindings()

@kb.add("f10")
def _(event):
event.app.current_buffer.text = multiline_input_with_editor(
event.app.current_buffer.text)

question = term.prompt.prompt(
prompt_message, completer=FuzzyCompleter(CommandCompleter()),
key_bindings=kb)
return question

def ask_openai(self, question):
session_mgr = chatapi.get_session_manager()
Expand All @@ -73,18 +91,137 @@ def ask_openai(self, question):
def run_chat_loop(self):
while True:
try:
for question in self.get_question():
if question == "":
continue
if question == "/hist":
current_session = chatapi.get_session_manager().current_session
for message in current_session.histories:
term.console.print(message.to_message_json())
continue
if question in ("/exit", "/quit"):
raise error.CommandExit()
question = self.get_question()
if question == "":
continue
question_pcmd = split_command_line(question)
for cmd in PCOMMANDS:
if question_pcmd[0] == cmd["match"]:
cmd["cls"]().run(question_pcmd[1:]) # type: ignore
break
else:
self.ask_openai(question)
except EOFError:
raise error.CommandExit()
except KeyboardInterrupt:
term.console.print("If you want to exit, please press <Ctrl+D>.")


class PCommandBase:

def complete(self, cmd, args) -> t.Iterator[Completion]:
return iter([])

def run(self, args):
raise NotImplementedError()


class HistoryPCommand(PCommandBase):

def run(self, args):
session_manager = chatapi.get_session_manager()
current_session = session_manager.current_session
for message in current_session.histories:
term.console.print(message.to_message_json())


class ContextPCommand(PCommandBase):

def complete(self, cmd, args) -> t.Iterator[Completion]:
if len(args) == 0:
yield Completion("on", start_position=-len(args))
yield Completion("off", start_position=-len(args))

def run(self, args):
session_manager = chatapi.get_session_manager()
current_session = session_manager.current_session
if len(args) < 1:
return
if args[0] == "on":
current_session.no_context = False
elif args[0] == "off":
current_session.no_context = True
else:
term.console.print("Invalid argument.")


class PromptPCommand(PCommandBase):

def complete(self, cmd, args) -> t.Iterator[Completion]:
for prompt_name in config.prompts():
yield Completion(prompt_name, start_position=-len(args))

def run(self, args):
session_manager = chatapi.get_session_manager()
current_session = session_manager.current_session
if len(args) < 1:
return
prompt_name = args[0]
prompt_txt = config.get_prompt_message(prompt_name)
if not prompt_txt:
term.console.print(f"Prompt '{prompt_name}' is not found.")
return
current_session.prompt = prompt_name

class TitlePCommand(PCommandBase):

def run(self, args):
session_manager = chatapi.get_session_manager()
if len(args) < 1:
term.console.print(
f"Current title is '{session_manager.current_session.session_name}'.")
return
title_name = args[0]
session_manager.current_session.session_name = title_name


class ExitPCommand(PCommandBase):

def run(self, args):
raise error.CommandExit()


class HelpPCommand(PCommandBase):

def run(self, args):
for pcmd in PCOMMANDS:
term.console.print(f"{pcmd['match']:12}{pcmd['desc']}")


# pylint: disable=line-too-long
PCOMMANDS = [
{ "match": "/hist", "desc": "Show history", "cls": HistoryPCommand },
{ "match": "/context", "desc": "Turn on/off context. Ex. /context <on|off>", "cls": ContextPCommand },
{ "match": "/prompt", "desc": "Change prompt. Ex. /prompt <prompt-name>", "cls": PromptPCommand },
{ "match": "/title", "desc": "Change title. Ex. /title <title-name>", "cls": TitlePCommand },
{ "match": "/exit", "desc": "Exit the program", "cls": ExitPCommand },
{ "match": "/quit", "desc": "Exit the program", "cls": ExitPCommand },
{ "match": "/help", "desc": "Show help", "cls": HelpPCommand },
]


class CommandCompleter(Completer):

def get_completions(self, document, complete_event) -> t.Iterator[Completion]:
if not document.text.startswith("/"):
return
pargs = split_command_line(document.text)
if len(pargs) == 0:
return
input_cmd: str = pargs[0]
input_args: t.List[str] = pargs[1:]
matched_cls : t.Optional[t.Type[PCommandBase]] = None
for cmd in PCOMMANDS:
match: str = cmd.get("match") # type: ignore
if input_cmd == match:
matched_cls: t.Type[PCommandBase] = cmd.get("cls") # type: ignore
break
if match.startswith(input_cmd):
yield Completion(
match,
start_position=-len(input_cmd),
)
if matched_cls is None:
return
pcommand = matched_cls()
yield from pcommand.complete(input_cmd, input_args)
Loading

0 comments on commit 799e678

Please sign in to comment.