import os
import shutil
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
from collections import namedtuple

from pyspark import accumulators
from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
    PairDeserializer
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD

from py4j.java_collections import ListConverter


class SparkContext(object):
    """
    Main entry point for Spark functionality. A SparkContext represents the
    connection to a Spark cluster, and can be used to create L{RDD}s and
    broadcast variables on that cluster.
    """

    _gateway = None
    _jvm = None
    _writeToFile = None
    _next_accum_id = 0
    _active_spark_context = None
    _lock = Lock()
    _python_includes = None  # zip and egg files that need to be added to PYTHONPATH

    def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
                 environment=None, batchSize=1024, serializer=PickleSerializer(),
                 conf=None, gateway=None):
        """
        Create a new SparkContext. At least the master and app name should be set,
        either through the named parameters here or through C{conf}.

        @param master: Cluster URL to connect to
               (e.g. mesos://host:port, spark://host:port, local[4]).
        @param appName: A name for your job, to display on the cluster web UI.
        @param sparkHome: Location where Spark is installed on cluster nodes.
        @param pyFiles: Collection of .zip or .py files to send to the cluster
               and add to PYTHONPATH.  These can be paths on the local file
               system or HDFS, HTTP, HTTPS, or FTP URLs.
        @param environment: A dictionary of environment variables to set on
               worker nodes.
        @param batchSize: The number of Python objects represented as a single
               Java object.  Set 1 to disable batching or -1 to use an
               unlimited batch size.
        @param serializer: The serializer for RDDs.
        @param conf: A L{SparkConf} object setting Spark properties.
        @param gateway: Use an existing gateway and JVM; otherwise a new JVM
               will be instantiated.

        >>> from pyspark.context import SparkContext
        >>> sc = SparkContext('local', 'test')

        >>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...
        """
        if rdd._extract_concise_traceback() is not None:
            self._callsite = rdd._extract_concise_traceback()
        else:
            tempNamedTuple = namedtuple("Callsite", "function file linenum")
            self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
        SparkContext._ensure_initialized(self, gateway=gateway)

        self.environment = environment or {}
        self._conf = conf or SparkConf(_jvm=self._jvm)
        self._batchSize = batchSize  # -1 represents an unlimited batch size
        self._unbatched_serializer = serializer
        if batchSize == 1:
            self.serializer = self._unbatched_serializer
        else:
            self.serializer = BatchedSerializer(self._unbatched_serializer,
                                                batchSize)

        # Apply any parameters passed directly to this constructor on the conf
        if master:
            self._conf.setMaster(master)
        if appName:
            self._conf.setAppName(appName)
        if sparkHome:
            self._conf.setSparkHome(sparkHome)
        if environment:
            for key, value in environment.iteritems():
                self._conf.setExecutorEnv(key, value)

        # Check that we have at least the required parameters
        if not self._conf.contains("spark.master"):
            raise Exception("A master URL must be set in your configuration")
        if not self._conf.contains("spark.app.name"):
            raise Exception("An application name must be set in your configuration")

        # Read back our properties from the conf in case we loaded some of them
        # from the classpath or an external configuration file
        self.master = self._conf.get("spark.master")
        self.appName = self._conf.get("spark.app.name")
        self.sparkHome = self._conf.get("spark.home", None)
        for (k, v) in self._conf.getAll():
            if k.startswith("spark.executorEnv."):
                varName = k[len("spark.executorEnv."):]
                self.environment[varName] = v

        # Create the Java SparkContext through Py4J
        self._jsc = self._initialize_context(self._conf._jconf)

        # Create a single Accumulator in Java that we'll send all our updates
        # through; updates from the workers come back via a local TCP server
        self._accumulatorServer = accumulators._start_update_server()
        (host, port) = self._accumulatorServer.server_address
        self._javaAccumulator = self._jsc.accumulator(
            self._jvm.java.util.ArrayList(),
            self._jvm.PythonAccumulatorParam(host, port))

        self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')

        # Broadcast variables that have been pickled while serializing task
        # commands; RDD code uses this set to ship the corresponding Java
        # broadcasts to the executors along with the serialized command.
        self._pickled_broadcast_vars = set()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
        sys.path.append(root_dir)

        # Deploy any code dependencies specified in the constructor
        self._python_includes = list()
        for path in (pyFiles or []):
            self.addPyFile(path)

        # Deploy code dependencies set by spark-submit; these will already have
        # been added with addFile, so we just need to add them to the PYTHONPATH
        for path in self._conf.get("spark.submit.pyFiles", "").split(","):
            if path != "":
                (dirname, filename) = os.path.split(path)
                self._python_includes.append(filename)
                sys.path.append(path)
                if dirname not in sys.path:
                    sys.path.append(dirname)

        # Create a temporary directory inside spark.local.dir
        local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
        self._temp_dir = \
            self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()

    def _initialize_context(self, jconf):
        """
        Initialize the JavaSparkContext; broken out into its own method so that
        subclasses can override the initialization.
        """
        return self._jvm.JavaSparkContext(jconf)

    @classmethod
    def _ensure_initialized(cls, instance=None, gateway=None):
        """
        Checks whether a SparkContext is initialized or not.
        Throws an error if a SparkContext is already running.
        """
        with SparkContext._lock:
            if not SparkContext._gateway:
                SparkContext._gateway = gateway or launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm
                SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile

            if instance:
                if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
                    currentMaster = SparkContext._active_spark_context.master
                    currentAppName = SparkContext._active_spark_context.appName
                    callsite = SparkContext._active_spark_context._callsite

                    # Raise an error if another SparkContext is already running
                    raise ValueError("Cannot run multiple SparkContexts at once; existing "
                                     "SparkContext(app=%s, master=%s) created by %s at %s:%s "
                                     % (currentAppName, currentMaster, callsite.function,
                                        callsite.file, callsite.linenum))
                else:
                    SparkContext._active_spark_context = instance

    @classmethod
    def setSystemProperty(cls, key, value):
        """
        Set a Java system property, such as spark.executor.memory. This must
        be invoked before instantiating SparkContext.
        """
        SparkContext._ensure_initialized()
        SparkContext._jvm.java.lang.System.setProperty(key, value)
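
    # Illustrative usage (not executed here): system properties must be set
    # before the SparkContext is created, e.g.
    #
    #   SparkContext.setSystemProperty("spark.executor.memory", "2g")
    #   sc = SparkContext("local", "test")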

    @property
    def defaultParallelism(self):
        """
        Default level of parallelism to use when not given by user (e.g. for
        reduce tasks)
        """
        return self._jsc.sc().defaultParallelism()

    @property
    def defaultMinPartitions(self):
        """
        Default min number of partitions for Hadoop RDDs when not given by user
        """
        return self._jsc.sc().defaultMinPartitions()

    def stop(self):
        """
        Shut down the SparkContext.
        """
        if self._jsc:
            self._jsc.stop()
            self._jsc = None
        if self._accumulatorServer:
            self._accumulatorServer.shutdown()
            self._accumulatorServer = None
        with SparkContext._lock:
            SparkContext._active_spark_context = None
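
    # Illustrative (not executed here): stop() clears the active context, so a
    # new SparkContext may be created afterwards in the same process, e.g.
    #
    #   sc.stop()
    #   sc = SparkContext("local", "another-app")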

    def parallelize(self, c, numSlices=None):
        """
        Distribute a local Python collection to form an RDD.

        >>> sc.parallelize(range(5), 5).glom().collect()
        [[0], [1], [2], [3], [4]]
        """
        numSlices = numSlices or self.defaultParallelism
        # Calling the Java parallelize() method with an ArrayList is too slow,
        # because it sends O(n) Py4J commands.  Instead, we serialize the data
        # to a temp file on disk and read it back as an RDD in the JVM.
        tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
        # Make sure we distribute data evenly if it's smaller than self.batchSize
        if "__len__" not in dir(c):
            c = list(c)    # Make it a list so we can compute its length
        batchSize = min(len(c) // numSlices, self._batchSize)
        if batchSize > 1:
            serializer = BatchedSerializer(self._unbatched_serializer,
                                           batchSize)
        else:
            serializer = self._unbatched_serializer
        serializer.dump_stream(c, tempFile)
        tempFile.close()
        readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
        jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
        return RDD(jrdd, self, serializer)

    def textFile(self, name, minPartitions=None):
        """
        Read a text file from HDFS, a local file system (available on all
        nodes), or any Hadoop-supported file system URI, and return it as an
        RDD of Strings.

        >>> path = os.path.join(tempdir, "sample-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello world!")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello world!']
        """
        minPartitions = minPartitions or min(self.defaultParallelism, 2)
        return RDD(self._jsc.textFile(name, minPartitions), self,
                   UTF8Deserializer())

    def wholeTextFiles(self, path, minPartitions=None):
        """
        Read a directory of text files from HDFS, a local file system
        (available on all nodes), or any Hadoop-supported file system
        URI. Each file is read as a single record and returned in a
        key-value pair, where the key is the path of each file and the
        value is the content of each file.

        For example, if you have the following files::

          hdfs://a-hdfs-path/part-00000
          hdfs://a-hdfs-path/part-00001
          ...
          hdfs://a-hdfs-path/part-nnnnn

        Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")},
        then C{rdd} contains::

          (a-hdfs-path/part-00000, its content)
          (a-hdfs-path/part-00001, its content)
          ...
          (a-hdfs-path/part-nnnnn, its content)

        NOTE: Small files are preferred, as each file will be loaded
        fully in memory.

        >>> dirPath = os.path.join(tempdir, "files")
        >>> os.mkdir(dirPath)
        >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1:
        ...    file1.write("1")
        >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2:
        ...    file2.write("2")
        >>> textFiles = sc.wholeTextFiles(dirPath)
        >>> sorted(textFiles.collect())
        [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')]
        """
        minPartitions = minPartitions or self.defaultMinPartitions
        return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
                   PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))

    def _checkpointFile(self, name, input_deserializer):
        jrdd = self._jsc.checkpointFile(name)
        return RDD(jrdd, self, input_deserializer)

    def union(self, rdds):
        """
        Build the union of a list of RDDs.

        This supports unions of RDDs with different serialized formats,
        although this forces them to be reserialized using the default
        serializer:

        >>> path = os.path.join(tempdir, "union-text.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("Hello")
        >>> textFile = sc.textFile(path)
        >>> textFile.collect()
        [u'Hello']
        >>> parallelized = sc.parallelize(["World!"])
        >>> sorted(sc.union([textFile, parallelized]).collect())
        [u'Hello', 'World!']
        """
        first_jrdd_deserializer = rdds[0]._jrdd_deserializer
        if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
            rdds = [x._reserialize() for x in rdds]
        first = rdds[0]._jrdd
        rest = [x._jrdd for x in rdds[1:]]
        rest = ListConverter().convert(rest, self._gateway._gateway_client)
        return RDD(self._jsc.union(first, rest), self,
                   rdds[0]._jrdd_deserializer)

    def broadcast(self, value):
        """
        Broadcast a read-only variable to the cluster, returning a
        L{Broadcast<pyspark.broadcast.Broadcast>}
        object for reading it in distributed functions. The variable will be
        sent to each cluster node only once.
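
        A minimal, driver-side example (the C{value} attribute simply holds
        the original object locally):

        >>> b = sc.broadcast([1, 2, 3, 4, 5])
        >>> b.value
        [1, 2, 3, 4, 5]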
        """
        pickleSer = PickleSerializer()
        pickled = pickleSer.dumps(value)
        jbroadcast = self._jsc.broadcast(bytearray(pickled))
        return Broadcast(jbroadcast.id(), value, jbroadcast,
                         self._pickled_broadcast_vars)

    def accumulator(self, value, accum_param=None):
        """
        Create an L{Accumulator} with the given initial value, using a given
        L{AccumulatorParam} helper object to define how to add values of the
        data type if provided. Default AccumulatorParams are used for integers
        and floating-point numbers if you do not provide one. For other types,
        a custom AccumulatorParam can be used.
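
        A minimal, driver-side example (C{+=} delegates to the accumulator
        parameter's C{addInPlace}):

        >>> a = sc.accumulator(1)
        >>> a += 4
        >>> a.value
        5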
        """
        if accum_param is None:
            if isinstance(value, int):
                accum_param = accumulators.INT_ACCUMULATOR_PARAM
            elif isinstance(value, float):
                accum_param = accumulators.FLOAT_ACCUMULATOR_PARAM
            elif isinstance(value, complex):
                accum_param = accumulators.COMPLEX_ACCUMULATOR_PARAM
            else:
                raise Exception("No default accumulator param for type %s" % type(value))
        SparkContext._next_accum_id += 1
        return Accumulator(SparkContext._next_accum_id - 1, value, accum_param)

    def addFile(self, path):
        """
        Add a file to be downloaded with this Spark job on every node.
        The C{path} passed can be either a local file, a file in HDFS
        (or other Hadoop-supported filesystems), or an HTTP, HTTPS or
        FTP URI.

        To access the file in Spark jobs, use
        L{SparkFiles.get(path)<pyspark.files.SparkFiles.get>} to find its
        download location.

        >>> from pyspark import SparkFiles
        >>> path = os.path.join(tempdir, "test.txt")
        >>> with open(path, "w") as testFile:
        ...    testFile.write("100")
        >>> sc.addFile(path)
        >>> def func(iterator):
        ...    with open(SparkFiles.get("test.txt")) as testFile:
        ...        fileVal = int(testFile.readline())
        ...        return [x * fileVal for x in iterator]
        >>> sc.parallelize([1, 2, 3, 4]).mapPartitions(func).collect()
        [100, 200, 300, 400]
        """
        self._jsc.sc().addFile(path)

    def clearFiles(self):
        """
        Clear the job's list of files added by L{addFile} or L{addPyFile} so
        that they do not get downloaded to any new nodes.
        """
        # Note: this does not remove previously added .py or .zip files from sys.path
        self._jsc.sc().clearFiles()

    def addPyFile(self, path):
        """
        Add a .py or .zip dependency for all tasks to be executed on this
        SparkContext in the future.  The C{path} passed can be either a local
        file, a file in HDFS (or other Hadoop-supported filesystems), or an
        HTTP, HTTPS or FTP URI.
        """
        self.addFile(path)
        (dirname, filename) = os.path.split(path)  # dirname may be a directory or an HDFS/S3 prefix

        if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
            self._python_includes.append(filename)
            # also make it importable on the driver, where addFile places it under the SparkFiles root
            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
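
    # Illustrative usage (module name and path are hypothetical): after
    # addPyFile, the dependency can be imported on the driver and inside
    # functions shipped to workers, e.g.
    #
    #   sc.addPyFile("hdfs:///deps/mylib.zip")
    #   import mylib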

    def setCheckpointDir(self, dirName):
        """
        Set the directory under which RDDs are going to be checkpointed. The
        directory must be an HDFS path if running on a cluster.
        """
        self._jsc.sc().setCheckpointDir(dirName)
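
    # Illustrative usage (path is hypothetical): checkpointing an RDD writes
    # its contents under this directory and truncates its lineage, e.g.
    #
    #   sc.setCheckpointDir("hdfs:///tmp/spark-checkpoints")
    #   rdd = sc.parallelize(range(100)).map(lambda x: x * 2)
    #   rdd.checkpoint()
    #   rdd.count()    # forces evaluation and materializes the checkpoint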

    def _getJavaStorageLevel(self, storageLevel):
        """
        Returns a Java StorageLevel based on a pyspark.StorageLevel.
        """
        if not isinstance(storageLevel, StorageLevel):
            raise Exception("storageLevel must be of type pyspark.StorageLevel")

        newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel
        return newStorageLevel(storageLevel.useDisk,
                               storageLevel.useMemory,
                               storageLevel.useOffHeap,
                               storageLevel.deserialized,
                               storageLevel.replication)

    def setJobGroup(self, groupId, description, interruptOnCancel=False):
        """
        Assigns a group ID to all the jobs started by this thread until the group ID is set to a
        different value or cleared.

        Often, a unit of execution in an application consists of multiple Spark actions or jobs.
        Application programmers can use this method to group all those jobs together and give a
        group description. Once set, the Spark web UI will associate such jobs with this group.

        The application can use L{SparkContext.cancelJobGroup} to cancel all
        running jobs in this group.

        >>> import thread, threading
        >>> from time import sleep
        >>> result = "Not Set"
        >>> lock = threading.Lock()
        >>> def map_func(x):
        ...     sleep(100)
        ...     raise Exception("Task should have been cancelled")
        >>> def start_job(x):
        ...     global result
        ...     try:
        ...         sc.setJobGroup("job_to_cancel", "some description")
        ...         result = sc.parallelize(range(x)).map(map_func).collect()
        ...     except Exception as e:
        ...         result = "Cancelled"
        ...     lock.release()
        >>> def stop_job():
        ...     sleep(5)
        ...     sc.cancelJobGroup("job_to_cancel")
        >>> suppress = lock.acquire()
        >>> suppress = thread.start_new_thread(start_job, (10,))
        >>> suppress = thread.start_new_thread(stop_job, tuple())
        >>> suppress = lock.acquire()
        >>> print result
        Cancelled

        If interruptOnCancel is set to true for the job group, then job cancellation will result
        in Thread.interrupt() being called on the job's executor threads. This is useful to help
        ensure that the tasks are actually stopped in a timely manner, but is off by default due
        to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead.
        """
        self._jsc.setJobGroup(groupId, description, interruptOnCancel)

    def setLocalProperty(self, key, value):
        """
        Set a local property that affects jobs submitted from this thread, such as the
        Spark fair scheduler pool.
        """
        self._jsc.setLocalProperty(key, value)
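
    # Illustrative usage (pool name is hypothetical): route jobs submitted from
    # this thread to a fair-scheduler pool, then clear the setting, e.g.
    #
    #   sc.setLocalProperty("spark.scheduler.pool", "production")
    #   ...  # run some jobs
    #   sc.setLocalProperty("spark.scheduler.pool", None)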

    def getLocalProperty(self, key):
        """
        Get a local property set in this thread, or None if it is missing. See
        L{setLocalProperty}.
        """
        return self._jsc.getLocalProperty(key)

    def sparkUser(self):
        """
        Get SPARK_USER for the user who is running SparkContext.
        """
        return self._jsc.sc().sparkUser()

    def cancelJobGroup(self, groupId):
        """
        Cancel active jobs for the specified group. See L{SparkContext.setJobGroup}
        for more information.
        """
        self._jsc.sc().cancelJobGroup(groupId)

    def cancelAllJobs(self):
        """
        Cancel all jobs that have been scheduled or are running.
        """
        self._jsc.sc().cancelAllJobs()


def _test():
    import atexit
    import doctest
    import tempfile
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['tempdir'] = tempfile.mkdtemp()
    atexit.register(lambda: shutil.rmtree(globs['tempdir']))
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()