diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/client.rs | 58 | ||||
| -rw-r--r-- | src/command/mod.rs | 51 | ||||
| -rw-r--r-- | src/command/ping.rs | 23 | ||||
| -rw-r--r-- | src/command/track.rs | 133 | ||||
| -rw-r--r-- | src/handler/mod.rs | 67 | ||||
| -rw-r--r-- | src/http/github.rs | 68 | ||||
| -rw-r--r-- | src/http/mod.rs | 44 | ||||
| -rw-r--r-- | src/main.rs | 29 |
8 files changed, 473 insertions, 0 deletions
diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..a779a3b --- /dev/null +++ b/src/client.rs @@ -0,0 +1,58 @@ +use std::sync::Arc; + +use crate::{ + handler::Handler, + http::{self, HttpClientExt}, +}; + +use eyre::Result; +use serenity::prelude::{Client, GatewayIntents, TypeMapKey}; +use tracing::trace; + +/// Container for [http::Client] +pub struct SharedClient; + +impl TypeMapKey for SharedClient { + type Value = Arc<http::Client>; +} + +/// Fetch our bot token +fn token() -> Result<String> { + let token = std::env::var("DISCORD_BOT_TOKEN")?; + Ok(token) +} + +/// Create our client +#[tracing::instrument] +pub async fn get() -> Client { + let token = token().expect("Couldn't find token in environment! Is DISCORD_BOT_TOKEN set?"); + + let intents = GatewayIntents::default(); + trace!("Creating client"); + let client = Client::builder(token, intents) + .event_handler(Handler) + .await + .expect("Couldn't create a client!"); + + // add state stuff + { + let mut data = client.data.write().await; + trace!("Creating HTTP client"); + let http_client = <http::Client as HttpClientExt>::default(); + trace!("Inserting HTTP client into Discord client"); + data.insert::<SharedClient>(Arc::new(http_client)) + } + + let shard_manager = client.shard_manager.clone(); + + // gracefully shutdown on ctrl+c + tokio::spawn(async move { + #[cfg(target_family = "unix")] + tokio::signal::ctrl_c() + .await + .expect("Couldn't registrl ctrl+c handler!"); + shard_manager.shutdown_all().await; + }); + + client +} diff --git a/src/command/mod.rs b/src/command/mod.rs new file mode 100644 index 0000000..eda4167 --- /dev/null +++ b/src/command/mod.rs @@ -0,0 +1,51 @@ +use eyre::{OptionExt, Result}; +use serenity::builder::{ + CreateCommand, CreateInteractionResponse, CreateInteractionResponseMessage, +}; +use serenity::model::application::CommandInteraction; +use serenity::prelude::Context; +use tracing::instrument; + +use crate::client::SharedClient; + +mod ping; +mod track; + +macro_rules! cmd { + ($module: ident) => { + $module::register() + }; +} + +/// Return a list of all our [CreateCommand]s +pub fn to_vec() -> Vec<CreateCommand> { + vec![cmd!(ping), cmd!(track)] +} + +/// Dispatch our commands from a [CommandInteraction] +#[instrument(skip(ctx))] +pub async fn dispatch(ctx: &Context, command: &CommandInteraction) -> Result<()> { + let command_name = command.data.name.as_str(); + + // grab our http client from the aether + let http = { + let read = ctx.data.read().await; + read.get::<SharedClient>() + .ok_or_eyre("Couldn't get shared HTTP client! WHY??????")? + .clone() + }; + + match command_name { + "ping" => ping::respond(ctx, command).await?, + "track" => track::respond(ctx, &http, command).await?, + _ => { + let message = CreateInteractionResponseMessage::new().content(format!( + "It doesn't look like you can use `{command_name}`. Sorry :(" + )); + let response = CreateInteractionResponse::Message(message); + command.create_response(&ctx, response).await? + } + }; + + Ok(()) +} diff --git a/src/command/ping.rs b/src/command/ping.rs new file mode 100644 index 0000000..1b1b812 --- /dev/null +++ b/src/command/ping.rs @@ -0,0 +1,23 @@ +use eyre::Result; +use serenity::builder::{ + CreateCommand, CreateInteractionResponse, CreateInteractionResponseMessage, +}; +use serenity::model::application::{CommandInteraction, InstallationContext}; +use serenity::prelude::Context; +use tracing::{instrument, trace}; + +#[instrument] +pub async fn respond(ctx: &Context, command: &CommandInteraction) -> Result<()> { + trace!("Responding to ping command"); + let message = CreateInteractionResponseMessage::new().content("Pong!"); + let response = CreateInteractionResponse::Message(message); + command.create_response(&ctx, response).await?; + + Ok(()) +} + +pub fn register() -> CreateCommand { + CreateCommand::new("ping") + .description("Check if the bot is up") + .add_integration_type(InstallationContext::User) +} diff --git a/src/command/track.rs b/src/command/track.rs new file mode 100644 index 0000000..6217043 --- /dev/null +++ b/src/command/track.rs @@ -0,0 +1,133 @@ +use crate::http::{Client, GithubClientExt}; + +use eyre::Result; +use futures::future::try_join_all; +use serenity::builder::{CreateCommand, CreateCommandOption, CreateInteractionResponseFollowup}; +use serenity::model::application::{ + CommandInteraction, CommandOptionType, InstallationContext, ResolvedOption, ResolvedValue, +}; +use serenity::prelude::Context; +use tracing::{instrument, trace}; + +/// All of our tracked branches in nixpkgs +const BRANCHES: [&str; 8] = [ + "master", + "staging", + "nixos-unstable", + "nixos-unstable-small", + "nixos-24.05-small", + "release-24.05", + "nixos-23.11-small", + "release-23.11", +]; + +#[derive(Clone, Debug, Default)] +struct BranchStatus { + repo_owner: String, + repo_name: String, + name: String, +} + +impl BranchStatus { + fn new(repo_owner: String, repo_name: String, name: String) -> Self { + Self { + repo_owner, + repo_name, + name, + } + } + + /// Make a nice friendly string displaying if this branch has a PR merged into it + fn to_status_string(&self, has_pr: bool) -> String { + let emoji = if has_pr { "✅" } else { "❌" }; + format!("`{}` {emoji}", &self.name) + } + + /// Check if this branch has the specified pull request merged into it + #[instrument(skip(http))] + async fn has_pr(&self, http: &Client, pr: u64) -> Result<bool> { + let commit = http + .merge_commit_for( + &self.repo_owner, + &self.repo_name, + u64::try_from(pr).unwrap(), + ) + .await?; + + let has_pr = http + .is_commit_in_branch(&self.repo_owner, &self.repo_name, &self.name, &commit) + .await?; + + Ok(has_pr) + } +} + +/// async wrapper for [BranchStatus::to_status_string()] +#[instrument(skip(http))] +async fn collect_status( + http: &Client, + repo_owner: String, + repo_name: String, + branch: String, + pr: u64, +) -> Result<String> { + let status = BranchStatus::new(repo_owner, repo_name, branch); + let has_pr = status.has_pr(http, pr).await?; + let res = status.to_status_string(has_pr); + + Ok(res) +} + +#[instrument(skip_all)] +pub async fn respond(ctx: &Context, http: &Client, command: &CommandInteraction) -> Result<()> { + trace!("Responding to track command"); + + // this will probably take a while + command.defer(&ctx).await?; + + // TODO: make these configurable for nixpkgs forks...or other github repos ig + const REPO_OWNER: &str = "NixOS"; + const REPO_NAME: &str = "nixpkgs"; + + let options = command.data.options(); + + let response = if let Some(ResolvedOption { + value: ResolvedValue::Integer(pr), + .. + }) = options.first() + { + if *pr < 0 { + CreateInteractionResponseFollowup::new().content("PR numbers aren't negative...") + } else { + // TODO: this is gross + let statuses = try_join_all(BRANCHES.iter().map(|&branch| { + collect_status( + http, + REPO_OWNER.to_string(), + REPO_NAME.to_string(), + branch.to_string(), + u64::try_from(*pr).unwrap(), + ) + })) + .await?; + + CreateInteractionResponseFollowup::new().content(statuses.join("\n")) + } + } else { + CreateInteractionResponseFollowup::new().content("Please provide a valid commit!") + }; + + command.create_followup(&ctx, response).await?; + + Ok(()) +} + +pub fn register() -> CreateCommand { + CreateCommand::new("track") + .description("Track a nixpkgs PR") + .add_integration_type(InstallationContext::User) + .add_option( + CreateCommandOption::new(CommandOptionType::Integer, "pull_request", "PR to track") + .required(true), + ) +} diff --git a/src/handler/mod.rs b/src/handler/mod.rs new file mode 100644 index 0000000..47e2774 --- /dev/null +++ b/src/handler/mod.rs @@ -0,0 +1,67 @@ +use crate::command; + +use std::error::Error; + +use serenity::async_trait; +use serenity::builder::{CreateEmbed, CreateInteractionResponse, CreateInteractionResponseMessage}; +use serenity::model::{ + application::{Command, Interaction}, + colour::Colour, + gateway::Ready, +}; +use serenity::prelude::{Context, EventHandler}; +use tracing::{debug, error, info, instrument}; + +#[derive(Clone, Copy, Debug)] +pub struct Handler; + +impl Handler { + async fn register_commands(&self, ctx: &Context) -> Result<(), Box<dyn Error>> { + let commands = command::to_vec(); + let commands_len = commands.len(); + for command in commands { + Command::create_global_command(&ctx.http, command).await?; + } + + debug!("Registered {} commands", commands_len); + Ok(()) + } +} + +#[async_trait] +impl EventHandler for Handler { + #[instrument(skip_all)] + async fn interaction_create(&self, ctx: Context, interaction: Interaction) { + if let Interaction::Command(command) = interaction { + let command_name = &command.data.name; + debug!("Received command: {}", command_name); + + if let Err(why) = command::dispatch(&ctx, &command).await { + error!( + "Ran into an error while dispatching command {}:\n{why:?}", + command_name + ); + + let embed = CreateEmbed::new() + .title("An error occurred") + .description("Sorry about that!") + .color(Colour::RED); + let message = CreateInteractionResponseMessage::new().embed(embed); + let response = CreateInteractionResponse::Message(message); + + if let Err(why) = command.create_response(&ctx.http, response).await { + error!("Ran into an error while trying to recover from an error!\n{why:?}"); + } + } + } + } + + #[instrument(skip_all)] + async fn ready(&self, ctx: Context, ready: Ready) { + info!("Connected as {}!", ready.user.name); + + if let Err(why) = self.register_commands(&ctx).await { + error!("Couldn't register commands!\n{why:?}"); + }; + } +} diff --git a/src/http/github.rs b/src/http/github.rs new file mode 100644 index 0000000..bdb363e --- /dev/null +++ b/src/http/github.rs @@ -0,0 +1,68 @@ +use super::{Error, HttpClientExt}; + +use serde::Deserialize; + +const GITHUB_API: &str = "https://api.github.com"; + +/// Bad version of `/repos/{owner}/{repo}/{compare}/{ref}...{ref}` +#[derive(Deserialize)] +struct Compare { + status: String, + ahead_by: i32, +} + +/// Bad version of `/repos/{owner}/{repo}/pulls/{pull_number}` +#[derive(Deserialize)] +struct PullRequest { + merge_commit_sha: String, +} + +pub trait GithubClientExt { + /// Get the commit that merged [`pr`] in [`repo_owner`]/[`repo_name`] + async fn merge_commit_for( + &self, + repo_owner: &str, + repo_name: &str, + pr: u64, + ) -> Result<String, Error>; + + /// Check if [`commit`] is in [`branch`] of [`repo_owner`]/[`repo_name`] + async fn is_commit_in_branch( + &self, + repo_owner: &str, + repo_name: &str, + branch_name: &str, + commit: &str, + ) -> Result<bool, Error>; +} + +impl GithubClientExt for super::Client { + async fn merge_commit_for( + &self, + repo_owner: &str, + repo_name: &str, + pr: u64, + ) -> Result<String, Error> { + let url = format!("{GITHUB_API}/repos/{repo_owner}/{repo_name}/pulls/{pr}"); + let resp: PullRequest = self.get_json(&url).await?; + let merge_commit = resp.merge_commit_sha; + + Ok(merge_commit) + } + + async fn is_commit_in_branch( + &self, + repo_owner: &str, + repo_name: &str, + branch: &str, + commit: &str, + ) -> Result<bool, Error> { + let url = format!( + "https://api.github.com/repos/{repo_owner}/{repo_name}/compare/{branch}...{commit}" + ); + let resp: Compare = self.get_json(&url).await?; + let in_branch = resp.status != "diverged" && resp.ahead_by >= 0; + + Ok(in_branch) + } +} diff --git a/src/http/mod.rs b/src/http/mod.rs new file mode 100644 index 0000000..fa60d67 --- /dev/null +++ b/src/http/mod.rs @@ -0,0 +1,44 @@ +use serde::de::DeserializeOwned; +use tracing::trace; + +mod github; + +pub use github::*; + +pub type Client = reqwest::Client; +pub type Response = reqwest::Response; +pub type Error = reqwest::Error; + +/// Fun trait for functions we use with [Client] +pub trait HttpClientExt { + fn default() -> Self; + async fn get_request(&self, url: &str) -> Result<Response, Error>; + async fn get_json<T: DeserializeOwned>(&self, url: &str) -> Result<T, Error>; +} + +impl HttpClientExt for Client { + fn default() -> Self { + reqwest::Client::builder() + .user_agent(format!( + "nixpkgs-tracker-bot/{}", + option_env!("CARGO_PKG_VERSION").unwrap_or_else(|| "development") + )) + .build() + .unwrap() + } + + async fn get_request(&self, url: &str) -> Result<Response, Error> { + trace!("Making GET request to {url}"); + + let resp = self.get(url).send().await?; + resp.error_for_status_ref()?; + + Ok(resp) + } + + async fn get_json<T: DeserializeOwned>(&self, url: &str) -> Result<T, Error> { + let resp = self.get_request(url).await?; + let json = resp.json().await?; + Ok(json) + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..3a604ba --- /dev/null +++ b/src/main.rs @@ -0,0 +1,29 @@ +use eyre::Result; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +mod client; +mod command; +mod handler; +mod http; + +fn init_logging() { + let fmt_layer = tracing_subscriber::fmt::layer().pretty(); + let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "nixpkgs_discord_tracker=info,warn".into()); + + tracing_subscriber::registry() + .with(fmt_layer) + .with(env_filter) + .init(); +} + +#[tokio::main] +async fn main() -> Result<()> { + dotenvy::dotenv().ok(); + init_logging(); + + let mut client = client::get().await; + client.start().await?; + + Ok(()) +} |
