From: Olivier Dion Date: Wed, 9 Aug 2023 21:35:40 +0000 (-0400) Subject: ustfork: Fix possible race conditions X-Git-Url: https://git.lttng.org./?a=commitdiff_plain;h=cf54f06a38ed1f3a2ac49236ffb1f1e7b3024d87;p=lttng-ust.git ustfork: Fix possible race conditions Assuming that `dlsym(RTLD_NEXT, "symbol")' is invariant for "symbol", then we could think that memory operations on the `plibc_func' pointers can be safely done without atomics. However, consider what would happen if a load to a`plibc_func' pointer is torn apart by the compiler. Then a thread could see: 1) NULL 2) The stored value as returned by a dlsym() call 3) A mix of 1) and 2) The same goes for other optimizations that a compiler is authorized to do (e.g. store tearing, load fusing). One could question whether such race condition is even possible for the clone(2) wrapper. Indeed, a thread must be cloned to get into existence. Therefore, the main thread would always store the value of `plibc_func' at least once before creating the first sibling thread, preventing any possible race condition for this wrapper. However, this assume that the main thread will not call the clone system call directly before calling the libc wrapper! Thus, to be on the safe side, we do the same for the clone wrapper. Fix the race conditions by using the uatomic_read/uatomic_set functions, on access to `plibc_func' pointers. Change-Id: Ic4be25983b8836d2b333f367af9c18d2f6b75879 Signed-off-by: Olivier Dion Signed-off-by: Mathieu Desnoyers --- diff --git a/src/lib/lttng-ust-fork/ustfork.c b/src/lib/lttng-ust-fork/ustfork.c index 321ffc30..9508cc75 100644 --- a/src/lib/lttng-ust-fork/ustfork.c +++ b/src/lib/lttng-ust-fork/ustfork.c @@ -18,25 +18,30 @@ #include +#include + pid_t fork(void) { static pid_t (*plibc_func)(void) = NULL; + pid_t (*func)(void); sigset_t sigset; pid_t retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "fork"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "fork"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"fork\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } lttng_ust_before_fork(&sigset); /* Do the real fork */ - retval = plibc_func(); + retval = func(); saved_errno = errno; if (retval == 0) { /* child */ @@ -51,22 +56,25 @@ pid_t fork(void) int daemon(int nochdir, int noclose) { static int (*plibc_func)(int nochdir, int noclose) = NULL; + int (*func)(int nochdir, int noclose); sigset_t sigset; int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "daemon"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "daemon"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"daemon\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } lttng_ust_before_fork(&sigset); /* Do the real daemon call */ - retval = plibc_func(nochdir, noclose); + retval = func(nochdir, noclose); saved_errno = errno; if (retval == 0) { /* child, parent called _exit() directly */ @@ -82,20 +90,23 @@ int daemon(int nochdir, int noclose) int setuid(uid_t uid) { static int (*plibc_func)(uid_t uid) = NULL; + int (*func)(uid_t uid); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "setuid"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "setuid"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"setuid\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setuid */ - retval = plibc_func(uid); + retval = func(uid); saved_errno = errno; lttng_ust_after_setuid(); @@ -107,20 +118,23 @@ int setuid(uid_t uid) int setgid(gid_t gid) { static int (*plibc_func)(gid_t gid) = NULL; + int (*func)(gid_t gid); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "setgid"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "setgid"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"setgid\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setgid */ - retval = plibc_func(gid); + retval = func(gid); saved_errno = errno; lttng_ust_after_setgid(); @@ -132,20 +146,23 @@ int setgid(gid_t gid) int seteuid(uid_t euid) { static int (*plibc_func)(uid_t euid) = NULL; + int (*func)(uid_t euid); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "seteuid"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "seteuid"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"seteuid\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real seteuid */ - retval = plibc_func(euid); + retval = func(euid); saved_errno = errno; lttng_ust_after_seteuid(); @@ -157,20 +174,23 @@ int seteuid(uid_t euid) int setegid(gid_t egid) { static int (*plibc_func)(gid_t egid) = NULL; + int (*func)(gid_t egid); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "setegid"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "setegid"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"setegid\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setegid */ - retval = plibc_func(egid); + retval = func(egid); saved_errno = errno; lttng_ust_after_setegid(); @@ -182,20 +202,23 @@ int setegid(gid_t egid) int setreuid(uid_t ruid, uid_t euid) { static int (*plibc_func)(uid_t ruid, uid_t euid) = NULL; + int (*func)(uid_t ruid, uid_t euid); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "setreuid"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "setreuid"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"setreuid\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setreuid */ - retval = plibc_func(ruid, euid); + retval = func(ruid, euid); saved_errno = errno; lttng_ust_after_setreuid(); @@ -207,20 +230,23 @@ int setreuid(uid_t ruid, uid_t euid) int setregid(gid_t rgid, gid_t egid) { static int (*plibc_func)(gid_t rgid, gid_t egid) = NULL; + int (*func)(gid_t rgid, gid_t egid); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "setregid"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "setregid"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"setregid\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setregid */ - retval = plibc_func(rgid, egid); + retval = func(rgid, egid); saved_errno = errno; lttng_ust_after_setregid(); @@ -253,6 +279,9 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...) static int (*plibc_func)(int (*fn)(void *), void *child_stack, int flags, void *arg, pid_t *ptid, struct user_desc *tls, pid_t *ctid) = NULL; + int (*func)(int (*fn)(void *), void *child_stack, + int flags, void *arg, pid_t *ptid, + struct user_desc *tls, pid_t *ctid); /* var args */ pid_t *ptid; struct user_desc *tls; @@ -268,13 +297,15 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...) ctid = va_arg(ap, pid_t *); va_end(ap); - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "clone"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "clone"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"clone\" symbol.\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } if (flags & CLONE_VM) { @@ -282,16 +313,16 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...) * Creating a thread, no need to intervene, just pass on * the arguments. */ - retval = plibc_func(fn, child_stack, flags, arg, ptid, - tls, ctid); + retval = func(fn, child_stack, flags, arg, ptid, + tls, ctid); saved_errno = errno; } else { /* Creating a real process, we need to intervene. */ struct ustfork_clone_info info = { .fn = fn, .arg = arg }; lttng_ust_before_fork(&info.sigset); - retval = plibc_func(clone_fn, child_stack, flags, &info, - ptid, tls, ctid); + retval = func(clone_fn, child_stack, flags, &info, + ptid, tls, ctid); saved_errno = errno; /* The child doesn't get here. */ lttng_ust_after_fork_parent(&info.sigset); @@ -303,20 +334,23 @@ int clone(int (*fn)(void *), void *child_stack, int flags, void *arg, ...) int setns(int fd, int nstype) { static int (*plibc_func)(int fd, int nstype) = NULL; + int (*func)(int fd, int nstype); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "setns"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "setns"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"setns\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setns */ - retval = plibc_func(fd, nstype); + retval = func(fd, nstype); saved_errno = errno; lttng_ust_after_setns(); @@ -328,20 +362,23 @@ int setns(int fd, int nstype) int unshare(int flags) { static int (*plibc_func)(int flags) = NULL; + int (*func)(int flags); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "unshare"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "unshare"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"unshare\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setns */ - retval = plibc_func(flags); + retval = func(flags); saved_errno = errno; lttng_ust_after_unshare(); @@ -353,20 +390,23 @@ int unshare(int flags) int setresuid(uid_t ruid, uid_t euid, uid_t suid) { static int (*plibc_func)(uid_t ruid, uid_t euid, uid_t suid) = NULL; + int (*func)(uid_t ruid, uid_t euid, uid_t suid); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "setresuid"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "setresuid"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"setresuid\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setresuid */ - retval = plibc_func(ruid, euid, suid); + retval = func(ruid, euid, suid); saved_errno = errno; lttng_ust_after_setresuid(); @@ -378,20 +418,23 @@ int setresuid(uid_t ruid, uid_t euid, uid_t suid) int setresgid(gid_t rgid, gid_t egid, gid_t sgid) { static int (*plibc_func)(gid_t rgid, gid_t egid, gid_t sgid) = NULL; + int (*func)(gid_t rgid, gid_t egid, gid_t sgid); int retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "setresgid"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "setresgid"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"setresgid\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } /* Do the real setresgid */ - retval = plibc_func(rgid, egid, sgid); + retval = func(rgid, egid, sgid); saved_errno = errno; lttng_ust_after_setresgid(); @@ -405,22 +448,25 @@ int setresgid(gid_t rgid, gid_t egid, gid_t sgid) pid_t rfork(int flags) { static pid_t (*plibc_func)(int flags) = NULL; + pid_t (*func)(int flags); sigset_t sigset; pid_t retval; int saved_errno; - if (plibc_func == NULL) { - plibc_func = dlsym(RTLD_NEXT, "rfork"); - if (plibc_func == NULL) { + func = uatomic_read(plibc_func); + if (func == NULL) { + func = dlsym(RTLD_NEXT, "rfork"); + if (func == NULL) { fprintf(stderr, "libustfork: unable to find \"rfork\" symbol\n"); errno = ENOSYS; return -1; } + uatomic_set(&plibc_func, func); } lttng_ust_before_fork(&sigset); /* Do the real rfork */ - retval = plibc_func(flags); + retval = func(flags); saved_errno = errno; if (retval == 0) { /* child */