package ssh import ( "context" "errors" "fmt" "net" "sync" "time" gossh "golang.org/x/crypto/ssh" ) // ErrServerClosed is returned by the Server's Serve, ListenAndServe, // and ListenAndServeTLS methods after a call to Shutdown or Close. var ErrServerClosed = errors.New("ssh: Server closed") type RequestHandler func(ctx Context, srv *Server, req *gossh.Request) (ok bool, payload []byte) var DefaultRequestHandlers = map[string]RequestHandler{} type ChannelHandler func(srv *Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx Context) var DefaultChannelHandlers = map[string]ChannelHandler{ "session": DefaultSessionHandler, } // Server defines parameters for running an SSH server. The zero value for // Server is a valid configuration. When both PasswordHandler and // PublicKeyHandler are nil, no client authentication is performed. type Server struct { Addr string // TCP address to listen on, ":22" if empty Handler Handler // handler to invoke, ssh.DefaultHandler if nil HostSigners []Signer // private keys for the host key, must have at least one Version string // server version to be sent before the initial handshake KeyboardInteractiveHandler KeyboardInteractiveHandler // keyboard-interactive authentication handler PasswordHandler PasswordHandler // password authentication handler PublicKeyHandler PublicKeyHandler // public key authentication handler PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil ReversePortForwardingCallback ReversePortForwardingCallback // callback for allowing reverse port forwarding, denies all if nil ServerConfigCallback ServerConfigCallback // callback for configuring detailed SSH options SessionRequestCallback SessionRequestCallback // callback for allowing or denying SSH sessions IdleTimeout time.Duration // connection timeout when no activity, none if empty MaxTimeout time.Duration // absolute connection timeout, none if empty // ChannelHandlers allow overriding the built-in session handlers or provide // extensions to the protocol, such as tcpip forwarding. By default only the // "session" handler is enabled. ChannelHandlers map[string]ChannelHandler // RequestHandlers allow overriding the server-level request handlers or // provide extensions to the protocol, such as tcpip forwarding. By default // no handlers are enabled. RequestHandlers map[string]RequestHandler listenerWg sync.WaitGroup mu sync.Mutex listeners map[net.Listener]struct{} conns map[*gossh.ServerConn]struct{} connWg sync.WaitGroup doneChan chan struct{} } func (srv *Server) ensureHostSigner() error { if len(srv.HostSigners) == 0 { signer, err := generateSigner() if err != nil { return err } srv.HostSigners = append(srv.HostSigners, signer) } return nil } func (srv *Server) ensureHandlers() { srv.mu.Lock() defer srv.mu.Unlock() if srv.RequestHandlers == nil { srv.RequestHandlers = map[string]RequestHandler{} for k, v := range DefaultRequestHandlers { srv.RequestHandlers[k] = v } } if srv.ChannelHandlers == nil { srv.ChannelHandlers = map[string]ChannelHandler{} for k, v := range DefaultChannelHandlers { srv.ChannelHandlers[k] = v } } } func (srv *Server) config(ctx Context) *gossh.ServerConfig { var config *gossh.ServerConfig if srv.ServerConfigCallback == nil { config = &gossh.ServerConfig{} } else { config = srv.ServerConfigCallback(ctx) } for _, signer := range srv.HostSigners { config.AddHostKey(signer) } if srv.PasswordHandler == nil && srv.PublicKeyHandler == nil { config.NoClientAuth = true } if srv.Version != "" { config.ServerVersion = "SSH-2.0-" + srv.Version } if srv.PasswordHandler != nil { config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { applyConnMetadata(ctx, conn) if ok := srv.PasswordHandler(ctx, string(password)); !ok { return ctx.Permissions().Permissions, fmt.Errorf("permission denied") } return ctx.Permissions().Permissions, nil } } if srv.PublicKeyHandler != nil { config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { applyConnMetadata(ctx, conn) if ok := srv.PublicKeyHandler(ctx, key); !ok { return ctx.Permissions().Permissions, fmt.Errorf("permission denied") } ctx.SetValue(ContextKeyPublicKey, key) return ctx.Permissions().Permissions, nil } } if srv.KeyboardInteractiveHandler != nil { config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { return ctx.Permissions().Permissions, fmt.Errorf("permission denied") } return ctx.Permissions().Permissions, nil } } return config } // Handle sets the Handler for the server. func (srv *Server) Handle(fn Handler) { srv.Handler = fn } // Close immediately closes all active listeners and all active // connections. // // Close returns any error returned from closing the Server's // underlying Listener(s). func (srv *Server) Close() error { srv.mu.Lock() defer srv.mu.Unlock() srv.closeDoneChanLocked() err := srv.closeListenersLocked() for c := range srv.conns { c.Close() delete(srv.conns, c) } return err } // Shutdown gracefully shuts down the server without interrupting any // active connections. Shutdown works by first closing all open // listeners, and then waiting indefinitely for connections to close. // If the provided context expires before the shutdown is complete, // then the context's error is returned. func (srv *Server) Shutdown(ctx context.Context) error { srv.mu.Lock() lnerr := srv.closeListenersLocked() srv.closeDoneChanLocked() srv.mu.Unlock() finished := make(chan struct{}, 1) go func() { srv.listenerWg.Wait() srv.connWg.Wait() finished <- struct{}{} }() select { case <-ctx.Done(): return ctx.Err() case <-finished: return lnerr } } // Serve accepts incoming connections on the Listener l, creating a new // connection goroutine for each. The connection goroutines read requests and then // calls srv.Handler to handle sessions. // // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { srv.ensureHandlers() defer l.Close() if err := srv.ensureHostSigner(); err != nil { return err } if srv.Handler == nil { srv.Handler = DefaultHandler } var tempDelay time.Duration srv.trackListener(l, true) defer srv.trackListener(l, false) for { conn, e := l.Accept() if e != nil { select { case <-srv.getDoneChan(): return ErrServerClosed default: } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { tempDelay *= 2 } if max := 1 * time.Second; tempDelay > max { tempDelay = max } time.Sleep(tempDelay) continue } return e } go srv.handleConn(conn) } } func (srv *Server) handleConn(newConn net.Conn) { if srv.ConnCallback != nil { cbConn := srv.ConnCallback(newConn) if cbConn == nil { newConn.Close() return } newConn = cbConn } ctx, cancel := newContext(srv) conn := &serverConn{ Conn: newConn, idleTimeout: srv.IdleTimeout, closeCanceler: cancel, } if srv.MaxTimeout > 0 { conn.maxDeadline = time.Now().Add(srv.MaxTimeout) } defer conn.Close() sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx)) if err != nil { // TODO: trigger event callback return } srv.trackConn(sshConn, true) defer srv.trackConn(sshConn, false) ctx.SetValue(ContextKeyConn, sshConn) applyConnMetadata(ctx, sshConn) //go gossh.DiscardRequests(reqs) go srv.handleRequests(ctx, reqs) for ch := range chans { handler := srv.ChannelHandlers[ch.ChannelType()] if handler == nil { handler = srv.ChannelHandlers["default"] } if handler == nil { ch.Reject(gossh.UnknownChannelType, "unsupported channel type") continue } go handler(srv, sshConn, ch, ctx) } } func (srv *Server) handleRequests(ctx Context, in <-chan *gossh.Request) { for req := range in { handler := srv.RequestHandlers[req.Type] if handler == nil { handler = srv.RequestHandlers["default"] } if handler == nil { req.Reply(false, nil) continue } /*reqCtx, cancel := context.WithCancel(ctx) defer cancel() */ ret, payload := handler(ctx, srv, req) req.Reply(ret, payload) } } // ListenAndServe listens on the TCP network address srv.Addr and then calls // Serve to handle incoming connections. If srv.Addr is blank, ":22" is used. // ListenAndServe always returns a non-nil error. func (srv *Server) ListenAndServe() error { addr := srv.Addr if addr == "" { addr = ":22" } ln, err := net.Listen("tcp", addr) if err != nil { return err } return srv.Serve(ln) } // AddHostKey adds a private key as a host key. If an existing host key exists // with the same algorithm, it is overwritten. Each server config must have at // least one host key. func (srv *Server) AddHostKey(key Signer) { // these are later added via AddHostKey on ServerConfig, which performs the // check for one of every algorithm. srv.HostSigners = append(srv.HostSigners, key) } // SetOption runs a functional option against the server. func (srv *Server) SetOption(option Option) error { return option(srv) } func (srv *Server) getDoneChan() <-chan struct{} { srv.mu.Lock() defer srv.mu.Unlock() return srv.getDoneChanLocked() } func (srv *Server) getDoneChanLocked() chan struct{} { if srv.doneChan == nil { srv.doneChan = make(chan struct{}) } return srv.doneChan } func (srv *Server) closeDoneChanLocked() { ch := srv.getDoneChanLocked() select { case <-ch: // Already closed. Don't close again. default: // Safe to close here. We're the only closer, guarded // by srv.mu. close(ch) } } func (srv *Server) closeListenersLocked() error { var err error for ln := range srv.listeners { if cerr := ln.Close(); cerr != nil && err == nil { err = cerr } delete(srv.listeners, ln) } return err } func (srv *Server) trackListener(ln net.Listener, add bool) { srv.mu.Lock() defer srv.mu.Unlock() if srv.listeners == nil { srv.listeners = make(map[net.Listener]struct{}) } if add { // If the *Server is being reused after a previous // Close or Shutdown, reset its doneChan: if len(srv.listeners) == 0 && len(srv.conns) == 0 { srv.doneChan = nil } srv.listeners[ln] = struct{}{} srv.listenerWg.Add(1) } else { delete(srv.listeners, ln) srv.listenerWg.Done() } } func (srv *Server) trackConn(c *gossh.ServerConn, add bool) { srv.mu.Lock() defer srv.mu.Unlock() if srv.conns == nil { srv.conns = make(map[*gossh.ServerConn]struct{}) } if add { srv.conns[c] = struct{}{} srv.connWg.Add(1) } else { delete(srv.conns, c) srv.connWg.Done() } }