diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index f84aad420d446454cf86619d494153e0bbc8f1ac..775d707ec708a7ad6879f81f578ca3a1bd282879 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -2176,9 +2176,14 @@ netlink_ack_tlv_len(struct netlink_sock *nlk, int err,
 	return tlvlen;
 }
 
+static bool nlmsg_check_in_payload(const struct nlmsghdr *nlh, const void *addr)
+{
+	return !WARN_ON(addr < nlmsg_data(nlh) ||
+			addr - (const void *) nlh >= nlh->nlmsg_len);
+}
+
 static void
-netlink_ack_tlv_fill(struct sk_buff *in_skb, struct sk_buff *skb,
-		     const struct nlmsghdr *nlh, int err,
+netlink_ack_tlv_fill(struct sk_buff *skb, const struct nlmsghdr *nlh, int err,
 		     const struct netlink_ext_ack *extack)
 {
 	if (extack->_msg)
@@ -2190,9 +2195,7 @@ netlink_ack_tlv_fill(struct sk_buff *in_skb, struct sk_buff *skb,
 	if (!err)
 		return;
 
-	if (extack->bad_attr &&
-	    !WARN_ON((u8 *)extack->bad_attr < in_skb->data ||
-		     (u8 *)extack->bad_attr >= in_skb->data + in_skb->len))
+	if (extack->bad_attr && nlmsg_check_in_payload(nlh, extack->bad_attr))
 		WARN_ON(nla_put_u32(skb, NLMSGERR_ATTR_OFFS,
 				    (u8 *)extack->bad_attr - (const u8 *)nlh));
 	if (extack->policy)
@@ -2201,9 +2204,7 @@ netlink_ack_tlv_fill(struct sk_buff *in_skb, struct sk_buff *skb,
 	if (extack->miss_type)
 		WARN_ON(nla_put_u32(skb, NLMSGERR_ATTR_MISS_TYPE,
 				    extack->miss_type));
-	if (extack->miss_nest &&
-	    !WARN_ON((u8 *)extack->miss_nest < in_skb->data ||
-		     (u8 *)extack->miss_nest > in_skb->data + in_skb->len))
+	if (extack->miss_nest && nlmsg_check_in_payload(nlh, extack->miss_nest))
 		WARN_ON(nla_put_u32(skb, NLMSGERR_ATTR_MISS_NEST,
 				    (u8 *)extack->miss_nest - (const u8 *)nlh));
 }
@@ -2232,7 +2233,7 @@ static int netlink_dump_done(struct netlink_sock *nlk, struct sk_buff *skb,
 	if (extack_len) {
 		nlh->nlmsg_flags |= NLM_F_ACK_TLVS;
 		if (skb_tailroom(skb) >= extack_len) {
-			netlink_ack_tlv_fill(cb->skb, skb, cb->nlh,
+			netlink_ack_tlv_fill(skb, cb->nlh,
 					     nlk->dump_done_errno, extack);
 			nlmsg_end(skb, nlh);
 		}
@@ -2491,7 +2492,7 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
 	}
 
 	if (tlvlen)
-		netlink_ack_tlv_fill(in_skb, skb, nlh, err, extack);
+		netlink_ack_tlv_fill(skb, nlh, err, extack);
 
 	nlmsg_end(skb, rep);