 18  """ 
 19  PySpark supports custom serializers for transferring data; this can improve 
 20  performance. 
 21   
 22  By default, PySpark uses L{PickleSerializer} to serialize objects using Python's 
 23  C{cPickle} serializer, which can serialize nearly any Python object. 
 24  Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be 
 25  faster. 
 26   
 27  The serializer is chosen when creating L{SparkContext}: 
 28   
 29  >>> from pyspark.context import SparkContext 
 30  >>> from pyspark.serializers import MarshalSerializer 
 31  >>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) 
 32  >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) 
 33  [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] 
 34  >>> sc.stop() 
 35   
By default, PySpark serializes objects in batches; the batch size can be
controlled through SparkContext's C{batchSize} parameter
(the default size is 1024 objects):

>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)

Behind the scenes, this creates a JavaRDD with four partitions, each of
which contains two batches of two objects:

>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> rdd._jrdd.count()
8L
>>> sc.stop()

A batch size of -1 uses an unlimited batch size, and a size of 1 disables
batching:

>>> sc = SparkContext('local', 'test', batchSize=1)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
>>> rdd.glom().collect()
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
>>> rdd._jrdd.count()
16L
 61  """ 
 62   
 63  import cPickle 
 64  from itertools import chain, izip, product 
 65  import marshal 
 66  import struct 
 67  import sys 
 68  import types 
 69  import collections 
 70  import zlib 
 71   
 72  from pyspark import cloudpickle 
 73   
 74   
 75  __all__ = ["PickleSerializer", "MarshalSerializer"] 
 76   
 77   
 79      END_OF_DATA_SECTION = -1 
 80      PYTHON_EXCEPTION_THROWN = -2 
 81      TIMING_DATA = -3 
  82   
 83   
 85   
 87          """ 
 88          Serialize an iterator of objects to the output stream. 
 89          """ 
 90          raise NotImplementedError 
  91   
 93          """ 
 94          Return an iterator of deserialized objects from the input stream. 
 95          """ 
 96          raise NotImplementedError 
  97   
 99          return self.load_stream(stream) 
 100   
    # Note: our notion of "equality" is that output generated by
    # equal serializers can be deserialized using the same serializer.

    # This default implementation handles the simple cases;
    # subclasses should override __eq__ as appropriate.

    def __eq__(self, other):
        return isinstance(other, self.__class__)

    def __ne__(self, other):
        return not self.__eq__(other)


class FramedSerializer(Serializer):

    """
    Serializer that writes objects as a stream of (length, data) pairs,
    where C{length} is a 32-bit integer and data is C{length} bytes.
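
    For example, a concrete subclass such as L{PickleSerializer} (defined
    below) frames each object with a 4-byte length prefix; a rough sketch
    using an in-memory stream:

    >>> from StringIO import StringIO
    >>> ser = PickleSerializer()
    >>> buf = StringIO()
    >>> ser.dump_stream(["a", "b"], buf)
    >>> _ = buf.seek(0)
    >>> list(ser.load_stream(buf))
    ['a', 'b']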
    """

    def __init__(self):
        # On Python 2.6, we can't write bytearrays to streams, so we need to
        # convert them to strings first
        self._only_write_strings = sys.version_info[0:2] <= (2, 6)

    def dump_stream(self, iterator, stream):
        for obj in iterator:
            self._write_with_length(obj, stream)

    def load_stream(self, stream):
        while True:
            try:
                yield self._read_with_length(stream)
            except EOFError:
                return

    def _write_with_length(self, obj, stream):
        serialized = self.dumps(obj)
        write_int(len(serialized), stream)
        if self._only_write_strings:
            stream.write(str(serialized))
        else:
            stream.write(serialized)

    def _read_with_length(self, stream):
        length = read_int(stream)
        obj = stream.read(length)
        if obj == "":
            raise EOFError
        return self.loads(obj)

    def dumps(self, obj):
        """
        Serialize an object into a byte array.
        When batching is used, this will be called with an array of objects.
        """
        raise NotImplementedError

    def loads(self, obj):
        """
        Deserialize an object from a byte array.
        """
        raise NotImplementedError


class BatchedSerializer(Serializer):

    """
    Serializes a stream of objects in batches by calling its wrapped
    Serializer with streams of objects.
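
    With the default C{UNLIMITED_BATCH_SIZE} (-1), the whole iterator is
    written as a single batch. For example, wrapping L{PickleSerializer} with
    a batch size of 2 writes pickled lists of (at most) two objects, which
    can be read back either batched or flattened; a rough sketch using an
    in-memory stream:

    >>> from StringIO import StringIO
    >>> batched = BatchedSerializer(PickleSerializer(), 2)
    >>> buf = StringIO()
    >>> batched.dump_stream([1, 2, 3, 4, 5], buf)
    >>> _ = buf.seek(0)
    >>> list(batched._load_stream_without_unbatching(buf))
    [[1, 2], [3, 4], [5]]
    >>> _ = buf.seek(0)
    >>> list(batched.load_stream(buf))
    [1, 2, 3, 4, 5]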
    """

    UNLIMITED_BATCH_SIZE = -1

    def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
        self.serializer = serializer
        self.batchSize = batchSize

    def _batched(self, iterator):
        if self.batchSize == self.UNLIMITED_BATCH_SIZE:
            yield list(iterator)
        else:
            items = []
            count = 0
            for item in iterator:
                items.append(item)
                count += 1
                if count == self.batchSize:
                    yield items
                    items = []
                    count = 0
            if items:
                yield items

    def dump_stream(self, iterator, stream):
        self.serializer.dump_stream(self._batched(iterator), stream)

    def load_stream(self, stream):
        return chain.from_iterable(self._load_stream_without_unbatching(stream))

    def _load_stream_without_unbatching(self, stream):
        return self.serializer.load_stream(stream)

    def __eq__(self, other):
        return (isinstance(other, BatchedSerializer) and
                other.serializer == self.serializer)

    def __str__(self):
        return "BatchedSerializer<%s>" % str(self.serializer)


class CartesianDeserializer(FramedSerializer):

    """
    Deserializes the JavaRDD cartesian() of two PythonRDDs.
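
    Each deserialized key batch is combined with the corresponding value
    batch via C{itertools.product}. A rough sketch, feeding the deserializer
    from an in-memory stream rather than a real JavaRDD:

    >>> from StringIO import StringIO
    >>> ser = BatchedSerializer(PickleSerializer(), 2)
    >>> buf = StringIO()
    >>> ser.dump_stream(["a", "b"], buf)  # one batch of keys
    >>> ser.dump_stream([1, 2], buf)      # one batch of values
    >>> _ = buf.seek(0)
    >>> list(CartesianDeserializer(ser, ser).load_stream(buf))
    [('a', 1), ('a', 2), ('b', 1), ('b', 2)]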
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def prepare_keys_values(self, stream):
        key_stream = self.key_ser._load_stream_without_unbatching(stream)
        val_stream = self.val_ser._load_stream_without_unbatching(stream)
        key_is_batched = isinstance(self.key_ser, BatchedSerializer)
        val_is_batched = isinstance(self.val_ser, BatchedSerializer)
        for (keys, vals) in izip(key_stream, val_stream):
            keys = keys if key_is_batched else [keys]
            vals = vals if val_is_batched else [vals]
            yield (keys, vals)

    def load_stream(self, stream):
        for (keys, vals) in self.prepare_keys_values(stream):
            for pair in product(keys, vals):
                yield pair

    def __eq__(self, other):
        return (isinstance(other, CartesianDeserializer) and
                self.key_ser == other.key_ser and self.val_ser == other.val_ser)

    def __str__(self):
        return "CartesianDeserializer<%s, %s>" % \
               (str(self.key_ser), str(self.val_ser))


class PairDeserializer(CartesianDeserializer):

    """
    Deserializes the JavaRDD zip() of two PythonRDDs.
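
    Corresponding key and value batches are zipped elementwise, so they must
    have the same length. A rough sketch with unbatched serializers, where
    keys and values alternate in the in-memory stream:

    >>> from StringIO import StringIO
    >>> ser = PickleSerializer()
    >>> buf = StringIO()
    >>> ser.dump_stream(["a", 1, "b", 2], buf)  # key, value, key, value
    >>> _ = buf.seek(0)
    >>> list(PairDeserializer(ser, ser).load_stream(buf))
    [('a', 1), ('b', 2)]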
    """

    def __init__(self, key_ser, val_ser):
        self.key_ser = key_ser
        self.val_ser = val_ser

    def load_stream(self, stream):
        for (keys, vals) in self.prepare_keys_values(stream):
            if len(keys) != len(vals):
                raise ValueError("Can not deserialize RDD with different number of items"
                                 " in pair: (%d, %d)" % (len(keys), len(vals)))
            for pair in izip(keys, vals):
                yield pair

    def __eq__(self, other):
        return (isinstance(other, PairDeserializer) and
                self.key_ser == other.key_ser and self.val_ser == other.val_ser)

    def __str__(self):
        return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser))


# Hack namedtuple, making it picklable

__cls = {}


def _restore(name, fields, value):
    """ Restore a namedtuple instance from its name, fields and values """
    k = (name, fields)
    cls = __cls.get(k)
    if cls is None:
        cls = collections.namedtuple(name, fields)
        __cls[k] = cls
    return cls(*value)


def _hack_namedtuple(cls):
    """ Make a class generated by namedtuple picklable """
    name = cls.__name__
    fields = cls._fields

    def __reduce__(self):
        return (_restore, (name, fields, tuple(self)))
    cls.__reduce__ = __reduce__
    return cls


def _hijack_namedtuple():
    """ Hack namedtuple() to make it picklable """
    # hijack all namedtuple() calls only once
    if hasattr(collections.namedtuple, "__hijack"):
        return

    global _old_namedtuple  # or it will be put in a closure

    def _copy_func(f):
        return types.FunctionType(f.func_code, f.func_globals, f.func_name,
                                  f.func_defaults, f.func_closure)

    _old_namedtuple = _copy_func(collections.namedtuple)

    def namedtuple(*args, **kwargs):
        cls = _old_namedtuple(*args, **kwargs)
        return _hack_namedtuple(cls)

    # replace namedtuple with the new one
    collections.namedtuple.func_globals["_old_namedtuple"] = _old_namedtuple
    collections.namedtuple.func_globals["_hack_namedtuple"] = _hack_namedtuple
    collections.namedtuple.func_code = namedtuple.func_code
    collections.namedtuple.__hijack = 1

    # hack the classes already generated by namedtuple; those created in
    # other modules can be pickled as normal, so only hack those in __main__
    for n, o in sys.modules["__main__"].__dict__.iteritems():
        if (type(o) is type and o.__base__ is tuple
                and hasattr(o, "_fields")
                and "__reduce__" not in o.__dict__):
            _hack_namedtuple(o)  # hack inplace


_hijack_namedtuple()
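
# After the hijack above, classes created by collections.namedtuple() pickle
# via _restore; a rough sketch of the effect:
#
#     Point = collections.namedtuple("Point", "x y")
#     cPickle.loads(cPickle.dumps(Point(1, 2), 2))   # -> Point(x=1, y=2)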


class PickleSerializer(FramedSerializer):

    """
    Serializes objects using Python's cPickle serializer:

        http://docs.python.org/2/library/pickle.html

    This serializer supports nearly any Python object, but may
    not be as fast as more specialized serializers.
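
    A rough round-trip sketch:

    >>> ser = PickleSerializer()
    >>> ser.loads(ser.dumps({'a': 1}))
    {'a': 1}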
    """

    def dumps(self, obj):
        return cPickle.dumps(obj, 2)

    loads = cPickle.loads


class CloudPickleSerializer(PickleSerializer):

    def dumps(self, obj):
        return cloudpickle.dumps(obj, 2)


class MarshalSerializer(FramedSerializer):

    """
    Serializes objects using Python's Marshal serializer:

        http://docs.python.org/2/library/marshal.html

    This serializer is faster than PickleSerializer but supports fewer datatypes.
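
    A rough round-trip sketch with built-in types:

    >>> ser = MarshalSerializer()
    >>> ser.loads(ser.dumps([1, 'two', (3.0, 4)]))
    [1, 'two', (3.0, 4)]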
    """

    dumps = marshal.dumps
    loads = marshal.loads


class AutoSerializer(FramedSerializer):

    """
    Choose marshal or cPickle as the serialization protocol automatically
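
    Each serialized object is prefixed with a one-byte tag ('M' for marshal,
    'P' for pickle); a rough sketch:

    >>> ser = AutoSerializer()
    >>> ser.dumps(u'spark')[0]
    'M'
    >>> ser.loads(ser.dumps(u'spark'))
    u'spark'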
    """

    def __init__(self):
        FramedSerializer.__init__(self)
        self._type = None

    def dumps(self, obj):
        if self._type is not None:
            return 'P' + cPickle.dumps(obj, -1)
        try:
            return 'M' + marshal.dumps(obj)
        except Exception:
            self._type = 'P'
            return 'P' + cPickle.dumps(obj, -1)

    def loads(self, obj):
        _type = obj[0]
        if _type == 'M':
            return marshal.loads(obj[1:])
        elif _type == 'P':
            return cPickle.loads(obj[1:])
        else:
            raise ValueError("invalid serialization type: %s" % _type)


class CompressedSerializer(FramedSerializer):
    """
    Compress the serialized data with zlib.
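
    A rough round-trip sketch, wrapping L{PickleSerializer}:

    >>> ser = CompressedSerializer(PickleSerializer())
    >>> ser.loads(ser.dumps(["a"] * 10))
    ['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a']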
    """

    def __init__(self, serializer):
        FramedSerializer.__init__(self)
        self.serializer = serializer

    def dumps(self, obj):
        return zlib.compress(self.serializer.dumps(obj), 1)

    def loads(self, obj):
        return self.serializer.loads(zlib.decompress(obj))


class UTF8Deserializer(Serializer):

    """
    Deserializes streams written by String.getBytes: each string is a 4-byte
    length followed by UTF-8 encoded bytes.
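
    A rough sketch using an in-memory stream:

    >>> from StringIO import StringIO
    >>> buf = StringIO()
    >>> write_int(5, buf)
    >>> _ = buf.write("hello")
    >>> _ = buf.seek(0)
    >>> UTF8Deserializer().loads(buf)
    u'hello'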
    """

    def loads(self, stream):
        length = read_int(stream)
        return stream.read(length).decode('utf8')

    def load_stream(self, stream):
        while True:
            try:
                yield self.loads(stream)
            except struct.error:
                return
            except EOFError:
                return


# Helpers for reading and writing framed, big-endian integers on raw streams.

def read_long(stream):
    length = stream.read(8)
    if length == "":
        raise EOFError
    return struct.unpack("!q", length)[0]


def write_long(value, stream):
    stream.write(struct.pack("!q", value))


def pack_long(value):
    return struct.pack("!q", value)


def read_int(stream):
    length = stream.read(4)
    if length == "":
        raise EOFError
    return struct.unpack("!i", length)[0]


def write_int(value, stream):
    stream.write(struct.pack("!i", value))


def write_with_length(obj, stream):
    write_int(len(obj), stream)
    stream.write(obj)