Skip to content

Commit eac5310

Browse files
committed
Fix QR based fitting
Avoid erroring out for low rank design matrices when dropcollinear=false. Avoid unnecessary triangular solves. Avoid indexing in the Q. Avoid slicing R matrix in a way that triggers a minimum norm solution. Remove unnecessary temporaries.
1 parent fce9b70 commit eac5310

File tree

3 files changed

+50
-50
lines changed

3 files changed

+50
-50
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ julia = "1.6"
3636
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
3737
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
3838
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
39+
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
3940
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
4041
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4142
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4243

4344
[targets]
44-
test = ["CategoricalArrays", "CSV", "DataFrames", "RDatasets", "StableRNGs", "Test"]
45+
test = ["CategoricalArrays", "CSV", "DataFrames", "Downloads", "RDatasets", "StableRNGs", "Test"]

src/linpred.jl

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -63,23 +63,19 @@ function delbeta! end
6363

6464
function delbeta!(p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}) where T<:BlasReal
6565
rnk = rank(p.qr.R)
66-
rnk == length(p.delbeta) || throw(RankDeficientException(rnk))
67-
p.delbeta = p.qr\r
68-
mul!(p.scratchm1, Diagonal(ones(size(r))), p.X)
66+
p.delbeta = p.qr \ r
6967
return p
7068
end
7169

7270
function delbeta!(p::DensePredQR{T,<:QRCompactWY}, r::Vector{T}, wt::Vector{T}) where T<:BlasReal
7371
rnk = rank(p.qr.R)
74-
rnk == length(p.delbeta) || throw(RankDeficientException(rnk))
7572
X = p.X
7673
W = Diagonal(wt)
7774
sqrtW = Diagonal(sqrt.(wt))
7875
mul!(p.scratchm1, sqrtW, X)
79-
mul!(p.delbeta, X'W, r)
80-
qnr = qr(p.scratchm1)
81-
Rinv = inv(qnr.R)
82-
p.delbeta = Rinv * Rinv' * p.delbeta
76+
= sqrtW * r
77+
p.qr = qr!(p.scratchm1)
78+
p.delbeta = p.qr \
8379
return p
8480
end
8581

@@ -88,44 +84,32 @@ function delbeta!(p::DensePredQR{T,<:QRPivoted}, r::Vector{T}) where T<:BlasReal
8884
if rnk == length(p.delbeta)
8985
p.delbeta = p.qr\r
9086
else
91-
R = @view p.qr.R[:, 1:rnk]
92-
Q = @view p.qr.Q[:, 1:size(R, 1)]
87+
R = UpperTriangular(view(parent(p.qr.R), 1:rnk, 1:rnk))
9388
piv = p.qr.p
94-
p.delbeta = zeros(size(p.delbeta))
95-
p.delbeta[1:rnk] = R \ Q'r
89+
fill!(p.delbeta, 0)
90+
p.delbeta[1:rnk] = R \ view(p.qr.Q'r, 1:rnk)
9691
invpermute!(p.delbeta, piv)
9792
end
98-
mul!(p.scratchm1, Diagonal(ones(size(r))), p.X)
9993
return p
10094
end
10195

10296
function delbeta!(p::DensePredQR{T,<:QRPivoted}, r::Vector{T}, wt::Vector{T}) where T<:BlasReal
103-
rnk = rank(p.qr.R)
10497
X = p.X
10598
W = Diagonal(wt)
10699
sqrtW = Diagonal(sqrt.(wt))
107-
delbeta = p.delbeta
108-
scratchm2 = similar(X, T)
109100
mul!(p.scratchm1, sqrtW, X)
110-
mul!(scratchm2, W, X)
111-
mul!(delbeta, transpose(scratchm2), r)
112-
113-
if rnk == length(p.delbeta)
114-
qnr = qr(p.scratchm1)
115-
Rinv = inv(qnr.R)
116-
p.delbeta = Rinv * Rinv' * delbeta
117-
else
118-
qnr = pivoted_qr!(copy(p.scratchm1))
119-
R = @view qnr.R[1:rnk, 1:rnk]
120-
Rinv = inv(R)
121-
piv = qnr.p
122-
permute!(delbeta, piv)
123-
for k=(rnk+1):length(delbeta)
124-
delbeta[k] = -zero(T)
125-
end
126-
p.delbeta[1:rnk] = Rinv * Rinv' * view(delbeta, 1:rnk)
127-
invpermute!(delbeta, piv)
101+
= sqrtW * r
102+
103+
p.qr = pivoted_qr!(copy(p.scratchm1))
104+
rnk = rank(p.qr.R) # FIXME! Don't use svd for this
105+
R = UpperTriangular(view(parent(p.qr.R), 1:rnk, 1:rnk))
106+
permute!(p.delbeta, p.qr.p)
107+
for k = (rnk + 1):length(p.delbeta)
108+
p.delbeta[k] = -zero(T)
128109
end
110+
p.delbeta[1:rnk] = R \ (p.qr.Q'*r̃)[1:rnk]
111+
invpermute!(p.delbeta, p.qr.p)
112+
129113
return p
130114
end
131115

@@ -279,27 +263,25 @@ end
279263
LinearAlgebra.cholesky(p::SparsePredChol{T}) where {T} = copy(p.chol)
280264
LinearAlgebra.cholesky!(p::SparsePredChol{T}) where {T} = p.chol
281265

282-
function invqr(x::DensePredQR{T,<: QRCompactWY}) where T
283-
Q,R = qr(x.scratchm1)
284-
Rinv = inv(R)
266+
function invqr(p::DensePredQR{T,<: QRCompactWY}) where T
267+
Rinv = inv(p.qr.R)
285268
Rinv*Rinv'
286269
end
287270

288-
function invqr(x::DensePredQR{T,<: QRPivoted}) where T
289-
Q,R,pv = pivoted_qr!(copy(x.scratchm1))
290-
rnk = rank(R)
291-
p = length(x.delbeta)
292-
if rnk == p
293-
Rinv = inv(R)
271+
function invqr(p::DensePredQR{T,<: QRPivoted}) where T
272+
rnk = rank(p.qr.R)
273+
k = length(p.delbeta)
274+
if rnk == k
275+
Rinv = inv(p.qr.R)
294276
xinv = Rinv*Rinv'
295-
ipiv = invperm(pv)
277+
ipiv = invperm(p.qr.p)
296278
return xinv[ipiv, ipiv]
297279
else
298-
Rsub = R[1:rnk, 1:rnk]
280+
Rsub = UpperTriangular(view(p.qr.R, 1:rnk, 1:rnk))
299281
RsubInv = inv(Rsub)
300-
xinv = fill(convert(T, NaN), (p,p))
282+
xinv = fill(convert(T, NaN), (k, k))
301283
xinv[1:rnk, 1:rnk] = RsubInv*RsubInv'
302-
ipiv = invperm(pv)
284+
ipiv = invperm(p.qr.p)
303285
return xinv[ipiv, ipiv]
304286
end
305287
end

test/runtests.jl

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ end
177177
@test isa(m2p_dep_pos.pp.chol, CholeskyPivoted)
178178
@test isa(m2p_dep_pos_kw.pp.chol, CholeskyPivoted)
179179
elseif dmethod == :qr
180-
@test_throws RankDeficientException m2 = fit(LinearModel, Xmissingcell, ymissingcell;
181-
method = dmethod, dropcollinear=false)
180+
@test fit(LinearModel, Xmissingcell, ymissingcell;
181+
method = dmethod, dropcollinear=false) isa LinearModel
182182
@test isapprox(coef(m2p), [0.9772643585228962, 11.889730016918342, 3.027347397503282,
183183
3.9661379199401177, 5.079410103608539, 6.194461814118862,
184184
-2.9863884084219015, 7.930328728005132, 8.87999491860477,
@@ -2015,3 +2015,20 @@ end
20152015
# values. It doesn't care about links, offsets, etc. as long as the model matrix,
20162016
# vcov matrix and stderrors are well defined.
20172017
end
2018+
2019+
@testset "NIST - Filip. Issue 558" begin
2020+
fn = Downloads.download("https://www.itl.nist.gov/div898/strd/lls/data/LINKS/DATA/Filip.dat")
2021+
filip_estimates_df = CSV.read(fn, DataFrame; skipto = 31, limit = 11, header = ["parameter", "estimate", "se"], delim = " ", ignorerepeated = true)
2022+
filip_data_df = CSV.read(fn, DataFrame; skipto = 61, header = ["y", "x"], delim = " ", ignorerepeated = true)
2023+
X = [filip_data_df.x[i]^j for i in 1:length(filip_data_df.x), j in 0:10]
2024+
2025+
# No weights
2026+
f1 = lm(X, filip_data_df.y, dropcollinear = false, method = :qr)
2027+
@test coef(f1) filip_estimates_df.estimate rtol = 1e-7
2028+
@test stderror(f1) filip_estimates_df.se rtol = 1e-7
2029+
2030+
# Weights
2031+
f2 = lm(X, filip_data_df.y, dropcollinear = false, method = :qr, wts = ones(length(filip_data_df.y)))
2032+
@test coef(f2) filip_estimates_df.estimate rtol = 1e-7
2033+
@test stderror(f2) filip_estimates_df.se rtol = 1e-7
2034+
end

0 commit comments

Comments
 (0)