春招Coding总结
写在前面
练习coding旨在复习,而非学习。
通过短时间练习,提高手感及套路化思维,因为面试大家问的主要还是dp,贪心,以及具体问题具体分析的题目,代码量通常不会很长,一般属于知道解法秒ac,不知道解法无限超时和爆内存的题。
至少过一遍剑指offer(约4天),我遇到的原题大半来自于此处。
leetcode随意,旨在复习而非提高自己解题水平。
推荐一个解题技巧的repo,通读一遍(刷完约1周)绝对有所帮助。
至少熟练LCS,LIS,股票买卖,区间问题,堆,链表(部分反转,环等),非递归遍历,排序,topk,字典序,滑动窗口,双指针等内容,的确容易遇到,无论是在笔试还是面试中。
个人笔记
不太全面,仅供参考。归根到底,纸上得来终觉浅,绝知此事要躬行。
凑零钱问题
给你 k 种⾯值的硬币,⾯值分别为 c1, c2 ... ck ,每种硬币的数量⽆限,再给⼀个总⾦额 amount ,问你最少需要⼏枚硬币凑出这个⾦额,如果不可能凑出,算法返回 -1 。
class MinCoin:
def __init__(self,coins,amount):
self.coins,self.amount = coins,amount
def ans(self):
cache = [0]
for i in range(1,self.amount+1):
cache.append(min([cache[i-coin]+1 if i-coin >= 0 else float('inf') for coin in self.coins]))
return cache[self.amount] if cache[self.amount] != float('inf') else -1
check = MinCoin([1,2,5],11)
print(check.ans())
最⻓递增⼦序列
给定无序数组,求最长上升子序列。
class LIS: # O(n^2)
def __init__(self,arr):
self.arr = arr
def ans(self):
if len(self.arr) <= 1:
return len(self.arr)
result = [1] * len(self.arr)
for i,x in enumerate(self.arr[1:],start=1):
result[i] = max([result[j]+1 if self.arr[i] > self.arr[j] else 1 for j in range(i)])
return max(result)
check = LIS([10,9,2,5,3,7,101,18])
print(check.ans())
class LIS: # O(nlogn)
def __init__(self,s):
self.s = s
def ans(self):
s = self.s
bucket = []
result = []
from bisect import bisect_left as bs
for i,x in enumerate(s):
idx = bs(bucket,x)
if idx >= len(bucket):
bucket.append(x)
result.append(result[idx-1] + [x] if idx > 0 else [x])
else:
bucket[idx] = x
result[idx] = result[idx-1] + [x] if idx > 0 else [x]
return len(bucket),result[-1],result
check = LIS([10,9,2,5,3,7,18,14,19])
print(check.ans())
最长回文子序列
子序列与子串区别在于可以不连续 将此题变形为最长公共子序列(回文子串不可变形为最长公共子串,如aacxycaa)
class LCS:
def __init__(self, s):
self.s = s
def ans(self):
dp = [[0 for __ in range(len(self.s) + 1)] for _ in range(len(self.s) + 1)]
s1, s2 = self.s, self.s[::-1]
for i, ti in enumerate(s1, start=1):
for j, tj in enumerate(s2, start=1):
dp[i][j] = (1 + dp[i - 1][j - 1]) if ti == tj \
else max(dp[i - 1][j - 1], dp[i][j - 1], dp[i - 1][j])
return dp[len(self.s)][len(self.s)]
check = LCS("bxaby")
print(check.ans())
# LCS可变形为LIS,LIS通过二分查找可以降低复杂度为O(nlogn)
# 注意算法可能会退化,因此字符串要尽可能不一样更好
class LCS:
def __init__(self, s1, s2):
self.s1, self.s2 = s1, s2
def ans(self):
s1, s2 = self.s1, self.s2
from collections import defaultdict
hashtable = defaultdict(lambda: [])
for i, ti in enumerate(s2):
hashtable[ti].insert(0, i)
s = [hashtable[token] for token in s1]
from functools import reduce
s = reduce(lambda x, y: x + y, s) # flatten
def LIS(arr):
from bisect import bisect_left as bs
if len(arr) <= 1:
return len(arr)
bucket = []
for i, ti in enumerate(arr):
idx = bs(bucket, ti)
if idx >= len(bucket):
bucket.append(ti)
else:
bucket[idx] = ti
return len(bucket)
return LIS(s)
check = LCS("bxaby", "ybaxb")
print(check.ans())
最长回文子串
直接解或DP是O(n^2),这里给出O(n)的Manacher算法,主体思路是在已判定的回文串内自然左右是回文的,下一次比较就可以直接从边界开始。
class Manacher:
def __init__(self, s):
self.s = s
def ans(self):
# 预处理字符串为奇数,如 ab -> [SOS] #a#b# [EOS}
self.s = "$#" + "".join([x + "#" for x in self.s]) + "&"
# 要找到最长的回文子串span,自然就要记录中心mx(max idx)
mx, p = 1, [1 for _ in self.s]
for idx, token in enumerate(self.s[1:-1], start=1):
if idx < mx + p[mx]: # the span
p[idx] = min(mx + p[mx] - idx, p[2 * mx - idx])
else:
p[idx] = 1
# 有两种情况idx的回文span可以扩展,显然else肯定是要扩展的
# if里面 如果两者span都在边界上也是可以扩展的
while self.s[idx + p[idx]] == self.s[idx - p[idx]]:
p[idx] += 1
maxlen, result = max(p), ""
maxidx = p.index(maxlen)
result = self.s[maxidx - maxlen + 1:maxidx + maxlen].replace("#", "")
maxlen = len(result)
return maxlen, result
check = Manacher("aacxycaa")
print(check.ans())
字符串匹配(KMP)
应该注意到即使全部回退,最差情况下也只会平局遍历pattern 2遍,这个画出实际例子更好理解 因此复杂度是O(M+N)
class KMP:
def __init__(self, text: str, pattern: str):
self.text, self.pattern = text, pattern
def __str__(self):
return "Pattern:{} \nText:{}".format(self.pattern, self.text)
def brute(self):
text, pattern = self.text, self.pattern
for ti in range(len(text) - len(pattern) + 1):
if text[ti:ti + len(pattern)] == pattern:
return ti
return -1
def genNext(self):
# next[i]:0-i 包含i 作为子字符串,其前缀与后缀的最大重叠长度
# 为什么是长度? 因为kmp比较中我们比较的是失配的单个字符
# 直接使用长度访问pattern比较的就是那个失配的单个字符
pattern = self.pattern
# Brute Ways to Generate Next Array
# self.next = [0 for _ in pattern]
# for i in range(len(pattern)):
# for j in range(i, 0, -1):
# if pattern[:j] == pattern[i + 1 - j:i + 1]:
# self.next[i] = j
# return self.next
# DP Ways to Generate Next Array
self.next = [0]
cur,past = 1,0
while cur < len(pattern):
if pattern[cur] == pattern[past]:
cur,past = cur+1,past+1
self.next.append(past)
elif past:
past = self.next[past-1]
else:
self.next.append(0)
cur += 1
return self.next
def kmp(self):
if not hasattr(self, "next"):
self.next = self.genNext()
next = self.next
text, pattern = self.text, self.pattern
ti, pi = 0, 0
while ti < len(text):
if text[ti] == pattern[pi]:
ti, pi = ti + 1, pi + 1
elif pi:
pi = next[pi - 1]
else:
ti += 1
if pi == len(pattern):
return ti - len(pattern)
return -1
kmp = KMP("abcababbabcaabcabd", "abcaabca")
print(kmp.brute())
print(kmp.kmp())
编辑距离&输出编辑过程
插入|删除|替换
class EditDistance:
def __init__(self, s1, s2):
self.s1, self.s2 = s1, s2
def ans(self):
distance = [[0 for __ in range(len(self.s2)+1)] for _ in range(len(self.s1) + 1)]
ptr = [[(0,0) for __ in range(len(self.s2)+1)] for _ in range(len(self.s1) + 1)]
for i in range(len(self.s1) + 1):
distance[i][0] = i
ptr[i][0] = (i - 1, 0)
for j in range(len(self.s2) + 1):
distance[0][j] = j
ptr[0][j] = (0, j - 1)
for i, ti in enumerate(self.s1, start=1):
for j, tj in enumerate(self.s2, start=1):
# f(i-1,j-1) f(i-1,j) f(i,j-1)
if ti == tj:
distance[i][j] = distance[i - 1][j - 1]
ptr[i][j] = (i - 1, j - 1)
else:
distance[i][j] = 1 + min(distance[i - 1][j - 1], distance[i - 1][j], distance[i][j - 1])
if distance[i - 1][j - 1] == distance[i][j] - 1:
ptr[i][j] = (i - 1, j - 1)
elif distance[i - 1][j] == distance[i][j] - 1:
ptr[i][j] = (i - 1, j)
else:
ptr[i][j] = (i, j - 1)
process = []
s1, s2 = self.s1, self.s2
i, j = len(self.s1), len(self.s2)
while not (i == 0 and j == 0):
ni, nj = ptr[i][j]
if (ni, nj) == (i - 1, j - 1):
if s1[i - 1] != s2[j - 1]:
ns1 = s1[:i - 1] + s2[j - 1] + s1[i:]
process.append("REPLACE|"+s1 + "->" + ns1)
s1 = ns1
elif (ni, nj) == (i - 1, j):
ns1 = s1[:i - 1] + s1[i:]
process.append("DELETE|"+s1 + "->" + ns1)
s1 = ns1
else:
ns1 = s1[:i] + s2[j - 1] + s1[i:]
process.append("INSERT|"+s1 + "->" + ns1)
s1 = ns1
i, j = ni, nj
return distance[len(self.s1)][len(self.s2)], process
check = EditDistance("intention", "interesting")
print(check.ans())
高楼扔鸡蛋
若⼲层楼,若⼲个鸡蛋,让你算出最少的尝试次数,找到鸡蛋恰好摔不碎的那层楼。
# O(KN^2) O(KNlogN)
class BrokenEgg:
def __init__(self, N, K):
self.N, self.K = N, K
self.cache = {}
def ans(self):
return self.dp(self.N, self.K)
def dp(self, n, k):
if n <= 1 or k == 1:
return n
if (n, k) in self.cache:
return self.cache[(n, k)]
# self.cache[(n, k)] = min([max(self.dp(i - 1, k - 1), self.dp(n - i, k)) + 1 for i in range(1, n + 1)])
l, r = 1, n
res = float('inf')
while l <= r: # 这里和二分的区别在于要找最小值,因此左右两侧也需要比较一下看哪个更小
mid = (l + r) // 2
broken = self.dp(mid - 1, k - 1) # inc
not_broken = self.dp(n - mid, k) # dec
if broken > not_broken:
r = mid - 1
res = min(res,broken + 1)
elif broken < not_broken:
l = mid + 1
res = min(res,not_broken + 1)
else:
res = broken + 1
break
self.cache[(n,k)] = res
return self.cache[(n, k)]
check = BrokenEgg(100, 2)
res = check.ans()
print(res)
# O(KN)
class BrokenEgg:
def __init__(self, N, K):
self.N, self.K = N, K
self.cache = {}
def ans(self):
# k eggs, m max throw
# dp[k][m] = dp[k-1][m-1] + dp[k][m-1] + 1
# 巧妙地避开了遍历楼层的循环,因为无论哪层鸡蛋碎了,其dp方程都是一致的
dp = [[0 for __ in range(self.N+1)] for __ in range(self.K+1)]
m = 0
while dp[self.K][m] < self.N:
m += 1
for k in range(1,self.K+1):
dp[k][m] = dp[k-1][m-1] + dp[k][m-1] + 1
return m
⽯头游戏
你和你的朋友⾯前有⼀排⽯头堆,⽤⼀个数组 piles 表⽰,pilesi 表⽰第 i堆⽯⼦有多少个。 你们轮流拿⽯头,⼀次拿⼀堆,但是只能拿⾛最左边或者最右边的⽯头堆。所有⽯头被拿完后,谁拥有的⽯头多,谁获胜。 ⽯头的堆数可以是任意正整数,⽯头的总数也可以是任意正整数,这样就能打破先⼿必胜的局⾯了。 ⽐如有三堆⽯头 piles = 1, 100, 3 ,先⼿不管拿 1 还是 3,能够决定胜负的 100 都会被后⼿拿⾛,后⼿会获胜。 假设两⼈都很聪明,请你设计⼀个算法,返回先⼿和后⼿的最后得分(⽯头总数)之差。 ⽐如上⾯那个例⼦,先⼿能获得 4 分,后⼿会获得 100 分,你的算法应该返回 -96。
class DivideStone:
def __init__(self, w):
self.w = w
def ans(self):
w = self.w
dp = [[[0, 0] for __ in w] for _ in w]
for i, wi in enumerate(w):
dp[i][i] = [w[i], 0]
for height in range(1, len(w)):
for i in range(len(w) - height):
i, wi, j, wj = i, w[i], i + height, w[i + height]
first_left = wi + dp[i + 1][j][1]
first_right = wj + dp[i][j - 1][1]
if first_left > first_right:
dp[i][j][1] = dp[i + 1][j][0]
dp[i][j][0] = first_left
else:
dp[i][j][1] = dp[i][j - 1][0]
dp[i][j][0] = first_right
return dp[0][len(w) - 1]
check = DivideStone([3, 9,1, 2])
print(check.ans())
区间覆盖问题
区间覆盖问题如分配教室,最多参与几个活动等。
典型的贪心:
- 对于分配教室,我们需要将所有活动都进行分配,因此按开始时间顺次排列即可。 严格来说这个属于最多有几个区间重叠的问题,需要用堆来进行判断区间是否结束。 典型的例子比如最大并发数量。
- 对于区间覆盖,求的是最多有几个区间不重叠,我们关注的是重叠问题,因此按结束时间顺次取得即可。 区间覆盖的反面问题就是最少去掉多少区间使得区间不重叠,或者说多个区间最少取几个点使得每个区间都有点。(leetcode 435,452)。
class Overlap:
def __init__(self,arr):
self.arr = arr
self.leetcode = 435
def ans(self):
arr = self.arr
if len(arr) <= 1:
return len(arr)
arr = sorted(arr,key = lambda x:x[1])
cnt,last_end = 0,arr[0][1]
for x in arr[1:]:
if x[0] < last_end:
cnt += 1
last_end = x[1]
return cnt
check = Overlap([ [1,2], [2,3], [3,4], [1,3] ])
print(check.ans())
class LeastArrow:
def __init__(self,arr):
self.arr = arr
self.leetcode = 452
def ans(self):
arr = self.arr
if len(arr) <= 1:
return len(arr)
arr = sorted(arr,key=lambda x:x[1])
cnt,last_end = 1,arr[0][1]
for x in arr[1:]:
if x[0] > last_end:
cnt += 1
last_end = x[1]
return cnt
check = LeastArrow([[10,16], [2,8], [1,6], [7,12]])
print(check.ans())
堆
class MyHeap:
def __init__(self, arr=None):
self.arr = arr or []
self.length = len(self.arr)
self.build()
def build(self):
arr, length = self.arr, self.length
for idx in reversed(range(length)):
self.shiftDown(idx)
def top(self):
return self.arr[0]
def pop(self):
if self.length == 0:
return None
if self.length == 1:
self.length -= 1
return self.arr.pop(0)
self.length -= 1
result = self.arr[0]
self.arr[0] = self.arr.pop(-1)
self.shiftDown(0)
return result
def push(self, x):
self.arr.append(x)
self.length += 1
cur = self.length - 1
while cur:
father = (cur - 1) // 2
if self.arr[father] > self.arr[cur]:
self.swap(cur, father)
cur = father
else:
break
def empty(self):
return self.length == 0
def shiftDown(self, idx):
arr, length = self.arr, self.length
l_idx, r_idx = 2 * idx + 1, 2 * idx + 2
l_val, r_val = arr[l_idx] if l_idx < length else float('inf'), \
arr[r_idx] if r_idx < length else float('inf')
if min(arr[idx], l_val, r_val) == l_val:
self.swap(idx, l_idx)
self.shiftDown(l_idx)
elif min(arr[idx], l_val, r_val) == r_val:
self.swap(idx, r_idx)
self.shiftDown(r_idx)
def swap(self, src, dst):
tmp = self.arr[src]
self.arr[src] = self.arr[dst]
self.arr[dst] = tmp
# self.arr[src] = self.arr[src] ^ self.arr[dst]
# self.arr[dst] = self.arr[src] ^ self.arr[dst]
# self.arr[src] = self.arr[src] ^ self.arr[dst]
# self.arr[src],self.arr[dst] = self.arr[dst],self.arr[src]
# heap = MyHeap([5, 4, -9, 2, 0.1, -10, 999, 0, -9])
heap = MyHeap()
for x in [5, 4, -9, 2, 0.1, -10, 999, 0, -9]:
heap.push(x)
print(heap.arr)
while not heap.empty():
print(heap.pop())
PriorityQueue/Sort/TopK
# 排序模板,TopK模板,Python 堆模板
# Quick Sort
arr = [1, 3, 5, 1, -9, 3, -4, 10, 2]
print("Array:", arr)
import heapq
class Heap:
def __init__(self):
self.heap = []
def __len__(self):
return len(self.heap)
def empty(self):
return len(self.heap) == 0
def pop(self):
key, val = heapq.heappop(self.heap)
return val
def top(self):
return self.heap[0][1]
class MinHeap(Heap):
def push(self, x):
heapq.heappush(self.heap, (x, x))
class MaxHeap(Heap):
def push(self, x):
heapq.heappush(self.heap, (-x, x))
class Sort:
def __init__(self, arr):
self.arr = arr
@staticmethod
def qsort(arr):
def qsort_i_j(i, j):
if i >= j:
return
b, e = i, j
val = arr[i]
while i != j:
while arr[j] >= val and i < j:
j -= 1
arr[i] = arr[j]
while arr[i] <= val and i < j:
i += 1
arr[j] = arr[i]
arr[i] = val
qsort_i_j(b, i - 1)
qsort_i_j(i + 1, e)
return
qsort_i_j(0, len(arr) - 1)
return arr
# print("Sorted Array:",Sort.qsort(arr))
class TopK:
def __init__(self, arr, k):
self.arr, self.k = arr, k
@staticmethod
def by_heap(arr, k):
heap = MinHeap()
for i, ti in enumerate(arr[:k]):
heap.push(ti)
if i == k - 1:
break
for i, ti in enumerate(arr[k:]):
if heap.top() < ti:
heap.pop()
heap.push(ti)
return heap.top()
@staticmethod
def by_qsort(arr, k):
k = len(arr) - k # qsort idx from 0
origin = arr[:]
def qsort(i, j):
if i == j:
return arr[i]
if j < 0:
return arr[0]
if i > len(arr) - 1:
return arr[-1]
b, e = i, j
val = arr[i]
while i != j:
while arr[j] >= val and i < j:
j -= 1
arr[i] = arr[j]
while arr[i] <= val and i < j:
i += 1
arr[j] = arr[i]
arr[i] = val
if i == k:
return val
elif i < k:
return qsort(i + 1, e)
else:
return qsort(b, i - 1)
ans = qsort(0, len(arr) - 1)
arr = origin
return ans
print(TopK.by_heap(arr, k=3))
print(TopK.by_qsort(arr, k=3))
递归&非递归遍历二叉树
class TreeNode:
def __init__(self, x):
self.val = x
self.left = None
self.right = None
def __str__(self):
return self.val
# Tree Structure:
# 0
# / \
# 1 2
# / \ / \
# 3 4 5 6
# / \ \
# 7 8 9
class Tree:
def __init__(self):
nodes = [TreeNode(i) for i in range(0, 10)]
nodes[0].left = nodes[1]
nodes[0].right = nodes[2]
nodes[1].left = nodes[3]
nodes[1].right = nodes[4]
nodes[2].left = nodes[5]
nodes[2].right = nodes[6]
nodes[4].left = nodes[7]
nodes[4].right = nodes[8]
nodes[5].right = nodes[9]
self.root = nodes[0]
def recur_pre(self, root: TreeNode):
if root is None:
return []
return [root.val] + self.recur_pre(root.left) + self.recur_pre(root.right)
def recur_mid(self, root: TreeNode):
if root is None:
return []
return self.recur_mid(root.left) + [root.val] + self.recur_mid(root.right)
def recur_post(self, root: TreeNode):
if root is None:
return []
return self.recur_post(root.left) + self.recur_post(root.right) + [root.val]
def stk_pre(self, root: TreeNode):
stk, res = [], []
cur = root
while cur or len(stk):
while cur:
res.append(cur)
stk.append(cur)
cur = cur.left
if len(stk):
cur = stk.pop(-1).right
return [line.val for line in res]
def stk_mid(self, root: TreeNode):
stk, res = [], []
cur = root
while cur or len(stk):
if cur:
stk.append(cur)
cur = cur.left
else:
res.append(stk.pop(-1))
cur = res[-1].right
return [line.val for line in res]
def stk_post(self, root: TreeNode):
stk, res = [], []
cur = root
while cur or len(stk):
while cur:
stk.append((cur, 0))
cur = cur.left
if len(stk):
cur, times = stk.pop(-1)
if times == 0:
stk.append((cur, 1))
cur = cur.right
else:
res.append(cur)
cur = None
return [line.val for line in res]
check = Tree()
print("Pre Order:", check.recur_pre(check.root), check.recur_pre(check.root) == check.stk_pre(check.root))
print("Mid Order:", check.recur_mid(check.root), check.recur_mid(check.root) == check.stk_mid(check.root))
print("Post Order:", check.recur_post(check.root), check.recur_post(check.root) == check.stk_post(check.root))
字典序问题
建议使用trie树来解题 这里举出两个例子:输出字典序和输出第m个字典序
from functools import reduce
flatten = lambda z: list(reduce(lambda x, y: x + y, z))
valid = [str(i) for i in range(10)] # valid chars in permutation,such as [0-9],[a-z]
N = str(101) # permutations under this string(include itself)
def getPermutation(prefix): # O(n)
if (len(prefix) == len(N) and prefix > N) or \
len(prefix) > len(N):
return []
ans = [prefix]
for c in valid:
ans += getPermutation(prefix + c)
return ans
def getPermutationLength(prefix): # O(depth = log n)
if (len(prefix) == len(N) and prefix > N) or \
len(prefix) > len(N):
return 0
if len(prefix) == len(N):
return 1
ratio, div = len(valid), len(N) - len(prefix)
# 1, len(valid),len(valid)**2,...
# (ratio**div-1)/(ratio-1)
if prefix < N[:len(prefix)]:
return (ratio ** (div + 1) - 1) // (ratio - 1)
ans = (ratio ** div - 1) // (ratio - 1)
if prefix > N[:len(prefix)]:
return ans
for i, c in enumerate(reversed(N[len(prefix):])):
ans += ((ratio ** i) * valid.index(c))
ans += 1 # include N
return ans
def getKthPermutation(k):
if k <= 0:
return None
total_permutation = sum([getPermutationLength(x) for x in valid[1:]])
if k > total_permutation:
return None
idx = 1 # idx = 0
ans = valid[idx]
while k > 0:
cur = getPermutationLength(ans)
if cur < k:
k, idx = k - cur, idx + 1
ans = ans[:-1] + valid[idx]
else:
k, idx = k - 1, 0
if k == 0:
return ans
ans = ans + valid[idx]
return ans
print("============Get Permutations==============")
permutations = flatten([getPermutation(x) for x in valid[1:]])
print(permutations)
print("============Test getPermutationLength==============")
for x in valid[1:]:
print("prefix:", x, "golden:", len(getPermutation(x)), "predict:", getPermutationLength(x))
print("============Test getKthPermutation==============")
for k in [1, 2, 45, 56, 78, 99, 100, 101]:
golden = permutations[k - 1] if 1 <= k <= len(permutations) else None
print("Kth:", k, "golden:", golden, "predict:", getKthPermutation(k))
链表操作
主要集中在递归非递归实现反转链表,指定区间反转链表,k个一组反转链表等
class Node:
def __init__(self, x):
self.val = x
self.next = None
def __str__(self):
ans, ptr = [str(self.val)], self.next
while ptr is not None:
ans.append(str(ptr.val))
ptr = ptr.next
return "->".join(ans)
def getNodeList():
nodes = [Node(x) for x in range(1, 10)]
for i in range(len(nodes) - 1):
nodes[i].next = nodes[i + 1]
root = nodes[0]
return root
notReverseHead = None
def reverseNodeListN(root: Node, begin: int):
global notReverseHead
if begin == 1 or root.next is None:
notReverseHead = root.next
return root
head = reverseNodeListN(root.next, begin - 1)
root.next.next = root
root.next = notReverseHead
return head
def reverseNodeList(root: Node, begin: int, end: int):
if root is None or root.next is None:
return root
if begin <= 1:
return reverseNodeListN(root, end)
root.next = reverseNodeList(root.next, begin - 1, end - 1)
return root
def reverseNodeListByLoop(root: Node, begin: int, end: int):
if root is None or root.next is None:
return root
begin, end = begin - 1, end - 1 # just for handy
prev, cur, nxt = None, root, root.next
while nxt and begin:
begin, end = begin - 1, end - 1
prev, cur, nxt = cur, nxt, nxt.next
if begin:
return root
reverse_root = prev
while nxt and end:
prev, cur, nxt = cur, nxt, nxt.next
cur.next = prev
end = end - 1
if reverse_root is None:
root.next = nxt
return cur
reverse_root.next.next = nxt
reverse_root.next = cur
return root
def reverseKGroup(root: Node, k: int):
ptr = root
idx = 0
while idx < k:
idx += 1
if ptr is None:
return root
ptr = ptr.next
def rev(x: Node, k: int):
idx = 1
prev, cur, nxt = None, x, x.next
while idx < k:
idx += 1
cur.next = prev
prev, cur, nxt = cur, nxt, nxt.next
cur.next = prev
return cur, nxt
head, nxt = rev(root, k)
root.next = reverseKGroup(nxt, k)
return head
n1,n2,n3 = getNodeList(),getNodeList(),getNodeList()
print(" origin :", n1)
n1 = reverseNodeList(n1, 2, 8)
n2 = reverseNodeListByLoop(n2, 3, 7)
n3 = reverseKGroup(n3, k=4)
print("reversed n1: {}\nreversed n2: {}\nreversed n3: {}".format(n1,n2,n3))
K-Means
谱聚类那一块也可以看看,原理和GCN有关联。
谱聚类SVD分解后,最后还是要过一遍K-Means,就很玄学。
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
K = 3
c1, c2, c3 = np.random.randn(200, 2) + [1, 1], \
np.random.randn(200, 2) + [4, 4], \
np.random.randn(200, 2) + [7, 1]
data = np.concatenate((c1, c2, c3), axis=0)
train, test = train_test_split(data, test_size=0.3)
def init_centers(k):
batch, dim = train.shape
mean, std = np.mean(train, axis=0, keepdims=True), np.std(train, axis=0, keepdims=True)
centers = np.random.randn(k, dim) * std + mean
return centers
centers = init_centers(k=K)
dist = np.zeros((train.shape[0], K))
iter, move = 0, 0.0
cluster = np.zeros(train.shape[0])
max_iter, min_move = 64, 1e-3
while iter < max_iter and move > min_move:
iter += 1
for i in range(K):
dist[:, i] = np.linalg.norm(train - centers[i], axis=1)
cluster = np.argmin(dist, axis=1)
last_center = deepcopy(centers)
for i in range(K):
centers[i] = np.mean(train[cluster == i], axis=0)
move = np.linalg.norm(centers - last_center)
plt.scatter(train[:, 0], train[:, 1], alpha=0.5, c=cluster)
plt.scatter(centers[:, 0], centers[:, 1], marker='*', c='k')
plt.show()

浙公网安备 33010602011771号