/* authserv.c -- sample authentication server main program
 *
 * BSD+ License  <http://access1.sun.com/codesamples/BSD.html>
 *
 * Copyright (c) 2002-2004 Sun Microsystems, Inc. All Rights Reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * o  Redistribution of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * o  Redistribution in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in
 *    the documentation and/or other materials provided with the
 *    distribution.
 *
 * Neither the name of Sun Microsystems, Inc. or the names of
 * contributors may be used to endorse or promote products derived
 * from this software without specific prior written permission.
 *
 * This software is provided "AS IS," without a warranty of any
 * kind. ALL EXPRESS OR IMPLIED CONDITIONS, REPRESENTATIONS AND
 * WARRANTIES, INCLUDING ANY IMPLIED WARRANTY OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE OR NON-INFRINGEMENT, ARE HEREBY
 * EXCLUDED. SUN MICROSYSTEMS, INC. ("SUN") AND ITS LICENSORS SHALL NOT
 * BE LIABLE FOR ANY DAMAGES SUFFERED BY LICENSEE AS A RESULT OF USING,
 * MODIFYING OR DISTRIBUTING THIS SOFTWARE OR ITS DERIVATIVES. IN NO
 * EVENT WILL SUN OR ITS LICENSORS BE LIABLE FOR ANY LOST REVENUE,
 * PROFIT OR DATA, OR FOR DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL,
 * INCIDENTAL OR PUNITIVE DAMAGES, HOWEVER CAUSED AND REGARDLESS OF THE
 * THEORY OF LIABILITY, ARISING OUT OF THE USE OF OR INABILITY TO USE
 * THIS SOFTWARE, EVEN IF SUN HAS BEEN ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGES.
 *
 * You acknowledge that this software is not designed, licensed or
 * intended for use in the design, construction, operation or
 * maintenance of any nuclear facility.
 *
 *  Contributor(s): Chris Newman <chris.newman@sun.com>
 *
 ************************************************************************/

#if defined(__sun) && !defined(_REENTRANT)
#define _REENTRANT
#endif

#include <assert.h>
#include <pthread.h>		/* use pthreads */
#include <stdlib.h>		/* malloc/realloc */
#include <unistd.h>             /* sleep */
#include <stdio.h>              /* fprintf */
#include <stdarg.h>             /* vfprintf */
#include <string.h>             /* strerror */
#include <errno.h>              /* errno */
#include <sys/types.h>          /* sockaddr */
#include <sys/socket.h>         /* socket */
#include <sys/uio.h>		/* writev and struct iovec */
#include <netinet/in.h>         /* sockaddr_in */
#include <arpa/inet.h>          /* inet_addr */

#include "authserv.h"

/* this should stay as 127.0.0.1 until protocol is revised to add
 * security (likely using BEEP RFC 3080)
 */
#define BINDIP                      "127.0.0.1"

/* these are good candidates for configuration options
 */
#define BINDPORT                    56
#define LISTEN_QUEUE                1024
#define MAX_CONCURRENT_CONNECTIONS  256
#define MAX_PAYLOAD                 8192
#define MAX_ATTRS                   64
#define MAX_VALS                    128

/* debugging macro */
#if DEBUG
#define DEBUGLOG(arglist) log_write arglist
#else
#define DEBUGLOG(arglist)
#endif

/* error handling macro */
#define FD_TEMPORARY_ERR()	(errno == EINTR || errno == ENOMEM)

/* constant for a carrage-return/line-feed sequence */
static const char crlf[] = "\r\n";

/* server connection context
 */
struct conn {
    int fd;			/* file descriptor */
    unsigned datasize;		/* size of semi-persistant data buffer */
    unsigned replysize;		/* size of transient reply buffer */
    unsigned numattr;		/* number of attributes in reply */
    unsigned numval;		/* number of values in reply */
    unsigned replyok;		/* flag for successful reply sent */
    void *data;			/* semi-persistant data buffer */
    char *reply;		/* transient reply buffer */
    struct memobj *authcontinue; /* context for multi-round trip auth */
    struct conn *next;		/* next connection context */
    struct threadpool *pool;	/* thread pool context */
    char info[64];		/* connection logging information */
    char buf[2048];		/* IO buffer for connection */
    char sentinel;		/* sentinel for IO buffer */
};

/* thread pool */
struct threadpool {
    pthread_mutex_t lock, logm;
    pthread_cond_t cond;
    int threads_total;
    int threads_waiting;
    int exit_flag;
    struct conn *head, *tail, *unused;
    void *context;
};

/*
 * Table for decoding base64
 */
#define XX 127
static const char base64d[256] = {
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,62, XX,XX,XX,63,
    52,53,54,55, 56,57,58,59, 60,61,XX,XX, XX,XX,XX,XX,
    XX, 0, 1, 2,  3, 4, 5, 6,  7, 8, 9,10, 11,12,13,14,
    15,16,17,18, 19,20,21,22, 23,24,25,XX, XX,XX,XX,XX,
    XX,26,27,28, 29,30,31,32, 33,34,35,36, 37,38,39,40,
    41,42,43,44, 45,46,47,48, 49,50,51,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
    XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX, XX,XX,XX,XX,
};
#define CHAR64(c)  (base64d[(unsigned char)(c)])

static const char base64e[] =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
#define BASE64(c)  (base64e[(unsigned char)(c)])
#define B64_LEN(x) (((((x) + 2) / 3) * 4) + 1)

/* strict base64 decode
 */
static int b64_decode(const char *in, int inlen,
		      char *out, int outmax, unsigned *outlen)
{
    unsigned char in1, in2, in3, in4;
    unsigned char *iptr, *optr, *end;

    /* base64 decode the input */
    optr = (unsigned char *) out;
    iptr = (unsigned char *) in;
    if (inlen < 0) inlen = strlen(in);
    if (outmax < 0 && in == (const char *) out) {
	outmax = inlen;
    }
    end = iptr + inlen;
    while (iptr < end) {
	if (iptr + 1 >= end || (in1 = CHAR64(iptr[0])) == XX
	    || (in2 = CHAR64(iptr[1])) == XX) {
	    return (-1);
	}
	if (outmax < 2) return (-1);
	*optr++ = (in1 << 2) + (in2 >> 4);
	if (iptr + 2 >= end || (in3 = CHAR64(iptr[2])) == XX) break;
	if (outmax < 3) return (-1);
	*optr++ = (in2 << 4) | (in3 >> 2);
	if (iptr + 3 >= end || (in4 = CHAR64(iptr[3])) == XX) break;
	if (outmax < 4) return (-1);
	*optr++ = (in3 << 6) | in4;
	iptr += 4;
	outmax -= 3;
    }
    if (outmax < 1) return (-1);
    *optr = '\0';
    if (outlen != (unsigned *)0) *outlen = optr - (unsigned char *) out;

    return (0);
}

/* base64 encoder (no line folding)
 */
static int b64_encode(const char *in, int inlen,
		      char *out, unsigned outmax, unsigned *outlen)
{
    unsigned char *iptr, *optr, *end, *oend;
    unsigned olen;

    /* compute space needed */
    if (inlen < 0) inlen = strlen(in);
    olen = ((inlen + 2) / 3) * 4;

    /* report space needed on error */
    if (outlen != (unsigned *)0) *outlen = olen;
    if (out == (char *)0 || outmax == 0) return (-1);

    /* if input and output are the same, shift input so we have room */
    if (in == (const char *) out && inlen > 0) {
	if (outmax <= olen) return (-1);
	memmove(out + olen - inlen, out, inlen);
	in += olen - inlen;
    }

    /* encode full blocks */
    iptr = (unsigned char *) in;
    end = iptr + inlen;
    optr = (unsigned char *) out;
    oend = optr + outmax;
    while (iptr + 2 < end && optr + 4 < oend) {
	optr[0] = BASE64(iptr[0] >> 2);
	optr[1] = BASE64(((iptr[0] & 0x03) << 4) | (iptr[1] >> 4));
	optr[2] = BASE64(((iptr[1] & 0x0f) << 2) | (iptr[2] >> 6));
	optr[3] = BASE64(iptr[2] & 0x3f);
	optr += 4;
	iptr += 3;
    }

    /* do final partial block */
    if (iptr + 1 < end && optr + 4 < oend) {
	optr[0] = BASE64(iptr[0] >> 2);
	optr[1] = BASE64(((iptr[0] & 0x03) << 4) | (iptr[1] >> 4));
	optr[2] = BASE64((iptr[1] & 0x0f) << 2);
	optr[3] = '=';
	optr += 4;
	iptr += 2;
    } else if (iptr < end && optr + 4 < oend) {
	optr[0] = BASE64(iptr[0] >> 2);
	optr[1] = BASE64((iptr[0] & 0x03) << 4);
	optr[2] = '=';
	optr[3] = '=';
	optr += 4;
	iptr += 1;
    }
    *optr = '\0';
    if (outlen != (unsigned *)0) *outlen = optr - (unsigned char *) out;

    return (iptr < end ? -1 : 0);
}

/* log an error
 */
static void log_write(struct threadpool *pool, const char *fmt, ...)
{
    va_list pvar;
    time_t now;
    struct tm tm;

    time(&now);
    localtime_r(&now, &tm);
    va_start(pvar, fmt);
    if (pool != (struct threadpool *)0) pthread_mutex_lock(&pool->logm);
    fprintf(stderr, "%04d-%02d-%02d %02d:%02d:%02d ",
             tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday,
             tm.tm_hour, tm.tm_min, tm.tm_sec);
    vfprintf(stderr, fmt, pvar);
    if (pool != (struct threadpool *)0) pthread_mutex_unlock(&pool->logm);
    va_end(pvar);
}

/* wait for a thread pool event
 */
static void pool_wait(struct threadpool *pool)
{
    int err;
    
    err = pthread_cond_wait(&pool->cond, &pool->lock);
    if (err != 0) {
        log_write(pool, "FATAL: pthread_cond_wait: %s (%d)\n",
		  strerror(err), err);
        exit(1);
    }
}

/* signal a thread pool event
 */
static void pool_signal(struct threadpool *pool)
{
    int err;
    
    err = pthread_cond_signal(&pool->cond);
    if (err != 0) {
        log_write(pool, "FATAL: pthread_cond_signal: %s (%d)\n",
		  strerror(err), err);
        exit(1);
    }
}

/* lock the thread pool context
 */
static void pool_lock(struct threadpool *pool)
{
    int err;

    err = pthread_mutex_lock(&pool->lock);
    if (err != 0) {
        log_write(pool, "FATAL: pthread_mutex_lock: %s (%d)\n",
		  strerror(err), err);
        exit(1);
    }
}

/* unlock the thread pool context
 */
static void pool_unlock(struct threadpool *pool)
{
    int err;

    err = pthread_mutex_unlock(&pool->lock);
    if (err != 0) {
        log_write(pool, "FATAL: pthread_mutex_unlock: %s (%d)\n",
		  strerror(err), err);
        exit(1);
    }
}

/* retry a writev until all data is written or a fatal error occurs
 *  NOTE: this may not work on HPUX due to bugs in HPUX writev
 */
static long retry_writev(int fd, const struct iovec *iov, int iovcnt)
{
    long n;
    long written = 0;
    struct iovec *iovp;
    struct iovec iov_saved = {0, 0};
    
    while (iovcnt > 0) {
	if (!iov[0].iov_len) {
	    iov++;
	    iovcnt--;
	    continue;
	}
	n = writev(fd, (struct iovec *)iov, iovcnt);
	if (n == -1) {
	    if (FD_TEMPORARY_ERR())
		continue;
#ifdef __hpux
	    /* workaround HP bug, see Sun bugtraq 4560296,4576405 */
	    if (errno == EFAULT)
		continue;
#endif
	    if (iov_saved.iov_base) {
		((struct iovec *)iov)[0] = iov_saved;
	    }
	    return -1;
	}
#ifdef __hpux
	/* workaround HP bug, see Sun bugtraq 4550812 for details */
	if (n == 0) {
	    errno = EPIPE;
	    return -1;
	}
#endif	
	written += n;
	while (n) {
	    if (iov[0].iov_len > (unsigned) n) {
		if (!iov_saved.iov_base) {
		    iov_saved = iov[0];
		}
	    	iovp = (struct iovec *)iov;
		iovp->iov_base = (char *)iovp->iov_base + n;
		iovp->iov_len -= n;
		n = 0;
	    } else {
		n -= iov[0].iov_len;
		if (iov_saved.iov_base) {
		    ((struct iovec *)iov)[0] = iov_saved;
		    iov_saved.iov_base = 0;
		}
		iov++;
		iovcnt--;
	    }
	}
    }
    
    return (written);
}

/* callback to signal authentication failed
 */
static void auth_fail(const struct authdata *adat,
		      int errcode,
		      const char *errlang,
		      const char *errtext)
{
    static const char a_errlang[] = "errlang ";
    static const char a_errtext[] = "errtext ";
    static const char a_errcode[] = "errcode ";
    struct iovec reply[14];
    char head[256], code[10];
    unsigned nio = 1;
    unsigned total = 0;
    int len;
    struct conn *conn;

    if (adat == (const struct authdata *)0 || errcode >= SASL_OK ||
	(errtext != (char *)0 && strchr(errtext, '\r') != (char *)0)) {
	return;
    }
    conn = adat->priv;
    if (errlang != (char *)0) {
	reply[nio].iov_base = (void *) a_errlang;
	total += reply[nio].iov_len = sizeof (a_errlang) - 1;
	reply[nio+1].iov_base = (void *) errlang;
	total += reply[nio+1].iov_len = strlen(errlang);
	reply[nio+2].iov_base = (void *) crlf;
	total += reply[nio+2].iov_len = 2;
	nio += 3;
	++conn->numattr, ++conn->numval;
    }
    if (errtext != (char *)0) {
	reply[nio].iov_base = (void *) a_errtext;
	total += reply[nio].iov_len = sizeof (a_errtext) - 1;
	reply[nio+1].iov_base = (void *) errtext;
	total += reply[nio+1].iov_len = strlen(errtext);
	reply[nio+2].iov_base = (void *) crlf;
	total += reply[nio+2].iov_len = 2;
	nio += 3;
	++conn->numattr, ++conn->numval;
    }
    len = snprintf(code, sizeof (code), "%d", errcode);
    assert(len > 0 && (unsigned) len < sizeof (code));
    reply[nio].iov_base = (void *) a_errcode;
    total += reply[nio].iov_len = sizeof (a_errcode) - 1;
    reply[nio+1].iov_base = code;
    total += reply[nio+1].iov_len = len;
    reply[nio+2].iov_base = (void *) crlf;
    total += reply[nio+2].iov_len = 2;
    reply[nio+3].iov_base = (void *) crlf;
    total += reply[nio+3].iov_len = 2;
    nio += 4;
    ++conn->numattr, ++conn->numval;

    /* format the header */
    reply[0].iov_base = head;
    len = snprintf(head, sizeof (head), "%u %u %u\r\n",
		   total, conn->numattr, conn->numval);
    assert(len > 0 && (unsigned) len < sizeof (head));
    reply[0].iov_len = len;

    /* write the reply */
    len = retry_writev(conn->fd, reply, nio);
    if (len < 0) {
	log_write(conn->pool, "%s: error result from writev: %s (%d)\n",
		  conn->info, strerror(errno), errno);
    } else if (len == 0) {
	log_write(conn->pool, "%s: zero result from writev\n", conn->info);
    } else {
	conn->replyok = 1;
    }

    /* clean up storage */
    if (conn->authcontinue != (struct memobj *)0) {
	conn->authcontinue->destruct(conn->authcontinue);
	conn->authcontinue = (struct memobj *)0;
    }
}

/* callback to signal authentication successful
 */
static void auth_success(const struct authdata *adat,
			 const struct replaydata *rdat)
{
    static const char replayauth[] = "replayauth ";
    static const char replayuser[] = "replayuser ";
    static const char replaypass[] = "replaypass ";
    static const char orig_user[] = "username ";
    static const char orig_authz[] = "authname ";
    static const char sasldata[] = "sasldata ";
    static const char errcode0[] = "errcode 0\r\n\r\n";
    static const char errcode1[] = "errcode 1\r\n\r\n";
    const char *errcode = errcode0;
    struct iovec reply[23];
    char head[256];
    unsigned nio = 1;
    unsigned total = 0;
    unsigned sasllen;
    int len;
    struct conn *conn;
    char *sasl64 = (char *)0;

    if (adat == (const struct authdata *)0) return;
    conn = adat->priv;
    if (rdat != (const struct replydata *)0) {
	if (rdat->authname != (char *)0) {
	    reply[nio].iov_base = (void *) replayauth;
	    total += reply[nio].iov_len = sizeof (replayauth) - 1;
	    reply[nio+1].iov_base = rdat->authname;
	    total += reply[nio+1].iov_len = strlen(rdat->authname);
	    reply[nio+2].iov_base = (void *) crlf;
	    total += reply[nio+2].iov_len = 2;
	    nio += 3;
	    ++conn->numattr, ++conn->numval;
	}
	if (rdat->username != (char *)0) {
	    reply[nio].iov_base = (void *) replayuser;
	    total += reply[nio].iov_len = sizeof (replayuser) - 1;
	    reply[nio+1].iov_base = rdat->username;
	    total += reply[nio+1].iov_len = strlen(rdat->username);
	    reply[nio+2].iov_base = (void *) crlf;
	    total += reply[nio+2].iov_len = 2;
	    nio += 3;
	    ++conn->numattr, ++conn->numval;
	}
	if (rdat->password != (char *)0) {
	    reply[nio].iov_base = (void *) replaypass;
	    total += reply[nio].iov_len = sizeof (replaypass) - 1;
	    reply[nio+1].iov_base = rdat->password;
	    total += reply[nio+1].iov_len = strlen(rdat->password);
	    reply[nio+2].iov_base = (void *) crlf;
	    total += reply[nio+2].iov_len = 2;
	    nio += 3;
	    ++conn->numattr, ++conn->numval;
	}

	/* reply data version 2 changes: */
	if (rdat->authcontinue != (struct memobj *)0) {
	    errcode = errcode1;
	    conn->authcontinue = rdat->authcontinue;
	}
	if (rdat->orig_user != (char *)0) {
	    reply[nio].iov_base = (void *) orig_user;
	    total += reply[nio].iov_len = sizeof (orig_user) - 1;
	    reply[nio+1].iov_base = rdat->orig_user;
	    total += reply[nio+1].iov_len = strlen(rdat->orig_user);
	    reply[nio+2].iov_base = (void *) crlf;
	    total += reply[nio+2].iov_len = 2;
	    nio += 3;
	    ++conn->numattr, ++conn->numval;
	}
	if (rdat->orig_authz != (char *)0) {
	    reply[nio].iov_base = (void *) orig_authz;
	    total += reply[nio].iov_len = sizeof (orig_authz) - 1;
	    reply[nio+1].iov_base = rdat->orig_authz;
	    total += reply[nio+1].iov_len = strlen(rdat->orig_authz);
	    reply[nio+2].iov_base = (void *) crlf;
	    total += reply[nio+2].iov_len = 2;
	    nio += 3;
	    ++conn->numattr, ++conn->numval;
	}
	if (rdat->sasldata != (char *)0) {
	    len = B64_LEN(rdat->sasllen);
	    sasl64 = (char *) malloc(len);
	    if (sasl64 == (char *)0) {
		log_write(conn->pool, "%s: malloc %d failed\n",
			  conn->info, len);
		return;
	    }
	    assert(b64_encode(rdat->sasldata, rdat->sasllen,
			      sasl64, len, &sasllen) == 0);
	    reply[nio].iov_base = (void *) sasldata;
	    total += reply[nio].iov_len = sizeof (sasldata) - 1;
	    reply[nio+1].iov_base = sasl64;
	    total += reply[nio+1].iov_len = sasllen;
	    reply[nio+2].iov_base = (void *) crlf;
	    total += reply[nio+2].iov_len = 2;
	    nio += 3;
	    ++conn->numattr, ++conn->numval;
	}
    }
    reply[nio].iov_base = (void *) errcode;
    total += reply[nio].iov_len = strlen(errcode);
    ++conn->numattr, ++conn->numval;
    ++nio;

    /* LDAP attributes, if any */
    if (conn->reply != (char *)0) {
	reply[nio].iov_base = conn->reply;
	total += reply[nio].iov_len = conn->replysize;
	++nio;
    }

    /* format the header */
    reply[0].iov_base = head;
    len = snprintf(head, sizeof (head), "%u %u %u\r\n",
		   total, conn->numattr, conn->numval);
    assert(len > 0 && (unsigned) len < sizeof (head));
    reply[0].iov_len = len;

    /* write the reply */
    len = retry_writev(conn->fd, reply, nio);
    if (len < 0) {
	log_write(conn->pool, "%s: error result from writev: %s (%d)\n",
		  conn->info, strerror(errno), errno);
    } else if (len == 0) {
	log_write(conn->pool, "%s: zero result from writev\n", conn->info);
    } else {
	conn->replyok = 1;
    }

    /* clean up allocated memory */
    if (sasl64 != (char *)0) free(sasl64);
    if (conn->authcontinue != (struct memobj *)0
	&& rdat != (const struct replydata *)0
	&& rdat->authcontinue == (struct memobj *)0) {
	conn->authcontinue->destruct(conn->authcontinue);
	conn->authcontinue = (struct memobj *)0;
    }
}

/* callback to set an LDAP-style attribute
 *  this actually copies and formats the data for the reply protocol
 */
static int set_attr(const struct authdata *adat,
		    const char *attrname,
		    const char **values,
		    unsigned numvals)
{
    unsigned j, total, len;
    struct conn *conn;
    char *reply;

    /* validate parameters */
    if (adat == (const struct authdata *)0 ||
	attrname == (const char *)0 ||
	(numvals > 0 && values == (const char **)0)) {
	return (-1);
    }
    conn = (struct conn *) adat->priv;

    /* determine space needed for protocol */
    len = strlen(attrname);
    total = len + 2;
    for (j = 0; j < numvals; ++j) {
	total += strlen(values[j]) + 2;
    }
    if (j == 0) total += 2;

    /* allocate space for protocol */
    reply = realloc(conn->reply, conn->replysize + total);
    if (reply == (char *)0) return (-1);
    conn->reply = reply;
    reply += conn->replysize;

    /* copy attribute/value data into reply buffer */
    memcpy(reply, attrname, len);
    reply[len] = ' ';
    reply += len + 1;
    for (j = 0; j < numvals; ++j) {
	len = strlen(values[j]);
	memcpy(reply, values[j], len);
	reply[len] = '\r';
	reply[len+1] = '\n';
	reply += len + 2;
    }
    if (j == 0) {
	reply[0] = '\r';
	reply[1] = '\n';
	reply += 2;
    }
    *reply = '\0';
    conn->replysize = reply - conn->reply;
    ++conn->numattr;
    conn->numval += numvals;

    return (0);
}

/* callback to get an LDAP-style attribute
 */
static const char **get_attr(const struct authdata *adat,
			     const char *attrname,
			     unsigned *numvals)
{
    struct conn *conn;
    char **dptr, **dbase;

    /* validate parameters */
    if (adat == (const struct authdata *)0 || attrname == (const char *)0) {
	return ((const char **)0);
    }

    /* loop through LDAP attributes looking for a match */
    conn = (struct conn *) adat->priv;
    dptr = conn->data;
    while (*dptr != (char *)0) {
	dbase = dptr++;
	while (*dptr != (char *)0) ++dptr;
	if (strcasecmp(attrname, *dbase++) == 0) {
	    if (numvals != (unsigned *)0) {
		*numvals = dptr - dbase;
	    } else if (dptr - dbase > 1) {
		break;
	    }
	    return ((const char **) dbase);
	}
	++dptr;
    }

    return ((const char **)0);
}

/* parse an integer from a buffer
 *  returns -1 on error, 0 on success
 */
static int parseint(char **bptr, const char *end, int *val)
{
    char *scan = *bptr;
    
    *val = 0;
    while (scan < end && *scan >= '0' && *scan <= '9') {
        *val = *val * 10 + (*scan - '0');
        ++scan;
    }
    if (scan == *bptr) return (-1);
    if (scan < end && *scan++ != ' ') return (-1);
    *bptr = scan;

    return (0);
}

/* parse an IP address string and port pair into a sockaddr_in
 *  NOTE: needs updating for IPv6
 */
static void parse_addr(char *str, struct sockaddr_in *addr)
{
    char *split;

    split = strchr(str, ' ');
    if (split != (char *)0) *split = '\0';
    addr->sin_addr.s_addr = inet_addr(str);
    if (split != (char *)0) {
	addr->sin_port = htons(atoi(split + 1));
	*split = ' ';
    }
}

/* handle payload of a request
 *  returns -1 on serious parse error
 */
static int handle_payload(struct conn *conn, char *dat, int dsize,
			  int numattr, int numval)
{
    char **dptr, **dbase;
    char *attr, *start, *scan, *datend, *dupcheck, *seclevel;
    int valnum, attnum = 1;
    int ldap_attr = 0, permit_multi;
    struct authdata adat;
    struct sockaddr_in laddr, raddr;
    
    /* initialize authentication data */
    memset(&adat, 0, sizeof (adat));
    memset(&laddr, 0, sizeof (laddr));
    memset(&raddr, 0, sizeof (raddr));
    adat.raddr = (struct sockaddr *) &raddr;
    adat.laddr = (struct sockaddr *) &laddr;

    /* set up pointer to handle attribute/value lists */
    dptr = dbase = conn->data;
    datend = dat + dsize;
    scan = dat;

    while (scan < datend) {
        /* validate attribute count */
        if (numattr-- == 0) {
            log_write(conn->pool, "%s: too many attributes found: %d\n",
		      conn->info, attnum);
            return (-1);
        }
        
        /* parse attribute */
        attr = scan;
        while (scan < datend && *scan >= 0x21 && *scan <= 0x7E) {
            ++scan;
        }
        if (scan == datend || *scan != ' ' || attr == scan) {
            log_write(conn->pool, "%s: Parse error at attribute %d\n",
		      conn->info, attnum);
            return (-1);
        }
        *scan = '\0';
        ++scan;
        *dptr++ = attr;

        /* loop over values */
        valnum = 1;
        do {
            /* validate value count */
            if (numval-- == 0) {
                log_write(conn->pool,
		  "%s: too many values found at attribute %d value %d\n",
			  conn->info, attnum, valnum);
                return (-1);
            }

            /* parse value */
            start = scan;
            while (scan < datend && *scan != '\0'
                   && *scan != '\r' && *scan != '\n') {
                ++scan;
            }
            if (scan == datend || *scan != '\r'
                || scan + 1 == datend || scan[1] != '\n') {
                log_write(conn->pool,
			  "%s: Parse error at attribute %d value %d\n",
			  conn->info, attnum, valnum);
                return (-1);
            }
            *scan = '\0';
            scan += 2;
            *dptr++ = start;
            ++valnum;
        } while (scan < datend && *scan == ' ' && ++scan < datend);
        ++attnum;
        *dptr++ = (char *)0;

        /* check for defined attributes */
        if (ldap_attr == 0) {
            permit_multi = 1;
            dupcheck = (char *)0;
            switch (*attr) {
              case 'a':
                if (strcmp(attr, "authname") != 0) break;
                dupcheck = adat.authname;
                adat.authname = dbase[1];
                permit_multi = 0;
                break;

              case 'l':
                if (strcmp(attr, "localaddr") == 0) {
		    dupcheck = adat.localaddr;
		    adat.localaddr = dbase[1];
		    permit_multi = 0;
		    parse_addr(adat.localaddr, &laddr);
                } else if (strcmp(attr, "lang") == 0) {
                    dupcheck = adat.lang;
                    adat.lang = dbase[1];
                    permit_multi = 0;
                }
                break;

              case 'p':
                if (strcmp(attr, "password") != 0) break;
                dupcheck = adat.password;
                adat.password = dbase[1];
                permit_multi = 0;
                break;

              case 'r':
                if (strcmp(attr, "remoteaddr") != 0) break;
		dupcheck = adat.remoteaddr;
		adat.remoteaddr = dbase[1];
		permit_multi = 0;
		parse_addr(adat.remoteaddr, &raddr);
                break;
                
              case 's':
                if (strcmp(attr, "saslmech") == 0) {
                    dupcheck = adat.saslmech;
                    adat.saslmech = dbase[1];
                    permit_multi = 0;
		} else if (strcmp(attr, "sasldata") == 0) {
		    dupcheck = adat.sasldata;
		    adat.sasldata = dbase[1];
		    permit_multi = 0;
                } else if (strcmp(attr, "seclevel") == 0) {
                    dupcheck = seclevel;
                    seclevel = dbase[1];
                    adat.seclevel = atoi(seclevel);
                    permit_multi = 0;
                } else if (strcmp(attr, "service") == 0) {
                    dupcheck = adat.service;
                    adat.service = dbase[1];
                    permit_multi = 0;
                }
                break;

              case 'u':
                if (strcmp(attr, "username") != 0) break;
                dupcheck = adat.username;
                adat.username = dbase[1];
                permit_multi = 0;
                break;
            }

	    if (dupcheck != (char *)0) {
		log_write(conn->pool, "%s: duplicate attribute '%s'\n",
			  conn->info, attr);
		return (-1);
	    }
	    if (permit_multi == 0 && dbase[2] != (char *)0) {
		log_write(conn->pool,
			  "%s: attribute '%s' must not be multi-valued\n",
			  conn->info, attr);
		return (-1);
	    }
        }
        
        /* don't save defined attributes */
        if (ldap_attr == 0) dptr = dbase;

        /* check for blank line between defined and LDAP attributes */
        if (ldap_attr == 0 && scan < datend && *scan == '\r'
            && scan + 1 < datend && scan[1] == '\n') {
            ldap_attr = 1;
            scan += 2;
        }
    }
    *dptr = '\0';

    if (ldap_attr == 0) {
	log_write(conn->pool,
	  "%s: missing blank line between defined and LDAP attributes\n",
		  conn->info);
	return (-1);
    }
    if (numattr > 0) {
	log_write(conn->pool,
		  "%s: too few attributes found %d of %d expected\n",
		  conn->info, attnum + numattr, numattr);
	return (-1);
    }
    if (numval > 0) {
	log_write(conn->pool,
		  "%s: too few values found; expected %d additional\n",
		  conn->info, numval);
	return (-1);
    }

    /* prepare for reply */
    conn->reply = (char *)0;
    conn->replysize = 0;
    conn->replyok = 0;
    conn->numattr = 0;
    conn->numval = 0;
    adat.version = AUTHDATA_VERSION;
    adat.priv = conn;
    adat.raddrsz = sizeof (raddr);
    adat.laddrsz = sizeof (laddr);
    laddr.sin_family = AF_INET;
    raddr.sin_family = AF_INET;
    adat.get_attr = get_attr;
    adat.set_attr = set_attr;
    adat.auth_success = auth_success;
    adat.auth_fail = auth_fail;
    adat.authcontinue = conn->authcontinue;
    if (adat.saslmech == (char *)0) adat.saslmech = "PLAIN";

    /* decode sasl data */
    if (adat.sasldata != (char *)0
	&& b64_decode(adat.sasldata, -1,
		      adat.sasldata, -1, &adat.sasllen) == -1) {
	/* for base64 decode failure, bypass authentication handler */
	log_write(conn->pool, "%s: Invalid base64 data", conn->info);
	auth_fail(&adat, SASL_BADPROT, NULL, "Invalid base64 data");
    } else {
	/* call custom authentication handler */
	authdat_handler(conn->pool->context, &adat);
    }

    /* cleanup storage */
    if (conn->reply != (char *)0) {
	free(conn->reply);
	conn->reply = (char *)0;
    }
    
    return (conn->replyok ? 0 : -1);
}

static const char *def_mechlist[] = { "PLAIN", 0 };
static const char GREETING[] = "version sample-authserver-v1.1\r\nsaslmech";
#define GREETATTR 2
#define GREETVAL  1

/* handle an authentication server connection
 */
static void handle_connection(struct conn *conn)
{
    int len, outlen, used, headlen, bufsize, err;
    int dsize, numattr, numval, j;
    unsigned datasize;
    char *endline, *scan, *bufend;
    struct authhead *ahead = (struct authhead *) conn->pool->context;
    const char **mechlist = def_mechlist;

    /* get list of SASL mechanisms to send */
    dsize = (int) sizeof (GREETING) - 1;
    numattr = GREETATTR;
    numval = GREETVAL;
    if (ahead != (struct authhead *)0
	&& ahead->id == AUTHHEAD_ID
	&& ahead->version == AUTHHEAD_VERSION) {
	if (ahead->mechlist != (const char **)0
	    && *ahead->mechlist != (const char *)0) {
	    mechlist = ahead->mechlist;
	}
    }
    for (j = 0; mechlist[j] != (const char *)0; ++j) {
	dsize += strlen(mechlist[j]) + 3;
    }
    numval += j;

    /* write greeting */
    len = snprintf(conn->buf, sizeof (conn->buf), "authserver %d %d %d\r\n%s",
		   dsize, numattr, numval, GREETING);
    assert(len > 0 && (unsigned) len < sizeof (conn->buf));
    for (j = 0; mechlist[j] != (const char *)0; ++j) {
	outlen = snprintf(conn->buf + len, sizeof (conn->buf) - len,
			  " %s\r\n", mechlist[j]);
	assert(outlen > 0 && (unsigned) outlen < sizeof (conn->buf) - len);
	len += outlen;
    }
    do {
        outlen = write(conn->fd, conn->buf, len);
    } while (outlen == -1 && FD_TEMPORARY_ERR());
    if (outlen == -1) {
        fprintf(stderr, "%s: greeting write error: %s (%d)\n",
                conn->info, strerror(errno), errno);
        return;
    }
    if (outlen < len) {
        fprintf(stderr, "%s: greeting short write: %d of %d\n",
                conn->info, outlen, len);
        return;
    }

    /* reset state */
    conn->authcontinue = (struct memobj *)0;

    /* handle requests */
    used = 0;
    for (;;) {
        len = read(conn->fd, conn->buf + used, sizeof (conn->buf) - used);
        if (len == 0) return;   /* normal completion */
        if (len == -1) {
            if (errno == EINTR) continue;
            log_write(conn->pool, "%s: read error: %s (%d)\n",
		      conn->info, strerror(errno), errno);
            return;
        }
        
        /* check for complete line */
        used += len;
        bufend = conn->buf + used;
        for (endline = conn->buf; endline < bufend; ++endline) {
            /* endline[1] test is safe due to sentinel */
            if (endline[0] == '\r' && endline[1] == '\n')
                break;
        }
        if (endline == bufend) {
            if (used == sizeof (conn->buf)) {
                log_write(conn->pool, "%s: attempt to overwrite buffer\n",
			  conn->info);
                return;
            }
            continue;
        }

        /* parse header */
        scan = conn->buf;
        if (parseint(&scan, endline, &dsize) == -1
            || parseint(&scan, endline, &numattr) == -1
            || parseint(&scan, endline, &numval) == -1) {
            log_write(conn->pool, "%s: invalid header: %.*s\n",
		      conn->info, endline - conn->buf, conn->buf);
            return;
        }

        /* validate header */
        if (dsize == 0
            || dsize > MAX_PAYLOAD
            || dsize < numattr * 4
            || numattr == 0
            || numattr > MAX_ATTRS
            || numval > MAX_VALS
            || numval < numattr) {
            log_write(conn->pool, "%s: invalid header %d %d %d\n",
		      conn->info, dsize, numattr, numval);
            return;
        }

        /* determine and allocate space for the payload */
        datasize = (numattr * 2 + numval + 1) * sizeof (void *);
        if (sizeof (conn->buf) - headlen <= (unsigned) dsize) {
	    datasize += dsize;
	}
        if (conn->datasize < datasize) {
            conn->data = realloc(conn->data, datasize);
            if (conn->data == (void *)0) {
                conn->datasize = 0;
                log_write(conn->pool, "%s: out of memory\n", conn->info);
                return;
            }
            conn->datasize = datasize;
        }

        /* set up the buffer for the payload */
        headlen = endline + 2 - conn->buf;
        if (sizeof (conn->buf) - headlen <= (unsigned) dsize) {
            scan = (char *) (((void **) conn->data) + numattr * 2 + numval);
            bufsize = dsize;
            if (used > headlen) {
                memcpy(scan, conn->buf + headlen, used - headlen);
                used -= headlen;
            } else {
                used = 0;
            }
        } else {
            scan = conn->buf + headlen;
            used -= headlen;
            bufsize = sizeof (conn->buf) - headlen;
        }

        /* fill in the payload buffer if necessary */
        while (used < dsize) {
            len = read(conn->fd, scan + used, bufsize - used);
            if (len == 0) {
                log_write(conn->pool,
			  "%s: connected closed in payload %d %d %d\n",
			  conn->info, dsize, numattr, numval);
                return;
            }
            if (len == -1) {
                if (errno == EINTR) continue;
                log_write(conn->pool,
			  "%s: read error in payload %d %d %d: %s %d\n",
			  conn->info, dsize, numattr, numval,
			  strerror(errno), errno);
                return;
            }
            used += len;
        }

        /* parse the payload */
        err = handle_payload(conn, scan, dsize, numattr, numval);

        /* if there was an error, drop the connection */
        if (err < 0) return;

        /* it's possible we read multiple requests into the default buffer */
        if (used > dsize) {
            used -= dsize;
            memmove(conn->buf, conn->buf + headlen + dsize, used);
        } else {
            used = 0;
        }
    }
}

/* thread pool handler
 *  Since threads are started on an as-needed basis, the first thing
 *  this does is attempt to pull a connection from the ready list.
 *  If no connection is ready or a connection is completed, this
 *  waits for a new connection to be ready.
 */
static void *threadpool_work(void *arg)
{
    struct threadpool *pool = arg;
    struct conn *conn;

    pool_lock(pool);
    ++pool->threads_total;
    DEBUGLOG((pool, "Thread id %u started\n", (unsigned int) pthread_self()));
    while (!pool->exit_flag) {
        /* pull the connection from the list */
        conn = pool->head;
        if (conn != (struct conn *)0) {
            pool->head = conn->next;
            if (conn->next == (struct conn *)0) {
                pool->tail = (struct conn *)0;
            }

            /* handle the connection */
            pool_unlock(pool);
            DEBUGLOG((pool, "Thread id %u got connection %d\n",
                     (unsigned int) pthread_self(), conn->fd));
            handle_connection(conn);
            close(conn->fd);
            conn->fd = -1;

	    /* clean up lingering allocated memory */
	    if (conn->authcontinue != (struct memobj *)0) {
		conn->authcontinue->destruct(conn->authcontinue);
		conn->authcontinue = (struct memobj *)0;
	    }

            /* push connection context on unused list */
            pool_lock(pool);
            conn->next = pool->unused;
            pool->unused = conn;

            /* dispose connection buffer for idle connections */
            conn = conn->next;
            if (conn != (struct conn *)0 && conn->data != (void *)0) {
                /* wipe potentially sensitive information from memory */
                memset(conn->data, 0, conn->datasize);
                free(conn->data);
                conn->data = (void *)0;
                conn->datasize = 0;
            }
        }
        if (pool->exit_flag) break;

        /* wait for a connection */
        if (pool->threads_waiting++ == 0
         && pool->threads_total >= MAX_CONCURRENT_CONNECTIONS) {
            /* signal if we're the first ready have a backlog */
            pool_signal(pool);
        }
        pool_wait(pool);
        --pool->threads_waiting;
    }
    DEBUGLOG((pool, "Thread id %u exit\n", (unsigned int) pthread_self()));
    --pool->threads_total;
    pool_unlock(pool);

    return ((void *)0);
}

/* connection accept loop
 */
static void do_accept(int sock, struct threadpool *pool)
{
    struct sockaddr_in addr;
    int newsock;
    int err;
    int addrlen;
    struct conn *newconn;
    pthread_t threadid;
    pthread_attr_t tattr;

    newconn = (struct conn *)0;
    for (;;) {
        /* accept new connection */
	addrlen = sizeof (addr);
        newsock = accept(sock, (struct sockaddr *) &addr, &addrlen);
        if (newsock < 1) {
            log_write(pool, "accept: %s (%d)\n", strerror(errno), errno);
            sleep(1);
            continue;
        }

        /* allocate space for connection, if needed */
        while (newconn == (struct conn *)0) {
            newconn = (struct conn *) malloc(sizeof (struct conn));
            if (newconn == (struct conn *)0) {
                log_write(pool, "malloc: out of memory; sleeping\n");
                sleep(1);
                continue;
            }
            newconn->data = (void *)0;
            newconn->datasize = 0;
        }
        newconn->fd = newsock;
        newconn->pool = pool;
        newconn->next = (struct conn *)0;
        newconn->sentinel = '\0';

        /* set connection information */
        snprintf(newconn->info, sizeof (newconn->info),
                 "%d %d.%d.%d.%d %d",
                 newsock,
		 ((char *) &addr.sin_addr.s_addr)[0],
		 ((char *) &addr.sin_addr.s_addr)[1],
		 ((char *) &addr.sin_addr.s_addr)[2],
		 ((char *) &addr.sin_addr.s_addr)[3],
                 ntohs(addr.sin_port));

        /* access the thread pool */
        pool_lock(pool);

        /* add to end of ready list */
        if (pool->tail == (struct conn *)0) {
            pool->head = pool->tail = newconn;
        } else {
            pool->tail->next = newconn;
            pool->tail = newconn;
        }
        DEBUGLOG((pool, "%s: accepted\n", newconn->info));

        /* wait for completion if necessary */
        if (pool->threads_waiting == 0
         && pool->threads_total >= MAX_CONCURRENT_CONNECTIONS) {
            DEBUGLOG((pool, "Max connections reached, waiting\n"));
            pool_wait(pool);
        }

        /* create a thread if necessary */
        while (pool->threads_waiting == 0) {
	    /* create detached threads to save resources */
	    err = pthread_attr_init(&tattr);
	    if (err != 0) {
		log_write(pool, "pthread_attr_init: %s (%d)\n",
			  strerror(err), err);
		pool_wait(pool);
		continue;
	    }
	    err = pthread_attr_setdetachstate(&tattr, PTHREAD_CREATE_DETACHED);
	    if (err != 0) {
		log_write(pool, "pthread_attr_setdetachstate: %s (%d)\n",
			  strerror(err), err);
		pool_wait(pool);
		continue;
	    }

	    /* create a thread */
            err = pthread_create(&threadid, &tattr,
				 threadpool_work, (void *) pool);
            if (err == 0) {
                DEBUGLOG((pool, "New thread created: %u\n",
                         (unsigned int) threadid));
                break;
            }
            log_write(pool, "pthread_create: %s (%d)\n",
		      strerror(err), err);
            pool_wait(pool);
        }

        /* reuse an old connection context if possible */
        if (pool->unused != (struct conn *)0) {
            newconn = pool->unused;
            pool->unused = newconn->next;
        }

        pool_unlock(pool);

        /* signal thread that connection is ready */
        pool_signal(pool);
    }
}

int main()
{
    struct threadpool pool;
    int sock;
    int on;
    int err;
    struct sockaddr_in addr;

    /* set up connection pool */
    memset(&pool, 0, sizeof (pool));
    pthread_mutex_init(&pool.lock, (pthread_mutexattr_t *)0);
    pthread_mutex_init(&pool.logm, (pthread_mutexattr_t *)0);
    pthread_cond_init(&pool.cond, (pthread_condattr_t *)0);
    
    /* set up address */
    addr.sin_family = AF_INET;
    addr.sin_addr.s_addr = inet_addr(BINDIP);
    addr.sin_port = htons(BINDPORT);

    /* create socket */
    sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock < 0) {
        perror("socket");
        exit(1);
    }

    /* allow address reuse */
    on = 1;
    setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, (char *) &on, sizeof (on));

    /* bind & listen for connections */
    err = bind(sock, (struct sockaddr *) &addr, sizeof (addr));
    if (err < 0) {
        perror("bind");
        exit(1);
    }
    err = listen(sock, LISTEN_QUEUE);
    if (err < 0) {
        perror("listen");
        exit(1);
    }

    /* initialize subsystem */
    err = authdat_init(&pool.context);
    if (err != 0) {
        fprintf(stderr, "FATAL: authdat_init failed\n");
        exit(1);
    }

    /* accept connections */
    do_accept(sock, &pool);

    return (0);
}
