Spark fornece um StackOverflowError ao treinar usando o ALS

Ao tentar treinar um modelo de aprendizado de máquina usando o ALS no MLLib do Spark, continuei recebendo um StackoverflowError. Aqui está uma pequena amostra do rastreamento de pilha:

Traceback (most recent call last): File "/Users/user/Spark/imf.py", line 31, in  model = ALS.train(rdd, rank, numIterations) File "/usr/local/Cellar/apache-spark/1.3.1_1/libexec/python/pyspark/mllib/recommendation.py", line 140, in train lambda_, blocks, nonnegative, seed) File "/usr/local/Cellar/apache-spark/1.3.1_1/libexec/python/pyspark/mllib/common.py", line 120, in callMLlibFunc return callJavaFunc(sc, api, *args) File "/usr/local/Cellar/apache-spark/1.3.1_1/libexec/python/pyspark/mllib/common.py", line 113, in callJavaFunc return _java2py(sc, func(*args)) File "/usr/local/Cellar/apache-spark/1.3.1_1/libexec/python/lib/py4j-0.8.2.1-src.zip/py4j/java_gateway.py", line 538, in __call__ File "/usr/local/Cellar/apache-spark/1.3.1_1/libexec/python/lib/py4j-0.8.2.1-src.zip/py4j/protocol.py", line 300, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o35.trainALSModel. : org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 40.0 failed 1 times, most recent failure: Lost task 0.0 in stage 40.0 (TID 35, localhost): java.lang.StackOverflowError at java.io.ObjectInputStream$PeekInputStream.peek(ObjectInputStream.java:2296) at java.io.ObjectInputStream$BlockDataInputStream.peek(ObjectInputStream.java:2589) 

Esse erro também apareceria ao tentar executar .mean () para calcular o erro médio quadrático. Ele apareceu na versão 1.3.1_1 e na versão 1.4.1 do Spark. Eu estava usando o PySpark e aumentar a memory disponível não ajudou.

A solução foi adicionar pontos de verificação, o que impede que a recursion usada pelo codebase crie um estouro. Primeiro, crie um novo diretório para armazenar os pontos de verificação. Em seguida, faça com que seu SparkContext use esse diretório para o ponto de verificação. Aqui está o exemplo em Python:

 sc.setCheckpointDir('checkpoint/') 

Talvez você também precise adicionar pontos de verificação ao ALS, mas não consegui determinar se isso faz diferença. Para adicionar um ponto de verificação lá (provavelmente não necessário), basta fazer:

 ALS.checkpointInterval = 2