Skip to content

Commit 6d8aacd

Browse files
ovrvasilev-alex
andauthored
feat: Support expression in SET statement (apache#574)
Co-authored-by: Alex Vasilev <[email protected]>
1 parent eb7f1b0 commit 6d8aacd

7 files changed

Lines changed: 63 additions & 52 deletions

File tree

src/ast/mod.rs

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,10 @@ impl fmt::Display for Expr {
545545
Expr::UnaryOp { op, expr } => {
546546
if op == &UnaryOperator::PGPostfixFactorial {
547547
write!(f, "{}{}", expr, op)
548-
} else {
548+
} else if op == &UnaryOperator::Not {
549549
write!(f, "{} {}", op, expr)
550+
} else {
551+
write!(f, "{}{}", op, expr)
550552
}
551553
}
552554
Expr::Cast { expr, data_type } => write!(f, "CAST({} AS {})", expr, data_type),
@@ -1100,7 +1102,7 @@ pub enum Statement {
11001102
local: bool,
11011103
hivevar: bool,
11021104
variable: ObjectName,
1103-
value: Vec<SetVariableValue>,
1105+
value: Vec<Expr>,
11041106
},
11051107
/// SET NAMES 'charset_name' [COLLATE 'collation_name']
11061108
///
@@ -2745,23 +2747,6 @@ impl fmt::Display for ShowStatementFilter {
27452747
}
27462748
}
27472749

2748-
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
2749-
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
2750-
pub enum SetVariableValue {
2751-
Ident(Ident),
2752-
Literal(Value),
2753-
}
2754-
2755-
impl fmt::Display for SetVariableValue {
2756-
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
2757-
use SetVariableValue::*;
2758-
match self {
2759-
Ident(ident) => write!(f, "{}", ident),
2760-
Literal(literal) => write!(f, "{}", literal),
2761-
}
2762-
}
2763-
}
2764-
27652750
/// Sqlite specific syntax
27662751
///
27672752
/// https://sqlite.org/lang_conflict.html

src/dialect/mysql.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ impl Dialect for MySqlDialect {
2424
|| ('A'..='Z').contains(&ch)
2525
|| ch == '_'
2626
|| ch == '$'
27+
|| ch == '@'
2728
|| ('\u{0080}'..='\u{ffff}').contains(&ch)
2829
}
2930

src/parser.rs

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3751,22 +3751,12 @@ impl<'a> Parser<'a> {
37513751
} else if self.consume_token(&Token::Eq) || self.parse_keyword(Keyword::TO) {
37523752
let mut values = vec![];
37533753
loop {
3754-
let token = self.peek_token();
3755-
let value = match (self.parse_value(), token) {
3756-
(Ok(value), _) => SetVariableValue::Literal(value),
3757-
(Err(_), Token::Word(ident)) => SetVariableValue::Ident(ident.to_ident()),
3758-
(Err(_), Token::Minus) => {
3759-
let next_token = self.next_token();
3760-
match next_token {
3761-
Token::Word(ident) => SetVariableValue::Ident(Ident {
3762-
quote_style: ident.quote_style,
3763-
value: format!("-{}", ident.value),
3764-
}),
3765-
_ => self.expected("word", next_token)?,
3766-
}
3767-
}
3768-
(Err(_), unexpected) => self.expected("variable value", unexpected)?,
3754+
let value = if let Ok(expr) = self.parse_expr() {
3755+
expr
3756+
} else {
3757+
self.expected("variable value", self.peek_token())?
37693758
};
3759+
37703760
values.push(value);
37713761
if self.consume_token(&Token::Comma) {
37723762
continue;

tests/sqlparser_common.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ fn parse_select_count_wildcard() {
580580

581581
#[test]
582582
fn parse_select_count_distinct() {
583-
let sql = "SELECT COUNT(DISTINCT + x) FROM customer";
583+
let sql = "SELECT COUNT(DISTINCT +x) FROM customer";
584584
let select = verified_only_select(sql);
585585
assert_eq!(
586586
&Expr::Function(Function {
@@ -597,8 +597,8 @@ fn parse_select_count_distinct() {
597597
);
598598

599599
one_statement_parses_to(
600-
"SELECT COUNT(ALL + x) FROM customer",
601-
"SELECT COUNT(+ x) FROM customer",
600+
"SELECT COUNT(ALL +x) FROM customer",
601+
"SELECT COUNT(+x) FROM customer",
602602
);
603603

604604
let sql = "SELECT COUNT(ALL DISTINCT + x) FROM customer";
@@ -754,7 +754,7 @@ fn parse_compound_expr_2() {
754754
#[test]
755755
fn parse_unary_math() {
756756
use self::Expr::*;
757-
let sql = "- a + - b";
757+
let sql = "-a + -b";
758758
assert_eq!(
759759
BinaryOp {
760760
left: Box::new(UnaryOp {

tests/sqlparser_hive.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
//! Test SQL syntax specific to Hive. The parser based on the generic dialect
1616
//! is also tested (on the inputs it can handle).
1717
18-
use sqlparser::ast::{CreateFunctionUsing, Ident, ObjectName, SetVariableValue, Statement};
18+
use sqlparser::ast::{CreateFunctionUsing, Expr, Ident, ObjectName, Statement, UnaryOperator};
1919
use sqlparser::dialect::{GenericDialect, HiveDialect};
2020
use sqlparser::parser::ParserError;
2121
use sqlparser::test_utils::*;
@@ -220,14 +220,17 @@ fn set_statement_with_minus() {
220220
Ident::new("java"),
221221
Ident::new("opts")
222222
]),
223-
value: vec![SetVariableValue::Ident("-Xmx4g".into())],
223+
value: vec![Expr::UnaryOp {
224+
op: UnaryOperator::Minus,
225+
expr: Box::new(Expr::Identifier(Ident::new("Xmx4g")))
226+
}],
224227
}
225228
);
226229

227230
assert_eq!(
228231
hive().parse_sql_statements("SET hive.tez.java.opts = -"),
229232
Err(ParserError::ParserError(
230-
"Expected word, found: EOF".to_string()
233+
"Expected variable value, found: EOF".to_string()
231234
))
232235
)
233236
}

tests/sqlparser_mysql.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,26 @@ fn parse_use() {
251251
);
252252
}
253253

254+
#[test]
255+
fn parse_set_variables() {
256+
mysql_and_generic().verified_stmt("SET sql_mode = CONCAT(@@sql_mode, ',STRICT_TRANS_TABLES')");
257+
assert_eq!(
258+
mysql_and_generic().verified_stmt("SET LOCAL autocommit = 1"),
259+
Statement::SetVariable {
260+
local: true,
261+
hivevar: false,
262+
variable: ObjectName(vec!["autocommit".into()]),
263+
value: vec![Expr::Value(Value::Number(
264+
#[cfg(not(feature = "bigdecimal"))]
265+
"1".to_string(),
266+
#[cfg(feature = "bigdecimal")]
267+
bigdecimal::BigDecimal::from(1),
268+
false
269+
))],
270+
}
271+
);
272+
}
273+
254274
#[test]
255275
fn parse_create_table_auto_increment() {
256276
let sql = "CREATE TABLE foo (bar INT PRIMARY KEY AUTO_INCREMENT)";

tests/sqlparser_postgres.rs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
mod test_utils;
1919
use test_utils::*;
2020

21-
use sqlparser::ast::Value::Boolean;
2221
use sqlparser::ast::*;
2322
use sqlparser::dialect::{GenericDialect, PostgreSqlDialect};
2423
use sqlparser::parser::ParserError;
@@ -782,7 +781,10 @@ fn parse_set() {
782781
local: false,
783782
hivevar: false,
784783
variable: ObjectName(vec![Ident::new("a")]),
785-
value: vec![SetVariableValue::Ident("b".into())],
784+
value: vec![Expr::Identifier(Ident {
785+
value: "b".into(),
786+
quote_style: None
787+
})],
786788
}
787789
);
788790

@@ -793,9 +795,7 @@ fn parse_set() {
793795
local: false,
794796
hivevar: false,
795797
variable: ObjectName(vec![Ident::new("a")]),
796-
value: vec![SetVariableValue::Literal(Value::SingleQuotedString(
797-
"b".into()
798-
))],
798+
value: vec![Expr::Value(Value::SingleQuotedString("b".into()))],
799799
}
800800
);
801801

@@ -806,7 +806,13 @@ fn parse_set() {
806806
local: false,
807807
hivevar: false,
808808
variable: ObjectName(vec![Ident::new("a")]),
809-
value: vec![SetVariableValue::Literal(number("0"))],
809+
value: vec![Expr::Value(Value::Number(
810+
#[cfg(not(feature = "bigdecimal"))]
811+
"0".to_string(),
812+
#[cfg(feature = "bigdecimal")]
813+
bigdecimal::BigDecimal::from(0),
814+
false,
815+
))],
810816
}
811817
);
812818

@@ -817,7 +823,10 @@ fn parse_set() {
817823
local: false,
818824
hivevar: false,
819825
variable: ObjectName(vec![Ident::new("a")]),
820-
value: vec![SetVariableValue::Ident("DEFAULT".into())],
826+
value: vec![Expr::Identifier(Ident {
827+
value: "DEFAULT".into(),
828+
quote_style: None
829+
})],
821830
}
822831
);
823832

@@ -828,7 +837,7 @@ fn parse_set() {
828837
local: true,
829838
hivevar: false,
830839
variable: ObjectName(vec![Ident::new("a")]),
831-
value: vec![SetVariableValue::Ident("b".into())],
840+
value: vec![Expr::Identifier("b".into())],
832841
}
833842
);
834843

@@ -839,7 +848,10 @@ fn parse_set() {
839848
local: false,
840849
hivevar: false,
841850
variable: ObjectName(vec![Ident::new("a"), Ident::new("b"), Ident::new("c")]),
842-
value: vec![SetVariableValue::Ident("b".into())],
851+
value: vec![Expr::Identifier(Ident {
852+
value: "b".into(),
853+
quote_style: None
854+
})],
843855
}
844856
);
845857

@@ -859,7 +871,7 @@ fn parse_set() {
859871
Ident::new("reducer"),
860872
Ident::new("parallelism")
861873
]),
862-
value: vec![SetVariableValue::Literal(Boolean(false))],
874+
value: vec![Expr::Value(Value::Boolean(false))],
863875
}
864876
);
865877

@@ -1107,7 +1119,7 @@ fn parse_pg_unary_ops() {
11071119
];
11081120

11091121
for (str_op, op) in pg_unary_ops {
1110-
let select = pg().verified_only_select(&format!("SELECT {} a", &str_op));
1122+
let select = pg().verified_only_select(&format!("SELECT {}a", &str_op));
11111123
assert_eq!(
11121124
SelectItem::UnnamedExpr(Expr::UnaryOp {
11131125
op: op.clone(),

0 commit comments

Comments
 (0)