Skip to content

Commit 9eb857b

Browse files
author
Davies Liu
committed
withColumn should replace the old column
1 parent 8bae901 commit 9eb857b

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,15 +1202,18 @@ def freqItems(self, cols, support=None):
12021202
@ignore_unicode_prefix
12031203
@since(1.3)
12041204
def withColumn(self, colName, col):
1205-
"""Returns a new :class:`DataFrame` by adding a column.
1205+
"""
1206+
Returns a new :class:`DataFrame` by adding a column or replacing the
1207+
existing column that has the same name.
12061208
12071209
:param colName: string, name of the new column.
12081210
:param col: a :class:`Column` expression for the new column.
12091211
12101212
>>> df.withColumn('age2', df.age + 2).collect()
12111213
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
12121214
"""
1213-
return self.select('*', col.alias(colName))
1215+
assert isinstance(col, Column), "col should be Column"
1216+
return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
12141217

12151218
@ignore_unicode_prefix
12161219
@since(1.3)
@@ -1223,10 +1226,8 @@ def withColumnRenamed(self, existing, new):
12231226
>>> df.withColumnRenamed('age', 'age2').collect()
12241227
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
12251228
"""
1226-
cols = [Column(_to_java_column(c)).alias(new)
1227-
if c == existing else c
1228-
for c in self.columns]
1229-
return self.select(*cols)
1229+
assert existing in self.columns, "%s is not an existing column" % existing
1230+
return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)
12301231

12311232
@since(1.4)
12321233
@ignore_unicode_prefix

python/pyspark/sql/tests.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,10 @@ def test_capture_illegalargument_exception(self):
10351035
self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
10361036
lambda: df.select(sha2(df.a, 1024)).collect())
10371037

1038+
def test_with_column_with_existing_name(self):
1039+
keys = self.df.withColumn("key", self.df.key).select("key").collect()
1040+
self.assertEqual([r.key for r in keys], list(range(100)))
1041+
10381042

10391043
class HiveContextSQLTests(ReusedPySparkTestCase):
10401044

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,8 @@ class DataFrame private[sql](
11331133
/////////////////////////////////////////////////////////////////////////////
11341134

11351135
/**
1136-
* Returns a new [[DataFrame]] by adding a column.
1136+
* Returns a new [[DataFrame]] by adding a column or replacing the existing column that has
1137+
* the same name.
11371138
* @group dfops
11381139
* @since 1.3.0
11391140
*/

0 commit comments

Comments
 (0)