Last active
December 12, 2022 06:10
-
-
Save xiabingquan/18c008cc8642eeeccd7ba0c750af4d3f to your computer and use it in GitHub Desktop.
Revisions
-
xiabingquan revised this gist
Dec 12, 2022 . 1 changed file with 19 additions and 5 deletions. There are no files selected for viewing
def analyze(gt, hyp, ops, d, text=None):
    """Render the alignment of ``hyp`` against ``gt`` and count each error type.

    The backpointer table ``ops`` is walked from cell ``(len(gt), len(hyp))``
    back to ``(0, 0)``.  Every step contributes one token to the rendered
    comparison:

    * ``"keep"`` - the hypothesis token, styled ``"underline"``
    * ``"rep"``  - ``【hyp->gt】``
    * ``"ins"``  - ``\\gt/`` (a ground-truth token missing from the hypothesis)
    * ``"del"``  - the hypothesis token, styled ``"strike"``

    Args:
        gt: Ground-truth tokens (a list of strings).
        hyp: Hypothesis tokens (a list of strings).
        ops: ``(len(gt) + 1) x (len(hyp) + 1)`` table whose cells on the
            alignment path are one of ``"keep"``, ``"rep"``, ``"ins"``,
            ``"del"``.
        d: Edit distance computed by the caller; it must equal the number of
            non-``"keep"`` steps on the path.
        text: Optional object exposing ``append(str, style=...)`` to write
            into.  A fresh ``Text()`` is created when omitted.

    Returns:
        A ``(text, errs)`` tuple, where ``errs`` maps ``"rep"``, ``"ins"`` and
        ``"del"`` to their counts.

    Raises:
        ValueError: If a visited cell of ``ops`` holds an unknown operation,
            or if the counted errors do not add up to ``d``.
    """
    errs = {"rep": 0, "ins": 0, "del": 0}
    i, j = len(gt), len(hyp)

    # Tokens are collected back-to-front while following the backpointers.
    # Indexing with ``i - 1`` / ``j - 1`` avoids padding the inputs with a
    # sentinel, which used to leak a leading blank into the printed lines.
    steps = []
    while i > 0 or j > 0:  # stop once the origin (0, 0) is reached
        op = ops[i][j]
        if op == "keep":
            steps.append((hyp[j - 1], "underline"))
            i -= 1
            j -= 1
        elif op == "rep":
            steps.append((f"【{hyp[j - 1]}->{gt[i - 1]}】", ''))
            i -= 1
            j -= 1
            errs["rep"] += 1
        elif op == "ins":
            steps.append((f"\\{gt[i - 1]}/", ''))
            i -= 1
            errs["ins"] += 1
        elif op == "del":
            steps.append((hyp[j - 1], "strike"))
            j -= 1
            errs["del"] += 1
        else:
            raise ValueError(f"unknown operation {op!r} at ops[{i}][{j}]")

    # ``d`` comes from the caller, so validate it explicitly; an ``assert``
    # would silently disappear under ``python -O``.
    total = sum(errs.values())
    if total != d:
        raise ValueError(
            f"edit distance mismatch: alignment has {total} errors, expected {d}"
        )

    if text is None:
        text = Text()
    text.append(f"gt: {' '.join(gt)}\n")
    text.append(f"hyp: {' '.join(hyp)}\n")
    text.append("comparison: ")
    for token, style in reversed(steps):
        text.append(token, style=style)
        text.append(' ')
    text.append('\n')
    text.append(f"err/num: {d}/{len(gt)} -> ")
    text.append(', '.join(f"{k}: {v}" for k, v in errs.items()))
    return text, errs
xiabingquan revised this gist
Dec 12, 2022 . 1 changed file with 5 additions and 4 deletions. There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -59,7 +59,7 @@ def edit_dist_dp(gt, hyp): return dp[m][n], ops def analyze(gt, hyp, ops, d, text=None): i, j = len(gt), len(hyp) gt = [''] + gt hyp = [''] + hyp @@ -70,19 +70,20 @@ def analyze(gt, hyp, ops, d): i -= 1 j -= 1 elif ops[i][j] == "rep": s.append((f"【{hyp[j]}->{gt[i]}】", '')) i -= 1 j -= 1 elif ops[i][j] == "ins": s.append((f"\\{gt[i]}/", '')) i -= 1 elif ops[i][j] == "del": s.append((hyp[j], "strike")) j -= 1 else: raise ValueError if text is None: text = Text() text.append(f"gt: {' '.join(gt)}\n") text.append(f"hyp: {' '.join(hyp)}\n") text.append("analysis: ") -
xiabingquan created this gist
Dec 12, 2022 . There are no files selected for viewing
"""Token-level edit distance between a ground truth and a hypothesis.

Computes the Levenshtein distance with dynamic programming, keeps a
backpointer table, and renders the resulting alignment with ``rich``.
"""
from argparse import ArgumentParser

import editdistance
from rich.text import Text
from rich.console import Console


def edit_dist_dp(gt, hyp):
    """Compute the edit distance between ``gt`` and ``hyp`` with backpointers.

    Reference: https://www.geeksforgeeks.org/edit-distance-dp-5/

    Args:
        gt: Ground-truth sequence of tokens.
        hyp: Hypothesis sequence of tokens.

    Returns:
        A ``(distance, ops)`` tuple.  ``ops`` is a
        ``(len(gt) + 1) x (len(hyp) + 1)`` table; ``ops[i][j]`` names the last
        operation used to align ``gt[:i]`` with ``hyp[:j]``:

        * ``"keep"`` - ``gt[i - 1] == hyp[j - 1]``, move diagonally
        * ``"rep"``  - replace ``hyp[j - 1]`` by ``gt[i - 1]``, move diagonally
        * ``"ins"``  - ``gt[i - 1]`` is missing from ``hyp``, move up
        * ``"del"``  - ``hyp[j - 1]`` is superfluous, move left
    """
    m, n = len(gt), len(hyp)
    # dp[i][j] is the edit distance between gt[:i] and hyp[:j].
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    ops = [[''] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0:
                # Empty ground truth: every hypothesis token is superfluous.
                dp[i][j] = j
                ops[i][j] = "del"
            elif j == 0:
                # Empty hypothesis: every ground-truth token is missing.
                dp[i][j] = i
                ops[i][j] = "ins"
            elif gt[i - 1] == hyp[j - 1]:
                # Matching tokens cost nothing.
                dp[i][j] = dp[i - 1][j - 1]
                ops[i][j] = "keep"
            else:
                best = min(dp[i][j - 1], dp[i - 1][j], dp[i - 1][j - 1])
                dp[i][j] = 1 + best
                # The checks are deliberately not chained with ``elif``: on a
                # tie the later one wins, so the priority is rep > ins > del.
                if dp[i][j - 1] == best:
                    ops[i][j] = "del"
                if dp[i - 1][j] == best:
                    ops[i][j] = "ins"
                if dp[i - 1][j - 1] == best:
                    ops[i][j] = "rep"
    return dp[m][n], ops


def analyze(gt, hyp, ops, d, text=None):
    """Render the alignment described by ``ops`` as styled text.

    Walks ``ops`` from cell ``(len(gt), len(hyp))`` back to ``(0, 0)`` and
    emits one token per step: kept tokens underlined, replacements as
    ``【hyp->gt】``, missing ground-truth tokens underlined, and superfluous
    hypothesis tokens struck through.

    Args:
        gt: Ground-truth tokens (a list of strings).
        hyp: Hypothesis tokens (a list of strings).
        ops: Backpointer table as returned by ``edit_dist_dp``.
        d: Edit distance, shown in the ``err/num`` summary.
        text: Optional object exposing ``append(str, style=...)`` to write
            into.  A fresh ``Text()`` is created when omitted.

    Returns:
        The text object holding the rendered analysis.

    Raises:
        ValueError: If a visited cell of ``ops`` holds an unknown operation.
    """
    i, j = len(gt), len(hyp)

    # Tokens are collected back-to-front while following the backpointers.
    steps = []
    # Stop exactly at the origin.  The previous ``while j >= 0`` walked one
    # cell too far, relied on ops[0][0] being "del" to terminate, and emitted
    # a spurious empty token.
    while i > 0 or j > 0:
        op = ops[i][j]
        if op == "keep":
            steps.append((hyp[j - 1], "underline"))
            i -= 1
            j -= 1
        elif op == "rep":
            steps.append((f"【{hyp[j - 1]}->{gt[i - 1]}】", "frame"))
            i -= 1
            j -= 1
        elif op == "ins":
            steps.append((gt[i - 1], "underline"))
            i -= 1
        elif op == "del":
            steps.append((hyp[j - 1], "strike"))
            j -= 1
        else:
            raise ValueError(f"unknown operation {op!r} at ops[{i}][{j}]")

    if text is None:
        text = Text()
    # The inputs are joined unpadded so no stray leading blank is printed.
    text.append(f"gt: {' '.join(gt)}\n")
    text.append(f"hyp: {' '.join(hyp)}\n")
    text.append("analysis: ")
    for token, style in reversed(steps):
        text.append(token, style=style)
        text.append(' ')
    text.append(f"\terr/num: {d}/{len(gt)}")
    return text


if __name__ == "__main__":
    argparser = ArgumentParser()
    argparser.add_argument("--ground_truth", "--gt", type=str, required=True)
    argparser.add_argument("--hypothesis", "--hyp", type=str, required=True)
    # NOTE: the flag is spelled "delimeter"; kept as-is for CLI compatibility.
    argparser.add_argument("--delimeter", "-d", type=str, default=' ')
    args = argparser.parse_args()

    gt, hyp = args.ground_truth, args.hypothesis
    gt, hyp = gt.split(args.delimeter), hyp.split(args.delimeter)
    d, ops = edit_dist_dp(gt, hyp)
    # Sanity-check the DP result against the reference implementation.
    assert d == editdistance.eval(gt, hyp)
    with Console() as console:
        console.print(analyze(gt, hyp, ops, d))