Skip to content

Commit cc521eb

Browse files
Place all the nodes created by the trivial_test_graph_input_yielder
PiperOrigin-RevId: 171045878
1 parent 9b93012 commit cc521eb

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
3939

4040
// x is from the feed.
4141
const int batch_size = tensor_size < 0 ? 1 : tensor_size;
42-
Output x =
43-
RandomNormal(s.WithOpName("x"), {batch_size, 1}, DataType::DT_FLOAT);
42+
Output x = RandomNormal(s.WithOpName("x").WithDevice("/CPU:0"),
43+
{batch_size, 1}, DataType::DT_FLOAT);
4444

4545
// Create stages.
4646
std::vector<Output> last_stage;
@@ -64,16 +64,19 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
6464
}
6565

6666
if (insert_queue) {
67-
FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_FLOAT});
68-
QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, last_stage);
69-
QueueDequeue dequeue(s.WithOpName("dequeue"), queue, {DataType::DT_FLOAT});
70-
QueueClose cancel(s.WithOpName("cancel"), queue,
67+
FIFOQueue queue(s.WithOpName("queue").WithDevice("/CPU:0"),
68+
{DataType::DT_FLOAT});
69+
QueueEnqueue enqueue(s.WithOpName("enqueue").WithDevice("/CPU:0"), queue,
70+
last_stage);
71+
QueueDequeue dequeue(s.WithOpName("dequeue").WithDevice("/CPU:0"), queue,
72+
{DataType::DT_FLOAT});
73+
QueueClose cancel(s.WithOpName("cancel").WithDevice("/CPU:0"), queue,
7174
QueueClose::CancelPendingEnqueues(true));
7275
last_stage = {dequeue[0]};
7376
}
7477

7578
// Create output.
76-
AddN output(s.WithOpName("y"), last_stage);
79+
AddN output(s.WithOpName("y").WithDevice("/CPU:0"), last_stage);
7780

7881
GraphDef def;
7982
TF_CHECK_OK(s.ToGraphDef(&def));

0 commit comments

Comments
 (0)