Skip to content

Commit

Permalink
Merge pull request RasaHQ#8349 from RasaHQ/bug-fix-7379
Browse files Browse the repository at this point in the history
Count only nlu training examples
  • Loading branch information
ancalita authored Apr 7, 2021
2 parents f598e6c + 4a6e816 commit 30b8991
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 1 deletion.
1 change: 1 addition & 0 deletions changelog/7379.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Change training data validation to count only NLU training examples per intent, so that intents merely referenced in stories and rules no longer add to the count.
13 changes: 13 additions & 0 deletions data/test_number_nlu_examples/nlu.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# NLU fixture: defines two intents, `greet` (2 examples) and `ask_weather` (3 examples).
version: "2.0"

nlu:
- intent: greet
examples: |
- Hello!
- Howdy!
- intent: ask_weather
examples: |
- What's the weather like today?
- Does it look sunny outside today?
- Oh, do you mind checking the weather for me please?
11 changes: 11 additions & 0 deletions data/test_number_nlu_examples/rules.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

# Rules fixture: two rules, each pairing one intent with one `utter_*` action.
rules:
- rule: provide_weather
steps:
- intent: ask_weather
- action: utter_weather

- rule: greet
steps:
- intent: greet
- action: utter_greet
8 changes: 8 additions & 0 deletions data/test_number_nlu_examples/stories.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

# Stories fixture: a single story that references both the `greet` and `ask_weather` intents.
stories:
- story: simple_story
steps:
- intent: greet
- action: utter_greet
- intent: ask_weather
- action: utter_weather
2 changes: 1 addition & 1 deletion rasa/shared/nlu/training_data/training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def retrieval_intents(self) -> Set[Text]:
@lazy_property
def number_of_examples_per_intent(self) -> Dict[Text, int]:
    """Calculates the number of examples per intent.

    Only the messages in `self.nlu_examples` are counted, not every entry
    of `self.training_examples`.

    Returns:
        A mapping from intent name to the number of NLU examples carrying
        that intent.
    """
    # Count over `nlu_examples` rather than `training_examples` so that
    # non-NLU messages do not inflate the per-intent totals.
    # NOTE(review): relies on `nlu_examples` excluding such messages; that
    # property is defined outside this view -- confirm.
    intents = [ex.get(INTENT) for ex in self.nlu_examples]
    return dict(Counter(intents))

@lazy_property
Expand Down
74 changes: 74 additions & 0 deletions tests/shared/nlu/training_data/test_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,80 @@ def test_train_test_split(filepaths: List[Text]):
)


def test_number_of_examples_per_intent():
    """Check the per-intent counts for a mix of messages.

    The input holds one action-only message plus, for each of two intents,
    messages with text and one message that carries only the intent. The
    expected counts equal the number of messages that have text.
    """
    training_data = TrainingData(
        training_examples=[
            # Action-only message: no intent, no text.
            Message(data={"action_name": "utter_greet"}),
            # "subscribe": one message with text, one without.
            Message(
                data={"text": "I would like the newsletter", "intent": "subscribe"}
            ),
            Message(data={"intent": "subscribe"}),
            # "ask_weather": two messages with text, one without.
            Message(
                data={
                    "text": "What is the weather like today?",
                    "intent": "ask_weather",
                }
            ),
            Message(data={"text": "Will it rain today?", "intent": "ask_weather"}),
            Message(data={"intent": "ask_weather"}),
        ]
    )

    counts = training_data.number_of_examples_per_intent

    assert counts["subscribe"] == 1
    assert counts["ask_weather"] == 2


async def test_number_of_examples_per_intent_with_yaml(tmp_path: Path):
    """Check per-intent counts when NLU data, stories and rules load together.

    An empty config file and an empty domain are written to `tmp_path`; the
    training files come from `data/test_number_nlu_examples`.
    """
    config_file = tmp_path / "config.yml"
    config_file.touch()

    domain_file = tmp_path / "domain.yml"
    domain_file.write_text(Domain.empty().as_yaml())

    training_files = [
        "data/test_number_nlu_examples/nlu.yml",
        "data/test_number_nlu_examples/stories.yml",
        "data/test_number_nlu_examples/rules.yml",
    ]
    importer = TrainingDataImporter.load_from_dict(
        {}, str(config_file), str(domain_file), training_files
    )

    nlu_data = await importer.get_nlu_data()
    assert nlu_data.intents == {"greet", "ask_weather"}

    counts = nlu_data.number_of_examples_per_intent
    assert counts["greet"] == 2
    assert counts["ask_weather"] == 3


def test_validate_number_of_examples_per_intent():
    """Expect `validate()` to emit exactly one warning for 'subscribe'.

    The data holds one 'subscribe' message with text and one with only the
    intent; the warning text must report a single training example.
    """
    training_data = TrainingData(
        training_examples=[
            Message(
                data={"text": "I would like the newsletter", "intent": "subscribe"}
            ),
            Message(data={"intent": "subscribe"}),
        ]
    )

    with pytest.warns(Warning) as recorded:
        training_data.validate()

    expected_message = (
        "Intent 'subscribe' has only 1 training examples! "
        "Minimum is 2, training may fail."
    )
    assert len(recorded) == 1
    assert recorded[0].message.args[0] == expected_message


@pytest.mark.parametrize(
"filepaths",
[
Expand Down

0 comments on commit 30b8991

Please sign in to comment.