use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use crate::{
config::SessionConfig,
memory_pool::MemoryPool,
registry::FunctionRegistry,
runtime_env::{RuntimeConfig, RuntimeEnv},
};
use datafusion_common::{plan_datafusion_err, DataFusionError, Result};
use datafusion_expr::planner::ExprPlanner;
use datafusion_expr::{AggregateUDF, ScalarUDF, WindowUDF};
/// Context handed to a task.
///
/// Bundles the owning session's id and [`SessionConfig`], an optional task id,
/// the scalar / aggregate / window user-defined functions that can be looked up
/// by name through [`FunctionRegistry`], and a shared [`RuntimeEnv`] (from which
/// the [`MemoryPool`] is obtained).
#[derive(Debug)]
pub struct TaskContext {
    /// Id of the session this context belongs to (`"DEFAULT"` for
    /// [`TaskContext::default`]).
    session_id: String,
    /// Optional id of the task; `None` when not supplied.
    task_id: Option<String>,
    /// Configuration of the session.
    session_config: SessionConfig,
    /// Scalar UDFs keyed by name; `register_udf` also adds one entry per alias.
    scalar_functions: HashMap<String, Arc<ScalarUDF>>,
    /// Aggregate UDFs keyed by name; `register_udaf` also adds one entry per alias.
    aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
    /// Window UDFs keyed by name; `register_udwf` also adds one entry per alias.
    window_functions: HashMap<String, Arc<WindowUDF>>,
    /// Runtime environment, shared via `Arc`; the memory pool is read from it.
    runtime: Arc<RuntimeEnv>,
}
impl Default for TaskContext {
    /// Builds a context with session id `"DEFAULT"`, no task id, a default
    /// [`SessionConfig`], no registered functions, and a newly created
    /// [`RuntimeEnv`] built from a default [`RuntimeConfig`].
    ///
    /// # Panics
    ///
    /// Panics if the default runtime environment cannot be created.
    fn default() -> Self {
        // Build the runtime first so a failure panics before anything else
        // is constructed.
        let shared_runtime = Arc::new(
            RuntimeEnv::new(RuntimeConfig::new())
                .expect("default runtime created successfully"),
        );
        Self::new(
            None,
            "DEFAULT".to_string(),
            SessionConfig::new(),
            HashMap::new(),
            HashMap::new(),
            HashMap::new(),
            shared_runtime,
        )
    }
}
impl TaskContext {
    /// Creates a [`TaskContext`] from its individual parts.
    ///
    /// The function maps are stored as given. The [`FunctionRegistry`]
    /// implementation looks functions up by their map key.
    pub fn new(
        task_id: Option<String>,
        session_id: String,
        session_config: SessionConfig,
        scalar_functions: HashMap<String, Arc<ScalarUDF>>,
        aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
        window_functions: HashMap<String, Arc<WindowUDF>>,
        runtime: Arc<RuntimeEnv>,
    ) -> Self {
        Self {
            session_id,
            task_id,
            session_config,
            scalar_functions,
            aggregate_functions,
            window_functions,
            runtime,
        }
    }

    /// Returns a reference to the session configuration.
    pub fn session_config(&self) -> &SessionConfig {
        &self.session_config
    }

    /// Returns an owned copy of the session id.
    pub fn session_id(&self) -> String {
        self.session_id.to_owned()
    }

    /// Returns an owned copy of the task id, if one was set.
    pub fn task_id(&self) -> Option<String> {
        self.task_id.as_deref().map(ToOwned::to_owned)
    }

    /// Returns the memory pool of the underlying runtime environment.
    pub fn memory_pool(&self) -> &Arc<dyn MemoryPool> {
        &self.runtime.memory_pool
    }

    /// Returns a new shared handle to the runtime environment.
    pub fn runtime_env(&self) -> Arc<RuntimeEnv> {
        Arc::clone(&self.runtime)
    }

    /// Consumes `self` and returns it with `session_config` replacing the
    /// current session configuration.
    pub fn with_session_config(self, session_config: SessionConfig) -> Self {
        Self {
            session_config,
            ..self
        }
    }

    /// Consumes `self` and returns it with `runtime` replacing the current
    /// runtime environment.
    pub fn with_runtime(self, runtime: Arc<RuntimeEnv>) -> Self {
        Self { runtime, ..self }
    }
}
impl FunctionRegistry for TaskContext {
    /// Returns every key (primary names and aliases) of the registered
    /// scalar UDFs.
    fn udfs(&self) -> HashSet<String> {
        self.scalar_functions.keys().cloned().collect()
    }

    /// Looks up a scalar UDF by name or alias.
    ///
    /// # Errors
    ///
    /// Returns a plan error if nothing is registered under `name`.
    fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>> {
        self.scalar_functions.get(name).cloned().ok_or_else(|| {
            plan_datafusion_err!("There is no UDF named \"{name}\" in the TaskContext")
        })
    }

    /// Looks up an aggregate UDF by name or alias.
    ///
    /// # Errors
    ///
    /// Returns a plan error if nothing is registered under `name`.
    fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
        self.aggregate_functions.get(name).cloned().ok_or_else(|| {
            plan_datafusion_err!("There is no UDAF named \"{name}\" in the TaskContext")
        })
    }

    /// Looks up a window UDF by name or alias.
    ///
    /// # Errors
    ///
    /// Returns a plan error if nothing is registered under `name`.
    fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
        self.window_functions.get(name).cloned().ok_or_else(|| {
            // An unknown function name is a planning/user error, not a broken
            // internal invariant. Report it the same way as `udf` / `udaf`
            // instead of as `DataFusionError::Internal`.
            DataFusionError::Plan(format!(
                "There is no UDWF named \"{name}\" in the TaskContext"
            ))
        })
    }

    /// Registers `udaf` under its name and each of its aliases.
    ///
    /// Returns the aggregate UDF previously registered under the primary name,
    /// if any. Entries previously stored under an alias are overwritten silently.
    fn register_udaf(
        &mut self,
        udaf: Arc<AggregateUDF>,
    ) -> Result<Option<Arc<AggregateUDF>>> {
        for alias in udaf.aliases() {
            self.aggregate_functions
                .insert(alias.clone(), Arc::clone(&udaf));
        }
        Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
    }

    /// Registers `udwf` under its name and each of its aliases.
    ///
    /// Returns the window UDF previously registered under the primary name,
    /// if any. Entries previously stored under an alias are overwritten silently.
    fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
        for alias in udwf.aliases() {
            self.window_functions
                .insert(alias.clone(), Arc::clone(&udwf));
        }
        Ok(self.window_functions.insert(udwf.name().into(), udwf))
    }

    /// Registers `udf` under its name and each of its aliases.
    ///
    /// Returns the scalar UDF previously registered under the primary name,
    /// if any. Entries previously stored under an alias are overwritten silently.
    fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
        for alias in udf.aliases() {
            self.scalar_functions
                .insert(alias.clone(), Arc::clone(&udf));
        }
        Ok(self.scalar_functions.insert(udf.name().into(), udf))
    }

    /// A `TaskContext` carries no expression planners.
    fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
        Vec::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_common::{
        config::{ConfigExtension, ConfigOptions, Extensions},
        extensions_options,
    };

    extensions_options! {
        struct TestExtension {
            value: usize, default = 42
        }
    }

    impl ConfigExtension for TestExtension {
        const PREFIX: &'static str = "test";
    }

    /// A config-extension value set on `ConfigOptions` must be readable
    /// through the `TaskContext`'s session config.
    #[test]
    fn task_context_extensions() -> Result<()> {
        let mut extensions = Extensions::new();
        extensions.insert(TestExtension::default());

        let mut options = ConfigOptions::new().with_extensions(extensions);
        options.set("test.value", "24")?;

        let task_context = TaskContext::new(
            Some("task_id".to_string()),
            "session_id".to_string(),
            SessionConfig::from(options),
            HashMap::default(),
            HashMap::default(),
            HashMap::default(),
            Arc::new(RuntimeEnv::default()),
        );

        // Collapse "extension present" and "value overridden" into one check.
        let value = task_context
            .session_config()
            .options()
            .extensions
            .get::<TestExtension>()
            .map(|ext| ext.value);
        assert_eq!(value, Some(24));
        Ok(())
    }
}