Drizzled Public API Documentation

result.py
1 #!/usr/bin/env python
2 #
3 # Drizzle Client & Protocol Library
4 #
5 # Copyright (C) 2008 Eric Day (eday@oddments.org)
6 # All rights reserved.
7 #
8 # Redistribution and use in source and binary forms, with or without
9 # modification, are permitted provided that the following conditions are
10 # met:
11 #
12 # * Redistributions of source code must retain the above copyright
13 # notice, this list of conditions and the following disclaimer.
14 #
15 # * Redistributions in binary form must reproduce the above
16 # copyright notice, this list of conditions and the following disclaimer
17 # in the documentation and/or other materials provided with the
18 # distribution.
19 #
20 # * The names of its contributors may not be used to endorse or
21 # promote products derived from this software without specific prior
22 # written permission.
23 #
24 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
25 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
26 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
27 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
28 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
29 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
30 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
31 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
32 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
33 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35 #
36 
37 '''
38 MySQL Protocol Result Objects
39 '''
40 
41 import struct
42 import unittest
43 
class BadFieldCount(Exception):
    '''Raised when a packet's leading (field count) byte is not valid for the
    result type being parsed.'''
46 
class OkResult(object):
    '''This class represents an OK result packet sent from the server.

    An instance is either built from keyword arguments (when ``packed`` is
    None) or parsed from ``packed``, the raw packet payload, whose first byte
    must be 0.  ``packed`` may be a byte string (``str`` on Python 2,
    ``bytes``/``bytearray`` on Python 3).

    ``affected_rows`` and ``insert_id`` are decoded as length-coded
    integers: a first byte below 251 is the value itself, while 252, 253 and
    254 introduce a 2-, 3- or 8-byte little-endian value.

    When ``version_40`` is true the pre-4.1 layout is used: there is no
    warning count, and the ``warning_count`` attribute is not set.

    Raises BadFieldCount if the first byte of ``packed`` is not 0, and
    ValueError if a length-coded integer starts with 251 or 255.
    '''

    # Prefix byte -> (struct format, payload size) for the multi-byte forms
    # of a length-coded integer.  '<HB' reads a 3-byte value as low 16 bits
    # followed by the high 8 bits.
    _LENGTH_CODED_FORMATS = {252: ('<H', 2), 253: ('<HB', 3), 254: ('<Q', 8)}

    def __init__(self, packed=None, affected_rows=0, insert_id=0, status=0,
                 warning_count=0, message='', version_40=False):
        self.version_40 = version_40
        if packed is None:
            self.affected_rows = affected_rows
            self.insert_id = insert_id
            self.status = status
            self.message = message
            # Truthiness (not identity) so __str__, which uses the same test,
            # always finds warning_count when it expects one.
            if not version_40:
                self.warning_count = warning_count
            return

        field_count = self._byte_at(packed, 0)
        if field_count != 0:
            raise BadFieldCount('Expected 0, received ' + str(field_count))
        self.affected_rows, offset = self._length_coded_at(packed, 1)
        self.insert_id, offset = self._length_coded_at(packed, offset)
        if version_40:
            if len(packed) == offset:
                # The packet ends right after insert_id: no status/message.
                self.status = 0
                self.message = ''
            else:
                self.status = struct.unpack('<H', packed[offset:offset + 2])[0]
                self.message = packed[offset + 2:]
        else:
            self.status, self.warning_count = struct.unpack(
                '<HH', packed[offset:offset + 4])
            self.message = packed[offset + 4:]

    @staticmethod
    def _byte_at(packed, offset):
        '''Return the byte at ``offset`` as an int for str or bytes input.'''
        value = packed[offset]
        if isinstance(value, int):
            return value
        return ord(value)

    @classmethod
    def _length_coded_at(cls, packed, offset):
        '''Decode the length-coded integer that starts at ``offset``.

        Returns a ``(value, offset_after_value)`` tuple.  Raises ValueError
        for the 251 and 255 prefixes, which do not introduce an integer.
        '''
        first = cls._byte_at(packed, offset)
        if first < 251:
            return first, offset + 1
        if first not in cls._LENGTH_CODED_FORMATS:
            raise ValueError('Invalid length-coded integer prefix ' +
                             str(first))
        fmt, size = cls._LENGTH_CODED_FORMATS[first]
        start = offset + 1
        parts = struct.unpack(fmt, packed[start:start + size])
        if len(parts) == 2:
            value = parts[0] | (parts[1] << 16)
        else:
            value = parts[0]
        return value, start + size

    def __str__(self):
        if self.version_40:
            return '''OkResult
  affected_rows = %s
  insert_id = %s
  status = %s
  message = %s
  version_40 = %s
''' % (self.affected_rows, self.insert_id, self.status, self.message,
       self.version_40)
        return '''OkResult
  affected_rows = %s
  insert_id = %s
  status = %s
  warning_count = %s
  message = %s
  version_40 = %s
''' % (self.affected_rows, self.insert_id, self.status, self.warning_count,
       self.message, self.version_40)
100 
101 class TestOkResult(unittest.TestCase):
102 
103  def testDefaultInit(self):
104  result = OkResult()
105  self.assertEqual(result.affected_rows, 0)
106  self.assertEqual(result.insert_id, 0)
107  self.assertEqual(result.status, 0)
108  self.assertEqual(result.warning_count, 0)
109  self.assertEqual(result.message, '')
110  self.assertEqual(result.version_40, False)
111  result.__str__()
112 
113  def testDefaultInit40(self):
114  result = OkResult(version_40=True)
115  self.assertEqual(result.affected_rows, 0)
116  self.assertEqual(result.insert_id, 0)
117  self.assertEqual(result.status, 0)
118  self.assertEqual(result.message, '')
119  self.assertEqual(result.version_40, True)
120  result.__str__()
121 
122  def testKeywordInit(self):
123  result = OkResult(affected_rows=3, insert_id=5, status=2,
124  warning_count=7, message='test', version_40=False)
125  self.assertEqual(result.affected_rows, 3)
126  self.assertEqual(result.insert_id, 5)
127  self.assertEqual(result.status, 2)
128  self.assertEqual(result.warning_count, 7)
129  self.assertEqual(result.message, 'test')
130  self.assertEqual(result.version_40, False)
131 
132  def testUnpackInit(self):
133  data = struct.pack('BBB', 0, 3, 5)
134  data += struct.pack('<HH', 2, 7)
135  data += 'test'
136 
137  result = OkResult(data)
138  self.assertEqual(result.affected_rows, 3)
139  self.assertEqual(result.insert_id, 5)
140  self.assertEqual(result.status, 2)
141  self.assertEqual(result.warning_count, 7)
142  self.assertEqual(result.message, 'test')
143  self.assertEqual(result.version_40, False)
144  result.__str__()
145 
146  def testUnpackInit40(self):
147  data = struct.pack('BBB', 0, 3, 5)
148  data += struct.pack('<H', 2)
149  data += 'test'
150 
151  result = OkResult(data, version_40=True)
152  self.assertEqual(result.affected_rows, 3)
153  self.assertEqual(result.insert_id, 5)
154  self.assertEqual(result.status, 2)
155  self.assertEqual(result.message, 'test')
156  self.assertEqual(result.version_40, True)
157  result.__str__()
158 
class ErrorResult(object):
    '''This class represents an error result packet sent from the server.

    An instance is either built from keyword arguments (when ``packed`` is
    None) or parsed from ``packed``, the raw packet payload, whose first byte
    must be 255.  ``packed`` may be a byte string (``str`` on Python 2,
    ``bytes``/``bytearray`` on Python 3).

    When ``version_40`` is true the pre-4.1 layout is used (error code then
    message), and ``sqlstate_marker``/``sqlstate`` are not set.  Otherwise a
    '#' byte at offset 3 introduces a 5-byte SQL state.  If that marker is
    missing, the rest of the packet is taken as the message, and the
    ``sqlstate_marker``/``sqlstate`` arguments are kept as the values.

    Raises BadFieldCount if the first byte of ``packed`` is not 255.
    '''

    def __init__(self, packed=None, error_code=0, sqlstate_marker='#',
                 sqlstate='XXXXX', message='', version_40=False):
        self.version_40 = version_40
        if packed is None:
            self.error_code = error_code
            self.message = message
            # Truthiness (not identity) so __str__, which uses the same test,
            # always finds the SQL state fields when it expects them.
            if not version_40:
                self.sqlstate_marker = sqlstate_marker
                self.sqlstate = sqlstate
            return

        field_count = packed[0]
        if not isinstance(field_count, int):
            field_count = ord(field_count)
        if field_count != 255:
            raise BadFieldCount('Expected 255, received ' + str(field_count))
        self.error_code = struct.unpack('<H', packed[1:3])[0]
        if version_40:
            self.message = packed[3:]
        elif self._has_sqlstate(packed):
            # Slice (not index) so the marker has the same type as sqlstate.
            self.sqlstate_marker = packed[3:4]
            self.sqlstate = packed[4:9]
            self.message = packed[9:]
        else:
            # NOTE(review): per the MySQL protocol docs, errors sent before
            # capabilities are negotiated carry no '#'+SQL state section.
            self.sqlstate_marker = sqlstate_marker
            self.sqlstate = sqlstate
            self.message = packed[3:]

    @staticmethod
    def _has_sqlstate(packed):
        '''Return True if ``packed`` has a '#' SQL state marker at offset 3.'''
        if len(packed) <= 3:
            return False
        marker = packed[3]
        if not isinstance(marker, int):
            marker = ord(marker)
        return marker == ord('#')

    def __str__(self):
        if self.version_40:
            return '''ErrorResult
  error_code = %s
  message = %s
  version_40 = %s
''' % (self.error_code, self.message, self.version_40)
        return '''ErrorResult
  error_code = %s
  sqlstate_marker = %s
  sqlstate = %s
  message = %s
  version_40 = %s
''' % (self.error_code, self.sqlstate_marker, self.sqlstate, self.message,
       self.version_40)
200 
201 class TestErrorResult(unittest.TestCase):
202 
203  def testDefaultInit(self):
204  result = ErrorResult()
205  self.assertEqual(result.error_code, 0)
206  self.assertEqual(result.sqlstate_marker, '#')
207  self.assertEqual(result.sqlstate, 'XXXXX')
208  self.assertEqual(result.message, '')
209  self.assertEqual(result.version_40, False)
210  result.__str__()
211 
212  def testDefaultInit40(self):
213  result = ErrorResult(version_40=True)
214  self.assertEqual(result.error_code, 0)
215  self.assertEqual(result.message, '')
216  self.assertEqual(result.version_40, True)
217  result.__str__()
218 
219  def testKeywordInit(self):
220  result = ErrorResult(error_code=3, sqlstate_marker='@', sqlstate='ABCDE',
221  message='test', version_40=False)
222  self.assertEqual(result.error_code, 3)
223  self.assertEqual(result.sqlstate_marker, '@')
224  self.assertEqual(result.sqlstate, 'ABCDE')
225  self.assertEqual(result.message, 'test')
226  self.assertEqual(result.version_40, False)
227  result.__str__()
228 
229  def testUnpackInit(self):
230  data = chr(255)
231  data += struct.pack('<H', 1234)
232  data += '#ABCDE'
233  data += 'test'
234 
235  result = ErrorResult(data)
236  self.assertEqual(result.error_code, 1234)
237  self.assertEqual(result.sqlstate_marker, '#')
238  self.assertEqual(result.sqlstate, 'ABCDE')
239  self.assertEqual(result.message, 'test')
240  self.assertEqual(result.version_40, False)
241  result.__str__()
242 
243  def testUnpackInit40(self):
244  data = chr(255)
245  data += struct.pack('<H', 1234)
246  data += 'test'
247 
248  result = ErrorResult(data, version_40=True)
249  self.assertEqual(result.error_code, 1234)
250  self.assertEqual(result.message, 'test')
251  self.assertEqual(result.version_40, True)
252  result.__str__()
253 
class EofResult(object):
    '''This class represents an EOF result packet sent from the server.

    An instance is either built from keyword arguments (when ``packed`` is
    None) or parsed from ``packed``, the raw packet payload, whose first byte
    must be 254.  ``packed`` may be a byte string (``str`` on Python 2,
    ``bytes``/``bytearray`` on Python 3).

    In the 4.1 layout (``version_40`` false), the first byte is followed by
    a little-endian warning count and status word.  In the 4.0 layout
    neither field exists, and the ``warning_count``/``status`` attributes
    are not set.

    Raises BadFieldCount if the first byte of ``packed`` is not 254.
    '''

    def __init__(self, packed=None, warning_count=0, status=0, version_40=False):
        self.version_40 = version_40
        if packed is None:
            # Truthiness (not identity) so __str__, which uses the same test,
            # always finds warning_count/status when it expects them.
            if not version_40:
                self.warning_count = warning_count
                self.status = status
            return

        field_count = packed[0]
        if not isinstance(field_count, int):
            field_count = ord(field_count)
        if field_count != 254:
            raise BadFieldCount('Expected 254, received ' + str(field_count))
        if not version_40:
            # Read exactly the 4-byte body so trailing bytes do not make
            # struct.unpack fail on a length mismatch.
            self.warning_count, self.status = struct.unpack('<HH',
                                                            packed[1:5])

    def __str__(self):
        if self.version_40:
            return '''EofResult
  version_40 = %s
''' % (self.version_40,)
        return '''EofResult
  warning_count = %s
  status = %s
  version_40 = %s
''' % (self.warning_count, self.status, self.version_40)
283 
284 class TestEofResult(unittest.TestCase):
285 
286  def testDefaultInit(self):
287  result = EofResult()
288  self.assertEqual(result.warning_count, 0)
289  self.assertEqual(result.status, 0)
290  self.assertEqual(result.version_40, False)
291  result.__str__()
292 
293  def testDefaultInit40(self):
294  result = EofResult(version_40=True)
295  self.assertEqual(result.version_40, True)
296  result.__str__()
297 
298  def testKeywordInit(self):
299  result = EofResult(warning_count=3, status=5, version_40=False)
300  self.assertEqual(result.warning_count, 3)
301  self.assertEqual(result.status, 5)
302  self.assertEqual(result.version_40, False)
303  result.__str__()
304 
305  def testUnpackInit(self):
306  data = chr(254)
307  data += struct.pack('<HH', 3, 5)
308 
309  result = EofResult(data)
310  self.assertEqual(result.warning_count, 3)
311  self.assertEqual(result.status, 5)
312  self.assertEqual(result.version_40, False)
313  result.__str__()
314 
315  def testUnpackInit40(self):
316  result = EofResult(chr(254), version_40=True)
317  self.assertEqual(result.version_40, True)
318  result.__str__()
319 
class CountResult(object):
    '''This class represents a count result packet sent from the server.

    An instance is either built from the ``count`` keyword (when ``packed``
    is None) or parsed from ``packed``, the raw packet payload.  ``packed``
    may be a byte string (``str`` on Python 2, ``bytes``/``bytearray`` on
    Python 3).

    The count is decoded as a length-coded integer: a first byte of 1-250
    is the count itself, while 252 and 253 introduce a 2- or 3-byte
    little-endian count.  Any other first byte (0, 251, 254, 255) raises
    BadFieldCount.
    '''

    def __init__(self, packed=None, count=0):
        if packed is None:
            self.count = count
            return

        first = packed[0]
        if not isinstance(first, int):
            first = ord(first)
        if 1 <= first <= 250:
            self.count = first
        elif first == 252:
            self.count = struct.unpack('<H', packed[1:3])[0]
        elif first == 253:
            # 3-byte value: low 16 bits followed by the high 8 bits.
            low, high = struct.unpack('<HB', packed[1:4])
            self.count = low | (high << 16)
        else:
            raise BadFieldCount('Expected 1-250, 252 or 253, received ' +
                                str(first))

    def __str__(self):
        return '''CountResult
  count = %s
''' % (self.count,)
335 
336 class TestCountResult(unittest.TestCase):
337 
338  def testDefaultInit(self):
339  result = CountResult()
340  self.assertEqual(result.count, 0)
341  result.__str__()
342 
343  def testKeywordInit(self):
344  result = CountResult(count=3)
345  self.assertEqual(result.count, 3)
346  result.__str__()
347 
348  def testUnpackInit(self):
349  result = CountResult("\x03")
350  self.assertEqual(result.count, 3)
351  result.__str__()
352 
def create_result(packed, version_40=False):
    '''Create the result object that matches the first byte of ``packed``.

    0 gives an OkResult, 254 an EofResult and 255 an ErrorResult (each
    parsed with ``version_40``).  Any other first byte is handed to
    CountResult.  ``packed`` may be a byte string (``str`` on Python 2,
    ``bytes``/``bytearray`` on Python 3).
    '''
    first = packed[0]
    # Indexing Python 3 bytes already yields an int; Python 2 str yields a
    # one-character string that needs ord().
    if not isinstance(first, int):
        first = ord(first)
    if first == 0:
        return OkResult(packed, version_40=version_40)
    if first == 254:
        return EofResult(packed, version_40=version_40)
    if first == 255:
        return ErrorResult(packed, version_40=version_40)
    return CountResult(packed)
364 
if __name__ == '__main__':
    # Running this module as a script executes the TestCase classes above.
    unittest.main(module='__main__')