Design By Contract in python

Arthur Gordon arthur.gordon at thales-esecurity.com
Tue Jul 24 06:23:10 EDT 2001


An implementation of Design by Contract in Python. It uses
docstrings to hold the pre- and post-conditions, as described in the paper
http://www.swe.uni-linz.ac.at/publications/abstract/TR-SE-97.24.html

Enjoy

-Arthur

--------------- cut here -----------------------------
"""FILENAME:
        DesignByContract.py
   DESCRIPTION:
        An implementation of Design by Contract in Python 2.1,
        based on a paper by R. Ploesch, Johannes Kepler University Linz,
        which can be found at:
        http://www.swe.uni-linz.ac.at/publications/abstract/TR-SE-97.24.html
   USAGE:
        Contracts are written in docstrings, one per line, each line
        starting with one of the keywords 'invar:', 'type:', 'param:',
        'require:' or 'ensure:'.  Wrap an instance with DbCClassWrapper
        to have its contracts checked.
   AUTHOR:
       Arthur Gordon <arthur.gordon at thales-esecurity.com>
   HISTORY:
       18-Jul-01: Added support for old and param namespace.
       17-Jul-01: Added PrintDbCTrace - which prints out the stack
                  up to the DbC classes, no further.
       12-Jul-01: First attempt.
"""
import types
import sys
import traceback
import copy

# sys.hexversion packs the version as 0xMMmmppRS (major, minor, micro,
# release level, serial), so Python 2.0 is 0x02000000.  The old constant
# 0x200000 meant "version 0.32", so the check could never fire.
if sys.hexversion < 0x02000000:
    # sys.exit() with a string prints it to stderr and exits with status 1,
    # signalling failure (the previous exit status 0 reported success).
    sys.exit('DesignByContract.py requires Python version 2.0 or greater.')

# Contract keywords recognised at the start of a (stripped) docstring line.
KEYWORD_INVARIANT   = 'invar:'
KEYWORD_TYPE        = 'type:'
KEYWORD_PARAMETER   = 'param:'
KEYWORD_REQUIRE     = 'require:'
KEYWORD_ENSURE      = 'ensure:'

# Keyword lengths, used to slice the keyword off a contract line.
LEN_KEYWORD_INVARIANT   = len(KEYWORD_INVARIANT)
LEN_KEYWORD_TYPE        = len(KEYWORD_TYPE)
LEN_KEYWORD_PARAMETER   = len(KEYWORD_PARAMETER)
LEN_KEYWORD_REQUIRE     = len(KEYWORD_REQUIRE)
LEN_KEYWORD_ENSURE      = len(KEYWORD_ENSURE)

TRUE = 1==1


def _buildTypeTable():
    """ Builds the string -> type table used by checkType().

        Every type object exported by the types module is entered under its
        attribute name (e.g. 'FunctionType', 'MethodType').  Python 3 removed
        the aliases for built-in types ('IntType', 'ListType', ...) from the
        types module, so those are added when missing.  On Python 2 the
        types-module entries are the same objects and take precedence.
    """
    table = {}
    for name in dir(types):
        candidate = getattr(types, name)    # plain lookup, no eval() needed
        if isinstance(candidate, type):
            table[name] = candidate
    legacy = {
        'NoneType': type(None),
        'TypeType': type,
        'ObjectType': object,
        'IntType': int,
        'FloatType': float,
        'ComplexType': complex,
        'BooleanType': bool,
        'StringType': str,
        'ListType': list,
        'TupleType': tuple,
        'DictType': dict,
        'DictionaryType': dict,
    }
    for name, builtin_type in legacy.items():
        table.setdefault(name, builtin_type)
    return table

# Used for string to type conversion, e.g. asType['IntType'] is int.
asType = _buildTypeTable()


def checkType(param, typeAsString):
    """ Checks the param against the string description of its type.

        param:        the value to check.
        typeAsString: a key of asType, e.g. 'IntType'; the special name
                      'AnyType' accepts every value.
        Returns TRUE when type(param) is exactly the named type (subclasses
        do not match), otherwise a false value.
        Raises DbCFormatError when typeAsString names no known type: that
        is a mistake in the contract itself, not a contract violation.
    """
    if typeAsString == 'AnyType':           # accept any
        return TRUE
    try:
        expected = asType[typeAsString]
    except KeyError:
        raise DbCFormatError('unknown type in contract: ' + repr(typeAsString))
    return type(param) == expected          # exact type match
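""" Violation handling
    --------------------------------------------------------------------------
    Example of catching the exceptions raised while checking contracts;
    TestStack() stands for any code that uses wrapped objects.
    DbCFormatError is raised for a malformed contract line.  DbCViolationError
    is raised only if the violation handler given to DbCClassWrapper raises
    it (the default handler prints the trace instead).

    try:
        TestStack()
    except DbCFormatError as msg:
        PrintDbCTrace(traceback.extract_tb(sys.exc_info()[2]), str(msg),
                      "FORMAT ERROR:")
    except DbCViolationError as msg:
        PrintDbCTrace(traceback.extract_tb(sys.exc_info()[2]), str(msg))
"""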
class DbCViolationError(Exception):
    """ Error: Contract Violation.

        Raised by violation handlers that choose to abort when an invariant,
        pre-condition or post-condition does not hold.
    """

class DbCFormatError(Exception):
    """ Error: Incorrect Contract Format.

        Raised when a contract line in a docstring cannot be parsed.
    """


def FormatDbCTrace(trace, line, msg= "VIOLATION:"):
    """ Formats the trace without going back into the DbC classes.

        trace: entries indexable as (filename, lineno, name, text), outermost
               call first, as returned by traceback.extract_stack() or
               traceback.extract_tb().
        line:  text describing the violation.
        msg:   prefix placed in front of line in the banner.
        Returns a list of strings: a three-line banner, followed by the
        traceback.format_list() entries for every frame before the first
        frame that lies in DesignByContract.py at or before DBC_LAST_LINE.
        If there is no such frame, the whole trace is formatted.
    """
    FIELD_FILENAME  = 0
    FIELD_LINENO    = 1
    retval = ['='*60, msg+line, '='*60]
    # Find how far back to go.  The length check stops the scan at the end
    # of the trace, so a trace without any DbC frame cannot raise IndexError.
    trace_slice = 0
    while (trace_slice < len(trace) and
           (not trace[trace_slice][FIELD_FILENAME].endswith('DesignByContract.py') or
            trace[trace_slice][FIELD_LINENO] > DBC_LAST_LINE)):
        trace_slice += 1
    retval.extend(traceback.format_list(trace[:trace_slice]))
    return retval

def PrintDbCTrace(trace, msg, str = "VIOLATION:"):
    """ Prints FormatDbCTrace(trace, msg, str) to stdout.

        trace: the stack trace, as accepted by FormatDbCTrace.
        msg:   text describing the violation.
        str:   banner prefix.  The name shadows the builtin; it is kept so
               callers passing it by keyword keep working.
    """
    for line in FormatDbCTrace(trace, msg, str):
        # print(x) with one argument behaves the same on Python 2 and 3.
        # traceback.format_list() entries already end in a newline; strip
        # it so print does not add a blank line after every frame.
        print(line.rstrip('\n'))

def LogDbCTrace(trace, line, fname='violation.txt', msg="VIOLATION:"):
    """ Writes FormatDbCTrace(trace, line, msg) to the file fname.

        trace: the stack trace, as accepted by FormatDbCTrace.
        line:  text describing the violation.
        fname: output file.  It is truncated first, so it holds only the
               latest violation.
        msg:   banner prefix (same meaning as in PrintDbCTrace).
    """
    entries = FormatDbCTrace(trace, line, msg)
    f = open(fname, 'w')
    try:
        for entry in entries:
            # The banner lines carry no newline, while format_list() entries
            # already end with one.  writelines() adds no separators, so
            # terminate each entry exactly once.
            if not entry.endswith('\n'):
                entry = entry + '\n'
            f.write(entry)
    finally:
        f.close()                   # close even if a write fails

def contractViolationBehaviour(object, line):
    """ Default handler, called when a contract line does not hold.

        object: the wrapped instance whose contract failed (unused here).
        line:   the contract line that failed.

        Prints the current stack, up to the DbC classes, to stdout and lets
        execution continue.  Alternatives are to raise
        DbCViolationError(line), or to write the trace to a log file with
        LogDbCTrace(traceback.extract_stack(), line).  To use one of them,
        pass a different violationProc to DbCClassWrapper.
    """
    PrintDbCTrace(traceback.extract_stack(), line)      # pass the stack trace
    



""" --------------------------------------------------------------------------
"""
def getParameters(orignal_line):
    """ Parses a 'name: Type, name: Type   # comment' contract fragment.

        orignal_line: the text following a 'param:' or 'type:' keyword.
        Returns a list of [name, type_name] lists, in order of appearance,
        each item stripped of surrounding whitespace.  Anything after '#'
        is ignored.
        Raises DbCFormatError(orignal_line) when a comma-separated item is
        not exactly one 'name: type' pair.
    """
    line = orignal_line
    comment_start = line.find('#')
    if comment_start != -1:                 # remove trailing comments
        line = line[:comment_start]
    # Build a real list: on Python 3 map() returns a one-shot iterator,
    # which the validation loop would exhaust before it is returned.
    pairs = []
    for item in line.split(','):            # split by ,
        pair = item.split(':')              # split by :
        if len(pair) != 2:                  # should be two
            raise DbCFormatError(orignal_line)
        pairs.append([pair[0].strip(), pair[1].strip()])
    return pairs

""" Example:
    getParameters(' list: ListType, another:IntType       # Should work')
    returns [['list', 'ListType'], ['another', 'IntType']]
"""



# Attributes stored on DbCClassWrapper itself.  Every other attribute
# access through the wrapper is forwarded to the wrapped object.
DBC_CLASS_WRAPPER_PRIVATE_VARS = ['_wrapped_object', '_violationProc',
                                  '_docstr']
class DbCClassWrapper:
    """ Wrapper for class instances that checks their contracts.

        The wrapped instance's class docstring may contain lines starting
        with 'invar:' (an expression that must be true) and
        'type: name: Type, ...' (attribute types).  These are checked when
        the wrapper is created, after every attribute assignment made
        through the wrapper, and after every method call made through the
        wrapper.  Methods are returned wrapped in DbCMethodWrapper, which
        also checks each method's own contract.

        Contract expressions are passed to eval(), so only wrap classes
        whose docstrings you trust.
    """

    def __init__(self, object, violationProc = contractViolationBehaviour):
        """ param: object:InstanceType

            object:        the instance to wrap.
            violationProc: called as violationProc(object, line) whenever
                           a contract line does not hold.

            The leading underscores on the wrapper's own attributes are
            required because the wrapper shares its attribute namespace
            with the wrapped object (see DBC_CLASS_WRAPPER_PRIVATE_VARS).
            The class invariants are checked here because the wrapped
            object's own __init__ ran before it was wrapped.
        """
        self._wrapped_object = object                   # reference to wrapped instance
        self._violationProc = violationProc             # who you going to call?
        self._wrapped_object._dbc_wrapper = self        # attach ourselves to the wrapped object
        if not self._wrapped_object.__doc__:
            self._docstr = None
        else:
            # One stripped entry per docstring line.  Use a real list, not a
            # map() iterator, because it is scanned again on every check.
            self._docstr = [docline.strip() for docline in
                            self._wrapped_object.__doc__.splitlines()]
            self.checkClassInvariants()                 # since we did not check at __init__

    def __getattr__(self, name):
        """ param: name:StringType

            Called only when normal lookup fails, i.e. for names that are
            not already set on the wrapper.  Methods of the wrapped object
            are returned wrapped in DbCMethodWrapper, so their contracts are
            checked when they are called.
        """
        if name in DBC_CLASS_WRAPPER_PRIVATE_VARS:
            # A private attribute that has not been set yet (e.g. while
            # copying, or before __init__ finished).  The attribute protocol
            # requires AttributeError, not KeyError, so getattr() defaults
            # and hasattr() keep working.
            try:
                return self.__dict__[name]
            except KeyError:
                raise AttributeError(name)
        attribute = getattr(self._wrapped_object, name)
        if type(attribute) == types.MethodType:         # if it is a method...
            attribute = DbCClassWrapper.DbCMethodWrapper(
                self._wrapped_object, attribute)        # ...wrap it
        return attribute

    def __setattr__(self, name, val):
        """ param: name:StringType, val:AnyType

            Private attributes are stored on the wrapper.  Everything else
            is set on the wrapped object, after which the class invariants
            are re-checked.
        """
        if name in DBC_CLASS_WRAPPER_PRIVATE_VARS:      # access private vars
            self.__dict__[name] = val
        else:
            setattr(self._wrapped_object, name, val)
            self.checkClassInvariants()

    def checkClassInvariants(self):
        """ Checks the class contract of the wrapped object.

            'invar:' expressions are evaluated with the wrapped object's
            attributes as plain names, the same way method contracts see
            them.  'self' is also bound to the wrapped object unless it has
            an attribute of that name.  'type:' lines check each listed
            attribute with checkType().  Each failing line is passed to the
            violation handler.  Does nothing when self._docstr is None.
        """
        if not self._docstr:
            return
        for line in self._docstr:
            if line.startswith(KEYWORD_INVARIANT):      # line starts invariant:
                # Evaluate in a copy: eval() inserts '__builtins__' into its
                # globals dict, which must not leak into the instance.
                namespace = dict(self._wrapped_object.__dict__)
                namespace.setdefault('self', self._wrapped_object)
                if not eval(line[LEN_KEYWORD_INVARIANT:].strip(), namespace):
                    self._violationProc(self._wrapped_object, line)
            elif line.startswith(KEYWORD_TYPE):
                for p in getParameters(line[LEN_KEYWORD_TYPE:]):
                    arg = getattr(self._wrapped_object, p[0])
                    if not checkType(arg, p[1]):
                        self._violationProc(self._wrapped_object, line)

    class DbCMethodWrapper:
        """ A wrapper for bound methods that checks the method contract.

            The method docstring may contain lines starting with:
              'param: name: Type, ...'  argument types, in positional order
                                        (also bound as param.name),
              'require:'                pre-conditions,
              'invar:'                  checked before and after the call,
              'ensure:'                 post-conditions; old.name gives an
                                        attribute's value before the call.
            Expressions see the wrapped object's attributes as plain names.
        """
        def __init__(self, object, method):
            """ param: object:InstanceType, method:MethodType  """
            self.wrapped_object = object
            self.method = method
            if method.__doc__:                          # has documentation string
                # A real list: it is scanned twice per call (checkBefore and
                # checkAfter).  A Python 3 map() iterator would be exhausted
                # by the first scan, silently skipping the post-conditions.
                self.docstr = [docline.strip() for docline in
                               method.__doc__.splitlines()]
            else:
                self.docstr = None
            self.violationProc = self.wrapped_object._dbc_wrapper._violationProc
            self.private_vars = {}

        def __call__(self, *args, **kwargs):
            """ param: args:AnyType

                Pre- and post-conditions are only checked when the method
                has a docstring.  The class invariants are always checked
                at the end.  Keyword arguments are passed through to the
                method and matched to 'param:' names by checkBefore.
            """
            if not self.docstr:
                retval = self.method(*args, **kwargs)               # call the method
            else:
                self.private_vars = self.getWrappedClassPrivateVars()   # shallow copy
                self.checkBefore(args, kwargs)                      # pre-conditions
                self.copyPrivateVarsToOld()                         # copy members to old
                retval = self.method(*args, **kwargs)               # call the method
                # Refresh the attribute values so ensure: sees the new state.
                self.private_vars.update(self.getWrappedClassPrivateVars())
                self.checkAfter()                                   # post-conditions
                self.private_vars = {}
            self.wrapped_object._dbc_wrapper.checkClassInvariants() # always check class invars
            return retval

        def getWrappedClassPrivateVars(self):
            """ Returns a shallow copy of the wrapped object's attributes,
                excluding the DbC bookkeeping entries.
            """
            retval = {}
            for key, value in self.wrapped_object.__dict__.items():
                if key not in ('docstr', '_dbc_wrapper', '__builtins__'):  # don't want to copy these
                    retval[key] = value
            return retval

        class DbCDictWrapper:
            """ Empty namespace object that allows the old. and param.
                notation in the DbC expressions. """
            pass

        def copyPrivateVarsToOld(self):
            """ Stores a deep copy of every attribute in
                self.private_vars['old'], so ensure: lines can refer to
                old.name.  The 'param' entry is not copied.
            """
            old = DbCClassWrapper.DbCMethodWrapper.DbCDictWrapper()
            # eval() in checkBefore inserted '__builtins__'.  It is not an
            # attribute of the wrapped object.
            if '__builtins__' in self.private_vars:
                del self.private_vars['__builtins__']
            for key in list(self.private_vars.keys()):
                if key != 'param':                                  # don't copy param over
                    setattr(old, key, copy.deepcopy(self.private_vars[key]))
            self.private_vars['old'] = old

        def displayPrivateVars(self):
            """ Prints self.private_vars to stdout; for debugging only. """
            print('-'*60)
            for key in self.private_vars.keys():
                if key != '__builtins__':
                    print('[ %s ] = %s' % (key, self.private_vars[key]))
                    if key in ('param', 'old'):
                        print('   %s' % (self.private_vars[key].__dict__,))
            print('-'*60)

        def checkBefore(self, args, kwargs=None):
            """ param: args:AnyType

                Checks the pre-conditions: 'param:' types (binding each
                supplied argument as param.name), 'require:' and 'invar:'.
                args:   positional arguments of the call, excluding self.
                kwargs: keyword arguments of the call, if any.
                Arguments that were not supplied (left to the method's
                default) are neither type-checked nor bound in param.
                require: self.docstr != None
            """
            if kwargs is None:
                kwargs = {}
            param = DbCClassWrapper.DbCMethodWrapper.DbCDictWrapper()
            self.private_vars['param'] = param
            # Positions continue across several 'param:' lines rather than
            # restarting at 0 on each line.
            paramIndex = 0
            for line in self.docstr:                                # for each line...
                if line.startswith(KEYWORD_PARAMETER):
                    for name, typeName in getParameters(line[LEN_KEYWORD_PARAMETER:]):
                        position = paramIndex
                        paramIndex += 1
                        if position < len(args):
                            value = args[position]
                        elif name in kwargs:
                            value = kwargs[name]
                        else:
                            continue                                # not supplied
                        if not checkType(value, typeName):          # check type is correct
                            self.violationProc(self.wrapped_object, line)
                        setattr(param, name, value)                 # assign to params
                elif line.startswith(KEYWORD_REQUIRE):
                    if not eval(line[LEN_KEYWORD_REQUIRE:].strip(), self.private_vars):
                        self.violationProc(self.wrapped_object, line)
                elif line.startswith(KEYWORD_INVARIANT):
                    if not eval(line[LEN_KEYWORD_INVARIANT:].strip(), self.private_vars):
                        self.violationProc(self.wrapped_object, line)

        def checkAfter(self):
            """ Checks the post-conditions: 'ensure:' and method 'invar:'.
                require: self.docstr != None
            """
            for line in self.docstr:
                if line.startswith(KEYWORD_ENSURE):
                    if not eval(line[LEN_KEYWORD_ENSURE:].strip(), self.private_vars):
                        self.violationProc(self.wrapped_object, ':"'+line+'"')
                elif line.startswith(KEYWORD_INVARIANT):
                    if not eval(line[LEN_KEYWORD_INVARIANT:].strip(), self.private_vars):
                        self.violationProc(self.wrapped_object, ':"'+line+'"')
            
            
# Frames in DesignByContract.py at or before this line belong to the DbC
# machinery above and are left out of violation traces (see FormatDbCTrace).
# It is taken from the executing frame rather than hard-coded, so it stays
# correct when the code above is edited.
DBC_LAST_LINE = sys._getframe().f_lineno

    
if __name__ == '__main__':
    import unittest

    class Stack:
        """ A simple test class
            invar: list != None                         # Should work
            type: list: ListType                        # Should work
        """
        def __init__(self):
            self.list = []
            self.another = 2        # used by the inc() contract demo below

        def ppush(self, val):
            """ param:   val:IntType                        # Should work
                require: param.val >1                       # Should work
                invar:  list != None                        # Should work
                ensure: len(list) == len(old.list)+1        # Should work
                """
            self.list.append(val)

        def ppop(self):
            """ invar:  list != None                        # Should work
                ensure: len(list) == len(old.list)-1        # Should work
                """
            return self.list.pop()

        def inc(self, val):
            """ param:  val:IntType
                require: another > 1                        # Should pass
                ensure: old.another > 100                   # Should fail
                ensure: another == old.another + 1          # Should pass
                ensure: another == old.another + 10         # Should fail
                """
            self.another = self.another + 1

    class DesignByContractTests(unittest.TestCase):
        def setUp(self):
            def onContractViolation(obj, line):
                raise DbCViolationError(line)               # raise exception
            aStack = Stack()                                # call the constructor for class
            self.ws = DbCClassWrapper(aStack, onContractViolation)  # and wrap it

        def checkPushValidValue(self):
            self.ws.ppush(10)

        def checkPushInValidType(self):
            self.assertRaises(DbCViolationError, self.ws.ppush, 'asd')

        def checkPushInValidValue(self):
            self.assertRaises(DbCViolationError, self.ws.ppush, -1)

        def checkPopValid(self):
            self.ws.ppush(10)
            self.assertEqual(self.ws.ppop(), 10)

    # The test methods use the 'check' prefix instead of the default 'test'.
    # TestLoader replaces unittest.makeSuite, which modern Python removed.
    loader = unittest.TestLoader()
    loader.testMethodPrefix = 'check'
    suite = loader.loadTestsFromTestCase(DesignByContractTests)
    result = unittest.TextTestRunner().run(suite)
    sys.exit(not result.wasSuccessful())        # non-zero exit status on failure



More information about the Python-list mailing list