diff --git a/net/ipv6/ip6mr.c b/net/ipv6/ip6mr.c index 2ce4ae0d8dc3b443986d7a7b4177a057f5affaec..d5057401701c1a2a09284e3d61d0fc5d035c0d5b 100644 --- a/net/ipv6/ip6mr.c +++ b/net/ipv6/ip6mr.c @@ -125,7 +125,7 @@ static struct mr_table *ip6mr_mr_table_iter(struct net *net, return ret; } -static struct mr_table *ip6mr_get_table(struct net *net, u32 id) +static struct mr_table *__ip6mr_get_table(struct net *net, u32 id) { struct mr_table *mrt; @@ -136,6 +136,16 @@ static struct mr_table *ip6mr_get_table(struct net *net, u32 id) return NULL; } +static struct mr_table *ip6mr_get_table(struct net *net, u32 id) +{ + struct mr_table *mrt; + + rcu_read_lock(); + mrt = __ip6mr_get_table(net, id); + rcu_read_unlock(); + return mrt; +} + static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6, struct mr_table **mrt) { @@ -177,7 +187,7 @@ static int ip6mr_rule_action(struct fib_rule *rule, struct flowi *flp, arg->table = fib_rule_get_table(rule, arg); - mrt = ip6mr_get_table(rule->fr_net, arg->table); + mrt = __ip6mr_get_table(rule->fr_net, arg->table); if (!mrt) return -EAGAIN; res->mrt = mrt; @@ -304,6 +314,8 @@ static struct mr_table *ip6mr_get_table(struct net *net, u32 id) return net->ipv6.mrt6; } +#define __ip6mr_get_table ip6mr_get_table + static int ip6mr_fib_lookup(struct net *net, struct flowi6 *flp6, struct mr_table **mrt) { @@ -382,7 +394,7 @@ static struct mr_table *ip6mr_new_table(struct net *net, u32 id) { struct mr_table *mrt; - mrt = ip6mr_get_table(net, id); + mrt = __ip6mr_get_table(net, id); if (mrt) return mrt; @@ -411,13 +423,15 @@ static void *ip6mr_vif_seq_start(struct seq_file *seq, loff_t *pos) struct net *net = seq_file_net(seq); struct mr_table *mrt; - mrt = ip6mr_get_table(net, RT6_TABLE_DFLT); - if (!mrt) + rcu_read_lock(); + mrt = __ip6mr_get_table(net, RT6_TABLE_DFLT); + if (!mrt) { + rcu_read_unlock(); return ERR_PTR(-ENOENT); + } iter->mrt = mrt; - rcu_read_lock(); return mr_vif_seq_start(seq, pos); } @@ -2275,11 +2289,13 @@ int ip6mr_get_route(struct net *net, struct sk_buff *skb, struct rtmsg *rtm, struct mfc6_cache *cache; struct rt6_info *rt = dst_rt6_info(skb_dst(skb)); - mrt = ip6mr_get_table(net, RT6_TABLE_DFLT); - if (!mrt) + rcu_read_lock(); + mrt = __ip6mr_get_table(net, RT6_TABLE_DFLT); + if (!mrt) { + rcu_read_unlock(); return -ENOENT; + } - rcu_read_lock(); cache = ip6mr_cache_find(mrt, &rt->rt6i_src.addr, &rt->rt6i_dst.addr); if (!cache && skb->dev) { int vif = ip6mr_find_vif(mrt, skb->dev); @@ -2559,7 +2575,7 @@ static int ip6mr_rtm_getroute(struct sk_buff *in_skb, struct nlmsghdr *nlh, grp = nla_get_in6_addr(tb[RTA_DST]); tableid = tb[RTA_TABLE] ? nla_get_u32(tb[RTA_TABLE]) : 0; - mrt = ip6mr_get_table(net, tableid ?: RT_TABLE_DEFAULT); + mrt = __ip6mr_get_table(net, tableid ?: RT_TABLE_DEFAULT); if (!mrt) { NL_SET_ERR_MSG_MOD(extack, "MR table does not exist"); return -ENOENT; @@ -2606,7 +2622,7 @@ static int ip6mr_rtm_dumproute(struct sk_buff *skb, struct netlink_callback *cb) if (filter.table_id) { struct mr_table *mrt; - mrt = ip6mr_get_table(sock_net(skb->sk), filter.table_id); + mrt = __ip6mr_get_table(sock_net(skb->sk), filter.table_id); if (!mrt) { if (rtnl_msg_family(cb->nlh) != RTNL_FAMILY_IP6MR) return skb->len;