Skip to content

Commit 7e96b4a

Browse files
bring back MatrixBase, simplify lower_tri and upper_tri
1 parent 1103573 commit 7e96b4a

7 files changed

Lines changed: 173 additions & 114 deletions

File tree

inst/include/Rcpp/Vector.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include <Rcpp/vector/swap.h>
3939
#include <Rcpp/vector/Demangler.h>
4040

41+
#include <Rcpp/vector/MatrixBase.h>
4142
#include <Rcpp/vector/Matrix.h>
4243

4344
#endif

inst/include/Rcpp/sugar/matrix/lower_tri.h

Lines changed: 26 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,32 @@
22
#define Rcpp__sugar__lower_tri_h
33

44
namespace Rcpp{
5-
namespace sugar{
6-
7-
template <int RTYPE, bool LHS_NA, typename LHS_T>
8-
class LowerTri : public VectorBase<
9-
LGLSXP ,
10-
false ,
11-
LowerTri<RTYPE,LHS_NA,LHS_T>
12-
> {
13-
public:
14-
typedef Rcpp::MatrixBase<RTYPE,LHS_NA,LHS_T> LHS_TYPE ;
15-
16-
LowerTri( const LHS_TYPE& lhs, bool diag) :
17-
nr( lhs.nrow() ), nc( lhs.ncol() ),
18-
getter( diag ? (&LowerTri::get_diag_true) : (&LowerTri::get_diag_false) ){}
19-
20-
// inline int operator[]( int index ) const {
21-
// int i = Rcpp::internal::get_line( index, nr ) ;
22-
// int j = Rcpp::internal::get_column( index, nr, i ) ;
23-
// return get(i,j) ;
24-
// }
25-
inline int operator()( int i, int j ) const {
26-
return get(i,j) ;
27-
}
28-
29-
inline int size() const { return nr * nc ; }
30-
inline int nrow() const { return nr; }
31-
inline int ncol() const { return nc; }
32-
33-
private:
34-
int nr, nc ;
35-
typedef bool (LowerTri::*Method)(int,int) ;
36-
37-
Method getter ;
38-
inline bool get_diag_true( int i, int j ){
39-
return i <= j ;
40-
}
41-
inline bool get_diag_false( int i, int j ){
42-
return i < j ;
43-
}
44-
inline bool get( int i, int j){
45-
return (this->*getter)(i, j ) ;
46-
}
47-
48-
} ;
49-
50-
} // sugar
51-
52-
template <int RTYPE, bool LHS_NA, typename LHS_T>
53-
inline sugar::LowerTri<RTYPE,LHS_NA,LHS_T>
54-
lower_tri( const Rcpp::MatrixBase<RTYPE,LHS_NA,LHS_T>& lhs, bool diag = false){
55-
return sugar::LowerTri<RTYPE,LHS_NA,LHS_T>( lhs, diag ) ;
56-
}
5+
namespace sugar{
6+
7+
class LowerTri : public MatrixBase<LGLSXP,false,LowerTri> {
8+
public:
9+
LowerTri( int nr_, int nc_, bool diag) : nr(nr_), nc(nc_), keep_diag(diag){}
10+
11+
inline int operator()( int i, int j ) const {
12+
return keep_diag ? (i<=j) : (i<j) ;
13+
}
14+
15+
inline int size() const { return nr * nc ; }
16+
inline int nrow() const { return nr; }
17+
inline int ncol() const { return nc; }
18+
19+
private:
20+
int nr, nc ;
21+
bool keep_diag ;
22+
23+
} ;
24+
25+
} // sugar
26+
27+
template <int RTYPE, bool LHS_NA, typename LHS_T>
28+
inline sugar::LowerTri lower_tri( const Rcpp::MatrixBase<RTYPE,LHS_NA,LHS_T>& lhs, bool diag = false){
29+
return sugar::LowerTri( lhs.nrow(), lhs.ncol(), diag ) ;
30+
}
5731

5832
} // Rcpp
5933

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
#ifndef RCPP_SUGAR_MATRIX_FUNCTIONS_H
22
#define RCPP_SUGAR_MATRIX_FUNCTIONS_H
33

4-
#include <Rcpp/sugar/matrix/outer.h>
5-
#include <Rcpp/sugar/matrix/row.h>
6-
#include <Rcpp/sugar/matrix/col.h>
7-
#include <Rcpp/sugar/matrix/lower_tri.h>
84
#include <Rcpp/sugar/matrix/upper_tri.h>
9-
#include <Rcpp/sugar/matrix/diag.h>
10-
#include <Rcpp/sugar/matrix/as_vector.h>
5+
#include <Rcpp/sugar/matrix/lower_tri.h>
6+
7+
// #include <Rcpp/sugar/matrix/outer.h>
8+
// #include <Rcpp/sugar/matrix/row.h>
9+
// #include <Rcpp/sugar/matrix/col.h>
10+
// #include <Rcpp/sugar/matrix/diag.h>
11+
// #include <Rcpp/sugar/matrix/as_vector.h>
1112

1213
#endif

inst/include/Rcpp/sugar/matrix/upper_tri.h

Lines changed: 28 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,34 @@
22
#define Rcpp__sugar__upper_tri_h
33

44
namespace Rcpp{
5-
namespace sugar{
6-
7-
template <int RTYPE, bool LHS_NA, typename LHS_T>
8-
class UpperTri : public VectorBase<
9-
LGLSXP ,
10-
false ,
11-
UpperTri<RTYPE,LHS_NA,LHS_T>
12-
> {
13-
public:
14-
typedef Rcpp::MatrixBase<RTYPE,LHS_NA,LHS_T> LHS_TYPE ;
15-
16-
UpperTri( const LHS_TYPE& lhs, bool diag) :
17-
nr( lhs.nrow() ), nc( lhs.ncol() ),
18-
getter( diag ? (&UpperTri::get_diag_true) : (&UpperTri::get_diag_false) ){}
19-
20-
inline int operator()( int i, int j ) const {
21-
return get(i,j) ;
22-
}
23-
24-
inline int size() const { return nr * nc ; }
25-
inline int nrow() const { return nr; }
26-
inline int ncol() const { return nc; }
27-
28-
private:
29-
int nr, nc ;
30-
typedef bool (UpperTri::*Method)(int,int) ;
31-
32-
Method getter ;
33-
inline bool get_diag_true( int i, int j ){
34-
return i >= j ;
35-
}
36-
inline bool get_diag_false( int i, int j ){
37-
return i > j ;
38-
}
39-
inline bool get( int i, int j){
40-
return (this->*getter)(i, j ) ;
41-
}
42-
43-
} ;
44-
45-
} // sugar
46-
47-
template <int RTYPE, bool LHS_NA, typename LHS_T>
48-
inline sugar::UpperTri<RTYPE,LHS_NA,LHS_T>
49-
upper_tri( const Rcpp::MatrixBase<RTYPE,LHS_NA,LHS_T>& lhs, bool diag = false){
50-
return sugar::UpperTri<RTYPE,LHS_NA,LHS_T>( lhs, diag ) ;
51-
}
5+
namespace sugar{
6+
7+
class UpperTri : public MatrixBase<LGLSXP,false,UpperTri> {
8+
public:
9+
10+
UpperTri( int nr_, int nc_, bool diag) : nr(nr_), nc(nc_), keep_diag(diag){}
11+
12+
inline int operator()( int i, int j ) const {
13+
return keep_diag ? (i>=j) : (i>j) ;
14+
}
15+
16+
inline int size() const { return nr * nc ; }
17+
inline int nrow() const { return nr; }
18+
inline int ncol() const { return nc; }
19+
20+
private:
21+
int nr, nc ;
22+
bool keep_diag ;
23+
24+
} ;
25+
26+
} // sugar
27+
28+
template <int RTYPE, bool LHS_NA, typename LHS_T>
29+
inline sugar::UpperTri
30+
upper_tri( const Rcpp::MatrixBase<RTYPE,LHS_NA,LHS_T>& lhs, bool diag = false){
31+
return sugar::UpperTri( lhs.nrow(), lhs.ncol(), diag ) ;
32+
}
5233

5334
} // Rcpp
5435

inst/include/Rcpp/sugar/sugar.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@
77
#include <Rcpp/sugar/operators/operators.h>
88
#include <Rcpp/sugar/functions/functions.h>
99

10-
// #include <Rcpp/sugar/matrix/matrix_functions.h>
10+
#include <Rcpp/sugar/matrix/matrix_functions.h>
1111

1212
#endif

inst/include/Rcpp/vector/Matrix.h

Lines changed: 89 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,77 @@
33

44
namespace Rcpp{
55

6+
template <int RTYPE, typename Mat>
7+
class MatrixColumn : public VectorBase<RTYPE,true,MatrixColumn<RTYPE,Mat>> {
8+
public:
9+
using iterator = typename Mat::iterator ;
10+
using Proxy = typename Mat::Proxy ;
11+
12+
MatrixColumn( Mat& mat, int i) : start(mat.begin() + i * mat.nrow()), n(mat.nrow()){}
13+
14+
inline int size() const { return n ;}
15+
inline Proxy operator[]( int i){ return *(start+i) ; }
16+
17+
private:
18+
typename Mat::iterator start ;
19+
int n ;
20+
};
21+
22+
template <int RTYPE, typename Mat>
23+
class const_MatrixColumn : public VectorBase<RTYPE,true,MatrixColumn<RTYPE,Mat>> {
24+
public:
25+
using const_iterator = typename Mat::const_iterator ;
26+
using const_Proxy = typename Mat::const_Proxy ;
27+
28+
const_MatrixColumn( const Mat& mat, int i) : start(mat.begin() + i * mat.nrow()), n(mat.nrow()){}
29+
30+
inline int size() const { return n ;}
31+
inline const_Proxy operator[]( int i){ return *(start+i) ; }
32+
33+
private:
34+
typename Mat::const_iterator start ;
35+
int n ;
36+
};
37+
38+
39+
template <int RTYPE, typename Mat>
40+
class MatrixRow : public VectorBase<RTYPE,true,MatrixRow<RTYPE,Mat>> {
41+
public:
42+
using iterator = typename Mat::iterator ;
43+
using Proxy = typename Mat::Proxy ;
44+
45+
MatrixRow( Mat& mat, int i) : start(mat.begin() + i), n(mat.ncol()), nr(mat.nrow()){}
46+
47+
inline int size() const { return n ;}
48+
inline Proxy operator[]( int i){ return *(start+i*nr) ; }
49+
50+
private:
51+
typename Mat::iterator start ;
52+
int n ;
53+
int nr ;
54+
};
55+
56+
template <int RTYPE, typename Mat>
57+
class const_MatrixRow : public VectorBase<RTYPE,true,MatrixRow<RTYPE,Mat>> {
58+
public:
59+
using const_iterator = typename Mat::const_iterator ;
60+
using const_Proxy = typename Mat::const_Proxy ;
61+
62+
const_MatrixRow( const Mat& mat, int i) : start(mat.begin() + i * mat.nrow()), n(mat.ncol()), nr(mat.nrow()){}
63+
64+
inline int size() const { return n ;}
65+
inline const_Proxy operator[]( int i){ return *(start+i*nr) ; }
66+
67+
private:
68+
typename Mat::const_iterator start ;
69+
int n ;
70+
int nr ;
71+
};
72+
73+
74+
675
template <int RTYPE, template <class> class StoragePolicy = PreserveStorage>
7-
class Matrix {
76+
class Matrix : public MatrixBase<RTYPE, true, Matrix<RTYPE,StoragePolicy> >{
877
private:
978
Vector<RTYPE,StoragePolicy> vec ;
1079
int* dims ;
@@ -14,7 +83,12 @@ namespace Rcpp{
1483
using const_Proxy = typename Vector<RTYPE,StoragePolicy>::const_Proxy ;
1584
using iterator = typename Vector<RTYPE,StoragePolicy>::iterator ;
1685
using const_iterator = typename Vector<RTYPE,StoragePolicy>::const_iterator ;
17-
86+
87+
using Column = MatrixColumn<RTYPE, Matrix> ;
88+
using const_Column = const_MatrixColumn<RTYPE, Matrix> ;
89+
using Row = MatrixRow<RTYPE, Matrix> ;
90+
using const_Row = const_MatrixRow<RTYPE, Matrix> ;
91+
1892
Matrix(int nr, int nc) : vec(nr, nc){
1993
set_dims(nr, nc) ;
2094
}
@@ -29,12 +103,9 @@ namespace Rcpp{
29103
dims = INTEGER(d) ;
30104
}
31105

32-
inline int nrow() const {
33-
return dims[0] ;
34-
}
35-
inline int ncol() const {
36-
return dims[1] ;
37-
}
106+
inline int nrow() const { return dims[0] ; }
107+
inline int ncol() const { return dims[1] ; }
108+
inline int size() const { return vec.size() ; }
38109

39110
inline iterator begin(){ return vec.begin() ; }
40111
inline iterator end(){ return vec.end(); }
@@ -45,6 +116,16 @@ namespace Rcpp{
45116
inline Proxy operator()(int i, int j) { return vec[offset(i,j)] ; }
46117
inline const_Proxy operator()(int i, int j) const { return vec[offset(i,j)] ; }
47118

119+
inline Column col(int i){ return Column(*this, i) ; }
120+
inline const_Column col(int i) const { return const_Column(*this, i) ; }
121+
inline Column operator()(internal::NamedPlaceHolder, int i){ return col(i); }
122+
inline const_Column operator()(internal::NamedPlaceHolder, int i) const { return col(i); }
123+
124+
inline Row row(int i){ return Row(*this, i) ; }
125+
inline const_Row row(int i) const { return const_Row(*this, i) ; }
126+
inline Column operator()(int i, internal::NamedPlaceHolder){ return row(i); }
127+
inline const_Column operator()(int i, internal::NamedPlaceHolder) const { return row(i); }
128+
48129
private:
49130

50131
inline void set_dims(int nr, int nc){
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef Rcpp__vector__MatrixBase_h
2+
#define Rcpp__vector__MatrixBase_h
3+
4+
namespace Rcpp{
5+
6+
template <int RTYPE, bool NA, typename Matrix>
7+
class MatrixBase {
8+
public:
9+
using value_type = typename traits::storage_type<RTYPE>::type ;
10+
11+
Matrix& get_ref() { return static_cast<Matrix&>(*this) ; }
12+
13+
inline value_type operator()(int i, int j){ return get_ref()(i,j) ; }
14+
inline int nrow() const { return get_ref().nrow() ; }
15+
inline int ncol() const { return get_ref().ncol() ; }
16+
inline int size() const { return get_ref().size() ; }
17+
} ;
18+
19+
}
20+
21+
#endif

0 commit comments

Comments
 (0)