【Naive Splay Template】

写小作业的时候重新复习了一下splay
只支持插入、删除、查询第 k 小（k 从 1 开始）和查询节点数，没有迭代器。插入已存在的元素（用 == 判断）时会被忽略。
T 类型需要重载 == 和 <，并且必须可拷贝构造（插入时会调用拷贝构造函数）。

/**
 * Splay<T>: a minimal splay tree holding distinct values of type T, ordered by
 * T's operator<. Supports insert, remove, 1-based k-th smallest lookup and
 * element count; it provides no iterators.
 *
 * Requirements on T: operator== and operator< (used to locate values) and a
 * copy constructor (each node stores a copy of the inserted value).
 *
 * The tree owns all of its nodes. Copying a tree makes a deep copy; moving
 * transfers the nodes and leaves the source empty.
 */
template<class T>
class Splay {
private:
	// One tree node. `size` caches the number of nodes in the subtree rooted here.
	struct node {
		T v;
		node *ch[2], *fa;   // ch[0] = left child, ch[1] = right child, fa = parent (null at the root)
		int size;
		// Initializers follow declaration order (avoids -Wreorder).
		explicit node(const T &a) : v(a), ch{nullptr, nullptr}, fa(nullptr), size(1) {}
		// Make r (possibly null) child number c of this node and fix r's parent link.
		void setc(node *r, int c) {
			ch[c] = r;
			if (r != nullptr) r->fa = this;
		}
		// Side of the parent this node hangs on: 1 for a right child,
		// 0 for a left child or a node without a parent.
		int pl() {
			if (fa != nullptr) return fa->ch[1] == this;
			else return 0;
		}
		// Recompute size from the children's cached sizes.
		void count() {
			size = 1;
			if (ch[0] != nullptr) size += ch[0]->size;
			if (ch[1] != nullptr) size += ch[1]->size;
		}
	};
	node *root;

	// Free every node of the subtree rooted at r (may be null). Instead of
	// recursing, the left child is repeatedly rotated above r until r has no
	// left subtree; then r is deleted and the walk continues to the right.
	// This needs O(1) extra space, so a list-shaped tree (which sorted inserts
	// produce) cannot overflow the call stack. Parent links are not maintained
	// because every node is being freed.
	static void release(node *r) {
		while (r != nullptr) {
			node *l = r->ch[0];
			if (l != nullptr) {
				r->ch[0] = l->ch[1];
				l->ch[1] = r;
				r = l;
			} else {
				node *next = r->ch[1];
				delete r;
				r = next;
			}
		}
	}

	// Deep-copy the subtree rooted at src (may be null) and return the copy's
	// root, whose parent is null. The source is walked through its parent
	// pointers rather than by recursion, so a list-shaped tree cannot overflow
	// the call stack. If copying a value throws, the partial copy is freed and
	// the exception propagates.
	static node *clone(const node *src) {
		if (src == nullptr) return nullptr;
		node *dst_root = new node(src->v);
		dst_root->size = src->size;
		try {
			const node *s = src;   // current node in the source tree
			node *d = dst_root;    // its counterpart in the copy
			for (;;) {
				// Descend into the first child of s that has not been copied yet.
				int c = -1;
				if (s->ch[0] != nullptr && d->ch[0] == nullptr) c = 0;
				else if (s->ch[1] != nullptr && d->ch[1] == nullptr) c = 1;
				if (c < 0) {
					// The whole subtree of s is copied: climb back up, or stop at src.
					if (s == src) break;
					s = s->fa;
					d = d->fa;
					continue;
				}
				s = s->ch[c];
				node *n = new node(s->v);
				n->size = s->size;
				d->setc(n, c);
				d = n;
			}
		} catch (...) {
			release(dst_root);
			throw;
		}
		return dst_root;
	}

public:
	Splay() : root(nullptr) {}

	// Deep copy: the new tree owns its own copies of every node.
	Splay(const Splay &other) : root(clone(other.root)) {}

	// Take over other's nodes; other is left empty.
	Splay(Splay &&other) noexcept : root(other.root) {
		other.root = nullptr;
	}

	// Unified copy/move assignment (copy-and-swap). `other` is already a copy
	// or a moved-from value; swapping roots hands our old nodes to `other`,
	// whose destructor frees them.
	Splay &operator=(Splay other) noexcept {
		node *old = root;
		root = other.root;
		other.root = old;
		return *this;
	}

	~Splay() {
		release(root);
	}

private:
	// Rotate r above its parent f, preserving in-order order and updating the
	// root pointer when f was the root. f's size is recomputed; r's size is left
	// stale and is refreshed by the caller (see splay).
	void rotate(node *r) {
		node *f = r->fa;
		int c = r->pl();
		if (f == root) r->fa = nullptr, root = r;
		else f->fa->setc(r, f->pl());
		f->setc(r->ch[c ^ 1], c);
		r->setc(f, c ^ 1);
		f->count();
	}

	// Rotate r upward until its parent is tar (tar == nullptr makes r the root).
	// tar is expected to be an ancestor of r, or nullptr. When r and its parent
	// are on the same side of their own parents, the parent is rotated first
	// (zig-zig); otherwise r is rotated twice (zig-zag). r's size is recomputed
	// at the end.
	void splay(node *r, node *tar = nullptr) {
		for(; r->fa != tar; rotate(r))
			if (r->fa->fa != tar) rotate(r->fa->pl() == r->pl() ? r->fa : r);
		r->count();
	}

	void rm(node *);

public:
	int size();
	
	void remove(const T &);

	void insert(const T &);

	T *kth(int);
};

/**
 * Number of elements currently stored.
 * An empty tree has no root; otherwise the root's cached subtree size is the total.
 */
template<class T>
int Splay<T>::size() {
	return root != nullptr ? root->size : 0;
}

/**
 * Unlink and free node r (which must belong to this tree), keeping the BST
 * order and every cached subtree size valid.
 *
 * - Leaf: detach it from its parent, or clear the root.
 * - One child on side c: splay r's in-order neighbour t (the extreme node on
 *   side c ^ 1 of that child's subtree) into r's position. r is then t's leaf
 *   child on side c ^ 1.
 * - Two children: splay the predecessor h into r's position, then the
 *   successor t directly under h. r is then t's leaf left child.
 *
 * In every case r ends up a detached leaf and is deleted. The root-leaf case
 * previously leaked r. Sizes are then recomputed from the detached leaf's
 * former parent up to the root.
 */
template<class T>
void Splay<T>::rm(node *r) {
	node *f = nullptr;   // lowest node whose cached size must be refreshed
	if (r->ch[0] == nullptr && r->ch[1] == nullptr) {
		if (r == root) {
			root = nullptr;
		} else {
			f = r->fa;
			f->setc(nullptr, r->pl());
		}
	} else if (r->ch[0] == nullptr || r->ch[1] == nullptr) {
		int c = r->ch[0] == nullptr;   // side of the only child
		node *t = r->ch[c];
		while (t->ch[c ^ 1] != nullptr) t = t->ch[c ^ 1];
		splay(t, r->fa);
		// t now occupies r's old position and r is its leaf child on side c ^ 1.
		f = r->fa;
		f->setc(nullptr, c ^ 1);
	} else {
		node *h = r->ch[0], *t = r->ch[1];
		while (h->ch[1] != nullptr) h = h->ch[1];
		while (t->ch[0] != nullptr) t = t->ch[0];
		splay(h, r->fa);
		splay(t, h);
		// Only r lies between h and t, so r is now the leaf left child of t.
		t->setc(nullptr, 0);
		f = t;
	}
	delete r;

	for (; f != nullptr; f = f->fa) f->count();
}

/**
 * Remove the element equal to a (by operator==), if present; otherwise the
 * tree's contents are left unchanged.
 *
 * operator== detects a match and operator< chooses the side to descend.
 * The node where the search ends (the match, or the last node visited when a is
 * absent) is splayed to the root. Splaying is what keeps the cost of the walk
 * amortized O(log n); a splay tree that skips it can degrade to O(n) per call.
 */
template<class T>
void Splay<T>::remove(const T &a) {
	node *r = root;
	node *last = nullptr;   // last node visited; splayed if a is absent
	while (r != nullptr && !(r->v == a)) {
		last = r;
		r = (a < r->v) ? r->ch[0] : r->ch[1];
	}
	if (r != nullptr) {
		splay(r);   // bring the match to the root, then unlink it
		rm(r);
	} else if (last != nullptr) {
		splay(last);
	}
}

/**
 * Insert a copy of a unless an equal element (by operator==) is already
 * stored; duplicates are ignored and the stored value is kept.
 *
 * operator< chooses the side to descend (less goes left). The newly linked
 * node, or the existing equal node, is splayed to the root. A node is allocated
 * only once an empty slot is found, so a duplicate insert no longer allocates,
 * copies and frees a throw-away node. If copying a throws, the tree is
 * unchanged.
 */
template<class T>
void Splay<T>::insert(const T &a) {
	if (root == nullptr) {
		root = new node(a);
		return;
	}
	node *r = root;
	for (;;) {
		if (r->v == a) {
			splay(r);   // already present: still splay the accessed node
			return;
		}
		int c = (a < r->v) ? 0 : 1;   // 0 = descend left, 1 = descend right
		if (r->ch[c] == nullptr) {
			node *n = new node(a);
			r->setc(n, c);
			splay(n);
			return;
		}
		r = r->ch[c];
	}
}

/**
 * Return a pointer to the k-th smallest stored value (k is 1-based), or
 * nullptr when k is outside [1, size()].
 *
 * The found node is splayed to the root. Without this, repeated lookups of a
 * deep node cost O(n) each time, e.g. kth(1) after inserting values in
 * increasing order. The pointer refers to the value inside the tree's node; it
 * is invalidated when that element is removed or the tree's nodes are freed.
 */
template<class T>
T *Splay<T>::kth(int k) {
	if (root == nullptr || k < 1 || k > root->size) return nullptr;
	node *r = root;
	while (r != nullptr) {
		int l_size = (r->ch[0] != nullptr) ? r->ch[0]->size : 0;
		if (k <= l_size) {
			r = r->ch[0];
		} else if (k == l_size + 1) {
			splay(r);
			return &r->v;
		} else {
			k -= l_size + 1;
			r = r->ch[1];
		}
	}
	return nullptr;   // not reached while the cached sizes are consistent
}
posted @ 2019-03-24 17:09  abclzr  阅读(211)  评论(0编辑  收藏  举报