[PATCH] wireguard: convert index_hashtable and pubkey_hashtable into rhashtables
Hamza Mahfooz
someguy at effective-light.com
Wed Sep 8 10:27:24 UTC 2021
ping
On Fri, Aug 6 2021 at 12:43:14 AM -0400, Hamza Mahfooz
<someguy at effective-light.com> wrote:
> It is made mention of in commit e7096c131e516 ("net: WireGuard secure
> network tunnel"), that it is desirable to move away from the
> statically
> sized hash-table implementation.
>
> Signed-off-by: Hamza Mahfooz <someguy at effective-light.com>
> ---
> drivers/net/wireguard/device.c | 4 +
> drivers/net/wireguard/device.h | 2 +-
> drivers/net/wireguard/noise.c | 1 +
> drivers/net/wireguard/noise.h | 1 +
> drivers/net/wireguard/peer.h | 2 +-
> drivers/net/wireguard/peerlookup.c | 190
> ++++++++++++++---------------
> drivers/net/wireguard/peerlookup.h | 27 ++--
> 7 files changed, 112 insertions(+), 115 deletions(-)
>
> diff --git a/drivers/net/wireguard/device.c
> b/drivers/net/wireguard/device.c
> index 551ddaaaf540..3bd43c9481ef 100644
> --- a/drivers/net/wireguard/device.c
> +++ b/drivers/net/wireguard/device.c
> @@ -243,7 +243,9 @@ static void wg_destruct(struct net_device *dev)
> skb_queue_purge(&wg->incoming_handshakes);
> free_percpu(dev->tstats);
> free_percpu(wg->incoming_handshakes_worker);
> + wg_index_hashtable_destroy(wg->index_hashtable);
> kvfree(wg->index_hashtable);
> + wg_pubkey_hashtable_destroy(wg->peer_hashtable);
> kvfree(wg->peer_hashtable);
> mutex_unlock(&wg->device_update_lock);
>
> @@ -382,8 +384,10 @@ static int wg_newlink(struct net *src_net,
> struct net_device *dev,
> err_free_tstats:
> free_percpu(dev->tstats);
> err_free_index_hashtable:
> + wg_index_hashtable_destroy(wg->index_hashtable);
> kvfree(wg->index_hashtable);
> err_free_peer_hashtable:
> + wg_pubkey_hashtable_destroy(wg->peer_hashtable);
> kvfree(wg->peer_hashtable);
> return ret;
> }
> diff --git a/drivers/net/wireguard/device.h
> b/drivers/net/wireguard/device.h
> index 854bc3d97150..24980eb766af 100644
> --- a/drivers/net/wireguard/device.h
> +++ b/drivers/net/wireguard/device.h
> @@ -50,7 +50,7 @@ struct wg_device {
> struct multicore_worker __percpu *incoming_handshakes_worker;
> struct cookie_checker cookie_checker;
> struct pubkey_hashtable *peer_hashtable;
> - struct index_hashtable *index_hashtable;
> + struct rhashtable *index_hashtable;
> struct allowedips peer_allowedips;
> struct mutex device_update_lock, socket_update_lock;
> struct list_head device_list, peer_list;
> diff --git a/drivers/net/wireguard/noise.c
> b/drivers/net/wireguard/noise.c
> index c0cfd9b36c0b..d42a0ff2be5d 100644
> --- a/drivers/net/wireguard/noise.c
> +++ b/drivers/net/wireguard/noise.c
> @@ -797,6 +797,7 @@ bool wg_noise_handshake_begin_session(struct
> noise_handshake *handshake,
> new_keypair->i_am_the_initiator = handshake->state ==
> HANDSHAKE_CONSUMED_RESPONSE;
> new_keypair->remote_index = handshake->remote_index;
> + new_keypair->entry.index = handshake->entry.index;
>
> if (new_keypair->i_am_the_initiator)
> derive_keys(&new_keypair->sending, &new_keypair->receiving,
> diff --git a/drivers/net/wireguard/noise.h
> b/drivers/net/wireguard/noise.h
> index c527253dba80..ea705747e4e4 100644
> --- a/drivers/net/wireguard/noise.h
> +++ b/drivers/net/wireguard/noise.h
> @@ -72,6 +72,7 @@ struct noise_handshake {
>
> u8 ephemeral_private[NOISE_PUBLIC_KEY_LEN];
> u8 remote_static[NOISE_PUBLIC_KEY_LEN];
> + siphash_key_t skey;
> u8 remote_ephemeral[NOISE_PUBLIC_KEY_LEN];
> u8 precomputed_static_static[NOISE_PUBLIC_KEY_LEN];
>
> diff --git a/drivers/net/wireguard/peer.h
> b/drivers/net/wireguard/peer.h
> index 76e4d3128ad4..d5403fb7a6a0 100644
> --- a/drivers/net/wireguard/peer.h
> +++ b/drivers/net/wireguard/peer.h
> @@ -48,7 +48,7 @@ struct wg_peer {
> atomic64_t last_sent_handshake;
> struct work_struct transmit_handshake_work, clear_peer_work,
> transmit_packet_work;
> struct cookie latest_cookie;
> - struct hlist_node pubkey_hash;
> + struct rhash_head pubkey_hash;
> u64 rx_bytes, tx_bytes;
> struct timer_list timer_retransmit_handshake, timer_send_keepalive;
> struct timer_list timer_new_handshake, timer_zero_key_material;
> diff --git a/drivers/net/wireguard/peerlookup.c
> b/drivers/net/wireguard/peerlookup.c
> index f2783aa7a88f..2ea2ba85a33d 100644
> --- a/drivers/net/wireguard/peerlookup.c
> +++ b/drivers/net/wireguard/peerlookup.c
> @@ -7,18 +7,29 @@
> #include "peer.h"
> #include "noise.h"
>
> -static struct hlist_head *pubkey_bucket(struct pubkey_hashtable
> *table,
> - const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
> +struct pubkey_pair {
> + u8 key[NOISE_PUBLIC_KEY_LEN];
> + siphash_key_t skey;
> +};
> +
> +static u32 pubkey_hash(const void *data, u32 len, u32 seed)
> {
> + const struct pubkey_pair *pair = data;
> +
> /* siphash gives us a secure 64bit number based on a random key.
> Since
> - * the bits are uniformly distributed, we can then mask off to get
> the
> - * bits we need.
> + * the bits are uniformly distributed.
> */
> - const u64 hash = siphash(pubkey, NOISE_PUBLIC_KEY_LEN, &table->key);
>
> - return &table->hashtable[hash & (HASH_SIZE(table->hashtable) - 1)];
> + return (u32)siphash(pair->key, len, &pair->skey);
> }
>
> +static const struct rhashtable_params wg_peer_params = {
> + .key_len = NOISE_PUBLIC_KEY_LEN,
> + .key_offset = offsetof(struct wg_peer, handshake.remote_static),
> + .head_offset = offsetof(struct wg_peer, pubkey_hash),
> + .hashfn = pubkey_hash
> +};
> +
> struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void)
> {
> struct pubkey_hashtable *table = kvmalloc(sizeof(*table),
> GFP_KERNEL);
> @@ -27,26 +38,25 @@ struct pubkey_hashtable
> *wg_pubkey_hashtable_alloc(void)
> return NULL;
>
> get_random_bytes(&table->key, sizeof(table->key));
> - hash_init(table->hashtable);
> - mutex_init(&table->lock);
> + rhashtable_init(&table->hashtable, &wg_peer_params);
> +
> return table;
> }
>
> void wg_pubkey_hashtable_add(struct pubkey_hashtable *table,
> struct wg_peer *peer)
> {
> - mutex_lock(&table->lock);
> - hlist_add_head_rcu(&peer->pubkey_hash,
> - pubkey_bucket(table, peer->handshake.remote_static));
> - mutex_unlock(&table->lock);
> + memcpy(&peer->handshake.skey, &table->key, sizeof(table->key));
> + WARN_ON(rhashtable_insert_fast(&table->hashtable,
> &peer->pubkey_hash,
> + wg_peer_params));
> }
>
> void wg_pubkey_hashtable_remove(struct pubkey_hashtable *table,
> struct wg_peer *peer)
> {
> - mutex_lock(&table->lock);
> - hlist_del_init_rcu(&peer->pubkey_hash);
> - mutex_unlock(&table->lock);
> + memcpy(&peer->handshake.skey, &table->key, sizeof(table->key));
> + rhashtable_remove_fast(&table->hashtable, &peer->pubkey_hash,
> + wg_peer_params);
> }
>
> /* Returns a strong reference to a peer */
> @@ -54,41 +64,54 @@ struct wg_peer *
> wg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
> const u8 pubkey[NOISE_PUBLIC_KEY_LEN])
> {
> - struct wg_peer *iter_peer, *peer = NULL;
> + struct wg_peer *peer = NULL;
> + struct pubkey_pair pair;
> +
> + memcpy(pair.key, pubkey, NOISE_PUBLIC_KEY_LEN);
> + memcpy(&pair.skey, &table->key, sizeof(pair.skey));
>
> rcu_read_lock_bh();
> - hlist_for_each_entry_rcu_bh(iter_peer, pubkey_bucket(table, pubkey),
> - pubkey_hash) {
> - if (!memcmp(pubkey, iter_peer->handshake.remote_static,
> - NOISE_PUBLIC_KEY_LEN)) {
> - peer = iter_peer;
> - break;
> - }
> - }
> - peer = wg_peer_get_maybe_zero(peer);
> + peer =
> wg_peer_get_maybe_zero(rhashtable_lookup_fast(&table->hashtable,
> + &pair,
> + wg_peer_params));
> rcu_read_unlock_bh();
> +
> return peer;
> }
>
> -static struct hlist_head *index_bucket(struct index_hashtable *table,
> - const __le32 index)
> +void wg_pubkey_hashtable_destroy(struct pubkey_hashtable *table)
> +{
> + WARN_ON(atomic_read(&table->hashtable.nelems));
> + rhashtable_destroy(&table->hashtable);
> +}
> +
> +static u32 index_hash(const void *data, u32 len, u32 seed)
> {
> + const __le32 *index = data;
> +
> /* Since the indices are random and thus all bits are uniformly
> - * distributed, we can find its bucket simply by masking.
> + * distributed, we can use them as the hash value.
> */
> - return &table->hashtable[(__force u32)index &
> - (HASH_SIZE(table->hashtable) - 1)];
> +
> + return (__force u32)*index;
> }
>
> -struct index_hashtable *wg_index_hashtable_alloc(void)
> +static const struct rhashtable_params index_entry_params = {
> + .key_len = sizeof(__le32),
> + .key_offset = offsetof(struct index_hashtable_entry, index),
> + .head_offset = offsetof(struct index_hashtable_entry, index_hash),
> + .hashfn = index_hash
> +};
> +
> +struct rhashtable *wg_index_hashtable_alloc(void)
> {
> - struct index_hashtable *table = kvmalloc(sizeof(*table),
> GFP_KERNEL);
> + struct rhashtable *table = kvmalloc(sizeof(*table), GFP_KERNEL);
>
> if (!table)
> return NULL;
>
> - hash_init(table->hashtable);
> - spin_lock_init(&table->lock);
> + rhashtable_init(table, &index_entry_params);
> +
> return table;
> }
>
> @@ -116,111 +139,86 @@ struct index_hashtable
> *wg_index_hashtable_alloc(void)
> * is another thing to consider moving forward.
> */
>
> -__le32 wg_index_hashtable_insert(struct index_hashtable *table,
> +__le32 wg_index_hashtable_insert(struct rhashtable *table,
> struct index_hashtable_entry *entry)
> {
> struct index_hashtable_entry *existing_entry;
>
> - spin_lock_bh(&table->lock);
> - hlist_del_init_rcu(&entry->index_hash);
> - spin_unlock_bh(&table->lock);
> + wg_index_hashtable_remove(table, entry);
>
> rcu_read_lock_bh();
>
> search_unused_slot:
> /* First we try to find an unused slot, randomly, while unlocked. */
> entry->index = (__force __le32)get_random_u32();
> - hlist_for_each_entry_rcu_bh(existing_entry,
> - index_bucket(table, entry->index),
> - index_hash) {
> - if (existing_entry->index == entry->index)
> - /* If it's already in use, we continue searching. */
> - goto search_unused_slot;
> - }
>
> - /* Once we've found an unused slot, we lock it, and then
> double-check
> - * that nobody else stole it from us.
> - */
> - spin_lock_bh(&table->lock);
> - hlist_for_each_entry_rcu_bh(existing_entry,
> - index_bucket(table, entry->index),
> - index_hash) {
> - if (existing_entry->index == entry->index) {
> - spin_unlock_bh(&table->lock);
> - /* If it was stolen, we start over. */
> - goto search_unused_slot;
> - }
> + existing_entry = rhashtable_lookup_get_insert_fast(table,
> + &entry->index_hash,
> + index_entry_params);
> +
> + if (existing_entry) {
> + WARN_ON(IS_ERR(existing_entry));
> +
> + /* If it's already in use, we continue searching. */
> + goto search_unused_slot;
> }
> - /* Otherwise, we know we have it exclusively (since we're locked),
> - * so we insert.
> - */
> - hlist_add_head_rcu(&entry->index_hash,
> - index_bucket(table, entry->index));
> - spin_unlock_bh(&table->lock);
>
> rcu_read_unlock_bh();
>
> return entry->index;
> }
>
> -bool wg_index_hashtable_replace(struct index_hashtable *table,
> +bool wg_index_hashtable_replace(struct rhashtable *table,
> struct index_hashtable_entry *old,
> struct index_hashtable_entry *new)
> {
> - bool ret;
> + int ret = rhashtable_replace_fast(table, &old->index_hash,
> + &new->index_hash,
> + index_entry_params);
>
> - spin_lock_bh(&table->lock);
> - ret = !hlist_unhashed(&old->index_hash);
> - if (unlikely(!ret))
> - goto out;
> + WARN_ON(ret == -EINVAL);
>
> - new->index = old->index;
> - hlist_replace_rcu(&old->index_hash, &new->index_hash);
> -
> - /* Calling init here NULLs out index_hash, and in fact after this
> - * function returns, it's theoretically possible for this to get
> - * reinserted elsewhere. That means the RCU lookup below might
> either
> - * terminate early or jump between buckets, in which case the packet
> - * simply gets dropped, which isn't terrible.
> - */
> - INIT_HLIST_NODE(&old->index_hash);
> -out:
> - spin_unlock_bh(&table->lock);
> - return ret;
> + return ret != -ENOENT;
> }
>
> -void wg_index_hashtable_remove(struct index_hashtable *table,
> +void wg_index_hashtable_remove(struct rhashtable *table,
> struct index_hashtable_entry *entry)
> {
> - spin_lock_bh(&table->lock);
> - hlist_del_init_rcu(&entry->index_hash);
> - spin_unlock_bh(&table->lock);
> + rhashtable_remove_fast(table, &entry->index_hash,
> index_entry_params);
> }
>
> /* Returns a strong reference to a entry->peer */
> struct index_hashtable_entry *
> -wg_index_hashtable_lookup(struct index_hashtable *table,
> +wg_index_hashtable_lookup(struct rhashtable *table,
> const enum index_hashtable_type type_mask,
> const __le32 index, struct wg_peer **peer)
> {
> - struct index_hashtable_entry *iter_entry, *entry = NULL;
> + struct index_hashtable_entry *entry = NULL;
>
> rcu_read_lock_bh();
> - hlist_for_each_entry_rcu_bh(iter_entry, index_bucket(table, index),
> - index_hash) {
> - if (iter_entry->index == index) {
> - if (likely(iter_entry->type & type_mask))
> - entry = iter_entry;
> - break;
> - }
> - }
> + entry = rhashtable_lookup_fast(table, &index, index_entry_params);
> +
> if (likely(entry)) {
> + if (unlikely(!(entry->type & type_mask))) {
> + entry = NULL;
> + goto out;
> + }
> +
> entry->peer = wg_peer_get_maybe_zero(entry->peer);
> if (likely(entry->peer))
> *peer = entry->peer;
> else
> entry = NULL;
> }
> +
> +out:
> rcu_read_unlock_bh();
> +
> return entry;
> }
> +
> +void wg_index_hashtable_destroy(struct rhashtable *table)
> +{
> + WARN_ON(atomic_read(&table->nelems));
> + rhashtable_destroy(table);
> +}
> diff --git a/drivers/net/wireguard/peerlookup.h
> b/drivers/net/wireguard/peerlookup.h
> index ced811797680..a3cef26cb733 100644
> --- a/drivers/net/wireguard/peerlookup.h
> +++ b/drivers/net/wireguard/peerlookup.h
> @@ -8,17 +8,14 @@
>
> #include "messages.h"
>
> -#include <linux/hashtable.h>
> -#include <linux/mutex.h>
> +#include <linux/rhashtable.h>
> #include <linux/siphash.h>
>
> struct wg_peer;
>
> struct pubkey_hashtable {
> - /* TODO: move to rhashtable */
> - DECLARE_HASHTABLE(hashtable, 11);
> + struct rhashtable hashtable;
> siphash_key_t key;
> - struct mutex lock;
> };
>
> struct pubkey_hashtable *wg_pubkey_hashtable_alloc(void);
> @@ -29,12 +26,7 @@ void wg_pubkey_hashtable_remove(struct
> pubkey_hashtable *table,
> struct wg_peer *
> wg_pubkey_hashtable_lookup(struct pubkey_hashtable *table,
> const u8 pubkey[NOISE_PUBLIC_KEY_LEN]);
> -
> -struct index_hashtable {
> - /* TODO: move to rhashtable */
> - DECLARE_HASHTABLE(hashtable, 13);
> - spinlock_t lock;
> -};
> +void wg_pubkey_hashtable_destroy(struct pubkey_hashtable *table);
>
> enum index_hashtable_type {
> INDEX_HASHTABLE_HANDSHAKE = 1U << 0,
> @@ -43,22 +35,23 @@ enum index_hashtable_type {
>
> struct index_hashtable_entry {
> struct wg_peer *peer;
> - struct hlist_node index_hash;
> + struct rhash_head index_hash;
> enum index_hashtable_type type;
> __le32 index;
> };
>
> -struct index_hashtable *wg_index_hashtable_alloc(void);
> -__le32 wg_index_hashtable_insert(struct index_hashtable *table,
> +struct rhashtable *wg_index_hashtable_alloc(void);
> +__le32 wg_index_hashtable_insert(struct rhashtable *table,
> struct index_hashtable_entry *entry);
> -bool wg_index_hashtable_replace(struct index_hashtable *table,
> +bool wg_index_hashtable_replace(struct rhashtable *table,
> struct index_hashtable_entry *old,
> struct index_hashtable_entry *new);
> -void wg_index_hashtable_remove(struct index_hashtable *table,
> +void wg_index_hashtable_remove(struct rhashtable *table,
> struct index_hashtable_entry *entry);
> struct index_hashtable_entry *
> -wg_index_hashtable_lookup(struct index_hashtable *table,
> +wg_index_hashtable_lookup(struct rhashtable *table,
> const enum index_hashtable_type type_mask,
> const __le32 index, struct wg_peer **peer);
> +void wg_index_hashtable_destroy(struct rhashtable *table);
>
> #endif /* _WG_PEERLOOKUP_H */
> --
> 2.32.0
>
More information about the WireGuard
mailing list