[PATCH] wireguard: convert index_hashtable and pubkey_hashtable into rhashtables

Hamza Mahfooz someguy at effective-light.com
Fri Aug 6 04:43:14 UTC 2021


Commit e7096c131e516 ("net: WireGuard secure network tunnel") notes that
it is desirable to move away from the statically sized hash-table
implementation, so convert index_hashtable and pubkey_hashtable into
rhashtables.

Signed-off-by: Hamza Mahfooz <someguy at effective-light.com>
---
 drivers/net/wireguard/device.c     |   4 +
 drivers/net/wireguard/device.h     |   2 +-
 drivers/net/wireguard/noise.c      |   1 +
 drivers/net/wireguard/noise.h      |   1 +
 drivers/net/wireguard/peer.h       |   2 +-
 drivers/net/wireguard/peerlookup.c | 190 ++++++++++++++---------------
 drivers/net/wireguard/peerlookup.h |  27 ++--
 7 files changed, 112 insertions(+), 115 deletions(-)

diff --git a/drivers/net/wireguard/device.c b/drivers/net/wireguard/device.c
index 551ddaaaf540..3bd43c9481ef 100644
--- a/drivers/net/wireguard/device.c
+++ b/drivers/net/wireguard/device.c
@@ -243,7 +243,9 @@ static void wg_destruct(struct net_device *dev)
 	skb_queue_purge(&wg->incoming_handshakes);
 	free_percpu(dev->tstats);
 	free_percpu(wg->incoming_handshakes_worker);
+	wg_index_hashtable_destroy(wg->index_hashtable);
 	kvfree(wg->index_hashtable);
+	wg_pubkey_hashtable_destroy(wg->peer_hashtable);
 	kvfree(wg->peer_hashtable);
 	mutex_unlock(&wg->device_update_lock);
 
@@ -382,8 +384,10 @@ static int wg_newlink(struct net *src_net, struct net_device *dev,
 err_free_tstats:
 	free_percpu(dev->tstats);
 err_free_index_hashtable:
+	wg_index_hashtable_destroy(wg->index_hashtable);
 	kvfree(wg->index_hashtable);
 err_free_peer_hashtable:
+	wg_pubkey_hashtable_destroy(wg->peer_hashtable);
 	kvfree(wg->peer_hashtable);
 	return ret;
 }
diff --git a/drivers/net/wireguard/device.h b/drivers/net/wireguard/device.h
index 854bc3d97150..24980eb766af 100644
--- a/drivers/net/wireguard/device.h
+++ b/drivers/net/wireguard/device.h
@@ -50,7 +50,7 @@ struct wg_device {
 	struct multicore_worker __percpu *incoming_handshakes_worker;
 	struct cookie_checker cookie_checker;
 	struct pubkey_hashtable *peer_hashtable;
-	struct index_hashtable *index_hashtable;
+	struct rhashtable *index_hashtable;
 	struct allowedips peer_allowedips;
 	struct mutex device_update_lock, socket_update_lock;
 	struct list_head device_list, peer_list;
diff --git a/drivers/net/wireguard/noise.c b/drivers/net/wireguard/noise.c
index c0cfd9b36c0b..d42a0ff2be5d 100644
--- a/drivers/net/wireguard/noise.c
+++ b/drivers/net/wireguard/noise.c
@@ -797,6 +797,7 @@ bool wg_noise_handshake_begin_session(struct noise_handshake *handshake,
 	new_keypair->i_am_the_initiator = handshake->state ==
 					  HANDSHAKE_CONSUMED_RESPONSE;
 	new_keypair->remote_index = handshake->remote_index;
+	new_keypair->entry.index = handshake->entry.index;
 
 	if (new_keypair->i_am_the_initiator)
 		derive_keys(&new_keypair->sending, &new_keypair->receiving,
diff --git a/drivers/net/wireguard/noise.h b/drivers/net/wireguard/noise.h
index c527253dba80..ea705747e4e4 100644
--- a/drivers/net/wireguard/noise.h
+++ b/drivers/net/wireguard/noise.h
@@ -72,6 +72,7 @@ struct noise_handshake {
 
 	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
 	u8 remote_static[NOISE_PUBLIC_KEY_LEN];
+	siphash_key_t skey;
 	u8 remote_ephemeral[NOISE_PUBLIC_KEY_LEN];
 	u8 precomputed_static_static[NOISE_PUBLIC_KEY_LEN];
 
diff --git a/drivers/net/wireguard/peer.h b/drivers/net/wireguard/peer.h
index 76e4d3128ad4..d5403fb7a6a0 100644
--- a/drivers/net/wireguard/peer.h
+++ b/drivers/net/wireguard/peer.h
@@ -48,7 +48,7 @@ struct wg_peer {
 	atomic64_t last_sent_handshake;
 	struct work_struct transmit_handshake_work, clear_peer_work, transmit_packet_work;
 	struct cookie latest_cookie;
-	struct hlist_node pubkey_hash;
+	struct rhash_head pubkey_hash;
 	u64 rx_bytes, tx_bytes;
 	struct timer_list timer_retransmit_handshake, timer_send_keepalive;
 	struct timer_list timer_new_handshake, timer_zero_key_material;
diff --git a/drivers/net/wireguard/peerlookup.c b/drivers/net/wireguard/peerlookup.c
index f2783aa7a88f..2ea2ba85a33d 100644
--- a/drivers/net/wireguard/peerlookup.c
+++ b/drivers/net/wireguard/peerlookup.c
@@ -7,18 +7,42 @@
 #include "peer.h"
 #include "noise.h"
 
-static struct hlist_head *pubkey_bucket(struct pubkey_hashtable *table,
-					const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
+/* Lookup key for the pubkey table. Its layout must mirror the
+ * remote_static/skey pair in struct noise_handshake: for stored peers,
+ * pubkey_hash() is handed a pointer to handshake.remote_static and reads
+ * the siphash key that immediately follows it.
+ */
+struct pubkey_pair {
+	u8 key[NOISE_PUBLIC_KEY_LEN];
+	siphash_key_t skey;
+};
+
+static u32 pubkey_hash(const void *data, u32 len, u32 seed)
 {
+	const struct pubkey_pair *pair = data;
+
+	BUILD_BUG_ON(offsetof(struct pubkey_pair, skey) !=
+		     NOISE_PUBLIC_KEY_LEN);
+	BUILD_BUG_ON(offsetof(struct noise_handshake, skey) !=
+		     offsetof(struct noise_handshake, remote_static) +
+		     NOISE_PUBLIC_KEY_LEN);
+
 	/* siphash gives us a secure 64bit number based on a random key. Since
-	 * the bits are uniformly distributed, we can then mask off to get the
-	 * bits we need.
+	 * the bits are uniformly distributed, truncating to 32 bits is fine;
+	 * rhashtable masks the result down to a bucket index. The rhashtable
+	 * seed is unused because the siphash key already provides randomness.
 	 */
-	const u64 hash = siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key);
 
-	return &table->hashtable[hash & (HASH_SIZE(table->hashtable) - 1)];
+	return (u32)siphash(pair->key, len, &pair->skey);
 }
 
+static const struct rhashtable_params wg_peer_params = {
+	.key_len = NOISE_PUBLIC_KEY_LEN,
+	.key_offset = offsetof(struct wg_peer, handshake.remote_static),
+	.head_offset = offsetof(struct wg_peer, pubkey_hash),
+	.hashfn = pubkey_hash
+};
+
 struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void)
 {
 	struct pubkey_hashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
@@ -27,26 +51,31 @@ struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void)
 		return NULL;
 
 	get_random_bytes(&table->key, sizeof(table->key));
-	hash_init(table->hashtable);
-	mutex_init(&table->lock);
+	if (rhashtable_init(&table->hashtable, &wg_peer_params)) {
+		kvfree(table);
+		return NULL;
+	}
+
 	return table;
 }
 
 void wg_pubkey_hashtable_add(struct pubkey_hashtable *table,
 			     struct wg_peer *peer)
 {
-	mutex_lock(&table->lock);
-	hlist_add_head_rcu(&peer->pubkey_hash,
-			   pubkey_bucket(table, peer->handshake.remote_static));
-	mutex_unlock(&table->lock);
+	/* pubkey_hash() reads the siphash key stored right after
+	 * handshake.remote_static, so it must be in place before hashing.
+	 */
+	memcpy(&peer->handshake.skey, &table->key, sizeof(table->key));
+	WARN_ON(rhashtable_insert_fast(&table->hashtable, &peer->pubkey_hash,
+				       wg_peer_params));
 }
 
 void wg_pubkey_hashtable_remove(struct pubkey_hashtable *table,
 				struct wg_peer *peer)
 {
-	mutex_lock(&table->lock);
-	hlist_del_init_rcu(&peer->pubkey_hash);
-	mutex_unlock(&table->lock);
+	memcpy(&peer->handshake.skey, &table->key, sizeof(table->key));
+	rhashtable_remove_fast(&table->hashtable, &peer->pubkey_hash,
+			       wg_peer_params);
 }
 
 /* Returns a strong reference to a peer */
@@ -54,41 +83,58 @@ struct wg_peer *
 wg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
 			   const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
 {
-	struct wg_peer *iter_peer, *peer = NULL;
+	struct wg_peer *peer = NULL;
+	struct pubkey_pair pair;
+
+	/* pubkey_hash() expects the siphash key right after the public key. */
+	memcpy(pair.key, pubkey, NOISE_PUBLIC_KEY_LEN);
+	memcpy(&pair.skey, &table->key, sizeof(pair.skey));
 
 	rcu_read_lock_bh();
-	hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey),
-				    pubkey_hash) {
-		if (!memcmp(pubkey, iter_peer->handshake.remote_static,
-			    NOISE_PUBLIC_KEY_LEN)) {
-			peer = iter_peer;
-			break;
-		}
-	}
-	peer = wg_peer_get_maybe_zero(peer);
+	peer = wg_peer_get_maybe_zero(rhashtable_lookup_fast(&table->hashtable,
+							     &pair,
+							     wg_peer_params));
 	rcu_read_unlock_bh();
+
 	return peer;
 }
 
-static struct hlist_head *index_bucket(struct index_hashtable *table,
-				       const __le32 index)
+void wg_pubkey_hashtable_destroy(struct pubkey_hashtable *table)
+{
+	WARN_ON(atomic_read(&table->hashtable.nelems));
+	rhashtable_destroy(&table->hashtable);
+}
+
+static u32 index_hash(const void *data, u32 len, u32 seed)
 {
+	const __le32 *index = data;
+
 	/* Since the indices are random and thus all bits are uniformly
-	 * distributed, we can find its bucket simply by masking.
+	 * distributed, we can use them as the hash value.
 	 */
-	return &table->hashtable[(__force u32)index &
-				 (HASH_SIZE(table->hashtable) - 1)];
+
+	return (__force u32)*index;
 }
 
-struct index_hashtable *wg_index_hashtable_alloc(void)
+static const struct rhashtable_params index_entry_params = {
+	.key_len = sizeof(__le32),
+	.key_offset = offsetof(struct index_hashtable_entry, index),
+	.head_offset = offsetof(struct index_hashtable_entry, index_hash),
+	.hashfn = index_hash
+};
+
+struct rhashtable *wg_index_hashtable_alloc(void)
 {
-	struct index_hashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
+	struct rhashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
 
 	if (!table)
 		return NULL;
 
-	hash_init(table->hashtable);
-	spin_lock_init(&table->lock);
+	if (rhashtable_init(table, &index_entry_params)) {
+		kvfree(table);
+		return NULL;
+	}
+
 	return table;
 }
 
@@ -116,111 +162,96 @@ struct index_hashtable *wg_index_hashtable_alloc(void)
  * is another thing to consider moving forward.
  */
 
-__le32 wg_index_hashtable_insert(struct index_hashtable *table,
+__le32 wg_index_hashtable_insert(struct rhashtable *table,
 				 struct index_hashtable_entry *entry)
 {
 	struct index_hashtable_entry *existing_entry;
 
-	spin_lock_bh(&table->lock);
-	hlist_del_init_rcu(&entry->index_hash);
-	spin_unlock_bh(&table->lock);
+	wg_index_hashtable_remove(table, entry);
 
 	rcu_read_lock_bh();
 
 search_unused_slot:
 	/* First we try to find an unused slot, randomly, while unlocked. */
 	entry->index = (__force __le32)get_random_u32();
-	hlist_for_each_entry_rcu_bh(existing_entry,
-				    index_bucket(table, entry->index),
-				    index_hash) {
-		if (existing_entry->index == entry->index)
-			/* If it's already in use, we continue searching. */
-			goto search_unused_slot;
-	}
 
-	/* Once we've found an unused slot, we lock it, and then double-check
-	 * that nobody else stole it from us.
-	 */
-	spin_lock_bh(&table->lock);
-	hlist_for_each_entry_rcu_bh(existing_entry,
-				    index_bucket(table, entry->index),
-				    index_hash) {
-		if (existing_entry->index == entry->index) {
-			spin_unlock_bh(&table->lock);
-			/* If it was stolen, we start over. */
-			goto search_unused_slot;
-		}
+	existing_entry = rhashtable_lookup_get_insert_fast(table,
+							   &entry->index_hash,
+							   index_entry_params);
+	if (IS_ERR(existing_entry)) {
+		/* Insertion failed (e.g. -ENOMEM). Retrying could spin forever
+		 * with BHs disabled, so give up: the entry stays unhashed and
+		 * packets carrying its index are dropped as unknown.
+		 */
+		goto out;
+	}
+	if (existing_entry) {
+		/* If it's already in use, we continue searching. */
+		goto search_unused_slot;
 	}
-	/* Otherwise, we know we have it exclusively (since we're locked),
-	 * so we insert.
-	 */
-	hlist_add_head_rcu(&entry->index_hash,
-			   index_bucket(table, entry->index));
-	spin_unlock_bh(&table->lock);
 
+out:
 	rcu_read_unlock_bh();
 
 	return entry->index;
 }
 
-bool wg_index_hashtable_replace(struct index_hashtable *table,
+bool wg_index_hashtable_replace(struct rhashtable *table,
 				struct index_hashtable_entry *old,
 				struct index_hashtable_entry *new)
 {
-	bool ret;
+	int ret;
 
-	spin_lock_bh(&table->lock);
-	ret = !hlist_unhashed(&old->index_hash);
-	if (unlikely(!ret))
-		goto out;
+	/* rhashtable_replace_fast() rejects (-EINVAL) objects whose keys hash
+	 * differently, so new must take over old's index before the swap.
+	 */
+	new->index = old->index;
+	ret = rhashtable_replace_fast(table, &old->index_hash,
+				     &new->index_hash, index_entry_params);
 
-	new->index = old->index;
-	hlist_replace_rcu(&old->index_hash, &new->index_hash);
-
-	/* Calling init here NULLs out index_hash, and in fact after this
-	 * function returns, it's theoretically possible for this to get
-	 * reinserted elsewhere. That means the RCU lookup below might either
-	 * terminate early or jump between buckets, in which case the packet
-	 * simply gets dropped, which isn't terrible.
-	 */
-	INIT_HLIST_NODE(&old->index_hash);
-out:
-	spin_unlock_bh(&table->lock);
-	return ret;
+	/* -ENOENT just means old was already unhashed. */
+	WARN_ON(ret && ret != -ENOENT);
+	return !ret;
 }
 
-void wg_index_hashtable_remove(struct index_hashtable *table,
+void wg_index_hashtable_remove(struct rhashtable *table,
 			       struct index_hashtable_entry *entry)
 {
-	spin_lock_bh(&table->lock);
-	hlist_del_init_rcu(&entry->index_hash);
-	spin_unlock_bh(&table->lock);
+	rhashtable_remove_fast(table, &entry->index_hash, index_entry_params);
 }
 
 /* Returns a strong reference to a entry->peer */
 struct index_hashtable_entry *
-wg_index_hashtable_lookup(struct index_hashtable *table,
+wg_index_hashtable_lookup(struct rhashtable *table,
 			  const enum index_hashtable_type type_mask,
 			  const __le32 index, struct wg_peer **peer)
 {
-	struct index_hashtable_entry *iter_entry, *entry = NULL;
+	struct index_hashtable_entry *entry = NULL;
 
 	rcu_read_lock_bh();
-	hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index),
-				    index_hash) {
-		if (iter_entry->index == index) {
-			if (likely(iter_entry->type & type_mask))
-				entry = iter_entry;
-			break;
-		}
-	}
+	entry = rhashtable_lookup_fast(table, &index, index_entry_params);
+
 	if (likely(entry)) {
+		if (unlikely(!(entry->type & type_mask))) {
+			entry = NULL;
+			goto out;
+		}
+
 		entry->peer = wg_peer_get_maybe_zero(entry->peer);
 		if (likely(entry->peer))
 			*peer = entry->peer;
 		else
 			entry = NULL;
 	}
+
+out:
 	rcu_read_unlock_bh();
+
 	return entry;
 }
+
+void wg_index_hashtable_destroy(struct rhashtable *table)
+{
+	WARN_ON(atomic_read(&table->nelems));
+	rhashtable_destroy(table);
+}
diff --git a/drivers/net/wireguard/peerlookup.h b/drivers/net/wireguard/peerlookup.h
index ced811797680..a3cef26cb733 100644
--- a/drivers/net/wireguard/peerlookup.h
+++ b/drivers/net/wireguard/peerlookup.h
@@ -8,17 +8,18 @@
 
 #include "messages.h"
 
-#include <linux/hashtable.h>
-#include <linux/mutex.h>
+#include <linux/rhashtable.h>
 #include <linux/siphash.h>
 
 struct wg_peer;
 
 struct pubkey_hashtable {
-	/* TODO: move to rhashtable */
-	DECLARE_HASHTABLE(hashtable, 11);
+	/* Peers, keyed by handshake.remote_static. */
+	struct rhashtable hashtable;
+	/* Random siphash key. It is also copied into each peer's handshake.skey
+	 * because the rhashtable hash callback only gets a pointer to the key.
+	 */
 	siphash_key_t key;
-	struct mutex lock;
 };
 
 struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void);
@@ -29,12 +30,7 @@ void wg_pubkey_hashtable_remove(struct pubkey_hashtable *table,
 struct wg_peer *
 wg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
 			   const u8 pubkey[NOISE_PUBLIC_KEY_LEN]);
-
-struct index_hashtable {
-	/* TODO: move to rhashtable */
-	DECLARE_HASHTABLE(hashtable, 13);
-	spinlock_t lock;
-};
+void wg_pubkey_hashtable_destroy(struct pubkey_hashtable *table);
 
 enum index_hashtable_type {
 	INDEX_HASHTABLE_HANDSHAKE = 1U << 0,
@@ -43,22 +39,23 @@ enum index_hashtable_type {
 
 struct index_hashtable_entry {
 	struct wg_peer *peer;
-	struct hlist_node index_hash;
+	struct rhash_head index_hash;
 	enum index_hashtable_type type;
 	__le32 index;
 };
 
-struct index_hashtable *wg_index_hashtable_alloc(void);
-__le32 wg_index_hashtable_insert(struct index_hashtable *table,
+struct rhashtable *wg_index_hashtable_alloc(void);
+__le32 wg_index_hashtable_insert(struct rhashtable *table,
 				 struct index_hashtable_entry *entry);
-bool wg_index_hashtable_replace(struct index_hashtable *table,
+bool wg_index_hashtable_replace(struct rhashtable *table,
 				struct index_hashtable_entry *old,
 				struct index_hashtable_entry *new);
-void wg_index_hashtable_remove(struct index_hashtable *table,
+void wg_index_hashtable_remove(struct rhashtable *table,
 			       struct index_hashtable_entry *entry);
 struct index_hashtable_entry *
-wg_index_hashtable_lookup(struct index_hashtable *table,
+wg_index_hashtable_lookup(struct rhashtable *table,
 			  const enum index_hashtable_type type_mask,
 			  const __le32 index, struct wg_peer **peer);
+void wg_index_hashtable_destroy(struct rhashtable *table);
 
 #endif /* _WG_PEERLOOKUP_H */
-- 
2.32.0



More information about the WireGuard mailing list