13 changes: 7 additions & 6 deletions python/pyspark/sql/dataframe.py
@@ -1202,15 +1202,18 @@ def freqItems(self, cols, support=None):
     @ignore_unicode_prefix
     @since(1.3)
     def withColumn(self, colName, col):
-        """Returns a new :class:`DataFrame` by adding a column.
+        """
+        Returns a new :class:`DataFrame` by adding a column or replacing the
+        existing column that has the same name.
 
         :param colName: string, name of the new column.
         :param col: a :class:`Column` expression for the new column.
 
         >>> df.withColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
-        return self.select('*', col.alias(colName))
+        assert isinstance(col, Column), "col should be Column"
+        return DataFrame(self._jdf.withColumn(colName, col._jc), self.sql_ctx)
 
     @ignore_unicode_prefix
     @since(1.3)
@@ -1223,10 +1226,8 @@ def withColumnRenamed(self, existing, new):
         >>> df.withColumnRenamed('age', 'age2').collect()
         [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
         """
-        cols = [Column(_to_java_column(c)).alias(new)
-                if c == existing else c
-                for c in self.columns]
-        return self.select(*cols)
+        assert existing in self.columns, "%s is not an existing column" % existing
Contributor:

This is not consistent with the Scala version (see the doc); we should not report an error here.

Actually, I'm wondering why we need to do this kind of checking on the Python side at all (not only this one). Can we just call the Scala API and catch the Java exception?

Contributor Author:

Good catch, will remove it. I'm really surprised by the behavior in Scala.

The reason we want some checks on the Python side is that the Java exception is not easy for a Python programmer to understand (it is nested inside a Py4J exception). A Python exception with a clear message does improve the experience for Python programmers, especially beginners.

Contributor:

Forgive my lack of Python skills... could we have a generic wrapper, used whenever we call back into Scala, that catches and unwraps certain exceptions nicely?

Contributor Author:

We already did this for some of them, for example AnalysisException and IllegalArgumentException; these could help.
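
(For readers following the thread: below is a minimal sketch of the kind of generic wrapper being discussed, not the actual PySpark helper. The decorator name capture_java_exception and the two Python exception classes are illustrative placeholders, and matching on the Java class name in the message string is just one possible way to recognize the underlying exception.)

from functools import wraps

from py4j.protocol import Py4JJavaError


class AnalysisException(Exception):
    """Illustrative Python-side stand-in for org.apache.spark.sql.AnalysisException."""


class IllegalArgumentException(Exception):
    """Illustrative Python-side stand-in for java.lang.IllegalArgumentException."""


def capture_java_exception(f):
    # Hypothetical decorator: wrap a method that calls into the JVM and
    # re-raise well-known Java exceptions as plain Python exceptions.
    @wraps(f)
    def deco(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Py4JJavaError as e:
            msg = e.java_exception.toString()
            if msg.startswith('org.apache.spark.sql.AnalysisException: '):
                raise AnalysisException(msg.split(': ', 1)[1])
            if msg.startswith('java.lang.IllegalArgumentException: '):
                raise IllegalArgumentException(msg.split(': ', 1)[1])
            raise  # anything unexpected propagates unchanged
    return deco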

Contributor:

Oh cool. Ideally we will pretty much only throw those two unless there is a bug. If there are cases where that is not true, we should consider fixing them on the Scala side.

+        return DataFrame(self._jdf.withColumnRenamed(existing, new), self.sql_ctx)

     @since(1.4)
     @ignore_unicode_prefix
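
As a side note for reviewers: the user-visible effect of the withColumn change can be sketched with the usual doctest DataFrame df (columns age and name). The output below is what the replace-in-place semantics are expected to produce, not a captured session.

>>> df.withColumn('age', df.age + 10).columns   # previously this appended a duplicate 'age' column
['age', 'name']
>>> df.withColumn('age2', df.age + 2).columns   # adding a brand-new column is unchanged
['age', 'name', 'age2']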
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests.py
@@ -1035,6 +1035,10 @@ def test_capture_illegalargument_exception(self):
         self.assertRaisesRegexp(IllegalArgumentException, "1024 is not in the permitted values",
                                 lambda: df.select(sha2(df.a, 1024)).collect())
 
+    def test_with_column_with_existing_name(self):
+        keys = self.df.withColumn("key", self.df.key).select("key").collect()
+        self.assertEqual([r.key for r in keys], list(range(100)))
+
 
 class HiveContextSQLTests(ReusedPySparkTestCase):

3 changes: 2 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1133,7 +1133,8 @@ class DataFrame private[sql](
   /////////////////////////////////////////////////////////////////////////////
 
   /**
-   * Returns a new [[DataFrame]] by adding a column.
+   * Returns a new [[DataFrame]] by adding a column or replacing the existing column that has
+   * the same name.
    * @group dfops
    * @since 1.3.0
    */