diff --git a/arch/arm64/include/asm/fpsimd.h b/arch/arm64/include/asm/fpsimd.h index 67f2fb781f59e8ad3fdc5dc6fe9e692a1fbd27d2..8df46f186c64b8e2665a99a5982ee1f5f65c7a4d 100644 --- a/arch/arm64/include/asm/fpsimd.h +++ b/arch/arm64/include/asm/fpsimd.h @@ -356,7 +356,7 @@ static inline int sme_max_virtualisable_vl(void) return vec_max_virtualisable_vl(ARM64_VEC_SME); } -extern void sme_alloc(struct task_struct *task); +extern void sme_alloc(struct task_struct *task, bool flush); extern unsigned int sme_get_vl(void); extern int sme_set_current_vl(unsigned long arg); extern int sme_get_current_vl(void); @@ -388,7 +388,7 @@ static inline void sme_smstart_sm(void) { } static inline void sme_smstop_sm(void) { } static inline void sme_smstop(void) { } -static inline void sme_alloc(struct task_struct *task) { } +static inline void sme_alloc(struct task_struct *task, bool flush) { } static inline void sme_setup(void) { } static inline unsigned int sme_get_vl(void) { return 0; } static inline int sme_max_vl(void) { return 0; } diff --git a/arch/arm64/kernel/fpsimd.c b/arch/arm64/kernel/fpsimd.c index 75c37b1c55aaf3911419c2cf6540b296bad80f49..087c05aa960ea310e3be7a4d5f3db06ecf11eeb9 100644 --- a/arch/arm64/kernel/fpsimd.c +++ b/arch/arm64/kernel/fpsimd.c @@ -1285,9 +1285,9 @@ void fpsimd_release_task(struct task_struct *dead_task) * the interest of testability and predictability, the architecture * guarantees that when ZA is enabled it will be zeroed. */ -void sme_alloc(struct task_struct *task) +void sme_alloc(struct task_struct *task, bool flush) { - if (task->thread.sme_state) { + if (task->thread.sme_state && flush) { memset(task->thread.sme_state, 0, sme_state_size(task)); return; } @@ -1515,7 +1515,7 @@ void do_sme_acc(unsigned long esr, struct pt_regs *regs) } sve_alloc(current, false); - sme_alloc(current); + sme_alloc(current, true); if (!current->thread.sve_state || !current->thread.sme_state) { force_sig(SIGKILL); return; diff --git a/arch/arm64/kernel/ptrace.c b/arch/arm64/kernel/ptrace.c index 5b9b4305248b88774cb86f36a5380499e5357f80..187aa2b175b4f68abd9a81ca84e653a975fa7743 100644 --- a/arch/arm64/kernel/ptrace.c +++ b/arch/arm64/kernel/ptrace.c @@ -881,6 +881,13 @@ static int sve_set_common(struct task_struct *target, break; case ARM64_VEC_SME: target->thread.svcr |= SVCR_SM_MASK; + + /* + * Disable traps and ensure there is SME storage but + * preserve any currently set values in ZA/ZT. + */ + sme_alloc(target, false); + set_tsk_thread_flag(target, TIF_SME); break; default: WARN_ON_ONCE(1); @@ -1100,7 +1107,7 @@ static int za_set(struct task_struct *target, } /* Allocate/reinit ZA storage */ - sme_alloc(target); + sme_alloc(target, true); if (!target->thread.sme_state) { ret = -ENOMEM; goto out; @@ -1170,8 +1177,13 @@ static int zt_set(struct task_struct *target, if (!system_supports_sme2()) return -EINVAL; + /* Ensure SVE storage in case this is first use of SME */ + sve_alloc(target, false); + if (!target->thread.sve_state) + return -ENOMEM; + if (!thread_za_enabled(&target->thread)) { - sme_alloc(target); + sme_alloc(target, true); if (!target->thread.sme_state) return -ENOMEM; } @@ -1179,8 +1191,10 @@ static int zt_set(struct task_struct *target, ret = user_regset_copyin(&pos, &count, &kbuf, &ubuf, thread_zt_state(&target->thread), 0, ZT_SIG_REG_BYTES); - if (ret == 0) + if (ret == 0) { target->thread.svcr |= SVCR_ZA_MASK; + set_tsk_thread_flag(target, TIF_SME); + } fpsimd_flush_task_state(target); diff --git a/arch/arm64/kernel/signal.c b/arch/arm64/kernel/signal.c index e304f7ebec2a91ff88c83afa20d8722a61f99cfa..c7ebe744c64e39fb93f6b26433beb460f6aed079 100644 --- a/arch/arm64/kernel/signal.c +++ b/arch/arm64/kernel/signal.c @@ -475,7 +475,7 @@ static int restore_za_context(struct user_ctxs *user) fpsimd_flush_task_state(current); /* From now, fpsimd_thread_switch() won't touch thread.sve_state */ - sme_alloc(current); + sme_alloc(current, true); if (!current->thread.sme_state) { current->thread.svcr &= ~SVCR_ZA_MASK; clear_thread_flag(TIF_SME);