diff --git a/extensions/xt_geoip.c b/extensions/xt_geoip.c index 0b15b83..9c080a7 100644 --- a/extensions/xt_geoip.c +++ b/extensions/xt_geoip.c @@ -113,6 +113,25 @@ static struct geoip_info *find_node(u_int16_t cc) return NULL; } +static bool geoip_bsearch(const struct geoip_subnet *range, + uint32_t addr, int lo, int hi) +{ + int mid; + + if (hi < lo) + return false; + mid = (lo + hi) / 2; + if (range[mid].begin <= addr && addr <= range[mid].end) + return true; + if (range[mid].begin > addr) + return geoip_bsearch(range, addr, lo, mid - 1); + else if (range[mid].end < addr) + return geoip_bsearch(range, addr, mid + 1, hi); + + WARN_ON(true); + return false; +} + static bool xt_geoip_mt(const struct sk_buff *skb, const struct net_device *in, const struct net_device *out, const struct xt_match *match, const void *matchinfo, int offset, unsigned int protoff, bool *hotdrop) @@ -120,7 +139,7 @@ static bool xt_geoip_mt(const struct sk_buff *skb, const struct net_device *in, const struct xt_geoip_match_info *info = matchinfo; const struct geoip_info *node; /* This keeps the code sexy */ const struct iphdr *iph = ip_hdr(skb); - u_int32_t ip, i, j; + uint32_t ip, i; if (info->flags & XT_GEOIP_SRC) ip = ntohl(iph->saddr); @@ -136,12 +155,10 @@ static bool xt_geoip_mt(const struct sk_buff *skb, const struct net_device *in, continue; } - for (j = 0; j < node->count; j++) - if (ip >= node->subnets[j].begin && - ip <= node->subnets[j].end) { - spin_unlock_bh(&geoip_lock); - return (info->flags & XT_GEOIP_INV) ? 0 : 1; - } + if (geoip_bsearch(node->subnets, ip, 0, node->count)) { + spin_unlock_bh(&geoip_lock); + return (info->flags & XT_GEOIP_INV) ? 0 : 1; + } } spin_unlock_bh(&geoip_lock);