forked from migueldeicaza/TensorFlowSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTensorflow.cs
More file actions
2330 lines (1996 loc) · 84.4 KB
/
Tensorflow.cs
File metadata and controls
2330 lines (1996 loc) · 84.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//
// TensorFlow.cs; Bindings to the TensorFlow C API for .NET
//
// Authors:
// Miguel de Icaza ([email protected])
//
// Strongly typed API
// Most methods take an optional TFStatus that defaults to null. If it is null and the operation fails, an exception is raised; otherwise the error is reported through the supplied TFStatus.
// TFStatus.Default can be passed when you do not want to create a status yourself and are fine reusing a shared instance.
//
// Guidance on doing language bindings for TensorFlow:
// https://www.tensorflow.org/versions/r0.11/how_tos/language_bindings/
//
//
using System;
using System.Runtime.InteropServices;
using System.Text;
using System.Globalization;
using System.Linq;
// We use this TF_Xxx as the native "TF_Xxx *" as those are opaque
using TF_Status = System.IntPtr;
using TF_SessionOptions = System.IntPtr;
using TF_Graph = System.IntPtr;
using TF_OperationDescription = System.IntPtr;
using TF_Operation = System.IntPtr;
using TF_Session = System.IntPtr;
using TF_DeprecatedSession = System.IntPtr;
using TF_Tensor = System.IntPtr;
using TF_ImportGraphDefOptions = System.IntPtr;
using TF_Library = System.IntPtr;
using TF_BufferPtr = System.IntPtr;
using size_t = System.UIntPtr;
using System.Numerics;
using System.Collections.Generic;
using System.Linq.Expressions;
namespace TensorFlow
{
/// <summary>
/// Shared constants and helpers for the P/Invoke layer over the TensorFlow C library.
/// </summary>
static partial class NativeBinding
{
	/// <summary>
	/// Name of the native library that every DllImport in this file binds against.
	/// </summary>
	public const string TensorFlowLibrary = "libtensorflow";

	/// <summary>
	/// Copies a native NUL-terminated ANSI string into a managed string.
	/// </summary>
	/// <param name="x">Pointer to the native characters.</param>
	/// <returns>The managed copy, or <c>null</c> when <paramref name="x"/> is <see cref="IntPtr.Zero"/>.</returns>
	internal static string GetStr (this IntPtr x)
	{
		string managed = Marshal.PtrToStringAnsi (x);
		return managed;
	}
}
/// <summary>
/// Library-wide entry points of the TensorFlow C API that are not tied to a particular object.
/// </summary>
public static class TFCore {
	// extern const char * TF_Version ();
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe IntPtr TF_Version ();

	/// <summary>
	/// Version string reported by the loaded native TensorFlow library.
	/// </summary>
	public static string Version {
		get {
			IntPtr raw = TF_Version ();
			return raw.GetStr ();
		}
	}

	// extern size_t TF_DataTypeSize (TF_DataType dt);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern IntPtr TF_DataTypeSize (TFDataType dt);

	/// <summary>
	/// Returns the size reported by the native <c>TF_DataTypeSize</c> for <paramref name="dt"/>.
	/// </summary>
	/// <param name="dt">The data type to query.</param>
	public static long GetDataTypeSize (TFDataType dt)
	{
		IntPtr size = TF_DataTypeSize (dt);
		return (long) size;
	}

	// extern TF_Buffer * TF_GetAllOpList ();
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe IntPtr TF_GetAllOpList ();

	/// <summary>
	/// Wraps the buffer returned by the native <c>TF_GetAllOpList</c> in a <see cref="TFBuffer"/>.
	/// </summary>
	public static TFBuffer GetAllOpList () => new TFBuffer (TF_GetAllOpList ());
}
/// <summary>
/// Base class for managed wrappers that hold a handle to a native TensorFlow object.
/// </summary>
/// <remarks>
/// Subclasses release the native object in <see cref="NativeDispose"/>. This class calls it at
/// most once, and only when the handle is non-zero, and then clears <see cref="Handle"/>.
///
/// NOTE(review): the finalizer calls <c>Dispose (false)</c>, which releases nothing, so a wrapper
/// that is never disposed keeps its native object alive. Confirm this is intended before changing it,
/// since releasing from the finalizer thread would run in a nondeterministic order.
/// </remarks>
public abstract class TFDisposable : IDisposable
{
	internal IntPtr handle;

	/// <summary>
	/// The native handle, or <see cref="IntPtr.Zero"/> after the object has been disposed.
	/// </summary>
	public IntPtr Handle => handle;

	public TFDisposable ()
	{ }

	public TFDisposable (IntPtr handle)
	{
		this.handle = handle;
	}

	/// <summary>
	/// Releases the native object and suppresses finalization.
	/// </summary>
	public void Dispose ()
	{
		Dispose (true);
		GC.SuppressFinalize (this);
	}

	~TFDisposable ()
	{
		Dispose (false);
	}

	// Must be implemented in subclasses to dispose the unmanaged object, it does
	// not need to take care of zeroing out the handle, that is done by the Dispose
	// method inherited from TFDisposable
	internal abstract void NativeDispose (IntPtr handle);

	/// <summary>
	/// Releases the native object when called from <see cref="Dispose()"/>.
	/// </summary>
	/// <param name="disposing">
	/// <c>true</c> when called from <see cref="Dispose()"/>; <c>false</c> when called from the
	/// finalizer, in which case nothing is released.
	/// </param>
	public virtual void Dispose (bool disposing)
	{
		if (disposing) {
			if (handle != IntPtr.Zero)
				NativeDispose (handle);
			handle = IntPtr.Zero;
		}
	}

	/// <summary>
	/// Throws a <see cref="System.ObjectDisposedException"/>. Subclasses call this when their handle has already been released.
	/// </summary>
	internal static void ObjectDisposedException ()
	{
		// The single-string constructor treats its argument as the *object name*, which produced
		// "Cannot access a disposed object. Object name: 'The object was disposed'.".
		// Pass the text as the message instead.
		throw new ObjectDisposedException (null, "The object was disposed");
	}
}
/// <summary>
/// Exception raised when a TensorFlow call reports a status other than OK.
/// </summary>
public class TFException : Exception
{
	/// <summary>
	/// Creates the exception with the given error text.
	/// </summary>
	/// <param name="message">The error text, typically the native status message.</param>
	public TFException (string message)
		: base (message)
	{
	}
}
/// <summary>
/// Used to track the result of TensorFlow operations.
/// </summary>
/// <remarks>
/// TFStatus is used to track the status of a call to some TensorFlow
/// operations. Instances of this object are passed to various
/// TensorFlow operations and you can use the <see cref="P:TensorFlow.TFStatus.Ok"/>
/// to quickly check if the operation succeeded, or get more detail from the
/// <see cref="P:TensorFlow.TFStatus.StatusCode"/> and a human-readable text
/// using the <see cref="P:TensorFlow.TFStatus.StatusMessage"/> property.
///
/// The convenience <see cref="M:TensorFlow.TFStatus.Raise"/> can be used
/// to raise a <see cref="T:TensorFlow.TFException"/> if the status of the
/// operation did not succeed.
/// </remarks>
public class TFStatus : TFDisposable
{
	// extern TF_Status * TF_NewStatus ();
	[DllImport (NativeBinding.TensorFlowLibrary)]
	internal static extern unsafe TF_Status TF_NewStatus ();

	/// <summary>
	/// A shared status instance that callers may reuse instead of allocating their own.
	/// </summary>
	/// <remarks>
	/// NOTE(review): the initializer of a [ThreadStatic] field runs only once, on the thread that
	/// runs the type initializer. On every other thread this field is <c>null</c> until something
	/// assigns it.
	/// </remarks>
	[ThreadStatic] public static TFStatus Default = new TFStatus ();

	/// <summary>
	/// Initializes a new instance of the <see cref="T:TensorFlow.TFStatus"/> class.
	/// </summary>
	public TFStatus () : base (TF_NewStatus ())
	{
	}

	// extern void TF_DeleteStatus (TF_Status *);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	internal static extern unsafe void TF_DeleteStatus (TF_Status status);

	internal override void NativeDispose (IntPtr handle)
	{
		TF_DeleteStatus (handle);
	}

	// extern void TF_SetStatus (TF_Status *s, TF_Code code, const char *msg);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_SetStatus (TF_Status s, TFCode code, string msg);

	/// <summary>
	/// Sets the status code on this TFStatus.
	/// </summary>
	/// <param name="code">Code.</param>
	/// <param name="msg">Message.</param>
	/// <exception cref="T:System.ObjectDisposedException">The status has been disposed.</exception>
	public void SetStatusCode (TFCode code, string msg)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		TF_SetStatus (handle, code, msg);
	}

	// extern TF_Code TF_GetCode (const TF_Status *s);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	internal static extern unsafe TFCode TF_GetCode (TF_Status s);

	/// <summary>
	/// Gets the status code.
	/// </summary>
	/// <value>The status code as an enumeration.</value>
	/// <exception cref="T:System.ObjectDisposedException">The status has been disposed.</exception>
	public TFCode StatusCode {
		get {
			// Guard: passing a zero handle to the native accessor would dereference NULL.
			if (handle == IntPtr.Zero)
				ObjectDisposedException ();
			return TF_GetCode (handle);
		}
	}

	// extern const char * TF_Message (const TF_Status *s);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe IntPtr TF_Message (TF_Status s);

	/// <summary>
	/// Gets a human-readable status message.
	/// </summary>
	/// <value>The status message.</value>
	/// <exception cref="T:System.ObjectDisposedException">The status has been disposed.</exception>
	public string StatusMessage {
		get {
			if (handle == IntPtr.Zero)
				ObjectDisposedException ();
			return TF_Message (handle).GetStr ();
		}
	}

	/// <summary>
	/// Returns a <see cref="T:System.String"/> that represents the current <see cref="T:TensorFlow.TFStatus"/>.
	/// </summary>
	/// <returns>A <see cref="T:System.String"/> that represents the current <see cref="T:TensorFlow.TFStatus"/>.</returns>
	public override string ToString ()
	{
		// ToString must not throw (debuggers call it), so describe a disposed status instead.
		if (handle == IntPtr.Zero)
			return "[TFStatus: disposed]";
		return string.Format ("[TFStatus: StatusCode={0}, StatusMessage={1}]", StatusCode, StatusMessage);
	}

	/// <summary>
	/// Gets a value indicating whether this <see cref="T:TensorFlow.TFStatus"/> state has been set to ok.
	/// </summary>
	/// <value><c>true</c> if ok; otherwise, <c>false</c>.</value>
	public bool Ok => StatusCode == TFCode.Ok;

	/// <summary>
	/// Gets a value indicating whether this <see cref="T:TensorFlow.TFStatus"/> state has been set to an error.
	/// </summary>
	/// <value><c>true</c> if error; otherwise, <c>false</c>.</value>
	public bool Error => StatusCode != TFCode.Ok;

	/// <summary>
	/// Convenience method that raises an exception if the current status is an error.
	/// </summary>
	/// <remarks>
	/// You can use this method as a convenience to raise an exception after you
	/// invoke an operation if the operation did not succeed.
	/// </remarks>
	/// <exception cref="T:TensorFlow.TFException">The status code is not <c>Ok</c>.</exception>
	public void Raise ()
	{
		if (StatusCode != TFCode.Ok)
			throw new TFException (StatusMessage);
	}

	//
	// Utility function used to simplify implementing the idiom
	// where the user optionally provides a TFStatus, if it is provided,
	// the error is returned there; If it is not provided, then an
	// exception is raised.
	//
	// When incomingStatus is null, "this" is a temporary created by Setup and is owned here:
	// it is disposed before raising, and also on success when "last" is true.
	//
	internal bool CheckMaybeRaise (TFStatus incomingStatus, bool last = true)
	{
		if (incomingStatus != null)
			return StatusCode == TFCode.Ok;

		// StatusCode throws ObjectDisposedException if this temporary was already released.
		if (StatusCode != TFCode.Ok) {
			var e = new TFException (StatusMessage);
			Dispose ();
			throw e;
		}
		if (last)
			Dispose ();
		return true;
	}

	// Returns the caller's status, or a fresh temporary one to be released by CheckMaybeRaise.
	internal static TFStatus Setup (TFStatus incoming)
	{
		return incoming ?? new TFStatus ();
	}
}
/// <summary>
/// Raw bindings for the string encode/decode helpers of the TensorFlow C API.
/// </summary>
internal class TFString
{
	// extern size_t TF_StringEncodedSize (size_t len);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	internal static extern size_t TF_StringEncodedSize (size_t len);

	// extern size_t TF_StringEncode (const char *src, size_t src_len, char *dst, size_t dst_len, TF_Status *status);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	internal static unsafe extern size_t TF_StringEncode (byte* src, size_t src_len, sbyte* dst, size_t dst_len, TF_Status status);

	// extern size_t TF_StringDecode (const char *src, size_t src_len, const char **dst, size_t *dst_len, TF_Status *status);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	internal static unsafe extern size_t TF_StringDecode (sbyte* src, size_t src_len, sbyte** dst, size_t* dst_len, TF_Status status);
}
/// <summary>
/// Wraps a native TF_SessionOptions object used to configure a session.
/// </summary>
public class TFSessionOptions : TFDisposable
{
	// extern TF_SessionOptions * TF_NewSessionOptions ();
	[DllImport (NativeBinding.TensorFlowLibrary)]
	internal static extern unsafe TF_SessionOptions TF_NewSessionOptions ();

	/// <summary>
	/// Creates a new native session options object.
	/// </summary>
	public TFSessionOptions () : base (TF_NewSessionOptions ()) { }

	// extern void TF_DeleteSessionOptions (TF_SessionOptions *);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	internal static extern unsafe void TF_DeleteSessionOptions (TF_SessionOptions options);

	internal override void NativeDispose (IntPtr handle)
	{
		TF_DeleteSessionOptions (handle);
	}

	// extern void TF_SetTarget (TF_SessionOptions *options, const char *target);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_SetTarget (TF_SessionOptions options, string target);

	/// <summary>
	/// Sets the target string passed to the native <c>TF_SetTarget</c>.
	/// </summary>
	/// <param name="target">The target; must not be null.</param>
	/// <exception cref="T:System.ArgumentNullException"><paramref name="target"/> is null.</exception>
	/// <exception cref="T:System.ObjectDisposedException">The options have been disposed.</exception>
	public void SetTarget (string target)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		// The native side takes a const char*; a null string would be marshalled as a NULL pointer.
		if (target == null)
			throw new ArgumentNullException (nameof (target));
		TF_SetTarget (handle, target);
	}

	// extern void TF_SetConfig (TF_SessionOptions *options, const void *proto, size_t proto_len, TF_Status *status);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_SetConfig (TF_SessionOptions options, IntPtr proto, size_t proto_len, TF_Status status);

	/// <summary>
	/// Passes a serialized configuration buffer to the native <c>TF_SetConfig</c>.
	/// </summary>
	/// <param name="protoData">Pointer to the serialized bytes; may be zero only when <paramref name="length"/> is 0.</param>
	/// <param name="length">Number of bytes at <paramref name="protoData"/>; must not be negative.</param>
	/// <param name="status">Optional status; when null, an error raises a <see cref="T:TensorFlow.TFException"/>.</param>
	/// <exception cref="T:System.ArgumentOutOfRangeException"><paramref name="length"/> is negative.</exception>
	/// <exception cref="T:System.ArgumentNullException"><paramref name="protoData"/> is zero while <paramref name="length"/> is positive.</exception>
	public void SetConfig (IntPtr protoData, int length, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		// A negative int is not a valid size_t; the cast below would not preserve it.
		if (length < 0)
			throw new ArgumentOutOfRangeException (nameof (length), "length must not be negative");
		if (protoData == IntPtr.Zero && length > 0)
			throw new ArgumentNullException (nameof (protoData));
		var cstatus = TFStatus.Setup (status);
		TF_SetConfig (handle, protoData, (UIntPtr)length, cstatus.handle);
		cstatus.CheckMaybeRaise (status);
	}
}
/// <summary>
/// Represents a computation graph. Graphs may be shared between sessions and are thread safe.
/// </summary>
/// <remarks>
/// Graphs consist of operations (represented by TFOperation objects), these can be named, or
/// the runtime will automatically assign a name.
///
/// For debugging purposes, you might want to group operations together, for this, call the
/// WithScope method with your new scope, which will create a new namespace for your object names.
///
/// For example, if you call WithScope ("demo"), and add an operation named "add" inside the
/// scope, the full name of the operation will be "demo/add", if you create a new scope inside, say
/// "hot", and add a "sub" operation there the result will be "demo/hot/sub".
///
/// Methods that take an optional <see cref="T:TensorFlow.TFStatus"/> raise a
/// <see cref="T:TensorFlow.TFException"/> on failure when no status is supplied; otherwise the
/// error is left in the supplied status.
/// </remarks>
public partial class TFGraph : TFDisposable
{
	// extern TF_Graph * TF_NewGraph ();
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe TF_Graph TF_NewGraph ();

	/// <summary>
	/// Initializes a new instance of the <see cref="T:TensorFlow.TFGraph"/> class.
	/// </summary>
	public TFGraph () : base (TF_NewGraph ())
	{
	}

	internal TFGraph (IntPtr handle) : base (handle)
	{
	}

	// extern void TF_DeleteGraph (TF_Graph *);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_DeleteGraph (TF_Graph graph);

	internal override void NativeDispose (IntPtr handle)
	{
		TF_DeleteGraph (handle);
	}

	// extern void TF_GraphSetTensorShape (TF_Graph *graph, TF_Output output, const int64_t *dims, const int num_dims, TF_Status *status);
	// dims must be a plain long[] so it marshals as a pinned int64_t*; "ref long []" would pass int64_t**.
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_GraphSetTensorShape (TF_Graph graph, TFOutput output, long [] dims, int num_dims, TF_Status status);

	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_GraphSetTensorShape (TF_Graph graph, TFOutput output, IntPtr dims, int num_dims, TF_Status status);

	/// <summary>
	/// Sets the shape of the tensor referenced by <paramref name="output"/>.
	/// </summary>
	/// <param name="output">The tensor whose shape is set.</param>
	/// <param name="dims">The dimensions; when null, a NULL pointer with zero dimensions is passed.</param>
	/// <param name="status">Optional status.</param>
	public void SetTensorShape (TFOutput output, long [] dims, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		var cstatus = TFStatus.Setup (status);
		if (dims == null)
			TF_GraphSetTensorShape (handle, output, IntPtr.Zero, 0, cstatus.handle);
		else
			TF_GraphSetTensorShape (handle, output, dims, dims.Length, cstatus.handle);
		cstatus.CheckMaybeRaise (status);
	}

	// extern int TF_GraphGetTensorNumDims (TF_Graph *graph, TF_Output output, TF_Status *status);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe int TF_GraphGetTensorNumDims (TF_Graph graph, TFOutput output, TF_Status status);

	/// <summary>
	/// Returns the number of dimensions the native library reports for <paramref name="output"/>.
	/// </summary>
	/// <param name="output">The tensor to probe.</param>
	/// <param name="status">Optional status.</param>
	public int GetTensorNumDims (TFOutput output, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		var cstatus = TFStatus.Setup (status);
		var code = TF_GraphGetTensorNumDims (handle, output, cstatus.handle);
		cstatus.CheckMaybeRaise (status);
		return code;
	}

	// extern void TF_GraphGetTensorShape (TF_Graph *graph, TF_Output output, int64_t *dims, int num_dims, TF_Status *status);
	// dims is a plain long[] (pinned int64_t*) that the native side fills in.
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_GraphGetTensorShape (TF_Graph graph, TFOutput output, [Out] long [] dims, int num_dims, TF_Status status);

	/// <summary>
	/// Returns the dimensions of the tensor referenced by <paramref name="output"/>.
	/// </summary>
	/// <param name="output">The tensor to probe.</param>
	/// <param name="status">Optional status.</param>
	/// <returns>
	/// The dimensions (an empty array for rank 0). Returns null in two cases: the rank is reported as
	/// negative (the C API uses -1 for an unknown rank), or an error is reported through a
	/// caller-supplied <paramref name="status"/>.
	/// </returns>
	public long [] GetTensorShape (TFOutput output, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		var cstatus = TFStatus.Setup (status);
		var n = TF_GraphGetTensorNumDims (handle, output, cstatus.handle);
		if (!cstatus.CheckMaybeRaise (status, last: false))
			return null;
		if (n < 0) {
			// Unknown rank: nothing to fetch. Still release the temporary status if we own it.
			cstatus.CheckMaybeRaise (status);
			return null;
		}
		var dims = new long [n];
		TF_GraphGetTensorShape (handle, output, dims, dims.Length, cstatus.handle);
		cstatus.CheckMaybeRaise (status);
		return dims;
	}

	// extern void TF_GraphToGraphDef (TF_Graph *graph, TF_Buffer *output_graph_def, TF_Status *status);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_GraphToGraphDef (TF_Graph graph, LLBuffer* output_graph_def, TF_Status status);

	/// <summary>
	/// Serializes the graph into <paramref name="outputGraphDef"/> via the native <c>TF_GraphToGraphDef</c>.
	/// </summary>
	public void ToGraphDef (TFBuffer outputGraphDef, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		if (outputGraphDef == null)
			throw new ArgumentNullException (nameof (outputGraphDef));
		var cstatus = TFStatus.Setup (status);
		unsafe
		{
			TF_GraphToGraphDef (handle, outputGraphDef.LLBuffer, cstatus.handle);
		}
		cstatus.CheckMaybeRaise (status);
	}

	// extern void TF_GraphImportGraphDef (TF_Graph *graph, const TF_Buffer *graph_def, const TF_ImportGraphDefOptions *options, TF_Status *status);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_GraphImportGraphDef (TF_Graph graph, LLBuffer* graph_def, TF_ImportGraphDefOptions options, TF_Status status);

	/// <summary>
	/// Imports a serialized graph definition, prefixing imported names with <paramref name="prefix"/>.
	/// </summary>
	public void Import (TFBuffer graphDef, string prefix = "", TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		if (graphDef == null)
			throw new ArgumentNullException (nameof (graphDef));
		if (prefix == null)
			throw new ArgumentNullException (nameof (prefix));
		using (var options = new TFImportGraphDefOptions ()) {
			options.SetPrefix (prefix);
			Import (graphDef, options, status);
		}
	}

	/// <summary>
	/// Imports a serialized graph definition using the given import options.
	/// </summary>
	public void Import (TFBuffer graphDef, TFImportGraphDefOptions options, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		if (graphDef == null)
			throw new ArgumentNullException (nameof (graphDef));
		if (options == null)
			throw new ArgumentNullException (nameof (options));
		var cstatus = TFStatus.Setup (status);
		unsafe
		{
			TF_GraphImportGraphDef (handle, graphDef.LLBuffer, options.handle, cstatus.handle);
		}
		cstatus.CheckMaybeRaise (status);
	}

	/// <summary>
	/// Imports a serialized graph definition held in a byte array, prefixing imported names with <paramref name="prefix"/>.
	/// </summary>
	public void Import (byte [] buffer, string prefix = "", TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		if (buffer == null)
			throw new ArgumentNullException (nameof (buffer));
		if (prefix == null)
			throw new ArgumentNullException (nameof (prefix));
		using (var options = new TFImportGraphDefOptions ()) {
			options.SetPrefix (prefix);
			Import (buffer, options, status);
		}
	}

	/// <summary>
	/// Imports a serialized graph definition held in a byte array using the given import options.
	/// </summary>
	public void Import (byte [] buffer, TFImportGraphDefOptions options, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		if (buffer == null)
			throw new ArgumentNullException (nameof (buffer));
		if (options == null)
			throw new ArgumentNullException (nameof (options));
		// The TFBuffer overload does its own status setup and raising; a second temporary
		// status created here would never be released.
		using (var tb = new TFBuffer (buffer, 0, buffer.Length))
			Import (tb, options, status);
	}

	// extern TF_Operation * TF_GraphOperationByName (TF_Graph *graph, const char *oper_name);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe TF_Operation TF_GraphOperationByName (TF_Graph graph, string oper_name);

	/// <summary>
	/// Looks up an operation by name, returning null if the graph has no such operation.
	/// </summary>
	/// <exception cref="T:System.ArgumentNullException"><paramref name="name"/> is null.</exception>
	public TFOperation this [string name] {
		get {
			if (handle == IntPtr.Zero)
				ObjectDisposedException ();
			// The native lookup takes a const char*; do not hand it a NULL pointer.
			if (name == null)
				throw new ArgumentNullException (nameof (name));
			var h = TF_GraphOperationByName (handle, name);
			return h == IntPtr.Zero ? null : new TFOperation (this, h);
		}
	}

	// extern TF_Operation * TF_GraphNextOperation (TF_Graph *graph, size_t *pos);
	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe TF_Operation TF_GraphNextOperation (TF_Graph graph, ref IntPtr token);

	/// <summary>
	/// Lazily enumerates the operations of the graph, stopping when the native iterator returns NULL.
	/// </summary>
	/// <remarks>
	/// This is an iterator, so the disposed check runs when enumeration starts, not when this method is called.
	/// </remarks>
	public IEnumerable<TFOperation> GetEnumerator ()
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		IntPtr token = IntPtr.Zero;
		IntPtr operll;
		while ((operll = TF_GraphNextOperation (handle, ref token)) != IntPtr.Zero)
			yield return new TFOperation (this, operll);
	}

	/// <summary>
	/// Returns the tensor shape for the specific output parameters as an array of longs.
	/// </summary>
	/// <returns>
	/// The dimensions. Returns null in three cases: the tensor has rank 0, the rank is reported as
	/// negative (unknown), or an error is reported through a caller-supplied <paramref name="status"/>.
	/// </returns>
	/// <param name="output">The output operation to probe.</param>
	/// <param name="status">Status.</param>
	public long [] GetShape (TFOutput output, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		var cstatus = TFStatus.Setup (status);
		var ndims = TF_GraphGetTensorNumDims (handle, output, cstatus.handle);
		if (!cstatus.CheckMaybeRaise (status, last: false))
			return null;
		if (ndims <= 0) {
			// Nothing to fetch; release the temporary status (if we own it) before returning.
			cstatus.CheckMaybeRaise (status);
			return null;
		}
		var ret = new long [ndims];
		TF_GraphGetTensorShape (handle, output, ret, ndims, cstatus.handle);
		cstatus.CheckMaybeRaise (status);
		return ret;
	}

	/// <summary>
	/// Returns the current name scope in use, to change this, use the WithScope method.
	/// </summary>
	/// <value>The current name scope.</value>
	public string CurrentNameScope { get; internal set; } = "";

	/// <summary>
	/// Creates a new namescope by setting the scope to the description provided.
	/// </summary>
	/// <returns>A new scope that will remain in use until the return TFScope is disposed.</returns>
	/// <param name="nameScopeDesc">The namescope description, if the value is null, this
	/// will reset the toplevel namescope to be the empty value. </param>
	/// <remarks>
	/// To more easily name your operations and group them, you can use the
	/// WithScope method to set a current name scope that alters the complete name
	/// of an operation added to the graph.
	///
	/// The graph starts with a scope set to the empty string, you can introduce new
	/// scopes by calling WithScope, and can be conveniently used with the C# using
	/// statement, like this:
	///
	/// <code>
	/// Assert (graph.CurrentNameScope, "");
	/// using (var nested = graph.WithScope ("nested")){
	///     Assert (graph.CurrentNameScope, "nested");
	///     using (var inner = graph.WithScope ("inner")){
	///         Assert (graph.CurrentNameScope, "nested/inner");
	///     }
	/// }
	/// </code>
	/// </remarks>
	public TFScope WithScope (string nameScopeDesc)
	{
		// Capture the current scope first so disposing the returned TFScope restores it.
		var scope = new TFScope (this);
		if (nameScopeDesc == null)
			CurrentNameScope = "";
		else if (CurrentNameScope.Length == 0)
			CurrentNameScope = nameScopeDesc;
		else
			CurrentNameScope = CurrentNameScope + "/" + nameScopeDesc;
		return scope;
	}

	Dictionary<string, int> values = new Dictionary<string, int> ();

	// Builds a scoped operation name; auto-generated names (userName == null) are made unique.
	string MakeName (string operName, string userName)
	{
		if (userName == null) {
			var k = CurrentNameScope == "" ? operName : CurrentNameScope + "/" + operName;
			return MakeUnique (k);
		}
		if (CurrentNameScope == "")
			return userName;
		return CurrentNameScope + "/" + userName;
	}

	// Appends a per-name counter: the first request yields name + "0", then "1", "2", ...
	string MakeUnique (string name)
	{
		int last;
		int next = values.TryGetValue (name, out last) ? last + 1 : 0;
		values [name] = next;
		return name + next;
	}

	internal int LastId;

	internal int GetNextId ()
	{
		return LastId++;
	}

	[DllImport (NativeBinding.TensorFlowLibrary)]
	unsafe extern static void TF_GraphImportGraphDefWithReturnOutputs (
		TF_Graph graph, LLBuffer *graph_def,
		TF_ImportGraphDefOptions options, TFOutput *return_outputs,
		int num_return_outputs, TF_Status status);

	/// <summary>
	/// Imports a graph serialized into the graph
	/// </summary>
	/// <param name="graphDef">Serialized graph definition (in protocol buffer format).</param>
	/// <param name="options">Import options.</param>
	/// <param name="returnOutputs">Array large enough to contain all the return options; null or empty requests none.</param>
	/// <param name="status">Status, optional.</param>
	public void ImportGraphDef (TFBuffer graphDef, TFImportGraphDefOptions options, TFOutput [] returnOutputs, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		if (graphDef == null)
			throw new ArgumentNullException (nameof (graphDef));
		if (options == null)
			throw new ArgumentNullException (nameof (options));
		var cstatus = TFStatus.Setup (status);
		unsafe
		{
			// Taking &returnOutputs [0] throws for an empty array, so treat empty the same as null.
			if (returnOutputs == null || returnOutputs.Length == 0) {
				TF_GraphImportGraphDefWithReturnOutputs (handle, graphDef.LLBuffer, options.handle, null, 0, cstatus.handle);
			} else {
				fixed (TFOutput* first = &returnOutputs [0])
				{
					TF_GraphImportGraphDefWithReturnOutputs (handle, graphDef.LLBuffer, options.handle, first, returnOutputs.Length, cstatus.handle);
				}
			}
		}
		// Surface import errors (or release the temporary status) like the other methods do.
		cstatus.CheckMaybeRaise (status);
	}

	[StructLayout (LayoutKind.Sequential)]
	unsafe struct TFWhileParams
	{
		public int ninputs;
		public TF_Graph cond_graph;
		public TFOutput* cond_inputs;
		public TFOutput cond_output;
		public TF_Graph body_graph;
		public TFOutput* body_inputs;
		public TFOutput* body_outputs;
		public IntPtr charPtrName;
	}

	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe TFWhileParams TF_NewWhile (TF_Graph g, TFOutput [] inputs, int ninputs, TF_Status status);

	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern void TF_AbortWhile (ref TFWhileParams pars);

	[DllImport (NativeBinding.TensorFlowLibrary)]
	static extern unsafe void TF_FinishWhile (ref TFWhileParams pars, TF_Status status, TFOutput *outputs);

	// Copies n native TFOutput values into a managed array.
	static unsafe TFOutput [] CopyFrom (TFOutput* ptr, int n)
	{
		var r = new TFOutput [n];
		for (int i = 0; i < n; i++)
			r [i] = ptr [i];
		return r;
	}

	/// <summary>
	/// Signature of the method that will be invoked by the TFGraph.While method to construct a while loop
	/// </summary>
	/// <remarks>
	/// The method should build up the condition on the conditionGraph and the body of the while
	/// loop in the provided bodyGraph. It should set the condOutput to the value used as the
	/// condition output and the array of values in bodyOutputs to the final outputs as well as the
	/// name to be used, if not set, one will be assigned.
	///
	/// The conditionGraph represents the while condition and the inputs are the current values of the
	/// input variables (condInputs). The output should be a scalar boolean.
	///
	/// The loop body graph is in bodyGraph, The inputs are the current values of the loop
	/// variables. The outputs are the updated values of the loop variables.
	///
	/// If the method throws, the loop construction is aborted and TFGraph.While returns null.
	/// </remarks>
	public delegate void WhileConstructor (TFGraph conditionGraph, TFOutput [] condInputs, out TFOutput condOutput, TFGraph bodyGraph, TFOutput [] bodyInputs, TFOutput [] bodyOutputs, out string name);

	/// <summary>
	/// Constructs a while loop with the specified inputs and a callback that composes the while loop
	/// </summary>
	/// <param name="inputs">Inputs.</param>
	/// <param name="constructor">Callback method that fills out the various while loop parameters.</param>
	/// <param name="status">Optional status; when null, native errors raise a <see cref="T:TensorFlow.TFException"/>.</param>
	/// <returns>
	/// An array of TFOutputs from creating the While loop, or null in two cases: an error is reported
	/// through a caller-supplied <paramref name="status"/>, or the constructor raised an exception
	/// when it was invoked.
	/// </returns>
	public TFOutput [] While (TFOutput [] inputs, WhileConstructor constructor, TFStatus status = null)
	{
		if (handle == IntPtr.Zero)
			ObjectDisposedException ();
		if (inputs == null)
			throw new ArgumentNullException (nameof (inputs));
		if (constructor == null)
			throw new ArgumentNullException (nameof (constructor));

		var cstatus = TFStatus.Setup (status);
		TFWhileParams result = TF_NewWhile (handle, inputs, inputs.Length, cstatus.handle);
		// Without a caller status this raises on error (and releases the temporary status).
		if (!cstatus.CheckMaybeRaise (status, last: false))
			return null;

		int n = result.ninputs;
		var bodyOutputs = new TFOutput [n];
		string name;
		try {
			unsafe
			{
				// Wrap the native subgraphs without taking ownership of them.
				var condGraph = new TFGraphUnowned (result.cond_graph);
				var bodyGraph = new TFGraphUnowned (result.body_graph);
				constructor (condGraph, CopyFrom (result.cond_inputs, n), out result.cond_output, bodyGraph, CopyFrom (result.body_inputs, n), bodyOutputs, out name);
			}
		} catch {
			// A failing constructor yields null; TF_FinishWhile will not run, so abort the params.
			TF_AbortWhile (ref result);
			if (status == null)
				cstatus.Dispose ();
			return null;
		}

		if (string.IsNullOrEmpty (name))
			name = MakeUnique ("while");

		bool finished = false;
		try {
			// Hand the loop name to native code as a NUL-terminated UTF-8 string.
			var text = Encoding.UTF8.GetBytes (name);
			result.charPtrName = Marshal.AllocHGlobal (text.Length + 1);
			Marshal.Copy (text, 0, result.charPtrName, text.Length);
			Marshal.WriteByte (result.charPtrName, text.Length, 0);

			var ret = new TFOutput [inputs.Length];
			unsafe
			{
				for (int i = 0; i < n; i++)
					result.body_outputs [i] = bodyOutputs [i];
				fixed (TFOutput* first = ret)
				{
					// TF_FinishWhile consumes the params whether or not it succeeds,
					// so TF_AbortWhile must not be called after this point.
					finished = true;
					TF_FinishWhile (ref result, cstatus.handle, first);
				}
			}
			return cstatus.CheckMaybeRaise (status) ? ret : null;
		} finally {
			if (!finished)
				TF_AbortWhile (ref result);
			// The name buffer is ours; it was previously never freed.
			if (result.charPtrName != IntPtr.Zero)
				Marshal.FreeHGlobal (result.charPtrName);
			// Dispose is idempotent, so this is safe even if CheckMaybeRaise already released it.
			if (status == null)
				cstatus.Dispose ();
		}
	}
}
//
// A TFGraph that never releases the underlying native graph. It is used to
// surface graphs owned by someone else (for example the condition and body
// graphs handed out while building a while loop), so disposing this wrapper
// must not delete the native handle.
//
internal class TFGraphUnowned : TFGraph
{
	internal TFGraphUnowned (IntPtr handle) : base (handle)
	{
	}

	// Declared with IntPtr like the other NativeDispose overrides; the handle is a graph, not a status.
	internal override void NativeDispose (IntPtr handle)
	{
		// nothing, we do not own the handle
	}
}
/// <summary>
/// TFGraph name scope handle
/// </summary>
/// <remarks>
/// Disposing an instance puts the owning graph's CurrentNameScope back to the value it had
/// when this handle was created by TFGraph.WithScope.
/// </remarks>
public class TFScope : IDisposable
{
	readonly TFGraph graph;
	readonly string savedScope;

	internal TFScope (TFGraph container)
	{
		graph = container;
		savedScope = container.CurrentNameScope;
	}

	/// <summary>
	/// Restores the name scope that was active when this handle was created.
	/// </summary>
	public void Dispose () => graph.CurrentNameScope = savedScope;
}
/// <summary>
/// Low-level TensorFlow operation builder
/// </summary>
/// <remarks>
/// This is the low-level API that is used to create operations by manually specificying all
/// the parameters of an operation (inputs, outputs, attribute descriptions) that can then
/// be attached into a graph.
///
/// Generally, you will instead be using the methods surfaced in <see cref="T:TensorFlow.TFGraph"/>
/// that surfaces a C# high-level API that has already been bound to the built-in TensorFlow
/// nodes.
/// </remarks>
public class TFOperationDesc : TFDisposable
{
string opType, operName;
TFGraph graph;
// extern TF_OperationDescription * TF_NewOperation (TF_Graph *graph, const char *op_type, const char *oper_name);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe TF_OperationDescription TF_NewOperation (TF_Graph graph, string opType, string oper_name);
public TFOperationDesc (TFGraph graph, string opType, string operName) : base (IntPtr.Zero)
{
if (graph == null)
throw new ArgumentNullException ("graph");
handle = TF_NewOperation (graph.handle, opType, operName);
this.graph = graph;
this.opType = opType;
this.operName = operName;
}
internal override void NativeDispose (IntPtr handle)
{
// If you reach this, you never called FinishOperation
Console.WriteLine ($"TFOperationDescription({opType},{operName} was never turned into an TFOperation");
}
// extern void TF_SetDevice (TF_OperationDescription *desc, const char *device);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_SetDevice (TF_OperationDescription desc, string device);
public void SetDevice (string device)
{
if (handle == IntPtr.Zero)
ObjectDisposedException ();
if (device == null)
throw new ArgumentNullException ("device");
TF_SetDevice (handle, device);
}
// extern void TF_AddInput (TF_OperationDescription *desc, TF_Output input);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_AddInput (TF_OperationDescription desc, TFOutput input);

/// <summary>
/// Appends a single data input to the operation being described.
/// </summary>
/// <param name="input">The output of another operation to feed in.</param>
public void AddInput (TFOutput input)
{
	if (handle == IntPtr.Zero) {
		ObjectDisposedException ();
	}
	var desc = handle;
	TF_AddInput (desc, input);
}
// extern void TF_AddInputList (TF_OperationDescription *desc, const TF_Output *inputs, int num_inputs);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_AddInputList (TF_OperationDescription desc, TFOutput [] inputs, int num_inputs);

/// <summary>
/// Appends a list of inputs to the operation being described, via <c>TF_AddInputList</c>.
/// </summary>
/// <param name="inputs">Outputs to add; a null or empty array is silently ignored.</param>
public void AddInputs (params TFOutput [] inputs)
{
	if (handle == IntPtr.Zero) {
		ObjectDisposedException ();
	}
	// A null array counts as empty; either way there is nothing to send to native code.
	var count = inputs?.Length ?? 0;
	if (count == 0) {
		return;
	}
	TF_AddInputList (handle, inputs, count);
}
// extern void TF_AddControlInput (TF_OperationDescription *desc, TF_Operation *input);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_AddControlInput (TF_OperationDescription desc, TF_Operation input);

/// <summary>
/// Adds a control-dependency input to the operation being described.
/// </summary>
/// <param name="input">The operation to register as a control input.</param>
/// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
public void AddControlInput (TFOperation input)
{
	if (handle == IntPtr.Zero) {
		ObjectDisposedException ();
	}
	if (input == null) {
		throw new ArgumentNullException (nameof (input));
	}
	var controlOp = input.handle;
	TF_AddControlInput (handle, controlOp);
}
// extern void TF_ColocateWith (TF_OperationDescription *desc, TF_Operation *op);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_ColocateWith (TF_OperationDescription desc, TF_Operation op);

/// <summary>
/// Requests that the operation being described be colocated with <paramref name="op"/>.
/// </summary>
/// <param name="op">The existing operation to colocate with.</param>
/// <exception cref="ArgumentNullException"><paramref name="op"/> is null.</exception>
public void ColocateWith (TFOperation op)
{
	if (handle == IntPtr.Zero) {
		ObjectDisposedException ();
	}
	if (op == null) {
		throw new ArgumentNullException (nameof (op));
	}
	var peer = op.handle;
	TF_ColocateWith (handle, peer);
}
// extern void TF_SetAttrString (TF_OperationDescription *desc, const char *attr_name, const void *value, size_t length);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_SetAttrString (TF_OperationDescription desc, string attr_name, IntPtr value, size_t length);

/// <summary>
/// Sets a string-valued attribute on the operation being described.
/// </summary>
/// <param name="attrName">Name of the attribute.</param>
/// <param name="value">Attribute value; it is passed to TensorFlow as UTF-8 bytes with an explicit length.</param>
/// <exception cref="ArgumentNullException"><paramref name="attrName"/> or <paramref name="value"/> is null.</exception>
public void SetAttr (string attrName, string value)
{
	if (handle == IntPtr.Zero)
		ObjectDisposedException ();
	if (attrName == null)
		throw new ArgumentNullException (nameof (attrName));
	if (value == null)
		throw new ArgumentNullException (nameof (value));
	var bytes = Encoding.UTF8.GetBytes (value);
	// The length is passed explicitly, so no terminator is written; the spare byte only
	// keeps the allocation non-empty for an empty string.
	var buf = Marshal.AllocHGlobal (bytes.Length + 1);
	try {
		Marshal.Copy (bytes, 0, buf, bytes.Length);
		// The C API copies the value into the node definition, so the temporary
		// buffer can be released as soon as the call returns.
		TF_SetAttrString (handle, attrName, buf, (UIntPtr)bytes.Length);
	} finally {
		Marshal.FreeHGlobal (buf);
	}
}
// extern void TF_SetAttrStringList (TF_OperationDescription *desc, const char *attr_name, const void *const *values, const size_t *lengths, int num_values);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_SetAttrStringList (TF_OperationDescription desc, string attr_name, IntPtr [] values, UIntPtr [] lengths, int num_values);

/// <summary>
/// Sets a list-of-strings attribute on the operation being described.
/// </summary>
/// <param name="attrName">Name of the attribute.</param>
/// <param name="values">Attribute values; each is passed to TensorFlow as UTF-8 bytes with an explicit length.</param>
/// <exception cref="ArgumentNullException">
/// <paramref name="attrName"/> or <paramref name="values"/> is null, or an element of <paramref name="values"/> is null.
/// </exception>
public void SetAttr (string attrName, string [] values)
{
	if (handle == IntPtr.Zero)
		ObjectDisposedException ();
	if (attrName == null)
		throw new ArgumentNullException (nameof (attrName));
	if (values == null)
		throw new ArgumentNullException (nameof (values));
	int n = values.Length;
	// Validate every element before allocating anything, so a bad element cannot strand buffers.
	for (int i = 0; i < n; i++) {
		if (values [i] == null)
			throw new ArgumentNullException (nameof (values), $"Element {i} is null");
	}
	var unmanaged = new IntPtr [n];
	var lengths = new UIntPtr [n];
	try {
		for (int i = 0; i < n; i++) {
			var bytes = Encoding.UTF8.GetBytes (values [i]);
			unmanaged [i] = Marshal.AllocHGlobal (bytes.Length + 1);
			Marshal.Copy (bytes, 0, unmanaged [i], bytes.Length);
			lengths [i] = (UIntPtr)bytes.Length;
		}
		// The C API copies the values into the node definition, so the buffers are
		// only needed for the duration of this call.
		TF_SetAttrStringList (handle, attrName, unmanaged, lengths, n);
	} finally {
		// Slots never reached (if an allocation failed) are still IntPtr.Zero.
		foreach (var p in unmanaged) {
			if (p != IntPtr.Zero)
				Marshal.FreeHGlobal (p);
		}
	}
}
// extern void TF_SetAttrInt (TF_OperationDescription *desc, const char *attr_name, int64_t value);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_SetAttrInt (TF_OperationDescription desc, string attr_name, long value);

/// <summary>
/// Sets a 64-bit integer attribute on the operation being described.
/// </summary>
/// <param name="attrName">Name of the attribute.</param>
/// <param name="value">Integer value forwarded to <c>TF_SetAttrInt</c>.</param>
/// <exception cref="ArgumentNullException"><paramref name="attrName"/> is null.</exception>
public void SetAttr (string attrName, long value)
{
	if (handle == IntPtr.Zero) {
		ObjectDisposedException ();
	}
	if (attrName == null) {
		throw new ArgumentNullException (nameof (attrName));
	}
	var desc = handle;
	TF_SetAttrInt (desc, attrName, value);
}
// extern void TF_SetAttrIntList (TF_OperationDescription *desc, const char *attr_name, const int64_t *values, int num_values);
[DllImport (NativeBinding.TensorFlowLibrary)]
static extern unsafe void TF_SetAttrIntList (TF_OperationDescription desc, string attr_name, long [] values, int num_values);

/// <summary>
/// Sets a list-of-integers attribute on the operation being described.
/// </summary>
/// <param name="attrName">Name of the attribute.</param>
/// <param name="values">Values forwarded, with their count, to <c>TF_SetAttrIntList</c>.</param>
/// <exception cref="ArgumentNullException"><paramref name="attrName"/> or <paramref name="values"/> is null.</exception>
public void SetAttr (string attrName, long [] values)
{
	if (handle == IntPtr.Zero) {
		ObjectDisposedException ();
	}
	if (attrName == null) {
		throw new ArgumentNullException (nameof (attrName));
	}
	if (values == null) {
		throw new ArgumentNullException (nameof (values));
	}
	var count = values.Length;
	TF_SetAttrIntList (handle, attrName, values, count);
}