"""
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> a = sc.accumulator(1)
>>> a.value
1
>>> a.value = 2
>>> a.value
2
>>> a += 5
>>> a.value
7

>>> sc.accumulator(1.0).value
1.0

>>> sc.accumulator(1j).value
1j

>>> rdd = sc.parallelize([1,2,3])
>>> def f(x):
...     global a
...     a += x
>>> rdd.foreach(f)
>>> a.value
13

>>> b = sc.accumulator(0)
>>> def g(x):
...     b.add(x)
>>> rdd.foreach(g)
>>> b.value
6

>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
...     def zero(self, value):
...         return [0.0] * len(value)
...     def addInPlace(self, val1, val2):
...         for i in xrange(len(val1)):
...             val1[i] += val2[i]
...         return val1
>>> va = sc.accumulator([1.0, 2.0, 3.0], VectorAccumulatorParam())
>>> va.value
[1.0, 2.0, 3.0]
>>> def g(x):
...     global va
...     va += [x] * 3
>>> rdd.foreach(g)
>>> va.value
[7.0, 8.0, 9.0]

>>> rdd.map(lambda x: a.value).collect() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> def h(x):
...     global a
...     a.value = 7
>>> rdd.foreach(h) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Py4JJavaError:...

>>> sc.accumulator([1.0, 2.0, 3.0]) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
    ...
Exception:...
"""

import select
import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
from pyspark.serializers import read_int, PickleSerializer


pickleSer = PickleSerializer()


# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}


def _deserialize_accumulator(aid, zero_value, accum_param):
    from pyspark.accumulators import _accumulatorRegistry
    accum = Accumulator(aid, zero_value, accum_param)
    accum._deserialized = True
    _accumulatorRegistry[aid] = accum
    return accum


class Accumulator(object):

    """
    A shared variable that can be accumulated, i.e., has a commutative and associative "add"
    operation. Worker tasks on a Spark cluster can add values to an Accumulator with the C{+=}
    operator, but only the driver program is allowed to access its value, using C{value}.
    Updates from the workers get propagated automatically to the driver program.

    While C{SparkContext} supports accumulators for primitive data types like C{int} and
    C{float}, users can also define accumulators for custom types by providing a custom
    L{AccumulatorParam} object. Refer to the doctest of this module for an example.
    """

    def __init__(self, aid, value, accum_param):
        """Create a new Accumulator with a given initial value and AccumulatorParam object"""
        from pyspark.accumulators import _accumulatorRegistry
        self.aid = aid
        self.accum_param = accum_param
        self._value = value
        self._deserialized = False
        _accumulatorRegistry[aid] = self

    def __reduce__(self):
        """Custom serialization; saves the zero value from our AccumulatorParam"""
        param = self.accum_param
        return (_deserialize_accumulator, (self.aid, param.zero(self._value), param))

    @property
    def value(self):
        """Get the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        return self._value

    @value.setter
    def value(self, value):
        """Sets the accumulator's value; only usable in driver program"""
        if self._deserialized:
            raise Exception("Accumulator.value cannot be accessed inside tasks")
        self._value = value

    def add(self, term):
        """Adds a term to this accumulator's value"""
        self._value = self.accum_param.addInPlace(self._value, term)

    def __iadd__(self, term):
        """The += operator; adds a term to this accumulator's value"""
        self.add(term)
        return self

    def __str__(self):
        return str(self._value)

    def __repr__(self):
        return "Accumulator<id=%i, value=%s>" % (self.aid, self._value)


class AccumulatorParam(object):

    """
    Helper object that defines how to accumulate values of a given type.
    """

    def zero(self, value):
        """
        Provide a "zero value" for the type, compatible in dimensions with the
        provided C{value} (e.g., a zero vector)
        """
        raise NotImplementedError

    def addInPlace(self, value1, value2):
        """
        Add two values of the accumulator's data type, returning a new value;
        for efficiency, can also update C{value1} in place and return it.
        """
        raise NotImplementedError

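
# Illustrative sketch (not part of the original module): a custom
# AccumulatorParam only has to implement zero() and addInPlace(). For example,
# a hypothetical set-union accumulator, assuming a live SparkContext `sc` as in
# the module doctest:
#
#     class SetAccumulatorParam(AccumulatorParam):
#         def zero(self, value):
#             return set()
#         def addInPlace(self, val1, val2):
#             val1 |= val2      # in-place union; return the mutated left value
#             return val1
#
#     seen = sc.accumulator(set(), SetAccumulatorParam())
#     sc.parallelize([1, 2, 2, 3]).foreach(lambda x: seen.add(set([x])))
#     # seen.value == set([1, 2, 3]) on the driver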


class AddingAccumulatorParam(AccumulatorParam):

    """
    An AccumulatorParam that uses the + operator to add values. Designed for simple types
    such as integers, floats, and lists. Requires the zero value for the underlying type
    as a parameter.
    """

    def __init__(self, zero_value):
        self.zero_value = zero_value

    def zero(self, value):
        return self.zero_value

    def addInPlace(self, value1, value2):
        value1 += value2
        return value1


# Singleton accumulator params for some standard types
INT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0)
FLOAT_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0)
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
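
# Illustrative sketch (not part of the original module): these singletons back
# the plain sc.accumulator(1), sc.accumulator(1.0), and sc.accumulator(1j)
# calls in the module doctest. A list accumulator can be built the same way,
# assuming a live SparkContext `sc`:
#
#     words = sc.accumulator([], AddingAccumulatorParam([]))
#     sc.parallelize(["a", "b"]).foreach(lambda w: words.add([w]))
#     # words.value now holds "a" and "b" (merge order not guaranteed)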


class _UpdateRequestHandler(SocketServer.StreamRequestHandler):

    """
    This handler will keep polling updates from the same socket until the
    server is shutdown.
    """

    def handle(self):
        while not self.server.server_shutdown:
            # Poll every 1 second for new data -- don't block in case of shutdown.
            r, _, _ = select.select([self.rfile], [], [], 1)
            if self.rfile in r:
                num_updates = read_int(self.rfile)
                for _ in range(num_updates):
                    (aid, update) = pickleSer._read_with_length(self.rfile)
                    _accumulatorRegistry[aid] += update
                # Write a byte in acknowledgement
                self.wfile.write(struct.pack("!b", 1))


class AccumulatorServer(SocketServer.TCPServer):

    """
    A simple TCP server that intercepts shutdown() in order to interrupt
    our continuous polling on the handler.
    """
    server_shutdown = False

    def shutdown(self):
        self.server_shutdown = True
        SocketServer.TCPServer.shutdown(self)


def _start_update_server():
    """Start a TCP server to receive accumulator updates in a daemon thread, and return it"""
    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True
    thread.start()
    return server
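
# Illustrative sketch (hypothetical client, not part of the original module):
# _UpdateRequestHandler above expects an int count, then `count` length-prefixed
# pickled (aid, update) pairs, and answers each batch with a single ack byte.
# A minimal client honoring that format, assuming write_int from
# pyspark.serializers:
#
#     import socket
#     from pyspark.serializers import write_int
#
#     def _send_updates(port, updates):
#         sock = socket.socket()
#         sock.connect(("localhost", port))
#         outfile = sock.makefile("wb")
#         write_int(len(updates), outfile)   # number of (aid, update) pairs
#         for aid, update in updates:
#             pickleSer._write_with_length((aid, update), outfile)
#         outfile.flush()
#         sock.recv(1)                       # wait for the acknowledgement byte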