【Jabberd2源码剖析系列 util (不包括nad, xhash, xdata, jid, config, stanza)】

jabberd2使用util/pool作为内存池, 相比于常见的内存池模式, jabberd2的pool使用了一种个性的设计: 一个pool为一个对象而生, 而对象随着pool的消亡而释放, 一个程序里可能有成百上千的pool, 每个pool管理一个小小的对象, 并且分配过的内存是不能重用的, 也没必要一点点的释放, 把整个pool销毁是唯一的释放方法.

 

1, 数据结构:

pool_cleanup_t是内存的释放回调函数, 负责真正的free内存.

pheap是存储内存块的结构体, 每个malloc返回的内存对应一个pheap, size表示malloc(size)的size, used表示使用了多少.

pfree是内存释放链表, 每个结点里存储了一个pheap以及对应的pool_cleanup_t, 后者负责前者的释放.

pool_t就是内存池了, 其中heap表示当前可用的内存块, cleanup与cleanup_tail表示内存释放链表的头与尾, size表示内存池至今总共分配过的内存.

到底pool是怎么工作的呢?  简单的说, 按heap分配内存并压入pfree链表, 内存分配从heap中切割得到, 最终遍历pfree链表释放所有分配过的heap.

/**
 * pool_cleanup_t - callback type which is associated
 * with a pool entry; invoked when the pool entry is
 * free'd. pool_free() calls it once for every node on
 * the cleanup list, passing that node's stored arg.
 **/
typedef void (*pool_cleanup_t)(void *arg);

/**
 * pheap - singular allocation of memory
 *
 * pmalloc() hands out consecutive slices of block and advances
 * used; nothing is ever returned to the heap individually.
 **/
struct pheap
{
    void *block;        /* start of the malloc'd region */
    int size, used;     /* size: bytes in block; used: bytes already handed out */
};

/**
 * pfree - a linked list node which stores an
 * allocation chunk, plus a callback
 **/
struct pfree
{
    pool_cleanup_t f;       /* callback run by pool_free() for this node */
    void *arg;              /* argument handed to f */
    struct pheap *heap;     /* set by _pool_heap() to the heap this node releases */
    struct pfree *next;     /* next node on the pool's cleanup list */
};

/**
 * pool - base node for a pool. Maintains a linked list
 * of pool entries (pfree)
 **/
typedef struct pool_struct
{
    int size;
    struct pfree *cleanup;
    struct pfree *cleanup_tail;
    struct pheap *heap;
#ifdef POOL_DEBUG
    char name[8], zone[32];
    int lsize;
#endif
} _pool, *pool_t;

2, API实现分析:

先贴出所有暴露给用户的API:

JABBERD2_API pool_t _pool_new(char *file, int line); /* new, empty pool (no heap attached) */
JABBERD2_API pool_t _pool_new_heap(int size, char *file, int line); /* creates a new memory pool with an initial heap size */
JABBERD2_API void *pmalloc(pool_t, int size); /* wrapper around malloc, takes from the pool, cleaned up automatically */
JABBERD2_API void *pmalloc_x(pool_t p, int size, char c); /* wrapper around pmalloc which prefills the buffer with c */
JABBERD2_API void *pmalloco(pool_t p, int size); /* yet another pmalloc wrapper, this one zeroes the block */
JABBERD2_API char *pstrdup(pool_t p, const char *src); /* wrapper around strdup, gains mem from pool; NULL src gives NULL */
JABBERD2_API char *pstrdupx(pool_t p, const char *src, int len); /* as pstrdup but copies exactly len bytes; NULL if len <= 0 */
JABBERD2_API void pool_stat(int full); /* print to stderr the changed pools and reset */
JABBERD2_API void pool_cleanup(pool_t p, pool_cleanup_t fn, void *arg); /* calls fn(arg) before the pool is freed during cleanup */
JABBERD2_API void pool_free(pool_t p); /* calls the cleanup functions, frees all the data on the pool, and deletes the pool itself */
JABBERD2_API int pool_size(pool_t p); /* returns total bytes allocated in this pool */

 

1, _pool_new与_pool_new_heap: 这俩是一组的, 前者不预分配内存, 后者预分配内存, 但显然命名容易令人误入歧途.

前者创建pool_t结构体, pfree(回收内存链表), heap(切割内存的堆), size(总分配尺寸)设置为0, 然后就返回了.

后者调用前者创建pool_t结构体, 但紧随其后预分配了一个heap存储到pool_t->heap中, 显然有了内存池的效果, 但并不是说前者创建的pool就没法用, 后面继续看.

/**
 * make an empty pool
 *
 * Only the pool header is allocated; no heap is attached, so every
 * pmalloc() on the returned pool becomes its own raw allocation.
 * Retries once per second until the header can be allocated.
 *
 * @param zone  caller's source file name (recorded only under POOL_DEBUG)
 * @param line  caller's source line (recorded only under POOL_DEBUG)
 * @return the new pool, never NULL
 */
pool_t _pool_new(char *zone, int line)
{
    pool_t p;

    while((p = _pool__malloc(sizeof(_pool))) == NULL) sleep(1);
    p->cleanup = NULL;
    p->cleanup_tail = NULL; /* an empty list must not carry an indeterminate tail */
    p->heap = NULL;
    p->size = 0;

#ifdef POOL_DEBUG
    p->lsize = -1;
    p->zone[0] = '\0';
    snprintf(p->zone, sizeof(p->zone), "%s:%i", zone, line);
    /* name[] holds 8 bytes; "%X" of an int can need 8 digits plus the
     * terminator, so the write has to be bounded */
    snprintf(p->name, sizeof(p->name), "%X", (int)p);

    if(pool__disturbed == NULL)
    {
        pool__disturbed = (xht)1; /* reentrancy flag! */
        pool__disturbed = xhash_new(POOL_NUM);
    }
    if(pool__disturbed != (xht)1)
        xhash_put(pool__disturbed,p->name,p);
#endif

    return p;
}

/**
 * make a pool that starts out with one pre-allocated heap
 *
 * @param size  bytes to reserve in the initial heap
 * @param zone  caller's source file name, passed through to _pool_new()
 * @param line  caller's source line, passed through to _pool_new()
 * @return the new pool with its heap attached
 */
pool_t _pool_new_heap(int size, char *zone, int line)
{
    pool_t fresh = _pool_new(zone, line);

    fresh->heap = _pool_heap(fresh, size);

    return fresh;
}

既然用到了_pool_heap, 接着看它实现:先malloc一个pheap结构体, 再malloc一块预分配的内存到pheap->block, 然后调用_pool_free函数创建一个pfree结点, 并由_pool_cleanup_append把它追加到pool->cleanup链表中, 其中内存释放回调函数是_pool_heap_free, pheap作为_pool_free第三个参数传入, 并在将来释放时作为typedef void (*pool_cleanup_t)(void *arg);

的参数传入以便释放使用.

/**
 * create a heap of the given size and register it for cleanup
 *
 * Both the heap header and its backing block are obtained with
 * _pool__malloc(), retrying once per second until each succeeds.
 * The pool's running byte total grows by size.
 *
 * @param p     pool that will own the heap
 * @param size  bytes to reserve in the heap's block
 * @return the new heap, already appended to p's cleanup list
 */
static struct pheap *_pool_heap(pool_t p, int size)
{
    struct pheap *heap;
    struct pfree *node;

    /* header first */
    for(;;)
    {
        heap = _pool__malloc(sizeof(struct pheap));
        if(heap != NULL)
            break;
        sleep(1);
    }

    /* then the block it manages */
    for(;;)
    {
        heap->block = _pool__malloc(size);
        if(heap->block != NULL)
            break;
        sleep(1);
    }

    heap->size = size;
    p->size += size;
    heap->used = 0;

    /* register the heap so pool_free() releases it through _pool_heap_free */
    node = _pool_free(p, _pool_heap_free, (void *)heap);
    node->heap = heap; /* for future use in finding used mem for pstrdup */
    _pool_cleanup_append(p, node);

    return heap;
}

2, 接着是最重要的分配内存函数: 其余接口均通过此接口获取内存, 然后完成进一步操作. 首先, 如果这不是一个预分配内存的pool或者 尺寸>预分配的内存/2, 那么pool直接

malloc一块全新的内存返回, 当然要在pool的cleanup链表中记录下来内存地址以便后续释放(_pool_cleanup_append). 否则(pool已有heap且请求不算大), 当size>=4时首先将heap的已用偏移调整到8字节对齐, 否则将地址起始的内存解释为不小于4字节的变量时(包括int, long, 较大的结构体等), 将可能引起bus error(总线错误), 这在内存池的设计里是必须考虑的, 已经见怪不怪了.

注意这里是调整heap的起始地址到8倍数, 而不是调整请求的size到8倍数, 我习惯后者, 但效果是一样的. 最后, 判断当前pool->heap剩余内存是否足够, 如果不够则丢弃heap剩余部分, 直接_pool_heap分配一块新的heap, 与原先的那块heap一样长, 然后从heap里割一块内存返回.

/**
 * allocate size bytes that live as long as the pool
 *
 * Requests go straight to _pool__malloc() (and are tracked on the cleanup
 * list) when the pool has no heap or when size exceeds half the heap size.
 * Otherwise the bytes are sliced off the current heap; if the heap cannot
 * hold them it is replaced by a fresh heap of the same size.
 *
 * @param p     pool to allocate from; a NULL pool aborts the process
 * @param size  number of bytes wanted
 * @return pointer to the allocated bytes
 */
void *pmalloc(pool_t p, int size)
{
    struct pheap *heap;
    void *block;

    if(p == NULL)
    {
        fprintf(stderr,"Memory Leak! [pmalloc received NULL pool, unable to track allocation, exiting]\n");
        abort();
    }

    heap = p->heap;

    /* heap-less pool or large request: one raw allocation, tracked for cleanup */
    if(heap == NULL || size > (heap->size / 2))
    {
        do
        {
            block = _pool__malloc(size);
            if(block == NULL)
                sleep(1);
        } while(block == NULL);

        p->size += size;
        _pool_cleanup_append(p, _pool_free(p, _pool__free, block));
        return block;
    }

    /* requests of 4 bytes or more start on an 8-byte boundary within the heap */
    if(size >= 4)
        heap->used = (heap->used + 7) & ~7;

    /* not enough room left: abandon the remainder and switch to a new heap */
    if(size > (heap->size - heap->used))
    {
        heap = _pool_heap(p, heap->size);
        p->heap = heap;
    }

    /* slice the request off the heap */
    block = (char *)heap->block + heap->used;
    heap->used += size;
    return block;
}

3, 内存池的释放:

很简单, 遍历pool->cleanup链表, 回调每个pfree结点的回调函数, 传入分配的内存地址, 之后释放pfree结点自身的内存, 并在遍历完成后释放pool自身内存.

(多提一点, _pool_cleanup_append和pool_cleanup是一套函数, 都是向pool添加一个pfree, 前者是追加, 后者是向前插入)

注意, _pool__free, _pool__malloc相当于free, malloc, 定义如下:

/* The pool code never calls malloc/free directly; it goes through these two
 * aliases, which here expand to the C library allocator. */
#define _pool__malloc malloc
#define _pool__free free
/**
 * destroy a pool and everything allocated from it
 *
 * Walks the cleanup list from the head, invoking each node's callback with
 * its stored argument and then releasing the node, and finally releases the
 * pool header. A NULL pool is ignored.
 *
 * @param p  pool to destroy
 */
void pool_free(pool_t p)
{
    struct pfree *node, *following;

    if(p == NULL)
        return;

    for(node = p->cleanup; node != NULL; node = following)
    {
        (*node->f)(node->arg);
        following = node->next;
        _pool__free(node);
    }

#ifdef POOL_DEBUG
    /* drop the pool from the debug registry, unless the registry is absent or still being set up */
    if (pool__disturbed != NULL && pool__disturbed != (xht)1)
        xhash_zap(pool__disturbed,p->name);
#endif

    _pool__free(p);
}

4, 其他一些函数: 列举俩, 简单一看就可以了.

/**
 * duplicate a NUL-terminated string into pool memory
 *
 * XXX efficient: move this to const char * and then loop through the
 * existing heaps to see if src is within a block in this pool
 *
 * @param p    pool that will own the copy
 * @param src  string to copy
 * @return the copy, or NULL when src is NULL
 */
char *pstrdup(pool_t p, const char *src)
{
    size_t bytes;
    char *copy;

    if(src == NULL)
        return NULL;

    bytes = strlen(src) + 1; /* text plus terminator */
    copy = pmalloc(p, bytes);
    memcpy(copy, src, bytes);

    return copy;
}

/**
 * duplicate the first len bytes of src into pool memory and terminate them
 *
 * @param p    pool that will own the copy
 * @param src  source bytes
 * @param len  number of bytes to copy
 * @return the NUL-terminated copy, or NULL when src is NULL or len <= 0
 */
char *pstrdupx(pool_t p, const char *src, int len)
{
    char *copy;

    if(src == NULL)
        return NULL;
    if(len <= 0)
        return NULL;

    copy = pmalloc(p, len + 1);
    memcpy(copy, src, len);
    copy[len] = '\0';

    return copy;
}

 

util/base64是专用于base64编码与解码的接口, jabberd2基于apache apr做了二次封装.

/* base64 functions */
/* NOTE(review): b64_decode() as shown in this document passes an extra
 * trailing -1 to both decode functions, which does not match the decode
 * prototypes below - confirm which revision of the API is current. */
JABBERD2_API int apr_base64_decode_len(const char *bufcoded);
JABBERD2_API int apr_base64_decode(char *bufplain, const char *bufcoded);
JABBERD2_API int apr_base64_encode_len(int len);
JABBERD2_API int apr_base64_encode(char *encoded, const unsigned char *string, int len);

/* convenience, result string must be free()'d by caller */
JABBERD2_API char *b64_encode(char *buf, int len); /* len == 0 means "use strlen(buf)" */
JABBERD2_API char *b64_decode(char *buf);

apache apr的base64接口的确很推荐使用, 比直接openssl要简单可靠的多, apr_base64_decode_len和 apr_base64_encode_len 用于计算base64需要的outbuf的大小, 

jabberd2会根据返回值使用Malloc分配内存, 将结果存入其中返回, 需要由调用者负责释放内存, 看一下代码就懂了:

/* convenience functions for j2 */

/**
 * Base64-encode a buffer into freshly malloc()'d memory.
 *
 * @param buf  bytes to encode
 * @param len  number of bytes to encode; 0 means "use strlen(buf)"
 * @return buffer filled in by apr_base64_encode(), to be free()'d by the
 *         caller; NULL if buf is NULL or the buffer cannot be allocated
 */
char *b64_encode(char *buf, int len) {
    int elen;
    char *out;

    /* strlen(NULL) below would be undefined behaviour */
    if(buf == NULL)
        return NULL;

    if(len == 0)
        len = strlen(buf);

    elen = apr_base64_encode_len(len);
    out = (char *) malloc(sizeof(char) * (elen + 1));
    if(out == NULL)
        return NULL;

    /* the encoder is declared to take unsigned bytes */
    apr_base64_encode(out, (const unsigned char *) buf, len);

    return out;
}

/**
 * Base64-decode a string into freshly malloc()'d memory.
 *
 * NOTE(review): the -1 passed to the apr_base64_decode* calls presumably
 * means "buf is NUL-terminated"; the prototypes quoted earlier in this
 * document do not show that parameter - confirm against the real header.
 *
 * @param buf  base64 text to decode
 * @return buffer filled in by apr_base64_decode(), to be free()'d by the
 *         caller; NULL if buf is NULL or the buffer cannot be allocated
 */
char *b64_decode(char *buf) {
    int elen;
    char *out;

    if(buf == NULL)
        return NULL;

    elen = apr_base64_decode_len(buf, -1);
    out = (char *) malloc(sizeof(char) * (elen + 1));
    if(out == NULL)
        return NULL;

    apr_base64_decode(out, buf, -1);

    return out;
}

 

util/hex用于hex编码与解码, 其声明在util/util.h中, 定义于util/hex.c.

/* hex conversion utils */
/* hex_from_raw: write the hex text of inlen raw bytes to out; shahash_r() in
 * this document passes 20 bytes and a 41-byte buffer, so out presumably needs
 * 2*inlen+1 bytes - confirm against hex.c.
 * hex_to_raw: the reverse conversion; meaning of the int result is not shown here. */
JABBERD2_API void hex_from_raw(char *in, int inlen, char *out);
JABBERD2_API int hex_to_raw(char *in, int inlen, char *out);

 

util/md5用于md5计算, 是原生的openssl md5实现, 但作者通过宏替换的方式, 将代码变得更加jabberd2了, 可以从下面看出来:

作者用4个宏将原生的md5结构体与接口进行了替换, 显得更加优雅了, 自己挂上了JABBERD2_API... 其实他什么事情都没有做.

#include <openssl/md5.h>

/*
 * Aliases mapping the md5_* names onto OpenSSL's MD5 API. Note that
 * md5_finish swaps the argument order to match MD5_Final(md, ctx).
 * None of the function-like macros ends in a semicolon, so each expands to
 * a single expression and is safe in an unbraced if/else.
 */
#define md5_state_t MD5_CTX
#define md5_init(c) MD5_Init(c)
#define md5_append(c, data, len) MD5_Update(c, data, len)
#define md5_finish(c, md) MD5_Final(md, c)

/* NOTE(review): md5_state_t is also #defined to MD5_CTX just above in this
 * document; presumably the two are alternative builds separated by a
 * preprocessor conditional that is not shown - confirm against md5.h. */
typedef uint8_t md5_byte_t; /* 8-bit byte */
typedef uint32_t md5_word_t; /* 32-bit word */

/* Define the state of the MD5 Algorithm. */
typedef struct md5_state_s {
md5_word_t count[2]; /* message length in bits, lsw first */
md5_word_t abcd[4]; /* digest buffer */
md5_byte_t buf[64]; /* accumulate block */
} md5_state_t;

#ifdef __cplusplus


extern "C" 
{
#endif

/* Initialize the algorithm. pms is the caller-provided state to set up. */
JABBERD2_API void md5_init(md5_state_t *pms);

/* Append a string to the message: nbytes bytes starting at data. */
JABBERD2_API void md5_append(md5_state_t *pms, const md5_byte_t *data, int nbytes);

/* Finish the message and return the digest in the 16-byte digest array. */
JABBERD2_API void md5_finish(md5_state_t *pms, md5_byte_t digest[16]);

util/str是一系列操作字符串的函数: 与标准库的区别就是保证NULL形参安全, 大部分的实现都是间接调用了标准库函数, 其中j_strcat行为最特殊, 它与strcat功能不同, 它实际上是strcpy后返回字符串末尾的指针, 这是因为该函数用于spool的特殊用途, 下面将会讲到.

j_atoi的第二个参数是default的意思, 如果第一个参数为NULL的话, 则返回def.

/* --------------------------------------------------------- */
/*                                                           */
/* String management routines                                */
/*                                                           */
/** --------------------------------------------------------- */
JABBERD2_API char *j_strdup(const char *str); /* provides NULL safe strdup wrapper */
JABBERD2_API char *j_strcat(char *dest, char *txt); /* strcpy() clone; spool_print() chains on the returned pointer as the next write position */
JABBERD2_API int j_strcmp(const char *a, const char *b); /* provides NULL safe strcmp wrapper */
JABBERD2_API int j_strcasecmp(const char *a, const char *b); /* provides NULL safe strcasecmp wrapper */
JABBERD2_API int j_strncmp(const char *a, const char *b, int i); /* provides NULL safe strncmp wrapper */
JABBERD2_API int j_strncasecmp(const char *a, const char *b, int i); /* provides NULL safe strncasecmp wrapper */
JABBERD2_API int j_strlen(const char *a); /* provides NULL safe strlen wrapper */
JABBERD2_API int j_atoi(const char *a, int def); /* checks for NULL and uses default instead, convenience */
JABBERD2_API char *j_attr(const char** atts, const char *attr); /* decode attr's (from expat) */
JABBERD2_API char *j_strnchr(const char *s, int c, int n); /* like strchr, but only searches n chars */

在util/str中还实现了spool, 即string pool, 其实就是字符串链表:

spool_node是结点, spool_struct使用Pool分配结点与自身所需内存.

/* --------------------------------------------------------- */
/*                                                           */
/* String pools (spool) functions                            */
/*                                                           */
/* --------------------------------------------------------- */

/* one string fragment in a spool */
struct spool_node
{
    char *c;                    /* the fragment text */
    struct spool_node *next;    /* following fragment, or NULL at the end */
};

/* a singly linked list of fragments plus the pool that owns their memory */
typedef struct spool_struct
{
    pool_t p;                   /* pool every node and string copy is allocated from */
    int len;                    /* sum of strlen() over all fragments added so far */
    struct spool_node *last;    /* tail of the list, NULL while empty */
    struct spool_node *first;   /* head of the list, NULL while empty */
} *spool;

JABBERD2_API spool spool_new(pool_t p); /* create a string pool whose memory comes from p */
JABBERD2_API void spooler(spool s, ...); /* append all the char * args to the pool, terminate args with s again */
JABBERD2_API char *spool_print(spool s); /* return a big string (pool memory), or NULL if nothing was added */
JABBERD2_API void spool_add(spool s, char *str); /* add a single string to the pool; NULL and "" are ignored */
JABBERD2_API void spool_escape(spool s, char *raw, int len); /* add and xml escape a single string to the pool */
JABBERD2_API char *spools(pool_t p, ...); /* wrap all the spooler stuff in one function, the happy fun ball! terminate args with p */

关键看一下spool的创建:

/**
 * create an empty string pool
 *
 * @param p  pool that supplies the memory for the spool and everything added to it
 * @return the new spool, with no fragments and a length of 0
 */
spool spool_new(pool_t p)
{
    spool fresh = pmalloc(p, sizeof(struct spool_struct));

    fresh->first = NULL;
    fresh->last = NULL;
    fresh->len = 0;
    fresh->p = p;

    return fresh;
}

以及添加一个字符串到spool:

/**
 * link an already pool-owned string onto the end of the spool
 *
 * The string is referenced, not copied, and the spool's running length
 * grows by its strlen().
 *
 * @param s        spool to extend
 * @param goodstr  non-NULL string the spool may keep a pointer to
 */
static void _spool_add(spool s, char *goodstr)
{
    struct spool_node *node = pmalloc(s->p, sizeof(struct spool_node));

    node->c = goodstr;
    node->next = NULL;

    s->len += strlen(goodstr);

    /* hook behind the current tail, if there is one */
    if(s->last != NULL)
        s->last->next = node;
    s->last = node;

    /* the very first fragment also becomes the head */
    if(s->first == NULL)
        s->first = node;
}

/**
 * append a copy of str to the spool
 *
 * The copy is made with pstrdup() from the spool's own pool. NULL and
 * empty strings are silently skipped.
 *
 * @param s    spool to extend
 * @param str  string to copy in
 */
void spool_add(spool s, char *str)
{
    if(str == NULL)
        return;
    if(strlen(str) == 0)
        return;

    _spool_add(s, pstrdup(s->p, str));
}

可以看出, 整个spool是基于pool创建的, 包括spool_node以及pstrdup生成的字符串副本.

另外, spool的使用接口必须看一下:

第一个接口遍历所有不定参数, spool s本身是一个指向struct spool_struct的指针, ...则是一系列的char *字符串, 在while(1)中, 不断地获取下一个参数, 直到(spool)arg == s, 即spooler的使用方法一定是:spooler(s, str1, str2, str3, s)的形式, 循环终止于此, 对于每个str都被spool_add到spool中.

第二个接口依旧使用spool的pool开辟所有spool中字符串总长度+1的buffer, 然后把每个spool_node中的字符串j_strcat追加到buffer中, 最后返回这个buffer.

/**
 * append every char * argument to the spool
 *
 * The argument list is terminated by passing s itself again, e.g.
 * spooler(s, "a", "b", s). A NULL spool makes the call a no-op.
 *
 * @param s  spool to extend; also the end-of-arguments marker
 */
void spooler(spool s, ...)
{
    va_list args;
    char *piece = NULL;

    if(s == NULL)
        return;

    va_start(args, s);

    /* consume arguments until the spool pointer shows up again */
    for(piece = va_arg(args, char *); (spool)piece != s; piece = va_arg(args, char *))
        spool_add(s, piece);

    va_end(args);
}

/**
 * join all fragments of the spool into one string
 *
 * The result is allocated from the spool's pool with room for the summed
 * fragment lengths plus a terminator.
 *
 * @param s  spool to render
 * @return the joined string, or NULL when s is NULL or holds nothing
 */
char *spool_print(spool s)
{
    struct spool_node *node;
    char *joined, *cursor;

    if(s == NULL)
        return NULL;
    if(s->len == 0 || s->first == NULL)
        return NULL;

    joined = pmalloc(s->p, s->len + 1);
    *joined = '\0';

    /* j_strcat hands back the position to continue writing from */
    cursor = joined;
    for(node = s->first; node != NULL; node = node->next)
        cursor = j_strcat(cursor, node->c);

    return joined;
}

该接口综合上述几个接口, 一次性完成若干字符串的串联并返回结果:

/**
 * convenience :) - concatenate the char * arguments in one call
 *
 * Builds a temporary spool on p, adds each argument, and prints it. The
 * argument list is terminated by passing p itself again, e.g.
 * spools(p, "a", "b", p).
 *
 * @param p  pool that owns the result; also the end-of-arguments marker
 * @return the joined string, or NULL when p is NULL or nothing was added
 */
char *spools(pool_t p, ...)
{
    va_list args;
    spool joiner;
    char *piece = NULL;

    if(p == NULL)
        return NULL;

    joiner = spool_new(p);

    va_start(args, p);

    /* consume arguments until the pool pointer shows up again */
    for(piece = va_arg(args, char *); (pool_t)piece != p; piece = va_arg(args, char *))
        spool_add(joiner, piece);

    va_end(args);

    return spool_print(joiner);
}

另外, spool对XML的escape/unescape做了良好的支持, 实现如下:

同spool_add一样, spool_escape在_spool_add前进行了xml转义, 实现方法很朴素, 看代码即可.

/**
 * XML-escape len bytes of raw and append the result to the spool
 *
 * @param s    spool to extend
 * @param raw  text to escape; NULL is ignored
 * @param len  number of bytes of raw to use; values <= 0 are ignored
 */
void spool_escape(spool s, char *raw, int len)
{
    if(raw == NULL)
        return;
    if(len <= 0)
        return;

    _spool_add(s, strescape(s->p, raw, len));
}

/**
 * Replace the five predefined XML entities (&amp; &quot; &apos; &lt; &gt;)
 * in buf with the characters they stand for.
 *
 * An '&' that does not start one of those entities is copied through
 * unchanged, so every output byte is well defined.
 *
 * @param p    pool to allocate the result from, or NULL to use malloc()
 * @param buf  NUL-terminated text to unescape
 * @return NULL if buf is NULL or allocation fails; buf itself when it
 *         contains no '&'; otherwise a new string (free()'d by the caller
 *         when p was NULL)
 */
char *strunescape(pool_t p, char *buf)
{
    size_t len, i, j = 0;
    char *temp;

    if (buf == NULL) return(NULL);

    /* nothing to do: hand the input back untouched */
    if (strchr(buf,'&') == NULL) return(buf);

    /* the result is never longer than the input; measure once, not per iteration */
    len = strlen(buf);

    if(p != NULL)
        temp = pmalloc(p,(int)(len+1));
    else
        temp = malloc(len+1);

    if (temp == NULL) return(NULL);

    for(i=0;i<len;i++)
    {
        if (buf[i] != '&')
        {
            temp[j++] = buf[i];
            continue;
        }

        /* on a match, skip the rest of the entity; the loop's i++ steps past its ';' */
        if (strncmp(&buf[i],"&amp;",5)==0) {
            temp[j++] = '&';
            i += 4;
        } else if (strncmp(&buf[i],"&quot;",6)==0) {
            temp[j++] = '\"';
            i += 5;
        } else if (strncmp(&buf[i],"&apos;",6)==0) {
            temp[j++] = '\'';
            i += 5;
        } else if (strncmp(&buf[i],"&lt;",4)==0) {
            temp[j++] = '<';
            i += 3;
        } else if (strncmp(&buf[i],"&gt;",4)==0) {
            temp[j++] = '>';
            i += 3;
        } else {
            /* unknown entity or bare ampersand: keep it literally */
            temp[j++] = '&';
        }
    }
    temp[j]='\0';
    return(temp);
}

/**
 * XML-escape the first len bytes of buf: & ' " < > become
 * &amp; &apos; &quot; &lt; &gt;.
 *
 * @param p    pool to allocate the result from, or NULL to use malloc()
 * @param buf  text to escape (need not be NUL-terminated)
 * @param len  number of bytes of buf to process
 * @return a new NUL-terminated string (free()'d by the caller when p was
 *         NULL), or NULL if buf is NULL, len is negative, or allocation fails
 */
char *strescape(pool_t p, char *buf, int len)
{
    int i,j,newlen = len;
    char *temp;

    if (buf == NULL || len < 0) return NULL;

    /* first pass: each special character grows the text by the length of
     * its entity minus the one character it replaces */
    for(i=0;i<len;i++)
    {
        switch(buf[i])
        {
        case '&':           /* "&amp;"  */
            newlen+=4;
            break;
        case '\'':          /* "&apos;" */
        case '\"':          /* "&quot;" */
            newlen+=5;
            break;
        case '<':           /* "&lt;"   */
        case '>':           /* "&gt;"   */
            newlen+=3;
            break;
        default:
            break;
        }
    }

    if(p != NULL)
        temp = pmalloc(p,newlen+1);
    else
        temp = malloc(newlen+1);

    /* malloc() may fail; do not write through a NULL pointer */
    if(temp == NULL)
        return NULL;

    /* nothing needed escaping: plain copy */
    if(newlen == len)
    {
        memcpy(temp,buf,len);
        temp[len] = '\0';
        return temp;
    }

    /* second pass: copy, expanding specials as we go */
    for(i=j=0;i<len;i++)
    {
        switch(buf[i])
        {
        case '&':
            memcpy(&temp[j],"&amp;",5);
            j += 5;
            break;
        case '\'':
            memcpy(&temp[j],"&apos;",6);
            j += 6;
            break;
        case '\"':
            memcpy(&temp[j],"&quot;",6);
            j += 6;
            break;
        case '<':
            memcpy(&temp[j],"&lt;",4);
            j += 4;
            break;
        case '>':
            memcpy(&temp[j],"&gt;",4);
            j += 4;
            break;
        default:
            temp[j++] = buf[i];
        }
    }
    temp[j] = '\0';
    return temp;
}

在str.c中还有两个函数, 一个是求raw的sha1, 一个是求hex过的sha1, 后者调用前者, 依赖于openssl的SHA1, 代码如下:其中hex_from_raw在hex.c中提到过了.


/**
 * convenience (originally by Thomas Muldowney)
 *
 * Hash a NUL-terminated string with SHA-1 and write the digest as hex text.
 *
 * @param str      string to hash
 * @param hashbuf  receives the hex form of the 20-byte digest
 */
void shahash_r(const char* str, char hashbuf[41]) {
    unsigned char digest[20];

    shahash_raw(str, digest);
    hex_from_raw((char *) digest, (int) sizeof(digest), hashbuf);
}
/**
 * Compute the raw (binary) SHA-1 digest of a NUL-terminated string.
 *
 * Uses OpenSSL's SHA1() when HAVE_SSL is defined and the bundled
 * sha1_hash() otherwise.
 *
 * @param str      string to hash; its length is taken with strlen()
 * @param hashval  receives the 20 digest bytes
 */
void shahash_raw(const char* str, unsigned char hashval[20]) {
#ifdef HAVE_SSL
    /* use OpenSSL functions when available */
#   include <openssl/sha.h>
    SHA1((unsigned char *)str, strlen(str), hashval);
#else
    sha1_hash((unsigned char *)str, strlen(str), hashval);
#endif
}

 

util/jsignal是信号处理函数的注册接口: 很简单, 但因为jabberd2跨平台支持win32与linux, 所以显得代码量略多, 实际linux代码只有这么一丁点, 使用sigaction注册信号, 如果不是SIGALRM, 还注册了SA_RESTART的flag, 返回值是之前的信号处理函数.

/* Portable signal function */
/* jsighandler_t is the handler function type: takes the signal number, returns nothing. */
typedef void jsighandler_t(int);
/* Install func as the handler for signo. On non-Windows builds this returns the
 * previously installed handler, or SIG_ERR on failure; on _WIN32 it always returns NULL. */
JABBERD2_API jsighandler_t* jabber_signal(int signo,  jsighandler_t *func);

/**
 * Install a signal handler.
 *
 * On _WIN32 only SIGTERM is honoured: the handler is stored in
 * jabber_term_handler and NULL is always returned. Elsewhere the handler is
 * installed with sigaction() using an empty mask, and SA_RESTART (where the
 * platform defines it) for every signal except SIGALRM.
 *
 * @param signo  signal number
 * @param func   handler to install
 * @return the previous handler, SIG_ERR if sigaction() fails, or NULL on _WIN32
 */
jsighandler_t* jabber_signal(int signo, jsighandler_t *func)
{
#ifdef _WIN32
    if(signo == SIGTERM)
        jabber_term_handler = func;
    return NULL;
#else
    struct sigaction wanted, previous;
    int flags = 0;

#ifdef SA_RESTART
    if (signo != SIGALRM)
        flags |= SA_RESTART;
#endif

    wanted.sa_handler = func;
    sigemptyset(&wanted.sa_mask);
    wanted.sa_flags = flags;

    if (sigaction(signo, &wanted, &previous) < 0)
        return SIG_ERR;

    return previous.sa_handler;
#endif
}

 

util/sha1定义了计算sha1的接口: 如果没有openssl支持, 那么jabberd2实现了自己的sha1, 如果有openssl支持, 则采用了util/md5一样的方法用宏替换掉了openssl的函数名与结构名, 通常我们用最后一个接口, 上面在util/str中已经见过, 直接即可计算得到raw的sha1结果.

#ifdef HAVE_SSL
#include <openssl/sha.h>

/*
 * With OpenSSL available the sha1_* names are plain aliases for its SHA-1
 * API. Note that sha1_finish swaps the argument order to match
 * SHA1_Final(md, ctx). None of the function-like macros ends in a
 * semicolon, so each expands to a single expression and is safe in an
 * unbraced if/else.
 */
#define sha1_state_t SHA_CTX
#define sha1_init(c) SHA1_Init(c)
#define sha1_append(c, data, len) SHA1_Update(c, data, len)
#define sha1_finish(c, md) SHA1_Final(md, c)
#define sha1_hash(data, len, md) SHA1(data, len, md)

#else

#include <inttypes.h>

/* State of the bundled SHA-1 implementation used when OpenSSL is absent. */
typedef struct sha1_state_s {
  uint32_t H[5];
  uint32_t W[80];
  int lenW;
  uint32_t sizeHi,sizeLo;
} sha1_state_t;

/* Incremental interface: init, append any number of times, then finish. */
JABBERD2_API void sha1_init(sha1_state_t *ctx);
JABBERD2_API void sha1_append(sha1_state_t *ctx, const unsigned char *dataIn, int len);
JABBERD2_API void sha1_finish(sha1_state_t *ctx, unsigned char hashout[20]);
/* One-shot interface: hash len bytes of dataIn into the 20-byte hashout. */
JABBERD2_API void sha1_hash(const unsigned char *dataIn, int len, unsigned char hashout[20]);

#endif

 

util/jqueue实现了优先级队列: 其原理非常简单, 即插入时刻保证优先级有序即可, 数据结构采用了双向非循环链表, 而不是使用复杂的平衡树或者堆结构.

其中, _jqueue_node_t做为node, 其中存有data数据和priority优先级.

而jqueue_t中有pool内存池, front,back即链表的头尾结点, size表示结点个数, init_time表示jqueue的创建时间, 最重要的是cache: 结点的data被jqueue_pull取走后, 留下的_jqueue_node_t会被插入到cache链表中, 下次就不用从pool重新开辟了, 俗称"对象池".

操作接口只有Push和pull, 直接看一下实现即可.

/*
 * priority queues
 */

typedef struct _jqueue_node_st  *_jqueue_node_t;
struct _jqueue_node_st {
    void            *data;      /* caller's payload */

    int             priority;   /* ordering key given to jqueue_push() */

    _jqueue_node_t  next;       /* neighbour towards the front; also links nodes on the cache list */
    _jqueue_node_t  prev;       /* neighbour towards the back */
};

typedef struct _jqueue_st {
    pool_t          p;          /* pool owning the queue itself and all of its nodes */
    _jqueue_node_t  cache;      /* nodes released by jqueue_pull(), reused by jqueue_push() */

    _jqueue_node_t  front;      /* node jqueue_pull() returns next */
    _jqueue_node_t  back;       /* opposite end of the list */

    int             size;       /* number of queued items */
    char            *key;       /* not used by the functions shown in this document */
    time_t          init_time;  /* time(NULL) at creation, basis for jqueue_age() */
} *jqueue_t;

JABBERD2_API jqueue_t    jqueue_new(void);                              /* create an empty queue with its own pool */
JABBERD2_API void        jqueue_free(jqueue_t q);                       /* free the queue's pool, and the queue with it */
JABBERD2_API void        jqueue_push(jqueue_t q, void *data, int pri);  /* insert data ordered by pri */
JABBERD2_API void        *jqueue_pull(jqueue_t q);                      /* remove and return the front item, NULL if empty */
JABBERD2_API int         jqueue_size(jqueue_t q);                       /* number of queued items */
JABBERD2_API time_t      jqueue_age(jqueue_t q);                        /* seconds since the queue was created */

jqueue自己创建了专用的Pool用于开辟内存, 并且jqueue自身就是使用pool分配的. 可以从Jqueue的释放中了解到pool的设计多么有意思, 只要把pool自己释放了, jqueue就不见了, 根本没有内存泄漏的机会.

jqueue_t jqueue_new(void) {
    pool_t p;
    jqueue_t q;

    p = pool_new();
    q = (jqueue_t) pmalloco(p, sizeof(struct _jqueue_st));

    q->p = p;
    q->init_time = time(NULL);

    return q;
}

/**
 * Destroy a queue by freeing the pool it lives in.
 *
 * @param q  queue to destroy; must not be NULL (asserted)
 */
void jqueue_free(jqueue_t q) {
    pool_t owner;

    assert((int) (q != NULL));

    owner = q->p;
    pool_free(owner);
}

接下来是插入一个元素到优先级队列: 在jqueue_push中, 先检查cache是否有空闲的node, 有就取出来使用, 否则从Pool重新开辟. 如果队列当前非空, 那么需要根据优先级将Node插入到合适的位置, for循环就是做这事的. 在jqueue里, 优先级最大的在头部, 所以遍历queue时从back向front走. 说实话, 作者这个next和prev正好搞反了, 读起来很别扭.

/**
 * Insert an item into the queue, keeping the list ordered by priority.
 *
 * A node is taken from the cache when one is available, otherwise it is
 * allocated from the queue's pool. The list is scanned from the back along
 * the next links, skipping nodes whose priority value is greater than the
 * new one; the new node is linked in on the back side of the first node
 * that is not skipped, or at the front when every node is skipped.
 *
 * @param q         queue to insert into; must not be NULL (asserted)
 * @param data      payload to store
 * @param priority  ordering key
 */
void jqueue_push(jqueue_t q, void *data, int priority) {
    _jqueue_node_t node, cursor;

    assert((int) (q != NULL));

    q->size++;

    /* reuse a cached node when possible */
    node = q->cache;
    if(node == NULL)
        node = (_jqueue_node_t) pmalloc(q->p, sizeof(struct _jqueue_node_st));
    else
        q->cache = node->next;

    node->data = data;
    node->priority = priority;
    node->next = NULL;
    node->prev = NULL;

    /* empty queue: the node is both ends */
    if(q->back == NULL && q->front == NULL) {
        q->front = node;
        q->back = node;
        return;
    }

    /* walk from the back past every node with a greater priority value */
    cursor = q->back;
    while(cursor != NULL && cursor->priority > priority)
        cursor = cursor->next;

    /* walked off the front end: the node becomes the new front */
    if(cursor == NULL) {
        node->prev = q->front;
        q->front->next = node;
        q->front = node;
        return;
    }

    /* splice in between cursor and its back-side neighbour */
    node->next = cursor;
    node->prev = cursor->prev;

    if(cursor->prev == NULL)
        q->back = node;
    else
        cursor->prev->next = node;

    cursor->prev = node;
}

jqueue_pull和剩下的接口就很简单了: 取出q->front中的data, 如果还有剩余结点, 令下一个结点的next=NULL并成为新的front, 把取出的front插入到cache头部, 返回data.

/**
 * Remove the front item from the queue and return its payload.
 *
 * The emptied node is pushed onto the queue's cache list for reuse by a
 * later jqueue_push().
 *
 * @param q  queue to pull from; must not be NULL (asserted)
 * @return the front item's data, or NULL when the queue is empty
 */
void *jqueue_pull(jqueue_t q) {
    _jqueue_node_t head;
    void *payload;

    assert((int) (q != NULL));

    head = q->front;
    if(head == NULL)
        return NULL;

    payload = head->data;

    /* the neighbour behind the head becomes the new front */
    q->front = head->prev;
    if(q->front != NULL)
        q->front->next = NULL;
    else
        q->back = NULL;     /* that was the last node */

    /* park the node on the cache list */
    head->next = q->cache;
    q->cache = head;

    q->size--;

    return payload;
}

/**
 * @param q  queue to inspect
 * @return the number of items currently queued
 */
int jqueue_size(jqueue_t q) {
    int count = q->size;

    return count;
}

/**
 * @param q  queue to inspect
 * @return seconds elapsed between the queue's creation and now
 */
time_t jqueue_age(jqueue_t q) {
    time_t now = time(NULL);

    return now - q->init_time;
}

 

util/datetime定义了两个接口, 一个从字符串转time_t, 一个从time_t转字符串, 预定义了几种时间格式使用sscanf尝试解析, 按照我个人理解:

datetime_in传入的是各地的时间, 各不相同, 有若干种格式, 有的采用gmt+/-的格式表示各地的当地时间, 所以fix就是存储各地时间与gmt的差值, 函数最终希望生成的是gmt时间, 但mktime函数会将struct tm中的时间当作本地时间转换为time_t(即转换为time_t后加上了本地时间与gmt的差值), 所以mktime()需要先-(tz.tz_minuteswest * 60)以便保证mktime是向gmt转换, 之后再将各地的时间与gmt的差值+fix补回去, 这样就得到了标准GMT时间.

 

/* output formats understood by datetime_out() */
typedef enum {
    dt_DATE     = 1,    /* "YYYY-MM-DD" */
    dt_TIME     = 2,    /* "hh:mm:ssZ" */
    dt_DATETIME = 3,    /* "YYYY-MM-DDThh:mm:ssZ" */
    dt_LEGACY   = 4     /* "YYYYMMDDThh:mm:ss" */
} datetime_t;

/* parse a date/time string in one of the DT_* formats into a time_t */
JABBERD2_API time_t  datetime_in(char *date);
/* format t (as UTC) into date, writing at most datelen bytes, in the format selected by type */
JABBERD2_API void    datetime_out(time_t t, datetime_t type, char *date, int datelen);

 

/* formats */
/* sscanf() patterns tried by datetime_in(), in the order that function lists
 * them. _P carries a "+hh:mm" offset, _M a "-hh:mm" offset, _Z a literal 'Z';
 * seconds are read as a double (%lf) so fractional seconds parse. */
#define DT_DATETIME_P       "%04d-%02d-%02dT%02d:%02d:%lf+%02d:%02d"
#define DT_DATETIME_M       "%04d-%02d-%02dT%02d:%02d:%lf-%02d:%02d"
#define DT_DATETIME_Z       "%04d-%02d-%02dT%02d:%02d:%lfZ"
#define DT_TIME_P           "%02d:%02d:%lf+%02d:%02d"
#define DT_TIME_M           "%02d:%02d:%lf-%02d:%02d"
#define DT_TIME_Z           "%02d:%02d:%lfZ"
#define DT_LEGACY           "%04d%02d%02dT%02d:%02d:%lf"

/**
 * Parse a date/time string into a time_t.
 *
 * The DT_* patterns are tried one after another with sscanf(); the first
 * one that converts all of its fields wins. The broken-down time is then
 * passed through mktime() and adjusted by the parsed offset (fix) and by
 * the tz_minuteswest value reported by gettimeofday().
 *
 * No error is reported when nothing matches: the result is then computed
 * from whatever fields the failed attempts happened to fill in.
 *
 * @param date  string to parse; must not be NULL (asserted)
 * @return mktime(&gmt) + fix - tz.tz_minuteswest * 60
 */
time_t datetime_in(char *date) {
    struct tm gmt, off;     /* gmt: parsed date/time fields; off: only tm_hour/tm_min used, for the zone offset */
    double sec;             /* seconds as parsed, fraction dropped when stored in gmt.tm_sec */
    off_t fix = 0;          /* zone offset in seconds, signed according to the matched pattern */
    struct timeval tv;
    struct timezone tz;

    assert((int) (date != NULL));

    /* !!! sucks having to call this each time */
    tzset();

    memset(&gmt, 0, sizeof(struct tm));
    memset(&off, 0, sizeof(struct tm));

    /* full date + time with "+hh:mm" offset */
    if(sscanf(date, DT_DATETIME_P,
                   &gmt.tm_year, &gmt.tm_mon, &gmt.tm_mday,
                   &gmt.tm_hour, &gmt.tm_min, &sec,
                   &off.tm_hour, &off.tm_min) == 8) {
        gmt.tm_sec = (int) sec;
        gmt.tm_year -= 1900;    /* struct tm counts years from 1900 */
        gmt.tm_mon--;           /* and months from 0 */
        fix = off.tm_hour * 3600 + off.tm_min * 60;
    }

    /* full date + time with "-hh:mm" offset */
    else if(sscanf(date, DT_DATETIME_M,
                   &gmt.tm_year, &gmt.tm_mon, &gmt.tm_mday,
                   &gmt.tm_hour, &gmt.tm_min, &sec,
                   &off.tm_hour, &off.tm_min) == 8) {
        gmt.tm_sec = (int) sec;
        gmt.tm_year -= 1900;
        gmt.tm_mon--;
        fix = - off.tm_hour * 3600 - off.tm_min * 60;
    }
    /* full date + time, 'Z' suffix */
    else if(sscanf(date, DT_DATETIME_Z,
                   &gmt.tm_year, &gmt.tm_mon, &gmt.tm_mday,
                   &gmt.tm_hour, &gmt.tm_min, &sec) == 6) {
        gmt.tm_sec = (int) sec;
        gmt.tm_year -= 1900;
        gmt.tm_mon--;
        fix = 0;
    }

    /* time only with "+hh:mm" offset */
    else if(sscanf(date, DT_TIME_P,
                   &gmt.tm_hour, &gmt.tm_min, &sec,
                   &off.tm_hour, &off.tm_min) == 5) {
        gmt.tm_sec = (int) sec;
        fix = off.tm_hour * 3600 + off.tm_min * 60;
    }

    /* time only with "-hh:mm" offset */
    else if(sscanf(date, DT_TIME_M,
                   &gmt.tm_hour, &gmt.tm_min, &sec,
                   &off.tm_hour, &off.tm_min) == 5) {
        gmt.tm_sec = (int) sec;
        fix = - off.tm_hour * 3600 - off.tm_min * 60;
    }

    /* time only, 'Z' suffix; this pattern has no offset fields, so the
     * expression below reads off as left by memset and earlier attempts */
    else if(sscanf(date, DT_TIME_Z,
                   &gmt.tm_hour, &gmt.tm_min, &sec) == 3) {
        gmt.tm_sec = (int) sec;
        fix = - off.tm_hour * 3600 - off.tm_min * 60;
    }

    /* legacy "YYYYMMDDThh:mm:ss" */
    else if(sscanf(date, DT_LEGACY,
                   &gmt.tm_year, &gmt.tm_mon, &gmt.tm_mday,
                   &gmt.tm_hour, &gmt.tm_min, &sec) == 6) {
        gmt.tm_sec = (int) sec;
        gmt.tm_year -= 1900;
        gmt.tm_mon--;
        fix = 0;
    }

    /* let mktime() work out whether DST applies */
    gmt.tm_isdst = -1;

    /* only tz is used from this call */
    gettimeofday(&tv, &tz);

    return mktime(&gmt) + fix - (tz.tz_minuteswest * 60);
}

datetime_out函数, 就直接调用了gmtime获取了GMT时间, 之后格式化成了某一种格式, 这个函数比较简单, 没有那么复杂的时间概念.

/**
 * Format a time_t as UTC text.
 *
 * @param t        time to format
 * @param type     output format; must be non-zero (asserted). Values other
 *                 than the four dt_* constants leave date untouched.
 * @param date     destination buffer; must not be NULL (asserted)
 * @param datelen  size of date in bytes; must be non-zero (asserted)
 */
void datetime_out(time_t t, datetime_t type, char *date, int datelen) {
    struct tm *utc;
    int year, month;

    assert((int) type);
    assert((int) (date != NULL));
    assert((int) datelen);

    utc = gmtime(&t);

    /* struct tm stores years since 1900 and months from 0 */
    year = utc->tm_year + 1900;
    month = utc->tm_mon + 1;

    if(type == dt_DATE)
        snprintf(date, datelen, "%04d-%02d-%02d", year, month, utc->tm_mday);
    else if(type == dt_TIME)
        snprintf(date, datelen, "%02d:%02d:%02dZ", utc->tm_hour, utc->tm_min, utc->tm_sec);
    else if(type == dt_DATETIME)
        snprintf(date, datelen, "%04d-%02d-%02dT%02d:%02d:%02dZ", year, month, utc->tm_mday, utc->tm_hour, utc->tm_min, utc->tm_sec);
    else if(type == dt_LEGACY)
        snprintf(date, datelen, "%04d%02d%02dT%02d:%02d:%02d", year, month, utc->tm_mday, utc->tm_hour, utc->tm_min, utc->tm_sec);
}

 

util/rate用于频率限制, 在实现上并没有限制使用场景, 是非常通用的:

注释本身很清晰了, 即如果我们seconds秒内做了total次, 那么我们休止wait秒. 其中time和count用于记录从time时刻起操作了count次, bad记录了进入休止状态的时刻, 当休止超过wait秒后, 限制解除.

/*
 * rate limiting
 */

typedef struct rate_st
{
    int             total;      /* if we exceed this many events */
    int             seconds;    /* in this many seconds */
    int             wait;       /* then go bad for this many seconds */

    time_t          time;       /* time we started counting events */
    int             count;      /* event count */

    time_t          bad;        /* time we went bad, or 0 if we're not */
} *rate_t;

/** Allocate a limiter with the given total/seconds/wait settings and cleared counters. */
JABBERD2_API rate_t      rate_new(int total, int seconds, int wait);
/** Release a limiter obtained from rate_new(). */
JABBERD2_API void        rate_free(rate_t rt);
/** Clear time, count and bad; the configured limits are kept. */
JABBERD2_API void        rate_reset(rate_t rt);

/**
 * Add a number of events to the counter.  This takes care of moving
 * the sliding window, if we've moved outside the previous window.
 */
JABBERD2_API void        rate_add(rate_t rt, int count);

/**
 * @return The amount of events we have left before we hit the rate
 *         limit.  This could be number of bytes, or number of
 *         connection attempts, etc.
 */
JABBERD2_API int         rate_left(rate_t rt);

/**
 * @return 1 if we're under the rate limit and everything is fine or
 *         0 if the rate limit has been exceeded and we should throttle
 *         something.
 */
JABBERD2_API int         rate_check(rate_t rt);

实现很简单, 总体来说并不是准确的, 只是大致那么一算: 

rate_new 创建并配置了rate_t结构体.

rate_free 释放.

rate_reset 重置.

rate_left 如果当前已休止, 那么没有剩余计数, 否则返回total-count.

rate_add 如果距离开始计数已经过了seconds秒, 则先重置计数; 然后累加count, 并且只要count >= total, 就记录进入bad状态的时刻.

rate_check 如果还没计数过, 或者没超过限制, 立即返回. 如果处于bad阶段, 那么判断是否超过wait, 超过则恢复并重置, 没超过则返回0.

/**
 * Create a rate limiter.
 *
 * @param total    number of events that trips the limit
 * @param seconds  length of the counting window in seconds
 * @param wait     seconds to stay throttled once the limit is tripped
 * @return a limiter with cleared counters (release with rate_free()), or
 *         NULL if memory cannot be allocated
 */
rate_t rate_new(int total, int seconds, int wait)
{
    rate_t rt = (rate_t) calloc(1, sizeof(struct rate_st));

    /* calloc() may fail; do not write through a NULL pointer */
    if(rt == NULL)
        return NULL;

    rt->total = total;
    rt->seconds = seconds;
    rt->wait = wait;

    return rt;
}

/**
 * Release a rate limiter.
 *
 * @param rt  limiter obtained from rate_new(); handed straight to free()
 */
void rate_free(rate_t rt)
{
    free(rt);
}

/**
 * Put the limiter back into its "nothing counted yet" state.
 * The configured total/seconds/wait values are left alone.
 *
 * @param rt  limiter to reset
 */
void rate_reset(rate_t rt)
{
    rt->bad = 0;
    rt->count = 0;
    rt->time = 0;
}

/**
 * Record count further events.
 *
 * If the current window started at least rt->seconds ago the limiter is
 * reset first. The window start is stamped on the first event, and the
 * limiter is marked bad (with the current time) whenever the running count
 * has reached rt->total.
 *
 * @param rt     limiter to update
 * @param count  number of events to add
 */
void rate_add(rate_t rt, int count)
{
    const time_t now = time(NULL);
    int window_over = (now - rt->time >= rt->seconds);

    /* rate expired */
    if(window_over)
        rate_reset(rt);

    rt->count += count;

    /* first event of a window: remember when it began */
    if(rt->time == 0)
        rt->time = now;

    /* limit reached: note when we went bad */
    if(rt->count >= rt->total)
        rt->bad = now;
}
/**
 * @param rt  limiter to inspect
 * @return 0 while the limiter is marked bad, otherwise total minus the
 *         events counted so far
 */
int rate_left(rate_t rt)
{
    return (rt->bad != 0) ? 0 : rt->total - rt->count;
}

/**
 * Decide whether the caller may proceed.
 *
 * When the limiter is bad and at least rt->wait seconds have passed since
 * it went bad, it is reset and the caller may proceed again.
 *
 * @param rt  limiter to check
 * @return 1 if under the limit (or not tracking yet), 0 if throttled
 */
int rate_check(rate_t rt)
{
    /* not tracking, or still under the limit */
    if(rt->time == 0 || rt->count < rt->total)
        return 1;

    /* at the limit but never marked bad */
    if(rt->bad == 0)
        return 1;

    /* bad and still inside the penalty period: keep them waiting */
    if(time(NULL) - rt->bad < rt->wait)
        return 0;

    /* wait over, they're good again */
    rate_reset(rt);
    return 1;
}

 

util/serial是一个内存中的序列化操作, 被其他地方使用到, 接口如下:

分成两个系列, 一个是(反)序列化string的, 一个是(反)序列化int的.

序列化在这里是指: 把string/int 存储到一个buffer里, 或者从buffer里反序列化得到一个string/int.

/*
 * serialisation helper functions
 *
 * The _set functions append a value at offset *dest of *buf, growing *buf
 * (and updating *len, its capacity) as needed, then advance *dest.
 * The _get functions read a value from buf (len bytes long) at offset
 * *source; ser_string_get returns 1 when no terminator is found before the
 * end of the buffer. The other return values are not shown in this document.
 */

JABBERD2_API int         ser_string_get(char **dest, int *source, const char *buf, int len);
JABBERD2_API int         ser_int_get(int *dest, int *source, const char *buf, int len);
JABBERD2_API void        ser_string_set(char *source, int *dest, char **buf, int *len);
JABBERD2_API void        ser_int_set(int source, int *dest, char **buf, int *len);

buf是指针的地址, 因为用户的内存可能不够需要realloc, 所以要传char **buf. 

ser_string_set序列化char *source字符串到buf中, SER_SAFE检测buf的尺寸是否足够, 其中int *dest表示buf已经使用的尺寸, int *len表示buf总长度, 函数就修改int *dest和int *len(因为调用SER_SAFE在内存不足情况下会调用realloc).

ser_int_set原理类似, 使用了一个union转换int到字节数组.

/**
 * Append a NUL-terminated string to a serialisation buffer.
 *
 * @param source string to store (its terminating NUL is stored too)
 * @param dest   in/out: write offset into *buf; advanced past the string
 * @param buf    in/out: buffer, reallocated through SER_SAFE when too small
 * @param len    in/out: allocated size of *buf, updated by SER_SAFE
 */
void ser_string_set(char *source, int *dest, char **buf, int *len)
{
    /* characters plus the terminator */
    int nbytes = sizeof(char) * (strlen(source) + 1);
    int end = *dest + nbytes;

    /* grow the buffer when the string would not fit */
    SER_SAFE(*buf, end, *len);

    /* store the string, terminator included, and move the offset past it */
    memcpy(*buf + *dest, source, nbytes);
    *dest = end;
}

/**
 * Append the raw bytes of an int to a serialisation buffer.
 *
 * The value is stored in host byte order, byte by byte, so the write offset
 * does not have to be aligned.
 *
 * @param source value to store
 * @param dest   in/out: write offset into *buf; advanced by sizeof(int)
 * @param buf    in/out: buffer, reallocated through SER_SAFE when too small
 * @param len    in/out: allocated size of *buf, updated by SER_SAFE
 */
void ser_int_set(int source, int *dest, char **buf, int *len)
{
    char *slot;

    /* grow the buffer when the value would not fit */
    SER_SAFE(*buf, *dest + sizeof(int), *len);

    /* byte-wise copy of the object representation */
    slot = *buf + *dest;
    memcpy(slot, &source, sizeof(int));

    /* move the offset past the value */
    *dest += sizeof(int);
}

其中SER_SAFE宏与实现如下:_ser_realloc重新分配一块内存, 存入*oblocks中, 新分配的内存大小长度为len, 但_ser_realloc会将其尺寸对齐到BLOCKSIZE.

SER_SAFE宏则判断buf剩余内存是否足够, 不够则调用_ser_realloc.

/* block allocator borrowed from nad.c */

#define BLOCKSIZE 1024

/**
 * internal: resize *oblocks so it holds at least `len` bytes.
 *
 * The requested size is rounded up to a whole number of BLOCKSIZE chunks.
 * When realloc() fails the call sleeps for a second and tries again, so this
 * function only returns once the allocation has succeeded.
 *
 * @param oblocks in/out: pointer to the block to resize
 * @param len     number of bytes required
 * @return the new allocated size in bytes
 */
static int _ser_realloc(void **oblocks, int len)
{
    /* round up to standard block sizes */
    int rounded = ((len - 1) / BLOCKSIZE + 1) * BLOCKSIZE;
    void *grown;

    /* keep trying till we get it */
    for(;;)
    {
        grown = realloc(*oblocks, rounded);
        if(grown != NULL)
            break;
        sleep(1);
    }

    *oblocks = grown;
    return rounded;
}

/**
 * safety check used to make sure there's always enough mem: when `size`
 * exceeds `len`, `blocks` is reallocated and `len` receives the new size.
 * The expansion is a complete if-statement including its own semicolon.
 */
#define SER_SAFE(blocks, size, len) if((size) > len) len = _ser_realloc((void**)&(blocks),(size));

反序列化两个接口如下, 一样会修改int *source, 对于string会strdup一份字符串的副本返回到char **dest; 对于int则通过int *dest返回.

/**
 * Read a NUL-terminated string out of a serialisation buffer.
 *
 * @param dest   out: receives a heap-allocated copy of the string, which the
 *               caller must free(); left untouched on failure
 * @param source in/out: read offset into buf; advanced past the string and
 *               its terminator on success, left untouched on failure
 * @param buf    buffer to read from
 * @param len    number of valid bytes in buf
 * @return 0 on success; 1 when the offset lies outside the buffer, when no
 *         terminator is found before the end of the buffer, or when memory
 *         for the copy cannot be allocated
 */
int ser_string_get(char **dest, int *source, const char *buf, int len)
{
    const char *start, *nul;
    size_t avail, nbytes;
    char *copy;

    /* validate the offset as an integer before forming any pointer from it */
    if(buf == NULL || len <= 0 || *source < 0 || *source >= len)
        return 1;

    start = buf + *source;
    avail = (size_t) (len - *source);

    /* make sure we have a \0 before the end of the buffer */
    nul = memchr(start, '\0', avail);
    if(nul == NULL)
        /* we ran past the end, fail */
        return 1;

    /* copy the string, terminator included */
    nbytes = (size_t) (nul - start) + 1;
    copy = malloc(nbytes);
    if(copy == NULL)
        return 1;
    memcpy(copy, start, nbytes);
    *dest = copy;

    /* and move the pointer */
    *source += (int) nbytes;

    return 0;
}

/**
 * Read an int (raw bytes, host byte order) out of a serialisation buffer.
 *
 * @param dest   out: receives the value; left untouched on failure
 * @param source in/out: read offset into buf; advanced by sizeof(int) on
 *               success, left untouched on failure
 * @param buf    buffer to read from
 * @param len    number of valid bytes in buf
 * @return 0 on success, 1 when fewer than sizeof(int) bytes are available
 *         at the offset (including a negative offset)
 */
int ser_int_get(int *dest, int *source, const char *buf, int len)
{
    const int width = (int) sizeof(int);
    int value;

    /* we need sizeof(int) bytes; check with integers so that no
     * out-of-range pointer is ever computed */
    if(buf == NULL || *source < 0 || len < width || *source > len - width)
        return 1;

    /* copy the bytes rather than casting, to avoid alignment problems */
    memcpy(&value, buf + *source, sizeof(int));
    *source += width;
    *dest = value;

    return 0;
}

 

util/inaddr是一套自封装的网络地址操作库, 实现了常见的inet_pton, inet_ntop等功能.

/* parse a textual address into *dst; returns 1 on success, 0 on failure */
JABBERD2_API int         j_inet_pton(char *src, struct sockaddr_storage *dst);
/* format the address in *src into dst (size bytes); returns dst-style pointer or NULL */
JABBERD2_API const char  *j_inet_ntop(struct sockaddr_storage *src, char *dst, size_t size);
/* port of an AF_INET/AF_INET6 address in host byte order; 0 for other families */
JABBERD2_API int         j_inet_getport(struct sockaddr_storage *sa);
/* store port (host byte order) into an AF_INET/AF_INET6 address; returns 1, or 0 for other families */
JABBERD2_API int         j_inet_setport(struct sockaddr_storage *sa, in_port_t port);
/* size of the concrete sockaddr structure matching sa's address family */
JABBERD2_API socklen_t   j_inet_addrlen(struct sockaddr_storage *sa);

这些接口的实现其实需要特别关注一下, 因为是兼容IPV4与IPV6的, 其中的原理与技巧需要掌握.

j_inet_pton将src字符串表达的地址转换存储到struct sockaddr_storage中, 该结构体保证兼容IPV4与IPV6的尺寸. 如果不支持inet_pton, 那么只使用inet_aton尝试转换ipv4, 对于ipv6则以失败告终. 如果支持inet_pton, 那么分别尝试inet_pton的AF_INET和AF_INET6, 只要成功则返回, 这算是个小技巧.

/**
 * Parse a textual IP address into a sockaddr_storage.
 *
 * *dst is zeroed first. Without HAVE_INET_PTON only IPv4 text accepted by
 * inet_aton() is understood; with it, the text is tried as IPv4 and then as
 * IPv6 through inet_pton(). On success ss_family is set to the matching
 * family (and sin6_len for IPv6 where SIN6_LEN is defined).
 *
 * @param src textual address
 * @param dst out: receives the parsed address
 * @return 1 when the text was parsed, 0 otherwise
 */
int j_inet_pton(char *src, struct sockaddr_storage *dst)
{
#ifndef HAVE_INET_PTON
    struct sockaddr_in *v4 = (struct sockaddr_in *)dst;

    memset(dst, 0, sizeof(struct sockaddr_storage));

    /* legacy path: IPv4 only */
    if(!inet_aton(src, &v4->sin_addr))
        return 0;

    dst->ss_family = AF_INET;
    return 1;
#else
    struct sockaddr_in *v4 = (struct sockaddr_in *)dst;
    struct sockaddr_in6 *v6 = (struct sockaddr_in6 *)dst;

    memset(dst, 0, sizeof(struct sockaddr_storage));

    /* first guess: IPv4 */
    if(inet_pton(AF_INET, src, &v4->sin_addr) > 0)
    {
        dst->ss_family = AF_INET;
        return 1;
    }

    /* second guess: IPv6; anything else is a failure */
    if(inet_pton(AF_INET6, src, &v6->sin6_addr) <= 0)
        return 0;

    dst->ss_family = AF_INET6;
#ifdef SIN6_LEN
    v6->sin6_len = sizeof(struct sockaddr_in6);
#endif
    return 1;
#endif
}

j_inet_ntop是将地址转换成可读字符串, 基本和上面类似, 看一下即可:

/**
 * Format the address held in a sockaddr_storage as text.
 *
 * Without HAVE_INET_NTOP only AF_INET is handled, via inet_ntoa(); the call
 * fails when the text (plus terminator) does not fit in `size` bytes. With
 * HAVE_INET_NTOP, AF_INET and AF_INET6 are formatted by inet_ntop(), and
 * AF_UNSPEC is treated like AF_INET.
 *
 * @param src  address to format
 * @param dst  out: buffer receiving the text
 * @param size capacity of dst in bytes
 * @return dst on success in the fallback path, whatever inet_ntop() returns
 *         in the other path, or NULL for unsupported families / failure
 */
const char *j_inet_ntop(struct sockaddr_storage *src, char *dst, size_t size)
{
#ifndef HAVE_INET_NTOP
    struct sockaddr_in *v4 = (struct sockaddr_in *)src;
    char *text;

    /* if we don't have inet_ntop we only accept AF_INET
     * it's unlikely that we would have use for AF_INET6
     */
    if(src->ss_family != AF_INET)
        return NULL;

    text = inet_ntoa(v4->sin_addr);

    /* the text and its terminator must fit */
    if(text == NULL || strlen(text) >= size)
        return NULL;

    strncpy(dst, text, size);
    return dst;
#else
    struct sockaddr_in *v4 = (struct sockaddr_in *)src;
    struct sockaddr_in6 *v6 = (struct sockaddr_in6 *)src;

    /* an unspecified family is formatted as IPv4 */
    if(src->ss_family == AF_UNSPEC || src->ss_family == AF_INET)
        return inet_ntop(AF_INET, &v4->sin_addr, dst, size);

    if(src->ss_family == AF_INET6)
        return inet_ntop(AF_INET6, &v6->sin6_addr, dst, size);

    return NULL;
#endif
}

j_inet_getport就是获取地址中的port, 兼容ipv4/ipv6.

/**
 * Fetch the port stored in a socket address.
 *
 * @param sa address to inspect
 * @return the port in host byte order for AF_INET and AF_INET6 addresses,
 *         0 for any other address family
 */
int j_inet_getport(struct sockaddr_storage *sa)
{
    if(sa->ss_family == AF_INET)
        return ntohs(((struct sockaddr_in *)sa)->sin_port);

    if(sa->ss_family == AF_INET6)
        return ntohs(((struct sockaddr_in6 *)sa)->sin6_port);

    /* no port concept for anything else */
    return 0;
}

j_inet_setport类似:

/**
 * Store a port in a socket address.
 *
 * @param sa   address to modify; its ss_family selects the layout
 * @param port port in host byte order (stored in network byte order)
 * @return 1 when the port was stored (AF_INET or AF_INET6), 0 when the
 *         address family is not supported and nothing was changed
 */
int j_inet_setport(struct sockaddr_storage *sa, in_port_t port)
{
    if(sa->ss_family == AF_INET)
    {
        ((struct sockaddr_in *)sa)->sin_port = htons(port);
        return 1;
    }

    if(sa->ss_family == AF_INET6)
    {
        ((struct sockaddr_in6 *)sa)->sin6_port = htons(port);
        return 1;
    }

    return 0;
}

最后:

/**
 * Size of the concrete sockaddr structure behind a sockaddr_storage.
 *
 * Where SIN6_LEN is defined, a non-zero ss_len stored in the address takes
 * precedence. Otherwise the size follows from the address family, falling
 * back to the size of sockaddr_storage itself for unknown families.
 *
 * @param sa address to inspect
 * @return length in bytes suitable for passing alongside sa to socket calls
 */
socklen_t j_inet_addrlen(struct sockaddr_storage *sa)
{
#ifdef SIN6_LEN
    if(sa->ss_len != 0)
        return sa->ss_len;
#endif
    if(sa->ss_family == AF_INET)
        return sizeof(struct sockaddr_in);

    if(sa->ss_family == AF_INET6)
        return sizeof(struct sockaddr_in6);

    return sizeof(struct sockaddr_storage);
}

 

util/access是基于IP的黑白名单, 支持带掩码的IP地址, 即可以按照网段过滤, 说白了就和路由表一样, 对于一个IP, 先用IP&mask, 然后判断是否符合access_rule_t中的ip, 其实现技巧也体现了很多网络IPV4,IPV6方面的处理, 需要特别关注一下.

/*
 * IP-based access controls
 *
 * An access list holds two growable arrays of rules (allow and deny). Each
 * rule is an address plus a prefix length; an address matches a rule when
 * its leading `mask` bits equal those of the rule's address.
 */

/* one rule: network address and the number of leading bits that must match */
typedef struct access_rule_st
{
    struct sockaddr_storage ip; 
    int            mask;
} *access_rule_t;

/* an access list: evaluation order plus the allow and deny rule arrays */
typedef struct access_st
{
    int             order;      /* 0 = allow,deny  1 = deny,allow */

    access_rule_t   allow;      /* array of nallow rules */
    int             nallow;

    access_rule_t   deny;       /* array of ndeny rules */
    int             ndeny;
} *access_t;

/* create an empty access list with the given evaluation order */
JABBERD2_API access_t    access_new(int order);
/* release an access list and both rule arrays */
JABBERD2_API void        access_free(access_t access);
/* append a rule to the allow list; returns 0 on success, non-zero on failure */
JABBERD2_API int         access_allow(access_t access, char *ip, char *mask);
/* append a rule to the deny list; returns 0 on success, non-zero on failure */
JABBERD2_API int         access_deny(access_t access, char *ip, char *mask);
/* returns 1 when the textual address ip is permitted, 0 when it is refused or unparsable */
JABBERD2_API int         access_check(access_t access, char *ip);

首先, 创建与释放access结构体, 其中access->allow和access->deny是两个struct access_rule_st的数组.

access_t access_new(int order)
{
    access_t access = (access_t) calloc(1, sizeof(struct access_st));

    access->order = order;

    return access;
}

/**
 * Release an access list together with its allow and deny rule arrays.
 *
 * @param access list to destroy; must not be used afterwards
 */
void access_free(access_t access)
{
    /* free() accepts NULL, so empty rule arrays need no special casing */
    free(access->allow);
    free(access->deny);
    free(access);
}

下面的函数是计算一个IPV4掩码的长度, 什么意思呢?  掩码也是一个IP地址的格式, xxx.xxx.xxx.xxx, 是一个4字节整形, 该函数就是计算一下最右边的比特为1的位置是多少(假设是offset, 是距离IP地址最左位的比特数), 即掩码的位数. 那么对于一个IP地址, 就可以取出IP的左边offset个比特, 判断是否和access_rule_st->ip相同, 相同则命中.

代码中, 需要先将ipv4字符串转换成网络序4字节整形, 然后做ntohl转换成本地序, 之后即可开始循环的右移, 直到整形最低位为1, 那么netsize此时就是掩码的长度了.同时, 也支持直接传递mask的长度, 那么直接j_atoi就可以立即得到掩码长度了.

注意一个细节: 如果不支持inet_pton, 则先判断mask中是否包含'.', 并尝试用inet_aton转换; 转换成功说明mask是点分十进制的IPV4掩码, 否则认为mask传入的是掩码长度而非掩码地址. 如果支持inet_pton, 也只会尝试按AF_INET解析mask. 也就是说, 两种情况下都不支持以地址形式给出的IPV6掩码, 对IPV6只能传入掩码长度(前缀长度).

/**
 * Turn a netmask specification into a prefix length.
 *
 * `mask` is first tried as a dotted-decimal IPv4 netmask; its prefix length
 * is 32 minus the number of trailing zero bits. Anything that does not parse
 * that way is handed to j_atoi() together with `defaultsize`.
 *
 * @param mask        netmask text (dotted IPv4 mask or a number)
 * @param defaultsize value passed to j_atoi() as its default
 * @return the prefix length in bits
 */
static int _access_calc_netsize(const char *mask, int defaultsize)
{
    struct in_addr legacy_mask;
    int bits, netsize;

#ifndef HAVE_INET_PTON
    if(!(strchr(mask, '.') && inet_aton(mask, &legacy_mask)))
#else
    if(inet_pton(AF_INET, mask, &legacy_mask.s_addr) <= 0)
#endif
        /* numerical netsize */
        return j_atoi(mask, defaultsize);

    /* netmask has been given in dotted decimal form:
     * strip trailing zero bits, one per step */
    bits = ntohl(legacy_mask.s_addr);
    for(netsize = 32; netsize && bits%2==0; netsize--)
        bits /= 2;

    return netsize;
}

下面这个函数(_access_unmap_v4, 此处只贴出了函数体)更简单: 把ipv4映射的ipv6地址转换成ipv4地址. 因为ipv6地址的第12~15字节就是ipv4地址的大端4字节, 所以在本地依次乘以256(相当于每次左移8比特)拼出本地序的整型地址, 之后用htonl转成网络序存储到地址结构体中.

    memset(dst, 0, sizeof(struct sockaddr_in));
    dst->sin_family = AF_INET;
    dst->sin_addr.s_addr = htonl((((int)src->sin6_addr.s6_addr[12]*256+src->sin6_addr.s6_addr[13])*256+src->sin6_addr.s6_addr[14])*256+(int)src->sin6_addr.s6_addr[15])
;

再看一下怎么添加白名单与黑名单: 两者代码是冗余的, 看一下添加白名单即可. 先j_inet_pton校验地址是否有效, 之后计算掩码长度, 此处直接传入了defaultsize, 即全掩码, 表示IP地址必须完全匹配. 计算完掩码长度后, 给access->allow数组realloc新添加一个rule, 把ip与掩码拷贝进去.

/**
 * Append a rule to the allow list.
 *
 * @param access access list to extend
 * @param ip     textual network address of the rule
 * @param mask   netmask specification, interpreted by _access_calc_netsize()
 *               with a default of 32 (IPv4) or 128 (otherwise)
 * @return 0 when the rule was added; 1 when `ip` cannot be parsed or the
 *         rule array cannot be grown (the existing rules stay intact)
 */
int access_allow(access_t access, char *ip, char *mask)
{
    struct sockaddr_storage ip_addr;
    access_rule_t grown, rule;
    int netsize;

    if(j_inet_pton(ip, &ip_addr) <= 0)
        return 1;

    netsize = _access_calc_netsize(mask, ip_addr.ss_family==AF_INET ? 32 : 128);

    /* grow through a temporary so a failed realloc neither leaks the old
     * array nor leaves a NULL pointer behind */
    grown = realloc(access->allow, sizeof(struct access_rule_st) * (access->nallow + 1));
    if(grown == NULL)
        return 1;
    access->allow = grown;

    rule = &grown[access->nallow];
    memcpy(&rule->ip, &ip_addr, sizeof(ip_addr));
    rule->mask = netsize;

    access->nallow++;

    return 0;
}

/**
 * Append a rule to the deny list.
 *
 * @param access access list to extend
 * @param ip     textual network address of the rule
 * @param mask   netmask specification, interpreted by _access_calc_netsize()
 *               with a default of 32 (IPv4) or 128 (otherwise)
 * @return 0 when the rule was added; 1 when `ip` cannot be parsed or the
 *         rule array cannot be grown (the existing rules stay intact)
 */
int access_deny(access_t access, char *ip, char *mask)
{
    struct sockaddr_storage ip_addr;
    access_rule_t grown, rule;
    int netsize;

    if(j_inet_pton(ip, &ip_addr) <= 0)
        return 1;

    netsize = _access_calc_netsize(mask, ip_addr.ss_family==AF_INET ? 32 : 128);

    /* grow through a temporary so a failed realloc neither leaks the old
     * array nor leaves a NULL pointer behind */
    grown = realloc(access->deny, sizeof(struct access_rule_st) * (access->ndeny + 1));
    if(grown == NULL)
        return 1;
    access->deny = grown;

    rule = &grown[access->ndeny];
    memcpy(&rule->ip, &ip_addr, sizeof(ip_addr));
    rule->mask = netsize;

    access->ndeny++;

    return 0;
}

下面就是用户接口了, 黑白名单有作用顺序的区别, 先检测黑名单还是先检测白名单, 结果是不同的, 因为黑名单和白名单可能同时命中. 先遍历allow数组, 确定是否allow, 然后deny数组确定是否deny, 最后根据order决定黑白名单的先后顺序.

/**
 * Decide whether an address is permitted by an access list.
 *
 * The address is matched against the allow rules and the deny rules, and
 * the two results are combined according to access->order:
 *   order 0: an allow match wins, then a deny match refuses, default allow;
 *   otherwise: a deny match wins, then an allow match permits, default deny.
 *
 * @param access access list to consult
 * @param ip     textual address to test
 * @return 1 when permitted, 0 when refused or when `ip` cannot be parsed
 */
int access_check(access_t access, char *ip)
{
    struct sockaddr_storage addr;
    int idx, allowed = 0, denied = 0;

    if(j_inet_pton(ip, &addr) <= 0)
        return 0;

    /* first, search the allow list; one hit is enough */
    for(idx = 0; idx < access->nallow; idx++)
    {
        if(_access_check_match(&addr, &access->allow[idx].ip, access->allow[idx].mask))
        {
            allowed = 1;
            break;
        }
    }

    /* now the deny list */
    for(idx = 0; idx < access->ndeny; idx++)
    {
        if(_access_check_match(&addr, &access->deny[idx].ip, access->deny[idx].mask))
        {
            denied = 1;
            break;
        }
    }

    /* allow then deny, allow by default */
    if(access->order == 0)
        return (allowed || !denied) ? 1 : 0;

    /* deny then allow, deny by default */
    return (!denied && allowed) ? 1 : 0;
}

其中用到的_access_check_match判断两个IP在掩码作用下是否相同:IN6_IS_ADDR_V4MAPPED宏出现在util/util_compat.h中, 用于判断IPV6地址是否是IPV4映射的, 具体IPV6地址格式得单独看一下了.

函数很简单, 先判断两个地址是否协议相同, 不同则判断是否是IPV4映射的IPV6,会做一次转换, 之后递归调用函数再次校验.

对于IPV4地址, 会先制作出掩码, 然后转成网络序和网络序的地址进行运算判断是否相等.

对于IPV6地址, 稍微特殊一点, 会先比较前netsize / 8 字节, 如果掩码没有占全最后一个字节, 那么最后一个字节需要按位比较, 剩下的和IPV4的比较一样.

#ifndef IN6_IS_ADDR_V4MAPPED
/**
 * Fallback definition, used only when the system headers do not already
 * provide IN6_IS_ADDR_V4MAPPED.
 *
 * Checks if an IPv6 address is just a mapped IPv4 address: the first eight
 * bytes of s6_addr must be zero and bytes 8..11 must be 00 00 ff ff (the
 * in-memory form of 0x0000ffff in network byte order).
 *
 * `a` is a pointer to a struct in6_addr and is evaluated more than once.
 * NOTE(review): the 32-bit loads go through casts of s6_addr, so this
 * presumably relies on the address being suitably aligned -- confirm for
 * the target platforms.
 */
#define IN6_IS_ADDR_V4MAPPED(a) \
((*(const uint32_t *)(const void *)(&(a)->s6_addr[0]) == 0) && \
(*(const uint32_t *)(const void *)(&(a)->s6_addr[4]) == 0) && \
(*(const uint32_t *)(const void *)(&(a)->s6_addr[8]) == ntohl(0x0000ffff)))
#endif

/**
 * Check if two ip addresses are within the same subnet.
 *
 * When the address families differ, an IPv4-mapped IPv6 address on either
 * side is converted to plain IPv4 with _access_unmap_v4() (a prefix length
 * above 96 is reduced by 96 to refer to the IPv4 part) and the comparison is
 * retried; any other family mix never matches.
 *
 * The prefix length is clamped to the valid range of the family (0..32 for
 * IPv4, 0..128 for IPv6). A prefix length of 0 matches every address of the
 * same family.
 *
 * @param ip_1    first address
 * @param ip_2    second address
 * @param netsize number of leading bits that must be equal
 * @return 1 when the leading `netsize` bits of both addresses are equal,
 *         0 otherwise (including unknown address families)
 */
static int _access_check_match(struct sockaddr_storage *ip_1, struct sockaddr_storage *ip_2, int netsize)
{
    struct sockaddr_in *sin_1 = (struct sockaddr_in *)ip_1;
    struct sockaddr_in *sin_2 = (struct sockaddr_in *)ip_2;
    struct sockaddr_in6 *sin6_1 = (struct sockaddr_in6 *)ip_1;
    struct sockaddr_in6 *sin6_2 = (struct sockaddr_in6 *)ip_2;
    int i;

    /* addresses of different families */
    if(ip_1->ss_family != ip_2->ss_family)
    {
        struct sockaddr_storage t;
        struct sockaddr_in *temp = (struct sockaddr_in *)&t;

        /* maybe one of the addresses is just a IPv6 mapped IPv4 address */
        if (ip_1->ss_family == AF_INET && ip_2->ss_family == AF_INET6 && IN6_IS_ADDR_V4MAPPED(&sin6_2->sin6_addr))
        {
            _access_unmap_v4(sin6_2, temp);
            if(netsize>96)
                netsize -= 96;

            return _access_check_match(ip_1, &t, netsize);
        }

        if (ip_1->ss_family == AF_INET6 && ip_2->ss_family == AF_INET && IN6_IS_ADDR_V4MAPPED(&sin6_1->sin6_addr))
        {
            _access_unmap_v4(sin6_1, temp);
            if(netsize>96)
                netsize -= 96;

            return _access_check_match(&t, ip_2, netsize);
        }

        return 0;
    }

    /* a negative prefix length would turn into an out-of-range shift below */
    if(netsize < 0)
        netsize = 0;

    /* IPv4? */
    if(ip_1->ss_family == AF_INET)
    {
        in_addr_t netmask;

        if(netsize > 32)
            netsize = 32;

        /* build the mask with unsigned arithmetic; a 32-bit shift (prefix
         * length 0) is undefined, so the empty mask is handled separately */
        if(netsize == 0)
            netmask = 0;
        else
            netmask = htonl((in_addr_t) (0xffffffffUL << (32-netsize)));

        return ((sin_1->sin_addr.s_addr&netmask) == (sin_2->sin_addr.s_addr&netmask));
    }

    /* IPv6? */
    if(ip_1->ss_family == AF_INET6)
    {
        unsigned char bytemask;

        if(netsize > 128)
            netsize = 128;

        /* whole bytes covered by the prefix must be identical */
        for(i=0; i<netsize/8; i++)
            if(sin6_1->sin6_addr.s6_addr[i] != sin6_2->sin6_addr.s6_addr[i])
                return 0;

        if(netsize%8 == 0)
            return 1;

        /* compare only the leading bits of the partially covered byte */
        bytemask = 0xff << (8 - netsize%8);

        return ((sin6_1->sin6_addr.s6_addr[i]&bytemask) == (sin6_2->sin6_addr.s6_addr[i]&bytemask));
    }

    /* unknown address family */
    return 0;
}

 

util/log用于打印日志, 支持syslog, stdout, file三种方式, 但代码里有不少关于DEBUG编译的特殊逻辑, 我看着也挺绕, 就不说明了.

log.c依赖util.h中的如下声明:

/* logging */

/* where log lines are sent */
typedef enum {
    log_STDOUT,
    log_SYSLOG,
    log_FILE
} log_type_t;

/* a log handle: destination type plus the stream used for non-syslog output */
typedef struct log_st
{
    log_type_t  type;
    FILE        *file;
} *log_t;

/* maps a facility name to its numeric value */
typedef struct log_facility_st
{
    const char  *facility;
    int         number;
} log_facility_t;

/* open a log; ident is the syslog ident or, for log_FILE, the path to open */
JABBERD2_API log_t    log_new(log_type_t type, const char *ident, const char *facility);
/* write one printf-style formatted line at the given level */
JABBERD2_API void     log_write(log_t log, int level, const char *msgfmt, ...);
/* close the log destination and release the handle */
JABBERD2_API void     log_free(log_t log);

在log.c中可以看到这些日志级别以及与syslog facility相关的定义. log_new创建log: 对于syslog类型, 调用_log_facility把facility名称转换成对应的数值(查不到则使用LOG_LOCAL7), 再用openlog打开syslog; 如果是file类型则打开(或创建)文件, 打开失败会退回到stdout; stdout类型则直接赋值FILE即可.

/**
 * Look up a facility name (case-insensitively) in _log_facilities.
 *
 * @param facility name to look up, may be NULL
 * @return -1 when facility is NULL; otherwise the number of the matching
 *         entry, or the number stored in the table's terminating entry
 *         (the one whose facility is NULL) when nothing matches
 */
static int _log_facility(const char *facility) {
    log_facility_t *entry = _log_facilities;

    if (facility == NULL)
        return -1;

    /* walk until a name matches or the terminating entry is reached */
    while (entry->facility != NULL && strcasecmp(entry->facility, facility) != 0)
        entry++;

    return entry->number;
}

/**
 * Create a log handle.
 *
 * log_SYSLOG: opens syslog with `ident` and the facility named by
 *             `facility` (LOG_LOCAL7 when the lookup yields a negative
 *             number).
 * log_STDOUT: writes to stdout.
 * otherwise:  `ident` is opened as a file in append mode; when that fails a
 *             message is printed to stderr and the log falls back to stdout.
 *
 * @param type     destination type
 * @param ident    syslog ident, or path of the log file
 * @param facility syslog facility name (only used for log_SYSLOG)
 * @return the new handle, to be released with log_free(); NULL when memory
 *         cannot be allocated
 */
log_t log_new(log_type_t type, const char *ident, const char *facility)
{
    log_t log;
    int fnum = 0;

    log = (log_t) calloc(1, sizeof(struct log_st));
    /* out of memory: report it instead of writing through NULL */
    if(log == NULL)
        return NULL;

    log->type = type;

    if(type == log_SYSLOG) {
        fnum = _log_facility(facility);
        if (fnum < 0)
            fnum = LOG_LOCAL7;
        openlog(ident, LOG_PID, fnum);
        return log;
    }

    else if(type == log_STDOUT) {
        log->file = stdout;
        return log;
    }

    log->file = fopen(ident, "a+");
    if(log->file == NULL)
    {
        fprintf(stderr,
            "ERROR: couldn't open logfile: %s\n"
            "       logging will go to stdout instead\n", strerror(errno));
        /* record the fallback in the type so log_free() won't fclose(stdout) */
        log->type = log_STDOUT;
        log->file = stdout;
    }
    return log;
}

下面的函数打印日志, 对于syslog类型优先使用vsyslog, 但如果没有vsyslog接口则使用vsnprintf+syslog完成相同功能.对于非syslog日志, 则自己拼装格式:时间[日志级别], 此处日志级别采用了syslog的级别设定, 从_log_level可以看出. 另外注意, 因为操作的FILE, 所以每次fprintf后立即调用了fflush, 除此之外, log对每行日志做了1024长度的限制, 过长将会被截断.

顺带看一下log_free: 对于syslog类型调用closelog, 对于log_FILE类型则调用fclose关闭文件, 最后释放log结构体.

/**
 * Write one formatted line to a log.
 *
 * For a syslog log the message goes to vsyslog() (or vsnprintf() + syslog()
 * where HAVE_VSYSLOG is not defined) and, unless DEBUG is defined, nothing
 * else happens. For other logs the line "<ctime timestamp> [<level name>]
 * <message>" is assembled in a buffer of MAX_LOG_LINE characters (longer
 * lines are truncated), written to log->file and flushed. With DEBUG
 * defined the line is additionally copied to debug_log_target (stderr by
 * default) when log is NULL, or when the debug flag is set and the log is
 * not already stdout.
 *
 * @param log    log handle; may be NULL
 * @param level  syslog priority, also used as index into _log_level
 *               NOTE(review): level is not range-checked here -- callers
 *               presumably pass only the standard syslog levels; confirm.
 * @param msgfmt printf-style format, followed by its arguments
 */
void log_write(log_t log, int level, const char *msgfmt, ...)
{
    va_list ap;
    char *pos, message[MAX_LOG_LINE+1];
    int sz, len;
    time_t t;

    if(log && log->type == log_SYSLOG) {
        va_start(ap, msgfmt);
#ifdef HAVE_VSYSLOG
        vsyslog(level, msgfmt, ap);
#else
        len = vsnprintf(message, MAX_LOG_LINE, msgfmt, ap);
        /* a negative result signals a formatting error: log an empty line
         * rather than writing the terminator before the buffer */
        if (len < 0)
            len = 0;
        else if (len > MAX_LOG_LINE)
            len = MAX_LOG_LINE;
        message[len] = '\0';
        syslog(level, "%s", message);
#endif
        va_end(ap);

#ifndef DEBUG
        return;
#endif
    }

    /* timestamp */
    t = time(NULL);
    pos = ctime(&t);
    /* ctime() may fail; carry on without a timestamp in that case */
    if(pos == NULL)
        pos = "";
    sz = strlen(pos);
    /* chop off the \n */
    if(sz > 0)
        pos[sz-1]=' ';

    /* insert the header */
    len = snprintf(message, MAX_LOG_LINE, "%s[%s] ", pos, _log_level[level]);
    if (len < 0)
        len = 0;
    else if (len > MAX_LOG_LINE)
        len = MAX_LOG_LINE;
    message[len] = '\0';

    /* find the end and attach the rest of the msg */
    sz = strlen(message);
    pos = message + sz;
    va_start(ap, msgfmt);
    vsnprintf(pos, MAX_LOG_LINE - sz, msgfmt, ap);
    va_end(ap);
#ifndef DEBUG
    if(log && log->type != log_SYSLOG) {
#endif
        if(log && log->file) {
            fprintf(log->file,"%s", message);
            fprintf(log->file, "\n");
            fflush(log->file);
        }
#ifndef DEBUG
    }
#endif

#ifdef DEBUG
    if (!debug_log_target) {
        debug_log_target = stderr;
    }
    /* If we are in debug mode we want everything copied to the stdout */
    if ((log == 0) || (get_debug_flag() && log->type != log_STDOUT)) {
        fprintf(debug_log_target, "%s\n", message);
        fflush(debug_log_target);
    }
#endif /*DEBUG*/
}

/**
 * Close a log's destination and release the handle.
 *
 * syslog logs are closed with closelog(), file logs with fclose(); a stdout
 * log has nothing to close.
 *
 * @param log handle obtained from log_new(); must not be used afterwards
 */
void log_free(log_t log) {
    switch(log->type) {
    case log_SYSLOG:
        closelog();
        break;
    case log_FILE:
        fclose(log->file);
        break;
    default:
        /* stdout is not ours to close */
        break;
    }

    free(log);
}
posted @ 2012-11-19 17:43  xmpp?  阅读(831)  评论(0编辑  收藏  举报