|
31 | 31 | using TF_Library = System.IntPtr; |
32 | 32 | using TF_BufferPtr = System.IntPtr; |
33 | 33 | using TF_Function = System.IntPtr; |
| 34 | +using TF_DeviceList = System.IntPtr; |
34 | 35 |
|
35 | 36 | using size_t = System.UIntPtr; |
36 | 37 | using System.Numerics; |
@@ -2064,6 +2065,27 @@ public TFOutput this [int idx] { |
2064 | 2065 | } |
2065 | 2066 | } |
2066 | 2067 |
|
| 2068 | + public enum DeviceType |
| 2069 | + { |
| 2070 | + CPU, GPU, TPU |
| 2071 | + } |
| 2072 | + |
| 2073 | + public class DeviceAttributes |
| 2074 | + { |
| 2075 | + public DeviceAttributes (string name, DeviceType deviceType, long memoryLimitBytes) |
| 2076 | + { |
| 2077 | + Name = name; |
| 2078 | + DeviceType = deviceType; |
| 2079 | + MemoryLimitBytes = memoryLimitBytes; |
| 2080 | + } |
| 2081 | + |
| 2082 | + public string Name { get; private set; } |
| 2083 | + |
| 2084 | + public DeviceType DeviceType { get; private set; } |
| 2085 | + |
| 2086 | + public long MemoryLimitBytes { get; private set; } |
| 2087 | + } |
| 2088 | + |
2067 | 2089 | /// <summary> |
2068 | 2090 | /// Contains options that are used to control how graph importing works. |
2069 | 2091 | /// </summary> |
@@ -2321,6 +2343,48 @@ public TFSession (TFStatus status = null) : this (new TFGraph (), status) |
2321 | 2343 | [DllImport (NativeBinding.TensorFlowLibrary)] |
2322 | 2344 | static extern unsafe TF_Session TF_LoadSessionFromSavedModel (TF_SessionOptions session_options, LLBuffer* run_options, string export_dir, string [] tags, int tags_len, TF_Graph graph, LLBuffer* meta_graph_def, TF_Status status); |
2323 | 2345 |
|
| 2346 | + [DllImport (NativeBinding.TensorFlowLibrary)] |
| 2347 | + static extern unsafe TF_DeviceList TF_SessionListDevices (TF_Session session, TF_Status status); |
| 2348 | + |
| 2349 | + [DllImport (NativeBinding.TensorFlowLibrary)] |
| 2350 | + static extern unsafe int TF_DeviceListCount (TF_DeviceList list); |
| 2351 | + |
| 2352 | + [DllImport (NativeBinding.TensorFlowLibrary)] |
| 2353 | + static extern unsafe string TF_DeviceListName (TF_DeviceList list, int index, TF_Status status); |
| 2354 | + |
| 2355 | + [DllImport (NativeBinding.TensorFlowLibrary)] |
| 2356 | + static extern unsafe string TF_DeviceListType (TF_DeviceList list, int index, TF_Status status); |
| 2357 | + |
| 2358 | + [DllImport (NativeBinding.TensorFlowLibrary)] |
| 2359 | + static extern unsafe long TF_DeviceListMemoryBytes (TF_DeviceList list, int index, TF_Status status); |
| 2360 | + |
| 2361 | + [DllImport (NativeBinding.TensorFlowLibrary)] |
| 2362 | + static extern unsafe void TF_DeleteDeviceList (TF_DeviceList list); |
| 2363 | + |
| 2364 | + /// <summary> |
| 2365 | + /// Lists available devices in this session. |
| 2366 | + /// </summary> |
| 2367 | + public IEnumerable<DeviceAttributes> ListDevices(TFStatus status = null) |
| 2368 | + { |
| 2369 | + var cstatus = TFStatus.Setup (status); |
| 2370 | + var rawDeviceList = TF_SessionListDevices (this.Handle, cstatus.handle); |
| 2371 | + var size = TF_DeviceListCount (rawDeviceList); |
| 2372 | + |
| 2373 | + var list = new List<DeviceAttributes> (); |
| 2374 | + for (var i = 0; i < size; i++) { |
| 2375 | + var name = TF_DeviceListName (rawDeviceList, i, cstatus.handle); |
| 2376 | + var deviceType = (DeviceType) Enum.Parse (typeof(DeviceType), TF_DeviceListType (rawDeviceList, i, cstatus.handle)); |
| 2377 | + var memory = TF_DeviceListMemoryBytes (rawDeviceList, i, cstatus.handle); |
| 2378 | + |
| 2379 | + list.Add (new DeviceAttributes (name, deviceType, memory)); |
| 2380 | + } |
| 2381 | + |
| 2382 | + // TODO: Fix deleting. |
| 2383 | + // TF_DeleteDeviceList (rawDeviceList); |
| 2384 | + |
| 2385 | + return list; |
| 2386 | + } |
| 2387 | + |
2324 | 2388 | /// <summary> |
2325 | 2389 | /// Creates a session and graph from a saved session model |
2326 | 2390 | /// </summary> |
|
0 commit comments