python – 如何在PySpark中处理数据之前在所有Spark工作程序上运行函数?

我正在使用YARN在集群中运行Spark Streaming任务.集群中的每个节点都运行多个spark worker.在流式传输开始之前,我想在群集中所有节点上的所有工作程序上执行“设置”功能.

流式传输任务将传入的邮件分类为垃圾邮件或非垃圾邮件,但在此之前,它需要将最新的预先训练的模型从HDFS下载到本地磁盘,如此伪代码示例:

def fetch_models():
    if hadoop.version > local.version:
        hadoop.download()

我在SO上看过以下示例:

sc.parallelize().map(fetch_models)

但是在Spark 1.6中,parallelize()需要使用一些数据,就像我现在正在做的那种糟糕的解决方法:

sc.parallelize(range(1, 1000)).map(fetch_models)

为了确保该函数在所有工作程序上运行,我将范围设置为1000.我还不确切地知道在运行时集群中有多少个工作程序.

我已经阅读了编程文档并无情地搜索了,但我似乎无法找到任何方法实际上只向所有工作人员分发任何没有任何数据的东西.

完成此初始化阶段后,流式传输任务与往常一样,对来自Kafka的传入数据进行操作.

我使用模型的方法是运行类似这样的函数:

spark_partitions = config.get(ConfigKeys.SPARK_PARTITIONS)
stream.union(*create_kafka_streams())\
    .repartition(spark_partitions)\
    .foreachRDD(lambda rdd: rdd.foreachPartition(lambda partition: spam.on_partition(config, partition)))

从理论上讲,我可以在on_partition函数中检查模型是否是最新的,尽管在每个批次上执行此操作会非常浪费.我想在Spark开始从Kafka检索批次之前这样做,因为从HDFS下载可能需要几分钟……

更新:

要明确:这不是关于如何分发文件或如何加载它们的问题,而是关于如何在不对任何数据进行操作的情况下对所有工作程序运行任意方法.

澄清当前实际加载模型的含义:

def on_partition(config, partition):
    if not MyClassifier.is_loaded():
        MyClassifier.load_models(config)

    handle_partition(config, partition)

虽然MyClassifier是这样的:

class MyClassifier:
    clf = None

    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None

    @staticmethod
    def load_models(config):
        MyClassifier.clf = load_from_file(config)

静态方法,因为PySpark似乎无法使用非静态方法序列化类(类的状态与另一个worker的关系无关).在这里,我们只需要调用load_models()一次,并在将来的所有批次中调用MyClassifier.clf.对于每个批次来说,这是不应该做的事情,这是一次性的事情.与使用fetch_models()从HDFS下载文件相同.

解决方法:

如果您只想在工作机器之间分发文件,最简单的方法是使用SparkFiles机制:

some_path = ...  # local file, a file in DFS, an HTTP, HTTPS or FTP URI.
sc.addFile(some_path)

并使用SparkFiles.get和标准IO工具在worker上检索它:

from pyspark import SparkFiles

with open(SparkFiles.get(some_path)) as fw:
    ... # Do something

如果要确保实际加载模型,最简单的方法是加载模块导入.假设config可用于检索模型路径:

> model.py:

from pyspark import SparkFiles

config = ...
class MyClassifier:
    clf = None

    @staticmethod
    def is_loaded():
        return MyClassifier.clf is not None

    @staticmethod
    def load_models(config):
        path = SparkFiles.get(config.get("model_file"))
        MyClassifier.clf = load_from_file(path)

# Executed once per interpreter 
MyClassifier.load_models(config)  

> main.py:

from pyspark import SparkContext

config = ...

sc = SparkContext("local", "foo")

# Executed before StreamingContext starts
sc.addFile(config.get("model_file"))
sc.addPyFile("model.py")

import model

ssc = ...
stream = ...
stream.map(model.MyClassifier.do_something).pprint()

ssc.start()
ssc.awaitTermination()
上一篇:Linux下Python3.5使用pyqt5.11报错 ImportError: /usr/local/lib/python3.5/dist-packages/PyQt5/QtCore.so: undefined symbol: PySlice_AdjustIndices 解决方法


下一篇:python – spark-submit和pyspark有什么区别?