下面代码按照之前参加Kaggle的python代码改写,只完成了模型的训练过程,还需要对test集的数据进行转换和对test集进行预测。
scala 2.11.12 spark 2.2.2package ML.Titanicimport org.apache.spark.SparkContextimport org.apache.spark.sql._import org.apache.spark.sql.functions._import org.apache.spark.ml.feature.Bucketizerimport org.apache.spark.ml.feature.QuantileDiscretizerimport org.apache.spark.ml.feature.StringIndexerimport org.apache.spark.ml.feature.OneHotEncoderimport org.apache.spark.ml.Pipelineimport org.apache.spark.ml.feature.VectorAssemblerimport org.apache.spark.ml.classification.GBTClassifierimport org.apache.spark.ml.tuning.ParamGridBuilderimport org.apache.spark.ml.evaluation.MulticlassClassificationEvaluatorimport org.apache.spark.ml.tuning.{TrainValidationSplit, TrainValidationSplitModel}import org.apache.spark.ml.PipelineModelimport org.apache.spark.sql.types._/** * GBTClassifier for predicting survival in the Titanic ship */object TitanicChallenge { def main(args: Array[String]) { val spark = SparkSession.builder. master("local[*]") .appName("example") .config("spark.sql.shuffle.partitions", 20) .config("spark.default.parallelism", 20) .config("spark.driver.memory", "4G") .config("spark.memory.fraction", 0.75) .getOrCreate() val sc = spark.sparkContext spark.sparkContext.setLogLevel("ERROR") val schemaArray = StructType(Array( StructField("PassengerId", IntegerType, true), StructField("Survived", IntegerType, true), StructField("Pclass", IntegerType, true), StructField("Name", StringType, true), StructField("Sex", StringType, true), StructField("Age", FloatType, true), StructField("SibSp", IntegerType, true), StructField("Parch", IntegerType, true), StructField("Ticket", StringType, true), StructField("Fare", FloatType, true), StructField("Cabin", StringType, true), StructField("Embarked", StringType, true) )) val path = "Titanic/" val df = spark.read .option("header", "true") .schema(schemaArray) .csv(path + "train.csv") .drop("PassengerId")// df.cache() val utils = new TitanicChallenge(spark) val df2 = utils.transCabin(df) val df3 = utils.transTicket(sc, df2) val df4 = utils.transEmbarked(df3) val df5 = utils.extractTitle(sc, df4) val df6 = utils.transAge(sc, df5) val df7 = utils.categorizeAge(df6) val df8 = utils.createFellow(df7) val df9 = utils.categorizeFellow(df8) val df10 = utils.extractFName(df9) val df11 = utils.transFare(df10) val prePipelineDF = df11.select("Survived", "Pclass", "Sex", "Age_categorized", "fellow_type", "Fare_categorized", "Embarked", "Cabin", "Ticket", "Title", "family_type")// prePipelineDF.show(1)// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+// |Survived|Pclass| Sex|Age_categorized|fellow_type|Fare_categorized|Embarked|Cabin|Ticket|Title|family_type|// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+// | 0| 3|male| 3.0| Small| 0.0| S| U| 0| Mr| 0|// +--------+------+----+---------------+-----------+----------------+--------+-----+------+-----+-----------+ val (df_indexed, colsTrain) = utils.index_onehot(prePipelineDF) df_indexed.cache() //训练模型 val validatorModel = utils.trainData(df_indexed, colsTrain) //打印最优模型的参数 val bestModel = validatorModel.bestModel println(bestModel.asInstanceOf[PipelineModel].stages.last.extractParamMap) //打印各模型的成绩和参数 val paramsAndMetrics = validatorModel.validationMetrics .zip(validatorModel.getEstimatorParamMaps) .sortBy(-_._1) paramsAndMetrics.foreach { case (metric, params) => println(metric) println(params) println() } validatorModel.write.overwrite().save(path + "Titanic_gbtc") spark.stop() }}class TitanicChallenge(private val spark: SparkSession) extends Serializable { import spark.implicits._ //Cabin,用“U”填充null,并提取Cabin的首字母 def transCabin(df: Dataset[Row]): Dataset[Row] = { df.na.fill("U", Seq("Cabin")) .withColumn("Cabin", substring($"Cabin", 0, 1)) } // def transTicket(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = { 提取船票的号码,如“A/5 21171”中的21171 val medDF1 = df.withColumn("Ticket", split($"Ticket", " ")) .withColumn("Ticket", $"Ticket"(size($"Ticket").minus(1))) .filter($"Ticket" =!= "LINE")//去掉某种特殊的船票 //对船票号进行分类,小于四位号码的为“1”,四位号码的以第一个数字开头,后面接上“0”,大于4位号码的,取前三个数字开头。如21171变为211 val ticketTransUdf = udf((ticket: String) => { if (ticket.length < 4) { "1" } else if (ticket.length == 4){ ticket(0)+"0" } else { ticket.slice(0, 3) } }) val medDF2 = medDF1.withColumn("Ticket", ticketTransUdf($"Ticket")) //将数量小于等于5的类别统一归为“0”。先统计小于5的名单,然后用udf进行转换。 val filterList = medDF2.groupBy($"Ticket").count() .filter($"count" <= 5) .map(row => row.getString(0)) .collect.toList val filterList_bc = sc.broadcast(filterList) val ticketTransAdjustUdf = udf((subticket: String) => { if (filterList_bc.value.contains(subticket)) "0" else subticket }) medDF2.withColumn("Ticket", ticketTransAdjustUdf($"Ticket")) } //用“S”填充null def transEmbarked(df: Dataset[Row]): Dataset[Row] = { df.na.fill("S", Seq("Embarked")) } def extractTitle(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = { val regex = ".*, (.*?)\\..*" //对头衔进行归类 val titlesMap = Map( "Capt"-> "Officer", "Col"-> "Officer", "Major"-> "Officer", "Jonkheer"-> "Royalty", "Don"-> "Royalty", "Sir" -> "Royalty", "Dr"-> "Officer", "Rev"-> "Officer", "the Countess"->"Royalty", "Mme"-> "Mrs", "Mlle"-> "Miss", "Ms"-> "Mrs", "Mr" -> "Mr", "Mrs" -> "Mrs", "Miss" -> "Miss", "Master" -> "Master", "Lady" -> "Royalty" ) val titlesMap_bc = sc.broadcast(titlesMap) df.withColumn("Title", regexp_extract(($"Name"), regex, 1)) .na.replace("Title", titlesMap_bc.value) } //根据null age的records对应的Pclass和Name_final分组后的平均来填充缺失age。 // 首先,生成分组key,并获取分组后的平均年龄map。然后广播map,当Age为null时,用udf返回需要填充的值。 def transAge(sc: SparkContext, df: Dataset[Row]): Dataset[Row] = { val medDF = df.withColumn("Pclass_Title_key", concat($"Title", $"Pclass")) val meanAgeMap = medDF.groupBy("Pclass_Title_key") .mean("Age") .map(row => (row.getString(0), row.getDouble(1))) .collect().toMap val meanAgeMap_bc = sc.broadcast(meanAgeMap) val fillAgeUdf = udf((comb_key: String) => meanAgeMap_bc.value.getOrElse(comb_key, 0.0)) medDF.withColumn("Age", when($"Age".isNull, fillAgeUdf($"Pclass_Title_key")).otherwise($"Age")) } //对Age进行分类 def categorizeAge(df: Dataset[Row]): Dataset[Row] = { val ageBucketBorders = 0.0 +: (10.0 to 60.0 by 5.0).toArray :+ 150.0 val ageBucketer = new Bucketizer().setSplits(ageBucketBorders).setInputCol("Age").setOutputCol("Age_categorized") ageBucketer.transform(df).drop("Pclass_Title_key") } //将SibSp和Parch相加,得出同行人数 def createFellow(df: Dataset[Row]): Dataset[Row] = { df.withColumn("fellow", $"SibSp" + $"Parch") } //fellow_type, 对fellow进行分类。此处其实可以留到pipeline部分一次性完成。 def categorizeFellow(df: Dataset[Row]): Dataset[Row] = { df.withColumn("fellow_type", when($"fellow" === 0, "Alone") .when($"fellow" <= 3, "Small") .otherwise("Large")) } def extractFName(df: Dataset[Row]): Dataset[Row] = { //检查df是否有Survived和fellow列 if (!df.columns.contains("Survived") || !df.columns.contains("fellow")){ throw new IllegalArgumentException( """ |Check if the argument is a training set or if this training set contains column named \"fellow\" """.stripMargin) } //FName,提取家庭名称。例如:"Johnston, Miss. Catherine Helen ""Carrie""" 提取出Johnston // 由于spark的读取csv时,如果有引号,读取就会出现多余的引号,所以除了split逗号,还要再split一次引号。 val medDF = df .withColumn("FArray", split($"Name", ",")) .withColumn("FName", expr("FArray[0]")) .withColumn("FArray", split($"FName", "\"")) .withColumn("FName", $"FArray"(size($"FArray").minus(1))) //family_type,分为三类,第一类是60岁以下女性遇难的家庭,第二类是18岁以上男性存活的家庭,第三类其他。 val femaleDiedFamily_filter = $"Sex" === "female" and $"Age" < 60 and $"Survived" === 0 and $"fellow" > 0 val maleSurvivedFamily_filter = $"Sex" === "male" and $"Age" >= 18 and $"Survived" === 1 and $"fellow" > 1 val resDF = medDF.withColumn("family_type", when(femaleDiedFamily_filter, 1) .when(maleSurvivedFamily_filter, 2).otherwise(0)) //familyTable,家庭分类名单,用于后续test集的转化。此处用${FName}_${family_type}的形式保存。 resDF.filter($"family_type".isin(1,2)) .select(concat($"FName", lit("_"), $"family_type")) .dropDuplicates() .write.format("text").mode("overwrite").save("familyTable") //如果需要直接收集成Map的话,可用下面代码。 // 此代码先利用mapPartitions对各分块的数据进行聚合,降低直接调用count而使driver挂掉的风险。 //另外新建一个默认Set是为了防止某个partition并没有数据的情况(出现概率可能比较少), // 从而使得Set的类型变为Set[_>:Tuple]而不能直接flatten // val familyMap = df10 // .filter($"family_type" === 1 || $"family_type" === 2) // .select("FName", "family_type") // .rdd // .mapPartitions{iter => { // if (!iter.isEmpty) { // Iterator(iter.map(row => (row.getString(0), row.getInt(1))).toSet)} // else Iterator(Set(("defualt", 9)))} // } // .collect() // .flatten // .toMap resDF } //Fare。首先去掉缺失的(test集合中有一个,如果量多的话,也可以像Age那样通过头衔,年龄等因数来推断) //然后对Fare进行分类 def transFare(df: Dataset[Row]): Dataset[Row] = { val medDF = df.na.drop("any", Seq("Fare")) val fareBucketer = new QuantileDiscretizer() .setInputCol("Fare") .setOutputCol("Fare_categorized") .setNumBuckets(4) fareBucketer.fit(medDF).transform(medDF) } def index_onehot(df: Dataset[Row]): Tuple2[Dataset[Row], Array[String]] = { val stringCols = Array("Sex","fellow_type", "Embarked", "Cabin", "Ticket", "Title") val subOneHotCols = stringCols.map(cname => s"${cname}_index") val index_transformers: Array[org.apache.spark.ml.PipelineStage] = stringCols.map( cname => new StringIndexer() .setInputCol(cname) .setOutputCol(s"${cname}_index") .setHandleInvalid("skip") ) val oneHotCols = subOneHotCols ++ Array("Pclass", "Age_categorized", "Fare_categorized", "family_type") val vectorCols = oneHotCols.map(cname => s"${cname}_encoded") val encode_transformers: Array[org.apache.spark.ml.PipelineStage] = oneHotCols.map( cname => new OneHotEncoder() .setInputCol(cname) .setOutputCol(s"${cname}_encoded") ) val pipelineStage = index_transformers ++ encode_transformers val index_onehot_pipeline = new Pipeline().setStages(pipelineStage) val index_onehot_pipelineModel = index_onehot_pipeline.fit(df) val resDF = index_onehot_pipelineModel.transform(df).drop(stringCols:_*).drop(subOneHotCols:_*) println(resDF.columns.size) (resDF, vectorCols) } def trainData(df: Dataset[Row], vectorCols: Array[String]): TrainValidationSplitModel = { //separate and model pipeline,包含划分label和features,机器学习模型的pipeline val vectorAssembler = new VectorAssembler() .setInputCols(vectorCols) .setOutputCol("features") val gbtc = new GBTClassifier() .setLabelCol("Survived") .setFeaturesCol("features") .setPredictionCol("prediction") val pipeline = new Pipeline().setStages(Array(vectorAssembler, gbtc)) val paramGrid = new ParamGridBuilder() .addGrid(gbtc.stepSize, Seq(0.1)) .addGrid(gbtc.maxDepth, Seq(5)) .addGrid(gbtc.maxIter, Seq(20)) .build() val multiclassEval = new MulticlassClassificationEvaluator() .setLabelCol("Survived") .setPredictionCol("prediction") .setMetricName("accuracy") val tvs = new TrainValidationSplit() .setTrainRatio(0.75) .setEstimatorParamMaps(paramGrid) .setEstimator(pipeline) .setEvaluator(multiclassEval) tvs.fit(df) }}