在PySpark中根据条件连接表格:如果点在多边形内部

8

我有两个 PySpark 数据框: 一个是点的数据框 df_pnt,另一个是多边形的数据框 df_poly。由于我不太熟悉 PySpark,所以我在正确地将这些数据框联接时遇到了困难,需要判断一个点是否在多边形内部。

我从这个页面上的资料中构建了以下代码:

from shapely import wkt  
import numpy as np
from shapely.geometry import Polygon, Point
import matplotlib.pyplot as plt
import pandas as pd
import geopandas as gpd
from pyspark.sql.types import StringType

# Create simple data
polygon1 = Polygon([[0, 0], [.5, 0], [0.3, 0.2], [0, 0.2]])
polygon2 = Polygon([[0.6, 0], [0.6, 0.3], [0.6, 0.4], [0.7, 0.2]])
polygon3 = Polygon([[0.6, 0.5], [.5, 0.5], [0.3, 0.7], [0.4, 0.8]])
polygon4 = Polygon([[0, .5], [.2, 0.4], [0.5, 0.3], [0.5, 0.1]])

df = pd.DataFrame(data={'id':[0, 1, 2, 3],
                 'geometry':[polygon1, polygon2,   polygon3, polygon4]})
df_poly = gpd.GeoDataFrame(
    df, geometry=df['geometry']); del df


df = pd.DataFrame(data={'id':range(0,15),
                'geometry':[Point(pnt) for pnt in np.random.rand(15,2)]})
df_pnt = gpd.GeoDataFrame(
    df, geometry=df['geometry']); del df

# convert shape to str in pandas df
df_poly['wkt'] = pd.Series(
    map(lambda geom: str(geom.to_wkt()), df_poly['geometry']),
    index=df_poly.index, dtype='str')

df_pnt['wkt'] = pd.Series(
        map(lambda geom: str(geom.to_wkt()), df_pnt['geometry']),
        index=df_pnt.index, dtype='str')

# Now we create geometry column as string column in pyspark df
tmp = df_poly.drop("geometry", axis=1)
df_poly = spark.createDataFrame(tmp).cache(); del tmp

tmp = df_pnt.drop("geometry", axis=1)
df_pnt = spark.createDataFrame(tmp).cache(); del tmp

如果我们想要绘制第一个多边形,我们应该运行这段代码。
wkt.loads(df_poly.take(1)[0].wkt)

如果我们想检查一个Polygon对象是否包含一个Point对象,我们需要以下代码:

Polygon.contains(Point)

问题是如何在连接过程中处理这个自定义条件?df_poly比点数据框要小得多,因此我希望也能利用广播技术。
更新: 如果我需要在geopandas中实现这个功能,代码将如下所示:
df_pnt
    id  geometry
0   0   POINT (0.08834 0.23203)
1   1   POINT (0.67457 0.19285)
2   2   POINT (0.71186 0.25128)
3   3   POINT (0.55621 0.35016)
4   4   POINT (0.79637 0.24668)
5   5   POINT (0.40932 0.37155)
6   6   POINT (0.36124 0.68229)
7   7   POINT (0.13476 0.58242)
8   8   POINT (0.41659 0.46298)
9   9   POINT (0.74878 0.78191)
10  10  POINT (0.82088 0.58064)
11  11  POINT (0.28797 0.24399)
12  12  POINT (0.40502 0.99233)
13  13  POINT (0.68928 0.73251)
14  14  POINT (0.37765 0.71518)

df_poly

        id  geometry
0   0   POLYGON ((0.00000 0.00000, 0.50000 0.00000, 0....
1   1   POLYGON ((0.60000 0.00000, 0.60000 0.30000, 0....
2   2   POLYGON ((0.60000 0.50000, 0.50000 0.50000, 0....
3   3   POLYGON ((0.00000 0.50000, 0.20000 0.40000, 0....

gpd.sjoin(df_pnt, df_poly, how="left", op='intersects')

    id_left     geometry    index_right     id_right
0   0   POINT (0.08834 0.23203)     NaN     NaN
1   1   POINT (0.67457 0.19285)     1.0     1.0
2   2   POINT (0.71186 0.25128)     NaN     NaN
3   3   POINT (0.55621 0.35016)     NaN     NaN
4   4   POINT (0.79637 0.24668)     NaN     NaN
5   5   POINT (0.40932 0.37155)     NaN     NaN
6   6   POINT (0.36124 0.68229)     2.0     2.0
7   7   POINT (0.13476 0.58242)     NaN     NaN
8   8   POINT (0.41659 0.46298)     NaN     NaN
9   9   POINT (0.74878 0.78191)     NaN     NaN
10  10  POINT (0.82088 0.58064)     NaN     NaN
11  11  POINT (0.28797 0.24399)     NaN     NaN
12  12  POINT (0.40502 0.99233)     NaN     NaN
13  13  POINT (0.68928 0.73251)     NaN     NaN
14  14  POINT (0.37765 0.71518)     2.0     2.0

data


当您创建Spark数据框df_pntdf_poly时,您可以打印模式吗?(df.printSchema())并显示一些值吗?(df.show(truncate=False))。并非每个pyspark用户都熟悉pandas,因此很难按原样回答您的问题。 - blackbishop
两个 Spark 数据框 printSchema 的结果: root |-- id: long (nullable = true) |-- wkt: string (nullable = true) - James Flash
df_pntdf.show 如下所示: +---+---------------------------------------------+ |id |wkt | +---+---------------------------------------------+ |0 |POINT (0.2921357376954469 0.6871580673326519)| |1 |POINT (0.6286913183363046 0.1356827455860742)| |2 |POINT (0.8953860983142878 0.5851118896234707)| |3 |POINT (0.3906532809342733 0.7742480793942560)| |4 |POINT (0.2680620635805934 0.1676353319933286)| +---+---------------------------------------------+ - James Flash
df_polydf.show如下所示: |id |wkt | |0 |POLYGON ((0.0000000000000000 0.0000000000000000, 0.5000000000000000 0.0000000000000000, 0.3000000000000000 0.2000000000000000, 0.0000000000000000 0.2000000000000000, 0.0000000000000000 0.0000000000000000))| - James Flash
输出结果更大了,但我已经删除了一些行和字符以适应注释部分。 - James Flash
显示剩余3条评论
1个回答

0

你可以将 Polygon.contains() 包装在 UDF 中,并在其上连接表格。用户定义函数作为连接条件仅允许在内连接中使用。

创建示例数据框:

from pyspark.sql import types as T
from pyspark.sql import functions as F
from shapely.geometry import Polygon, Point
from typing import Tuple, List


# create points dataframe
point = T.StructType(
    [
        T.StructField("x", T.FloatType()),
        T.StructField("y", T.FloatType()),
    ]
)

# point table
df_pnt = spark_session.createDataFrame(
    data=[
        (0, (0.08834, 0.23203)),
        (1, (0.67457, 0.19285)),
        (2, (0.71186, 0.25128)),
    ],
    schema=T.StructType(
        [
            T.StructField("id", T.IntegerType()),
            T.StructField("point", point),
        ]
    ),
)
df_pnt.printSchema()
df_pnt.show()

# create polygon dataframe 
polygon = T.ArrayType(point)
df_plg = spark_session.createDataFrame(
    data=[
        (0, [[0.0, 0.0], [0.5, 0.0], [0.3, 0.2], [0.0, 0.2]]),
        (1, [[0.6, 0.0], [0.6, 0.3], [0.6, 0.4], [0.7, 0.2]]),
        (2, [[0.6, 0.5], [0.5, 0.5], [0.3, 0.7], [0.4, 0.8]]),
        (3, [[0.0, 0.5], [0.2, 0.4], [0.5, 0.3], [0.5, 0.1]]),
    ],
    schema=T.StructType(
        [
            T.StructField("id", T.IntegerType()),
            T.StructField("polygon", polygon),
        ]
    ),
)
df_plg.printSchema()
df_plg.show(truncate=False)

输出:

root
 |-- id: integer (nullable = true)
 |-- point: struct (nullable = true)
 |    |-- x: float (nullable = true)
 |    |-- y: float (nullable = true)

+---+------------------+
| id|             point|
+---+------------------+
|  0|[0.08834, 0.23203]|
|  1|[0.67457, 0.19285]|
|  2|[0.71186, 0.25128]|
+---+------------------+

root
 |-- id: integer (nullable = true)
 |-- polygon: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- x: float (nullable = true)
 |    |    |-- y: float (nullable = true)

+---+------------------------------------------------+
|id |polygon                                         |
+---+------------------------------------------------+
|0  |[[0.0, 0.0], [0.5, 0.0], [0.3, 0.2], [0.0, 0.2]]|
|1  |[[0.6, 0.0], [0.6, 0.3], [0.6, 0.4], [0.7, 0.2]]|
|2  |[[0.6, 0.5], [0.5, 0.5], [0.3, 0.7], [0.4, 0.8]]|
|3  |[[0.0, 0.5], [0.2, 0.4], [0.5, 0.3], [0.5, 0.1]]|
+---+------------------------------------------------+

使用自定义连接条件连接表:

# join condition UDF
@F.udf(returnType=T.BooleanType())
def is_point_in_polygon(
    point: Tuple[float, float], polygon: List[Tuple[float, float]]
) -> bool:
    return Polygon(shell=polygon).contains(Point(*point))


# use UDF as join condition
df_pnt.join(df_plg, on=is_point_in_polygon("point", "polygon"), how="inner").show(truncate=False)

# show UDF value for each (point, polygon) pair
df_pnt.crossJoin(df_plg).withColumn(
    "is_in", is_point_in_polygon("point", "polygon")
).show(truncate=False)

输出:

+---+------------------+---+------------------------------------------------+
|id |point             |id |polygon                                         |
+---+------------------+---+------------------------------------------------+
|1  |[0.67457, 0.19285]|1  |[[0.6, 0.0], [0.6, 0.3], [0.6, 0.4], [0.7, 0.2]]|
+---+------------------+---+------------------------------------------------+


+---+------------------+---+------------------------------------------------+-----+
|id |point             |id |polygon                                         |is_in|
+---+------------------+---+------------------------------------------------+-----+
|0  |[0.08834, 0.23203]|0  |[[0.0, 0.0], [0.5, 0.0], [0.3, 0.2], [0.0, 0.2]]|false|
|0  |[0.08834, 0.23203]|1  |[[0.6, 0.0], [0.6, 0.3], [0.6, 0.4], [0.7, 0.2]]|false|
|0  |[0.08834, 0.23203]|2  |[[0.6, 0.5], [0.5, 0.5], [0.3, 0.7], [0.4, 0.8]]|false|
|0  |[0.08834, 0.23203]|3  |[[0.0, 0.5], [0.2, 0.4], [0.5, 0.3], [0.5, 0.1]]|false|
|1  |[0.67457, 0.19285]|0  |[[0.0, 0.0], [0.5, 0.0], [0.3, 0.2], [0.0, 0.2]]|false|
|1  |[0.67457, 0.19285]|1  |[[0.6, 0.0], [0.6, 0.3], [0.6, 0.4], [0.7, 0.2]]|true |
|2  |[0.71186, 0.25128]|0  |[[0.0, 0.0], [0.5, 0.0], [0.3, 0.2], [0.0, 0.2]]|false|
|2  |[0.71186, 0.25128]|1  |[[0.6, 0.0], [0.6, 0.3], [0.6, 0.4], [0.7, 0.2]]|false|
|1  |[0.67457, 0.19285]|2  |[[0.6, 0.5], [0.5, 0.5], [0.3, 0.7], [0.4, 0.8]]|false|
|1  |[0.67457, 0.19285]|3  |[[0.0, 0.5], [0.2, 0.4], [0.5, 0.3], [0.5, 0.1]]|false|
|2  |[0.71186, 0.25128]|2  |[[0.6, 0.5], [0.5, 0.5], [0.3, 0.7], [0.4, 0.8]]|false|
|2  |[0.71186, 0.25128]|3  |[[0.0, 0.5], [0.2, 0.4], [0.5, 0.3], [0.5, 0.1]]|false|
+---+------------------+---+------------------------------------------------+-----+

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接