diff --git a/bot/src/cli.rs b/bot/src/cli.rs index 1427e08..4d029b3 100644 --- a/bot/src/cli.rs +++ b/bot/src/cli.rs @@ -10,4 +10,12 @@ pub struct CliArgs { /// Redis stream name #[arg(short = 't', long)] pub redis_stream_name: String, + + /// Redis consumer group name + #[arg(short = 'c', long)] + pub redis_consumer_group: String, + + /// The current consumer name + #[arg(short = 'n', long)] + pub redis_consumer_name: String, } diff --git a/bot/src/main.rs b/bot/src/main.rs index 7209312..b241c81 100644 --- a/bot/src/main.rs +++ b/bot/src/main.rs @@ -1,6 +1,8 @@ use crate::cli::CliArgs; use clap::Parser; -use log::{error, info}; +use infrastructure::RedisService; +use log::{error, info, warn}; +use post::NewsPost; use signal_hook::consts::{SIGINT, SIGTERM}; use signal_hook::iterator::Signals; use std::sync::atomic::{AtomicBool, Ordering}; @@ -11,7 +13,7 @@ mod cli; //noinspection DuplicatedCode /// Sets up a signal handler in a separate thread to handle SIGINT and SIGTERM signals. -fn setup_graceful_shutdown(running: Arc) { +fn setup_graceful_shutdown(running: &Arc) { let r = running.clone(); thread::spawn(move || { let signals = Signals::new([SIGINT, SIGTERM]); @@ -36,7 +38,40 @@ async fn main() -> Result<(), anyhow::Error> { // Graceful shutdown. let running = Arc::new(AtomicBool::new(true)); - setup_graceful_shutdown(running); + setup_graceful_shutdown(&running); + // Redis setup + let mut redis_service = RedisService::new(&args.redis_connection_string).await; + + // Create consumer group for stream. + let result = redis_service + .create_group(&args.redis_stream_name, &args.redis_consumer_group, 0) + .await; + if let Err(err) = result { + warn!("{}", err); + } + + // Read from stream + while running.load(Ordering::SeqCst) { + match redis_service + .read_stream::( + &args.redis_stream_name, + &args.redis_consumer_group, + &args.redis_consumer_name, + 5000, + ) + .await + { + Ok(data) => { + // TODO: Implement + dbg!(data); + } + Err(err) => { + error!("error reading stream: {err}") + } + } + } + + info!("Stopping the program"); Ok(()) } diff --git a/infrastructure/Cargo.toml b/infrastructure/Cargo.toml index e666094..fed53ea 100644 --- a/infrastructure/Cargo.toml +++ b/infrastructure/Cargo.toml @@ -8,11 +8,12 @@ tokio = { version = "1", features = ["full"] } # Note: This appears unused by the RustRover analyzer, but it works. # If in the future it stops working for whatever reason because the dependency has # the same name as the module, then we can try to rename it using `package`. -redis = { version = "0.27.6", features = ["tokio-comp"] } +redis = { version = "0.27.6", features = ["tokio-comp", "streams"] } md5 = "0.7.0" serde = { version = "1.0.216", features = ["derive"] } serde_json = "1.0.134" log = "0.4.22" +anyhow = "1.0.95" [dev-dependencies] rand = "0.8.5" diff --git a/infrastructure/src/redis.rs b/infrastructure/src/redis.rs index 4792828..220b699 100644 --- a/infrastructure/src/redis.rs +++ b/infrastructure/src/redis.rs @@ -1,7 +1,10 @@ +use anyhow::anyhow; use log::error; use redis::aio::MultiplexedConnection; -use redis::{AsyncCommands, RedisError}; -use serde::Serialize; +use redis::streams::StreamReadReply; +use redis::Value::BulkString; +use redis::{AsyncCommands, RedisError, RedisResult}; +use serde::{Deserialize, Serialize}; pub struct RedisService { multiplexed_connection: MultiplexedConnection, @@ -53,6 +56,85 @@ impl RedisService { }; true } + + /// Creates a group for the given stream that consumes from the specified starting id. + pub async fn create_group( + &mut self, + stream_name: &str, + group_name: &str, + starting_id: u32, + ) -> Result<(), anyhow::Error> { + redis::cmd("XGROUP") + .arg("CREATE") + .arg(stream_name) + .arg(group_name) + .arg(starting_id) + .exec_async(&mut self.multiplexed_connection) + .await + .map_err(|e| { + anyhow!("failed to create group {group_name} for stream {stream_name}: {e}") + }) + } + + /// Reads a stream from Redis and in a blocking fashion. + /// + /// Messages are acknowledged automatically when read. + /// + /// stream_name - is the name of the stream + /// consumer_group - is the name of the consumer group + /// consumer_name - is the name of the current consumer + /// block_timeout - is the timeout in milliseconds to block for messages. + pub async fn read_stream( + &mut self, + stream_name: &str, + consumer_group: &str, + consumer_name: &str, + block_timeout: u32, + ) -> Result + where + T: for<'a> Deserialize<'a>, + { + let result: RedisResult = redis::cmd("XREADGROUP") + .arg("GROUP") + .arg(consumer_group) + .arg(consumer_name) + .arg("BLOCK") + .arg(block_timeout) + .arg("COUNT") + .arg(1) + .arg("NOACK") + .arg("STREAMS") + .arg(stream_name) + .arg(">") + .query_async(&mut self.multiplexed_connection) + .await; + + match result { + Ok(data) => { + if data.keys.is_empty() { + return Err(anyhow!("read stream entry with empty keys")); + } + if data.keys[0].ids.is_empty() { + return Err(anyhow!("read stream entry with empty ids")); + } + let stream = data.keys[0].ids[0].map.get("data"); + if let Some(BulkString(data)) = stream { + let string_data = std::str::from_utf8(data); + return match string_data { + Ok(string_data) => { + let deserialized_data: T = serde_json::from_str(string_data)?; + Ok(deserialized_data) + } + Err(err) => Err(anyhow!("can't convert data to string: {err}")), + }; + } + Err(anyhow!( + "invalid type read from streams, expected BulkString" + )) + } + Err(err) => Err(err.into()), + } + } } #[cfg(test)] @@ -140,4 +222,40 @@ mod tests { assert_eq!(stream_length, Ok(1)); cleanup(&mut service).await; } + + #[tokio::test] + #[serial] + async fn test_redis_service_read() -> Result<(), anyhow::Error> { + // Setup + let random_stream_name = Alphanumeric.sample_string(&mut rand::thread_rng(), 6); + + let mut service = RedisService::new(REDIS_CONNECTION_STRING).await; + let post = NewsPost { + image: Some(String::from("i")), + title: Some(String::from("t")), + summary: Some(String::from("s")), + link: Some(String::from("l")), + author: Some(String::from("a")), + }; + let _ = service.publish(&random_stream_name, &post).await; + + // Test + service + .create_group(&random_stream_name, &random_stream_name, 0) + .await?; + let result = service + .read_stream::( + &random_stream_name, + &random_stream_name, + &random_stream_name, + 10_000, + ) + .await?; + + assert_eq!(result, post); + + // Assert + cleanup(&mut service).await; + Ok(()) + } } diff --git a/post/src/lib.rs b/post/src/lib.rs index 43e3b2d..fdf1a30 100644 --- a/post/src/lib.rs +++ b/post/src/lib.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; /// NewsPost represents a news post. -#[derive(Serialize, Deserialize, Debug, Clone)] +#[derive(Serialize, Deserialize, Debug, Clone, PartialOrd, PartialEq)] pub struct NewsPost { /// A URL containing the image of the post. pub image: Option,