
Commit 5a2520e
Merge pull request #23 from huggingface/training
Training
thomwolf authored Mar 23, 2018
2 parents 5d9ef6b + 4b858d6 commit 5a2520e
Showing 93 changed files with 13,894 additions and 246 deletions.
9 changes: 6 additions & 3 deletions .gitignore
@@ -14,6 +14,9 @@ __pycache__/
/include/
/lib/
/pip-selfcheck.json
neuralcoref/data/*
neuralcoref/train/*
.cache
/runs/*
test_corefs.txt
test_mentions.txt
.cache
/.vscode/*
/.vscode
99 changes: 65 additions & 34 deletions neuralcoref/algorithm.py
@@ -4,13 +4,13 @@
from __future__ import unicode_literals
from __future__ import print_function

from pprint import pprint

import sys
import os
import spacy
import numpy as np

from neuralcoref.data import Data, MENTION_TYPE, NO_COREF_LIST
from neuralcoref.compat import unicode_
from neuralcoref.document import Document, MENTION_TYPE, NO_COREF_LIST

PACKAGE_DIRECTORY = os.path.dirname(os.path.abspath(__file__))

@@ -22,24 +22,28 @@
#######################
###### CLASSES ########

class Model:
class Model(object):
'''
Coreference neural model
'''
def __init__(self, model_path):
weights, biases = [], []
for file in sorted(os.listdir(model_path)):
if file.startswith("single_mention_weights"):
weights.append(np.load(os.path.join(model_path, file)))
w = np.load(os.path.join(model_path, file))
weights.append(w)
if file.startswith("single_mention_bias"):
biases.append(np.load(os.path.join(model_path, file)))
w = np.load(os.path.join(model_path, file))
biases.append(w)
self.single_mention_model = list(zip(weights, biases))
weights, biases = [], []
for file in sorted(os.listdir(model_path)):
if file.startswith("pair_mentions_weights"):
weights.append(np.load(os.path.join(model_path, file)))
w = np.load(os.path.join(model_path, file))
weights.append(w)
if file.startswith("pair_mentions_bias"):
biases.append(np.load(os.path.join(model_path, file)))
w = np.load(os.path.join(model_path, file))
biases.append(w)
self.pair_mentions_model = list(zip(weights, biases))

def _score(self, features, layers):
@@ -49,8 +53,8 @@ def _score(self, features, layers):
features = np.maximum(features, 0) # ReLU
return np.sum(features)

def get_single_mention_score(self, mention_embedding, anaphoricity_features):
first_layer_input = np.concatenate([mention_embedding,
def get_single_mention_score(self, mention, anaphoricity_features):
first_layer_input = np.concatenate([mention.embedding,
anaphoricity_features], axis=0)[:, np.newaxis]
return self._score(first_layer_input, self.single_mention_model)

@@ -61,32 +65,31 @@ def get_pair_mentions_score(self, antecedent, mention, pair_features):
return self._score(first_layer_input, self.pair_mentions_model)


class Coref:
class Coref(object):
'''
Main coreference resolution algorithm
'''
def __init__(self, nlp=None, greedyness=0.5, max_dist=50, max_dist_match=500, conll=None, use_no_coref_list=True, debug=False):
def __init__(self, nlp=None, greedyness=0.5, max_dist=50, max_dist_match=500, conll=None,
use_no_coref_list=True, debug=False):
self.greedyness = greedyness
self.max_dist = max_dist
self.max_dist_match = max_dist_match
self.debug = debug

model_path = os.path.join(PACKAGE_DIRECTORY, "weights/conll/" if conll is not None else "weights/")
trained_embed_path = os.path.join(PACKAGE_DIRECTORY, "weights/")
print("Loading neuralcoref model from", model_path)
self.coref_model = Model(model_path)
if nlp is None:
print("Loading spacy model")
try:
spacy.info('en_core_web_sm')
model = 'en_core_web_sm'
except IOError:
print("No spacy 2 model detected, using spacy1 'en' model")
spacy.info('en')
model = 'en'
nlp = spacy.load(model)

model_path = os.path.join(PACKAGE_DIRECTORY, "weights/conll/" if conll is not None else "weights/")
embed_model_path = os.path.join(PACKAGE_DIRECTORY, "weights/")
print("loading model from", model_path)
self.data = Data(nlp, model_path=embed_model_path, conll=conll, use_no_coref_list=use_no_coref_list, consider_speakers=conll)
self.coref_model = Model(model_path)

self.data = Document(nlp, conll=conll, use_no_coref_list=use_no_coref_list, trained_embed_path=trained_embed_path)
self.clusters = {}
self.mention_to_cluster = []
self.mentions_single_scores = {}
@@ -129,13 +132,22 @@ def _merge_coreference_clusters(self, ant_idx, mention_idx):

del self.clusters[remove_id]

def remove_singletons_clusters(self):
remove_id = []
for key, mentions in self.clusters.items():
if len(mentions) == 1:
remove_id.append(key)
self.mention_to_cluster[key] = None
for rem in remove_id:
del self.clusters[rem]

def display_clusters(self):
'''
Print cluster information
'''
print(self.clusters)
for key, mentions in self.clusters.items():
print("cluster", key, "(", ", ".join(str(self.data[m]) for m in mentions), ")")
print("cluster", key, "(", ", ".join(unicode_(self.data[m]) for m in mentions), ")")

###################################
####### MAIN COREF FUNCTIONS ######
@@ -150,11 +162,10 @@ def run_coref_on_mentions(self, mentions):
for mention_idx, ant_list in self.data.get_candidate_pairs(mentions, self.max_dist, self.max_dist_match):
mention = self.data[mention_idx]
feats_, ana_feats = self.data.get_single_mention_features(mention)
anaphoricity_score = self.coref_model.get_single_mention_score(mention.embedding, ana_feats)
self.mentions_single_scores[mention_idx] = anaphoricity_score
single_score = self.coref_model.get_single_mention_score(mention, ana_feats)
self.mentions_single_scores[mention_idx] = single_score
self.mentions_single_features[mention_idx] = {"spansEmbeddings": mention.spans_embeddings_, "wordsEmbeddings": mention.words_embeddings_, "features": feats_}

best_score = anaphoricity_score - 50 * (self.greedyness - 0.5)
best_score = single_score - 50 * (self.greedyness - 0.5)
for ant_idx in ant_list:
antecedent = self.data[ant_idx]
feats_, pwf = self.data.get_pair_mentions_features(antecedent, mention)
@@ -164,7 +175,6 @@ def run_coref_on_mentions(self, mentions):
"antecedentWordsEmbeddings": antecedent.words_embeddings_,
"mentionSpansEmbeddings": mention.spans_embeddings_,
"mentionWordsEmbeddings": mention.words_embeddings_ }

if score > best_score:
best_score = score
best_ant[mention_idx] = ant_idx
@@ -173,25 +183,29 @@ def run_coref_on_mentions(self, mentions):
self._merge_coreference_clusters(best_ant[mention_idx], mention_idx)
return (n_ant, best_ant)

def run_coref_on_utterances(self, last_utterances_added=False, follow_chains=True):
def run_coref_on_utterances(self, last_utterances_added=False, follow_chains=True, debug=False):
''' Run the coreference model on some utterances
Args:
last_utterances_added: run the coreference model over the last utterances added to the data
follow_chains: follow coreference chains over previous utterances
'''
if debug: print("== run_coref_on_utterances == start")
self._prepare_clusters()
if debug: self.display_clusters()
mentions = list(self.data.get_candidate_mentions(last_utterances_added=last_utterances_added))
n_ant, antecedents = self.run_coref_on_mentions(mentions)
mentions = antecedents.values()
if follow_chains and n_ant > 0:
if follow_chains and last_utterances_added and n_ant > 0:
i = 0
while i < MAX_FOLLOW_UP:
i += 1
n_ant, antecedents = self.run_coref_on_mentions(mentions)
mentions = antecedents.values()
if n_ant == 0:
break
if debug: self.display_clusters()
if debug: print("== run_coref_on_utterances == end")

def one_shot_coref(self, utterances, utterances_speakers_id=None, context=None,
context_speakers_id=None, speakers_names=None):
@@ -236,7 +250,7 @@ def continuous_coref(self, utterances, utterances_speakers_id=None, speakers_nam

def get_utterances(self, last_utterances_added=True):
''' Retrieve the list of parsed utterances'''
if last_utterances_added:
if last_utterances_added and len(self.data.last_utterances_loaded):
return [self.data.utterances[idx] for idx in self.data.last_utterances_loaded]
else:
return self.data.utterances
@@ -272,9 +286,10 @@ def get_scores(self):
return {"single_scores": self.mentions_single_scores,
"pair_scores": self.mentions_pairs_scores}

def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
def get_clusters(self, remove_singletons=False, use_no_coref_list=False):
''' Retrieve cleaned clusters'''
clusters = self.clusters
mention_to_cluster = self.mention_to_cluster
remove_id = []
if use_no_coref_list:
for key, mentions in clusters.items():
@@ -289,7 +304,7 @@ def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
for key, mentions in clusters.items():
if self.data.mentions[key].lower_ in NO_COREF_LIST:
remove_id.append(key)
self.mention_to_cluster[key] = None
mention_to_cluster[key] = None
if mentions:
added[mentions[0]] = mentions
for rem in remove_id:
@@ -301,11 +316,11 @@ def get_clusters(self, remove_singletons=True, use_no_coref_list=True):
for key, mentions in clusters.items():
if len(mentions) == 1:
remove_id.append(key)
self.mention_to_cluster[key] = None
mention_to_cluster[key] = None
for rem in remove_id:
del clusters[rem]

return clusters
return clusters, mention_to_cluster

def get_most_representative(self, last_utterances_added=True, use_no_coref_list=True):
'''
@@ -314,7 +329,7 @@ def get_most_representative(self, last_utterances_added=True, use_no_coref_list
Return:
Dictionary of {original_mention: most_representative_resolved_mention, ...}
'''
clusters = self.get_clusters(remove_singletons=True, use_no_coref_list=use_no_coref_list)
clusters, _ = self.get_clusters(remove_singletons=True, use_no_coref_list=use_no_coref_list)
coreferences = {}
for key in self.data.get_candidate_mentions(last_utterances_added=last_utterances_added):
if self.mention_to_cluster[key] is None:
@@ -333,3 +348,19 @@ def get_most_representative(self, last_utterances_added=True, use_no_coref_list=
representative = mention

return coreferences

if __name__ == '__main__':
coref = Coref(use_no_coref_list=False)
if len(sys.argv) > 1:
sent = sys.argv[1]
coref.one_shot_coref(sent)
else:
coref.one_shot_coref(u"Yes, I noticed that many friends, around me received it. It seems that almost everyone received this SMS.")#u"My sister has a dog. She loves him.")
mentions = coref.get_mentions()
print(mentions)

utterances = coref.get_utterances()
print(utterances)

resolved_utterance_text = coref.get_resolved_utterances()
print(resolved_utterance_text)
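
Note for callers: get_clusters() above now defaults to remove_singletons=False and use_no_coref_list=False, and returns a (clusters, mention_to_cluster) pair instead of clusters alone. A minimal sketch of the updated call pattern, mirroring the display_clusters() code above (the import path and example sentence are assumptions for illustration):

    from neuralcoref.algorithm import Coref
    from neuralcoref.compat import unicode_

    coref = Coref(use_no_coref_list=False)
    coref.one_shot_coref(u"My sister has a dog. She loves him.")
    # get_clusters() now returns both mappings; unpack the pair.
    clusters, mention_to_cluster = coref.get_clusters(remove_singletons=True)
    for cluster_id, mention_ids in clusters.items():
        print(cluster_id, [unicode_(coref.data[m]) for m in mention_ids])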
2 changes: 2 additions & 0 deletions neuralcoref/bld.bat
@@ -0,0 +1,2 @@
"%PYTHON%" setup.py install --single-version-externally-managed --record=record.txt
if errorlevel 1 exit 1
1 change: 1 addition & 0 deletions neuralcoref/build.sh
@@ -0,0 +1 @@
$PYTHON setup.py install --single-version-externally-managed --record=record.txt # Python command to install the script.
4 changes: 4 additions & 0 deletions neuralcoref/checkpoints/.gitignore
@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore
32 changes: 32 additions & 0 deletions neuralcoref/compat.py
@@ -0,0 +1,32 @@
# coding: utf8
"""Py2/3 compatibility"""
import sys

is_python2 = int(sys.version[0]) == 2
is_windows = sys.platform.startswith('win')
is_linux = sys.platform.startswith('linux')
is_osx = sys.platform == 'darwin'

if is_python2:
bytes_ = str
unicode_ = unicode
string_types = (str, unicode)
chr_ = unichr

def unicode_to_bytes(s, encoding='utf8', errors='strict'):
return s.encode(encoding=encoding, errors=errors)

def bytes_to_unicode(b, encoding='utf8', errors='strict'):
return unicode_(b, encoding=encoding, errors=errors)

else:
bytes_ = bytes
unicode_ = str
string_types = (bytes, str)
chr_ = chr

def unicode_to_bytes(s, encoding='utf8', errors='strict'):
return s.encode(encoding=encoding, errors=errors)

def bytes_to_unicode(b, encoding='utf8', errors='strict'):
return b.decode(encoding=encoding, errors=errors)
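
A short round-trip sketch of the helpers above; it behaves the same under both the Python 2 and Python 3 branches (the string literal is illustrative):

    from neuralcoref.compat import unicode_to_bytes, bytes_to_unicode

    text = u"caf\xe9"                        # u'café' on both Python 2 and 3
    data = unicode_to_bytes(text)            # UTF-8 bytes: b'caf\xc3\xa9'
    assert bytes_to_unicode(data) == text    # lossless round trip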