/*
 * <:copyright-BRCM:2017:DUAL/GPL:standard 
 * 
 *    Copyright (c) 2017 Broadcom 
 *    All Rights Reserved
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License, version 2, as published by
 * the Free Software Foundation (the "GPL").
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * 
 * A copy of the GPL is available at http://www.broadcom.com/licenses/GPLv2.php, or by
 * writing to the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
 * Boston, MA 02111-1307, USA.
 * 
 * :>
 */

#include <linux/kernel.h>
#include <linux/types.h>
#include <linux/percpu.h>
#include <linux/skbuff.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <linux/if_ether.h>
#include <linux/etherdevice.h>
#include <linux/netfilter_ipv6.h>
#include <linux/netlink.h>
#include <linux/proc_fs.h>
#include <linux/udp.h>
#include <net/dst.h>
#include <net/neighbour.h>
#include <net/netfilter/nf_conntrack_acct.h>
#include <net/netfilter/nf_conntrack_ecache.h>

#include <linux/dpi.h>
#include <linux/blog.h>
#include <linux/bcm_netlink.h>
#include <bcmdpi.h>

#include <tdts.h>

#include "dpi_local.h"

/* --- types and constants --- */
/* Default classification window: once a flow's conntrack has accounted more
 * packets than this (both directions), classification is stopped. The live
 * value is dpi_max_pkt, which userspace can change with DPI_NL_MAXPKT. */
#define DPI_DEFAULT_MAX_PKT	15
/* Device identification for a flow stops once its conntrack has accounted
 * at least this many packets. */
#define DPI_DEVICE_MAX_PKT	3

/* A notification queued by dpi_queue_notify() and delivered later from the
 * notify_work handler. */
struct dpi_event {
	unsigned long		event;	/* event code, e.g. DPI_NOTIFY_DEVICE */
	void			*data;	/* event payload (struct dpi_dev for device events) */
	atomic_t		*ref;	/* refcount held while queued, or NULL */
	struct list_head	node;	/* link in notify_list */
};

/* ----- local functions ----- */
/* forward declarations: referenced by the hooks[] table and notify_work */
static unsigned int dpi_nf_hook(void *priv, struct sk_buff *skb,
				const struct nf_hook_state *state);
static unsigned int dpi_nf_block(void *priv, struct sk_buff *skb,
				 const struct nf_hook_state *state);
static int dpi_queue_notify(unsigned long event, void *data);
static void dpi_process_notifications(struct work_struct *work);

/* ----- global variables ----- */
/* lookup/classification counters, shown in /proc/dpi/stats */
struct dpi_stats dpi_stats;
/* the /proc/dpi directory, created in dpicore_init() */
struct proc_dir_entry *dpi_dir;
/* classification on/off switch (DPI_NL_ENABLE / DPI_NL_DISABLE); note that
 * dpi_nf_block() does not consult it */
int dpi_enabled = 1;

/* ----- local variables ----- */
/* Build a struct nf_hook_ops initializer: hook function @_fun for protocol
 * family @_pf at hook point @_hooknum with priority @_priority. */
#define DECLARE_HOOK(_fun, _pf, _hooknum, _priority) \
	{ \
		.hook		= _fun, \
		.pf		= _pf, \
		.hooknum	= _hooknum, \
		.priority	= _priority, \
	}
/*
 * Netfilter hooks registered in every network namespace by
 * dpi_pernet_init(). dpi_nf_hook classifies traffic at FORWARD, LOCAL_IN and
 * LOCAL_OUT; dpi_nf_block drops packets of blocked conntracks at
 * PRE_ROUTING. Each point is hooked for IPv4, bridge and (if enabled) IPv6.
 */
static struct nf_hook_ops hooks[] __read_mostly = {
	/* Forward */
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_IPV4, NF_INET_FORWARD, NF_IP_PRI_FILTER),
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_BRIDGE, NF_INET_FORWARD, NF_IP_PRI_FILTER),
#if defined(CONFIG_IPV6)
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_IPV6, NF_INET_FORWARD, NF_IP6_PRI_FILTER),
#endif
	/* Local in */
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_IPV4, NF_INET_LOCAL_IN, NF_IP_PRI_FILTER),
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_BRIDGE, NF_INET_LOCAL_IN, NF_IP_PRI_FILTER),
#if defined(CONFIG_IPV6)
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_IPV6, NF_INET_LOCAL_IN, NF_IP6_PRI_FILTER),
#endif
	/* Local out */
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_IPV4, NF_INET_LOCAL_OUT, NF_IP_PRI_FILTER),
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_BRIDGE, NF_INET_LOCAL_OUT, NF_IP_PRI_FILTER),
#if defined(CONFIG_IPV6)
	DECLARE_HOOK(dpi_nf_hook, NFPROTO_IPV6, NF_INET_LOCAL_OUT, NF_IP6_PRI_FILTER),
#endif
	/* Pre-routing (blocking) */
	DECLARE_HOOK(dpi_nf_block, NFPROTO_IPV4, NF_INET_PRE_ROUTING, NF_IP_PRI_MANGLE),
	DECLARE_HOOK(dpi_nf_block, NFPROTO_BRIDGE, NF_INET_PRE_ROUTING, NF_IP_PRI_MANGLE),
#if defined(CONFIG_IPV6)
	DECLARE_HOOK(dpi_nf_block, NFPROTO_IPV6, NF_INET_PRE_ROUTING, NF_IP6_PRI_MANGLE),
#endif
};
/* Callbacks handed to the DPI core by dpi_core_hooks_register(). .delete is
 * dpi_ct_destroy (defined elsewhere; presumably releases the DPI references
 * stored on a conntrack — TODO confirm). */
static struct dpi_core_hooks dpi_core_hooks = {
	.delete	= dpi_ct_destroy,
};

/* NETLINK_DPI kernel sockets, one entry per network namespace */
static LIST_HEAD(nl_bcast_list);
/* current classification window in packets (see DPI_DEFAULT_MAX_PKT) */
static int dpi_max_pkt = DPI_DEFAULT_MAX_PKT;
/* per-CPU engine parameter block, used between get_cpu_var() and
 * put_cpu_var() in the classification path */
static DEFINE_PER_CPU(tdts_pkt_parameter_t, pkt_params);
/* deferred notification queue: entries are added/removed under notify_lock
 * and drained by notify_work */
static DECLARE_WORK(notify_work, dpi_process_notifications);
static LIST_HEAD(notify_list);
static DEFINE_SPINLOCK(notify_lock);

/* nl_bcast_list entry: the DPI netlink socket of one namespace */
struct nl_bcast_entry {
	struct sock		*socket;
	struct net		*net;
	struct list_head	node;
};

/* ----- local functions ----- */
static u64 pkt_count(struct nf_conn *ct)
{
	struct nf_conn_acct *acct;
	struct nf_conn_counter *ctr;

	if (!ct)
		return 0;

	acct = nf_conn_acct_find(ct);
	if (!acct)
		return 0;

	ctr = acct->counter;
	return atomic64_read(&ctr[0].packets) + atomic64_read(&ctr[1].packets);
}

/* Non-zero when @addr is in ff00::/8, i.e. its first octet is 0xff. */
static inline int ipv6_is_multicast(struct in6_addr addr)
{
	return addr.s6_addr[0] == 0xff;
}

/*
 * ignore_device() - whether device classification should be skipped.
 *
 * Multicast flows are skipped, except that DHCP/DHCPv6 packets (including
 * broadcasts) are never skipped.
 */
static inline int ignore_device(struct dpi_classify_parms *p)
{
	if (is_dhcp(p->skb) || is_dhcp6(p->skb))
		return 0;

	return p->is_multicast;
}

/*
 * calculate_lookup_flags() - narrow the lookups requested by the hook.
 * @p:            classification state (skb, ct and flags already loaded)
 * @lookup_flags: SW_* lookups requested for this packet by hook_flags()
 *
 * Returns 0 when the packet has no MAC header or an all-zero source MAC.
 * Otherwise SW_DEVID is dropped for packets received on a WAN device and,
 * when a conntrack exists, lookups whose classification is already finished
 * (STOP_CLASSIFY bits in p->flags) are dropped as well.
 */
static int calculate_lookup_flags(struct dpi_classify_parms *p,
				  int lookup_flags)
{
	/* only dereferenced after skb_mac_header_was_set() succeeds */
	struct ethhdr *h = eth_hdr(p->skb);

	/* don't classify packets which have no source MAC */
	if (!skb_mac_header_was_set(p->skb) || is_zero_ether_addr(h->h_source))
		return 0;

	/* don't classify WAN-side devices */
	if (is_netdev_wan(p->skb->dev))
		lookup_flags &= ~SW_DEVID;

	if (!p->ct)
		goto out;

	/* filter out classification based on current classification status.
	 * App lookups continue after APPID_STOP while blog_request() reports
	 * FLOWTRACK_ALG_HELPER (presumably: an ALG helper is attached to the
	 * conntrack — TODO confirm). */
	if (test_bit(DPI_APPID_STOP_CLASSIFY_BIT, &p->flags) &&
	    !blog_request(FLOWTRACK_ALG_HELPER, p->ct, 0, 0))
		lookup_flags &= ~SW_APP;
	if (test_bit(DPI_DEVID_STOP_CLASSIFY_BIT, &p->flags))
		lookup_flags &= ~SW_DEVID;
#ifdef DPI_URL_RECORD
	if (test_bit(DPI_URL_STOP_CLASSIFY_BIT, &p->flags))
#endif
		/* without DPI_URL_RECORD the test above compiles away, so URL
		 * lookups are always dropped for flows with a conntrack */
		lookup_flags &= ~SW_URL_QUERY;

out:
	return lookup_flags;
}

/*
 * update_classification_stats() - account a flow's classification outcome.
 * @p:          classification state of the current packet
 * @classified: non-zero after a classification pass; zero when the flow's
 *              classification window has expired
 *
 * After a pass, the "classified" counter of each type is bumped when that
 * type's STOP_CLASSIFY bit was newly set during the pass (relative to
 * p->old_flags). When the window expires, the "unclassified" counter of each
 * type is bumped when that type never reached STOP_CLASSIFY.
 */
static void update_classification_stats(struct dpi_classify_parms *p,
					int classified)
{
	if (classified) {
		/* bits changed since dpi_classify_prepare() */
		unsigned long diff = p->flags ^ p->old_flags;

		if (test_bit(DPI_APPID_STOP_CLASSIFY_BIT, &diff))
			dpi_stats.apps.classified++;
		if (test_bit(DPI_DEVID_STOP_CLASSIFY_BIT, &diff))
			dpi_stats.devs.classified++;
#ifdef DPI_URL_RECORD
		if (test_bit(DPI_URL_STOP_CLASSIFY_BIT, &diff))
			dpi_stats.urls.classified++;
#endif
		return;
	}

	/*
	 * Window expired: look at the accumulated flags, not at
	 * flags ^ old_flags. In the window-expired path nothing modifies
	 * p->flags before this call, so the diff is always zero. Testing it
	 * would count every type as unclassified, including types that were
	 * classified on earlier packets.
	 */
	if (!test_bit(DPI_APPID_STOP_CLASSIFY_BIT, &p->flags))
		dpi_stats.apps.unclassified++;
	if (!test_bit(DPI_DEVID_STOP_CLASSIFY_BIT, &p->flags))
		dpi_stats.devs.unclassified++;
#ifdef DPI_URL_RECORD
	if (!test_bit(DPI_URL_STOP_CLASSIFY_BIT, &p->flags))
		dpi_stats.urls.unclassified++;
#endif
}

static void device_save_classification(struct dpi_classify_parms *p)
{
	tdts_pkt_parameter_t *pkt_param = p->pkt_param;
	struct dpi_dev *dev	= p->dev;
	u16 prio		= TDTS_PKT_PARAMETER_RES_DEVID_PRIO(pkt_param);
	int notify		= 0;

	if (dev->prio < prio)
		return;
	if (!TDTS_PKT_PARAMETER_RES_DEVID_DEV_ID(pkt_param))
		return;

	/* save device classification information */
#define update(field, val)			\
	({					\
		typeof(field) _orig = (field);	\
		(field) = (val);		\
		(field != _orig);		\
	})

	notify |= update(dev->dev_id,	TDTS_PKT_PARAMETER_RES_DEVID_DEV_ID(pkt_param));
	notify |= update(dev->category,	TDTS_PKT_PARAMETER_RES_DEVID_DEV_CAT_ID(pkt_param));
	notify |= update(dev->family,	TDTS_PKT_PARAMETER_RES_DEVID_DEV_FAMILY_ID(pkt_param));
	notify |= update(dev->os,	TDTS_PKT_PARAMETER_RES_DEVID_OS_NAME_ID(pkt_param));
	notify |= update(dev->os_class,	TDTS_PKT_PARAMETER_RES_DEVID_OS_CLASS_ID(pkt_param));
	notify |= update(dev->vendor,	TDTS_PKT_PARAMETER_RES_DEVID_VENDOR_ID(pkt_param));
	notify |= update(dev->prio,	prio);

#undef update

	pr_debug("devid:%08x category:%d family:%d os:%d os_class:%d vendor:%d prio:%d hostname:%s\n",
		 dev->dev_id, dev->category, dev->family, dev->os,
		 dev->os_class, dev->vendor, dev->prio, dev->hostname);

	if (notify)
		dpi_queue_notify(DPI_NOTIFY_DEVICE, dev);
}

/*
 * dpi_classify_device() - resolve the flow's device and record the engine's
 * device identification result on it.
 *
 * Does nothing when no lookups are requested or ignore_device() says so.
 * If p->dev is not known yet, it is found or allocated from a MAC address
 * (see below) and a reference is taken on it. That reference is later either
 * stored on the conntrack or dropped by save_classifications(). The engine
 * result is only consumed when SW_DEVID is in the lookup flags. Device
 * identification stops for the flow once its conntrack has accounted
 * DPI_DEVICE_MAX_PKT packets.
 */
static void dpi_classify_device(struct dpi_classify_parms *p)
{
	const struct dst_entry *dst	= skb_dst(p->skb);
	struct neighbour *n		= NULL;

	if (!p->lookup_flags || ignore_device(p))
		return;

	/* find or allocate device info */
	if (!p->dev) {
		struct dpi_ip *ip = dpi_ip_find_by_skb(p->skb);

		/*
		 * For saved IPs, use the saved MAC.
		 * For packets received on a LAN (non-WAN) device, use the
		 * source MAC.
		 * For packets received from the WAN, we have to try a
		 * neighbour lookup on the packet's route to find an
		 * appropriate MAC.
		 */
		if (ip) {
			p->dev = dpi_dev_find_or_alloc(ip->mac);
		} else if (!is_netdev_wan(p->skb->dev)) {
			u8 *mac = eth_hdr(p->skb)->h_source;
			p->dev = dpi_dev_find_or_alloc(mac);
		} else {
			rcu_read_lock_bh();
			if (dst)
				n = dst_neigh_lookup_skb(dst, p->skb);
			/* only a neighbour in a NUD_VALID state is trusted to
			 * have a usable hardware address.
			 * NOTE(review): n->ha is read without n->ha_lock. */
			if (n && (n->nud_state & NUD_VALID))
				p->dev = dpi_dev_find_or_alloc(n->ha);
			/* drop the reference taken by the lookup */
			if (n)
				neigh_release(n);
			rcu_read_unlock_bh();
		}

		if (!p->dev)
			return;

		atomic_inc(&p->dev->refcount);
	}

	if (!(p->lookup_flags & SW_DEVID))
		return;

	if (p->pkt_param) {
		dpi_stats.devs.lookups++;

		if (tdts_check_pkt_parameter_res(p->pkt_param,
						 TDTS_RES_TYPE_DEVID))
			dpi_stats.devs.hits++;
		else
			dpi_stats.devs.misses++;

		/* called on hits and misses alike; it ignores results without
		 * a device id */
		device_save_classification(p);
	}

	if (p->ct && pkt_count(p->ct) >= DPI_DEVICE_MAX_PKT)
		set_bit(DPI_DEVID_STOP_CLASSIFY_BIT, &p->flags);
}

/*
 * dpi_classify_app() - derive the flow's application id and bind p->app.
 *
 * Requires SW_APP in the lookup flags and an engine result (p->pkt_param).
 * The app id comes from dpi_classify_app_test() when that returns non-zero
 * (presumably test/override rules set up by dpi_test_init() — TODO
 * confirm). Otherwise it comes from the engine's category/app result. A
 * different id than the currently bound app releases that app
 * (reclassification). The APPID_* flags follow the engine's
 * final/nomore/noint indications. Lookups, hits and misses are counted.
 */
static void dpi_classify_app(struct dpi_classify_parms *p)
{
	u32 app_id;

	if (!(p->lookup_flags & SW_APP) || (!p->pkt_param))
		return;

	dpi_stats.apps.lookups++;

	/* update flags: the engine says no (more) app results will come */
	if (TDTS_PKT_PARAMETER_RES_APPID_CHECK_NOMORE(p->pkt_param) ||
	    TDTS_PKT_PARAMETER_RES_APPID_CHECK_NOINT(p->pkt_param)) {
		set_bit(DPI_APPID_STOP_CLASSIFY_BIT, &p->flags);
		clear_bit(DPI_APPID_ONGOING_BIT, &p->flags);
	}

	app_id = dpi_classify_app_test(p->skb, &p->flags, p->lookup_flags);
	if (app_id)
		goto classify_save;

	if (!tdts_check_pkt_parameter_res(p->pkt_param, TDTS_RES_TYPE_APPID)) {
		dpi_stats.apps.misses++;
		return;
	}
	app_id = dpi_make_app_id(TDTS_PKT_PARAMETER_RES_APPID_CAT_ID(p->pkt_param),
				 TDTS_PKT_PARAMETER_RES_APPID_APP_ID(p->pkt_param),
				 0);
	if (!app_id) {
		dpi_stats.apps.misses++;
		return;
	}

classify_save:
	/* if this is a reclassification, set the current app to NULL (dropping
	 * its reference) to search again for a new entry; the appinst is
	 * re-evaluated later by dpi_classify_appinst() */
	if (p->app && p->app->app_id != app_id) {
		pr_debug("reclassify from %08x to %08x\n", p->app->app_id,
			 app_id);
		atomic_dec(&p->app->refcount);
		p->app = NULL;
	}

	/* find or allocate application info; the reference taken here is
	 * owned by p */
	if (!p->app) {
		p->app = dpi_app_find_or_alloc(app_id);
		if (!p->app)
			return;
		atomic_inc(&p->app->refcount);
	}

	set_bit(DPI_APPID_IDENTIFIED_BIT, &p->flags);
	if (TDTS_PKT_PARAMETER_RES_APPID_CHECK_FINAL(p->pkt_param)) {
		set_bit(DPI_APPID_FINAL_BIT, &p->flags);
		set_bit(DPI_APPID_STOP_CLASSIFY_BIT, &p->flags);
		clear_bit(DPI_APPID_ONGOING_BIT, &p->flags);
	} else {
		set_bit(DPI_APPID_ONGOING_BIT, &p->flags);
	}
	dpi_stats.apps.hits++;

	pr_debug("appid:%08x [app:%u(%s) cat:%u(%s) beh:%u(%s)] flags:%lx [final:%d nomore:%d noint:%d]\n",
		 p->app->app_id,
		 TDTS_PKT_PARAMETER_RES_APPID_APP_ID(p->pkt_param),
		 TDTS_PKT_PARAMETER_RES_APPID_APP_NAME(p->pkt_param),
		 TDTS_PKT_PARAMETER_RES_APPID_CAT_ID(p->pkt_param),
		 TDTS_PKT_PARAMETER_RES_APPID_CAT_NAME(p->pkt_param),
		 TDTS_PKT_PARAMETER_RES_APPID_BEH_ID(p->pkt_param),
		 TDTS_PKT_PARAMETER_RES_APPID_BEH_NAME(p->pkt_param),
		 p->flags,
		 TDTS_PKT_PARAMETER_RES_APPID_CHECK_FINAL(p->pkt_param) ? 1 : 0,
		 TDTS_PKT_PARAMETER_RES_APPID_CHECK_NOMORE(p->pkt_param) ? 1 : 0,
		 TDTS_PKT_PARAMETER_RES_APPID_CHECK_NOINT(p->pkt_param) ? 1 : 0);
}

/*
 * dpi_classify_appinst() - keep the flow's (app, device) instance in sync.
 *
 * A cached appinst whose app differs from p->app is released. If no appinst
 * is cached and the device is known, one is found or allocated for
 * (p->app, p->dev), a reference is taken on success, and a conntrack update
 * is requested.
 */
static void dpi_classify_appinst(struct dpi_classify_parms *p)
{
	if (p->appinst) {
		/* still describes the current app: nothing to do */
		if (p->appinst->app == p->app)
			return;

		/* app was reclassified: drop the stale instance */
		atomic_dec(&p->appinst->refcount);
		p->appinst = NULL;
	}

	if (!p->dev)
		return;

	p->appinst = dpi_appinst_find_or_alloc(p->app, p->dev);
	if (p->appinst)
		atomic_inc(&p->appinst->refcount);
	p->nfct_update = 1;
}

/*
 * dpi_classify_url() - bind the engine's domain result to p->url.
 *
 * Compiled to an empty function unless DPI_URL_RECORD is defined. Requires
 * SW_URL_QUERY and an engine result. Once a URL entry is bound (or already
 * was), URL classification is marked finished for the flow.
 */
static void dpi_classify_url(struct dpi_classify_parms *p)
{
#ifdef DPI_URL_RECORD
	char *host;
	int host_len;

	if (!p->pkt_param || !(p->lookup_flags & SW_URL_QUERY))
		return;

	dpi_stats.urls.lookups++;

	host     = TDTS_PKT_PARAMETER_RES_URL_DOMAIN(p->pkt_param);
	host_len = TDTS_PKT_PARAMETER_RES_URL_DOMAIN_LEN(p->pkt_param);
	if (host == NULL || host_len == 0) {
		dpi_stats.urls.misses++;
		return;
	}

	if (p->url == NULL) {
		p->url = dpi_url_find_or_alloc(host, host_len);
		if (p->url == NULL)
			return;
		atomic_inc(&p->url->refcount);
	}

	dpi_stats.urls.hits++;
	set_bit(DPI_URL_STOP_CLASSIFY_BIT, &p->flags);
#endif /* DPI_URL_RECORD */
}

/* Out-of-band parsing: hand DHCP/DHCPv6 packets to dpi_parse_dhcp(). */
static void dpi_classify_oob(struct dpi_classify_parms *p)
{
	struct sk_buff *skb = p->skb;

	if (!is_dhcp(skb) && !is_dhcp6(skb))
		return;

	dpi_parse_dhcp(p);
}

static void save_classifications(struct dpi_classify_parms *p)
{
	if (!p->ct) {
		if (p->appinst)
			atomic_dec(&p->appinst->refcount);
		if (p->dev)
			atomic_dec(&p->dev->refcount);
		if (p->app)
			atomic_dec(&p->app->refcount);
		if (p->url)
			atomic_dec(&p->url->refcount);
		return;
	}

	/* save new classification data */
	p->ct->bcm_ext.dpi.appinst	= p->appinst;
	p->ct->bcm_ext.dpi.dev		= p->dev;
	p->ct->bcm_ext.dpi.app		= p->app;
	p->ct->bcm_ext.dpi.url		= p->url;
	p->ct->bcm_ext.dpi.flags	= p->flags;
}

/*
 * send_nfct_update() - report the flow's conntrack (and its master
 * conntrack, if any) to userspace when p->nfct_update is set.
 *
 * Empty unless CONFIG_NF_CONNTRACK_EVENTS is enabled.
 */
static void send_nfct_update(struct dpi_classify_parms *p)
{
#if IS_ENABLED(CONFIG_NF_CONNTRACK_EVENTS)
	struct nf_conn *ct = p->ct;
	struct nf_conn *master;
	u32 portid;

	if (!ct || !p->nfct_update)
		return;

	portid = NETLINK_CB(p->skb).portid;

	pr_debug("ct %px reporting to userspace\n", ct);
	dpi_nf_ct_event_report(ct, portid);

	master = ct->master;
	if (master) {
		pr_debug("ct %px reporting to userspace\n", master);
		dpi_nf_ct_event_report(master, portid);
	}
#endif
}

/* Trend engine routine used below to drop a TCP connection from the
 * engine's internal connection table. */
extern int
tdts_shell_tcp_conn_remove(u8 ip_ver, u8 *sip, u8 *dip, u16 sport, u16 dport);

/*
 * dpi_stop_classification() - permanently end classification of a flow.
 *
 * Sets DPI_CLASSIFICATION_STOP_BIT in p->flags; the caller must persist the
 * flags to the conntrack. For TCP over IPv4/IPv6 the connection is also
 * removed from the Trend engine's table. That removal is best-effort, so its
 * return value is deliberately ignored.
 */
static void dpi_stop_classification(struct dpi_classify_parms *p)
{
	struct sk_buff *skb = p->skb;

	/* ignore all further classifications for this flow */
	set_bit(DPI_CLASSIFICATION_STOP_BIT, &p->flags);

	/* remove connection from Trend's internal database */
	if (cf_l3v4(skb, IPPROTO_TCP, 0, 0))
		(void)tdts_shell_tcp_conn_remove(ip_hdr(skb)->version,
						 (u8 *)&ip_hdr(skb)->saddr,
						 (u8 *)&ip_hdr(skb)->daddr,
						 ntohs(tcp_hdr(skb)->source),
						 ntohs(tcp_hdr(skb)->dest));
	else if (cf_l3v6(skb, IPPROTO_TCP, 0, 0))
		(void)tdts_shell_tcp_conn_remove(ipv6_hdr(skb)->version,
						 (u8 *)&ipv6_hdr(skb)->saddr,
						 (u8 *)&ipv6_hdr(skb)->daddr,
						 ntohs(tcp_hdr(skb)->source),
						 ntohs(tcp_hdr(skb)->dest));
}

static void dpi_classify_prepare(struct dpi_classify_parms *p,
				 struct sk_buff *skb, int lookup_flags)
{
	enum ip_conntrack_info ctinfo;
	struct nf_conn *ctm;

	memset(p, 0, sizeof(*p));
	p->skb	= skb;
	p->ct	= nf_ct_get(skb, &ctinfo);

	/* populate classification data based on the conntrack entry */
	if (p->ct) {
		p->appinst	= p->ct->bcm_ext.dpi.appinst;
		p->dev		= p->ct->bcm_ext.dpi.dev;
		p->app		= p->ct->bcm_ext.dpi.app;
		p->url		= p->ct->bcm_ext.dpi.url;
		p->flags	= p->ct->bcm_ext.dpi.flags;

		p->ct_stats_available = nf_ct_net(p->ct)->ct.sysctl_acct;
	}

	/* if the parent conntrack has classification data and we don't, copy
	 * the classification data to the child */
	ctm = p->ct ? p->ct->master : NULL;
	if (ctm) {
		p->appinst	= p->appinst ? : ctm->bcm_ext.dpi.appinst;
		p->dev		= p->dev ? : ctm->bcm_ext.dpi.dev;
		p->app		= p->app ? : ctm->bcm_ext.dpi.app;
		p->url		= p->url ? : ctm->bcm_ext.dpi.url;
		p->flags	= p->flags ? : ctm->bcm_ext.dpi.flags;
		p->nfct_update	= 1;
	}

	/* check for mcast flows */
	if (skb->protocol == htons(ETH_P_IP)) {
		p->is_multicast = ipv4_is_multicast(ip_hdr(skb)->saddr);
		p->is_multicast |= ipv4_is_multicast(ip_hdr(skb)->daddr);
	} else if (skb->protocol == htons(ETH_P_IPV6)) {
		p->is_multicast = ipv6_is_multicast(ipv6_hdr(skb)->saddr);
		p->is_multicast |= ipv6_is_multicast(ipv6_hdr(skb)->daddr);
	}

	p->old_flags	= p->flags;
	p->lookup_flags	= calculate_lookup_flags(p, lookup_flags);
}

static void dpi_invoke_engine(struct dpi_classify_parms *p)
{
	p->pkt_param = &get_cpu_var(pkt_params);

	memset(p->pkt_param, 0, sizeof(*p->pkt_param));
	tdts_init_pkt_parameter(p->pkt_param, p->lookup_flags, 0);

	if (tdts_shell_dpi_l3_skb(p->skb, p->pkt_param)) {
		if (p->pkt_param->results.pkt_decoder_verdict == -1)
			dpi_stats.engine_errors++;
		put_cpu_var(pkt_params);
		p->pkt_param = NULL;
	}
}

/*
 * dpi_classify() - classify one packet of a flow and persist the results.
 *
 * Flows marked DPI_CLASSIFICATION_STOP_BIT are skipped. Once the flow has
 * more than dpi_max_pkt packets, unclassified stats are recorded and
 * classification is stopped. Otherwise:
 *  - the engine is invoked (if there is anything to look up);
 *  - the per-type classifiers update @p;
 *  - the result is saved to the conntrack.
 * Flows still being classified are excluded from blog acceleration. Fully
 * classified flows are stopped.
 */
static void dpi_classify(struct dpi_classify_parms *p)
{
	if (test_bit(DPI_CLASSIFICATION_STOP_BIT, &p->flags))
		return;

	/* if we have classified too many packets, ignore further */
	if (pkt_count(p->ct) > dpi_max_pkt) {
		pr_debug("ct %px exceeded classification window %d, flags %lx\n",
			 p->ct, dpi_max_pkt, p->flags);

		/* save stats and stop classification */
		update_classification_stats(p, 0);
		dpi_stop_classification(p);
		/* guard like the stop path below rather than relying on
		 * pkt_count() returning 0 for a NULL conntrack */
		if (p->ct)
			p->ct->bcm_ext.dpi.flags = p->flags;
		p->nfct_update = 1;
		return;
	}

	if (p->lookup_flags) {
		/* classify packet in DPI Engine; a successful run leaves
		 * p->pkt_param holding this CPU's parameter block until it is
		 * released below */
		dpi_invoke_engine(p);

		dpi_stats.total_lookups++;
		pr_debug("ct:%px flags:%lx lookup:%x pktcnt:%llu\n",
			 p->ct, p->flags, p->lookup_flags, pkt_count(p->ct));
	}

	/* use classification results to update local data */
	dpi_classify_oob(p);
	dpi_classify_device(p);
	dpi_classify_app(p);
	dpi_classify_url(p);
	dpi_classify_appinst(p);

	if (p->pkt_param) {
		put_cpu_var(pkt_params);
		p->pkt_param = NULL;
	}

	save_classifications(p);
	update_classification_stats(p, 1);

	/* if we have not yet identified everything and if it's NOT
	 * a mcast flow, skip accelerating */
	if (p->ct_stats_available && p->lookup_flags && !p->is_multicast) {
		if (!test_bit(DPI_APPID_STOP_CLASSIFY_BIT, &p->flags) ||
		    !test_bit(DPI_DEVID_STOP_CLASSIFY_BIT, &p->flags))
			blog_skip(p->skb, blog_skip_reason_dpi);
	}

	/* if we have finished identifying everything, explicitly stop further
	 * classification calls and remove from Trend internal list (if TCP) */
	if (test_bit(DPI_APPID_STOP_CLASSIFY_BIT, &p->flags) &&
	    test_bit(DPI_DEVID_STOP_CLASSIFY_BIT, &p->flags) &&
	    test_bit(DPI_URL_STOP_CLASSIFY_BIT, &p->flags)) {
		dpi_stop_classification(p);
		if (p->ct)
			p->ct->bcm_ext.dpi.flags = p->flags;
		p->nfct_update = 1;
	}
}

/* Final step after dpi_classify(): push any pending conntrack update to
 * userspace. */
static void dpi_classify_post(struct dpi_classify_parms *p)
{
	send_nfct_update(p);
}

/*
 * hook_flags() - SW_* lookups to request for @skb seen at netfilter @hook.
 *
 * FORWARD: full classification.
 * LOCAL_IN: only packets not received on a WAN device, and only for DHCP,
 * DNS and HTTP(S).
 * LOCAL_OUT: DNS replies and DHCP.
 * Any other hook: no lookups.
 */
static int hook_flags(unsigned int hook, struct sk_buff *skb)
{
	int flags = 0;

	if (hook == NF_INET_FORWARD) {
		flags = SW_APP | SW_URL_QUERY | SW_DEVID | SW_SSL_DECRYPTED;
	} else if (hook == NF_INET_LOCAL_IN) {
		/* ignore packets from the WAN */
		if (!is_netdev_wan(skb->dev)) {
			if (is_dhcp(skb) || is_dhcp6(skb))
				flags |= SW_DEVID;
			if (is_dns(skb))
				flags |= SW_APP | SW_URL_QUERY | SW_DEVID;
			if (is_http(skb) || is_https(skb))
				flags |= SW_APP | SW_DEVID;
		}
	} else if (hook == NF_INET_LOCAL_OUT) {
		if (is_dns_reply(skb))
			flags |= SW_APP;
		if (is_dhcp(skb) || is_dhcp6(skb))
			flags |= SW_DEVID;
	}

	return flags;
}

static unsigned int dpi_nf_hook(void *priv, struct sk_buff *skb,
				const struct nf_hook_state *state)
{
	struct dpi_classify_parms p;
	int flags;

	if (!dpi_enabled)
		goto out;

	if (unlikely(!skb || !skb->dev)) {
		pr_debug("no skb or data, skipping packet\n");
		goto out;
	}

	flags = hook_flags(state->hook, skb);
	dpi_classify_prepare(&p, skb, flags);
	dpi_classify(&p);
	dpi_classify_post(&p);

out:
	return NF_ACCEPT;
}

/*
 * dpi_nf_block() - drop packets whose conntrack has DPI_CT_BLOCK_BIT set.
 *
 * Packets without a device or conntrack are accepted. Every drop is counted
 * in dpi_stats.blocked_pkts.
 */
static unsigned int dpi_nf_block(void *priv, struct sk_buff *skb,
				 const struct nf_hook_state *state)
{
	if (likely(skb && skb->dev && skb_nfct(skb)) &&
	    test_bit(DPI_CT_BLOCK_BIT, &dpi_info(skb)->flags)) {
		dpi_stats.blocked_pkts++;
		return NF_DROP;
	}

	return NF_ACCEPT;
}


/* ----- netlink handling functions ----- */
/* Pointer to the payload that follows the struct dpi_nlmsg_hdr of the first
 * netlink message in @_skb. No length checking is done here; callers must
 * make sure the message is long enough for what they read. */
#define dpi_nl_data_ptr(_skb)  \
	((void*)(NLMSG_DATA((struct nlmsghdr *)(_skb)->data) + \
		sizeof(struct dpi_nlmsg_hdr)))

static void
dpi_nl_msg_reply(struct sk_buff *orig_skb, int type, int len, void *data)
{
	enum ip_conntrack_info ctinfo;
	struct nf_conn *ct = nf_ct_get(orig_skb, &ctinfo);
	struct net *net = ct ? nf_ct_net(ct) : &init_net;
	struct dpi_nlmsg_hdr *hdr;
	struct nl_bcast_entry *nb;
	struct nlmsghdr *orig_nlh = (struct nlmsghdr *)orig_skb->data;
	struct nlmsghdr *nlh;
	struct sk_buff *skb;
	int pid = orig_nlh->nlmsg_pid;
	int size;

	size	= NLMSG_SPACE(sizeof(*hdr) + len);
	skb	= alloc_skb(size, GFP_ATOMIC);
	if (!skb) {
		pr_err("couldn't allocate memory for skb\n");
		return;
	}

	nlh = nlmsg_put(skb, 0, 0, NLMSG_DONE, size, 0);
	hdr = NLMSG_DATA(nlh);

	hdr->type	= type;
	hdr->length	= len;
	hdr++;

	memcpy(hdr, data, len);
	NETLINK_CB(skb).dst_group = 0;

	list_for_each_entry(nb, &nl_bcast_list, node) {
		if (nb->net != net)
			continue;
		if (netlink_unicast(nb->socket, skb, pid, MSG_DONTWAIT) < 0)
			pr_err("error while sending DPI status\n");
	}
}

static void dpi_nl_process_maxpkt(struct sk_buff *skb)
{
	int *cfg = dpi_nl_data_ptr(skb);

	if (!cfg)
		return;
	dpi_max_pkt = *cfg;
}

/* DPI_NL_STATUS: reply with the current value of dpi_enabled. */
static void dpi_nl_process_status(struct sk_buff *skb)
{
	int reply_len = sizeof(dpi_enabled);

	dpi_nl_msg_reply(skb, DPI_NL_STATUS, reply_len, &dpi_enabled);
}

static void dpi_nl_handler(struct sk_buff *skb)
{
	struct nlmsghdr *nlh = (struct nlmsghdr *)skb->data;
	struct dpi_nlmsg_hdr *hdr;
	unsigned long flags = 0;

	if (skb->len < NLMSG_SPACE(0))
		return;

	hdr = NLMSG_DATA(nlh);
	if (nlh->nlmsg_len < sizeof(*nlh) || skb->len < nlh->nlmsg_len)
		return;

	switch (hdr->type) {
	case DPI_NL_ENABLE:
		dpi_enabled = 1;
		break;

	case DPI_NL_DISABLE:
		dpi_enabled = 0;
		break;

	case DPI_NL_STATUS:
		dpi_nl_process_status(skb);
		break;

	case DPI_NL_MAXPKT:
		dpi_nl_process_maxpkt(skb);
		break;

	case DPI_NL_RESET_STATS:
		flags = *(unsigned long *)dpi_nl_data_ptr(skb);
		dpi_reset_stats(flags);
		break;

	default:
		pr_debug("unknown msg type '%d'\n", hdr->type);
		break;
	}
}

static int dpi_queue_notify(unsigned long event, void *data)
{
	struct dpi_event *e;

	e = kzalloc(sizeof(*e), GFP_ATOMIC);
	if (!e) {
		pr_err("couldn't allocate dpi event\n");
		goto out;
	}
	e->event	= event;
	e->data		= data;
	if (event == DPI_NOTIFY_DEVICE)
		e->ref	= &(((struct dpi_dev *)data)->refcount);
	INIT_LIST_HEAD(&e->node);

	/* increment refcount, add the new event to the event list, and
	 * schedule processing */
	if (e->ref)
		atomic_inc(e->ref);
	spin_lock_bh(&notify_lock);
	list_add_tail_rcu(&e->node, &notify_list);
	spin_unlock_bh(&notify_lock);

	schedule_work(&notify_work);

out:
	return NOTIFY_OK;
}

static void dpi_process_notifications(struct work_struct *w)
{
	struct dpi_event *e;

	rcu_read_lock();
	while (1) {
		e = list_first_or_null_rcu(&notify_list, struct dpi_event,
					   node);
		if (!e)
			break;
		dpi_notify(e->event, e->data);

		/* remove from the list and cleanup */
		spin_lock_bh(&notify_lock);
		list_del(&e->node);
		spin_unlock_bh(&notify_lock);

		if (e->ref)
			atomic_dec(e->ref);
		kfree(e);
	}
	rcu_read_unlock();
}

/* ----- driver funs ----- */
/*
 * dpi_stat_seq_show() - render /proc/dpi/stats: per-type lookup counters,
 * totals, object counts, blocked packets and engine errors.
 * Each row prints the counters of its own type (the device and url rows
 * previously showed the app miss count).
 */
static int dpi_stat_seq_show(struct seq_file *s, void *v)
{
	seq_printf(s, "%-9s %-10s   %-10s   %-10s   %-10s   %-10s\n",
		   "type", "lookups", "hits", "misses", "classified", "unclassified");
	seq_puts(s, "---------------------------------------------------------------------------\n");
	seq_printf(s, " app      %-10u   %-10u   %-10u   %-10u   %-10u\n",
		   dpi_stats.apps.lookups,
		   dpi_stats.apps.hits,
		   dpi_stats.apps.misses,
		   dpi_stats.apps.classified,
		   dpi_stats.apps.unclassified);
	seq_printf(s, " device   %-10u   %-10u   %-10u   %-10u   %-10u\n",
		   dpi_stats.devs.lookups,
		   dpi_stats.devs.hits,
		   dpi_stats.devs.misses,
		   dpi_stats.devs.classified,
		   dpi_stats.devs.unclassified);
	seq_printf(s, " url      %-10u   %-10u   %-10u   %-10u   %-10u\n",
		   dpi_stats.urls.lookups,
		   dpi_stats.urls.hits,
		   dpi_stats.urls.misses,
		   dpi_stats.urls.classified,
		   dpi_stats.urls.unclassified);
	seq_printf(s, " total    %-10u\n", dpi_stats.total_lookups);
	seq_puts(s, "---------------------------------------------------------------------------\n");
	seq_printf(s, "    apps identified: %u\n", dpi_stats.app_count);
	seq_printf(s, " devices identified: %u\n", dpi_stats.dev_count);
	seq_printf(s, "appinsts identified: %u\n", dpi_stats.appinst_count);
	seq_printf(s, "    urls identified: %u\n", dpi_stats.url_count);
	seq_printf(s, "    packets blocked: %llu\n", dpi_stats.blocked_pkts);
	seq_printf(s, "      engine errors: %u\n", dpi_stats.engine_errors);
	return 0;
}
/* open handler for /proc/dpi/stats (single-record seq_file) */
static int dpi_stat_open(struct inode *inode, struct file *file)
{
	return single_open(file, dpi_stat_seq_show, NULL);
}

/* file operations for /proc/dpi/stats */
static const struct file_operations dpi_stat_fops = {
	.open    = dpi_stat_open,
	.read    = seq_read,
	.llseek  = seq_lseek,
	.release = single_release,
};

/*
 * dpi_info_seq_show() - render /proc/dpi/info: the engine version and the
 * Trend engine's TCP/UDP connection-table settings.
 */
static int dpi_info_seq_show(struct seq_file *s, void *v)
{
	unsigned int conn_max, tcp_timeout;

	seq_printf(s, " DPI engine version : %d.%d.%d%s\n",
		   TMCFG_E_MAJ_VER, TMCFG_E_MID_VER, TMCFG_E_MIN_VER,
		   TMCFG_E_LOCAL_VER);

	tdts_shell_system_setting_tcp_conn_max_get(&conn_max);
	tdts_shell_system_setting_tcp_conn_timeout_get(&tcp_timeout);
	seq_printf(s, " Max TCP connections: %-8u  Timeout: %u\n",
		   conn_max, tcp_timeout);

	tdts_shell_system_setting_udp_conn_max_get(&conn_max);
	seq_printf(s, " Max UDP connections: %u\n", conn_max);

	return 0;
}
/* open handler for /proc/dpi/info (single-record seq_file) */
static int dpi_info_open(struct inode *inode, struct file *file)
{
	return single_open(file, dpi_info_seq_show, NULL);
}

/* file operations for /proc/dpi/info */
static const struct file_operations dpi_info_fops = {
	.open    = dpi_info_open,
	.read    = seq_read,
	.llseek  = seq_lseek,
	.release = single_release,
};

static int dpi_pernet_init(struct net *net)
{
	struct nl_bcast_entry *entry;
	struct netlink_kernel_cfg cfg = {
		.groups	= 0,
		.input	= dpi_nl_handler,
	};
	int ret = -ENOMEM;

	entry = kzalloc(sizeof(*entry), GFP_ATOMIC);
	if (!entry) {
		pr_err("cannot allocate socket entry\n");
		goto err;
	}

	ret = nf_register_net_hooks(net, hooks, ARRAY_SIZE(hooks));
	if (ret < 0) {
		pr_err("cannot register netfilter hooks\n");
		goto err_free_entry;
	}

	entry->socket = netlink_kernel_create(net, NETLINK_DPI, &cfg);
	if (!entry->socket) {
		pr_err("failed to create kernel netlink socket\n");
		goto err_unreg_nf_hooks;
	}
	entry->net = net;
	list_add(&entry->node, &nl_bcast_list);

	/* nf_conntrack accounting must be on for dpi to work */
	net->ct.sysctl_acct = 1;

	return 0;

err_unreg_nf_hooks:
	nf_unregister_net_hooks(&init_net, hooks, ARRAY_SIZE(hooks));
err_free_entry:
	kfree(entry);
err:
	return ret;
}

/*
 * dpi_pernet_exit() - per-namespace teardown: release and forget the
 * namespace's DPI netlink socket(s), then unregister the netfilter hooks.
 */
static void dpi_pernet_exit(struct net *net)
{
	struct nl_bcast_entry *cur, *next;

	list_for_each_entry_safe(cur, next, &nl_bcast_list, node) {
		if (cur->net == net) {
			netlink_kernel_release(cur->socket);
			list_del(&cur->node);
			kfree(cur);
		}
	}

	nf_unregister_net_hooks(net, hooks, ARRAY_SIZE(hooks));
}

/* Per-namespace operations; registered last in dpicore_init() and
 * unregistered first in dpicore_exit(). */
static struct pernet_operations dpi_net_ops = {
	.init	= dpi_pernet_init,
	.exit	= dpi_pernet_exit,
};

/*
 * dpicore_init() - module initialisation.
 *
 * Steps, in order:
 *  - create /proc/dpi with its "stats" and "info" entries;
 *  - initialise the DPI tables and the test support;
 *  - register the DPI core hooks;
 *  - register the per-namespace operations (netfilter hooks and a netlink
 *    socket in every namespace).
 * On failure the completed steps are undone in reverse order through the
 * goto ladder, and the error is returned.
 */
static int __init dpicore_init(void)
{
	struct proc_dir_entry *pde;
	/* NOTE(review): proc creation failures return -EINVAL; -ENOMEM may be
	 * the more accurate code */
	int ret = -EINVAL;

	memset(&dpi_stats, 0, sizeof(dpi_stats));

	/* create proc entries */
	dpi_dir = proc_mkdir("dpi", NULL);
	if (!dpi_dir) {
		pr_err("couldn't create dpi proc directory\n");
		goto err;
	}
	pde = proc_create("stats", 0440, dpi_dir, &dpi_stat_fops);
	if (!pde) {
		pr_err("couldn't create proc entry 'stats'\n");
		goto err_free_dpi_dir;
	}
	pde = proc_create("info", 0440, dpi_dir, &dpi_info_fops);
	if (!pde) {
		pr_err("couldn't create proc entry 'info'\n");
		goto err_free_stats;
	}

	ret = dpi_init_tables();
	if (ret < 0)
		goto err_free_info;

	ret = dpi_test_init();
	if (ret < 0)
		goto err_free_tables;

	ret = dpi_core_hooks_register(&dpi_core_hooks);
	if (ret < 0)
		goto err_free_test;

	/* done last: from here on packets reach the netfilter hooks */
	ret = register_pernet_subsys(&dpi_net_ops);
	if (ret < 0)
		goto err_unreg_dpi_hooks;

	return 0;

err_unreg_dpi_hooks:
	dpi_core_hooks_unregister();
err_free_test:
	dpi_test_exit();
err_free_tables:
	dpi_deinit_tables();
err_free_info:
	remove_proc_entry("info", dpi_dir);
err_free_stats:
	remove_proc_entry("stats", dpi_dir);
err_free_dpi_dir:
	proc_remove(dpi_dir);
err:
	return ret;
}

/*
 * dpicore_exit() - module teardown, in reverse order of dpicore_init().
 *
 * unregister_pernet_subsys() runs dpi_pernet_exit() for every namespace,
 * init_net included. That already removes the netfilter hooks and netlink
 * sockets, so no separate init_net hook unregistration is done: a second
 * unregistration warns about missing hooks. The notification worker is
 * cancelled and waited for before the queue is drained. Otherwise it could
 * race with the drain or run after the module is gone.
 */
static void __exit dpicore_exit(void)
{
	struct dpi_event *e, *tmp;

	unregister_pernet_subsys(&dpi_net_ops);
	dpi_core_hooks_unregister();

	/* no new events can be queued past this point; stop the worker */
	cancel_work_sync(&notify_work);

	spin_lock_bh(&notify_lock);
	list_for_each_entry_safe(e, tmp, &notify_list, node) {
		list_del(&e->node);
		if (e->ref)
			atomic_dec(e->ref);
		kfree(e);
	}
	spin_unlock_bh(&notify_lock);

	dpi_test_exit();
	dpi_deinit_tables();

	/* remove proc entries */
	remove_proc_entry("info", dpi_dir);
	remove_proc_entry("stats", dpi_dir);
	proc_remove(dpi_dir);
}

/* module entry and exit points */
module_init(dpicore_init);
module_exit(dpicore_exit);
MODULE_LICENSE("GPL");
