1 #include <cassert>
2 #include <cmath>
3 #include <iostream>
4 #include <vector>
5
6 #ifndef M_PI
7 #define M_PI 3.1415926
8 #endif
9 // 定义精度常量
10 const double EPS = 1e-8;
11
12 // 三维向量结构
13 struct Vector3 {
14 double x, y, z;
15
16 Vector3(double x = 0, double y = 0, double z = 0) : x(x), y(y), z(z) {}
17
18 // 向量加法
19 Vector3 operator+(const Vector3 &other) const {
20 return Vector3(x + other.x, y + other.y, z + other.z);
21 }
22
23 // 向量减法
24 Vector3 operator-(const Vector3 &other) const {
25 return Vector3(x - other.x, y - other.y, z - other.z);
26 }
27
28 // 向量数乘
29 Vector3 operator*(double s) const { return Vector3(x * s, y * s, z * s); }
30
31 // 点积
32 double dot(const Vector3 &other) const {
33 return x * other.x + y * other.y + z * other.z;
34 }
35
36 // 叉积
37 Vector3 cross(const Vector3 &other) const {
38 return Vector3(y * other.z - z * other.y, z * other.x - x * other.z,
39 x * other.y - y * other.x);
40 }
41
42 // 向量归一化
43 Vector3 normalize() const {
44 double len = std::sqrt(x * x + y * y + z * z);
45 if (len > EPS) {
46 return Vector3(x / len, y / len, z / len);
47 }
48 return *this;
49 }
50
51 // 打印向量
52 void print(const std::string &name = "") const {
53 if (!name.empty()) {
54 std::cout << name << ": ";
55 }
56 std::cout << "(" << x << ", " << y << ", " << z << ")" << std::endl;
57 }
58
59 // 计算向量范数平方
60 double normSquared() const { return x * x + y * y + z * z; }
61 };
62
63 // 3x3矩阵结构
64 struct Matrix3x3 {
65 double m[3][3];
66
67 Matrix3x3() {
68 // 初始化为单位矩阵
69 for (int i = 0; i < 3; i++) {
70 for (int j = 0; j < 3; j++) {
71 m[i][j] = (i == j) ? 1.0 : 0.0;
72 }
73 }
74 }
75
76 // 矩阵与向量乘法
77 Vector3 multiply(const Vector3 &v) const {
78 return Vector3(m[0][0] * v.x + m[0][1] * v.y + m[0][2] * v.z,
79 m[1][0] * v.x + m[1][1] * v.y + m[1][2] * v.z,
80 m[2][0] * v.x + m[2][1] * v.y + m[2][2] * v.z);
81 }
82
83 // 矩阵减法
84 Matrix3x3 operator-(const Matrix3x3 &other) const {
85 Matrix3x3 result;
86 for (int i = 0; i < 3; i++) {
87 for (int j = 0; j < 3; j++) {
88 result.m[i][j] = m[i][j] - other.m[i][j];
89 }
90 }
91 return result;
92 }
93
94 // 计算矩阵Frobenius范数平方
95 double normSquared() const {
96 double sum = 0;
97 for (int i = 0; i < 3; i++) {
98 for (int j = 0; j < 3; j++) {
99 sum += m[i][j] * m[i][j];
100 }
101 }
102 return sum;
103 }
104
105 // 打印矩阵
106 void print(const std::string &name = "") const {
107 if (!name.empty()) {
108 std::cout << name << ":" << std::endl;
109 }
110 for (int i = 0; i < 3; i++) {
111 for (int j = 0; j < 3; j++) {
112 std::cout << m[i][j] << "\t";
113 }
114 std::cout << std::endl;
115 }
116 }
117 };
118
119 // 4x4矩阵结构
120 struct Matrix4x4 {
121 double m[4][4];
122
123 Matrix4x4() {
124 // 初始化为单位矩阵
125 for (int i = 0; i < 4; i++) {
126 for (int j = 0; j < 4; j++) {
127 m[i][j] = (i == j) ? 1.0 : 0.0;
128 }
129 }
130 }
131
132 // 从4x4矩阵中提取3x3旋转矩阵
133 Matrix3x3 getRotationMatrix() const {
134 Matrix3x3 result;
135 for (int i = 0; i < 3; i++) {
136 for (int j = 0; j < 3; j++) {
137 result.m[i][j] = m[i][j];
138 }
139 }
140 return result;
141 }
142
143 // 从4x4矩阵中提取平移向量
144 Vector3 getTranslationVector() const {
145 return Vector3(m[0][3], m[1][3], m[2][3]);
146 }
147
148 // 打印矩阵
149 void print(const std::string &name = "") const {
150 if (!name.empty()) {
151 std::cout << name << ":" << std::endl;
152 }
153 for (int i = 0; i < 4; i++) {
154 for (int j = 0; j < 4; j++) {
155 std::cout << m[i][j] << "\t";
156 }
157 std::cout << std::endl;
158 }
159 }
160 };
161
162 // 使用叉积法直接计算旋转轴方向
163 Vector3 computeRotationAxisDirection(const Matrix3x3 &R) {
164 // 计算R-I
165 Matrix3x3 R_minus_I = R - Matrix3x3();
166
167 // 获取列向量
168 Vector3 col0(R_minus_I.m[0][0], R_minus_I.m[1][0], R_minus_I.m[2][0]);
169 Vector3 col1(R_minus_I.m[0][1], R_minus_I.m[1][1], R_minus_I.m[2][1]);
170 Vector3 col2(R_minus_I.m[0][2], R_minus_I.m[1][2], R_minus_I.m[2][2]);
171
172 // 计算列向量的范数平方
173 double n0 = col0.normSquared();
174 double n1 = col1.normSquared();
175 double n2 = col2.normSquared();
176
177 Vector3 axis_dir;
178 double max_norm = 0;
179 bool found = false;
180
181 // 尝试所有列组合的叉积
182 if (n0 > EPS && n1 > EPS) {
183 axis_dir = col0.cross(col1);
184 double norm_sq = axis_dir.normSquared();
185 if (norm_sq > max_norm) {
186 max_norm = norm_sq;
187 found = true;
188 }
189 }
190
191 if (n0 > EPS && n2 > EPS) {
192 Vector3 candidate = col0.cross(col2);
193 double norm_sq = candidate.normSquared();
194 if (norm_sq > max_norm) {
195 axis_dir = candidate;
196 max_norm = norm_sq;
197 found = true;
198 }
199 }
200
201 if (n1 > EPS && n2 > EPS) {
202 Vector3 candidate = col1.cross(col2);
203 double norm_sq = candidate.normSquared();
204 if (norm_sq > max_norm) {
205 axis_dir = candidate;
206 max_norm = norm_sq;
207 found = true;
208 }
209 }
210
211 // 如果没有找到有效的叉积,使用默认方向
212 if (!found || max_norm < EPS) {
213 // 可能无旋转或旋转角很小,返回默认方向
214 return Vector3(1, 0, 0);
215 }
216
217 return axis_dir.normalize();
218 }
219
220 // 从4x4变换矩阵中提取旋转轴方向向量
221 Vector3 extractRotationAxisDirection(const Matrix4x4 &transform) {
222 // 提取旋转矩阵R
223 Matrix3x3 R = transform.getRotationMatrix();
224
225 // 检查是否无旋转
226 if ((R - Matrix3x3()).normSquared() < EPS) {
227 return Vector3(1, 0, 0); // 默认方向
228 }
229
230 return computeRotationAxisDirection(R);
231 }
232
233 // 从4x4变换矩阵中提取不过原点的旋转轴(返回轴上一点和方向向量)
234 bool extractRotationAxis(const Matrix4x4 &transform, Vector3 &axis_point,
235 Vector3 &axis_dir) {
236 // 提取旋转矩阵R和平移向量t
237 Matrix3x3 R = transform.getRotationMatrix();
238 Vector3 t = transform.getTranslationVector();
239
240 // 检查是否无旋转
241 Matrix3x3 R_minus_I = R - Matrix3x3();
242 if (R_minus_I.normSquared() < EPS) {
243 axis_dir = Vector3(1, 0, 0);
244 axis_point = Vector3(0, 0, 0);
245 // 检查平移是否为零
246 if (t.normSquared() > EPS) {
247 // 纯平移变换,没有固定旋转轴
248 return false;
249 }
250 return true;
251 }
252
253 // 计算旋转轴方向向量
254 axis_dir = computeRotationAxisDirection(R);
255
256 // 构造与旋转轴方向正交的基
257 Vector3 u = axis_dir.normalize();
258 Vector3 u1;
259
260 // 选择一个与u不平行的基础向量
261 if (std::abs(u.x) < 0.9) {
262 u1 = Vector3(1, 0, 0);
263 } else {
264 u1 = Vector3(0, 1, 0);
265 }
266
267 // Gram-Schmidt正交化
268 u1 = u1 - u * u.dot(u1);
269 u1 = u1.normalize();
270 Vector3 u2 = u.cross(u1).normalize();
271
272 // 计算(R-I)在正交基上的投影
273 Vector3 A_u1 = R_minus_I.multiply(u1);
274 Vector3 A_u2 = R_minus_I.multiply(u2);
275
276 // 构造2x2方程组系数
277 double a11 = u1.dot(A_u1);
278 double a12 = u1.dot(A_u2);
279 double a21 = u2.dot(A_u1);
280 double a22 = u2.dot(A_u2);
281
282 // 构造右端项: -t 在正交基上的投影
283 double b1 = -u1.dot(t);
284 double b2 = -u2.dot(t);
285
286 // 解2x2方程组: [a11 a12; a21 a22] * [x; y] = [b1; b2]
287 double det = a11 * a22 - a12 * a21;
288 if (std::abs(det) < EPS) {
289 // 方程组奇异,使用原点作为轴上一点
290 axis_point = Vector3(0, 0, 0);
291 } else {
292 double x = (a22 * b1 - a12 * b2) / det;
293 double y = (a11 * b2 - a21 * b1) / det;
294 axis_point = u1 * x + u2 * y;
295 }
296
297 return true;
298 }
299
300 // 测试函数:创建一个绕指定轴旋转的4x4变换矩阵
301 Matrix4x4 createRotationAroundAxis(const Vector3 &axis_point,
302 const Vector3 &axis_dir, double angle) {
303 Matrix4x4 mat;
304
305 // 轴方向归一化
306 Vector3 u = axis_dir.normalize();
307
308 double cos_theta = std::cos(angle);
309 double sin_theta = std::sin(angle);
310 double one_minus_cos = 1.0 - cos_theta;
311
312 // 设置旋转矩阵部分
313 mat.m[0][0] = cos_theta + u.x * u.x * one_minus_cos;
314 mat.m[0][1] = u.x * u.y * one_minus_cos - u.z * sin_theta;
315 mat.m[0][2] = u.x * u.z * one_minus_cos + u.y * sin_theta;
316
317 mat.m[1][0] = u.y * u.x * one_minus_cos + u.z * sin_theta;
318 mat.m[1][1] = cos_theta + u.y * u.y * one_minus_cos;
319 mat.m[1][2] = u.y * u.z * one_minus_cos - u.x * sin_theta;
320
321 mat.m[2][0] = u.z * u.x * one_minus_cos - u.y * sin_theta;
322 mat.m[2][1] = u.z * u.y * one_minus_cos + u.x * sin_theta;
323 mat.m[2][2] = cos_theta + u.z * u.z * one_minus_cos;
324
325 // 计算平移部分
326 Vector3 R_p0(mat.m[0][0] * axis_point.x + mat.m[0][1] * axis_point.y +
327 mat.m[0][2] * axis_point.z,
328 mat.m[1][0] * axis_point.x + mat.m[1][1] * axis_point.y +
329 mat.m[1][2] * axis_point.z,
330 mat.m[2][0] * axis_point.x + mat.m[2][1] * axis_point.y +
331 mat.m[2][2] * axis_point.z);
332
333 Vector3 t = axis_point - R_p0;
334 mat.m[0][3] = t.x;
335 mat.m[1][3] = t.y;
336 mat.m[2][3] = t.z;
337
338 // 最后一行保持不变
339 mat.m[3][0] = 0;
340 mat.m[3][1] = 0;
341 mat.m[3][2] = 0;
342 mat.m[3][3] = 1;
343
344 return mat;
345 }
346
347 // 计算点到直线的距离平方
348 double pointToLineDistanceSquared(const Vector3 &point,
349 const Vector3 &line_point,
350 const Vector3 &line_dir) {
351 Vector3 v = point - line_point;
352 Vector3 cross = v.cross(line_dir);
353 return cross.normSquared() / line_dir.normSquared();
354 }
355
356 int main0() {
357 // 测试用例1:一般情况
358 {
359 std::cout << "===== 测试用例1: 一般情况 =====" << std::endl;
360 // 定义旋转轴:过点(1, 2, 3),方向向量为(1, 1, 1)
361 Vector3 original_point(17.0710678118655, 0.0, -1.21320343559642);
362 Vector3 original_dir(0.0, -1.0, 0.0);
363 double angle = M_PI / 4; // 45度旋转
364
365 // 创建变换矩阵
366 Matrix4x4 transform =
367 createRotationAroundAxis(original_point, original_dir, angle);
368 std::cout << "创建的变换矩阵:" << std::endl;
369 transform.print();
370
371
372 // 提取旋转轴
373 Vector3 extracted_point, extracted_dir;
374 bool success =
375 extractRotationAxis(transform, extracted_point, extracted_dir);
376
377 if (success) {
378 std::cout << "\n提取的旋转轴信息:" << std::endl;
379 extracted_point.print("轴上一点");
380 extracted_dir.print("方向向量");
381
382 // 验证方向向量是否正确(应该与原方向共线)
383 Vector3 normalized_original = original_dir.normalize();
384 Vector3 normalized_extracted = extracted_dir.normalize();
385 double dot_product = normalized_original.dot(normalized_extracted);
386 std::cout << "\n方向向量点积(应接近±1):" << std::abs(dot_product)
387 << std::endl;
388
389 // 验证提取的点是否在原轴上
390 double distance_sq = pointToLineDistanceSquared(
391 extracted_point, original_point, normalized_original);
392 std::cout << "提取点到原轴的距离平方(应接近0):" << distance_sq
393 << std::endl;
394 } else {
395 std::cout << "\n提取旋转轴失败!" << std::endl;
396 }
397 std::cout << "\n\n";
398 }
399
400 // 测试用例2:轴过原点
401 /* {
402 std::cout << "===== 测试用例2: 轴过原点 =====" << std::endl;
403 // 定义旋转轴:过点(0, 0, 0),方向向量为(0, 1, 0)
404 Vector3 original_point(0.0, 0.0, 0.0);
405 Vector3 original_dir(0.0, 1.0, 0.0);
406 double angle = M_PI / 3; // 60度旋转
407
408 // 创建变换矩阵
409 Matrix4x4 transform =
410 createRotationAroundAxis(original_point, original_dir, angle);
411 std::cout << "创建的变换矩阵:" << std::endl;
412 transform.print();
413
414 // 提取旋转轴
415 Vector3 extracted_point, extracted_dir;
416 bool success =
417 extractRotationAxis(transform, extracted_point, extracted_dir);
418
419 if (success) {
420 std::cout << "\n提取的旋转轴信息:" << std::endl;
421 extracted_point.print("轴上一点");
422 extracted_dir.print("方向向量");
423
424 // 验证方向向量是否正确
425 Vector3 normalized_original = original_dir.normalize();
426 Vector3 normalized_extracted = extracted_dir.normalize();
427 double dot_product = normalized_original.dot(normalized_extracted);
428 std::cout << "\n方向向量点积(应接近±1):" << std::abs(dot_product)
429 << std::endl;
430
431 // 验证提取的点是否在轴上(应接近原点)
432 double distance_to_origin = extracted_point.normSquared();
433 std::cout << "提取点到原点的距离平方(应接近0):"
434 << distance_to_origin << std::endl;
435 } else {
436 std::cout << "\n提取旋转轴失败!" << std::endl;
437 }
438 std::cout << "\n\n";
439 }
440
441 // 测试用例3:无旋转(纯平移)
442 {
443 std::cout << "===== 测试用例3: 纯平移变换 =====" << std::endl;
444 Matrix4x4 transform;
445 // 设置平移
446 transform.m[0][3] = 1.0;
447 transform.m[1][3] = 2.0;
448 transform.m[2][3] = 3.0;
449
450 std::cout << "创建的变换矩阵:" << std::endl;
451 transform.print();
452
453 // 提取旋转轴
454 Vector3 extracted_point, extracted_dir;
455 bool success =
456 extractRotationAxis(transform, extracted_point, extracted_dir);
457
458 if (success) {
459 std::cout << "\n提取的旋转轴信息:" << std::endl;
460 extracted_point.print("轴上一点");
461 extracted_dir.print("方向向量");
462 std::cout << "\n注意:纯平移变换没有固定旋转轴,返回默认值"
463 << std::endl;
464 } else {
465 std::cout << "\n提取旋转轴失败(符合预期)!" << std::endl;
466 }
467 std::cout << "\n\n";
468 }
469
470 // 测试用例4:无旋转无平移(单位矩阵)
471 {
472 std::cout << "===== 测试用例4: 单位矩阵 =====" << std::endl;
473 Matrix4x4 transform;
474
475 std::cout << "创建的变换矩阵:" << std::endl;
476 transform.print();
477
478 // 提取旋转轴
479 Vector3 extracted_point, extracted_dir;
480 bool success =
481 extractRotationAxis(transform, extracted_point, extracted_dir);
482
483 if (success) {
484 std::cout << "\n提取的旋转轴信息:" << std::endl;
485 extracted_point.print("轴上一点");
486 extracted_dir.print("方向向量");
487 std::cout << "\n注意:单位矩阵返回默认旋转轴" << std::endl;
488 } else {
489 std::cout << "\n提取旋转轴失败!" << std::endl;
490 }
491 }*/
492
493 return 0;
494 }