import os
from typing import Union, List, Dict, Any, Tuple
from urllib.parse import quote_plus

import pymongo
from loguru import logger
class MongoClient:
    """Thin convenience wrapper around ``pymongo.MongoClient``.

    Every data helper returns an ``(ok, result)`` tuple: ``(True, value)`` on
    success, or ``(False, exception)`` after logging the error. Helpers operate
    on ``self.db``, which is selected by ``db_name`` in the constructor or by
    :meth:`set_db`. Instances can be used as a context manager to close the
    underlying client automatically.
    """

    def __init__(self, host="127.0.0.1", port=27017, db_name=None, user=None, pwd=None, uri=None):
        """Connect to MongoDB and optionally select a database.

        Args:
            host: Server host, used when ``uri`` is not given.
            port: Server port, used when ``uri`` is not given.
            db_name: Database to select. With ``user``/``pwd`` it is also the
                URI path, i.e. the authentication database.
            user: Username; percent-escaped automatically.
            pwd: Password; percent-escaped automatically.
            uri: Full MongoDB connection string; takes precedence over all
                other connection arguments.

        On any failure the error is logged and the process exits via
        ``os._exit(1)``.
        """
        try:
            if uri:  # a full connection string takes precedence
                self.client = pymongo.MongoClient(uri)
            elif user and pwd:
                self.client = pymongo.MongoClient(self._build_uri(host, port, db_name, user, pwd))
            else:
                self.client = pymongo.MongoClient(f"mongodb://{host}:{port}")
            # pymongo's constructor does not block on connecting, so an unreachable
            # server or bad credentials would otherwise go unnoticed here. Ping
            # forces a round trip so the except branch below can actually trigger.
            self.client.admin.command("ping")
            self.db = self.client[db_name] if db_name else None
            logger.success("✅ MongoDB 连接成功!")
        except Exception as e:
            logger.error(f"❌ MongoDB 连接失败:{e}")
            # NOTE(review): os._exit skips finally blocks, atexit handlers and stdio
            # buffer flushing; raising would be friendlier for library callers.
            os._exit(1)

    @staticmethod
    def _build_uri(host, port, db_name, user, pwd):
        """Build an authenticated ``mongodb://`` URI.

        The username and password are percent-escaped, so values containing
        ``@``, ``:`` or ``/`` do not corrupt the URI. ``db_name`` becomes the URI
        path. When it is ``None`` the path is left empty rather than becoming the
        literal text ``"None"``.
        """
        return f"mongodb://{quote_plus(str(user))}:{quote_plus(str(pwd))}@{host}:{port}/{db_name or ''}"

    @staticmethod
    def _serialize(doc):
        """Return *doc* with its top-level ``_id`` converted to ``str``.

        Accepts a single document (dict), a list of documents, or anything else,
        such as ``None`` from ``find_one``. Anything that is not a dict or a list
        is returned unchanged. Documents without an ``_id`` key are passed through
        instead of raising ``KeyError``; such documents are common in aggregation
        output. Only the top-level ``_id`` is converted. Input documents are never
        mutated; converted documents are shallow copies.
        """
        if isinstance(doc, list):
            return [MongoClient._serialize(d) for d in doc]
        if isinstance(doc, dict) and "_id" in doc:
            return {**doc, "_id": str(doc["_id"])}
        return doc

    def set_db(self, db_name: str):
        """Switch the database used by subsequent operations."""
        self.db = self.client[db_name]

    def _col(self, col: str):
        """Return collection *col* of the selected database.

        Raises:
            RuntimeError: if no database has been selected yet.
        """
        # Compare against None explicitly: pymongo Database objects do not
        # support truth-value testing.
        if self.db is None:
            raise RuntimeError("未选择数据库,请在构造时传入 db_name 或调用 set_db()")
        return self.db[col]

    def close(self):
        """Close the underlying pymongo client."""
        self.client.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    # ------------------- Create -------------------
    def insert(self, col: str, data: Union[Dict, List[Dict]]) -> Tuple[bool, Any]:
        """Insert one document (dict) or several (an iterable of dicts).

        Returns:
            ``(True, id_str)`` for a single dict, ``(True, [id_str, ...])`` for
            several documents (``[]`` for an empty input), or
            ``(False, exception)`` on failure. pymongo adds an ``_id`` key to
            passed-in dicts that lack one.
        """
        try:
            collection = self._col(col)
            if isinstance(data, dict):
                res = collection.insert_one(data)
                return True, str(res.inserted_id)
            docs = list(data)
            if not docs:
                # insert_many rejects an empty list; inserting nothing is not an error.
                return True, []
            res = collection.insert_many(docs)
            return True, [str(_id) for _id in res.inserted_ids]
        except Exception as e:
            logger.error(f"插入数据失败: {e}")
            return False, e

    # ------------------- Read -------------------
    def find(self, col: str, query: Dict = None, one=False, skip=0, limit=0, sort=None, to_str=True):
        """Query documents from collection *col*.

        Args:
            col: Collection name.
            query: Filter document; ``None`` matches every document.
            one: Return a single document (or ``None``) instead of a list.
            skip: Number of matching documents to skip.
            limit: Maximum number of documents; ``0`` means no limit (list mode).
            sort: Sort spec forwarded to pymongo, e.g. ``[("age", -1)]``.
            to_str: Convert top-level ``_id`` values to ``str``.

        Returns:
            ``(True, doc_or_docs)`` or ``(False, exception)``.
        """
        query = query or {}
        try:
            collection = self._col(col)
            if one:
                # Forward skip/sort so queries like "latest document" work in one-mode too.
                doc = collection.find_one(query, skip=skip, sort=sort)
                return True, self._serialize(doc) if to_str else doc
            cursor = collection.find(query).skip(skip)
            if limit:
                cursor = cursor.limit(limit)
            if sort:
                cursor = cursor.sort(sort)
            docs = list(cursor)
            return True, self._serialize(docs) if to_str else docs
        except Exception as e:
            logger.error(f"查询失败: {e}")
            return False, e

    # ------------------- Update -------------------
    def update(self, col: str, query: Dict, update: Dict, many=False):
        """Apply *update* (e.g. ``{"$set": {...}}``) to the first or all matches.

        Returns:
            ``(True, modified_count)`` or ``(False, exception)``.
        """
        try:
            collection = self._col(col)
            if many:
                res = collection.update_many(query, update)
            else:
                res = collection.update_one(query, update)
            return True, res.modified_count
        except Exception as e:
            logger.error(f"更新失败: {e}")
            return False, e

    # ------------------- Delete -------------------
    def delete(self, col: str, query: Dict, many=False):
        """Delete the first (or, with ``many=True``, every) document matching *query*.

        Returns:
            ``(True, deleted_count)`` or ``(False, exception)``.
        """
        try:
            collection = self._col(col)
            if many:
                res = collection.delete_many(query)
            else:
                res = collection.delete_one(query)
            return True, res.deleted_count
        except Exception as e:
            logger.error(f"删除失败: {e}")
            return False, e

    # ------------------- Aggregate -------------------
    def aggregate(self, col: str, pipeline: List[Dict], to_str=True):
        """Run an aggregation *pipeline* on *col*.

        Args:
            col: Collection name.
            pipeline: List of aggregation stages.
            to_str: Convert top-level ``_id`` values to ``str``. Result documents
                without ``_id`` are returned unchanged.

        Returns:
            ``(True, [docs])`` or ``(False, exception)``.
        """
        try:
            res = list(self._col(col).aggregate(pipeline))
            return True, self._serialize(res) if to_str else res
        except Exception as e:
            logger.error(f"聚合失败: {e}")
            return False, e

    # ------------------- Count -------------------
    def count(self, col: str, query: Dict = None):
        """Count documents in *col* matching *query* (all documents when ``None``).

        Returns:
            ``(True, count)`` or ``(False, exception)``.
        """
        query = query or {}
        try:
            return True, self._col(col).count_documents(query)
        except Exception as e:
            logger.error(f"统计失败: {e}")
            return False, e
if __name__ == "__main__":
    # Demo: exercise each CRUD helper on collection "test" of database "cup"
    # using the constructor's default host/port.
    mongo = MongoClient(db_name="cup")
    try:
        # Insert a single document
        mongo.insert("test", {"name": "test1", "value": 123})
        # Insert several documents
        mongo.insert("test", [{"name": "test2"}, {"name": "test3"}])
        # Query all documents
        ok, data = mongo.find("test")
        print(data)
        # Query a single document
        ok, doc = mongo.find("test", {"name": "test2"}, one=True)
        print(doc)
        # Update; report how many documents changed instead of discarding the result
        ok, modified = mongo.update("test", {"name": "test3"}, {"$set": {"value": 959}})
        print(f"modified: {modified}")
        # Delete; report how many documents were removed
        ok, deleted = mongo.delete("test", {"name": "test3"})
        print(f"deleted: {deleted}")
    finally:
        # Release the client's connections even if a step raises.
        mongo.client.close()