

text.max_spanning_tree


Finds the maximum directed spanning tree of a digraph.

```python
text.max_spanning_tree(
    num_nodes, scores, forest=False, name=None
)
```

Given a batch of directed graphs with scored arcs and root selections, solves for the maximum spanning tree of each digraph, where the score of a tree is defined as the sum of the scores of the arcs and roots making up the tree.
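To make the scoring rule concrete: for one digraph, a tree can be described by giving each node t its source s (with s == t meaning t is a root), and the tree's score is the sum of `scores[t][s]` over all nodes. The plain-Python sketch below uses the same [t][s] layout as one slice of `scores` and one row of `argmax_sources`; the helper name `tree_score` is hypothetical and not part of the library.

```python
def tree_score(scores, sources):
    """Sum the scores of the arcs and roots that make up one tree.

    scores[t][s] is the score of the arc s -> t when s != t, and the score
    of selecting t as a root when s == t. sources[t] is the source of the
    arc into t, or t itself if t is a root; -1 marks a padded node, which
    contributes nothing.
    """
    return sum(scores[t][s] for t, s in enumerate(sources) if s >= 0)


scores = [[5, 1, 2], [3, 0, 4], [1, 6, 0]]
# Node 0 is the root (5), 0 -> 1 scores 3, and 1 -> 2 scores 6.
print(tree_score(scores, [0, 0, 1]))  # 14
```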

Returns the score of the maximum spanning tree of each digraph, as well as the arcs and roots in that tree. Each digraph in a batch may contain a different number of nodes, so the sizes of the digraphs must be provided as an input.
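As a way to pin down these semantics (not how the op is implemented), the sketch below computes the same two outputs for a small batch by exhaustive search: every node picks a source, candidates that are not a single tree (or, with the `forest` flag described under Args, not a forest) are discarded, and the highest-scoring candidate wins. It is exponential in the number of nodes, so it is only useful for checking tiny examples. The function names are hypothetical, ties are broken arbitrarily, and it assumes each digraph has at least one node and at least one finite-scoring spanning tree.

```python
import itertools
import math


def _is_spanning(sources, forest):
    """True if `sources` encodes a spanning tree (or forest if `forest`)."""
    n = len(sources)
    num_roots = sum(1 for t in range(n) if sources[t] == t)
    if num_roots == 0 or (not forest and num_roots != 1):
        return False
    # Walk source pointers from every node; in a tree or forest each walk
    # reaches a root within n steps, otherwise the pointers form a cycle.
    for t in range(n):
        node = t
        for _ in range(n):
            if sources[node] == node:
                break
            node = sources[node]
        else:
            return False
    return True


def brute_force_mst(num_nodes, scores, forest=False):
    """Exhaustive reference over a batch of nested-list score matrices.

    scores[b][t][s] scores the arc s -> t (s != t) or root t (s == t).
    Returns (max_scores, argmax_sources), padding each row of
    argmax_sources with -1 past num_nodes[b].
    """
    m = len(scores[0])
    max_scores, argmax_sources = [], []
    for b, n in enumerate(num_nodes):
        best_score, best_sources = -math.inf, None
        # Node t picks a source in [0, n); picking itself makes t a root.
        for sources in itertools.product(range(n), repeat=n):
            if not _is_spanning(sources, forest):
                continue
            terms = [scores[b][t][s] for t, s in enumerate(sources)]
            # Non-finite scores mark nonexistent arcs or roots.
            if not all(math.isfinite(x) for x in terms):
                continue
            if sum(terms) > best_score:
                best_score, best_sources = sum(terms), list(sources)
        max_scores.append(best_score)
        argmax_sources.append(best_sources + [-1] * (m - n))
    return max_scores, argmax_sources
```

For example, with scores `[[5, 1], [2, 4]]` the best single tree roots node 0 and adds the arc 0 -> 1 (score 7), while with `forest=True` making both nodes roots scores higher (9).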

Note that this operation is only differentiable with respect to its `scores` input and its `max_scores` output.
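Why this is the only differentiable path: `max_scores[b]` is the sum of the entries of `scores[b]` selected by the winning tree, so, while that tree stays the unique maximizer, the derivative of `max_scores[b]` with respect to `scores[b,t,s]` is 1 if s is the chosen source of t and 0 otherwise, whereas `argmax_sources` is an integer output with no useful gradient. The plain-Python sketch below spells out that math from `argmax_sources` and an upstream gradient; `max_score_gradient` is a hypothetical name, and this illustrates the derivative rather than reproducing the op's registered gradient code.

```python
def max_score_gradient(argmax_sources, upstream):
    """Gradient of a loss w.r.t. `scores`, given d(loss)/d(max_scores).

    argmax_sources: [B][M] nested list, with -1 for padded nodes.
    upstream: [B] list holding d(loss)/d(max_scores[b]).
    Returns a [B][M][M] nested list.
    """
    grads = []
    for sources, g in zip(argmax_sources, upstream):
        m = len(sources)
        # max_scores[b] == sum_t scores[b][t][sources[t]], so each selected
        # entry gets the upstream gradient and every other entry gets 0.
        # Padded rows (source -1) never match and stay all zero.
        grads.append([[g if s == sources[t] else 0.0 for s in range(m)]
                      for t in range(m)])
    return grads
```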

The code here is intended for NLP applications, but attempts to remain agnostic to particular NLP tasks (such as dependency parsing).

Args

`num_nodes` A `Tensor` of type `int32`. [B] vector where entry b is the number of nodes in the b'th digraph.
`scores` A `Tensor`. Must be one of the following types: `int32`, `float32`, `float64`. [B,M,M] tensor where entry b,t,s is the score of the arc from node s to node t in the b'th digraph if s!=t, or the score of selecting node t as a root in the b'th digraph if s==t. Because this single dense tensor covers the whole batch, M must be >= num_nodes[b] for all b (i.e., for every graph in the batch); entries b,t,s where s or t is >= num_nodes[b] are ignored. Arcs or root selections with non-finite scores are treated as nonexistent.
`forest` An optional `bool`. Defaults to `False`. If `True`, solves for a maximum spanning forest instead of a maximum spanning tree, where a spanning forest is a set of disjoint trees that together span all nodes of the digraph.
`name` A name for the operation (optional).
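Since `scores` must be one dense [B,M,M] tensor, digraphs of different sizes have to be padded to a common M before calling the op. A minimal sketch, assuming each digraph's scores start out as a square nested list (the helper name `pad_scores` is hypothetical): it takes M to be the size of the largest digraph and fills the padding with -inf. The op ignores entries at or beyond num_nodes[b], so the fill value should not matter; -inf is simply a conservative choice for float scores (it cannot be stored in an `int32` tensor, where any filler would do).

```python
import math


def pad_scores(per_graph_scores):
    """Pack square per-graph score matrices into a dense [B][M][M] batch.

    Returns (num_nodes, scores), where M is the largest graph size and
    every padded entry is -inf.
    """
    num_nodes = [len(g) for g in per_graph_scores]
    m = max(num_nodes)
    scores = []
    for g, n in zip(per_graph_scores, num_nodes):
        # Pad each existing row on the right, then append all-padding rows.
        rows = [list(row) + [-math.inf] * (m - n) for row in g]
        rows += [[-math.inf] * m for _ in range(m - n)]
        scores.append(rows)
    return num_nodes, scores
```

The resulting lists could then be turned into tensors, for example with `tf.constant(num_nodes, dtype=tf.int32)` and `tf.constant(scores, dtype=tf.float32)`.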

Returns

A tuple of `Tensor` objects (max_scores, argmax_sources).
`max_scores` A `Tensor`. Has the same type as `scores`. [B] vector where entry b is the score of the maximum spanning tree (or forest, if `forest` is `True`) of the b'th digraph.
`argmax_sources` A `Tensor` of type `int32`. [B,M] matrix where entry b,t is the source of the arc inbound to t in the maximum spanning tree of the b'th digraph, or t if t is a root. Entries b,t where t is >= num_nodes[b] are set to -1. The roots can be found with `tf.equal(tf.map_fn(lambda x: tf.range(tf.size(x)), argmax_sources), argmax_sources)`, which is true exactly at the entries b,t where t is a root.
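The `tf.equal(...)` expression above marks node t as a root when its entry in `argmax_sources` equals its own index t. The same check in plain Python, over nested lists standing in for the tensor (`find_roots` is a hypothetical helper), makes the padding behavior easy to see: padded entries are -1 and can never match a valid index.

```python
def find_roots(argmax_sources):
    """Return, for each digraph, the nodes whose source is themselves."""
    return [[t for t, s in enumerate(row) if s == t] for row in argmax_sources]


# One tree rooted at 0, one tree rooted at 1 with a padded third node.
print(find_roots([[0, 0, 1], [1, 1, -1]]))  # [[0], [1]]
```

In TensorFlow, broadcasting should give the same boolean mask without `tf.map_fn`, e.g. `tf.equal(argmax_sources, tf.range(tf.shape(argmax_sources)[1]))`.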