Skip to content

Commit 7cfa6a1

Browse files
alexpantyukhinmigueldeicaza
authored andcommitted
* add list devices. * Adapting ListDevices method for existing method of handling errors.
1 parent 0a3c663 commit 7cfa6a1

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

TensorFlowSharp/Tensorflow.cs

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
using TF_Library = System.IntPtr;
3232
using TF_BufferPtr = System.IntPtr;
3333
using TF_Function = System.IntPtr;
34+
using TF_DeviceList = System.IntPtr;
3435

3536
using size_t = System.UIntPtr;
3637
using System.Numerics;
@@ -2064,6 +2065,27 @@ public TFOutput this [int idx] {
20642065
}
20652066
}
20662067

2068+
public enum DeviceType
2069+
{
2070+
CPU, GPU, TPU
2071+
}
2072+
2073+
public class DeviceAttributes
2074+
{
2075+
public DeviceAttributes (string name, DeviceType deviceType, long memoryLimitBytes)
2076+
{
2077+
Name = name;
2078+
DeviceType = deviceType;
2079+
MemoryLimitBytes = memoryLimitBytes;
2080+
}
2081+
2082+
public string Name { get; private set; }
2083+
2084+
public DeviceType DeviceType { get; private set; }
2085+
2086+
public long MemoryLimitBytes { get; private set; }
2087+
}
2088+
20672089
/// <summary>
20682090
/// Contains options that are used to control how graph importing works.
20692091
/// </summary>
@@ -2321,6 +2343,48 @@ public TFSession (TFStatus status = null) : this (new TFGraph (), status)
23212343
[DllImport (NativeBinding.TensorFlowLibrary)]
23222344
static extern unsafe TF_Session TF_LoadSessionFromSavedModel (TF_SessionOptions session_options, LLBuffer* run_options, string export_dir, string [] tags, int tags_len, TF_Graph graph, LLBuffer* meta_graph_def, TF_Status status);
23232345

2346+
[DllImport (NativeBinding.TensorFlowLibrary)]
2347+
static extern unsafe TF_DeviceList TF_SessionListDevices (TF_Session session, TF_Status status);
2348+
2349+
[DllImport (NativeBinding.TensorFlowLibrary)]
2350+
static extern unsafe int TF_DeviceListCount (TF_DeviceList list);
2351+
2352+
[DllImport (NativeBinding.TensorFlowLibrary)]
2353+
static extern unsafe string TF_DeviceListName (TF_DeviceList list, int index, TF_Status status);
2354+
2355+
[DllImport (NativeBinding.TensorFlowLibrary)]
2356+
static extern unsafe string TF_DeviceListType (TF_DeviceList list, int index, TF_Status status);
2357+
2358+
[DllImport (NativeBinding.TensorFlowLibrary)]
2359+
static extern unsafe long TF_DeviceListMemoryBytes (TF_DeviceList list, int index, TF_Status status);
2360+
2361+
[DllImport (NativeBinding.TensorFlowLibrary)]
2362+
static extern unsafe void TF_DeleteDeviceList (TF_DeviceList list);
2363+
2364+
/// <summary>
2365+
/// Lists available devices in this session.
2366+
/// </summary>
2367+
public IEnumerable<DeviceAttributes> ListDevices(TFStatus status = null)
2368+
{
2369+
var cstatus = TFStatus.Setup (status);
2370+
var rawDeviceList = TF_SessionListDevices (this.Handle, cstatus.handle);
2371+
var size = TF_DeviceListCount (rawDeviceList);
2372+
2373+
var list = new List<DeviceAttributes> ();
2374+
for (var i = 0; i < size; i++) {
2375+
var name = TF_DeviceListName (rawDeviceList, i, cstatus.handle);
2376+
var deviceType = (DeviceType) Enum.Parse (typeof(DeviceType), TF_DeviceListType (rawDeviceList, i, cstatus.handle));
2377+
var memory = TF_DeviceListMemoryBytes (rawDeviceList, i, cstatus.handle);
2378+
2379+
list.Add (new DeviceAttributes (name, deviceType, memory));
2380+
}
2381+
2382+
// TODO: Fix deleting.
2383+
// TF_DeleteDeviceList (rawDeviceList);
2384+
2385+
return list;
2386+
}
2387+
23242388
/// <summary>
23252389
/// Creates a session and graph from a saved session model
23262390
/// </summary>

0 commit comments

Comments
 (0)