import numpy as np
import matplotlib.pyplot as plt
# ---------------------------------------
# Define the Swish and SwiGLU activation functions
# ---------------------------------------
def swish(x, beta=1.0):
    """Swish activation: ``x * sigmoid(beta * x)``.

    With the default ``beta=1.0`` this is exactly ``x * sigmoid(x)``
    (also known as SiLU).

    Parameters
    ----------
    x : array_like or scalar
        Input values; lists are converted with ``np.asarray``.
    beta : float, optional
        Slope of the sigmoid gate. Defaults to 1.0, which reproduces the
        original ``x * sigmoid(x)`` definition.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Element-wise Swish of ``x`` (a NumPy scalar for scalar input).
    """
    x = np.asarray(x)
    # sigmoid(z) == exp(-log(1 + exp(-z))). np.logaddexp(0, -z) computes
    # log(1 + exp(-z)) without materialising exp(-z). As a result, very
    # negative inputs no longer overflow: the naive 1 + np.exp(-x) becomes
    # inf and warns for x below about -709.
    gate = np.exp(-np.logaddexp(0.0, -beta * x))
    return x * gate
def swiglu(x, w=1.0, v=1.0):
    """Simplified 1D SwiGLU: ``Swish(w * x) * (v * x)``.

    Full SwiGLU is ``Swish(x @ W) * (x @ V)`` with two projection matrices
    W and V. Here both projections are collapsed to scalars ``w`` and ``v``.
    With the defaults ``w = v = 1``, this reduces to ``x * Swish(x)``, the
    curve this script plots.

    Parameters
    ----------
    x : array_like or scalar
        Input values; lists are converted with ``np.asarray``.
    w : float, optional
        Scalar weight of the gate branch (fed through Swish). Default 1.0.
    v : float, optional
        Scalar weight of the linear value branch. Default 1.0.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Element-wise gated output.
    """
    x = np.asarray(x)
    # Gate branch (Swish-activated) multiplied by the linear value branch.
    return swish(w * x) * (v * x)
# ---------------------------------------
# Sample the input range
# ---------------------------------------
x = np.linspace(-6, 6, 400)
y_swish, y_swiglu = swish(x), swiglu(x)

# ---------------------------------------
# Plot both curves
# ---------------------------------------
fig, ax = plt.subplots(figsize=(8, 5))
for _curve, _label, _color in (
    (y_swish, "Swish", "orange"),
    (y_swiglu, "SwiGLU", "blue"),
):
    ax.plot(x, _curve, label=_label, color=_color, linewidth=2)

# Dashed reference axes through the origin.
ax.axhline(0, color="black", linestyle="--", linewidth=0.8)
ax.axvline(0, color="black", linestyle="--", linewidth=0.8)

ax.set_title("Swish vs. SwiGLU Activation Functions (1D)", fontsize=14)
ax.set_xlabel("Input (x)", fontsize=12)
ax.set_ylabel("Output", fontsize=12)
ax.legend()
ax.grid(True, linestyle="--", alpha=0.6)
fig.tight_layout()
plt.show()
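# ![image]()  -- stray Markdown image placeholder with no target (the figure is produced by plt.show() above); kept as a comment so the script parses.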