Skip to content

Commit

Permalink
Add support for :0 port
Browse files Browse the repository at this point in the history
  • Loading branch information
System-Glitch committed Oct 17, 2023
1 parent 125bb2f commit dc304ba
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 19 deletions.
2 changes: 1 addition & 1 deletion router.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ func (r *Router) StatusHandler(handler StatusHandler, status int, additionalStat
// ServeHTTP dispatches the handler registered in the matched route.
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.URL.Scheme != "" && req.URL.Scheme != "http" {
address := getProxyAddress(r.server.config) + req.URL.Path
address := r.server.getProxyAddress(r.server.config) + req.URL.Path
query := req.URL.Query()
if len(query) != 0 {
address += "?" + query.Encode()
Expand Down
36 changes: 25 additions & 11 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Server struct {
startupHooks []func(*Server)
shutdownHooks []func(*Server)

port int

state uint32 // 0 -> created, 1 -> preparing, 2 -> ready, 3 -> stopped
}

Expand All @@ -72,7 +74,8 @@ func NewWithConfig(cfg *config.Config) (*Server, error) { // TODO with options?
return nil, err
}

host := cfg.GetString("server.host") + ":" + strconv.Itoa(cfg.GetInt("server.port"))
port := cfg.GetInt("server.port")
host := cfg.GetString("server.host") + ":" + strconv.Itoa(port)

server := &Server{
server: &http.Server{
Expand All @@ -87,11 +90,11 @@ func NewWithConfig(cfg *config.Config) (*Server, error) { // TODO with options?
stopChannel: make(chan struct{}, 1),
startupHooks: []func(*Server){},
shutdownHooks: []func(*Server){},
host: host,
baseURL: getAddress(cfg),
proxyBaseURL: getProxyAddress(cfg),
host: cfg.GetString("server.host"),
port: port, // TODO document using 0 as port will auto assign an available port
Logger: slogger,
}
server.refreshURLs()
server.server.ErrorLog = log.New(&errLogWriter{server: server}, "", 0)

if cfg.GetString("database.connection") != "none" {
Expand All @@ -107,9 +110,8 @@ func NewWithConfig(cfg *config.Config) (*Server, error) { // TODO with options?
return server, nil
}

func getAddress(cfg *config.Config) string {
port := cfg.GetInt("server.port")
shouldShowPort := port != 80
func (s *Server) getAddress(cfg *config.Config) string {
shouldShowPort := s.port != 80
host := cfg.GetString("server.domain")
if len(host) == 0 {
host = cfg.GetString("server.host")
Expand All @@ -119,15 +121,15 @@ func getAddress(cfg *config.Config) string {
}

if shouldShowPort {
host += ":" + strconv.Itoa(port)
host += ":" + strconv.Itoa(s.port)
}

return "http://" + host
}

func getProxyAddress(cfg *config.Config) string {
func (s *Server) getProxyAddress(cfg *config.Config) string {
if !cfg.Has("server.proxy.host") {
return getAddress(cfg)
return s.getAddress(cfg)
}

var shouldShowPort bool
Expand All @@ -146,6 +148,11 @@ func getProxyAddress(cfg *config.Config) string {
return proto + "://" + host + cfg.GetString("server.proxy.base")
}

func (s *Server) refreshURLs() {
s.baseURL = s.getAddress(s.config)
s.proxyBaseURL = s.getProxyAddress(s.config)
}

// Service returns the service identified by the given name.
// Panics if no service could be found with the given name.
func (s *Server) Service(name string) Service {
Expand All @@ -172,7 +179,12 @@ func (s *Server) RegisterService(service Service) {

// Host returns the hostname and port the server is running on.
func (s *Server) Host() string {
return s.host
return s.host + ":" + strconv.Itoa(s.port)
}

// Port returns the port the server is running on.
func (s *Server) Port() int {
return s.port
}

// BaseURL returns the base URL of your application.
Expand Down Expand Up @@ -324,6 +336,8 @@ func (s *Server) Start() error {
if err != nil {
return errors.New(err)
}
s.port = ln.Addr().(*net.TCPAddr).Port
s.refreshURLs()
defer func() {
for _, hook := range s.shutdownHooks {
hook(s)
Expand Down
73 changes: 67 additions & 6 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,21 @@ func TestServer(t *testing.T) {
t.Run("0.0.0.0", func(t *testing.T) {
cfg := config.LoadDefault()
cfg.Set("server.host", "0.0.0.0")
assert.Equal(t, "http://127.0.0.1:8080", getAddress(cfg))
cfg.Set("server.port", 8080)
server := &Server{config: cfg, port: 8080}
assert.Equal(t, "http://127.0.0.1:8080", server.getAddress(cfg))
})
t.Run("hide_port", func(t *testing.T) {
cfg := config.LoadDefault()
cfg.Set("server.port", 80)
assert.Equal(t, "http://127.0.0.1", getAddress(cfg))
server := &Server{config: cfg, port: 80}
assert.Equal(t, "http://127.0.0.1", server.getAddress(cfg))
})
t.Run("domain", func(t *testing.T) {
cfg := config.LoadDefault()
cfg.Set("server.domain", "example.org")
assert.Equal(t, "http://example.org:8080", getAddress(cfg))
server := &Server{config: cfg, port: 1234}
assert.Equal(t, "http://example.org:1234", server.getAddress(cfg))
})
})

Expand All @@ -147,7 +151,8 @@ func TestServer(t *testing.T) {
cfg.Set("server.proxy.protocol", "https")
cfg.Set("server.proxy.port", 1234)
cfg.Set("server.proxy.base", "/base")
assert.Equal(t, "https://proxy.example.org:1234/base", getProxyAddress(cfg))
server := &Server{config: cfg, port: 1234}
assert.Equal(t, "https://proxy.example.org:1234/base", server.getProxyAddress(cfg))
})

t.Run("hide_port", func(t *testing.T) {
Expand All @@ -156,14 +161,16 @@ func TestServer(t *testing.T) {
cfg.Set("server.proxy.protocol", "https")
cfg.Set("server.proxy.port", 443)
cfg.Set("server.proxy.base", "/base")
assert.Equal(t, "https://proxy.example.org/base", getProxyAddress(cfg))
server := &Server{config: cfg, port: 443}
assert.Equal(t, "https://proxy.example.org/base", server.getProxyAddress(cfg))

cfg = config.LoadDefault()
cfg.Set("server.proxy.host", "proxy.example.org")
cfg.Set("server.proxy.protocol", "http")
cfg.Set("server.proxy.port", 80)
cfg.Set("server.proxy.base", "/base")
assert.Equal(t, "http://proxy.example.org/base", getProxyAddress(cfg))
server = &Server{config: cfg, port: 80}
assert.Equal(t, "http://proxy.example.org/base", server.getProxyAddress(cfg))
})
})

Expand Down Expand Up @@ -201,6 +208,7 @@ func TestServer(t *testing.T) {
}

assert.Equal(t, "127.0.0.1:8080", server.Host())
assert.Equal(t, 8080, server.Port())
assert.Equal(t, "http://127.0.0.1:8080", server.BaseURL())
assert.Equal(t, "http://127.0.0.1:8080", server.ProxyBaseURL())
assert.False(t, server.IsReady())
Expand Down Expand Up @@ -324,6 +332,59 @@ func TestServer(t *testing.T) {
assert.Equal(t, uint32(3), atomic.LoadUint32(&server.state))
})

t.Run("StartWithAutoPort", func(t *testing.T) {
cfg := config.LoadDefault()
cfg.Set("server.port", 0)
server, err := NewWithConfig(cfg)
if !assert.NoError(t, err) {
return
}

startupHookExecuted := false
wg := sync.WaitGroup{}
wg.Add(2)

server.RegisterStartupHook(func(s *Server) {
// Should be executed when the server is ready
startupHookExecuted = true

assert.True(t, server.IsReady())
assert.NotEqual(t, 0, s.Port())

res, err := http.Get(s.BaseURL())
if !assert.NoError(t, err) {
return
}
respBody, err := io.ReadAll(res.Body)
if !assert.NoError(t, err) {
return
}
_ = res.Body.Close()
assert.Equal(t, []byte("hello world"), respBody)

// Stop the server, goroutine should return
server.Stop()
wg.Done()
})

server.RegisterRoutes(func(s *Server, router *Router) {
router.Get("/", func(r *Response, _ *Request) {
r.String(http.StatusOK, "hello world")
}).Name("base")
})

go func() {
err := server.Start()
assert.Nil(t, err)
wg.Done()
}()

wg.Wait()
assert.True(t, startupHookExecuted)
assert.False(t, server.IsReady())
assert.Equal(t, uint32(3), atomic.LoadUint32(&server.state))
})

t.Run("Start_already_running", func(t *testing.T) {
server, err := NewWithConfig(config.LoadDefault())
if !assert.NoError(t, err) {
Expand Down
2 changes: 1 addition & 1 deletion websocket/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (a *adapter) onError(w http.ResponseWriter, _ *http.Request, status int, re

func (a *adapter) getCheckOriginFunc() func(r *http.Request) bool {
if a.checkOrigin != nil {
return func(r *http.Request) bool {
return func(_ *http.Request) bool {
return a.checkOrigin(a.request)
}
}
Expand Down

0 comments on commit dc304ba

Please sign in to comment.