Spark SQL custom aggregate functions

1. Extend the UserDefinedAggregateFunction class: many input rows aggregate to one output value.

package sparkRdd_practice
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
  * @Author 黄仁议<613024710@qq.com>
  * @Version V1.0
  * @Since 1.0
  * @Date 2019/6/4 17:39
  * @ClassName UDAF1
  */
// Weakly typed custom aggregate function, used with DataFrame
// Many inputs map to one output: computes the average age
class UDAF1 extends UserDefinedAggregateFunction{
  // Schema of the input arguments
  override def inputSchema: StructType = {
    StructType(List(StructField("age", IntegerType, nullable = true)))
  }
  // Schema of the intermediate buffer produced during per-partition aggregation,
  // e.g. (age1 + age2, 1 + 1)
  override def bufferSchema: StructType = {
    StructType(StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
    // equivalent alternatives:
    // StructType(List(StructField("sum", IntegerType), StructField("count", IntegerType)))
    // new StructType().add("sum", IntegerType).add("count", IntegerType)
  }

  // Data type of the final aggregation result
  override def dataType: DataType = DoubleType

  // Whether the same inputs always produce the same output
  override def deterministic: Boolean = true
  // Initialize the intermediate buffer
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    // running sum of ages
    buffer(0) = 0
    // running count of people
    buffer(1) = 0
  }
  }
  // Fold each row within a partition into the buffer, skipping null ages
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getInt(0) + input.getInt(0)
      buffer(1) = buffer.getInt(1) + 1
    }
  }
  // Merge the buffers produced by different partitions
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(1) = buffer1.getInt(1) + buffer2.getInt(1)
  }
  // Compute the final result from the merged buffer
  override def evaluate(buffer: Row): Double = {
    buffer.getInt(0) / buffer.getInt(1).toDouble
  }

}

object UDAFDemo {

  def main(args: Array[String]): Unit = {
    val sparkSession = SparkSession.builder().appName("UDAFDemo").master("local[2]").getOrCreate()
    val frame: DataFrame = sparkSession.read.json("jsonres")
    // Register the custom aggregate function under the name myAvg
    sparkSession.udf.register("myAvg", new UDAF1)
    frame.createOrReplaceTempView("t_people")
    sparkSession.sql("SELECT myAvg(age) AS ageavg FROM t_people").show()
    sparkSession.stop()
  }

}
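The jsonres path above stands in for a directory of JSON Lines records carrying an age field; it is kept as-is from the original. An illustrative input (names and values are made up for this example):

{"name":"zhangsan","age":20}
{"name":"lisi","age":30}

With these two records, myAvg(age) returns (20 + 30) / 2 = 25.0.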

2. Extend the Aggregator class: the strongly typed API, used with Dataset (again many input records, one output value).

package sparkRdd_practice
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

/**
  * @Author 黄仁议<613024710@qq.com>
  * @Version V1.0
  * @Since 1.0
  * @Date 2019/6/4 17:42
  * @ClassName UDAF2
  */
// Type parameters of Aggregator: input record type, intermediate buffer type, final result type
// Intermediate buffer: running sum of ages and count of records
case class Average(var sum: Int, var count: Int)
// Input record type; not defined in the original, assumed here to carry a name and an Int age
case class People(name: String, age: Int)
class UDAF2 extends Aggregator[People,Average,Double]{
  // Zero value for the intermediate buffer
  override def zero: Average = Average(0, 0)
  // Per-partition reduce: fold each record into the buffer
  override def reduce(b: Average, a: People): Average = {
    b.sum = b.sum + a.age
    b.count = b.count + 1
    b
  }

  // Merge buffers from different partitions
  override def merge(b1: Average, b2: Average): Average = {
    b1.sum = b1.sum + b2.sum
    b1.count = b1.count + b2.count
    b1
  }
  // Produce the final result from the merged buffer
  override def finish(reduction: Average): Double = reduction.sum / reduction.count.toDouble
  // Encoders for the intermediate buffer and the output value
  override def bufferEncoder: Encoder[Average] = Encoders.product

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble

}
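A strongly typed Aggregator is used with a Dataset rather than registered under a SQL name: toColumn converts it into a TypedColumn that Dataset.select can consume. A minimal driver sketch follows; the in-memory sample data is illustrative and not from the original, which could equally read a Dataset[People] from the jsonres input.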

object UDAFDemo2 {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("UDAFDemo2").master("local[2]").getOrCreate()
    import spark.implicits._
    // illustrative sample data; any Dataset[People] would do
    val people = Seq(People("zhangsan", 20), People("lisi", 30)).toDS()
    // toColumn turns the Aggregator into a TypedColumn usable in select
    people.select(new UDAF2().toColumn.name("ageavg")).show()
    spark.stop()
  }

}
