forked from juju/juju
-
Notifications
You must be signed in to change notification settings - Fork 0
/
open.go
270 lines (237 loc) · 7.78 KB
/
open.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
// Copyright 2014 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.
package mongo
import (
"crypto/tls"
"crypto/x509"
stderrors "errors"
"fmt"
"net"
"strings"
"time"
"github.com/juju/errors"
"github.com/juju/http/v2"
"github.com/juju/mgo/v2"
"github.com/juju/names/v4"
"github.com/juju/utils/v2/cert"
)
const (
	// SocketTimeout bounds how long a request may wait on a mongo socket.
	// It has to be generous enough that even a slow server answers in time,
	// and that heavyweight queries are able to run to completion.
	//
	// Note: one minute is also mgo's own default socket timeout.
	//
	// Also note: mongodb has been observed in the field to stall ("get
	// stuck") for more than 30s at a time.
	SocketTimeout = time.Minute

	// defaultDialTimeout approximates the worst-case time needed to dial a
	// mongo server that lives on the same cloud or private network.
	defaultDialTimeout = 30 * time.Second
)
// DialOpts configures how a connection to a controller's mongo
// database is dialed.
type DialOpts struct {
	// Timeout limits how long to spend contacting a controller.
	Timeout time.Duration

	// SocketTimeout limits how long a socket to the database may stay
	// unresponsive before it is forcibly closed. A zero value means the
	// package-level SocketTimeout constant is used instead.
	SocketTimeout time.Duration

	// Direct, when true, restricts connections to the listed seed
	// servers. When false, the whole cluster is discovered and further
	// servers may be connected to as well.
	Direct bool

	// PostDial is an optional hook. DialWithInfo invokes it with the
	// freshly dialed mgo.Session, before returning that session to the
	// caller.
	PostDial func(*mgo.Session) error

	// PostDialServer is an optional hook. DialWithInfo invokes it after
	// every attempt to dial an individual MongoDB server, successful or
	// not. It receives the dialed address, the time the attempt took, and
	// the resulting error (nil on success).
	PostDialServer func(addr string, elapsed time.Duration, err error)

	// PoolLimit caps the socket pool size for each server.
	PoolLimit int
}
// DefaultDialOpts returns the DialOpts normally used to contact a
// controller. It sets Timeout to defaultDialTimeout and SocketTimeout to
// the package-level SocketTimeout. All other fields keep their zero
// values.
//
// NOTE(axw) these options are inappropriate for tests in CI,
// as CI tends to run on machines with slow I/O (or thrashed
// I/O with limited IOPs). For tests, use mongotest.DialOpts().
func DefaultDialOpts() DialOpts {
	var opts DialOpts
	opts.Timeout = defaultDialTimeout
	opts.SocketTimeout = SocketTimeout
	return opts
}
// Info describes a cluster of mongo servers: where the servers are and
// how their TLS identity is verified. It holds enough to open a
// connection to that cluster.
type Info struct {
	// Addrs lists the MongoDB servers backing the state, each written
	// as address:port.
	Addrs []string

	// CACert is the PEM-encoded CA certificate used to verify the
	// controller's certificate.
	CACert string

	// DisableTLS turns off TLS for connections to the MongoDB servers.
	// TLS is used when this is false, which is the default.
	DisableTLS bool
}
// MongoInfo describes a cluster of servers that hold juju state, together
// with the credentials used to connect to it.
type MongoInfo struct {
	// Info carries the addresses and CA certificate of the mongo cluster.
	Info

	// Tag names the entity making the connection. Leave it nil to
	// connect as the administrator.
	Tag names.Tag

	// Password is the secret for the connecting entity.
	Password string
}
// DialInfo returns information on how to dial the state's mongo server
// with the given info and dial options.
//
// Unless info.DisableTLS is set, connections are secured with TLS. The
// server certificate is verified against info.CACert using the server name
// "juju-mongodb". An error is returned if info.Addrs is empty, or if TLS
// is enabled and the CA certificate is missing or cannot be parsed.
//
// The returned DialServer function uses opts.Timeout for the TCP connect.
// When opts.Timeout is positive, it also bounds the TLS handshake with
// that timeout. If opts.PostDialServer is set, it is invoked after every
// dial attempt with the server address, the elapsed time and the
// resulting error (nil on success).
func DialInfo(info Info, opts DialOpts) (*mgo.DialInfo, error) {
	if len(info.Addrs) == 0 {
		return nil, stderrors.New("no mongo addresses")
	}
	var tlsConfig *tls.Config
	if !info.DisableTLS {
		if len(info.CACert) == 0 {
			return nil, stderrors.New("missing CA certificate")
		}
		xcert, err := cert.ParseCert(info.CACert)
		if err != nil {
			return nil, fmt.Errorf("cannot parse CA certificate: %w", err)
		}
		pool := x509.NewCertPool()
		pool.AddCert(xcert)
		tlsConfig = http.SecureTLSConfig()
		tlsConfig.RootCAs = pool
		tlsConfig.ServerName = "juju-mongodb"

		// TODO(natefinch): revisit this when are full-time on mongo 3.
		// We have to add non-ECDHE suites because mongo doesn't support ECDHE.
		moreSuites := []uint16{
			tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
			tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
		}
		tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, moreSuites...)
	}

	// err is a named result so the deferred PostDialServer hook observes
	// the error actually returned from this attempt.
	dial := func(server *mgo.ServerAddr) (_ net.Conn, err error) {
		if opts.PostDialServer != nil {
			before := time.Now()
			defer func() {
				opts.PostDialServer(server.String(), time.Since(before), err)
			}()
		}
		addr := server.TCPAddr().String()
		conn, err := net.DialTimeout("tcp", addr, opts.Timeout)
		if err != nil {
			logger.Debugf("mongodb connection failed, will retry: %v", err)
			return nil, err
		}
		if tlsConfig != nil {
			tlsConn := tls.Client(conn, tlsConfig)
			// net.DialTimeout only bounds the TCP connect. Apply the same
			// timeout to the TLS handshake, otherwise a server that accepts
			// the connection but never completes the handshake would block
			// this dial indefinitely. A zero Timeout means "no timeout",
			// matching net.DialTimeout's behaviour.
			if opts.Timeout > 0 {
				err = tlsConn.SetDeadline(time.Now().Add(opts.Timeout))
			}
			if err == nil {
				err = tlsConn.Handshake()
			}
			if err == nil && opts.Timeout > 0 {
				// Clear the handshake deadline so that later I/O on the
				// connection is not cut off by it.
				err = tlsConn.SetDeadline(time.Time{})
			}
			if err != nil {
				logger.Warningf("TLS handshake failed: %v", err)
				if closeErr := conn.Close(); closeErr != nil {
					logger.Warningf("failed to close connection: %v", closeErr)
				}
				return nil, err
			}
			conn = tlsConn
		}
		logger.Debugf("dialed mongodb server at %q", addr)
		return conn, nil
	}

	return &mgo.DialInfo{
		Addrs:      info.Addrs,
		Timeout:    opts.Timeout,
		DialServer: dial,
		Direct:     opts.Direct,
		PoolLimit:  opts.PoolLimit,
	}, nil
}
// DialWithInfo opens a new session to the cluster described by info,
// using the given options. opts.Timeout must be non-zero. A zero
// opts.SocketTimeout falls back to the package-level SocketTimeout.
//
// After dialing, the session's socket timeout is set and opts.PostDial
// (if any) is run. When info carries a Tag or a Password, the session
// logs in to the admin database: as info.Tag when set, otherwise as
// AdminUser. If the PostDial hook or the login fails, the session is
// closed before the error is returned.
func DialWithInfo(info MongoInfo, opts DialOpts) (*mgo.Session, error) {
	if opts.Timeout == 0 {
		return nil, errors.New("a non-zero Timeout must be specified")
	}
	// DialInfo ignores SocketTimeout, so the default can be resolved here.
	socketTimeout := opts.SocketTimeout
	if socketTimeout == 0 {
		socketTimeout = SocketTimeout
	}

	dialInfo, err := DialInfo(info.Info, opts)
	if err != nil {
		return nil, err
	}
	session, err := mgo.DialWithInfo(dialInfo)
	if err != nil {
		return nil, err
	}
	session.SetSocketTimeout(socketTimeout)

	if hook := opts.PostDial; hook != nil {
		if hookErr := hook(session); hookErr != nil {
			session.Close()
			return nil, errors.Annotate(hookErr, "PostDial failed")
		}
	}

	needsLogin := info.Tag != nil || info.Password != ""
	if !needsLogin {
		return session, nil
	}
	user := AdminUser
	if info.Tag != nil {
		user = info.Tag.String()
	}
	if loginErr := Login(session, user, info.Password); loginErr != nil {
		session.Close()
		return nil, errors.Trace(loginErr)
	}
	return session, nil
}
// Login authenticates the session as user against the "admin" database.
// A failed login is passed through MaybeUnauthorizedf, so authorization
// failures are reported as unauthorized errors, annotated with the user
// name.
func Login(session *mgo.Session, user, password string) error {
	err := session.DB("admin").Login(user, password)
	if err == nil {
		return nil
	}
	return MaybeUnauthorizedf(err, "cannot log in to admin database as %q", user)
}
// MaybeUnauthorizedf annotates err with the formatted message. If the
// cause of err is recognised as a Mongo authorization failure, err is
// first replaced by an errors.Unauthorizedf error that embeds the original
// error text. A nil err yields nil, as errors.Annotatef does.
func MaybeUnauthorizedf(err error, message string, args ...interface{}) error {
	result := err
	if isUnauthorized(errors.Cause(err)) {
		result = errors.Unauthorizedf("unauthorized mongo access: %s", err)
	}
	return errors.Annotatef(result, message, args...)
}
// isUnauthorized reports whether err looks like a MongoDB authentication
// or authorization failure. A nil error is never unauthorized.
//
// Some unauthorized access errors carry no error code, only a plain
// message. Those that do carry a code arrive as inconsistent types
// (LastError/QueryError). So the message prefix is checked first, and only
// *mgo.QueryError values are then inspected for a known code or an exact
// message.
func isUnauthorized(err error) bool {
	if err == nil {
		return false
	}
	text := err.Error()
	knownPrefixes := [...]string{
		"auth fail",
		"not authorized",
		"server returned error on SASL authentication step: Authentication failed.",
	}
	for i := range knownPrefixes {
		if strings.HasPrefix(text, knownPrefixes[i]) {
			return true
		}
	}

	queryErr, isQueryErr := err.(*mgo.QueryError)
	if !isQueryErr {
		return false
	}
	switch queryErr.Code {
	case 10057, 13:
		return true
	}
	switch queryErr.Message {
	case "need to login", "unauthorized":
		return true
	}
	return false
}