# pymssql.connect(server='.', user='', password='', database='', timeout=0, login_timeout=60, charset='UTF-8', as_dict=False, host='', appname=None, port='1433', conn_properties, autocommit=False, tds_ ... (signature truncated in this copy)
# Reference: http://pymssql.org/en/stable/ref/pymssql.html
"""
This is an effort to convert the pymssql low-level C module to Cython.
"""
#
# _mssql.pyx
#
# Copyright (C) 2003 Joon-cheol Park <jooncheol@gmail.com>
# 2008 Andrzej Kukula <akukula@gmail.com>
# 2009-2010 Damien Churchill <damoxc@gmail.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
# MA 02110-1301 USA
#
# Compile-time (DEF) switches: set to 1 to enable the IF-guarded stderr
# tracing used throughout this module.
DEF PYMSSQL_DEBUG = 0
DEF PYMSSQL_DEBUG_ERRORS = 0
# Size in bytes of MSSQLConnection._charset (allocated in __cinit__).
DEF PYMSSQL_CHARSETBUFSIZE = 100
DEF MSSQLDB_MSGSIZE = 1024
# Size in bytes of every last-message buffer (module-level and per connection).
DEF PYMSSQL_MSGSIZE = (MSSQLDB_MSGSIZE * 8)
# Severity that err_handler reports as a Net-Lib (communication) error.
DEF EXCOMM = 9
# Provide constants missing in FreeTDS 0.82 so that we can build against it
DEF DBVERSION_71 = 5
DEF DBVERSION_72 = 6
# Row formats accepted by get_iterator()/fetch_next_row(): Python-visible
# values plus C-level copies of the same numbers.
ROW_FORMAT_TUPLE = 1
ROW_FORMAT_DICT = 2
cdef int _ROW_FORMAT_TUPLE = ROW_FORMAT_TUPLE
cdef int _ROW_FORMAT_DICT = ROW_FORMAT_DICT
from cpython cimport PY_MAJOR_VERSION, PY_MINOR_VERSION
try:
    # collections.abc is the home of Iterable since Python 3.3; the alias in
    # `collections` itself was removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python 2
    from collections import Iterable
import os
import sys
import socket
import decimal
import binascii
import datetime
import re
import uuid
from sqlfront cimport *
from libc.stdio cimport fprintf, snprintf, stderr, FILE
from libc.string cimport strlen, strncpy, memcpy
from cpython cimport bool
from cpython.mem cimport PyMem_Malloc, PyMem_Free
from cpython.long cimport PY_LONG_LONG
from cpython.ref cimport Py_INCREF
from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
# Version string generated at build time (decoded into __full_version__ below).
cdef extern from "pymssql_version.h":
    const char *PYMSSQL_VERSION
# Build-time capability flag, checked in MSSQLConnection.__init__ before
# calling DBSETLDBNAME.
cdef extern from "cpp_helpers.h":
    cdef bint FREETDS_SUPPORTS_DBSETLDBNAME
# Vars to store messages from the server in.
# err_handler/msg_handler fall back to these when a message cannot be matched
# to a registered connection. Each text buffer holds PYMSSQL_MSGSIZE bytes and
# starts out as the empty string.
# NOTE(review): the PyMem_Malloc results are not checked for NULL.
cdef int _mssql_last_msg_no = 0
cdef int _mssql_last_msg_severity = 0
cdef int _mssql_last_msg_state = 0
cdef int _mssql_last_msg_line = 0
cdef char *_mssql_last_msg_str = <char *>PyMem_Malloc(PYMSSQL_MSGSIZE)
_mssql_last_msg_str[0] = <char>0
cdef char *_mssql_last_msg_srv = <char *>PyMem_Malloc(PYMSSQL_MSGSIZE)
_mssql_last_msg_srv[0] = <char>0
cdef char *_mssql_last_msg_proc = <char *>PyMem_Malloc(PYMSSQL_MSGSIZE)
_mssql_last_msg_proc[0] = <char>0
IF PYMSSQL_DEBUG == 1:
    cdef int _row_count = 0
# Local host name, UTF-8 encoded.
cdef bytes HOSTNAME = socket.gethostname().encode('utf-8')
# List to store the connection objects in.
# err_handler/msg_handler scan it to find the connection owning a DBPROCESS.
cdef list connection_object_list = list()
# Store the 32bit int limit values
# (convert_python_value range-checks SQLINT1/2/4 parameters against these).
cdef int MAX_INT = 2147483647
cdef int MIN_INT = -2147483648
# Store the module version
__full_version__ = PYMSSQL_VERSION.decode('ascii')
__version__ = '.'.join(__full_version__.split('.')[:3]) # drop '.dev' from 'X.Y.Z.dev'
#############################
## DB-API type definitions ##
#############################
# Type codes for DB-API column descriptions (get_header() reports
# self.column_types in the second field of each description tuple; presumably
# these codes -- confirm where column_types is filled in).
STRING = 1
BINARY = 2
NUMBER = 3
DATETIME = 4
DECIMAL = 5
##################
## DB-LIB types ##
##################
# Python-visible aliases for DB-Library type codes. Most come from the SYB*
# constants in sqlfront; SQLBITN and SQLUUID are raw numbers (presumably the
# TDS codes for nullable bit and uniqueidentifier -- TODO confirm against
# sybdb.h).
SQLBINARY = SYBBINARY
SQLBIT = SYBBIT
SQLBITN = 104
SQLCHAR = SYBCHAR
SQLDATETIME = SYBDATETIME
SQLDATETIM4 = SYBDATETIME4
SQLDATETIMN = SYBDATETIMN
SQLDECIMAL = SYBDECIMAL
SQLFLT4 = SYBREAL
SQLFLT8 = SYBFLT8
SQLFLTN = SYBFLTN
SQLIMAGE = SYBIMAGE
SQLINT1 = SYBINT1
SQLINT2 = SYBINT2
SQLINT4 = SYBINT4
SQLINT8 = SYBINT8
SQLINTN = SYBINTN
SQLMONEY = SYBMONEY
SQLMONEY4 = SYBMONEY4
SQLMONEYN = SYBMONEYN
SQLNUMERIC = SYBNUMERIC
SQLREAL = SYBREAL
SQLTEXT = SYBTEXT
SQLVARBINARY = SYBVARBINARY
SQLVARCHAR = SYBVARCHAR
SQLUUID = 36
#######################
## Exception classes ##
#######################
# Re-declare the built-in Exception type at the C level so that the cdef
# exception classes below can derive from it.
cdef extern from "pyerrors.h":
    ctypedef class __builtin__.Exception [object PyBaseExceptionObject]:
        pass
cdef class MSSQLException(Exception):
    """
    Base exception class for the MSSQL driver.
    """
cdef class MSSQLDriverException(MSSQLException):
    """
    Inherits from the base class and raised when an error is caused within
    the driver itself (e.g. dblogin() failing or a parameter value that
    cannot be converted).
    """
cdef class MSSQLDatabaseException(MSSQLException):
    """
    Raised when an error occurs within the database.

    Read-only attributes: number, severity, state and line of the server
    message, plus its text, server name and procedure name as C strings
    (exposed to Python as bytes).
    """
    cdef readonly int number
    cdef readonly int severity
    cdef readonly int state
    cdef readonly int line
    cdef readonly char *text
    cdef readonly char *srvname
    cdef readonly char *procname
    property message:
        """
        Human-readable description of the server message.

        The C string fields are decoded before formatting so that Python 3
        shows text instead of b'...' literals (UTF-8 with undecodable bytes
        replaced -- NOTE(review): assumes the server text is UTF-8). A NULL
        text counts as empty, and the procedure part is only included when a
        non-empty procedure name is present.
        """
        def __get__(self):
            text = self.text.decode('utf-8', 'replace') if self.text != NULL else u''
            if self.procname != NULL and self.procname[0] != 0:
                procname = self.procname.decode('utf-8', 'replace')
                return 'SQL Server message %d, severity %d, state %d, ' \
                    'procedure %s, line %d:\n%s' % (self.number,
                    self.severity, self.state, procname,
                    self.line, text)
            return 'SQL Server message %d, severity %d, state %d, ' \
                'line %d:\n%s' % (self.number, self.severity,
                self.state, self.line, text)
# Module attributes for configuring _mssql:
#   login_timeout      -- seconds handed to dbsetlogintime() whenever a
#                         connection is opened (application wide).
#   min_error_severity -- errors/messages below this severity are not
#                         recorded by err_handler / msg_handler.
#   wait_callback      -- optional callable that db_sqlok() invokes with the
#                         connection's read descriptor before blocking.
login_timeout = 60
min_error_severity = 6
wait_callback = None
def set_wait_callback(a_callable):
    """
    Install `a_callable` as the module-wide wait callback.

    Before each blocking wait for query results the callable receives the
    DBPROCESS read file descriptor, e.g. ``gevent.socket.wait_read``, so a
    cooperative scheduler can switch tasks. Passing None removes it.
    """
    global wait_callback
    wait_callback = a_callable
# Buffer size for large numbers: convert_db_value() renders MONEY, NUMERIC
# and DECIMAL values as text into a stack buffer of this many bytes before
# building a decimal.Decimal from it.
DEF NUMERIC_BUF_SZ = 45
cdef bytes ensure_bytes(s, encoding='utf-8'):
    # Return `s` as a byte string suitable for handing to DB-Library.
    #
    # s        -- text (str/unicode) or already-encoded bytes/bytearray.
    # encoding -- codec used for text input. None (what
    #             MSSQLConnection.charset returns when no charset is set)
    #             falls back to UTF-8 instead of failing in encode().
    #
    # Byte strings are passed through unchanged. They are not round-tripped
    # through decode()/encode(): on Python 2 that path implicitly
    # ASCII-encodes unicode input first and fails for non-ASCII text.
    if encoding is None:
        encoding = 'utf-8'
    if isinstance(s, bytes):
        return s
    if isinstance(s, bytearray):
        return bytes(s)
    return s.encode(encoding)
cdef void log(char * message, ...):
    # Debug tracing helper: prints "+++ <message>" to stderr, but only in
    # builds compiled with PYMSSQL_DEBUG == 1 (otherwise it does nothing).
    # Extra variadic arguments are accepted and ignored.
    IF PYMSSQL_DEBUG == 1:
        fprintf(stderr, "+++ %s\n", message)
# Python types treated as "a single string" when MSSQLConnection.__init__
# decides whether conn_properties is one SQL string or an iterable of them.
# PY_MAJOR_VERSION is a C int, so it must be compared with the integer 3;
# comparing it with the literal '3' does not test the interpreter version.
if PY_MAJOR_VERSION >= 3:
    string_types = str,
else:
    string_types = basestring,
###################
## Error Handler ##
###################
cdef int err_handler(DBPROCESS *dbproc, int severity, int dberr, int oserr,
        char *dberrstr, char *oserrstr) with gil:
    # DB-Library error handler callback; runs with the GIL held.
    #
    # dbproc             -- DBPROCESS the error belongs to; matched against
    #                       connection_object_list to find its connection.
    # severity           -- error severity; anything below the module-level
    #                       min_error_severity is ignored.
    # dberr              -- DB-Library error number (stored as message number).
    # oserr              -- OS error number (stored as message state); included
    #                       in the text when it is neither 0 nor DBNOERR.
    # dberrstr, oserrstr -- error texts; NULL is treated as "".
    #
    # The error is recorded in the owning connection's last_msg_* fields, or in
    # the module-level _mssql_last_msg_* variables when no registered
    # connection owns dbproc. number/severity/state are only replaced when this
    # error is more severe than the one already stored, while the formatted
    # text is always appended to the accumulated message (truncated to
    # PYMSSQL_MSGSIZE bytes). A connection whose DBPROCESS is dead is marked
    # disconnected. Always returns INT_CANCEL.
    cdef char *mssql_lastmsgstr
    cdef int *mssql_lastmsgno
    cdef int *mssql_lastmsgseverity
    cdef int *mssql_lastmsgstate
    cdef int _min_error_severity = min_error_severity
    cdef char mssql_message[PYMSSQL_MSGSIZE]
    if severity < _min_error_severity:
        return INT_CANCEL
    if dberrstr == NULL:
        dberrstr = ''
    if oserrstr == NULL:
        oserrstr = ''
    IF PYMSSQL_DEBUG == 1 or PYMSSQL_DEBUG_ERRORS == 1:
        fprintf(stderr, "\n*** err_handler(dbproc = %p, severity = %d, " \
            "dberr = %d, oserr = %d, dberrstr = '%s', oserrstr = '%s'); " \
            "DBDEAD(dbproc) = %d\n", <void *>dbproc, severity, dberr,
            oserr, dberrstr, oserrstr, DBDEAD(dbproc));
        fprintf(stderr, "*** previous max severity = %d\n\n",
            _mssql_last_msg_severity);
    # Default target: the module-level message state.
    mssql_lastmsgstr = _mssql_last_msg_str
    mssql_lastmsgno = &_mssql_last_msg_no
    mssql_lastmsgseverity = &_mssql_last_msg_severity
    mssql_lastmsgstate = &_mssql_last_msg_state
    # Redirect to the connection that owns this DBPROCESS, if any.
    for conn in connection_object_list:
        if dbproc != (<MSSQLConnection>conn).dbproc:
            continue
        mssql_lastmsgstr = (<MSSQLConnection>conn).last_msg_str
        mssql_lastmsgno = &(<MSSQLConnection>conn).last_msg_no
        mssql_lastmsgseverity = &(<MSSQLConnection>conn).last_msg_severity
        mssql_lastmsgstate = &(<MSSQLConnection>conn).last_msg_state
        if DBDEAD(dbproc):
            log("+++ err_handler: dbproc is dead; killing conn...\n")
            # NOTE(review): the message below is still written into this
            # connection's buffers after this call, so mark_disconnected()
            # must leave them allocated.
            conn.mark_disconnected()
        # Stop here: mark_disconnected() may have removed conn from the list.
        break
    # Keep number/severity/state of the most severe error seen so far.
    if severity > mssql_lastmsgseverity[0]:
        mssql_lastmsgseverity[0] = severity
        mssql_lastmsgno[0] = dberr
        mssql_lastmsgstate[0] = oserr
    # Format this error, prefixed with the text accumulated so far.
    if oserr != DBNOERR and oserr != 0:
        if severity == EXCOMM:
            snprintf(
                mssql_message, sizeof(mssql_message),
                '%sDB-Lib error message %d, severity %d:\n%s\nNet-Lib error during %s (%d)\n',
                mssql_lastmsgstr, dberr, severity, dberrstr, oserrstr, oserr)
        else:
            snprintf(
                mssql_message, sizeof(mssql_message),
                '%sDB-Lib error message %d, severity %d:\n%s\nOperating System error during %s (%d)\n',
                mssql_lastmsgstr, dberr, severity, dberrstr, oserrstr, oserr)
    else:
        snprintf(
            mssql_message, sizeof(mssql_message),
            '%sDB-Lib error message %d, severity %d:\n%s\n',
            mssql_lastmsgstr, dberr, severity, dberrstr)
    # strncpy() does not terminate on truncation, hence the explicit NUL.
    strncpy(mssql_lastmsgstr, mssql_message, PYMSSQL_MSGSIZE)
    mssql_lastmsgstr[ PYMSSQL_MSGSIZE - 1 ] = '\0'
    return INT_CANCEL
#####################
## Message Handler ##
#####################
cdef int msg_handler(DBPROCESS *dbproc, DBINT msgno, int msgstate,
        int severity, char *msgtext, char *srvname, char *procname,
        LINE_T line) with gil:
    # DB-Library server-message handler callback; runs with the GIL held.
    #
    # Any msghandler installed on the owning connection (set_msghandler()) is
    # called first, regardless of severity. Messages below the module-level
    # min_error_severity are then dropped and INT_CANCEL is returned.
    # Otherwise, if this message is more severe than the one already stored,
    # its number/severity/state/line and text, server and procedure names
    # replace the owning connection's last_msg_* fields, or the module-level
    # _mssql_last_msg_* variables when no registered connection owns dbproc.
    # Returns 0 in that case.
    cdef int *mssql_lastmsgno
    cdef int *mssql_lastmsgseverity
    cdef int *mssql_lastmsgstate
    cdef int *mssql_lastmsgline
    cdef char *mssql_lastmsgstr
    cdef char *mssql_lastmsgsrv
    cdef char *mssql_lastmsgproc
    cdef int _min_error_severity = min_error_severity
    cdef MSSQLConnection conn = None
    # The string arguments may be NULL; treat them as "" (as err_handler
    # does) so they can be converted to Python objects and copied safely.
    if msgtext == NULL:
        msgtext = ''
    if srvname == NULL:
        srvname = ''
    if procname == NULL:
        procname = ''
    IF PYMSSQL_DEBUG == 1:
        fprintf(stderr, "\n+++ msg_handler(dbproc = %p, msgno = %d, " \
            "msgstate = %d, severity = %d, msgtext = '%s', " \
            "srvname = '%s', procname = '%s', line = %d)\n",
            dbproc, msgno, msgstate, severity, msgtext, srvname,
            procname, line);
        fprintf(stderr, "+++ previous max severity = %d\n\n",
            _mssql_last_msg_severity);
    for cnx in connection_object_list:
        if (<MSSQLConnection>cnx).dbproc != dbproc:
            continue
        conn = <MSSQLConnection>cnx
        break
    if conn is not None and conn.msghandler is not None:
        conn.msghandler(msgstate, severity, srvname, procname, line, msgtext)
    if severity < _min_error_severity:
        return INT_CANCEL
    if conn is not None:
        mssql_lastmsgstr = conn.last_msg_str
        mssql_lastmsgsrv = conn.last_msg_srv
        mssql_lastmsgproc = conn.last_msg_proc
        mssql_lastmsgno = &conn.last_msg_no
        mssql_lastmsgseverity = &conn.last_msg_severity
        mssql_lastmsgstate = &conn.last_msg_state
        mssql_lastmsgline = &conn.last_msg_line
    else:
        mssql_lastmsgstr = _mssql_last_msg_str
        mssql_lastmsgsrv = _mssql_last_msg_srv
        mssql_lastmsgproc = _mssql_last_msg_proc
        mssql_lastmsgno = &_mssql_last_msg_no
        mssql_lastmsgseverity = &_mssql_last_msg_severity
        mssql_lastmsgstate = &_mssql_last_msg_state
        mssql_lastmsgline = &_mssql_last_msg_line
    # Calculate the maximum severity of all messages in a row
    # Fill the remaining fields as this is going to raise the exception
    if severity > mssql_lastmsgseverity[0]:
        mssql_lastmsgseverity[0] = severity
        mssql_lastmsgno[0] = msgno
        mssql_lastmsgstate[0] = msgstate
        mssql_lastmsgline[0] = line
        # strncpy() does not NUL-terminate a source that fills the buffer,
        # so terminate each copy explicitly.
        strncpy(mssql_lastmsgstr, msgtext, PYMSSQL_MSGSIZE)
        mssql_lastmsgstr[PYMSSQL_MSGSIZE - 1] = '\0'
        strncpy(mssql_lastmsgsrv, srvname, PYMSSQL_MSGSIZE)
        mssql_lastmsgsrv[PYMSSQL_MSGSIZE - 1] = '\0'
        strncpy(mssql_lastmsgproc, procname, PYMSSQL_MSGSIZE)
        mssql_lastmsgproc[PYMSSQL_MSGSIZE - 1] = '\0'
    return 0
cdef int db_sqlexec(DBPROCESS *dbproc):
    # Send the command buffer of `dbproc` and wait for the server to accept
    # it: a split version of dbsqlexec().
    #
    # dbsqlsend() only transmits the batch without waiting for a reply, which
    # leaves db_sqlok() room to run the wait callback before blocking.
    # Returns dbsqlsend()'s code if sending failed, otherwise db_sqlok()'s;
    # dbresults() can be called afterwards to process the results.
    cdef RETCODE send_status
    with nogil:
        send_status = dbsqlsend(dbproc)
    if send_status == SUCCEED:
        # The batch is in flight; wait for and verify the server's answer.
        return db_sqlok(dbproc)
    return send_status
cdef int db_sqlok(DBPROCESS *dbproc):
    # Wait for the server to acknowledge a batch sent with dbsqlsend().
    #
    # When a module-level wait_callback is installed, it is first called with
    # the DBPROCESS read descriptor (dbiordesc) so that cooperative frameworks
    # can yield until data arrives, e.g. gevent.socket.wait_read(fd).
    # dbsqlok() then runs without the GIL and its return code is returned;
    # dbresults() may be called afterwards.
    cdef RETCODE ok_status
    if wait_callback:
        wait_callback(dbiordesc(dbproc))
    with nogil:
        ok_status = dbsqlok(dbproc)
    return ok_status
cdef void clr_err(MSSQLConnection conn):
    # Reset the "last message" state before a new operation.
    #
    # With a connection, its own last_msg_* fields are cleared. With None, the
    # module-level _mssql_last_msg_* variables are cleared. The `global`
    # declaration is required: without it the assignments in the else branch
    # only create locals and the module-level state is never reset.
    # Line, server and procedure fields are cleared too, so a later error
    # that only sets number/severity/state (err_handler) cannot be reported
    # together with stale details from a previous message.
    global _mssql_last_msg_no, _mssql_last_msg_severity
    global _mssql_last_msg_state, _mssql_last_msg_line
    if conn is not None:
        conn.last_msg_no = 0
        conn.last_msg_severity = 0
        conn.last_msg_state = 0
        conn.last_msg_line = 0
        conn.last_msg_str[0] = 0
        conn.last_msg_srv[0] = 0
        conn.last_msg_proc[0] = 0
    else:
        _mssql_last_msg_no = 0
        _mssql_last_msg_severity = 0
        _mssql_last_msg_state = 0
        _mssql_last_msg_line = 0
        _mssql_last_msg_str[0] = 0
        _mssql_last_msg_srv[0] = 0
        _mssql_last_msg_proc[0] = 0
cdef RETCODE db_cancel(MSSQLConnection conn):
    # Cancel any pending results on `conn` and reset its result metadata.
    # Without a connection, or without a DBPROCESS, there is nothing to cancel
    # and SUCCEED is returned; otherwise dbcancel()'s return code is passed on.
    cdef RETCODE cancel_status
    if conn is None or conn.dbproc == NULL:
        return SUCCEED
    with nogil:
        cancel_status = dbcancel(conn.dbproc)
    conn.clear_metadata()
    return cancel_status
##############################
## MSSQL Row Iterator Class ##
##############################
cdef class MSSQLRowIterator:
    """
    Iterator over the rows of a connection's current result set.

    Every step checks that the connection is still open, clears its
    last-message state and fetches the next row in `row_format`
    (ROW_FORMAT_TUPLE or ROW_FORMAT_DICT). The connection raises
    StopIteration once the rows are exhausted.
    """
    def __init__(self, connection, int row_format):
        self.conn = connection
        self.row_format = row_format
    def __iter__(self):
        return self
    def __next__(self):
        cdef MSSQLConnection conn = self.conn
        assert_connected(conn)
        clr_err(conn)
        return conn.fetch_next_row(1, self.row_format)
############################
## MSSQL Connection Class ##
############################
cdef class MSSQLConnection:
property charset:
    """
    Name of the client character set in use, or None when none was set.
    """
    def __get__(self):
        if not strlen(self._charset):
            return None
        if PY_MAJOR_VERSION == 3:
            return self._charset.decode('ascii')
        return self._charset
property connected:
    """
    True while the connection to the database is open.
    """
    def __get__(self):
        return self._connected
property identity:
    """
    Identity value of the most recently inserted row (SCOPE_IDENTITY()).
    When the previous operation did not insert into a table with an identity
    column, None is returned.
    ** Usage **
    >>> conn.execute_non_query("INSERT INTO table (name) VALUES ('John')")
    >>> print('Last inserted row has ID = %s' % conn.identity)
    Last inserted row has ID = 178
    """
    def __get__(self):
        return self.execute_scalar('SELECT SCOPE_IDENTITY()')
property query_timeout:
    """
    Query timeout value passed to dbsettime(); must be >= 0.
    Setting it currently affects the whole application, not only this
    connection, because dbsettime() is not per-connection.
    """
    def __get__(self):
        return self._query_timeout
    def __set__(self, value):
        cdef RETCODE settime_status
        cdef int timeout = int(value)
        if timeout < 0:
            raise ValueError("The 'query_timeout' attribute must be >= 0.")
        # XXX: Currently this will set it application wide :-(
        settime_status = dbsettime(timeout)
        check_and_raise(settime_status, self)
        # Only remember the new value once DB-Library accepted it.
        self._query_timeout = timeout
property rows_affected:
    """
    Number of rows affected by the last query. For SELECT statements the
    value is only meaningful once every row has been read.
    """
    def __get__(self):
        return self._rows_affected
property tds_version:
    """
    TDS protocol version used by the connection, as a float.
    dbtds() codes without a mapping give None. Code 9 (TDS 7.1) is reported
    as 8.0 for backward compatibility.
    """
    def __get__(self):
        cdef int version = dbtds(self.dbproc)
        return {
            11: 7.3,
            10: 7.2,
            9: 8.0,  # Actually 7.1, return 8.0 to keep backward compatibility
            8: 7.0,
            6: 5.0,
            4: 4.2,
        }.get(version)
def __cinit__(self):
    # Allocate the per-connection C buffers (charset name, and last message
    # text / server name / procedure name) and set each to "".
    # If any allocation fails, the buffers that were obtained are released
    # and MemoryError is raised, instead of writing through a NULL pointer.
    log("_mssql.MSSQLConnection.__cinit__()")
    self._connected = 0
    self._charset = <char *>PyMem_Malloc(PYMSSQL_CHARSETBUFSIZE)
    self.last_msg_str = <char *>PyMem_Malloc(PYMSSQL_MSGSIZE)
    self.last_msg_srv = <char *>PyMem_Malloc(PYMSSQL_MSGSIZE)
    self.last_msg_proc = <char *>PyMem_Malloc(PYMSSQL_MSGSIZE)
    if (self._charset == NULL or self.last_msg_str == NULL or
            self.last_msg_srv == NULL or self.last_msg_proc == NULL):
        # PyMem_Free(NULL) is a no-op, so partial allocations are fine here.
        PyMem_Free(self._charset)
        self._charset = NULL
        PyMem_Free(self.last_msg_str)
        self.last_msg_str = NULL
        PyMem_Free(self.last_msg_srv)
        self.last_msg_srv = NULL
        PyMem_Free(self.last_msg_proc)
        self.last_msg_proc = NULL
        raise MemoryError()
    self._charset[0] = <char>0
    self.last_msg_str[0] = <char>0
    self.last_msg_srv[0] = <char>0
    self.last_msg_proc[0] = <char>0
    self.column_names = None
    self.column_types = None
def __init__(self, server="localhost", user="sa", password="",
        charset='UTF-8', database='', appname=None, port='1433', tds_version='7.1', conn_properties=None):
    # Open a DB-Library connection and apply the initial session settings.
    #
    # server          -- "host", "host:port" or "host\\instance"; "." and
    #                    "(local)" mean "localhost".
    # user, password  -- login credentials (sent UTF-8 encoded).
    # charset         -- client character set name; also used to encode
    #                    conn_properties (UTF-8 when empty).
    # database        -- initial database; put into the LOGINREC when
    #                    FreeTDS supports DBSETLDBNAME, then select_db()'d.
    # appname         -- application name, default "pymssql=<version>".
    # port            -- appended as ":port" unless server already has a
    #                    port or names an instance.
    # tds_version     -- TDS protocol version string for DBSETLVERSION.
    # conn_properties -- SQL run right after connecting; a non-string
    #                    iterable is joined with spaces, None uses defaults.
    #
    # On any failure the LOGINREC is freed, the object is removed from
    # connection_object_list and an opened connection is closed before the
    # exception propagates. The list keeps its members alive, so a failed
    # object that stayed registered would never be garbage collected.
    log("_mssql.MSSQLConnection.__init__()")
    cdef LOGINREC *login
    cdef RETCODE rtc
    cdef char *_charset
    cdef bytes user_bytes
    cdef char *user_cstr
    cdef bytes password_bytes
    cdef char *password_cstr
    cdef bytes appname_bytes
    cdef char *appname_cstr
    cdef bytes charset_bytes
    cdef bytes dbname_bytes
    cdef char *dbname_cstr
    cdef bytes server_bytes
    cdef char *server_cstr
    cdef bytes conn_props_bytes
    # support MS methods of connecting locally
    instance = ""
    if "\\" in server:
        server, instance = server.split("\\")
    if server in (".", "(local)"):
        server = "localhost"
    server = server + "\\" + instance if instance else server
    # add the port to the server string if it doesn't have one already and
    # if we are not using an instance
    if ':' not in server and not instance:
        server = '%s:%s' % (server, port)
    # override the HOST to be the portion without the server, otherwise
    # FreeTDS chokes when server still has the port definition.
    # BUT, a patch on the mailing list fixes the need for this. I am
    # leaving it here just to remind us how to fix the problem if the bug
    # doesn't get fixed for a while. But if it does get fixed, this code
    # can be deleted.
    # patch: http://lists.ibiblio.org/pipermail/freetds/2011q2/026997.html
    #if ':' in server:
    #    os.environ['TDSHOST'] = server.split(':', 1)[0]
    #else:
    #    os.environ['TDSHOST'] = server
    appname = appname or "pymssql=%s" % __full_version__
    # For Python 3, we need to convert unicode to byte strings. Done before
    # dblogin() so that an encoding error cannot leak the LOGINREC.
    user_bytes = user.encode('utf-8')
    user_cstr = user_bytes
    password_bytes = password.encode('utf-8')
    password_cstr = password_bytes
    appname_bytes = appname.encode('utf-8')
    appname_cstr = appname_bytes
    server_bytes = server.encode('utf-8')
    server_cstr = server_bytes
    login = dblogin()
    if login == NULL:
        raise MSSQLDriverException("dblogin() failed")
    # Add ourselves to the global connection list so err_handler/msg_handler
    # can attribute messages raised while connecting to this object.
    connection_object_list.append(self)
    try:
        DBSETLUSER(login, user_cstr)
        DBSETLPWD(login, password_cstr)
        DBSETLAPP(login, appname_cstr)
        DBSETLVERSION(login, _tds_ver_str_to_constant(tds_version))
        # Set the character set name
        if charset:
            charset_bytes = charset.encode('utf-8')
            _charset = charset_bytes
            strncpy(self._charset, _charset, PYMSSQL_CHARSETBUFSIZE)
            # strncpy() leaves a name that fills the buffer unterminated.
            self._charset[PYMSSQL_CHARSETBUFSIZE - 1] = <char>0
            DBSETLCHARSET(login, self._charset)
        # Put the DB name in the login LOGINREC because it helps with connections to Azure
        if database:
            if FREETDS_SUPPORTS_DBSETLDBNAME:
                dbname_bytes = database.encode('ascii')
                dbname_cstr = dbname_bytes
                DBSETLDBNAME(login, dbname_cstr)
            else:
                log("_mssql.MSSQLConnection.__init__(): Warning: This version of FreeTDS doesn't support selecting the DB name when setting up the connection. This will keep connections to Azure from working.")
        # Set the login timeout
        # XXX: Currently this will set it application wide :-(
        dbsetlogintime(login_timeout)
        # Connect to the server
        with nogil:
            self.dbproc = dbopen(login, server_cstr)
    except BaseException:
        if self in connection_object_list:
            connection_object_list.remove(self)
        raise
    finally:
        # Frees the login record, can be called immediately after dbopen.
        dbloginfree(login)
    if self.dbproc == NULL:
        log("_mssql.MSSQLConnection.__init__() -> dbopen() returned NULL")
        # err_handler may already have unregistered us via mark_disconnected().
        if self in connection_object_list:
            connection_object_list.remove(self)
        maybe_raise_MSSQLDatabaseException(None)
        raise MSSQLDriverException("Connection to the database failed for an unknown reason.")
    self._connected = 1
    if conn_properties is None:
        conn_properties = \
            "SET ARITHABORT ON;" \
            "SET CONCAT_NULL_YIELDS_NULL ON;" \
            "SET ANSI_NULLS ON;" \
            "SET ANSI_NULL_DFLT_ON ON;" \
            "SET ANSI_PADDING ON;" \
            "SET ANSI_WARNINGS ON;" \
            "SET CURSOR_CLOSE_ON_COMMIT ON;" \
            "SET QUOTED_IDENTIFIER ON;" \
            "SET TEXTSIZE 2147483647;" # http://msdn.microsoft.com/en-us/library/aa259190%28v=sql.80%29.aspx
    elif isinstance(conn_properties, Iterable) and not isinstance(conn_properties, string_types):
        conn_properties = ' '.join(conn_properties)
    try:
        if conn_properties:
            log("_mssql.MSSQLConnection.__init__() -> dbcmd() setting connection values")
            # Set connection properties, some reasonable values are used by
            # default but they can be customized. An empty charset would be
            # an invalid codec name, so fall back to UTF-8.
            conn_props_bytes = ensure_bytes(conn_properties, charset or 'utf-8')
            dbcmd(self.dbproc, conn_props_bytes)
            rtc = db_sqlexec(self.dbproc)
            if (rtc == FAIL):
                raise MSSQLDriverException("Could not set connection properties")
            db_cancel(self)
            clr_err(self)
        if database:
            self.select_db(database)
    except BaseException:
        # The caller never receives this object; close it rather than leave
        # an open connection pinned in connection_object_list.
        self.close()
        raise
def __dealloc__(self):
    # Close the connection if it is still open, then release the C buffers
    # allocated in __cinit__. They are freed here, not in mark_disconnected():
    # err_handler keeps writing into them right after marking a dead
    # connection, and properties such as `charset` read them after close().
    # PyMem_Free(NULL) is a no-op, so buffers that were never allocated are
    # fine.
    log("_mssql.MSSQLConnection.__dealloc__()")
    self.close()
    PyMem_Free(self.last_msg_proc)
    self.last_msg_proc = NULL
    PyMem_Free(self.last_msg_srv)
    self.last_msg_srv = NULL
    PyMem_Free(self.last_msg_str)
    self.last_msg_str = NULL
    PyMem_Free(self._charset)
    self._charset = NULL
def __enter__(self):
    # Context-manager entry: the connection itself is the managed object.
    return self
def __exit__(self, exc_type, exc_value, traceback):
    # Context-manager exit: always close; exceptions are not suppressed.
    self.close()
def __iter__(self):
    # Iterate the current result set as dictionaries.
    assert_connected(self)
    clr_err(self)
    return MSSQLRowIterator(self, ROW_FORMAT_DICT)
cpdef set_msghandler(self, object handler):
    """
    set_msghandler(handler) -- set the msghandler for the connection
    This function allows setting a msghandler for the connection to
    allow a client to gain access to the messages returned from the
    server.
    """
    self.msghandler = handler
cpdef cancel(self):
    """
    cancel() -- cancel all pending results.
    This function cancels all pending results from the last SQL operation.
    It can be called more than once in a row. No exception is raised in
    this case.
    """
    log("_mssql.MSSQLConnection.cancel()")
    cdef RETCODE rtc
    assert_connected(self)
    clr_err(self)
    rtc = db_cancel(self)
    check_and_raise(rtc, self)
cdef void clear_metadata(self):
    # Forget the column description and result state of the current results.
    log("_mssql.MSSQLConnection.clear_metadata()")
    self.column_names = None
    self.column_types = None
    self.num_columns = 0
    self.last_dbresults = 0
def close(self):
    """
    close() -- close connection to an MS SQL Server.
    This function closes the DB-Library connection and unregisters it.
    The message/charset buffers stay valid until the object is destroyed.
    It can be called more than once in a row. No exception is raised in
    this case.
    """
    log("_mssql.MSSQLConnection.close()")
    if not self._connected:
        return None
    clr_err(self)
    with nogil:
        dbclose(self.dbproc)
    self.mark_disconnected()
def mark_disconnected(self):
    # Forget the closed (or dead) DBPROCESS and unregister this connection
    # from connection_object_list. Safe to call more than once. The C buffers
    # are deliberately left allocated; __dealloc__ frees them.
    log("_mssql.MSSQLConnection.mark_disconnected()")
    self.dbproc = NULL
    self._connected = 0
    if self in connection_object_list:
        connection_object_list.remove(self)
cdef object convert_db_value(self, BYTE *data, int dbtype, int length):
    # Convert a raw value returned by DB-Library into a Python object.
    #
    # data   -- pointer to the value's bytes.
    # dbtype -- SQL* type code of the value.
    # length -- number of bytes at `data` (used for string/binary/UUID types).
    #
    # BIT -> bool; INT1/2/4 -> int; INT8 -> long; FLT4/FLT8 -> float;
    # MONEY/MONEY4/NUMERIC/DECIMAL -> decimal.Decimal; DATETIME/DATETIM4 ->
    # datetime.datetime; VARCHAR/CHAR/TEXT -> text decoded with the connection
    # charset (raw bytes when no charset is set); UUID -> uuid.UUID; any other
    # type -> the raw bytes.
    # Raises MSSQLDriverException when dbconvert() cannot render a numeric
    # value as text.
    log("_mssql.MSSQLConnection.convert_db_value()")
    cdef char buf[NUMERIC_BUF_SZ] # buffer in which we store text rep of big nums
    cdef int converted_length
    cdef DBDATEREC di
    cdef DBDATETIME dt
    IF PYMSSQL_DEBUG == 1:
        sys.stderr.write("convert_db_value: dbtype = %d; length = %d\n" % (dbtype, length))
    if dbtype == SQLBIT:
        return bool(<int>(<DBBIT *>data)[0])
    elif dbtype == SQLINT1:
        return int(<int>(<DBTINYINT *>data)[0])
    elif dbtype == SQLINT2:
        return int(<int>(<DBSMALLINT *>data)[0])
    elif dbtype == SQLINT4:
        return int(<int>(<DBINT *>data)[0])
    elif dbtype == SQLINT8:
        return long(<PY_LONG_LONG>(<PY_LONG_LONG *>data)[0])
    elif dbtype == SQLFLT4:
        return float(<float>(<DBREAL *>data)[0])
    elif dbtype == SQLFLT8:
        return float(<double>(<DBFLT8 *>data)[0])
    elif dbtype in (SQLMONEY, SQLMONEY4, SQLNUMERIC, SQLDECIMAL):
        converted_length = dbconvert(self.dbproc, dbtype, data, -1, SQLCHAR,
            <BYTE *>buf, NUMERIC_BUF_SZ)
        if converted_length < 0:
            # dbconvert() signals failure with -1; buf holds no valid text.
            raise MSSQLDriverException('Unable to convert value')
        # Decimal(text) is exact and does not depend on the context
        # precision, so no decimal.localcontext() is needed here.
        return decimal.Decimal(_remove_locale(buf, converted_length).decode(self._charset))
    elif dbtype == SQLDATETIM4:
        dbconvert(self.dbproc, dbtype, data, -1, SQLDATETIME,
            <BYTE *>&dt, -1)
        dbdatecrack(self.dbproc, &di, <DBDATETIME *><BYTE *>&dt)
        return datetime.datetime(di.year, di.month, di.day,
            di.hour, di.minute, di.second, di.millisecond * 1000)
    elif dbtype == SQLDATETIME:
        dbdatecrack(self.dbproc, &di, <DBDATETIME *>data)
        return datetime.datetime(di.year, di.month, di.day,
            di.hour, di.minute, di.second, di.millisecond * 1000)
    elif dbtype in (SQLVARCHAR, SQLCHAR, SQLTEXT):
        if strlen(self._charset):
            return (<char *>data)[:length].decode(self._charset)
        else:
            return (<char *>data)[:length]
    elif dbtype == SQLUUID:
        return uuid.UUID(bytes_le=(<char *>data)[:length])
    else:
        return (<char *>data)[:length]
cdef int convert_python_value(self, object value, BYTE **dbValue,
        int *dbtype, int *length) except -1:
    # Convert a Python parameter value into a newly allocated C buffer in the
    # representation DB-Library expects for dbtype[0].
    #
    # On return dbValue[0] points to a PyMem_Malloc'ed buffer (NULL for a None
    # value). dbtype[0] may be rewritten (SQLINTN -> SQLINT4; dates, datetimes
    # and MONEY values -> SQLCHAR). length[0] is set for NUMERIC/DECIMAL,
    # binary and UUID values. NOTE(review): freeing the buffer is presumably
    # the caller's job -- confirm.
    #
    # Raises TypeError for a value of the wrong Python type, ValueError for a
    # non-finite decimal, MemoryError when allocation fails, and
    # MSSQLDriverException for out-of-range integers, failed decimal
    # conversions and unsupported types.
    log("_mssql.MSSQLConnection.convert_python_value()")
    cdef int *intValue
    cdef double *dblValue
    cdef float *fltValue
    cdef PY_LONG_LONG *longValue
    cdef char *strValue
    cdef char *tmp
    cdef BYTE *binValue
    cdef DBDECIMAL *decValue
    cdef DBTYPEINFO decimal_type_info
    cdef Py_ssize_t value_len
    IF PYMSSQL_DEBUG == 1:
        sys.stderr.write("convert_python_value: value = %r; dbtype = %d" % (value, dbtype[0]))
    if value is None:
        dbValue[0] = <BYTE *>NULL
        return 0
    if dbtype[0] in (SQLBIT, SQLBITN):
        intValue = <int *>PyMem_Malloc(sizeof(int))
        if intValue == NULL:
            raise MemoryError()
        intValue[0] = <int>value
        dbValue[0] = <BYTE *><DBBIT *>intValue
        return 0
    if dbtype[0] == SQLINTN:
        dbtype[0] = SQLINT4
    if dbtype[0] in (SQLINT1, SQLINT2, SQLINT4):
        if value > MAX_INT:
            raise MSSQLDriverException('value cannot be larger than %d' % MAX_INT)
        elif value < MIN_INT:
            raise MSSQLDriverException('value cannot be smaller than %d' % MIN_INT)
        intValue = <int *>PyMem_Malloc(sizeof(int))
        if intValue == NULL:
            raise MemoryError()
        intValue[0] = <int>value
        if dbtype[0] == SQLINT1:
            dbValue[0] = <BYTE *><DBTINYINT *>intValue
            return 0
        if dbtype[0] == SQLINT2:
            dbValue[0] = <BYTE *><DBSMALLINT *>intValue
            return 0
        if dbtype[0] == SQLINT4:
            dbValue[0] = <BYTE *><DBINT *>intValue
            return 0
    if dbtype[0] == SQLINT8:
        longValue = <PY_LONG_LONG *>PyMem_Malloc(sizeof(PY_LONG_LONG))
        if longValue == NULL:
            raise MemoryError()
        longValue[0] = <PY_LONG_LONG>value
        dbValue[0] = <BYTE *>longValue
        return 0
    if dbtype[0] in (SQLFLT4, SQLREAL):
        fltValue = <float *>PyMem_Malloc(sizeof(float))
        if fltValue == NULL:
            raise MemoryError()
        fltValue[0] = <float>value
        dbValue[0] = <BYTE *><DBREAL *>fltValue
        return 0
    if dbtype[0] == SQLFLT8:
        dblValue = <double *>PyMem_Malloc(sizeof(double))
        if dblValue == NULL:
            raise MemoryError()
        dblValue[0] = <double>value
        dbValue[0] = <BYTE *><DBFLT8 *>dblValue
        return 0
    if dbtype[0] in (SQLDATETIM4, SQLDATETIME):
        if type(value) not in (datetime.date, datetime.datetime):
            raise TypeError('value can only be a date or datetime')
        # datetime.date has no `microsecond` attribute; plain dates use .000.
        value = value.strftime('%Y-%m-%d %H:%M:%S.') + \
            "%03d" % (getattr(value, 'microsecond', 0) // 1000)
        # The text is pure ASCII; fall back to it when no charset is set.
        value = value.encode(self.charset or 'ascii')
        dbtype[0] = SQLCHAR
    if dbtype[0] in (SQLNUMERIC, SQLDECIMAL):
        # Accept the same Python types as the MONEY branch below rather than
        # failing with AttributeError on .as_tuple().
        if type(value) in (int, long, bytes):
            value = decimal.Decimal(value)
        elif type(value) is float:
            # repr() gives the shortest round-tripping text; Decimal(float)
            # would expand the binary fraction into dozens of digits.
            value = decimal.Decimal(repr(value))
        if not isinstance(value, decimal.Decimal):
            raise TypeError('value can only be a Decimal')
        if not value.is_finite():
            raise ValueError('value must be a finite Decimal')
        # There seems to be no harm in setting precision higher than
        # necessary
        decimal_type_info.precision = 33
        # Figure out `scale` - number of digits after decimal point. A
        # positive exponent (e.g. Decimal('1E+2')) means no fractional digits.
        exponent = value.as_tuple().exponent
        decimal_type_info.scale = -exponent if exponent < 0 else 0
        # Render in plain fixed-point notation (str() may use exponent form,
        # e.g. '1E-7') so the text matches `scale`. Keeping the bytes in a
        # named variable prevents the Cython error:
        # "Obtaining 'BYTE *' from temporary Python value"
        bytes_value = u'{0:f}'.format(value).encode("ascii")
        decValue = <DBDECIMAL *>PyMem_Malloc(sizeof(DBDECIMAL))
        if decValue == NULL:
            raise MemoryError()
        length[0] = dbconvert_ps(
            self.dbproc,
            SQLCHAR,
            bytes_value,
            -1,
            dbtype[0],
            <BYTE *>decValue,
            sizeof(DBDECIMAL),
            &decimal_type_info,
        )
        if length[0] < 0:
            # dbconvert_ps() signals failure with -1.
            PyMem_Free(decValue)
            raise MSSQLDriverException('Unable to convert value')
        dbValue[0] = <BYTE *>decValue
        IF PYMSSQL_DEBUG == 1:
            fprintf(stderr, "convert_python_value: Converted value to DBDECIMAL with length = %d\n", length[0])
            for i in range(0, 35):
                fprintf(stderr, "convert_python_value: dbValue[0][%d] = %d\n", i, dbValue[0][i])
        return 0
    if dbtype[0] in (SQLMONEY, SQLMONEY4):
        if type(value) in (int, long, bytes):
            value = decimal.Decimal(value)
        if type(value) not in (decimal.Decimal, float):
            raise TypeError('value can only be a Decimal')
        value = str(value)
        dbtype[0] = SQLCHAR
    if dbtype[0] in (SQLVARCHAR, SQLCHAR, SQLTEXT):
        if not hasattr(value, 'startswith'):
            raise TypeError('value must be a string type')
        if strlen(self._charset) > 0 and type(value) is unicode:
            value = value.encode(self.charset)
        # Obtain the C view before allocating so a failed coercion cannot
        # leak the buffer.
        tmp = value
        value_len = len(value)
        strValue = <char *>PyMem_Malloc(value_len + 1)
        if strValue == NULL:
            raise MemoryError()
        strncpy(strValue, tmp, value_len + 1)
        strValue[value_len] = '\0'
        dbValue[0] = <BYTE *>strValue
        return 0
    if dbtype[0] in (SQLBINARY, SQLVARBINARY, SQLIMAGE):
        # `bytes` is `str` on Python 2 and the binary type on Python 3.
        if type(value) is not bytes:
            raise TypeError('value can only be bytes')
        value_len = len(value)
        binValue = <BYTE *>PyMem_Malloc(value_len)
        if binValue == NULL:
            raise MemoryError()
        memcpy(binValue, <char *>value, value_len)
        length[0] = value_len
        dbValue[0] = <BYTE *>binValue
        return 0
    if dbtype[0] == SQLUUID:
        binValue = <BYTE *>PyMem_Malloc(16)
        if binValue == NULL:
            raise MemoryError()
        memcpy(binValue, <char *>value.bytes_le, 16)
        length[0] = 16
        dbValue[0] = <BYTE *>binValue
        return 0
    # No conversion was possible so raise an error
    raise MSSQLDriverException('Unable to convert value')
cpdef execute_non_query(self, query_string, params=None):
    """
    execute_non_query(query_string, params=None)
    Send a query to the connected MS SQL Server and discard any results it
    produces. Pending results or rows from an earlier command are silently
    dropped first, and an exception is raised on failure. Python-style
    parameter formatting is supported; see execute_query() for details.
    Intended for INSERT, UPDATE, DELETE and DDL statements (changes to the
    database schema). Afterwards, the rows_affected property holds the
    number of rows affected by the last SQL command.
    """
    cdef RETCODE cancel_status
    log("_mssql.MSSQLConnection.execute_non_query() BEGIN")
    self.format_and_run_query(query_string, params)
    with nogil:
        dbresults(self.dbproc)
    self._rows_affected = dbcount(self.dbproc)
    cancel_status = db_cancel(self)
    check_and_raise(cancel_status, self)
    log("_mssql.MSSQLConnection.execute_non_query() END")
cpdef execute_query(self, query_string, params=None):
    """
    execute_query(query_string, params=None)
    Send a query to the connected MS SQL Server; iterate over the
    connection afterwards to read the rows it returns. Pending results or
    rows from an earlier command are silently dropped first, and an
    exception is raised on failure.
    Python formatting may be used and all values are properly quoted:
    conn.execute_query('SELECT * FROM empl WHERE id=%d', 13)
    conn.execute_query('SELECT * FROM empl WHERE id IN (%s)', ((5,6),))
    conn.execute_query('SELECT * FROM empl WHERE name=%s', 'John Doe')
    conn.execute_query('SELECT * FROM empl WHERE name LIKE %s', 'J%')
    conn.execute_query('SELECT * FROM empl WHERE name=%(name)s AND \
    city=%(city)s', { 'name': 'John Doe', 'city': 'Nowhere' } )
    conn.execute_query('SELECT * FROM cust WHERE salesrep=%s \
    AND id IN (%s)', ('John Doe', (1,2,3)))
    conn.execute_query('SELECT * FROM empl WHERE id IN (%s)',\
    (tuple(xrange(4)),))
    conn.execute_query('SELECT * FROM empl WHERE id IN (%s)',\
    (tuple([3,5,7,11]),))
    Intended for queries that return results, i.e. SELECT. Once every row
    has been read, rows_affected holds the number of rows returned by the
    last command (this is how MS SQL reports it).
    """
    log("_mssql.MSSQLConnection.execute_query() BEGIN")
    self.format_and_run_query(query_string, params)
    self.get_result()
    log("_mssql.MSSQLConnection.execute_query() END")
cpdef execute_row(self, query_string, params=None):
    """
    execute_row(query_string, params=None)
    Send a query to the connected MS SQL Server and return the first row
    of its result. Pending results or rows from an earlier command are
    silently dropped first, and an exception is raised on failure.
    Python formatting is supported; see execute_query().
    Handy when a single row is all you need:
    conn.execute_row('SELECT * FROM employees WHERE id=%d', 13)
    Equivalent to 'iter(conn).next()'; any remaining rows can still be
    iterated afterwards.
    """
    log("_mssql.MSSQLConnection.execute_row()")
    self.format_and_run_query(query_string, params)
    return self.fetch_next_row(0, _ROW_FORMAT_DICT)
cpdef execute_scalar(self, query_string, params=None):
    """
    execute_scalar(query_string, params=None)
    This method sends a query to the MS SQL Server to which this object
    instance is connected, then returns first column of first row from
    result (None when the result has no rows). An exception is raised on
    failure. If there are pending results or rows prior to executing this
    command, they are silently discarded.
    This method accepts Python formatting. Please see execute_query()
    for details.
    This method is useful if you want just a single value, as in:
    conn.execute_scalar('SELECT COUNT(*) FROM employees')
    This method works in the same way as 'iter(conn).next()[0]'.
    Remaining rows, if any, can still be iterated after calling this
    method.
    """
    cdef RETCODE rtc
    log("_mssql.MSSQLConnection.execute_scalar()")
    self.format_and_run_query(query_string, params)
    self.get_result()
    with nogil:
        rtc = dbnextrow(self.dbproc)
    # Handle a failing dbnextrow() the same way fetch_next_row() does
    # instead of handing the error code to get_row().
    check_cancel_and_raise(rtc, self)
    self._rows_affected = dbcount(self.dbproc)
    if rtc == NO_MORE_ROWS:
        # clear_metadata() also resets last_dbresults to 0.
        self.clear_metadata()
        return None
    return self.get_row(rtc, ROW_FORMAT_TUPLE)[0]
cdef fetch_next_row(self, int throw, int row_format):
    # Return the next row of the current result set in `row_format`.
    # When there are no more results, or the current result has no more
    # rows, the metadata is cleared and StopIteration is raised if `throw`
    # is true (None is returned otherwise). rows_affected is refreshed once
    # the rows of a result run out.
    cdef RETCODE row_status
    log("_mssql.MSSQLConnection.fetch_next_row() BEGIN")
    try:
        self.get_result()
        if self.last_dbresults != NO_MORE_RESULTS:
            with nogil:
                row_status = dbnextrow(self.dbproc)
            check_cancel_and_raise(row_status, self)
            if row_status != NO_MORE_ROWS:
                return self.get_row(row_status, row_format)
            log("_mssql.MSSQLConnection.fetch_next_row(): NO MORE ROWS")
            self.clear_metadata()
            # 'rows_affected' is nonzero only after all records are read
            self._rows_affected = dbcount(self.dbproc)
        else:
            log("_mssql.MSSQLConnection.fetch_next_row(): NO MORE RESULTS")
            self.clear_metadata()
        if throw:
            raise StopIteration
        return None
    finally:
        log("_mssql.MSSQLConnection.fetch_next_row() END")
cdef format_and_run_query(self, query_string, params=None):
    """
    Shared helper behind the execute_*() methods: cancel pending results,
    substitute `params` into `query_string` when given, encode the text with
    the connection charset, then send it and wait for the server to accept
    it. The resulting status goes to check_cancel_and_raise() (which
    presumably raises on failure); otherwise None is returned.
    """
    cdef RETCODE exec_status
    # For Python 3, we need to convert unicode to byte strings
    cdef bytes encoded_query
    cdef char *c_query
    log("_mssql.MSSQLConnection.format_and_run_query() BEGIN")
    try:
        # Cancel any pending results
        self.cancel()
        if params:
            query_string = self.format_sql_command(query_string, params)
        encoded_query = ensure_bytes(query_string, self.charset)
        c_query = encoded_query
        log(c_query)
        if self.debug_queries:
            sys.stderr.write("#%s#\n" % query_string)
        # Fill the command buffer, then execute it.
        dbcmd(self.dbproc, c_query)
        exec_status = db_sqlexec(self.dbproc)
        check_cancel_and_raise(exec_status, self)
    finally:
        log("_mssql.MSSQLConnection.format_and_run_query() END")
cdef format_sql_command(self, format, params=None):
    # Substitute `params` into the SQL text `format` for this connection's
    # charset (delegates to _substitute_params).
    log("_mssql.MSSQLConnection.format_sql_command()")
    return _substitute_params(format, params, self.charset)
def get_header(self):
    """
    get_header() -- get the Python DB-API compliant header information.

    This method is infrastructure and doesn't need to be called by your
    code. It returns a tuple with one 7-element tuple per column of the
    current result set:
    (name, DB-API type, None, None, None, None, None).
    Only the name and type are filled in, as permitted by the specs.
    Returns None when the current result has no columns.
    """
    log("_mssql.MSSQLConnection.get_header() BEGIN")
    try:
        self.get_result()

        if self.num_columns == 0:
            log("_mssql.MSSQLConnection.get_header(): num_columns == 0")
            return None

        return tuple([
            (self.column_names[idx], self.column_types[idx],
             None, None, None, None, None)
            for idx in range(self.num_columns)
        ])
    finally:
        log("_mssql.MSSQLConnection.get_header() END")
def get_iterator(self, int row_format):
    """
    get_iterator(row_format) -- allows the format of the iterator to be specified

    Returns an MSSQLRowIterator over this connection that yields rows in
    ``row_format`` (ROW_FORMAT_TUPLE or ROW_FORMAT_DICT). Raises
    MSSQLDriverException if the connection is not open; any recorded
    error state is cleared first.
    """
    assert_connected(self)
    clr_err(self)
    iterator = MSSQLRowIterator(self, row_format)
    return iterator
cdef get_result(self):
    """
    Advance to the next result set that reports columns and cache its
    column metadata.

    Returns immediately if ``last_dbresults`` is already non-zero.
    Otherwise clear_metadata() is called and dbresults() is polled until it
    returns something other than SUCCEED, or a result set reports at least
    one column. Errors are reported through check_cancel_and_raise(), and
    ``_rows_affected`` is refreshed from dbcount(). On NO_MORE_RESULTS
    ``num_columns`` is reset to 0. Otherwise the column names (decoded with
    ``_charset``) and DB-API type codes (from get_api_coltype()) are stored
    as tuples in ``column_names`` / ``column_types``. Always returns None.
    """
    cdef int coltype
    cdef char log_message[200]
    log("_mssql.MSSQLConnection.get_result() BEGIN")
    try:
        if self.last_dbresults:
            log("_mssql.MSSQLConnection.get_result(): last_dbresults == True, return None")
            return None
        self.clear_metadata()
        # Since python doesn't have a do/while loop do it this way
        # (result sets that report zero columns are skipped).
        while True:
            # The db-lib call runs with the GIL released.
            with nogil:
                self.last_dbresults = dbresults(self.dbproc)
            self.num_columns = dbnumcols(self.dbproc)
            if self.last_dbresults != SUCCEED or self.num_columns > 0:
                break
        check_cancel_and_raise(self.last_dbresults, self)
        self._rows_affected = dbcount(self.dbproc)
        if self.last_dbresults == NO_MORE_RESULTS:
            self.num_columns = 0
            log("_mssql.MSSQLConnection.get_result(): NO_MORE_RESULTS, return None")
            return None
        # Re-reads the count already stored by the loop above.
        self.num_columns = dbnumcols(self.dbproc)
        snprintf(log_message, sizeof(log_message), "_mssql.MSSQLConnection.get_result(): num_columns = %d", self.num_columns)
        # Force termination in case the message was truncated.
        log_message[ sizeof(log_message) - 1 ] = '\0'
        log(log_message)
        column_names = list()
        column_types = list()
        for col in xrange(1, self.num_columns + 1):
            col_name = dbcolname(self.dbproc, col)
            # NOTE(review): an empty column name decrements num_columns and
            # returns before column_names/column_types are assigned, so they
            # keep whatever clear_metadata() left — confirm this is intended.
            if not col_name:
                self.num_columns -= 1
                return None
            column_name = col_name.decode(self._charset)
            column_names.append(column_name)
            coltype = dbcoltype(self.dbproc, col)
            column_types.append(get_api_coltype(coltype))
        self.column_names = tuple(column_names)
        self.column_types = tuple(column_types)
    finally:
        log("_mssql.MSSQLConnection.get_result() END")
cdef get_row(self, int row_info, int row_format):
    """
    Convert the row db-lib currently has buffered into a Python row.

    ``row_info`` is the status returned by dbnextrow(). REG_ROW selects the
    regular-row accessors (dbdata/dbcoltype/dbdatlen); any other value is
    passed to the compute-row accessors (dbadata/dbalttype/dbadlen).

    For ROW_FORMAT_TUPLE a tuple with one value per column is returned.
    For ROW_FORMAT_DICT a dict is returned in which each value is stored
    under its 0-based column index and, when the column has a name, also
    under that name. Columns without a data pointer (SQL NULL) become None.
    Any other ``row_format`` returns None.
    """
    cdef DBPROCESS *dbproc = self.dbproc
    cdef int col
    cdef int col_type
    # Not named 'len' so the builtin is not shadowed in this method.
    cdef int data_len
    cdef BYTE *data
    cdef tuple trecord
    cdef dict drecord
    log("_mssql.MSSQLConnection.get_row()")

    # Compile-time switch, like the other PYMSSQL_DEBUG blocks in this
    # module: the row counter only exists in debug builds.
    IF PYMSSQL_DEBUG == 1:
        global _row_count
        _row_count += 1

    if row_format == _ROW_FORMAT_TUPLE:
        trecord = PyTuple_New(self.num_columns)
    elif row_format == _ROW_FORMAT_DICT:
        drecord = dict()

    for col in xrange(1, self.num_columns + 1):
        with nogil:
            data = get_data(dbproc, row_info, col)
            col_type = get_type(dbproc, row_info, col)
            data_len = get_length(dbproc, row_info, col)

        if data == NULL:
            value = None
        else:
            IF PYMSSQL_DEBUG == 1:
                # %p is the portable printf conversion for a pointer.
                fprintf(stderr, 'Processing row %d, column %d, '
                    'Got data=%p, coltype=%d, len=%d\n', _row_count, col,
                    <void *>data, col_type, data_len)
            value = self.convert_db_value(data, col_type, data_len)

        if row_format == _ROW_FORMAT_TUPLE:
            # PyTuple_SetItem steals a reference, so give the tuple its own.
            Py_INCREF(value)
            PyTuple_SetItem(trecord, col - 1, value)
        elif row_format == _ROW_FORMAT_DICT:
            name = self.column_names[col - 1]
            drecord[col - 1] = value
            if name:
                drecord[name] = value

    if row_format == _ROW_FORMAT_TUPLE:
        return trecord
    elif row_format == _ROW_FORMAT_DICT:
        return drecord
def init_procedure(self, procname):
    """
    init_procedure(procname) -- creates and returns a MSSQLStoredProcedure
    object.

    ``procname`` is encoded with the connection charset and passed to
    MSSQLStoredProcedure, whose constructor starts the RPC with
    dbrpcinit(); parameters can then be bound on the returned object.
    """
    log("_mssql.MSSQLConnection.init_procedure()")
    encoded_name = procname.encode(self.charset)
    return MSSQLStoredProcedure(encoded_name, self)
def nextresult(self):
    """
    nextresult() -- move to the next result, skipping all pending rows.

    Any rows left in the current result set are read and discarded, then
    the connection advances to the next result set, if there is one.
    Returns 1 when another result set is available, otherwise None.
    """
    cdef RETCODE row_status
    log("_mssql.MSSQLConnection.nextresult()")
    assert_connected(self)
    clr_err(self)

    # Drain the remainder of the current result set.
    while True:
        row_status = dbnextrow(self.dbproc)
        check_cancel_and_raise(row_status, self)
        if row_status == NO_MORE_ROWS:
            break

    self.last_dbresults = 0
    self.get_result()

    if self.last_dbresults == NO_MORE_RESULTS:
        return None
    return 1
def select_db(self, dbname):
    """
    select_db(dbname) -- Select the current database.

    This function selects the given database. An exception is raised on
    failure.

    ``dbname`` is encoded as ASCII and passed to dbuse(). Raises
    MSSQLDriverException if the connection is not open. A failing dbuse()
    is reported through check_and_raise(), which raises
    MSSQLDatabaseException (subject to min_error_severity).
    """
    cdef RETCODE rtc
    log("_mssql.MSSQLConnection.select_db()")
    assert_connected(self)
    # Drop any stale message so check_and_raise() only sees dbuse() output.
    clr_err(self)

    # For Python 3, we need to convert unicode to byte strings
    cdef bytes dbname_bytes = dbname.encode('ascii')
    cdef char *dbname_cstr = dbname_bytes

    # dbuse() reports failure through its return code; it used to be
    # discarded, so selecting a bad database went unnoticed.
    rtc = dbuse(self.dbproc, dbname_cstr)
    check_and_raise(rtc, self)
##################################
## MSSQL Stored Procedure Class ##
##################################
cdef class MSSQLStoredProcedure:
    """
    A stored-procedure call assembled as a db-lib remote procedure call.

    Normally created by MSSQLConnection.init_procedure(). Parameters are
    attached with bind(), the call is sent with execute(), and output
    parameter values are written back into ``parameters``.
    """

    property connection:
        """The underlying MSSQLConnection object."""
        def __get__(self):
            return self.conn

    property name:
        """The name of the procedure that this object represents."""
        def __get__(self):
            return self.procname

    property parameters:
        """The parameters that have been bound to this procedure."""
        def __get__(self):
            return self.params

    def __init__(self, bytes name, MSSQLConnection connection):
        """
        Begin an RPC for procedure ``name`` on ``connection``.

        Raises MSSQLDriverException when ``connection.tds_version`` is below
        7. Errors from dbrpcinit() are reported through
        check_cancel_and_raise().
        """
        cdef RETCODE rtc
        log("_mssql.MSSQLStoredProcedure.__init__()")

        # We firstly want to check if tdsver is >= 7 as anything less
        # doesn't support remote procedure calls.
        if connection.tds_version < 7:
            raise MSSQLDriverException("Stored Procedures aren't "
                "supported with a TDS version less than 7.")

        self.conn = connection
        self.dbproc = connection.dbproc
        self.procname = name
        self.params = dict()
        self.output_indexes = list()
        self.param_count = 0
        self.had_positional = False

        with nogil:
            rtc = dbrpcinit(self.dbproc, self.procname, 0)
        check_cancel_and_raise(rtc, self.conn)

    def __dealloc__(self):
        # Walk the list built by bind(), freeing each node's converted value
        # buffer and then the node itself.
        cdef _mssql_parameter_node *n
        cdef _mssql_parameter_node *p
        log("_mssql.MSSQLStoredProcedure.__dealloc__()")

        n = self.params_list
        p = NULL
        while n != NULL:
            PyMem_Free(n.value)
            p = n
            n = n.next
            PyMem_Free(p)

    def bind(self, object value, int dbtype, str param_name=None,
            int output=False, int null=False, int max_length=-1):
        """
        bind(value, data_type, param_name = None, output = False,
            null = False, max_length = -1) -- bind a parameter

        This method binds a parameter to the stored procedure.

        ``value`` is converted by the connection's convert_python_value(),
        which may also adjust ``dbtype`` and the data length. ``output``
        binds it as an output parameter (DBRPCRETURN). ``null`` generally
        forces a data length of 0. ``max_length`` is only passed through
        for output parameters; otherwise it is reset to -1. SQLINTN/SQLBITN
        values that convert to no data are not bound at all. Binding a
        named parameter after a positional one raises MSSQLDriverException.
        """
        cdef int length = -1
        cdef RETCODE rtc
        cdef BYTE status
        cdef BYTE *data
        cdef bytes param_name_bytes
        cdef char *param_name_cstr
        cdef _mssql_parameter_node *pn
        log("_mssql.MSSQLStoredProcedure.bind()")

        # Set status according to output being True or False
        status = DBRPCRETURN if output else <BYTE>0

        # Convert the PyObject to the db type
        self.conn.convert_python_value(value, &data, &dbtype, &length)

        # We support nullable parameters by just not binding them
        if dbtype in (SQLINTN, SQLBITN) and data == NULL:
            return

        # Store the converted parameter in our parameter list so we can
        # free() it later.
        if data != NULL:
            pn = <_mssql_parameter_node *>PyMem_Malloc(sizeof(_mssql_parameter_node))
            if pn == NULL:
                raise MSSQLDriverException('Out of memory')
            pn.next = self.params_list
            pn.value = data
            self.params_list = pn

        # We may need to set the data length depending on the type being
        # passed to the server here.
        if dbtype in (SQLVARCHAR, SQLCHAR, SQLTEXT, SQLBINARY,
                SQLVARBINARY, SQLIMAGE):
            if null or data == NULL:
                length = 0
                if not output:
                    max_length = -1
            # only set the length for strings, binary may contain NULLs
            elif dbtype in (SQLVARCHAR, SQLCHAR, SQLTEXT):
                length = strlen(<char *>data)
        else:
            # Fixed length data type
            if null or (output and dbtype not in (SQLDECIMAL, SQLNUMERIC)):
                length = 0
            max_length = -1

        # Add some monkey fixing for nullable bit types
        if dbtype == SQLBITN:
            if output:
                max_length = 1
                length = 0
            else:
                length = 1

        if status != DBRPCRETURN:
            max_length = -1

        if param_name:
            param_name_bytes = param_name.encode('ascii')
            param_name_cstr = param_name_bytes
            if self.had_positional:
                raise MSSQLDriverException('Cannot bind named parameter after positional')
        else:
            param_name_cstr = ''
            self.had_positional = True

        IF PYMSSQL_DEBUG == 1:
            sys.stderr.write(
                "\n--- rpc_bind(name = '%s', status = %d, "
                "max_length = %d, data_type = %d, data_length = %d\n"
                % (param_name, status, max_length, dbtype, length)
            )

        with nogil:
            rtc = dbrpcparam(self.dbproc, param_name_cstr, status, dbtype,
                max_length, length, data)
        check_cancel_and_raise(rtc, self.conn)

        # Store the value in the parameters dictionary for returning
        # later, by name if that has been supplied.
        if param_name:
            self.params[param_name] = value
        self.params[self.param_count] = value
        if output:
            self.output_indexes.append(self.param_count)
        self.param_count += 1

    def execute(self):
        """
        Send the RPC to the server and return the procedure's return status.

        Pending results on the connection are cancelled first. Each output
        parameter value reported by the server is stored in ``parameters``
        under the index it was bound with and, when the server reports a
        non-empty name, under that name as well. Output values with no data
        pointer (SQL NULL) are stored as None.
        """
        cdef RETCODE rtc
        cdef int output_count, i, ret_type, length
        cdef char *ret_name
        cdef BYTE *data
        log("_mssql.MSSQLStoredProcedure.execute()")

        # Cancel any pending results as this throws a server error
        # otherwise.
        db_cancel(self.conn)

        # Send the RPC request
        with nogil:
            rtc = dbrpcsend(self.dbproc)
        check_cancel_and_raise(rtc, self.conn)

        # Wait for results to come back and return the return code, optionally
        # calling wait_callback first...
        rtc = db_sqlok(self.dbproc)
        check_cancel_and_raise(rtc, self.conn)

        # Need to call this regardless of whether or not there are output
        # parameters in order for the return status to be correct.
        output_count = dbnumrets(self.dbproc)

        # If there are any output parameters then we are going to want to
        # set the values in the parameters dictionary.
        if output_count:
            for i in xrange(1, output_count + 1):
                with nogil:
                    ret_type = dbrettype(self.dbproc, i)
                    ret_name = dbretname(self.dbproc, i)
                    length = dbretlen(self.dbproc, i)
                    data = dbretdata(self.dbproc, i)

                # Map a missing data pointer to None, as get_row() does for
                # NULL columns, instead of handing NULL to convert_db_value().
                if data == NULL:
                    value = None
                else:
                    value = self.conn.convert_db_value(data, ret_type, length)

                # Guard the name pointer before strlen() dereferences it.
                if ret_name != NULL and strlen(ret_name):
                    param_name = ret_name.decode('utf-8')
                    self.params[param_name] = value
                self.params[self.output_indexes[i-1]] = value

        # Get the return value from the procedure ready for return.
        return dbretstatus(self.dbproc)
cdef int check_and_raise(RETCODE rtc, MSSQLConnection conn) except 1:
    """
    Turn a db-lib return code and any recorded message into an exception.

    Delegates to maybe_raise_MSSQLDatabaseException() when ``rtc`` is FAIL
    or when a message is recorded for ``conn``. That function only raises
    when the message severity reaches min_error_severity. Returns 0 when
    nothing is raised; ``except 1`` lets Cython propagate the exception.

    NOTE(review): get_last_msg_str() returns a char*, so the ``elif``
    appears to test the pointer for NULL rather than testing for an empty
    message. If the buffer is never NULL, this branch is always taken and
    the severity check decides.
    """
    if rtc == FAIL:
        return maybe_raise_MSSQLDatabaseException(conn)
    elif get_last_msg_str(conn):
        return maybe_raise_MSSQLDatabaseException(conn)
cdef int check_cancel_and_raise(RETCODE rtc, MSSQLConnection conn) except 1:
    """
    Same as check_and_raise(), except that a FAIL return code first cancels
    pending results with db_cancel() before the exception check.

    Returns 0 when nothing is raised. NOTE(review): when an exception is
    raised, maybe_raise_MSSQLDatabaseException() calls db_cancel() again.
    The char* truthiness caveat of the ``elif`` also applies here.
    """
    if rtc == FAIL:
        db_cancel(conn)
        return maybe_raise_MSSQLDatabaseException(conn)
    elif get_last_msg_str(conn):
        return maybe_raise_MSSQLDatabaseException(conn)
# Accessors for the details of the most recent message recorded for a
# connection (e.g. by err_handler()). With ``conn`` set they read that
# connection's fields; with None they fall back to the module-level
# _mssql_last_msg_* values. ``is not None`` is a plain identity test and
# does not go through MSSQLConnection's rich comparison.
cdef char *get_last_msg_str(MSSQLConnection conn):
    return conn.last_msg_str if conn is not None else _mssql_last_msg_str

cdef char *get_last_msg_srv(MSSQLConnection conn):
    return conn.last_msg_srv if conn is not None else _mssql_last_msg_srv

cdef char *get_last_msg_proc(MSSQLConnection conn):
    return conn.last_msg_proc if conn is not None else _mssql_last_msg_proc

cdef int get_last_msg_no(MSSQLConnection conn):
    return conn.last_msg_no if conn is not None else _mssql_last_msg_no

cdef int get_last_msg_severity(MSSQLConnection conn):
    return conn.last_msg_severity if conn is not None else _mssql_last_msg_severity

cdef int get_last_msg_state(MSSQLConnection conn):
    return conn.last_msg_state if conn is not None else _mssql_last_msg_state

cdef int get_last_msg_line(MSSQLConnection conn):
    return conn.last_msg_line if conn is not None else _mssql_last_msg_line
cdef int maybe_raise_MSSQLDatabaseException(MSSQLConnection conn) except 1:
    """
    Raise MSSQLDatabaseException built from the last message recorded for
    ``conn`` (or the module-level message when ``conn`` is None).

    Returns 0 without raising when the message severity is below
    min_error_severity. An empty message text is replaced with
    b"Unknown error". Before raising, pending results are cancelled with
    db_cancel() and the recorded message is cleared with clr_err().
    """
    cdef MSSQLDatabaseException exc

    if get_last_msg_severity(conn) < min_error_severity:
        return 0

    message = get_last_msg_str(conn)
    if not message:
        message = b"Unknown error"

    exc = MSSQLDatabaseException((get_last_msg_no(conn), message))
    exc.text = message
    exc.srvname = get_last_msg_srv(conn)
    exc.procname = get_last_msg_proc(conn)
    exc.number = get_last_msg_no(conn)
    exc.severity = get_last_msg_severity(conn)
    exc.state = get_last_msg_state(conn)
    exc.line = get_last_msg_line(conn)

    db_cancel(conn)
    clr_err(conn)
    raise exc
cdef void assert_connected(MSSQLConnection conn) except *:
    """Raise MSSQLDriverException unless ``conn.connected`` is true."""
    log("_mssql.assert_connected()")
    if conn.connected:
        return
    raise MSSQLDriverException("Not connected to any MS SQL server")
# Column accessors shared by get_row(). A REG_ROW status uses the regular
# row functions; any other row_info is passed to the compute-row (dba*)
# functions together with the column number.
cdef inline BYTE *get_data(DBPROCESS *dbproc, int row_info, int col) nogil:
    if row_info != REG_ROW:
        return dbadata(dbproc, row_info, col)
    return dbdata(dbproc, col)

cdef inline int get_type(DBPROCESS *dbproc, int row_info, int col) nogil:
    if row_info != REG_ROW:
        return dbalttype(dbproc, row_info, col)
    return dbcoltype(dbproc, col)

cdef inline int get_length(DBPROCESS *dbproc, int row_info, int col) nogil:
    if row_info != REG_ROW:
        return dbadlen(dbproc, row_info, col)
    return dbdatlen(dbproc, col)
######################
## Helper Functions ##
######################
cdef int get_api_coltype(int coltype):
    """
    Map a db-lib column type code to one of the DB-API type codes NUMBER,
    DECIMAL, DATETIME or STRING; any other type maps to BINARY.
    """
    if coltype in (SQLBIT, SQLINT1, SQLINT2, SQLINT4, SQLINT8, SQLINTN,
                   SQLFLT4, SQLFLT8, SQLFLTN):
        return NUMBER
    if coltype in (SQLMONEY, SQLMONEY4, SQLMONEYN, SQLNUMERIC, SQLDECIMAL):
        return DECIMAL
    if coltype in (SQLDATETIME, SQLDATETIM4, SQLDATETIMN):
        return DATETIME
    if coltype in (SQLVARCHAR, SQLCHAR, SQLTEXT):
        return STRING
    return BINARY
cdef char *_remove_locale(char *s, size_t buflen):
    """
    Strip locale formatting from numeric text, rewriting ``s`` in place.

    Within the first ``buflen`` bytes it keeps ASCII digits, '+' and '-',
    plus only the LAST ',' or '.' (the one treated as the decimal
    separator). Every other character is dropped, including earlier
    separators such as thousands separators. The kept separator is copied
    unchanged, so it may still be ','. The result is NUL-terminated and
    the returned pointer is ``s`` itself.

    NOTE(review): writes the terminator at index <= buflen, so ``s`` is
    assumed to be a writable buffer of at least buflen + 1 bytes — confirm
    for every caller.
    """
    cdef char c
    cdef char *stripped = s
    cdef int i, x = 0, last_sep = -1
    # First pass: remember the position of the last separator.
    for i, c in enumerate(s[0:buflen]):
        if c in (',', '.'):
            last_sep = i
    # Second pass: compact the wanted characters to the front of the buffer.
    for i, c in enumerate(s[0:buflen]):
        if (c >= '0' and c <= '9') or c in ('+', '-'):
            stripped[x] = c
            x += 1
        elif i == last_sep:
            stripped[x] = c
            x += 1
    stripped[x] = 0
    return stripped
def remove_locale(bytes value):
    """
    Return a copy of ``value`` with locale formatting stripped
    (see _remove_locale()).

    _remove_locale() rewrites its buffer in place, so it runs on a private
    NUL-terminated copy. Python bytes objects are immutable and may be
    shared, and writing into ``value`` directly would corrupt the caller's
    object.
    """
    cdef char *src = <char *>value
    cdef size_t l = strlen(src)
    cdef char *buf = <char *>PyMem_Malloc(l + 1)
    if buf == NULL:
        raise MemoryError()
    try:
        memcpy(buf, src, l + 1)
        # Assigning to an untyped local converts the C string into a new
        # bytes object while buf is still allocated.
        result = _remove_locale(buf, l)
    finally:
        PyMem_Free(buf)
    return result
cdef int _tds_ver_str_to_constant(verstr) except -1:
    """
    Translate a TDS protocol version string ('4.2', '7.0', '7.1', '7.2' or
    '8.0') into the matching DBVERSION_* constant. Any other value raises
    MSSQLException.

    http://www.freetds.org/userguide/choosingtdsprotocol.htm
    """
    for label, constant in ((u'4.2', DBVERSION_42),
                            (u'7.0', DBVERSION_70),
                            (u'7.1', DBVERSION_71),
                            (u'7.2', DBVERSION_72),
                            (u'8.0', DBVERSION_80)):
        if verstr == label:
            return constant
    raise MSSQLException('unrecognized tds version: %s' % verstr)
#######################
## Quoting Functions ##
#######################
cdef _quote_simple_value(value, charset='utf8'):
    """
    Render ``value`` as a SQL literal, or return None if its type is not a
    supported simple type.

    - None -> NULL; bool -> 1 / 0
    - float -> repr(); int, long and Decimal -> str()
    - uuid.UUID -> quoted string of its canonical text
    - unicode -> N'...' with embedded quotes doubled, encoded with charset
    - bytearray -> 0x hex literal
    - str/bytes -> '...' if ASCII without NUL bytes, otherwise 0x hex
    - datetime.datetime / datetime.date -> {ts '...'} / {d '...'} escapes

    Every supported branch returns bytes so the results can be combined
    with b','.join() in _quote_or_flatten(). Previously the bool and
    date/datetime branches returned str, which fails there on Python 3.
    """
    if value is None:
        return b'NULL'

    if isinstance(value, bool):
        return b'1' if value else b'0'

    if isinstance(value, float):
        return repr(value).encode(charset)

    if isinstance(value, (int, long, decimal.Decimal)):
        return str(value).encode(charset)

    if isinstance(value, uuid.UUID):
        return _quote_simple_value(str(value), charset)

    if isinstance(value, unicode):
        return ("N'" + value.replace("'", "''") + "'").encode(charset)

    if isinstance(value, bytearray):
        return b'0x' + binascii.hexlify(bytes(value))

    if isinstance(value, (str, bytes)):
        # see if it can be decoded as ascii if there are no null bytes
        if b'\0' not in value:
            try:
                value.decode('ascii')
                return b"'" + value.replace(b"'", b"''") + b"'"
            except UnicodeDecodeError:
                pass

        # Python 3: handle bytes
        # @todo - Marc - hack hack hack
        if isinstance(value, bytes):
            return b'0x' + binascii.hexlify(value)

        # will still be string type if there was a null byte in it or if the
        # decoding failed. In this case, just send it as hex.
        if isinstance(value, str):
            return '0x' + value.encode('hex')

    # datetime must be tested before date: datetime is a date subclass.
    if isinstance(value, datetime.datetime):
        return ("{ts '%04d-%02d-%02d %02d:%02d:%02d.%03d'}" % (
            value.year, value.month, value.day,
            value.hour, value.minute, value.second,
            value.microsecond // 1000)).encode(charset)

    if isinstance(value, datetime.date):
        return ("{d '%04d-%02d-%02d'} " % (
            value.year, value.month, value.day)).encode(charset)

    return None
cdef _quote_or_flatten(data, charset='utf8'):
    """
    Quote a simple value, or turn a list/tuple of simple values into a
    parenthesised, comma-separated SQL list such as (1,2,3).

    Raises ValueError when ``data`` is neither a simple value nor a
    list/tuple, or when an element is not a simple value (nested
    sequences are not flattened).
    """
    simple = _quote_simple_value(data, charset)
    if simple is not None:
        return simple

    if not issubclass(type(data), (list, tuple)):
        raise ValueError('expected a simple type, a tuple or a list')

    pieces = []
    for item in data:
        piece = _quote_simple_value(item, charset)
        if piece is None:
            raise ValueError('found an unsupported type')
        pieces.append(piece)

    joined = b','.join(pieces)
    return b'(' + joined + b')'
# This function is supposed to take a simple value, tuple or dictionary,
# normally passed in via the params argument in the execute_* methods. It
# then quotes and flattens the arguments and returns them.
cdef _quote_data(data, charset='utf8'):
    """
    Quote ``data`` for substitution into a SQL string.

    - simple value -> its quoted literal (see _quote_simple_value())
    - dict -> new dict with every value quoted or flattened
    - tuple -> new tuple with every element quoted or flattened

    ``charset`` is used for every text encoding, including a single simple
    value; previously the top-level value was always encoded as utf8.
    Raises ValueError for any other type.
    """
    result = _quote_simple_value(data, charset)
    if result is not None:
        return result

    if issubclass(type(data), dict):
        # items() rather than iteritems(): dict subclasses on Python 3
        # have no iteritems().
        return {key: _quote_or_flatten(val, charset)
                for key, val in data.items()}

    if issubclass(type(data), tuple):
        return tuple([_quote_or_flatten(val, charset) for val in data])

    raise ValueError('expected a simple type, a tuple or a dictionary.')
# Positional placeholders %s / %d: group 1 is the whole placeholder,
# group 2 the conversion letter.
_re_pos_param = re.compile(br'(%([sd]))')
# Named placeholders %(name)s / %(name)d: group 1 is the whole placeholder,
# group 2 the name between the parentheses.
_re_name_param = re.compile(br'(%\(([^\)]+)\)(?:[sd]))')
cdef _substitute_params(toformat, params, charset):
    """
    Substitute quoted ``params`` into the %-style placeholders of
    ``toformat``.

    A dict ``params`` fills %(name)s / %(name)d placeholders; any other
    supported value fills %s / %d placeholders in order. A single simple
    value is treated as a one-element tuple. Unicode ``toformat`` is
    encoded with ``charset`` first. Returns ``toformat`` unchanged when
    ``params`` is None or nothing matched.

    Raises ValueError for an unsupported ``params`` type, for a named
    placeholder missing from the dict, or when there are more positional
    placeholders than values.

    The result is assembled from a list of pieces and joined once, instead
    of rebuilding the whole string for every placeholder.
    """
    if params is None:
        return toformat

    if not issubclass(type(params),
            (bool, int, long, float, unicode, str, bytes, bytearray, dict, tuple,
             datetime.datetime, datetime.date, decimal.Decimal, uuid.UUID)):
        raise ValueError("'params' arg (%r) can be only a tuple or a dictionary." % type(params))

    if charset:
        quoted = _quote_data(params, charset)
    else:
        quoted = _quote_data(params)

    # positional string substitution now requires a tuple
    if hasattr(quoted, 'startswith'):
        quoted = (quoted,)

    if isinstance(toformat, unicode):
        toformat = toformat.encode(charset)

    # Unchanged text between placeholders, interleaved with substituted values.
    pieces = []
    last_end = 0

    if isinstance(params, dict):
        # assume name based substitutions
        for match in _re_name_param.finditer(toformat):
            param_key = match.group(2).decode(charset)
            if param_key not in params:
                raise ValueError('params dictionary did not contain value for placeholder: %s' % param_key)
            pieces.append(toformat[last_end:match.start(1)])
            pieces.append(ensure_bytes(quoted[param_key]))
            last_end = match.end(1)
    else:
        # assume position based substitutions
        for count, match in enumerate(_re_pos_param.finditer(toformat)):
            try:
                param_val = quoted[count]
            except IndexError:
                raise ValueError('more placeholders in sql than params available')
            pieces.append(toformat[last_end:match.start(1)])
            pieces.append(ensure_bytes(param_val))
            last_end = match.end(1)

    if not pieces:
        return toformat

    pieces.append(toformat[last_end:])
    return b''.join(pieces)
# We'll add these methods to the module to allow for unit testing of the
# underlying C methods.
def quote_simple_value(value, charset='utf8'):
    """Python-visible wrapper around _quote_simple_value() for unit tests."""
    return _quote_simple_value(value, charset)

def quote_or_flatten(data, charset='utf8'):
    """Python-visible wrapper around _quote_or_flatten() for unit tests."""
    return _quote_or_flatten(data, charset)

def quote_data(data, charset='utf8'):
    """Python-visible wrapper around _quote_data() for unit tests."""
    return _quote_data(data, charset)

def substitute_params(toformat, params, charset='utf8'):
    """Python-visible wrapper around _substitute_params() for unit tests."""
    return _substitute_params(toformat, params, charset)
###########################
## Compatibility Aliases ##
###########################
def connect(*args, **kwargs):
    """
    Compatibility entry point: construct and return an MSSQLConnection,
    passing all arguments through unchanged.
    """
    connection = MSSQLConnection(*args, **kwargs)
    return connection

# Older spellings of the public class names, kept as aliases.
MssqlDatabaseException = MSSQLDatabaseException
MssqlDriverException = MSSQLDriverException
MssqlConnection = MSSQLConnection
###########################
## Test Helper Functions ##
###########################
def test_err_handler(connection, int severity, int dberr, int oserr, dberrstr, oserrstr):
    """
    Expose err_handler function and its side effects to facilitate testing.

    Calls err_handler() with the given codes. ``dberrstr`` and ``oserrstr``
    are UTF-8 encoded, and a falsy value is passed as NULL. The DBPROCESS
    is NULL when ``connection`` is falsy. Returns a tuple of
    (err_handler() result, last message text, number, severity, state) as
    read back for ``connection``; the recorded message is then cleared with
    clr_err().

    NOTE(review): get_last_msg_*() and clr_err() take an MSSQLConnection
    (or None), so presumably any other ``connection`` object is rejected.
    """
    cdef DBPROCESS *dbproc = NULL
    cdef char *dberrstrc = NULL
    cdef char *oserrstrc = NULL
    if dberrstr:
        # The bytes object is kept in a named local so the char* taken from
        # it stays valid until err_handler() has been called.
        dberrstr_byte_string = dberrstr.encode('UTF-8')
        dberrstrc = dberrstr_byte_string
    if oserrstr:
        oserrstr_byte_string = oserrstr.encode('UTF-8')
        oserrstrc = oserrstr_byte_string
    if connection:
        dbproc = (<MSSQLConnection>connection).dbproc
    results = (
        err_handler(dbproc, severity, dberr, oserr, dberrstrc, oserrstrc),
        get_last_msg_str(connection),
        get_last_msg_no(connection),
        get_last_msg_severity(connection),
        get_last_msg_state(connection)
    )
    clr_err(connection)
    return results
#####################
## Max Connections ##
#####################
def get_max_connections():
    """
    Return the maximum number of simultaneous connections db-lib will open
    to the server, as reported by dbgetmaxprocs().
    """
    limit = dbgetmaxprocs()
    return limit
def set_max_connections(int limit):
    """
    Set the maximum number of simultaneous connections db-lib will open to
    the server (via dbsetmaxprocs()).

    :param limit: the connection limit
    :type limit: int
    """
    dbsetmaxprocs(limit)
    return None
cdef void init_mssql() except *:
    """
    Initialise db-lib and install this module's error and message handlers.

    Declared ``except *`` (like assert_connected()) so the
    MSSQLDriverException raised when dbinit() fails propagates to the
    module-level call below and aborts the import. Without the declaration,
    the legacy Cython default cannot propagate an exception out of a
    ``cdef void`` function; it only prints it as ignored and the import
    carries on with db-lib uninitialised.
    """
    if dbinit() == FAIL:
        raise MSSQLDriverException("dbinit() failed")
    dberrhandle(err_handler)
    dbmsghandle(msg_handler)

init_mssql()

浙公网安备 33010602011771号