From 824a5cdb434637ac57b72e2695f78407402d2aab Mon Sep 17 00:00:00 2001 From: Nicholas Orlowsky Date: Sat, 23 Dec 2023 12:03:25 -0500 Subject: [PATCH] logic exprs --- CHANGELOG.md | 8 + README.md | 14 +- squirrel-client/src/main.rs | 7 +- squirrel-client/test_queries.sql | 32 +++ squirrel-server/src/main.rs | 128 +++++++++-- squirrel-server/src/parser/command.rs | 311 +++++++++++++++++++++++++- 6 files changed, 474 insertions(+), 26 deletions(-) create mode 100644 CHANGELOG.md create mode 100644 squirrel-client/test_queries.sql diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..0079cf5 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,8 @@ +# Changelog + +## 12/23/23 + +- Added LogicExpressions for creating and evaluating expressions +- Added LogicExpressions to SELECT and DELETE commands (i.e SELECT FROM WHERE and DELETE FROM WHERE) +- Updated SELECT formatting to pad table columns +- General system stability improvements to enhance the user's experience. diff --git a/README.md b/README.md index 01d14d7..b328c5e 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,19 @@ This is a SQL database written in Rust. It will be based off of (and hopefully b [x] DELETE command -[ ] WHERE clause for SELECT and DELETE +[x] WHERE clause for SELECT and DELETE + +[ ] Create squirrel-core library for shared code between client & server + +[ ] Update parser to use common logic to identify 'objects' (i.e function calls, column references, and variables) + +[ ] Move parsing to client + +[ ] Create better logging + +[ ] UPDATE command + +[ ] Prune deleted records from disk [ ] Primary Keys via B+ Tree diff --git a/squirrel-client/src/main.rs b/squirrel-client/src/main.rs index 556ade9..0dacab4 100644 --- a/squirrel-client/src/main.rs +++ b/squirrel-client/src/main.rs @@ -6,13 +6,15 @@ fn main() { match TcpStream::connect("localhost:5433") { Ok(mut stream) => { println!("Connected to Database"); - loop { print!("SQUIRREL: "); io::stdout().flush().unwrap(); let mut msg_str = String::new(); - std::io::stdin().read_line(&mut msg_str).unwrap(); + let bytes = std::io::stdin().read_line(&mut msg_str).unwrap(); + if bytes == 0 { + break; + } let msg = msg_str.as_bytes(); stream.write(msg).unwrap(); @@ -32,4 +34,5 @@ fn main() { println!("Failed to connect: {}", e); } } + println!("Goodbye!"); } diff --git a/squirrel-client/test_queries.sql b/squirrel-client/test_queries.sql new file mode 100644 index 0000000..78f0af9 --- /dev/null +++ b/squirrel-client/test_queries.sql @@ -0,0 +1,32 @@ +CREATE TABLE users (id int, first_name varchar 128, last_name varchar 128, address varchar 128, age int); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (1, "Alex", "Karev", "613 Harper Lane Seattle, Washington", 33); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (2, "Richard", "Hendricks", "5230 Newell Road Palo Alto, California", 24); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (3, "Dinesh", "Chugati", "5230 Newell Road Palo Alto, California", 23); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (4, "Michael", "Scott", "1725 Slough Avenue Scranton, Pennsylvania", 40); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (5, "Walter", "White", "308 Negra Arroyo Lane Albuquerque, New Mexico", 50); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (6, "Jerry", "Seinfeld", "129 West 81st Street Apartment 5A New York, New York", 38); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (7, "Elaine", "Benes", "162 Riverside Drive Apartment 3E New York, New York", 36); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (8, "George", "Costanza", "129 West 81st Street Apartment 4B New York, New York", 39); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (9, "Cosmo", "Kramer", "129 West 81st Street Apartment 5B New York, New York", 41); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (10, "Newman", "Newman", "The Postal Office, New York, New York", 45); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (11, "Frank", "Costanza", "329 West 81st Street Apartment 5A New York, New York", 68); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (12, "Estelle", "Costanza", "Del Boca Vista Phase III, Florida", 65); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (13, "Jesse", "Pinkman", "9809 Margo Street Albuquerque, New Mexico", 27); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (14, "Skyler", "White", "308 Negra Arroyo Lane Albuquerque, New Mexico", 42); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (15, "Saul", "Goodman", "160 Juan Tabo Boulevard, Suite 503 Albuquerque, New Mexico", 50); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (16, "Gus", "Fring", "Los Pollos Hermanos, Albuquerque, New Mexico", 48); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (17, "Hank", "Schrader", "4901 Cumbre del Sur Court Albuquerque, New Mexico", 43); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (18, "Wendy", "S.", "Riverside Motel, Albuquerque, New Mexico", 32); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (19, "Tuco", "Salamanca", "1230 Negra Arroyo Lane Albuquerque, New Mexico", 34); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (20, "Gale", "Boetticher", "308 Negra Arroyo Lane Albuquerque, New Mexico", 32); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (21, "Badger", "Mayhew", "RV, Somewhere in the Desert", 29); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (22, "Skinny", "Pete", "Apartment 23, Albuquerque, New Mexico", 30); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (23, "Meredith", "Grey", "Grey Sloan Memorial Hospital, Seattle, Washington", 35); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (24, "Derek", "Shepherd", "Grey Sloan Memorial Hospital, Seattle, Washington", 40); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (25, "Cristina", "Yang", "1234 Chief Webber's Apartment, Seattle, Washington", 32); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (26, "Miranda", "Bailey", "Grey Sloan Memorial Hospital, Seattle, Washington", 38); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (27, "Arizona", "Robbins", "567 Surgical Wing, Grey Sloan Memorial Hospital, Seattle, Washington", 37); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (28, "Callie", "Torres", "789 Orthopedic Wing, Grey Sloan Memorial Hospital, Seattle, Washington", 36); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (29, "Mark", "Sloan", "Grey Sloan Memorial Hospital, Seattle, Washington", 40); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (30, "April", "Kepner", "1010 Mercy West Hospital, Seattle, Washington", 30); +INSERT INTO users (id, first_name, last_name, address, age) VALUES (31, "Jackson", "Avery", "678 Plastic Surgery Wing, Grey Sloan Memorial Hospital, Seattle, Washington", 34); diff --git a/squirrel-server/src/main.rs b/squirrel-server/src/main.rs index 6f2c9fa..08a7a3c 100644 --- a/squirrel-server/src/main.rs +++ b/squirrel-server/src/main.rs @@ -4,12 +4,13 @@ use std::io::{BufRead, BufReader, Read, Write}; use std::net::{Shutdown, TcpListener, TcpStream}; use std::thread; use std::collections::HashMap; +use std::cmp; mod parser; pub use parser::command::Command; mod table; -use parser::command::{CreateCommand, InsertCommand, SelectCommand, DeleteCommand}; +use parser::command::{CreateCommand, InsertCommand, SelectCommand, DeleteCommand, LogicExpression, InsertItem, LogicValue}; pub use table::datatypes::Datatype; pub use table::table_definition::{ColumnDefinition, TableDefinition}; @@ -93,17 +94,40 @@ fn handle_delete(command: DeleteCommand) -> ::anyhow::Result { let mut buf: Vec = vec![0; tabledef.get_byte_size()]; let mut row_count: usize = 0; - while file.read_exact(buf.as_mut_slice()).is_ok() { - row_count += 1; - } - - let _ = fs::remove_file(format!("./data/blobs/{}", command.table_name)); - - let _ = fs::OpenOptions::new() + let mut new_file = fs::OpenOptions::new() .create(true) .write(true) .append(true) - .open(format!("./data/blobs/{}", command.table_name))?; + .open(format!("./data/blobs/{}_new", command.table_name))?; + + while file.read_exact(buf.as_mut_slice()).is_ok() { + let mut row_data: HashMap = HashMap::new(); + let mut idx: usize = 0; + if let Some(ref le) = command.logic_expression { + let mut logic_expr = le.clone(); + for col_def in &tabledef.column_defs { + let len = if col_def.length > 0 { + col_def.length + } else { + 1 + }; + let str_val = col_def.data_type.from_bytes(&buf[idx..(idx + len)])?; + idx += len; + row_data.insert(col_def.name.clone(), LogicValue::from_string(str_val)?); + } + idx = 0; + logic_expr.fill_values(row_data); + if !logic_expr.evaluate()? { + new_file.write_all(&buf)?; + continue; + } + } + row_count += 1; + } + new_file.flush()?; + + let _ = fs::remove_file(format!("./data/blobs/{}", command.table_name))?; + let _ = fs::rename(format!("./data/blobs/{}_new", command.table_name), format!("./data/blobs/{}", command.table_name))?; return Ok(format!("{} Rows Deleted", row_count)); } @@ -124,15 +148,10 @@ fn handle_select(command: SelectCommand) -> ::anyhow::Result { } } - response += "| "; - for col_name in &column_names { - response += format!("{} | ", col_name).as_str(); - } - response += "\n"; - response += "-----------\n"; let mut buf: Vec = vec![0; tabledef.get_byte_size()]; let mut table: HashMap> = HashMap::new(); + let mut longest_cols: HashMap = HashMap::new(); let mut num_rows: usize = 0; for col_name in &column_names { @@ -141,6 +160,26 @@ fn handle_select(command: SelectCommand) -> ::anyhow::Result { while file.read_exact(buf.as_mut_slice()).is_ok() { let mut idx: usize = 0; + let mut row_data: HashMap = HashMap::new(); + if let Some(ref le) = command.logic_expression { + let mut logic_expr = le.clone(); + for col_def in &tabledef.column_defs { + let len = if col_def.length > 0 { + col_def.length + } else { + 1 + }; + let str_val = col_def.data_type.from_bytes(&buf[idx..(idx + len)])?; + idx += len; + row_data.insert(col_def.name.clone(), LogicValue::from_string(str_val)?); + } + idx = 0; + logic_expr.fill_values(row_data); + if !logic_expr.evaluate()? { + continue; + } + } + for col_def in &tabledef.column_defs { let len = if col_def.length > 0 { col_def.length @@ -148,18 +187,36 @@ fn handle_select(command: SelectCommand) -> ::anyhow::Result { 1 }; if column_names.iter().any(|col_name| &col_def.name == col_name) { - let str_val = col_def.data_type.from_bytes(&buf[idx..(idx + len)])?; - table.get_mut(&col_def.name).unwrap().push(str_val); + let str_val = col_def.data_type.from_bytes(&buf[idx..(idx + len)])?.trim_matches(char::from(0)).to_string(); + table.get_mut(&col_def.name).unwrap().push(str_val.clone()); + longest_cols.entry(col_def.name.clone()).and_modify(|val| *val = cmp::max(*val, str_val.len())).or_insert(str_val.len()); } idx += len; } num_rows += 1; } + + + // construct table string + response += "| "; + for col_name in &column_names { + longest_cols.entry(col_name.clone()).and_modify(|val| *val = cmp::max(*val, col_name.len())).or_insert(col_name.len()); + response += format!("{:0width$} | ", col_name, width = longest_cols.get(col_name).unwrap()).as_str(); + } + let mut total_length: usize = 1; + for (col_name, max_len) in longest_cols.clone() { + total_length += max_len + 3; + } + response += "\n"; + for i in 0..total_length { + response += "-"; + } + response += "\n"; for i in 0..num_rows { response += "| "; for col_name in &column_names { - response += format!("{} | ", table.get(col_name).unwrap()[i]).as_str(); + response += format!("{:0width$} | ", table.get(col_name).unwrap()[i], width = longest_cols.get(col_name).unwrap()).as_str(); } response += "\n"; } @@ -261,6 +318,40 @@ fn main() -> std::io::Result<()> { Ok(()) } +#[test] +fn logical_expression() -> anyhow::Result<()> { + assert_eq!(Command::le_from_string(String::from("1 < 5")).unwrap().evaluate().unwrap(), true); + assert_eq!(Command::le_from_string(String::from("1 > 5")).unwrap().evaluate().unwrap(), false); + assert_eq!(Command::le_from_string(String::from("1 <= 5")).unwrap().evaluate().unwrap(), true); + assert_eq!(Command::le_from_string(String::from("1 >= 5")).unwrap().evaluate().unwrap(), false); + assert_eq!(Command::le_from_string(String::from("5 >= 5")).unwrap().evaluate().unwrap(), true); + assert_eq!(Command::le_from_string(String::from("5 <= 5")).unwrap().evaluate().unwrap(), true); + assert_eq!(Command::le_from_string(String::from("5 = 5")).unwrap().evaluate().unwrap(), true); + assert_eq!(Command::le_from_string(String::from("5 AND 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("5 OR 5")).unwrap().evaluate().is_ok(), false); + + assert_eq!(Command::le_from_string(String::from("'Test' = 'Test'")).unwrap().evaluate().unwrap(), true); + assert_eq!(Command::le_from_string(String::from("'Test' = 'Text'")).unwrap().evaluate().unwrap(), false); + assert_eq!(Command::le_from_string(String::from("'Test' <= 'Test'")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' >= 'Test'")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' < 'Test'")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' > 'Test'")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' AND 'Test'")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' OR 'Test'")).unwrap().evaluate().is_ok(), false); + + assert_eq!(Command::le_from_string(String::from("'Test' < 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' > 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' <= 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' >= 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' >= 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' <= 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' = 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' AND 5")).unwrap().evaluate().is_ok(), false); + assert_eq!(Command::le_from_string(String::from("'Test' OR 5")).unwrap().evaluate().is_ok(), false); + + Ok(()) +} + #[test] fn insert_statement() -> anyhow::Result<()> { let empty_statement = ""; @@ -335,5 +426,6 @@ fn insert_statement() -> anyhow::Result<()> { expected_output_comma ); + Ok(()) } diff --git a/squirrel-server/src/parser/command.rs b/squirrel-server/src/parser/command.rs index fe64142..bff00d5 100644 --- a/squirrel-server/src/parser/command.rs +++ b/squirrel-server/src/parser/command.rs @@ -1,4 +1,5 @@ use std::collections::{HashMap, HashSet}; +use std::mem; use crate::table::table_definition::ColumnDefinition; use crate::{Datatype, TableDefinition}; @@ -20,6 +21,7 @@ pub struct CreateCommand { #[derive(Debug, Eq, PartialEq)] pub struct DeleteCommand { pub table_name: String, + pub logic_expression: Option, } #[derive(Debug, Eq, PartialEq)] @@ -32,6 +34,7 @@ pub struct InsertCommand { pub struct SelectCommand { pub table_name: String, pub column_names: Vec, + pub logic_expression: Option, } #[derive(Debug, Eq, PartialEq)] @@ -40,6 +43,38 @@ pub struct InsertItem { pub column_value: String, } +#[derive(Debug, Eq, PartialEq, Clone)] +enum LogicalOperator { + Equal, + GreaterThan, + LessThan, + GreaterThanEqualTo, + LessThanEqualTo, + And, + Or, +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub enum LogicValue { + StringValue(String), + U8Value(u8), + BoolValue(bool), + ColumnName(String), +} + +#[derive(Debug, Eq, PartialEq)] +pub enum LogicSide { + // pub expression: LogicExpression, + Value(LogicValue) +} + +#[derive(Debug, Eq, PartialEq, Clone)] +pub struct LogicExpression { + pub left_hand: LogicValue, + pub right_hand: LogicValue, + pub operator: LogicalOperator +} + enum CreateParserState { Object, TableName, @@ -55,12 +90,14 @@ enum SelectParserState { ColumnName, ColumnNameCommaOrFrom, TableName, + WhereKeywordOrSemicolon, Semicolon, } enum DeleteParserState { FromKeyword, TableName, + WhereKeywordOrSemicolon, Semicolon, } @@ -77,8 +114,17 @@ enum InsertParserState { Semicolon, } +#[derive(Debug)] +enum LogicExpressionParserState { + NumberOrQuoteOrColname, + StringValue, + EndQuote, + Operator, +} + + pub fn tokenizer(text: String) -> Vec { - let parts = HashSet::from([' ', ',', ';', '(', ')']); + let parts = HashSet::from([' ', ',', ';', '(', ')', '\'']); let mut tokens: Vec = vec![]; let mut cur_str = String::new(); let mut in_quotes = false; @@ -100,10 +146,154 @@ pub fn tokenizer(text: String) -> Vec { cur_str.push(cur_char); } } + tokens.push(cur_str); tokens } +impl LogicValue { + pub fn from_string(string: String) -> ::anyhow::Result { + let test = string.parse::(); + match test { + Ok(u8_val) => { + return Ok(LogicValue::U8Value(u8_val)); + }, + Err(_) => { + let res = string.trim_matches(char::from(0)); + return Ok(LogicValue::StringValue(res.to_string())); + }, + } + } +} + +impl LogicExpression { + pub fn is_valid(&self) -> bool { + return mem::discriminant(&self.left_hand) == mem::discriminant(&self.right_hand); + } + + pub fn is_evaluatable(&self) -> bool { + return mem::discriminant(&self.left_hand) != mem::discriminant(&LogicValue::ColumnName(String::from(""))) && + mem::discriminant(&self.right_hand) != mem::discriminant(&LogicValue::ColumnName(String::from(""))); + } + + pub fn fill_values(&mut self, hmap: HashMap) -> ::anyhow::Result<()> { + for (name, value) in hmap { + if self.left_hand == LogicValue::ColumnName(name.clone()) { + self.left_hand = value.clone(); + } + if self.right_hand == LogicValue::ColumnName(name.clone()) { + self.right_hand = value.clone(); + } + } + Ok(()) + } + + pub fn evaluate(&self) -> ::anyhow::Result { + if !self.is_evaluatable() { + return Err(anyhow!("Logical expression has not been properly filled. (Do you have a typo in a column name?)")); + } + if !self.is_valid() { + return Err(anyhow!("Logical expression is comparing 2 differing datatypes")); + } + println!("{:?}", self); + match self.left_hand { + LogicValue::StringValue(_) => { + return self.evaluate_string(); + } + LogicValue::BoolValue(_) => { + return self.evaluate_bool(); + } + LogicValue::U8Value(_) => { + return self.evaluate_u8(); + } + LogicValue::ColumnName(_) => { + return Err(anyhow!("Cannot compare names of 2 columns, only values")); + } + } + } + + fn evaluate_string(&self) -> ::anyhow::Result { + match self.operator { + LogicalOperator::Equal => { + return Ok(self.left_hand == self.right_hand); + } + _ => { + return Err(anyhow!("Invalid operator for datatype varchar")); + } + } + } + + fn evaluate_bool(&self) -> ::anyhow::Result{ + match self.operator { + LogicalOperator::Equal => { + return Ok(self.left_hand == self.right_hand); + } + LogicalOperator::And => { + if let LogicValue::BoolValue(left) = self.left_hand { + if let LogicValue::BoolValue(right) = self.right_hand { + return Ok(left && right); + } + } + return Err(anyhow!("Mismatched datatypes")); + } + LogicalOperator::Or => { + if let LogicValue::BoolValue(left) = self.left_hand { + if let LogicValue::BoolValue(right) = self.right_hand { + return Ok(left || right); + } + } + return Err(anyhow!("Mismatched datatypes")); + } + _ => { + return Err(anyhow!("Invalid operator for datatype bool")); + } + } + } + + fn evaluate_u8(&self) -> ::anyhow::Result{ + match self.operator { + LogicalOperator::Equal => { + return Ok(self.left_hand == self.right_hand); + } + LogicalOperator::GreaterThan => { + if let LogicValue::U8Value(left) = self.left_hand { + if let LogicValue::U8Value(right) = self.right_hand { + return Ok(left > right); + } + } + return Err(anyhow!("Mismatched datatypes")); + } + LogicalOperator::LessThan => { + if let LogicValue::U8Value(left) = self.left_hand { + if let LogicValue::U8Value(right) = self.right_hand { + return Ok(left < right); + } + } + return Err(anyhow!("Mismatched datatypes")); + } + LogicalOperator::GreaterThanEqualTo => { + if let LogicValue::U8Value(left) = self.left_hand { + if let LogicValue::U8Value(right) = self.right_hand { + return Ok(left >= right); + } + } + return Err(anyhow!("Mismatched datatypes")); + } + LogicalOperator::LessThanEqualTo => { + if let LogicValue::U8Value(left) = self.left_hand { + if let LogicValue::U8Value(right) = self.right_hand { + return Ok(left <= right); + } + } + return Err(anyhow!("Mismatched datatypes")); + } + _ => { + return Err(anyhow!("Invalid operator for datatype integer")); + } + } + } +} + impl Command { fn parse_insert_command(tokens: &mut Vec) -> ::anyhow::Result { let mut state: InsertParserState = InsertParserState::IntoKeyword; @@ -218,12 +408,93 @@ impl Command { Err(anyhow!("Unexpected end of input")) } + fn parse_logic_expression(tokens: &mut Vec) -> ::anyhow::Result { + let mut state: LogicExpressionParserState = LogicExpressionParserState::NumberOrQuoteOrColname; + let mut left_hand: Option = None; + let mut right_hand: Option = None; + let mut operator: Option = None; + + while let Some(token) = &tokens.pop() { + match state { + LogicExpressionParserState::NumberOrQuoteOrColname => { + if token == "'" { + state = LogicExpressionParserState::StringValue; + } else { + let test = token.parse::(); + match test { + Ok(u8_val) => { + if left_hand.is_none() { + left_hand = Some(LogicValue::U8Value(u8_val)); + state = LogicExpressionParserState::Operator; + } else { + right_hand = Some(LogicValue::U8Value(u8_val)); + return Ok(LogicExpression {left_hand: left_hand.unwrap(), right_hand: right_hand.unwrap(), operator: operator.unwrap()}); + } + }, + Err(_) => { + if left_hand.is_none() { + left_hand = Some(LogicValue::ColumnName(token.to_string())); + state = LogicExpressionParserState::Operator; + } else { + right_hand = Some(LogicValue::ColumnName(token.to_string())); + return Ok(LogicExpression {left_hand: left_hand.unwrap(), right_hand: right_hand.unwrap(), operator: operator.unwrap()}); + } + }, + } + } + + } + LogicExpressionParserState::StringValue => { + let mut value: Option = None; + if token == "'" { + value = Some(LogicValue::StringValue("".to_string())); + } else { + value = Some(LogicValue::StringValue(token.to_string())); + } + if left_hand.is_none() { + left_hand = value; + } else { + right_hand = value; + } + state = LogicExpressionParserState::EndQuote; + } + LogicExpressionParserState::EndQuote => { + if token == "'" { + if right_hand.is_none() { + state = LogicExpressionParserState::Operator; + } else { + return Ok(LogicExpression {left_hand: left_hand.unwrap(), right_hand: right_hand.unwrap(), operator: operator.unwrap()}); + } + } else { + return Err(anyhow!("Expected end quote at or near {}", token)); + } + } + LogicExpressionParserState::Operator => { + operator = match token.as_str() { + "OR" => Some(LogicalOperator::Or), + "AND" => Some(LogicalOperator::And), + "=" => Some(LogicalOperator::Equal), + ">" => Some(LogicalOperator::GreaterThan), + "<" => Some(LogicalOperator::LessThan), + ">=" => Some(LogicalOperator::GreaterThanEqualTo), + "<=" => Some(LogicalOperator::LessThanEqualTo), + _ => return Err(anyhow!("Unknown operator {}", token)) + }; + state = LogicExpressionParserState::NumberOrQuoteOrColname; + } + } + } + + Err(anyhow!("Unexpected end of input")) + } + fn parse_select_command(tokens: &mut Vec) -> ::anyhow::Result { let mut state: SelectParserState = SelectParserState::ColumnName; // intermediate tmp vars let mut table_name = String::new(); let mut column_names: Vec = vec![]; + let mut logic_expression: Option = None; while let Some(token) = &tokens.pop() { match state { @@ -246,13 +517,24 @@ impl Command { } SelectParserState::TableName => { table_name = token.to_string(); - state = SelectParserState::Semicolon; + state = SelectParserState::WhereKeywordOrSemicolon; + } + SelectParserState::WhereKeywordOrSemicolon => { + if token == ";" { + return Ok(Command::Select(SelectCommand { table_name, column_names, logic_expression: None })); + } else if token == "WHERE" { + logic_expression = Some(Self::parse_logic_expression(tokens)?); + state = SelectParserState::Semicolon; + } else { + return Err(anyhow!("Expected semicolon at or near '{}'", token)); + + } } SelectParserState::Semicolon => { if token != ";" { return Err(anyhow!("Expected semicolon at or near '{}'", token)); } else { - return Ok(Command::Select(SelectCommand { table_name, column_names })); + return Ok(Command::Select(SelectCommand { table_name, column_names, logic_expression })); } } } @@ -267,6 +549,7 @@ impl Command { // intermediate tmp vars let mut table_name = String::new(); + let mut logic_expression: Option = None; while let Some(token) = &tokens.pop() { match state { @@ -279,13 +562,23 @@ impl Command { } DeleteParserState::TableName => { table_name = token.to_string(); - state = DeleteParserState::Semicolon; + state = DeleteParserState::WhereKeywordOrSemicolon; + } + DeleteParserState::WhereKeywordOrSemicolon => { + if token == ";" { + return Ok(Command::Delete(DeleteCommand { table_name, logic_expression: None })); + } else if token == "WHERE" { + logic_expression = Some(Self::parse_logic_expression(tokens)?); + state = DeleteParserState::Semicolon; + } else { + return Err(anyhow!("Expected semicolon at or near '{}'", token)); + } } DeleteParserState::Semicolon => { if token != ";" { return Err(anyhow!("Expected semicolon at or near '{}'", token)); } else { - return Ok(Command::Delete(DeleteCommand { table_name })); + return Ok(Command::Delete(DeleteCommand { table_name, logic_expression })); } } } @@ -400,4 +693,12 @@ impl Command { Err(anyhow!("Unexpected end of statement")) } + + pub fn le_from_string(command_str: String) -> ::anyhow::Result { + let mut tokens: Vec = tokenizer(command_str.clone()); + println!("{}", command_str); + println!("{:?}", tokens); + tokens.reverse(); + return Ok(Self::parse_logic_expression(&mut tokens)?); + } }