diff --git a/components/smp/smp.c b/components/smp/smp.c index 1c854946c4..134d17e3d0 100644 --- a/components/smp/smp.c +++ b/components/smp/smp.c @@ -1,3 +1,13 @@ +/* + * Copyright (c) 2006-2024 RT-Thread Development Team + * + * SPDX-License-Identifier: Apache-2.0 + * + * Change Logs: + * Date Author Notes + * 2024/9/12 zhujiale the first version + */ + #include "smp.h" #define DBG_TAG "SMP" @@ -5,18 +15,18 @@ #include struct smp_call global_work[RT_CPUS_NR]; - +rt_atomic_t wait; rt_err_t smp_call_handler(struct smp_event *event) { switch(event->event_id) { case SMP_CALL_EVENT_FUNC: event->func(event->data); + rt_atomic_add(&wait,1); break; default: - rt_kprintf("error event id\n"); - return -RT_ERROR; - break; + LOG_E("error event id\n"); + return RT_ERROR; } return RT_EOK; } @@ -37,20 +47,36 @@ void rt_smp_call_ipi_handler(int vector, void *param) rt_memset(&global_work[cur_cpu].event,0,sizeof(struct smp_event)); } rt_spin_unlock(&global_work[cur_cpu].lock); - + } -void rt_smp_call_func_cond(int cpu_mask, smp_func func, void *data) +/** + * @brief call function on specified CPU , + * + * @param cpu_mask cpu mask for call + * @param func the function pointer + * @param data the data pointer + * @param flag call flag if you set SMP_CALL_WAIT_ALL + * then it will wait all cpu call finish and return + * else it will call function on specified CPU and return immediately + * @param cond the condition function pointer,if you set it then it will call function only when cond return true + */ +void rt_smp_call_func_cond(int cpu_mask, smp_call_func_back func, void *data,rt_uint8_t flag,smp_cond cond) { RT_DEBUG_NOT_IN_INTERRUPT; - struct smp_call work; struct smp_event event; - rt_bool_t need_call = RT_TRUE; + rt_bool_t need_call = RT_TRUE,need_wait = RT_FALSE; int cur_cpu = rt_hw_cpu_id(); int cpuid = 1 << cur_cpu; - int tmp_id = 0; + int tmp_id = 0,cpu_nr = 0; int tmp_mask; + if(flag == SMP_CALL_WAIT_ALL) + { + need_wait = RT_TRUE; + rt_atomic_store(&wait,0); + } + if(cpuid & cpu_mask) { func(data); @@ -58,7 +84,7 @@ void rt_smp_call_func_cond(int cpu_mask, smp_func func, void *data) } if(!cpu_mask) - need_call = RT_FALSE; + need_call = RT_FALSE; tmp_mask = cpu_mask; if(need_call) @@ -67,6 +93,9 @@ void rt_smp_call_func_cond(int cpu_mask, smp_func func, void *data) { if((tmp_mask & 1) && (tmp_id < RT_CPUS_NR)) { + if(cond && !cond(tmp_id,data)) + continue; + cpu_nr++; event.event_id = SMP_CALL_EVENT_FUNC; event.func = func; event.data = data; @@ -80,15 +109,38 @@ void rt_smp_call_func_cond(int cpu_mask, smp_func func, void *data) } rt_hw_ipi_send(RT_IPI_FUNC, cpu_mask); } -} -struct rt_spinlock lock_1; -void smp_init(void) -{ - rt_spin_lock_init(&lock_1); - for(int i = 0; i < RT_CPUS_NR; i++) - { - rt_memset(&global_work[i],0,sizeof(struct smp_call)); - rt_spin_lock_init(&global_work[i].lock); - } + + if(need_wait) + { + while(rt_atomic_load(&wait) != cpu_nr); + } +} + +void rt_call_each_cpu(smp_call_func_back func, void *data,rt_uint8_t flag) +{ + rt_smp_call_func_cond(RT_ALL_CPU,func,data,flag,RT_NULL); +} + +void rt_call_each_cpu_cond(smp_call_func_back func, void *data,rt_uint8_t flag,smp_cond cond_func) +{ + rt_smp_call_func_cond(RT_ALL_CPU,func,data,flag,cond_func); +} +void rt_call_any_cpu(int cpu_mask,smp_call_func_back func, void *data,rt_uint8_t flag) +{ + rt_smp_call_func_cond(cpu_mask,func,data,flag,RT_NULL); +} + +void rt_call_any_cpu_cond(int cpu_mask,smp_call_func_back func, void *data,rt_uint8_t flag,smp_cond cond_func) +{ + rt_smp_call_func_cond(cpu_mask,func,data,flag,cond_func); +} + +void smp_init(void) +{ + for(int i = 0; i < RT_CPUS_NR; i++) + { + rt_memset(&global_work[i],0,sizeof(struct smp_call)); + rt_spin_lock_init(&global_work[i].lock); + } } diff --git a/components/smp/smp.h b/components/smp/smp.h index ea0a2de124..821b95ff75 100644 --- a/components/smp/smp.h +++ b/components/smp/smp.h @@ -1,28 +1,34 @@ #ifndef __SMP_IPI_H__ #define __SMP_IPI_H__ #include -typedef void (*smp_func)(void *data); +typedef void (*smp_call_func_back)(void *data); +typedef rt_bool_t (*smp_cond)(int cpu, void *info); #define SMP_CALL_EVENT_FUNC 0x1 +#define SMP_CALL_WAIT_ALL (1 << 0) +#define SMP_CALL_NO_WAIT (1 << 1) + +#define RT_ALL_CPU ((1 << RT_CPUS_NR) - 1) struct smp_event { int cpu_mask; int event_id; void *data; - smp_func func; - + smp_call_func_back func; }; struct smp_call { struct rt_spinlock lock; struct smp_event event; - }; -void rt_smp_call_func_cond(int cpu_mask,smp_func func, void *data); void rt_smp_call_ipi_handler(int vector, void *param); +void rt_call_each_cpu(smp_call_func_back func, void *data,rt_uint8_t flag); +void rt_call_each_cpu_cond(smp_call_func_back func, void *data,rt_uint8_t flag,smp_cond cond_func); +void rt_call_any_cpu(int cpu_mask,smp_call_func_back func, void *data,rt_uint8_t flag); +void rt_call_any_cpu_cond(int cpu_mask,smp_call_func_back func, void *data,rt_uint8_t flag,smp_cond cond_func); void smp_init(void); #endif diff --git a/examples/utest/testcases/smp/smp.c b/examples/utest/testcases/smp/smp.c index f35e243912..1af161b72f 100644 --- a/examples/utest/testcases/smp/smp.c +++ b/examples/utest/testcases/smp/smp.c @@ -2,39 +2,58 @@ #include "utest.h" #include "utest_assert.h" #include "smp.h" -int pass_count = 0; -int pass = 1000; +int pass_count = 0; +int pass = 1000; struct rt_spinlock lock; void test_call(void *data) { rt_spin_lock(&lock); - int *i = (int *)data; - int id = rt_hw_cpu_id(); - *i &= ~(1 << id); - if(*i == 0) + int *i = (int *)data; + int id = rt_hw_cpu_id(); + *i &= ~(1 << id); + if (*i == 0) pass_count++; rt_spin_unlock(&lock); } -void test() +void test1() { int cpu_mask = 0xf; - for(int i =0 ;i < 1000 ;i++) + for (int i = 0; i < 1000; i++) { - cpu_mask = rand()% 0xf; - if (cpu_mask == 0) - pass--; - rt_smp_call_func_cond(cpu_mask,test_call, &cpu_mask); - if(i % 20 == 0) + cpu_mask = rand() % 0xf; + if (cpu_mask == 0) + pass--; + rt_call_each_cpu(test_call, &cpu_mask, SMP_CALL_NO_WAIT); + if (i % 20 == 0) rt_kprintf("#"); rt_thread_mdelay(1); - } + } rt_kprintf("\n"); + uassert_true(pass_count == pass); } - +void test_call2(void *data) +{ + rt_spin_lock(&lock); + int a = 100000; + while (a--); + int *i = (int *)data; + (*i)++; + rt_spin_unlock(&lock); +} +void test2(void) +{ + int data = 0; + rt_call_each_cpu(test_call2, &data, SMP_CALL_WAIT_ALL); + uassert_true(data == RT_CPUS_NR); + rt_thread_mdelay(10); + data = 0; + rt_call_each_cpu(test_call2, &data, SMP_CALL_NO_WAIT); + uassert_true(data != RT_CPUS_NR); +} static rt_err_t utest_tc_init(void) { @@ -44,12 +63,12 @@ static rt_err_t utest_tc_init(void) static rt_err_t utest_tc_cleanup(void) { - uassert_true(pass_count == pass); return RT_EOK; } static void testcase(void) { - UTEST_UNIT_RUN(test); + UTEST_UNIT_RUN(test1); + UTEST_UNIT_RUN(test2); } UTEST_TC_EXPORT(testcase, "testcase.smp.smp", utest_tc_init, utest_tc_cleanup, 10);