diff --git a/net/core/rtnetlink.c b/net/core/rtnetlink.c
index dabba2a91fc8ff5852681dcee53149d99798ed79..ff292d3f2c417b885fd7d9ddc7301f44ca37a20a 100644
--- a/net/core/rtnetlink.c
+++ b/net/core/rtnetlink.c
@@ -63,6 +63,7 @@ struct rtnl_link {
 	rtnl_doit_func		doit;
 	rtnl_dumpit_func	dumpit;
 	unsigned int		flags;
+	struct rcu_head		rcu;
 };
 
 static DEFINE_MUTEX(rtnl_mutex);
@@ -127,7 +128,7 @@ bool lockdep_rtnl_is_held(void)
 EXPORT_SYMBOL(lockdep_rtnl_is_held);
 #endif /* #ifdef CONFIG_PROVE_LOCKING */
 
-static struct rtnl_link __rcu *rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
+static struct rtnl_link __rcu **rtnl_msg_handlers[RTNL_FAMILY_MAX + 1];
 static refcount_t rtnl_msg_handlers_ref[RTNL_FAMILY_MAX + 1];
 
 static inline int rtm_msgindex(int msgtype)
@@ -144,6 +145,20 @@ static inline int rtm_msgindex(int msgtype)
 	return msgindex;
 }
 
+static struct rtnl_link *rtnl_get_link(int protocol, int msgtype)
+{
+	struct rtnl_link **tab;
+
+	if (protocol >= ARRAY_SIZE(rtnl_msg_handlers))
+		protocol = PF_UNSPEC;
+
+	tab = rcu_dereference_rtnl(rtnl_msg_handlers[protocol]);
+	if (!tab)
+		tab = rcu_dereference_rtnl(rtnl_msg_handlers[PF_UNSPEC]);
+
+	return tab[msgtype];
+}
+
 /**
  * __rtnl_register - Register a rtnetlink message type
  * @protocol: Protocol family or PF_UNSPEC
@@ -166,28 +181,52 @@ int __rtnl_register(int protocol, int msgtype,
 		    rtnl_doit_func doit, rtnl_dumpit_func dumpit,
 		    unsigned int flags)
 {
-	struct rtnl_link *tab;
+	struct rtnl_link **tab, *link, *old;
 	int msgindex;
+	int ret = -ENOBUFS;
 
 	BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
 	msgindex = rtm_msgindex(msgtype);
 
-	tab = rcu_dereference_raw(rtnl_msg_handlers[protocol]);
+	rtnl_lock();
+	tab = rtnl_msg_handlers[protocol];
 	if (tab == NULL) {
-		tab = kcalloc(RTM_NR_MSGTYPES, sizeof(*tab), GFP_KERNEL);
-		if (tab == NULL)
-			return -ENOBUFS;
+		tab = kcalloc(RTM_NR_MSGTYPES, sizeof(void *), GFP_KERNEL);
+		if (!tab)
+			goto unlock;
 
+		/* ensures we see the 0 stores */
 		rcu_assign_pointer(rtnl_msg_handlers[protocol], tab);
 	}
 
+	old = rtnl_dereference(tab[msgindex]);
+	if (old) {
+		link = kmemdup(old, sizeof(*old), GFP_KERNEL);
+		if (!link)
+			goto unlock;
+	} else {
+		link = kzalloc(sizeof(*link), GFP_KERNEL);
+		if (!link)
+			goto unlock;
+	}
+
+	WARN_ON(doit && link->doit && link->doit != doit);
 	if (doit)
-		tab[msgindex].doit = doit;
+		link->doit = doit;
+	WARN_ON(dumpit && link->dumpit && link->dumpit != dumpit);
 	if (dumpit)
-		tab[msgindex].dumpit = dumpit;
-	tab[msgindex].flags |= flags;
+		link->dumpit = dumpit;
 
-	return 0;
+	link->flags |= flags;
+
+	/* publish protocol:msgtype */
+	rcu_assign_pointer(tab[msgindex], link);
+	ret = 0;
+	if (old)
+		kfree_rcu(old, rcu);
+unlock:
+	rtnl_unlock();
+	return ret;
 }
 EXPORT_SYMBOL_GPL(__rtnl_register);
 
@@ -220,24 +259,25 @@ EXPORT_SYMBOL_GPL(rtnl_register);
  */
 int rtnl_unregister(int protocol, int msgtype)
 {
-	struct rtnl_link *handlers;
+	struct rtnl_link **tab, *link;
 	int msgindex;
 
 	BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
 	msgindex = rtm_msgindex(msgtype);
 
 	rtnl_lock();
-	handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
-	if (!handlers) {
+	tab = rtnl_dereference(rtnl_msg_handlers[protocol]);
+	if (!tab) {
 		rtnl_unlock();
 		return -ENOENT;
 	}
 
-	handlers[msgindex].doit = NULL;
-	handlers[msgindex].dumpit = NULL;
-	handlers[msgindex].flags = 0;
+	link = tab[msgindex];
+	rcu_assign_pointer(tab[msgindex], NULL);
 	rtnl_unlock();
 
+	kfree_rcu(link, rcu);
+
 	return 0;
 }
 EXPORT_SYMBOL_GPL(rtnl_unregister);
@@ -251,20 +291,29 @@ EXPORT_SYMBOL_GPL(rtnl_unregister);
  */
 void rtnl_unregister_all(int protocol)
 {
-	struct rtnl_link *handlers;
+	struct rtnl_link **tab, *link;
+	int msgindex;
 
 	BUG_ON(protocol < 0 || protocol > RTNL_FAMILY_MAX);
 
 	rtnl_lock();
-	handlers = rtnl_dereference(rtnl_msg_handlers[protocol]);
+	tab = rtnl_msg_handlers[protocol];
 	RCU_INIT_POINTER(rtnl_msg_handlers[protocol], NULL);
+	for (msgindex = 0; msgindex < RTM_NR_MSGTYPES; msgindex++) {
+		link = tab[msgindex];
+		if (!link)
+			continue;
+
+		rcu_assign_pointer(tab[msgindex], NULL);
+		kfree_rcu(link, rcu);
+	}
 	rtnl_unlock();
 
 	synchronize_net();
 
 	while (refcount_read(&rtnl_msg_handlers_ref[protocol]) > 1)
 		schedule();
-	kfree(handlers);
+	kfree(tab);
 }
 EXPORT_SYMBOL_GPL(rtnl_unregister_all);
 
@@ -2973,18 +3022,26 @@ static int rtnl_dump_all(struct sk_buff *skb, struct netlink_callback *cb)
 		s_idx = 1;
 
 	for (idx = 1; idx <= RTNL_FAMILY_MAX; idx++) {
+		struct rtnl_link **tab;
 		int type = cb->nlh->nlmsg_type-RTM_BASE;
-		struct rtnl_link *handlers;
+		struct rtnl_link *link;
 		rtnl_dumpit_func dumpit;
 
 		if (idx < s_idx || idx == PF_PACKET)
 			continue;
 
-		handlers = rtnl_dereference(rtnl_msg_handlers[idx]);
-		if (!handlers)
+		if (type < 0 || type >= RTM_NR_MSGTYPES)
 			continue;
 
-		dumpit = READ_ONCE(handlers[type].dumpit);
+		tab = rcu_dereference_rtnl(rtnl_msg_handlers[idx]);
+		if (!tab)
+			continue;
+
+		link = tab[type];
+		if (!link)
+			continue;
+
+		dumpit = link->dumpit;
 		if (!dumpit)
 			continue;
 
@@ -4314,7 +4371,7 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
 			     struct netlink_ext_ack *extack)
 {
 	struct net *net = sock_net(skb->sk);
-	struct rtnl_link *handlers;
+	struct rtnl_link *link;
 	int err = -EOPNOTSUPP;
 	rtnl_doit_func doit;
 	unsigned int flags;
@@ -4338,32 +4395,20 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
 	if (kind != 2 && !netlink_net_capable(skb, CAP_NET_ADMIN))
 		return -EPERM;
 
-	if (family >= ARRAY_SIZE(rtnl_msg_handlers))
-		family = PF_UNSPEC;
-
 	rcu_read_lock();
-	handlers = rcu_dereference(rtnl_msg_handlers[family]);
-	if (!handlers) {
-		family = PF_UNSPEC;
-		handlers = rcu_dereference(rtnl_msg_handlers[family]);
-	}
-
 	if (kind == 2 && nlh->nlmsg_flags&NLM_F_DUMP) {
 		struct sock *rtnl;
 		rtnl_dumpit_func dumpit;
 		u16 min_dump_alloc = 0;
 
-		dumpit = READ_ONCE(handlers[type].dumpit);
-		if (!dumpit) {
+		link = rtnl_get_link(family, type);
+		if (!link || !link->dumpit) {
 			family = PF_UNSPEC;
-			handlers = rcu_dereference(rtnl_msg_handlers[PF_UNSPEC]);
-			if (!handlers)
-				goto err_unlock;
-
-			dumpit = READ_ONCE(handlers[type].dumpit);
-			if (!dumpit)
+			link = rtnl_get_link(family, type);
+			if (!link || !link->dumpit)
 				goto err_unlock;
 		}
+		dumpit = link->dumpit;
 
 		refcount_inc(&rtnl_msg_handlers_ref[family]);
 
@@ -4384,33 +4429,36 @@ static int rtnetlink_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh,
 		return err;
 	}
 
-	doit = READ_ONCE(handlers[type].doit);
-	if (!doit) {
+	link = rtnl_get_link(family, type);
+	if (!link || !link->doit) {
 		family = PF_UNSPEC;
-		handlers = rcu_dereference(rtnl_msg_handlers[family]);
+		link = rtnl_get_link(PF_UNSPEC, type);
+		if (!link || !link->doit)
+			goto out_unlock;
 	}
 
-	flags = READ_ONCE(handlers[type].flags);
+	flags = link->flags;
 	if (flags & RTNL_FLAG_DOIT_UNLOCKED) {
 		refcount_inc(&rtnl_msg_handlers_ref[family]);
-		doit = READ_ONCE(handlers[type].doit);
+		doit = link->doit;
 		rcu_read_unlock();
 		if (doit)
 			err = doit(skb, nlh, extack);
 		refcount_dec(&rtnl_msg_handlers_ref[family]);
 		return err;
 	}
-
 	rcu_read_unlock();
 
 	rtnl_lock();
-	handlers = rtnl_dereference(rtnl_msg_handlers[family]);
-	if (handlers) {
-		doit = READ_ONCE(handlers[type].doit);
-		if (doit)
-			err = doit(skb, nlh, extack);
-	}
+	link = rtnl_get_link(family, type);
+	if (link && link->doit)
+		err = link->doit(skb, nlh, extack);
 	rtnl_unlock();
+
+	return err;
+
+out_unlock:
+	rcu_read_unlock();
 	return err;
 
 err_unlock: