
Merge pull request #60 from manubot/langchain-integration
LangChain Integration
falquaddoomi authored Dec 3, 2024
2 parents c74f9c7 + 4c66a58 commit d7cf0c1
Showing 8 changed files with 174 additions and 142 deletions.
2 changes: 1 addition & 1 deletion libs/manubot_ai_editor/env_vars.py
@@ -16,7 +16,7 @@
OPENAI_API_KEY = "OPENAI_API_KEY"

# Language model to use. For example, "text-davinci-003", "gpt-3.5-turbo", "gpt-3.5-turbo-0301", etc
# The tool currently supports the "chat/completions", "completions", and "edits" endpoints, and you can check
# The tool currently supports the "chat/completions" and "completions" endpoints, and you can check
# compatible models here: https://platform.openai.com/docs/models/model-endpoint-compatibility
LANGUAGE_MODEL = "AI_EDITOR_LANGUAGE_MODEL"

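As a usage note (not part of the diff): a minimal sketch of supplying these variables from Python before invoking the editor. The variable names come from env_vars.py above; the key and model values are placeholders.

```python
import os

# minimal sketch: configure the editor via the environment variables
# defined in env_vars.py (both values here are placeholders)
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder API key
os.environ["AI_EDITOR_LANGUAGE_MODEL"] = "gpt-3.5-turbo"  # example chat-endpoint model
```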
128 changes: 79 additions & 49 deletions libs/manubot_ai_editor/models.py
@@ -5,7 +5,8 @@
import time
import json

import openai
from langchain_openai import OpenAI, ChatOpenAI
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage

from manubot_ai_editor import env_vars

@@ -141,12 +142,13 @@ def __init__(
super().__init__()

# make sure the OpenAI API key is set
openai.api_key = openai_api_key
if openai_api_key is None:
# attempt to get the OpenAI API key from the environment, since one
# wasn't specified as an argument
openai_api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)

if openai.api_key is None:
openai.api_key = os.environ.get(env_vars.OPENAI_API_KEY, None)

if openai.api_key is None or openai.api_key.strip() == "":
# if it's *still* not set, bail
if openai_api_key is None or openai_api_key.strip() == "":
raise ValueError(
f"OpenAI API key not found. Please provide it as parameter "
f"or set it as an the environment variable "
@@ -221,17 +223,14 @@ def __init__(
self.title = title
self.keywords = keywords if keywords is not None else []

# adjust options if edits or chat endpoint was selected
# adjust options if chat endpoint was selected
self.endpoint = "chat"

if model_engine.startswith(
("text-davinci-", "text-curie-", "text-babbage-", "text-ada-")
):
self.endpoint = "completions"

if "-edit-" in model_engine:
self.endpoint = "edits"

print(f"Language model: {model_engine}")
print(f"Model endpoint used: {self.endpoint}")

@@ -253,6 +252,18 @@ def __init__(

self.several_spaces_pattern = re.compile(r"\s+")

if self.endpoint == "chat":
client_cls = ChatOpenAI
else:
client_cls = OpenAI

# construct the OpenAI client after all the rest of
# the settings above have been processed
self.client = client_cls(
api_key=openai_api_key,
**self.model_parameters,
)

def get_prompt(
self, paragraph_text: str, section_name: str = None, resolved_prompt: str = None
) -> str | tuple[str, str]:
@@ -268,13 +279,9 @@ def get_prompt(
resolved_prompt: prompt resolved via ai-revision config, if available
Returns:
If self.endpoint != "edits", then returns a string with the prompt to be used by the model for the revision of the paragraph.
A string with the prompt to be used by the model for the revision of the paragraph.
It contains two paragraphs of text: the command for the model
("Revise...") and the paragraph to revise.
If self.endpoint == "edits", then returns a tuple with two strings:
1) the instructions to be used by the model for the revision of the paragraph,
2) the paragraph to revise.
"""

# prompts are resolved in the following order, with the first satisfied
@@ -310,8 +317,6 @@ def get_prompt(
f"Using custom prompt from environment variable '{env_vars.CUSTOM_PROMPT}'"
)

# FIXME: if {paragraph_text} is in the prompt, this won't work for the edits endpoint
# a simple workaround is to remove {paragraph_text} from the prompt
prompt = custom_prompt.format(**placeholders)
elif resolved_prompt:
# use the resolved prompt from the ai-revision config files, if available
@@ -384,14 +389,10 @@ def get_prompt(
if custom_prompt is None:
prompt = self.several_spaces_pattern.sub(" ", prompt).strip()

if self.endpoint != "edits":
if custom_prompt is not None and "{paragraph_text}" in custom_prompt:
return prompt
if custom_prompt is not None and "{paragraph_text}" in custom_prompt:
return prompt

return f"{prompt}.\n\n{paragraph_text.strip()}"
else:
prompt = prompt.replace("the following paragraph", "this paragraph")
return f"{prompt}.", paragraph_text.strip()
return f"{prompt}.\n\n{paragraph_text.strip()}"

def get_max_tokens(self, paragraph_text: str, fraction: float = 2.0) -> int:
"""
@@ -465,21 +466,30 @@ def get_max_tokens_from_error_message(error_message: str) -> dict[str, int] | None
}

def get_params(self, paragraph_text, section_name, resolved_prompt=None):
"""
Given the paragraph text and section name, produces parameters that are
used when invoking an LLM via an API.
The specific parameters vary depending on the endpoint being used, which
is determined by the model that was chosen when GPT3CompletionModel was
instantiated.
Args:
paragraph_text: The text of the paragraph to be revised.
section_name: The name of the section the paragraph belongs to.
resolved_prompt: The prompt resolved via ai-revision config files, if available.
Returns:
A dictionary of parameters to be used when invoking an LLM API.
"""
max_tokens = self.get_max_tokens(paragraph_text)
prompt = self.get_prompt(paragraph_text, section_name, resolved_prompt)

params = {
"n": 1,
}

if self.endpoint == "edits":
params.update(
{
"instruction": prompt[0],
"input": prompt[1],
}
)
elif self.endpoint == "chat":
if self.endpoint == "chat":
params.update(
{
"messages": [
@@ -502,19 +512,23 @@ def get_params(self, paragraph_text, section_name, resolved_prompt=None):

return params

def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolved_prompt=None):
def revise_paragraph(
self, paragraph_text: str, section_name: str = None, resolved_prompt=None
):
"""
Revises a paragraph using the GPT-3 completion model.
Arguments:
paragraph_text (str): Paragraph text to revise.
section_name (str): Section name of the paragraph.
throw_error (bool): If True, it throws an error if the API call fails.
If False, it returns the original paragraph text.
section_name (str): Section name of the paragraph.
resolved_prompt (str): Prompt resolved via ai-revision config files, if available.
Returns:
Revised paragraph text.
"""

# based on the paragraph text to revise and the section to which it
# belongs, constructs parameters that we'll use to query the LLM's API
params = self.get_params(paragraph_text, section_name, resolved_prompt)

retry_count = 0
@@ -526,17 +540,33 @@ def revise_paragraph(self, paragraph_text: str, section_name: str = None, resolved_prompt=None):
flush=True,
)

if self.endpoint == "edits":
completions = openai.Edit.create(**params)
elif self.endpoint == "chat":
completions = openai.ChatCompletion.create(**params)
else:
completions = openai.Completion.create(**params)
# map the prompt to langchain's prompt types, based on what
# kind of endpoint we're using
if "messages" in params:
# map the messages to langchain's message types
# based on the 'role' field
prompt = [
(
HumanMessage(content=msg["content"])
if msg["role"] == "user"
else SystemMessage(content=msg["content"])
)
for msg in params["messages"]
]
elif "prompt" in params:
prompt = [HumanMessage(content=params["prompt"])]

response = self.client.invoke(
input=prompt,
max_tokens=params.get("max_tokens"),
stop=params.get("stop"),
)

if self.endpoint == "chat":
message = completions.choices[0].message.content.strip()
if isinstance(response, BaseMessage):
message = response.content.strip()
else:
message = completions.choices[0].text.strip()
message = response.strip()

except Exception as e:
error_message = str(e)
print(f"Error: {error_message}")
@@ -583,10 +613,10 @@ class DebuggingManuscriptRevisionModel(GPT3CompletionModel):
"""

def __init__(self, *args, **kwargs):
if 'title' not in kwargs or kwargs['title'] is None:
kwargs['title'] = "Debugging Title"
if 'keywords' not in kwargs or kwargs['keywords'] is None:
kwargs['keywords'] = ["debugging", "keywords"]
if "title" not in kwargs or kwargs["title"] is None:
kwargs["title"] = "Debugging Title"
if "keywords" not in kwargs or kwargs["keywords"] is None:
kwargs["keywords"] = ["debugging", "keywords"]

super().__init__(*args, **kwargs)

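To illustrate the LangChain pattern this diff adopts, here is a minimal standalone sketch; it mirrors how GPT3CompletionModel now constructs a client once and then calls invoke() with langchain-core messages. The api_key and model values are placeholders, not values from this repository.

```python
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage

# construct the client once, as GPT3CompletionModel.__init__ now does
# (api_key and model are placeholders)
client = ChatOpenAI(api_key="sk-...", model="gpt-3.5-turbo")

# chat models take a list of messages and return a message object
# whose .content attribute holds the generated text
response = client.invoke(
    [
        SystemMessage(content="You revise scientific manuscripts."),
        HumanMessage(content="Revise the following paragraph:\n\n..."),
    ]
)
print(response.content.strip())
```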
20 changes: 11 additions & 9 deletions libs/manubot_ai_editor/prompt_config.py
@@ -47,9 +47,9 @@ def __init__(self, config_dir: str | Path, title: str, keywords: str) -> None:
# specify filename-to-prompt mappings; if both are present, we use
# self.config.files, but warn the user that they should only use one
if (
self.prompts_files is not None and
self.config is not None and
self.config.get('files', {}).get('matchings') is not None
self.prompts_files is not None
and self.config is not None
and self.config.get("files", {}).get("matchings") is not None
):
print(
"WARNING: Both 'ai-revision-config.yaml' and 'ai-revision-prompts.yaml' specify filename-to-prompt mappings. "
@@ -93,7 +93,7 @@ def _load_custom_prompts(self) -> tuple[dict, dict]:
# same as _load_config: if no config folder was specified, we just return (None, None)
if self.config_dir is None:
return (None, None)

prompt_file_path = os.path.join(self.config_dir, "ai-revision-prompts.yaml")

try:
@@ -150,7 +150,7 @@ def get_prompt_for_filename(
# ai-revision-prompts.yaml specifies prompts_files, then files.matchings
# takes precedence.
# (the user is notified of this in a validation warning in __init__)

# then, consult ai-revision-config.yaml's 'matchings' collection; if a
# match is found, use the corresponding prompt from ai-revision-prompts.yaml
for entry in get_obj_path(self.config, ("files", "matchings"), missing=[]):
@@ -169,7 +169,10 @@ def get_prompt_for_filename(
if resolved_prompt is not None:
resolved_prompt = resolved_prompt.strip()

return ( resolved_prompt, m, )
return (
resolved_prompt,
m,
)

# since we haven't found a match yet, consult ai-revision-prompts.yaml's
# 'prompts_files' collection
@@ -185,11 +188,10 @@
resolved_default_prompt = None
if use_default and self.prompts is not None:
resolved_default_prompt = self.prompts.get(
get_obj_path(self.config, ("files", "default_prompt")),
None
get_obj_path(self.config, ("files", "default_prompt")), None
)

if resolved_default_prompt is not None:
resolved_default_prompt = resolved_default_prompt.strip()

return (resolved_default_prompt, None)
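For orientation, a sketch of how this resolution logic might be exercised. The class name ManuscriptPromptConfig and all argument values are assumptions for illustration, not taken from this diff.

```python
from pathlib import Path

from manubot_ai_editor.prompt_config import ManuscriptPromptConfig  # class name assumed

# config_dir, title, and keywords are hypothetical values
config = ManuscriptPromptConfig(
    config_dir=Path("ci/prompts"),
    title="Example Manuscript",
    keywords="genomics, mutation",
)

# returns a (resolved_prompt, match) tuple; the prompt may come from
# files.matchings, prompts_files, or the configured default prompt
prompt, match = config.get_prompt_for_filename("01.introduction.md")
```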
5 changes: 3 additions & 2 deletions setup.py
@@ -11,7 +11,7 @@

setuptools.setup(
name="manubot-ai-editor",
version="0.5.2",
version="0.5.3",
author="Milton Pividori",
author_email="[email protected]",
description="A Manubot plugin to revise a manuscript using GPT-3",
@@ -25,7 +25,8 @@
],
python_requires=">=3.10",
install_requires=[
"openai==0.28",
"langchain-core~=0.3.6",
"langchain-openai~=0.2.0",
"pyyaml",
],
classifiers=[
20 changes: 9 additions & 11 deletions tests/test_editor.py
@@ -610,9 +610,7 @@ def test_revise_methods_with_equation_that_was_alrady_revised(
# GPT3CompletionModel(None, None),
],
)
def test_revise_methods_mutator_epistasis_paper(
tmp_path, model, filename
):
def test_revise_methods_mutator_epistasis_paper(tmp_path, model, filename):
"""
This paper has several test cases:
- it ends with multiple blank lines
@@ -635,7 +633,7 @@
)

assert (
r"""
r"""
%%% PARAGRAPH START %%%
Briefly, we identified private single-nucleotide mutations in each BXD that were absent from all other BXDs, as well as from the C57BL/6J and DBA/2J parents.
We required each private variant to be meet the following criteria:
Expand All @@ -651,11 +649,11 @@ def test_revise_methods_mutator_epistasis_paper(
* must occur on a parental haplotype that was inherited by at least one other BXD at the same locus; these other BXDs must be homozygous for the reference allele at the variant site
%%% PARAGRAPH END %%%
""".strip()
in open(tmp_path / filename).read()
in open(tmp_path / filename).read()
)

assert (
r"""
r"""
### Extracting mutation signatures
We used SigProfilerExtractor (v.1.1.21) [@PMID:30371878] to extract mutation signatures from the BXD mutation data.
@@ -678,11 +676,11 @@ def test_revise_methods_mutator_epistasis_paper(
### Comparing mutation spectra between Mouse Genomes Project strains
""".strip()
in open(tmp_path / filename).read()
in open(tmp_path / filename).read()
)

assert (
r"""
r"""
%%% PARAGRAPH START %%%
We investigated the region implicated by our aggregate mutation spectrum distance approach on chromosome 6 by subsetting the joint-genotyped BXD VCF file (European Nucleotide Archive accession PRJEB45429 [@url:https://www.ebi.ac.uk/ena/browser/view/PRJEB45429]) using `bcftools` [@PMID:33590861].
We defined the candidate interval surrounding the cosine distance peak on chromosome 6 as the 90% bootstrap confidence interval (extending from approximately 95 Mbp to 114 Mbp).
@@ -693,7 +691,7 @@
java -Xmx16g -jar /path/to/snpeff/jarfile GRCm38.75 /path/to/bxd/vcf > /path/to/uncompressed/output/vcf
```
""".strip()
in open(tmp_path / filename).read()
in open(tmp_path / filename).read()
)

