Listing 2: CSslCredentials class

... // other code not shown

// definition

// CSslCredentials owns an SSPI credentials handle (m_hCredentials) for the
// Schannel package (UNISP_NAME) together with the certificate store handle
// (m_hStore) used to locate the certificate bound to it. Construction and
// destruction are restricted to the friend class CSslProvider.
//
// NOTE(review): the class holds raw handles and relies on the
// compiler-generated copy constructor/assignment, so a copy would alias both
// handles. Confirm that no caller copies instances.
class CSslCredentials
{
    friend class CSslProvider;

protected:
    CSslCredentials();
    ~CSslCredentials();

public:
    // Acquire credentials for the certificate matching strPrincipal, found in
    // the system store named strStore. (Implementation not shown; presumably
    // OpenSysStore + FindCertificate + ObtainImpl.)
    HRESULT Obtain(BOOL bAsServer, LPCTSTR strPrincipal,
                   LPCTSTR strStore, BOOL bMutualAuth = FALSE,
                   DWORD dwProtoFlags = 0);

    // Same as above, but searches a caller-supplied certificate store.
    // (Implementation not shown.)
    HRESULT Obtain(BOOL bAsServer, LPCTSTR strPrincipal,
                   HCERTSTORE hStore, BOOL bMutualAuth = FALSE,
                   DWORD dwProtoFlags = 0);

    // Acquire credentials without naming a certificate.
    // (Implementation not shown; presumably client-side credentials.)
    HRESULT Obtain(DWORD dwProtoFlags = 0);

    // Adopt an already-acquired credentials handle. (Implementation not
    // shown; NOTE(review): confirm whether ownership of hCredentials moves
    // to this object.)
    HRESULT Attach(CredHandle hCredentials, BOOL bAsServer,
                   BOOL bMutualAuth = FALSE, DWORD dwProtoFlags = 0);

    // Counterpart of Attach. (Implementation not shown.)
    HRESULT Detach();

    // Role recorded when the credentials were acquired.
    BOOL IsServer() const { return m_bServer; }

    // Mutual-authentication setting recorded when the credentials were
    // acquired.
    BOOL MutualAuthRequired() const { return m_bMutualAuth; }

    // SCHANNEL_CRED::grbitEnabledProtocols value the credentials were
    // acquired with.
    DWORD GetProtoFlags() const { return m_dwProtoFlags; }

    // TRUE when m_hCredentials holds a valid SSPI handle.
    BOOL HasCredHandle() const
    {
        return SecIsValidHandle(&m_hCredentials);
    }

    // Pointer to the credentials handle, or NULL when no valid handle is
    // held. Non-const because callers receive a mutable PCredHandle.
    PCredHandle GetHandle()
    {
        return (HasCredHandle() ? &m_hCredentials : NULL);
    }

    // Releases held resources. (Implementation not shown.)
    void CleanUp();

private:
    // Opens a system certificate store into m_hStore.
    HRESULT OpenSysStore(LPCTSTR strStore, HCERTSTORE& hStore);

    // Finds a certificate by subject string in m_hStore.
    HRESULT FindCertificate(HCERTSTORE hStore,
                            LPCTSTR strPrincipal, PCCERT_CONTEXT& certContext);

    // Calls AcquireCredentialsHandle for Schannel and records the settings.
    HRESULT ObtainImpl(PCCERT_CONTEXT certContext,
                       BOOL bAsServer, BOOL bMutualAuth, DWORD dwProtoFlags);

private:
    CredHandle m_hCredentials;   // SSPI credentials handle
    BOOL m_bServer;              // TRUE for inbound (server) credentials
    BOOL m_bMutualAuth;          // server requires client authentication
    DWORD m_dwProtoFlags;        // enabled protocol mask
    HCERTSTORE m_hStore;         // store opened by OpenSysStore
};

... // other code not shown

// implementation

// Acquires a Schannel (UNISP_NAME) credentials handle into m_hCredentials.
// On success, it records the role, mutual-auth and protocol settings the
// handle was acquired with.
//
// certContext  - certificate to bind to the credentials; when NULL, no
//                certificate is passed to Schannel (cCreds stays 0).
// bAsServer    - TRUE requests inbound (server) credentials, FALSE requests
//                outbound (client) credentials.
// bMutualAuth  - server-side requirement for client authentication; ignored
//                (recorded as FALSE) for client credentials.
// dwProtoFlags - copied into SCHANNEL_CRED::grbitEnabledProtocols
//                (0 = system default, per the SCHANNEL_CRED documentation).
// Returns the status from CSspiLib::AcquireCredentialsHandle.
//
// NOTE(review): m_hCredentials is overwritten without releasing any handle
// it already holds; callers are presumably expected to CleanUp/Detach first.
HRESULT CSslCredentials::ObtainImpl(PCCERT_CONTEXT certContext,
                                    BOOL bAsServer,
                                    BOOL bMutualAuth,
                                    DWORD dwProtoFlags)
{
    // Build the Schannel credentials structure.
    SCHANNEL_CRED credSchannel = {0};
    credSchannel.dwVersion = SCHANNEL_CRED_VERSION;
    credSchannel.grbitEnabledProtocols = dwProtoFlags;

    if (certContext != NULL)
    {
        // &certContext points at this function's parameter, so it is valid
        // only for the AcquireCredentialsHandle call below.
        credSchannel.cCreds = 1;
        credSchannel.paCred = &certContext;
    }

    // Create the SSL credentials.
    TimeStamp tsExpires;
    HRESULT hr = CSspiLib::AcquireCredentialsHandle(
        NULL,
        UNISP_NAME,
        bAsServer ? SECPKG_CRED_INBOUND : SECPKG_CRED_OUTBOUND,
        NULL,
        &credSchannel,
        NULL,
        NULL,
        &m_hCredentials,
        &tsExpires);

    if (FAILED(hr))
    {
        TRACE(_T("Line: %d. Error: 0x%08X\n"), __LINE__, hr);
        return hr;
    }

    m_bServer = bAsServer;
    // Mutual authentication is only requested by a server. Clear it for
    // client credentials so that MutualAuthRequired() never reports a value
    // left over from an earlier acquisition (or an uninitialized one).
    m_bMutualAuth = bAsServer ? bMutualAuth : FALSE;
    m_dwProtoFlags = dwProtoFlags;
    return hr;
}

... // other code not shown

// Opens the system certificate store named strStore. The handle is kept in
// m_hStore (the store FindCertificate searches) and also returned through
// hStore.
//
// strStore - name of the system store to open.
// hStore   - [out] receives the opened store handle, or NULL on failure.
//            NOTE(review): the handle is also held in m_hStore; presumably
//            it is released through that member, so the caller should not
//            close it separately — confirm against CleanUp/destructor.
// Returns S_OK on success, otherwise an error HRESULT (never S_OK when the
// open failed).
HRESULT CSslCredentials::OpenSysStore(LPCTSTR strStore,
                                      HCERTSTORE& hStore)
{
    // CertOpenSystemStore maps to the A or W variant according to UNICODE,
    // which matches LPCTSTR. Narrowing the name first would not compile in
    // UNICODE builds and would be lossy for non-ANSI names.
    m_hStore = ::CertOpenSystemStore(0, strStore);

    // Fill the out-parameter the signature promises.
    hStore = m_hStore;

    if (m_hStore == NULL)
    {
        DWORD dwErrCode = ::GetLastError();
        // HRESULT_FROM_WIN32(ERROR_SUCCESS) is S_OK; a failed open must
        // never be reported as success.
        return (dwErrCode != ERROR_SUCCESS) ? HRESULT_FROM_WIN32(dwErrCode)
                                            : E_FAIL;
    }

    return S_OK;
}

// Looks up the first certificate in m_hStore whose subject name matches
// strPrincipal, using a CERT_FIND_SUBJECT_STR search.
//
// hStore       - NOTE(review): currently unused; the search always runs
//                against m_hStore (opened by OpenSysStore). Confirm whether
//                callers expect this store to be searched instead.
// strPrincipal - subject-name string to search for.
// certContext  - [out] the matching certificate, or NULL if none was found.
//                Per the CertFindCertificateInStore documentation, a
//                returned context must eventually be released with
//                CertFreeCertificateContext.
// Returns S_OK when a certificate was found, otherwise an error HRESULT
// (never S_OK with a NULL certContext).
HRESULT CSslCredentials::FindCertificate(
    HCERTSTORE hStore,
    LPCTSTR strPrincipal,
    PCCERT_CONTEXT& certContext)
{
    ASSERT(m_hStore != NULL);

    // Search with the wide-character variant. CERT_FIND_SUBJECT_STR_A would
    // require narrowing strPrincipal to the ANSI code page first, which is
    // lossy for subject names outside that code page in UNICODE builds.
    CT2W hostName(strPrincipal);

    certContext = ::CertFindCertificateInStore(
        m_hStore,
        X509_ASN_ENCODING | PKCS_7_ASN_ENCODING,
        0,
        CERT_FIND_SUBJECT_STR_W,
        static_cast<LPCWSTR>(hostName),
        NULL);

    if (certContext == NULL)
    {
        DWORD dwErrCode = ::GetLastError();
        // HRESULT_FROM_WIN32(ERROR_SUCCESS) is S_OK; a failed lookup must
        // never be reported as success.
        return (dwErrCode != ERROR_SUCCESS) ? HRESULT_FROM_WIN32(dwErrCode)
                                            : E_FAIL;
    }

    return S_OK;
}