/* Copyright (C) 2000-2002 Joakim Axelsson <gozem@linux.nu>
 *                         Patrick Schaaf <bof@bof.de>
 *                         Martin Josefsson <gandalf@wlug.westbo.se>
 * Copyright (C) 2003-2004 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

/* Kernel module to match an IP set. */

#include <linux/module.h>
#include <linux/ip.h>
#include <linux/skbuff.h>
#include <linux/version.h>

#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,16)
#include <linux/netfilter_ipv4/ip_tables.h>
#define xt_register_match	ipt_register_match
#define xt_unregister_match	ipt_unregister_match
#define xt_match		ipt_match
#else
#include <linux/netfilter/x_tables.h>
#endif
#include <linux/netfilter_ipv4/ip_set.h>
#include <linux/netfilter_ipv4/ipt_set.h>

/*
 * Test a single set against the packet and apply the per-set inversion.
 *
 * @info: set descriptor (set index and its flags array)
 * @skb:  packet being examined
 * @inv:  inversion flag for this set
 *
 * Returns !inv when ip_set_testip_kernel() reports a hit, otherwise inv
 * unchanged.
 */
static inline int
match_set(const struct ipt_set_info *info,
	  const struct sk_buff *skb,
	  int inv)
{
	return ip_set_testip_kernel(info->index, skb, info->flags) ? !inv : inv;
}

/*
 * "mset" match callback: the rule matches only if EVERY configured set
 * matches (each one possibly inverted via IPSET_MATCH_INV in flags[0]).
 * The configured sets occupy match_set[] from slot 0 up to the first slot
 * whose index is IP_SET_INVALID_ID, or MAXSET slots at most.
 *
 * Returns 1 on match, 0 otherwise. The kernel-version-specific extra
 * parameters (devices, offset, hotdrop, ...) are not used.
 */
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,0)
static int
and_match(const struct sk_buff *skb,
	  const struct net_device *in,
	  const struct net_device *out,
	  const void *matchinfo,
	  int offset,
	  const void *hdr,
	  u_int16_t datalen,
	  int *hotdrop)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,16)
static int
and_match(const struct sk_buff *skb,
	  const struct net_device *in,
	  const struct net_device *out,
	  const void *matchinfo,
	  int offset,
	  int *hotdrop)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
static int
and_match(const struct sk_buff *skb,
	  const struct net_device *in,
	  const struct net_device *out,
	  const void *matchinfo,
	  int offset,
	  unsigned int protoff,
	  int *hotdrop)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,23)
static int
and_match(const struct sk_buff *skb,
	  const struct net_device *in,
	  const struct net_device *out,
	  const struct xt_match *match,
	  const void *matchinfo,
	  int offset,
	  unsigned int protoff,
	  int *hotdrop)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,28)
static bool
and_match(const struct sk_buff *skb,
	  const struct net_device *in,
	  const struct net_device *out,
	  const struct xt_match *match,
	  const void *matchinfo,
	  int offset,
	  unsigned int protoff,
	  bool *hotdrop)
#else /* LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,28) */
static bool
and_match(const struct sk_buff *skb,
	  const struct xt_match_param *par)
#endif
{
	const struct ipt_mset_info_match *minfo;
	int slot;

#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,28)
	minfo = matchinfo;
#else
	minfo = par->matchinfo;
#endif
	/* Walk the in-use slots; a single non-matching set rejects the packet. */
	for (slot = 0;
	     slot < MAXSET && minfo->match_set[slot].index != IP_SET_INVALID_ID;
	     slot++) {
		const int inv = minfo->match_set[slot].flags[0] & IPSET_MATCH_INV;

		if (!match_set(&minfo->match_set[slot], skb, inv))
			return 0;
	}
	return 1;
}

/*
 * "mset2" match callback: the rule matches if ANY configured set matches
 * (each one possibly inverted via IPSET_MATCH_INV in flags[0]).
 * The configured sets occupy match_set[] from slot 0 up to the first slot
 * whose index is IP_SET_INVALID_ID, or MAXSET slots at most.
 *
 * Returns 1 on match, 0 otherwise. The kernel-version-specific extra
 * parameters (devices, offset, hotdrop, ...) are not used.
 */
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,0)
static int
or_match(const struct sk_buff *skb,
	 const struct net_device *in,
	 const struct net_device *out,
	 const void *matchinfo,
	 int offset,
	 const void *hdr,
	 u_int16_t datalen,
	 int *hotdrop)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,16)
static int
or_match(const struct sk_buff *skb,
	 const struct net_device *in,
	 const struct net_device *out,
	 const void *matchinfo,
	 int offset,
	 int *hotdrop)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
static int
or_match(const struct sk_buff *skb,
	 const struct net_device *in,
	 const struct net_device *out,
	 const void *matchinfo,
	 int offset,
	 unsigned int protoff,
	 int *hotdrop)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,23)
static int
or_match(const struct sk_buff *skb,
	 const struct net_device *in,
	 const struct net_device *out,
	 const struct xt_match *match,
	 const void *matchinfo,
	 int offset,
	 unsigned int protoff,
	 int *hotdrop)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,28)
static bool
or_match(const struct sk_buff *skb,
	 const struct net_device *in,
	 const struct net_device *out,
	 const struct xt_match *match,
	 const void *matchinfo,
	 int offset,
	 unsigned int protoff,
	 bool *hotdrop)
#else /* LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,28) */
static bool
or_match(const struct sk_buff *skb,
	 const struct xt_match_param *par)
#endif
{
	const struct ipt_mset_info_match *minfo;
	int slot;

#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,28)
	minfo = matchinfo;
#else
	minfo = par->matchinfo;
#endif
	/* Walk the in-use slots; the first matching set accepts the packet. */
	for (slot = 0;
	     slot < MAXSET && minfo->match_set[slot].index != IP_SET_INVALID_ID;
	     slot++) {
		const int inv = minfo->match_set[slot].flags[0] & IPSET_MATCH_INV;

		if (match_set(&minfo->match_set[slot], skb, inv))
			return 1;
	}
	return 0;
}

/*
 * Return values for checkentry(). Before 2.6.35 success is reported as a
 * non-zero boolean; from 2.6.35 on success is 0 and failure a negative
 * errno.
 */
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,35)
#define CHECK_OK	1
#define CHECK_FAIL	0
#else /* LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,35) */
#define CHECK_OK	0
#define CHECK_FAIL	(-EINVAL)
#endif

/*
 * Validate a new "mset"/"mset2" rule and take a reference on every set it
 * names (released again by destroy()).
 *
 * The rule must name at least one set. Sets occupy match_set[] from slot 0
 * up to the first slot whose index is IP_SET_INVALID_ID (or MAXSET slots).
 * On any failure, every reference taken so far is dropped before returning
 * CHECK_FAIL; on success CHECK_OK is returned.
 */
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,16)
static int
checkentry(const char *tablename,
	   const struct ipt_ip *ip,
	   void *matchinfo,
	   unsigned int matchsize,
	   unsigned int hook_mask)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
static int
checkentry(const char *tablename,
	   const void *inf,
	   void *matchinfo,
	   unsigned int matchsize,
	   unsigned int hook_mask)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,19)
static int
checkentry(const char *tablename,
	   const void *inf,
	   const struct xt_match *match,
	   void *matchinfo,
	   unsigned int matchsize,
	   unsigned int hook_mask)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,23)
static int
checkentry(const char *tablename,
	   const void *inf,
	   const struct xt_match *match,
	   void *matchinfo,
	   unsigned int hook_mask)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,28)
static bool
checkentry(const char *tablename,
	   const void *inf,
	   const struct xt_match *match,
	   void *matchinfo,
	   unsigned int hook_mask)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,35)
static bool
checkentry(const struct xt_mtchk_param *par)
#else /* LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,35) */
/*
 * Must return int here: with a bool return type CHECK_FAIL (-EINVAL)
 * would be converted to true and the errno lost.
 */
static int
checkentry(const struct xt_mtchk_param *par)
#endif
{
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,28)
	struct ipt_mset_info_match *info = matchinfo;
#else
	struct ipt_mset_info_match *info = par->matchinfo;
#endif
	int i, j;
	ip_set_id_t index;
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
	if (matchsize != IPT_ALIGN(sizeof(struct ipt_mset_info_match))) {
		ip_set_printk("invalid matchsize %u", matchsize);
		return CHECK_FAIL;
	}
#endif

	/* A rule has to reference at least one set. */
	if (info->match_set[0].index == IP_SET_INVALID_ID)
		return CHECK_FAIL;

	for (i = 0; i < MAXSET; i++) {
		if (info->match_set[i].index == IP_SET_INVALID_ID)
			break;
		index = ip_set_get_byindex(info->match_set[i].index);
		if (index == IP_SET_INVALID_ID) {
			ip_set_printk("Cannot find set indentified by id %u to match",
				      info->match_set[i].index);
			goto err;	/* no reference held for slot i */
		}
		/*
		 * NOTE(review): the last flags[] element is required to be
		 * zero, presumably as a terminator - confirm against ip_set.h.
		 */
		if (info->match_set[i].flags[IP_SET_MAX_BINDINGS] != 0) {
			ip_set_printk("That's nasty!");
			/*
			 * The reference for slot i was taken just above, but the
			 * unwind below only covers slots [0, i): release it here.
			 */
			ip_set_put_byindex(index);
			goto err;
		}
	}
	return CHECK_OK;
err:
	/* Drop the references taken for slots [0, i). */
	for (j = 0; j < i; j++)
		ip_set_put_byindex(info->match_set[j].index);
	return CHECK_FAIL;
}

/*
 * Release the set references taken by checkentry() when the rule is
 * removed. Slots are walked from 0; the first slot whose index is
 * IP_SET_INVALID_ID (or above) ends the list.
 */
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
static void destroy(void *matchinfo,
		    unsigned int matchsize)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,19)
static void destroy(const struct xt_match *match,
		    void *matchinfo,
		    unsigned int matchsize)
#elif LINUX_VERSION_CODE < KERNEL_VERSION(2,6,28)
static void destroy(const struct xt_match *match,
		    void *matchinfo)
#else /* LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,28) */
static void destroy(const struct xt_mtdtor_param *par)
#endif
{
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,28)
	struct ipt_mset_info_match *info = matchinfo;
#else
	struct ipt_mset_info_match *info = par->matchinfo;
#endif
	int i;
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
	if (matchsize != IPT_ALIGN(sizeof(struct ipt_mset_info_match))) {
		/* matchsize is unsigned int, hence %u */
		ip_set_printk("invalid matchsize %u", matchsize);
		return;
	}
#endif
	for (i = 0; i < MAXSET; i++) {
		if (info->match_set[i].index >= IP_SET_INVALID_ID)
			break;
		ip_set_put_byindex(info->match_set[i].index);
	}
}

/*
 * The two registered matches share checkentry()/destroy() and differ only
 * in how the configured sets are combined:
 *   "mset"  -> and_match(): every set must match
 *   "mset2" -> or_match():  any set may match
 * From 2.6.17 on the entries also carry the address family and the size
 * of the userspace match data.
 */
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,17)
static struct xt_match mset_match[] = {
	{
		.me		= THIS_MODULE,
		.name		= "mset",
		.match		= and_match,
		.checkentry	= checkentry,
		.destroy	= destroy,
	},
	{
		.me		= THIS_MODULE,
		.name		= "mset2",
		.match		= or_match,
		.checkentry	= checkentry,
		.destroy	= destroy,
	},
};
#else /* LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,17) */
static struct xt_match mset_match[] = {
	{
		.me		= THIS_MODULE,
		.name		= "mset",
		.family		= AF_INET,
		.matchsize	= sizeof(struct ipt_mset_info_match),
		.match		= and_match,
		.checkentry	= checkentry,
		.destroy	= destroy,
	},
	{
		.me		= THIS_MODULE,
		.name		= "mset2",
		.family		= AF_INET,
		.matchsize	= sizeof(struct ipt_mset_info_match),
		.match		= or_match,
		.checkentry	= checkentry,
		.destroy	= destroy,
	},
};
#endif

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
MODULE_DESCRIPTION("iptables IP set match module");

static int __init ipt_mset_init(void)
{
	return xt_register_matches(mset_match, ARRAY_SIZE(mset_match));
}

static void __exit ipt_mset_fini(void)
{
	xt_unregister_matches(mset_match, ARRAY_SIZE(mset_match));
}

module_init(ipt_mset_init);
module_exit(ipt_mset_fini);
