[PATCH] wireguard: convert index_hashtable and pubkey_hashtable into rhashtables

Hamza Mahfooz someguy at effective-light.com
Wed Sep 8 10:27:24 UTC 2021


ping

On Fri, Aug 6 2021 at 12:43:14 AM -0400, Hamza Mahfooz 
<someguy at effective-light.com> wrote:
> Commit e7096c131e516 ("net: WireGuard secure network tunnel") mentions
> that it is desirable to move away from the statically sized hash-table
> implementation. Convert index_hashtable and pubkey_hashtable to
> rhashtables accordingly.
> 
> Signed-off-by: Hamza Mahfooz <someguy at effective-light.com>
> ---
>  drivers/net/wireguard/device.c     |   4 +
>  drivers/net/wireguard/device.h     |   2 +-
>  drivers/net/wireguard/noise.c      |   1 +
>  drivers/net/wireguard/noise.h      |   1 +
>  drivers/net/wireguard/peer.h       |   2 +-
>  drivers/net/wireguard/peerlookup.c | 190 
> ++++++++++++++---------------
>  drivers/net/wireguard/peerlookup.h |  27 ++--
>  7 files changed, 112 insertions(+), 115 deletions(-)
> 
> diff --git a/drivers/net/wireguard/device.c 
> b/drivers/net/wireguard/device.c
> index 551ddaaaf540..3bd43c9481ef 100644
> --- a/drivers/net/wireguard/device.c
> +++ b/drivers/net/wireguard/device.c
> @@ -243,7 +243,9 @@ static void wg_destruct(struct net_device *dev)
>  	skb_queue_purge(&wg->incoming_handshakes);
>  	free_percpu(dev->tstats);
>  	free_percpu(wg->incoming_handshakes_worker);
> +	wg_index_hashtable_destroy(wg->index_hashtable);
>  	kvfree(wg->index_hashtable);
> +	wg_pubkey_hashtable_destroy(wg->peer_hashtable);
>  	kvfree(wg->peer_hashtable);
>  	mutex_unlock(&wg->device_update_lock);
> 
> @@ -382,8 +384,10 @@ static int wg_newlink(struct net *src_net, 
> struct net_device *dev,
>  err_free_tstats:
>  	free_percpu(dev->tstats);
>  err_free_index_hashtable:
> +	wg_index_hashtable_destroy(wg->index_hashtable);
>  	kvfree(wg->index_hashtable);
>  err_free_peer_hashtable:
> +	wg_pubkey_hashtable_destroy(wg->peer_hashtable);
>  	kvfree(wg->peer_hashtable);
>  	return ret;
>  }
> diff --git a/drivers/net/wireguard/device.h 
> b/drivers/net/wireguard/device.h
> index 854bc3d97150..24980eb766af 100644
> --- a/drivers/net/wireguard/device.h
> +++ b/drivers/net/wireguard/device.h
> @@ -50,7 +50,7 @@ struct wg_device {
>  	struct multicore_worker __percpu *incoming_handshakes_worker;
>  	struct cookie_checker cookie_checker;
>  	struct pubkey_hashtable *peer_hashtable;
> -	struct index_hashtable *index_hashtable;
> +	struct rhashtable *index_hashtable;
>  	struct allowedips peer_allowedips;
>  	struct mutex device_update_lock, socket_update_lock;
>  	struct list_head device_list, peer_list;
> diff --git a/drivers/net/wireguard/noise.c 
> b/drivers/net/wireguard/noise.c
> index c0cfd9b36c0b..d42a0ff2be5d 100644
> --- a/drivers/net/wireguard/noise.c
> +++ b/drivers/net/wireguard/noise.c
> @@ -797,6 +797,7 @@ bool wg_noise_handshake_begin_session(struct 
> noise_handshake *handshake,
>  	new_keypair->i_am_the_initiator = handshake->state ==
>  					  HANDSHAKE_CONSUMED_RESPONSE;
>  	new_keypair->remote_index = handshake->remote_index;
> +	new_keypair->entry.index = handshake->entry.index;
> 
>  	if (new_keypair->i_am_the_initiator)
>  		derive_keys(&new_keypair->sending, &new_keypair->receiving,
> diff --git a/drivers/net/wireguard/noise.h 
> b/drivers/net/wireguard/noise.h
> index c527253dba80..ea705747e4e4 100644
> --- a/drivers/net/wireguard/noise.h
> +++ b/drivers/net/wireguard/noise.h
> @@ -72,6 +72,7 @@ struct noise_handshake {
> 
>  	u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
>  	u8 remote_static[NOISE_PUBLIC_KEY_LEN];
> +	siphash_key_t skey;
>  	u8 remote_ephemeral[NOISE_PUBLIC_KEY_LEN];
>  	u8 precomputed_static_static[NOISE_PUBLIC_KEY_LEN];
> 
> diff --git a/drivers/net/wireguard/peer.h 
> b/drivers/net/wireguard/peer.h
> index 76e4d3128ad4..d5403fb7a6a0 100644
> --- a/drivers/net/wireguard/peer.h
> +++ b/drivers/net/wireguard/peer.h
> @@ -48,7 +48,7 @@ struct wg_peer {
>  	atomic64_t last_sent_handshake;
>  	struct work_struct transmit_handshake_work, clear_peer_work, 
> transmit_packet_work;
>  	struct cookie latest_cookie;
> -	struct hlist_node pubkey_hash;
> +	struct rhash_head pubkey_hash;
>  	u64 rx_bytes, tx_bytes;
>  	struct timer_list timer_retransmit_handshake, timer_send_keepalive;
>  	struct timer_list timer_new_handshake, timer_zero_key_material;
> diff --git a/drivers/net/wireguard/peerlookup.c 
> b/drivers/net/wireguard/peerlookup.c
> index f2783aa7a88f..2ea2ba85a33d 100644
> --- a/drivers/net/wireguard/peerlookup.c
> +++ b/drivers/net/wireguard/peerlookup.c
> @@ -7,18 +7,29 @@
>  #include "peer.h"
>  #include "noise.h"
> 
> -static struct hlist_head *pubkey_bucket(struct pubkey_hashtable 
> *table,
> -					const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
> +struct pubkey_pair {
> +	u8 key[NOISE_PUBLIC_KEY_LEN];
> +	siphash_key_t skey;
> +};
> +
> +static u32 pubkey_hash(const void *data, u32 len, u32 seed)
>  {
> +	const struct pubkey_pair *pair = data;
> +
>  	/* siphash gives us a secure 64bit number based on a random key. 
> Since
> -	 * the bits are uniformly distributed, we can then mask off to get 
> the
> -	 * bits we need.
> +	 * the bits are uniformly distributed, we can truncate it to 32 bits.
>  	 */
> -	const u64 hash = siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key);
> 
> -	return &table->hashtable[hash & (HASH_SIZE(table->hashtable) - 1)];
> +	return (u32)siphash(pair->key, len, &pair->skey);
>  }
> 
> +static const struct rhashtable_params wg_peer_params = {
> +	.key_len = NOISE_PUBLIC_KEY_LEN,
> +	.key_offset = offsetof(struct wg_peer, handshake.remote_static),
> +	.head_offset = offsetof(struct wg_peer, pubkey_hash),
> +	.hashfn = pubkey_hash
> +};
> +
>  struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void)
>  {
>  	struct pubkey_hashtable *table = kvmalloc(sizeof(*table), 
> GFP_KERNEL);
> @@ -27,26 +38,25 @@ struct pubkey_hashtable 
> *wg_pubkey_hashtable_alloc(void)
>  		return NULL;
> 
>  	get_random_bytes(&table->key, sizeof(table->key));
> -	hash_init(table->hashtable);
> -	mutex_init(&table->lock);
> +	rhashtable_init(&table->hashtable, &wg_peer_params);
> +
>  	return table;
>  }
> 
>  void wg_pubkey_hashtable_add(struct pubkey_hashtable *table,
>  			     struct wg_peer *peer)
>  {
> -	mutex_lock(&table->lock);
> -	hlist_add_head_rcu(&peer->pubkey_hash,
> -			   pubkey_bucket(table, peer->handshake.remote_static));
> -	mutex_unlock(&table->lock);
> +	memcpy(&peer->handshake.skey, &table->key, sizeof(table->key));
> +	WARN_ON(rhashtable_insert_fast(&table->hashtable, 
> &peer->pubkey_hash,
> +				       wg_peer_params));
>  }
> 
>  void wg_pubkey_hashtable_remove(struct pubkey_hashtable *table,
>  				struct wg_peer *peer)
>  {
> -	mutex_lock(&table->lock);
> -	hlist_del_init_rcu(&peer->pubkey_hash);
> -	mutex_unlock(&table->lock);
> +	memcpy(&peer->handshake.skey, &table->key, sizeof(table->key));
> +	rhashtable_remove_fast(&table->hashtable, &peer->pubkey_hash,
> +			       wg_peer_params);
>  }
> 
>  /* Returns a strong reference to a peer */
> @@ -54,41 +64,54 @@ struct wg_peer *
>  wg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
>  			   const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
>  {
> -	struct wg_peer *iter_peer, *peer = NULL;
> +	struct wg_peer *peer = NULL;
> +	struct pubkey_pair pair;
> +
> +	memcpy(pair.key, pubkey, NOISE_PUBLIC_KEY_LEN);
> +	memcpy(&pair.skey, &table->key, sizeof(pair.skey));
> 
>  	rcu_read_lock_bh();
> -	hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey),
> -				    pubkey_hash) {
> -		if (!memcmp(pubkey, iter_peer->handshake.remote_static,
> -			    NOISE_PUBLIC_KEY_LEN)) {
> -			peer = iter_peer;
> -			break;
> -		}
> -	}
> -	peer = wg_peer_get_maybe_zero(peer);
> +	peer = 
> wg_peer_get_maybe_zero(rhashtable_lookup_fast(&table->hashtable,
> +							     &pair,
> +							     wg_peer_params));
>  	rcu_read_unlock_bh();
> +
>  	return peer;
>  }
> 
> -static struct hlist_head *index_bucket(struct index_hashtable *table,
> -				       const __le32 index)
> +void wg_pubkey_hashtable_destroy(struct pubkey_hashtable *table)
> +{
> +	WARN_ON(atomic_read(&table->hashtable.nelems));
> +	rhashtable_destroy(&table->hashtable);
> +}
> +
> +static u32 index_hash(const void *data, u32 len, u32 seed)
>  {
> +	const __le32 *index = data;
> +
>  	/* Since the indices are random and thus all bits are uniformly
> -	 * distributed, we can find its bucket simply by masking.
> +	 * distributed, we can use them as the hash value.
>  	 */
> -	return &table->hashtable[(__force u32)index &
> -				 (HASH_SIZE(table->hashtable) - 1)];
> +
> +	return (__force u32)*index;
>  }
> 
> -struct index_hashtable *wg_index_hashtable_alloc(void)
> +static const struct rhashtable_params index_entry_params = {
> +	.key_len = sizeof(__le32),
> +	.key_offset = offsetof(struct index_hashtable_entry, index),
> +	.head_offset = offsetof(struct index_hashtable_entry, index_hash),
> +	.hashfn = index_hash
> +};
> +
> +struct rhashtable *wg_index_hashtable_alloc(void)
>  {
> -	struct index_hashtable *table = kvmalloc(sizeof(*table), 
> GFP_KERNEL);
> +	struct rhashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
> 
>  	if (!table)
>  		return NULL;
> 
> -	hash_init(table->hashtable);
> -	spin_lock_init(&table->lock);
> +	rhashtable_init(table, &index_entry_params);
> +
>  	return table;
>  }
> 
> @@ -116,111 +139,86 @@ struct index_hashtable 
> *wg_index_hashtable_alloc(void)
>   * is another thing to consider moving forward.
>   */
> 
> -__le32 wg_index_hashtable_insert(struct index_hashtable *table,
> +__le32 wg_index_hashtable_insert(struct rhashtable *table,
>  				 struct index_hashtable_entry *entry)
>  {
>  	struct index_hashtable_entry *existing_entry;
> 
> -	spin_lock_bh(&table->lock);
> -	hlist_del_init_rcu(&entry->index_hash);
> -	spin_unlock_bh(&table->lock);
> +	wg_index_hashtable_remove(table, entry);
> 
>  	rcu_read_lock_bh();
> 
>  search_unused_slot:
>  	/* First we try to find an unused slot, randomly, while unlocked. */
>  	entry->index = (__force __le32)get_random_u32();
> -	hlist_for_each_entry_rcu_bh(existing_entry,
> -				    index_bucket(table, entry->index),
> -				    index_hash) {
> -		if (existing_entry->index == entry->index)
> -			/* If it's already in use, we continue searching. */
> -			goto search_unused_slot;
> -	}
> 
> -	/* Once we've found an unused slot, we lock it, and then 
> double-check
> -	 * that nobody else stole it from us.
> -	 */
> -	spin_lock_bh(&table->lock);
> -	hlist_for_each_entry_rcu_bh(existing_entry,
> -				    index_bucket(table, entry->index),
> -				    index_hash) {
> -		if (existing_entry->index == entry->index) {
> -			spin_unlock_bh(&table->lock);
> -			/* If it was stolen, we start over. */
> -			goto search_unused_slot;
> -		}
> +	existing_entry = rhashtable_lookup_get_insert_fast(table,
> +							   &entry->index_hash,
> +							   index_entry_params);
> +
> +	if (existing_entry) {
> +		WARN_ON(IS_ERR(existing_entry));
> +
> +		/* If it's already in use, we continue searching. */
> +		goto search_unused_slot;
>  	}
> -	/* Otherwise, we know we have it exclusively (since we're locked),
> -	 * so we insert.
> -	 */
> -	hlist_add_head_rcu(&entry->index_hash,
> -			   index_bucket(table, entry->index));
> -	spin_unlock_bh(&table->lock);
> 
>  	rcu_read_unlock_bh();
> 
>  	return entry->index;
>  }
> 
> -bool wg_index_hashtable_replace(struct index_hashtable *table,
> +bool wg_index_hashtable_replace(struct rhashtable *table,
>  				struct index_hashtable_entry *old,
>  				struct index_hashtable_entry *new)
>  {
> -	bool ret;
> +	int ret = rhashtable_replace_fast(table, &old->index_hash,
> +					  &new->index_hash,
> +					  index_entry_params);
> 
> -	spin_lock_bh(&table->lock);
> -	ret = !hlist_unhashed(&old->index_hash);
> -	if (unlikely(!ret))
> -		goto out;
> +	WARN_ON(ret == -EINVAL);
> 
> -	new->index = old->index;
> -	hlist_replace_rcu(&old->index_hash, &new->index_hash);
> -
> -	/* Calling init here NULLs out index_hash, and in fact after this
> -	 * function returns, it's theoretically possible for this to get
> -	 * reinserted elsewhere. That means the RCU lookup below might 
> either
> -	 * terminate early or jump between buckets, in which case the packet
> -	 * simply gets dropped, which isn't terrible.
> -	 */
> -	INIT_HLIST_NODE(&old->index_hash);
> -out:
> -	spin_unlock_bh(&table->lock);
> -	return ret;
> +	return ret != -ENOENT;
>  }
> 
> -void wg_index_hashtable_remove(struct index_hashtable *table,
> +void wg_index_hashtable_remove(struct rhashtable *table,
>  			       struct index_hashtable_entry *entry)
>  {
> -	spin_lock_bh(&table->lock);
> -	hlist_del_init_rcu(&entry->index_hash);
> -	spin_unlock_bh(&table->lock);
> +	rhashtable_remove_fast(table, &entry->index_hash, 
> index_entry_params);
>  }
> 
>  /* Returns a strong reference to a entry->peer */
>  struct index_hashtable_entry *
> -wg_index_hashtable_lookup(struct index_hashtable *table,
> +wg_index_hashtable_lookup(struct rhashtable *table,
>  			  const enum index_hashtable_type type_mask,
>  			  const __le32 index, struct wg_peer **peer)
>  {
> -	struct index_hashtable_entry *iter_entry, *entry = NULL;
> +	struct index_hashtable_entry *entry = NULL;
> 
>  	rcu_read_lock_bh();
> -	hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index),
> -				    index_hash) {
> -		if (iter_entry->index == index) {
> -			if (likely(iter_entry->type & type_mask))
> -				entry = iter_entry;
> -			break;
> -		}
> -	}
> +	entry = rhashtable_lookup_fast(table, &index, index_entry_params);
> +
>  	if (likely(entry)) {
> +		if (unlikely(!(entry->type & type_mask))) {
> +			entry = NULL;
> +			goto out;
> +		}
> +
>  		entry->peer = wg_peer_get_maybe_zero(entry->peer);
>  		if (likely(entry->peer))
>  			*peer = entry->peer;
>  		else
>  			entry = NULL;
>  	}
> +
> +out:
>  	rcu_read_unlock_bh();
> +
>  	return entry;
>  }
> +
> +void wg_index_hashtable_destroy(struct rhashtable *table)
> +{
> +	WARN_ON(atomic_read(&table->nelems));
> +	rhashtable_destroy(table);
> +}
> diff --git a/drivers/net/wireguard/peerlookup.h 
> b/drivers/net/wireguard/peerlookup.h
> index ced811797680..a3cef26cb733 100644
> --- a/drivers/net/wireguard/peerlookup.h
> +++ b/drivers/net/wireguard/peerlookup.h
> @@ -8,17 +8,14 @@
> 
>  #include "messages.h"
> 
> -#include <linux/hashtable.h>
> -#include <linux/mutex.h>
> +#include <linux/rhashtable.h>
>  #include <linux/siphash.h>
> 
>  struct wg_peer;
> 
>  struct pubkey_hashtable {
> -	/* TODO: move to rhashtable */
> -	DECLARE_HASHTABLE(hashtable, 11);
> +	struct rhashtable hashtable;
>  	siphash_key_t key;
> -	struct mutex lock;
>  };
> 
>  struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void);
> @@ -29,12 +26,7 @@ void wg_pubkey_hashtable_remove(struct 
> pubkey_hashtable *table,
>  struct wg_peer *
>  wg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
>  			   const u8 pubkey[NOISE_PUBLIC_KEY_LEN]);
> -
> -struct index_hashtable {
> -	/* TODO: move to rhashtable */
> -	DECLARE_HASHTABLE(hashtable, 13);
> -	spinlock_t lock;
> -};
> +void wg_pubkey_hashtable_destroy(struct pubkey_hashtable *table);
> 
>  enum index_hashtable_type {
>  	INDEX_HASHTABLE_HANDSHAKE = 1U << 0,
> @@ -43,22 +35,23 @@ enum index_hashtable_type {
> 
>  struct index_hashtable_entry {
>  	struct wg_peer *peer;
> -	struct hlist_node index_hash;
> +	struct rhash_head index_hash;
>  	enum index_hashtable_type type;
>  	__le32 index;
>  };
> 
> -struct index_hashtable *wg_index_hashtable_alloc(void);
> -__le32 wg_index_hashtable_insert(struct index_hashtable *table,
> +struct rhashtable *wg_index_hashtable_alloc(void);
> +__le32 wg_index_hashtable_insert(struct rhashtable *table,
>  				 struct index_hashtable_entry *entry);
> -bool wg_index_hashtable_replace(struct index_hashtable *table,
> +bool wg_index_hashtable_replace(struct rhashtable *table,
>  				struct index_hashtable_entry *old,
>  				struct index_hashtable_entry *new);
> -void wg_index_hashtable_remove(struct index_hashtable *table,
> +void wg_index_hashtable_remove(struct rhashtable *table,
>  			       struct index_hashtable_entry *entry);
>  struct index_hashtable_entry *
> -wg_index_hashtable_lookup(struct index_hashtable *table,
> +wg_index_hashtable_lookup(struct rhashtable *table,
>  			  const enum index_hashtable_type type_mask,
>  			  const __le32 index, struct wg_peer **peer);
> +void wg_index_hashtable_destroy(struct rhashtable *table);
> 
>  #endif /* _WG_PEERLOOKUP_H */
> --
> 2.32.0
> 




More information about the WireGuard mailing list