Skip to content

Commit

Permalink
Add filter func
Browse files: browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Apr 24, 2024
1 parent 97e789f commit e985d19
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 6 additions & 2 deletions aisploit/classifiers/presidio/presidio_analyser.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from dataclasses import dataclass, field
from typing import List
from typing import Callable, List, Optional

from presidio_analyzer import AnalyzerEngine, EntityRecognizer, RecognizerResult

from ...core import BaseTextClassifier, Score


@dataclass
@dataclass(kw_only=True)
class PresidioAnalyserClassifier(BaseTextClassifier[List[RecognizerResult]]):
"""A text classifier using the Presidio Analyzer for detecting Personally Identifiable Information (PII)."""

language: str = "en"
entities: List[str] | None = None
threshold: float = 0.7
additional_recognizers: List[EntityRecognizer] = field(default_factory=list)
filter_func: Optional[Callable[[str, RecognizerResult], bool]] = None
tags: List[str] = field(default_factory=lambda: ["leakage"], init=False)

def __post_init__(self) -> None:
Expand All @@ -36,6 +37,9 @@ def score(self, input: str, _: List[str] | None = None) -> Score[List[RecognizerResult]]:
"""
results = self._analyzer.analyze(text=input, entities=self.entities, language=self.language)

if self.filter_func:
results = [result for result in results if self.filter_func(input, result)]

return Score[List[RecognizerResult]](
flagged=len(results) > 0,
value=results,
Expand Down
4 changes: 2 additions & 2 deletions examples/classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -94,7 +94,7 @@
"Score(flagged=True, value=[type: PERSON, start: 11, end: 19, score: 0.85, type: PHONE_NUMBER, start: 43, end: 55, score: 0.75], description='Returns True if entities are found in the input', explanation='Found 2 entities in input')"
]
},
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
Expand Down

0 comments on commit e985d19

Please sign in to comment.