从4x4变换矩阵中提取旋转轴

  1 #include <cassert>
  2 #include <cmath>
  3 #include <iostream>
  4 #include <vector>
  5 
  6 #ifndef M_PI
  7 #define M_PI 3.1415926
  8 #endif
  9 // 定义精度常量
 10 const double EPS = 1e-8;
 11 
 12 // 三维向量结构
 13 struct Vector3 {
 14     double x, y, z;
 15 
 16     Vector3(double x = 0, double y = 0, double z = 0) : x(x), y(y), z(z) {}
 17 
 18     // 向量加法
 19     Vector3 operator+(const Vector3 &other) const {
 20         return Vector3(x + other.x, y + other.y, z + other.z);
 21     }
 22 
 23     // 向量减法
 24     Vector3 operator-(const Vector3 &other) const {
 25         return Vector3(x - other.x, y - other.y, z - other.z);
 26     }
 27 
 28     // 向量数乘
 29     Vector3 operator*(double s) const { return Vector3(x * s, y * s, z * s); }
 30 
 31     // 点积
 32     double dot(const Vector3 &other) const {
 33         return x * other.x + y * other.y + z * other.z;
 34     }
 35 
 36     // 叉积
 37     Vector3 cross(const Vector3 &other) const {
 38         return Vector3(y * other.z - z * other.y, z * other.x - x * other.z,
 39                        x * other.y - y * other.x);
 40     }
 41 
 42     // 向量归一化
 43     Vector3 normalize() const {
 44         double len = std::sqrt(x * x + y * y + z * z);
 45         if (len > EPS) {
 46             return Vector3(x / len, y / len, z / len);
 47         }
 48         return *this;
 49     }
 50 
 51     // 打印向量
 52     void print(const std::string &name = "") const {
 53         if (!name.empty()) {
 54             std::cout << name << ": ";
 55         }
 56         std::cout << "(" << x << ", " << y << ", " << z << ")" << std::endl;
 57     }
 58 
 59     // 计算向量范数平方
 60     double normSquared() const { return x * x + y * y + z * z; }
 61 };
 62 
 63 // 3x3矩阵结构
 64 struct Matrix3x3 {
 65     double m[3][3];
 66 
 67     Matrix3x3() {
 68         // 初始化为单位矩阵
 69         for (int i = 0; i < 3; i++) {
 70             for (int j = 0; j < 3; j++) {
 71                 m[i][j] = (i == j) ? 1.0 : 0.0;
 72             }
 73         }
 74     }
 75 
 76     // 矩阵与向量乘法
 77     Vector3 multiply(const Vector3 &v) const {
 78         return Vector3(m[0][0] * v.x + m[0][1] * v.y + m[0][2] * v.z,
 79                        m[1][0] * v.x + m[1][1] * v.y + m[1][2] * v.z,
 80                        m[2][0] * v.x + m[2][1] * v.y + m[2][2] * v.z);
 81     }
 82 
 83     // 矩阵减法
 84     Matrix3x3 operator-(const Matrix3x3 &other) const {
 85         Matrix3x3 result;
 86         for (int i = 0; i < 3; i++) {
 87             for (int j = 0; j < 3; j++) {
 88                 result.m[i][j] = m[i][j] - other.m[i][j];
 89             }
 90         }
 91         return result;
 92     }
 93 
 94     // 计算矩阵Frobenius范数平方
 95     double normSquared() const {
 96         double sum = 0;
 97         for (int i = 0; i < 3; i++) {
 98             for (int j = 0; j < 3; j++) {
 99                 sum += m[i][j] * m[i][j];
100             }
101         }
102         return sum;
103     }
104 
105     // 打印矩阵
106     void print(const std::string &name = "") const {
107         if (!name.empty()) {
108             std::cout << name << ":" << std::endl;
109         }
110         for (int i = 0; i < 3; i++) {
111             for (int j = 0; j < 3; j++) {
112                 std::cout << m[i][j] << "\t";
113             }
114             std::cout << std::endl;
115         }
116     }
117 };
118 
119 // 4x4矩阵结构
120 struct Matrix4x4 {
121     double m[4][4];
122 
123     Matrix4x4() {
124         // 初始化为单位矩阵
125         for (int i = 0; i < 4; i++) {
126             for (int j = 0; j < 4; j++) {
127                 m[i][j] = (i == j) ? 1.0 : 0.0;
128             }
129         }
130     }
131 
132     // 从4x4矩阵中提取3x3旋转矩阵
133     Matrix3x3 getRotationMatrix() const {
134         Matrix3x3 result;
135         for (int i = 0; i < 3; i++) {
136             for (int j = 0; j < 3; j++) {
137                 result.m[i][j] = m[i][j];
138             }
139         }
140         return result;
141     }
142 
143     // 从4x4矩阵中提取平移向量
144     Vector3 getTranslationVector() const {
145         return Vector3(m[0][3], m[1][3], m[2][3]);
146     }
147 
148     // 打印矩阵
149     void print(const std::string &name = "") const {
150         if (!name.empty()) {
151             std::cout << name << ":" << std::endl;
152         }
153         for (int i = 0; i < 4; i++) {
154             for (int j = 0; j < 4; j++) {
155                 std::cout << m[i][j] << "\t";
156             }
157             std::cout << std::endl;
158         }
159     }
160 };
161 
162 // 使用叉积法直接计算旋转轴方向
163 Vector3 computeRotationAxisDirection(const Matrix3x3 &R) {
164     // 计算R-I
165     Matrix3x3 R_minus_I = R - Matrix3x3();
166 
167     // 获取列向量
168     Vector3 col0(R_minus_I.m[0][0], R_minus_I.m[1][0], R_minus_I.m[2][0]);
169     Vector3 col1(R_minus_I.m[0][1], R_minus_I.m[1][1], R_minus_I.m[2][1]);
170     Vector3 col2(R_minus_I.m[0][2], R_minus_I.m[1][2], R_minus_I.m[2][2]);
171 
172     // 计算列向量的范数平方
173     double n0 = col0.normSquared();
174     double n1 = col1.normSquared();
175     double n2 = col2.normSquared();
176 
177     Vector3 axis_dir;
178     double max_norm = 0;
179     bool found = false;
180 
181     // 尝试所有列组合的叉积
182     if (n0 > EPS && n1 > EPS) {
183         axis_dir = col0.cross(col1);
184         double norm_sq = axis_dir.normSquared();
185         if (norm_sq > max_norm) {
186             max_norm = norm_sq;
187             found = true;
188         }
189     }
190 
191     if (n0 > EPS && n2 > EPS) {
192         Vector3 candidate = col0.cross(col2);
193         double norm_sq = candidate.normSquared();
194         if (norm_sq > max_norm) {
195             axis_dir = candidate;
196             max_norm = norm_sq;
197             found = true;
198         }
199     }
200 
201     if (n1 > EPS && n2 > EPS) {
202         Vector3 candidate = col1.cross(col2);
203         double norm_sq = candidate.normSquared();
204         if (norm_sq > max_norm) {
205             axis_dir = candidate;
206             max_norm = norm_sq;
207             found = true;
208         }
209     }
210 
211     // 如果没有找到有效的叉积,使用默认方向
212     if (!found || max_norm < EPS) {
213         // 可能无旋转或旋转角很小,返回默认方向
214         return Vector3(1, 0, 0);
215     }
216 
217     return axis_dir.normalize();
218 }
219 
220 // 从4x4变换矩阵中提取旋转轴方向向量
221 Vector3 extractRotationAxisDirection(const Matrix4x4 &transform) {
222     // 提取旋转矩阵R
223     Matrix3x3 R = transform.getRotationMatrix();
224 
225     // 检查是否无旋转
226     if ((R - Matrix3x3()).normSquared() < EPS) {
227         return Vector3(1, 0, 0); // 默认方向
228     }
229 
230     return computeRotationAxisDirection(R);
231 }
232 
233 // 从4x4变换矩阵中提取不过原点的旋转轴(返回轴上一点和方向向量)
234 bool extractRotationAxis(const Matrix4x4 &transform, Vector3 &axis_point,
235                          Vector3 &axis_dir) {
236     // 提取旋转矩阵R和平移向量t
237     Matrix3x3 R = transform.getRotationMatrix();
238     Vector3 t = transform.getTranslationVector();
239 
240     // 检查是否无旋转
241     Matrix3x3 R_minus_I = R - Matrix3x3();
242     if (R_minus_I.normSquared() < EPS) {
243         axis_dir = Vector3(1, 0, 0);
244         axis_point = Vector3(0, 0, 0);
245         // 检查平移是否为零
246         if (t.normSquared() > EPS) {
247             // 纯平移变换,没有固定旋转轴
248             return false;
249         }
250         return true;
251     }
252 
253     // 计算旋转轴方向向量
254     axis_dir = computeRotationAxisDirection(R);
255 
256     // 构造与旋转轴方向正交的基
257     Vector3 u = axis_dir.normalize();
258     Vector3 u1;
259 
260     // 选择一个与u不平行的基础向量
261     if (std::abs(u.x) < 0.9) {
262         u1 = Vector3(1, 0, 0);
263     } else {
264         u1 = Vector3(0, 1, 0);
265     }
266 
267     // Gram-Schmidt正交化
268     u1 = u1 - u * u.dot(u1);
269     u1 = u1.normalize();
270     Vector3 u2 = u.cross(u1).normalize();
271 
272     // 计算(R-I)在正交基上的投影
273     Vector3 A_u1 = R_minus_I.multiply(u1);
274     Vector3 A_u2 = R_minus_I.multiply(u2);
275 
276     // 构造2x2方程组系数
277     double a11 = u1.dot(A_u1);
278     double a12 = u1.dot(A_u2);
279     double a21 = u2.dot(A_u1);
280     double a22 = u2.dot(A_u2);
281 
282     // 构造右端项: -t 在正交基上的投影
283     double b1 = -u1.dot(t);
284     double b2 = -u2.dot(t);
285 
286     // 解2x2方程组: [a11 a12; a21 a22] * [x; y] = [b1; b2]
287     double det = a11 * a22 - a12 * a21;
288     if (std::abs(det) < EPS) {
289         // 方程组奇异,使用原点作为轴上一点
290         axis_point = Vector3(0, 0, 0);
291     } else {
292         double x = (a22 * b1 - a12 * b2) / det;
293         double y = (a11 * b2 - a21 * b1) / det;
294         axis_point = u1 * x + u2 * y;
295     }
296 
297     return true;
298 }
299 
300 // 测试函数:创建一个绕指定轴旋转的4x4变换矩阵
301 Matrix4x4 createRotationAroundAxis(const Vector3 &axis_point,
302                                    const Vector3 &axis_dir, double angle) {
303     Matrix4x4 mat;
304 
305     // 轴方向归一化
306     Vector3 u = axis_dir.normalize();
307 
308     double cos_theta = std::cos(angle);
309     double sin_theta = std::sin(angle);
310     double one_minus_cos = 1.0 - cos_theta;
311 
312     // 设置旋转矩阵部分
313     mat.m[0][0] = cos_theta + u.x * u.x * one_minus_cos;
314     mat.m[0][1] = u.x * u.y * one_minus_cos - u.z * sin_theta;
315     mat.m[0][2] = u.x * u.z * one_minus_cos + u.y * sin_theta;
316 
317     mat.m[1][0] = u.y * u.x * one_minus_cos + u.z * sin_theta;
318     mat.m[1][1] = cos_theta + u.y * u.y * one_minus_cos;
319     mat.m[1][2] = u.y * u.z * one_minus_cos - u.x * sin_theta;
320 
321     mat.m[2][0] = u.z * u.x * one_minus_cos - u.y * sin_theta;
322     mat.m[2][1] = u.z * u.y * one_minus_cos + u.x * sin_theta;
323     mat.m[2][2] = cos_theta + u.z * u.z * one_minus_cos;
324 
325     // 计算平移部分
326     Vector3 R_p0(mat.m[0][0] * axis_point.x + mat.m[0][1] * axis_point.y +
327                          mat.m[0][2] * axis_point.z,
328                  mat.m[1][0] * axis_point.x + mat.m[1][1] * axis_point.y +
329                          mat.m[1][2] * axis_point.z,
330                  mat.m[2][0] * axis_point.x + mat.m[2][1] * axis_point.y +
331                          mat.m[2][2] * axis_point.z);
332 
333     Vector3 t = axis_point - R_p0;
334     mat.m[0][3] = t.x;
335     mat.m[1][3] = t.y;
336     mat.m[2][3] = t.z;
337 
338     // 最后一行保持不变
339     mat.m[3][0] = 0;
340     mat.m[3][1] = 0;
341     mat.m[3][2] = 0;
342     mat.m[3][3] = 1;
343 
344     return mat;
345 }
346 
347 // 计算点到直线的距离平方
348 double pointToLineDistanceSquared(const Vector3 &point,
349                                   const Vector3 &line_point,
350                                   const Vector3 &line_dir) {
351     Vector3 v = point - line_point;
352     Vector3 cross = v.cross(line_dir);
353     return cross.normSquared() / line_dir.normSquared();
354 }
355 
356 int main0() {
357     // 测试用例1:一般情况
358     {
359         std::cout << "===== 测试用例1: 一般情况 =====" << std::endl;
360         // 定义旋转轴:过点(1, 2, 3),方向向量为(1, 1, 1)
361         Vector3 original_point(17.0710678118655, 0.0, -1.21320343559642);
362         Vector3 original_dir(0.0, -1.0, 0.0);
363         double angle = M_PI / 4; // 45度旋转
364 
365         // 创建变换矩阵
366         Matrix4x4 transform =
367                 createRotationAroundAxis(original_point, original_dir, angle);
368         std::cout << "创建的变换矩阵:" << std::endl;
369         transform.print();
370 
371 
372         // 提取旋转轴
373         Vector3 extracted_point, extracted_dir;
374         bool success =
375                 extractRotationAxis(transform, extracted_point, extracted_dir);
376 
377         if (success) {
378             std::cout << "\n提取的旋转轴信息:" << std::endl;
379             extracted_point.print("轴上一点");
380             extracted_dir.print("方向向量");
381 
382             // 验证方向向量是否正确(应该与原方向共线)
383             Vector3 normalized_original = original_dir.normalize();
384             Vector3 normalized_extracted = extracted_dir.normalize();
385             double dot_product = normalized_original.dot(normalized_extracted);
386             std::cout << "\n方向向量点积(应接近±1):" << std::abs(dot_product)
387                       << std::endl;
388 
389             // 验证提取的点是否在原轴上
390             double distance_sq = pointToLineDistanceSquared(
391                     extracted_point, original_point, normalized_original);
392             std::cout << "提取点到原轴的距离平方(应接近0):" << distance_sq
393                       << std::endl;
394         } else {
395             std::cout << "\n提取旋转轴失败!" << std::endl;
396         }
397         std::cout << "\n\n";
398     }
399 
400     // 测试用例2:轴过原点
401    /* {
402         std::cout << "===== 测试用例2: 轴过原点 =====" << std::endl;
403         // 定义旋转轴:过点(0, 0, 0),方向向量为(0, 1, 0)
404         Vector3 original_point(0.0, 0.0, 0.0);
405         Vector3 original_dir(0.0, 1.0, 0.0);
406         double angle = M_PI / 3; // 60度旋转
407 
408         // 创建变换矩阵
409         Matrix4x4 transform =
410                 createRotationAroundAxis(original_point, original_dir, angle);
411         std::cout << "创建的变换矩阵:" << std::endl;
412         transform.print();
413 
414         // 提取旋转轴
415         Vector3 extracted_point, extracted_dir;
416         bool success =
417                 extractRotationAxis(transform, extracted_point, extracted_dir);
418 
419         if (success) {
420             std::cout << "\n提取的旋转轴信息:" << std::endl;
421             extracted_point.print("轴上一点");
422             extracted_dir.print("方向向量");
423 
424             // 验证方向向量是否正确
425             Vector3 normalized_original = original_dir.normalize();
426             Vector3 normalized_extracted = extracted_dir.normalize();
427             double dot_product = normalized_original.dot(normalized_extracted);
428             std::cout << "\n方向向量点积(应接近±1):" << std::abs(dot_product)
429                       << std::endl;
430 
431             // 验证提取的点是否在轴上(应接近原点)
432             double distance_to_origin = extracted_point.normSquared();
433             std::cout << "提取点到原点的距离平方(应接近0):"
434                       << distance_to_origin << std::endl;
435         } else {
436             std::cout << "\n提取旋转轴失败!" << std::endl;
437         }
438         std::cout << "\n\n";
439     }
440 
441     // 测试用例3:无旋转(纯平移)
442     {
443         std::cout << "===== 测试用例3: 纯平移变换 =====" << std::endl;
444         Matrix4x4 transform;
445         // 设置平移
446         transform.m[0][3] = 1.0;
447         transform.m[1][3] = 2.0;
448         transform.m[2][3] = 3.0;
449 
450         std::cout << "创建的变换矩阵:" << std::endl;
451         transform.print();
452 
453         // 提取旋转轴
454         Vector3 extracted_point, extracted_dir;
455         bool success =
456                 extractRotationAxis(transform, extracted_point, extracted_dir);
457 
458         if (success) {
459             std::cout << "\n提取的旋转轴信息:" << std::endl;
460             extracted_point.print("轴上一点");
461             extracted_dir.print("方向向量");
462             std::cout << "\n注意:纯平移变换没有固定旋转轴,返回默认值"
463                       << std::endl;
464         } else {
465             std::cout << "\n提取旋转轴失败(符合预期)!" << std::endl;
466         }
467         std::cout << "\n\n";
468     }
469 
470     // 测试用例4:无旋转无平移(单位矩阵)
471     {
472         std::cout << "===== 测试用例4: 单位矩阵 =====" << std::endl;
473         Matrix4x4 transform;
474 
475         std::cout << "创建的变换矩阵:" << std::endl;
476         transform.print();
477 
478         // 提取旋转轴
479         Vector3 extracted_point, extracted_dir;
480         bool success =
481                 extractRotationAxis(transform, extracted_point, extracted_dir);
482 
483         if (success) {
484             std::cout << "\n提取的旋转轴信息:" << std::endl;
485             extracted_point.print("轴上一点");
486             extracted_dir.print("方向向量");
487             std::cout << "\n注意:单位矩阵返回默认旋转轴" << std::endl;
488         } else {
489             std::cout << "\n提取旋转轴失败!" << std::endl;
490         }
491     }*/
492 
493     return 0;
494 }

 

posted @ 2025-08-05 14:13  禅元天道  阅读(9)  评论(0)    收藏  举报