Một số bài viết trước trên blog VietnamLab đã giới thiệu về Apache Spark, một framework rất mạnh phục vụ cho việc tính toán phân tán, đồng thời cũng hỗ trợ rất nhiều các thuật toán Machine Learning. PySpark là giao diện Python hỗ trợ việc viết chương trình Spark sử dụng ngôn ngữ Python (dễ học và dễ dùng hơn nhiều so với Scala). Tuy nhiên việc viết Unit Test tưởng chừng như rất đơn giản đối với các chương trình Python thì lại khá phức tạp và rắc rối đối với các chương trình PySpark (do bản chất phân tán của Spark, đồng thời cũng do PySpark có khá nhiều dependencies phức tạp). Bài viết này sẽ trình bày cách để thiết lập và viết unit test cho các chương trình PySpark.
Unit Test với Python
Việc viết Unit Test trong Python khá đơn giản, cách thông dụng nhất là sử dụng package unittest
có sẵn trong standard library. Một ví dụ vô cùng đơn giản:
import unittest
class SimpleTest(unittest.TestCase):
def test_simple_function(self):
self.assertEqual(len([1, 2, 3, 4]), 4)
if __name__ == '__main__':
unittest.main()
Và ta sẽ chạy test với câu lệnh như sau:
$ python test_pyspark.py SimpleTest.test_simple_function
.
----------------------------------------------------------------------
Ran 1 test in 0.000s
OK
Tuy nhiên với chương trình PySpark thì mọi thứ không đơn giản như vậy. Do Spark thường được cài ở vị trí riêng biệt so với các thư viện của Python nên cần phải thiết lập PYTHON_PATH
để có thể chương trình có thể sử dụng được các module của PySpark. Ngoài ra, việc khởi tạo SparkContext
(bắt buộc đối với chương trình Spark) cho mỗi test case cũng cần có giải pháp phù hợp. Thông thường, khi chạy chương trình Spark với câu lệnh spark-submit
thì ta sẽ không cần để ý đến những vấn đề ở trên.
Viết Unit Test cho Spark
Ta có thể tự mình viết code để giải quyết các vấn đề ở trên, tuy nhiên Spark đã hỗ trợ sẵn 2 base class phục vụ cho việc viết test case với unittest
(dù không có tài liệu chính thức):
- pyspark.tests.PySparkTestCase
class PySparkTestCase(unittest.TestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
self.sc = SparkContext('local[4]', class_name)
def tearDown(self):
self.sc.stop()
sys.path = self._old_sys_path
- pyspark.tests.ReusedPySparkTestCase
class ReusedPySparkTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.sc = SparkContext('local[4]', cls.__name__)
@classmethod
def tearDownClass(cls):
cls.sc.stop()
Sự khác biệt giữa hai class này là với ReusedPySparkTestCase
2 class method setUpClass
và tearDownClass
của unittest.TestCase
được override và phục vụ cho việc khởi tạo và huỷ bỏ SparkContext
, còn với PySparkTestCase
thì là hai instance method setUp
và tearDown
. Do vậy khi sử dụng ReusedPySparkTestCase
, tất cả các test case trong class sẽ chỉ dùng chung một SparkContext
, còn với PySparkTestCase
thì mỗi test case sẽ có một SparkContext
riêng nên việc chạy test sẽ tốn thời gian hơn. Trong hầu hết các trường hợp, ta sẽ sử dụng ReusedPySparkTestCase
.
Một ví dụ đơn giản sử dụng ReusedPySparkTestCase
:
import unittest
from pyspark.tests import ReusedPySparkTestCase
class SimpleSparkTest(ReusedPySparkTestCase):
def test_simple_function(self):
rdd = self.sc.parallelize([1, 2, 3, 4])
self.assertEqual(rdd.collect(), [1, 2, 3, 4])
if __name__ == '__main__':
unittest.main()
Tuy nhiên, khi chạy test case ở trên sẽ xảy ra lỗi thông báo là module pyspark
không tồn tại. Ta cần thêm đường dẫn các thư viện PySpark vào PYTHON_PATH
của chương trình test bằng cách thêm đoạn code như sau trước khi import bất kì module của Spark (đường dẫn và version sẽ tuỳ thuộc thiết lập cài đặt Spark)
import os
import sys
import unittest
spark_home = os.environ['SPARK_HOME']
sys.path.insert(1, os.path.join(spark_home, 'python'));
sys.path.insert(1, os.path.join(spark_home, 'python', 'lib/pyspark.zip'))
sys.path.insert(1, os.path.join(spark_home, 'python', 'lib/py4j-0.10.4-src.zip'))
from pyspark.tests import ReusedPySparkTestCase
class SimpleSparkTest(ReusedPySparkTestCase):
...
SPARK_HOME là biến môi trường lưu đường dẫn đến thư mục cài đặt Spark. Ta có thể thiết lập biến này trong file .bashrc
hoặc thủ công thiết lập trước mỗi lần chạy test. Giờ test case đã có thể chạy ngon lành:
$ export SPARK_HOME=/usr/local/Cellar/apache-spark/2.2.0/libexec
$ python test_pyspark.py SimpleSparkTest.test_simple_function
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
17/07/24 02:28:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
.
----------------------------------------------------------------------
Ran 1 test in 5.258s
OK
Kết luận
Dù hơi mất thời gian để viết, nhưng unit test sẽ khiến chương trình của bạn trở nên an toàn và dễ maintain hơn rất nhiều. Không những thế, việc thiết kế kiến trúc code sao cho nó có thể test được cũng giúp nâng cao chất lượng code. Hi vọng là bài viết này phần nào sẽ giúp bạn đọc có thể dễ dàng viết test, nâng cao chất lượng chương trình PySpark của mình.