#define _GNU_SOURCE

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <fcntl.h>

// Three arbitrary sample structs. Their members are never read or written;
// only sizeof() of each is used -- as the memfd length in main() and as the
// size that check_and_link() compares file sizes against. Member types and
// order therefore determine the program's behavior.

// Smallest sample: two ints.
typedef struct {
    int a;
    int b;
} struct_small;

// Medium sample: 16 longs followed by one int. sizeof() includes any tail
// padding the compiler adds after `b` to keep `long` aligned (presumably
// 4 bytes on a typical LP64 ABI -- verify for your target).
typedef struct {
    long a[16];
    int b;
} struct_medium;

// Large sample: 256 longs, 128 doubles and 64 chars (several KiB; 3136 bytes
// on a typical LP64 ABI -- verify for your target).
typedef struct {
    long a[256];
    double b[128];
    char c[64];
} struct_large;

/* One lookup-table entry pairing a type's name with its size in bytes.
 * Member order is relied on by positional initializers. */
typedef struct type_map {
    const char *name;   /* type name; also the basename of the /dev/shm link */
    size_t size;        /* size in bytes to match a memfd's length against */
} type_map;

/*
 * Create an anonymous memory-backed file labelled `name` and resize it to
 * `size` bytes.
 *
 * Any failure is fatal: the failing call is reported with perror() and the
 * process exits with status 1. On success the new descriptor is returned.
 */
int create_memfd(const char *name, size_t size)
{
    const char *failed_call = NULL;
    int memfd = memfd_create(name, 0);

    if (memfd < 0) {
        failed_call = "memfd_create";
        goto fail;
    }
    if (ftruncate(memfd, size) < 0) {
        failed_call = "ftruncate";
        goto fail;
    }
    return memfd;

fail:
    perror(failed_call);
    exit(1);
}

void check_and_link(int fd, size_t size, type_map *types, int n)
{
    struct stat st;

    if (fstat(fd, &st) < 0)
        return;

    size_t actual = st.st_size;

    for (int i = 0; i < n; i++)
    {
        if (actual == types[i].size)
        {
            char linkname[256];
            snprintf(linkname,
                     sizeof(linkname),
                     "/dev/shm/%s",
                     types[i].name);

            unlink(linkname);

            char fdpath[256];
            snprintf(fdpath,
                     sizeof(fdpath),
                     "/proc/self/fd/%d",
                     fd);

            if (symlink(fdpath, linkname) == 0)
            {
                printf("[+] matched %s -> %s\n",
                       types[i].name,
                       linkname);
            }
        }
    }
}

/*
 * Create one memfd per sample struct, each sized to sizeof that struct, and
 * hand every memfd to check_and_link() so it can be exposed under
 * /dev/shm/<struct name>. The process then blocks in pause(), which keeps it
 * alive -- and with it the memfds the links point at -- until a signal ends
 * it. The links are not removed when the process exits.
 */
int main(void)
{
    type_map types[] = {
        {"struct_small", sizeof(struct_small)},
        {"struct_medium", sizeof(struct_medium)},
        {"struct_large", sizeof(struct_large)}
    };

    /* Separate statements keep the memfd creation order fixed. */
    int fd1 = create_memfd("small", sizeof(struct_small));
    int fd2 = create_memfd("medium", sizeof(struct_medium));
    int fd3 = create_memfd("large", sizeof(struct_large));

    int fds[] = {fd1, fd2, fd3};

    /* Counts are derived from the arrays rather than hard-coded. The size
     * argument (0) is not used by check_and_link(). */
    for (size_t i = 0; i < sizeof fds / sizeof fds[0]; i++) {
        check_and_link(fds[i],
                       0,
                       types,
                       (int)(sizeof types / sizeof types[0]));
    }

    printf("done. check /dev/shm/\n");

    /*
     * When stdout is a pipe or file it is fully buffered, and a terminating
     * signal (the usual way out of pause()) does not flush stdio buffers.
     * Flush now so the messages above are not lost.
     */
    fflush(stdout);

    pause();
    return 0;
}
