Compare commits

...

11 Commits

Author SHA1 Message Date
Erik Johnston
cdd83480dc Fix power levels and tests 2022-09-09 15:10:52 +01:00
Erik Johnston
e6a11bde38 Revert debug logging 2022-09-09 15:10:52 +01:00
Erik Johnston
310df25173 Fixup 2022-09-09 15:10:52 +01:00
Erik Johnston
df24732962 Experimental rules 2022-09-09 15:10:52 +01:00
Erik Johnston
bbba09e583 fix up 2022-09-09 15:10:52 +01:00
Erik Johnston
9e2712aa2f Fixup 2022-09-09 15:10:52 +01:00
Erik Johnston
88fddf991d Fix visibility 2022-09-09 15:10:52 +01:00
Erik Johnston
9f42e64145 Newsfile 2022-09-09 15:10:52 +01:00
Erik Johnston
32ea637ce7 stuff 2022-09-09 15:10:52 +01:00
Erik Johnston
a759cf7d82 wip 2022-09-09 15:10:52 +01:00
Erik Johnston
18b6ccadd2 Implement the push evaluator in Rust 2022-09-09 15:10:51 +01:00
14 changed files with 1166 additions and 46 deletions

1
.rustfmt.toml Normal file
View File

@@ -0,0 +1 @@
group_imports = "StdExternalCrate"

1
changelog.d/13720.misc Normal file
View File

@@ -0,0 +1 @@
Speed up push rule evaluation in rooms with large numbers of local users.

View File

@@ -18,4 +18,12 @@ crate-type = ["cdylib"]
name = "synapse.synapse_rust"
[dependencies]
pyo3 = { version = "0.16.5", features = ["extension-module", "macros", "abi3", "abi3-py37"] }
anyhow = "1.0.63"
lazy_static = "1.4.0"
log = "0.4.17"
pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] }
pyo3-log = "0.7.0"
pythonize = "0.17.0"
regex = "1.6.0"
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"

View File

@@ -1,5 +1,7 @@
use pyo3::prelude::*;
pub mod push;
/// Formats the sum of two numbers as string.
#[pyfunction]
#[pyo3(text_signature = "(a, b, /)")]
@@ -9,8 +11,12 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
/// The entry point for defining the Python module.
#[pymodule]
fn synapse_rust(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
pyo3_log::init();
m.add_function(wrap_pyfunction!(sum_as_string, m)?)?;
push::register_module(py, m)?;
Ok(())
}

306
rust/src/push/base_rules.rs Normal file
View File

@@ -0,0 +1,306 @@
//! Contains the definitions of the "base" push rules.
use std::borrow::Cow;
use std::collections::HashMap;
use lazy_static::lazy_static;
use serde_json::Value;
use crate::push::Action;
use crate::push::Condition;
use crate::push::EventMatchCondition;
use crate::push::PushRule;
use crate::push::SetTweak;
use crate::push::TweakValue;
const HIGHLIGHT_ACTION: Action = Action::SetTweak(SetTweak {
set_tweak: Cow::Borrowed("highlight"),
value: None,
other_keys: Value::Null,
});
const HIGHLIGHT_FALSE_ACTION: Action = Action::SetTweak(SetTweak {
set_tweak: Cow::Borrowed("highlight"),
value: Some(TweakValue::Other(Value::Bool(false))),
other_keys: Value::Null,
});
const SOUND_ACTION: Action = Action::SetTweak(SetTweak {
set_tweak: Cow::Borrowed("sound"),
value: Some(TweakValue::String(Cow::Borrowed("default"))),
other_keys: Value::Null,
});
const RING_ACTION: Action = Action::SetTweak(SetTweak {
set_tweak: Cow::Borrowed("sound"),
value: Some(TweakValue::String(Cow::Borrowed("ring"))),
other_keys: Value::Null,
});
pub const BASE_PREPEND_OVERRIDE_RULES: &[PushRule] = &[PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.master"),
priority_class: 5,
conditions: Cow::Borrowed(&[]),
actions: Cow::Borrowed(&[Action::DontNotify]),
default: true,
default_enabled: false,
}];
pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.suppress_notices"),
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("content.msgtype"),
pattern: Some(Cow::Borrowed("m.notice")),
pattern_type: None,
})]),
actions: Cow::Borrowed(&[Action::DontNotify]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.invite_for_me"),
priority_class: 5,
conditions: Cow::Borrowed(&[
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.room.member")),
pattern_type: None,
}),
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("content.membership"),
pattern: Some(Cow::Borrowed("invite")),
pattern_type: None,
}),
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("state_key"),
pattern: None,
pattern_type: Some(Cow::Borrowed("user_id")),
}),
]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.member_event"),
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.room.member")),
pattern_type: None,
})]),
actions: Cow::Borrowed(&[Action::DontNotify]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.contains_display_name"),
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::ContainsDisplayName]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.roomnotif"),
priority_class: 5,
conditions: Cow::Borrowed(&[
Condition::SenderNotificationPermission {
key: Cow::Borrowed("room"),
},
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("content.body"),
pattern: Some(Cow::Borrowed("@room")),
pattern_type: None,
}),
]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.tombstone"),
priority_class: 5,
conditions: Cow::Borrowed(&[
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.room.tombstone")),
pattern_type: None,
}),
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("state_key"),
pattern: Some(Cow::Borrowed("")),
pattern_type: None,
}),
]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.m.rule.reaction"),
priority_class: 5,
conditions: Cow::Borrowed(&[Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.reaction")),
pattern_type: None,
})]),
actions: Cow::Borrowed(&[Action::DontNotify]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/override/.org.matrix.msc3786.rule.room.server_acl"),
priority_class: 5,
conditions: Cow::Borrowed(&[
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.room.server_acl")),
pattern_type: None,
}),
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("state_key"),
pattern: Some(Cow::Borrowed("")),
pattern_type: None,
}),
]),
actions: Cow::Borrowed(&[]),
default: true,
default_enabled: true,
},
];
pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule {
rule_id: Cow::Borrowed("global/content/.m.rule.contains_user_name"),
priority_class: 4,
conditions: Cow::Borrowed(&[Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("content.body"),
pattern: None,
pattern_type: Some(Cow::Borrowed("user_localpart")),
})]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_ACTION, SOUND_ACTION]),
default: true,
default_enabled: true,
}];
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.call"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.call.invite")),
pattern_type: None,
})]),
actions: Cow::Borrowed(&[Action::Notify, RING_ACTION, HIGHLIGHT_FALSE_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.room_one_to_one"),
priority_class: 1,
conditions: Cow::Borrowed(&[
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.room.message")),
pattern_type: None,
}),
Condition::RoomMemberCount {
is: Some(Cow::Borrowed("2")),
},
]),
actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.encrypted_room_one_to_one"),
priority_class: 1,
conditions: Cow::Borrowed(&[
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.room.encrypted")),
pattern_type: None,
}),
Condition::RoomMemberCount {
is: Some(Cow::Borrowed("2")),
},
]),
actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION, HIGHLIGHT_FALSE_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3772.thread_reply"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::RelationMatch {
rel_type: Cow::Borrowed("m.thread"),
sender: None,
sender_type: Some(Cow::Borrowed("user_id")),
}]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.message"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.room.message")),
pattern_type: None,
})]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.encrypted"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("m.room.encrypted")),
pattern_type: None,
})]),
actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/underride/.im.vector.jitsi"),
priority_class: 1,
conditions: Cow::Borrowed(&[
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("type"),
pattern: Some(Cow::Borrowed("im.vector.modular.widgets")),
pattern_type: None,
}),
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("content.type"),
pattern: Some(Cow::Borrowed("jitsi")),
pattern_type: None,
}),
Condition::EventMatch(EventMatchCondition {
key: Cow::Borrowed("state_key"),
pattern: Some(Cow::Borrowed("*")),
pattern_type: None,
}),
]),
actions: Cow::Borrowed(&[HIGHLIGHT_FALSE_ACTION]),
default: true,
default_enabled: true,
},
];
lazy_static! {
pub static ref BASE_RULES_BY_ID: HashMap<&'static str, &'static PushRule> =
BASE_PREPEND_OVERRIDE_RULES
.iter()
.chain(BASE_APPEND_OVERRIDE_RULES.iter())
.chain(BASE_APPEND_CONTENT_RULES.iter())
.chain(BASE_APPEND_UNDERRIDE_RULES.iter())
.map(|rule| { (&*rule.rule_id, rule) })
.collect();
}

248
rust/src/push/evaluator.rs Normal file
View File

@@ -0,0 +1,248 @@
use std::collections::{BTreeMap, BTreeSet};
use anyhow::{Context, Error};
use log::{info, warn};
use pyo3::prelude::*;
use super::{
utils::{get_localpart_from_id, glob_to_regex, GlobMatchType},
Action, Condition, EventMatchCondition, FilteredPushRules, INEQUALITY_EXPR,
};
#[pyclass]
pub struct PushRuleEvaluator {
flattened_keys: BTreeMap<String, String>,
body: String,
room_member_count: u64,
notification_power_levels: BTreeMap<String, i64>,
relations: BTreeMap<String, BTreeSet<(String, String)>>,
relation_match_enabled: bool,
sender_power_level: i64,
}
#[pymethods]
impl PushRuleEvaluator {
#[new]
fn py_new(
flattened_keys: BTreeMap<String, String>,
room_member_count: u64,
sender_power_level: i64,
notification_power_levels: BTreeMap<String, i64>,
relations: BTreeMap<String, BTreeSet<(String, String)>>,
relation_match_enabled: bool,
) -> Result<Self, Error> {
let body = flattened_keys
.get("content.body")
.cloned()
.unwrap_or_default();
Ok(PushRuleEvaluator {
flattened_keys,
body,
room_member_count,
notification_power_levels,
relations,
relation_match_enabled,
sender_power_level,
})
}
fn run(
&self,
push_rules: &FilteredPushRules,
user_id: Option<&str>,
display_name: Option<&str>,
) -> Vec<Action> {
let mut actions = Vec::new();
'outer: for (push_rule, enabled) in push_rules.iter() {
if !enabled {
continue;
}
for condition in push_rule.conditions.iter() {
match self.match_condition(condition, user_id, display_name) {
Ok(true) => {}
Ok(false) => continue 'outer,
Err(err) => {
warn!("Condition match failed {err}");
continue 'outer;
}
}
}
actions.extend(
push_rule
.actions
.iter()
// .filter(|a| **a != Action::DontNotify)
.cloned(),
);
return actions;
}
actions
}
}
impl PushRuleEvaluator {
pub fn match_condition(
&self,
condition: &Condition,
user_id: Option<&str>,
display_name: Option<&str>,
) -> Result<bool, Error> {
let result = match condition {
Condition::EventMatch(event_match) => self.match_event_match(event_match, user_id)?,
Condition::ContainsDisplayName => {
if let Some(dn) = display_name {
let matcher = glob_to_regex(dn, GlobMatchType::Word)?;
matcher.is_match(&self.body)
} else {
false
}
}
Condition::RoomMemberCount { is } => {
if let Some(is) = is {
self.match_member_count(is)?
} else {
false
}
}
Condition::SenderNotificationPermission { key } => {
let required_level = self
.notification_power_levels
.get(key.as_ref())
.copied()
.unwrap_or(50);
info!(
"Power level {required_level} vs {}",
self.sender_power_level
);
self.sender_power_level >= required_level
}
Condition::RelationMatch {
rel_type,
sender,
sender_type,
} => {
if !self.relation_match_enabled {
return Ok(false);
}
let sender_pattern = if let Some(sender) = sender {
sender
} else if let Some(sender_type) = sender_type {
if sender_type == "user_id" {
if let Some(user_id) = user_id {
user_id
} else {
return Ok(false);
}
} else {
warn!("Unrecognized sender_type: {sender_type}");
return Ok(false);
}
} else {
warn!("relation_match condition missing sender or sender_type");
return Ok(false);
};
let relations = if let Some(relations) = self.relations.get(&**rel_type) {
relations
} else {
return Ok(false);
};
let sender_compiled_pattern = glob_to_regex(sender_pattern, GlobMatchType::Whole)?;
let rel_type_compiled_pattern = glob_to_regex(rel_type, GlobMatchType::Whole)?;
for (relation_sender, event_type) in relations {
if sender_compiled_pattern.is_match(&relation_sender)
&& rel_type_compiled_pattern.is_match(event_type)
{
return Ok(true);
}
}
false
}
};
Ok(result)
}
fn match_event_match(
&self,
event_match: &EventMatchCondition,
user_id: Option<&str>,
) -> Result<bool, Error> {
let pattern = if let Some(pattern) = &event_match.pattern {
pattern
} else if let Some(pattern_type) = &event_match.pattern_type {
let user_id = if let Some(user_id) = user_id {
user_id
} else {
return Ok(false);
};
match &**pattern_type {
"user_id" => user_id,
"user_localpart" => get_localpart_from_id(user_id)?,
_ => return Ok(false),
}
} else {
return Ok(false);
};
if event_match.key == "content.body" {
let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Word)?;
Ok(compiled_pattern.is_match(&self.body))
} else if let Some(value) = self.flattened_keys.get(&*event_match.key) {
let compiled_pattern = glob_to_regex(pattern, GlobMatchType::Whole)?;
Ok(compiled_pattern.is_match(value))
} else {
Ok(false)
}
}
fn match_member_count(&self, is: &str) -> Result<bool, Error> {
let captures = INEQUALITY_EXPR.captures(is).context("bad is clause")?;
let ineq = captures.get(1).map(|m| m.as_str()).unwrap_or("==");
let rhs: u64 = captures
.get(2)
.context("missing number")?
.as_str()
.parse()?;
let matches = match ineq {
"" | "==" => self.room_member_count == rhs,
"<" => self.room_member_count < rhs,
">" => self.room_member_count > rhs,
">=" => self.room_member_count >= rhs,
"<=" => self.room_member_count <= rhs,
_ => false,
};
Ok(matches)
}
}
#[test]
fn push_rule_evaluator() {
let mut flattened_keys = BTreeMap::new();
flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string());
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
10,
0,
BTreeMap::new(),
BTreeMap::new(),
true,
)
.unwrap();
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
assert_eq!(result.len(), 3);
}

393
rust/src/push/mod.rs Normal file
View File

@@ -0,0 +1,393 @@
//! An implementation of Matrix push rules.
//!
//! The `Cow<_>` type is used extensively within this module to allow creating
//! the base rules as constants (in Rust constants can't require explicit
//! allocation atm).
use std::borrow::Cow;
use std::collections::{BTreeMap, HashMap};
use anyhow::{Context, Error};
use lazy_static::lazy_static;
use log::warn;
use pyo3::prelude::*;
use pythonize::pythonize;
use regex::Regex;
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use self::evaluator::PushRuleEvaluator;
mod base_rules;
mod evaluator;
mod utils;
lazy_static! {
static ref INEQUALITY_EXPR: Regex = Regex::new(r"^([=<>]*)([0-9]*)$").expect("valid regex");
static ref WORD_BOUNDARY_EXPR: Regex = Regex::new(r"\W*\b\W*").expect("valid regex");
static ref WILDCARD_RUN: Regex = Regex::new(r"([^\?\*]*)([\?\*]*)").expect("valid regex");
}
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let child_module = PyModule::new(py, "push")?;
child_module.add_class::<PushRule>()?;
child_module.add_class::<PushRules>()?;
child_module.add_class::<PushRuleEvaluator>()?;
child_module.add_class::<FilteredPushRules>()?;
m.add_submodule(child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import push` work.
py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.push", child_module)?;
Ok(())
}
#[derive(Debug, Clone)]
#[pyclass(frozen)]
pub struct PushRule {
pub rule_id: Cow<'static, str>,
#[pyo3(get)]
pub priority_class: i32,
pub conditions: Cow<'static, [Condition]>,
pub actions: Cow<'static, [Action]>,
#[pyo3(get)]
pub default: bool,
#[pyo3(get)]
pub default_enabled: bool,
}
#[pymethods]
impl PushRule {
#[staticmethod]
pub fn from_db(
rule_id: String,
priority_class: i32,
conditions: &str,
actions: &str,
) -> Result<PushRule, Error> {
let conditions = serde_json::from_str(conditions).context("parsing conditions")?;
let actions = serde_json::from_str(actions).context("parsing actions")?;
Ok(PushRule {
rule_id: Cow::Owned(rule_id),
priority_class,
conditions,
actions,
default: false,
default_enabled: true,
})
}
#[getter]
fn rule_id(&self) -> &str {
&self.rule_id
}
#[getter]
fn actions(&self) -> Vec<Action> {
self.actions.clone().into_owned()
}
#[getter]
fn conditions(&self) -> Vec<Condition> {
self.conditions.clone().into_owned()
}
fn __repr__(&self) -> String {
format!(
"<PushRule rule_id={}, conditions={:?}, actions={:?}>",
self.rule_id, self.conditions, self.actions
)
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Action {
DontNotify,
Notify,
Coalesce,
SetTweak(SetTweak),
}
impl IntoPy<PyObject> for Action {
fn into_py(self, py: Python<'_>) -> PyObject {
pythonize(py, &self).expect("valid action")
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct SetTweak {
set_tweak: Cow<'static, str>,
#[serde(skip_serializing_if = "Option::is_none")]
value: Option<TweakValue>,
// This picks saves any other fields that may have been added as clients.
// These get added when we convert the `Action` to a python object.
#[serde(flatten)]
other_keys: Value,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum TweakValue {
String(Cow<'static, str>),
Other(Value),
}
impl Serialize for Action {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
Action::DontNotify => serializer.serialize_str("dont_notify"),
Action::Notify => serializer.serialize_str("notify"),
Action::Coalesce => serializer.serialize_str("coalesce"),
Action::SetTweak(tweak) => tweak.serialize(serializer),
}
}
}
#[derive(Deserialize)]
#[serde(untagged)]
enum ActionDeserializeHelper {
Str(String),
SetTweak(SetTweak),
}
impl<'de> Deserialize<'de> for Action {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let helper: ActionDeserializeHelper = Deserialize::deserialize(deserializer)?;
match helper {
ActionDeserializeHelper::Str(s) => match &*s {
"dont_notify" => Ok(Action::DontNotify),
"notify" => Ok(Action::Notify),
"coalesce" => Ok(Action::Coalesce),
_ => Err(D::Error::custom("unrecognized action")),
},
ActionDeserializeHelper::SetTweak(set_tweak) => Ok(Action::SetTweak(set_tweak)),
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
#[serde(tag = "kind")]
pub enum Condition {
EventMatch(EventMatchCondition),
ContainsDisplayName,
RoomMemberCount {
#[serde(skip_serializing_if = "Option::is_none")]
is: Option<Cow<'static, str>>,
},
SenderNotificationPermission {
key: Cow<'static, str>,
},
#[serde(rename = "org.matrix.msc3772.relation_match")]
RelationMatch {
rel_type: Cow<'static, str>,
#[serde(skip_serializing_if = "Option::is_none")]
sender: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
sender_type: Option<Cow<'static, str>>,
},
}
impl IntoPy<PyObject> for Condition {
fn into_py(self, py: Python<'_>) -> PyObject {
pythonize(py, &self).expect("valid condition")
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct EventMatchCondition {
key: Cow<'static, str>,
#[serde(skip_serializing_if = "Option::is_none")]
pattern: Option<Cow<'static, str>>,
#[serde(skip_serializing_if = "Option::is_none")]
pattern_type: Option<Cow<'static, str>>,
}
#[derive(Debug, Clone, Default)]
#[pyclass(frozen)]
struct PushRules {
overridden_base_rules: HashMap<Cow<'static, str>, PushRule>,
override_rules: Vec<PushRule>,
content: Vec<PushRule>,
room: Vec<PushRule>,
sender: Vec<PushRule>,
underride: Vec<PushRule>,
}
#[pymethods]
impl PushRules {
#[new]
fn new(rules: Vec<PushRule>) -> PushRules {
let mut push_rules: PushRules = Default::default();
for rule in rules {
if let Some(&o) = base_rules::BASE_RULES_BY_ID.get(&*rule.rule_id) {
push_rules.overridden_base_rules.insert(
rule.rule_id.clone(),
PushRule {
actions: rule.actions.clone(),
..o.clone()
},
);
continue;
}
match rule.priority_class {
5 => push_rules.override_rules.push(rule),
4 => push_rules.content.push(rule),
3 => push_rules.room.push(rule),
2 => push_rules.sender.push(rule),
1 => push_rules.underride.push(rule),
_ => {
warn!(
"Unrecognized priority class for rule {}: {}",
rule.rule_id, rule.priority_class
);
}
}
}
push_rules
}
fn rules(&self) -> Vec<PushRule> {
self.iter().cloned().collect()
}
}
impl PushRules {
pub fn iter(&self) -> impl Iterator<Item = &PushRule> {
base_rules::BASE_PREPEND_OVERRIDE_RULES
.iter()
.chain(self.override_rules.iter())
.chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter())
.chain(self.content.iter())
.chain(base_rules::BASE_APPEND_CONTENT_RULES.iter())
.chain(self.room.iter())
.chain(self.sender.iter())
.chain(self.underride.iter())
.chain(base_rules::BASE_APPEND_UNDERRIDE_RULES.iter())
.map(|rule| {
self.overridden_base_rules
.get(&*rule.rule_id)
.unwrap_or(rule)
})
}
}
#[derive(Debug, Clone, Default)]
#[pyclass(frozen)]
pub struct FilteredPushRules {
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
msc3786_enabled: bool,
msc3772_enabled: bool,
}
#[pymethods]
impl FilteredPushRules {
#[new]
fn py_new(
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
msc3786_enabled: bool,
msc3772_enabled: bool,
) -> Self {
Self {
push_rules,
enabled_map,
msc3786_enabled,
msc3772_enabled,
}
}
fn rules(&self) -> Vec<(PushRule, bool)> {
self.iter().map(|(r, e)| (r.clone(), e)).collect()
}
}
impl FilteredPushRules {
fn iter(&self) -> impl Iterator<Item = (&PushRule, bool)> {
self.push_rules
.iter()
.filter(|rule| {
// Ignore disabled experimental push rules
if !self.msc3786_enabled
&& rule.rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
{
return false;
}
if !self.msc3772_enabled
&& rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
{
return false;
}
true
})
.map(|r| {
let enabled = *self
.enabled_map
.get(&*r.rule_id)
.unwrap_or(&r.default_enabled);
(r, enabled)
})
}
}
#[test]
fn split_string() {
let split_body: Vec<_> = WORD_BOUNDARY_EXPR
.split("this is. A. TEST!!")
.filter(|s| *s != "")
.collect();
assert_eq!(split_body, ["this", "is", "A", "TEST"]);
}
#[test]
fn test_erialize_condition() {
let condition = Condition::EventMatch(EventMatchCondition {
key: "content.body".into(),
pattern: Some("coffee".into()),
pattern_type: None,
});
let json = serde_json::to_string(&condition).unwrap();
assert_eq!(
json,
r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#
)
}
#[test]
fn test_deserialize_condition() {
let json = r#"{"kind":"event_match","key":"content.body","pattern":"coffee"}"#;
let _: Condition = serde_json::from_str(json).unwrap();
}
#[test]
fn test_deserialize_action() {
let _: Action = serde_json::from_str(r#""notify""#).unwrap();
let _: Action = serde_json::from_str(r#""dont_notify""#).unwrap();
let _: Action = serde_json::from_str(r#""coalesce""#).unwrap();
let _: Action = serde_json::from_str(r#"{"set_tweak": "highlight"}"#).unwrap();
}

115
rust/src/push/utils.rs Normal file
View File

@@ -0,0 +1,115 @@
use anyhow::bail;
use anyhow::Context;
use anyhow::Error;
use regex;
use regex::Regex;
use regex::RegexBuilder;
use super::WILDCARD_RUN;
pub(crate) fn get_localpart_from_id(id: &str) -> Result<&str, Error> {
let (localpart, _) = id
.split_once(':')
.with_context(|| format!("ID does not contain colon: {id}"))?;
// We need to strip off the first character, which is the ID type.
if localpart.is_empty() {
bail!("Invalid ID {id}");
}
Ok(&localpart[1..])
}
pub(crate) enum GlobMatchType {
Whole,
Word,
}
pub(crate) fn glob_to_regex(glob: &str, match_type: GlobMatchType) -> Result<Regex, Error> {
let mut chunks = Vec::new();
for captures in WILDCARD_RUN.captures_iter(glob) {
if let Some(chunk) = captures.get(1) {
chunks.push(regex::escape(chunk.as_str()));
}
if let Some(wildcards) = captures.get(2) {
if wildcards.as_str() == "" {
continue;
}
let question_marks = wildcards.as_str().chars().filter(|c| *c == '?').count();
if wildcards.as_str().contains('*') {
chunks.push(format!(".{{{question_marks},}}"));
} else {
chunks.push(format!(".{{{question_marks}}}"));
}
}
}
let joined = chunks.join("");
let regex_str = match match_type {
GlobMatchType::Whole => format!(r"\A{joined}\z"),
// `^|\W` and `\W|$` handle the case where `pattern` starts or ends with a non-word
// character.
GlobMatchType::Word => format!(r"(?:^|\W|\b){joined}(?:\b|\W|$)"),
};
Ok(RegexBuilder::new(&regex_str)
.case_insensitive(true)
.build()?)
}
#[test]
fn test_get_domain_from_id() {
get_localpart_from_id("").unwrap_err();
get_localpart_from_id(":").unwrap_err();
get_localpart_from_id(":asd").unwrap_err();
get_localpart_from_id("::as::asad").unwrap_err();
assert_eq!(get_localpart_from_id("@test:foo").unwrap(), "test");
assert_eq!(get_localpart_from_id("@:").unwrap(), "");
assert_eq!(get_localpart_from_id("@test:foo:907").unwrap(), "test");
}
#[test]
fn tset_glob() -> Result<(), Error> {
assert_eq!(
glob_to_regex("simple", GlobMatchType::Whole)?.as_str(),
r"\Asimple\z"
);
assert_eq!(
glob_to_regex("simple*", GlobMatchType::Whole)?.as_str(),
r"\Asimple.{0,}\z"
);
assert_eq!(
glob_to_regex("simple?", GlobMatchType::Whole)?.as_str(),
r"\Asimple.{1}\z"
);
assert_eq!(
glob_to_regex("simple?*?*", GlobMatchType::Whole)?.as_str(),
r"\Asimple.{2,}\z"
);
assert_eq!(
glob_to_regex("simple???", GlobMatchType::Whole)?.as_str(),
r"\Asimple.{3}\z"
);
assert_eq!(
glob_to_regex("escape.", GlobMatchType::Whole)?.as_str(),
r"\Aescape\.\z"
);
assert_eq!(
glob_to_regex("simple", GlobMatchType::Word)?.as_str(),
r"\bsimple\b"
);
assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("some simple."));
assert!(glob_to_regex("simple", GlobMatchType::Word)?.is_match("simple"));
Ok(())
}

View File

@@ -0,0 +1,47 @@
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
from synapse.types import JsonDict
class PushRule:
rule_id: str
priority_class: int
conditions: Sequence[Mapping[str, str]]
actions: Sequence[Union[Mapping[str, Any], str]]
default: bool
default_enabled: bool
@staticmethod
def from_db(
rule_id: str, priority_class: int, conditions: str, actions: str
) -> "PushRule": ...
class PushRules:
def __init__(self, rules: Collection[PushRule]): ...
def rules(self) -> Collection[PushRule]: ...
class FilteredPushRules:
def __init__(
self,
push_rules: PushRules,
enabled_map: Dict[str, bool],
msc3786_enabled: bool,
msc3772_enabled: bool,
): ...
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
class PushRuleEvaluator:
def __init__(
self,
flattened_keys: Mapping[str, str],
room_member_count: int,
sender_power_level: int,
notification_power_levels: Mapping[str, int],
relations: Mapping[str, Set[Tuple[str, str]]],
relation_match_enabled: bool,
): ...
def run(
self,
push_rules: FilteredPushRules,
user_id: Optional[str],
display_name: Optional[str],
) -> Collection[dict]: ...

View File

@@ -34,16 +34,15 @@ from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import auth_types_for_event, get_user_power_level
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.push.push_rule_evaluator import _flatten_dict
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.storage.state import StateFilter
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
from .baserules import FilteredPushRules, PushRule
from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -282,14 +281,16 @@ class BulkPushRuleEvaluator:
) = await self._get_power_levels_and_sender_level(event, context)
relations = await self._get_mutual_relations(
event, itertools.chain(*rules_by_user.values())
event, itertools.chain(*(r.rules() for r in rules_by_user.values()))
)
evaluator = PushRuleEvaluatorForEvent(
event,
logger.info("Flatten map: %s", _flatten_dict(event))
logger.info("room_member_count: %s", room_member_count)
evaluator = PushRuleEvaluator(
_flatten_dict(event),
room_member_count,
sender_power_level,
power_levels,
power_levels.get("notifications", {}),
relations,
self._relations_match_enabled,
)
@@ -333,17 +334,7 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later.
actions_by_user[uid] = []
for rule, enabled in rules:
if not enabled:
continue
matches = evaluator.check_conditions(rule.conditions, uid, display_name)
if matches:
actions = [x for x in rule.actions if x != "dont_notify"]
if actions and "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions
break
actions_by_user[uid] = evaluator.run(rules, uid, display_name)
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist

View File

@@ -16,10 +16,9 @@ import copy
from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
from synapse.synapse_rust.push import FilteredPushRules, PushRule
from synapse.types import UserID
from .baserules import FilteredPushRules, PushRule
def format_push_rules_for_user(
user: UserID, ruleslist: FilteredPushRules
@@ -34,7 +33,7 @@ def format_push_rules_for_user(
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r, enabled in ruleslist:
for r, enabled in ruleslist.rules():
template_name = _priority_class_to_template_name(r.priority_class)
rulearray = rules["global"][template_name]

View File

@@ -30,9 +30,8 @@ from typing import (
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -51,6 +50,7 @@ from synapse.storage.util.id_generators import (
IdGenerator,
StreamIdGenerator,
)
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@@ -72,18 +72,26 @@ def _load_rules(
"""
ruleslist = [
PushRule(
PushRule.from_db(
rule_id=rawrule["rule_id"],
priority_class=rawrule["priority_class"],
conditions=db_to_json(rawrule["conditions"]),
actions=db_to_json(rawrule["actions"]),
conditions=rawrule["conditions"],
actions=rawrule["actions"],
)
for rawrule in rawrules
]
push_rules = compile_push_rules(ruleslist)
push_rules = PushRules(
ruleslist,
)
filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
# TODO: Experimental config
filtered_rules = FilteredPushRules(
push_rules,
enabled_map,
msc3786_enabled=experimental_config.msc3786_enabled,
msc3772_enabled=experimental_config.msc3772_enabled,
)
return filtered_rules
@@ -845,7 +853,7 @@ class PushRuleStore(PushRulesWorkerStore):
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
for rule, enabled in user_push_rules:
for rule, enabled in user_push_rules.rules():
if not enabled:
continue

View File

@@ -15,11 +15,11 @@
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import AccountDataTypes
from synapse.push.baserules import PushRule
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest import admin
from synapse.rest.client import account, login
from synapse.server import HomeServer
from synapse.synapse_rust.push import PushRule
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -161,20 +161,15 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self._store.get_push_rules_for_user(self.user)
)
# Filter out default rules; we don't care
push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
push_rules = [
r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
]
# Check our rule made it
self.assertEqual(
push_rules,
[
PushRule(
rule_id="personal.override.rule1",
priority_class=5,
conditions=[],
actions=[],
)
],
push_rules,
)
self.assertEqual(len(push_rules), 1)
self.assertEqual(push_rules[0].rule_id, "personal.override.rule1")
self.assertEqual(push_rules[0].priority_class, 5)
self.assertEqual(push_rules[0].conditions, [])
self.assertEqual(push_rules[0].actions, [])
# Request the deactivation of our account
self._deactivate_my_account()
@@ -183,7 +178,9 @@ class DeactivateAccountTestCase(HomeserverTestCase):
self._store.get_push_rules_for_user(self.user)
)
# Filter out default rules; we don't care
push_rules = [r for r, _ in filtered_push_rules if self._is_custom_rule(r)]
push_rules = [
r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
]
# Check our rule no longer exists
self.assertEqual(push_rules, [], push_rules)