Skip to content

Commit deaacff

Browse files
authored
fix: early oidc refresh with fake idp tests (#22712) (cherry 2.31) (#22716)
Confirmed manually using this branch with 5min tokens (always refreshed) and 15min tokens (refreshed after 5min elapsed)
1 parent 2828d28 commit deaacff

File tree

3 files changed

+238
-19
lines changed

3 files changed

+238
-19
lines changed

coderd/provisionerdserver/provisionerdserver.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
564564
// The check `s.OIDCConfig != nil` is not as strict, since it can be an interface
565565
// pointing to a typed nil.
566566
if !reflect.ValueOf(s.OIDCConfig).IsNil() {
567-
workspaceOwnerOIDCAccessToken, err = obtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
567+
workspaceOwnerOIDCAccessToken, err = ObtainOIDCAccessToken(ctx, s.Logger, s.Database, s.OIDCConfig, owner.ID)
568568
if err != nil {
569569
return nil, failJob(fmt.Sprintf("obtain OIDC access token: %s", err))
570570
}
@@ -3075,15 +3075,15 @@ func deleteSessionTokenForUserAndWorkspace(ctx context.Context, db database.Stor
30753075
return nil
30763076
}
30773077

3078-
func shouldRefreshOIDCToken(link database.UserLink) bool {
3078+
func shouldRefreshOIDCToken(link database.UserLink) (bool, time.Time) {
30793079
if link.OAuthRefreshToken == "" {
30803080
// We cannot refresh even if we wanted to
3081-
return false
3081+
return false, link.OAuthExpiry
30823082
}
30833083

30843084
if link.OAuthExpiry.IsZero() {
30853085
// 0 expire means the token never expires, so we shouldn't refresh
3086-
return false
3086+
return false, link.OAuthExpiry
30873087
}
30883088

30893089
// This handles an edge case where the token is about to expire. A workspace
@@ -3093,15 +3093,19 @@ func shouldRefreshOIDCToken(link database.UserLink) bool {
30933093
//
30943094
// If an OIDC provider issues short-lived tokens less than our defined period,
30953095
// the token will always be refreshed on every workspace build.
3096-
assumeExpiredAt := dbtime.Now().Add(-1 * time.Minute * 10)
3096+
//
3097+
// By setting the expiration backwards, we are effectively shortening the
3098+
// time a token can be alive for by 10 minutes.
3099+
// Note: This is how it is done in the oauth2 package's own token refreshing logic.
3100+
expiresAt := link.OAuthExpiry.Add(-time.Minute * 10)
30973101

30983102
// Return if the token is assumed to be expired.
3099-
return link.OAuthExpiry.Before(assumeExpiredAt)
3103+
return expiresAt.Before(dbtime.Now()), expiresAt
31003104
}
31013105

3102-
// obtainOIDCAccessToken returns a valid OpenID Connect access token
3106+
// ObtainOIDCAccessToken returns a valid OpenID Connect access token
31033107
// for the user if it's able to obtain one, otherwise it returns an empty string.
3104-
func obtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
3108+
func ObtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.Store, oidcConfig promoauth.OAuth2Config, userID uuid.UUID) (string, error) {
31053109
link, err := db.GetUserLinkByUserIDLoginType(ctx, database.GetUserLinkByUserIDLoginTypeParams{
31063110
UserID: userID,
31073111
LoginType: database.LoginTypeOIDC,
@@ -3113,11 +3117,13 @@ func obtainOIDCAccessToken(ctx context.Context, logger slog.Logger, db database.
31133117
return "", xerrors.Errorf("get owner oidc link: %w", err)
31143118
}
31153119

3116-
if shouldRefreshOIDCToken(link) {
3120+
if shouldRefresh, expiresAt := shouldRefreshOIDCToken(link); shouldRefresh {
31173121
token, err := oidcConfig.TokenSource(ctx, &oauth2.Token{
31183122
AccessToken: link.OAuthAccessToken,
31193123
RefreshToken: link.OAuthRefreshToken,
3120-
Expiry: link.OAuthExpiry,
3124+
// Use the expiresAt returned by shouldRefreshOIDCToken.
3125+
// It will force a refresh with an expired time.
3126+
Expiry: expiresAt,
31213127
}).Token()
31223128
if err != nil {
31233129
// If OIDC fails to refresh, we return an empty string and don't fail.

coderd/provisionerdserver/provisionerdserver_internal_test.go

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,51 @@ func TestShouldRefreshOIDCToken(t *testing.T) {
3636
want: false,
3737
},
3838
{
39-
name: "ExpiredBeyondAssumedWindow",
39+
name: "LongExpired",
4040
link: database.UserLink{
4141
OAuthRefreshToken: "refresh",
42-
OAuthExpiry: now.Add(-20 * time.Minute),
42+
OAuthExpiry: now.Add(-1 * time.Hour),
4343
},
4444
want: true,
4545
},
4646
{
47-
name: "ExpiredWithinAssumedWindow",
47+
// Edge being "+/- 10 minutes"
48+
name: "EdgeExpired",
4849
link: database.UserLink{
4950
OAuthRefreshToken: "refresh",
50-
OAuthExpiry: now.Add(-5 * time.Minute),
51+
OAuthExpiry: now.Add(-1 * time.Minute * 10),
52+
},
53+
want: true,
54+
},
55+
{
56+
name: "Expired",
57+
link: database.UserLink{
58+
OAuthRefreshToken: "refresh",
59+
OAuthExpiry: now.Add(-1 * time.Minute),
60+
},
61+
want: true,
62+
},
63+
{
64+
name: "SoonToBeExpired",
65+
link: database.UserLink{
66+
OAuthRefreshToken: "refresh",
67+
OAuthExpiry: now.Add(5 * time.Minute),
68+
},
69+
want: true,
70+
},
71+
{
72+
name: "SoonToBeExpiredEdge",
73+
link: database.UserLink{
74+
OAuthRefreshToken: "refresh",
75+
OAuthExpiry: now.Add(9 * time.Minute),
76+
},
77+
want: true,
78+
},
79+
{
80+
name: "AfterEdge",
81+
link: database.UserLink{
82+
OAuthRefreshToken: "refresh",
83+
OAuthExpiry: now.Add(11 * time.Minute),
5184
},
5285
want: false,
5386
},
@@ -59,13 +92,22 @@ func TestShouldRefreshOIDCToken(t *testing.T) {
5992
},
6093
want: false,
6194
},
95+
{
96+
name: "NotEvenCloseExpired",
97+
link: database.UserLink{
98+
OAuthRefreshToken: "refresh",
99+
OAuthExpiry: now.Add(time.Hour * 24),
100+
},
101+
want: false,
102+
},
62103
}
63104

64105
for _, tc := range testCases {
65106
tc := tc
66107
t.Run(tc.name, func(t *testing.T) {
67108
t.Parallel()
68-
require.Equal(t, tc.want, shouldRefreshOIDCToken(tc.link))
109+
shouldRefresh, _ := shouldRefreshOIDCToken(tc.link)
110+
require.Equal(t, tc.want, shouldRefresh)
69111
})
70112
}
71113
}
@@ -76,7 +118,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
76118
t.Run("NoToken", func(t *testing.T) {
77119
t.Parallel()
78120
db, _ := dbtestutil.NewDB(t)
79-
_, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, nil, uuid.Nil)
121+
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, nil, uuid.Nil)
80122
require.NoError(t, err)
81123
})
82124
t.Run("InvalidConfig", func(t *testing.T) {
@@ -89,7 +131,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
89131
LoginType: database.LoginTypeOIDC,
90132
OAuthExpiry: dbtime.Now().Add(-time.Hour),
91133
})
92-
_, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
134+
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
93135
require.NoError(t, err)
94136
})
95137
t.Run("MissingLink", func(t *testing.T) {
@@ -98,7 +140,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
98140
user := dbgen.User(t, db, database.User{
99141
LoginType: database.LoginTypeOIDC,
100142
})
101-
tok, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
143+
tok, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &oauth2.Config{}, user.ID)
102144
require.Empty(t, tok)
103145
require.NoError(t, err)
104146
})
@@ -111,7 +153,7 @@ func TestObtainOIDCAccessToken(t *testing.T) {
111153
LoginType: database.LoginTypeOIDC,
112154
OAuthExpiry: dbtime.Now().Add(-time.Hour),
113155
})
114-
_, err := obtainOIDCAccessToken(ctx, testutil.Logger(t), db, &testutil.OAuth2Config{
156+
_, err := ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, &testutil.OAuth2Config{
115157
Token: &oauth2.Token{
116158
AccessToken: "token",
117159
},

coderd/provisionerdserver/provisionerdserver_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"testing"
1616
"time"
1717

18+
"github.com/golang-jwt/jwt/v4"
1819
"github.com/google/uuid"
1920
"github.com/prometheus/client_golang/prometheus"
2021
"github.com/stretchr/testify/assert"
@@ -30,6 +31,7 @@ import (
3031
"github.com/coder/coder/v2/coderd"
3132
"github.com/coder/coder/v2/coderd/audit"
3233
"github.com/coder/coder/v2/coderd/coderdtest"
34+
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
3335
"github.com/coder/coder/v2/coderd/database"
3436
"github.com/coder/coder/v2/coderd/database/dbauthz"
3537
"github.com/coder/coder/v2/coderd/database/dbgen"
@@ -58,6 +60,175 @@ import (
5860
"github.com/coder/serpent"
5961
)
6062

63+
// TestTokenIsRefreshedEarly creates a fake OIDC IDP that sets expiration times
64+
// of the token to values that are "near expiration". Expiration being 10minutes
65+
// earlier than it needs to be. The `ObtainOIDCAccessToken` should refresh these
66+
// tokens early.
67+
func TestTokenIsRefreshedEarly(t *testing.T) {
68+
t.Parallel()
69+
70+
t.Run("WithCoderd", func(t *testing.T) {
71+
t.Parallel()
72+
tokenRefreshCount := 0
73+
fake := oidctest.NewFakeIDP(t,
74+
oidctest.WithServing(),
75+
oidctest.WithDefaultExpire(time.Minute*8),
76+
oidctest.WithRefresh(func(email string) error {
77+
tokenRefreshCount++
78+
return nil
79+
}),
80+
)
81+
cfg := fake.OIDCConfig(t, nil, func(cfg *coderd.OIDCConfig) {
82+
cfg.AllowSignups = true
83+
})
84+
db, ps := dbtestutil.NewDB(t)
85+
owner := coderdtest.New(t, &coderdtest.Options{
86+
OIDCConfig: cfg,
87+
IncludeProvisionerDaemon: true,
88+
Database: db,
89+
Pubsub: ps,
90+
})
91+
first := coderdtest.CreateFirstUser(t, owner)
92+
version := coderdtest.CreateTemplateVersion(t, owner, first.OrganizationID, nil)
93+
coderdtest.AwaitTemplateVersionJobCompleted(t, owner, version.ID)
94+
template := coderdtest.CreateTemplate(t, owner, first.OrganizationID, version.ID)
95+
96+
// Setup an OIDC user.
97+
client, _ := fake.Login(t, owner, jwt.MapClaims{
98+
"email": "[email protected]",
99+
"email_verified": true,
100+
"sub": uuid.NewString(),
101+
})
102+
103+
// Creating a workspace should refresh the oidc early.
104+
tokenRefreshCount = 0
105+
wrk := coderdtest.CreateWorkspace(t, client, template.ID)
106+
coderdtest.AwaitWorkspaceBuildJobCompleted(t, client, wrk.LatestBuild.ID)
107+
require.Equal(t, 1, tokenRefreshCount)
108+
})
109+
}
110+
111+
//nolint:tparallel,paralleltest // Sub tests need to run sequentially.
112+
func TestTokenIsRefreshedEarlyWithoutCoderd(t *testing.T) {
113+
t.Parallel()
114+
tokenRefreshCount := 0
115+
fake := oidctest.NewFakeIDP(t,
116+
oidctest.WithServing(),
117+
oidctest.WithDefaultExpire(time.Minute*8),
118+
oidctest.WithRefresh(func(email string) error {
119+
tokenRefreshCount++
120+
return nil
121+
}),
122+
)
123+
cfg := fake.OIDCConfig(t, nil)
124+
125+
// Fetch a valid token from the fake OIDC provider
126+
token, err := fake.GenerateAuthenticatedToken(jwt.MapClaims{
127+
"email": "[email protected]",
128+
"email_verified": true,
129+
"sub": uuid.NewString(),
130+
})
131+
require.NoError(t, err)
132+
133+
db, _ := dbtestutil.NewDB(t)
134+
user := dbgen.User(t, db, database.User{})
135+
dbgen.UserLink(t, db, database.UserLink{
136+
UserID: user.ID,
137+
LoginType: database.LoginTypeOIDC,
138+
LinkedID: "foo",
139+
OAuthAccessToken: token.AccessToken,
140+
OAuthRefreshToken: token.RefreshToken,
141+
// The oauth expiry does not really matter, since each test will manually control
142+
// this value.
143+
OAuthExpiry: dbtime.Now().Add(time.Hour),
144+
})
145+
146+
setLinkExpiration := func(t *testing.T, exp time.Time) database.UserLink {
147+
ctx := testutil.Context(t, testutil.WaitShort)
148+
links, err := db.GetUserLinksByUserID(ctx, user.ID)
149+
require.NoError(t, err)
150+
require.Len(t, links, 1)
151+
link := links[0]
152+
153+
newLink, err := db.UpdateUserLink(ctx, database.UpdateUserLinkParams{
154+
OAuthAccessToken: link.OAuthAccessToken,
155+
OAuthAccessTokenKeyID: link.OAuthAccessTokenKeyID,
156+
OAuthRefreshToken: link.OAuthRefreshToken,
157+
OAuthRefreshTokenKeyID: link.OAuthRefreshTokenKeyID,
158+
OAuthExpiry: exp,
159+
Claims: link.Claims,
160+
UserID: link.UserID,
161+
LoginType: link.LoginType,
162+
})
163+
require.NoError(t, err)
164+
return newLink
165+
}
166+
167+
for _, c := range []struct {
168+
name string
169+
// expires is a function to return a more up to date "now".
170+
// Because the oauth library is calling `time.Now()`, we cannot use
171+
// mocked clocks.
172+
expires func() time.Time
173+
refreshExpected bool
174+
}{
175+
{
176+
name: "ZeroExpiry",
177+
expires: func() time.Time { return time.Time{} },
178+
refreshExpected: false,
179+
},
180+
{
181+
name: "LongExpired",
182+
expires: func() time.Time { return dbtime.Now().Add(-time.Hour) },
183+
refreshExpected: true,
184+
},
185+
{
186+
name: "EdgeExpired",
187+
expires: func() time.Time { return dbtime.Now().Add(-time.Minute * 10) },
188+
refreshExpected: true,
189+
},
190+
{
191+
name: "RecentExpired",
192+
expires: func() time.Time { return dbtime.Now().Add(-time.Second * -1) },
193+
refreshExpected: true,
194+
},
195+
196+
{
197+
name: "Future",
198+
expires: func() time.Time { return dbtime.Now().Add(time.Hour) },
199+
refreshExpected: false,
200+
},
201+
{
202+
name: "FutureWithinRefreshWindow",
203+
expires: func() time.Time { return dbtime.Now().Add(time.Minute * 8) },
204+
refreshExpected: true,
205+
},
206+
} {
207+
t.Run(c.name, func(t *testing.T) {
208+
ctx := testutil.Context(t, testutil.WaitShort)
209+
oldLink := setLinkExpiration(t, c.expires())
210+
tokenRefreshCount = 0
211+
_, err := provisionerdserver.ObtainOIDCAccessToken(ctx, testutil.Logger(t), db, cfg, user.ID)
212+
require.NoError(t, err)
213+
links, err := db.GetUserLinksByUserID(ctx, user.ID)
214+
require.NoError(t, err)
215+
require.Len(t, links, 1)
216+
newLink := links[0]
217+
218+
if c.refreshExpected {
219+
require.Equal(t, 1, tokenRefreshCount)
220+
221+
require.NotEqual(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken)
222+
require.NotEqual(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken)
223+
} else {
224+
require.Equal(t, 0, tokenRefreshCount)
225+
require.Equal(t, oldLink.OAuthAccessToken, newLink.OAuthAccessToken)
226+
require.Equal(t, oldLink.OAuthRefreshToken, newLink.OAuthRefreshToken)
227+
}
228+
})
229+
}
230+
}
231+
61232
func testTemplateScheduleStore() *atomic.Pointer[schedule.TemplateScheduleStore] {
62233
poitr := &atomic.Pointer[schedule.TemplateScheduleStore]{}
63234
store := schedule.NewAGPLTemplateScheduleStore()

0 commit comments

Comments
 (0)