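自定义 UDAF 需要继承 org.apache.spark.sql.expressions.UserDefinedAggregateFunction,并实现其中的 8 个方法(inputSchema、bufferSchema、dataType、deterministic、initialize、update、merge、evaluate)。注:从 Spark 3.0 起该类已被标记为废弃,新代码推荐使用 Aggregator 并通过 functions.udaf 注册。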
UDAF 写起来比较麻烦,下面列出一个我之前写的取众数的聚合函数。我们在做聚合统计时,结果常常会受到个别脏数据的影响。
举个栗子:
对 app 日志做聚合时,有 id 与 ip 两个字段,原则上一个 id 只对应一个 ip。但如果多条数据中有一条 ip 是错误的或为空的,按 id、ip 分组时就可能聚合出两条数据;如果改用 max、min 对 ip 进行聚合,也不太合理。这时可以进行投票,取出现次数最多的 ip 作为结果,从而聚合后每个 id 只对应一个设备。
废话少说,上代码:
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Custom aggregate function: mode (the most frequent value in a column).
 *
 * Every non-null input string is counted in a `Map[String, Int]` aggregation buffer, and
 * `evaluate` returns the key with the highest count. Ties are broken by choosing the
 * lexicographically smallest key, so the result does not depend on partitioning, merge order
 * or map iteration order. Null inputs are skipped; a group with no non-null input yields null.
 */
class UDAFGetMode extends UserDefinedAggregateFunction {

  /** One nullable string input column. */
  override def inputSchema: StructType = {
    StructType(StructField("inputStr", StringType, true) :: Nil)
  }

  /** The buffer is a single map from input value to its occurrence count. */
  override def bufferSchema: StructType = {
    StructType(StructField("bufferMap", MapType(keyType = StringType, valueType = IntegerType), true) :: Nil)
  }

  override def dataType: DataType = StringType

  // Ties are broken by key (see `evaluate`), so the same multiset of inputs always yields the
  // same output regardless of how Spark partitions and merges the buffers.
  override def deterministic: Boolean = true

  /** Starts every group with an empty count map. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = Map.empty[String, Int]
  }

  /**
   * Increments the count for the input value (inserting it with count 1 if unseen).
   * Null inputs are ignored: Spark map keys must not be null, and a missing value should not
   * take part in the vote anyway.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      val key = input.getAs[String](0)
      val counts = buffer.getAs[Map[String, Int]](0)
      buffer.update(0, counts + (key -> (counts.getOrElse(key, 0) + 1)))
    }
  }

  /** Merges two partial count maps, summing the counts of keys present in both. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val merged = buffer2.getAs[Map[String, Int]](0).foldLeft(buffer1.getAs[Map[String, Int]](0)) {
      case (acc, (key, count)) => acc + (key -> (acc.getOrElse(key, 0) + count))
    }
    buffer1.update(0, merged)
  }

  /**
   * Returns the most frequent key. On equal counts the lexicographically smallest key wins.
   * Returns null when the group contained no non-null values.
   */
  override def evaluate(buffer: Row): Any = {
    val counts = buffer.getAs[Map[String, Int]](0)
    if (counts.isEmpty) null
    // Minimising (-count, key) selects the highest count, then the smallest key.
    else counts.minBy { case (key, count) => (-count, key) }._1
  }
}
测试类:
/** Local driver that registers `UDAFGetMode` as `get_mode` and runs it over a small sample. */
object UDAFTest {
  def main(args: Array[String]): Unit = {
    // SparkSession is not imported at file level in this listing; import it locally so the
    // object compiles.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local")
      .appName(this.getClass.getSimpleName)
      .getOrCreate()
    try {
      spark.udf.register("get_mode", new UDAFGetMode)
      import spark.implicits._
      val df = Seq(
        (1, "10.10.1.1", "start"),
        (1, "10.10.1.1", "search"),
        (2, "123.123.123.1", "search"),
        (1, "10.10.1.0", "stop"),
        (2, "123.123.123.1", "start")
      ).toDF("id", "ip", "action")
      // createOrReplaceTempView does not fail if "tb" already exists in this session.
      df.createOrReplaceTempView("tb")
      spark.sql("select id,get_mode(ip) as u_ip,count(*) as cnt from tb group by id").show()
    } finally {
      // Release the local Spark context even if the query fails.
      spark.stop()
    }
  }
}