changeset 20:4b62687da58a

Added batch requests Refactored code Added some extra tests
author Ben Croston <ben@croston.org>
date Sun, 11 Sep 2011 22:27:42 +0100
parents 636405ae3a66
children 6df12f09f4f4
files AuthRPC/client/__init__.py AuthRPC/server/__init__.py AuthRPC/tests.py README.txt setup.py
diffstat 5 files changed, 223 insertions(+), 88 deletions(-) [+]
line wrap: on
line diff
--- a/AuthRPC/client/__init__.py	Sun Sep 11 12:17:44 2011 +0100
+++ b/AuthRPC/client/__init__.py	Sun Sep 11 22:27:42 2011 +0100
@@ -34,6 +34,12 @@
 else:
     IS_PY3 = False
 
+def _encrypt_password(password):
+    if IS_PY3:
+        return hashlib.md5(password.encode()).hexdigest()
+    else:
+        return hashlib.md5(password).hexdigest()
+
 class _Method(object):
     def __init__(self, call, name, username=None, password=None):
         self.call = call
@@ -53,10 +59,6 @@
             params = copy.copy(args)
         elif len(kwargs) > 0:
             params = copy.copy(kwargs)
-            index = 0
-            for arg in args:
-                params[str(index)] = arg
-                index = index + 1
         else:
             params = None
         request['params'] = params
@@ -64,16 +66,17 @@
         if self._username is not None:
             request['username'] = self._username
         if self._password is not None:
-            if IS_PY3:
-                request['password'] = hashlib.md5(self._password.encode()).hexdigest()
-            else:
-                request['password'] = hashlib.md5(self._password).hexdigest()
+            request['password'] = _encrypt_password(self._password)
 
         resp = self.call(json.dumps(request))
-        if resp is not None and resp['error'] is None and resp['id'] == request['id']:
-            return resp['result']
-        else:
-            raise Exception('This is not supposed to happen -- btc') ########
+        if resp is None:
+            raise Exception('Server response is None')
+        if resp['id'] != request['id']:
+            raise Exception('Inconsistent JSON request id returned')
+        if resp['error'] is not None:
+            raise Exception('JSONRPC Server Exception:\n%s'%resp['error']['error'])
+
+        return resp['result']
 
     def __getattr__(self, name):
         return _Method(self.call, "%s.%s" % (self.name, name), self._username, self._password)
@@ -111,8 +114,8 @@
 
 class BadRequestException(Exception):
     """HTTP 400 - Bad Request"""
-    def __init__(self):
-        Exception.__init__(self,'HTTP 400 - Bad Request')
+    def __init__(self, msg=''):
+        Exception.__init__(self,'HTTP 400 - Bad Request\n%s'%msg)
 
 class UnauthorisedException(Exception):
     """HTTP 401 - Unauthorised"""
@@ -138,6 +141,11 @@
     def __init__(self):
         Exception.__init__(self,'HTTP 502 - Bad Gateway')
 
+class InternalServerException(Exception):
+    """HTTP 500 - Internal Server Error"""
+    def __init__(self):
+        Exception.__init__(self,'HTTP 500 - Internal Server Error')
+
 class ProtocolError(Exception):
     """Raised when the JSONRPC protocol has been broken"""
     pass
@@ -159,7 +167,7 @@
         self._username = username
         self._password = password
 
-    def __request(self, request):
+    def _request(self, request):
         # call a method on the remote server
         try:
             response = self.__transport.request(request)
@@ -171,7 +179,10 @@
             else:
                 return json.loads(response.read())
         elif response.status == 400:
-            raise BadRequestException
+            if IS_PY3:
+                raise BadRequestException(response.read().decode())
+            else:
+                raise BadRequestException(response.read())
         elif response.status == 401:
             raise UnauthorisedException
         elif response.status == 403:
@@ -179,11 +190,7 @@
         elif response.status == 404:
             raise NotFoundException
         elif response.status == 500:
-            if IS_PY3:
-                msg = json.loads(response.read().decode())
-            else:
-                msg = json.loads(response.read())
-            raise Exception('JSONRPCError\n%s'%msg['error']['error'])
+            raise InternalServerException
         elif response.status == 502:
             raise BadGatewayException
         else:
@@ -199,5 +206,86 @@
 
     def __getattr__(self, name):
         # magic method dispatcher
-        return _Method(self.__request, name, self._username, self._password)
+        return _Method(self._request, name, self._username, self._password)
+
+
+################# batch calls vvv
+
+class _BatchMethod(object):
+    def __init__(self, name):
+        self.name = name
+        self.request = {}
+
+    @property
+    def id(self):
+        return self.request['id']
+
+    def __call__(self, *args, **kwargs):
+        self.request = {}
+        self.request['method'] = self.name
+        self.request['id'] = str(uuid4())
+        if len(args) > 0 and len(kwargs) > 0:
+            raise ProtocolError('Cannot use both positional and keyword arguments.')
+        if len(args) > 0:
+            self.request['params'] = copy.copy(args)
+        elif len(kwargs) > 0:
+            self.request['params'] = copy.copy(kwargs)
+        else:
+            self.request['params'] = None
+
+    def __getattr__(self, name):
+         new_name = '%s.%s' % (self.name, name)
+         self.name = new_name
+         return self
+
+    def __repr__(self):
+        return json.dumps(self.request)
+
+    __str__ = __repr__
 
+class BatchCall(object):
+    """
+    A class to place a batch of requests in one call
+    """
+    def __init__(self, serverproxy):
+        self._server = serverproxy
+        self._queue = []
+
+    def __getattr__(self, name):
+        """Add the call to the queue"""
+        method = _BatchMethod(name)
+        self._queue.append(method)
+        return method
+
+    def __call__(self):
+        """Process the queue"""
+        requests = []
+        if len(self._queue) < 1:
+            # no calls have been added to batch
+            return
+
+        req = [self._server._username, _encrypt_password(self._server._password)]
+        req = json.dumps(req)
+        req = req[:-1]    # strip trailing ']'
+        req += ', '
+        req += ', '.join(str(q) for q in self._queue)
+        req += ']'
+
+        response = self._server._request(req)
+
+        result = []
+        for i,r in enumerate(response):
+            if r['id'] != self._queue[i].id:
+                raise Exception('Inconsistent JSON request id returned')
+            result.append(r['result'])
+
+        # clear the queue
+        self._queue = []
+
+        return result
+
+    def __repr__(self):
+        return '<BatchCall for %s>' % self._server
+
+    __str__ = __repr__
+
--- a/AuthRPC/server/__init__.py	Sun Sep 11 12:17:44 2011 +0100
+++ b/AuthRPC/server/__init__.py	Sun Sep 11 22:27:42 2011 +0100
@@ -42,8 +42,6 @@
         req = Request(environ)
         try:
             resp = self._process(req)
-        except ValueError, e:
-            resp = exc.HTTPBadRequest(str(e))
         except exc.HTTPException, e:
             resp = e
         return resp(environ, start_response)
@@ -59,54 +57,45 @@
         try:
             json = loads(req.body)
         except ValueError, e:
-            raise ValueError('Bad JSON: %s' % e)
-
-        try:
-            method = json['method']
-            params = json['params']
-            id = json['id']
-            username = json['username'] if 'username' in json else None
-            password = json['password'] if 'password' in json else None
-        except KeyError, e:
-            raise ValueError("JSON body missing parameter: %s" % e)
+            raise exc.HTTPBadRequest('Bad JSON: %s' % e)
 
-        if params is None:
-            params = []
-        if not isinstance(params, list):
-            raise ValueError("Bad params %r: must be a list" % params)
-            text = traceback.format_exc()
-            exc_value = sys.exc_info()[1]
-            error_value = dict(
-                name='JSONRPCError',
-                code=100,
-                message=str(exc_value),
-                error=text)
-            return Response(
-                status=500,
-                content_type='application/json',
-                body=dumps(dict(result=None,
-                                error=error_value,
-                                id=id)))
+        if isinstance(json, list):
+            # batch request
+            try:
+                username = json[0]
+                password = json[1]
+                cmds = json[2:]
+            except IndexError:
+                raise exc.HTTPBadRequest('JSON body missing parameters')
+            self._check_auth(username, password, req.user_agent)
+            result = []
+            for c in cmds:
+                if 'method' in c and 'params' in c and 'id' in c:
+                    self._process_single(c['method'], c['params'], c['id'])
+                else:
+                    raise exc.HTTPBadRequest('JSON body missing parameter')
+                result.append(self._process_single(c['method'], c['params'], c['id']))
+        else:
+            # single request
+            try:
+                username = json['username'] if 'username' in json else None
+                password = json['password'] if 'password' in json else None
+                id = json['id']
+                method = json['method']
+                params = json['params']
+            except KeyError, e:
+                raise exc.HTTPBadRequest("JSON body missing parameter: %s" % e)
+            self._check_auth(username, password, req.user_agent, id)
+            result = self._process_single(method, params, id)
 
-        obj = self.obj
-        if isinstance(self.obj,tuple) or isinstance(self.obj,list):
-            for x in self.obj:
-                if method.startswith('%s.'%x.__class__.__name__):
-                   obj = x
-                   method = method.replace('%s.'%obj.__class__.__name__,'',1)
-                   break
-        elif method.startswith('%s.'%self.obj.__class__.__name__):
-            method = method.replace('%s.'%self.obj.__class__.__name__,'',1)
-        if method.startswith('_'):
-            raise exc.HTTPForbidden("Bad method name %s: must not start with _" % method).exception
-        try:
-            method = getattr(obj, method)
-        except AttributeError:
-            raise ValueError("No such method %s" % method)
+        return Response(content_type='application/json',
+                        body=dumps(result))
 
+
+    def _check_auth(self, username, password, user_agent, id=None):
         if self.auth is not None:
             try:
-                auth_result = self.auth(username, password, req.user_agent)
+                auth_result = self.auth(username, password, user_agent)
             except:
                 text = traceback.format_exc()
                 exc_value = sys.exc_info()[1]
@@ -124,8 +113,37 @@
             if not auth_result:
                 raise exc.HTTPUnauthorized().exception
 
+    def _process_single(self, method, params, id):
+        retval = {}
+        retval['id'] = id
+        retval['result'] = None
+        retval['error'] = None
+
+        if params is None:
+            params = []
+
+        obj = self.obj
+        if isinstance(self.obj, tuple) or isinstance(self.obj, list):
+            for x in self.obj:
+                if method.startswith('%s.'%x.__class__.__name__):
+                   obj = x
+                   method = method.replace('%s.'%obj.__class__.__name__,'',1)
+                   break
+        elif method.startswith('%s.'%self.obj.__class__.__name__):
+            method = method.replace('%s.'%self.obj.__class__.__name__,'',1)
+        if method.startswith('_'):
+            retval['error'] = 'Bad method name %s: must not start with _' % method
+            return retval
         try:
-            result = method(*params)
+            method = getattr(obj, method)
+        except AttributeError, e:
+            raise exc.HTTPBadRequest(str(e))
+
+        try:
+            if isinstance(params, list):
+                retval['result'] = method(*params)
+            else:
+                retval['result'] = method(**params)
         except:
             text = traceback.format_exc()
             exc_value = sys.exc_info()[1]
@@ -134,16 +152,8 @@
                 code=100,
                 message=str(exc_value),
                 error=text)
-            return Response(
-                status=500,
-                content_type='application/json',
-                body=dumps(dict(result=None,
-                                error=error_value,
-                                id=id)))
+            retval['result'] = None
+            retval['error'] = error_value
 
-        return Response(
-            content_type='application/json',
-            body=dumps(dict(result=result,
-                            error=None,
-                            id=id)))
+        return retval
 
--- a/AuthRPC/tests.py	Sun Sep 11 12:17:44 2011 +0100
+++ b/AuthRPC/tests.py	Sun Sep 11 22:27:42 2011 +0100
@@ -48,6 +48,9 @@
     def returnnothing(self):
         pass
 
+    def add(self, a, b):
+        return a+b
+
 def myauth(username, password, useragent=None):
     return username == 'testuser' and \
            hashlib.md5('s3cr3t').hexdigest() == password and \
@@ -118,13 +121,13 @@
 class ExceptionTest(AuthRPCTests):
     def runTest(self):
         with self.assertRaises(Exception):
-            self.client.raiseexception()
+            self.client.api.raiseexception()
 
 class BadRequestTest(AuthRPCTests):
     def runTest(self):
         from client import BadRequestException
         with self.assertRaises(BadRequestException):
-            self.client.FunctionDoesNotExist()
+            self.client.api.FunctionDoesNotExist()
 
 class EchoTest(AuthRPCTests):
     def runTest(self):
@@ -132,18 +135,44 @@
             POUND = '\u00A3'
         else:
             POUND = unicode('\u00A3')
-        self.assertEqual(self.client.echo(POUND), 'ECHO: ' + POUND)
-        self.assertEqual(self.client.echo('hello mum!'), 'ECHO: hello mum!')
+        self.assertEqual(self.client.api.echo(POUND), 'ECHO: ' + POUND)
+        self.assertEqual(self.client.api.echo('hello mum!'), 'ECHO: hello mum!')
+        self.assertEqual(self.client.api.echo(mystring='wibble'), 'ECHO: wibble')
+
+class AddTest(AuthRPCTests):
+    def runTest(self):
+        self.assertEqual(self.client.api.add(12,34), 46)
+        self.assertEqual(self.client.api.add(1.2, 34), 35.2)
 
 class ReturnNothing(AuthRPCTests):
     def runTest(self):
-        self.assertEqual(self.client.returnnothing(), None)
+        self.assertEqual(self.client.api.returnnothing(), None)
 
 class ProtocolErrorTest(AuthRPCTests):
     def runTest(self):
         from client import ProtocolError
         with self.assertRaises(ProtocolError):
-            self.client.test(1, '2', three=3)
+            self.client.api.test(1, '2', three=3)
+
+class BatchTest(AuthRPCTests):
+    def runTest(self):
+        from client import BatchCall
+        batch = BatchCall(self.client)
+        batch.api.echo('One')
+        batch.api.echo(mystring='Two')
+        batch.echo('Three')
+        batch.api.returnnothing()
+        batch.api.add(9,1)
+        batch.api.FunctionDoesNotExist()
+        self.assertEqual(batch(), ['ECHO: One', 'ECHO: Two', 'ECHO: Three', None, 10])
+
+class BadBatchTest(AuthRPCTests):
+    def runTest(self):
+        from client import BatchCall, BadRequestException
+        batch = BatchCall(self.client)
+        batch.api.FunctionDoesNotExist()
+        with self.assertRaises(BadRequestException):
+            batch()
 ##### client ^^^ #####
 
 finished = False
@@ -167,8 +196,12 @@
     suite.addTest(ExceptionTest())
     suite.addTest(BadRequestTest())
     suite.addTest(EchoTest())
+    suite.addTest(AddTest())
     suite.addTest(ReturnNothing())
     suite.addTest(ProtocolErrorTest())
+    suite.addTest(BatchTest())
+    suite.addTest(BadBatchTest())
+    # btc fixme - test a list/tuple of api classes in another server
     return suite
 
 if __name__ == '__main__':
--- a/README.txt	Sun Sep 11 12:17:44 2011 +0100
+++ b/README.txt	Sun Sep 11 22:27:42 2011 +0100
@@ -29,10 +29,14 @@
 
 ::
 
-    from AuthRPC.client import ServerProxy
+    from AuthRPC.client import ServerProxy, BatchCall
     client = ServerProxy('http://localhost:1234/',
                          username='myuser',
                          password='secret',
                          user_agent='myprogram')
     retval = client.do_something('test')
+    batch = BatchCall(client)
+    batch.do_something('call 1')
+    batch.do_something('call 2')
+    batch()
 
--- a/setup.py	Sun Sep 11 12:17:44 2011 +0100
+++ b/setup.py	Sun Sep 11 22:27:42 2011 +0100
@@ -27,12 +27,12 @@
     extra['use_2to3'] = True
 
 setup(name             = 'AuthRPC',
-      version          = '0.0.1a',
+      version          = '0.0.2a',
       packages         = find_packages(exclude=exclude),
       install_requires = install_requires,
       author           = 'Ben Croston',
       author_email     = 'ben@croston.org',
-      description      = 'A JSONRPC-like client and server with additions to enable authentication',
+      description      = 'A JSONRPC-like client and server with additions to enable authenticated requests',
       long_description = open('README.txt').read(),
       license          = 'MIT',
       keywords         = 'json, rpc, wsgi, auth',