Skip to content
56 changes: 47 additions & 9 deletions src/ast/ddl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3178,6 +3178,26 @@ impl Spanned for RenameTableNameKind {
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
/// Whether the syntax used for the trigger object (ROW or STATEMENT) is `FOR` or `FOR EACH`.
pub enum TriggerObjectKind {
/// The `FOR` syntax is used.
For(TriggerObject),
/// The `FOR EACH` syntax is used.
ForEach(TriggerObject),
}

impl Display for TriggerObjectKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TriggerObjectKind::For(obj) => write!(f, "FOR {obj}"),
TriggerObjectKind::ForEach(obj) => write!(f, "FOR EACH {obj}"),
}
}
}

#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
Expand All @@ -3199,6 +3219,23 @@ pub struct CreateTrigger {
///
/// [MsSql](https://learn.microsoft.com/en-us/sql/t-sql/statements/create-trigger-transact-sql?view=sql-server-ver16#arguments)
pub or_alter: bool,
/// True if this is a temporary trigger.
///
/// Examples:
///
/// ```sql
/// CREATE TEMP TRIGGER trigger_name
/// ```
///
/// or
///
/// ```sql
/// CREATE TEMPORARY TRIGGER trigger_name;
/// CREATE TEMP TRIGGER trigger_name;
/// ```
///
/// [SQLite](https://sqlite.org/lang_createtrigger.html#temp_triggers_on_non_temp_tables)
pub temporary: bool,
/// The `OR REPLACE` clause is used to re-create the trigger if it already exists.
///
/// Example:
Expand Down Expand Up @@ -3243,6 +3280,8 @@ pub struct CreateTrigger {
/// ```
pub period: TriggerPeriod,
/// Whether the trigger period was specified before the target table name.
/// This does not refer to whether the period is BEFORE, AFTER, or INSTEAD OF,
/// but rather the position of the period clause in relation to the table name.
///
/// ```sql
/// -- period_before_table == true: Postgres, MySQL, and standard SQL
Expand All @@ -3262,9 +3301,9 @@ pub struct CreateTrigger {
pub referencing: Vec<TriggerReferencing>,
/// This specifies whether the trigger function should be fired once for
/// every row affected by the trigger event, or just once per SQL statement.
pub trigger_object: TriggerObject,
/// Whether to include the `EACH` term of the `FOR EACH`, as it is optional syntax.
pub include_each: bool,
/// This is optional in some SQL dialects, such as SQLite, and if not specified, in
/// those cases, the implied default is `FOR EACH ROW`.
pub trigger_object: Option<TriggerObjectKind>,
/// Triggering conditions
pub condition: Option<Expr>,
/// Execute logic block
Expand All @@ -3281,6 +3320,7 @@ impl Display for CreateTrigger {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let CreateTrigger {
or_alter,
temporary,
or_replace,
is_constraint,
name,
Expand All @@ -3292,15 +3332,15 @@ impl Display for CreateTrigger {
referencing,
trigger_object,
condition,
include_each,
exec_body,
statements_as,
statements,
characteristics,
} = self;
write!(
f,
"CREATE {or_alter}{or_replace}{is_constraint}TRIGGER {name} ",
"CREATE {temporary}{or_alter}{or_replace}{is_constraint}TRIGGER {name} ",
temporary = if *temporary { "TEMPORARY " } else { "" },
or_alter = if *or_alter { "OR ALTER " } else { "" },
or_replace = if *or_replace { "OR REPLACE " } else { "" },
is_constraint = if *is_constraint { "CONSTRAINT " } else { "" },
Expand Down Expand Up @@ -3332,10 +3372,8 @@ impl Display for CreateTrigger {
write!(f, " REFERENCING {}", display_separated(referencing, " "))?;
}

if *include_each {
write!(f, " FOR EACH {trigger_object}")?;
} else if exec_body.is_some() {
write!(f, " FOR {trigger_object}")?;
if let Some(trigger_object) = trigger_object {
write!(f, " {trigger_object}")?;
}
if let Some(condition) = condition {
write!(f, " WHEN {condition}")?;
Expand Down
4 changes: 2 additions & 2 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ pub use self::ddl::{
IdentityProperty, IdentityPropertyFormatKind, IdentityPropertyKind, IdentityPropertyOrder,
IndexColumn, IndexOption, IndexType, KeyOrIndexDisplay, NullsDistinctOption, Owner, Partition,
ProcedureParam, ReferentialAction, RenameTableNameKind, ReplicaIdentity, TableConstraint,
TagsColumnOption, UserDefinedTypeCompositeAttributeDef, UserDefinedTypeRepresentation,
ViewColumnDef,
TagsColumnOption, TriggerObjectKind, UserDefinedTypeCompositeAttributeDef,
UserDefinedTypeRepresentation, ViewColumnDef,
};
pub use self::dml::{Delete, Insert};
pub use self::operator::{BinaryOperator, UnaryOperator};
Expand Down
6 changes: 3 additions & 3 deletions src/dialect/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use crate::ast::helpers::attached_token::AttachedToken;
use crate::ast::{
BeginEndStatements, ConditionalStatementBlock, ConditionalStatements, CreateTrigger,
GranteesType, IfStatement, Statement, TriggerObject,
GranteesType, IfStatement, Statement,
};
use crate::dialect::Dialect;
use crate::keywords::{self, Keyword};
Expand Down Expand Up @@ -254,6 +254,7 @@ impl MsSqlDialect {

Ok(CreateTrigger {
or_alter,
temporary: false,
or_replace: false,
is_constraint: false,
name,
Expand All @@ -263,8 +264,7 @@ impl MsSqlDialect {
table_name,
referenced_table_name: None,
referencing: Vec::new(),
trigger_object: TriggerObject::Statement,
include_each: false,
trigger_object: None,
condition: None,
exec_body: None,
statements_as: true,
Expand Down
47 changes: 32 additions & 15 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4750,9 +4750,9 @@ impl<'a> Parser<'a> {
} else if self.parse_keyword(Keyword::DOMAIN) {
self.parse_create_domain()
} else if self.parse_keyword(Keyword::TRIGGER) {
self.parse_create_trigger(or_alter, or_replace, false)
self.parse_create_trigger(temporary, or_alter, or_replace, false)
} else if self.parse_keywords(&[Keyword::CONSTRAINT, Keyword::TRIGGER]) {
self.parse_create_trigger(or_alter, or_replace, true)
self.parse_create_trigger(temporary, or_alter, or_replace, true)
} else if self.parse_keyword(Keyword::MACRO) {
self.parse_create_macro(or_replace, temporary)
} else if self.parse_keyword(Keyword::SECRET) {
Expand Down Expand Up @@ -5546,7 +5546,8 @@ impl<'a> Parser<'a> {
/// DROP TRIGGER [ IF EXISTS ] name ON table_name [ CASCADE | RESTRICT ]
/// ```
pub fn parse_drop_trigger(&mut self) -> Result<Statement, ParserError> {
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect | MsSqlDialect) {
if !dialect_of!(self is PostgreSqlDialect | SQLiteDialect | GenericDialect | MySqlDialect | MsSqlDialect)
{
self.prev_token();
return self.expected("an object type after DROP", self.peek_token());
}
Expand Down Expand Up @@ -5574,11 +5575,13 @@ impl<'a> Parser<'a> {

pub fn parse_create_trigger(
&mut self,
temporary: bool,
or_alter: bool,
or_replace: bool,
is_constraint: bool,
) -> Result<Statement, ParserError> {
if !dialect_of!(self is PostgreSqlDialect | GenericDialect | MySqlDialect | MsSqlDialect) {
if !dialect_of!(self is PostgreSqlDialect | SQLiteDialect | GenericDialect | MySqlDialect | MsSqlDialect)
{
self.prev_token();
return self.expected("an object type after CREATE", self.peek_token());
}
Expand All @@ -5605,14 +5608,27 @@ impl<'a> Parser<'a> {
}
}

self.expect_keyword_is(Keyword::FOR)?;
let include_each = self.parse_keyword(Keyword::EACH);
let trigger_object =
match self.expect_one_of_keywords(&[Keyword::ROW, Keyword::STATEMENT])? {
Keyword::ROW => TriggerObject::Row,
Keyword::STATEMENT => TriggerObject::Statement,
_ => unreachable!(),
};
let trigger_object = if self.parse_keyword(Keyword::FOR) {
let include_each = self.parse_keyword(Keyword::EACH);
let trigger_object =
match self.expect_one_of_keywords(&[Keyword::ROW, Keyword::STATEMENT])? {
Keyword::ROW => TriggerObject::Row,
Keyword::STATEMENT => TriggerObject::Statement,
_ => unreachable!(),
};

Some(if include_each {
TriggerObjectKind::ForEach(trigger_object)
} else {
TriggerObjectKind::For(trigger_object)
})
} else {
if !dialect_of!(self is SQLiteDialect ) {
self.expect_keyword_is(Keyword::FOR)?;
}
Comment on lines +5626 to +5628
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if !dialect_of!(self is SQLiteDialect ) {
self.expect_keyword_is(Keyword::FOR)?;
}
self.expect_keyword_is(Keyword::FOR)?;

Can we skip the dialect guard here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems this comment is unresolved? ideally we want to skip dialect_of checks where possible


None
};

let condition = self
.parse_keyword(Keyword::WHEN)
Expand All @@ -5627,8 +5643,9 @@ impl<'a> Parser<'a> {
statements = Some(self.parse_conditional_statements(&[Keyword::END])?);
}

Ok(Statement::CreateTrigger(CreateTrigger {
Ok(CreateTrigger {
or_alter,
temporary,
or_replace,
is_constraint,
name,
Expand All @@ -5639,13 +5656,13 @@ impl<'a> Parser<'a> {
referenced_table_name,
referencing,
trigger_object,
include_each,
condition,
exec_body,
statements_as: false,
statements,
characteristics,
}))
}
.into())
}

pub fn parse_trigger_period(&mut self) -> Result<TriggerPeriod, ParserError> {
Expand Down
4 changes: 2 additions & 2 deletions tests/sqlparser_mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2386,6 +2386,7 @@ fn parse_create_trigger() {
create_stmt,
Statement::CreateTrigger(CreateTrigger {
or_alter: true,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("reminder1")]),
Expand All @@ -2395,8 +2396,7 @@ fn parse_create_trigger() {
table_name: ObjectName::from(vec![Ident::new("Sales"), Ident::new("Customer")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Statement,
include_each: false,
trigger_object: None,
condition: None,
exec_body: None,
statements_as: true,
Expand Down
4 changes: 2 additions & 2 deletions tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3924,6 +3924,7 @@ fn parse_create_trigger() {
create_stmt,
Statement::CreateTrigger(CreateTrigger {
or_alter: false,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("emp_stamp")]),
Expand All @@ -3933,8 +3934,7 @@ fn parse_create_trigger() {
table_name: ObjectName::from(vec![Ident::new("emp")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
trigger_object: Some(TriggerObjectKind::ForEach(TriggerObject::Row)),
condition: None,
exec_body: Some(TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
Expand Down
24 changes: 12 additions & 12 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5673,6 +5673,7 @@ fn parse_create_simple_before_insert_trigger() {
let sql = "CREATE TRIGGER check_insert BEFORE INSERT ON accounts FOR EACH ROW EXECUTE FUNCTION check_account_insert";
let expected = Statement::CreateTrigger(CreateTrigger {
or_alter: false,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("check_insert")]),
Expand All @@ -5682,8 +5683,7 @@ fn parse_create_simple_before_insert_trigger() {
table_name: ObjectName::from(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
trigger_object: Some(TriggerObjectKind::ForEach(TriggerObject::Row)),
condition: None,
exec_body: Some(TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
Expand All @@ -5705,6 +5705,7 @@ fn parse_create_after_update_trigger_with_condition() {
let sql = "CREATE TRIGGER check_update AFTER UPDATE ON accounts FOR EACH ROW WHEN (NEW.balance > 10000) EXECUTE FUNCTION check_account_update";
let expected = Statement::CreateTrigger(CreateTrigger {
or_alter: false,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("check_update")]),
Expand All @@ -5714,8 +5715,7 @@ fn parse_create_after_update_trigger_with_condition() {
table_name: ObjectName::from(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
trigger_object: Some(TriggerObjectKind::ForEach(TriggerObject::Row)),
condition: Some(Expr::Nested(Box::new(Expr::BinaryOp {
left: Box::new(Expr::CompoundIdentifier(vec![
Ident::new("NEW"),
Expand Down Expand Up @@ -5744,6 +5744,7 @@ fn parse_create_instead_of_delete_trigger() {
let sql = "CREATE TRIGGER check_delete INSTEAD OF DELETE ON accounts FOR EACH ROW EXECUTE FUNCTION check_account_deletes";
let expected = Statement::CreateTrigger(CreateTrigger {
or_alter: false,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("check_delete")]),
Expand All @@ -5753,8 +5754,7 @@ fn parse_create_instead_of_delete_trigger() {
table_name: ObjectName::from(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
trigger_object: Some(TriggerObjectKind::ForEach(TriggerObject::Row)),
condition: None,
exec_body: Some(TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
Expand All @@ -5776,6 +5776,7 @@ fn parse_create_trigger_with_multiple_events_and_deferrable() {
let sql = "CREATE CONSTRAINT TRIGGER check_multiple_events BEFORE INSERT OR UPDATE OR DELETE ON accounts DEFERRABLE INITIALLY DEFERRED FOR EACH ROW EXECUTE FUNCTION check_account_changes";
let expected = Statement::CreateTrigger(CreateTrigger {
or_alter: false,
temporary: false,
or_replace: false,
is_constraint: true,
name: ObjectName::from(vec![Ident::new("check_multiple_events")]),
Expand All @@ -5789,8 +5790,7 @@ fn parse_create_trigger_with_multiple_events_and_deferrable() {
table_name: ObjectName::from(vec![Ident::new("accounts")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
trigger_object: Some(TriggerObjectKind::ForEach(TriggerObject::Row)),
condition: None,
exec_body: Some(TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
Expand All @@ -5816,6 +5816,7 @@ fn parse_create_trigger_with_referencing() {
let sql = "CREATE TRIGGER check_referencing BEFORE INSERT ON accounts REFERENCING NEW TABLE AS new_accounts OLD TABLE AS old_accounts FOR EACH ROW EXECUTE FUNCTION check_account_referencing";
let expected = Statement::CreateTrigger(CreateTrigger {
or_alter: false,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("check_referencing")]),
Expand All @@ -5836,8 +5837,7 @@ fn parse_create_trigger_with_referencing() {
transition_relation_name: ObjectName::from(vec![Ident::new("old_accounts")]),
},
],
trigger_object: TriggerObject::Row,
include_each: true,
trigger_object: Some(TriggerObjectKind::ForEach(TriggerObject::Row)),
condition: None,
exec_body: Some(TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
Expand Down Expand Up @@ -6132,6 +6132,7 @@ fn parse_trigger_related_functions() {
create_trigger,
Statement::CreateTrigger(CreateTrigger {
or_alter: false,
temporary: false,
or_replace: false,
is_constraint: false,
name: ObjectName::from(vec![Ident::new("emp_stamp")]),
Expand All @@ -6141,8 +6142,7 @@ fn parse_trigger_related_functions() {
table_name: ObjectName::from(vec![Ident::new("emp")]),
referenced_table_name: None,
referencing: vec![],
trigger_object: TriggerObject::Row,
include_each: true,
trigger_object: Some(TriggerObjectKind::ForEach(TriggerObject::Row)),
condition: None,
exec_body: Some(TriggerExecBody {
exec_type: TriggerExecBodyType::Function,
Expand Down
Loading