לדלג לתוכן

11.2 שרת TCP מתקדם פתרון

פתרון - שרת TCP מתקדם - advanced TCP server

פתרון 1 - שרת הד עם fork

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <signal.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#define PORT 8080
#define BUFFER_SIZE 1024

void handle_client(int client_fd)
{
    char buffer[BUFFER_SIZE];
    ssize_t n;
    while ((n = read(client_fd, buffer, sizeof(buffer))) > 0) {
        write(client_fd, buffer, n);
    }
    close(client_fd);
    exit(0);
}

int main(void)
{
    signal(SIGCHLD, SIG_IGN);

    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (server_fd == -1) {
        perror("socket");
        exit(1);
    }

    int opt = 1;
    setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(PORT);
    server_addr.sin_addr.s_addr = INADDR_ANY;

    if (bind(server_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
        perror("bind");
        exit(1);
    }

    if (listen(server_fd, 10) == -1) {
        perror("listen");
        exit(1);
    }
    printf("fork echo server on port %d\n", PORT);

    while (1) {
        struct sockaddr_in client_addr;
        socklen_t client_len = sizeof(client_addr);
        int client_fd = accept(server_fd, (struct sockaddr *)&client_addr, &client_len);
        if (client_fd == -1) {
            perror("accept");
            continue;
        }

        char client_ip[INET_ADDRSTRLEN];
        inet_ntop(AF_INET, &client_addr.sin_addr, client_ip, sizeof(client_ip));

        pid_t pid = fork();
        if (pid == -1) {
            perror("fork");
            close(client_fd);
            continue;
        }

        if (pid == 0) {
            close(server_fd);
            handle_client(client_fd);
        }

        printf("child pid %d handling client %s:%d\n",
               pid, client_ip, ntohs(client_addr.sin_port));
        close(client_fd);
    }

    close(server_fd);
    return 0;
}

signal(SIGCHLD, SIG_IGN) אומר למערכת ההפעלה לנקות תהליכי ילד אוטומטית. בלי זה, כל ילד שמסיים הופך לzombie (תהליך שנגמר אבל עדיין תופס כניסה בטבלת התהליכים) עד שהאב קורא ל-wait.


פתרון 2 - שרת צ'אט עם threads

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <pthread.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#define PORT 8080
#define BUFFER_SIZE 1024
#define MAX_CLIENTS 100

int client_fds[MAX_CLIENTS];
int client_count = 0;
pthread_mutex_t clients_mutex = PTHREAD_MUTEX_INITIALIZER;

void add_client(int fd)
{
    pthread_mutex_lock(&clients_mutex);
    if (client_count < MAX_CLIENTS) {
        client_fds[client_count] = fd;
        client_count++;
    }
    pthread_mutex_unlock(&clients_mutex);
}

void remove_client(int fd)
{
    pthread_mutex_lock(&clients_mutex);
    for (int i = 0; i < client_count; i++) {
        if (client_fds[i] == fd) {
            client_fds[i] = client_fds[client_count - 1];
            client_count--;
            break;
        }
    }
    pthread_mutex_unlock(&clients_mutex);
}

void broadcast(int sender_fd, const char *msg, int len)
{
    pthread_mutex_lock(&clients_mutex);
    for (int i = 0; i < client_count; i++) {
        if (client_fds[i] != sender_fd) {
            write(client_fds[i], msg, len);
        }
    }
    pthread_mutex_unlock(&clients_mutex);
}

void *handle_client(void *arg)
{
    int client_fd = *(int *)arg;
    free(arg);

    add_client(client_fd);

    char buffer[BUFFER_SIZE];
    ssize_t n;
    while ((n = read(client_fd, buffer, sizeof(buffer))) > 0) {
        buffer[n] = '\0';
        printf("[fd %d] says: %s", client_fd, buffer);
        broadcast(client_fd, buffer, n);
    }

    printf("[fd %d] disconnected\n", client_fd);
    remove_client(client_fd);
    close(client_fd);
    return NULL;
}

int main(void)
{
    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (server_fd == -1) {
        perror("socket");
        exit(1);
    }

    int opt = 1;
    setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(PORT);
    server_addr.sin_addr.s_addr = INADDR_ANY;

    if (bind(server_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
        perror("bind");
        exit(1);
    }

    if (listen(server_fd, 10) == -1) {
        perror("listen");
        exit(1);
    }
    printf("chat server on port %d\n", PORT);

    while (1) {
        struct sockaddr_in client_addr;
        socklen_t client_len = sizeof(client_addr);
        int client_fd = accept(server_fd, (struct sockaddr *)&client_addr, &client_len);
        if (client_fd == -1) {
            perror("accept");
            continue;
        }

        char client_ip[INET_ADDRSTRLEN];
        inet_ntop(AF_INET, &client_addr.sin_addr, client_ip, sizeof(client_ip));
        printf("client connected: %s:%d (fd %d)\n",
               client_ip, ntohs(client_addr.sin_port), client_fd);

        int *fd_ptr = malloc(sizeof(int));
        *fd_ptr = client_fd;

        pthread_t thread;
        if (pthread_create(&thread, NULL, handle_client, fd_ptr) != 0) {
            perror("pthread_create");
            close(client_fd);
            free(fd_ptr);
            continue;
        }
        pthread_detach(thread);
    }

    close(server_fd);
    return 0;
}

נקודות חשובות:

  • המוטקס מגן על מערך client_fds ועל client_count. בלעדיו, שני threads יכולים לנסות להוסיף/להסיר לקוח בו-זמנית ולהשחית את המערך.
  • פונקציית broadcast נועלת את המוטקס לכל משך השליחה. אם יש הרבה לקוחות, זה יכול ליצור צוואר בקבוק. פתרון מתקדם יותר: לשמור עותק מקומי של הרשימה ולשחרר את המוטקס לפני השליחה.
  • מעבירים את ה-fd דרך malloc ולא כמצביע לstack - כדי למנוע race condition כמו שהסברנו בהרצאה.

פתרון 3 - שרת הד עם select

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/select.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#define PORT 8080
#define MAX_CLIENTS 100
#define BUFFER_SIZE 1024

int main(void)
{
    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (server_fd == -1) {
        perror("socket");
        exit(1);
    }

    int opt = 1;
    setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(PORT);
    server_addr.sin_addr.s_addr = INADDR_ANY;

    if (bind(server_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
        perror("bind");
        exit(1);
    }

    if (listen(server_fd, 10) == -1) {
        perror("listen");
        exit(1);
    }
    printf("select echo server on port %d\n", PORT);

    int clients[MAX_CLIENTS];
    for (int i = 0; i < MAX_CLIENTS; i++)
        clients[i] = -1;

    while (1) {
        fd_set read_fds;
        FD_ZERO(&read_fds);
        FD_SET(server_fd, &read_fds);
        int max_fd = server_fd;

        for (int i = 0; i < MAX_CLIENTS; i++) {
            if (clients[i] != -1) {
                FD_SET(clients[i], &read_fds);
                if (clients[i] > max_fd)
                    max_fd = clients[i];
            }
        }

        int ready = select(max_fd + 1, &read_fds, NULL, NULL, NULL);
        if (ready == -1) {
            perror("select");
            break;
        }

        if (FD_ISSET(server_fd, &read_fds)) {
            struct sockaddr_in client_addr;
            socklen_t client_len = sizeof(client_addr);
            int client_fd = accept(server_fd,
                (struct sockaddr *)&client_addr, &client_len);
            if (client_fd != -1) {
                char ip[INET_ADDRSTRLEN];
                inet_ntop(AF_INET, &client_addr.sin_addr, ip, sizeof(ip));
                printf("new client: %s:%d (fd %d)\n",
                       ip, ntohs(client_addr.sin_port), client_fd);

                int added = 0;
                for (int i = 0; i < MAX_CLIENTS; i++) {
                    if (clients[i] == -1) {
                        clients[i] = client_fd;
                        added = 1;
                        break;
                    }
                }
                if (!added) {
                    printf("max clients reached\n");
                    close(client_fd);
                }
            }
        }

        for (int i = 0; i < MAX_CLIENTS; i++) {
            if (clients[i] != -1 && FD_ISSET(clients[i], &read_fds)) {
                char buffer[BUFFER_SIZE];
                ssize_t n = read(clients[i], buffer, sizeof(buffer));
                if (n <= 0) {
                    printf("client fd %d disconnected\n", clients[i]);
                    close(clients[i]);
                    clients[i] = -1;
                } else {
                    /* הוספת prefix עם מספר הfd */
                    char response[BUFFER_SIZE + 20];
                    snprintf(response, sizeof(response), "[fd %d] ", clients[i]);
                    int prefix_len = strlen(response);
                    memcpy(response + prefix_len, buffer, n);
                    write(clients[i], response, prefix_len + n);
                }
            }
        }
    }

    close(server_fd);
    return 0;
}

שימו לב שהלולאה הראשית בונה מחדש את read_fds בכל סיבוב. זה הכרחי כי select משנה את ה-fd_set - אחרי הקריאה, רק הfd-ים שמוכנים נשארים.


פתרון 4 - שרת הד עם epoll

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#define PORT 8080
#define MAX_EVENTS 64
#define BUFFER_SIZE 1024

struct client_data {
    int fd;
    long bytes_received;
    char ip[INET_ADDRSTRLEN];
    int port;
};

int main(void)
{
    int server_fd = socket(AF_INET, SOCK_STREAM, 0);
    if (server_fd == -1) {
        perror("socket");
        exit(1);
    }

    int opt = 1;
    setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt));

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(PORT);
    server_addr.sin_addr.s_addr = INADDR_ANY;

    if (bind(server_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
        perror("bind");
        exit(1);
    }

    if (listen(server_fd, 10) == -1) {
        perror("listen");
        exit(1);
    }
    printf("epoll echo server on port %d\n", PORT);

    int epoll_fd = epoll_create1(0);
    if (epoll_fd == -1) {
        perror("epoll_create1");
        exit(1);
    }

    /* הוספת סוקט השרת - משתמשים ב-data.fd כי זה השרת, לא צריך struct */
    struct epoll_event ev;
    ev.events = EPOLLIN;
    ev.data.fd = server_fd;
    epoll_ctl(epoll_fd, EPOLL_CTL_ADD, server_fd, &ev);

    struct epoll_event events[MAX_EVENTS];

    while (1) {
        int n = epoll_wait(epoll_fd, events, MAX_EVENTS, -1);
        if (n == -1) {
            perror("epoll_wait");
            break;
        }

        for (int i = 0; i < n; i++) {
            if (events[i].data.fd == server_fd) {
                /* לקוח חדש */
                struct sockaddr_in client_addr;
                socklen_t client_len = sizeof(client_addr);
                int client_fd = accept(server_fd,
                    (struct sockaddr *)&client_addr, &client_len);
                if (client_fd == -1) {
                    perror("accept");
                    continue;
                }

                /* יצירת מבנה נתונים ללקוח */
                struct client_data *cd = malloc(sizeof(struct client_data));
                cd->fd = client_fd;
                cd->bytes_received = 0;
                inet_ntop(AF_INET, &client_addr.sin_addr, cd->ip, sizeof(cd->ip));
                cd->port = ntohs(client_addr.sin_port);

                printf("new client: %s:%d (fd %d)\n", cd->ip, cd->port, client_fd);

                /* הוספה ל-epoll עם מצביע ל-struct */
                ev.events = EPOLLIN;
                ev.data.ptr = cd;
                if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &ev) == -1) {
                    perror("epoll_ctl add");
                    close(client_fd);
                    free(cd);
                }

            } else {
                /* נתונים מלקוח */
                struct client_data *cd = (struct client_data *)events[i].data.ptr;
                char buffer[BUFFER_SIZE];
                ssize_t bytes = read(cd->fd, buffer, sizeof(buffer));

                if (bytes <= 0) {
                    printf("client %s:%d (fd %d) disconnected. total bytes: %ld\n",
                           cd->ip, cd->port, cd->fd, cd->bytes_received);
                    epoll_ctl(epoll_fd, EPOLL_CTL_DEL, cd->fd, NULL);
                    close(cd->fd);
                    free(cd);
                } else {
                    cd->bytes_received += bytes;
                    write(cd->fd, buffer, bytes);
                }
            }
        }
    }

    close(epoll_fd);
    close(server_fd);
    return 0;
}

נקודה חשובה: כשמשתמשים ב-event.data.ptr (מצביע ל-struct) במקום event.data.fd, צריך לשים לב שזו union - אפשר להשתמש רק באחד מהם. לסוקט השרת השתמשנו ב-data.fd ולסוקטי הלקוחות השתמשנו ב-data.ptr. כדי להבחין ביניהם, בדקנו אם events[i].data.fd == server_fd. זה עובד כי כשהאירוע הוא מסוקט השרת, data.fd מכיל את server_fd.


פתרון 5 - השוואת ביצועים

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <time.h>
#include <pthread.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>

#define SERVER_IP "127.0.0.1"
#define SERVER_PORT 8080
#define MSG_SIZE 64
#define MESSAGES_PER_CLIENT 100

struct thread_args {
    int id;
    int messages;
};

void *benchmark_thread(void *arg)
{
    struct thread_args *ta = (struct thread_args *)arg;

    int sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (sockfd == -1) {
        perror("socket");
        return NULL;
    }

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(SERVER_PORT);
    inet_pton(AF_INET, SERVER_IP, &server_addr.sin_addr);

    if (connect(sockfd, (struct sockaddr *)&server_addr, sizeof(server_addr)) == -1) {
        perror("connect");
        close(sockfd);
        return NULL;
    }

    char msg[MSG_SIZE];
    char buffer[MSG_SIZE * 2];
    memset(msg, 'A', MSG_SIZE);

    for (int i = 0; i < ta->messages; i++) {
        write(sockfd, msg, MSG_SIZE);
        int total = 0;
        while (total < MSG_SIZE) {
            ssize_t n = read(sockfd, buffer + total, MSG_SIZE - total);
            if (n <= 0) break;
            total += n;
        }
    }

    close(sockfd);
    return NULL;
}

int main(int argc, char *argv[])
{
    if (argc != 2) {
        fprintf(stderr, "usage: %s <num_clients>\n", argv[0]);
        exit(1);
    }

    int num_clients = atoi(argv[1]);
    int messages = MESSAGES_PER_CLIENT;

    printf("benchmark: %d clients, %d messages each, %d bytes per message\n",
           num_clients, messages, MSG_SIZE);

    pthread_t *threads = malloc(num_clients * sizeof(pthread_t));
    struct thread_args *args = malloc(num_clients * sizeof(struct thread_args));

    struct timespec start, end;
    clock_gettime(CLOCK_MONOTONIC, &start);

    /* יצירת כל ה-threads */
    for (int i = 0; i < num_clients; i++) {
        args[i].id = i;
        args[i].messages = messages;
        if (pthread_create(&threads[i], NULL, benchmark_thread, &args[i]) != 0) {
            perror("pthread_create");
            num_clients = i;
            break;
        }
    }

    /* המתנה לסיום */
    for (int i = 0; i < num_clients; i++) {
        pthread_join(threads[i], NULL);
    }

    clock_gettime(CLOCK_MONOTONIC, &end);

    double elapsed = (end.tv_sec - start.tv_sec) +
                     (end.tv_nsec - start.tv_nsec) / 1e9;

    long total_messages = (long)num_clients * messages;
    long total_bytes = total_messages * MSG_SIZE;

    printf("results:\n");
    printf("  time: %.3f seconds\n", elapsed);
    printf("  total messages: %ld\n", total_messages);
    printf("  total bytes: %ld\n", total_bytes);
    printf("  messages/sec: %.0f\n", total_messages / elapsed);
    printf("  throughput: %.2f MB/sec\n", total_bytes / elapsed / 1e6);

    free(threads);
    free(args);
    return 0;
}

קומפילציה:

gcc -o benchmark benchmark.c -pthread

הרצה:

# בטרמינל אחד - הפעלת השרת:
./epoll_server

# בטרמינל אחר - הרצת הבדיקה:
./benchmark 10
./benchmark 100
./benchmark 500

בתוצאות תראו הבדלים ברורים: שרת fork יהיה האיטי ביותר (בגלל העלות של יצירת תהליכים), select יהיה בינוני, ו-epoll יהיה המהיר ביותר - במיוחד עם מספר גבוה של לקוחות.