diff --git a/drivers/infiniband/sw/rxe/rxe_pool.c b/drivers/infiniband/sw/rxe/rxe_pool.c
index d26730eec720c6f090f8c34512091815d814cc42..cfcd5517557287718e9124510d44e48dd915bc86 100644
--- a/drivers/infiniband/sw/rxe/rxe_pool.c
+++ b/drivers/infiniband/sw/rxe/rxe_pool.c
@@ -343,8 +343,6 @@ void *rxe_alloc_nl(struct rxe_pool *pool)
 	struct rxe_pool_entry *elem;
 	u8 *obj;
 
-	might_sleep_if(!(pool->flags & RXE_POOL_ATOMIC));
-
 	if (pool->state != RXE_POOL_STATE_VALID)
 		return NULL;
 
@@ -356,8 +354,7 @@ void *rxe_alloc_nl(struct rxe_pool *pool)
 	if (atomic_inc_return(&pool->num_elem) > pool->max_elem)
 		goto out_cnt;
 
-	obj = kzalloc(info->size, (pool->flags & RXE_POOL_ATOMIC) ?
-		      GFP_ATOMIC : GFP_KERNEL);
+	obj = kzalloc(info->size, GFP_ATOMIC);
 	if (!obj)
 		goto out_cnt;
 
@@ -378,14 +375,46 @@ void *rxe_alloc_nl(struct rxe_pool *pool)
 
 void *rxe_alloc(struct rxe_pool *pool)
 {
-	u8 *obj;
 	unsigned long flags;
+	struct rxe_type_info *info = &rxe_type_info[pool->type];
+	struct rxe_pool_entry *elem;
+	u8 *obj;
+
+	might_sleep_if(!(pool->flags & RXE_POOL_ATOMIC));
 
 	read_lock_irqsave(&pool->pool_lock, flags);
-	obj = rxe_alloc_nl(pool);
+	if (pool->state != RXE_POOL_STATE_VALID) {
+		read_unlock_irqrestore(&pool->pool_lock, flags);
+		return NULL;
+	}
+
+	kref_get(&pool->ref_cnt);
 	read_unlock_irqrestore(&pool->pool_lock, flags);
 
+	if (!ib_device_try_get(&pool->rxe->ib_dev))
+		goto out_put_pool;
+
+	if (atomic_inc_return(&pool->num_elem) > pool->max_elem)
+		goto out_cnt;
+
+	obj = kzalloc(info->size, (pool->flags & RXE_POOL_ATOMIC) ?
+		      GFP_ATOMIC : GFP_KERNEL);
+	if (!obj)
+		goto out_cnt;
+
+	elem = (struct rxe_pool_entry *)(obj + info->elem_offset);
+
+	elem->pool = pool;
+	kref_init(&elem->ref_cnt);
+
 	return obj;
+
+out_cnt:
+	atomic_dec(&pool->num_elem);
+	ib_device_put(&pool->rxe->ib_dev);
+out_put_pool:
+	rxe_pool_put(pool);
+	return NULL;
 }
 
 int __rxe_add_to_pool(struct rxe_pool *pool, struct rxe_pool_entry *elem)