Skip to content

Commit f060536

Browse files
authored
fix horrifying predict output transpose issue (#144)
1 parent e6de5af commit f060536

File tree

6 files changed

+46
-3
lines changed

6 files changed

+46
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "XGBoost"
22
uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
3-
version = "2.1.0"
3+
version = "2.1.1"
44

55
[deps]
66
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

assets/data/blobs.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

src/booster.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ function deserialize(::Type{Booster}, buf::AbstractVector{UInt8}, data=DMatrix[]
255255
deserialize!(b, buf)
256256
end
257257

258+
# sadly this is type unstable because we might return a transpose
258259
"""
259260
predict(b::Booster, data; margin=false, training=false, ntree_limit=0)
260261
@@ -287,8 +288,9 @@ function predict(b::Booster, Xy::DMatrix;
287288
odim = Ref{Lib.bst_ulong}()
288289
o = Ref{Ptr{Cfloat}}()
289290
xgbcall(XGBoosterPredictFromDMatrix, b.handle, Xy.handle, opts, oshape, odim, o)
290-
dims = unsafe_wrap(Array, oshape[], odim[])
291-
unsafe_wrap(Array, o[], tuple(dims...))
291+
dims = reverse(unsafe_wrap(Array, oshape[], odim[]))
292+
o = unsafe_wrap(Array, o[], tuple(dims...))
293+
length(dims) > 1 ? transpose(o) : o
292294
end
293295
predict(b::Booster, Xy; kw...) = predict(b, DMatrix(Xy); kw...)
294296

test/gendata.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# this file is for generating test data that's stored in the repository
2+
# it uses MLJ which is not a dependency
3+
4+
# JSON is used as the serialization format only because it is already a dependency
5+
6+
using MLJ
7+
using JSON3
8+
9+
10+
# see utils.jl for loading of the output
11+
function gen_classification()
12+
(X, y) = make_blobs(1000, 3; centers=3, rng=999)
13+
y = Int.(int.(y)) .- 1
14+
X = MLJ.matrix(X)
15+
fname = joinpath(@__DIR__,"..","assets","data","blobs.json")
16+
dict = Dict("X1"=>X[:,1], "X2"=>X[:,2], "X3"=>X[:,3], "y"=>y)
17+
JSON3.write(open(fname, write=true), dict)
18+
end
19+

test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,17 @@ end
126126
end
127127
end
128128

129+
@testset "Blobs training" begin
130+
(X, y) = load_classification()
131+
132+
bst = xgboost((X, y), num_round=10, objective="multi:softprob", num_class=3, watchlist=Dict())
133+
134+
= map-> argmax(ζ) - 1, eachrow(predict(bst, X)))
135+
136+
# this is a pretty low bar that xgboost should always pass
137+
@test sum(ŷ .== y)/length(y) > 0.9
138+
end
139+
129140
@testset "Feature importance" begin
130141
dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"))
131142
dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"))

test/utils.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using JSON3
2+
13

24
function testfilepath(name::AbstractString)
35
dir = joinpath(dirname(pathof(XGBoost)), "..")
@@ -22,3 +24,11 @@ function readlibsvm(fname::AbstractString, shape)
2224
end
2325
(dmx, label)
2426
end
27+
28+
function load_classification()
29+
fname = joinpath(@__DIR__,"..","assets","data","blobs.json")
30+
o = JSON3.read(String(open(read, fname)))
31+
X = Matrix{Float32}([o[:X1] o[:X2] o[:X3]])
32+
y = o[:y]
33+
(X, y)
34+
end

0 commit comments

Comments
 (0)