// Copyright (c) 2025, Apple Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-3-clause license that can be
// found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
#import "CoreMLPythonArray.h"
@implementation PybindCompatibleArray
/// Maps the dtype of a (pybind11-wrapped) NumPy array to the matching MLMultiArrayDataType.
///
/// Supported (dtype.kind, dtype.itemsize) pairs:
///   - ('i', 4)        -> MLMultiArrayDataTypeInt32
///   - ('i', 1)        -> MLMultiArrayDataTypeInt8 (only compiled in when BUILT_WITH_MACOS26_SDK is set)
///   - ('f', 4)        -> MLMultiArrayDataTypeFloat32
///   - ('f' or 'd', 8) -> MLMultiArrayDataTypeDouble
///
/// Only the dtype is inspected; the array's data buffer is not touched.
///
/// @param array The array whose element type should be translated.
/// @return The Core ML multi-array data type corresponding to the array's dtype.
/// @throws std::runtime_error for any other kind/itemsize combination (e.g. unsigned
///         integers, 64-bit integers, 16-bit floats, booleans).
+ (MLMultiArrayDataType)dataTypeOf:(py::array)array {
    const auto& dt = array.dtype();
    char kind = dt.kind();
    size_t itemsize = dt.itemsize();
    if(kind == 'i' && itemsize == 4) {
        return MLMultiArrayDataTypeInt32;
    }
#if BUILT_WITH_MACOS26_SDK
    // Int8 mapping is only available in builds that define BUILT_WITH_MACOS26_SDK
    // (presumably because MLMultiArrayDataTypeInt8 requires that SDK — confirm).
    else if (kind == 'i' && itemsize == 1) {
        return MLMultiArrayDataTypeInt8;
    }
#endif
    else if(kind == 'f' && itemsize == 4) {
        return MLMultiArrayDataTypeFloat32;
    } else if( (kind == 'f' || kind == 'd') && itemsize == 8) {
        // NOTE(review): 'd' is not one of NumPy's dtype.kind codes (it is the double
        // type *char* code); the check looks defensive and is harmless — confirm intent.
        return MLMultiArrayDataTypeDouble;
    }
    // std::to_string(char) promotes the char to int and would print its numeric code
    // (e.g. "117" for 'u'); build a one-character string so the real kind is reported.
    throw std::runtime_error("Unsupported array type: " + std::string(1, kind) + " with itemsize = " + std::to_string(itemsize));
}
+ (NSArray