Creating Nested Columns in a PySpark DataFrame

A nested column is basically just a column with one or more sub-columns. Take a look at the following example.

=====================
|     nested_col    |
=====================
|  col_a  |  col_b  |
=====================
|  dummy  |  dummy  |
|  dummy  |  dummy  |
=====================

The dataframe above has one nested column, nested_col, which consists of two sub-columns, namely col_a and col_b.

Let's take a look at how we can create such a nested column in a PySpark dataframe.

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType

spark = SparkSession.builder.getOrCreate()

# inner schema for the nested column 'p',
# consisting of sub-columns 'x' and 'y'
schema_p = StructType(
    [
        StructField('x', StringType(), True),
        StructField('y', StringType(), True)
    ]
)

# schema for the whole dataframe; the type of
# column 'p' is the inner schema above
schema_df = StructType(
    [
        StructField('dummy_col', StringType(), True),
        StructField('p', schema_p, True)
    ]
)

df = spark.createDataFrame([
        ('dummy_a', ('x_a', 'y_a')),
        ('dummy_b', ('x_b', 'y_b')),
        ('dummy_c', ('x_c', 'y_c'))
], schema_df)

The above code does the following:

  • Create the inner schema (schema_p) for column p. This inner schema consists of two sub-columns, namely x and y.
  • Create the schema for the whole dataframe (schema_df). As you can see, the type of column p is schema_p itself.
  • Create the dataframe rows based on schema_df.

The above code will result in the following dataframe and schema.
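
To inspect the result, we can print the schema and show the dataframe's contents:

df.printSchema()
df.show()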

root
 |-- dummy_col: string (nullable = true)
 |-- p: struct (nullable = true)
 |    |-- x: string (nullable = true)
 |    |-- y: string (nullable = true)


==========================
| dummy_col |     p      |
==========================
| dummy_a   | [x_a, y_a] |
| dummy_b   | [x_b, y_b] |
| dummy_c   | [x_c, y_c] |
==========================
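
As a side note, if the data already lives in flat columns, we can also build the nested column afterwards with F.struct. Here's a minimal sketch (flat_df and nested_df are just illustrative names):

from pyspark.sql import functions as F

# a flat dataframe with three top-level string columns
flat_df = spark.createDataFrame(
    [('dummy_a', 'x_a', 'y_a')],
    ['dummy_col', 'x', 'y']
)

# combine the flat columns 'x' and 'y' into a single struct column 'p'
nested_df = flat_df.select('dummy_col', F.struct('x', 'y').alias('p'))
nested_df.printSchema()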

To access the sub-columns, we can use the dot (.) character. Here are some operations we can apply to the above dataframe.

from pyspark.sql import functions as F

# select sub-column 'y' of column 'p'
df.select('p.y')

# keep only the rows whose sub-column 'x' is 'x_a'
df.filter(F.col('p.x') == 'x_a')

# group based on sub-column 'x' of column 'p'
df.groupby('p.x').count()

However, this dot (.) mechanism can't be applied to all dataframe operations. One example is the drop operation.

df.drop('p.x').printSchema()

Resulting schema:
root
 |-- dummy_col: string (nullable = true)
 |-- p: struct (nullable = true)
 |    |-- x: string (nullable = true)
 |    |-- y: string (nullable = true)
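
As you can see, p.x is still there; drop simply ignores the dotted name. One possible workaround is to rebuild the struct with only the sub-columns we want to keep. Here's a minimal sketch (df_dropped is just an illustrative name):

from pyspark.sql import functions as F

# rebuild 'p' with only sub-column 'y', which effectively drops 'p.x'
df_dropped = df.withColumn('p', F.struct(F.col('p.y').alias('y')))
df_dropped.printSchema()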

The same issue arises when we'd like to rename a sub-column.

df.withColumnRenamed('p.x', 'new_px').printSchema()

Resulting schema:
root
 |-- dummy_col: string (nullable = true)
 |-- p: struct (nullable = true)
 |    |-- x: string (nullable = true)
 |    |-- y: string (nullable = true)
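
Again, the struct-rebuilding trick above works here as well. Here's a sketch (df_renamed is just an illustrative name):

from pyspark.sql import functions as F

# rebuild 'p', renaming sub-column 'x' to 'new_px' along the way
df_renamed = df.withColumn(
    'p',
    F.struct(
        F.col('p.x').alias('new_px'),
        F.col('p.y').alias('y')
    )
)
df_renamed.printSchema()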

Hope it helps. Thank you for reading.