COM组件的类厂(COM技术内幕笔记之四)

    在上一篇中,介绍了怎么样用动态链接库去实现COM,但组件对我们来说仍是不透明的,我们需要知道实现组件DLL的位置,必须自己来加载组件的CreateInstance函数来获得组件的指针.在书中第一篇就曾经提到过:COM组件可以透明地在网络上(或本地)被重新分配位置,而不会影响本地客户程序.所以,由客户端来调用DLL并不是什么好主意.必须有一种更好的办法让组件的实现更透明,更灵活!
    于是,就引入了类厂的概念.什么是类厂,类厂也是一个接口,它的职责是帮我们创造组件的对象.并返回给客户程序一个接口的指针.每个组件都必须有一个与之相关的类厂,这个类厂知道怎么样创建组件.当客户请求一个组件对象的实例时,实际上这个请求交给了类厂,由类厂创建组件实例,然后把实例指针交给客户程序。这么说有点难明白.先看一个伪实例.
 1.实现二个接口IX,IY        (上二节中有详细介绍)
 2.实现一个组件CA,实现了IX,IY接口.    (上二节中有详细介绍)
 3.对于这个组件进行注册,把组件的信息加入到注册表中.
     实现DllRegisterServer和DllUnregisterServer函数.函数具体功能就是把本组件的CLSID,ProgID,DLL的位置放入注册表中.这样程序就可以通过查询注册表来获得组件的位置.
 4.创建本组件类厂的实例
class CFactory:public IClassFactory
{
 virtual HRESULT __stdcall QueryInterface(const IID& iid,void** ppv);
 virtual ULONG   __stdcall AddRef();
 virtual ULONG   __stdcall Release();

 virtual HRESULT __stdcall CreateInstance(IUnknown* pUnknownOuter,
  const IID& iid,
  void** ppv);
}
 在类厂实例中,主要的功能就是CreateInstance了,这个函数就是创建组件的相应实例.看它的实现:
HRESULT __stdcall CFactory::CreateInstance(IUnknown* pUnknownOuter,const IID& iid,void** ppv)
{
   //...
 CA* pA = new CA;
 if(pA == NULL)
  return E_OUTOFMEMORY;
 HRESULT hr = pA->QueryInterface(iid,ppv);

 pA->Release();
 return hr;
}

 5.在这个组件的DLL中导出DllGetClassObject函数.这个函数的功能就是创建类厂的实例对象并查询接口.看其实现:
STDAPI DllGetClassObject(const CLSID& clsid,
       const IID& iid,
       void** ppv)
{
 //....
 CFactory* pFactory = new CFactory();

 if(pFactory == NULL)
  return E_OUTOFMEMORY;

 HRESULT hr = pFactory->QueryInterface(iid,ppv);
 pFactory->Release();
 return hr;
}

组件的实现差不多就这么多,下面在客户端怎么调用组件呢?这就需要用到COM函数库了,由COM函数库去查找注册表,调用组件的类厂,创建组件实例,返回接口.如下所示:
IUnknown* pUnk = NULL;
IX* iX = NULL;
CoInitialize(NULL);
CoCreateInstance(CLSID_Component1,CLSCTX_INPROC_SERVER,IID_IUnknown,(void**)&pUnk);
pUnk->QueryInterface(IID_IX,(void**)&iX);
pUnk->Release();
iX->Fx();
iX->Release();
CoUninitialize();

至于客户是通过CoCreateInstance怎么获得组件的类厂,创建组件实例的.下面摘录的一篇文章很清晰的说明了这一切:
-------------------------------------------------------------------------------------
这部分我们将构造一个创建COM组件的最小框架结构,然后看一看其内部处理流程是怎样的

COM组件的运行机制,即COM是怎么跑起来的。
    IUnknown *pUnk=NULL;
    IObject *pObject=NULL;
    CoInitialize(NULL);
    CoCreateInstance(CLSID_Object, CLSCTX_INPROC_SERVER, NULL, IID_IUnknown, (void**)&pUnk);
    pUnk->QueryInterface(IID_IOjbect, (void**)&pObject);
    pUnk->Release();
    pObject->Func();
    pObject->Release();
    CoUninitialize();
  CoCreateInstance身上,让我们来看看它内部做了一些什么事情。以下是它内部实现的一个伪代码:

    CoCreateInstance(....)
    {
      .......
      IClassFactory *pClassFactory=NULL;
      CoGetClassObject(CLSID_Object, CLSCTX_INPROC_SERVER, NULL, IID_IClassFactory, (void **)&pClassFactory);
      pClassFactory->CreateInstance(NULL, IID_IUnknown, (void**)&pUnk);
      pClassFactory->Release();
      ........
    }
 


  这段话的意思就是先得到类厂对象,再通过类厂创建组件从而得到IUnknown指针。
  继续深入一步,看看CoGetClassObject的内部伪码:

    CoGetClassObject(.....)
    {
      //通过查注册表CLSID_Object,得知组件DLL的位置、文件名
      //装入DLL库
      //使用函数GetProcAddress(...)得到DLL库中函数DllGetClassObject的函数指针。
      //调用DllGetClassObject
    } 


  DllGetClassObject是干什么的,它是用来获得类厂对象的。只有先得到类厂才能去创建组件.
  下面是DllGetClassObject的伪码:

   DllGetClassObject(...)
   {
      ......
      CFactory* pFactory= new CFactory; //类厂对象
      pFactory->QueryInterface(IID_IClassFactory, (void**)&pClassFactory);
      //查询IClassFactory指针
      pFactory->Release();
      ......
   }
CoGetClassObject的流程已经到此为止,现在返回CoCreateInstance,看看CreateInstance的伪码:
   CFactory::CreateInstance(.....)
   {
      ...........
      CObject *pObject = new CObject; //组件对象
      pObject->QueryInterface(IID_IUnknown, (void**)&pUnk);
      pObject->Release();
      ...........
   } 


  下图是从COM+技术内幕中COPY来的一个例图,从图中可以清楚的看到CoCreateInstance的整个流程。



接下来就写下完全的源代码,说明类厂的概念:

Component实现:(FacInterFace.dll)

//In FACE.H
#ifndef _IFACE_H
#define _IFACE_H

//
interfaces
interface IX:IUnknown
{
    
virtual void __stdcall Fx() = 0;
}
;

interface IY: IUnknown
{
    
virtual void __stdcall Fy() = 0;
}
;

interface IZ: IUnknown
{
    
virtual void __stdcall Fz() = 0;
}
;

//Forward references for GUIDs
extern
 "C"
{
    
extern const IID IID_IX;
    
extern const IID IID_IY;
    
extern const IID IID_IZ;
    
extern const CLSID CLSID_Component1;
}


extern
 "C
{
// {A33D4226-0F56-4e34-91F3-BF4F85761101}
static
 const IID IID_IX 
0xa33d42260xf560x4e340x910xf30xbf0x4f0x850x760x110x1 } };

// {41A5F090-B33A-4ae8-A1BB-EF2D0B4F8B0E}
static
 const IID IID_IY 
0x41a5f0900xb33a0x4ae80xa10xbb0xef0x2d0xb0x4f0x8b0xe } };

// {65411881-4E05-4b71-9CB5-943D5E0787C4}
static
 const IID IID_IZ 
0x654118810x4e050x4b710x9c0xb50x940x3d0x5e0x70x870xc4 } };
}


//组件的CLSID,每个组件都有唯一的CLSID,需要把此CLSID添加到注册表中去.如何添加,见Register.cpp文件.
// {282D8F98-BC89-43d5-9225-0B1BB479CBDE}
static
 const CLSID CLSID_Component1 
0x282d8f980xbc890x43d50x920x250xb0x1b0xb40x790xcb0xde } };

#endif


组件的注册:

//In Register.h
HRESULT RegisterServer(HMODULE hModule,               const CLSID& clsid,               const char* szFriendlyName,           const char* szVerIndProgID,           const char* szProgID);

HRESULT UnRegisterServer(
const CLSID& clsid,         const char* szVerIndProgID,         const char* szProgID);





//In Register.cpp
//此文件是如何注册组件的代码实现,是把CLSID,ProgID,Version,Dll位置添加到
//HKEY_CLASSES_ROOT/CLSID,HKEY_CLASSES_ROOT的子键中去.
#include <objbase.h>
#include 
<assert.h>
#include 
"Register.h"


//set the given key and its value;
BOOL setKeyAndValue(const char* pszPath,
                    
const char* szSubkey,
                    
const char* szValue);

//Convert a CLSID into a char string
void CLSIDtochar(const CLSID& clsid,
                 
char* szCLSID,
                 
int length);

//Delete szKeyChild and all of its descendents
LONG recursiveDeleteKey(HKEY hKeyParent,const char* szKeyChild);

//size of a CLSID as a string
const int CLSID_STRING_SIZE = 39;


//Register the component in the registry
HRESULT RegisterServer(HMODULE hModule,
                       
const CLSID& clsid,
                       
const char* szFriendlyName,
                       
const char* szVerIndProgID,
                       
const char* szProgID)
{
    
//Get the Server location
    char szModule[512];
    DWORD dwResult 
= ::GetModuleFileName(hModule,szModule,sizeof(szModule)/sizeof(char));
    assert(dwResult
!=0);

    
//Convert the CLSID into a char
    char szCLSID[CLSID_STRING_SIZE];
    CLSIDtochar(clsid,szCLSID,
sizeof(szCLSID));

    
//Build the key CLSID\\{}
    char szKey[64];
    strcpy(szKey,
"CLSID\\");
    strcat(szKey,szCLSID);

    
//Add the CLSID to the registry
    setKeyAndValue(szKey,NULL,szFriendlyName);

    
//Add the Server filename subkey under the CLSID key
    setKeyAndValue(szKey,"InprocServer32",szModule);

    setKeyAndValue(szKey,
"ProgID",szProgID);

    setKeyAndValue(szKey,
"VersionIndependentProgID",szVerIndProgID);

    
//Add the version-independent ProgID subkey under HKEY_CLASSES_ROOT
    setKeyAndValue(szVerIndProgID,NULL,szFriendlyName);
    setKeyAndValue(szVerIndProgID,
"CLSID",szCLSID);
    setKeyAndValue(szVerIndProgID,
"CurVer",szProgID);

    
//Add the versioned ProgID subkey under HKEY_CLASSES_ROOT
    setKeyAndValue(szProgID,NULL,szFriendlyName);
    setKeyAndValue(szProgID,
"CLSID",szCLSID);
    
return S_OK;
}


//
//Remove the component from the register
//
HRESULT UnRegisterServer(const CLSID& clsid,           // Class ID
                         const char* szVerIndProgID,   // Programmatic
                         const char* szProgID)           // IDs
{
    
//Convert the CLSID into a char.
    char szCLSID[CLSID_STRING_SIZE];
    CLSIDtochar(clsid,szCLSID,
sizeof(szCLSID));

    
//Build the key CLSID\\{}
    char szKey[64];
    strcpy(szKey,
"CLSID\\");
    strcat(szKey,szCLSID);

    
//Delete the CLSID key - CLSID\{}
    LONG lResult = recursiveDeleteKey(HKEY_CLASSES_ROOT,szKey);
    assert((lResult 
== ERROR_SUCCESS) || (lResult == ERROR_FILE_NOT_FOUND));

    
//Delete the version-independent ProgID Key
    lResult = recursiveDeleteKey(HKEY_CLASSES_ROOT,szVerIndProgID);
    assert((lResult 
== ERROR_SUCCESS) || (lResult == ERROR_FILE_NOT_FOUND));

    
//Delete the ProgID key.
    lResult = recursiveDeleteKey(HKEY_CLASSES_ROOT,szProgID);
    assert((lResult 
== ERROR_SUCCESS) || (lResult == ERROR_FILE_NOT_FOUND));

    
return S_OK;
}



//Convert a CLSID to a char string
void CLSIDtochar(const CLSID& clsid,
                 
char* szCLSID,
                 
int length)
{
    assert(length
>=CLSID_STRING_SIZE);

    
//Get CLSID
    LPOLESTR wszCLSID = NULL;
    HRESULT hr 
= StringFromCLSID(clsid,&wszCLSID);
    assert(SUCCEEDED(hr));

    
//Convert from wide characters to non_wide
    wcstombs(szCLSID,wszCLSID,length);
    
    
//Free memory
    CoTaskMemFree(wszCLSID);
}



//
// Delete a Key and all of its descendents
//
LONG recursiveDeleteKey(HKEY hKeyParent,const char* lpszKeyChild)
{
    
//Open the child.
    HKEY hKeyChild;
    LONG lRes 
= RegOpenKeyEx(hKeyParent,lpszKeyChild,0,KEY_ALL_ACCESS,&hKeyChild);

    
if(lRes != ERROR_SUCCESS)
        
return lRes;

    
//Enumerate all of the decendents of this child
    FILETIME time;
    
char szBuffer[256];
    DWORD dwSize 
= 256 ;
    
    
while(RegEnumKeyEx(hKeyChild,0,szBuffer,&dwSize,NULL,
        NULL,NULL,
&time) == S_OK)
    
{
        
//Delete the decendents of this child.
        lRes = recursiveDeleteKey(hKeyChild,szBuffer);
        
if(lRes != ERROR_SUCCESS)
        
{
            RegCloseKey(hKeyChild);
            
return lRes;
        }

        dwSize 
= 256;
    }

    RegCloseKey(hKeyChild);
    
return RegDeleteKey(hKeyParent,lpszKeyChild);
}


BOOL setKeyAndValue(
const char* szKey,
                    
const char* szSubkey,
                    
const char* szValue)
{
    HKEY hKey;
    
char szKeyBuf[1024];

    
//Copy keyname into buffer.
    strcpy(szKeyBuf,szKey);

    
//Add subkey name to buffer.
    if(szSubkey!=NULL)
    
{
        strcat(szKeyBuf,
"\\");
        strcat(szKeyBuf,szSubkey);
    }


    
// Create and open key and subkey.
    long lResult = RegCreateKeyEx(HKEY_CLASSES_ROOT ,
                                  szKeyBuf, 
                                  
0, NULL, REG_OPTION_NON_VOLATILE,
                                  KEY_ALL_ACCESS, NULL, 
                                  
&hKey, NULL) ;
    
if (lResult != ERROR_SUCCESS)

    
{
        
return FALSE ;
    }


    
// Set the Value.
    if (szValue != NULL)
    
{
        RegSetValueEx(hKey, NULL, 
0, REG_SZ, 
                      (BYTE 
*)szValue, 
                      strlen(szValue)
+1) ;
    }


    RegCloseKey(hKey) ;
    
return TRUE ;
}

 

组件的实现:

//CMPNT.cpp
//此文件是组件CA,组件类厂CFactory的实现,CA的实现与前面讲述的是一样的,关键在于多引入了
//一个CFactory,还有一个是全局函数DllGetClassObject,另外,除了要导出DllGetClassObject之
//外,还要导出三个函数,分别是DllCanUnloadNow / DllRegisterServer / DllUnregisterServer.
//还有一项工作就是在DllMain中保存模块的信息.
#include <iostream.h>
#include 
<objbase.h>

#include 
"..\MYIF2\IFACE.h"
#include 
"Register.h"


//#ifndef EXPORTAPI 
//#define EXPORTAPI extern "C" __declspec(dllexport)
//#endif

void trace(const char* msg){cout<<msg<<endl;}

//Gobal variables
static HMODULE g_hModule = NULL ; 
static long g_cComponents    =    0;            //Count of active components
static long g_cServerLocks    =    0;            //Count of locks


//Friendly name of component
const char g_szFriendlyName[] = "Inside COM.Chapter 7 Example";


//Version-independent ProgID
const char g_szVerIndProgID[] = "InsideCOM.Chap07";

//ProgID
const char g_szProgID[] = "InsideCOM.Chap07.1";


//Component
class CA:public IX,public IY
{
public:
    
//IUnknown
    virtual HRESULT __stdcall QueryInterface(const IID& iid,void** ppv);
    
virtual ULONG   __stdcall AddRef();
    
virtual ULONG   __stdcall Release();

    
//Interface IX
    virtual void   __stdcall Fx(){cout<<"Fx function"<<endl;}

    
//Interface IY
    virtual void   __stdcall Fy(){cout<<"Fy function"<<endl;}

    
//Constructor
    CA();

    
//Destructor
    ~CA();

private:
    
long m_cRef;
}
;

        
CA::CA():m_cRef(
1)
{
    InterlockedIncrement(
&g_cComponents);
}


CA::
~CA()
{
    InterlockedDecrement(
&g_cComponents);
    trace(
"Component:\t\tDestory self");
}



//IUnknown implementation
HRESULT __stdcall CA::QueryInterface(const IID& iid,void** ppv)
{
    
if(iid == IID_IUnknown)
    
{
        
*ppv = static_cast<IX*>(this);
    }

    
else if(iid == IID_IX)
    
{
        
*ppv = static_cast<IX*>(this);
        trace(
"Component:\tReturn pointer to IX.");
    }

    
else if(iid == IID_IY)
    
{
        
*ppv = static_cast<IY*>(this);
        trace(
"Component:\tReturn pointer to IY.");
    }

    
else
    
{
        
*ppv = NULL;
        trace(
"Component:\tCannot Get pointer to IX/IY");
        
return E_NOINTERFACE;
    }

    reinterpret_cast
<IUnknown*>(*ppv)->AddRef();
    
return S_OK;
}


ULONG __stdcall CA::AddRef()
{
    
return InterlockedIncrement(&m_cRef);
}


ULONG __stdcall CA::Release()
{
    
if(InterlockedDecrement(&m_cRef)==0)
    
{
        delete 
this;
        
return 0;
    }

    
return m_cRef;
}



///////////////////////////////////////////
//class factory
///////////////////////////////////////////
class CFactory:public IClassFactory
{
public:
    
//IUnknown
    virtual HRESULT __stdcall QueryInterface(const IID& iid,void** ppv);
    
virtual ULONG   __stdcall AddRef();
    
virtual ULONG   __stdcall Release();

    
//Interface IClassFactory
    virtual HRESULT __stdcall CreateInstance(IUnknown* pUnknownOuter,
        
const IID& iid,
        
void** ppv);
    
virtual HRESULT __stdcall LockServer(BOOL bLock);

    
//Constructor
    CFactory():m_cRef(1){}
    
    
//Destructor
    ~CFactory() {trace("Class factory:\t\tDestory self.");}

private:
    
long m_cRef;
}
;

HRESULT __stdcall CFactory::QueryInterface(
const IID& iid,void** ppv)
{
    
if((iid == IID_IUnknown) || (iid == IID_IClassFactory))
    
{
        
*ppv= static_cast<IClassFactory*>(this);
    }

    
else
    
{
        
*ppv = NULL;
        
return E_NOINTERFACE;
    }

    reinterpret_cast
<IUnknown*>(*ppv)->AddRef();
    
return S_OK;
}


ULONG __stdcall CFactory::AddRef()
{
    
return InterlockedIncrement(&m_cRef);
}


ULONG __stdcall CFactory::Release()
{
    
if(InterlockedDecrement(&m_cRef)==0)
    
{
        delete 
this;
        
return 0;
    }

    
else
        
return m_cRef;
}


HRESULT __stdcall CFactory::CreateInstance(IUnknown
* pUnknownOuter,const IID& iid,void** ppv)
{
    trace(
"Class factory:\t\tCreate component.");

    
// Cannot aggregate.
    if (pUnknownOuter != NULL)
    
{
        
return CLASS_E_NOAGGREGATION ;
    }

    
//if(pUnknownOuter!=NULL)
    
//    return CLASS_E_NOA

    CA
* pA = new CA;
    
if(pA == NULL)
        
return E_OUTOFMEMORY;

    
//Get the request interface
    HRESULT hr = pA->QueryInterface(iid,ppv);

    pA
->Release();
    
return hr;
}


//LockServer
HRESULT __stdcall CFactory::LockServer(BOOL bLock)
{
    
if(bLock)
    
{
        InterlockedIncrement(
&g_cServerLocks);
    }

    
else
        InterlockedDecrement(
&g_cServerLocks);

    
return S_OK;
}




//Can Dll unload now?

int AddNum(int a,int b)
{
    
return a+b;
}


STDAPI DllCanUnloadNow()
{
    
if((g_cComponents ==0 ) && (g_cServerLocks==0))
        
return S_OK;
    
else
        
return S_FALSE;
}


STDAPI DllGetClassObject(
const CLSID& clsid,
                         
const IID& iid,
                         
void** ppv)
{
    trace(
"DllGetClassObject:\tCreate Class factory");

    
if(clsid != CLSID_Component1)
    
{
        
return CLASS_E_CLASSNOTAVAILABLE;
    }


    CFactory
* pFactory = new CFactory();

    
if(pFactory == NULL)
        
return E_OUTOFMEMORY;

    
//Get request interfaces
    HRESULT hr = pFactory->QueryInterface(iid,ppv);
    pFactory
->Release();
    
return hr;
}




//Server registration
STDAPI DllRegisterServer()
{
    
return RegisterServer(g_hModule,CLSID_Component1,g_szFriendlyName,
        g_szVerIndProgID,g_szProgID);
}


//Server unregistration
STDAPI DllUnregisterServer()
{
    
return UnRegisterServer(CLSID_Component1,g_szVerIndProgID,g_szProgID);
}


BOOL APIENTRY DllMain(HANDLE hModule,
                      DWORD dwReason,
                      
void* lpReserved)
{
    
if (dwReason == DLL_PROCESS_ATTACH)
    
{
        g_hModule 
= (HMODULE)hModule ;
    }

    
return TRUE ;
}



以上是组件的实现。下面是客户端的代码实现:
//In Client.cpp
int main()
{
    HRESULT hr;

    ::CoInitialize(NULL);
    trace(
"Call CoCreateInstance to Create");
    trace(
" componet and get interface IX");
    IX
* pIX = NULL;

              
    
                            
    hr 
= ::CoCreateInstance(CLSID_Component1,
        NULL,
        CLSCTX_INPROC_SERVER,
        IID_IX,
        (
void**)&pIX);

    
if(SUCCEEDED(hr))
    
{
        trace(
"Succeeded getting IX");
        pIX
->Fx();

        trace(
"Ask for Interface IY");
        IY
* pIY = NULL;
        hr 
= pIX->QueryInterface(IID_IY,(void**)&pIY);
        
if(SUCCEEDED(hr))
        
{
            trace(
"Succeeded getting IY");
            pIY
->Fy();
            pIY
->Release();
            trace(
"Release IY interface");
        }

        
else
        
{
            trace(
"Could not get interface IY.");
        }

        pIX
->Release();
    }

    
else
        cout
<<"Client: \t\tCould not create component hr="<<hex<<hr<<endl;

    CoUninitialize();
}


再在最后详述一篇,客户端调用CoCreateInstance,导致调用CoGetClassObject,CoGetClassObject通过查找注册表,得知DLL位置,文件名,然后调用DLL中DllGetClassObject,
DllGetClassObject的功能是返回CFactory的实例.
返回后,回到CoCreateInstance,通过CFactory的指针,调用
pClassFactory->CreaetInstance()创建组件实例.
这样就返回了组件实例的指针.
CoCreateInstace  -->  CoGetClassObject  --> DllGetClassObject --> Get CFactory*
                      <-------------------------------------------------------
                 -->  CFactory->CreateInstance(); --> Get IX* 
IX->Fx();
posted @ 2007-02-13 10:47  shipfi  阅读(7560)  评论(3编辑  收藏