diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 2a7285c14013c1c75ab52465bc381ba3729fc95a..032d2fa301f576b33d1f8887e6742ca0722f14e0 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -3480,6 +3480,11 @@ static inline void update_halt_poll_stats(struct kvm_vcpu *vcpu, ktime_t start,
 	}
 }
 
+static unsigned int kvm_vcpu_max_halt_poll_ns(struct kvm_vcpu *vcpu)
+{
+	return READ_ONCE(vcpu->kvm->max_halt_poll_ns);
+}
+
 /*
  * Emulate a vCPU halt condition, e.g. HLT on x86, WFI on arm, etc...  If halt
  * polling is enabled, busy wait for a short time before blocking to avoid the
@@ -3488,14 +3493,15 @@ static inline void update_halt_poll_stats(struct kvm_vcpu *vcpu, ktime_t start,
  */
 void kvm_vcpu_halt(struct kvm_vcpu *vcpu)
 {
+	unsigned int max_halt_poll_ns = kvm_vcpu_max_halt_poll_ns(vcpu);
 	bool halt_poll_allowed = !kvm_arch_no_poll(vcpu);
 	ktime_t start, cur, poll_end;
 	bool waited = false;
 	bool do_halt_poll;
 	u64 halt_ns;
 
-	if (vcpu->halt_poll_ns > vcpu->kvm->max_halt_poll_ns)
-		vcpu->halt_poll_ns = vcpu->kvm->max_halt_poll_ns;
+	if (vcpu->halt_poll_ns > max_halt_poll_ns)
+		vcpu->halt_poll_ns = max_halt_poll_ns;
 
 	do_halt_poll = halt_poll_allowed && vcpu->halt_poll_ns;
 
@@ -3537,18 +3543,21 @@ void kvm_vcpu_halt(struct kvm_vcpu *vcpu)
 		update_halt_poll_stats(vcpu, start, poll_end, !waited);
 
 	if (halt_poll_allowed) {
+		/* Recompute the max halt poll time in case it changed. */
+		max_halt_poll_ns = kvm_vcpu_max_halt_poll_ns(vcpu);
+
 		if (!vcpu_valid_wakeup(vcpu)) {
 			shrink_halt_poll_ns(vcpu);
-		} else if (vcpu->kvm->max_halt_poll_ns) {
+		} else if (max_halt_poll_ns) {
 			if (halt_ns <= vcpu->halt_poll_ns)
 				;
 			/* we had a long block, shrink polling */
 			else if (vcpu->halt_poll_ns &&
-				 halt_ns > vcpu->kvm->max_halt_poll_ns)
+				 halt_ns > max_halt_poll_ns)
 				shrink_halt_poll_ns(vcpu);
 			/* we had a short halt and our poll time is too small */
-			else if (vcpu->halt_poll_ns < vcpu->kvm->max_halt_poll_ns &&
-				 halt_ns < vcpu->kvm->max_halt_poll_ns)
+			else if (vcpu->halt_poll_ns < max_halt_poll_ns &&
+				 halt_ns < max_halt_poll_ns)
 				grow_halt_poll_ns(vcpu);
 		} else {
 			vcpu->halt_poll_ns = 0;