diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 65a937ed576285fbb0e84020be0b8271cfb54c9d..ac09b35a9567f3ecc07ad661455697b8ad924e45 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -72,6 +72,7 @@ struct bpf_htab {
 	u32 n_buckets;
 	u32 elem_size;
 	struct bpf_sock_progs progs;
+	struct rcu_head rcu;
 };
 
 struct htab_elem {
@@ -89,8 +90,8 @@ enum smap_psock_state {
 struct smap_psock_map_entry {
 	struct list_head list;
 	struct sock **entry;
-	struct htab_elem *hash_link;
-	struct bpf_htab *htab;
+	struct htab_elem __rcu *hash_link;
+	struct bpf_htab __rcu *htab;
 };
 
 struct smap_psock {
@@ -120,6 +121,7 @@ struct smap_psock {
 	struct bpf_prog *bpf_parse;
 	struct bpf_prog *bpf_verdict;
 	struct list_head maps;
+	spinlock_t maps_lock;
 
 	/* Back reference used when sock callback trigger sockmap operations */
 	struct sock *sock;
@@ -258,16 +260,54 @@ static void bpf_tcp_release(struct sock *sk)
 	rcu_read_unlock();
 }
 
+static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
+					 u32 hash, void *key, u32 key_size)
+{
+	struct htab_elem *l;
+
+	hlist_for_each_entry_rcu(l, head, hash_node) {
+		if (l->hash == hash && !memcmp(&l->key, key, key_size))
+			return l;
+	}
+
+	return NULL;
+}
+
+static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
+{
+	return &htab->buckets[hash & (htab->n_buckets - 1)];
+}
+
+static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
+{
+	return &__select_bucket(htab, hash)->head;
+}
+
 static void free_htab_elem(struct bpf_htab *htab, struct htab_elem *l)
 {
 	atomic_dec(&htab->count);
 	kfree_rcu(l, rcu);
 }
 
+static struct smap_psock_map_entry *psock_map_pop(struct sock *sk,
+						  struct smap_psock *psock)
+{
+	struct smap_psock_map_entry *e;
+
+	spin_lock_bh(&psock->maps_lock);
+	e = list_first_entry_or_null(&psock->maps,
+				     struct smap_psock_map_entry,
+				     list);
+	if (e)
+		list_del(&e->list);
+	spin_unlock_bh(&psock->maps_lock);
+	return e;
+}
+
 static void bpf_tcp_close(struct sock *sk, long timeout)
 {
 	void (*close_fun)(struct sock *sk, long timeout);
-	struct smap_psock_map_entry *e, *tmp;
+	struct smap_psock_map_entry *e;
 	struct sk_msg_buff *md, *mtmp;
 	struct smap_psock *psock;
 	struct sock *osk;
@@ -286,7 +326,6 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 	 */
 	close_fun = psock->save_close;
 
-	write_lock_bh(&sk->sk_callback_lock);
 	if (psock->cork) {
 		free_start_sg(psock->sock, psock->cork);
 		kfree(psock->cork);
@@ -299,20 +338,38 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 		kfree(md);
 	}
 
-	list_for_each_entry_safe(e, tmp, &psock->maps, list) {
+	e = psock_map_pop(sk, psock);
+	while (e) {
 		if (e->entry) {
 			osk = cmpxchg(e->entry, sk, NULL);
 			if (osk == sk) {
-				list_del(&e->list);
 				smap_release_sock(psock, sk);
 			}
 		} else {
-			hlist_del_rcu(&e->hash_link->hash_node);
-			smap_release_sock(psock, e->hash_link->sk);
-			free_htab_elem(e->htab, e->hash_link);
+			struct htab_elem *link = rcu_dereference(e->hash_link);
+			struct bpf_htab *htab = rcu_dereference(e->htab);
+			struct hlist_head *head;
+			struct htab_elem *l;
+			struct bucket *b;
+
+			b = __select_bucket(htab, link->hash);
+			head = &b->head;
+			raw_spin_lock_bh(&b->lock);
+			l = lookup_elem_raw(head,
+					    link->hash, link->key,
+					    htab->map.key_size);
+			/* If another thread deleted this object skip deletion.
+			 * The refcnt on psock may or may not be zero.
+			 */
+			if (l) {
+				hlist_del_rcu(&link->hash_node);
+				smap_release_sock(psock, link->sk);
+				free_htab_elem(htab, link);
+			}
+			raw_spin_unlock_bh(&b->lock);
 		}
+		e = psock_map_pop(sk, psock);
 	}
-	write_unlock_bh(&sk->sk_callback_lock);
 	rcu_read_unlock();
 	close_fun(sk, timeout);
 }
@@ -1395,7 +1452,9 @@ static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
 {
 	if (refcount_dec_and_test(&psock->refcnt)) {
 		tcp_cleanup_ulp(sock);
+		write_lock_bh(&sock->sk_callback_lock);
 		smap_stop_sock(psock, sock);
+		write_unlock_bh(&sock->sk_callback_lock);
 		clear_bit(SMAP_TX_RUNNING, &psock->state);
 		rcu_assign_sk_user_data(sock, NULL);
 		call_rcu_sched(&psock->rcu, smap_destroy_psock);
@@ -1546,6 +1605,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock, int node)
 	INIT_LIST_HEAD(&psock->maps);
 	INIT_LIST_HEAD(&psock->ingress);
 	refcount_set(&psock->refcnt, 1);
+	spin_lock_init(&psock->maps_lock);
 
 	rcu_assign_sk_user_data(sock, psock);
 	sock_hold(sock);
@@ -1607,10 +1667,12 @@ static void smap_list_map_remove(struct smap_psock *psock,
 {
 	struct smap_psock_map_entry *e, *tmp;
 
+	spin_lock_bh(&psock->maps_lock);
 	list_for_each_entry_safe(e, tmp, &psock->maps, list) {
 		if (e->entry == entry)
 			list_del(&e->list);
 	}
+	spin_unlock_bh(&psock->maps_lock);
 }
 
 static void smap_list_hash_remove(struct smap_psock *psock,
@@ -1618,12 +1680,14 @@ static void smap_list_hash_remove(struct smap_psock *psock,
 {
 	struct smap_psock_map_entry *e, *tmp;
 
+	spin_lock_bh(&psock->maps_lock);
 	list_for_each_entry_safe(e, tmp, &psock->maps, list) {
-		struct htab_elem *c = e->hash_link;
+		struct htab_elem *c = rcu_dereference(e->hash_link);
 
 		if (c == hash_link)
 			list_del(&e->list);
 	}
+	spin_unlock_bh(&psock->maps_lock);
 }
 
 static void sock_map_free(struct bpf_map *map)
@@ -1649,7 +1713,6 @@ static void sock_map_free(struct bpf_map *map)
 		if (!sock)
 			continue;
 
-		write_lock_bh(&sock->sk_callback_lock);
 		psock = smap_psock_sk(sock);
 		/* This check handles a racing sock event that can get the
 		 * sk_callback_lock before this case but after xchg happens
@@ -1660,7 +1723,6 @@ static void sock_map_free(struct bpf_map *map)
 			smap_list_map_remove(psock, &stab->sock_map[i]);
 			smap_release_sock(psock, sock);
 		}
-		write_unlock_bh(&sock->sk_callback_lock);
 	}
 	rcu_read_unlock();
 
@@ -1709,7 +1771,6 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
 	if (!sock)
 		return -EINVAL;
 
-	write_lock_bh(&sock->sk_callback_lock);
 	psock = smap_psock_sk(sock);
 	if (!psock)
 		goto out;
@@ -1719,7 +1780,6 @@ static int sock_map_delete_elem(struct bpf_map *map, void *key)
 	smap_list_map_remove(psock, &stab->sock_map[k]);
 	smap_release_sock(psock, sock);
 out:
-	write_unlock_bh(&sock->sk_callback_lock);
 	return 0;
 }
 
@@ -1800,7 +1860,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
 		}
 	}
 
-	write_lock_bh(&sock->sk_callback_lock);
 	psock = smap_psock_sk(sock);
 
 	/* 2. Do not allow inheriting programs if psock exists and has
@@ -1857,7 +1916,9 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
 		if (err)
 			goto out_free;
 		smap_init_progs(psock, verdict, parse);
+		write_lock_bh(&sock->sk_callback_lock);
 		smap_start_sock(psock, sock);
+		write_unlock_bh(&sock->sk_callback_lock);
 	}
 
 	/* 4. Place psock in sockmap for use and stop any programs on
@@ -1867,9 +1928,10 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
 	 */
 	if (map_link) {
 		e->entry = map_link;
+		spin_lock_bh(&psock->maps_lock);
 		list_add_tail(&e->list, &psock->maps);
+		spin_unlock_bh(&psock->maps_lock);
 	}
-	write_unlock_bh(&sock->sk_callback_lock);
 	return err;
 out_free:
 	smap_release_sock(psock, sock);
@@ -1880,7 +1942,6 @@ static int __sock_map_ctx_update_elem(struct bpf_map *map,
 	}
 	if (tx_msg)
 		bpf_prog_put(tx_msg);
-	write_unlock_bh(&sock->sk_callback_lock);
 	kfree(e);
 	return err;
 }
@@ -1917,10 +1978,8 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 	if (osock) {
 		struct smap_psock *opsock = smap_psock_sk(osock);
 
-		write_lock_bh(&osock->sk_callback_lock);
 		smap_list_map_remove(opsock, &stab->sock_map[i]);
 		smap_release_sock(opsock, osock);
-		write_unlock_bh(&osock->sk_callback_lock);
 	}
 out:
 	return err;
@@ -2109,14 +2168,13 @@ static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
 	return ERR_PTR(err);
 }
 
-static inline struct bucket *__select_bucket(struct bpf_htab *htab, u32 hash)
+static void __bpf_htab_free(struct rcu_head *rcu)
 {
-	return &htab->buckets[hash & (htab->n_buckets - 1)];
-}
+	struct bpf_htab *htab;
 
-static inline struct hlist_head *select_bucket(struct bpf_htab *htab, u32 hash)
-{
-	return &__select_bucket(htab, hash)->head;
+	htab = container_of(rcu, struct bpf_htab, rcu);
+	bpf_map_area_free(htab->buckets);
+	kfree(htab);
 }
 
 static void sock_hash_free(struct bpf_map *map)
@@ -2135,16 +2193,18 @@ static void sock_hash_free(struct bpf_map *map)
 	 */
 	rcu_read_lock();
 	for (i = 0; i < htab->n_buckets; i++) {
-		struct hlist_head *head = select_bucket(htab, i);
+		struct bucket *b = __select_bucket(htab, i);
+		struct hlist_head *head;
 		struct hlist_node *n;
 		struct htab_elem *l;
 
+		raw_spin_lock_bh(&b->lock);
+		head = &b->head;
 		hlist_for_each_entry_safe(l, n, head, hash_node) {
 			struct sock *sock = l->sk;
 			struct smap_psock *psock;
 
 			hlist_del_rcu(&l->hash_node);
-			write_lock_bh(&sock->sk_callback_lock);
 			psock = smap_psock_sk(sock);
 			/* This check handles a racing sock event that can get
 			 * the sk_callback_lock before this case but after xchg
@@ -2155,13 +2215,12 @@ static void sock_hash_free(struct bpf_map *map)
 				smap_list_hash_remove(psock, l);
 				smap_release_sock(psock, sock);
 			}
-			write_unlock_bh(&sock->sk_callback_lock);
-			kfree(l);
+			free_htab_elem(htab, l);
 		}
+		raw_spin_unlock_bh(&b->lock);
 	}
 	rcu_read_unlock();
-	bpf_map_area_free(htab->buckets);
-	kfree(htab);
+	call_rcu(&htab->rcu, __bpf_htab_free);
 }
 
 static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
@@ -2188,19 +2247,6 @@ static struct htab_elem *alloc_sock_hash_elem(struct bpf_htab *htab,
 	return l_new;
 }
 
-static struct htab_elem *lookup_elem_raw(struct hlist_head *head,
-					 u32 hash, void *key, u32 key_size)
-{
-	struct htab_elem *l;
-
-	hlist_for_each_entry_rcu(l, head, hash_node) {
-		if (l->hash == hash && !memcmp(&l->key, key, key_size))
-			return l;
-	}
-
-	return NULL;
-}
-
 static inline u32 htab_map_hash(const void *key, u32 key_len)
 {
 	return jhash(key, key_len, 0);
@@ -2320,9 +2366,12 @@ static int sock_hash_ctx_update_elem(struct bpf_sock_ops_kern *skops,
 		goto bucket_err;
 	}
 
-	e->hash_link = l_new;
-	e->htab = container_of(map, struct bpf_htab, map);
+	rcu_assign_pointer(e->hash_link, l_new);
+	rcu_assign_pointer(e->htab,
+			   container_of(map, struct bpf_htab, map));
+	spin_lock_bh(&psock->maps_lock);
 	list_add_tail(&e->list, &psock->maps);
+	spin_unlock_bh(&psock->maps_lock);
 
 	/* add new element to the head of the list, so that
 	 * concurrent search will find it before old elem
@@ -2392,7 +2441,6 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key)
 		struct smap_psock *psock;
 
 		hlist_del_rcu(&l->hash_node);
-		write_lock_bh(&sock->sk_callback_lock);
 		psock = smap_psock_sk(sock);
 		/* This check handles a racing sock event that can get the
 		 * sk_callback_lock before this case but after xchg happens
@@ -2403,7 +2451,6 @@ static int sock_hash_delete_elem(struct bpf_map *map, void *key)
 			smap_list_hash_remove(psock, l);
 			smap_release_sock(psock, sock);
 		}
-		write_unlock_bh(&sock->sk_callback_lock);
 		free_htab_elem(htab, l);
 		ret = 0;
 	}