Project 2: Autograd for Algebraic Expressions
Introduction
Given an expression in several variables that uses the operators +, -, *, /, ^, parentheses, and functions such as ln(x), log(a,b), exp(x), pow(a,b), cos(x), sin(x), tan(x), compute the partial derivative of the expression with respect to each variable.
Algorithm Specification
Declaration: Here we only consider expressions built from +, -, *, /, (, ), ^; functions are not supported in the input, and the resulting expression is not required to be simplified.
First, consider how to differentiate with respect to a single variable. The usual approach to problems involving expressions is to convert the expression into postfix form, build a syntax tree from it, and then process the tree recursively. Differentiation fits this pattern. Take addition as an example: since \((f(x) + g(x))' = f'(x) + g'(x)\), we can find \((f(x) + g(x))'\) by first computing \(f'(x)\) and \(g'(x)\) and then adding the results. Subtraction, multiplication, division, and exponentiation are differentiated in the same recursive way, each with its own rule.
Then the entire process can be broken down into the following steps:
- Convert the infix input into a postfix expression, then build a syntax tree from it (the postfix syntax tree).
- Differentiate the postfix syntax tree to obtain the syntax tree of the derivative.
- Convert the derivative's syntax tree back into an infix expression and output it.
Get tokens and build a syntax tree
In the first step, we scan the expression and split it into tokens, which are then used to generate the postfix expression. A token is a symbol (an operator or a parenthesis), a variable, or a constant. At this stage a constant is kept as a string; it is converted to an integer later, when the tree is built. One point needs care: - has two meanings, negative sign and subtraction operator, and we decide which one it is from its position (it can only be a sign at the start of the expression, after a left parenthesis, or after another operator). Otherwise tokenizing is simple: we scan the expression and group adjacent digits into one number and adjacent lowercase letters into one variable name.
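The scanning rule can be sketched as follows. This is a minimal illustration rather than the code in the appendix: the names Tokenize and AllowsSign are hypothetical, and the sketch assumes integer constants and alphabetic variable names.

```cpp
#include <cassert>
#include <cctype>
#include <string>
#include <vector>

// Hypothetical helper: true when the previous token lets a following '-' act as a sign.
bool AllowsSign(const std::vector<std::string> &tokens)
{
    if (tokens.empty()) return true;  // start of the expression
    const std::string &prev = tokens.back();
    return prev.size() == 1 && std::string("(+-*/^").find(prev[0]) != std::string::npos;
}

// Hypothetical tokenizer: splits an expression into numbers, variables and symbols.
std::vector<std::string> Tokenize(const std::string &expr)
{
    std::vector<std::string> tokens;
    std::size_t i = 0, n = expr.size();
    while (i < n)
    {
        unsigned char c = static_cast<unsigned char>(expr[i]);
        if (std::isspace(c)) { ++i; continue; }
        // '-' starts a negative number only in a sign position and before a digit.
        bool negativeNumber = c == '-' && AllowsSign(tokens) && i + 1 < n &&
                              std::isdigit(static_cast<unsigned char>(expr[i + 1]));
        if (std::isdigit(c) || negativeNumber)
        {
            std::size_t j = i + 1;  // step over the first digit or the sign
            while (j < n && std::isdigit(static_cast<unsigned char>(expr[j]))) ++j;
            tokens.push_back(expr.substr(i, j - i));
            i = j;
        }
        else if (std::isalpha(c))
        {
            std::size_t j = i;
            while (j < n && std::isalpha(static_cast<unsigned char>(expr[j]))) ++j;
            tokens.push_back(expr.substr(i, j - i));
            i = j;
        }
        else
        {
            tokens.push_back(std::string(1, expr[i]));  // operator or parenthesis
            ++i;
        }
    }
    return tokens;
}
```

With this sketch, Tokenize("-3*(a-b)") yields the tokens -3, *, (, a, -, b, ): the first - is read as a sign because it opens the expression, the second as subtraction because it follows a variable.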
To convert the infix expression to postfix, we use the classic stack method. The tokens are processed in order: variables and constants go directly to the output, while operators are pushed onto and popped off the stack according to their priority. One important point is that +, -, *, / are left-associative operators, but ^ is right-associative, so the two kinds are popped under slightly different conditions. Before a left-associative operator is pushed, every operator on the stack with priority greater than or equal to its own is popped to the output; before ^ is pushed, only operators with strictly greater priority are popped.
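The difference between the two pop conditions is easiest to see in code. The following is a minimal sketch of the stack method, assuming every operand token starts with a letter or digit (so negative constants are left out); Precedence and ToPostfix are hypothetical names, not the appendix functions.

```cpp
#include <cassert>
#include <cctype>
#include <stack>
#include <string>
#include <vector>

// Hypothetical helper: larger numbers bind more tightly.
int Precedence(const std::string &op)
{
    if (op == "+" || op == "-") return 1;
    if (op == "*" || op == "/") return 2;
    if (op == "^") return 3;
    return 0;
}

// Hypothetical converter from infix tokens to postfix tokens.
std::vector<std::string> ToPostfix(const std::vector<std::string> &tokens)
{
    std::vector<std::string> out;
    std::stack<std::string> ops;
    for (const std::string &t : tokens)
    {
        if (std::isalnum(static_cast<unsigned char>(t[0]))) out.push_back(t);  // operand
        else if (t == "(") ops.push(t);
        else if (t == ")")
        {
            while (!ops.empty() && ops.top() != "(") { out.push_back(ops.top()); ops.pop(); }
            if (!ops.empty()) ops.pop();  // discard the "("
        }
        else
        {
            bool rightAssoc = (t == "^");
            int cur = Precedence(t);
            while (!ops.empty() && ops.top() != "(")
            {
                int top = Precedence(ops.top());
                // Left-associative: pop on >=. Right-associative: pop only on >.
                if (top > cur || (top == cur && !rightAssoc)) { out.push_back(ops.top()); ops.pop(); }
                else break;
            }
            ops.push(t);
        }
    }
    while (!ops.empty()) { out.push_back(ops.top()); ops.pop(); }
    return out;
}
```

For a-b-c the second - pops the first, giving a b - c -, which groups as (a-b)-c. For 2^3^2 the second ^ does not pop the first, giving 2 3 2 ^ ^, which groups as 2^(3^2).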
After obtaining the postfix expression, building the tree is relatively easy: we scan the postfix expression with a stack of nodes, and each operator takes the two most recent subtrees as its left and right children. During this process we also determine the type of each node, that is, whether it is a constant node, a variable node, or an operator node, and record the corresponding value, variable name, or operator.
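A minimal sketch of this construction is shown below. It assumes a well-formed postfix sequence with binary operators only, and it uses a hypothetical TreeNode with a single string label instead of the typed Node of the appendix; std::make_unique requires C++14.

```cpp
#include <cassert>
#include <cctype>
#include <memory>
#include <stack>
#include <string>
#include <vector>

// Hypothetical node: "label" holds a number, a variable name or an operator.
struct TreeNode
{
    std::string label;
    std::unique_ptr<TreeNode> left, right;
};

// Build the tree from postfix tokens; returns nullptr for malformed input.
std::unique_ptr<TreeNode> BuildTree(const std::vector<std::string> &postfix)
{
    std::stack<std::unique_ptr<TreeNode>> nodes;
    for (const std::string &t : postfix)
    {
        auto node = std::make_unique<TreeNode>();
        node->label = t;
        if (!std::isalnum(static_cast<unsigned char>(t[0])))  // operator
        {
            if (nodes.size() < 2) return nullptr;
            node->right = std::move(nodes.top()); nodes.pop();  // pushed last, so right operand
            node->left = std::move(nodes.top()); nodes.pop();
        }
        nodes.push(std::move(node));
    }
    if (nodes.size() != 1) return nullptr;
    return std::move(nodes.top());
}

// Fully parenthesized rendering, handy for checking the tree shape.
std::string Show(const TreeNode *n)
{
    if (!n->left) return n->label;
    return "(" + Show(n->left.get()) + n->label + Show(n->right.get()) + ")";
}
```

For the postfix sequence a b c ^ *, Show prints (a*(b^c)): the operand popped first becomes the right child, which is what keeps a-b and a/b in the correct order.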
Differentiate
After obtaining the postfix syntax tree, we can proceed with the most crucial step, differentiation. We construct a new tree while traversing the syntax tree, and this new tree is the syntax tree of the derivative.
Addition and Subtraction
Specifically, for + and -, we only need to create a new operator node for the addition or subtraction, recursively differentiate the left and right children of the original node, and attach the two results as the left and right children of the new node.
Multiplication and Division
For *, we use \(\mathrm{d}(u\times v)=\mathrm{d}u\times v+u\times \mathrm{d}v\): first create two multiplication nodes and set their left and right children, then create an addition node that joins the two multiplication nodes.
As for /, based on the formula \(\mathrm{d}(\frac{u}{v})=\frac{\mathrm{d}u\times v-u\times \mathrm{d}v}{v^2}\), corresponding nodes can be established.
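The rules for +, -, * and / can be sketched as tree construction in a few lines. This sketch makes one design change that should be read as an assumption: subtrees are immutable and shared through std::shared_ptr, so \(u\) and \(v\) are reused instead of copied. Expr, Leaf, Op, Diff and Show are hypothetical names.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical immutable expression node; subtrees are shared rather than copied.
struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;
struct Expr
{
    std::string label;  // number, variable name or operator
    ExprPtr left, right;
};

ExprPtr Leaf(const std::string &s) { return std::make_shared<Expr>(Expr{s, nullptr, nullptr}); }
ExprPtr Op(const std::string &op, ExprPtr l, ExprPtr r) { return std::make_shared<Expr>(Expr{op, l, r}); }

// Differentiate with respect to var; only +, -, * and / are covered in this sketch.
ExprPtr Diff(const ExprPtr &e, const std::string &var)
{
    if (!e->left) return Leaf(e->label == var ? "1" : "0");  // leaf: variable or constant
    ExprPtr u = e->left, v = e->right;
    ExprPtr du = Diff(u, var), dv = Diff(v, var);
    if (e->label == "+" || e->label == "-") return Op(e->label, du, dv);
    if (e->label == "*") return Op("+", Op("*", du, v), Op("*", u, dv));
    // Quotient rule: (du*v - u*dv) / (v*v)
    return Op("/", Op("-", Op("*", du, v), Op("*", u, dv)), Op("*", v, v));
}

// Fully parenthesized rendering of the result.
std::string Show(const ExprPtr &e)
{
    if (!e->left) return e->label;
    return "(" + Show(e->left) + e->label + Show(e->right) + ")";
}
```

For example, differentiating x*y with respect to x produces ((1*y)+(x*0)), and differentiating x/y with respect to y produces (((0*y)-(x*1))/(y*y)), both unsimplified as in the report.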
Power and Ln()
^ is a bit more complicated. Its derivative is based on \(\mathrm{d}(u^v)=v\times u^{v-1}\times \mathrm{d}u+\ln(u)\times u^v\times \mathrm{d}v\), and the function \(\ln\) appears in this formula, so \(\ln\) must be representable in the syntax tree. Specifically, \(\ln\) is regarded as a new unary operator whose single child is its argument. We then also need the derivative of \(\ln\). This is relatively simple because \(\mathrm{d}(\ln(u))=\frac{\mathrm{d}u}{u}\), so it is handled in the same way as addition, subtraction, multiplication, and division.
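The power rule can be derived by writing the power as an exponential. The following short derivation assumes \(u>0\), so that \(\ln(u)\) is defined:

\[
\begin{aligned}
u^v &= e^{v\ln(u)},\\
\mathrm{d}(u^v) &= e^{v\ln(u)}\,\mathrm{d}(v\ln(u)) = u^v\left(\ln(u)\,\mathrm{d}v+v\,\frac{\mathrm{d}u}{u}\right)\\
&= v\times u^{v-1}\times \mathrm{d}u+\ln(u)\times u^v\times \mathrm{d}v.
\end{aligned}
\]

When \(v\) is a constant, \(\mathrm{d}v=0\) and the formula reduces to the familiar \(v\times u^{v-1}\times \mathrm{d}u\).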
Convert syntax tree to string
Once we have the syntax tree of the derivative, all that remains is to convert it into an infix expression. This step only requires a simple inorder traversal. The only thing to watch is where to add parentheses, which is decided by comparing the operator of each node with the operators of its left and right children.
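One way to decide on parentheses is to compare precedence ranks, as in the sketch below. It is an illustration under assumptions, not the appendix code: PNode, Make, Rank and ToInfix are hypothetical names, only binary operators are handled, and the rule is deliberately conservative (it prints a+(b+c) with parentheses, as the sample outputs of this report do).

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical node type for this sketch.
struct PNode
{
    std::string label;
    std::shared_ptr<PNode> left, right;
};
using PNodePtr = std::shared_ptr<PNode>;

PNodePtr Make(const std::string &label, PNodePtr l = nullptr, PNodePtr r = nullptr)
{
    return std::make_shared<PNode>(PNode{label, l, r});
}

// Leaves get the highest rank so they never need parentheses.
int Rank(const PNode &n)
{
    if (!n.left) return 4;
    if (n.label == "+" || n.label == "-") return 1;
    if (n.label == "*" || n.label == "/") return 2;
    return 3;  // "^"
}

std::string ToInfix(const PNodePtr &n)
{
    if (!n->left) return n->label;
    bool rightAssoc = (n->label == "^");
    std::string l = ToInfix(n->left), r = ToInfix(n->right);
    int p = Rank(*n);
    // Wrap a child that binds more loosely than its parent, or equally
    // loosely on the side where the parser would otherwise regroup it.
    if (Rank(*n->left) < p || (Rank(*n->left) == p && rightAssoc)) l = "(" + l + ")";
    if (Rank(*n->right) < p || (Rank(*n->right) == p && !rightAssoc)) r = "(" + r + ")";
    return l + n->label + r;
}
```

The associativity check matters for cases such as a-(b-c) and (a^b)^c, where dropping the parentheses would change the value of the expression.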
So far we have differentiated with respect to one specific variable. However, the expression in the problem involves multiple variables. Therefore, we first list all the variable names (this step is simple, as the variables already appear among the tokens and in the syntax tree), and then go through the variables one by one, building the derivative syntax tree for each.
Testing Results
Sample 1
Input:
x+2*x+x/2
Output:
x: 1+(0*x+2*1)+(1*2-x*0)/(2*2)
Sample 2
Input:
a+b*c^d
Output:
a: 1+(0*c^d+b*(d*c^(d-1)*0+ln(c)*c^d*0))
b: 0+(1*c^d+b*(d*c^(d-1)*0+ln(c)*c^d*0))
c: 0+(0*c^d+b*(d*c^(d-1)*1+ln(c)*c^d*0))
d: 0+(0*c^d+b*(d*c^(d-1)*0+ln(c)*c^d*1))
Analysis and Comments
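Tokenizing, the conversion to postfix, and tree construction each make a single pass over the input, so they take \(\mathcal O(N)\) time and space, where \(N\) is the number of tokens. Differentiation also visits every node of the syntax tree once, but the product, quotient, and power rules copy the subtrees \(u\) and \(v\) into the result, so the derivative tree can be larger than the original: for a syntax tree of height \(h\) it has \(\mathcal O(N\cdot h)\) nodes, which reaches \(\mathcal O(N^2)\) for a deeply nested chain such as x*x*...*x. Building one derivative therefore takes time and space proportional to that size, and the work is repeated for each variable.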
This algorithm can be improved in two directions. First, support for functions can be added, so that expressions containing functions can be differentiated. Second, the final expression can be simplified. The former can be achieved by adding a new node type, the function node, which takes part in both the construction of the syntax tree and the differentiation. The latter can be done by simplifying and merging reducible parts, such as multiplications by 0 or 1, when the syntax tree is converted into an expression.
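As an illustration of the second direction, a bottom-up pass over the tree can remove the most common redundant terms. The sketch below is an assumption about how such a pass might look, not part of the submitted program: SNode, Make, Is, Simplify and Show are hypothetical names, and the identity 0*e=0 is applied without checking that e is defined.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical node type for this sketch.
struct SNode
{
    std::string label;
    std::shared_ptr<SNode> left, right;
};
using SNodePtr = std::shared_ptr<SNode>;

SNodePtr Make(const std::string &label, SNodePtr l = nullptr, SNodePtr r = nullptr)
{
    return std::make_shared<SNode>(SNode{label, l, r});
}

// True when n is a leaf with the given label.
bool Is(const SNodePtr &n, const std::string &s) { return !n->left && n->label == s; }

// Bottom-up pass applying 0*e=e*0=0, 1*e=e*1=e, 0+e=e+0=e and e-0=e.
SNodePtr Simplify(const SNodePtr &n)
{
    if (!n->left) return n;
    SNodePtr l = Simplify(n->left), r = Simplify(n->right);
    if (n->label == "*")
    {
        if (Is(l, "0") || Is(r, "0")) return Make("0");
        if (Is(l, "1")) return r;
        if (Is(r, "1")) return l;
    }
    if (n->label == "+")
    {
        if (Is(l, "0")) return r;
        if (Is(r, "0")) return l;
    }
    if (n->label == "-" && Is(r, "0")) return l;
    return Make(n->label, l, r);
}

// Fully parenthesized rendering of the result.
std::string Show(const SNodePtr &n)
{
    if (!n->left) return n->label;
    return "(" + Show(n->left) + n->label + Show(n->right) + ")";
}
```

Applied to the subtree 0*x+2*1 from Sample 1, this pass returns the single leaf 2. Folding numeric constants such as 1+2 would be a further step that the sketch does not attempt.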
Appendix: Source Code (C++ 13)
#include<bits/stdc++.h>
using namespace std;
string InfixExpr;
enum treetype {const_node,var_node,op_node};//Syntax tree node type
struct Node
{
    treetype type;
    int val;
    string var;
    char op;
    Node *left,*right;
    Node (treetype t,int v=0,const string &vvar="",char oop='\0')
        : type(t),val(v),var(vvar),op(oop),left(nullptr),right(nullptr){}
};//Syntax tree node structure
bool IsOp(char c) {return c=='+'||c=='-'||c=='*'||c=='/'||c=='^';}//Determine whether a character is an operator
bool IsVar(char c) {return isalpha(c)&&islower(c);}//Determine whether a character is part of a variable name (lowercase letters)
int GetPriority(char op)//Obtain operator priority
{
    if (op=='+'||op=='-') return 1;
    if (op=='*'||op=='/') return 2;
    if (op=='^') return 3;
    return 0;
}
vector<string> InfixToToken(string &expr)//Convert the infix expression to a token sequence
{
    vector<string> tokens;
    int n=expr.length();
    int i=0;
    while (i<n)
    {
        char c=expr[i];
        if (isspace(c)) {i++;continue;}//Skip the spaces
        if (isdigit(c)||(c=='-'&&(i==0||expr[i-1]=='('||IsOp(expr[i-1]))))//Handling numbers (integers)
        {
            int j=i;
            if (c=='-') j++;//Skip the minus sign
            while (j<n&&isdigit(expr[j])) j++;
            if (c=='-'&&j>i+1)//Check if it is indeed a minus sign (followed by a number)
            {
                tokens.push_back(expr.substr(i,j-i));
                i=j;
            }
            else if (isdigit(c))
            {
                tokens.push_back(expr.substr(i,j-i));
                i=j;
            }
            else
            {//A single minus sign is treated as an operator.
                tokens.push_back(string(1,c));
                i++;
            }
            continue;
        }
        if (IsVar(c))// Handle variable (lowercase letter sequence)
        {
            int j=i;
            while (j<n&&IsVar(expr[j])) j++;
            tokens.push_back(expr.substr(i,j-i));
            i=j;
            continue;
        }
        if (IsOp(c)||c=='('||c==')')// Handling operators and parentheses
        {
            tokens.push_back(string(1,c));
            i++;
            continue;
        }
        i++;// Skip any other character
    }
    return tokens;
}
vector<string> TokensToPostfix(vector<string>& tokens)//Token to postfix conversion
{
    vector<string> postfix;
    stack<string> Opstack;
    for (string &token:tokens)
    {
        if (isdigit(token[0])||(token.length()>1&&token[0]=='-'&&isdigit(token[1]))||isalpha(token[0])) postfix.push_back(token);// If it is an operand (a number or a variable)
        else if (token=="(") Opstack.push(token);// If it is a left parenthesis
        else if (token==")")// If it is a right parenthesis
        {
            while (!Opstack.empty()&&Opstack.top()!="(")
            {
                postfix.push_back(Opstack.top());
                Opstack.pop();
            }
            if (!Opstack.empty()) Opstack.pop();
        }
        else // If it is an operator
        {
            while (!Opstack.empty()&&Opstack.top()!="(")
            {
                char topOp=Opstack.top()[0];
                char curOp=token[0];
                // Handling operator precedence and associativity
                int topPriority=GetPriority(topOp);
                int curPriority=GetPriority(curOp);
                if (curOp!='^')// Left-associative operator
                {
                    if (topPriority>=curPriority)
                    {
                        postfix.push_back(Opstack.top());
                        Opstack.pop();
                    }
                    else break;
                }
                else// Right-associative operator
                {
                    if (topPriority>curPriority)
                    {
                        postfix.push_back(Opstack.top());
                        Opstack.pop();
                    }
                    else break;
                }
            }
            Opstack.push(token);
        }
    }
    while (!Opstack.empty())// Pop out the remaining operator
    {
        postfix.push_back(Opstack.top());
        Opstack.pop();
    }
    return postfix;
}
Node *CreateNode(string &token)// Create syntax tree node
{
    if (isdigit(token[0])||(token[0]=='-'&&token.length()>1&&isdigit(token[1]))) return new Node(const_node,atoi(token.c_str()));// Check if it is a number (including negative numbers)
    else if (isalpha(token[0])) return new Node(var_node,0,token);// Check if it is a variable
    else return new Node(op_node,0,"",token[0]);// Otherwise, it is an operator
}
Node *BuildExpressTree(vector<string> &post)// Build a syntax tree from the postfix expression
{
    stack<Node *>NodeStack;
    for (string &token:post)
    {
        Node *NewNode=CreateNode(token);
        if (IsOp(token[0]))// If it is an operator, the operands need to be popped out.
        {
            if (NodeStack.size()<2&&token=="-")// Handling the case of a unary minus sign
            {
                Node *right=NodeStack.top();
                NodeStack.pop();
                NewNode->right=right;
                NodeStack.push(NewNode);
                continue;
            }
            Node *right=NodeStack.top();
            NodeStack.pop();
            Node *left=NodeStack.top();
            NodeStack.pop();
            NewNode->left=left;NewNode->right=right;
        }
        NodeStack.push(NewNode);
    }
    return NodeStack.empty()?nullptr:NodeStack.top();
}
void GetAllVar(Node *root,set<string> &Vars)//Collect all the variables in the expression
{
    if (!root) return;
    if (root->type==var_node) Vars.insert(root->var);
    else if (root->type==op_node)
    {
        GetAllVar(root->left,Vars);
        GetAllVar(root->right,Vars);
    }
}
Node *CopyTree(Node *root)// Copy the syntax tree
{
    if (!root) return nullptr;
    Node *NewNode=new Node(root->type,root->val,root->var,root->op);
    NewNode->left=CopyTree(root->left);
    NewNode->right=CopyTree(root->right);
    return NewNode;
}
Node *CreateConstNode(int v) {return new Node(const_node,v);}
Node *CreateVarNode(string &var) {return new Node(var_node,0,var);}
Node *CreateOpNode(char op,Node *left,Node *right)
{
    Node *node=new Node(op_node,0,"",op);
    node->left=left;node->right=right;
    return node;
}
Node *CreateLnNode(Node *arg)// ln is stored as operator 'l' with its argument as the right child
{
    Node *NewNode=new Node(op_node,0,"",'l');
    NewNode->right=arg;
    return NewNode;
}
Node *Differentiate(Node *root,string &var)//Derivation of expressions
{
    if (!root) return nullptr;
    if (root->type==const_node) return CreateConstNode(0);// The constant derivative is 0
    if (root->type==var_node)// Variable differentiation
    {
        if (root->var==var) return CreateConstNode(1);
        else return CreateConstNode(0);
    }
    if (root->type==op_node)// Operator Node
    {
        Node *u=root->left,*v=root->right;
        switch(root->op)
        {
            case '+': return CreateOpNode('+',Differentiate(u,var),Differentiate(v,var));//(u+v)'=u'+v'
            case '-': return CreateOpNode('-',Differentiate(u,var),Differentiate(v,var));//(u-v)'=u'-v'
            case '*'://(u*v)'=u'*v+u*v'
            {
                Node *firstpart=CreateOpNode('*',Differentiate(u,var),CopyTree(v));
                Node *secondpart=CreateOpNode('*',CopyTree(u),Differentiate(v,var));
                return CreateOpNode('+',firstpart,secondpart);
            }
            case '/'://(u/v)'=(u'*v-u*v')/(v*v)
            {
                Node *firstpart=CreateOpNode('*',Differentiate(u,var),CopyTree(v));
                Node *secondpart=CreateOpNode('*',CopyTree(u),Differentiate(v,var));
                Node *numerator=CreateOpNode('-',firstpart,secondpart);
                Node *denominator=CreateOpNode('*',CopyTree(v),CopyTree(v));
                return CreateOpNode('/',numerator,denominator);
            }
            case '^'://(u^v)'=v*u^(v-1)*u'+ln(u)*u^v*v'
            {
                Node *One=CreateConstNode(1);// The constant 1 subtracted from the exponent
                Node *v_MinusOne=CreateOpNode('-',CopyTree(v),One);
                Node *u_pow_v_minusone=CreateOpNode('^',CopyTree(u),v_MinusOne);
                Node *v_mul_u_v_minusone=CreateOpNode('*',CopyTree(v),u_pow_v_minusone);
                Node *firstpart=CreateOpNode('*',v_mul_u_v_minusone,Differentiate(u,var));
                Node *ln_u=CreateLnNode(CopyTree(u));
                Node *u_pow_v=CreateOpNode('^',CopyTree(u),CopyTree(v));
                Node *ln_mul_pow=CreateOpNode('*',ln_u,u_pow_v);
                Node *secondpart=CreateOpNode('*',ln_mul_pow,Differentiate(v,var));
                return CreateOpNode('+',firstpart,secondpart);
            }
            case 'l'://(ln(v))'=(1/v)*v', where v is the argument stored in the right child
            {
                Node *one=CreateConstNode(1);
                Node *one_div=CreateOpNode('/',one,CopyTree(v));
                return CreateOpNode('*',one_div,Differentiate(v,var));
            }
        }
    }
    return CreateConstNode(0);
}
string TreeToString(Node *root)// Convert syntax tree to string
{
    if (!root) return "";
    if (root->type==const_node) return to_string(root->val);
    if (root->type==var_node) return root->var;
    if (root->type==op_node)
    {
        if (root->op=='l') return "ln("+TreeToString(root->right)+")";// Handle the ln function
        char op=root->op;
        string LeftString,RightString;
        if (root->left) LeftString=TreeToString(root->left);
        if (root->right) RightString=TreeToString(root->right);
        bool LeftNeedParen=false,RightNeedParen=false;
        if (root->left&&root->left->type==op_node)
        {
            char lop=root->left->op;
            if ((op=='*'||op=='/')&&(lop=='+'||lop=='-')) LeftNeedParen=true;
            // ^ binds tightest and is right-associative, so any operator base except ln(...) needs parentheses
            else if (op=='^'&&lop!='l') LeftNeedParen=true;
        }
        if (root->right&&root->right->type==op_node)
        {
            // Always inspect the right child here; the left child is null for a unary minus
            char rop=root->right->op;
            if (op=='+'||op=='-'||op=='*')
            {if (rop=='+'||rop=='-') RightNeedParen=true;}
            else if (op=='/'||op=='^')
            {if (rop=='+'||rop=='-'||rop=='*'||rop=='/') RightNeedParen=true;}
        }
        string result="";
        if (LeftNeedParen) result+="("+LeftString+")";
        else result+=LeftString;
        result+=string(1,op);
        if (RightNeedParen) result+="("+RightString+")";
        else result+=RightString;
        return result;
    }
    return "";
}
void DeleteTree(Node *root)
{
    if (!root) return;
    DeleteTree(root->left);
    DeleteTree(root->right);
    delete(root);
}
int main()
{
    getline(cin,InfixExpr);
    vector<string> tokens=InfixToToken(InfixExpr);
    vector<string> postfix=TokensToPostfix(tokens);
    Node *ExpressTree=BuildExpressTree(postfix);
    if (!ExpressTree)
    {
        printf("Error\n");
        return 0;
    }
    set<string> AllVariale;
    GetAllVar(ExpressTree,AllVariale);
    vector<string> SortedVar(AllVariale.begin(),AllVariale.end());
    sort(SortedVar.begin(),SortedVar.end());// Sort the variables in alphabetical order
    for (string &var:SortedVar)// Calculate the derivative for each variable and output the result
    {
        Node *DerivativeTree=Differentiate(ExpressTree,var);
        string DerivativeStr=TreeToString(DerivativeTree);
        cout<<var<<": "<<DerivativeStr<<endl;
        DeleteTree(DerivativeTree);//Clear the memory
    }
    DeleteTree(ExpressTree);//Clear the memory
    return 0;
}
Declaration
I hereby declare that all the work done in this project titled Autograd for Algebraic Expressions is of my independent effort.
