diff --git a/components/net/sal_socket/src/sal_socket.c b/components/net/sal_socket/src/sal_socket.c index 44442d306f..37bc2eb440 100644 --- a/components/net/sal_socket/src/sal_socket.c +++ b/components/net/sal_socket/src/sal_socket.c @@ -94,12 +94,12 @@ do { * SAL (Socket Abstraction Layer) initialize. * * @return result 0: initialize success - * -1: initialize failed + * -1: initialize failed */ int sal_init(void) { int cn; - + if (init_ok) { LOG_D("Socket Abstraction Layer is already initialized."); @@ -115,7 +115,7 @@ int sal_init(void) LOG_E("No memory for socket table.\n"); return -1; } - + /* create sal socket lock */ rt_mutex_init(&sal_core_lock, "sal_lock", RT_IPC_FLAG_FIFO); @@ -148,7 +148,7 @@ static void check_netdev_internet_up_work(struct rt_work *work, void *work_data) char send_data[SAL_INTERNET_BUFF_LEN], recv_data = 0; struct rt_delayed_work *delay_work = (struct rt_delayed_work *)work; - const char month[][SAL_INTERNET_MONTH_LEN] = {"Jan","Feb","Mar","Apr","May","Jun","Jul","Aug","Sep","Oct","Nov","Dec"}; + const char month[][SAL_INTERNET_MONTH_LEN] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}; char date[SAL_INTERNET_DATE_LEN]; int moth_num = 0; @@ -175,7 +175,7 @@ static void check_netdev_internet_up_work(struct rt_work *work, void *work_data) } skt_ops = pf->skt_ops; - if((sockfd = skt_ops->socket(AF_INET, SOCK_DGRAM, 0)) < 0) + if ((sockfd = skt_ops->socket(AF_INET, SOCK_DGRAM, 0)) < 0) { result = -RT_ERROR; goto __exit; @@ -224,7 +224,7 @@ static void check_netdev_internet_up_work(struct rt_work *work, void *work_data) send_data[11] = RT_REVISION; skt_ops->sendto(sockfd, send_data, SAL_INTERNET_BUFF_LEN, 0, - (struct sockaddr *)&server_addr, sizeof(struct sockaddr)); + (struct sockaddr *)&server_addr, sizeof(struct sockaddr)); result = skt_ops->recvfrom(sockfd, &recv_data, sizeof(recv_data), 0, (struct sockaddr *)&server_addr, &addr_len); if (result < 0) @@ -242,7 +242,7 @@ __exit: if (result > 0) { LOG_D("Set network interface device(%s) internet status up.", netdev->name); - netdev->flags |= NETDEV_FLAG_INTERNET_UP; + netdev->flags |= NETDEV_FLAG_INTERNET_UP; } else { @@ -277,7 +277,7 @@ int sal_check_netdev_internet_up(struct netdev *netdev) rt_delayed_work_init(net_work, check_netdev_internet_up_work, (void *)netdev); rt_work_submit(&(net_work->work), RT_TICK_PER_SECOND); - + return 0; } @@ -316,10 +316,7 @@ struct sal_socket *sal_get_socket(int socket) socket = socket - SAL_SOCKET_OFFSET; /* check socket structure valid or not */ - if (st->sockets[socket]->magic != SAL_SOCKET_MAGIC) - { - return RT_NULL; - } + RT_ASSERT(st->sockets[socket]->magic == SAL_SOCKET_MAGIC); return st->sockets[socket]; } @@ -376,7 +373,8 @@ int sal_netdev_cleanup(struct netdev *netdev) { rt_thread_mdelay(rt_tick_from_millisecond(100)); } - } while (find_dev); + } + while (find_dev); return 0; } @@ -428,7 +426,7 @@ static int socket_init(int family, int type, int protocol, struct sal_socket **r flag = RT_TRUE; } } - + if (flag == RT_FALSE) { /* get network interface device by protocol family */ @@ -452,8 +450,7 @@ static int socket_alloc(struct sal_socket_table *st, int f_socket) /* find an empty socket entry */ for (idx = f_socket; idx < (int) st->max_socket; idx++) { - if (st->sockets[idx] == RT_NULL || - st->sockets[idx]->netdev == RT_NULL) + if (st->sockets[idx] == RT_NULL) { break; } @@ -497,6 +494,15 @@ __result: return idx; } +static void socket_free(struct sal_socket_table *st, int idx) +{ + struct sal_socket *sock; + + sock = st->sockets[idx]; + st->sockets[idx] = RT_NULL; + rt_free(sock); +} + static int socket_new(void) { struct sal_socket *sock; @@ -529,18 +535,38 @@ __result: return idx + SAL_SOCKET_OFFSET; } +static void socket_delete(int socket) +{ + struct sal_socket *sock; + struct sal_socket_table *st = &socket_table; + int idx; + + idx = socket - SAL_SOCKET_OFFSET; + if (idx < 0 || idx >= (int) st->max_socket) + { + return; + } + sal_lock(); + sock = sal_get_socket(socket); + RT_ASSERT(sock != RT_NULL); + sock->magic = 0; + sock->netdev = RT_NULL; + socket_free(st, idx); + sal_unlock(); +} + int sal_accept(int socket, struct sockaddr *addr, socklen_t *addrlen) { int new_socket; struct sal_socket *sock; - struct sal_proto_family *pf; + struct sal_proto_family *pf; /* get the socket object by socket descriptor */ SAL_SOCKET_OBJ_GET(sock, socket); /* check the network interface socket operations */ SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, accept); - + new_socket = pf->skt_ops->accept((int) sock->user_data, addr, addrlen); if (new_socket != -1) { @@ -550,18 +576,20 @@ int sal_accept(int socket, struct sockaddr *addr, socklen_t *addrlen) /* allocate a new socket structure and registered socket options */ new_sal_socket = socket_new(); - if (new_sal_socket < 0) + new_sock = sal_get_socket(new_sal_socket); + if (new_sock == RT_NULL) { pf->skt_ops->closesocket(new_socket); return -1; } - new_sock = sal_get_socket(new_sal_socket); retval = socket_init(sock->domain, sock->type, sock->protocol, &new_sock); if (retval < 0) { pf->skt_ops->closesocket(new_socket); rt_memset(new_sock, 0x00, sizeof(struct sal_socket)); + /* socket init failed, delete socket */ + socket_delete(new_sal_socket); LOG_E("New socket registered failed, return error %d.", retval); return -1; } @@ -587,7 +615,7 @@ static void sal_sockaddr_to_ipaddr(const struct sockaddr *name, ip_addr_t *local #elif NETDEV_IPV6 #error "not only support IPV6" #endif /* NETDEV_IPV4 && NETDEV_IPV6*/ -} +} int sal_bind(int socket, const struct sockaddr *name, socklen_t namelen) { @@ -636,7 +664,7 @@ int sal_bind(int socket, const struct sockaddr *name, socklen_t namelen) sock->user_data = (void *) new_socket; } } - + /* check and get protocol families by the network interface device */ SAL_NETDEV_SOCKETOPS_VALID(sock->netdev, pf, bind); return pf->skt_ops->bind((int) sock->user_data, name, namelen); @@ -673,9 +701,8 @@ int sal_shutdown(int socket, int how) error = -1; } - /* free socket */ - rt_free(sock); - socket_table.sockets[socket] = RT_NULL; + /* delete socket */ + socket_delete(socket); return error; } @@ -791,7 +818,7 @@ int sal_connect(int socket, const struct sockaddr *name, socklen_t namelen) { return -1; } - + return ret; } #endif @@ -814,7 +841,7 @@ int sal_listen(int socket, int backlog) } int sal_recvfrom(int socket, void *mem, size_t len, int flags, - struct sockaddr *from, socklen_t *fromlen) + struct sockaddr *from, socklen_t *fromlen) { struct sal_socket *sock; struct sal_proto_family *pf; @@ -831,11 +858,11 @@ int sal_recvfrom(int socket, void *mem, size_t len, int flags, if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, recv)) { int ret; - + if ((ret = proto_tls->ops->recv(sock->user_data_tls, mem, len)) < 0) { return -1; - } + } return ret; } else @@ -848,7 +875,7 @@ int sal_recvfrom(int socket, void *mem, size_t len, int flags, } int sal_sendto(int socket, const void *dataptr, size_t size, int flags, - const struct sockaddr *to, socklen_t tolen) + const struct sockaddr *to, socklen_t tolen) { struct sal_socket *sock; struct sal_proto_family *pf; @@ -865,11 +892,11 @@ int sal_sendto(int socket, const void *dataptr, size_t size, int flags, if (SAL_SOCKOPS_PROTO_TLS_VALID(sock, send)) { int ret; - + if ((ret = proto_tls->ops->send(sock->user_data_tls, dataptr, size)) < 0) { return -1; - } + } return ret; } else @@ -907,6 +934,7 @@ int sal_socket(int domain, int type, int protocol) if (retval < 0) { LOG_E("SAL socket protocol family input failed, return error %d.", retval); + socket_delete(socket); return -1; } @@ -922,6 +950,7 @@ int sal_socket(int domain, int type, int protocol) sock->user_data_tls = proto_tls->ops->socket(socket); if (sock->user_data_tls == RT_NULL) { + socket_delete(socket); return -1; } } @@ -964,9 +993,8 @@ int sal_closesocket(int socket) error = -1; } - /* free socket */ - rt_free(sock); - socket_table.sockets[socket] = RT_NULL; + /* delete socket */ + socket_delete(socket); return error; } @@ -1027,7 +1055,7 @@ struct hostent *sal_gethostbyname(const char *name) } int sal_gethostbyname_r(const char *name, struct hostent *ret, char *buf, - size_t buflen, struct hostent **result, int *h_errnop) + size_t buflen, struct hostent **result, int *h_errnop) { struct netdev *netdev = netdev_default; struct sal_proto_family *pf; @@ -1050,9 +1078,9 @@ int sal_gethostbyname_r(const char *name, struct hostent *ret, char *buf, } int sal_getaddrinfo(const char *nodename, - const char *servname, - const struct addrinfo *hints, - struct addrinfo **res) + const char *servname, + const struct addrinfo *hints, + struct addrinfo **res) { struct netdev *netdev = netdev_default; struct sal_proto_family *pf;