Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add remote IP filter to allow a connection from remote kms #692

Merged
merged 5 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/ostracon/commands/show_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func showValidator(cmd *cobra.Command, args []string, config *cfg.Config) error
if err != nil {
return err
}
pv, err = node.CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, chainID, logger)
pv, err = node.CreateAndStartPrivValidatorSocketClient(config, chainID, logger)
if err != nil {
return err
}
Expand Down
2 changes: 2 additions & 0 deletions cmd/ostracon/commands/show_validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package commands
import (
"bytes"
"os"
"strings"
"sync"
"testing"

Expand Down Expand Up @@ -79,6 +80,7 @@ func TestShowValidatorWithKMS(t *testing.T) {
}
privval.WithMockKMS(t, dir, chainID, func(addr string, privKey crypto.PrivKey) {
config.PrivValidatorListenAddr = addr
config.PrivValidatorRemoteAddr = addr[:strings.Index(addr, ":")]
require.NoFileExists(t, config.PrivValidatorKeyFile())
output, err := captureStdout(func() {
err := showValidator(ShowValidatorCmd, nil, config)
Expand Down
6 changes: 6 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,14 @@ type BaseConfig struct { //nolint: maligned

// TCP or UNIX socket address for Ostracon to listen on for
// connections from an external PrivValidator process
// example) tcp://0.0.0.0:26659
PrivValidatorListenAddr string `mapstructure:"priv_validator_laddr"`

// Validator's remote address(without port) to allow a connection
// ostracon only allow a connection from this address
// example) 10.0.0.7
PrivValidatorRemoteAddr string `mapstructure:"priv_validator_raddr"`
ulbqb marked this conversation as resolved.
Show resolved Hide resolved

// A JSON file containing the private key to use for p2p authenticated encryption
NodeKey string `mapstructure:"node_key_file"`

Expand Down
6 changes: 6 additions & 0 deletions config/toml.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,14 @@ priv_validator_state_file = "{{ js .BaseConfig.PrivValidatorState }}"

# TCP or UNIX socket address for Ostracon to listen on for
# connections from an external PrivValidator process
# example) tcp://0.0.0.0:26659
priv_validator_laddr = "{{ .BaseConfig.PrivValidatorListenAddr }}"

# Validator's remote address to allow a connection
# ostracon only allow a connection from this address
# example) 127.0.0.1
priv_validator_raddr = "127.0.0.1"

# Path to the JSON file containing the private key to use for node authentication in the p2p protocol
node_key_file = "{{ js .BaseConfig.NodeKey }}"

Expand Down
10 changes: 3 additions & 7 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ func NewNode(config *cfg.Config,
// external signing process.
if config.PrivValidatorListenAddr != "" {
// FIXME: we should start services inside OnStart
privValidator, err = CreateAndStartPrivValidatorSocketClient(config.PrivValidatorListenAddr, genDoc.ChainID, logger)
privValidator, err = CreateAndStartPrivValidatorSocketClient(config, genDoc.ChainID, logger)
if err != nil {
return nil, fmt.Errorf("error with private validator socket client: %w", err)
}
Expand Down Expand Up @@ -1523,12 +1523,8 @@ func saveGenesisDoc(db dbm.DB, genDoc *types.GenesisDoc) error {
return nil
}

func CreateAndStartPrivValidatorSocketClient(
listenAddr,
chainID string,
logger log.Logger,
) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(listenAddr, logger)
func CreateAndStartPrivValidatorSocketClient(config *cfg.Config, chainID string, logger log.Logger) (types.PrivValidator, error) {
pve, err := privval.NewSignerListener(logger, config.PrivValidatorListenAddr, config.PrivValidatorRemoteAddr)
if err != nil {
return nil, fmt.Errorf("failed to start private validator: %w", err)
}
Expand Down
6 changes: 6 additions & 0 deletions node/node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,17 @@ func TestNodeSetAppVersion(t *testing.T) {
}

func TestNodeSetPrivValTCP(t *testing.T) {
address := testFreeAddr(t)
addr := "tcp://" + testFreeAddr(t)

config := cfg.ResetTestRoot("node_priv_val_tcp_test")
defer os.RemoveAll(config.RootDir)
config.BaseConfig.PrivValidatorListenAddr = addr
addrPart, _, err := net.SplitHostPort(address)
if err != nil {
return
}
config.BaseConfig.PrivValidatorRemoteAddr = addrPart

dialer := privval.DialTCPFn(addr, 100*time.Millisecond, ed25519.GenPrivKey())
dialerEndpoint := privval.NewSignerDialerEndpoint(
Expand Down
8 changes: 8 additions & 0 deletions privval/internal/conn_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package internal

import "net"

type ConnectionFilter interface {
Filter(addr net.Addr) net.Addr
String() string
}
44 changes: 44 additions & 0 deletions privval/internal/ip_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package internal

import (
"fmt"
"github.com/Finschia/ostracon/libs/log"
"net"
)

type IpFilter struct {
allowAddr string
log log.Logger
}

func NewIpFilter(addr string, l log.Logger) *IpFilter {
return &IpFilter{
allowAddr: addr,
log: l,
}
}

func (f *IpFilter) Filter(addr net.Addr) net.Addr {
if f.isAllowedAddr(addr) {
return addr
}
return nil
}

func (f *IpFilter) String() string {
return f.allowAddr
}

func (f *IpFilter) isAllowedAddr(addr net.Addr) bool {
if len(f.allowAddr) == 0 {
return false
}
hostAddr, _, err := net.SplitHostPort(addr.String())
if err != nil {
if f.log != nil {
f.log.Error(fmt.Sprintf("IpFilter: can't split host and port from addr.String()=%s", addr.String()))
}
return false

Check warning on line 41 in privval/internal/ip_filter.go

View check run for this annotation

Codecov / codecov/patch

privval/internal/ip_filter.go#L38-L41

Added lines #L38 - L41 were not covered by tests
}
return f.allowAddr == hostAddr
}
91 changes: 91 additions & 0 deletions privval/internal/ip_filter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package internal

import (
"github.com/stretchr/testify/assert"
"net"
"testing"
)

type addrStub struct {
address string
}

func (a addrStub) Network() string {
return ""
}

func (a addrStub) String() string {
return a.address
}

func TestFilterRemoteConnectionByIP(t *testing.T) {
type fields struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}
tests := []struct {
name string
fields fields
}{
{
"should allow correct ip",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"127.0.0.1", addrStub{"127.0.0.1:45678"}, addrStub{"127.0.0.1:45678"}},
},
{
"should not allow different ip",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"127.0.0.1", addrStub{"10.0.0.2:45678"}, nil},
},
{
"should works for IPv6 with correct ip",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"2001:db8::1", addrStub{"[2001:db8::1]:80"}, addrStub{"[2001:db8::1]:80"}},
},
{
"should works for IPv6 with incorrect ip",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"2001:db8::2", addrStub{"[2001:db8::1]:80"}, nil},
},
{
"empty allowIP should deny all",
struct {
allowIP string
remoteAddr net.Addr
expected net.Addr
}{"", addrStub{"127.0.0.1:45678"}, nil},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cut := NewIpFilter(tt.fields.allowIP, nil)
assert.Equalf(t, tt.fields.expected, cut.Filter(tt.fields.remoteAddr), tt.name)
})
}
}

func TestIpFilterShouldSetAllowAddress(t *testing.T) {
expected := "192.168.0.1"

cut := NewIpFilter(expected, nil)

assert.Equal(t, expected, cut.allowAddr)
}

func TestIpFilterStringShouldReturnsIP(t *testing.T) {
expected := "127.0.0.1"
assert.Equal(t, expected, NewIpFilter(expected, nil).String())
}
19 changes: 19 additions & 0 deletions privval/internal/null_object_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package internal

import "net"

// NullObject is null object pattern. It does nothing
type NullObject struct {
}

func NewNullObject() *NullObject {
return &NullObject{}
}

func (n NullObject) Filter(addr net.Addr) net.Addr {
return addr
}

func (n NullObject) String() string {
return "NullObject"
}
40 changes: 40 additions & 0 deletions privval/internal/null_object_filter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package internal

import (
"github.com/stretchr/testify/assert"
"net"
"reflect"
"testing"
)

func TestNullObject_filter(t *testing.T) {
stubInput := addrStub{}
tests := []struct {
name string
addr net.Addr
want net.Addr
}{
{
name: "null object does nothing, returns what it receives",
addr: stubInput,
want: stubInput,
},
{
name: "null object does nothing, returns nil it receives nil",
addr: nil,
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
n := NewNullObject()
if got := n.Filter(tt.addr); !reflect.DeepEqual(got, tt.want) {
t.Errorf("Filter() = %v, want %v", got, tt.want)
}
})
}
}

func TestNullObjectString(t *testing.T) {
assert.Equal(t, "NullObject", NewNullObject().String())
}
28 changes: 28 additions & 0 deletions privval/signer_listener_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import (
"fmt"
"github.com/Finschia/ostracon/privval/internal"
"net"
"time"

Expand All @@ -24,6 +25,19 @@
return func(sl *SignerListenerEndpoint) { sl.signerEndpoint.timeoutReadWrite = timeout }
}

// SignerListenerEndpointAllowAddress sets the address to allow
// connections from the only allowed address
//
func SignerListenerEndpointAllowAddress(protocol string, addr string) SignerListenerEndpointOption {
return func(sl *SignerListenerEndpoint) {
if protocol == "tcp" || len(protocol) == 0 {
sl.connFilter = internal.NewIpFilter(addr, sl.Logger)
return
}
sl.connFilter = internal.NewNullObject()
}
}

// SignerListenerEndpoint listens for an external process to dial in and keeps
// the connection alive by dropping and reconnecting.
//
Expand All @@ -41,6 +55,7 @@
pingInterval time.Duration

instanceMtx tmsync.Mutex // Ensures instance public methods access, i.e. SendRequest
connFilter internal.ConnectionFilter
}

// NewSignerListenerEndpoint returns an instance of SignerListenerEndpoint.
Expand Down Expand Up @@ -186,6 +201,12 @@
{
conn, err := sl.acceptNewConnection()
if err == nil {
remoteAddr := conn.RemoteAddr()
if sl.filter(remoteAddr) == nil {
sl.Logger.Info(fmt.Sprintf("SignerListener: deny a connection request from remote address=%s, expected=%s", remoteAddr, sl.connFilter))
conn.Close()
continue

Check warning on line 208 in privval/signer_listener_endpoint.go

View check run for this annotation

Codecov / codecov/patch

privval/signer_listener_endpoint.go#L206-L208

Added lines #L206 - L208 were not covered by tests
}
sl.Logger.Info("SignerListener: Connected")

// We have a good connection, wait for someone that needs one otherwise cancellation
Expand All @@ -207,6 +228,13 @@
}
}

func (sl *SignerListenerEndpoint) filter(addr net.Addr) net.Addr {
if sl.connFilter == nil {
return addr
}
return sl.connFilter.Filter(addr)

Check warning on line 235 in privval/signer_listener_endpoint.go

View check run for this annotation

Codecov / codecov/patch

privval/signer_listener_endpoint.go#L235

Added line #L235 was not covered by tests
}

func (sl *SignerListenerEndpoint) pingLoop() {
for {
select {
Expand Down
13 changes: 13 additions & 0 deletions privval/signer_listener_endpoint_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package privval

import (
"github.com/Finschia/ostracon/privval/internal"
"net"
"testing"
"time"
Expand Down Expand Up @@ -213,3 +214,15 @@ func getMockEndpoints(

return listenerEndpoint, dialerEndpoint
}

func TestSignerListenerEndpointAllowAddressSetIpFilterForTCP(t *testing.T) {
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("tcp", "127.0.0.1"))
_, ok := cut.connFilter.(*internal.IpFilter)
assert.True(t, ok)
}

func TestSignerListenerEndpointAllowAddressSetNullObjectFilterForUDS(t *testing.T) {
cut := NewSignerListenerEndpoint(nil, nil, SignerListenerEndpointAllowAddress("unix", "/mnt/uds/sock01"))
_, ok := cut.connFilter.(*internal.NullObject)
assert.True(t, ok)
}
Loading
Loading