]> code.delx.au - bg-scripts/blobdiff - lib/twisted_wget.py
Initial import
[bg-scripts] / lib / twisted_wget.py
diff --git a/lib/twisted_wget.py b/lib/twisted_wget.py
new file mode 100755 (executable)
index 0000000..aca0d55
--- /dev/null
@@ -0,0 +1,145 @@
+#!/usr/bin/env python2.4
+
+import GregDebug, base64, os, sys, urlparse
+
+from twisted.internet import reactor, protocol
+from twisted.web.client import HTTPClientFactory
+from twisted.web.http import HTTPClient
+from twisted.web.client import _parse as parseURL
+
+__all__ = ('downloadURL', )
+
def parseURL(url, defaultPort = None):
    """Split url into (scheme, auth, host, port, path).

    Based on twisted.web.client._parse, extended to pull "user[:password]@"
    credentials out of the network location.

    auth is None when no credentials are present, a 1-tuple (user,) when only
    a user name is given, or a 2-tuple (user, password).  port comes from an
    explicit ":port" suffix; otherwise it is defaultPort, which itself
    defaults to 443 for https and 80 for anything else.  path carries the
    params, query and fragment, and is "/" when the URL has no path (as in
    the twisted original).
    """
    parsed = urlparse.urlparse(url.strip())
    scheme = parsed[0]
    path = urlparse.urlunparse(('', '') + parsed[2:])
    if not path:
        path = '/'

    if defaultPort is None:
        if scheme == 'https':
            defaultPort = 443
        else:
            defaultPort = 80
    host, port = parsed[1], defaultPort

    if '@' in host:
        # The credentials end at the LAST '@': an '@' inside the password
        # must not be mistaken for the start of the host name.
        authUser, host = host.rsplit('@', 1)
        if ':' in authUser:
            auth = tuple(authUser.split(':', 1))
        else:
            auth = (authUser, )
    else:
        auth = None

    if ':' in host:
        host, portText = host.rsplit(':', 1)
        # "host:" with an empty port keeps the default instead of
        # crashing in int('').
        if portText:
            port = int(portText)

    return scheme, auth, host, port, path
+
class HTTPProxyFactory(protocol.ClientFactory):
    """Client factory that wraps another factory so its request goes via a proxy.

    Every protocol built is an HTTPProxyProtocol around a protocol from
    realFactory.  Any attribute not found here is looked up on realFactory.
    The connection-event callbacks, which ClientFactory already defines as
    no-ops, are forwarded explicitly.

    proxyServer is a (host, port) pair; the address is stored as proxyHost /
    proxyPort but the connection itself is made by the caller.
    proxyPassword is None or a credential tuple handed to HTTPProxyProtocol.
    """

    def __init__(self, realFactory, proxyServer, proxyMethod = 'GET', proxyPassword = None):
        self.realFactory = realFactory
        self.proxyHost, self.proxyPort = proxyServer
        self.proxyMethod = proxyMethod
        self.proxyPassword = proxyPassword

    def buildProtocol(self, addr):
        proxied = self.realFactory.buildProtocol(addr)
        return HTTPProxyProtocol(self, proxied)

    # ClientFactory provides no-op versions of these three callbacks, so
    # __getattr__ never reaches the real factory for them.  Forward them
    # explicitly; otherwise, for example, a refused proxy connection never
    # reaches the real factory and its deferred never fires.
    def startedConnecting(self, connector):
        self.realFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.realFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.realFactory.clientConnectionLost(connector, reason)

    def __getattr__(self, key):
        # Guard against infinite recursion when realFactory itself has not
        # been set yet (e.g. an instance created without __init__).
        if key == 'realFactory':
            raise AttributeError(key)
        return getattr(self.realFactory, key)
+
class HTTPProxyProtocol(protocol.Protocol):
    """Protocol that sends a wrapped HTTP client's request through an HTTP proxy.

    The wrapped ("proxied") client protocol writes its request as usual.
    This class captures those writes and replaces the request line with one
    carrying the absolute URL.  It inserts a Proxy-Authorization header when
    credentials were supplied, then relays everything received from the
    proxy back to the wrapped client.

    Only factory.proxyMethod == 'GET' is supported; anything else raises
    NotImplementedError.
    """

    def __init__(self, factory, proxied):
        self.factory = factory
        self.proxied = proxied
        # Pre-compute the Basic credentials once.  factory.proxyPassword is
        # None, (user,) or (user, password).
        self.proxyPassword = factory.proxyPassword
        if self.proxyPassword is not None:
            credentials = tuple(self.proxyPassword)
            if len(credentials) == 1:
                # A user name without a password: send "user:" instead of
                # failing the '%s:%s' format below.
                credentials = (credentials[0], '')
            self.proxyPassword = base64.standard_b64encode('%s:%s' % credentials)
        if factory.proxyMethod == 'GET':
            self.__connectionMade = self.__connectionMade_GET
        else:
            raise NotImplementedError

    def __send(self, value):
        self.transport.write(value)

    def __getTransportWrites(self, function, *args, **kwargs):
        """Call function(*args, **kwargs) and return a list of everything it
        writes to self.transport, instead of letting it reach the network."""
        originalWrite = self.transport.write
        request = []
        self.transport.write = request.append
        try:
            function(*args, **kwargs)
        finally:
            # Always restore the real write, even if function raises.
            self.transport.write = originalWrite
        return request

    def __connectionMade_GET(self):
        realFactory = self.factory.realFactory
        # A proxy needs the absolute URL as the request target, so make the
        # real client put the full URL where it normally puts the path.
        realFactory.path = realFactory.url

        # makeConnection() invokes the client's connectionMade(), which
        # writes the whole request.  Capture it so it is sent exactly once,
        # after our own request line and optional Proxy-Authorization header.
        request = self.__getTransportWrites(self.proxied.makeConnection, self.transport)

        self.__send('GET %s HTTP/1.0\r\n' % realFactory.url)
        if self.proxyPassword is not None:
            self.__send('Proxy-Authorization: Basic %s\r\n' % self.proxyPassword)

        # Drop the real client's own request line (its first write).
        for line in request[1:]:
            self.__send(line)

    def connectionMade(self):
        self.proxied.transport = self.transport
        self.__connectionMade()

    def dataReceived(self, data):
        self.proxied.dataReceived(data)

    def connectionLost(self, reason):
        self.proxied.connectionLost(reason)
+
# URL scheme -> ((proxyHost, proxyPort), auth, factoryClass); filled in by
# parseProxies() from the environment.
proxies = {}
def downloadURL(url, method = 'GET', successBack = None, errorBack = None):
    """Begin fetching url and return the factory driving the download.

    successBack / errorBack, when given, are attached to the HTTP client
    factory's deferred.  If the URL's scheme has an entry in proxies, the
    client factory is wrapped in that entry's proxy factory, and the TCP
    connection is made to the proxy rather than the URL's own host.
    """
    clientFactory = HTTPClientFactory(url, method = method)
    scheme, _, targetHost, targetPort, _ = parseURL(url)

    d = clientFactory.deferred
    if successBack is not None:
        d.addCallback(successBack)
    if errorBack is not None:
        d.addErrback(errorBack)

    if scheme in proxies:
        proxyAddress, proxyAuth, proxyFactoryClass = proxies[scheme]
        targetHost, targetPort = proxyAddress
        # Wrap the real client so its request travels through the proxy.
        clientFactory = proxyFactoryClass(
            realFactory = clientFactory,
            proxyServer = (targetHost, targetPort),
            proxyMethod = method,
            proxyPassword = proxyAuth,
        )

    reactor.connectTCP(targetHost, targetPort, clientFactory)
    return clientFactory
+
def parseProxies():
    """Populate the module-level proxies dict from *_proxy environment variables.

    Only http_proxy is currently used.  Its value is parsed with parseURL and
    stored as proxies['http'] = ((host, port), auth, HTTPProxyFactory).
    Empty values, and values from which no host can be extracted, are ignored.

    Note: does not currently honor the no_proxy variable.
    """
    for key, value in os.environ.items():
        if not value or not key.endswith('_proxy'):
            continue
        proxyType = key[:-len('_proxy')]
        if proxyType != 'http':
            continue
        # A common form is "proxy.example.com:3128" with no scheme.  urlparse
        # then finds no network location, so the host would come out empty.
        if '://' not in value:
            value = 'http://' + value
        _, auth, host, port, _ = parseURL(value)
        if not host:
            continue
        proxies[proxyType] = (host, port), auth, HTTPProxyFactory
+
def main(urls):
    """Download every URL in urls concurrently and print a summary of each.

    Successful bodies are printed shortened to about 100 characters; failures
    are printed as-is.  The reactor is stopped once every download has
    finished, so the script exits instead of running forever.
    """
    from twisted.internet import defer

    def summerise(string, summerisedLen = 100):
        """Return string unchanged if it fits in summerisedLen characters,
        otherwise its head and tail joined by ' ... '."""
        if len(string) <= summerisedLen:
            return string
        summerisedLen -= 5  # leave room for the ' ... ' separator
        start = summerisedLen // 2
        return '%s ... %s' % (string[:start], string[-(summerisedLen - start):])

    def s(data):
        print('Success: "%r"' % summerise(data))

    def e(data):
        print(data)

    def stopReactor(_):
        # callLater rather than a direct stop: with no URLs the DeferredList
        # fires before reactor.run() has started.
        reactor.callLater(0, reactor.stop)

    deferreds = []
    for url in urls:
        deferreds.append(downloadURL(url, successBack = s, errorBack = e).deferred)

    defer.DeferredList(deferreds).addBoth(stopReactor)
    reactor.run()
+
# Fill in the proxies table from *_proxy environment variables at import
# time, so downloadURL() sees them whether this module is imported or run.
parseProxies()

if __name__ == "__main__":
    cliArgs = sys.argv[1:]
    main(cliArgs)