diff --git a/context.go b/context.go index beababe..f0751d4 100644 --- a/context.go +++ b/context.go @@ -120,6 +120,10 @@ func newContext(srv *Server) (*sshContext, context.CancelFunc) { return ctx, cancel } +func resetPermissions(ctx Context) { + ctx.Permissions().Permissions = &gossh.Permissions{} +} + // this is separate from newContext because we will get ConnMetadata // at different points so it needs to be applied separately func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { diff --git a/server.go b/server.go index 396998f..2423af5 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package ssh import ( "context" + "encoding/base64" "errors" "fmt" "net" @@ -29,6 +30,16 @@ var DefaultChannelHandlers = map[string]ChannelHandler{ "session": DefaultSessionHandler, } +var permissionsPublicKeyExt = "gliderlabs/ssh.PublicKey" + +func ensureNoPKInPermissions(ctx Context) error { + if _, ok := ctx.Permissions().Permissions.Extensions[permissionsPublicKeyExt]; ok { + return errors.New("misconfigured server: public key incorrectly set") + } + + return nil +} + // Server defines parameters for running an SSH server. The zero value for // Server is a valid configuration. When both PasswordHandler and // PublicKeyHandler are nil, no client authentication is performed. @@ -151,8 +162,14 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig { } if srv.PasswordHandler != nil { config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { + resetPermissions(ctx) applyConnMetadata(ctx, conn) - if ok := srv.PasswordHandler(ctx, string(password)); !ok { + err := ensureNoPKInPermissions(ctx) + if err != nil { + return ctx.Permissions().Permissions, err + } + ok := srv.PasswordHandler(ctx, string(password)) + if !ok { return ctx.Permissions().Permissions, fmt.Errorf("permission denied") } return ctx.Permissions().Permissions, nil @@ -160,18 +177,33 @@ func (srv *Server) config(ctx Context) *gossh.ServerConfig { } if srv.PublicKeyHandler != nil { config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { + resetPermissions(ctx) applyConnMetadata(ctx, conn) - if ok := srv.PublicKeyHandler(ctx, key); !ok { + err := ensureNoPKInPermissions(ctx) + if err != nil { + return ctx.Permissions().Permissions, err + } + ok := srv.PublicKeyHandler(ctx, key) + if !ok { return ctx.Permissions().Permissions, fmt.Errorf("permission denied") } - ctx.SetValue(ContextKeyPublicKey, key) + + pkStr := base64.StdEncoding.EncodeToString(key.Marshal()) + ctx.Permissions().Permissions.Extensions[permissionsPublicKeyExt] = pkStr + return ctx.Permissions().Permissions, nil } } if srv.KeyboardInteractiveHandler != nil { config.KeyboardInteractiveCallback = func(conn gossh.ConnMetadata, challenger gossh.KeyboardInteractiveChallenge) (*gossh.Permissions, error) { + resetPermissions(ctx) applyConnMetadata(ctx, conn) - if ok := srv.KeyboardInteractiveHandler(ctx, challenger); !ok { + ok := srv.KeyboardInteractiveHandler(ctx, challenger) + err := ensureNoPKInPermissions(ctx) + if err != nil { + return ctx.Permissions().Permissions, err + } + if !ok { return ctx.Permissions().Permissions, fmt.Errorf("permission denied") } return ctx.Permissions().Permissions, nil @@ -303,6 +335,35 @@ func (srv *Server) HandleConn(newConn net.Conn) { return } + if sshConn.Permissions != nil { + // Now that the connection was authed, if the permissionsPublicKeyExt was + // attached, we need to re-parse it as a public key. + if keyData, ok := sshConn.Permissions.Extensions[permissionsPublicKeyExt]; ok { + decodedData, err := base64.StdEncoding.DecodeString(keyData) + if err != nil { + if srv.ConnectionFailedCallback != nil { + srv.ConnectionFailedCallback(conn, err) + } + return + } + + key, err := gossh.ParsePublicKey(decodedData) + if err != nil { + if srv.ConnectionFailedCallback != nil { + srv.ConnectionFailedCallback(conn, err) + } + return + } + + ctx.SetValue(ContextKeyPublicKey, key) + } + } + + // Additionally, now that the connection was authed, we can take the + // permissions off of the gossh.Conn and re-attach them to the Permissions + // object stored in the Context. + ctx.Permissions().Permissions = sshConn.Permissions + srv.trackConn(sshConn, true) defer srv.trackConn(sshConn, false)