Arunav Goswami
Data Science Consultant at AlmaBetter
Explore a detailed PySpark cheat sheet covering functions, DataFrame operations, RDD basics, and common commands. Perfect for data engineers and big data enthusiasts.
PySpark is the Python API for Apache Spark, an open-source, distributed computing system. PySpark allows data engineers and data scientists to process large datasets efficiently and integrate with Hadoop and other big data technologies. This cheat sheet is designed to provide an overview of the most frequently used PySpark functionalities, organized for ease of reference.
Creating a SparkSession:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("App Name").getOrCreate()
Creating DataFrames:
From an RDD:
from pyspark.sql import Row
rdd = spark.sparkContext.parallelize([Row(name="Alice", age=25)])
df = rdd.toDF()
From a file:
df = spark.read.csv("path/to/file.csv", header=True, inferSchema=True)
Basic DataFrame operations:
Display:
df.show()
Schema:
df.printSchema()
Select columns:
df.select("column_name").show()
Filter rows:
df.filter(df["column_name"] > value).show()
Add a column:
df.withColumn("new_column", df["existing_column"] + 10).show()
Running SQL queries:
df.createOrReplaceTempView("table_name")
spark.sql("SELECT * FROM table_name").show()
Aggregations:
from pyspark.sql.functions import col, avg, max, min, count, countDistinct
# Group by and aggregate
df.groupBy("column1").agg(avg("column2"), max("column2")).show()
# Count distinct values
df.select(countDistinct("column1")).show()
# Aggregate without grouping
df.agg(min("column1"), max("column2")).show()
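Aggregate columns get default names such as avg(column2); alias() assigns readable ones. A brief sketch:
# Name the aggregated columns explicitly
df.groupBy("column1").agg(avg("column2").alias("avg_col2"), count("column2").alias("n")).show()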
String functions:
from pyspark.sql.functions import upper
df.select(upper(df["column_name"])).show()
Window functions:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, rank
# Define a window
window_spec = Window.partitionBy("column1").orderBy("column2")
# Add row numbers
df.withColumn("row_number", row_number().over(window_spec)).show()
# Add ranks
df.withColumn("rank", rank().over(window_spec)).show()
Column operations:
from pyspark.sql.functions import lit, concat
df.withColumn("concatenated", concat(df["col1"], lit("_"), df["col2"])).show()
Date functions:
from pyspark.sql.functions import current_date, date_add
df.withColumn("today", current_date()).withColumn("tomorrow", date_add(current_date(), 1)).show()
from pyspark.sql.functions import sum, count
df.groupBy("group_col").agg(sum("value_col"), count("*")).show()
Reading and writing data:
df = spark.read.json("path/to/json/file.json")
df = spark.read.parquet("path/to/parquet/file")
df.write.csv("path/to/save.csv", header=True)
df.write.json("path/to/save.json")
Caching and persistence:
df.cache()
df.persist()
df.unpersist()
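cache() uses the default storage level, while persist() also accepts one explicitly. A small sketch:
from pyspark import StorageLevel
# Keep partitions in memory and spill to disk if they do not fit
df.persist(StorageLevel.MEMORY_AND_DISK)
df.count()      # the first action materializes the cached data
df.unpersist()  # release the storage when done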
Joins:
# Inner join
df1.join(df2, df1["key"] == df2["key"], "inner").show()
# Left outer join
df1.join(df2, df1["key"] == df2["key"], "left").show()
# Full outer join
df1.join(df2, df1["key"] == df2["key"], "outer").show()
# Sort by column
df.sort("column1").show()
# Sort descending
df.orderBy(df["column1"].desc()).show()
# Drop rows with null values
df.na.drop().show()
# Fill null values
df.na.fill({"column1": 0, "column2": "missing"}).show()
# Replace null values in a column
df.fillna(0, subset=["column1"]).show()
RDD basics:
rdd = spark.sparkContext.parallelize([1, 2, 3, 4])
rddFromFile = spark.sparkContext.textFile("path/to/file.txt")
Transformations:
map: Applies a function to each element.
rdd.map(lambda x: x * 2).collect()
filter: Filters elements based on a condition.
rdd.filter(lambda x: x % 2 == 0).collect()
flatMap: Maps each element to multiple elements.
rdd.flatMap(lambda x: [x, x * 2]).collect()
Actions:
collect: Returns all elements.
rdd.collect()
count: Returns the number of elements.
rdd.count()
reduce: Aggregates elements using a function.
rdd.reduce(lambda x, y: x + y)
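Transformations are lazy and only run when an action is called. A quick sketch chaining the operations above on the rdd defined earlier:
# Nothing executes until reduce (an action) is called; for [1, 2, 3, 4] the result is 18
result = (rdd.map(lambda x: x * 2)
             .filter(lambda x: x > 2)
             .reduce(lambda x, y: x + y))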
User-defined functions (UDFs):
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
# Define a UDF
def multiply_by_two(x):
    return x * 2
multiply_udf = udf(multiply_by_two, IntegerType())
# Apply UDF
df.withColumn("new_column", multiply_udf(df["column"])).show()
Machine learning with MLlib:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
# Fit a model (training_data must already contain feature and label columns)
model = LinearRegression().fit(training_data)
# Save model
model.save("path/to/model")
# Load model
from pyspark.ml.regression import LinearRegressionModel
loaded_model = LinearRegressionModel.load("path/to/model")
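LinearRegression expects a DataFrame with a vector column of features and a label column, which is where the imported VectorAssembler comes in. A minimal sketch, assuming a DataFrame df with numeric columns x1 and x2 and a target column label (placeholder names):
# Combine numeric inputs into a single "features" vector column
assembler = VectorAssembler(inputCols=["x1", "x2"], outputCol="features")
training_data = assembler.transform(df).select("features", "label")
# Fit on the assembled data
model = LinearRegression(featuresCol="features", labelCol="label").fit(training_data)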
Repartitioning:
# repartition returns a new DataFrame, so reassign the result
df = df.repartition(10)
Reading from a JDBC source:
jdbc_url = "jdbc:mysql://host:port/db"
df = spark.read.format("jdbc").option("url", jdbc_url).option("dbtable", "table").load()
PySpark is a versatile tool for handling big data. This cheat sheet covers RDDs, DataFrames, SQL queries, and built-in functions essential for data engineering. Using these commands effectively can optimize data processing workflows, making PySpark indispensable for scalable, efficient data solutions.