forked from apple/coremltools
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCoreMLPython.h
More file actions
64 lines (54 loc) · 2.21 KB
/
CoreMLPython.h
File metadata and controls
64 lines (54 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wexit-time-destructors"
#pragma clang diagnostic ignored "-Wdocumentation"
#pragma clang diagnostic ignored "-Wdeprecated-declarations"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#pragma clang diagnostic pop

#import <CoreML/CoreML.h>
#import "NeuralNetworkBuffer.hpp"
#import "Validation/NeuralNetwork/NeuralNetworkShapes.hpp"
namespace py = pybind11;

namespace CoreML {
namespace Python {

/// Python-facing wrapper around a Core ML `MLModel`.
///
/// Copy construction and copy assignment are deleted, so each wrapper is the
/// sole holder of its model and URL references.
class Model {
private:
    /// The wrapped Core ML model.
    MLModel *m_model = nil;
    // NOTE(review): judging by the name, this is the location of the compiled
    // model associated with m_model — confirm in the implementation file.
    NSURL *compiledUrl = nil;

public:
    Model(const Model&) = delete;
    Model& operator=(const Model&) = delete;
    ~Model();

    /// Creates a wrapper for the model identified by @p urlStr.
    /// @param urlStr String form of the model location (path/URL — confirm
    ///        the accepted forms in the implementation).
    /// @param useCPUOnly Presumably restricts execution to the CPU when true —
    ///        confirm in the implementation.
    explicit Model(const std::string& urlStr, bool useCPUOnly);

    /// Runs a prediction with the wrapped model.
    /// @param input Python dictionary of inputs (presumably keyed by feature
    ///        name — confirm against callers).
    /// @param useCPUOnly Presumably restricts execution to the CPU when true.
    /// @return Python dictionary of outputs.
    py::dict predict(const py::dict& input, bool useCPUOnly);

    /// Returns a copy of a serialized model specification whose specification
    /// version has presumably been set automatically — confirm in the
    /// implementation.
    /// @param modelBytes Serialized model specification bytes.
    static py::bytes autoSetSpecificationVersion(const py::bytes& modelBytes);

    /// @return The highest specification version this build reports as
    ///         supported (per the name — confirm in the implementation).
    static int32_t maximumSupportedSpecificationVersion();

    /// @return A textual description of this model wrapper.
    std::string toString() const;
};

/// Exposes the results of a `NeuralNetworkShaper` to Python.
class NeuralNetworkShapeInformation {
private:
    std::unique_ptr<NeuralNetworkShaper> shaper;

public:
    /// @param filename Path of the file to analyse (presumably a serialized
    ///        model specification — confirm in the implementation).
    explicit NeuralNetworkShapeInformation(const std::string& filename);

    /// @param filename Path of the file to analyse.
    /// @param useInputAndOutputConstraints Flag forwarded to shape inference
    ///        (exact meaning is defined by NeuralNetworkShaper — confirm).
    NeuralNetworkShapeInformation(const std::string& filename, bool useInputAndOutputConstraints);

    /// (Re)initializes this object from @p filename.
    /// @param filename Path of the file to analyse.
    void init(const std::string& filename);

    /// Returns the shape information recorded for @p name as a dictionary.
    /// @param name Name of the blob/layer to query (assumed — confirm).
    py::dict shape(const std::string& name);

    /// @return A textual description of the collected shape information.
    std::string toString() const;

    /// Prints the collected shape information (destination defined by the
    /// implementation).
    void print() const;
};

// TODO: Replace the per-call member templates below with a class template
// that is instantiated once for each supported element type.

/// Python-facing wrapper around an `NNBuffer::NeuralNetworkBuffer`.
class NeuralNetworkBufferInformation {
private:
    std::unique_ptr<NNBuffer::NeuralNetworkBuffer> nnBuffer;

public:
    /// @param bufferFilePath Path of the buffer file to operate on.
    /// @param mode Mode in which the buffer file is used.
    NeuralNetworkBufferInformation(const std::string& bufferFilePath, NNBuffer::BufferMode mode);
    ~NeuralNetworkBufferInformation();

    // NOTE(review): these member templates have no definition in this header,
    // so only element types explicitly instantiated in the implementation file
    // can be used — confirm the supported set there.

    /// Writes @p buffer into the underlying buffer file.
    /// @param buffer Elements to write.
    /// @return Presumably the offset at which the data was stored, suitable for
    ///         passing to getBuffer() — confirm in the implementation.
    template <typename T>
    uint64_t addBuffer(const std::vector<T>& buffer);

    /// Reads back a buffer previously stored in the underlying buffer file.
    /// @param offset Location of the buffer (presumably a value returned by
    ///        addBuffer()).
    /// @return The elements read from the file.
    template <typename T>
    std::vector<T> getBuffer(uint64_t offset);
};

}  // namespace Python
}  // namespace CoreML