2023年8月6日 星期日

Linux Kernel(22.1)- My Socket Domain and Protocol


本章主要參考 Add a new protocol to Linux Kernel, 寫一個自創的socket protocol family小範例, 主要要填寫“struct proto” (/include/net/sock.h) 與“struct net_proto_family” (/include/linux/net.h)相關的operation, 再分別用proto_register(struct proto *, int)與sock_register(const struct net_proto_family *)去跟系統註冊, 並將struct proto_ops分配給socket, 讓對應的system call都能找到對應的operation去執行.

首先要先呼叫“proto_register()”跟系統註冊protocol handler.
/*
 * Per-socket state of the demo protocol.
 * my_proto.obj_size is sizeof(struct my_sock), so the object returned by
 * sk_alloc() is large enough to be used as a struct my_sock.
 */
struct my_sock {
  /* struct sock must be the first member of my_sock */
  struct sock sk;
  int channel;  /* demo value: 999 after creation, replaced by bind() */
};

/* Protocol descriptor passed to proto_register() and to sk_alloc(). */
static struct proto my_proto = {
  .owner = THIS_MODULE,
  .name = "MYSOCK",
  /* Size of the whole per-socket object, not only the embedded struct sock. */
  .obj_size = sizeof(struct my_sock),
};

/*
 * Module entry point, first half: register the protocol descriptor.
 * The "..." below stands for the rest of the function, which this excerpt
 * leaves out.
 */
static int __init myproto_init(void)
{
  int ret = -1;

  /* Second argument (alloc_slab) is 0: no dedicated slab cache is requested. */
  ret = proto_register(&my_proto, 0);
  if (ret) {
    mypr_err("Failed to register myprotocol\n");
    return ret;
  }
  ...
}

這個註冊動作只是把自訂的proto加入proto_list中, 我跳過這個註冊也不影響該範例, 有空再來研究細節吧, 註冊成功後可以在/proc/net/protocols中看見.
/ # cat /proc/net/protocols | grep MY
/ # insmod /lib/modules/5.15.0/extra/socket_demo.ko
socket_demo: loading out-of-tree module taints kernel.
NET: Registered PF_MCTP protocol family
myproto_init(#182)myprotocol module loaded
/ # cat /proc/net/protocols | grep MY
MYSOCK     504      0      -1   NI       0   no   socket_demo  n  n  n  n  n  n  n  n  n  n  n  n  n  n  n  n  n  n  n

接著要註冊socket layer的handler, 是透過sock_register()註冊到net_families[NPROTO=AF_MAX]中, 當user space呼叫socket()時, 就會透過sock_register()所掛載的create()創建對應的socket.
socket() /* userspace */
|-> SYSCALL_DEFINE3(socket, int, family, int, type, int, protocol) /* kernel */
  |-> __sys_socket(family, type, protocol);
    |-> sock_create(family, type, protocol, &sock);
      |-> __sock_create(current->nsproxy->net_ns, family, type, protocol, res, 0);
        |-> pf = rcu_dereference(net_families[family]);
        |-> err = pf->create(net, sock, protocol, kern);
  |->sock_map_fd(sock, flags & (O_CLOEXEC | O_NONBLOCK));

相對應的"sock_register()"代碼
/*
 * Address/protocol family number used by this demo.
 * NOTE: 45 is (AF_MAX - 1) on Linux 5.15, which is AF_MCTP -- that is why the
 * kernel prints "Registered PF_MCTP protocol family" when the module loads.
 * The number is only usable while nothing else has registered family 45;
 * otherwise sock_register() fails.
 */
#define PF_MYPROTO 45
#define AF_MYPROTO PF_MYPROTO

/*
 * Logging helpers that prefix every message with "function(#line)".
 * The expansion deliberately has no trailing semicolon, so a call behaves
 * like one ordinary statement and is safe in an unbraced if/else.
 */
#define mypr_info(fmt, ...)  pr_info("%s(#%d)"fmt, __func__, __LINE__, ##__VA_ARGS__)
#define mypr_err(fmt, ...)  pr_err("%s(#%d)"fmt, __func__, __LINE__, ##__VA_ARGS__)

/* for user space */
/*
 * Address structure exchanged with user space in bind()/connect().
 * NOTE(review): it carries only the channel number and has no leading
 * address-family field like the standard sockaddr_* layouts.
 */
struct sockaddr_my {
  int channel;
};

static const struct proto_ops my_proto_ops = {
  .family = PF_MYPROTO,
  .owner = THIS_MODULE,
  .bind = my_bind,
  .listen = my_listen,
  .accept = my_accept,
  .connect = my_connect,
  .release = my_release,
  .sendmsg = my_sendmsg,
  .recvmsg = my_recvmsg,
};

static int myproto_create(struct net *net, struct socket *sock, int protocol, int kern)
{
  struct sock *sk;
  struct my_sock *my_sock;
  // 這裡的alloc會把my_proto帶入, 這樣在alloc時, 就可以alloc "struct my_sock"大小的記憶體
  // struct my_sock的struct sock sk;可以用kernel的sk相關函數操作, 自定義部分再轉型成"my_sock"去操作
  sk = sk_alloc(net, PF_MYPROTO, GFP_KERNEL, &my_proto, kern);
  if (!sk) {
    mypr_err("sk_alloc failed\n");
    return -ENOMEM;
  }
  // 將socket operation掛上來, 屆時對應的system call就會呼叫到對應的socket operation
  sock->ops = &my_proto_ops;
  // struct sock *sk 剛alloc, 透過sock_init_data()做一下init, 並將sock與sk做關聯
  // sk->sk_socket = sock;
  sock_init_data(sock, sk);
  // sk已經透過sock_init_data()處理好後, 再轉型成my_sock做自定義操作
  my_sock = (struct my_sock *) sk;
  my_sock->channel = 999; // 範例而已, 沒特別意思
  mypr_info("default channel:%d\n", my_sock->channel);

  return 0;
}

/* Family descriptor for sock_register(): maps PF_MYPROTO to myproto_create(). */
static struct net_proto_family myproto_family = {
  .owner = THIS_MODULE,
  .family = PF_MYPROTO,
  .create = myproto_create,
};

/*
 * Module entry point, second half: register the socket family.
 * The proto_register() step that precedes this code is left out of this
 * excerpt.
 */
static int __init myproto_init(void)
{
  int ret;

  /* Install myproto_create() as the create hook for PF_MYPROTO sockets. */
  ret = sock_register(&myproto_family);
  if (ret) {
    mypr_err("Failed to register myprotocol family\n");
    /* Roll back the earlier proto_register() so nothing stays half-registered. */
    proto_unregister(&my_proto);
    return ret;
  }

  /* A successful load is informational, not an error. */
  mypr_info("myprotocol module loaded\n");
  return 0;
}

下面舉幾個socket operation從user到kernel的socket operation的路徑
bind() /* userspace */
|-> SYSCALL_DEFINE3(bind, int, fd, struct sockaddr __user *, umyaddr, int, addrlen) // kernel space
  |-> __sys_bind(fd, umyaddr, addrlen);
    |-> sock = sockfd_lookup_light(fd, &err, &fput_needed);
    |-> sock->ops->bind(sock,(struct sockaddr *)&address, addrlen);
    
listen() // userspace
|-> SYSCALL_DEFINE2(listen, int, fd, int, backlog) // kernel space
  |-> __sys_listen(fd, backlog);
    |-> sock = sockfd_lookup_light(fd, &err, &fput_needed);
    |-> sock->ops->listen(sock, backlog);
從上面的範例不難理解, 大概就是在system call(__sys_xx())時直接呼叫對應的socket operation, 但是, 用過user space的都知道, 也可以透過read()/write()呼叫對應的recvmsg()與sendmsg(), 主要是在__sys_socket()時, 透過sock_map_fd()將file operation掛上去, 其中的read()/write()就是分別對應到recvmsg()/sendmsg().
int sock_map_fd(struct socket *sock, int flags)
|-> sock_alloc_file(sock, flags, NULL);
  |-> alloc_file_pseudo(&socket_file_ops);
    |-> file = alloc_file(&path, flags, fop);
      |-> file->f_op = fop;
      
static const struct file_operations socket_file_ops = {
  .read_iter =    sock_read_iter,
  .write_iter =   sock_write_iter,
};

sock_read_iter(struct kiocb *iocb, struct iov_iter *to)
|-> sock_recvmsg(sock, &msg, msg.msg_flags);

sock_write_iter(struct kiocb *iocb, struct iov_iter *from)
|-> res = sock_sendmsg(sock, &msg);

這篇只有簡單的介紹一下相關的API, 所以底下的socket operation都只是簡單的印出訊息, sendmsg()則是將user資料印出, 而recvmsg()則是固定回傳"My test", 如果不支援的socket operation可以使用sock_no_xxx即可.
/* Bind socket to specified sockaddr. */
/*
 * Bind socket to specified sockaddr.
 *
 * Stores the channel number from the supplied struct sockaddr_my in the
 * socket. Returns 0 on success or -EINVAL when the address is too short.
 */
static int my_bind(struct socket *sock, struct sockaddr *saddr, int len)
{
  DECLARE_SOCKADDR(struct sockaddr_my *, addr, saddr);
  struct my_sock *my_sock = my_sock_sk(sock->sk);

  mypr_info("sock->channel %d\n", my_sock->channel);

  /*
   * Compare as signed values: against a size_t a negative len would be
   * converted to a huge unsigned number and slip through the check.
   */
  if (len < (int)sizeof(*addr)) {
    mypr_err("len of addr is small\n");
    return -EINVAL;
  }
  my_sock->channel = addr->channel;
  return 0;
}

/*
 * listen() handler: logs the current channel, then reports the operation as
 * unsupported through the generic sock_no_listen() stub.
 */
static int my_listen(struct socket *sock, int len)
{
  struct my_sock *msk = my_sock_sk(sock->sk);

  mypr_info("sock->channel %d\n", msk->channel);

  return sock_no_listen(sock, len);
}

/*
 * accept() handler: logs the current channel, then hands the request to the
 * generic sock_no_accept() stub and returns its result.
 */
static int my_accept(struct socket *sock, struct socket *newsock, int flags, bool kern)
{
  struct my_sock *msk = my_sock_sk(sock->sk);

  mypr_info("sock->channel %d\n", msk->channel);

  return sock_no_accept(sock, newsock, flags, kern);
}

static int my_release(struct socket *sock)
{
  struct my_sock *my_sock = my_sock_sk(sock->sk);
  mypr_info("sock->channel %d\n", my_sock->channel);
  return 0;
}

static int my_connect(struct socket *sock, struct sockaddr *saddr, int len, int flags)
{
  DECLARE_SOCKADDR(struct sockaddr_my *, addr, saddr);
  struct my_sock *my_sock = my_sock_sk(sock->sk);
  struct sock *sk = sock->sk;

  if (len < sizeof(*addr)) {
    return -EINVAL;
  }
  return 0;
}

/*
 * recvmsg() handler: always answers with the fixed string "My test"
 * (including its terminating NUL).
 *
 * At most len bytes are copied, so a short user buffer receives a truncated
 * reply. Returns the number of bytes copied, or a negative error code when
 * the copy to the caller's buffer fails.
 */
static int my_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, int flags)
{
  unsigned char buf[] = "My test";
  size_t copied = min_t(size_t, len, sizeof(buf));
  int err;

  err = memcpy_to_msg(msg, buf, copied);
  if (err) {
    return err;
  }

  return (int)copied;
}

/*
 * sendmsg() handler: copies the caller's payload into a temporary kernel
 * buffer and prints it as a string.
 *
 * Returns len on success, -ENOMEM when the buffer cannot be allocated, or
 * the error reported by memcpy_from_msg() when the copy fails.
 */
static int my_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
{
  struct my_sock *my_sock = my_sock_sk(sock->sk);
  char *buf;
  int err;

  mypr_info("len:%zu, channel:%d\n", len, my_sock->channel);

  /* kzalloc() zero-fills, so the extra byte is the NUL terminator for %s. */
  buf = kzalloc(len + 1, GFP_KERNEL);
  if (!buf) {
    return -ENOMEM;
  }

  /* Safely copy data from user space to kernel space */
  err = memcpy_from_msg(buf, msg, len);
  mypr_info("data: err:%d, msg:%s\n", err, buf);
  kfree(buf);

  /* Do not claim success when the payload could not be copied. */
  return err ? err : len;
}

完整的Module code
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/init.h>
#include <linux/slab.h>
#include <linux/socket.h>
#include <linux/net.h>
#include <linux/sockios.h>
#include <linux/netdevice.h>
#include <linux/errno.h>
#include <linux/proc_fs.h>
#include <linux/file.h>
#include <linux/fs.h>
#include <net/protocol.h>

/*
 * Address/protocol family number used by this demo.
 * NOTE: 45 is (AF_MAX - 1) on Linux 5.15, which is AF_MCTP -- that is why the
 * kernel prints "Registered PF_MCTP protocol family" when the module loads.
 * The number is only usable while nothing else has registered family 45;
 * otherwise sock_register() fails.
 */
#define PF_MYPROTO 45
#define AF_MYPROTO PF_MYPROTO

/*
 * Logging helpers that prefix every message with "function(#line)".
 * The expansion deliberately has no trailing semicolon, so a call behaves
 * like one ordinary statement and is safe in an unbraced if/else.
 */
#define mypr_info(fmt, ...)  pr_info("%s(#%d)"fmt, __func__, __LINE__, ##__VA_ARGS__)
#define mypr_err(fmt, ...)  pr_err("%s(#%d)"fmt, __func__, __LINE__, ##__VA_ARGS__)

#include <net/sock.h>
/*
 * Per-socket state of the demo protocol.
 * my_proto.obj_size is sizeof(struct my_sock), so the object returned by
 * sk_alloc() is large enough to be used as a struct my_sock.
 */
struct my_sock {
  /* struct sock must be the first member of my_sock */
  struct sock sk;
  int channel;  /* demo value: 999 after creation, replaced by bind() */
};

/* Convert the generic struct sock back to the my_sock that embeds it. */
static inline struct my_sock *my_sock_sk(struct sock *sk)
{
  struct my_sock *outer = container_of(sk, struct my_sock, sk);

  return outer;
}

/* for user space */
/*
 * Address structure exchanged with user space in bind()/connect().
 * NOTE(review): it carries only the channel number and has no leading
 * address-family field like the standard sockaddr_* layouts.
 */
struct sockaddr_my {
  int channel;
};

/* Protocol descriptor passed to proto_register() and to sk_alloc(). */
static struct proto my_proto = {
  .owner = THIS_MODULE,
  .name = "MYSOCK",
  /* Size of the whole per-socket object, not only the embedded struct sock. */
  .obj_size = sizeof(struct my_sock),
};

/* Bind socket to specified sockaddr. */
/*
 * Bind socket to specified sockaddr.
 *
 * Stores the channel number from the supplied struct sockaddr_my in the
 * socket. Returns 0 on success or -EINVAL when the address is too short.
 */
static int my_bind(struct socket *sock, struct sockaddr *saddr, int len)
{
  DECLARE_SOCKADDR(struct sockaddr_my *, addr, saddr);
  struct my_sock *my_sock = my_sock_sk(sock->sk);

  mypr_info("sock->channel %d\n", my_sock->channel);

  /*
   * Compare as signed values: against a size_t a negative len would be
   * converted to a huge unsigned number and slip through the check.
   */
  if (len < (int)sizeof(*addr)) {
    mypr_err("len of addr is small\n");
    return -EINVAL;
  }
  my_sock->channel = addr->channel;
  return 0;
}

/*
 * listen() handler: logs the current channel, then reports the operation as
 * unsupported through the generic sock_no_listen() stub.
 */
static int my_listen(struct socket *sock, int len)
{
  struct my_sock *msk = my_sock_sk(sock->sk);

  mypr_info("sock->channel %d\n", msk->channel);

  return sock_no_listen(sock, len);
}

/*
 * accept() handler: logs the current channel, then hands the request to the
 * generic sock_no_accept() stub and returns its result.
 */
static int my_accept(struct socket *sock, struct socket *newsock, int flags, bool kern)
{
  struct my_sock *msk = my_sock_sk(sock->sk);

  mypr_info("sock->channel %d\n", msk->channel);

  return sock_no_accept(sock, newsock, flags, kern);
}

static int my_release(struct socket *sock)
{
  struct my_sock *my_sock = my_sock_sk(sock->sk);
  mypr_info("sock->channel %d\n", my_sock->channel);
  return 0;
}

static int my_connect(struct socket *sock, struct sockaddr *saddr, int len, int flags)
{
  DECLARE_SOCKADDR(struct sockaddr_my *, addr, saddr);
  struct my_sock *my_sock = my_sock_sk(sock->sk);
  struct sock *sk = sock->sk;

  if (len < sizeof(*addr)) {
    return -EINVAL;
  }
  return 0;
}

/*
 * recvmsg() handler: always answers with the fixed string "My test"
 * (including its terminating NUL).
 *
 * At most len bytes are copied, so a short user buffer receives a truncated
 * reply. Returns the number of bytes copied, or a negative error code when
 * the copy to the caller's buffer fails.
 */
static int my_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, int flags)
{
  unsigned char buf[] = "My test";
  size_t copied = min_t(size_t, len, sizeof(buf));
  int err;

  err = memcpy_to_msg(msg, buf, copied);
  if (err) {
    return err;
  }

  return (int)copied;
}

/*
 * sendmsg() handler: copies the caller's payload into a temporary kernel
 * buffer and prints it as a string.
 *
 * Returns len on success, -ENOMEM when the buffer cannot be allocated, or
 * the error reported by memcpy_from_msg() when the copy fails.
 */
static int my_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
{
  struct my_sock *my_sock = my_sock_sk(sock->sk);
  char *buf;
  int err;

  mypr_info("len:%zu, channel:%d\n", len, my_sock->channel);

  /* kzalloc() zero-fills, so the extra byte is the NUL terminator for %s. */
  buf = kzalloc(len + 1, GFP_KERNEL);
  if (!buf) {
    return -ENOMEM;
  }

  /* Safely copy data from user space to kernel space */
  err = memcpy_from_msg(buf, msg, len);
  mypr_info("data: err:%d, msg:%s\n", err, buf);
  kfree(buf);

  /* Do not claim success when the payload could not be copied. */
  return err ? err : len;
}

static const struct proto_ops my_proto_ops = {
  .family = PF_MYPROTO,
  .owner = THIS_MODULE,
  .bind = my_bind,
  .listen = my_listen,
  .accept = my_accept,
  .connect = my_connect,
  .release = my_release,
  .sendmsg = my_sendmsg,
  .recvmsg = my_recvmsg,
};

static int myproto_create(struct net *net, struct socket *sock, int protocol, int kern)
{
  struct sock *sk;
  struct my_sock *my_sock;
  sk = sk_alloc(net, PF_MYPROTO, GFP_KERNEL, &my_proto, kern);
  if (!sk) {
    mypr_err("sk_alloc failed\n");
    return -ENOMEM;
  }
  sock->ops = &my_proto_ops;
  sock_init_data(sock, sk);
  my_sock = (struct my_sock *) sk;
  my_sock->channel = 999;
  mypr_info("default channel:%d\n", my_sock->channel);

  return 0;
}

/* Family descriptor for sock_register(): maps PF_MYPROTO to myproto_create(). */
static struct net_proto_family myproto_family = {
  .owner = THIS_MODULE,
  .family = PF_MYPROTO,
  .create = myproto_create,
};

/*
 * Module entry point: register the protocol descriptor, then the socket
 * family. Returns 0 on success or the error of the step that failed; a
 * failure of the second step undoes the first.
 */
static int __init myproto_init(void)
{
  int ret;

  /* Second argument (alloc_slab) is 0: no dedicated slab cache is requested. */
  ret = proto_register(&my_proto, 0);
  if (ret) {
    mypr_err("Failed to register myprotocol\n");
    return ret;
  }

  /* Install myproto_create() as the create hook for PF_MYPROTO sockets. */
  ret = sock_register(&myproto_family);
  if (ret) {
    mypr_err("Failed to register myprotocol family\n");
    /* Roll back step one so nothing stays half-registered. */
    proto_unregister(&my_proto);
    return ret;
  }

  /* A successful load is informational, not an error. */
  mypr_info("myprotocol module loaded\n");
  return 0;
}

/*
 * Module exit: undo myproto_init() in reverse order -- remove the socket
 * family first, then the protocol descriptor.
 */
static void __exit myproto_exit(void)
{
  sock_unregister(PF_MYPROTO);
  proto_unregister(&my_proto);
  mypr_info("myprotocol module unloaded\n");
}

/* Declare the handlers that run when the module is loaded and unloaded. */
module_init(myproto_init);
module_exit(myproto_exit);

MODULE_LICENSE("GPL");

完整的User code
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <arpa/inet.h>

/* Family number of the demo protocol; must match the value in the kernel module. */
#define AF_MYPROTO 45
#define PF_MYPROTO AF_MYPROTO

/* User-space copy of the address structure the module reads in bind(). */
struct sockaddr_my {
  int channel;
};

/*
 * Exercise the demo socket family from user space.
 *
 * Usage: my_sock <message>
 * Creates an AF_MYPROTO socket, binds it to channel 123, tries listen(),
 * writes <message> to the socket and prints whatever read() returns.
 * bind()/listen() failures are reported but are not fatal, so the remaining
 * calls are still exercised.
 */
int main(int argc, char *argv[]) {
    int sfd;
    ssize_t ret;
    struct sockaddr_my saddr;
    char buf[128];

    /* The message is mandatory: argv[1] is passed to strlen()/write() below. */
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <message>\n", argc > 0 ? argv[0] : "my_sock");
        return EXIT_FAILURE;
    }

    // Create a socket
    printf("%s(#%d): socket\n", __func__, __LINE__);
    sfd = socket(AF_MYPROTO, SOCK_STREAM, 0);
    if (sfd == -1) {
        perror("Socket creation failed");
        exit(EXIT_FAILURE);
    }

    // Set up the address structure
    memset(&saddr, 0, sizeof(saddr));
    saddr.channel = 123;

    printf("%s(#%d): bind\n", __func__, __LINE__);
    // Bind the socket to the chosen channel
    if (bind(sfd, (struct sockaddr *)&saddr, sizeof(saddr)) == -1) {
        perror("Bind failed");
    }

    printf("%s(#%d): listen\n", __func__, __LINE__);
    // Try to listen for incoming connections
    if (listen(sfd, 1) == -1) {
        perror("Listen failed");
    }

    ret = write(sfd, argv[1], strlen(argv[1]));
    if (ret < 0) {
        perror("write");
        close(sfd);
        return EXIT_FAILURE;
    }
    printf("write: %zd\n", ret);

    // Read one byte less than the buffer so it always stays NUL-terminated
    memset(buf, 0, sizeof(buf));
    ret = read(sfd, buf, sizeof(buf) - 1);
    if (ret < 0) {
        perror("read");
        close(sfd);
        return EXIT_FAILURE;
    }
    printf("read: %zd/%s\n", ret, buf);

    // Close the socket
    close(sfd);

    return 0;
}

執行結果
/ # insmod /lib/modules/5.15.0/extra/socket_demo.ko
socket_demo: loading out-of-tree module taints kernel.
NET: Registered PF_MCTP protocol family
myproto_init(#178)myprotocol module loaded
/ # /my_sock abc
main(#23): socket
myproto_create(#150)default channel:999
main(#33): bind
my_bind(#49)sock->channel 999
main(#39): listen
my_listen(#61)sock->channel 123
Listen failed: Operation not supported
my_sendmsg(#110)len:3, channel:123
my_sendmsg(#119)data: err:0, msg:abc
write: 3
read: 8/My test
my_release(#75)sock->channel 123


    參考資料:
  • Add a new protocol to Linux Kernel, https://linuxwarrior.wordpress.com/2008/12/02/add-a-new-protocol-to-linux-kernel/
  • https://lishiwen4.github.io/network/socket-interface-and-network-protocol
  • https://www.cnblogs.com/hellokitty2/p/10188376.html
  • https://liuhangbin.netlify.app/post/linux-socket/
  • https://hackmd.io/@rickywu0421/linux_networking_1




沒有留言:

張貼留言

熱門文章