/*
 * Copyright (c) 1993 The Regents of the University of California.
 *
 * See the file "license.terms" for information on usage and redistribution
 * of this file, and for a DISCLAIMER OF ALL WARRANTIES.
 */

/* checks network connections */

/*
 * Solaris 2.x has this really wacky protocol for doing sockets stuff.
 * The OS doesn't implement the sockets calls directly; it implements
 * some other interface.  libc emulates sockets calls, with the use
 * of a sockets STREAMS module pushed on top of /dev/tcp or /dev/udp.
 * 
 * Unfortunately, this means we don't get to monitor and trap a connect()
 * system call with all the relevant arguments.  Instead, we have to monitor
 * the interaction between the libc and the sockets STREAMS module and
 * other OS gunk.  This is a serious pain in the ass, especially since
 * Solaris has an undocumented icky protocol for interfacing those parts.
 * 
 * Gack.
 * 
 * A short synopsis of the relevant bits of a typical protocol run:
 *  socket(family = 2 = AF_INET, type = SOCK_STREAM, 0) -->
 * 	fd = open("/dev/tcp", O_RDWR)
 * 	ioctl(fd, I_STR, (char *) {AF_INET, SOCK_STREAM, ...})
 *  connect(dest_name = {AF_INET, port = 80, sin_addr}) -->
 * 	putmsg(fd, (char *) {len = 36, maxlen, buf}, 0, 0)
 * 	example buf value and translation of it follows:

<00000000	00000010	00000014	FFFFFFFF	00000000
		00020050	8020231F	00000000	00000000>
(T_CONN_REQ	dst_len=16	dst_off=20	opt_len=-1 ***	opt_off=0
     family=2=AF_INET port=80	ip_dst=128.32.35.31
***: OPT_length == -1 == 0xFFFFFFFF is a secret sign of a nonblocking connect)
 
 * Note that the relevant parameters on which our access control is based
 * are spread across multiple system calls.  Blech.  This means that we must
 * maintain per-file-descriptor state across system calls.
 * 
 * I hate security-critical state, but it can't be avoided.  Sigh.
 */

#include "module.h"
#include "version.h"
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <unistd.h>
#include <sys/ioctl.h>
#include <sys/stropts.h>
#include <sys/tiuser.h>
#include <sys/sockmod.h>
#include <sys/filio.h>

extern char **environ, **traced_environ;

extern char    *fetchstr(const int, const char *);


/*
 * Bit flags naming the transport protocol(s) a rule applies to.
 * The hooks test them with '&' against a tracked fd's socket type
 * (TCP for SOCK_STREAM, UDP for SOCK_DGRAM); BOTH is simply TCP|UDP.
 */
typedef enum {NONE=0, TCP=1, UDP=2, BOTH=TCP|UDP} proto_type;

/*
 * One parsed configuration rule, built by init() and handed back to the
 * hooks through their void *st argument.
 *
 * A destination matches when
 *     (dest_addr & addr_mask) == addr  &&  (dest_port & port_mask) == port
 * and in that case the hook returns 'what'.
 */
typedef struct
{
    proto_type ptype;			/* protocol(s) the rule covers */
    unsigned long addr, addr_mask;	/* as produced by inet_addr() */
    unsigned short port, port_mask;	/* stored after htons() */
    action what;			/* verdict returned on a match */
} state_t;

/*
 * Singly linked list node remembering one descriptor that was returned
 * by an open() of a transport device (see open_exit_hook).
 */
typedef struct s_fd_node_t
{
    int fd;   /* descriptor number in the traced process */
    int type; /* SOCK_STREAM or SOCK_DGRAM */
    struct s_fd_node_t *next; /* next tracked descriptor, or NULL */
} fd_node_t;

/*
 * Head of the list of tracked descriptors.  open_exit_hook() prepends
 * new nodes; putmsg_hook() and ioctl_hook() only search it.  Nothing in
 * this file removes nodes, so entries can outlive the descriptor.
 */
fd_node_t *open_fds;

/*
 * putmsg_hook -- policy check for putmsg()/putpmsg() on a tracked fd.
 *
 * p      : status of the traced process; pr_sysarg[0] is the fd and
 *          pr_sysarg[1] the address of the control-buffer descriptor.
 * procfd : descriptor used to lseek()/read() the traced process' memory.
 * st     : the state_t rule produced by init().
 *
 * Returns NO_COMMENT when the call is not ours to judge (no rule, no
 * control buffer, untracked fd, protocol not covered by the rule,
 * control buffer not exactly 36 bytes, memory unreadable, or the
 * destination does not match the rule).  Returns DENY when a 36-byte
 * control message is not a T_CONN_REQ with the expected layout.
 * Returns state->what when the destination matches the rule.
 */
action	putmsg_hook(const struct prstatus *p, int procfd, void *st)
{
    state_t *state = st;
    fd_node_t *fp;
    unsigned short fam;
    int where, len, fd = p->pr_sysarg[0];
    unsigned char buf[36];
    unsigned int raw_addr;	/* exactly the 4 address bytes from buf */
    unsigned long dest_addr;
    unsigned short dest_port;

    if (!state || !p->pr_sysarg[1]) return(NO_COMMENT);

    /* if fd is not in the list, we won't comment on it */
    for (fp = open_fds; fp; fp=fp->next)
	if (fp->fd == fd)
	    break;
    if (!fp) return(NO_COMMENT);

    /* only rules covering this socket's protocol get a say */
    switch(fp->type) {
	case SOCK_STREAM: if (!(state->ptype & TCP)) return(NO_COMMENT); break;
	case SOCK_DGRAM:  if (!(state->ptype & UDP)) return(NO_COMMENT); break;
	default: break;
    }

    /*
     * The descriptor at pr_sysarg[1] holds the buffer length at offset 4
     * immediately followed by the buffer address.
     */
    if (lseek(procfd, p->pr_sysarg[1]+4, SEEK_SET) == -1) return(NO_COMMENT);
    if (read(procfd, &len, 4) < 4) return (NO_COMMENT);
    /* hey, buflen = 36 is the common case */
    if (len != 36) return (NO_COMMENT);
    if (read(procfd, &where, 4) < 4) return (NO_COMMENT);
    if (lseek(procfd, where, SEEK_SET) == -1) return(NO_COMMENT);

    /* len == sizeof(buf) here, so this cannot overrun buf */
    if (read(procfd, buf, len) < len) return (NO_COMMENT);

    /* ok.  now buf contains the contents of the buffer */
    /* T_CONN_REQ == 0 according to <sys/tihdr.h> */
    if (memcmp(buf, "\x00\x00\x00\x00", 4) != 0) return(DENY);

    /* hey, dst_addr_len = 16, dst_addr_off = 20 is the common case */
    if (memcmp(buf+4, "\x00\x00\x00\x10\x00\x00\x00\x14", 8) != 0) return(DENY);

    /* there should be no options.
       but to support netscape, opt_len = -1 (non-blocking connect) allowed. */
    if (memcmp(buf+12, "\x00\x00\x00\x00", 4) != 0 &&
	    memcmp(buf+12, "\xFF\xFF\xFF\xFF", 4) != 0) return(DENY);
    if (memcmp(buf+16, "\x00\x00\x00\x00", 4) != 0) return(DENY);

    /* destination sockaddr: family at 20, port at 22, address at 24 */
    memmove(&fam, buf+20, 2);
    memmove(&dest_port, buf+22, 2);
    /*
     * Copy the 4 address bytes into a 4-byte object first and widen
     * afterwards; copying straight into an unsigned long would leave
     * part of it indeterminate where long is wider than 32 bits.
     */
    memmove(&raw_addr, buf+24, 4);
    dest_addr = raw_addr;

    if ((fam == AF_INET) &&
	((dest_addr & state->addr_mask) == state->addr) &&
	((dest_port & state->port_mask) == state->port))
    {
	return(state->what);
    }

    return(NO_COMMENT);
}

/*
 * ioctl_push -- judge an I_PUSH ioctl.
 *
 * pr_sysarg[2] is the address of the module name in the traced process.
 * Returns NO_COMMENT when there is no name or it cannot be fetched,
 * ALLOW when the module is "timod" or "sockmod", DENY for any other name.
 */
static action ioctl_push(const struct prstatus *p, int procfd)
{
    char *modname;
    action verdict;

    if (p->pr_sysarg[2] == 0)
	return(NO_COMMENT);

    modname = (char *) fetchstr(procfd, (char *) p->pr_sysarg[2]);
    if (modname == NULL)
	return(NO_COMMENT);

    if (!strcmp(modname, "timod") || !strcmp(modname, "sockmod"))
	verdict = ALLOW;
    else
	verdict = DENY;

    /* the fetched copy is released here */
    free(modname);
    return(verdict);
}

/*
 * ioctl_soparms -- judge an I_STR ioctl whose command is SI_SOCKPARAMS.
 *
 * procfd must be positioned just past the 4-byte command word of the
 * I_STR argument (ioctl_hook leaves it there); the next three words are
 * read as timeout, data length and data pointer.
 * fp is the tracked descriptor the ioctl is aimed at.
 *
 * Returns NO_COMMENT if memory cannot be read, the length is not 12, or
 * the parameters are not AF_INET with protocol 0; DENY if the requested
 * socket type disagrees with the device that was opened; ALLOW otherwise.
 */
static action ioctl_soparms(int procfd, fd_node_t *fp)
{
    unsigned int len, timeout, ptr;
    struct si_sockparams parms;

    if (read(procfd, (char *) &timeout, 4) < 4) return(NO_COMMENT);
    if (read(procfd, (char *) &len, 4) < 4) return(NO_COMMENT);
    /* hey, len = 12 is the common case */
    if (len != 12) return(NO_COMMENT);
    if (read(procfd, (char *) &ptr, 4) < 4) return(NO_COMMENT);

    if (lseek(procfd, ptr, SEEK_SET) == -1) return(NO_COMMENT);

    /*
     * Compare as signed: read() returns -1 on error, and comparing that
     * directly with a size_t would convert it to a huge unsigned value
     * and let the failure slip through with parms uninitialised.
     */
    if (read(procfd, (char *) &parms, sizeof(parms)) < (ssize_t) sizeof(parms))
	return(NO_COMMENT);

    if (parms.sp_family != AF_INET) return(NO_COMMENT);
    if (parms.sp_protocol != 0) return(NO_COMMENT);
    /*
     * sanity check:
     * check in case they open("/dev/tcp") but do ioctl(... type=SOCK_DGRAM ...)
     */
    if (parms.sp_type != fp->type) return(DENY);

    return(ALLOW);
}


/* MDD */
/*
 * Transport-level interface to sockets.
 */
#include <sys/tihdr.h>
#include <sys/timod.h>
#include <net/if.h>
#include <sys/sockio.h>




/*
 * ioctl_tioptmgmt -- judge an I_STR ioctl whose command is TI_OPTMGMT.
 *
 * procfd must be positioned just past the 4-byte command word of the
 * I_STR argument; the next three words are read as timeout, data length
 * and data pointer.  f is the tracked descriptor (currently unused).
 *
 * Returns ALLOW when the 32-byte payload starts with a T_optmgmt_req
 * whose PRIM_type is T_OPTMGMT_REQ, NO_COMMENT in every other case.
 */
static action ioctl_tioptmgmt(int procfd, fd_node_t *f)
{
    unsigned int len, timeout, ptr;
    struct T_optmgmt_req opt_req;

    (void) f;

    if (read(procfd, (char *) &timeout, 4) < 4) return(NO_COMMENT);
    if (read(procfd, (char *) &len, 4) < 4) return(NO_COMMENT);
    /* hey, len = 32 is the common case */
    if (len != 32) return(NO_COMMENT);
    if (read(procfd, (char *) &ptr, 4) < 4) return(NO_COMMENT);
    if (lseek(procfd, ptr, SEEK_SET) == -1) return(NO_COMMENT);

    /*
     * The request comes in a T_optmgmt_req struct.
     * (struct defined in <sys/tihdr.h>)
     *
     * The size is cast to ssize_t so that a -1 error return from read()
     * is not converted to a large unsigned value and mistaken for success.
     */
    if (read(procfd, (char *) &opt_req, sizeof(opt_req)) < (ssize_t) sizeof(opt_req))
	return(NO_COMMENT);

    /*
     * Currently allow only T_OPTMGMT_REQ type. List
     * of request types is in <sys/tihdr.h>
     */
    if (opt_req.PRIM_type == T_OPTMGMT_REQ) {
	/* allow all setsockopt calls */
	return ALLOW;
    }
    return NO_COMMENT;
}



/*
 * ioctl_tibind -- judge an I_STR ioctl whose command is TI_BIND.
 *
 * procfd must be positioned just past the 4-byte command word of the
 * I_STR argument; the next three words are read as timeout, data length
 * and data pointer.  f is the tracked descriptor (currently unused) and
 * state the rule to match the bind address against.
 *
 * Returns state->what when the payload is a T_BIND_REQ carrying an
 * AF_INET address that matches the rule's address/port masks, and
 * NO_COMMENT in every other case (unreadable memory, unexpected sizes,
 * other primitive, other address family, no match).
 */
static action ioctl_tibind(int procfd, fd_node_t *f,   state_t *state)
{
    unsigned int len, timeout, bind_req_ptr;
    unsigned int name_ptr;
    struct T_bind_req bind_req;
    struct sockaddr name;
    struct sockaddr_in *nameInet;
    unsigned long bind_addr;

    (void) f;

    if (read(procfd, (char *) &timeout, 4) < 4) return(NO_COMMENT);
    if (read(procfd, (char *) &len, 4) < 4) return(NO_COMMENT);
    /*
     * NOTE(review): this length check of 32 has the same shape as the
     * one in ioctl_tioptmgmt; confirm it is the right value for TI_BIND.
     */
    if (len != 32) return(NO_COMMENT);
    if (read(procfd, (char *) &bind_req_ptr, 4) < 4) return(NO_COMMENT);
    if (lseek(procfd, bind_req_ptr, SEEK_SET) == -1) return(NO_COMMENT);

    /*
     * The request comes in a T_bind_req struct.
     * (struct defined in <sys/tihdr.h>)
     *
     * Sizes are cast to ssize_t so that a -1 error return from read()
     * is not converted to a large unsigned value and taken for success.
     */
    if (read(procfd, (char *) &bind_req, sizeof(bind_req)) < (ssize_t) sizeof(bind_req))
	return(NO_COMMENT);

    /* Only T_BIND_REQ is examined; anything else is left alone. */
    if (bind_req.PRIM_type != T_BIND_REQ)
	return NO_COMMENT;

    if (bind_req.ADDR_length != sizeof(struct sockaddr))
	return NO_COMMENT;

    /* the address lives ADDR_offset bytes into the request itself */
    name_ptr = bind_req.ADDR_offset + bind_req_ptr;
    if (lseek(procfd, name_ptr, SEEK_SET) == -1) return(NO_COMMENT);

    if (read(procfd, (char *) &name, sizeof(name)) < (ssize_t) sizeof(name))
	return NO_COMMENT;

    if (name.sa_family != AF_INET)
	return NO_COMMENT;

    /* we're now entitled to make the cast */
    nameInet = (struct sockaddr_in *) &name;

    /*
     * Take the 32-bit address through its own member rather than through
     * an unsigned long pointer, which would read past sin_addr where
     * long is wider than 32 bits.
     */
    bind_addr = (unsigned long) nameInet->sin_addr.s_addr;

    if (((bind_addr & state->addr_mask) == state->addr) &&
	((unsigned short) nameInet->sin_port & state->port_mask) == state->port)
    {
	return(state->what);
    }

    return NO_COMMENT;
}


/* end MDD */





action	ioctl_hook(const struct prstatus *p, int procfd, void *st)
{
    state_t *state = (state_t *) st;
    fd_node_t *fp;
    int fd = p->pr_sysarg[0];
    unsigned int whichi = p->pr_sysarg[1], cmd=0;

    /* if fd is not in the list, we won't comment on it */
    for (fp = open_fds; fp; fp=fp->next)
	if (fp->fd == fd)
        	break;
    if (!fp) return(NO_COMMENT);

/* Hrm. Hrm. Hrm.
    switch(fp->type) {
	case SOCK_STREAM: if (!(state->ptype & TCP)) return(NO_COMMENT);
			  break;
	case SOCK_DGRAM:  if (!(state->ptype & UDP)) return(NO_COMMENT);
			  break;
    }
*/



    switch(whichi) {
	case I_PUSH: return(ioctl_push(p, procfd));
	case I_STR: break;
	/* MDD: allow accept() to work */
        case I_FDINSERT: return (ALLOW);
	  /* MDD: allow some other thing that apache needs */
        case SIOCGIFCONF: 
        case SIOCGIFFLAGS: 
	  return(ALLOW);
	default: 
	  return(NO_COMMENT);
    }

    if (!(p->pr_sysarg[2])) return(NO_COMMENT);
    if (lseek(procfd, p->pr_sysarg[2], SEEK_SET) == -1) return(NO_COMMENT);

    if (read(procfd, (char *) &cmd, 4) < 4) return(NO_COMMENT);

    switch(whichi) {

	case I_STR:

	  switch(cmd) {
			case SI_SOCKPARAMS: return(ioctl_soparms(procfd, fp));
			case SI_GETUDATA: return(ALLOW);
			  /* MDD: allow some setsockopt calls*/
	                case TI_OPTMGMT: return (ioctl_tioptmgmt(procfd, fp));
			  /* MDD: allow bind */
	                case TI_BIND:  return (ioctl_tibind(procfd, fp, (state_t *)st));
			  /* MDD: allow listen */
	                case SI_LISTEN: return (ALLOW);
			default: 
			  return(NO_COMMENT);
		    }
	default: return(NO_COMMENT);
    }
}

/*
 * open_exit_hook -- runs when open() returns; records descriptors that
 * refer to a transport device so the other hooks can recognise them.
 *
 * "/dev/tcp" and "/dev/ticotsord" are recorded as SOCK_STREAM,
 * "/dev/udp" and "/dev/ticlts" as SOCK_DGRAM; every other path is
 * ignored.  The return value is always NO_COMMENT.  Exits the program
 * if a list node cannot be allocated.
 */
action	open_exit_hook(const struct prstatus *p, int procfd, void *st)
{
    const int newfd = p->pr_reg[R_O0];
    int socktype = 0;
    int is_transport = 1;
    char *path;
    fd_node_t *node;

    if (newfd < 0)
	return(NO_COMMENT);
    if (p->pr_sysarg[0] == 0)
	return(NO_COMMENT);

    path = (char *) fetchstr(procfd, (char *) p->pr_sysarg[0]);
    if (path == NULL)
	return(NO_COMMENT);

    if (!strcmp(path, "/dev/tcp") || !strcmp(path, "/dev/ticotsord"))
	socktype = SOCK_STREAM;
    else if (!strcmp(path, "/dev/udp") || !strcmp(path, "/dev/ticlts"))
	socktype = SOCK_DGRAM;
    else
	is_transport = 0;

    free(path);
    if (!is_transport)
	return(NO_COMMENT);

    /*
     * A node with this fd number already present means the descriptor
     * was closed earlier and the number has been reused; just refresh
     * the recorded type.
     */
    node = open_fds;
    while (node != NULL && node->fd != newfd)
	node = node->next;
    if (node != NULL) {
	node->type = socktype;
	return(NO_COMMENT);
    }

    /* first sighting of this fd: push a new node on the list head */
    node = (fd_node_t *) malloc(sizeof(fd_node_t));
    if (node == NULL) {
	fprintf(stderr, "out of memory\n");
	exit(1);
    }
    node->fd = newfd;
    node->type = socktype;
    node->next = open_fds;
    open_fds = node;

    return(NO_COMMENT);
}




/*
 * parse_display -- derive the X server address from the traced
 * program's DISPLAY variable.
 *
 * environ is temporarily pointed at traced_environ so that getenv()
 * looks DISPLAY up there.
 *
 * addr : receives the server's IPv4 address (as inet_addr() returns it)
 * port : receives htons(6000 + display number)
 *
 * Returns 0 on success.  Also returns 0 when DISPLAY is not set; in that
 * case *addr and *port are set to 0 so the caller never sees
 * uninitialised values.  Returns 1 when DISPLAY cannot be parsed or the
 * host cannot be resolved to an IPv4 address.
 */
static int parse_display(unsigned long *addr, unsigned short *port)
{
    char *dpy;
    char *colon;
    int dspnum;
    int rv = 1;
    char **save_environ;

    *addr = 0;
    *port = 0;

    save_environ = environ;
    environ = traced_environ;
    dpy = getenv("DISPLAY");
    environ = save_environ;
    if (!dpy)
	return 0;

    /* work on a private copy: the string gets split at the colon */
    dpy = strdup(dpy);
    if (!dpy)
	return 1;

    /* Find the colon */
    colon = strchr(dpy, ':');
    if (!colon)
	goto out;
    *colon = '\0';
    ++colon;

    /* Get the display number; it must leave the port within 16 bits */
    if (sscanf(colon, "%d", &dspnum) < 1)
	goto out;
    if (dspnum < 0 || dspnum > 65535 - 6000)
	goto out;
    *port = htons((unsigned short) (6000 + dspnum));

    /* Get the host IP */
    if (*dpy == '\0')
    {
	*addr = inet_addr("127.0.0.1");
    }
    else if (*dpy >= '0' && *dpy <= '9')
    {
	unsigned long numeric = inet_addr(dpy);

	/* all ones in the low 32 bits is inet_addr()'s failure value */
	if ((numeric & 0xffffffffUL) == 0xffffffffUL)
	    goto out;
	*addr = numeric;
    }
    else
    {
	struct hostent *hst = gethostbyname(dpy);
	unsigned int raw;	/* holds exactly the 4 address bytes */

	if (!hst || hst->h_addrtype != AF_INET ||
	    hst->h_length != (int) sizeof(raw) ||
	    !hst->h_addr_list || !*hst->h_addr_list)
	    goto out;
	/*
	 * Copy bytewise: the entry is a char buffer with no alignment
	 * guarantee and only h_length bytes long.
	 */
	memcpy(&raw, *hst->h_addr_list, sizeof(raw));
	*addr = raw;
    }

    rv = 0;

out:
    free(dpy);
    return rv;
}

/*
 * parse_addr -- parse a dotted-quad IPv4 address (or mask) at the start
 * of str.
 *
 * str  : text to parse; it is briefly NUL-terminated in place while
 *        inet_addr() runs and restored before returning
 * addr : receives the inet_addr() result on success
 * next : if non-NULL, receives a pointer to the first character after
 *        the address
 *
 * Returns 0 on success, 1 if str does not start with four dot-separated
 * numbers, if any of them is outside 0..255, or if inet_addr() rejects
 * the text.
 */
static int parse_addr(char *str, unsigned long *addr, char **next)
{
    int octet[4];
    int len = 0;
    int i;
    int all_ones = 1;
    char saved;
    unsigned long parsed;

    if (sscanf(str, "%3d.%3d.%3d.%3d%n",
	       &octet[0], &octet[1], &octet[2], &octet[3], &len) < 4)
	return 1;
    if (len < 7) return 1;

    for (i = 0; i < 4; i++) {
	if (octet[i] < 0 || octet[i] > 255) return 1;
	if (octet[i] != 255) all_ones = 0;
    }

    saved = str[len];
    str[len] = 0;
    parsed = inet_addr(str);
    str[len] = saved;

    /*
     * inet_addr() signals failure with all ones, which is also the
     * legitimate value of 255.255.255.255 (a valid mask).  Treat it as
     * an error only when the text was not literally all 255s.
     */
    if ((parsed & 0xffffffffUL) == 0xffffffffUL && !all_ones)
	return 1;

    *addr = parsed;
    if (next) *next = str+len;
    return 0;
}

/*
 * parse_port -- parse a decimal port number (or port mask) at the start
 * of str.
 *
 * str  : text to parse
 * port : receives the value in network byte order (htons) on success
 * next : if non-NULL, receives a pointer to the first character after
 *        the number
 *
 * Returns 0 on success, 1 if no number is present or the value does not
 * fit in 16 bits.
 */
static int parse_port(char *str, unsigned short *port, char **next)
{
    char *end = str;
    unsigned long value;

    value = strtoul(str, &end, 10);
    if (end == str) return 1;		/* no digits at all */
    /*
     * Also catches overflow (strtoul yields ULONG_MAX) and negative
     * input, which strtoul wraps to a huge value.
     */
    if (value > 65535UL) return 1;

    *port = htons((unsigned short) value);
    if (next) *next = end;
    return 0;
}

void *	init(const char *conf_line)
{
    char *addr, *protocol, *addr_mask, *port, *port_mask, *what;
    state_t *state = malloc(sizeof(state_t));
    if (state == NULL) return (INIT_FAIL);

    /* Parse conf_line, fill in state */
    what = strdup(conf_line);
    if (!what) return (INIT_FAIL);

    if (!strncmp(what, "allow ", 6)) {
	state->what = ALLOW;
	protocol = what + 6;
    } else if (!strncmp(what, "deny ", 5)) {
	state->what = DENY;
	protocol = what + 5;
    } else {
	free(what);
	fprintf(stderr, "couldn't find allow/deny keyword\n");
	return (INIT_FAIL);
    }

    /* What's the protocol? */
    if (strncasecmp(protocol, "tcp ", 4) == 0) {
	state->ptype = TCP;
	addr = protocol + 4;
    } else if (strncasecmp(protocol, "udp ", 4) == 0) {
	state->ptype = UDP;
	addr = protocol + 4;
    } else if (strncmp(protocol, "* ", 2) == 0) {
	state->ptype = TCP|UDP;
	addr = protocol + 2;
    } else if (strcmp(protocol, "display") == 0) {
	if (parse_display(&state->addr, &state->port))
	{
	    free(what);
	    return (INIT_FAIL);
	}
	state->ptype = TCP;
	state->addr_mask = -1;
	state->port_mask = -1;
	free(what);
	return state;
    } else {
	fprintf(stderr, "couldn't find tcp/udp/*/display keyword\n");
	free(what);
	return(INIT_FAIL);
    }

    /* Grab an IP address */
    if (parse_addr(addr, &state->addr, &addr_mask))
    {
	free(what);
	return(INIT_FAIL);
    }

    /* Grab the IP mask */
    if (*addr_mask == '/')
    {
	++addr_mask;
	if (parse_addr(addr_mask, &state->addr_mask, &port))
	{
	    free(what);
	    return(INIT_FAIL);
	}
    }
    else
    {
	state->addr_mask = -1;
	port = addr_mask;
    }

    /* Grab the port */
    if (*port == ':')
    {
	++port;
	if (parse_port(port, &state->port, &port_mask))
	{
	    free(what);
	    return(INIT_FAIL);
	}
    }
    else
    {
	state->port = 0;
	state->port_mask = 0;
	free(what);
	return(state);
    }

    /* Grab the port mask */
    if (*port_mask == '/')
    {
	++port_mask;
	if (parse_port(port_mask, &state->port_mask, NULL))
	{
	    free(what);
	    return(INIT_FAIL);
	}
    }
    else
    {
	state->port_mask = -1;
	free(what);
	return state;
    }

    free(what);
    return state;
}

/*
 * System calls this module hooks.  putmsg and putpmsg share one handler;
 * open is hooked with EXIT_FUNC, the others with FUNC.
 */
const syscall_entry	entries[] = {
    { SYS_putmsg,  FUNC,      putmsg_hook    },
    { SYS_putpmsg, FUNC,      putmsg_hook    },
    { SYS_open,    EXIT_FUNC, open_exit_hook },
    { SYS_ioctl,   FUNC,      ioctl_hook     }
};
/* number of rows in entries[] */
const int		nentries = sizeof(entries) / sizeof(entries[0]);
