Skip to content

Commit ff31f2b

Browse files
authored
Merge pull request antoniomika#55 from antoniomika/fix_no_tty
Fixed connection close during a notty session
2 parents 9226761 + 6761b05 commit ff31f2b

File tree

4 files changed

+70
-26
lines changed

4 files changed

+70
-26
lines changed

channels.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ func handleSession(newChannel ssh.NewChannel, sshConn *SSHConnection, state *Sta
4242
data := make([]byte, 4096)
4343
dataRead, err := connection.Read(data)
4444
if err != nil && err == io.EOF {
45+
break
46+
} else if err != nil {
4547
select {
4648
case <-sshConn.Close:
4749
break
@@ -122,8 +124,14 @@ func handleAlias(newChannel ssh.NewChannel, sshConn *SSHConnection, state *State
122124

123125
sshConn.Listeners.Store(conn.RemoteAddr(), nil)
124126

125-
copyBoth(conn, connection, false)
126-
sshConn.CleanUp(state)
127+
copyBoth(conn, connection)
128+
129+
select {
130+
case <-sshConn.Close:
131+
break
132+
default:
133+
sshConn.CleanUp(state)
134+
}
127135
}
128136

129137
func writeToSession(connection ssh.Channel, c string) {

handle.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ func handleRequest(newRequest *ssh.Request, sshConn *SSHConnection, state *State
2222
case "tcpip-forward":
2323
go checkSession(newRequest, sshConn, state)
2424
handleRemoteForward(newRequest, sshConn, state)
25+
26+
err := newRequest.Reply(true, nil)
27+
if err != nil {
28+
log.Println("Error replying to socket request:", err)
29+
}
2530
default:
2631
err := newRequest.Reply(false, nil)
2732
if err != nil {

main.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,20 @@ func main() {
267267

268268
state.SSHConnections.Store(sshConn.RemoteAddr(), holderConn)
269269

270+
go func() {
271+
err := sshConn.Wait()
272+
if err != nil && *debug {
273+
log.Println("Closing SSH connection:", err)
274+
}
275+
276+
select {
277+
case <-holderConn.Close:
278+
break
279+
default:
280+
holderConn.CleanUp(state)
281+
}
282+
}()
283+
270284
go handleRequests(reqs, holderConn, state)
271285
go handleChannels(chans, holderConn, state)
272286

requests.go

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ import (
88
"net"
99
"os"
1010
"strconv"
11-
"sync"
1211
"time"
1312

1413
"github.com/logrusorgru/aurora"
@@ -227,47 +226,65 @@ func handleRemoteForward(newRequest *ssh.Request, sshConn *SSHConnection, state
227226
}
228227
}
229228

230-
go copyBoth(cl, newChan, false)
229+
go copyBoth(cl, newChan)
231230
go ssh.DiscardRequests(newReqs)
232231
}
233232
}
234233

235-
func copyBoth(writer net.Conn, reader ssh.Channel, wait bool) {
234+
// IdleTimeoutConn handles the connection with a context deadline
235+
// code adapted from https://qiita.com/kwi/items/b38d6273624ad3f6ae79
236+
type IdleTimeoutConn struct {
237+
Conn net.Conn
238+
}
239+
240+
// Read is needed to implement the reader part
241+
func (i IdleTimeoutConn) Read(buf []byte) (int, error) {
242+
err := i.Conn.SetDeadline(time.Now().Add(5 * time.Second))
243+
if err != nil {
244+
return 0, err
245+
}
246+
247+
return i.Conn.Read(buf)
248+
}
249+
250+
// Write is needed to implement the writer part
251+
func (i IdleTimeoutConn) Write(buf []byte) (int, error) {
252+
err := i.Conn.SetDeadline(time.Now().Add(5 * time.Second))
253+
if err != nil {
254+
return 0, err
255+
}
256+
257+
return i.Conn.Write(buf)
258+
}
259+
260+
func copyBoth(writer net.Conn, reader ssh.Channel) {
236261
closeBoth := func() {
237-
time.Sleep(100 * time.Millisecond)
238-
writer.Close()
239262
reader.Close()
263+
writer.Close()
240264
}
241265

242-
var wg sync.WaitGroup
243-
244-
go func() {
245-
if wait {
246-
wg.Add(1)
247-
defer wg.Done()
248-
}
266+
tcon := IdleTimeoutConn{
267+
Conn: writer,
268+
}
249269

250-
_, err := io.Copy(reader, writer)
270+
copyToReader := func() {
271+
_, err := io.Copy(reader, tcon)
251272
if err != nil && *debug {
252273
log.Println("Error copying to reader:", err)
253274
}
254-
}()
255275

256-
func() {
257-
if wait {
258-
wg.Add(1)
259-
defer wg.Done()
260-
}
276+
closeBoth()
277+
}
261278

262-
_, err := io.Copy(writer, reader)
279+
copyToWriter := func() {
280+
_, err := io.Copy(tcon, reader)
263281
if err != nil && *debug {
264282
log.Println("Error copying to writer:", err)
265283
}
266-
}()
267284

268-
if wait {
269-
wg.Wait()
285+
closeBoth()
270286
}
271287

272-
closeBoth()
288+
go copyToReader()
289+
copyToWriter()
273290
}

0 commit comments

Comments
 (0)