[PATCH] [wireguard-go] Pool for endpoint objects
Natan Elul
elul.natan at gmail.com
Sun Apr 24 09:20:08 UTC 2022
Use a sync.Pool for endpoints to avoid a heap allocation on every receive.
When the Linux bind returns an endpoint, Go allocates it on the heap; with a
sync.Pool those allocations can be reused, which can be dramatically more
efficient.
This patch includes the changes for the Linux bind, and an optimization
for SetEndpointFromPacket that takes the write lock only when the endpoint
has actually changed.
Signed-off-by: Natan Elul <elul.natan at gmail.com>
---
conn/bind_linux.go | 58 +++++++++++++++++++++++++++++++--------
conn/bind_std.go | 15 ++++++++++
conn/bind_windows.go | 14 ++++++++++
conn/bindtest/bindtest.go | 11 ++++++++
conn/conn.go | 5 ++++
device/peer.go | 11 +++++++-
device/pools.go | 1 +
device/receive.go | 1 +
8 files changed, 103 insertions(+), 13 deletions(-)
diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index f11f031..5630481 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -38,6 +38,24 @@ func (endpoint *LinuxSocketEndpoint) Src4()
*ipv4Source { return endpoin
func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 {
return endpoint.dst4() }
func (endpoint *LinuxSocketEndpoint) IsV6() bool {
return endpoint.isV6 }
+// IsEqual reports whether ep is a *LinuxSocketEndpoint with the same address
+// family, destination and cached source as endpoint. Any other Endpoint type,
+// or a nil pointer, compares unequal instead of panicking.
+func (endpoint *LinuxSocketEndpoint) IsEqual(ep Endpoint) bool {
+	other, ok := ep.(*LinuxSocketEndpoint)
+	if !ok || other == nil {
+		return false
+	}
+	if other == endpoint {
+		return true
+	}
+
+	// Snapshot the other endpoint under its own lock first, then compare
+	// under ours. The two mutexes are never held at the same time, so
+	// concurrent a.IsEqual(b) and b.IsEqual(a) calls cannot deadlock.
+	other.mu.Lock()
+	dst, src, isV6 := other.dst, other.src, other.isV6
+	other.mu.Unlock()
+
+	// Hold endpoint.mu while reading our own fields so writers that take
+	// the same lock cannot tear the comparison.
+	endpoint.mu.Lock()
+	defer endpoint.mu.Unlock()
+	return endpoint.isV6 == isV6 && endpoint.dst == dst && endpoint.src == src
+}
+
+// Copy returns a new *LinuxSocketEndpoint with the same dst, src and isV6 and
+// a fresh, unlocked mutex. The fields are read under endpoint.mu, the same
+// lock IsEqual takes, so the copy is never a torn mix of old and new values.
+func (endpoint *LinuxSocketEndpoint) Copy() Endpoint {
+	endpoint.mu.Lock()
+	defer endpoint.mu.Unlock()
+	return &LinuxSocketEndpoint{
+		dst:  endpoint.dst,
+		src:  endpoint.src,
+		isV6: endpoint.isV6,
+	}
+}
+
func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
@@ -58,13 +76,28 @@ func (endpoint *LinuxSocketEndpoint) dst6()
*unix.SockaddrInet6 {
type LinuxSocketBind struct {
// mu guards sock4 and sock6 and the associated fds.
// As long as someone holds mu (read or write), the associated fds are valid.
- mu sync.RWMutex
- sock4 int
- sock6 int
+ mu sync.RWMutex
+ sock4 int
+ sock6 int
+ // epElementsPool recycles *LinuxSocketEndpoint values between receives
+ // (see GetEndpoint and PutEndpoint). It is not guarded by mu: sync.Pool
+ // is safe for concurrent use on its own.
+ epElementsPool sync.Pool
+}
+
+// GetEndpoint takes an endpoint from the pool and zeroes it, so no address
+// left over from a previous receive can survive into the next one. Pooled
+// endpoints are no longer referenced by anyone, so resetting the embedded
+// mutex along with the other fields is safe.
+func (bind *LinuxSocketBind) GetEndpoint() *LinuxSocketEndpoint {
+	end := bind.epElementsPool.Get().(*LinuxSocketEndpoint)
+	*end = LinuxSocketEndpoint{}
+	return end
}
-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1,
sock6: -1} }
-func NewDefaultBind() Bind { return NewLinuxSocketBind() }
+// PutEndpoint returns a received endpoint to the pool for reuse. Anything
+// that is not a non-nil *LinuxSocketEndpoint is dropped, because GetEndpoint
+// type-asserts every pooled value to that type.
+func (bind *LinuxSocketBind) PutEndpoint(endpoint Endpoint) {
+	if ep, ok := endpoint.(*LinuxSocketEndpoint); ok && ep != nil {
+		bind.epElementsPool.Put(ep)
+	}
+}
+
+// NewLinuxSocketBind returns a Bind with both sockets unset (-1) and an
+// endpoint pool that allocates a fresh *LinuxSocketEndpoint when empty.
+func NewLinuxSocketBind() Bind {
+	return &LinuxSocketBind{
+		sock4: -1,
+		sock6: -1,
+		epElementsPool: sync.Pool{
+			New: func() interface{} { return new(LinuxSocketEndpoint) },
+		},
+	}
+}
+
+// NewDefaultBind returns a LinuxSocketBind.
+func NewDefaultBind() Bind { return NewLinuxSocketBind() }
var (
_ Endpoint = (*LinuxSocketEndpoint)(nil)
@@ -224,14 +257,14 @@ func (bind *LinuxSocketBind) Close() error {
}
func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
if bind.sock4 == -1 {
return 0, nil, net.ErrClosed
}
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buf, &end)
- return n, &end, err
+	// Take an endpoint from the pool only once the socket is known to be
+	// open; received endpoints are later handed back with PutEndpoint.
+	end := bind.GetEndpoint()
+	n, err := receive4(bind.sock4, buf, end)
+	return n, end, err
}
func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
@@ -448,11 +481,12 @@ func send4(sock int, end *LinuxSocketEndpoint,
buff []byte) error {
// clear src and retry
if err == unix.EINVAL {
- end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
end.mu.Lock()
+	// Clear the stale cached source while holding end.mu. The write is then
+	// serialized with IsEqual, which reads src under the same lock, and the
+	// endpoint itself stops carrying the rejected source into later sends.
+	// Clearing it on a temporary copy would leave end unchanged.
+	end.ClearSrc()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.mu.Unlock()
}
return err
diff --git a/conn/bind_std.go b/conn/bind_std.go
index e0f6cdd..4306f7f 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -27,8 +27,23 @@ type StdNetBind struct {
func NewStdNetBind() Bind { return &StdNetBind{} }
+// PutEndpoint implements Bind. StdNetBind does not pool endpoints, so this
+// is a no-op.
+func (bind *StdNetBind) PutEndpoint(endpoint Endpoint) {
+}
+
type StdNetEndpoint netip.AddrPort
+// IsEqual reports whether endpoint is a *StdNetEndpoint with the same address
+// (including zone) and port as e. Any other Endpoint type, or a nil pointer,
+// compares unequal instead of panicking.
+func (e *StdNetEndpoint) IsEqual(endpoint Endpoint) bool {
+	other, ok := endpoint.(*StdNetEndpoint)
+	if !ok || other == nil {
+		return false
+	}
+	// netip.AddrPort is comparable; == agrees with comparing Port() and
+	// Addr().Compare(...) == 0.
+	return *e == *other
+}
+
+// Copy returns an independent *StdNetEndpoint with the same address and
+// port. netip.AddrPort is a plain value type, so copying it needs no
+// formatting, parsing or error handling.
+func (e *StdNetEndpoint) Copy() Endpoint {
+	dup := *e
+	return &dup
+}
+
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = (*StdNetEndpoint)(nil)
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index 9268bc1..d0a1d66 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -77,6 +77,9 @@ type WinRingBind struct {
isOpen uint32
}
+// PutEndpoint implements Bind. WinRingBind keeps no endpoint pool, so the
+// endpoint is simply left to the garbage collector.
+func (bind *WinRingBind) PutEndpoint(endpoint Endpoint) {}
+
func NewDefaultBind() Bind { return NewWinRingBind() }
func NewWinRingBind() Bind {
@@ -131,6 +134,17 @@ func (*WinRingBind) ParseEndpoint(s string)
(Endpoint, error) {
func (*WinRingEndpoint) ClearSrc() {}
+// IsEqual reports whether endpoint is a *WinRingEndpoint with the same address
+// family and address data as e. Any other Endpoint type, or a nil pointer,
+// compares unequal instead of panicking.
+func (e *WinRingEndpoint) IsEqual(endpoint Endpoint) bool {
+	other, ok := endpoint.(*WinRingEndpoint)
+	if !ok || other == nil {
+		return false
+	}
+	return other.family == e.family && other.data == e.data
+}
+
+// Copy returns a new *WinRingEndpoint with the same family and data.
+func (e *WinRingEndpoint) Copy() Endpoint {
+	return &WinRingEndpoint{
+		family: e.family,
+		data:   e.data,
+	}
+}
+
func (e *WinRingEndpoint) DstIP() netip.Addr {
switch e.family {
case windows.AF_INET:
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index b38cae6..a04ccc9 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -23,8 +23,19 @@ type ChannelBind struct {
target4, target6 ChannelEndpoint
}
+// PutEndpoint implements conn.Bind. ChannelBind does not pool endpoints, so
+// this is a no-op.
+func (c *ChannelBind) PutEndpoint(endpoint conn.Endpoint) {
+}
+
type ChannelEndpoint uint16
+// IsEqual reports whether endpoint identifies the same channel as c. Both
+// ChannelEndpoint and *ChannelEndpoint satisfy conn.Endpoint, so both are
+// accepted; any other type compares unequal instead of panicking.
+func (c ChannelEndpoint) IsEqual(endpoint conn.Endpoint) bool {
+	switch other := endpoint.(type) {
+	case ChannelEndpoint:
+		return c == other
+	case *ChannelEndpoint:
+		return other != nil && c == *other
+	default:
+		return false
+	}
+}
+
+// Copy returns c itself: a ChannelEndpoint is a plain integer value, so
+// returning it already yields an independent copy.
+func (c ChannelEndpoint) Copy() conn.Endpoint {
+	return c
+}
+
var (
_ conn.Bind = (*ChannelBind)(nil)
_ conn.Endpoint = (*ChannelEndpoint)(nil)
diff --git a/conn/conn.go b/conn/conn.go
index 5a93b2b..c772b25 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -43,6 +43,9 @@ type Bind interface {
// ParseEndpoint creates a new endpoint from a string.
ParseEndpoint(s string) (Endpoint, error)
+
+ // PutEndpoint hands an Endpoint obtained from a receive back to the Bind
+ // so it may be reused. The caller must not use the endpoint afterwards;
+ // anything that needs to keep it must hold a Copy instead. Binds that do
+ // not pool endpoints ignore the call.
+ PutEndpoint(endpoint Endpoint)
}
// BindSocketToInterface is implemented by Bind objects that support being
@@ -70,6 +73,8 @@ type Endpoint interface {
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() netip.Addr
SrcIP() netip.Addr
+ // IsEqual reports whether endpoint refers to the same destination as the
+ // receiver (and, for implementations that cache one, the same source).
+ IsEqual(endpoint Endpoint) bool
+ // Copy returns an independent Endpoint with the same contents, suitable
+ // for retaining after the original is handed back via Bind.PutEndpoint.
+ Copy() Endpoint
}
var (
diff --git a/device/peer.go b/device/peer.go
index 5bd52df..eb1cc41 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -271,7 +271,16 @@ func (peer *Peer) SetEndpointFromPacket(endpoint
conn.Endpoint) {
if peer.disableRoaming {
return
}
+
+	// Fast path: compare under the read lock and skip the write lock when
+	// the packet came from the endpoint we already have. Guard against a
+	// peer with no endpoint yet (nil), which would otherwise panic here.
+	peer.RLock()
+	unchanged := peer.endpoint != nil && peer.endpoint.IsEqual(endpoint)
+	peer.RUnlock()
+	if unchanged {
+		return
+	}
+
peer.Lock()
- peer.endpoint = endpoint
+	// Store a copy: received endpoints are handed back to the Bind's pool
+	// (PutEndpoint) after processing and may then be reused.
+	peer.endpoint = endpoint.Copy()
peer.Unlock()
}
diff --git a/device/pools.go b/device/pools.go
index f40477b..f861c51 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -70,6 +70,7 @@ func (device *Device) GetInboundElement()
*QueueInboundElement {
}
func (device *Device) PutInboundElement(elem *QueueInboundElement) {
+ // Hand the received endpoint back to the Bind for reuse before
+ // clearPointers runs. Anything that must outlive this call (e.g. the
+ // peer's roaming endpoint) has to hold a Copy instead.
+ // NOTE(review): device.net.bind is read here without any lock on
+ // device.net visible in this hunk — confirm the bind cannot be swapped
+ // concurrently while inbound elements are in flight.
+ device.net.bind.PutEndpoint(elem.endpoint)
elem.clearPointers()
device.pool.inboundElements.Put(elem)
}
diff --git a/device/receive.go b/device/receive.go
index cc34498..e6cebcd 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -390,6 +390,7 @@ func (device *Device) RoutineHandshake(id int) {
peer.SendKeepalive()
}
skip:
+ device.net.bind.PutEndpoint(elem.endpoint)
device.PutMessageBuffer(elem.buffer)
}
}
--
2.30.1 (Apple Git-130)
More information about the WireGuard
mailing list