Skip to content

Commit

Permalink
Switch from polling on FIFOs to signal (#12)
Browse files Browse the repository at this point in the history
* uids are pids

* remove fifo, go for signal

* that's more what i'm saying

* oops, that's wrong

* add some todos

* make signal poll work

* bring randomness back

* fix case in msgq when subs get evicted while polling

* check for ready messages before poll starts

* No pr builds

* use nanosleep with remainder

* this should pass the test
  • Loading branch information
geohot authored and pd0wm committed Nov 22, 2019
1 parent e25bba7 commit 347a866
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 114 deletions.
2 changes: 2 additions & 0 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pr: none

pool:
vmImage: 'ubuntu-16.04'

Expand Down
155 changes: 50 additions & 105 deletions messaging/msgq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,35 @@
#include <cstring>
#include <cstdint>
#include <chrono>
#include <random>
#include <algorithm>
#include <cstdlib>

#include <csignal>
#include <random>

#include <poll.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/syscall.h>
#include <fcntl.h>
#include <unistd.h>


#include <stdio.h>

#include "msgq.hpp"

// SIGUSR1 handler installed by msgq_new_queue via std::signal.
// It performs no work; it only sanity-checks that it was invoked for SIGUSR1.
// NOTE(review): presumably its purpose is to keep SIGUSR1 from terminating the
// process so the signal can merely interrupt a sleeping poller -- confirm.
void sigusr1_handler(int signal) {
  const bool is_sigusr1 = (signal == SIGUSR1);
  assert(is_sigusr1);
  (void)is_sigusr1;  // keeps the local "used" when assert() is compiled out
}

// Builds a 64-bit identifier for the calling thread:
//   upper 32 bits: a random value drawn from /dev/urandom
//   lower bits:    the result of the SYS_gettid syscall (OR-ed in unmasked)
// Other code in this file recovers the thread id with `uid & 0xFFFFFFFF`
// and passes it to SYS_tkill.
uint64_t msgq_get_uid(void){
  std::random_device entropy("/dev/urandom");
  std::uniform_int_distribution<uint64_t> pick_high(0, std::numeric_limits<uint32_t>::max());

  // Random draw is limited to 32 bits, then moved into the upper half.
  const uint64_t high_bits = pick_high(entropy) << 32;
  const auto tid = syscall(SYS_gettid);

  return high_bits | tid;
}

int msgq_msg_init_size(msgq_msg_t * msg, size_t size){
msg->size = size;
Expand Down Expand Up @@ -69,6 +80,8 @@ void msgq_wait_for_subscriber(msgq_queue_t *q){
int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){
assert(size < 0xFFFFFFFF); // Buffer must be smaller than 2^32 bytes

std::signal(SIGUSR1, sigusr1_handler);

const char * prefix = "/dev/shm/";
char * full_path = new char[strlen(path) + strlen(prefix) + 1];
strcpy(full_path, prefix);
Expand Down Expand Up @@ -114,23 +127,11 @@ int msgq_new_queue(msgq_queue_t * q, const char * path, size_t size){

q->endpoint = path;
q->read_conflate = false;
q->read_fifo = -1;

return 0;
}

void msgq_close_queue(msgq_queue_t *q){
if (q->read_fifo >= 0){
close(q->read_fifo);
remove(q->read_fifo_path.c_str());
}

for (uint64_t i = 0; i < NUM_READERS; i++){
if (q->read_fifos[i] >= 0){
close(q->read_fifos[i]);
}
}

if (q->mmap_p != NULL){
munmap(q->mmap_p, q->size + sizeof(msgq_header_t));
}
Expand All @@ -139,18 +140,13 @@ void msgq_close_queue(msgq_queue_t *q){

void msgq_init_publisher(msgq_queue_t * q) {
std::cout << "Starting publisher" << std::endl;

std::random_device rd("/dev/urandom");
std::uniform_int_distribution<uint64_t> distribution(0,std::numeric_limits<uint64_t>::max());
uint64_t uid = distribution(rd);
uint64_t uid = msgq_get_uid();

*q->write_uid = uid;
*q->num_readers = 0;

for (size_t i = 0; i < NUM_READERS; i++){
*q->read_valids[i] = false;
q->read_fifos[i] = -1;
q->read_fifos_uid[i] = 0;
*q->read_uids[i] = 0;
}

Expand All @@ -161,9 +157,7 @@ void msgq_init_subscriber(msgq_queue_t * q) {
assert(q != NULL);
assert(q->num_readers != NULL);

std::random_device rd("/dev/urandom");
std::uniform_int_distribution<uint64_t> distribution(0,std::numeric_limits<uint64_t>::max());
uint64_t uid = distribution(rd);
uint64_t uid = msgq_get_uid();

// Get reader id
while (true){
Expand All @@ -177,7 +171,12 @@ void msgq_init_subscriber(msgq_queue_t * q) {

for (size_t i = 0; i < NUM_READERS; i++){
*q->read_valids[i] = false;

uint64_t old_uid = *q->read_uids[i];
*q->read_uids[i] = 0;

// Wake up reader in case they are in a poll
syscall(SYS_tkill, old_uid & 0xFFFFFFFF, SIGUSR1);
}

continue;
Expand All @@ -200,26 +199,6 @@ void msgq_init_subscriber(msgq_queue_t * q) {
}
}

for (size_t i = 0; i < NUM_READERS; i++){
q->read_fifos[i] = -1;
}

q->read_fifo_path = "/dev/shm/fifo-";
q->read_fifo_path += std::to_string(q->read_uid_local);

std::cout << q->read_fifo_path << std::endl;
int r = mkfifo(q->read_fifo_path.c_str(), 0777);
if (r != 0)
perror("Fifo: ");
assert(r == 0);

q->read_fifo = open(q->read_fifo_path.c_str(), O_RDWR | O_NONBLOCK);

// Fsync so the fifo shows up in the directory
auto shm_fd = open("/dev/shm", O_RDONLY);
fsync(shm_fd);
close(shm_fd);

std::cout << "New subscriber id: " << q->reader_id << " uid: " << q->read_uid_local << " " << q->endpoint << std::endl;
msgq_reset_reader(q);
}
Expand All @@ -231,7 +210,6 @@ int msgq_msg_send(msgq_msg_t * msg, msgq_queue_t *q){
assert(q->write_uid_local == *q->write_uid);
}


uint64_t total_msg_size = ALIGN(msg->size + sizeof(int64_t));

// We need to fit at least three messages in the queue,
Expand Down Expand Up @@ -303,44 +281,12 @@ int msgq_msg_send(msgq_msg_t * msg, msgq_queue_t *q){
for (uint64_t i = 0; i < num_readers; i++){
uint64_t reader_uid = *q->read_uids[i];

// Open fifo when not set, or when reader changes
if (q->read_fifos[i] == -1 || q->read_fifos_uid[i] != reader_uid){
// Close old reader fifo
if (q->read_fifos[i] >= 0){
close(q->read_fifos[i]);
}

q->read_fifos_uid[i] = reader_uid;

std::string path = "/dev/shm/fifo-";
path += std::to_string(reader_uid);

q->read_fifos[i] = open(path.c_str(), O_RDWR | O_NONBLOCK);
if(q->read_fifos[i] < 0){
std::cout << "Fifo: " << path << std::endl;
perror("Error opening fifo");
}
}

uint8_t m = 1;
write(q->read_fifos[i], &m, 1);
syscall(SYS_tkill, reader_uid & 0xFFFFFFFF, SIGUSR1);
}

return msg->size;
}

// Returns q->read_fifo for this queue's reader.
// Asserts that q->reader_id is non-negative (subscriber initialized). If the
// uid stored in the reader's slot no longer matches q->read_uid_local, logs
// the eviction and calls msgq_init_subscriber(q) before returning.
int msgq_get_fd(msgq_queue_t * q){
  const int reader_slot = q->reader_id;
  assert(reader_slot >= 0); // Make sure subscriber is initialized

  const bool evicted = (q->read_uid_local != *q->read_uids[reader_slot]);
  if (evicted){
    std::cout << q->endpoint << ": Reader was evicted, reconnecting" << std::endl;
    msgq_init_subscriber(q);
  }

  // Read after the possible re-subscribe above, as in the original.
  return q->read_fifo;
}


int msgq_msg_ready(msgq_queue_t * q){
start:
Expand Down Expand Up @@ -380,10 +326,6 @@ int msgq_msg_recv(msgq_msg_t * msg, msgq_queue_t * q){
goto start;
}

// Read one byte from fifo
char buf[1];
read(q->read_fifo, buf, 1);

// Check valid
if (!*q->read_valids[id]){
msgq_reset_reader(q);
Expand Down Expand Up @@ -465,34 +407,37 @@ int msgq_poll(msgq_pollitem_t * items, size_t nitems, int timeout){
assert(timeout >= 0);

int num = 0;
struct pollfd * fds = (struct pollfd *)calloc(nitems, sizeof(struct pollfd));

// Build poll structure
for (size_t i = 0; i < nitems; i++){
fds[i].fd = msgq_get_fd(items[i].q);
fds[i].events = POLLIN;

// Check if message is ready in case we get out of sync with the pipe
if (msgq_msg_ready(items[i].q)){
items[i].revents = 1;
timeout = 0; // No timeout if a message is ready
num++;
} else {
items[i].revents = 0;
}

// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
items[i].revents = msgq_msg_ready(items[i].q);
if (items[i].revents) num++;
}

poll(fds, nitems, timeout);
int ms = (timeout == -1) ? 100 : timeout;
struct timespec ts;
ts.tv_sec = ms / 1000;
ts.tv_nsec = (ms % 1000) * 1000 * 1000;


// Read poll results
for (size_t i = 0; i < nitems; i++){
if (fds[i].revents && !items[i].revents){
// Don't add it if it was already added
num++;
items[i].revents = 1;
while (num == 0) {
int ret;

ret = nanosleep(&ts, &ts);

// Check if messages ready
for (size_t i = 0; i < nitems; i++) {
if (items[i].revents == 0 && msgq_msg_ready(items[i].q)){
num += 1;
items[i].revents = 1;
}
}

// exit if we had a timeout and the sleep finished
if (timeout != -1 && ret == 0){
break;
}
}

free(fds);
return num;
}
7 changes: 0 additions & 7 deletions messaging/msgq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,7 @@ struct msgq_queue_t {
uint64_t write_uid_local;

bool read_conflate;
int read_fifo;

// Fifo fds and corresponding reader uid
int read_fifos[NUM_READERS];
uint64_t read_fifos_uid[NUM_READERS];

std::string endpoint;
std::string read_fifo_path;
};

struct msgq_msg_t {
Expand Down
9 changes: 7 additions & 2 deletions messaging/tests/test_poller.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def poller():
sub.connect(context, 'controlsState')
p.registerSocket(sub)

socks = p.poll(1000)
socks = p.poll(10000)
r = [s.receive(non_blocking=True) for s in socks]

return r
Expand Down Expand Up @@ -44,7 +44,6 @@ def test_poll_once(self):

self.assertEqual(result, [b"a"])

@unittest.skipIf(os.environ.get('MSGQ'), "fails under msgq")
def test_poll_and_create_many_subscribers(self):
context = messaging.Context()

Expand All @@ -59,6 +58,8 @@ def test_poll_and_create_many_subscribers(self):
for _ in range(10):
messaging.SubSocket().connect(c, 'controlsState')

time.sleep(0.1)

# Send message
pub.send("a")

Expand All @@ -69,3 +70,7 @@ def test_poll_and_create_many_subscribers(self):
context.term()

self.assertEqual(result, [b"a"])


if __name__ == "__main__":
unittest.main()

0 comments on commit 347a866

Please sign in to comment.