Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 5 additions & 57 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
"os/user"
"runtime"
"sort"
"strconv"
"strings"
Expand Down Expand Up @@ -502,18 +500,18 @@ func Serve(opts *ServeConfig) {
}
select {
case <-ctx.Done():
// Cancellation. We can stop the server by closing the listener.
// This isn't graceful at all but this is currently only used by
// tests and its our only way to stop.
_ = listener.Close()

// If this is a grpc server, then we also ask the server itself to
// end which will kill all connections. There isn't an easy way to do
// this for net/rpc currently but net/rpc is more and more unused.
if s, ok := server.(*GRPCServer); ok {
s.Stop()
}

// Cancellation. We can stop the server by closing the listener.
// This isn't graceful at all but this is currently only used by
// tests and its our only way to stop.
_ = listener.Close()

// Wait for the server itself to shut down
<-doneCh

Expand All @@ -525,56 +523,6 @@ func Serve(opts *ServeConfig) {
}
}

func serverListener(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
if runtime.GOOS == "windows" {
return serverListener_tcp()
}

return serverListener_unix(unixSocketCfg)
}

func serverListener_tcp() (net.Listener, error) {
envMinPort := os.Getenv("PLUGIN_MIN_PORT")
envMaxPort := os.Getenv("PLUGIN_MAX_PORT")

var minPort, maxPort int64
var err error

switch {
case len(envMinPort) == 0:
minPort = 0
default:
minPort, err = strconv.ParseInt(envMinPort, 10, 32)
if err != nil {
return nil, fmt.Errorf("couldn't get value from PLUGIN_MIN_PORT: %v", err)
}
}

switch {
case len(envMaxPort) == 0:
maxPort = 0
default:
maxPort, err = strconv.ParseInt(envMaxPort, 10, 32)
if err != nil {
return nil, fmt.Errorf("couldn't get value from PLUGIN_MAX_PORT: %v", err)
}
}

if minPort > maxPort {
return nil, fmt.Errorf("PLUGIN_MIN_PORT value of %d is greater than PLUGIN_MAX_PORT value of %d", minPort, maxPort)
}

for port := minPort; port <= maxPort; port++ {
address := fmt.Sprintf("127.0.0.1:%d", port)
listener, err := net.Listen("tcp", address)
if err == nil {
return listener, nil
}
}

return nil, errors.New("couldn't bind plugin TCP listener")
}

func serverListener_unix(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
tf, err := os.CreateTemp(unixSocketCfg.socketDir, "plugin")
if err != nil {
Expand Down
26 changes: 21 additions & 5 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"log"
"net"
"os"
"path"
"path/filepath"
"runtime"
"strings"
"testing"
Expand Down Expand Up @@ -68,6 +68,21 @@ func TestServer_testMode(t *testing.T) {
if err := client.Ping(); err != nil {
t.Fatalf("should not err: %s", err)
}
// Grab the impl
raw, err := client.Dispense("test")
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}

tester, ok := raw.(testInterface)
if !ok {
t.Fatalf("bad: %#v", raw)
}

n := tester.Double(3)
if n != 6 {
t.Fatal("invalid response", n)
}

// Kill which should do nothing
c.Kill()
Expand Down Expand Up @@ -309,9 +324,10 @@ func TestServer_testStdLogger(t *testing.T) {

func TestUnixSocketDir(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("go-plugin doesn't support unix sockets on Windows")
if !isSupportUnix() {
t.Skip("go-plugin doesn't support unix sockets on Windows")
}
}

tmpDir := t.TempDir()
t.Setenv(EnvUnixSocketDir, tmpDir)

Expand Down Expand Up @@ -344,8 +360,8 @@ func TestUnixSocketDir(t *testing.T) {
t.Fatal("should've received reattach")
}

actualDir := path.Clean(path.Dir(cfg.Addr.String()))
expectedDir := path.Clean(tmpDir)
actualDir := filepath.Clean(filepath.Dir(cfg.Addr.String()))
expectedDir := filepath.Clean(tmpDir)
if actualDir != expectedDir {
t.Fatalf("Expected socket in dir: %s, but was in %s", expectedDir, actualDir)
}
Expand Down
13 changes: 13 additions & 0 deletions server_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build !windows
// +build !windows

package plugin

var serverListener = serverListener_unix

func isSupportUnix() bool {
return true
}
72 changes: 72 additions & 0 deletions server_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

//go:build windows
// +build windows

package plugin

import (
"errors"
"fmt"
"net"
"os"
"strconv"

"golang.org/x/sys/windows"
)

func serverListener(unixSocketCfg UnixSocketConfig) (net.Listener, error) {
if isSupportUnix() {
unixSocketCfg.Group = ""
return serverListener_unix(unixSocketCfg)
}
return serverListener_tcp()
}

func serverListener_tcp() (net.Listener, error) {
envMinPort := os.Getenv("PLUGIN_MIN_PORT")
envMaxPort := os.Getenv("PLUGIN_MAX_PORT")

var minPort, maxPort int64
var err error

switch {
case len(envMinPort) == 0:
minPort = 0
default:
minPort, err = strconv.ParseInt(envMinPort, 10, 32)
if err != nil {
return nil, fmt.Errorf("couldn't get value from PLUGIN_MIN_PORT: %v", err)
}
}

switch {
case len(envMaxPort) == 0:
maxPort = 0
default:
maxPort, err = strconv.ParseInt(envMaxPort, 10, 32)
if err != nil {
return nil, fmt.Errorf("couldn't get value from PLUGIN_MAX_PORT: %v", err)
}
}

if minPort > maxPort {
return nil, fmt.Errorf("PLUGIN_MIN_PORT value of %d is greater than PLUGIN_MAX_PORT value of %d", minPort, maxPort)
}

for port := minPort; port <= maxPort; port++ {
address := fmt.Sprintf("127.0.0.1:%d", port)
listener, err := net.Listen("tcp", address)
if err == nil {
return listener, nil
}
}

return nil, errors.New("couldn't bind plugin TCP listener")
}

func isSupportUnix() bool {
major, _, build := windows.RtlGetNtVersionNumbers()
return major >= 10 && build >= 17063
}