Skip to content
Prev Previous commit
Next Next commit
simplify
  • Loading branch information
williebsweet committed Jul 30, 2025
commit 959530490e1b8093f3f46970df3b82d5754f684a
56 changes: 7 additions & 49 deletions sqlmesh/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,57 +1163,15 @@ def _fetchdf_athena_pandas_cursor(self, context: Context, sql: str) -> "pd.DataF
except ImportError as e:
raise MagicError(f"PyAthena with pandas support is required: {e}")

# Use SQLMesh's transpilation to convert SQL to Athena dialect
# This handles features like QUALIFY that need transpilation
try:
# Parse the SQL string into a SQLGlot expression first
from sqlmesh.core.dialect import parse
parsed_expressions = parse(sql, default_dialect=context.config.dialect)

# Get the first expression (should be a SELECT statement)
if parsed_expressions:
transpiled_sql = context.engine_adapter._to_sql(parsed_expressions[0], quote=False)
else:
raise ValueError("No valid SQL expressions found")

except Exception as e:
context.console.log_error(f"SQL transpilation failed: {e}")
# Fall back to the regular fetchdf method if transpilation fails
return context.fetchdf(sql)
conn_config = context.config.get_connection(context.config.default_connection)
connection_kwargs = {
k: v for k, v in conn_config.dict().items()
if k in conn_config._connection_kwargs_keys and v is not None
}
cursor = connect(cursor_class=PandasCursor, **connection_kwargs).cursor()
return cursor.execute(sql).as_pandas()

# Get the connection configuration for Athena
conn_config = context.config.get_connection(context.config.default_connection)

# Build connection kwargs using the same logic as SQLMesh
connection_kwargs = {
k: v for k, v in conn_config.dict().items()
if k in conn_config._connection_kwargs_keys and v is not None
}

# Create connection with PandasCursor specifically
try:
with connect(
cursor_class=PandasCursor,
**connection_kwargs
) as conn:
with conn.cursor() as cursor:
cursor.execute(transpiled_sql)

# PyAthena PandasCursor needs to be converted to DataFrame manually
# It returns data but we need to use pandas.DataFrame constructor
data = cursor.fetchall()

if data:
# Get column names from cursor description
columns = [desc[0] for desc in cursor.description] if cursor.description else None
df = pd.DataFrame(data, columns=columns)
else:
# Empty result set
columns = [desc[0] for desc in cursor.description] if cursor.description else []
df = pd.DataFrame(columns=columns)

return df

except Exception as e:
# Fall back to the regular fetchdf method if PandasCursor fails
context.console.log_error(f"PandasCursor failed, falling back to standard method: {e}")
Expand Down
Loading