Skip to content

Commit

Permalink
models: add ability to run the model agian if the model returns an em…
Browse files Browse the repository at this point in the history
…pty paragraph
  • Loading branch information
miltondp committed Jan 1, 2023
1 parent c00bae3 commit 7e90233
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
5 changes: 5 additions & 0 deletions libs/manubot/ai_editor/env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@
# AI model when revising a paragraph. For example, for the introduction, prompts
# contain sentences to preserve most of the citations to other papers.
SECTIONS_MAPPING = "AI_EDITOR_FILENAME_SECTION_MAPPING"

# Sometimes the AI model returns an empty paragraph. Usually, this is resolved
# by running again the model. The AI Editor will try three (3) times in these
# cases. This variable allows to specify the number of retries.
RETRY_COUNT = "AI_EDITOR_RETRY_COUNT"
24 changes: 20 additions & 4 deletions libs/manubot/ai_editor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
frequency_penalty: float = None,
best_of: int = None,
top_p: float = None,
retry_count: int = 3,
):
super().__init__()

Expand Down Expand Up @@ -174,7 +175,18 @@ def __init__(
best_of = int(os.environ[env_vars.BEST_OF])
print(f"Using best_of from environment variable '{env_vars.BEST_OF}'")
except ValueError:
# if it is not a float, we ignore it
# if it is not an int, we ignore it
pass

self.retry_count = retry_count
if env_vars.RETRY_COUNT in os.environ:
try:
self.retry_count = int(os.environ[env_vars.RETRY_COUNT])
print(
f"Using retry_count from environment variable '{env_vars.RETRY_COUNT}'"
)
except ValueError:
# if it is not an int, we ignore it
pass

self.title = title
Expand Down Expand Up @@ -294,7 +306,11 @@ def revise_paragraph(self, paragraph_text, section_name):

params.update(self.model_parameters)

completions = openai.Completion.create(**params)
retry_count = 0
message = ""
while message == "" and retry_count < self.retry_count:
completions = openai.Completion.create(**params)
message = completions.choices[0].text.strip()
retry_count += 1

message = completions.choices[0].text
return message.strip()
return message

0 comments on commit 7e90233

Please sign in to comment.