Skip to content

Commit

Permalink
fix: add custom time format option
Browse files Browse the repository at this point in the history
  • Loading branch information
stebenz committed Mar 6, 2023
1 parent 59373a3 commit 71ad690
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 47 deletions.
2 changes: 1 addition & 1 deletion pkg/provider/attribute_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (p *IdentityProvider) attributeQueryHandleFunc(w http.ResponseWriter, r *ht
queriedAttrs = append(queriedAttrs, queriedAttr)
}
}
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs)
response = makeAttributeQueryResponse(attrQuery.Id, p.GetEntityID(r.Context()), sp.GetEntityID(), attrs, queriedAttrs, p.timeFormat)
return nil
},
func() {
Expand Down
7 changes: 5 additions & 2 deletions pkg/provider/identityprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import (
"crypto/rsa"
"encoding/pem"
"fmt"
"html/template"
"io"
"net/http"
"reflect"
"html/template"
"time"

"github.com/zitadel/saml/pkg/provider/serviceprovider"
Expand Down Expand Up @@ -67,6 +67,8 @@ type IdentityProvider struct {

metadataEndpoint *Endpoint
endpoints *Endpoints

timeFormat string
}

type Endpoints struct {
Expand All @@ -77,7 +79,7 @@ type Endpoints struct {
attributeEndpoint Endpoint
}

func NewIdentityProvider(ctx context.Context, metadata Endpoint, conf *IdentityProviderConfig, storage IDPStorage) (*IdentityProvider, error) {
func NewIdentityProvider(metadata Endpoint, conf *IdentityProviderConfig, storage IDPStorage) (*IdentityProvider, error) {
postTemplate, err := template.New("post").Parse(postTemplate)
if err != nil {
return nil, err
Expand All @@ -95,6 +97,7 @@ func NewIdentityProvider(ctx context.Context, metadata Endpoint, conf *IdentityP
postTemplate: postTemplate,
logoutTemplate: logoutTemplate,
endpoints: endpointConfigToEndpoints(conf.Endpoints),
timeFormat: DefaultTimeFormat,
}

if conf.MetadataIDPConfig == nil {
Expand Down
3 changes: 1 addition & 2 deletions pkg/provider/identityprovider_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package provider

import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
Expand Down Expand Up @@ -171,7 +170,7 @@ func TestIDP_certificateHandleFunc(t *testing.T) {
return
}

idp, err := NewIdentityProvider(context.Background(), endpoint, tt.args.config, mockStorage)
idp, err := NewIdentityProvider(endpoint, tt.args.config, mockStorage)
if (err != nil) != tt.res.err {
t.Errorf("NewIdentityProvider() error = %v", err.Error())
return
Expand Down
8 changes: 4 additions & 4 deletions pkg/provider/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
authRequest, err := p.storage.AuthRequestByID(r.Context(), requestID)
if err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to get request: %w", err).Error()))
response.sendBackResponse(r, w, response.makeDeniedResponse(fmt.Errorf("failed to get request: %w", err).Error(), p.timeFormat))
return
}
response.RequestID = authRequest.GetAuthRequestID()
Expand Down Expand Up @@ -63,19 +63,19 @@ func (p *IdentityProvider) callbackHandleFunc(w http.ResponseWriter, r *http.Req
return
}

samlResponse := response.makeSuccessfulResponse(attrs)
samlResponse := response.makeSuccessfulResponse(attrs, p.timeFormat)

switch response.ProtocolBinding {
case PostBinding:
if err := createPostSignature(r.Context(), samlResponse, p); err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error()))
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
return
}
case RedirectBinding:
if err := createRedirectSignature(r.Context(), samlResponse, p, response); err != nil {
logging.Error(err)
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error()))
response.sendBackResponse(r, w, response.makeResponderFailResponse(fmt.Errorf("failed to sign response: %w", err).Error(), p.timeFormat))
return
}
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/provider/login_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package provider

import (
"context"
"io/ioutil"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -245,7 +244,7 @@ func TestSSO_loginHandleFunc(t *testing.T) {
return
}

idp, err := NewIdentityProvider(context.Background(), endpoint, tt.args.config, mockStorage)
idp, err := NewIdentityProvider(endpoint, tt.args.config, mockStorage)
if (err != nil) != tt.res.err {
t.Errorf("NewIdentityProvider() error = %v", err.Error())
return
Expand Down
8 changes: 4 additions & 4 deletions pkg/provider/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return nil
},
func() {
response.sendBackLogoutResponse(w, response.makeUnsupportedlLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error()))
response.sendBackLogoutResponse(w, response.makeUnsupportedlLogoutResponse(fmt.Errorf("failed to decode request: %w", err).Error(), p.timeFormat))
},
)

Expand All @@ -71,7 +71,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
func() string { return logoutRequest.NotOnOrAfter },
),
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error()))
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to validate request: %w", err).Error(), p.timeFormat))
},
)

Expand All @@ -82,7 +82,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque
return err
},
func() {
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error()))
response.sendBackLogoutResponse(w, response.makeDeniedLogoutResponse(fmt.Errorf("failed to find registered serviceprovider: %w", err).Error(), p.timeFormat))
},
)

Expand All @@ -105,7 +105,7 @@ func (p *IdentityProvider) logoutHandleFunc(w http.ResponseWriter, r *http.Reque

response.sendBackLogoutResponse(
w,
response.makeSuccessfulLogoutResponse(),
response.makeSuccessfulLogoutResponse(p.timeFormat),
)
logging.Info(fmt.Sprintf("logout request for user %s", logoutRequest.NameID.Text))
}
Expand Down
15 changes: 9 additions & 6 deletions pkg/provider/logout_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"bytes"
"encoding/base64"
"encoding/xml"
"net/http"
"html/template"
"net/http"
"time"

"github.com/zitadel/saml/pkg/provider/xml/saml"
Expand Down Expand Up @@ -67,11 +67,11 @@ func (r *LogoutResponse) sendBackLogoutResponse(w http.ResponseWriter, resp *sam
}
}

func (r *LogoutResponse) makeSuccessfulLogoutResponse() *samlp.LogoutResponseType {
func (r *LogoutResponse) makeSuccessfulLogoutResponse(timeFormat string) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(DefaultTimeFormat),
time.Now().UTC().Format(timeFormat),
StatusCodeSuccess,
"",
getIssuer(r.Issuer),
Expand All @@ -80,11 +80,12 @@ func (r *LogoutResponse) makeSuccessfulLogoutResponse() *samlp.LogoutResponseTyp

func (r *LogoutResponse) makeUnsupportedlLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(DefaultTimeFormat),
time.Now().UTC().Format(timeFormat),
StatusCodeRequestUnsupported,
message,
getIssuer(r.Issuer),
Expand All @@ -93,11 +94,12 @@ func (r *LogoutResponse) makeUnsupportedlLogoutResponse(

func (r *LogoutResponse) makePartialLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(DefaultTimeFormat),
time.Now().UTC().Format(timeFormat),
StatusCodePartialLogout,
message,
getIssuer(r.Issuer),
Expand All @@ -106,11 +108,12 @@ func (r *LogoutResponse) makePartialLogoutResponse(

func (r *LogoutResponse) makeDeniedLogoutResponse(
message string,
timeFormat string,
) *samlp.LogoutResponseType {
return makeLogoutResponse(
r.RequestID,
r.LogoutURL,
time.Now().UTC().Format(DefaultTimeFormat),
time.Now().UTC().Format(timeFormat),
StatusCodeRequestDenied,
message,
getIssuer(r.Issuer),
Expand Down
12 changes: 10 additions & 2 deletions pkg/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
)

const (
DefaultTimeFormat = "2006-01-02T15:04:05.999999Z"
PostBinding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
RedirectBinding = "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
SOAPBinding = "urn:oasis:names:tc:SAML:2.0:bindings:SOAP"
Expand Down Expand Up @@ -79,10 +80,10 @@ type Provider struct {
conf *Config
issuerFromRequest IssuerFromRequest
identityProvider *IdentityProvider
timeFormat string
}

func NewProvider(
ctx context.Context,
storage Storage,
path string,
conf *Config,
Expand All @@ -94,7 +95,6 @@ func NewProvider(
}

idp, err := NewIdentityProvider(
ctx,
metadataEndpoint,
conf.IDPConfig,
storage,
Expand Down Expand Up @@ -246,3 +246,11 @@ func WithAllowInsecure() Option {
return nil
}
}

// WithCustomTimeFormat allows the use of a custom timeformat instead of the default
func WithCustomTimeFormat(timeFormat string) Option {
return func(p *Provider) error {
p.identityProvider.timeFormat = timeFormat
return nil
}
}
23 changes: 14 additions & 9 deletions pkg/provider/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
)

const (
DefaultTimeFormat = "2006-01-02T15:04:05.999999Z"
StatusCodeSuccess = "urn:oasis:names:tc:SAML:2.0:status:Success"
StatusCodeVersionMissmatch = "urn:oasis:names:tc:SAML:2.0:status:VersionMismatch"
StatusCodeAuthNFailed = "urn:oasis:names:tc:SAML:2.0:status:AuthnFailed"
Expand Down Expand Up @@ -99,9 +98,10 @@ func (r *Response) sendBackResponse(

func (r *Response) makeUnsupportedBindingResponse(
message string,
timeFormat string,
) *samlp.ResponseType {
now := time.Now().UTC()
nowStr := now.Format(DefaultTimeFormat)
nowStr := now.Format(timeFormat)
return makeResponse(
NewID(),
r.RequestID,
Expand All @@ -115,9 +115,10 @@ func (r *Response) makeUnsupportedBindingResponse(

func (r *Response) makeResponderFailResponse(
message string,
timeFormat string,
) *samlp.ResponseType {
now := time.Now().UTC()
nowStr := now.Format(DefaultTimeFormat)
nowStr := now.Format(timeFormat)
return makeResponse(
NewID(),
r.RequestID,
Expand All @@ -131,9 +132,10 @@ func (r *Response) makeResponderFailResponse(

func (r *Response) makeDeniedResponse(
message string,
timeFormat string,
) *samlp.ResponseType {
now := time.Now().UTC()
nowStr := now.Format(DefaultTimeFormat)
nowStr := now.Format(timeFormat)
return makeResponse(
NewID(),
r.RequestID,
Expand All @@ -147,9 +149,10 @@ func (r *Response) makeDeniedResponse(

func (r *Response) makeFailedResponse(
message string,
timeFormat string,
) *samlp.ResponseType {
now := time.Now().UTC()
nowStr := now.Format(DefaultTimeFormat)
nowStr := now.Format(timeFormat)
return makeResponse(
NewID(),
r.RequestID,
Expand All @@ -163,10 +166,11 @@ func (r *Response) makeFailedResponse(

func (r *Response) makeSuccessfulResponse(
attributes *Attributes,
timeFormat string,
) *samlp.ResponseType {
now := time.Now().UTC()
nowStr := now.Format(DefaultTimeFormat)
fiveFromNowStr := now.Add(5 * time.Minute).Format(DefaultTimeFormat)
nowStr := now.Format(timeFormat)
fiveFromNowStr := now.Add(5 * time.Minute).Format(timeFormat)

return r.makeAssertionResponse(
nowStr,
Expand Down Expand Up @@ -200,12 +204,13 @@ func makeAttributeQueryResponse(
entityID string,
attributes *Attributes,
queriedAttrs []saml.AttributeType,
timeFormat string,
) *samlp.ResponseType {
now := time.Now().UTC()
nowStr := now.Format(DefaultTimeFormat)
nowStr := now.Format(timeFormat)
fiveMinutes, _ := time.ParseDuration("5m")
fiveFromNow := now.Add(fiveMinutes)
fiveFromNowStr := fiveFromNow.Format(DefaultTimeFormat)
fiveFromNowStr := fiveFromNow.Format(timeFormat)

providedAttrs := []*saml.AttributeType{}
attrsSaml := attributes.GetSAML()
Expand Down
Loading

0 comments on commit 71ad690

Please sign in to comment.