Skip to content

Commit

Permalink
refactor: use arcstr except for IDs (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
mmta authored Mar 26, 2024
1 parent 2fee2b0 commit 6a233c0
Show file tree
Hide file tree
Showing 13 changed files with 257 additions and 230 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ axum-extra = { version = "0.9.2", features = ["erased-json"] }
ratelimit = "0.9.0"
atomic-counter = "1.0.1"
tower-http = { version = "0.5.1", features = ["timeout", "fs" ] , default-features = false }
arcstr = { version = "1.1.5", features = ["serde"] }

[dev-dependencies]
table-test = "0.2.1"
Expand Down
11 changes: 9 additions & 2 deletions server/benches/generator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::ops::Range;

use arcstr::ArcStr;
use dsiem::rule::{SIDPair, TaxoPair};
use rand::{distributions::Alphanumeric, Rng};

Expand Down Expand Up @@ -27,10 +28,16 @@ pub fn gen_taxopairs(correct_pair: TaxoPair, n: u64) -> Vec<TaxoPair> {
} else {
// use one of the correct product for the 2nd half
let pos = rand::thread_rng().gen_range(0..correct_pair.product.len());
vec![correct_pair.product[pos].clone()]
vec![correct_pair.product[pos].to_string()]
};
let category: String = rand::thread_rng().sample_iter(&Alphanumeric).take(10).map(char::from).collect();
pairs.push(TaxoPair { product, category });
let mut p: Vec<ArcStr> = vec![];

for i in product {
let a = ArcStr::from(i.as_str());
p.push(a)
}
pairs.push(TaxoPair { product: p, category: category.into() });
}
pairs.push(correct_pair);
pairs
Expand Down
3 changes: 1 addition & 2 deletions server/benches/quick_checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ pub fn quick_check_taxo_rule_with_rayon(pairs: &[TaxoPair], e: &NormalizedEvent)
}

fn bench_quick_check_taxo_rule(c: &mut Criterion) {
let correct_pair =
TaxoPair { product: vec!["Suricata".to_string(), "Snort".to_string()], category: "Firewall".to_string() };
let correct_pair = TaxoPair { product: vec!["Suricata".into(), "Snort".into()], category: "Firewall".into() };
let event = NormalizedEvent {
product: correct_pair.product[1].clone(),
category: correct_pair.category.clone(),
Expand Down
102 changes: 56 additions & 46 deletions server/src/backlog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::{
};

use anyhow::{anyhow, Result};
use arcstr::ArcStr;
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use serde::Deserialize;
Expand All @@ -34,8 +35,8 @@ pub mod manager;

#[derive(Serialize, Deserialize, Debug, Clone, Hash, PartialEq, Eq)]
pub struct CustomData {
pub label: String,
pub content: String,
pub label: ArcStr,
pub content: ArcStr,
}

#[derive(Serialize)]
Expand All @@ -61,19 +62,19 @@ pub enum BacklogState {
pub struct Backlog {
#[serde(rename(serialize = "alarm_id", deserialize = "alarm_id"))]
pub id: String,
pub title: String,
pub status: String,
pub tag: String,
pub kingdom: String,
pub category: String,
pub title: ArcStr,
pub status: ArcStr,
pub tag: ArcStr,
pub kingdom: ArcStr,
pub category: ArcStr,
pub created_time: AtomicI64,
pub update_time: AtomicI64,
pub risk: AtomicU8,
pub risk_class: Mutex<String>,
pub risk_class: Mutex<ArcStr>,
pub rules: Vec<DirectiveRule>,
pub src_ips: Mutex<HashSet<IpAddr>>,
pub dst_ips: Mutex<HashSet<IpAddr>>,
pub networks: Mutex<HashSet<String>>,
pub networks: Mutex<HashSet<ArcStr>>,
#[serde(skip_serializing_if = "is_locked_data_empty")]
#[serde(default)]
pub intel_hits: Mutex<HashSet<IntelResult>>,
Expand Down Expand Up @@ -179,8 +180,8 @@ pub struct BacklogOpt {
pub bp_tx: Sender<()>,
pub delete_tx: Option<Sender<()>>, // allow late initialization
pub min_alarm_lifetime: i64,
pub default_status: String,
pub default_tag: String,
pub default_status: ArcStr,
pub default_tag: ArcStr,
pub med_risk_min: u8,
pub med_risk_max: u8,
pub intel_private_ip: bool,
Expand All @@ -194,7 +195,7 @@ impl Backlog {
title: o.directive.name.clone(),
kingdom: o.directive.kingdom.clone(),
category: o.directive.category.clone(),
status: o.default_status.to_owned(),
status: o.default_status.clone(),
tag: o.default_tag.to_owned(),
intel_private_ip: o.intel_private_ip,
current_stage: AtomicU8::new(1),
Expand All @@ -219,15 +220,15 @@ impl Backlog {
} else {
v.src_ip.to_string()
};
backlog.title = backlog.title.replace("SRC_IP", &src);
backlog.title = backlog.title.replace("SRC_IP", &src).into();
}
if backlog.title.contains("DST_IP") {
let dst = if let Some(hostname) = backlog.assets.get_name(&v.dst_ip) {
hostname
} else {
v.dst_ip.to_string()
};
backlog.title = backlog.title.replace("DST_IP", &dst);
backlog.title = backlog.title.replace("DST_IP", &dst).into();
}

backlog.rules = o.directive.init_backlog_rules(v.as_ref());
Expand Down Expand Up @@ -317,15 +318,15 @@ impl Backlog {
// - status, kingdom, tag, category, created_time are empty;
let mut backlog = Backlog {
id: (*running.id).to_string(),
title: (*running.title).to_string(),
status: (*running.status).to_string(),
kingdom: (*running.kingdom).to_string(),
category: (*running.category).to_string(),
tag: (*running.tag).to_string(),
title: running.title.clone(),
status: running.status.clone(),
kingdom: running.kingdom.clone(),
category: running.category.clone(),
tag: running.tag.clone(),
created_time: AtomicI64::new(running.created_time.load(Relaxed)),
update_time: AtomicI64::new(running.update_time.load(Relaxed)),
risk: AtomicU8::new(running.risk.load(Relaxed)),
risk_class: (*running.risk_class.lock()).to_string().into(),
risk_class: (*running.risk_class.lock()).clone().into(),
rules: running.rules.clone(),
src_ips: (*running.src_ips.lock()).clone().into(),
dst_ips: (*running.dst_ips.lock()).clone().into(),
Expand Down Expand Up @@ -580,7 +581,7 @@ impl Backlog {
fn set_rule_status(&self, status: &str) -> Result<()> {
let curr_rule = self.current_rule()?;
let mut w = curr_rule.status.lock();
*w = status.to_owned();
*w = status.into();
Ok(())
}
fn set_rule_endtime(&self, t: DateTime<Utc>) -> Result<()> {
Expand Down Expand Up @@ -619,11 +620,11 @@ impl Backlog {
let mut w = self.risk_class.lock();
let risk = self.risk.load(Relaxed);
*w = if risk < self.med_risk_min {
"Low".to_string()
"Low".into()
} else if risk >= self.med_risk_min && risk <= self.med_risk_max {
"Medium".to_string()
"Medium".into()
} else {
"High".to_string()
"High".into()
};
}

Expand Down Expand Up @@ -716,13 +717,22 @@ impl Backlog {
{
let mut w = self.custom_data.lock();
if !event.custom_data1.is_empty() {
w.insert(CustomData { label: event.custom_label1.clone(), content: event.custom_data1.clone() });
w.insert(CustomData {
label: event.custom_label1.clone().into(),
content: event.custom_data1.clone().into(),
});
}
if !event.custom_data2.is_empty() {
w.insert(CustomData { label: event.custom_label2.clone(), content: event.custom_data2.clone() });
w.insert(CustomData {
label: event.custom_label2.clone().into(),
content: event.custom_data2.clone().into(),
});
}
if !event.custom_data3.is_empty() {
w.insert(CustomData { label: event.custom_label3.clone(), content: event.custom_data3.clone() });
w.insert(CustomData {
label: event.custom_label3.clone().into(),
content: event.custom_data3.clone().into(),
});
}
}

Expand Down Expand Up @@ -770,7 +780,7 @@ impl Backlog {
for ip in r.iter() {
if let Some(v) = self.assets.get_asset_networks(ip) {
for x in v {
w.insert(x);
w.insert(x.into());
}
}
}
Expand Down Expand Up @@ -960,8 +970,8 @@ mod test {
bp_tx,
delete_tx: Some(mgr_delete_tx),
min_alarm_lifetime: 0,
default_status: "Open".to_string(),
default_tag: "Identified Threat".to_string(),
default_status: "Open".into(),
default_tag: "Identified Threat".into(),
med_risk_min: 3,
med_risk_max: 6,
intel_private_ip: true,
Expand Down Expand Up @@ -1013,7 +1023,7 @@ mod test {
// but different title
let b = Backlog::new(&get_opt()).unwrap();
let mut saveable = Backlog::saveable_version(Arc::new(b));
saveable.title = "foo".to_string();
saveable.title = "foo".into();
let res = Backlog::runnable_version(get_opt(), saveable);
assert!(res.unwrap_err().to_string().contains("different title detected"));

Expand All @@ -1022,7 +1032,7 @@ mod test {
let b = Backlog::new(&get_opt()).unwrap();
let mut saveable = Backlog::saveable_version(Arc::new(b));
for rule in saveable.rules.iter_mut() {
rule.status = Arc::new(Mutex::new("finished".to_string()));
rule.status = Arc::new(Mutex::new("finished".into()));
}
let res = Backlog::runnable_version(get_opt(), saveable);
assert!(res.unwrap_err().to_string().contains("skipping this backlog"));
Expand Down Expand Up @@ -1058,12 +1068,12 @@ mod test {
dst_ip: "192.168.0.2".parse().unwrap(),
src_port: 31337,
dst_port: 80,
custom_label1: "label".to_string(),
custom_data1: "data".to_string(),
custom_label2: "label".to_string(),
custom_data2: "data".to_string(),
custom_label3: "label".to_string(),
custom_data3: "data".to_string(),
custom_label1: "label".into(),
custom_data1: "data".into(),
custom_label2: "label".into(),
custom_data2: "data".into(),
custom_label3: "label".into(),
custom_data3: "data".into(),
rcvd_time: now - 10000,
..Default::default()
};
Expand All @@ -1077,8 +1087,8 @@ mod test {
event: Some(Arc::new(evt_cloned.clone())),
bp_tx,
delete_tx: Some(mgr_delete_tx),
default_status: "Open".to_string(),
default_tag: "Identified Threat".to_string(),
default_status: "Open".into(),
default_tag: "Identified Threat".into(),
min_alarm_lifetime: 0,
med_risk_min: 3,
med_risk_max: 6,
Expand Down Expand Up @@ -1181,8 +1191,8 @@ mod test {
event: Some(Arc::new(evt_cloned.clone())),
bp_tx,
delete_tx: Some(mgr_delete_tx),
default_status: "Open".to_string(),
default_tag: "Identified Threat".to_string(),
default_status: "Open".into(),
default_tag: "Identified Threat".into(),
min_alarm_lifetime: 0,
med_risk_min: 3,
med_risk_max: 5,
Expand Down Expand Up @@ -1265,8 +1275,8 @@ mod test {
let evt = NormalizedEvent {
plugin_id: 1337,
plugin_sid: 1,
custom_label1: "label".to_string(),
custom_data1: "data".to_string(),
custom_label1: "label".into(),
custom_data1: "data".into(),
..Default::default()
};

Expand All @@ -1279,8 +1289,8 @@ mod test {
event: Some(Arc::new(evt.clone())),
bp_tx,
delete_tx: Some(mgr_delete_tx),
default_status: "Open".to_string(),
default_tag: "Identified Threat".to_string(),
default_status: "Open".into(),
default_tag: "Identified Threat".into(),
min_alarm_lifetime: 0,
med_risk_min: 3,
med_risk_max: 6,
Expand Down
5 changes: 3 additions & 2 deletions server/src/bin/dsiem-backend.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{process::ExitCode, sync::Arc, thread, time::Duration};

use anyhow::{anyhow, Result};
use arcstr::ArcStr;
use clap::{arg, command, Args, Parser, Subcommand};
use dsiem::{
asset::NetworkAssets,
Expand Down Expand Up @@ -80,7 +81,7 @@ struct ServeArgs {
value_delimiter = ',',
default_value = "Open,In-Progress,Closed"
)]
status: Vec<String>,
status: Vec<ArcStr>,
/// Alarm tags to use, the first one will be assigned to new alarms
#[arg(
short('t'),
Expand All @@ -91,7 +92,7 @@ struct ServeArgs {
value_delimiter = ',',
default_value = "Identified Threat,False Positive,Valid Threat,Security Incident"
)]
tags: Vec<String>,
tags: Vec<ArcStr>,
/// Minimum alarm risk value to be classified as Medium risk. Lower value
/// than this will be classified as Low risk
#[arg(long = "med_risk_min", value_name = "2 to 8", env = "DSIEM_MEDRISKMIN", default_value_t = 3)]
Expand Down
Loading

0 comments on commit 6a233c0

Please sign in to comment.