
Commit 2086e70

peer: ensure destruction doesn't race
Completely rework peer removal to ensure peers don't jump between contexts and create races.
1 parent: c9ba3b7
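
The rework hinges on a new per-peer dead_count. peer_create() initializes it to 1; every context that could otherwise keep running concurrently with removal (handshake receipt, data-packet enqueue, queued handshake transmission) brackets its work with atomic_inc_not_zero()/atomic_dec(); and peer_remove() spins with atomic_cmpxchg() until it can swing the count from 1 to 0, after which no new context can begin. Below is a minimal userspace sketch of that guard, written with C11 atomics rather than the kernel's atomic_t API; the helper names peer_enter_context(), peer_exit_context(), and peer_mark_dead() are invented for illustration and do not appear in the patch.

/*
 * Userspace illustration of the dead_count guard this commit introduces.
 * The kernel code open-codes the same operations with atomic_inc_not_zero(),
 * atomic_dec(), and atomic_cmpxchg(); everything below is a sketch only.
 */
#include <stdatomic.h>
#include <stdbool.h>
#include <stdio.h>

struct peer {
	atomic_int dead_count; /* 1 while the peer is alive, 0 once removal has begun */
};

/* Equivalent of atomic_inc_not_zero(): only succeed while the peer is alive. */
static bool peer_enter_context(struct peer *peer)
{
	int count = atomic_load(&peer->dead_count);

	while (count != 0) {
		/* On failure the CAS reloads count, so a dying peer is noticed. */
		if (atomic_compare_exchange_weak(&peer->dead_count, &count, count + 1))
			return true;
	}
	return false;
}

static void peer_exit_context(struct peer *peer)
{
	atomic_fetch_sub(&peer->dead_count, 1);
}

/* Removal side: spin until the count can swing 1 -> 0, i.e. until every
 * in-flight context has called peer_exit_context(); afterwards no new
 * context can enter. */
static void peer_mark_dead(struct peer *peer)
{
	int expected = 1;

	while (!atomic_compare_exchange_weak(&peer->dead_count, &expected, 0))
		expected = 1; /* another context still holds the guard; retry */
}

int main(void)
{
	struct peer peer = { .dead_count = 1 };

	if (peer_enter_context(&peer)) {
		/* ... touch peer state, queue work on its behalf ... */
		peer_exit_context(&peer);
	}

	peer_mark_dead(&peer);
	printf("can enter after removal started? %s\n",
	       peer_enter_context(&peer) ? "yes" : "no"); /* prints "no" */
	return 0;
}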

File tree: 9 files changed, +125 additions, -84 deletions


src/cookie.c

Lines changed: 1 addition & 7 deletions
@@ -165,15 +165,9 @@ void cookie_message_consume(struct message_handshake_cookie *src, struct wiregua
 {
 	u8 cookie[COOKIE_LEN];
 	struct wireguard_peer *peer = NULL;
-	struct index_hashtable_entry *entry;
 	bool ret;

-	rcu_read_lock_bh();
-	entry = index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index);
-	if (likely(entry))
-		peer = entry->peer;
-	rcu_read_unlock_bh();
-	if (unlikely(!peer))
+	if (unlikely(!index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, src->receiver_index, &peer)))
 		return;

 	down_read(&peer->latest_cookie.lock);

src/hashtables.c

Lines changed: 4 additions & 2 deletions
@@ -152,7 +152,7 @@ void index_hashtable_remove(struct index_hashtable *table, struct index_hashtabl
 }

 /* Returns a strong reference to a entry->peer */
-struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index)
+struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer)
 {
 	struct index_hashtable_entry *iter_entry, *entry = NULL;

@@ -166,7 +166,9 @@ struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *tab
 	}
 	if (likely(entry)) {
 		entry->peer = peer_get_maybe_zero(entry->peer);
-		if (unlikely(!entry->peer))
+		if (likely(entry->peer))
+			*peer = entry->peer;
+		else
 			entry = NULL;
 	}
 	rcu_read_unlock_bh();
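
The new out-parameter changes the ownership contract of index_hashtable_lookup(): the peer reference is taken with peer_get_maybe_zero() inside the lookup's own rcu_read_lock_bh() section and handed back through *peer, so callers no longer reach through entry->peer themselves and must eventually drop the reference with peer_put(). A caller-side fragment, modeled on the cookie.c and noise.c hunks in this commit rather than copied from any one of them (wg and receiver_index stand in for the caller's device and index):

	struct wireguard_peer *peer = NULL;
	struct index_hashtable_entry *entry;

	entry = index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE | INDEX_HASHTABLE_KEYPAIR, receiver_index, &peer);
	if (unlikely(!entry))
		return;		/* nothing matched; *peer was left untouched */

	/* ... use entry and peer; the lookup's strong reference is now ours ... */
	peer_put(peer);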

src/hashtables.h

Lines changed: 1 addition & 1 deletion
@@ -47,6 +47,6 @@ void index_hashtable_init(struct index_hashtable *table);
 __le32 index_hashtable_insert(struct index_hashtable *table, struct index_hashtable_entry *entry);
 bool index_hashtable_replace(struct index_hashtable *table, struct index_hashtable_entry *old, struct index_hashtable_entry *new);
 void index_hashtable_remove(struct index_hashtable *table, struct index_hashtable_entry *entry);
-struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index);
+struct index_hashtable_entry *index_hashtable_lookup(struct index_hashtable *table, const enum index_hashtable_type type_mask, const __le32 index, struct wireguard_peer **peer);

 #endif /* _WG_HASHTABLES_H */

src/noise.c

Lines changed: 6 additions & 9 deletions
@@ -103,16 +103,13 @@ static struct noise_keypair *keypair_create(struct wireguard_peer *peer)

 static void keypair_free_rcu(struct rcu_head *rcu)
 {
-	struct noise_keypair *keypair = container_of(rcu, struct noise_keypair, rcu);
-
-	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id);
-	kzfree(keypair);
+	kzfree(container_of(rcu, struct noise_keypair, rcu));
 }

 static void keypair_free_kref(struct kref *kref)
 {
 	struct noise_keypair *keypair = container_of(kref, struct noise_keypair, refcount);
-
+	net_dbg_ratelimited("%s: Keypair %llu destroyed for peer %llu\n", keypair->entry.peer->device->dev->name, keypair->internal_id, keypair->entry.peer->internal_id);
 	index_hashtable_remove(&keypair->entry.peer->device->index_hashtable, &keypair->entry);
 	call_rcu_bh(&keypair->rcu, keypair_free_rcu);
 }
@@ -542,7 +539,7 @@ bool noise_handshake_create_response(struct message_handshake_response *dst, str
 struct wireguard_peer *noise_handshake_consume_response(struct message_handshake_response *src, struct wireguard_device *wg)
 {
 	struct noise_handshake *handshake;
-	struct wireguard_peer *ret_peer = NULL;
+	struct wireguard_peer *peer = NULL, *ret_peer = NULL;
 	u8 key[NOISE_SYMMETRIC_KEY_LEN];
 	u8 hash[NOISE_HASH_LEN];
 	u8 chaining_key[NOISE_HASH_LEN];
@@ -556,7 +553,7 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
 	if (unlikely(!wg->static_identity.has_identity))
 		goto out;

-	handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index);
+	handshake = (struct noise_handshake *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_HANDSHAKE, src->receiver_index, &peer);
 	if (unlikely(!handshake))
 		goto out;

@@ -601,11 +598,11 @@ struct wireguard_peer *noise_handshake_consume_response(struct message_handshake
 	handshake->remote_index = src->sender_index;
 	handshake->state = HANDSHAKE_CONSUMED_RESPONSE;
 	up_write(&handshake->lock);
-	ret_peer = handshake->entry.peer;
+	ret_peer = peer;
 	goto out;

 fail:
-	peer_put(handshake->entry.peer);
+	peer_put(peer);
 out:
 	memzero_explicit(key, NOISE_SYMMETRIC_KEY_LEN);
 	memzero_explicit(hash, NOISE_HASH_LEN);

src/peer.c

Lines changed: 34 additions & 10 deletions
@@ -48,6 +48,7 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_
 	spin_lock_init(&peer->keypairs.keypair_update_lock);
 	INIT_WORK(&peer->transmit_handshake_work, packet_handshake_send_worker);
 	rwlock_init(&peer->endpoint_lock);
+	atomic_set(&peer->dead_count, 1);
 	kref_init(&peer->refcount);
 	skb_queue_head_init(&peer->staged_packet_queue);
 	peer->last_sent_handshake = ktime_get_boot_fast_ns() - (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC;
@@ -86,27 +87,47 @@ void peer_remove(struct wireguard_peer *peer)
 	if (unlikely(!peer))
 		return;
 	lockdep_assert_held(&peer->device->device_update_lock);
+
+	/* Remove from configuration-time lookup structures so new packets can't enter. */
+	list_del_init(&peer->peer_list);
 	allowedips_remove_by_peer(&peer->device->peer_allowedips, peer, &peer->device->device_update_lock);
 	pubkey_hashtable_remove(&peer->device->peer_hashtable, peer);
-	skb_queue_purge(&peer->staged_packet_queue);
+
+	/* Mark as dead, so that we don't allow jumping contexts after. */
+	while (atomic_cmpxchg(&peer->dead_count, 1, 0) != 1) cpu_relax();
+
+	/* The transition between packet encryption/decryption queues isn't guarded
+	 * by the dead_count, but each reference's life is strictly bounded by
+	 * two generations: once for parallel crypto and once for serial ingestion,
+	 * so we can simply flush twice, and be sure that we no longer have references
+	 * inside these queues.
+	 */
+
+	/* The first flush is for encrypt/decrypt. */
+	flush_workqueue(peer->device->packet_crypt_wq);
+	/* The second.1 flush is for send (but not receive, since that's napi). */
+	flush_workqueue(peer->device->packet_crypt_wq);
+	/* The second.2 flush is for receive (but not send, since that's wq). */
+	napi_disable(&peer->napi);
+	/* It's now safe to remove the napi struct, which must be done here from process context. */
+	netif_napi_del(&peer->napi);
+	/* Ensure any workstructs we own (like transmit_handshake_work or clear_peer_work) no longer are in use. */
+	flush_workqueue(peer->device->handshake_send_wq);
+
+	/* Remove keys and handshakes from memory. Handshake removal must be done here from process context. */
 	noise_handshake_clear(&peer->handshake);
 	noise_keypairs_clear(&peer->keypairs);
-	list_del_init(&peer->peer_list);
+
+	/* Destroy all ongoing timers that were in-flight at the beginning of this function. */
 	timers_stop(peer);
-	flush_workqueue(peer->device->packet_crypt_wq); /* The first flush is for encrypt/decrypt. */
-	flush_workqueue(peer->device->packet_crypt_wq); /* The second.1 flush is for send (but not receive, since that's napi). */
-	napi_disable(&peer->napi); /* The second.2 flush is for receive (but not send, since that's wq). */
-	flush_workqueue(peer->device->handshake_send_wq);
-	netif_napi_del(&peer->napi);
+
 	--peer->device->num_peers;
 	peer_put(peer);
 }

 static void rcu_release(struct rcu_head *rcu)
 {
 	struct wireguard_peer *peer = container_of(rcu, struct wireguard_peer, rcu);
-
-	pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
 	dst_cache_destroy(&peer->endpoint_cache);
 	packet_queue_free(&peer->rx_queue, false);
 	packet_queue_free(&peer->tx_queue, false);
@@ -116,9 +137,12 @@ static void rcu_release(struct rcu_head *rcu)
 static void kref_release(struct kref *refcount)
 {
 	struct wireguard_peer *peer = container_of(refcount, struct wireguard_peer, refcount);
-
+	pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr);
+	/* Remove ourself from dynamic runtime lookup structures, now that the last reference is gone. */
 	index_hashtable_remove(&peer->device->index_hashtable, &peer->handshake.entry);
+	/* Remove any lingering packets that didn't have a chance to be transmitted. */
 	skb_queue_purge(&peer->staged_packet_queue);
+	/* Free the memory used. */
 	call_rcu_bh(&peer->rcu, rcu_release);
 }
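
The comment block added to peer_remove() carries the ordering argument for the whole commit: once dead_count is parked at zero, any packet that still references the peer is at most two scheduling generations from completion, and flush_workqueue() waits only for work queued before the flush began. A condensed restatement of that reasoning against the calls above (a sketch of the argument, not additional code from the patch):

	/* Generation 1 is the parallel encrypt/decrypt work on packet_crypt_wq.
	 * Generation 2 is the serial ingestion step: on the transmit side it is
	 * another work item on the same packet_crypt_wq, on the receive side it
	 * is the peer's napi instance. Hence:
	 */
	flush_workqueue(peer->device->packet_crypt_wq); /* drains generation 1, which may still queue generation 2 */
	flush_workqueue(peer->device->packet_crypt_wq); /* transmit generation 2 was queued before this flush began, so it drains too */
	napi_disable(&peer->napi);                      /* receive generation 2 is napi, so it is stopped separately */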

src/peer.h

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@ struct wireguard_peer {
 	bool timers_enabled, timer_need_another_keepalive, sent_lastminute_handshake;
 	struct timespec walltime_last_handshake;
 	struct kref refcount;
+	atomic_t dead_count;
 	struct rcu_head rcu;
 	struct list_head peer_list;
 	u64 internal_id;

src/receive.c

Lines changed: 28 additions & 20 deletions
@@ -120,6 +120,8 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff
 		net_dbg_skb_ratelimited("%s: Invalid handshake initiation from %pISpfsc\n", wg->dev->name, skb);
 		return;
 	}
+	if (unlikely(!atomic_inc_not_zero(&peer->dead_count)))
+		goto err_dead;
 	socket_set_peer_endpoint_from_skb(peer, skb);
 	net_dbg_ratelimited("%s: Receiving handshake initiation from peer %llu (%pISpfsc)\n", wg->dev->name, peer->internal_id, &peer->endpoint.addr);
 	packet_send_handshake_response(peer);
@@ -137,6 +139,8 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff
 		net_dbg_skb_ratelimited("%s: Invalid handshake response from %pISpfsc\n", wg->dev->name, skb);
 		return;
 	}
+	if (unlikely(!atomic_inc_not_zero(&peer->dead_count)))
+		goto err_dead;
 	socket_set_peer_endpoint_from_skb(peer, skb);
 	net_dbg_ratelimited("%s: Receiving handshake response from peer %llu (%pISpfsc)\n", wg->dev->name, peer->internal_id, &peer->endpoint.addr);
 	if (noise_handshake_begin_session(&peer->handshake, &peer->keypairs)) {
@@ -164,6 +168,8 @@ static void receive_handshake_packet(struct wireguard_device *wg, struct sk_buff

 	timers_any_authenticated_packet_received(peer);
 	timers_any_authenticated_packet_traversal(peer);
+	atomic_dec(&peer->dead_count);
+err_dead:
 	peer_put(peer);
 }

@@ -282,9 +288,9 @@ static inline bool counter_validate(union noise_counter *counter, u64 their_coun
 }
 #include "selftest/counter.h"

-static void packet_consume_data_done(struct sk_buff *skb, struct endpoint *endpoint)
+static void packet_consume_data_done(struct wireguard_peer *peer, struct sk_buff *skb, struct endpoint *endpoint)
 {
-	struct wireguard_peer *peer = PACKET_PEER(skb), *routed_peer;
+	struct wireguard_peer *routed_peer;
 	struct net_device *dev = peer->device->dev;
 	unsigned int len, len_before_trim;

@@ -400,7 +406,7 @@ int packet_rx_poll(struct napi_struct *napi, int budget)
 			goto next;

 		skb_reset(skb);
-		packet_consume_data_done(skb, &endpoint);
+		packet_consume_data_done(peer, skb, &endpoint);
 		free = false;

 next:
@@ -436,32 +442,34 @@ void packet_decrypt_worker(struct work_struct *work)

 static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb)
 {
-	struct wireguard_peer *peer;
+	struct wireguard_peer *peer = NULL;
 	__le32 idx = ((struct message_data *)skb->data)->key_idx;
 	int ret;

 	rcu_read_lock_bh();
-	PACKET_CB(skb)->keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx));
-	rcu_read_unlock_bh();
-	if (unlikely(!PACKET_CB(skb)->keypair)) {
-		dev_kfree_skb(skb);
-		return;
+	PACKET_CB(skb)->keypair = (struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx, &peer);
+	if (unlikely(!noise_keypair_get(PACKET_CB(skb)->keypair))) {
+		rcu_read_unlock_bh();
+		goto err_keypair;
 	}
+	rcu_read_unlock_bh();

-	/* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_get(). */
-	peer = PACKET_PEER(skb);
+	if (unlikely(list_empty(&peer->peer_list) || !atomic_inc_not_zero(&peer->dead_count)))
+		goto err;

+	peer_get(peer);
 	ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
-	if (likely(!ret))
-		return; /* Successful. No need to drop references below. */
-
-	if (ret == -EPIPE)
+	if (unlikely(ret == -EPIPE))
 		queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD);
-	else {
-		peer_put(peer);
-		noise_keypair_put(PACKET_CB(skb)->keypair);
-		dev_kfree_skb(skb);
-	}
+	atomic_dec(&peer->dead_count);
+	peer_put(peer);
+	if (likely(!ret || ret == -EPIPE))
+		return;
+err:
+	noise_keypair_put(PACKET_CB(skb)->keypair);
+err_keypair:
+	peer_put(peer);
+	dev_kfree_skb(skb);
 }

 void packet_receive(struct wireguard_device *wg, struct sk_buff *skb)

src/send.c

Lines changed: 15 additions & 10 deletions
@@ -58,13 +58,14 @@ void packet_send_queued_handshake_initiation(struct wireguard_peer *peer, bool i
 	/* First checking the timestamp here is just an optimization; it will
 	 * be caught while properly locked inside the actual work queue.
 	 */
-	if (!has_expired(peer->last_sent_handshake, REKEY_TIMEOUT))
+	if (!has_expired(peer->last_sent_handshake, REKEY_TIMEOUT) || unlikely(!atomic_inc_not_zero(&peer->dead_count)))
 		return;

 	peer_get(peer);
 	/* Queues up calling packet_send_queued_handshakes(peer), where we do a peer_put(peer) after: */
 	if (!queue_work(peer->device->handshake_send_wq, &peer->transmit_handshake_work))
 		peer_put(peer); /* If the work was already queued, we want to drop the extra reference */
+	atomic_dec(&peer->dead_count);
 }

 void packet_send_handshake_response(struct wireguard_peer *peer)
@@ -268,17 +269,21 @@ static void packet_create_data(struct sk_buff *first)
 	struct wireguard_device *wg = peer->device;
 	int ret;

-	ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu);
-	if (likely(!ret))
-		return; /* Successful. No need to fall through to drop references below. */
+	if (unlikely(!atomic_inc_not_zero(&peer->dead_count)))
+		goto err;

-	if (ret == -EPIPE)
+	peer_get(peer);
+	ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu);
+	if (unlikely(ret == -EPIPE))
 		queue_enqueue_per_peer(&peer->tx_queue, first, PACKET_STATE_DEAD);
-	else {
-		peer_put(peer);
-		noise_keypair_put(PACKET_CB(first)->keypair);
-		skb_free_null_queue(first);
-	}
+	atomic_dec(&peer->dead_count);
+	peer_put(peer);
+	if (likely(!ret || ret == -EPIPE))
+		return;
+err:
+	noise_keypair_put(PACKET_CB(first)->keypair);
+	peer_put(peer);
+	skb_free_null_queue(first);
 }

 void packet_send_staged_packets(struct wireguard_peer *peer)
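
packet_consume_data() in receive.c and packet_create_data() in send.c now share the same guarded-enqueue shape: enter the dead_count guard, take a peer reference for the duration of the enqueue, hand the skb to the per-device/per-peer queue (or mark it PACKET_STATE_DEAD on -EPIPE), then drop both the guard and the reference before returning; they fall through to the err/err_keypair unwind only when the packet was never queued. The guard only has to cover getting the skb into the queues, because anything already queued is accounted for by the double flush in peer_remove(). A reduced, annotated sketch of that shape using the receive-side names from the hunk above (not new code from the patch):

	if (unlikely(list_empty(&peer->peer_list) || !atomic_inc_not_zero(&peer->dead_count)))
		goto err;			/* removal already started; unwind without queueing */

	peer_get(peer);				/* reference pinned only across the enqueue */
	ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu);
	if (unlikely(ret == -EPIPE))
		queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD);
	atomic_dec(&peer->dead_count);		/* guard no longer needed once the skb is queued */
	peer_put(peer);
	if (likely(!ret || ret == -EPIPE))
		return;				/* queued (or marked dead); nothing left for us to release */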
