In Apache Spark, Accumulators are shared variables that can only be "added" to through an associative and commutative operation. They are primarily used to implement counters or sums across a distributed cluster.
While normal variables in Spark are copied to each worker node (and changes made on those workers are not sent back to the driver), Accumulators provide a way for tasks running on different executors to send updates back to a single central variable on the Driver.
### Key Characteristics
- **Write-Only for Executors:** Worker nodes (executors) can add to an accumulator, but they cannot read its value. This prevents complex synchronization issues.
- **Read-Only for the Driver:** Only the Driver program can read the final accumulated value, via the accumulator's `.value` attribute (see the minimal sketch after this list).
- **Commutative & Associative:** The operation must work regardless of order (e.g., $1 + 2 = 2 + 1$) because Spark tasks run in parallel and can complete at different times.
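A minimal sketch of that flow, assuming an existing `SparkContext` named `sc`:

```python
# Executors add to the accumulator; only the Driver reads it back.
counter = sc.accumulator(0)

# foreach runs on the executors; each task adds to its local copy of the accumulator
sc.parallelize([1, 2, 3, 4]).foreach(lambda x: counter.add(x))

# The Driver merges the per-task updates and reads the total
print(counter.value)  # 10
```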
### What are they used for?
Accumulators are most commonly used for side-effect tasks that don't change the main data flow:
- **Data Quality Metrics:** Counting how many "bad" or "corrupted" records were found during an ETL process without needing a separate `count()` action.
- **Debugging:** Tracking how many times a specific code path was executed across a massive dataset.
- **Performance Monitoring:** Measuring internal metrics, such as the number of empty lines processed in a file or the number of times a specific API was called.
### Important Limitations & Best Practices
Because Spark is a distributed system with lazy evaluation, Accumulators can behave unexpectedly if not used carefully:
- **Use in Actions vs. Transformations:**
  - In actions (e.g., `foreach`): Spark guarantees that each task's update to the accumulator is applied only once, even if a task is retried due to failure.
  - In transformations (e.g., `map`, `filter`): there is no guarantee. If Spark re-executes a stage because of a node failure or speculative execution, the accumulator may be updated more than once, leading to inflated results.
- **Performance:** While efficient, having thousands of tasks sending updates to a single Driver can create a bottleneck if the accumulator itself is a very complex object.
| Feature | Broadcast Variables | Accumulators |
|---|---|---|
| Direction | Driver $\rightarrow$ Executors | Executors $\rightarrow$ Driver |
| Purpose | Efficiently sharing large read-only data | Aggregating results/counters |
| Access | Read-only for workers | Write-only for workers |
In PySpark, using an accumulator involves three main steps: initializing it on the driver, updating it within a transformation or action on the executors, and reading the final value back on the driver.
Here is a practical example where we use an accumulator to count "malformed" or "invalid" records in a dataset without interrupting the main processing flow.
```python
from pyspark.sql import SparkSession

# Initialize Spark Session
spark = SparkSession.builder.appName("AccumulatorExample").getOrCreate()
sc = spark.sparkContext

# 1. Initialize the Accumulator (starting at 0)
invalid_record_count = sc.accumulator(0)

# Sample data: some numbers, some "None", and some empty strings
data = ["10", "20", "", "30", None, "40", "error"]

def process_data(x):
    global invalid_record_count
    try:
        # Try to convert to integer
        return int(x)
    except (ValueError, TypeError):
        # 2. Update the Accumulator if the data is "bad"
        invalid_record_count += 1
        return None

# Create RDD and apply the processing function
rdd = sc.parallelize(data)
processed_rdd = rdd.map(process_data)

# 3. Trigger an Action (Accumulators are only updated when an action runs)
# We'll collect the valid numbers (filtering out the Nones)
valid_numbers = processed_rdd.filter(lambda x: x is not None).collect()
print(f"Valid Numbers: {valid_numbers}")

# 4. Read the Accumulator value on the Driver
print(f"Total Invalid Records Found: {invalid_record_count.value}")
```
### Key Takeaways from this Example
- **Global Reference:** In the worker function `process_data`, we refer to the accumulator globally. Spark handles the complexity of shipping that reference to the workers.
- **Lazy Evaluation:** If you commented out the `.collect()` (or a `.count()`) action, `invalid_record_count.value` would still be 0, because Spark wouldn't have actually processed the data yet.
- **The "Double Counting" Trap:** If you were to run another action on `processed_rdd` (like `processed_rdd.count()`), Spark might re-run the map function, causing the accumulator to increment a second time. To avoid this, cache the RDD if you plan to perform multiple actions: `processed_rdd.cache()` (see the sketch after this list).
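To make the double-counting trap concrete, here is a minimal sketch against the same `sc`; the variable and function names are illustrative, not part of the example above:

```python
# Without caching, every action re-executes map() and its side effects.
trap_counter = sc.accumulator(0)

def tag(x):
    global trap_counter
    trap_counter += 1
    return x

tagged = sc.parallelize(range(5)).map(tag)
tagged.collect()            # accumulator += 5
tagged.count()              # map() runs again: accumulator += 5
print(trap_counter.value)   # 10, not 5

# Caching lets the first action materialize the partitions, so later actions
# reuse them instead of re-running tag() (as long as the blocks are not evicted).
safe_counter = sc.accumulator(0)

def tag_cached(x):
    global safe_counter
    safe_counter += 1
    return x

cached = sc.parallelize(range(5)).map(tag_cached).cache()
cached.collect()            # accumulator += 5, partitions cached
cached.count()              # served from the cache: no extra increments
print(safe_counter.value)   # 5
```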
Would you like to see how to create a Custom Accumulator (for example, one that accumulates a list of error messages instead of just a counter)?
Yes, I would.
To create a custom Accumulator in PySpark, you inherit from the `AccumulatorParam` class (the Scala/Java API uses `AccumulatorV2` instead). Your subclass tells Spark how to initialize the "zero" value and how to "add" two values together.
This is particularly useful if you want to collect a list of specific error messages or a dictionary of category counts rather than just a simple sum.
### Example: Collecting Error Messages in a List
In this example, we’ll create an accumulator that gathers strings. Since Spark is distributed, it needs to know how to merge lists from different executors.
```python
from pyspark.accumulators import AccumulatorParam

# 1. Define the Custom Accumulator Logic
class ListAccumulatorParam(AccumulatorParam):
    # Define the 'zero' value (an empty list)
    def zero(self, value):
        return []

    # How to add a single value to the accumulator (on the executor)
    # OR how to merge two lists (when executors send results back to the driver)
    def addInPlace(self, value1, value2):
        if isinstance(value2, list):
            value1.extend(value2)
        else:
            value1.append(value2)
        return value1

# 2. Initialize the Spark Context and the Custom Accumulator
sc = spark.sparkContext
error_log = sc.accumulator([], ListAccumulatorParam())

data = [
    {"id": 1, "status": "OK"},
    {"id": 2, "status": "ERROR", "msg": "Timeout"},
    {"id": 3, "status": "OK"},
    {"id": 4, "status": "ERROR", "msg": "Connection Refused"}
]

def check_status(row):
    global error_log
    if row["status"] == "ERROR":
        # Add the specific error message to the list
        error_log += row["msg"]
    return row

# 3. Process the Data
rdd = sc.parallelize(data)
rdd.map(check_status).collect()  # Action triggers the mapping

# 4. Access the final list on the Driver
print("Captured Errors:")
for msg in error_log.value:
    print(f"- {msg}")
```
### How it works under the hood
- **Local Addition:** Each executor maintains its own local list. When `error_log += row["msg"]` is called, it appends to that local list.
- **Merging (addInPlace):** When a task finishes, it sends its local list to the Driver. The Driver uses the `addInPlace` method to merge these lists together into one final master list (illustrated just after this list).
- **Associativity:** This works because adding list A to list B results in the same total items as adding list B to list A (the order might vary, but the content remains).
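To make the merge step concrete, the same `addInPlace` can be called directly on two local lists (purely illustrative, outside any Spark job):

```python
param = ListAccumulatorParam()

# Merging two per-executor lists on the Driver
print(param.addInPlace(["Timeout"], ["Connection Refused"]))  # ['Timeout', 'Connection Refused']

# Adding a single value on an executor goes through the same method
print(param.addInPlace([], "Timeout"))  # ['Timeout']
```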
### When should you avoid this?
Custom accumulators are powerful, but they are kept in the Driver's memory.
- **Don't** use them to collect millions of strings; you will cause an `OutOfMemoryError` on the Driver.
- **Do** use them for small-to-medium diagnostic data, like a "Top 10" list of the most frequent error types or simple summary statistics (a dictionary-based sketch follows this list).
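As a sketch of that kind of diagnostic use, here is a dictionary-of-counts accumulator built with the same `AccumulatorParam` pattern; the class, variable, and function names below are illustrative, and `sc` and `data` are reused from the example above:

```python
from pyspark.accumulators import AccumulatorParam

class DictSumParam(AccumulatorParam):
    def zero(self, value):
        return {}

    def addInPlace(self, acc, other):
        # 'other' is either another dict (merge on the Driver)
        # or a single category string (increment on an executor)
        if isinstance(other, dict):
            for key, count in other.items():
                acc[key] = acc.get(key, 0) + count
        else:
            acc[other] = acc.get(other, 0) + 1
        return acc

error_types = sc.accumulator({}, DictSumParam())

def classify(row):
    global error_types
    if row["status"] == "ERROR":
        error_types += row.get("msg", "Unknown")
    return row

sc.parallelize(data).map(classify).collect()
print(error_types.value)  # e.g. {'Timeout': 1, 'Connection Refused': 1}
```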
Would you like to explore how to do this using the Spark SQL/DataFrame API instead of RDDs, or are you comfortable with this approach?
Yes, I would.
Using Accumulators with the DataFrame API is a bit different because DataFrames are optimized by Spark's Catalyst engine. While the RDD API allows you to easily inject a Python function with global variables, DataFrames prefer a more structured approach.
To use an accumulator in a DataFrame context, you typically wrap the logic inside a UDF (User Defined Function) or use the .foreach() / .foreachPartition() methods.
### Example: Counting Issues in a DataFrame
This is the most common way to track row-level metadata without breaking your DataFrame pipeline.
```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

spark = SparkSession.builder.appName("DataFrameAccumulator").getOrCreate()

# 1. Initialize the counter
null_count = spark.sparkContext.accumulator(0)

# Sample DataFrame with some null values
data = [(1, "Apple"), (2, None), (3, "Banana"), (4, None)]
df = spark.createDataFrame(data, ["id", "fruit"])

# 2. Define a function that updates the accumulator
def count_nulls(value):
    global null_count
    if value is None:
        null_count += 1
    return value

# Wrap it in a UDF (the function passes the string through, so it returns StringType)
count_udf = udf(count_nulls, StringType())

# 3. Apply the transformation
# NOTE: The accumulator won't update yet (Lazy Evaluation)
df_transformed = df.withColumn("fruit_checked", count_udf(df["fruit"]))

# 4. Trigger an Action
# We must perform an action to actually "see" the data and trigger the UDF
df_transformed.collect()

print(f"Number of NULL fruits found: {null_count.value}")
```
### Why `.foreach()` is often better
If you don't actually need to change the data but just want to scan it for metrics, `foreach` is much more efficient. It avoids the overhead of creating a new DataFrame and keeps the logic "side-effect only."
```python
# Reset the accumulator so this count is independent of the UDF example above
null_count = spark.sparkContext.accumulator(0)

def check_row(row):
    global null_count
    if row['fruit'] is None:
        null_count += 1

# This triggers the processing immediately without needing a .collect()
df.foreach(check_row)
print(f"Final Count via foreach: {null_count.value}")
```
### ⚠️ The "Lazy Evaluation" Warning (Again!)
This is the most common mistake with DataFrames and Accumulators.
If you define a transformation (like `withColumn`) but never call an action (like `count()`, `show()`, or a write/save), your accumulator will stay at 0. Conversely, if you call `show()` and then write the same transformed DataFrame out, Spark may execute the UDF twice, and your accumulator value may be doubled. Caching the transformed DataFrame before the first action avoids the re-computation (see the sketch below).
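A minimal sketch of that guard, reusing `df` from the example above; the fresh accumulator, function, and UDF names here are illustrative, and the behaviour assumes the cached partitions are not evicted:

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

# A fresh accumulator and UDF so the numbers below are easy to follow
cache_demo_count = spark.sparkContext.accumulator(0)

def count_nulls_cached(value):
    global cache_demo_count
    if value is None:
        cache_demo_count += 1
    return value

count_udf_cached = udf(count_nulls_cached, StringType())

# Cache the transformed DataFrame so the UDF only runs once
df_checked = df.withColumn("fruit_checked", count_udf_cached(df["fruit"])).cache()

df_checked.collect()   # first action: UDF runs, partitions are cached, accumulator -> 2
df_checked.show()      # served from the cache: no additional updates

print(cache_demo_count.value)  # 2
```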
### Summary Table: RDD vs. DataFrame Accumulators
| Feature | RDD Approach | DataFrame/SQL Approach |
|---|---|---|
| Ease of Use | Very natural (standard Python) | Requires UDFs or `foreach` |
| Performance | Slower (Python serialization) | Faster, but UDFs can slow it down |
| Reliability | Good for quick debugging | Risky due to Catalyst optimizations |