diff --git a/net/ipv4/gre_offload.c b/net/ipv4/gre_offload.c
index 2e6d1b7a7bc9a5c369c35e9be3a4ca40cee2e491..e0a24657588721ac7b7279603025d876188b385d 100644
--- a/net/ipv4/gre_offload.c
+++ b/net/ipv4/gre_offload.c
@@ -15,12 +15,12 @@ static struct sk_buff *gre_gso_segment(struct sk_buff *skb,
 				       netdev_features_t features)
 {
 	int tnl_hlen = skb_inner_mac_header(skb) - skb_transport_header(skb);
+	bool need_csum, need_recompute_csum, gso_partial;
 	struct sk_buff *segs = ERR_PTR(-EINVAL);
 	u16 mac_offset = skb->mac_header;
 	__be16 protocol = skb->protocol;
 	u16 mac_len = skb->mac_len;
 	int gre_offset, outer_hlen;
-	bool need_csum, gso_partial;
 
 	if (!skb->encapsulation)
 		goto out;
@@ -41,6 +41,7 @@ static struct sk_buff *gre_gso_segment(struct sk_buff *skb,
 	skb->protocol = skb->inner_protocol;
 
 	need_csum = !!(skb_shinfo(skb)->gso_type & SKB_GSO_GRE_CSUM);
+	need_recompute_csum = skb->csum_not_inet;
 	skb->encap_hdr_csum = need_csum;
 
 	features &= skb->dev->hw_enc_features;
@@ -98,7 +99,15 @@ static struct sk_buff *gre_gso_segment(struct sk_buff *skb,
 		}
 
 		*(pcsum + 1) = 0;
-		*pcsum = gso_make_checksum(skb, 0);
+		if (need_recompute_csum && !skb_is_gso(skb)) {
+			__wsum csum;
+
+			csum = skb_checksum(skb, gre_offset,
+					    skb->len - gre_offset, 0);
+			*pcsum = csum_fold(csum);
+		} else {
+			*pcsum = gso_make_checksum(skb, 0);
+		}
 	} while ((skb = skb->next));
 out:
 	return segs;