Skip to content

Commit 6e710ad

Browse files
authored
java side deprecations/removals (#10185)
typesafe copy implementations
1 parent 693e558 commit 6e710ad

File tree

17 files changed

+90
-405
lines changed

17 files changed

+90
-405
lines changed

libnd4j/include/ops/special_random_ops.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,6 @@ class TruncatedNormalDistribution {
367367
sd::LongType *xShape = shape::shapeOf(xShapeBuffer);
368368
sd::LongType xRank = shape::rank(xShapeBuffer);
369369
sd::LongType *xStride = shape::stride(xShapeBuffer);
370-
const T epsilon = static_cast<T>(1e-5);
371370

372371
auto func = PRAGMA_THREADS_FOR {
373372
for (auto e = start; e < stop; e++) {

libnd4j/include/system/op_boilerplate.h

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,26 @@
7979
#define ELEMENT_THRESHOLD sd::Environment::getInstance().elementwiseThreshold()
8080
#define TAD_THRESHOLD sd::Environment::getInstance().tadThreshold()
8181

82-
#define SHAPELIST(...) new ShapeList({__VA_ARGS__}, block.workspace() != nullptr)
82+
// Helper to pick the correct macro based on the number of arguments.
83+
// This macro works by “peeling” off up to 10 parameters; if only one parameter is passed,
84+
// then _1 is that parameter and NAME ends up as SHAPELIST_1. For two or more parameters,
85+
// NAME ends up as SHAPELIST_N.
86+
#define GET_SHAPELIST_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, NAME, ...) NAME
87+
88+
// This chooser macro always returns SHAPELIST_1 when only one argument is given,
89+
// and returns SHAPELIST_N when more than one argument is passed.
90+
#define SHAPELIST_CHOOSER(...) \
91+
GET_SHAPELIST_MACRO(__VA_ARGS__, SHAPELIST_N, SHAPELIST_N, SHAPELIST_N, SHAPELIST_N, \
92+
SHAPELIST_N, SHAPELIST_N, SHAPELIST_N, SHAPELIST_N, SHAPELIST_N, SHAPELIST_1)
93+
94+
// For one argument, call the constructor directly.
95+
#define SHAPELIST_1(a) new ShapeList(a)
96+
// For two or more arguments, wrap them in braces so that the constructor
97+
// accepting a vector is called.
98+
#define SHAPELIST_N(...) new ShapeList({__VA_ARGS__})
99+
100+
// Finally, define SHAPELIST(...) to select the proper version:
101+
#define SHAPELIST(...) SHAPELIST_CHOOSER(__VA_ARGS__)(__VA_ARGS__)
83102

84103
#ifdef __CUDA_ARCH__
85104
#define PRINT_FIRST(...) \
@@ -2387,7 +2406,7 @@
23872406
auto shapeList = SHAPELIST(); \
23882407
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() \
23892408
: this->getOpDescriptor()->getNumberOfOutputs(); \
2390-
for (int e = 0; e < opLimit; e++) { \
2409+
for (size_t e = 0; e < opLimit; e++) { \
23912410
auto newshape = ConstantShapeHelper::getInstance().createShapeInfo( \
23922411
ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), \
23932412
shape::shapeOf(inputShape->at(e)),shape::extra(inputShape->at(e))); \
@@ -2431,7 +2450,7 @@
24312450
auto shapeList = SHAPELIST(); \
24322451
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() \
24332452
: this->getOpDescriptor()->getNumberOfOutputs(); \
2434-
for (int e = 0; e < opLimit; e++) { \
2453+
for (size_t e = 0; e < opLimit; e++) { \
24352454
sd::LongType* newshape; \
24362455
COPY_SHAPE(inputShape->at(0), newshape); \
24372456
shapeList->push_back(CONSTANT(newshape)); \
@@ -2459,7 +2478,7 @@
24592478
auto shapeList = SHAPELIST(); \
24602479
auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() \
24612480
: this->getOpDescriptor()->getNumberOfOutputs(); \
2462-
for (int e = 0; e < opLimit; e++) { \
2481+
for (size_t e = 0; e < opLimit; e++) { \
24632482
int inputShapeIdx = block.width() < opLimit ? 0 : e; \
24642483
auto shapeInfo = inputShape->at(inputShapeIdx); \
24652484
if(shape::isEmptyConst(shapeInfo)) { \
@@ -2474,7 +2493,7 @@
24742493
} \
24752494
\
24762495
auto dtString = DataTypeUtils::asString(ArrayOptions::dataType(shapeInfo)); \
2477-
printf("CONFIGURABLE_OP_IMPL: Creating empty data type: %s for index %d\n",dtString.c_str(),e);\
2496+
printf("CONFIGURABLE_OP_IMPL: Creating empty data type: %s for index %d\n",dtString.c_str(),static_cast<int>(e));\
24782497
\
24792498
auto newShape = ConstantShapeHelper::getInstance() \
24802499
.emptyShapeInfoWithShape(ArrayOptions::dataType(shapeInfo),shape2); \
@@ -2642,13 +2661,18 @@
26422661

26432662
#endif
26442663

2664+
#include <type_traits>
2665+
#include <cstring>
2666+
#include <algorithm>
2667+
26452668
template <typename TT, typename WW>
26462669
SD_INLINE TT* internal_alloc_host(WW workSpace, sd::LongType len) {
26472670
TT* var;
26482671
if (workSpace == nullptr) {
26492672
#if defined(SD_ALIGNED_ALLOC)
26502673
var = static_cast<TT*>(
2651-
aligned_alloc(SD_DESIRED_ALIGNMENT, (len * sizeof(TT) + SD_DESIRED_ALIGNMENT - 1) & (-SD_DESIRED_ALIGNMENT)));
2674+
aligned_alloc(SD_DESIRED_ALIGNMENT,
2675+
(len * sizeof(TT) + SD_DESIRED_ALIGNMENT - 1) & (-SD_DESIRED_ALIGNMENT)));
26522676
#else
26532677
var = new TT[len];
26542678
#endif
@@ -2658,10 +2682,15 @@ SD_INLINE TT* internal_alloc_host(WW workSpace, sd::LongType len) {
26582682
} else {
26592683
var = reinterpret_cast<TT*>(workSpace->allocateBytes(len * sizeof(TT)));
26602684
}
2661-
memset(var, 0, len * sizeof(TT));
2685+
if constexpr (std::is_trivially_copyable<TT>::value) {
2686+
memset(var, 0, len * sizeof(TT));
2687+
} else {
2688+
std::fill_n(var, len, TT());
2689+
}
26622690
return var;
26632691
}
26642692

2693+
26652694
template <typename TT_PTR, typename WW>
26662695
SD_INLINE void internal_release_host(WW workspace, TT_PTR var) {
26672696
if (workspace == nullptr) {
@@ -2680,7 +2709,6 @@ SD_INLINE void internal_release_host(WW workspace, TT_PTR var) {
26802709
#ifndef __JAVACPP_HACK__
26812710

26822711
#if defined(SD_GCC_FUNCTRACE) && !defined(OP_BOILER_PLATE_THROW_EXCEPTIONS)
2683-
#pragma once
26842712
#define OP_BOILER_PLATE_THROW_EXCEPTIONS
26852713
#include <exceptions/backward.hpp>
26862714
using namespace backward;

libnd4j/include/types/bfloat16.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ struct isNumericType {
5151
struct bfloat16 {
5252

5353
public:
54+
constexpr bfloat16(const bfloat16&) = default;
55+
5456
int16_t _data;
5557

5658
SD_INLINE SD_HOST_DEVICE bfloat16() { _data = 0; }

libnd4j/include/types/float16.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ struct float16 {
204204
};
205205

206206
public:
207+
constexpr float16(const float16&) = default;
208+
207209
ihalf data;
208210
SD_INLINE SD_HOST_DEVICE float16() { *data.getXP() = 0; }
209211

libnd4j/include/types/float8.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ quarter SD_INLINE SD_HOST_DEVICE cpu_float2quarter_rn(float f);
3737
float SD_INLINE SD_HOST_DEVICE cpu_quarter2float(quarter b);
3838

3939
struct float8 {
40+
constexpr float8(const float8&) = default;
41+
4042
quarter data;
4143

4244
SD_INLINE SD_HOST_DEVICE float8();

libnd4j/include/types/types.h

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1325,4 +1325,36 @@
13251325
// Callback macro
13261326
#define CALLBACK_INSTANTIATE_NORM(a1, b1, FUNC_NAME, ARGS) \
13271327
INSTANTIATE_NORM(a1, b1, FUNC_NAME, ARGS)
1328-
1328+
#ifndef SD_COPY
1329+
#define SD_COPY
1330+
namespace sd {
1331+
namespace ops {
1332+
template <typename U, typename V>
1333+
static void safe_copy(U* dest, const V* src, size_t count) {
1334+
if constexpr (std::is_same<U, V>::value && std::is_trivially_copyable<U>::value) {
1335+
memcpy(dest, src, count * sizeof(U));
1336+
} else {
1337+
std::copy(src, src + count, dest);
1338+
}
1339+
}
1340+
1341+
#define INSTANTIATE_COPY(a1,b1,FUNC_NAME,ARGS) template void safe_copy<GET_SECOND(a1), GET_SECOND(b1)>( GET_SECOND(a1)* dest, const GET_SECOND(b1) * src, size_t count);
1342+
ITERATE_COMBINATIONS((SD_NUMERIC_TYPES), (SD_NUMERIC_TYPES), INSTANTIATE_COPY,safe_copy,;);
1343+
1344+
template <typename T>
1345+
static void safe_zero(T* dest, size_t count) {
1346+
if constexpr (std::is_trivially_copyable<T>::value) {
1347+
// For trivially copyable types, we can use memset.
1348+
memset(dest, 0, count * sizeof(T));
1349+
} else {
1350+
// Otherwise, default-construct each element.
1351+
std::fill_n(dest, count, T());
1352+
}
1353+
}
1354+
1355+
#define INSTANTIATE_ZERO(a1) template void safe_zero<GET_SECOND(a1)>( GET_SECOND(a1)* dest, size_t count);
1356+
ITERATE_LIST((SD_NUMERIC_TYPES), INSTANTIATE_ZERO)
1357+
}
1358+
1359+
}
1360+
#endif

nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/BlasBufferUtil.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public static long getBlasOffset(INDArray arr) {
4949
* @return the blas stride
5050
*/
5151
public static int getBlasStride(INDArray arr) {
52-
return arr.elementWiseStride();
52+
return arr.stride(-1);
5353
}
5454

5555
/**

nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/params/GemmParams.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ private INDArray copyIfNeccessary(INDArray arr) {
157157
return arr.dup();
158158
else if (arr.ordering() == 'f' && (arr.stride(0) != 1 || arr.stride(1) != arr.size(0)))
159159
return arr.dup();
160-
else if (arr.elementWiseStride() < 0)
160+
else if (arr.isView())
161161
return arr.dup();
162162
return arr;
163163
}

nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/blas/params/GemvParameters.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,19 +44,19 @@ public GemvParameters(INDArray a, INDArray x, INDArray y) {
4444

4545

4646
if (a.ordering() == 'f' && a.isMatrix()) {
47-
this.m = (int) a.rows();
48-
this.n = (int) a.columns();
49-
this.lda = (int) a.rows();
47+
this.m = a.rows();
48+
this.n = a.columns();
49+
this.lda = a.rows();
5050
} else if (a.ordering() == 'c' && a.isMatrix()) {
51-
this.m = (int) a.columns();
52-
this.n = (int) a.rows();
53-
this.lda = (int) a.columns();
51+
this.m = a.columns();
52+
this.n = a.rows();
53+
this.lda = a.columns();
5454
aOrdering = 'T';
5555
}
5656

5757
else {
58-
this.m = (int) a.rows();
59-
this.n = (int) a.columns();
58+
this.m = a.rows();
59+
this.n = a.columns();
6060
this.lda = (int) a.size(0);
6161
}
6262

@@ -69,7 +69,7 @@ public GemvParameters(INDArray a, INDArray x, INDArray y) {
6969
incx = x.stride(1);
7070
}
7171

72-
this.incy = y.elementWiseStride();
72+
this.incy = y.stride(1);
7373

7474
}
7575

@@ -82,7 +82,7 @@ private INDArray copyIfNecessary(INDArray arr) {
8282
return arr.dup();
8383
else if (arr.ordering() == 'f' && (arr.stride(0) != 1 || arr.stride(1) != arr.size(0)))
8484
return arr.dup();
85-
else if (arr.elementWiseStride() < 1)
85+
else if (arr.isView())
8686
return arr.dup();
8787
return arr;
8888
}

nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -384,49 +384,6 @@ public interface NDArrayFactory {
384384
*/
385385
void shuffle(List<INDArray> array, Random rnd, List<long[]> dimensions);
386386

387-
/**
388-
* This method averages input arrays, and returns averaged array
389-
*
390-
* @param arrays
391-
* @return
392-
*/
393-
INDArray average(INDArray target, INDArray[] arrays);
394-
395-
/**
396-
* This method averages input arrays, and returns averaged array
397-
*
398-
* @param arrays
399-
* @return
400-
*/
401-
INDArray average(INDArray[] arrays);
402-
403-
/**
404-
* This method averages input arrays, and returns averaged array
405-
*
406-
* @param arrays
407-
* @return
408-
*/
409-
INDArray average(Collection<INDArray> arrays);
410-
411-
412-
/**
413-
* This method sums given arrays to target
414-
*
415-
* @param target
416-
* @param arrays
417-
* @return
418-
*/
419-
INDArray accumulate(INDArray target, INDArray... arrays);
420-
421-
422-
/**
423-
* This method averages input arrays, and returns averaged array
424-
*
425-
* @param arrays
426-
* @return
427-
*/
428-
INDArray average(INDArray target, Collection<INDArray> arrays);
429-
430387

431388
/**
432389
* Create a random ndarray with the given shape using the given rng

0 commit comments

Comments
 (0)