Skip to content
Closed
Prev Previous commit
Next Next commit
refactor gen-protos.sh
  • Loading branch information
LuciferYang committed Nov 11, 2024
commit 05606ef25fabd22f63d99dcdbdfb36ecca242175
29 changes: 15 additions & 14 deletions dev/gen-protos.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ cd "$SPARK_HOME"

OUTPUT_PATH=""
MODULE=""
SOURCE_MODULE=""
TARGET_MODULE=""

function usage() {
echo "Usage:./dev/gen-protos.sh [connect|streaming] [output_path]"
echo "Usage:./dev/gen-protos.sh [connect|streaming] [output_path]"
exit -1
}

Expand All @@ -37,9 +39,13 @@ fi
if [[ $1 == "connect" ]]; then
MODULE="connect"
OUTPUT_PATH=${SPARK_HOME}/python/pyspark/sql/connect/proto/
SOURCE_MODULE="spark.connect"
TARGET_MODULE="pyspark.sql.connect.proto"
elif [[ $1 == "streaming" ]]; then
MODULE="streaming"
OUTPUT_PATH=${SPARK_HOME}/python/pyspark/sql/streaming/proto/
SOURCE_MODULE="spark.streaming"
TARGET_MODULE="pyspark.sql.streaming.proto"
else
usage
fi
Expand Down Expand Up @@ -82,23 +88,18 @@ rm -Rf gen
# Now, regenerate the new files
buf generate --debug -vvv

# We need to edit the generate python files to account for the actual package location and not
# the one generated by proto.
for f in `find gen/proto/python -name "*.py*"`; do
if [[ $MODULE == "connect" && ($f == *_pb2.py || $f == *_pb2_grpc.py) ]]; then
sed -e 's/from spark.connect import/from pyspark.sql.connect.proto import/g' $f > $f.tmp
# First fix the imports.
if [[ $f == *_pb2.py || $f == *_pb2_grpc.py ]]; then
sed -e "s/from ${SOURCE_MODULE} import/from ${TARGET_MODULE} import/g" $f > $f.tmp
mv $f.tmp $f
sed -e "s/DESCRIPTOR, 'spark.connect/DESCRIPTOR, 'pyspark.sql.connect.proto/g" $f > $f.tmp
mv $f.tmp $f
elif [[ $MODULE == "streaming" && $f == *_pb2.py ]]; then
sed -e 's/from spark.streaming import/from pyspark.sql.streaming.proto import/g' $f > $f.tmp
mv $f.tmp $f
sed -e "s/DESCRIPTOR, 'spark.streaming/DESCRIPTOR, 'pyspark.sql.streaming.proto/g" $f > $f.tmp
# Now fix the module name in the serialized descriptor.
sed -e "s/DESCRIPTOR, '${SOURCE_MODULE}/DESCRIPTOR, '${TARGET_MODULE}/g" $f > $f.tmp
mv $f.tmp $f
elif [[ $f == *.pyi ]]; then
if [[ $MODULE == "connect" ]]; then
sed -e 's/import spark.connect./import pyspark.sql.connect.proto./g' -e 's/spark.connect./pyspark.sql.connect.proto./g' -e '/ *@typing_extensions\.final/d' $f > $f.tmp
else
sed -e 's/import spark.streaming./import pyspark.sql.streaming.proto./g' -e 's/spark.streaming./pyspark.sql.streaming.proto./g' -e '/ *@typing_extensions\.final/d' $f > $f.tmp
fi
sed -e "s/import ${SOURCE_MODULE}./import ${TARGET_MODULE}./g" -e "s/${SOURCE_MODULE}./${TARGET_MODULE}./g" -e '/ *@typing_extensions\.final/d' $f > $f.tmp
mv $f.tmp $f
fi

Expand Down