博客内容Blog Content

FlinkML实时流数据机器学习 FlinkML Real-time Stream Data Machine Learning

BlogType : Big Data releaseTime : 2024-09-20 16:00:00

通过FlinkML可实现”实时大数据+机器学习“的结合,本节给出一个实时流数据机器学习的例子 FlinkML enables the combination of "real-time big data + machine learning." This section provides an example of real-time stream data machine learning.

概述 Introduction

FlinkML 是 Apache Flink 的机器学习库,其设计目的是在 Flink 上进行分布式的实时机器学习,主要处理大规模流数据。FlinkML 可以充分利用 Flink 的分布式计算能力,处理大规模数据流,并在数据流中实时训练和应用机器学习模型。而 Python 机器学习库则更适合静态数据的批量处理和模型开发。如果你的应用场景需要实时处理大规模数据流,FlinkML 是一个不错的选择。而如果你专注于模型开发、实验和静态数据分析,Python 机器学习库可能更适合你的需求。

FlinkML is Apache Flink's machine learning library, designed for distributed real-time machine learning on Flink, primarily for handling large-scale streaming data. FlinkML can fully leverage Flink’s distributed computing capabilities to process large data streams, and to train and apply machine learning models in real time within the data stream. While Python machine learning libraries are more suited for batch processing of static data and model development. If your use case involves real-time processing of large-scale data streams, FlinkML is a good choice. However, if you are focused on model development, experimentation, and static data analysis, Python machine learning libraries may be more appropriate for your needs.



Java示例代码 Java Code Example

以下为一段K-means实时分类的示例代码,在线 K-Means 扩展了 K-Means 的功能,支持根据无限的训练数据流持续训练 K-Means 模型,使用“迷你批量” K-Means 规则进行更新,包含遗忘机制(即衰减)并推广。

The following is an example code for real-time K-means classification. Online K-Means extends the function of K-Means, supporting to train a K-Means model continuously according to an unbounded stream of train data, and it makes updates with the “mini-batch” K-Means rule, generalized to incorporate forgetfulness (i.e. decay).

package flink.clustering;

import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Queue;

public class OnlineKmeansExample {
    public static void main(String[] args) throws Exception {
        // init env
        Configuration config = new Configuration();
        config.setBoolean("rest.enable-web-submission", true);
        config.setInteger("rest.port", 8081);
        StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironmentWithWebUI(config).setParallelism(1);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

        String featuresCol = "features";
        String predictionCol = "prediction";
        DataStream<DenseVector> inputStream = env
                .addSource(new SourceFunction<DenseVector>() {
                    private Queue<DenseVector> data = new ArrayDeque<>(Arrays.asList(
                            Vectors.dense(0.0),
                            Vectors.dense(0.1),
                            Vectors.dense(0.3),
                            Vectors.dense(9.1),
                            Vectors.dense(9.2),
                            Vectors.dense(9.6),
                            Vectors.dense(100.1),
                            Vectors.dense(100.2),
                            Vectors.dense(0.5),
                            Vectors.dense(9.5),
                            Vectors.dense(30)
                    ));

                    @Override
                    public void run(SourceContext<DenseVector> ctx) throws Exception {
                        while (!data.isEmpty()) {
                            ctx.collect(data.poll());
                            Thread.sleep(1000);
                        }
                    }

                    @Override
                    public void cancel() {
                    }
                });

        // Convert data from DataStream to Table, as Flink ML uses Table API.
        Table input = tEnv.fromDataStream(inputStream).as(featuresCol);
        OnlineKMeans onlineKMeans =
                new OnlineKMeans()
                        .setFeaturesCol(featuresCol)
                        .setPredictionCol(predictionCol)
                        .setGlobalBatchSize(3)
                        .setInitialModelData(
                                KMeansModelData.generateRandomModelData(tEnv, 2, 1, 0.0, 0)
                        )
                        .setDecayFactor(1.0);

        // Trains the online K-means Model.
        OnlineKMeansModel onlineModel = onlineKMeans.fit(input);

        // Use the K-means Model for predictions.
        Table output = onlineModel.transform(input)[0];

        // Extracts and displays prediction results in a separate stream.
        DataStream<String> predictionStream = tEnv.toDataStream(output)
                .map(row -> {
                    DenseVector vector = (DenseVector) row.getField(featuresCol);
                    int clusterId = (Integer) row.getField(predictionCol);
                    return "Vector: " + vector + "\tCluster ID: " + clusterId;
                });
        predictionStream.printToErr();

        // Show centroids
        DataStream<String> centroidStream = tEnv.toDataStream(onlineModel.getModelData()[0])
                .map(row -> {
                    DenseVector[] centroids = (DenseVector[]) row.getField("centroids");
                    return "Centroids: " + Arrays.toString(centroids);
                });
        centroidStream.printToErr();

        env.execute("FlinkML Example");
    }
}


这里,我们指定了计算批次=3、聚类参数K=2以及衰减因子=1.0,对应结果和质心更新情况如下

Here, we specified the batch size = 3, the number of clusters K = 2, and the decay factor = 1.0. The corresponding results and centroid updates are as follows:

image.png


从结果可以看出,在线 K-Means 的质心会根据数据流的输入不断调整,体现了流处理的特点。模型不会等所有数据都到达后再做全局优化,而是根据当前的输入数据做出局部的调整

From the results, we can see that the centroids in online K-Means continuously adjust based on the incoming data stream, reflecting the characteristics of stream processing. The model does not wait for all data to arrive before performing a global optimization, but instead makes local adjustments based on the current input data,


尽管这可能不是全局最优的聚类结果,但它符合流处理的逻辑,即模型会逐渐学习和调整,而不是基于全局静态数据做出决策。

Although this may not result in a globally optimal clustering, it aligns with the logic of stream processing, where the model gradually learns and adjusts instead of making decisions based on globally static data.


接下来我们把衰减因子改成0.0看看结果如何,这样旧数据权重更低(相当于会更快遗忘新数据),数据会更迅速向质心靠拢:

Next, we will set the decay factor to 0.0 to see the results. This way, the weight of old data is lower (essentially, the model will forget old data more quickly), and the data will converge towards the centroids more rapidly:

image.png


我们看到,最后一个数据的聚类结果和之前不同,因为两者的衰减因子不同导致了质心变化的快慢不同,最终影响了聚类结果

We can see that the clustering result of the last data point is different from before. This is because the decay factor is different, which leads to varying speeds of centroid changes, ultimately affecting the clustering result.


此外,计算的批次也是个重要因素,因为它决定了到达什么批次时更新质心

In addition, the batch size is also an important factor, as it determines at which batch the centroids are updated.