diff --git a/src/threading/ipc_channel.cpp b/src/threading/ipc_channel.cpp index 113a7f701..1ea61e7ff 100644 --- a/src/threading/ipc_channel.cpp +++ b/src/threading/ipc_channel.cpp @@ -202,6 +202,18 @@ static struct timespec *set_timespec(struct timespec *ts, int ms) } #endif // !defined(_WIN32) +template +static inline void write_once(volatile T *var, const T val) +{ + *var = val; +} + +template +static inline T read_once(const volatile T *var) +{ + return *var; +} + IPCChannelEnd IPCChannelEnd::makeA(std::unique_ptr stuff) { IPCChannelShared *shared = stuff->getShared(); @@ -228,7 +240,7 @@ IPCChannelEnd IPCChannelEnd::makeB(std::unique_ptr stuff) void IPCChannelEnd::sendSmall(const void *data, size_t size) noexcept { - m_out->size = size; + write_once(&m_out->size, size); memcpy(m_out->data, data, size); #if defined(_WIN32) post(m_sem_out); @@ -245,7 +257,7 @@ bool IPCChannelEnd::sendLarge(const void *data, size_t size, int timeout_ms) noe struct timespec timeout; struct timespec *timeoutp = set_timespec(&timeout, timeout_ms); #endif - m_out->size = size; + write_once(&m_out->size, size); do { memcpy(m_out->data, data, IPC_CHANNEL_MSG_SIZE); #if defined(_WIN32) @@ -285,10 +297,11 @@ bool IPCChannelEnd::recv(int timeout_ms) noexcept if (!wait(m_in, timeoutp)) #endif return false; - size_t size = m_in->size; + size_t size = read_once(&m_in->size); + m_recv_size = size; if (size <= IPC_CHANNEL_MSG_SIZE) { - m_recv_size = size; - m_recv_data = m_in->data; + // m_large_recv.size() is always >= IPC_CHANNEL_MSG_SIZE + memcpy(m_large_recv.data(), m_in->data, size); } else { try { m_large_recv.resize(size); @@ -299,8 +312,6 @@ bool IPCChannelEnd::recv(int timeout_ms) noexcept FATAL_ERROR(errmsg.c_str()); } u8 *recv_data = m_large_recv.data(); - m_recv_size = size; - m_recv_data = recv_data; do { memcpy(recv_data, m_in->data, IPC_CHANNEL_MSG_SIZE); size -= IPC_CHANNEL_MSG_SIZE; diff --git a/src/threading/ipc_channel.h b/src/threading/ipc_channel.h index b085d286a..8247eddb5 100644 --- a/src/threading/ipc_channel.h +++ b/src/threading/ipc_channel.h @@ -48,7 +48,7 @@ with this program; if not, write to the Free Software Foundation, Inc., * other posix: uses posix mutex and condition variable */ -#define IPC_CHANNEL_MSG_SIZE 8192U +#define IPC_CHANNEL_MSG_SIZE 0x2000U struct IPCChannelBuffer { @@ -67,8 +67,11 @@ struct IPCChannelBuffer bool posted = false; // protected by mutex #endif #endif // !defined(_WIN32) - size_t size; - u8 data[IPC_CHANNEL_MSG_SIZE]; + // Note: If the other side isn't acting cooperatively, they might write to + // this at any times. So we must make sure to copy out the data once, and + // only access that copy. + size_t size = 0; + u8 data[IPC_CHANNEL_MSG_SIZE] = {}; IPCChannelBuffer(); ~IPCChannelBuffer(); @@ -121,7 +124,7 @@ public: } // Get the content of the last received message - inline const void *getRecvData() const noexcept { return m_recv_data; } + inline const void *getRecvData() const noexcept { return m_large_recv.data(); } inline size_t getRecvSize() const noexcept { return m_recv_size; } private: @@ -155,7 +158,8 @@ private: HANDLE m_sem_in; HANDLE m_sem_out; #endif - const void *m_recv_data = nullptr; size_t m_recv_size = 0; - std::vector m_large_recv; + // we always copy from the shared buffer into this + // (this buffer only grows) + std::vector m_large_recv = std::vector(IPC_CHANNEL_MSG_SIZE); }; diff --git a/src/unittest/test_threading.cpp b/src/unittest/test_threading.cpp index ed56abe61..28bde44ee 100644 --- a/src/unittest/test_threading.cpp +++ b/src/unittest/test_threading.cpp @@ -285,8 +285,8 @@ void TestThreading::testIPCChannel() IPCChannelEnd end_b = IPCChannelEnd::makeB(std::make_unique(stuff)); for (;;) { - end_b.recv(); - end_b.send(end_b.getRecvData(), end_b.getRecvSize()); + UASSERT(end_b.recv()); + UASSERT(end_b.send(end_b.getRecvData(), end_b.getRecvSize())); if (end_b.getRecvSize() == 0) break; } @@ -295,12 +295,12 @@ void TestThreading::testIPCChannel() char buf[20000] = {}; for (int i = sizeof(buf); i > 0; i -= 1000) { buf[i - 1] = 123; - end_a.exchange(buf, i); + UASSERT(end_a.exchange(buf, i)); UASSERTEQ(int, end_a.getRecvSize(), i); UASSERTEQ(int, ((const char *)end_a.getRecvData())[i - 1], 123); } - end_a.exchange(buf, 0); + UASSERT(end_a.exchange(buf, 0)); UASSERTEQ(int, end_a.getRecvSize(), 0); thread_b.join();