forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathnode_builder_test.cc
More file actions
59 lines (51 loc) · 2.04 KB
/
node_builder_test.cc
File metadata and controls
59 lines (51 loc) · 2.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include <gtest/gtest.h>
namespace tensorflow {
namespace {
// Op used by the test in this file: declares no inputs and one output "o"
// typed by the list(type) attr "out_types".
REGISTER_OP("Source")
    .Output("o: out_types")
    .Attr("out_types: list(type)");

// Op used by the test in this file: declares one input "i" typed by the
// attr "T".
REGISTER_OP("Sink")
    .Input("i: T")
    .Attr("T: type");
// Builds a "Source" node with two output types and then checks how
// NodeBuilder::Input() / Finalize() react to different input specifications:
// in-range output indices are expected to succeed, while out-of-range
// indices and null input nodes are expected to yield a non-OK status.
TEST(NodeBuilderTest, Simple) {
  RequireDefaultOps();
  Graph graph(OpRegistry::Global());

  // Initialized to nullptr so that the ASSERT_TRUE below reads a
  // well-defined value even if Finalize() fails without writing to its
  // output parameter (EXPECT_OK is non-fatal, so execution continues).
  Node* source_node = nullptr;
  EXPECT_OK(NodeBuilder("source_op", "Source")
                .Attr("out_types", {DT_INT32, DT_STRING})
                .Finalize(&graph, &source_node));
  // Every case below dereferences or forwards source_node, so stop here if
  // it was not created.
  ASSERT_TRUE(source_node != nullptr);

  // Try connecting to each of source_node's outputs.
  EXPECT_OK(NodeBuilder("sink1", "Sink")
                .Input(source_node)
                .Finalize(&graph, nullptr));
  EXPECT_OK(NodeBuilder("sink2", "Sink")
                .Input(source_node, 1)
                .Finalize(&graph, nullptr));

  // Generate an error if the index is out of range.
  // One past the last declared output.
  EXPECT_FALSE(NodeBuilder("sink3", "Sink")
                   .Input(source_node, 2)
                   .Finalize(&graph, nullptr)
                   .ok());
  // Negative index, passed as a separate argument.
  EXPECT_FALSE(NodeBuilder("sink4", "Sink")
                   .Input(source_node, -1)
                   .Finalize(&graph, nullptr)
                   .ok());
  // Same negative index, passed as a braced {node, index} pair.
  EXPECT_FALSE(NodeBuilder("sink5", "Sink")
                   .Input({source_node, -1})
                   .Finalize(&graph, nullptr)
                   .ok());

  // Generate an error if the node is nullptr.  This can happen when using
  // GraphDefBuilder if there was an error creating the input node.
  EXPECT_FALSE(NodeBuilder("sink6", "Sink")
                   .Input(nullptr)
                   .Finalize(&graph, nullptr)
                   .ok());
  // Null node wrapped in an explicitly constructed NodeOut.
  EXPECT_FALSE(NodeBuilder("sink7", "Sink")
                   .Input(NodeBuilder::NodeOut(nullptr, 0))
                   .Finalize(&graph, nullptr)
                   .ok());
}
} // namespace
} // namespace tensorflow