WireGuardNT: Tunnels cannot be "nested"

David Lönnhager dv.lnh.d at gmail.com
Thu Sep 16 09:32:34 UTC 2021


The patch below does what I want, though I suspect it has problems. In effect,
it lets the network stack choose the source of the endpoint socket implicitly,
instead of the driver performing a route lookup manually. I expected that
omitting IP_PKTINFO after a route change might prevent the socket from being
rebound correctly, but that does not seem to happen in practice.
Feedback would be appreciated.

To expand slightly on my earlier comment: what we want is for traffic to a
peer's endpoint to go through the tunnel *if no other route can be used*. With
wireguard-go and other implementations, we can do this in two steps. First, we
use WFP (or nftables) to drop traffic that would bypass the tunnel. Second, we
add a route to that endpoint via the tunnel interface. This way we can create
"multihop" tunnels. wireguard-nt, however, ignores this route: its manual route
lookup skips every route whose interface is the tunnel's own interface.

Code for reproducing the issue can be provided if it would be helpful.

David

---
driver/peer.h | 1 -
driver/socket.c | 172 ++++++++++++------------------------------------
2 files changed, 41 insertions(+), 132 deletions(-)

diff --git a/driver/peer.h b/driver/peer.h
index d5d14d7..27a81e9 100644
--- a/driver/peer.h
+++ b/driver/peer.h
@@ -33,7 +33,6 @@ typedef struct _ENDPOINT
};
};
UINT32 RoutingGeneration;
- UINT32 UpdateGeneration;
} ENDPOINT;
typedef enum _HANDSHAKE_TX_ACTION
diff --git a/driver/socket.c b/driver/socket.c
index 11e402b..854ec65 100644
--- a/driver/socket.c
+++ b/driver/socket.c
@@ -173,114 +173,6 @@ CidrMaskMatchV6(_In_ CONST IN6_ADDR *Addr, _In_
CONST IP_ADDRESS_PREFIX *Prefix)
((UINT32 *)&Prefix->Prefix.Ipv6.sin6_addr)[WholeParts];
}
-_IRQL_requires_max_(PASSIVE_LEVEL)
-_IRQL_raises_(DISPATCH_LEVEL)
-_Acquires_shared_lock_(Peer->EndpointLock)
-_Requires_lock_not_held_(Peer->EndpointLock)
-static NTSTATUS
-SocketResolvePeerEndpoint(_Inout_ WG_PEER *Peer, _Out_ _At_(*Irql,
_IRQL_saves_) KIRQL *Irql)
-{
- *Irql = ExAcquireSpinLockShared(&Peer->EndpointLock);
-retryWhileHoldingSharedLock:
- if ((Peer->Endpoint.Addr.si_family == AF_INET &&
- Peer->Endpoint.RoutingGeneration ==
(UINT32)ReadNoFence(&RoutingGenerationV4) &&
- Peer->Endpoint.Src4.ipi_ifindex && Peer->Endpoint.Src4.ipi_ifindex
!= Peer->Device->InterfaceIndex) ||
- (Peer->Endpoint.Addr.si_family == AF_INET6 &&
- Peer->Endpoint.RoutingGeneration ==
(UINT32)ReadNoFence(&RoutingGenerationV6) &&
- Peer->Endpoint.Src6.ipi6_ifindex && Peer->Endpoint.Src6.ipi6_ifindex
!= Peer->Device->InterfaceIndex))
- return STATUS_SUCCESS;
-
- SOCKADDR_INET Addr;
- UINT32 UpdateGeneration = Peer->Endpoint.UpdateGeneration;
- RtlCopyMemory(&Addr, &Peer->Endpoint.Addr, sizeof(Addr));
- ExReleaseSpinLockShared(&Peer->EndpointLock, *Irql);
- SOCKADDR_INET SrcAddr = { 0 };
- ULONG BestIndex = 0, BestCidr = 0, BestMetric = ~0UL;
- NET_LUID BestLuid = { 0 };
- MIB_IPFORWARD_TABLE2 *Table;
- NTSTATUS Status = GetIpForwardTable2(Addr.si_family, &Table);
- if (!NT_SUCCESS(Status))
- return Status;
- union
- {
- MIB_IF_ROW2 Interface;
- MIB_IPINTERFACE_ROW IpInterface;
- } *If = MemAllocate(sizeof(*If));
- if (!If)
- return STATUS_INSUFFICIENT_RESOURCES;
- for (ULONG i = 0; i < Table->NumEntries; ++i)
- {
- if (Table->Table[i].InterfaceLuid.Value == Peer->Device->InterfaceLuid.Value)
- continue;
- if (Table->Table[i].DestinationPrefix.PrefixLength < BestCidr)
- continue;
- if (Addr.si_family == AF_INET &&
!CidrMaskMatchV4(&Addr.Ipv4.sin_addr,
&Table->Table[i].DestinationPrefix))
- continue;
- if (Addr.si_family == AF_INET6 &&
!CidrMaskMatchV6(&Addr.Ipv6.sin6_addr,
&Table->Table[i].DestinationPrefix))
- continue;
- If->Interface = (MIB_IF_ROW2){ .InterfaceLuid =
Table->Table[i].InterfaceLuid };
- if (!NT_SUCCESS(GetIfEntry2(&If->Interface)) ||
If->Interface.OperStatus != IfOperStatusUp)
- continue;
- If->IpInterface =
- (MIB_IPINTERFACE_ROW){ .Family = Addr.si_family, .InterfaceLuid =
Table->Table[i].InterfaceLuid };
- if (!NT_SUCCESS(GetIpInterfaceEntry(&If->IpInterface)))
- continue;
- ULONG Metric = Table->Table[i].Metric + If->IpInterface.Metric;
- if (Table->Table[i].DestinationPrefix.PrefixLength == BestCidr &&
Metric > BestMetric)
- continue;
- BestCidr = Table->Table[i].DestinationPrefix.PrefixLength;
- BestMetric = Metric;
- BestIndex = Table->Table[i].InterfaceIndex;
- BestLuid = Table->Table[i].InterfaceLuid;
- }
- MemFree(If);
- if (Table->NumEntries && BestIndex)
- Status = GetBestRoute2(&BestLuid, 0, NULL, &Addr, 0,
&Table->Table[0], &SrcAddr);
- FreeMibTable(Table);
- if (!BestIndex)
- return STATUS_BAD_NETWORK_PATH;
- if (!NT_SUCCESS(Status))
- return Status;
-
- *Irql = ExAcquireSpinLockExclusive(&Peer->EndpointLock);
- if (UpdateGeneration != Peer->Endpoint.UpdateGeneration)
- {
- ExReleaseSpinLockExclusiveFromDpcLevel(&Peer->EndpointLock);
- ExAcquireSpinLockSharedAtDpcLevel(&Peer->EndpointLock);
- goto retryWhileHoldingSharedLock;
- }
- if (Peer->Endpoint.Addr.si_family == AF_INET)
- {
- Peer->Endpoint.Cmsg.cmsg_len = WSA_CMSG_LEN(sizeof(Peer->Endpoint.Src4));
- Peer->Endpoint.Cmsg.cmsg_level = IPPROTO_IP;
- Peer->Endpoint.Cmsg.cmsg_type = IP_PKTINFO;
- Peer->Endpoint.Src4.ipi_addr = SrcAddr.Ipv4.sin_addr;
- Peer->Endpoint.Src4.ipi_ifindex = BestIndex;
- Peer->Endpoint.CmsgHack4.cmsg_len = WSA_CMSG_LEN(0);
- Peer->Endpoint.CmsgHack4.cmsg_level = IPPROTO_IP;
- Peer->Endpoint.CmsgHack4.cmsg_type = IP_OPTIONS;
- Peer->Endpoint.RoutingGeneration = ReadNoFence(&RoutingGenerationV4);
- }
- else if (Peer->Endpoint.Addr.si_family == AF_INET6)
- {
- Peer->Endpoint.Cmsg.cmsg_len = WSA_CMSG_LEN(sizeof(Peer->Endpoint.Src6));
- Peer->Endpoint.Cmsg.cmsg_level = IPPROTO_IPV6;
- Peer->Endpoint.Cmsg.cmsg_type = IPV6_PKTINFO;
- Peer->Endpoint.Src6.ipi6_addr = SrcAddr.Ipv6.sin6_addr;
- Peer->Endpoint.Src6.ipi6_ifindex = BestIndex;
- Peer->Endpoint.CmsgHack6.cmsg_len = WSA_CMSG_LEN(0);
- Peer->Endpoint.CmsgHack6.cmsg_level = IPPROTO_IPV6;
- Peer->Endpoint.CmsgHack6.cmsg_type = IPV6_RTHDR;
- Peer->Endpoint.RoutingGeneration = ReadNoFence(&RoutingGenerationV6);
- }
- ++Peer->Endpoint.UpdateGeneration, ++UpdateGeneration;
- ExReleaseSpinLockExclusiveFromDpcLevel(&Peer->EndpointLock);
- ExAcquireSpinLockSharedAtDpcLevel(&Peer->EndpointLock);
- if (Peer->Endpoint.UpdateGeneration != UpdateGeneration)
- goto retryWhileHoldingSharedLock;
- return STATUS_SUCCESS;
-}
-
#pragma warning(suppress : 28194) /* `Nbl` is aliased in Ctx->Nbl or
freed on failure. */
#pragma warning(suppress : 28167) /* IRQL is either not raised on
SocketResolvePeerEndpoint failure, or \
restored by ExReleaseSpinLockShared */
@@ -320,10 +212,7 @@ SocketSendNblsToPeer(WG_PEER *Peer,
NET_BUFFER_LIST *First, BOOLEAN *AllKeepaliv
Ctx->Wg = Peer->Device;
IoInitializeIrp(&Ctx->Irp, sizeof(Ctx->IrpBuffer), 1);
IoSetCompletionRoutine(&Ctx->Irp, NblSendComplete, Ctx, TRUE, TRUE, TRUE);
- KIRQL Irql;
- Status = SocketResolvePeerEndpoint(Peer, &Irql);
- if (!NT_SUCCESS(Status))
- goto cleanupCtx;
+ KIRQL Irql = ExAcquireSpinLockShared(&Peer->EndpointLock);
SOCKET *Socket = NULL;
RcuReadLockAtDpcLevel();
if (Peer->Endpoint.Addr.si_family == AF_INET)
@@ -340,13 +229,24 @@ SocketSendNblsToPeer(WG_PEER *Peer,
NET_BUFFER_LIST *First, BOOLEAN *AllKeepaliv
if (NoWskSendMessages)
WskSendMessages = PolyfilledWskSendMessages;
#endif
+ ULONG CmsgLen = 0;
+ WSACMSGHDR *Cmsg = NULL;
+ if ((Peer->Endpoint.Addr.si_family == AF_INET &&
+ Peer->Endpoint.RoutingGeneration ==
(UINT32)ReadNoFence(&RoutingGenerationV4) &&
+ Peer->Endpoint.Src4.ipi_ifindex) ||
+ (Peer->Endpoint.Addr.si_family == AF_INET6 &&
+ Peer->Endpoint.RoutingGeneration ==
(UINT32)ReadNoFence(&RoutingGenerationV6) &&
+ Peer->Endpoint.Src6.ipi6_ifindex)) {
+ CmsgLen = (ULONG)WSA_CMSGDATA_ALIGN(Peer->Endpoint.Cmsg.cmsg_len) +
WSA_CMSG_SPACE(0);
+ Cmsg = &Peer->Endpoint.Cmsg;
+ }
Status = WskSendMessages(
Socket->Sock,
FirstWskBuf,
0,
(PSOCKADDR)&Peer->Endpoint.Addr,
- (ULONG)WSA_CMSGDATA_ALIGN(Peer->Endpoint.Cmsg.cmsg_len) + WSA_CMSG_SPACE(0),
- &Peer->Endpoint.Cmsg,
+ CmsgLen,
+ Cmsg,
&Ctx->Irp);
RcuReadUnlockFromDpcLevel();
ExReleaseSpinLockShared(&Peer->EndpointLock, Irql);
@@ -364,7 +264,6 @@ SocketSendNblsToPeer(WG_PEER *Peer,
NET_BUFFER_LIST *First, BOOLEAN *AllKeepaliv
cleanupRcuLock:
RcuReadUnlockFromDpcLevel();
ExReleaseSpinLockShared(&Peer->EndpointLock, Irql);
-cleanupCtx:
ExFreeToLookasideListEx(&SocketSendCtxCache, Ctx);
cleanupNbls:
FreeSendNetBufferList(Peer->Device, First, 0);
@@ -390,10 +289,7 @@ SocketSendBufferToPeer(WG_PEER *Peer, CONST VOID
*Buffer, ULONG Len)
Ctx->Wg = Peer->Device;
IoInitializeIrp(&Ctx->Irp, sizeof(Ctx->IrpBuffer), 1);
IoSetCompletionRoutine(&Ctx->Irp, BufferSendComplete, Ctx, TRUE, TRUE, TRUE);
- KIRQL Irql;
- Status = SocketResolvePeerEndpoint(Peer, &Irql);
- if (!NT_SUCCESS(Status))
- goto cleanupMdl;
+ KIRQL Irql = ExAcquireSpinLockShared(&Peer->EndpointLock);
SOCKET *Socket = NULL;
RcuReadLockAtDpcLevel();
if (Peer->Endpoint.Addr.si_family == AF_INET)
@@ -405,14 +301,25 @@ SocketSendBufferToPeer(WG_PEER *Peer, CONST VOID
*Buffer, ULONG Len)
Status = STATUS_NETWORK_UNREACHABLE;
goto cleanupRcuLock;
}
+ ULONG CmsgLen = 0;
+ WSACMSGHDR *Cmsg = NULL;
+ if ((Peer->Endpoint.Addr.si_family == AF_INET &&
+ Peer->Endpoint.RoutingGeneration ==
(UINT32)ReadNoFence(&RoutingGenerationV4) &&
+ Peer->Endpoint.Src4.ipi_ifindex) ||
+ (Peer->Endpoint.Addr.si_family == AF_INET6 &&
+ Peer->Endpoint.RoutingGeneration ==
(UINT32)ReadNoFence(&RoutingGenerationV6) &&
+ Peer->Endpoint.Src6.ipi6_ifindex)) {
+ CmsgLen = (ULONG)WSA_CMSGDATA_ALIGN(Peer->Endpoint.Cmsg.cmsg_len) +
WSA_CMSG_SPACE(0);
+ Cmsg = &Peer->Endpoint.Cmsg;
+ }
Status = ((WSK_PROVIDER_DATAGRAM_DISPATCH *)Socket->Sock->Dispatch)
->WskSendTo(
Socket->Sock,
&Ctx->Buffer,
0,
(PSOCKADDR)&Peer->Endpoint.Addr,
- (ULONG)WSA_CMSGDATA_ALIGN(Peer->Endpoint.Cmsg.cmsg_len) + WSA_CMSG_SPACE(0),
- &Peer->Endpoint.Cmsg,
+ CmsgLen,
+ Cmsg,
&Ctx->Irp);
RcuReadUnlockFromDpcLevel();
ExReleaseSpinLockShared(&Peer->EndpointLock, Irql);
@@ -423,7 +330,6 @@ SocketSendBufferToPeer(WG_PEER *Peer, CONST VOID
*Buffer, ULONG Len)
cleanupRcuLock:
RcuReadUnlockFromDpcLevel();
ExReleaseSpinLockShared(&Peer->EndpointLock, Irql);
-cleanupMdl:
MemFreeDataAndMdlChain(Ctx->Buffer.Mdl);
cleanupCtx:
ExFreeToLookasideListEx(&SocketSendCtxCache, Ctx);
@@ -452,9 +358,6 @@ SocketSendBufferAsReplyToNbl(WG_DEVICE *Wg, CONST
NET_BUFFER_LIST *InNbl, CONST
if (!NT_SUCCESS(Status))
goto cleanupMdl;
Status = STATUS_BAD_NETWORK_PATH;
- if ((Endpoint.Addr.si_family == AF_INET && Endpoint.Src4.ipi_ifindex
== Wg->InterfaceIndex) ||
- (Endpoint.Addr.si_family == AF_INET6 && Endpoint.Src6.ipi6_ifindex
== Wg->InterfaceIndex))
- goto cleanupMdl;
KIRQL Irql = RcuReadLock();
SOCKET *Socket = NULL;
if (Endpoint.Addr.si_family == AF_INET)
@@ -466,14 +369,25 @@ SocketSendBufferAsReplyToNbl(WG_DEVICE *Wg,
CONST NET_BUFFER_LIST *InNbl, CONST
Status = STATUS_NETWORK_UNREACHABLE;
goto cleanupRcuLock;
}
+ ULONG CmsgLen = 0;
+ WSACMSGHDR *Cmsg = NULL;
+ if ((Endpoint.Addr.si_family == AF_INET &&
+ Endpoint.RoutingGeneration == (UINT32)ReadNoFence(&RoutingGenerationV4) &&
+ Endpoint.Src4.ipi_ifindex) ||
+ (Endpoint.Addr.si_family == AF_INET6 &&
+ Endpoint.RoutingGeneration == (UINT32)ReadNoFence(&RoutingGenerationV6) &&
+ Endpoint.Src6.ipi6_ifindex)) {
+ CmsgLen = (ULONG)WSA_CMSGDATA_ALIGN(Endpoint.Cmsg.cmsg_len) +
WSA_CMSG_SPACE(0);
+ Cmsg = &Endpoint.Cmsg;
+ }
Status = ((WSK_PROVIDER_DATAGRAM_DISPATCH *)Socket->Sock->Dispatch)
->WskSendTo(
Socket->Sock,
&Ctx->Buffer,
0,
(PSOCKADDR)&Endpoint.Addr,
- (ULONG)WSA_CMSGDATA_ALIGN(Endpoint.Cmsg.cmsg_len) + WSA_CMSG_SPACE(0),
- &Endpoint.Cmsg,
+ CmsgLen,
+ Cmsg,
&Ctx->Irp);
RcuReadUnlock(Irql);
return Status;
@@ -600,7 +514,6 @@ SocketSetPeerEndpoint(WG_PEER *Peer, CONST
ENDPOINT *Endpoint)
if (Endpoint->Addr.si_family == AF_INET)
{
Peer->Endpoint.Addr.Ipv4 = Endpoint->Addr.Ipv4;
- if (Endpoint->Src4.ipi_ifindex != Peer->Device->InterfaceIndex)
{
Peer->Endpoint.Cmsg = Endpoint->Cmsg;
Peer->Endpoint.Src4 = Endpoint->Src4;
@@ -610,7 +523,6 @@ SocketSetPeerEndpoint(WG_PEER *Peer, CONST
ENDPOINT *Endpoint)
else if (Endpoint->Addr.si_family == AF_INET6)
{
Peer->Endpoint.Addr.Ipv6 = Endpoint->Addr.Ipv6;
- if (Endpoint->Src6.ipi6_ifindex != Peer->Device->InterfaceIndex)
{
Peer->Endpoint.Cmsg = Endpoint->Cmsg;
Peer->Endpoint.Src6 = Endpoint->Src6;
@@ -620,7 +532,6 @@ SocketSetPeerEndpoint(WG_PEER *Peer, CONST
ENDPOINT *Endpoint)
else
goto out;
Peer->Endpoint.RoutingGeneration = Endpoint->RoutingGeneration;
- ++Peer->Endpoint.UpdateGeneration;
out:
ExReleaseSpinLockExclusive(&Peer->EndpointLock, Irql);
}
@@ -643,7 +554,6 @@ SocketClearPeerEndpointSrc(WG_PEER *Peer)
Irql = ExAcquireSpinLockExclusive(&Peer->EndpointLock);
Peer->Endpoint.RoutingGeneration = 0;
- ++Peer->Endpoint.UpdateGeneration;
RtlZeroMemory(&Peer->Endpoint.Src6, sizeof(Peer->Endpoint.Src6));
ExReleaseSpinLockExclusive(&Peer->EndpointLock, Irql);
}
-- 
2.31.1


More information about the WireGuard mailing list