#!/usr/bin/env python

# Copyright 2006 Dan Callaghan
# See .

# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation; either version 2 of the License, or (at your option) any later
# version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# The GNU General Public License is available as
# .

""" Greylisting policy server for Postfix, using a Bloom filter as its
whitelist of triples that have already passed greylisting."""

import sys
import array
import random
from datetime import datetime, timedelta
import socket
import select
try:
    import cPickle
except ImportError:
    # Python 3 merged cPickle into pickle; keep the old name in scope.
    import pickle as cPickle


class BitVector(object):
    """ A BitVector efficiently represents a fixed-size vector of binary
    digits, packed eight to a byte in an array of unsigned bytes."""

    @staticmethod
    def _pos_to_index_bit(pos):
        """ Maps a bit position to a (byte index, bit mask) pair.
        pos must be an integer!"""
        # Floor division keeps the byte index an integer on every Python
        # version (plain "/" yields a float under true division).
        return (pos // 8, 1 << (pos % 8))

    def __init__(self, size):
        """ Creates a vector able to hold at least size bits, all zero.
        size must be an integer!"""
        self.vector = array.array('B', [0] * ((size - 1) // 8 + 1))

    def __len__(self):
        """ Returns the capacity in bits, i.e. the requested size rounded up
        to a whole number of bytes."""
        return len(self.vector) * 8

    def __setitem__(self, i, x):
        if x:
            self.set_bit(i)
        else:
            self.clear_bit(i)

    def __getitem__(self, i):
        return self.get_bit(i)

    def get_bit(self, pos):
        """ Returns 1 if the bit at pos is set, otherwise 0."""
        index, bit = BitVector._pos_to_index_bit(pos)
        if self.vector[index] & bit:
            return 1
        return 0

    def set_bit(self, pos):
        """ Sets the bit at pos to 1."""
        index, bit = BitVector._pos_to_index_bit(pos)
        self.vector[index] |= bit

    def clear_bit(self, pos):
        """ Sets the bit at pos to 0."""
        index, bit = BitVector._pos_to_index_bit(pos)
        self.vector[index] &= ~bit

    def toggle_bit(self, pos):
        """ Inverts the bit at pos."""
        index, bit = BitVector._pos_to_index_bit(pos)
        self.vector[index] ^= bit

    def clear_all(self):
        """ Resets every bit to 0 (the vector keeps its size)."""
        self.vector = array.array('B', [0] * len(self.vector))


class SaltedHashFunc(object):
    """ A SaltedHashFunc is a callable object that wraps Python's hash(),
    adding a salt (randomly generated at construction time) and a modulo."""

    def __init__(self, m):
        """ Initialises the SaltedHashFunc (including random salt generation).

        m is the modulo operand, i.e. range of values the hash function
        should return."""
        self.salt = "%.10f" % random.random()
        self.m = m

    def __call__(self, s):
        """ Returns the salted hash of str(s), reduced modulo m."""
        return hash(self.salt + str(s)) % self.m


class BloomFilter(object):
    """ A Bloom filter is a hash-based storage structure (supports insertion
    of hashable objects). Membership tests are supported, but have a
    probability of returning a falsely positive result (false negatives are
    impossible). Object retrieval and removal from the structure are NOT
    possible."""

    def __init__(self, num_hashes, vector_size, max_insertions=None):
        """ Initialises the Bloom filter.

        num_hashes is the number of hash functions to be applied (k),
        vector_size is the size of the underlying bit vector (m).

        If max_insertions is given and non-zero, the Bloom filter will be
        automatically reset to empty once it already holds max_insertions
        values and another one is inserted."""
        self.v = BitVector(vector_size)
        self.insertion_count = 0
        self.max_insertions = max_insertions or 0
        self.hash_funcs = [SaltedHashFunc(vector_size)
                           for _ in range(num_hashes)]

    def insert(self, s):
        """ Inserts a new value into the Bloom filter. The value is converted
        to a string for insertion."""
        if self.max_insertions and self.insertion_count >= self.max_insertions:
            # The filter is full: start over before storing the new value.
            self.v.clear_all()
            self.insertion_count = 0
        # Count after any reset so the counter always equals the number of
        # insertions since the filter was last empty.
        self.insertion_count += 1
        for hf in self.hash_funcs:
            self.v.set_bit(hf(s))

    def test(self, s):
        """ Tests for the presence of a value in the Bloom filter.

        Note that one of the properties of Bloom filters is that this can
        sometimes return a falsely positive result. False positive rates
        depend on initialisation variables and hash distribution."""
        for hf in self.hash_funcs:
            if not self.v.get_bit(hf(s)):
                return False
        return True


class BloomgreyPolicyServer(object):
    """ A Postfix policy server implementing greylisting. New triples wait on
    an in-memory greylist; triples that pass are remembered in a Bloom
    filter."""

    def __init__(self, exempt_list=None, grey_delay=5, grey_lifetime=120,
                 max_white_entries=100000):
        """ Initialises the policy server.

        exempt_list is a list of strings against which the client_address
        attribute (IP address of the client connected to Postfix) is
        compared. If found in this list, the request is always allowed.
        Avoid making this list too large as every request performs a search
        against it!

        grey_delay is the number of minutes after first seeing a triple that
        we will allow it.

        grey_lifetime is the number of minutes after first seeing a triple
        (and never again seeing it) that we will forget about it.

        max_white_entries is the maximum number of entries the underlying
        Bloom filter will store before it resets itself. For optimum
        performance this should be set to the number of unique triples your
        server sees in 36 days. Note that the Bloom filter's memory
        requirements are directly proportional to this value
        (1 max_white_entries = 1 byte)."""
        self.greylist = {}
        self.grey_delay = grey_delay
        self.grey_lifetime = grey_lifetime
        self.last_grey_clean = datetime.now()
        self.whitelist = BloomFilter(4, max_white_entries * 8,
                                     max_white_entries)
        self.exempt_list = exempt_list or []

    def process_triple(self, triple):
        """ Processes a triple. Any string will be accepted, although
        convention dictates that it should be
        "client_address/sender/recipient".

        Returns True for "allow", False for "reject" (defer).

        When first seen, a triple is placed on a greylist with the timestamp
        recorded, and it is rejected. If the triple is then seen at least
        grey_delay minutes after it was first seen, it is added to the
        internal whitelist (a Bloom filter) and allowed. Any subsequent
        checks of that triple will match the whitelist and be allowed.

        Triples on the greylist expire when their timestamp is more than
        grey_lifetime minutes old; an expired triple is treated as if it had
        never been seen.

        Note that user-defined whitelist checking occurs in
        process_request()."""
        # First we check the whitelist to see if the triple is already allowed.
        if self.whitelist.test(triple):
            return True

        now = datetime.now()

        # Otherwise we look for the triple on the greylist.
        first_seen = self.greylist.get(triple)
        if first_seen is not None:
            age = now - first_seen
            if age < timedelta(minutes=self.grey_delay):
                return False
            if age <= timedelta(minutes=self.grey_lifetime):
                del self.greylist[triple]
                self.whitelist.insert(triple)
                return True
            # The entry expired but has not been swept yet: fall through and
            # greylist the triple afresh rather than allowing it.

        # Otherwise we add the triple to the grey list.
        self.greylist[triple] = now

        # Clean the greylist if necessary (XXX too wasteful??)
        if self.last_grey_clean < now - timedelta(minutes=self.grey_lifetime / 2.0):
            self.clean_greylist()
            self.last_grey_clean = now
        return False

    def clean_greylist(self):
        """ Removes expired entries from the greylist. This is expensive,
        avoid excessive calls."""
        cutoff = datetime.now() - timedelta(minutes=self.grey_lifetime)
        # Collect first, delete afterwards: a dict must not change size while
        # it is being iterated.
        expired = [triple for triple, first_seen in self.greylist.items()
                   if first_seen < cutoff]
        for triple in expired:
            del self.greylist[triple]

    def process_request(self, attrs):
        """ Processes a dict of request attributes, as provided by Postfix.
        Returns True for "allow", False for "reject" (defer).

        Apart from user-defined whitelist checking and sanity checking, all
        the checks are done in process_triple(). Requests that are malformed
        (unknown or missing request type, missing attributes) are allowed."""
        request = attrs.get("request")
        if request != "smtpd_access_policy":
            sys.stderr.write("Unknown request type %s\n" % request)
            return True

        if attrs.get("client_address") in self.exempt_list:
            return True

        try:
            triple = "%s/%s/%s" % (attrs["client_address"], attrs["sender"],
                                   attrs["recipient"])
        except KeyError:
            sys.stderr.write("Request lacked necessary attributes\n")
            return True
        return self.process_triple(triple)

    @staticmethod
    def _parse_attribute_line(line):
        """ Splits a "name=value" request line at its first equals sign.

        Returns a (name, value) tuple with the trailing newline stripped from
        value, or None if the line has no equals sign or an empty name."""
        equals = line.find("=")
        # find() returns -1 when there is no "=", and 0 when the name is empty.
        if equals < 1:
            return None
        return (line[:equals], line[equals + 1:].rstrip("\n"))

    def _handle_request(self, sock):
        """ Reads one policy request from sock and sends back the action.

        Returns True if a request was answered (the connection stays open for
        further requests), False if the peer closed the connection."""
        # Unbuffered, so no bytes of a following request are consumed and lost
        # when this file object is discarded.
        stream = sock.makefile("rb", 0)
        try:
            attrs = {}
            while True:
                line = stream.readline()
                if not line:
                    # connection was closed
                    return False
                if not isinstance(line, str):
                    # Python 3 hands us bytes; the rest of the code uses str.
                    line = line.decode("utf-8", "replace")
                if not line.rstrip("\n"):
                    # end of request attributes
                    if self.process_request(attrs):
                        reply = "action=DUNNO\n\n"
                    else:
                        reply = "action=DEFER_IF_PERMIT Service temporarily unavailable\n\n"
                    sock.sendall(reply.encode("ascii"))
                    return True
                # build the attrs dict
                parsed = self._parse_attribute_line(line)
                if parsed is None:
                    sys.stderr.write("Invalid input line (no equals): %s\n" % line)
                else:
                    attrs[parsed[0]] = parsed[1]
        finally:
            stream.close()

    def serve_forever(self, address):
        """ Binds a socket to the given address and begins serving Postfix
        policy requests on the socket.

        If address is a string, we use a UNIX domain socket. If it is a tuple
        ("ip_address", port) we use a TCP socket.

        For protocol details see the Postfix policy delegation documentation
        (SMTPD_POLICY_README).

        This only returns by raising an exception (e.g. KeyboardInterrupt);
        all sockets are closed before the exception propagates.

        Note that a request is read to completion once a client becomes
        readable, so a client that sends a partial request stalls the
        others."""
        if isinstance(address, str):
            listensock = socket.socket(socket.AF_UNIX)
        else:
            listensock = socket.socket(socket.AF_INET)
            listensock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        activesocks = [listensock]
        try:
            listensock.bind(address)
            listensock.listen(5)
            while True:
                for sock in select.select(activesocks, [], [])[0]:
                    if sock is listensock:
                        newsock, _ = listensock.accept()
                        activesocks.append(newsock)
                        continue
                    try:
                        keep_open = self._handle_request(sock)
                    except socket.error as err:
                        # One broken client must not take the server down.
                        sys.stderr.write("Connection error: %s\n" % err)
                        keep_open = False
                    if not keep_open:
                        activesocks.remove(sock)
                        sock.close()
        finally:
            for sock in activesocks:
                sock.close()

    # Commented out until Py2.5 which will allow array pickling
    # NOTE: if this is revived, only ever unpickle a trusted state file.
    #def save_state(self, file):
    #    cPickle.dump(self, file, cPickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    #try:
    #    ps = cPickle.load(open("ps.p", "r"))
    #except IOError: # assume file not found
    exempt_list = ["204.13.250.91", "204.13.250.92", "204.13.249.91", "204.13.249.92",
                   "140.105.134.102"]  # lists.gentoo.org
    ps = BloomgreyPolicyServer(exempt_list=exempt_list, grey_delay=1)
    try:
        ps.serve_forever(("127.0.0.1", 10942))
    except KeyboardInterrupt:
        #ps.save_state(open("ps.p", "w"))
        sys.exit(0)