Spark machine learning - Let us learn machine learning using Yelp dataset


The demo is given on our state-of-the-art big data cluster.

This weekend I am planning to explore machine learning on the Yelp dataset. You are welcome to be part of this. I will update this topic as we make progress.

We will use Spark machine learning (MLlib) with Scala as the programming language for this effort.


If you have lab access, the data sets are available under /data/yelp-dataset.
You can also download the data set from Kaggle to your own environment:

  • Sign up to Kaggle using a browser
  • Download the data set using the browser
  • Upload the data to the environment in which you are going to use it

You should see the following CSV files once you download and unarchive the data sets.


Problem Statement

We need to use yelp_review.csv for this exercise.

Official Spark Documentation
We will be following the official Spark machine learning documentation.

Typical ML Cycle

  • Understand data
  • Create Data Frame
  • Build a training model using sample data
  • Apply the model to the actual data
  • Validate for accuracy


A few findings about the data in yelp_review.csv

  • There are 9 fields in this data set
    • review_id
    • user_id
    • business_id
    • stars
    • date
    • text
    • useful
    • funny
    • cool
  • text contains the review text and stars contains the rating
  • Each record is delimited by a carriage return (\r)
  • Fields are delimited by , and each field is enclosed in double quotes
  • Reviews are written in free-flowing text and many of them contain new line characters
  • Some reviews even contain carriage returns (so there are some data quality issues)
  • For now, we will ignore all records that do not have exactly 9 fields, using , as the delimiter and " as the enclosing character

Here is how the data was validated.

  • Use an awk script to count the number of fields in each record
  • The record separator is \r
  • The field separator is "," (double quote, comma, double quote)
  • Copy the script to an awk file (yelp_review.awk)

BEGIN { RS = "\r"; FS = "\",\"" }
{ print NF; }

  • Run the script using awk: awk -f yelp_review.awk yelp_review.csv | sort | uniq -c
  • The output shows how many records have each field count (first column: number of records, second column: number of fields)
     12 0
   2440 1
     20 10
      4 11
      1 12
    534 4
    534 6
5261110 9
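For readers who prefer to stay in Scala, the same field count check can be sketched without awk. This is a minimal sketch, assuming the input fits in memory as a single String (fine for a small sample, not for the full file); the object name FieldCountSketch and the sample records are hypothetical. It mirrors the awk settings above: \r ends a record and "," separates fields.

```scala
object FieldCountSketch {
  // Same separators as the awk script: records end at \r and fields are
  // separated by the three characters "," (double quote, comma, double quote).
  val RecordSeparator = "\r"
  val FieldSeparator = "\",\"" // no regex metacharacters, so split treats it literally

  // Number of fields in one record; an empty record has 0 fields, as in awk.
  def fieldCount(record: String): Int =
    if (record.isEmpty) 0 else record.split(FieldSeparator, -1).length

  // Map of field count -> number of records, like `sort | uniq -c` over awk's NF.
  def fieldCountHistogram(raw: String): Map[Int, Int] =
    raw.split(RecordSeparator)
      .map(fieldCount)
      .groupBy(identity)
      .map { case (fields, records) => fields -> records.length }

  def main(args: Array[String]): Unit = {
    val header = "\"review_id\",\"user_id\",\"business_id\",\"stars\",\"date\",\"text\",\"useful\",\"funny\",\"cool\""
    // An embedded \n inside the text is harmless because only \r ends a record.
    val good = "\"r1\",\"u1\",\"b1\",\"5\",\"2016-01-01\",\"Great food,\nfriendly staff\",\"1\",\"0\",\"0\""
    val broken = "\"r2\",\"u2\",\"b2\",\"1\",\"2016-01-02\",\"Cold"
    val raw = Seq(header, good, broken).mkString(RecordSeparator)
    fieldCountHistogram(raw).toSeq.sortBy(_._1).foreach { case (fields, records) =>
      println(f"$records%7d $fields")
    }
  }
}
```

On the sample input this reports one record with 6 fields and two records (the header and the good review) with 9 fields, in the same "count fields" layout as uniq -c.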


Spark code to read yelp_review.csv into an RDD

  • We cannot use sc.textFile here. By default, Hadoop's line reader (which sc.textFile relies on) ends a record at \n, \r or \r\n, so every review containing a new line character would be split into several records. We need \r alone to act as the record delimiter.
  • We can use the Hadoop input format TextInputFormat with a custom delimiter
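To see why the delimiter matters, here is a minimal plain-Scala sketch (no Spark or Hadoop involved) that imitates the two behaviours on a hypothetical two-review input: splitting on any of \n, \r or \r\n, as Hadoop's default line reader does, versus splitting on \r only, which is what the custom textinputformat.record.delimiter setting gives us. The object and method names are illustrative.

```scala
object DelimiterSketch {
  // Default line reading: a record ends at \r\n, \r or \n.
  def splitOnAnyLineTerminator(s: String): Seq[String] =
    s.split("\r\n|\r|\n").toSeq

  // Custom record delimiter: only \r ends a record.
  def splitOnCarriageReturn(s: String): Seq[String] =
    s.split("\r").toSeq

  def main(args: Array[String]): Unit = {
    // The first review's text contains a \n, the records are separated by \r.
    val raw = "\"r1\",\"Great food,\nfriendly staff\"\r\"r2\",\"Too salty\""
    println(splitOnAnyLineTerminator(raw).size) // 3: review r1 is cut in half
    println(splitOnCarriageReturn(raw).size)    // 2: one entry per review
  }
}
```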

We need to create a Configuration object to set the property for the custom delimiter

val conf = new Configuration
conf.set("textinputformat.record.delimiter", "\r")
  • Then we need to call newAPIHadoopFile with the path of our data, the input format class, the key and value types from the Hadoop APIs, and the configuration object

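Using newAPIHadoopFile to read the data

val yelpReviews = sc.newAPIHadoopFile(yelpReviewsPath,
  classOf[TextInputFormat], classOf[LongWritable], classOf[Text], conf)
  • This creates an RDD of key-value pairs where the key is of type LongWritable (the byte offset of the record in the file) and the value is of type Text (the record itself)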

Here is the complete code snippet

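import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat

val conf = new Configuration
// make \r the only record delimiter
conf.set("textinputformat.record.delimiter", "\r")

val yelpReviewsPath = "/user/training/yelp-dataset/yelp_review.csv"
val yelpReviews = sc.newAPIHadoopFile(yelpReviewsPath,
  classOf[TextInputFormat],
  classOf[LongWritable],
  classOf[Text],
  conf)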


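Let us preview the first few records. We convert each Text value to a String before calling take, because Text is not serializable and Hadoop reuses the same Text object for every record.

yelpReviews.map(t => t._2.toString).
  take(5).foreach(println)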
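  • Here is the code snippet to count records by number of fields. The output should match our earlier awk findings, which validates that we are parsing the data correctly. The only expected difference is that empty records count as 1 field with Java's split but as 0 fields in awk.
  • We use "," (all three characters together) as the field delimiter
  • An alternative is a regular expression that splits only on commas outside double quotes: s.split(",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", -1).size. But the regular expression is much slower, because for every comma the lookahead rescans the rest of the record.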

Validation code:

yelpReviews.values.
  map(s => (s.toString.split("\",\"").size, 1)).
  reduceByKey(_ + _).
  collect.foreach(println)
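The difference between the literal "," delimiter and the quote-aware regular expression can be checked on a single record without a cluster. This is a minimal sketch; the object name SplitSketch and the sample record are hypothetical, and the regex is the one quoted above. Both approaches find 9 fields when the review text contains a plain comma, but the literal split strips the enclosing quotes at the field boundaries while the regex keeps them.

```scala
object SplitSketch {
  // Literal delimiter: the three characters "," (quote, comma, quote).
  // String.split takes a regex, but this one has no metacharacters.
  def splitLiteral(record: String): Array[String] =
    record.split("\",\"", -1)

  // Quote-aware regex: a comma is a delimiter only if an even number of
  // double quotes follows it up to the end of the record (it is outside quotes).
  // The lookahead rescans the rest of the record for every comma.
  def splitQuoteAware(record: String): Array[String] =
    record.split(",(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", -1)

  def main(args: Array[String]): Unit = {
    val record = "\"r1\",\"u1\",\"b1\",\"5\",\"2016-01-01\",\"Great food, friendly staff\",\"1\",\"0\",\"0\""
    println(splitLiteral(record).length)    // 9
    println(splitQuoteAware(record).length) // 9
    println(splitLiteral(record)(5))        // Great food, friendly staff
    println(splitQuoteAware(record)(5))     // "Great food, friendly staff"
  }
}
```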

Now it is time to create a data frame from this RDD using only the valid records, i.e. those with exactly 9 fields.


Creating data frame for yelp data

While creating the data frame,

  • we need to filter out the header
  • we need to filter out all records that do not have exactly 9 fields
  • there are also 8 records where the text ends with ",", due to which we get incorrect values for useful. That needs to be handled.

Filtering out invalid records and the header

val yelpReviewsFiltered = yelpReviews.
  filter(r => {
    // keep only records with exactly 9 fields
    r._2.toString.split("\",\"").size == 9 &&
    // the header is the record at byte offset 0
    r._1.toString.toLong != 0
  })

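Creating the data frame

// toDF needs import spark.implicits._ (already imported in spark-shell)
val yelpReviewsDF = yelpReviewsFiltered.
  map(r => {
    val yr = r._2.toString.split("\",\"")
    // strip the leading quote of review_id and the trailing quote of cool;
    // remove the ," fragment that breaks useful in the 8 problem records
    (yr(0).substring(1), yr(1), yr(2),
     yr(3), yr(4), yr(5),
     yr(6).replace(",\"", "").toInt, yr(7).toInt,
     yr(8).substring(0, yr(8).length - 1).toInt)
  }).toDF("review_id", "user_id", "business_id",
          "stars", "date", "text",
          "useful", "funny", "cool")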

Now we can apply data frame operations on yelpReviewsDF to understand the data further.
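The per-record logic inside the map above can be pulled out into a plain function and tested on its own, without a cluster. This is a minimal sketch, assuming the same rules as the data frame code (split on ",", strip the leading quote of review_id and the trailing quote of cool, remove the ," fragment from useful). The case class Review and the function parseReview are hypothetical names, and returning None instead of throwing is a design choice of this sketch, not of the code above.

```scala
import scala.util.Try

object ParseSketch {
  // Hypothetical record type mirroring the 9 columns of yelp_review.csv.
  case class Review(reviewId: String, userId: String, businessId: String,
                    stars: String, date: String, text: String,
                    useful: Int, funny: Int, cool: Int)

  // Returns None for records that do not have exactly 9 fields or whose
  // counters are not integers, instead of failing the whole job.
  def parseReview(record: String): Option[Review] = {
    val yr = record.split("\",\"")
    if (yr.length != 9) None
    else Try {
      Review(yr(0).substring(1), yr(1), yr(2), yr(3), yr(4), yr(5),
             yr(6).replace(",\"", "").toInt, yr(7).toInt,
             yr(8).substring(0, yr(8).length - 1).toInt)
    }.toOption
  }

  def main(args: Array[String]): Unit = {
    val ok = "\"r1\",\"u1\",\"b1\",\"5\",\"2016-01-01\",\"Great food\",\"3\",\"0\",\"1\""
    println(parseReview(ok))
    println(parseReview("\"r2\",\"u2\",\"Cold")) // None: only 3 fields
  }
}
```

A side effect of this approach is that the header row also yields None, because "useful" is not an integer, so the offset-based header filter and the 9-field filter collapse into one step.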




Now it is time to understand TF-IDF
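TF-IDF scores a term in a review by how often it occurs in that review (term frequency, TF), scaled down when the term occurs in many reviews (inverse document frequency, IDF). The Spark MLlib documentation defines IDF(t, D) = log((|D| + 1) / (DF(t, D) + 1)), where |D| is the number of documents and DF(t, D) is the number of documents containing t, and TF-IDF = TF × IDF. Below is a minimal plain-Scala sketch of that formula on three made-up reviews. The object and method names are hypothetical, tokenization is a simple lower-case whitespace split, and exact term counts are used instead of HashingTF's hashed buckets (where different terms can collide), so it illustrates the arithmetic rather than reproducing Spark's output.

```scala
object TfIdfSketch {
  // Lower-case and split on whitespace, similar in spirit to Spark's Tokenizer.
  def tokenize(doc: String): Seq[String] =
    doc.toLowerCase.split("\\s+").filter(_.nonEmpty).toSeq

  // Term frequency: raw count of each term in one document.
  def termFrequency(tokens: Seq[String]): Map[String, Int] =
    tokens.groupBy(identity).map { case (t, occ) => t -> occ.size }

  // IDF(t) = ln((m + 1) / (df(t) + 1)), where m is the number of documents
  // and df(t) is the number of documents containing t.
  def inverseDocumentFrequency(docs: Seq[Seq[String]]): Map[String, Double] = {
    val m = docs.size
    val df = docs.flatMap(_.distinct).groupBy(identity).map { case (t, occ) => t -> occ.size }
    df.map { case (t, n) => t -> math.log((m + 1.0) / (n + 1.0)) }
  }

  // TF-IDF vector (as a term -> weight map) for every document.
  def tfIdf(docs: Seq[String]): Seq[Map[String, Double]] = {
    val tokenized = docs.map(tokenize)
    val idf = inverseDocumentFrequency(tokenized)
    tokenized.map(tokens => termFrequency(tokens).map { case (t, tf) => t -> tf * idf(t) })
  }

  def main(args: Array[String]): Unit = {
    val reviews = Seq("great food great service", "food was cold", "great place")
    tfIdf(reviews).foreach(println)
  }
}
```

A term that appears in every review gets IDF = log(1) = 0, so very common words contribute nothing, while a word like "service" that appears in only one of the three reviews gets the highest weight.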


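First cut, still being validated. I will provide an explanation later. Note that this first cut uses only term frequencies (HashingTF); an IDF stage is not included yet.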

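import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.expr

val yelpReviewsSample = yelpReviewsDF.sample(false, 0.005)

// label: 1 for reviews with more than 3 stars, 0 otherwise
val yelpReviewsTraining = yelpReviewsSample.
  withColumn("label", expr("case when stars > 3 then 1 else 0 end")).
  select("review_id", "text", "label")

// the original stage settings were truncated; column names below are assumed
val tokenizer = new Tokenizer().
  setInputCol("text").
  setOutputCol("words")
val hashingTF = new HashingTF().
  setInputCol(tokenizer.getOutputCol).
  setOutputCol("features")
val lr = new LogisticRegression() // library defaults (features/label columns)
val pipeline = new Pipeline().
  setStages(Array(tokenizer, hashingTF, lr))

val model = pipeline.fit(yelpReviewsTraining)

// score the full data set and compare with the actual label
val validate = model.transform(yelpReviewsDF).
  select("review_id", "text", "stars", "probability", "prediction").
  withColumn("actual", expr("case when stars > 3 then 1.0 else 0.0 end"))

// number of wrong predictions
validate.where("prediction != actual").count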

With a 0.005 sample, right predictions are 4296256 out of 5261109 (about 81.7% accuracy).
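The accuracy figure follows directly from the counts: validate.where("prediction != actual").count returns the number of wrong predictions, and accuracy = (total - wrong) / total. A tiny sketch of that arithmetic with the numbers reported above; the object and helper names are hypothetical.

```scala
object AccuracySketch {
  // Accuracy = correct predictions / all predictions.
  def accuracy(total: Long, wrong: Long): Double =
    (total - wrong).toDouble / total

  def main(args: Array[String]): Unit = {
    val total = 5261109L
    val correct = 4296256L
    val wrong = total - correct // what the where("prediction != actual").count step returns
    println(wrong)                                    // 964853
    println(f"${accuracy(total, wrong) * 100}%.1f%%") // 81.7%
  }
}
```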


Thanks for sharing this, Durga. Just curious to know if you have started to figure out the statistics behind these algorithms.

Is there any material that you would recommend for the same?

Thanks in advance


Durga, why use Scala? Why not use Python, as it has better support for machine learning? I'm also curious, as someone who aspires to be a data engineer, which language should I focus on? I already know Python, but many of your courses focus on Scala. Please help, thanks!


Could it stem from the fact that Python is interpreted and therefore bound to be slower? For applications that need to process data faster, Scala looks like a better alternative than Python. And it is not as if you cannot implement machine learning algorithms in Scala.