Skip to content

Commit

Permalink
Formatted code w/black
Browse files Browse the repository at this point in the history
  • Loading branch information
falquaddoomi committed Nov 20, 2024
1 parent 3bafcf6 commit a98f69a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 49 deletions.
20 changes: 11 additions & 9 deletions libs/manubot_ai_editor/prompt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def __init__(self, config_dir: str | Path, title: str, keywords: str) -> None:
# specify filename-to-prompt mappings; if both are present, we use
# self.config.files, but warn the user that they should only use one
if (
self.prompts_files is not None and
self.config is not None and
self.config.get('files', {}).get('matchings') is not None
self.prompts_files is not None
and self.config is not None
and self.config.get("files", {}).get("matchings") is not None
):
print(
"WARNING: Both 'ai-revision-config.yaml' and 'ai-revision-prompts.yaml' specify filename-to-prompt mappings. "
Expand Down Expand Up @@ -93,7 +93,7 @@ def _load_custom_prompts(self) -> tuple[dict, dict]:
# same as _load_config: if no config folder was specified, we just return (None, None)
if self.config_dir is None:
return (None, None)

prompt_file_path = os.path.join(self.config_dir, "ai-revision-prompts.yaml")

try:
Expand Down Expand Up @@ -150,7 +150,7 @@ def get_prompt_for_filename(
# ai-revision-prompts.yaml specifies prompts_files, then files.matchings
# takes precedence.
# (the user is notified of this in a validation warning in __init__)

# then, consult ai-revision-config.yaml's 'matchings' collection; if a
# match is found, use the corresponding prompt from ai-revision-prompts.yaml
for entry in get_obj_path(self.config, ("files", "matchings"), missing=[]):
Expand All @@ -169,7 +169,10 @@ def get_prompt_for_filename(
if resolved_prompt is not None:
resolved_prompt = resolved_prompt.strip()

return ( resolved_prompt, m, )
return (
resolved_prompt,
m,
)

# since we haven't found a match yet, consult ai-revision-prompts.yaml's
# 'prompts_files' collection
Expand All @@ -185,11 +188,10 @@ def get_prompt_for_filename(
resolved_default_prompt = None
if use_default and self.prompts is not None:
resolved_default_prompt = self.prompts.get(
get_obj_path(self.config, ("files", "default_prompt")),
None
get_obj_path(self.config, ("files", "default_prompt")), None
)

if resolved_default_prompt is not None:
resolved_default_prompt = resolved_default_prompt.strip()

return (resolved_default_prompt, None)
20 changes: 9 additions & 11 deletions tests/test_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,7 @@ def test_revise_methods_with_equation_that_was_alrady_revised(
# GPT3CompletionModel(None, None),
],
)
def test_revise_methods_mutator_epistasis_paper(
tmp_path, model, filename
):
def test_revise_methods_mutator_epistasis_paper(tmp_path, model, filename):
"""
This paper has several test cases:
- it ends with multiple blank lines
Expand All @@ -635,7 +633,7 @@ def test_revise_methods_mutator_epistasis_paper(
)

assert (
r"""
r"""
%%% PARAGRAPH START %%%
Briefly, we identified private single-nucleotide mutations in each BXD that were absent from all other BXDs, as well as from the C57BL/6J and DBA/2J parents.
We required each private variant to be meet the following criteria:
Expand All @@ -651,11 +649,11 @@ def test_revise_methods_mutator_epistasis_paper(
* must occur on a parental haplotype that was inherited by at least one other BXD at the same locus; these other BXDs must be homozygous for the reference allele at the variant site
%%% PARAGRAPH END %%%
""".strip()
in open(tmp_path / filename).read()
in open(tmp_path / filename).read()
)

assert (
r"""
r"""
### Extracting mutation signatures
We used SigProfilerExtractor (v.1.1.21) [@PMID:30371878] to extract mutation signatures from the BXD mutation data.
Expand All @@ -678,11 +676,11 @@ def test_revise_methods_mutator_epistasis_paper(
### Comparing mutation spectra between Mouse Genomes Project strains
""".strip()
in open(tmp_path / filename).read()
in open(tmp_path / filename).read()
)

assert (
r"""
r"""
%%% PARAGRAPH START %%%
We investigated the region implicated by our aggregate mutation spectrum distance approach on chromosome 6 by subsetting the joint-genotyped BXD VCF file (European Nucleotide Archive accession PRJEB45429 [@url:https://www.ebi.ac.uk/ena/browser/view/PRJEB45429]) using `bcftools` [@PMID:33590861].
We defined the candidate interval surrounding the cosine distance peak on chromosome 6 as the 90% bootstrap confidence interval (extending from approximately 95 Mbp to 114 Mbp).
Expand All @@ -693,7 +691,7 @@ def test_revise_methods_mutator_epistasis_paper(
java -Xmx16g -jar /path/to/snpeff/jarfile GRCm38.75 /path/to/bxd/vcf > /path/to/uncompressed/output/vcf
```
""".strip()
in open(tmp_path / filename).read()
in open(tmp_path / filename).read()
)


Expand Down
94 changes: 65 additions & 29 deletions tests/test_prompt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@
from manubot_ai_editor.models import (
GPT3CompletionModel,
RandomManuscriptRevisionModel,
DebuggingManuscriptRevisionModel
DebuggingManuscriptRevisionModel,
)
from manubot_ai_editor.prompt_config import IGNORE_FILE
import pytest

from utils.dir_union import mock_unify_open

MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full" / "content"
MANUSCRIPTS_CONFIG_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full" / "ci"
MANUSCRIPTS_CONFIG_DIR = (
Path(__file__).parent / "manuscripts" / "phenoplier_full" / "ci"
)


# check that this path exists and resolve it
Expand Down Expand Up @@ -42,7 +44,9 @@ def test_create_manuscript_editor():


# check that we can resolve a file to a prompt, and that it's the correct prompt
@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR))
@mock.patch(
"builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)
)
def test_resolve_prompt():
content_dir = MANUSCRIPTS_DIR.resolve(strict=True)
config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True)
Expand Down Expand Up @@ -100,7 +104,9 @@ def test_resolve_prompt():

# test that we get the default prompt with a None match object for a
# file we don't recognize
@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR))
@mock.patch(
"builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)
)
def test_resolve_default_prompt_unknown_file():
content_dir = MANUSCRIPTS_DIR.resolve(strict=True)
config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True)
Expand All @@ -114,7 +120,9 @@ def test_resolve_default_prompt_unknown_file():

# check that a file we don't recognize gets match==None and the 'default' prompt
# from the ai-revision-config.yaml file
@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR))
@mock.patch(
"builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PHENOPLIER_PROMPTS_DIR)
)
def test_unresolved_gets_default_prompt():
content_dir = MANUSCRIPTS_DIR.resolve(strict=True)
config_dir = MANUSCRIPTS_CONFIG_DIR.resolve(strict=True)
Expand Down Expand Up @@ -150,7 +158,9 @@ def test_unresolved_gets_default_prompt():
# - Both ai-revision-config.yaml and ai-revision-prompts.yaml specify filename matchings
# (conflicting_promptsfiles_matchings)
CONFLICTING_PROMPTSFILES_MATCHINGS_DIR = (
Path(__file__).parent / "config_loader_fixtures" / "conflicting_promptsfiles_matchings"
Path(__file__).parent
/ "config_loader_fixtures"
/ "conflicting_promptsfiles_matchings"
)
# ---
# test ManuscriptEditor.prompt_config sub-attributes are set correctly
Expand Down Expand Up @@ -178,7 +188,9 @@ def test_no_config_unloaded():
assert editor.prompt_config.config is None


@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, ONLY_REV_PROMPTS_DIR))
@mock.patch(
"builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, ONLY_REV_PROMPTS_DIR)
)
def test_only_rev_prompts_loaded():
editor = get_editor()

Expand All @@ -188,7 +200,9 @@ def test_only_rev_prompts_loaded():
assert editor.prompt_config.config is None


@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR))
@mock.patch(
"builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)
)
def test_both_prompts_loaded():
editor = get_editor()

Expand All @@ -211,7 +225,8 @@ def test_single_generic_loaded():


@mock.patch(
"builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, CONFLICTING_PROMPTSFILES_MATCHINGS_DIR)
"builtins.open",
mock_unify_open(MANUSCRIPTS_CONFIG_DIR, CONFLICTING_PROMPTSFILES_MATCHINGS_DIR),
)
def test_conflicting_sources_warning(capfd):
"""
Expand All @@ -234,7 +249,7 @@ def test_conflicting_sources_warning(capfd):
# for this test, we define both prompts_files and files.matchings which
# creates a conflict that produces the warning we're looking for
assert editor.prompt_config.prompts_files is not None
assert editor.prompt_config.config['files']['matchings'] is not None
assert editor.prompt_config.config["files"]["matchings"] is not None

expected_warning = (
"WARNING: Both 'ai-revision-config.yaml' and "
Expand Down Expand Up @@ -262,11 +277,13 @@ def test_conflicting_sources_warning(capfd):
RandomManuscriptRevisionModel(),
DebuggingManuscriptRevisionModel(
title="Test title", keywords=["test", "keywords"]
)
),
# GPT3CompletionModel(None, None),
],
)
@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR))
@mock.patch(
"builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)
)
def test_revise_entire_manuscript(tmp_path, model):
print(f"\n{str(tmp_path)}\n")
me = get_editor()
Expand All @@ -284,7 +301,9 @@ def test_revise_entire_manuscript(tmp_path, model):
assert len(output_md_files) == 9


@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR))
@mock.patch(
"builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, BOTH_PROMPTS_CONFIG_DIR)
)
def test_revise_entire_manuscript_includes_title_keywords(tmp_path):
from os.path import basename

Expand Down Expand Up @@ -317,8 +336,12 @@ def test_revise_entire_manuscript_includes_title_keywords(tmp_path):

with open(output_md_file, "r") as f:
content = f.read()
assert me.title in content, f"not found in filename: {basename(output_md_file)}"
assert ", ".join(me.keywords) in content, f"not found in filename: {basename(output_md_file)}"
assert (
me.title in content
), f"not found in filename: {basename(output_md_file)}"
assert (
", ".join(me.keywords) in content
), f"not found in filename: {basename(output_md_file)}"


# ==============================================================================
Expand All @@ -329,7 +352,11 @@ def test_revise_entire_manuscript_includes_title_keywords(tmp_path):
Path(__file__).parent / "config_loader_fixtures" / "prompt_propogation"
)

@mock.patch("builtins.open", mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR))

@mock.patch(
"builtins.open",
mock_unify_open(MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR),
)
def test_prompts_in_final_result(tmp_path):
"""
Tests that the prompts are making it into the final resulting .md files.
Expand All @@ -348,9 +375,7 @@ def test_prompts_in_final_result(tmp_path):
"""
me = get_editor()

model = DebuggingManuscriptRevisionModel(
title=me.title, keywords=me.keywords
)
model = DebuggingManuscriptRevisionModel(title=me.title, keywords=me.keywords)

output_folder = tmp_path
assert output_folder.exists()
Expand All @@ -361,7 +386,8 @@ def test_prompts_in_final_result(tmp_path):
files_to_prompts = {
"00.front-matter.md": "This is the front-matter prompt.",
"01.abstract.md": "This is the abstract prompt",
"02.introduction.md": "This is the introduction prompt for the paper titled '%s'." % me.title,
"02.introduction.md": "This is the introduction prompt for the paper titled '%s'."
% me.title,
# "04.00.results.md": "This is the results prompt",
"04.05.00.results_framework.md": "This is the results_framework prompt",
"04.05.01.crispr.md": "This is the crispr prompt",
Expand Down Expand Up @@ -389,15 +415,26 @@ def test_prompts_in_final_result(tmp_path):

# to save on time/cost, we use a version of the phenoplier manuscript that only
# contains the first paragraph of each section
BRIEF_MANUSCRIPTS_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "content"
BRIEF_MANUSCRIPTS_CONFIG_DIR = Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "ci"
BRIEF_MANUSCRIPTS_DIR = (
Path(__file__).parent
/ "manuscripts"
/ "phenoplier_full_only_first_para"
/ "content"
)
BRIEF_MANUSCRIPTS_CONFIG_DIR = (
Path(__file__).parent / "manuscripts" / "phenoplier_full_only_first_para" / "ci"
)

PROMPT_PROPOGATION_CONFIG_DIR = (
Path(__file__).parent / "config_loader_fixtures" / "prompt_gpt3_e2e"
)


@pytest.mark.cost
@mock.patch("builtins.open", mock_unify_open(BRIEF_MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR))
@mock.patch(
"builtins.open",
mock_unify_open(BRIEF_MANUSCRIPTS_CONFIG_DIR, PROMPT_PROPOGATION_CONFIG_DIR),
)
def test_prompts_apply_gpt3(tmp_path):
"""
Tests that the custom prompts are applied when actually applying
Expand All @@ -408,16 +445,15 @@ def test_prompts_apply_gpt3(tmp_path):
this test is marked 'cost' and requires the --runcost argument to be run,
e.g. to run just this test: `pytest --runcost -k test_prompts_apply_gpt3`.
As with test_prompts_in_final_result above, files that have no input and
As with test_prompts_in_final_result above, files that have no input and
thus no applied prompt are ignored.
"""
me = get_editor(content_dir=BRIEF_MANUSCRIPTS_DIR, config_dir=BRIEF_MANUSCRIPTS_CONFIG_DIR)

model = GPT3CompletionModel(
title=me.title,
keywords=me.keywords
me = get_editor(
content_dir=BRIEF_MANUSCRIPTS_DIR, config_dir=BRIEF_MANUSCRIPTS_CONFIG_DIR
)

model = GPT3CompletionModel(title=me.title, keywords=me.keywords)

output_folder = tmp_path
assert output_folder.exists()

Expand Down

0 comments on commit a98f69a

Please sign in to comment.