Skip to content

Commit 02f7d16

Browse files
committed
New bindings in recent tensorflows
1 parent 50e8b23 commit 02f7d16

File tree

3 files changed

+282
-0
lines changed

3 files changed

+282
-0
lines changed

SampleTest/SampleTest.cs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,12 @@ public static void p (string p)
236236
Console.WriteLine (p);
237237
}
238238

239+
#region Samples
240+
//
241+
// Samples to exercise the API usability
242+
//
243+
// From https://github.com/aymericdamien/TensorFlow-Examples
244+
//
239245
void BasicConstantOps ()
240246
{
241247
//
@@ -323,7 +329,53 @@ void BasicMatrix ()
323329

324330
};
325331
}
332+
#if false
333+
void LinearRegression ()
334+
{
335+
Console.WriteLine ("Linear regression");
336+
// Parameters
337+
var learning_rate = 0.01;
338+
var training_epochs = 1000;
339+
var display_step = 50;
340+
341+
// Training data
342+
var train_x = new double [] {
343+
3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
344+
7.042, 10.791, 5.313, 7.997, 5.654, 9.27, 3.1
345+
};
346+
var train_y = new double [] {
347+
1.7,2.76,2.09,3.19,1.694,1.573,3.366,2.596,2.53,1.221,
348+
2.827,3.465,1.65,2.904,2.42,2.94,1.3
349+
};
350+
var n_samples = train_x.Length;
351+
using (var g = new TFGraph ()) {
352+
var s = new TFSession (g);
353+
var rng = new Random ();
354+
// tf Graph Input
355+
356+
var X = g.Placeholder (TFDataType.Float);
357+
var Y = g.Placeholder (TFDataType.Float);
358+
var W = g.Variable (new TFShape (rng.Next ()), TFDataType.Float, operName: "weight");
359+
var b = g.Variable (new TFShape (rng.Next ()), TFDataType.Float, operName: "bias");
360+
361+
var pred = g.Add (g.Mul (X, W), b);
362+
363+
// Struggling with the following:
364+
// The call to g.Pow returns a TFOutput, but g.ReduceSum expects a TFTensor
365+
// Python seems to return operation definitions, and somehow those can be p
366+
//passed as tensors:
367+
// tensorflow/python/framework/op_def_library.py
368+
// (apply_op)
369+
//
370+
//https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/linear_regression.py
371+
var cost = g.Div (g.ReduceSum (g.Pow (g.Sub (pred, Y), g.Const (2))), g.Mul (g.Const (2), g.Const (n_samples)));
326372

373+
374+
375+
}
376+
}
377+
#endif
378+
#endregion
327379
public static void Main (string [] args)
328380
{
329381
Console.WriteLine (Environment.CurrentDirectory);

TensorFlowSharp/OperationsExtras.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using System;
2+
using System.Linq;
3+
24
namespace TensorFlow
35
{
46
public partial class TFGraph
@@ -12,5 +14,46 @@ public TFOutput Const (TFTensor value, string operName = null)
1214
{
1315
return Const (value, value.TensorType, operName);
1416
}
17+
18+
// Returns range(0, rank(x)) if reduction_indices is null
19+
TFOutput ReduceDims (TFTensor input, TFOutput? axis = null)
20+
{
21+
if (axis.HasValue)
22+
return axis.Value;
23+
24+
// Fast path: avoid creating Rank and Range ops if ndims is known.
25+
if (input.NumDims >= 0) {
26+
// The python code distinguishes between tensor and sparsetensor
27+
28+
var array = new int [input.NumDims];
29+
for (int i = 0; i < array.Length; i++)
30+
array [i] = i;
31+
32+
return this.Const (array, TFDataType.Int32);
33+
}
34+
return Range (Const (0), Const (input.NumDims), Const (1));
35+
}
36+
37+
/// <summary>
38+
/// Computes the sum of elements across dimensions of a tensor.
39+
/// </summary>
40+
/// <returns>The reduced tensor.</returns>
41+
/// <param name="input">The tensor to reduce. Should have numeric type.</param>
42+
/// <param name="axis">The dimensions to reduce. If not se (the default), reduces all dimensions.</param>
43+
/// <param name="keep_dims">If set to <c>true</c> retains reduced dimensions with length 1.</param>
44+
/// <param name="operName">A name for the operation, optional.</param>
45+
/// <remarks>
46+
/// Reduces input_tensor along the dimensions given in axis.
47+
/// Unless keep_dims is true, the rank of the tensor is reduced by 1 for each
48+
/// entry in axis. If keep_dims is true, the reduced dimensions
49+
/// are retained with length 1.
50+
///
51+
/// If axis has no entries, all dimensions are reduced, and a
52+
/// tensor with a single element is returned.
53+
/// </remarks>
54+
public TFOutput ReduceSum (TFTensor input, TFOutput? axis = null, bool? keep_dims = false, string operName = null)
55+
{
56+
return Sum (Const (input), this.ReduceDims (input, axis), keep_dims, operName);
57+
}
1558
}
1659
}

TensorFlowSharp/Tensorflow.cs

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,6 +1498,116 @@ string MakeUnique (string name)
14981498
return name + val;
14991499
}
15001500

1501+
[DllImport (NativeBinding.TensorFlowLibrary)]
1502+
unsafe extern static void TF_GraphImportGraphDefWithReturnOutputs (
1503+
TF_Graph graph, LLBuffer *graph_def,
1504+
TF_ImportGraphDefOptions options, TFOutput *return_outputs,
1505+
int num_return_outputs, TF_Status status);
1506+
1507+
/// <summary>
1508+
/// Imports a graph serialized into the graph
1509+
/// </summary>
1510+
/// <param name="graphDef">Serialized graph definition (in protocol buffer format).</param>
1511+
/// <param name="options">Import options.</param>
1512+
/// <param name="returnOutputs">Array large enough to contain all the return options.</param>
1513+
/// <param name="status">Status, optional.</param>
1514+
public void ImportGraphDef (TFBuffer graphDef, TFImportGraphDefOptions options, TFOutput [] returnOutputs, TFStatus status = null)
1515+
{
1516+
if (handle == IntPtr.Zero)
1517+
ObjectDisposedException ();
1518+
if (graphDef == null)
1519+
throw new ArgumentNullException (nameof (graphDef));
1520+
if (options == null)
1521+
throw new ArgumentNullException (nameof (options));
1522+
var cstatus = TFStatus.Setup (status);
1523+
1524+
unsafe
1525+
{
1526+
if (returnOutputs == null) {
1527+
TF_GraphImportGraphDefWithReturnOutputs (handle, graphDef.LLBuffer, options.handle, null, 0, cstatus.handle);
1528+
} else {
1529+
fixed (TFOutput* first = &returnOutputs [0])
1530+
{
1531+
TF_GraphImportGraphDefWithReturnOutputs (handle, graphDef.LLBuffer, options.handle, first, returnOutputs.Length, cstatus.handle);
1532+
}
1533+
}
1534+
}
1535+
}
1536+
1537+
[StructLayout (LayoutKind.Sequential)]
1538+
unsafe struct TFWhileParams
1539+
{
1540+
int ninputs;
1541+
TF_Graph cond_graph;
1542+
TFOutput* cond_inputs;
1543+
TFOutput cond_output;
1544+
TF_Graph* body_graph;
1545+
TFOutput* body_inputs;
1546+
TFOutput* body_outputs;
1547+
IntPtr charPtrName;
1548+
}
1549+
1550+
[DllImport (NativeBinding.TensorFlowLibrary)]
1551+
static extern unsafe TFWhileParams TF_NewWhile (TF_Graph g, TFOutput [] inputs, int ninputs, TF_Status status);
1552+
1553+
[DllImport (NativeBinding.TensorFlowLibrary)]
1554+
static extern void TF_AbortWhile (ref TFWhileParams pars);
1555+
1556+
[DllImport (NativeBinding.TensorFlowLibrary)]
1557+
static extern unsafe void TF_FinishWhile (ref TFWhileParams pars, TF_Status status, TFOutput [] outputs);
1558+
1559+
/// <summary>
1560+
/// Signature of the method that will be invoked by the TFGraph.While method to construct a while loop
1561+
/// </summary>
1562+
/// <remarks>
1563+
/// The method should build up the condition on the conditionGraph and the body of the while
1564+
/// loop in the provided bodyGraph. It should set the condOutput to the value used as the
1565+
/// condition output and the array of values in bodyOutputs to the final outputs as well as the
1566+
/// name to be used, if not set, one will be assigned.
1567+
/// </remarks>
1568+
public delegate void WhileConstructor (TFGraph conditionGraph, TFGraph bodyGraph, TFOutput [] condInputs, TFOutput [] bodyInputs, ref TFOutput condOutput, ref TFOutput [] bodyOutputs, ref string name);
1569+
1570+
/// <summary>
1571+
/// Constructs a while loop with the specified inputs and a callback that composes the while loop
1572+
/// </summary>
1573+
/// <param name="inputs">Inputs.</param>
1574+
/// <param name="constructor">Callback method that fills out the various while loop parameters.</param>
1575+
/// <returns>
1576+
/// true on success, or false if it was not possible to create the while loop.
1577+
/// </returns>
1578+
public TFOutput [] While (TFOutput [] inputs, WhileConstructor constructor, TFStatus status = null)
1579+
{
1580+
if (handle == IntPtr.Zero)
1581+
ObjectDisposedException ();
1582+
if (inputs == null)
1583+
throw new ArgumentNullException (nameof (inputs));
1584+
if (constructor == null)
1585+
throw new ArgumentNullException (nameof (constructor));
1586+
var s = TFStatus.Setup (status);
1587+
var result = TF_NewWhile (handle, inputs, inputs.Length, s.handle);
1588+
if (s.Error)
1589+
return null;
1590+
try {
1591+
//
1592+
// Call constructor here
1593+
// Wrap the various TF_graphs (with owns=false)
1594+
// Marshal the condInputs, bodyInputs
1595+
//
1596+
// TODO:
1597+
throw new NotImplementedException ();
1598+
1599+
// On return, copy the condOutput and bodyOututs
1600+
// Set the name
1601+
var ret = new TFOutput [inputs.Length];
1602+
TF_FinishWhile (ref result, s.handle, ret);
1603+
return ret;
1604+
} catch {
1605+
TF_AbortWhile (ref result);
1606+
return null;
1607+
}
1608+
}
1609+
1610+
15011611
}
15021612

15031613
/// <summary>
@@ -2213,6 +2323,81 @@ public void SetPrefix (string prefix)
22132323
TF_ImportGraphDefOptionsSetPrefix (handle, prefix);
22142324
}
22152325

2326+
// extern void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions *opts, const char* src_name, int src_index, TF_Output dst);
2327+
[DllImport (NativeBinding.TensorFlowLibrary)]
2328+
static extern unsafe void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions opts, string src_name, int src_index, TFOutput dst);
2329+
2330+
2331+
/// <summary>
2332+
/// Adds an input mapping from a source name and index to a destination output
2333+
/// </summary>
2334+
/// <param name="srcName">Source name.</param>
2335+
/// <param name="srcIndex">Source index (in the source).</param>
2336+
/// <param name="dst">Replacement value for the srcName:srcIndex.</param>
2337+
/// <remarks>
2338+
/// Set any imported nodes with input `src_name:src_index` to have that input
2339+
/// replaced with `dst`. `src_name` refers to a node in the graph to be imported,
2340+
/// `dst` references a node already existing in the graph being imported into.
2341+
/// </remarks>
2342+
public void AddInputMapping (string srcName, int srcIndex, TFOutput dst)
2343+
{
2344+
if (handle == IntPtr.Zero)
2345+
ObjectDisposedException ();
2346+
TF_ImportGraphDefOptionsAddInputMapping (handle, srcName, srcIndex, dst);
2347+
}
2348+
2349+
[DllImport (NativeBinding.TensorFlowLibrary)]
2350+
extern static void TF_ImportGraphDefOptionsAddControlDependency (TF_ImportGraphDefOptions opts, TF_Operation oper);
2351+
2352+
/// <summary>
2353+
/// Cause the imported graph to have a control dependency on the provided operation.
2354+
/// </summary>
2355+
/// <param name="operation">This operation should exist in the graph being imported to.</param>
2356+
public void AddControlDependency (TFOperation operation)
2357+
{
2358+
if (operation == null)
2359+
throw new ArgumentNullException (nameof (operation));
2360+
if (handle == IntPtr.Zero)
2361+
ObjectDisposedException ();
2362+
2363+
TF_ImportGraphDefOptionsAddControlDependency (handle, operation.handle);
2364+
}
2365+
2366+
[DllImport (NativeBinding.TensorFlowLibrary)]
2367+
extern static void TF_ImportGraphDefOptionsAddReturnOutput (TF_ImportGraphDefOptions opts, string oper_name, int index);
2368+
2369+
/// <summary>
2370+
/// Add an output in the graph definition to be returned via the return outputs parameter.
2371+
/// </summary>
2372+
/// <param name="operName">Operation name.</param>
2373+
/// <param name="index">Operation index.</param>
2374+
/// <remarks>
2375+
/// If the output is remapped via an input
2376+
/// mapping, the corresponding existing tensor in graph will be returned.
2377+
/// </remarks>
2378+
public void AddReturnOutput (string operName, int index)
2379+
{
2380+
if (operName == null)
2381+
throw new ArgumentNullException (nameof (operName));
2382+
if (handle == IntPtr.Zero)
2383+
ObjectDisposedException ();
2384+
TF_ImportGraphDefOptionsAddReturnOutput (handle, operName, index);
2385+
}
2386+
2387+
[DllImport (NativeBinding.TensorFlowLibrary)]
2388+
extern static int TF_ImportGraphDefOptionsNumReturnOutputs (TF_ImportGraphDefOptions opts);
2389+
2390+
/// <summary>
2391+
/// Gets the number return outputs added via AddReturnOutput.
2392+
/// </summary>
2393+
/// <value>The number return outputs.</value>
2394+
public int NumReturnOutputs {
2395+
get {
2396+
if (handle == IntPtr.Zero)
2397+
ObjectDisposedException ();
2398+
return TF_ImportGraphDefOptionsNumReturnOutputs (handle);
2399+
}
2400+
}
22162401

22172402
}
22182403

@@ -2854,4 +3039,6 @@ public override string ToString ()
28543039
}
28553040
}
28563041

3042+
3043+
28573044
}

0 commit comments

Comments
 (0)