// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use async_trait::async_trait;
use datafusion::arrow::array::{UInt64Builder, UInt8Builder};
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result;
use datafusion::execution::context::{SessionState, TaskContext};
use datafusion::logical_plan::{provider_as_source, Expr, LogicalPlanBuilder};
use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::physical_plan::{
    project_schema, ExecutionPlan, SendableRecordBatchStream, Statistics,
};
use datafusion::prelude::*;
use std::any::Any;
use std::collections::{BTreeMap, HashMap};
use std::fmt::{Debug, Formatter};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::time::timeout;

/// This example demonstrates executing a simple query against a custom datasource
#[tokio::main]
async fn main() -> Result<()> {
    // create our custom datasource and add some users
    let db = CustomDataSource::default();
    db.populate_users();

    // register the datasource as a table and run a simple SQL query against it
    let ctx = SessionContext::new();
    ctx.register_table("abc", Arc::new(db.clone()))?;
    let df = ctx.sql("select * from abc").await?;
    df.show().await?;

    // exercise the same datasource through the DataFrame API, with and without filters
    search_accounts(db.clone(), None, 3).await?;
    search_accounts(db.clone(), Some(col("bank_account").gt(lit(8000u64))), 1).await?;
    search_accounts(db, Some(col("bank_account").gt(lit(200u64))), 2).await?;

    Ok(())
}

/// A User, with an id and a bank account
#[derive(Clone, Debug)]
struct User {
    id: u8,
    bank_account: u64,
}

/// A custom datasource, used to represent a datastore with a single index
#[derive(Clone)]
pub struct CustomDataSource {
    inner: Arc<Mutex<CustomDataSourceInner>>,
}

struct CustomDataSourceInner {
    data: HashMap<u8, User>,
    bank_account_index: BTreeMap<u64, u8>,
}

impl Debug for CustomDataSource {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("custom_db")
    }
}

impl CustomDataSource {
    pub(crate) async fn create_physical_plan(
        &self,
        projections: &Option<Vec<usize>>,
        schema: SchemaRef,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(Arc::new(CustomExec::new(projections, schema, self.clone())))
    }

    pub(crate) fn populate_users(&self) {
        self.add_user(User {
            id: 1,
            bank_account: 9_000,
        });
        self.add_user(User {
            id: 2,
            bank_account: 100,
        });
        self.add_user(User {
            id: 3,
            bank_account: 1_000,
        });
    }

    fn add_user(&self, user: User) {
        let mut inner = self.inner.lock().unwrap();
        inner.bank_account_index.insert(user.bank_account, user.id);
        inner.data.insert(user.id, user);
    }
}

impl Default for CustomDataSource {
    fn default() -> Self {
        CustomDataSource {
            inner: Arc::new(Mutex::new(CustomDataSourceInner {
                data: Default::default(),
                bank_account_index: Default::default(),
            })),
        }
    }
}

#[async_trait]
impl TableProvider for CustomDataSource {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        SchemaRef::new(Schema::new(vec![
            Field::new("id", DataType::UInt8, false),
            Field::new("bank_account", DataType::UInt64, true),
        ]))
    }

    fn table_type(&self) -> TableType {
        TableType::Base
    }

    async fn scan(
        &self,
        _state: &SessionState,
        projection: &Option<Vec<usize>>,
        // filters and limit can be used here to inject some push-down operations if needed
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        self.create_physical_plan(projection, self.schema()).await
    }
}

#[derive(Debug, Clone)]
struct CustomExec {
    db: CustomDataSource,
    projected_schema: SchemaRef,
}

impl CustomExec {
    fn new(
        projections: &Option<Vec<usize>>,
        schema: SchemaRef,
        db: CustomDataSource,
    ) -> Self {
        let projected_schema = project_schema(&schema, projections.as_ref()).unwrap();
        Self {
            db,
            projected_schema,
        }
    }
}

impl ExecutionPlan for CustomExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        self.projected_schema.clone()
    }

    fn output_partitioning(&self) -> datafusion::physical_plan::Partitioning {
        datafusion::physical_plan::Partitioning::UnknownPartitioning(1)
    }

    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
        None
    }

    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        _: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        Ok(self)
    }

    fn execute(
        &self,
        _partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream> {
        // copy the users out of the datasource while holding the lock, then release it
        let users: Vec<User> = {
            let db = self.db.inner.lock().unwrap();
            db.data.values().cloned().collect()
        };

        let mut id_array = UInt8Builder::new(users.len());
        let mut account_array = UInt64Builder::new(users.len());

        for user in users {
            id_array.append_value(user.id);
            account_array.append_value(user.bank_account);
        }

        // note: this example always materializes both columns, so it only supports
        // projections that keep both of them (as every query in this file does)
        Ok(Box::pin(MemoryStream::try_new(
            vec![RecordBatch::try_new(
                self.projected_schema.clone(),
                vec![
                    Arc::new(id_array.finish()),
                    Arc::new(account_array.finish()),
                ],
            )?],
            self.schema(),
            None,
        )?))
    }

    fn statistics(&self) -> Statistics {
        // no statistics are collected for this in-memory example
        Statistics::default()
    }
}

// ================================================================================================================================

async fn search_accounts(
    db: CustomDataSource,
    filter: Option<Expr>,
    expected_result_length: usize,
) -> Result<()> {
    // create local execution context
    let ctx = SessionContext::new();

    // create logical plan composed of a single TableScan
    let logical_plan = LogicalPlanBuilder::scan_with_filters(
        "accounts",
        provider_as_source(Arc::new(db)),
        None,
        vec![],
    )
    .unwrap()
    .build()
    .unwrap();

    let mut dataframe = DataFrame::new(ctx.state, &logical_plan)
        .select_columns(&["id", "bank_account"])?;

    if let Some(f) = filter {
        dataframe = dataframe.filter(f)?;
    }

    timeout(Duration::from_secs(10), async move {
        let result = dataframe.collect().await.unwrap();
        let record_batch = result.get(0).unwrap();

        assert_eq!(expected_result_length, record_batch.column(1).len());
        dbg!(record_batch.columns());
    })
    .await
    .unwrap();

    Ok(())
}
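
// A minimal sketch of how this datasource could be exercised from a test,
// assuming the same DataFusion APIs used above (`SessionContext::register_table`
// and SQL via `SessionContext::sql`). The table name `users` and the test name
// are illustrative, not part of the original example.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn full_scan_returns_all_users() -> Result<()> {
        let db = CustomDataSource::default();
        db.populate_users();

        let ctx = SessionContext::new();
        ctx.register_table("users", Arc::new(db))?;

        // a full scan should surface all three users added by populate_users
        let df = ctx.sql("select id, bank_account from users").await?;
        let batches = df.collect().await?;
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(3, total_rows);

        Ok(())
    }
}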