from pyspark.rdd import RDD, PipelinedRDD
from pyspark.serializers import BatchedSerializer, PickleSerializer

from py4j.protocol import Py4JError

__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
    """Main entry point for SparkSQL functionality.

    A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
    tables, execute SQL over tables, cache tables, and read parquet files.
    """

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.
        @param sqlContext: An optional JVM SQLContext to reuse. If set, no new
        SQLContext is instantiated in the JVM and all calls go to this object.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ... "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ... x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

        if sqlContext:
            self._scala_SQLContext = sqlContext

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM SparkSQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """Infer and apply a schema to an RDD of L{dict}s.

        We peek at the first row of the RDD to determine the field names
        and types, and then use that to extract all the dictionaries. Nested
        collections are supported, which include array, dict, list, set, and
        tuple.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ...                    {"field1" : 3, "field2": "row3"}]
        True

        >>> from array import array
        >>> srdd = sqlCtx.inferSchema(nestedRdd1)
        >>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        ...                    {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
        True

        >>> srdd = sqlCtx.inferSchema(nestedRdd2)
        >>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        ...                    {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
        True
        """
        if (rdd.__class__ is SchemaRDD):
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of
        SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if (rdd.__class__ is SchemaRDD):
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonFile(self, path):
        """Loads a text file storing one JSON object per line,
        returning the result as a L{SchemaRDD}.

        It goes through the entire dataset once to determine the schema.

        >>> import tempfile, shutil
        >>> jsonFile = tempfile.mkdtemp()
        >>> shutil.rmtree(jsonFile)
        >>> ofn = open(jsonFile, 'w')
        >>> for json in jsonStrings:
        ...   print>>ofn, json
        >>> ofn.close()
        >>> srdd = sqlCtx.jsonFile(jsonFile)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql(
        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
        >>> srdd2.collect() == [
        ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
        ... {"f1":2, "f2":None, "f3":{"field4":22,  "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
        ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
        True
        """
        jschema_rdd = self._ssql_ctx.jsonFile(path)
        return SchemaRDD(jschema_rdd, self)

    def jsonRDD(self, rdd):
        """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}.

        It goes through the entire dataset once to determine the schema.

        >>> srdd = sqlCtx.jsonRDD(json)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql(
        ...   "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1")
        >>> srdd2.collect() == [
        ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None},
        ... {"f1":2, "f2":None, "f3":{"field4":22,  "field5": [10, 11]}, "f4":[{"field7": "row2"}]},
        ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}]
        True
        """
        def func(split, iterator):
            # Ensure every element reaches the JVM as a UTF-8 encoded string.
            for x in iterator:
                if not isinstance(x, basestring):
                    x = unicode(x)
                yield x.encode("utf-8")
        # Bypass the usual pickle serializer so the raw UTF-8 bytes are handed
        # to the JVM, decoded back into Strings, and passed to the Scala
        # jsonRDD implementation.
        keyed = PipelinedRDD(rdd, func)
        keyed._bypass_serializer = True
        jrdd = keyed._jrdd.map(self._jvm.BytesToString())
        jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """Return a L{SchemaRDD} representing the result of the given query.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ...                     {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)
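
    # Illustrative usage sketch (comments only, not executed): cacheTable and
    # uncacheTable manage the in-memory cache for a table that was previously
    # registered with registerRDDAsTable; the table name below is assumed.
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   sqlCtx.registerRDDAsTable(srdd, "people")
    #   sqlCtx.cacheTable("people")          # later queries read from the cache
    #   sqlCtx.sql("SELECT * FROM people")
    #   sqlCtx.uncacheTable("people")        # release the cached data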


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return self.hiveql(hqlQuery)
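
    # Illustrative usage sketch (comments only, not executed): hql/hiveql run
    # HiveQL text directly against tables known to the Hive metastore; the
    # "src" table below is assumed to exist, as in the LocalHiveContext
    # doctest.
    #
    #   hiveCtx = HiveContext(sc)
    #   rows = hiveCtx.hql("SELECT key, value FROM src").collect()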


class LocalHiveContext(HiveContext):
    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metadata database is created with data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


class Row(dict):
    """A row in L{SchemaRDD}.

    An extended L{dict} that takes a L{dict} in its constructor, and
    exposes those items as fields.

    >>> r = Row({"hello" : "world", "foo" : "bar"})
    >>> r.hello
    'world'
    >>> r.foo
    'bar'
    """

    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)


class SchemaRDD(RDD):
    """An RDD of L{Row} objects that has an associated schema.

    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
    utilize the relational query API exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
    L{SchemaRDD} is not operated on directly, as its underlying
    implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can
    be done.
    """

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer

    @property
    def _jrdd(self):
        """Lazy evaluation of PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        return self._jrdd.id()

    def saveAsParquetFile(self, path):
        """Save the contents as a Parquet file, preserving the schema.

        Files that are written out using this method can be read back in as
        a SchemaRDD using the L{SQLContext.parquetFile} method.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> sorted(srdd2.collect()) == sorted(srdd.collect())
        True
        """
        self._jschema_rdd.saveAsParquetFile(path)

    def registerAsTable(self, name):
        """Registers this RDD as a temporary table using the given name.

        The lifetime of this temporary table is tied to the L{SQLContext}
        that was used to create this SchemaRDD.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.registerAsTable("test")
        >>> srdd2 = sqlCtx.sql("select * from test")
        >>> sorted(srdd.collect()) == sorted(srdd2.collect())
        True
        """
        self._jschema_rdd.registerAsTable(name)

    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this SchemaRDD into the specified table,
        optionally overwriting any existing data.
        """
        self._jschema_rdd.insertInto(tableName, overwrite)
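
    # Illustrative usage sketch (comments only, not executed): insertInto
    # appends to (or, with overwrite=True, replaces the contents of) an
    # existing table; the "people" table below is assumed to already exist.
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   srdd.insertInto("people")                  # append rows
    #   srdd.insertInto("people", overwrite=True)  # replace existing rows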

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this SchemaRDD."""
        self._jschema_rdd.saveAsTable(tableName)

    def schemaString(self):
        """Returns the output schema in the tree format."""
        return self._jschema_rdd.schemaString()

    def printSchema(self):
        """Prints out the schema in the tree format."""
        print self.schemaString()
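
    # Illustrative usage sketch (comments only, not executed): printSchema
    # writes the tree-formatted text returned by schemaString to stdout,
    # i.e. a root line followed by one entry per column.
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   srdd.printSchema()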

    def count(self):
        """Return the number of elements in this RDD.

        Unlike the base RDD implementation of count, this implementation
        leverages the query optimizer to compute the count on the SchemaRDD,
        which supports features such as filter pushdown.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.count()
        3L
        >>> srdd.count() == srdd.map(lambda x: x).count()
        True
        """
        return self._jschema_rdd.count()

    def _toPython(self):
        # Convert the underlying JVM SchemaRDD of Rows into an RDD of pickled
        # Python dicts, then wrap each dict in a Row so that fields can be
        # accessed as attributes from Python.
        jrdd = self._jschema_rdd.javaToPython()
        return RDD(jrdd, self._sc, BatchedSerializer(
            PickleSerializer())).map(lambda d: Row(d))

    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def isCheckpointed(self):
        return self._jschema_rdd.isCheckpointed()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None

    def coalesce(self, numPartitions, shuffle=False):
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        if (other.__class__ is SchemaRDD):
            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only intersect with another SchemaRDD")

    def subtract(self, other, numPartitions=None):
        if (other.__class__ is SchemaRDD):
            if numPartitions is None:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
            else:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only subtract another SchemaRDD")
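
    # Illustrative usage sketch (comments only, not executed): the set-style
    # operations return new SchemaRDDs and only accept another SchemaRDD as
    # their argument.
    #
    #   srdd1 = sqlCtx.inferSchema(rdd)
    #   srdd2 = sqlCtx.sql("SELECT * FROM table1")
    #   both = srdd1.intersection(srdd2)
    #   only_in_1 = srdd1.subtract(srdd2)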


def _test():
    import doctest
    from array import array
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size ensures that the serializer produces multiple
    # batches, even for these small doctest datasets.
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
        {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
    jsonStrings = ['{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
       '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]}, "field6":[{"field7": "row2"}]}',
       '{"field1" : null, "field2": "row3", "field3":{"field4":33, "field5": []}}']
    globs['jsonStrings'] = jsonStrings
    globs['json'] = sc.parallelize(jsonStrings)
    globs['nestedRdd1'] = sc.parallelize([
        {"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
        {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
    globs['nestedRdd2'] = sc.parallelize([
        {"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
        {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()