/*
 * Copyright (C) 2010  Internet Systems Consortium, Inc. ("ISC")
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND ISC DISCLAIMS ALL WARRANTIES WITH
 * REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
 * AND FITNESS.  IN NO EVENT SHALL ISC BE LIABLE FOR ANY SPECIAL, DIRECT,
 * INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
 * LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE
 * OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
 * PERFORMANCE OF THIS SOFTWARE.
 */

/* $Id: pcpproxy.c 1120 2011-01-30 10:16:08Z fdupont $ */

/*
 * Proxy (aka stateful relay) for PCP
 *
 * Francis_Dupont@isc.org, December 2010
 *
 * from trpmp.c/pmpproxy.c
 *
 * usage: pcpproxy -b <IPv6> -s <IPv6> -l <IPv4> [-l <IPv4> ...]
 *  -b <IPv6>: local IPv6 address on the server side, used as the source
 *	address for PCP requests. Required.
 *
 *  -s <IPv6>: IPv6 address of the server, used as the destination address
 *	for PCP requests. Required.
 *
 *  -l <IPv4>: local IPv4 address on the client side; clients send
 *	PCP requests to this address. May be repeated; at least one is
 *	required.
 */

#include <sys/select.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <err.h>
#include <errno.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

/* Lifetime of a redirect entry, in seconds counted from its creation */

enum { TTL = 60 };

/*
 * Tail queue (doubly linked list) macros, adapted from BSD <sys/queue.h>.
 * A list is a head struct declared with ISC_TAILQ_HEAD plus an
 * ISC_TAILQ_ENTRY field embedded in each element; the macros below take
 * the name of that embedded field as their "field" argument.
 */

/*
 * Declare a list head.  tqh_first points at the first element (NULL when
 * the list is empty); tqh_last points at the tqe_next field of the last
 * element, or at tqh_first itself when the list is empty.  name may be
 * left empty to declare an anonymous head struct.
 */
#define ISC_TAILQ_HEAD(name, type)					\
	struct name {							\
		struct type *tqh_first;					\
		struct type **tqh_last;					\
	}

/*
 * Linkage embedded in each element.  tqe_next points at the next element
 * (NULL for the last one); tqe_prev points at the pointer that points at
 * this element (the previous element's tqe_next, or the head's tqh_first).
 */
#define ISC_TAILQ_ENTRY(type)						\
	struct {							\
		struct type *tqe_next;					\
		struct type **tqe_prev;					\
	}

/* Make head an empty list */
#define ISC_TAILQ_INIT(head)						\
	do {								\
		(head)->tqh_first = NULL;				\
		(head)->tqh_last = &((head)->tqh_first);		\
	} while (0)

/* First element of the list, NULL if the list is empty */
#define ISC_TAILQ_FIRST(head)						\
	((head)->tqh_first)

/*
 * Last element of the list.  As in BSD, this reinterprets tqh_last as a
 * pointer to a head struct, relying on the head and entry structs having
 * the same layout.  headname is the tag of the head struct, so this cannot
 * be applied to anonymous heads such as lans and rdrs; it is unused here.
 */
#define ISC_TAILQ_LAST(head, headname)					\
	(*(((struct headname *)((head)->tqh_last))->tqh_last))

/* Element following elm, NULL if elm is the last one */
#define ISC_TAILQ_NEXT(elm, field)					\
	((elm)->field.tqe_next)

/*
 * Iterate var over the list in order.  The next pointer is read from var
 * after the loop body runs, so the body must not remove or free var.
 */
#define ISC_TAILQ_FOREACH(var, head, field)				\
	for ((var) = ISC_TAILQ_FIRST((head));				\
	     (var) != NULL;						\
	     (var) = ISC_TAILQ_NEXT((var), field))

/* Insert is at tail: append elm in constant time */

#define ISC_TAILQ_INSERT(head, elm, field)				\
	do {								\
		ISC_TAILQ_NEXT((elm), field) = NULL;			\
		(elm)->field.tqe_prev = (head)->tqh_last;		\
		*(head)->tqh_last = (elm);				\
		(head)->tqh_last = &ISC_TAILQ_NEXT((elm), field);	\
	} while (0)

/*
 * Unlink elm from the list in constant time.  elm's own link fields are
 * left unchanged and elm is not freed; that is the caller's job.
 */
#define ISC_TAILQ_REMOVE(head, elm, field)				\
	do {								\
		if ((ISC_TAILQ_NEXT((elm), field)) != NULL)		\
			ISC_TAILQ_NEXT((elm), field)->field.tqe_prev =	\
				(elm)->field.tqe_prev;			\
		else							\
			(head)->tqh_last = (elm)->field.tqe_prev;	\
		*(elm)->field.tqe_prev = ISC_TAILQ_NEXT((elm), field);	\
	} while (0)

/*
 * Per-LAN state: one entry is created by setlan() for each -l option.
 * Each LAN has its own UDP socket bound to the given IPv4 address; client
 * requests are received on it and responses are sent back through it.
 */
struct lan {				/* per LAN structure */
	ISC_TAILQ_ENTRY(lan) chain;	/* tailq chaining */
	uint32_t addr;			/* IPv4 address on the LAN (network order) */
	int fd;				/* LAN socket */
};
ISC_TAILQ_HEAD(, lan) lans;		/* LANs tailq list */

/*
 * Redirect entry: the proxy state for one client request.
 *
 * An entry is keyed by the LAN the request arrived on, the client IPv4
 * address and source port (both kept in network byte order, as copied
 * from the sockaddr_in), and the exact request octets.  It owns a
 * connected UDP socket towards the server, a copy of the request and,
 * once the server has answered, a copy of the response that is replayed
 * to client retransmissions.
 *
 * reqlen/resplen are 16 bits wide: fromclient() accepts requests of up to
 * 1028 octets and fromserver() responses of up to 1024 octets, which
 * 8-bit fields silently truncated.
 */
struct rdr {				/* redirect entries */
	ISC_TAILQ_ENTRY(rdr) chain;	/* tailq chaining */
	const struct lan *lan;		/* LAN (requests from, responses to) */
	int fd;				/* server socket */
	uint32_t expire;		/* time(NULL) at which it is freed */
	uint32_t claddr;		/* client IPv4 address */
	uint16_t clport;		/* request source port */
	uint16_t reqlen;		/* cached request length */
	uint16_t resplen;		/* cached response size */
	uint8_t *request;		/* cached request */
	uint8_t *response;		/* cached response */
};
ISC_TAILQ_HEAD(, rdr) rdrs;		/* redirects tailq list, oldest first */

/* UDP port requests are received on (LANs) and sent to (server) */
enum { PCP_PORT = 44323 };

/* Server and local IPv6 addresses in text form, from -s and -b */
char *server, *local;

/*
 * Open a UDP socket bound to the IPv6 address srvlocal (port 0, i.e. an
 * ephemeral port) and connected to the IPv6 address srvaddr on PCP_PORT.
 *
 * Returns the socket descriptor.  Exits via err(3)/errx(3) when an
 * address does not parse or a socket call fails.
 */
int
setserver(const char *srvaddr, const char *srvlocal)
{
	struct sockaddr_in6 sa;
	int fd;

	fd = socket(PF_INET6, SOCK_DGRAM, IPPROTO_UDP);
	if (fd < 0)
		err(1, "socket6");

	memset(&sa, 0, sizeof(sa));
	sa.sin6_family = AF_INET6;
	if (inet_pton(AF_INET6, srvlocal, &sa.sin6_addr) <= 0)
		errx(1, "bad local \"%s\"", srvlocal);
	if (bind(fd, (struct sockaddr *) &sa, sizeof(sa)) < 0)
		err(1, "bind(%s)", srvlocal);

	memset(&sa, 0, sizeof(sa));
	sa.sin6_family = AF_INET6;
	if (inet_pton(AF_INET6, srvaddr, &sa.sin6_addr) <= 0)
		errx(1, "bad server \"%s\"", srvaddr);
	sa.sin6_port = htons(PCP_PORT);
	if (connect(fd, (struct sockaddr *) &sa, sizeof(sa)) < 0)
		err(1, "connect");
	return fd;
}

/*
 * Return the redirect entry for a client request, creating it if needed.
 *
 * An existing entry matches when it was created on the same LAN, for the
 * same client address and source port, and for a byte-identical request
 * of the same length: this is how retransmissions are recognised.  The
 * LAN is part of the key so that a reply always leaves through the socket
 * the request came in on.
 *
 * A new entry gets a fresh server socket (setserver()), an expire date
 * TTL seconds from now and a private copy of the request.  It is appended
 * at the tail of rdrs, which keeps the list ordered by expire date (as
 * long as the clock does not step backwards), as expire() requires.
 *
 * Never returns NULL: exits via err(3) on allocation failure, and
 * setserver() exits on socket errors.
 */
struct rdr *
create_rdr(const struct lan *lan, const struct sockaddr_in *from,
	   const uint8_t *request, uint16_t reqlen)
{
	struct rdr *r;

	ISC_TAILQ_FOREACH(r, &rdrs, chain) {
		if (r->lan != lan)
			continue;
		if (memcmp(&r->claddr, &from->sin_addr, 4) != 0)
			continue;
		if (memcmp(&r->clport, &from->sin_port, 2) != 0)
			continue;
		if (r->reqlen != reqlen)
			continue;
		if (memcmp(r->request, request, reqlen) == 0)
			return r;
	}

	r = calloc(1, sizeof(*r));
	if (r == NULL)
		err(1, "malloc");
	r->lan = lan;
	r->fd = setserver(server, local);
	r->expire = (uint32_t) time(NULL) + TTL;
	memcpy(&r->claddr, &from->sin_addr, 4);
	memcpy(&r->clport, &from->sin_port, 2);
	r->reqlen = reqlen;
	r->request = malloc(reqlen);
	if (r->request == NULL)
		err(1, "malloc");
	memcpy(r->request, request, reqlen);
	ISC_TAILQ_INSERT(&rdrs, r, chain);
	return r;
}

/*
 * Free every redirect whose expire date has been reached, closing its
 * server socket.  Only the head of rdrs is examined: the list is assumed
 * to be ordered by expire date (see create_rdr()).
 */
void
expire(void)
{
	struct rdr *r;
	uint32_t now;

	now = (uint32_t) time(NULL);
	while ((r = ISC_TAILQ_FIRST(&rdrs)) != NULL && r->expire <= now) {
		ISC_TAILQ_REMOVE(&rdrs, r, chain);
		free(r->request);
		free(r->response);
		(void) close(r->fd);
		free(r);
	}
}

/*
 * Create a LAN entry for the IPv4 address lan: open a UDP socket bound to
 * that address on PCP_PORT and append it to lans.  Exits via err(3)/errx(3)
 * on a bad address, allocation failure or socket error.
 */
void
setlan(const char *lan)
{
	struct lan *l;
	struct sockaddr_in clt;

	l = calloc(1, sizeof(*l));
	if (l == NULL)
		err(1, "malloc lan");
	if (inet_pton(AF_INET, lan, &l->addr) <= 0)
		errx(1, "bad lan \"%s\"", lan);
	memset(&clt, 0, sizeof(clt));
	clt.sin_family = AF_INET;
	clt.sin_addr.s_addr = l->addr;
	clt.sin_port = htons(PCP_PORT);
	l->fd = socket(PF_INET, SOCK_DGRAM, IPPROTO_UDP);
	if (l->fd < 0)
		err(1, "socket(%s)", lan);
	if (bind(l->fd, (struct sockaddr *) &clt, sizeof(clt)) < 0)
		err(1, "bind(%s)", lan);
	ISC_TAILQ_INSERT(&lans, l, chain);
}

/*
 * Handle a datagram from a client on LAN l.
 *
 * Datagrams shorter than 4 octets are dropped; longer than 1024 octets
 * are warned about and cut to 1028 octets.  If the server already
 * answered an identical request, the cached response is replayed to the
 * client; otherwise the request is forwarded on the redirect's server
 * socket.
 */
void
fromclient(const struct lan *l)
{
	struct rdr *r;
	uint8_t buf[1040];
	struct sockaddr_in from;
	socklen_t fromlen;
	int cc;

	memset(buf, 0, sizeof(buf));
	memset(&from, 0, sizeof(from));
	fromlen = sizeof(from);
	cc = recvfrom(l->fd, buf, sizeof(buf), 0,
		      (struct sockaddr *) &from, &fromlen);
	if (cc < 0) {
		warn("recvfrom");
		return;
	} else if (cc < 4) {
		warnx("underrun %d", cc);
		return;
	} else if (cc > 1024) {
		warnx("overrun %d", cc);
		/* just let 4 extra octets go through */
		if (cc > 1028)
			cc = 1028;
	}
	/* 4 <= cc <= 1028 here, so the 16-bit length is exact */
	r = create_rdr(l, &from, buf, (uint16_t) cc);

	/* Already answered: replay the cached response to the client */
	if (r->response != NULL) {
		struct sockaddr_in to;

		memset(&to, 0, sizeof(to));
		to.sin_family = AF_INET;
		memcpy(&to.sin_port, &r->clport, 2);
		memcpy(&to.sin_addr, &r->claddr, 4);
		if (sendto(r->lan->fd, r->response, r->resplen, 0,
			   (struct sockaddr *) &to, sizeof(to)) < 0)
			warn("sendto(cached)");
		return;
	}
	/* Not answered yet (new request or retransmission): forward it */
	if (send(r->fd, buf, cc, 0) < 0)
		warn("send6");
}

/*
 * Handle a datagram from the server on the socket of redirect r.
 *
 * Datagrams shorter than 8 octets or longer than 1024 octets are dropped
 * with a warning.  Otherwise the response is cached in r (a previously
 * cached response that differs is discarded with a warning) and relayed
 * to the client address/port of r through the socket of r's LAN.
 */
void
fromserver(struct rdr *r)
{
	struct sockaddr_in dst;
	uint8_t pkt[1040];
	int len;

	memset(pkt, 0, sizeof(pkt));
	len = recv(r->fd, pkt, sizeof(pkt), 0);
	if (len < 0) {
		warn("recv6");
		return;
	}
	if (len < 8) {
		warnx("underrun6 %d", len);
		return;
	}
	if (len > 1024) {
		/* should be dropped anyway */
		warnx("overrun6 %d", len);
		return;
	}

	/* Discard a cached response that is not identical to this one */
	if (r->response != NULL) {
		int same = (r->resplen == len) &&
		    (memcmp(r->response, pkt, len) == 0);

		if (!same) {
			warnx("response mismatch");
			free(r->response);
			r->response = NULL;
		}
	}

	/* Cache the response so client retransmissions can be answered */
	if (r->response == NULL) {
		r->resplen = len;
		r->response = malloc(r->resplen);
		if (r->response == NULL)
			err(1, "malloc");
		memcpy(r->response, pkt, r->resplen);
	}

	memset(&dst, 0, sizeof(dst));
	dst.sin_family = AF_INET;
	memcpy(&dst.sin_addr, &r->claddr, 4);
	memcpy(&dst.sin_port, &r->clport, 2);
	if (sendto(r->lan->fd, pkt, len, 0,
		   (struct sockaddr *) &dst, sizeof(dst)) < 0)
		warn("sendto(forward)");
}

/* Main */

int
main(int argc, char *argv[])
{
	struct rdr *r;
	struct lan *l;
	fd_set set;
	struct timeval tv;
	int opt, maxfd;
	extern char *optarg;
	extern int optind;

	ISC_TAILQ_INIT(&lans);
	ISC_TAILQ_INIT(&rdrs);

	while ((opt = getopt(argc, argv, "b:s:l:")) != -1)
		switch (opt) {
		case 'b':
			local = optarg;
			break;
		case 's':
			server = optarg;
			break;
		case 'l':
			setlan(optarg);
			break;
		default:
			errx(1, "usage: -b <local> -s <server> [-l <lan>]+");
		}
	if (optind != argc)
		errx(1, "extra arguments");
	if (server == NULL)
		errx(1, "server is mandatory");
	if (local == NULL)
		errx(1, "local is mandatory");
	if (ISC_TAILQ_FIRST(&lans) == NULL)
		errx(1, "at least one lan is mandatory");

	for (;;) {
		expire();
		tv.tv_sec = 1;
		tv.tv_usec = 0;
		maxfd = 0;
		FD_ZERO(&set);
		ISC_TAILQ_FOREACH(r, &rdrs, chain) {
			FD_SET(r->fd, &set);
			if (r->fd > maxfd)
				maxfd = r->fd;
		}
		ISC_TAILQ_FOREACH(l, &lans, chain) {
			FD_SET(l->fd, &set);
			if (l->fd > maxfd)
				maxfd = l->fd;
		}
		if (select(maxfd + 1, &set, NULL, NULL, &tv) < 0)
			err(1, "select");
		ISC_TAILQ_FOREACH(r, &rdrs, chain)
			if (FD_ISSET(r->fd, &set))
			    fromserver(r);
		ISC_TAILQ_FOREACH(l, &lans, chain)
			if (FD_ISSET(l->fd, &set))
				fromclient(l);
	}
	errx(1, "unreachable");
}
