OCS2::MPC startup flow

1. Create the MPC_ROS_Interface, using SqpMpc as the example

// User-defined robot interface
LeggedRobotInterface interface(taskFile, urdfFile, referenceFile);

// Create the synchronized module that receives gait commands
auto gaitReceiverPtr =
    std::make_shared<GaitReceiver>(nodeHandle, interface.getSwitchedModelReferenceManagerPtr()->getGaitSchedule(), robotName);

// Create the reference manager wrapper that subscribes to ROS topics
auto rosReferenceManagerPtr = std::make_shared<RosReferenceManager>(robotName, interface.getReferenceManagerPtr());
rosReferenceManagerPtr->subscribe(nodeHandle);

// Set up the MPC
SqpMpc mpc(interface.mpcSettings(), interface.sqpSettings(), interface.getOptimalControlProblem(), interface.getInitializer());
mpc.getSolverPtr()->setReferenceManager(rosReferenceManagerPtr);
mpc.getSolverPtr()->addSynchronizedModule(gaitReceiverPtr);

// Create the MPC ROS node
MPC_ROS_Interface mpcNode(mpc, robotName);

In GaitReceiver.cpp:
Subscribes to "robotName + _mpc_mode_schedule" to receive gait data.
The publisher is in LeggedRobotGaitCommandNode.cpp, which reads the gait.info file and publishes the selected gait.

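In RosReferenceManager.cpp:
Subscribes to "robotName + _mode_schedule"; this one does not appear to be used.
Subscribes to "robotName + _mpc_target", receives the target and syncs it into the referenceManager.
The publisher is in TargetTrajectoriesKeyboardPublisher.cpp.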
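The topic names listed above are all built by appending a fixed suffix to robotName. A minimal sketch of that naming scheme follows; the helper function names are hypothetical and only the suffixes come from the notes above.

```cpp
#include <string>

// Hypothetical helpers: the function names are illustrative,
// only the suffixes are taken from the notes above.
std::string gaitTopic(const std::string& robotName) {
  return robotName + "_mpc_mode_schedule";  // received by GaitReceiver
}

std::string modeScheduleTopic(const std::string& robotName) {
  return robotName + "_mode_schedule";  // subscribed in RosReferenceManager
}

std::string targetTopic(const std::string& robotName) {
  return robotName + "_mpc_target";  // subscribed in RosReferenceManager
}
```

With robotName set to "legged_robot", the target topic would therefore be "legged_robot_mpc_target".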

2. MPC_ROS_Interface.cpp

robotName = "legged_robot"
A subscription to the topic "robotName + _mpc_observation" is created; the main MPC loop runs inside its callback, mpcObservationCallback.
The publisher is in MRT_ROS_Interface.cpp.

bool controllerIsUpdated = mpc_.run(currentObservation.time, currentObservation.state);
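The callback-driven loop can be pictured with the following sketch. It is not the OCS2 implementation: Observation and MpcNodeSketch are hypothetical stand-ins, and the assumption that a new policy is published only when the run reports an update is stated here for illustration.

```cpp
#include <functional>
#include <utility>
#include <vector>

// Hypothetical stand-in for the observation message.
struct Observation {
  double time;
  std::vector<double> state;
};

class MpcNodeSketch {
 public:
  // The solver callable plays the role of mpc_.run(time, state).
  using Solver = std::function<bool(double, const std::vector<double>&)>;

  explicit MpcNodeSketch(Solver solver) : solver_(std::move(solver)) {}

  // Called once for every received observation message.
  void observationCallback(const Observation& obs) {
    const bool controllerIsUpdated = solver_(obs.time, obs.state);
    if (controllerIsUpdated) {
      ++publishedPolicies_;  // assumption: a real node would publish the new policy here
    }
  }

  int publishedPolicies() const { return publishedPolicies_; }

 private:
  Solver solver_;
  int publishedPolicies_ = 0;
};
```

Each incoming observation triggers exactly one solver run, so the MPC rate follows the rate at which observations arrive.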

3. MPC_BASE.cpp

calculateController is overridden by the concrete solver class; SQP is used as the example in the next step.

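bool MPC_BASE::run(scalar_t currentTime, const vector_t& currentState)
{
  // ... (abridged: finalTime is determined before this call)
  calculateController(currentTime, currentState, finalTime);
  // ... (abridged: the return value reports whether the controller was updated)
}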

4. SqpMpc.h

calculateController is overridden here:

void calculateController(scalar_t initTime, const vector_t& initState, scalar_t finalTime) override {
    if (settings().coldStart_) {
      solverPtr_->reset();
    }
    solverPtr_->run(initTime, initState, finalTime);
  }

solverPtr_ holds a newly created SqpSolver, which derives from SolverBase.
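Steps 3 and 4 together form a template-method pattern: the base class owns run() and the derived class supplies calculateController(). The sketch below illustrates this under stated assumptions: the class names are hypothetical, and computing finalTime as current time plus a fixed horizon is an assumption of this sketch.

```cpp
#include <vector>

// Hypothetical stand-in for the MPC base class.
class MpcBaseSketch {
 public:
  explicit MpcBaseSketch(double horizon) : horizon_(horizon) {}
  virtual ~MpcBaseSketch() = default;

  bool run(double currentTime, const std::vector<double>& currentState) {
    const double finalTime = currentTime + horizon_;  // assumption of this sketch
    calculateController(currentTime, currentState, finalTime);
    return true;  // the controller was updated
  }

 protected:
  // Supplied by the concrete solver wrapper.
  virtual void calculateController(double initTime, const std::vector<double>& initState, double finalTime) = 0;

 private:
  double horizon_;
};

// Hypothetical stand-in for the SQP-based MPC.
class SqpMpcSketch : public MpcBaseSketch {
 public:
  using MpcBaseSketch::MpcBaseSketch;

  bool coldStart = false;
  int resets = 0;
  double lastFinalTime = -1.0;

 protected:
  void calculateController(double /*initTime*/, const std::vector<double>& /*initState*/, double finalTime) override {
    if (coldStart) {
      ++resets;  // stands in for solverPtr_->reset()
    }
    lastFinalTime = finalTime;  // stands in for solverPtr_->run(...)
  }
};
```

The caller only ever sees run(); which solver does the work is decided by the derived class.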

5. SolverBase.cpp

Of the functions below, SqpSolver.cpp overrides only runImpl. runImpl is covered in the detailed solver notes and is not expanded here.

void SolverBase::run(scalar_t initTime, const vector_t& initState, scalar_t finalTime) {
  preRun(initTime, initState, finalTime);
  runImpl(initTime, initState, finalTime);
  postRun();
}
void SolverBase::preRun(scalar_t initTime, const vector_t& initState, scalar_t finalTime) {
  referenceManagerPtr_->preSolverRun(initTime, finalTime, initState);

  for (auto& module : synchronizedModules_) {
    module->preSolverRun(initTime, finalTime, initState, *referenceManagerPtr_);
  }
}
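The call order implied by run() and preRun() can be traced with a small sketch that logs each step. ModuleSketch, SolverSketch and the log strings are hypothetical; the sketch only mirrors the ordering shown in the excerpt above.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

using CallLog = std::vector<std::string>;

// Hypothetical stand-in for a synchronized module such as the gait receiver.
class ModuleSketch {
 public:
  ModuleSketch(std::string name, CallLog& log) : name_(std::move(name)), log_(log) {}
  void preSolverRun() { log_.push_back(name_ + ":preSolverRun"); }

 private:
  std::string name_;
  CallLog& log_;
};

class SolverSketch {
 public:
  explicit SolverSketch(CallLog& log) : log_(log) {}

  void addSynchronizedModule(std::shared_ptr<ModuleSketch> module) { modules_.push_back(std::move(module)); }

  void run() {
    preRun();
    log_.push_back("runImpl");  // the solver-specific part
    log_.push_back("postRun");
  }

 private:
  void preRun() {
    // The reference manager is updated first, then every synchronized module.
    log_.push_back("referenceManager:preSolverRun");
    for (auto& module : modules_) {
      module->preSolverRun();
    }
  }

  CallLog& log_;
  std::vector<std::shared_ptr<ModuleSketch>> modules_;
};
```

Because the reference manager runs before the modules, a module can read the already updated references during its own preSolverRun().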

6. ReferenceManager.cpp

Entered from referenceManagerPtr_->preSolverRun() in step 5.
A user-defined class usually derives from ReferenceManager.

void ReferenceManager::preSolverRun(scalar_t initTime, scalar_t finalTime, const vector_t& initState) {
  targetTrajectories_.updateFromBuffer();
  modeSchedule_.updateFromBuffer();
  modifyReferences(initTime, finalTime, initState, targetTrajectories_.get(), modeSchedule_.get());
}

The three functions above are not expanded here; for an example, see SwitchedModelReferenceManager.cpp.
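The updateFromBuffer() calls suggest a double-buffering scheme: a subscriber thread writes into a buffer, and the solver swaps it in at the start of each run. The following is a minimal sketch of that idea, assuming a mutex-protected buffer; BufferedValueSketch is a hypothetical name, not the OCS2 class.

```cpp
#include <mutex>
#include <utility>

// Hypothetical double-buffered value; T must be default-constructible.
template <typename T>
class BufferedValueSketch {
 public:
  explicit BufferedValueSketch(T value) : active_(std::move(value)) {}

  // Called from another thread, e.g. a subscriber callback.
  void setBuffer(T value) {
    std::lock_guard<std::mutex> lock(mutex_);
    buffer_ = std::move(value);
    hasNewValue_ = true;
  }

  // Called at the start of a solver run. Returns true if the active value changed.
  bool updateFromBuffer() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!hasNewValue_) {
      return false;
    }
    active_ = std::move(buffer_);
    hasNewValue_ = false;
    return true;
  }

  // Read by the solver thread only, so no lock is taken here.
  const T& get() const { return active_; }

 private:
  T active_;
  T buffer_{};
  bool hasNewValue_ = false;
  std::mutex mutex_;
};
```

The point of this design is that the value seen by the solver stays fixed for the whole run, even if new targets arrive while it is solving.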

7. LoopshapingSynchronizedModule.cpp

Entered from module->preSolverRun() in step 5.

void LoopshapingSynchronizedModule::preSolverRun(scalar_t initTime, scalar_t finalTime, const vector_t& initState,
                                                 const ReferenceManagerInterface& referenceManager) {
  if (!synchronizedModulesPtrArray_.empty()) {
    const auto systemState = loopshapingDefinitionPtr_->getSystemState(initState);
    for (auto& module : synchronizedModulesPtrArray_) {
      module->preSolverRun(initTime, finalTime, systemState, referenceManager);
    }
  }
}

preSolverRun() is usually overridden by a derived class and is not detailed here; for an example, see GaitReceiver.cpp.
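The loopshaping module above acts as a wrapper: it strips the state down to the system part and forwards it to the wrapped modules. A sketch of that delegation follows. All names are hypothetical, and the assumption that the system state occupies the leading entries of the augmented state is made only for this illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

using State = std::vector<double>;

// Hypothetical wrapped module that only understands the system state.
struct InnerModuleSketch {
  std::size_t lastStateSize = 0;
  void preSolverRun(const State& systemState) { lastStateSize = systemState.size(); }
};

class WrapperModuleSketch {
 public:
  WrapperModuleSketch(std::size_t systemDim, std::vector<std::shared_ptr<InnerModuleSketch>> modules)
      : systemDim_(systemDim), modules_(std::move(modules)) {}

  void preSolverRun(const State& augmentedState) {
    if (modules_.empty()) {
      return;  // nothing to forward, skip the projection
    }
    // Assumption: the system state is the leading block of the augmented state.
    const std::size_t n = std::min(systemDim_, augmentedState.size());
    const State systemState(augmentedState.begin(), augmentedState.begin() + static_cast<std::ptrdiff_t>(n));
    for (auto& module : modules_) {
      module->preSolverRun(systemState);
    }
  }

 private:
  std::size_t systemDim_;
  std::vector<std::shared_ptr<InnerModuleSketch>> modules_;
};
```

The wrapped modules never see the filter states, so modules written for the plain system can be reused unchanged.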

8. SolverBase.cpp

postRun():

void SolverBase::postRun() {
  if (!synchronizedModules_.empty() || !solverObservers_.empty()) {
    const auto solution = primalSolution(getFinalTime());
    for (auto& module : synchronizedModules_) {
      module->postSolverRun(solution);
    }
    for (auto& observer : solverObservers_) {
      observer->extractTermConstraint(getOptimalControlProblem(), solution, getSolutionMetrics());
      observer->extractTermLagrangianMetrics(getOptimalControlProblem(), solution, getSolutionMetrics());
      if (getDualSolution() != nullptr) {
        observer->extractTermMultipliers(getOptimalControlProblem(), *getDualSolution());
      }
    }
  }
}

For postSolverRun(), see GaitReceiver.cpp.
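One detail of postRun() worth noting is the outer check: the solution is assembled only when at least one module or observer will consume it. The sketch below reproduces just that control flow; SolutionSketch and PostRunSketch are hypothetical names, and callbacks stand in for the module and observer objects.

```cpp
#include <functional>
#include <vector>

// Hypothetical stand-in for the primal solution.
struct SolutionSketch {
  double finalTime;
};

class PostRunSketch {
 public:
  using Listener = std::function<void(const SolutionSketch&)>;

  std::vector<Listener> modules;    // stand-ins for synchronized modules
  std::vector<Listener> observers;  // stand-ins for solver observers
  int solutionsBuilt = 0;

  void postRun(double finalTime) {
    if (modules.empty() && observers.empty()) {
      return;  // nobody listens, so the solution is never assembled
    }
    const SolutionSketch solution{finalTime};
    ++solutionsBuilt;
    for (auto& module : modules) {
      module(solution);
    }
    for (auto& observer : observers) {
      observer(solution);
    }
  }
};
```

Modules are notified before observers, matching the loop order in the excerpt above.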

9. SolverObserver.cpp

9.1 extractTermConstraint():

Extracts the constraint-related data of the optimal control problem and passes it to the constraintCallback_ callback.

void SolverObserver::extractTermConstraint(const OptimalControlProblem& ocp, const PrimalSolution& primalSolution,
                                           const ProblemMetrics& problemMetrics) {
  // ... (abridged: termIsFound is declared before the switch)
  switch (type_) {
    case Type::Final: {
      const auto* termConstraintPtr = extractFinalTermConstraint(ocp, termName_, problemMetrics.final);
      termIsFound = termConstraintPtr != nullptr;
      if (termIsFound) {
        const scalar_array_t timeArray{primalSolution.timeTrajectory_.back()};
        const std::vector<std::reference_wrapper<const vector_t>> termConstraintArray{*termConstraintPtr};
        constraintCallback_(timeArray, termConstraintArray);
      }
      break;
    }
    case Type::PreJump: {
      std::vector<std::reference_wrapper<const vector_t>> termConstraintArray;
      termIsFound = extractPreJumpTermConstraint(ocp, termName_, problemMetrics.preJumps, termConstraintArray);
      if (termIsFound) {
        scalar_array_t timeArray(primalSolution.postEventIndices_.size());
        std::transform(primalSolution.postEventIndices_.cbegin(), primalSolution.postEventIndices_.cend(), timeArray.begin(),
                       [&](size_t postInd) -> scalar_t { return primalSolution.timeTrajectory_[postInd - 1]; });
        constraintCallback_(timeArray, termConstraintArray);
      }
      break;
    }
    case Type::Intermediate: {
      std::vector<std::reference_wrapper<const vector_t>> termConstraintArray;
      termIsFound = extractIntermediateTermConstraint(ocp, termName_, problemMetrics.intermediates, termConstraintArray);
      if (termIsFound) {
        constraintCallback_(primalSolution.timeTrajectory_, termConstraintArray);
      }
      break;
    }
  }
  // ... (abridged)
}
  • Final: extracts the terminal constraint
ocp.finalEqualityConstraintPtr->getTermIndex(name, index)
ocp.finalInequalityConstraintPtr->getTermIndex(name, index)
return &metrics.stateIneqConstraint[index];
  • PreJump: extracts the pre-jump constraints
ocp.preJumpEqualityConstraintPtr->getTermIndex(name, index)
ocp.preJumpInequalityConstraintPtr->getTermIndex(name, index)
constraintArray.emplace_back(m.stateEqConstraint[index])
  • Intermediate: extracts the intermediate constraints
ocp.equalityConstraintPtr->getTermIndex(name, index)
ocp.stateEqualityConstraintPtr->getTermIndex(name, index)
ocp.inequalityConstraintPtr->getTermIndex(name, index)
ocp.stateInequalityConstraintPtr->getTermIndex(name, index)
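In the PreJump case, the time stamps are taken from the entry just before each post-event index. The same std::transform can be run on plain vectors to see what it produces; the function name preJumpTimes and the sample numbers in the note below are made up for illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// For every post-event index, pick the time of the node just before it.
// Assumes every index is at least 1 and within the trajectory.
std::vector<double> preJumpTimes(const std::vector<double>& timeTrajectory,
                                 const std::vector<std::size_t>& postEventIndices) {
  std::vector<double> times(postEventIndices.size());
  std::transform(postEventIndices.cbegin(), postEventIndices.cend(), times.begin(),
                 [&](std::size_t postInd) { return timeTrajectory[postInd - 1]; });
  return times;
}
```

For a time trajectory {0.0, 0.1, 0.2, 0.2, 0.3, 0.4, 0.4, 0.5} with post-event indices {3, 6}, this returns {0.2, 0.4}: one time stamp per event, read from the node immediately before the jump.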

9.2 extractTermLagrangianMetrics():

Extracts the Lagrangian-related data of the optimization problem and passes it to the lagrangianCallback_() callback.

case Type::Final: {
      const auto* lagrangianMetricsPtr = extractFinalTermLagrangianMetrics(ocp, termName_, problemMetrics.final);
      termIsFound = lagrangianMetricsPtr != nullptr;
      if (termIsFound) {
        const scalar_array_t timeArray{primalSolution.timeTrajectory_.back()};
        const std::vector<LagrangianMetricsConstRef> termLagrangianMetricsArray{*lagrangianMetricsPtr};
        lagrangianCallback_(timeArray, termLagrangianMetricsArray);
      }
      break;
    }
    case Type::PreJump: {
      std::vector<LagrangianMetricsConstRef> termLagrangianMetricsArray;
      termIsFound = extractPreJumpTermLagrangianMetrics(ocp, termName_, problemMetrics.preJumps, termLagrangianMetricsArray);
      if (termIsFound) {
        scalar_array_t timeArray(primalSolution.postEventIndices_.size());
        std::transform(primalSolution.postEventIndices_.cbegin(), primalSolution.postEventIndices_.cend(), timeArray.begin(),
                       [&](size_t postInd) -> scalar_t { return primalSolution.timeTrajectory_[postInd - 1]; });
        lagrangianCallback_(timeArray, termLagrangianMetricsArray);
      }
      break;
    }
    case Type::Intermediate: {
      std::vector<LagrangianMetricsConstRef> termLagrangianMetricsArray;
      termIsFound = extractIntermediateTermLagrangianMetrics(ocp, termName_, problemMetrics.intermediates, termLagrangianMetricsArray);
      if (termIsFound) {
        lagrangianCallback_(primalSolution.timeTrajectory_, termLagrangianMetricsArray);
      }
      break;
    }
  • Final:
ocp.finalEqualityLagrangianPtr->getTermIndex(name, index)
ocp.finalInequalityLagrangianPtr->getTermIndex(name, index)
  • PreJump:
ocp.preJumpEqualityLagrangianPtr->getTermIndex(name, index)
ocp.preJumpInequalityLagrangianPtr->getTermIndex(name, index)
  • Intermediate:
ocp.equalityLagrangianPtr->getTermIndex(name, index)
ocp.stateEqualityLagrangianPtr->getTermIndex(name, index)
ocp.inequalityLagrangianPtr->getTermIndex(name, index)
ocp.stateInequalityLagrangianPtr->getTermIndex(name, index)
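The getTermIndex(name, index) calls listed above all follow a lookup-by-name pattern: search one collection, then the next, and report "not found" with a null pointer. The following sketch shows that pattern under stated assumptions: CollectionSketch, MetricsSketch and extractTerm are hypothetical, plain doubles replace the metric types, and the equality-before-inequality order is inferred from the listing order.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical collection of named terms.
struct CollectionSketch {
  std::vector<std::string> termNames;

  // Writes the position of the term into index and returns true if it exists.
  bool getTermIndex(const std::string& name, std::size_t& index) const {
    for (std::size_t i = 0; i < termNames.size(); ++i) {
      if (termNames[i] == name) {
        index = i;
        return true;
      }
    }
    return false;
  }
};

// Hypothetical metrics: one value per term, stored in the same order as the names.
struct MetricsSketch {
  std::vector<double> equality;
  std::vector<double> inequality;
};

// Returns a pointer to the metric of the named term, or nullptr if no collection has it.
const double* extractTerm(const CollectionSketch& equalityTerms, const CollectionSketch& inequalityTerms,
                          const MetricsSketch& metrics, const std::string& name) {
  std::size_t index = 0;
  if (equalityTerms.getTermIndex(name, index)) {
    return &metrics.equality[index];
  }
  if (inequalityTerms.getTermIndex(name, index)) {
    return &metrics.inequality[index];
  }
  return nullptr;
}
```

A null result corresponds to termIsFound being false in the excerpts above.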

9.3 extractTermMultipliers():

Extracts the multiplier-related data of the optimization problem and passes it to the multiplierCallback_() callback.

case Type::Final: {
      const auto* multiplierPtr = extractFinalTermMultiplier(ocp, termName_, dualSolution.final);
      termIsFound = multiplierPtr != nullptr;
      if (termIsFound) {
        const scalar_array_t timeArray{dualSolution.timeTrajectory.back()};
        const std::vector<MultiplierConstRef> termMultiplierArray{*multiplierPtr};
        multiplierCallback_(timeArray, termMultiplierArray);
      }
      break;
    }
    case Type::PreJump: {
      std::vector<MultiplierConstRef> termMultiplierArray;
      termIsFound = extractPreJumpTermMultiplier(ocp, termName_, dualSolution.preJumps, termMultiplierArray);
      if (termIsFound) {
        scalar_array_t timeArray(dualSolution.postEventIndices.size());
        std::transform(dualSolution.postEventIndices.cbegin(), dualSolution.postEventIndices.cend(), timeArray.begin(),
                       [&](size_t postInd) -> scalar_t { return dualSolution.timeTrajectory[postInd - 1]; });
        multiplierCallback_(timeArray, termMultiplierArray);
      }
      break;
    }
    case Type::Intermediate: {
      std::vector<MultiplierConstRef> termMultiplierArray;
      termIsFound = extractIntermediateTermMultiplier(ocp, termName_, dualSolution.intermediates, termMultiplierArray);
      if (termIsFound) {
        multiplierCallback_(dualSolution.timeTrajectory, termMultiplierArray);
      }
      break;
    }
  • Final:
ocp.finalEqualityLagrangianPtr->getTermIndex(name, index)
ocp.finalInequalityLagrangianPtr->getTermIndex(name, index)
  • PreJump:
ocp.preJumpEqualityLagrangianPtr->getTermIndex(name, index)
ocp.preJumpInequalityLagrangianPtr->getTermIndex(name, index)
  • Intermediate:
ocp.equalityLagrangianPtr->getTermIndex(name, index)
ocp.stateEqualityLagrangianPtr->getTermIndex(name, index)
ocp.inequalityLagrangianPtr->getTermIndex(name, index)
ocp.stateInequalityLagrangianPtr->getTermIndex(name, index)
posted @ 2024-12-30 10:55  penuel