简单学习netlink

最近稍微了解了下内核与应用程序之间通信相关的知识, 先简单熟悉了一下netlink. 有关它的相关介绍强烈推荐一下CU的这篇帖子. 感觉用这种方式有两个好处: 一是socket接口风格, 比较好上手; 二是内核态程序可以主动向应用层发起通信. 结合着这篇帖子, 自己尝试写了一段小代码熟悉一下.

因为我这边的内核版本比较新, 相关的api稍微有点改动. 另外真心觉得写内核态程序的时候最好放在虚拟机上搞, 因为很容易整出panic, 把我郁闷坏了. 下面是一个简单echo/reply的小例子.

首先是内核态的程序(在3.10以上内核版本上编译通过):

fun.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <linux/module.h>
#include <net/sock.h>
#include <linux/netlink.h>
#include <linux/skbuff.h>

/*
 * Netlink protocol number handed to netlink_kernel_create(); the user-space
 * program passes the same value to socket().
 */
#define NETLINK_FUNTRANSPORT 31

/* Value stored in netlink_kernel_cfg.groups when the socket is created. */
#define FUN_NL_GROUP 1

/*
 * Kernel-side netlink socket: created in fun_init(), used by the receive
 * callback to send replies, released in fun_exit().  Only this file uses it,
 * so it is kept out of the global kernel namespace with 'static'.
 */
static struct sock *fun_nl_sk;

static void
fun_nl_rcv_msg(struct sk_buff *skb)
{
    int              pid, res, msglen;
    char             buffer[256] = {0};
    struct nlmsghdr *onlh, *nnlh;
    struct sk_buff  *skb_out;

    printk(KERN_INFO "in %s\n", __func__);

    onlh    = nlmsg_hdr(skb);
    msglen  = nlmsg_len(onlh);       /* payload lenghth */
    pid     = onlh->nlmsg_pid;

    memcpy(buffer, nlmsg_data(onlh), msglen);
    printk(KERN_INFO "receive msg from pid(%d) size(%d): %s\n", pid, msglen, buffer);

    skb_out = nlmsg_new(msglen, 0);
    if (!skb_out) {
        printk(KERN_ERR "failed to allocate new skb\n");
        return;
    }

    nnlh = nlmsg_put(skb_out, 0, 0, NLMSG_DONE, msglen, 0);
    memcpy(nlmsg_data(nnlh), nlmsg_data(onlh), msglen);

    NETLINK_CB(skb_out).dst_group = 0;

    res = nlmsg_unicast(fun_nl_sk, skb_out, pid);
    if (res < 0) {
        printk(KERN_ERR "failed to send msg to user\n");
    }
}

/*
 * fun_init - module load hook
 *
 * Creates the kernel-side netlink socket for protocol NETLINK_FUNTRANSPORT in
 * the initial network namespace and installs fun_nl_rcv_msg() as its input
 * callback.
 *
 * Returns 0 on success, or -ENOMEM when the socket could not be created
 * (netlink_kernel_create() only reports failure as NULL, so a real errno is
 * returned instead of a bare -1).
 */
static int __init fun_init(void)
{
    struct netlink_kernel_cfg cfg = {
        .input  = fun_nl_rcv_msg,
        .groups = FUN_NL_GROUP,
    };

    fun_nl_sk = netlink_kernel_create(&init_net, NETLINK_FUNTRANSPORT, &cfg);
    if (!fun_nl_sk) {
        printk(KERN_ERR "%s: register of receive handler failed\n", __func__);
        return -ENOMEM;
    }

    printk(KERN_INFO "fun mod init\n");

    return 0;
}

/*
 * fun_exit - module unload hook
 *
 * Releases the netlink socket stored in fun_nl_sk and then logs that the
 * module is gone.
 */
static void __exit
fun_exit(void)
{
    /* Give the socket back first, then announce the unload. */
    netlink_kernel_release(fun_nl_sk);

    printk(KERN_INFO "fun mod exit\n");
}

/* Register fun_init()/fun_exit() as the module's load and unload hooks. */
module_init(fun_init);
module_exit(fun_exit);

/*
 * NOTE(review): "FUN" does not look like one of the license strings the
 * kernel recognises (such as "GPL"); presumably loading this module marks the
 * kernel as tainted -- confirm and pick a real license for anything beyond a
 * toy example.
 */
MODULE_LICENSE("FUN");

然后是一个互动的应用程序:

user.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <sys/socket.h>
#include <linux/netlink.h>
#include <errno.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

/* Netlink protocol number passed to socket(); same value as in the kernel module. */
#define NETLINK_FUNTRANSPORT 31

/* Size in bytes of the line buffer read from stdin and reused for the reply. */
#define MAX_LEN 256

/**
 * nlmsg_len - length of message payload
 * @nlh: netlink message header
 *
 * Local copy of the helper the kernel module uses under the same name.
 * Returns the header's nlmsg_len field minus NLMSG_HDRLEN.
 */
static inline int nlmsg_len(const struct nlmsghdr *nlh)
{
    const __u32 total = nlh->nlmsg_len;   /* header + payload, as stored */

    return total - NLMSG_HDRLEN;
}


/*
 * Interactive netlink echo client.
 *
 * Reads lines from stdin, sends each one (without its trailing newline) to
 * the kernel over a NETLINK_FUNTRANSPORT socket, waits for the reply and
 * prints it.
 *
 * Exits with status -1 if the socket cannot be created.  Otherwise returns
 * EXIT_SUCCESS once stdin is exhausted, or EXIT_FAILURE after a bind,
 * allocation, send or receive error (the error is reported on stderr).
 */
int main(int argc, char *argv[])
{
    int  sockfd, pid;
    int  status = EXIT_SUCCESS;
    size_t msglen, total, paylen;
    ssize_t nbytes;
    char buffer[MAX_LEN];
    struct sockaddr_nl  src_addr, dst_addr;
    struct nlmsghdr    *nlh;

    (void)argc;
    (void)argv;

    sockfd = socket(AF_NETLINK, SOCK_RAW, NETLINK_FUNTRANSPORT);
    if (sockfd < 0) {
        fprintf(stderr, "socket failed(%d)\n", errno);
        exit(-1);
    }

    pid = getpid();

    /* Our own address: the process id is used as the netlink port id. */
    memset(&src_addr, 0, sizeof(src_addr));
    src_addr.nl_family = AF_NETLINK;
    src_addr.nl_pid    = pid;
    if (bind(sockfd, (struct sockaddr *)&src_addr, sizeof(src_addr)) < 0) {
        fprintf(stderr, "bind failed(%d)\n", errno);
        close(sockfd);
        return EXIT_FAILURE;
    }

    /* Destination: port id 0 addresses the kernel. */
    memset(&dst_addr, 0, sizeof(dst_addr));
    dst_addr.nl_family = AF_NETLINK;
    dst_addr.nl_pid    = 0;

    while (fgets(buffer, MAX_LEN, stdin) != NULL) {
        msglen = strlen(buffer);
        /*
         * chomp: strip the newline only when there is one.  The last line of
         * input, or a line longer than the buffer, has none and must not
         * lose its final character.
         */
        if (msglen > 0 && buffer[msglen - 1] == '\n') {
            buffer[--msglen] = 0;
        }
        fprintf(stdout, "STDIN: %s(%zu)\n", buffer, msglen);

        total = NLMSG_SPACE(msglen);      /* header + aligned payload */
        nlh = calloc(1, total);
        if (nlh == NULL) {
            fprintf(stderr, "out of memory\n");
            status = EXIT_FAILURE;
            break;
        }
        nlh->nlmsg_len   = (__u32)total;
        nlh->nlmsg_pid   = pid;
        nlh->nlmsg_flags = 0;
        memcpy(NLMSG_DATA(nlh), buffer, msglen);

        if (sendto(sockfd, nlh, nlh->nlmsg_len, 0,
                   (struct sockaddr *)&dst_addr, (socklen_t)sizeof(dst_addr)) < 0) {
            fprintf(stderr, "sendto failed(%d)\n", errno);
            free(nlh);
            status = EXIT_FAILURE;
            break;
        }

        /* reuse the nlh: the reply is an echo, so it fits in 'total' bytes */
        nbytes = recvfrom(sockfd, nlh, total, 0, NULL, NULL);
        if (nbytes < 0) {
            fprintf(stderr, "recvfrom failed(%d)\n", errno);
            free(nlh);
            status = EXIT_FAILURE;
            break;
        }
        if (!NLMSG_OK(nlh, nbytes)) {
            fprintf(stderr, "malformed reply from kernel\n");
            free(nlh);
            status = EXIT_FAILURE;
            break;
        }

        /*
         * Copy at most MAX_LEN - 1 payload bytes and terminate right after
         * them, so the write never lands outside 'buffer'.
         */
        paylen = (size_t)nlmsg_len(nlh);
        if (paylen > sizeof(buffer) - 1) {
            paylen = sizeof(buffer) - 1;
        }
        memcpy(buffer, NLMSG_DATA(nlh), paylen);
        buffer[paylen] = 0;
        fprintf(stdout, "KERNEL: %s\n", buffer);

        free(nlh);
    }

    close(sockfd);
    return status;
}

一个值得注意的地方是: nlmsg_len 这类辅助函数定义在内核内部的头文件 (net/netlink.h) 中, 并没有导出给用户态, 所以应用态的程序无法直接使用, 这里只好照抄了一份实现.

Comments