from pyspark.rdd import RDD

from py4j.protocol import Py4JError

__all__ = [
    "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"]


class SQLContext:
    """Main entry point for SparkSQL functionality.

    A SQLContext can be used to create L{SchemaRDD}s, register L{SchemaRDD}s as
    tables, execute SQL over tables, cache tables, and read Parquet files.
    """

    def __init__(self, sparkContext, sqlContext=None):
        """Create a new SQLContext.

        @param sparkContext: The SparkContext to wrap.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> bad_rdd = sc.parallelize([1,2,3])
        >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
            ...
        ValueError:...

        >>> allTypes = sc.parallelize([{"int" : 1, "string" : "string", "double" : 1.0, "long": 1L,
        ... "boolean" : True}])
        >>> srdd = sqlCtx.inferSchema(allTypes).map(lambda x: (x.int, x.string, x.double, x.long,
        ... x.boolean))
        >>> srdd.collect()[0]
        (1, u'string', 1.0, 1, True)
        """
        self._sc = sparkContext
        self._jsc = self._sc._jsc
        self._jvm = self._sc._jvm
        self._pythonToJavaMap = self._jvm.PythonRDD.pythonToJavaMap

        if sqlContext:
            self._scala_SQLContext = sqlContext

    @property
    def _ssql_ctx(self):
        """Accessor for the JVM SparkSQL context.

        Subclasses can override this property to provide their own
        JVM Contexts.
        """
        if not hasattr(self, '_scala_SQLContext'):
            self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
        return self._scala_SQLContext

    def inferSchema(self, rdd):
        """Infer and apply a schema to an RDD of L{dict}s.

        We peek at the first row of the RDD to determine the field names
        and types, and then use that to extract all the dictionaries.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
        ...                    {"field1" : 3, "field2": "row3"}]
        True
        """
        if (rdd.__class__ is SchemaRDD):
            raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
        elif not isinstance(rdd.first(), dict):
            raise ValueError("Only RDDs with dictionaries can be converted to %s: %s" %
                             (SchemaRDD.__name__, rdd.first()))

        jrdd = self._pythonToJavaMap(rdd._jrdd)
        srdd = self._ssql_ctx.inferSchema(jrdd.rdd())
        return SchemaRDD(srdd, self)

    def registerRDDAsTable(self, rdd, tableName):
        """Registers the given RDD as a temporary table in the catalog.

        Temporary tables exist only during the lifetime of this instance of
        SQLContext.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        """
        if (rdd.__class__ is SchemaRDD):
            jschema_rdd = rdd._jschema_rdd
            self._ssql_ctx.registerRDDAsTable(jschema_rdd, tableName)
        else:
            raise ValueError("Can only register SchemaRDD as table")

    def parquetFile(self, path):
        """Loads a Parquet file, returning the result as a L{SchemaRDD}.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> srdd.collect() == srdd2.collect()
        True
        """
        jschema_rdd = self._ssql_ctx.parquetFile(path)
        return SchemaRDD(jschema_rdd, self)

    def sql(self, sqlQuery):
        """Return a L{SchemaRDD} representing the result of the given query.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
        >>> srdd2.collect() == [{"f1" : 1, "f2" : "row1"}, {"f1" : 2, "f2": "row2"},
        ...                     {"f1" : 3, "f2": "row3"}]
        True
        """
        return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self)

    def table(self, tableName):
        """Returns the specified table as a L{SchemaRDD}.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> sqlCtx.registerRDDAsTable(srdd, "table1")
        >>> srdd2 = sqlCtx.table("table1")
        >>> srdd.collect() == srdd2.collect()
        True
        """
        return SchemaRDD(self._ssql_ctx.table(tableName), self)

    def cacheTable(self, tableName):
        """Caches the specified table in-memory."""
        self._ssql_ctx.cacheTable(tableName)

    def uncacheTable(self, tableName):
        """Removes the specified table from the in-memory cache."""
        self._ssql_ctx.uncacheTable(tableName)
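
    # Illustrative usage of the caching helpers (a sketch, not executed as a
    # doctest; it assumes a table registered via registerRDDAsTable as above):
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   sqlCtx.registerRDDAsTable(srdd, "table1")
    #   sqlCtx.cacheTable("table1")      # later queries read the in-memory copy
    #   sqlCtx.sql("SELECT * FROM table1").count()
    #   sqlCtx.uncacheTable("table1")    # drop the cached data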


class HiveContext(SQLContext):
    """A variant of Spark SQL that integrates with data stored in Hive.

    Configuration for Hive is read from hive-site.xml on the classpath.
    It supports running both SQL and HiveQL commands.
    """

    @property
    def _ssql_ctx(self):
        try:
            if not hasattr(self, '_scala_HiveContext'):
                self._scala_HiveContext = self._get_hive_ctx()
            return self._scala_HiveContext
        except Py4JError as e:
            raise Exception("You must build Spark with Hive. Export 'SPARK_HIVE=true' and run "
                            "sbt/sbt assembly", e)

    def _get_hive_ctx(self):
        return self._jvm.HiveContext(self._jsc.sc())

    def hiveql(self, hqlQuery):
        """
        Runs a query expressed in HiveQL, returning the result as a L{SchemaRDD}.
        """
        return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery), self)

    def hql(self, hqlQuery):
        """
        Alias for L{hiveql}: runs a query expressed in HiveQL, returning the
        result as a L{SchemaRDD}.
        """
        return self.hiveql(hqlQuery)
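
    # Illustrative usage (a sketch, not executed; it assumes Spark was built with
    # Hive support and that a Hive table named "src" already exists):
    #
    #   hiveCtx = HiveContext(sc)
    #   rows = hiveCtx.hql("SELECT key, value FROM src LIMIT 10").collect()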


class LocalHiveContext(HiveContext):
    """Starts up an instance of Hive where metadata is stored locally.

    An in-process metadata database is created, with data stored in ./metadata.
    Warehouse data is stored in ./warehouse.

    >>> import os
    >>> hiveCtx = LocalHiveContext(sc)
    >>> try:
    ...     suppress = hiveCtx.hql("DROP TABLE src")
    ... except Exception:
    ...     pass
    >>> kv1 = os.path.join(os.environ["SPARK_HOME"], 'examples/src/main/resources/kv1.txt')
    >>> suppress = hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
    >>> suppress = hiveCtx.hql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" % kv1)
    >>> results = hiveCtx.hql("FROM src SELECT value").map(lambda r: int(r.value.split('_')[1]))
    >>> num = results.count()
    >>> reduce_sum = results.reduce(lambda x, y: x + y)
    >>> num
    500
    >>> reduce_sum
    130091
    """

    def _get_hive_ctx(self):
        return self._jvm.LocalHiveContext(self._jsc.sc())


class TestHiveContext(HiveContext):

    def _get_hive_ctx(self):
        return self._jvm.TestHiveContext(self._jsc.sc())


class Row(dict):
    """A row in L{SchemaRDD}.

    An extended L{dict} that takes a L{dict} in its constructor, and
    exposes those items as fields.

    >>> r = Row({"hello" : "world", "foo" : "bar"})
    >>> r.hello
    'world'
    >>> r.foo
    'bar'
    """

    def __init__(self, d):
        d.update(self.__dict__)
        self.__dict__ = d
        dict.__init__(self, d)


class SchemaRDD(RDD):
    """An RDD of L{Row} objects that has an associated schema.

    The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
    utilize the relational query API exposed by SparkSQL.

    For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
    L{SchemaRDD} is not operated on directly, as its underlying
    implementation is an RDD composed of Java objects. Instead it is
    converted to a PythonRDD in the JVM, on which Python operations can
    be done.
    """
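
    # For example (an illustrative sketch, not executed as a doctest): plain RDD
    # methods such as map() work on a SchemaRDD because it is first converted to
    # a PythonRDD of L{Row} objects:
    #
    #   srdd = sqlCtx.inferSchema(rdd)
    #   srdd.map(lambda row: row.field1).collect()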

    def __init__(self, jschema_rdd, sql_ctx):
        self.sql_ctx = sql_ctx
        self._sc = sql_ctx._sc
        self._jschema_rdd = jschema_rdd

        self.is_cached = False
        self.is_checkpointed = False
        self.ctx = self.sql_ctx._sc
        self._jrdd_deserializer = self.ctx.serializer

    @property
    def _jrdd(self):
        """Lazy evaluation of PythonRDD object.

        Only done when a user calls methods defined by the
        L{pyspark.rdd.RDD} super class (map, filter, etc.).
        """
        if not hasattr(self, '_lazy_jrdd'):
            self._lazy_jrdd = self._toPython()._jrdd
        return self._lazy_jrdd

    @property
    def _id(self):
        return self._jrdd.id()

    def saveAsParquetFile(self, path):
        """Save the contents as a Parquet file, preserving the schema.

        Files that are written out using this method can be read back in as
        a SchemaRDD using the L{SQLContext.parquetFile} method.

        >>> import tempfile, shutil
        >>> parquetFile = tempfile.mkdtemp()
        >>> shutil.rmtree(parquetFile)
        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.saveAsParquetFile(parquetFile)
        >>> srdd2 = sqlCtx.parquetFile(parquetFile)
        >>> srdd2.collect() == srdd.collect()
        True
        """
        self._jschema_rdd.saveAsParquetFile(path)

    def registerAsTable(self, name):
        """Registers this RDD as a temporary table using the given name.

        The lifetime of this temporary table is tied to the L{SQLContext}
        that was used to create this SchemaRDD.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.registerAsTable("test")
        >>> srdd2 = sqlCtx.sql("select * from test")
        >>> srdd.collect() == srdd2.collect()
        True
        """
        self._jschema_rdd.registerAsTable(name)

    def insertInto(self, tableName, overwrite=False):
        """Inserts the contents of this SchemaRDD into the specified table,
        optionally overwriting any existing data.
        """
        self._jschema_rdd.insertInto(tableName, overwrite)

    def saveAsTable(self, tableName):
        """Creates a new table with the contents of this SchemaRDD."""
        self._jschema_rdd.saveAsTable(tableName)
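
    # Illustrative usage of the two write paths above (a sketch, not executed;
    # the table names are hypothetical and require a working HiveContext):
    #
    #   srdd.insertInto("existing_table")
    #   srdd.insertInto("existing_table", overwrite=True)
    #   srdd.saveAsTable("new_table")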

    def count(self):
        """Return the number of elements in this RDD.

        Unlike the base RDD implementation of count, this implementation
        leverages the query optimizer to compute the count on the SchemaRDD,
        which supports features such as filter pushdown.

        >>> srdd = sqlCtx.inferSchema(rdd)
        >>> srdd.count()
        3L
        >>> srdd.count() == srdd.map(lambda x: x).count()
        True
        """
        return self._jschema_rdd.count()

    def _toPython(self):
        # Import Row explicitly so that the rows created below reference
        # pyspark.sql.Row rather than a __main__-local definition.
        from pyspark.sql import Row
        jrdd = self._jschema_rdd.javaToPython()
        # javaToPython() gives an RDD of dicts; wrap each dict in a Row so that
        # fields are accessible as attributes on the Python side.
        return RDD(jrdd, self._sc, self._sc.serializer).map(lambda d: Row(d))

    # Override the RDD caching/checkpointing methods so that they act on the
    # underlying JVM SchemaRDD rather than on the lazily created PythonRDD.
    def cache(self):
        self.is_cached = True
        self._jschema_rdd.cache()
        return self

    def persist(self, storageLevel):
        self.is_cached = True
        javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
        self._jschema_rdd.persist(javaStorageLevel)
        return self

    def unpersist(self):
        self.is_cached = False
        self._jschema_rdd.unpersist()
        return self
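
    # Illustrative persistence usage (a sketch, not executed; StorageLevel comes
    # from pyspark.storagelevel):
    #
    #   from pyspark.storagelevel import StorageLevel
    #   srdd.cache()
    #   srdd.persist(StorageLevel.MEMORY_AND_DISK)
    #   srdd.unpersist()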

    def checkpoint(self):
        self.is_checkpointed = True
        self._jschema_rdd.checkpoint()

    def getCheckpointFile(self):
        checkpointFile = self._jschema_rdd.getCheckpointFile()
        if checkpointFile.isDefined():
            return checkpointFile.get()
        else:
            return None

    def coalesce(self, numPartitions, shuffle=False):
        rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
        return SchemaRDD(rdd, self.sql_ctx)

    def intersection(self, other):
        if (other.__class__ is SchemaRDD):
            rdd = self._jschema_rdd.intersection(other._jschema_rdd)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only intersect with another SchemaRDD")

    def subtract(self, other, numPartitions=None):
        if (other.__class__ is SchemaRDD):
            if numPartitions is None:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd)
            else:
                rdd = self._jschema_rdd.subtract(other._jschema_rdd, numPartitions)
            return SchemaRDD(rdd, self.sql_ctx)
        else:
            raise ValueError("Can only subtract another SchemaRDD")
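
    # Illustrative usage of the RDD-style operators above (a sketch, not executed
    # as a doctest; srdd and srdd2 are assumed to be SchemaRDDs over the same
    # schema):
    #
    #   srdd.coalesce(1).count()
    #   srdd.intersection(srdd2).collect()
    #   srdd.subtract(srdd2).collect()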

def _test():
    import doctest
    from pyspark.context import SparkContext
    globs = globals().copy()
    # The small batch size ensures that the doctests exercise multiple batches
    # even with this tiny amount of test data.
    sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
    globs['sc'] = sc
    globs['sqlCtx'] = SQLContext(sc)
    globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
        {"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()