forked from migueldeicaza/TensorFlowSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathArrayOps.cs
More file actions
72 lines (65 loc) · 2.47 KB
/
ArrayOps.cs
File metadata and controls
72 lines (65 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
//
// ArrayOps: support for manipulating tensors
//
// Authors:
// Stephanus van Staden
//
// This is a port of the Python code in tensorflow
//
//
using System;
namespace TensorFlow
{
	public partial class TFGraph
	{
		/// <summary>
		/// Creates a constant tensor of the given shape with every element set to zero.
		/// </summary>
		/// <param name="shape">Shape of the output tensor.</param>
		/// <param name="dtype">Element type of the output tensor. Default: <see cref="TFDataType.Double"/>.</param>
		/// <param name="operName">Operation name, optional.</param>
		/// <returns>The output of the Const operation holding the zero-filled tensor.</returns>
		public TFOutput Zeros (TFShape shape, TFDataType dtype = TFDataType.Double, string operName = null)
		{
			return Constant (0, shape, dtype, operName);
		}

		/// <summary>
		/// Creates a constant tensor of the given shape with every element set to one.
		/// </summary>
		/// <param name="shape">Shape of the output tensor.</param>
		/// <param name="dtype">Element type of the output tensor. Default: <see cref="TFDataType.Double"/>.</param>
		/// <param name="operName">Operation name, optional.</param>
		/// <returns>The output of the Const operation holding the one-filled tensor.</returns>
		public TFOutput Ones (TFShape shape, TFDataType dtype = TFDataType.Double, string operName = null)
		{
			return Constant (1, shape, dtype, operName);
		}

		/// <summary>
		/// Creates a constant tensor of the given shape with every element set to <paramref name="value"/>.
		/// Used by <see cref="Zeros"/> and <see cref="Ones"/>.
		/// </summary>
		/// <param name="value">Fill value; converted with TFTensor.FetchSimple to the .NET type that corresponds to <paramref name="dtype"/>.</param>
		/// <param name="tfshape">Shape of the tensor. Must have at least one dimension, and every dimension must be between 0 and Int32.MaxValue.</param>
		/// <param name="dtype">Element type of the output tensor. Default: <see cref="TFDataType.Double"/>.</param>
		/// <param name="operName">Operation name, optional.</param>
		/// <returns>The output of the Const operation holding the filled tensor.</returns>
		/// <exception cref="ArgumentNullException"><paramref name="tfshape"/> is null.</exception>
		/// <exception cref="ArgumentException"><paramref name="tfshape"/> has no dimensions.</exception>
		/// <exception cref="ArgumentOutOfRangeException">A dimension is negative or larger than Int32.MaxValue.</exception>
		/// <remarks>
		/// See https://github.com/tensorflow/tensorflow/blob/r1.1/tensorflow/python/framework/constant_op.py
		/// </remarks>
		public TFOutput Constant (object value, TFShape tfshape, TFDataType dtype = TFDataType.Double, string operName = null)
		{
			if (tfshape == null)
				throw new ArgumentNullException ("tfshape");

			// Convert the .NET value to the type matching dtype; the backing array is created with
			// this exact element type, so this converted value is also the one that must be stored.
			object dtvalue = TFTensor.FetchSimple (dtype, value);

			var shape = tfshape.ToArray ();
			// NOTE(review): ToArray may return null for a shape of unknown rank — guarded defensively.
			// A rank-0 shape would otherwise fail inside Array.CreateInstance with an opaque error.
			if (shape == null || shape.Length == 0)
				throw new ArgumentException ("Constant requires a shape with at least one dimension", "tfshape");

			// Array.CreateInstance needs non-negative lengths and .NET arrays are indexed by int,
			// so validate every dimension and build the int[] lengths in the same pass.
			var lengths = new int [shape.Length];
			for (int i = 0; i < shape.Length; i++) {
				if (shape [i] < 0 || shape [i] > Int32.MaxValue)
					throw new ArgumentOutOfRangeException ("tfshape", "Shape dimensions must be between 0 and Int32.MaxValue");
				lengths [i] = (int)shape [i];
			}

			Array data = Array.CreateInstance (dtvalue.GetType (), lengths);

			// Scratch index array handed to TFTensor.Set (presumably its working multi-index
			// while it fills every element of data).
			var idx = new int [shape.Length];
			TFTensor.Set (data, dtype, shape, idx, 0, dtvalue);

			TFTensor tensor_value = new TFTensor (data);
			return Const (tensor_value, tensor_value.TensorType, operName);
		}
	}
}