// Dataset filter overload taking a Column predicate, e.g. new Column("id").gt(3).
public Dataset<T> where(Column condition)
// Dataset filter overload taking the predicate as a SQL expression string, e.g. "id > 3 and age = 16".
public Dataset<T> where(String conditionExpr)
第一种通过 Column 对象构造查询条件，第二种直接传入字符串形式的条件表达式（例如 "id > 3 and age = 16"）。
// Conditional queries: three equivalent ways to express "id > 3 AND age = 16".
// Predicate: id > 3
Column idGreaterThan3 = new Column("id").gt(3);
// Predicate: age = 16 -- compared against a numeric literal, matching the SQL-string form below,
// instead of the string "16" (which would rely on Spark's implicit string-to-number cast).
Column ageEquals16 = new Column("age").equalTo(16);
// Combined predicate: id > 3 AND age = 16
Column idAndAge = idGreaterThan3.and(ageEquals16);
// 1) Write the condition directly as a SQL expression string
json.select("id", "name", "age", "phone")
        .where("id > 3 and age = 16")
        .orderBy(new Column("id").desc())
        .show();
// 2) Chain several where() calls; together they produce id > 3 AND age = 16
json.select("id", "name", "age", "phone")
        .where(idGreaterThan3)
        .where(ageEquals16)
        .orderBy(new Column("id").desc())
        .show();
// 3) Pass one Column combined with Column.and(): id > 3 AND age = 16
json.select("id", "name", "age", "phone")
        .where(idAndAge)
        .orderBy(new Column("id").desc())
        .show();
/**
 * Untyped user-defined aggregate function that computes the arithmetic mean of a single
 * {@code long} input column. Null input values are skipped. The result is a {@code double},
 * or SQL NULL when no non-null input value was aggregated.
 *
 * <p>NOTE(review): {@code UserDefinedAggregateFunction} is deprecated since Spark 3.0 in favour
 * of the typed {@code Aggregator} API; it is kept here to demonstrate the untyped API.
 */
public class MyAverage extends UserDefinedAggregateFunction {
    /** Input schema: one nullable long column named "inputColumn". */
    private final StructType inputSchema;
    /** Buffer schema: slot 0 = running "sum", slot 1 = "count" of non-null inputs. */
    private final StructType bufferSchema;

    public MyAverage() {
        inputSchema = DataTypes.createStructType(new StructField[] {
                DataTypes.createStructField("inputColumn", DataTypes.LongType, true)
        });
        bufferSchema = DataTypes.createStructType(new StructField[] {
                DataTypes.createStructField("sum", DataTypes.LongType, true),
                DataTypes.createStructField("count", DataTypes.LongType, true)
        });
    }

    /**
     * Data types of the input arguments of this aggregate function, i.e. the type of the
     * Dataset column the function is applied to.
     */
    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    /** Structure of the aggregation buffer holding the running sum and count. */
    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    /** Data type of the value returned by {@link #evaluate(Row)}. */
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    /** This function always returns the same output for identical input. */
    @Override
    public boolean deterministic() {
        return true;
    }

    /**
     * Initializes the given aggregation buffer to sum = 0 and count = 0. The buffer is a
     * {@code Row} that additionally allows its values to be updated in place.
     */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L);
        buffer.update(1, 0L);
    }

    /** Adds one input row to the buffer; rows whose input value is null are ignored. */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            long updatedSum = buffer.getLong(0) + input.getLong(0);
            long updatedCount = buffer.getLong(1) + 1;
            buffer.update(0, updatedSum);
            buffer.update(1, updatedCount);
        }
    }

    /** Merges two partial aggregation buffers and stores the totals back into {@code buffer1}. */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
        long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
        buffer1.update(0, mergedSum);
        buffer1.update(1, mergedCount);
    }

    /**
     * Computes the final average from the buffer.
     *
     * @return sum / count as a {@code Double}, or {@code null} if no non-null input was seen
     */
    @Override
    public Object evaluate(Row buffer) {
        long count = buffer.getLong(1);
        if (count == 0L) {
            // Empty or all-null input: return SQL NULL (like built-in AVG) instead of 0.0/0 = NaN.
            return null;
        }
        return ((double) buffer.getLong(0)) / count;
    }
}
运行代码:
/**
 * Demo entry point: registers the untyped {@code MyAverage} aggregate function and uses it from
 * Spark SQL to compute the average of the "age" column of a JSON file.
 *
 * @param args optional; {@code args[0]} overrides the default input JSON path
 */
public static void main(String[] args) {
    String path = (args != null && args.length > 0)
            ? args[0]
            : "D:\\sparksqlfile\\jsondata\\student5.json";
    // spark.testing.memory = 471859200 bytes (450 MiB); presumably set so a small local JVM
    // passes Spark's minimum-memory check -- TODO confirm it is still needed.
    SparkSession sparkSession = SparkSession.builder()
            .master("local")
            .appName("XXXXXXXXXX")
            .config("spark.testing.memory", 471859200)
            .getOrCreate();
    try {
        // Read the input file
        Dataset<Row> df = sparkSession.read().json(path);
        // Register the user-defined aggregate function under the SQL name "myAverage"
        sparkSession.udf().register("myAverage", new MyAverage());
        // Expose the data to SQL and show the raw rows
        df.createOrReplaceTempView("student");
        df.show();
        // Average the "age" column with the custom aggregate function
        Dataset<Row> result =
                sparkSession.sql("SELECT myAverage(age) as average_age FROM student");
        result.show();
    } finally {
        // Release the session's resources even if reading or querying fails
        sparkSession.stop();
    }
}
/**
 * JavaBean describing one input record of the typed aggregate function.
 * Exposes the properties name, age, sex, institute and phone through standard getters and setters.
 */
public class Employee implements Serializable {
    private String name;
    private long age;
    private String sex;
    private String institute;
    private String phone;

    /** Creates an empty bean; all reference properties start as {@code null}, age as 0. */
    public Employee() {
    }

    /** Creates a fully populated bean. */
    public Employee(String name, long age, String sex, String institute, String phone) {
        this.phone = phone;
        this.institute = institute;
        this.sex = sex;
        this.age = age;
        this.name = name;
    }

    // ---- getters ----

    public String getName() {
        return this.name;
    }

    public long getAge() {
        return this.age;
    }

    public String getSex() {
        return this.sex;
    }

    public String getInstitute() {
        return this.institute;
    }

    public String getPhone() {
        return this.phone;
    }

    // ---- setters ----

    public void setName(String name) {
        this.name = name;
    }

    public void setAge(long age) {
        this.age = age;
    }

    public void setSex(String sex) {
        this.sex = sex;
    }

    public void setInstitute(String institute) {
        this.institute = institute;
    }

    public void setPhone(String phone) {
        this.phone = phone;
    }

    /** Renders every property, with string values wrapped in single quotes. */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("Employee{");
        text.append("name='").append(this.name).append('\'');
        text.append(", age=").append(this.age);
        text.append(", sex='").append(this.sex).append('\'');
        text.append(", institute='").append(this.institute).append('\'');
        text.append(", phone='").append(this.phone).append('\'');
        return text.append('}').toString();
    }
}
定义聚合函数缓冲器:
// 定义Average样例类规范buffer聚合缓冲器的数据类型
public class Average implements Serializable {
private long sum;
private long count;
public Average() {
}
public Average(long sum, long count) {
this.sum = sum;
this.count = count;
}
public long getSum() {
return sum;
}
public void setSum(long sum) {
this.sum = sum;
}
public long getCount() {
return count;
}
public void setCount(long count) {
this.count = count;
}
@Override
public String toString() {
return "Average{" +
"sum=" + sum +
", count=" + count +
'}';
}
}
强类型用户自定义聚合函数（Aggregator）代码：
/**
 * Strongly-typed user-defined aggregate function that averages the {@code age} of
 * {@link Employee} records. A typed aggregate must extend {@code Aggregator} and supply the
 * generic types of its input ({@code Employee}), its buffer ({@code Average}) and its result
 * ({@code Double}).
 */
public class MyAverage2 extends Aggregator<Employee, Average, Double> {

    /**
     * Returns the zero value of the aggregation (sum = 0, count = 0). It satisfies
     * {@code b + zero = b} for any buffer b.
     */
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    /**
     * Folds one input record into the buffer: adds its age to the sum and increments the count.
     * Like the untyped variant's {@code update}, the buffer is modified in place and returned
     * rather than allocating a new object.
     */
    @Override
    public Average reduce(Average buffer, Employee employee) {
        long newSum = buffer.getSum() + employee.getAge();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    /** Merges two intermediate buffers into {@code b1} and returns it. */
    @Override
    public Average merge(Average b1, Average b2) {
        long mergeSum = b1.getSum() + b2.getSum();
        long mergeCount = b1.getCount() + b2.getCount();
        b1.setSum(mergeSum);
        b1.setCount(mergeCount);
        return b1;
    }

    /**
     * Transforms the final reduced buffer (after all reduce/merge steps) into the result.
     *
     * @return the average age, or {@code null} if no records were aggregated
     */
    @Override
    public Double finish(Average reduction) {
        long count = reduction.getCount();
        if (count == 0L) {
            // No input records: return null instead of 0.0/0 = NaN (consistent with SQL AVG).
            return null;
        }
        // The former debug println was removed: it ran on executors for every finished group.
        return ((double) reduction.getSum()) / count;
    }

    /** Specifies the Encoder for the intermediate buffer type. */
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    /** Specifies the Encoder for the final output value type. */
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}