+free_sockaddr:
+ free(sockaddr);
+
+del_sock:
+ ustcomm_del_sock(sock, keep_socket_file);
+}
+
+int ustcomm_recv_alloc(int sock,
+ struct ustcomm_header *header,
+ char **data) {
+ int result;
+ struct ustcomm_header peek_header;
+ struct iovec iov[2];
+ struct msghdr msg;
+
+ /* Just to make the caller fail hard */
+ *data = NULL;
+
+ result = recv(sock, &peek_header, sizeof(peek_header),
+ MSG_PEEK | MSG_WAITALL);
+ if (result <= 0) {
+ if(errno == ECONNRESET) {
+ return 0;
+ } else if (errno == EINTR) {
+ return -1;
+ } else if (result < 0) {
+ PERROR("recv");
+ return -1;
+ }
+ return 0;
+ }
+
+ memset(&msg, 0, sizeof(msg));
+
+ iov[0].iov_base = (char *)header;
+ iov[0].iov_len = sizeof(struct ustcomm_header);
+
+ msg.msg_iov = iov;
+ msg.msg_iovlen = 1;
+
+ if (peek_header.size) {
+ *data = zmalloc(peek_header.size);
+ if (!*data) {
+ return -ENOMEM;
+ }
+
+ iov[1].iov_base = *data;
+ iov[1].iov_len = peek_header.size;
+
+ msg.msg_iovlen++;
+ }
+
+ result = recvmsg(sock, &msg, MSG_WAITALL);
+ if (result < 0) {
+ free(*data);
+ PERROR("recvmsg failed");
+ }
+
+ return result;
+}
+
+/* returns 1 to indicate a message was received
+ * returns 0 to indicate no message was received (end of stream)
+ * returns -1 to indicate an error
+ */
+int ustcomm_recv_fd(int sock,
+ struct ustcomm_header *header,
+ char *data, int *fd)
+{
+ int result;
+ struct ustcomm_header peek_header;
+ struct iovec iov[2];
+ struct msghdr msg;
+ struct cmsghdr *cmsg;
+ char buf[CMSG_SPACE(sizeof(int))];
+
+ result = recv(sock, &peek_header, sizeof(peek_header),
+ MSG_PEEK | MSG_WAITALL);
+ if (result <= 0) {
+ if(errno == ECONNRESET) {
+ return 0;
+ } else if (errno == EINTR) {
+ return -1;
+ } else if (result < 0) {
+ PERROR("recv");
+ return -1;
+ }
+ return 0;
+ }
+
+ memset(&msg, 0, sizeof(msg));
+
+ iov[0].iov_base = (char *)header;
+ iov[0].iov_len = sizeof(struct ustcomm_header);
+
+ msg.msg_iov = iov;
+ msg.msg_iovlen = 1;
+
+ if (peek_header.size && data) {
+ if (peek_header.size < 0 ||
+ peek_header.size > USTCOMM_DATA_SIZE) {
+ ERR("big peek header! %d", peek_header.size);
+ return 0;
+ }
+
+ iov[1].iov_base = data;
+ iov[1].iov_len = peek_header.size;
+
+ msg.msg_iovlen++;
+ }
+
+ if (fd && peek_header.fd_included) {
+ msg.msg_control = buf;
+ msg.msg_controllen = sizeof(buf);
+ }
+
+ result = recvmsg(sock, &msg, MSG_WAITALL);
+ if (result <= 0) {
+ if (result < 0) {
+ PERROR("recvmsg failed");
+ }
+ return result;
+ }
+
+ if (fd && peek_header.fd_included) {
+ cmsg = CMSG_FIRSTHDR(&msg);
+ result = 0;
+ while (cmsg != NULL) {
+ if (cmsg->cmsg_level == SOL_SOCKET
+ && cmsg->cmsg_type == SCM_RIGHTS) {
+ *fd = *(int *) CMSG_DATA(cmsg);
+ result = 1;
+ break;
+ }
+ cmsg = CMSG_NXTHDR(&msg, cmsg);
+ }
+ if (!result) {
+ ERR("Failed to receive file descriptor\n");
+ }
+ }
+
+ return 1;
+}
+
+int ustcomm_recv(int sock,
+ struct ustcomm_header *header,
+ char *data)
+{
+ return ustcomm_recv_fd(sock, header, data, NULL);
+}
+
+
+int ustcomm_send_fd(int sock,
+ const struct ustcomm_header *header,
+ const char *data,
+ int *fd)
+{
+ struct iovec iov[2];
+ struct msghdr msg;
+ int result;
+ struct cmsghdr *cmsg;
+ char buf[CMSG_SPACE(sizeof(int))];
+
+ memset(&msg, 0, sizeof(msg));
+
+ iov[0].iov_base = (char *)header;
+ iov[0].iov_len = sizeof(struct ustcomm_header);
+
+ msg.msg_iov = iov;
+ msg.msg_iovlen = 1;
+
+ if (header->size && data) {
+ iov[1].iov_base = (char *)data;
+ iov[1].iov_len = header->size;
+
+ msg.msg_iovlen++;
+
+ }
+
+ if (fd && header->fd_included) {
+ msg.msg_control = buf;
+ msg.msg_controllen = sizeof(buf);
+ cmsg = CMSG_FIRSTHDR(&msg);
+ cmsg->cmsg_level = SOL_SOCKET;
+ cmsg->cmsg_type = SCM_RIGHTS;
+ cmsg->cmsg_len = CMSG_LEN(sizeof(int));
+ *(int *) CMSG_DATA(cmsg) = *fd;
+ msg.msg_controllen = cmsg->cmsg_len;
+ }
+
+ result = sendmsg(sock, &msg, MSG_NOSIGNAL);
+ if (result < 0 && errno != EPIPE) {
+ PERROR("sendmsg failed");
+ }
+ return result;
+}
+
+int ustcomm_send(int sock,
+ const struct ustcomm_header *header,
+ const char *data)
+{
+ return ustcomm_send_fd(sock, header, data, NULL);
+}
+
+int ustcomm_req(int sock,
+ const struct ustcomm_header *req_header,
+ const char *req_data,
+ struct ustcomm_header *res_header,
+ char *res_data)
+{
+ int result;