Can Adding Partitions Improve The Performance of Your Spark Job On Skewed Data Sets?

After reading a number of on-line articles on how to handle ‘data skew’ in one’s Spark cluster, I ran some experiments on my own ‘single JVM’ cluster to try out one of the techniques mentioned. This post presents the results, but before we get to those, I could not restrain myself from some nitpicking (below) about the definition of ‘skew’. You can quite easily skip the next section if you just want to get to the Spark techniques.

A Statistical Aside

Statistics defines a symmetric distribution as one in which the mean, median, and mode are all equal, and a skewed distribution as one where these properties do not hold. Many online resources use a conflicting definition of data skew, for example this one, which talks about skew in terms of “some data slices [having] more rows of a table than others”. We can’t use the traditional statistics definition of skew if our concern is unequal distribution of data across the partitions of our Spark tasks.

Consider a degenerate case where you have allocated 100 partitions to process a batch of data, and all the keys in that batch are from the same customer. Then, if we are using a hash or range partitioner, all records would be processed in one partition, while the other 99 would be idle. But clearly in this case the mean (average), the mode (most common value), and the median (the value ‘in the middle’ of the distribution) would all be the same. So, our data is not ‘skewed’ in the traditional sense, but definitely unequally distributed amongst our partitions. Perhaps a better term to use instead of ‘skewed’ would be ‘non-uniform’, but everyone uses ‘skewed’. So, fine. I will stop losing sleep over this and go with the data-processing literature usage of the term.

Techniques for Handling Data Skew

More Partitions

Increasing the number of partitions may spread the keys in your data across more partitions, so that a heavily hit key is less likely to share its partition with other keys. However, this will likely not help when one key, or relatively few keys, are dominant in the data. The following sections discuss this technique in more detail.
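
For a concrete picture, here is a minimal sketch of how one might increase the partition count (the count of 200 and the toy data are arbitrary, for illustration only):

    import org.apache.spark.HashPartitioner
    import org.apache.spark.sql.SparkSession

    object MorePartitionsSketch extends App {
      val spark = SparkSession.builder()
        .appName("more-partitions-sketch")
        .master("local[4]")
        .getOrCreate()

      // A toy pair RDD; in practice this would be your real keyed data.
      val keyValueRdd = spark.sparkContext
        .parallelize(1 to 1000)
        .map(i => (i % 10, i))

      // Spread the keys over more partitions than the default.
      val repartitioned = keyValueRdd.partitionBy(new HashPartitioner(200))
      println("partitions: " + repartitioned.partitions.length)

      // For DataFrame/Dataset shuffles (joins, aggregations), the analogous
      // knob is the number of shuffle partitions.
      spark.conf.set("spark.sql.shuffle.partitions", "200")
    }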

Bump up spark.sql.autoBroadcastJoinThreshold

Increasing the value of this setting will increase the likelihood that the Spark query engine chooses the BroadcastHashJoin strategy for joins, in preference to the shuffle-heavy SortMergeJoin. A broadcast join transmits the smaller to-be-joined table to each executor’s memory, then streams the larger table past it, joining row by row. As the size of the smaller table grows, so does memory pressure, and the viability of this technique decreases.
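
As a hedged illustration of what this might look like (the 100 MB figure and the toy tables are made up for the example), the threshold can be raised through the runtime config, or a broadcast can be requested explicitly with a hint:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.broadcast

    object BroadcastThresholdSketch extends App {
      val spark = SparkSession.builder()
        .appName("broadcast-threshold-sketch")
        .master("local[4]")
        .getOrCreate()
      import spark.implicits._

      // Raise the threshold (default 10 MB) to roughly 100 MB so that a
      // larger 'small' table is still eligible for BroadcastHashJoin.
      spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 100L * 1024 * 1024)

      // Toy stand-ins for a large fact table and a small dimension table.
      val orders    = Seq((1, "a-1", 10), (2, "a-2", 44)).toDF("customerId", "item", "qty")
      val discounts = Seq((1, 0.010), (2, 0.001)).toDF("customerId", "discountPercent")

      // Alternatively, hint the broadcast explicitly, regardless of the threshold.
      orders.join(broadcast(discounts), "customerId")
        .explain()   // look for BroadcastHashJoin in the physical plan
    }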

Iterative (Chunked) Broadcast Join

When your smaller table becomes prohibitively large to broadcast in one shot, it might be worth considering iteratively taking slices of that smaller (but not that small) table, broadcasting each slice, joining it with the larger table, and then unioning the results. Here is a talk that explains the details nicely.
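
A rough sketch of the idea follows; the object, function, and column names here are invented for illustration, and the chunk count would need tuning so that each slice stays under the broadcast limit:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{broadcast, col, hash, lit, pmod}

    object ChunkedBroadcastJoinSketch {
      // Slice the medium-sized table into numChunks pieces (by hashing the
      // join key), broadcast-join each slice against the large table, and
      // union the partial results back together.
      def chunkedBroadcastJoin(large: DataFrame,
                               medium: DataFrame,
                               joinKey: String,
                               numChunks: Int): DataFrame = {
        val chunked = medium.withColumn("chunk", pmod(hash(col(joinKey)), lit(numChunks)))
        (0 until numChunks)
          .map { i =>
            val slice = chunked.filter(col("chunk") === i).drop("chunk")
            large.join(broadcast(slice), joinKey)
          }
          .reduce(_ union _)
      }
    }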

Adding Salt

Add ‘salt’ to the keys of your data set by mapping each key to a pair whose first element is the original key, and whose second element is a random integer in some range. For very frequently occurring keys the range would be larger than for keys which occur with average or lower frequency.

Say you had a table with data like the one below:

        customerId  itemOrdered Quantity
            USGOV   a-1         10   // frequently occurring
            USGOV   a-2         44   // frequently occurring
            USGOV   a-5         553  // frequently occurring
            small1  a-1         2
            small1  a-1         4
            small3  a-1         2
            USGOV   a-5         553  // frequently occurring
            small2  a-5         1

And you needed to join to a table of discounts to figure final price, like this:

        customerId  discountPercent
            USGOV   .010
            small1  .001
            small2  .001
            small3  .002

You would add an additional salt column to both tables, then join on the customerId and the salt, with the modified input to the join appearing as shown below. Note that ‘USGOV’ records used to wind up in one partition, but now, with the salt added to the key they will likely end up in one of three partitions (‘salt range’ == 3.) The records associated with less frequently occurring keys will only get one salt value (‘salt range’ == 1), as we don’t need to ensure that they end up in different partitions.

            customerId  salt  itemOrdered Quantity 
                USGOV   1     a-1         10   
                USGOV   2     a-2         44   
                USGOV   3     a-5         553  
                small1  1     a-1         2    
                small1  1     a-1         4    
                small3  1     a-1         2     
                USGOV   3     a-5         553
                small2  1     a-5         1

To ensure the join works, the salt column needs to be added to the smaller table as well, and for each salt value associated with a higher-frequency key we need to add a new record (note there are now three USGOV records). This will add to the size of the smaller table, but often that cost is outweighed by the efficiency gained from not having a few partitions loaded up with a majority of the data to be processed. A code sketch of this salting approach appears after the table below.

        customerId  salt discountPercent 
            USGOV   1    .010
            USGOV   2    .010
            USGOV   3    .010
            small1  1    .001
            small2  1    .001 
            small3  1    .002 
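
Below is a sketch of how this might look using DataFrames. The hard-coded salt range of 3 for ‘USGOV’ mirrors the example tables above; a real job would derive the heavy keys and their salt ranges from observed key frequencies, so treat the names and numbers here as illustrative only:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._

    object SaltedJoinSketch extends App {
      val spark = SparkSession.builder()
        .appName("salted-join-sketch")
        .master("local[4]")
        .getOrCreate()
      import spark.implicits._

      val orders = Seq(
        ("USGOV", "a-1", 10), ("USGOV", "a-2", 44), ("USGOV", "a-5", 553),
        ("small1", "a-1", 2), ("small1", "a-1", 4), ("small3", "a-1", 2),
        ("USGOV", "a-5", 553), ("small2", "a-5", 1)
      ).toDF("customerId", "itemOrdered", "quantity")

      val discounts = Seq(
        ("USGOV", 0.010), ("small1", 0.001), ("small2", 0.001), ("small3", 0.002)
      ).toDF("customerId", "discountPercent")

      // Heavy keys get a salt range > 1; everything else gets a range of 1.
      val saltRangeFor = udf((customerId: String) => if (customerId == "USGOV") 3 else 1)

      // Larger table: assign each row a random salt in [1, saltRange].
      val saltedOrders = orders.withColumn(
        "salt", (floor(rand() * saltRangeFor($"customerId")) + 1).cast("int"))

      // Smaller table: replicate each row once per possible salt value.
      val saltValues = udf((range: Int) => (1 to range).toArray)
      val saltedDiscounts = discounts.withColumn(
        "salt", explode(saltValues(saltRangeFor($"customerId"))))

      // Join on the original key plus the salt.
      saltedOrders.join(saltedDiscounts, Seq("customerId", "salt")).show(false)
    }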

Adding More Partitions: Unhelpful When One Key Dominates

Before we look at code, let’s consider a minimal contrived example where we have a data set of twelve records that needs to be distributed across our cluster, which we will accomplish by ‘mod’ing the key by the number of partitions. First consider a non-skewed data set where no key dominates, and 3 partitions. We see that partition 0 is filled with five items, partition 1 with four, and partition 2 with only three. This mild skewing is a result of the fact that we have very few partitions.

3 partitions: 0, 1, 2
distribute to partition via: key % 3

Uniform Data Set 

    key     partition
*   0       0 
    1       1
    2       2
*   3       0
    4       1
    5       2
*   6       0
    7       1
    8       2
*   9       0
    10      1
*   12      0 

Now, let’s look at two skewed data sets: one in which a single key (0) dominates, and another where the skew is spread across two keys (0 and 12). We will again partition by mod’ing by the number of available partitions. In both cases, partition 0 gets flooded with 8 of the 12 records, while the other partitions get only 2 records each.

Skewed Data Set  -- One Key (0) Dominates


    key     partition
*   0       0
*   0       0
*   0       0
    1       1
    2       2
*   3       0
    4       1
    5       2
*   6       0
*   0       0
*   0       0
*   0       0




Skewed Data Set  -- No Single Key Dominates (0 & 12 occur most often)

    key     partition
*   0       0
*   0       0
*   0       0
    1       1
    2       2
*   3       0
    4       1
    5       2
*   6       0
*   12      0
*   12      0
*   12      0

Now let’s see what happens when we increase the number of partitions to 11, and distribute records across partitions by mod’ing by that number. In the case where one key (0) dominates, we find that partition 0 still gets 6 out of 12 records. But when the ‘skew’ is spread across not one, but two keys (0 and 12), we find that only 3 out of 12 records end up in partition zero. This shows that the more ‘dominance’ is concentrated around a small set of keys (or one key, as often happens with nulls), the less we will benefit by simply adding partitions. (A few lines of Scala reproducing these counts appear after the tables below.)

Skewed Data Set  -- One Key (0) Dominates



    key     partition
*   0       0
*   0       0
*   0       0
    1       1
    2       2
    3       3
    4       4
    5       5
    6       6
*   0       0
*   0       0
*   0       0




Skewed Data Set  -- No Single Key Dominates (0 & 12 occur most often)

    key     partition
*   0       0
*   0       0
*   0       0
    1       1
    2       2
    3       3
    4       4
    5       5
    6       6
    12      1    
    12      1
    12      1
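
For reference, the counts in the tables above can be reproduced with a few lines of plain Scala. This is just a toy calculation of `key % numPartitions`, not a Spark job:

    object ModPartitioningSketch extends App {
      val oneKeyDominant  = Seq(0, 0, 0, 1, 2, 3, 4, 5, 6, 0, 0, 0)
      val twoKeysDominant = Seq(0, 0, 0, 1, 2, 3, 4, 5, 6, 12, 12, 12)

      for (numPartitions <- Seq(3, 11); keys <- Seq(oneKeyDominant, twoKeysDominant)) {
        // Count how many of the twelve records land in each partition.
        val counts = keys.groupBy(_ % numPartitions).mapValues(_.size).toMap
        println(s"$numPartitions partitions, keys=$keys -> records per partition: $counts")
      }
    }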

Adding More Partitions: A Simple Test Program

The main processing component of the test program I developed to explore what happens with skewed data does a mapPartitionsWithIndex over each partition of a Pair RDD, where each pair consists of the key followed by the value (here the value is simply a copy of the key). The keys are generated by the KeyGenerator object, which always generates 100 keys: either a roughly uniform distribution of 100 random keys between 1 and 99, or one of two flavors of skewed distribution, both of which include 45 random keys. The ‘oneKeyDominant’ distribution augments those 45 random keys with 55 0’s, while the ‘not oneKeyDominant’ distribution instead adds 3 high-frequency keys: 0, 2, and 4, occurring 18, 18, and 19 times, respectively.

At the beginning of the mapPartitionsWithIndex we start a timer so we can see how much time it takes to completely process each partition. As we iterate over each key we call ‘process’ which emulates some complex processing by sleeping for 50 milliseconds.

    def process(key: (Int, Int), index: Int): Unit = {
      println(s"processing $key in partition '$index'")
      Thread.sleep(50)     // Simulate processing w/ delay
    }

    keysRdd.mapPartitionsWithIndex{
      (index, keyzIter) =>
        val start = System.nanoTime()
        keyzIter.foreach {
          key =>
            process(key, index)
        }
        val end = System.nanoTime()
        println(
          s"processing of keys in partition '$index' took " +
            s" ${(end-start) / (1000 * 1000)} milliseconds")
        keyzIter
    }
      .count

The full code is presented below. As long as you use Spark 2.2+, you should be able to run it by copy/pasting it into any existing Spark project you might have. To reproduce the results reported in the next section, you need to manually set the variables numPartitions, useSkewed, and oneKeyDominant before you launch the application.

import org.apache.spark.{HashPartitioner, SparkConf}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._

import scala.collection.immutable
import scala.util.Random


object KeyGenerator {
  val random = new Random()

  def getKeys(useSkewed: Boolean, 
              oneKeyDominant: Boolean)  : immutable.Seq[Int] = {

    def genKeys(howMany: Int,
                lowerInclusive: Int,
                upperExclusive: Int)  = {
      (1 to howMany).map{ i =>
        lowerInclusive + 
            random.nextInt(upperExclusive - lowerInclusive)
      }
    }

    val keys  =
      if (useSkewed) {
        val skewedKeys =
          if (oneKeyDominant)
            Seq.fill(55)(0)
          else
            Seq.fill(18)(0) ++ Seq.fill(18)(2) ++ Seq.fill(19)(4)

        genKeys(45, 1, 45) ++ skewedKeys
      }
      else {
        genKeys(100, 1, 100)
      }

    System.out.println("keys:" + keys);
    System.out.println("keys size:" + keys.size);

    keys
  }
}

object SaltedToPerfection extends App { 

  import KeyGenerator._
  def runApp(numPartitions: Int, 
             useSkewed: Boolean, 
             oneKeyDominant: Boolean) = {

    val keys: immutable.Seq[Int] = getKeys(useSkewed, oneKeyDominant)
    val keysRdd: RDD[(Int, Int)] =
      sparkSession.sparkContext.
        parallelize(keys).map(key => (key,key)). // to pair RDD
        partitionBy(new HashPartitioner(numPartitions))


    System.out.println("keyz.partitioner:" + keysRdd.partitioner)
    System.out.println("keyz.size:" + keysRdd.partitions.length)

    def process(key: (Int, Int), index: Int): Unit = {
      println(s"processing $key in partition '$index'")
      Thread.sleep(50)     // Simulate processing w/ delay
    }

    keysRdd.mapPartitionsWithIndex{
      (index, keyzIter) =>
        val start = System.nanoTime()
        keyzIter.foreach {
          key =>
            process(key, index)
        }
        val end = System.nanoTime()
        println(
          s"processing of keys in partition '$index' took " +
            s" ${(end-start) / (1000 * 1000)} milliseconds")
        keyzIter
    }
      .count
  }


  lazy val sparkConf = new SparkConf()
    .setAppName("Learn Spark")
    .setMaster("local[4]")

  lazy val sparkSession = SparkSession
    .builder()
    .config(sparkConf)
    .getOrCreate()


  val numPartitions = 50
  val useSkewed = true
  val oneKeyDominant = true

  runApp(numPartitions, useSkewed, oneKeyDominant)
  Thread.sleep(1000 * 600)    // 10 minutes sleep to explore with UI
}

Adding More Partitions: Test Results

The results obtained from running our test program accorded with the informal analysis we performed above on various cardinality=12 data sets, namely, that increasing the number of partitions is more helpful when more than one key dominates the distribution. When one key dominates, increasing partitions improved performance by 16% (see difference between runs 3 and 5), whereas when multiple keys dominate the distribution we saw an improvement of 29% (see difference between runs 2 and 4.)

Run     Partitions      Skew                        Job Duration

1       4               none                        2.057556 s
2       4               multiple dominant keys      3.125907 s
3       4               one dominant key            4.045455 s
4       50              multiple dominant keys      2.217383 s
5       50              one dominant key            3.378734 s



Performance improvements obtained by increasing partitions (4->50)

    one dominant key    
        Elapsed time difference between run 3 and 5
        (4.045455 - 3.378734) / 4.045455  = 16%

    multiple dominant keys
        Elapsed time difference between run 2 and 4
        (3.125907 - 2.217383) / 3.125907  = 29%

Reducing Integration Hassles With JSON Schema Contracts

I recently worked on a project where the ‘contract’ between service consumers and providers consisted primarily of annotated mock-ups of the JSON responses one would obtain from each of a given service’s end-points. A much better way of expressing the contract for a service is to use a standard schema format. If you’re stuck with XML, use XML Schema. If you are using JSON, then there are tools and libraries (presented below) that will help you use JSON schema to express a service’s contract. This article assumes that you have gone through the available JSON schema documentation and have a basic idea of how to use it, and that you are developing on a JVM-based platform; most of the recipes will be helpful for Java developers (although our example of dynamic schema validation is presented using a bit of Scala.)

Why Use JSON Schema As Your Contract?

Suppose you are supporting a JSON-based service, with your contract expressed in some type of “by-example” format rather than the JSON Schema standard. Now one of the components consuming your service throws an exception while parsing a response, and the developer of said client service comes to you and says “your service has a problem”. Both of you then have to pore over the examples that define your service’s responses and figure out whether the response sent in this instance honors or violates the implicit contract. This is a very manual process with room for mistakes, and at worst it can lead to finger-pointing and debates about whether the response is correct. Not fun.

However, if the server and client teams on your project come to agreement on a schema for each JSON response, then the task of figuring out if a given response is correct boils down to simply running a validation tool where the inputs are the response document in question, and the schema to which it must conform. If the validator reports no errors then you are off the hook, with no debate.

JSON Schema Tools

This section describes how to install and use various tools for auto-generation of JSON schema from sample documents, generation of sample instance documents from schema, and schema validation. As long as your environment is configured with Java 1.8, Python 2.7+, and the pip installer, the provided set-up instructions should work on either Linux or Mac (at least they worked for me!)

Auto-generating JSON Schema From Instance Documents

genson is a utility for auto-generating JSON schema from instance documents. It can be installed via the command

    sudo pip install genson==0.1.0   # install it

Next try generating a schema for a simple document.

    echo '{ "foo": 100 }'  > /tmp/foo.json
    cat /tmp/foo.json | genson | tee /tmp/foo.schema 

foo.schema should contain the following content:

    {
      "$schema": "http://json-schema.org/schema#",
      "required": [
        "foo"
      ],
      "type": "object",
      "properties": {
        "foo": {
          "type": "integer"
        }
      }
    }

Sometimes you will be generating multiple schemas from a related set of JSON documents (e.g., you might be starting from a set of sample responses from a legacy service with no defined schema, which you plan to retrofit). In this case you will definitely want to familiarize yourself with the $ref keyword, which lets you refactor commonly occurring fragments of schema into one place (even a different file), as sketched below.
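
As a tiny, hypothetical illustration (the property names and the external file common.schema.json are made up), a shared ‘money’ fragment could be defined once and then referenced from several properties, locally or from another file:

    {
      "$schema": "http://json-schema.org/schema#",
      "type": "object",
      "definitions": {
        "money": {
          "type": "object",
          "required": ["amount", "currency"],
          "properties": {
            "amount":   { "type": "number" },
            "currency": { "type": "string" }
          }
        }
      },
      "properties": {
        "price":    { "$ref": "#/definitions/money" },
        "discount": { "$ref": "common.schema.json#/definitions/money" }
      }
    }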

Generating Sample Instance Documents From Schema

Once you have a schema you can feed it into a tool, such as this one from Liquid Technologies, to facilitate generation of mock data that you can use for testing.

Command Line Tools For Schema Validation

The best command line tool I have found for JSON schema validation is json-schema-validator. Its current documentation indicates support for JSON Schema draft v4 which is a bit behind the latest draft (7, at the time of this writing.) So, if you need the latest spec-supported features in your schemas, you should take extra care to ensure this tool is right for your needs.

Assuming you have gone through the previous step of installing and testing genson, you can download and verify the validator via the commands below (if you are on a Mac without wget, then please try curl):

wget 'https://bintray.com/fge/maven/download_file?file_path=com%2Fgithub%2Ffge%2Fjson-schema-validator%2F2.2.6%2Fjson-schema-validator-2.2.6-lib.jar' -O /tmp/validator.jar

# now validate your sample document against the schema you created above

cd /tmp ;  java -jar validator.jar /tmp/foo.schema /tmp/foo.json

You should see:

validation: SUCCESS

Now let’s see how the tool reports validation failures. Deliberately mess up your instance document (so it no longer conforms to the schema) via the command:

cat /tmp/foo.json |  sed -e's/foo/zoo/' > /tmp/bad.json

cd /tmp ; java -jar validator.jar /tmp/foo.schema /tmp/bad.json

You should see error output which includes the line:

"message" : "object has missing required properties ([\"foo\"])",

On-The-Fly Schema Validation At Run-Time

In the previous section, json-schema-validator was shown in command-line mode. As a bonus, you can also embed this project’s associated Java library into any of your services that require run-time validation of arbitrary instance documents against a schema. The code snippet below (available as a project here) is written in Scala, but you could easily use the library from Java as well.


import com.fasterxml.jackson.core.JsonParser
import com.fasterxml.jackson.databind.JsonNode
import com.github.fge.jackson.JsonLoader
import com.github.fge.jsonschema.main.{JsonSchema, JsonSchemaFactory}
import com.fasterxml.jackson.databind._
import com.github.fge.jsonschema.core.report.ProcessingReport

object SchemaValidator {
  lazy val mapper: ObjectMapper = new ObjectMapper
  lazy val jsonSchemaFactory: JsonSchemaFactory = JsonSchemaFactory.byDefault
  lazy val schemaNode: JsonNode = JsonLoader.fromResource("/schema.json")
  lazy val schema: JsonSchema = jsonSchemaFactory.getJsonSchema(schemaNode)

  def validateWithReport(json: String): Boolean = {
    val bytes: Array[Byte] = json.getBytes("utf-8")
    val parser: JsonParser = mapper.getFactory.createParser(bytes)
    val node: JsonNode = mapper.readTree(parser)
    val validationResult: ProcessingReport = schema.validate(node)
    if (validationResult.isSuccess) {
      true
    } else {
      val errMsg =
        s"Validation error. Instance=$json, msg=$validationResult"
      System.out.println("errMsg:" + errMsg)
      false
    }
  }
}

object FakeGoodWebService {
  def getJsonResponse = """{ "foo": 100 }"""
}

object FakeBadWebService {
  def getJsonResponse = """{ "zoo": 100 }"""
}

object JsonSchemaValidationDemo extends App {
  import SchemaValidator._

  val goodResult =
    validateWithReport(FakeGoodWebService.getJsonResponse)
  System.out.println("result:" + goodResult)

  val badResult =
    validateWithReport(FakeBadWebService.getJsonResponse)
  System.out.println("result:" + badResult)
}



We have stashed the ‘foo’ schema from our previous discussion into src/main/resources, and the object constructor for SchemaValidator loads that schema into the ‘schema’ variable. We then call validateWithReport, first with a valid response from a mock of a well-behaved web service, and then with a JSON response from a misbehaving web service. The resulting output is shown below.

result:true
errMsg:Validation error. Instance={ "zoo": 100 }, 
    msg=com.github.fge.jsonschema.core.report.ListProcessingReport: failure
--- BEGIN MESSAGES ---
error: object has missing required properties (["foo"])
    level: "error"
    schema: {"loadingURI":"#","pointer":""}
    instance: {"pointer":""}
    domain: "validation"
    keyword: "required"
    required: ["foo"]
    missing: ["foo"]
---  END MESSAGES  ---

result:false

Conclusion

Miscommunication and incorrect assumptions are most likely at what formally trained project managers call “interface points at subsystem boundaries” (you can read up more here.) But now you have some tools for minimizing the thrash and churn that can occur around these interface points.

License

This work is licensed under the Creative Commons Attribution 4.0 International License. Use as you wish, but if you can, please give attribution to the Data Lackey Labs Blog.