From 2a24beede560e15aad22ae9d76cff89a8d30eb57 Mon Sep 17 00:00:00 2001 From: Abdelaziz Bendadani Date: Wed, 18 Mar 2015 15:46:27 +0100 Subject: [PATCH 1/7] Use the same library when Reading and Writing CSV --- .../databricks/spark/csv/JavaCsvParser.java | 8 ++- .../com/databricks/spark/csv/CsvParser.scala | 8 ++- .../databricks/spark/csv/CsvRelation.scala | 3 ++ .../databricks/spark/csv/DefaultSource.scala | 9 +++- .../com/databricks/spark/csv/package.scala | 49 +++++++++++-------- 5 files changed, 54 insertions(+), 23 deletions(-) diff --git a/src/main/java/com/databricks/spark/csv/JavaCsvParser.java b/src/main/java/com/databricks/spark/csv/JavaCsvParser.java index e8e9d18..3ff2615 100755 --- a/src/main/java/com/databricks/spark/csv/JavaCsvParser.java +++ b/src/main/java/com/databricks/spark/csv/JavaCsvParser.java @@ -27,6 +27,7 @@ public class JavaCsvParser { private Boolean useHeader = true; private Character delimiter = ','; private Character quote = '"'; + private Character escape = '\\'; private StructType schema = null; public JavaCsvParser withUseHeader(Boolean flag) { @@ -44,6 +45,11 @@ public JavaCsvParser withQuoteChar(Character quote) { return this; } + public JavaCsvParser withEscapeChar(Character escape) { + this.escape = escape; + return this; + } + public JavaCsvParser withSchema(StructType schema) { this.schema = schema; return this; @@ -52,7 +58,7 @@ public JavaCsvParser withSchema(StructType schema) { /** Returns a Schema RDD for the given CSV path. */ public DataFrame csvFile(SQLContext sqlContext, String path) { CsvRelation relation = new - CsvRelation(path, useHeader, delimiter, quote, schema, sqlContext); + CsvRelation(path, useHeader, delimiter, quote, escape, schema, sqlContext); return sqlContext.baseRelationToDataFrame(relation); } } diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala index e75f3b3..beeeda6 100644 --- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala @@ -26,6 +26,7 @@ class CsvParser { private var useHeader: Boolean = true private var delimiter: Character = ',' private var quote: Character = '"' + private var escape: Character = '\\' private var schema: StructType = null def withUseHeader(flag: Boolean): CsvParser = { @@ -43,6 +44,11 @@ class CsvParser { this } + def withEscapeChar(escape: Character): CsvParser = { + this.escape = escape + this + } + def withSchema(schema: StructType): CsvParser = { this.schema = schema this @@ -50,7 +56,7 @@ class CsvParser { /** Returns a Schema RDD for the given CSV path. 
*/ def csvFile(sqlContext: SQLContext, path: String): DataFrame = { - val relation: CsvRelation = CsvRelation(path, useHeader, delimiter, quote, schema)(sqlContext) + val relation: CsvRelation = CsvRelation(path, useHeader, delimiter, quote, escape, schema)(sqlContext) sqlContext.baseRelationToDataFrame(relation) } diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index 7b324d4..5c14499 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -34,6 +34,7 @@ case class CsvRelation protected[spark] ( useHeader: Boolean, delimiter: Char, quote: Char, + escape: Char, userSchema: StructType = null)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan with InsertableRelation { @@ -53,6 +54,7 @@ case class CsvRelation protected[spark] ( val csvFormat = CSVFormat.DEFAULT .withDelimiter(delimiter) .withQuote(quote) + .withEscape(escape) .withSkipHeaderRecord(false) .withHeader(fieldNames: _*) @@ -78,6 +80,7 @@ case class CsvRelation protected[spark] ( val csvFormat = CSVFormat.DEFAULT .withDelimiter(delimiter) .withQuote(quote) + .withEscape(escape) .withSkipHeaderRecord(false) val firstRow = CSVParser.parse(firstLine, csvFormat).getRecords.head.toList val header = if (useHeader) { diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala index a0c1ea1..d9fac9a 100755 --- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala +++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala @@ -62,6 +62,13 @@ class DefaultSource throw new Exception("Quotation cannot be more than one character.") } + val escape = parameters.getOrElse("escape", "\\") + val escapeChar = if (escape.length == 1) { + escape.charAt(0) + } else { + throw new Exception("Escape cannot be more than one character.") + } + val useHeader = parameters.getOrElse("header", "true") val headerFlag = if (useHeader == "true") { true @@ -71,7 +78,7 @@ class DefaultSource throw new Exception("Header flag can be true or false") } - CsvRelation(path, headerFlag, delimiterChar, quoteChar, schema)(sqlContext) + CsvRelation(path, headerFlag, delimiterChar, quoteChar, escapeChar, schema)(sqlContext) } override def createRelation( diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala index 3692ec3..b0bb1f0 100755 --- a/src/main/scala/com/databricks/spark/csv/package.scala +++ b/src/main/scala/com/databricks/spark/csv/package.scala @@ -17,6 +17,13 @@ package com.databricks.spark import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.commons.csv.CSVFormat; +import org.apache.commons.csv.CSVPrinter; + +import java.io.StringWriter; + +import scala.collection.convert.WrapAsJava + package object csv { /** @@ -28,7 +35,8 @@ package object csv { location = filePath, useHeader = true, delimiter = ',', - quote = '"')(sqlContext) + quote = '"', + escape = '\\')(sqlContext) sqlContext.baseRelationToDataFrame(csvRelation) } @@ -37,7 +45,8 @@ package object csv { location = filePath, useHeader = true, delimiter = '\t', - quote = '"')(sqlContext) + quote = '"', + escape = '\\')(sqlContext) sqlContext.baseRelationToDataFrame(csvRelation) } } @@ -45,28 +54,28 @@ package object csv { implicit class CsvSchemaRDD(dataFrame: DataFrame) { def saveAsCsvFile(path: String, parameters: Map[String, String] = Map()): Unit = { 
// TODO(hossein): For nested types, we may want to perform special work - val delimiter = parameters.getOrElse("delimiter", ",") + val delimiter = parameters.getOrElse("delimiter", ",").charAt(0) + val quote = parameters.getOrElse("quote", "\"").charAt(0) + val escape = parameters.getOrElse("escape", "\\").charAt(0) val generateHeader = parameters.getOrElse("header", "false").toBoolean - val header = if (generateHeader) { - dataFrame.columns.map(c => s""""$c"""").mkString(delimiter) - } else { - "" // There is no need to generate header in this case - } - val strRDD = dataFrame.rdd.mapPartitions { iter => - new Iterator[String] { - var firstRow: Boolean = generateHeader + val header = dataFrame.columns.map(c => c).toArray - override def hasNext = iter.hasNext + var firstRow: Boolean = true + val csvFileFormat = CSVFormat.DEFAULT + .withDelimiter(delimiter) + .withQuote(quote) + .withEscape(escape) - override def next: String = { - if (firstRow) { - firstRow = false - header + "\n" + iter.next.mkString(delimiter) - } else { - iter.next.mkString(delimiter) - } - } + val strRDD = dataFrame.rdd.mapPartitions { iter => + var firstRow: Boolean = generateHeader + val newIter = iter.map(_.toSeq.toArray) + val stringWriter = new StringWriter() + val csvPrinter = new CSVPrinter(stringWriter, csvFileFormat) + if (firstRow) { + csvPrinter.printRecord(header) } + csvPrinter.printRecords(WrapAsJava.asJavaIterable(newIter.toIterable)) + Iterator(stringWriter.toString) } strRDD.saveAsTextFile(path) } From 411a0ae6bd9034539e540c356c33be35b9a59c7b Mon Sep 17 00:00:00 2001 From: Abdelaziz Bendadani Date: Wed, 18 Mar 2015 21:36:44 +0100 Subject: [PATCH 2/7] fixed issue header in not well printed --- src/main/scala/com/databricks/spark/csv/package.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala index b0bb1f0..e292e54 100755 --- a/src/main/scala/com/databricks/spark/csv/package.scala +++ b/src/main/scala/com/databricks/spark/csv/package.scala @@ -58,9 +58,9 @@ package object csv { val quote = parameters.getOrElse("quote", "\"").charAt(0) val escape = parameters.getOrElse("escape", "\\").charAt(0) val generateHeader = parameters.getOrElse("header", "false").toBoolean - val header = dataFrame.columns.map(c => c).toArray + val header = dataFrame.columns - var firstRow: Boolean = true + var firstRow: Boolean = generateHeader val csvFileFormat = CSVFormat.DEFAULT .withDelimiter(delimiter) .withQuote(quote) @@ -72,7 +72,8 @@ package object csv { val stringWriter = new StringWriter() val csvPrinter = new CSVPrinter(stringWriter, csvFileFormat) if (firstRow) { - csvPrinter.printRecord(header) + firstRow = false + csvPrinter.printRecord(header:_*) } csvPrinter.printRecords(WrapAsJava.asJavaIterable(newIter.toIterable)) Iterator(stringWriter.toString) From 54782ac62703cadffda91a28ec15bfb2a92c40e7 Mon Sep 17 00:00:00 2001 From: Abdelaziz Bendadani Date: Thu, 19 Mar 2015 14:23:55 +0100 Subject: [PATCH 3/7] changed the type of quote and escape from char to Character so nullable --- src/main/scala/com/databricks/spark/csv/CsvParser.scala | 2 +- src/main/scala/com/databricks/spark/csv/CsvRelation.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala index beeeda6..31f121f 100644 --- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala +++ 
b/src/main/scala/com/databricks/spark/csv/CsvParser.scala @@ -26,7 +26,7 @@ class CsvParser { private var useHeader: Boolean = true private var delimiter: Character = ',' private var quote: Character = '"' - private var escape: Character = '\\' + private var escape: Character = null private var schema: StructType = null def withUseHeader(flag: Boolean): CsvParser = { diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index 5c14499..3cfb922 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -33,8 +33,8 @@ case class CsvRelation protected[spark] ( location: String, useHeader: Boolean, delimiter: Char, - quote: Char, - escape: Char, + quote: Character, + escape: Character, userSchema: StructType = null)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan with InsertableRelation { From 508dbbebd7712b5d22f8e57b69e367ac66fd03a6 Mon Sep 17 00:00:00 2001 From: Abdelaziz Bendadani Date: Thu, 19 Mar 2015 14:24:35 +0100 Subject: [PATCH 4/7] added tests for escape --- src/test/resources/family-cars.csv | 4 +++ .../com/databricks/spark/csv/CsvSuite.scala | 33 +++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 src/test/resources/family-cars.csv diff --git a/src/test/resources/family-cars.csv b/src/test/resources/family-cars.csv new file mode 100644 index 0000000..1819bea --- /dev/null +++ b/src/test/resources/family-cars.csv @@ -0,0 +1,4 @@ +year,make,model,comment +2012,VW,Touran,"The ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\"" +2013,Seat,Alhambra,"It is a great \"family\" car, for big families" +2014,Peugeot,5008,"It is a fine \"family\" car" diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index ace3a44..4c2756c 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -27,8 +27,10 @@ import TestSQLContext._ class CsvSuite extends FunSuite { val carsFile = "src/test/resources/cars.csv" val carsAltFile = "src/test/resources/cars-alternative.csv" + val familyCarsFile = "src/test/resources/family-cars.csv" val emptyFile = "src/test/resources/empty.csv" val tempEmptyDir = "target/test/empty/" + val tempFamilyCarsDir = "target/test/family-cars" test("DSL test") { val results = TestSQLContext @@ -61,6 +63,37 @@ class CsvSuite extends FunSuite { assert(results.size === 2) } + test("DSL test read write with escape") { + //Parse a csv file with \ as escape character + val results = new CsvParser() + .withEscapeChar('\\') + .csvFile(TestSQLContext, familyCarsFile) + //Check that the file was as expected parse + val firstComment1 = results + .select("comment") + .collect() + .head + .getString(0) + assert(firstComment1 === "The ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\"") + + TestUtils.deleteRecursively(new File(tempFamilyCarsDir)) + //Save the dataFrame without providing an escape character (default is ") + results.saveAsCsvFile(tempFamilyCarsDir, Map("header" -> "true")) + //Check that the generated file is well formed + val rawData = TestSQLContext.sparkContext.textFile(tempFamilyCarsDir).toArray + assert(rawData.contains("2012,VW,Touran,\"The ideal car for \"\"families\"\" and all their \"\"bags\"\", \"\"boxes\"\" and \"\"barbecues\"\"\"")) + + //Check that the generated file is 
well parsed + val results2 = new CsvParser() + .csvFile(TestSQLContext, tempFamilyCarsDir) + val firstComment2 = results2 + .select("comment") + .collect() + .head + .getString(0) + assert(firstComment2 === "The ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\"") + } + test("DDL test with alternative delimiter and quote") { sql( s""" From 2909010e956b41e6a4343cfc0c0f5e0fcb2b6688 Mon Sep 17 00:00:00 2001 From: Abdelaziz Bendadani Date: Fri, 3 Apr 2015 14:14:41 +0200 Subject: [PATCH 5/7] merged with original --- .../databricks/spark/csv/JavaCsvParser.java | 64 ------------------- 1 file changed, 64 deletions(-) delete mode 100755 src/main/java/com/databricks/spark/csv/JavaCsvParser.java diff --git a/src/main/java/com/databricks/spark/csv/JavaCsvParser.java b/src/main/java/com/databricks/spark/csv/JavaCsvParser.java deleted file mode 100755 index 3ff2615..0000000 --- a/src/main/java/com/databricks/spark/csv/JavaCsvParser.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 2014 Databricks - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.databricks.spark.csv; - -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.types.StructType; - -/** - * A collection of static functions for working with CSV files in Spark SQL - */ -public class JavaCsvParser { - - private Boolean useHeader = true; - private Character delimiter = ','; - private Character quote = '"'; - private Character escape = '\\'; - private StructType schema = null; - - public JavaCsvParser withUseHeader(Boolean flag) { - this.useHeader = flag; - return this; - } - - public JavaCsvParser withDelimiter(Character delimiter) { - this.delimiter = delimiter; - return this; - } - - public JavaCsvParser withQuoteChar(Character quote) { - this.quote = quote; - return this; - } - - public JavaCsvParser withEscapeChar(Character escape) { - this.escape = escape; - return this; - } - - public JavaCsvParser withSchema(StructType schema) { - this.schema = schema; - return this; - } - - /** Returns a Schema RDD for the given CSV path. 
*/ - public DataFrame csvFile(SQLContext sqlContext, String path) { - CsvRelation relation = new - CsvRelation(path, useHeader, delimiter, quote, escape, schema, sqlContext); - return sqlContext.baseRelationToDataFrame(relation); - } -} From b26655731136adca26dfdb589b669c8719d21fd9 Mon Sep 17 00:00:00 2001 From: Abdelaziz Bendadani Date: Fri, 3 Apr 2015 14:17:25 +0200 Subject: [PATCH 6/7] merged with original --- .gitignore | 1 + README.md | 98 +++++++--- build.sbt | 11 +- .../com/databricks/spark/csv/CsvParser.scala | 30 ++- .../databricks/spark/csv/CsvRelation.scala | 37 +++- .../databricks/spark/csv/DefaultSource.scala | 14 +- .../com/databricks/spark/csv/package.scala | 121 ++++++++---- .../spark/csv/util/ParseModes.scala | 40 ++++ .../databricks/spark/csv/JavaCsvSuite.java | 63 ++++++ src/test/resources/cars-alternative.csv | 1 + src/test/resources/cars.csv | 3 +- src/test/resources/escape.csv | 2 + src/test/resources/family-cars.csv | 4 - .../com/databricks/spark/csv/CsvSuite.scala | 179 ++++++++++++++---- .../com/databricks/spark/csv/TestUtils.scala | 15 ++ 15 files changed, 491 insertions(+), 128 deletions(-) create mode 100644 src/main/scala/com/databricks/spark/csv/util/ParseModes.scala create mode 100644 src/test/java/com/databricks/spark/csv/JavaCsvSuite.java create mode 100644 src/test/resources/escape.csv delete mode 100644 src/test/resources/family-cars.csv diff --git a/.gitignore b/.gitignore index 706ddde..a1c2435 100755 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ sbt/sbt-launch*.jar target/ .idea/ .idea_modules/ +.DS_Store diff --git a/README.md b/README.md index 189036f..6b052bb 100755 --- a/README.md +++ b/README.md @@ -13,36 +13,37 @@ You can link against this library in your program at the following coordiates: ``` groupId: com.databricks -artifactId: spark-csv_2.10 -version: 1.0.0 +artifactId: spark-csv_2.11 +version: 1.0.3 ``` -The spark-csv assembly jar file can also be added to a Spark using the `--jars` command line option. For example, to include it when starting the spark shell: + +## Using with Apache Spark +This package can be added to Spark using the `--jars` command line option. For example, to include it when starting the spark shell: ``` -$ bin/spark-shell --jars spark-csv-assembly-1.0.0.jar +$ bin/spark-shell --packages com.databricks:spark-csv_2.10:1.0.3 ``` ## Features +This package allows reading CSV files in local or distributed filesystem as [Spark DataFrames](https://spark.apache.org/docs/1.3.0/sql-programming-guide.html). +When reading files the API accepts several options: +* path: location of files. Similar to Spark can accept standard Hadoop globbing expressions. +* header: when set to true the first line of files will be used to name columns and will not be included in data. All types will be assumed string. Default value is false. +* delimiter: by default lines are delimited using ',', but delimiter can be set to any character +* quote: by default the quote character is '"', but can be set to any character. Delimiters inside quotes are ignored +* mode: determines the parsing mode. By default it is PERMISSIVE. Possible values are: + * PERMISSIVE: tries to parse all lines: nulls are inserted for missing tokens and extra tokens are ignored. + * DROPMALFORMED: drops lines which have fewer or more tokens than expected + * FAILFAST: aborts with a RuntimeException if encounters any malformed line + +The package also support saving simple (non-nested) DataFrame. 
When saving you can specify the delimiter and whether we should generate a header row for the table. See following examples for more details. + These examples use a CSV file available for download [here](https://github.com/databricks/spark-csv/raw/master/src/test/resources/cars.csv): ``` $ wget https://github.com/databricks/spark-csv/raw/master/src/test/resources/cars.csv ``` -### Scala API - -You can use the library by loading the implicits from `com.databricks.spark.csv._`. - -``` -import org.apache.spark.sql.SQLContext - -val sqlContext = new SQLContext(sc) - -import com.databricks.spark.csv._ - -val cars = sqlContext.csvFile("cars.csv") -``` - ### SQL API CSV data can be queried in pure SQL by registering the data as a (temporary) table. @@ -59,18 +60,65 @@ USING com.databricks.spark.csv OPTIONS (path "cars.csv", header "true") ``` +### Scala API +The recommended way to load CSV data is using the load/save functions in SQLContext. + +```scala +import org.apache.spark.sql.SQLContext + +val sqlContext = new SQLContext(sc) +val df = sqlContext.load("com.databricks.spark.csv", Map("path" -> "cars.csv", "header" -> "true")) +df.select("year", "model").save("newcars.csv", "com.databricks.spark.csv") +``` + +You can also use the implicits from `com.databricks.spark.csv._`. + +```scala +import org.apache.spark.sql.SQLContext +import com.databricks.spark.csv._ + +val sqlContext = new SQLContext(sc) + +val cars = sqlContext.csvFile("cars.csv") +cars.select("year", "model").saveAsCsvFile("newcars.tsv") +``` + ### Java API -CSV files can be read using functions in JavaCsvParser. +Similar to Scala, we recommend load/save functions in SQLContext. ```java -import com.databricks.spark.csv.JavaCsvParser; +import org.apache.spark.sql.SQLContext -DataFrame cars = (new JavaCsvParser()).withUseHeader(true).csvFile(sqlContext, "cars.csv"); +SQLContext sqlContext = new SQLContext(sc); + +HashMap options = new HashMap(); +options.put("header", "true"); +options.put("path", "cars.csv"); + +DataFrame df = sqlContext.load("com.databricks.spark.csv", options); +df.select("year", "model").save("newcars.csv", "com.databricks.spark.csv"); +``` +See documentations of load and save for more details. + +In Java (as well as Scala) CSV files can be read using functions in CsvParser. + +```java +import com.databricks.spark.csv.CsvParser; +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); + +DataFrame cars = (new CsvParser()).withUseHeader(true).csvFile(sqlContext, "cars.csv"); ``` -### Saving as CSV -You can save your DataFrame using `saveAsCsvFile` function. The function allows you to specify the delimiter and whether we should generate a header row for the table (each header has name `C$i` where `$i` is column index). For example: -```myDataFrame.saveAsCsvFile("/mydir", Map("delimiter" -> "|", "header" -> "true"))``` +### Python API +In Python you can read and save CSV files using load/save functions. + +```python +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) + +df = sqlContext.load(source="com.databricks.spark.csv", header="true", path = "cars.csv") +df.select("year", "model").save("newcars.csv", "com.databricks.spark.csv") +``` ## Building From Source -This library is built with [SBT](http://www.scala-sbt.org/0.13/docs/Command-Line-Reference.html), which is automatically downloaded by the included shell script. To build a JAR file simply run `sbt/sbt assembly` from the project root. 
+This library is built with [SBT](http://www.scala-sbt.org/0.13/docs/Command-Line-Reference.html), which is automatically downloaded by the included shell script. To build a JAR file simply run `sbt/sbt package` from the project root. The build configuration includes support for both Scala 2.10 and 2.11. diff --git a/build.sbt b/build.sbt index 6cb34a0..1f511bb 100755 --- a/build.sbt +++ b/build.sbt @@ -1,10 +1,12 @@ name := "spark-csv" -version := "1.0.0" +version := "1.0.3" organization := "com.databricks" -scalaVersion := "2.10.4" +scalaVersion := "2.11.6" + +crossScalaVersions := Seq("2.10.4", "2.11.6") libraryDependencies += "org.apache.spark" %% "spark-sql" % "1.3.0" % "provided" @@ -55,7 +57,6 @@ sparkVersion := "1.3.0" sparkComponents += "sql" -// Enable Junit testing. -// libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test" - libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.1" % "test" + +libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test" diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala index 31f121f..0699484 100644 --- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala @@ -18,16 +18,19 @@ package com.databricks.spark.csv import org.apache.spark.sql.{SQLContext, DataFrame} import org.apache.spark.sql.types.StructType +import com.databricks.spark.csv.util.ParseModes + /** * A collection of static functions for working with CSV files in Spark SQL */ class CsvParser { - private var useHeader: Boolean = true + private var useHeader: Boolean = false private var delimiter: Character = ',' private var quote: Character = '"' - private var escape: Character = null + private var escape: Character = '\\' private var schema: StructType = null + private var parseMode: String = ParseModes.DEFAULT def withUseHeader(flag: Boolean): CsvParser = { this.useHeader = flag @@ -44,19 +47,32 @@ class CsvParser { this } - def withEscapeChar(escape: Character): CsvParser = { - this.escape = escape + def withSchema(schema: StructType): CsvParser = { + this.schema = schema this } - def withSchema(schema: StructType): CsvParser = { - this.schema = schema + def withParseMode(mode: String): CsvParser = { + this.parseMode = mode + this + } + + def withEscape(escapeChar: Character): CsvParser = { + this.escape = escapeChar this } /** Returns a Schema RDD for the given CSV path. 
*/ + @throws[RuntimeException] def csvFile(sqlContext: SQLContext, path: String): DataFrame = { - val relation: CsvRelation = CsvRelation(path, useHeader, delimiter, quote, escape, schema)(sqlContext) + val relation: CsvRelation = CsvRelation( + path, + useHeader, + delimiter, + quote, + escape, + parseMode, + schema)(sqlContext) sqlContext.baseRelationToDataFrame(relation) } diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index 3cfb922..2f3a43f 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -28,18 +28,28 @@ import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan import org.apache.spark.sql.types.{StructType, StructField, StringType} import org.slf4j.LoggerFactory +import com.databricks.spark.csv.util.ParseModes case class CsvRelation protected[spark] ( location: String, useHeader: Boolean, delimiter: Char, - quote: Character, - escape: Character, + quote: Char, + escape: Char, + parseMode: String, userSchema: StructType = null)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan with InsertableRelation { private val logger = LoggerFactory.getLogger(CsvRelation.getClass) + // Parse mode flags + if (!ParseModes.isValidMode(parseMode)) { + logger.warn(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") + } + private val failFast = ParseModes.isFailFastMode(parseMode) + private val dropMalformed = ParseModes.isDropMalformedMode(parseMode) + private val permissive = ParseModes.isPermissiveMode(parseMode) + val schema = inferSchema() // By making this a lazy val we keep the RDD around, amortizing the cost of locating splits. @@ -118,6 +128,7 @@ case class CsvRelation protected[spark] ( projection: MutableProjection, row: GenericMutableRow): Iterator[Row] = { iter.flatMap { line => + var index: Int = 0 try { val records = CSVParser.parse(line, csvFormat).getRecords if (records.isEmpty) { @@ -125,15 +136,25 @@ case class CsvRelation protected[spark] ( None } else { val tokens = records.head - var index = 0 - while (index < schemaFields.length) { - row(index) = tokens.get(index) - index = index + 1 + index = 0 + if (dropMalformed && schemaFields.length != tokens.size) { + logger.warn(s"Dropping malformed line: $line") + None + } else if (failFast && schemaFields.length != tokens.size) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: $line") + } else { + while (index < schemaFields.length) { + row(index) = tokens.get(index) + index = index + 1 + } + Some(projection(row)) } - Some(projection(row)) } } catch { - case NonFatal(e) => + case aiob: ArrayIndexOutOfBoundsException if permissive => + (index until schemaFields.length).foreach(ind => row(ind) = null) + Some(projection(row)) + case NonFatal(e) if !failFast => logger.error(s"Exception while parsing line: $line. 
", e) None } diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala index d9fac9a..0b18d5a 100755 --- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala +++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala @@ -66,10 +66,12 @@ class DefaultSource val escapeChar = if (escape.length == 1) { escape.charAt(0) } else { - throw new Exception("Escape cannot be more than one character.") + throw new Exception("Escape character cannot be more than one character.") } - val useHeader = parameters.getOrElse("header", "true") + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + + val useHeader = parameters.getOrElse("header", "false") val headerFlag = if (useHeader == "true") { true } else if (useHeader == "false") { @@ -78,7 +80,13 @@ class DefaultSource throw new Exception("Header flag can be true or false") } - CsvRelation(path, headerFlag, delimiterChar, quoteChar, escapeChar, schema)(sqlContext) + CsvRelation(path, + headerFlag, + delimiterChar, + quoteChar, + escapeChar, + parseMode, + schema)(sqlContext) } override def createRelation( diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala index e292e54..d5cec13 100755 --- a/src/main/scala/com/databricks/spark/csv/package.scala +++ b/src/main/scala/com/databricks/spark/csv/package.scala @@ -15,14 +15,10 @@ */ package com.databricks.spark -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.commons.csv.CSVFormat +import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.commons.csv.CSVFormat; -import org.apache.commons.csv.CSVPrinter; - -import java.io.StringWriter; - -import scala.collection.convert.WrapAsJava +import org.apache.spark.sql.{SQLContext, DataFrame, Row} package object csv { @@ -30,55 +26,106 @@ package object csv { * Adds a method, `csvFile`, to SQLContext that allows reading CSV data. */ implicit class CsvContext(sqlContext: SQLContext) { - def csvFile(filePath: String) = { + def csvFile(filePath: String, + useHeader: Boolean = true, + delimiter: Char = ',', + quote: Char = '"', + escape: Char = '\\', + mode: String = "PERMISSIVE") = { val csvRelation = CsvRelation( location = filePath, - useHeader = true, - delimiter = ',', - quote = '"', - escape = '\\')(sqlContext) + useHeader = useHeader, + delimiter = delimiter, + quote = quote, + escape = escape, + parseMode = mode)(sqlContext) sqlContext.baseRelationToDataFrame(csvRelation) } - def tsvFile(filePath: String) = { + def tsvFile(filePath: String, useHeader: Boolean = true) = { val csvRelation = CsvRelation( location = filePath, - useHeader = true, + useHeader = useHeader, delimiter = '\t', quote = '"', - escape = '\\')(sqlContext) + escape = '\\', + parseMode = "PERMISSIVE")(sqlContext) sqlContext.baseRelationToDataFrame(csvRelation) } } - + implicit class CsvSchemaRDD(dataFrame: DataFrame) { - def saveAsCsvFile(path: String, parameters: Map[String, String] = Map()): Unit = { + + /** + * Saves DataFrame as csv files. By default uses ',' as delimiter, and includes header line. 
+ */ + def saveAsCsvFile(path: String, parameters: Map[String, String] = Map(), + compressionCodec: Class[_ <: CompressionCodec] = null): Unit = { // TODO(hossein): For nested types, we may want to perform special work - val delimiter = parameters.getOrElse("delimiter", ",").charAt(0) - val quote = parameters.getOrElse("quote", "\"").charAt(0) - val escape = parameters.getOrElse("escape", "\\").charAt(0) - val generateHeader = parameters.getOrElse("header", "false").toBoolean - val header = dataFrame.columns + val delimiter = parameters.getOrElse("delimiter", ",") + val delimiterChar = if (delimiter.length == 1) { + delimiter.charAt(0) + } else { + throw new Exception("Delimiter cannot be more than one character.") + } + + val escape = parameters.getOrElse("escape", "\\") + val escapeChar = if (escape.length == 1) { + escape.charAt(0) + } else { + throw new Exception("Escape character cannot be more than one character.") + } + + val quoteChar = parameters.get("quote") match { + case Some(s) => { + if (s.length == 1) { + Some(s.charAt(0)) + } else { + throw new Exception("Quotation cannot be more than one character.") + } + } + case None => None + } - var firstRow: Boolean = generateHeader - val csvFileFormat = CSVFormat.DEFAULT - .withDelimiter(delimiter) - .withQuote(quote) - .withEscape(escape) + val csvFormatBase = CSVFormat.DEFAULT + .withDelimiter(delimiterChar) + .withEscape(escapeChar) + .withSkipHeaderRecord(false) + .withNullString("null") + val csvFormat = quoteChar match { + case Some(c) => csvFormatBase.withQuote(c) + case _ => csvFormatBase + } + + val generateHeader = parameters.getOrElse("header", "false").toBoolean + val header = if (generateHeader) { + csvFormat.format(dataFrame.columns.map(_.asInstanceOf[AnyRef]):_*) + } else { + "" // There is no need to generate header in this case + } val strRDD = dataFrame.rdd.mapPartitions { iter => - var firstRow: Boolean = generateHeader - val newIter = iter.map(_.toSeq.toArray) - val stringWriter = new StringWriter() - val csvPrinter = new CSVPrinter(stringWriter, csvFileFormat) - if (firstRow) { - firstRow = false - csvPrinter.printRecord(header:_*) + + new Iterator[String] { + var firstRow: Boolean = generateHeader + + override def hasNext = iter.hasNext + + override def next: String = { + val row = csvFormat.format(iter.next.toSeq.map(_.asInstanceOf[AnyRef]):_*) + if (firstRow) { + firstRow = false + header + "\n" + row + } else { + row + } + } } - csvPrinter.printRecords(WrapAsJava.asJavaIterable(newIter.toIterable)) - Iterator(stringWriter.toString) } - strRDD.saveAsTextFile(path) + compressionCodec match { + case null => strRDD.saveAsTextFile(path) + case codec => strRDD.saveAsTextFile(path, codec) + } } } } diff --git a/src/main/scala/com/databricks/spark/csv/util/ParseModes.scala b/src/main/scala/com/databricks/spark/csv/util/ParseModes.scala new file mode 100644 index 0000000..babad29 --- /dev/null +++ b/src/main/scala/com/databricks/spark/csv/util/ParseModes.scala @@ -0,0 +1,40 @@ +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.databricks.spark.csv.util + +private[csv] object ParseModes { + + val PERMISSIVE_MODE = "PERMISSIVE" + val DROP_MALFORMED_MODE = "DROPMALFORMED" + val FAIL_FAST_MODE = "FAILFAST" + + val DEFAULT = PERMISSIVE_MODE + + def isValidMode(mode: String): Boolean = { + mode.toUpperCase match { + case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true + case _ => false + } + } + + def isDropMalformedMode(mode: String) = mode.toUpperCase == DROP_MALFORMED_MODE + def isFailFastMode(mode: String) = mode.toUpperCase == FAIL_FAST_MODE + def isPermissiveMode(mode: String) = if (isValidMode(mode)) { + mode.toUpperCase == PERMISSIVE_MODE + } else { + true // We default to permissive is the mode string is not valid + } +} diff --git a/src/test/java/com/databricks/spark/csv/JavaCsvSuite.java b/src/test/java/com/databricks/spark/csv/JavaCsvSuite.java new file mode 100644 index 0000000..1e26dbf --- /dev/null +++ b/src/test/java/com/databricks/spark/csv/JavaCsvSuite.java @@ -0,0 +1,63 @@ +package com.databricks.spark.csv; + +import java.io.File; +import java.util.HashMap; +import java.util.Random; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext$; + +public class JavaCsvSuite { + private transient SQLContext sqlContext; + private int numCars = 3; + + String carsFile = "src/test/resources/cars.csv"; + + private String tempDir = "target/test/csvData/"; + + @Before + public void setUp() { + // Trigger static initializer of TestData + sqlContext = TestSQLContext$.MODULE$; + } + + @After + public void tearDown() { + sqlContext = null; + } + + @Test + public void testCsvParser() { + DataFrame df = (new CsvParser()).withUseHeader(true).csvFile(sqlContext, carsFile); + int result = df.select("model").collect().length; + Assert.assertEquals(result, numCars); + } + + @Test + public void testLoad() { + HashMap options = new HashMap(); + options.put("header", "true"); + options.put("path", carsFile); + + DataFrame df = sqlContext.load("com.databricks.spark.csv", options); + int result = df.select("year").collect().length; + Assert.assertEquals(result, numCars); + } + + @Test + public void testSave() { + DataFrame df = (new CsvParser()).withUseHeader(true).csvFile(sqlContext, carsFile); + TestUtils.deleteRecursively(new File(tempDir)); + df.select("year", "model").save(tempDir, "com.databricks.spark.csv"); + + DataFrame newDf = (new CsvParser()).csvFile(sqlContext, tempDir); + int result = newDf.select("C1").collect().length; + Assert.assertEquals(result, numCars); + + } +} diff --git a/src/test/resources/cars-alternative.csv b/src/test/resources/cars-alternative.csv index b7f83c8..2c1285a 100644 --- a/src/test/resources/cars-alternative.csv +++ b/src/test/resources/cars-alternative.csv @@ -2,3 +2,4 @@ year|make|model|comment '2012'|'Tesla'|'S'| 'No comment' 1997|Ford|E350|'Go get one now they are going fast' +2015|Chevy|Volt diff --git a/src/test/resources/cars.csv b/src/test/resources/cars.csv index 86512c1..24d5e11 100644 --- a/src/test/resources/cars.csv +++ b/src/test/resources/cars.csv @@ -1,4 +1,5 @@ year,make,model,comment,blank -"2012","Tesla","S", "No comment", +"2012","Tesla","S","No comment", 1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt \ No newline at end of file diff --git a/src/test/resources/escape.csv 
b/src/test/resources/escape.csv new file mode 100644 index 0000000..d9ff81a --- /dev/null +++ b/src/test/resources/escape.csv @@ -0,0 +1,2 @@ +"column" +|"thing \ No newline at end of file diff --git a/src/test/resources/family-cars.csv b/src/test/resources/family-cars.csv deleted file mode 100644 index 1819bea..0000000 --- a/src/test/resources/family-cars.csv +++ /dev/null @@ -1,4 +0,0 @@ -year,make,model,comment -2012,VW,Touran,"The ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\"" -2013,Seat,Alhambra,"It is a great \"family\" car, for big families" -2014,Peugeot,5008,"It is a fine \"family\" car" diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index 4c2756c..88db433 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -17,7 +17,9 @@ package com.databricks.spark.csv import java.io.File +import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.sql.test._ +import org.apache.spark.SparkException import org.apache.spark.sql.types._ import org.scalatest.FunSuite @@ -27,10 +29,11 @@ import TestSQLContext._ class CsvSuite extends FunSuite { val carsFile = "src/test/resources/cars.csv" val carsAltFile = "src/test/resources/cars-alternative.csv" - val familyCarsFile = "src/test/resources/family-cars.csv" val emptyFile = "src/test/resources/empty.csv" + val escapeFile = "src/test/resources/escape.csv" val tempEmptyDir = "target/test/empty/" - val tempFamilyCarsDir = "target/test/family-cars" + + val numCars = 3 test("DSL test") { val results = TestSQLContext @@ -38,7 +41,7 @@ class CsvSuite extends FunSuite { .select("year") .collect() - assert(results.size === 2) + assert(results.size === numCars) } test("DDL test") { @@ -49,49 +52,72 @@ class CsvSuite extends FunSuite { |OPTIONS (path "$carsFile", header "true") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT year FROM carsTable").collect().size === 2) + assert(sql("SELECT year FROM carsTable").collect().size === numCars) + } + + test("DSL test for DROPMALFORMED parsing mode") { + val results = new CsvParser() + .withParseMode("DROPMALFORMED") + .withUseHeader(true) + .csvFile(TestSQLContext, carsFile) + .select("year") + .collect() + + assert(results.size === numCars - 1) } + test("DSL test for FAILFAST parsing mode") { + val parser = new CsvParser() + .withParseMode("FAILFAST") + .withUseHeader(true) + + val exception = intercept[SparkException]{ + parser.csvFile(TestSQLContext, carsFile) + .select("year") + .collect() + } + + assert(exception.getMessage.contains("Malformed line in FAILFAST mode")) + } + + test("DSL test with alternative delimiter and quote") { val results = new CsvParser() .withDelimiter('|') .withQuoteChar('\'') + .withUseHeader(true) .csvFile(TestSQLContext, carsAltFile) .select("year") .collect() - assert(results.size === 2) + assert(results.size === numCars) } - test("DSL test read write with escape") { - //Parse a csv file with \ as escape character - val results = new CsvParser() - .withEscapeChar('\\') - .csvFile(TestSQLContext, familyCarsFile) - //Check that the file was as expected parse - val firstComment1 = results - .select("comment") + test("DSL test with alternative delimiter and quote using sparkContext.csvFile") { + val results = + TestSQLContext.csvFile(carsAltFile, useHeader = true, delimiter = '|', quote = '\'') + .select("year") .collect() - .head - .getString(0) - assert(firstComment1 === "The 
ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\"") - - TestUtils.deleteRecursively(new File(tempFamilyCarsDir)) - //Save the dataFrame without providing an escape character (default is ") - results.saveAsCsvFile(tempFamilyCarsDir, Map("header" -> "true")) - //Check that the generated file is well formed - val rawData = TestSQLContext.sparkContext.textFile(tempFamilyCarsDir).toArray - assert(rawData.contains("2012,VW,Touran,\"The ideal car for \"\"families\"\" and all their \"\"bags\"\", \"\"boxes\"\" and \"\"barbecues\"\"\"")) - - //Check that the generated file is well parsed - val results2 = new CsvParser() - .csvFile(TestSQLContext, tempFamilyCarsDir) - val firstComment2 = results2 - .select("comment") + + assert(results.size === numCars) + } + + test("Expect parsing error with wrong delimiter settting using sparkContext.csvFile") { + intercept[ org.apache.spark.sql.AnalysisException] { + TestSQLContext.csvFile(carsAltFile, useHeader = true, delimiter = ',', quote = '\'') + .select("year") + .collect() + } + } + + test("Expect wrong parsing results with wrong quote setting using sparkContext.csvFile") { + val results = + TestSQLContext.csvFile(carsAltFile, useHeader = true, delimiter = '|', quote = '"') + .select("year") .collect() - .head - .getString(0) - assert(firstComment2 === "The ideal car for \"families\" and all their \"bags\", \"boxes\" and \"barbecues\"") + + assert(results.slice(0, numCars).toSeq.map(_(0).asInstanceOf[String]) == + Seq("'2012'", "1997", "2015")) } test("DDL test with alternative delimiter and quote") { @@ -102,7 +128,7 @@ class CsvSuite extends FunSuite { |OPTIONS (path "$carsAltFile", header "true", quote "'", delimiter "|") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT year FROM carsTable").collect().size === 2) + assert(sql("SELECT year FROM carsTable").collect().size === numCars) } @@ -134,11 +160,12 @@ class CsvSuite extends FunSuite { |OPTIONS (path "$carsFile", header "true") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT makeName FROM carsTable").collect().size === 2) - assert(sql("SELECT avg(yearMade) FROM carsTable group by grp").collect().head(0) === 2004.5) + assert(sql("SELECT makeName FROM carsTable").collect().size === numCars) + assert(sql("SELECT avg(yearMade) FROM carsTable where grp = '' group by grp") + .collect().head(0) === 2004.5) } - test("column names test") { + test("DSL column names test") { val cars = new CsvParser() .withUseHeader(false) .csvFile(TestSQLContext, carsFile) @@ -163,7 +190,7 @@ class CsvSuite extends FunSuite { |OPTIONS (path "$tempEmptyDir", header "false") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT * FROM carsTableIO").collect().size === 3) + assert(sql("SELECT * FROM carsTableIO").collect().size === numCars + 1) assert(sql("SELECT * FROM carsTableEmpty").collect().isEmpty) sql( @@ -171,6 +198,82 @@ class CsvSuite extends FunSuite { |INSERT OVERWRITE TABLE carsTableEmpty |SELECT * FROM carsTableIO """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT * FROM carsTableEmpty").collect().size == 3) + assert(sql("SELECT * FROM carsTableEmpty").collect().size == numCars + 1) + } + + test("DSL save") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = TestSQLContext.csvFile(carsFile) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true")) + + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/") + + 
assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet== cars.collect.map(_.toString).toSet) + } + + test("DSL save with a compression codec") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = TestSQLContext.csvFile(carsFile) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true"), classOf[GzipCodec]) + + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/") + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + + test("DSL save with quoting") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = TestSQLContext.csvFile(carsFile) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "\"")) + + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/") + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + + test("DSL save with alternate quoting") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = TestSQLContext.csvFile(carsFile) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "!")) + + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/", quote = '!') + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + + test("DSL save with quoting, escaped quote") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "escape-copy.csv" + + val escape = TestSQLContext.csvFile(escapeFile, escape='|', quote='"') + escape.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "\"")) + + val escapeCopy = TestSQLContext.csvFile(copyFilePath + "/") + + assert(escapeCopy.count == escape.count) + assert(escapeCopy.collect.map(_.toString).toSet == escape.collect.map(_.toString).toSet) + assert(escapeCopy.head().getString(0) == "\"thing") } } diff --git a/src/test/scala/com/databricks/spark/csv/TestUtils.scala b/src/test/scala/com/databricks/spark/csv/TestUtils.scala index ac78215..0c32f12 100644 --- a/src/test/scala/com/databricks/spark/csv/TestUtils.scala +++ b/src/test/scala/com/databricks/spark/csv/TestUtils.scala @@ -1,3 +1,18 @@ +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
 package com.databricks.spark.csv
 
 import java.io.{File, IOException}

From 6987735c0818733bcfebaac24cb4e2e3b9de69af Mon Sep 17 00:00:00 2001
From: Abdelaziz Bendadani
Date: Thu, 16 Apr 2015 17:48:45 +0200
Subject: [PATCH 7/7] added comment

---
 src/main/scala/com/databricks/spark/csv/package.scala | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala
index d5cec13..58f503e 100755
--- a/src/main/scala/com/databricks/spark/csv/package.scala
+++ b/src/main/scala/com/databricks/spark/csv/package.scala
@@ -99,6 +99,7 @@ package object csv {
       }
 
       val generateHeader = parameters.getOrElse("header", "false").toBoolean
+      //Use format instead of mkString
       val header = if (generateHeader) {
         csvFormat.format(dataFrame.columns.map(_.asInstanceOf[AnyRef]):_*)
       } else {
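
For reference, a minimal usage sketch of the escape support added by this series, assuming an existing SparkContext `sc`; the file paths are illustrative only:

```scala
import org.apache.spark.sql.SQLContext
import com.databricks.spark.csv._

val sqlContext = new SQLContext(sc)

// Read a CSV file whose quotes are escaped with backslash, using the
// withEscape option added to CsvParser in this series.
val cars = new CsvParser()
  .withUseHeader(true)
  .withEscape('\\')
  .csvFile(sqlContext, "cars.csv")

// Write it back out; header, quote and escape are passed through the
// options map consumed by saveAsCsvFile.
cars.saveAsCsvFile("newcars.csv", Map("header" -> "true", "quote" -> "\"", "escape" -> "\\"))
```

The same `escape` option can also be supplied through the data source API (the SQL `OPTIONS` clause or `sqlContext.load`), since `DefaultSource` reads it from the parameter map.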