diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 44f6a20fa29df830c1208e825816eaa11785f1ab..2e2aecbe22c4499d80dd715ce0f369bd20dc8876 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -560,15 +560,11 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
 						 __be16 sport, __be16 dport,
 						 struct udp_table *udptable)
 {
-	struct sock *sk;
 	const struct iphdr *iph = ip_hdr(skb);
 
-	if (unlikely(sk = skb_steal_sock(skb)))
-		return sk;
-	else
-		return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
-					 iph->daddr, dport, inet_iif(skb),
-					 udptable);
+	return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
+				 iph->daddr, dport, inet_iif(skb),
+				 udptable);
 }
 
 struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
@@ -1739,15 +1735,15 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 	if (udp4_csum_init(skb, uh, proto))
 		goto csum_error;
 
-	if (skb->sk) {
+	sk = skb_steal_sock(skb);
+	if (sk) {
 		int ret;
-		sk = skb->sk;
 
 		if (unlikely(sk->sk_rx_dst == NULL))
 			udp_sk_rx_dst_set(sk, skb);
 
 		ret = udp_queue_rcv_skb(sk, skb);
-
+		sock_put(sk);
 		/* a return value > 0 means to resubmit the input, but
 		 * it wants the return to be -protocol, or 0
 		 */