本文共 2728 字,大约阅读时间需要 9 分钟。
UDF 输入一行,输出一行
UDAF 输入多行,输出一行
UDTF 输入一样,输出多行
//导包import org.apache.spark.sql.SparkSession//编写代码// 1.实例SparkSession val spark = SparkSession.builder().master("local[*]").appName("udf").getOrCreate()// 2.根据sparkSession获取SparkContext val sc = spark.sparkContext//3.读取数据并输出 val datas = spark.read.textFile("./data/udf/udf.txt")//4.数据展示 datas.show()//5.编写UDF将小写变成大写 spark.udf.register("smallToBig", (str: String) => str.toUpperCase())//6.将RDD转换为DataFrame val dataFrame = datas.toDF()//7.注册表 dataFrame.createOrReplaceTempView("word")//8.使用自定义函数查询 并输出 spark.sql("select value, smallToBig(value) from word").show()
继承UserDefinedAggregateFunction方法重写说明
InputSchema: 输入数据的类型
bufferSchema: 产生中间结果的数据类型
dataType:最终返回的结果类型
dataeministic: 确保一致性,一般用true
initialize: 指定初始值
update:每有一条数据参与运算就更新一下中间结果(update相当于每一个分区中的运算)
merge:全局聚合(将每个分区的结果进行聚合)
evaluate: 计算最终的结果
//导包import org.apache.spark.sql.{Row, SparkSession}import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}//编写自定义UDAFclass MyUDAF extends UserDefinedAggregateFunction {// 输入的数据类型的schema override def inputSchema: StructType = { StructType(StructField("input", LongType) :: Nil) }//缓冲去数据类型schema 就是转换字后的数据schema override def bufferSchema: StructType = { StructType(StructField("sum", LongType) :: StructField("total", LongType) :: Nil) }// 返回值数据类型 override def dataType: DataType = { DoubleType }// 确定是否相同的输入会有相同的输出 override def deterministic: Boolean = true// 初始化内部数据结构 override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer(0) = 0L buffer(1) = 0L }// 更新数据内部结构,区内计算 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {// 所有的金额 buffer(0) = buffer.getLong(0) + input.getLong(0)// 一共多少条数据 buffer(1) = buffer.getLong(1) + 1 }// 来字不同分区数据进行合并,全局合并 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1) }// 计算输出数据值 override def evaluate(buffer: Row): Any = { buffer.getLong(0).toDouble / buffer.getLong(1) } }//编写测试代码//1.实例SparkSession val spark = SparkSession.builder().master("local[*]").appName("sql").getOrCreate()//2.根据SparkSession获取SparkContext 获取上下文对象 val sc = spark.sparkContext//3.使用SparkContext 读取数据 val dataRDD = spark.read.json("./data/udf/udaf.json")//4.注册表 dataRDD.createOrReplaceTempView("word")//5.注册 UDAF 函数 spark.udf.register("myavg", new MyUDAF)//6.使用自定义UDAF函数 spark.sql("select myavg(salary) from word").show()
转载地址:http://cokzi.baihongyu.com/