Logo Search packages:      
Sourcecode: matplotlib version File versions  Download package

trait_handlers.py

#-------------------------------------------------------------------------------
#
#  Define the base 'TraitHandler' class and a standard set of TraitHandler
#  subclasses for use with the 'traits' package.
#
#  A trait handler mediates the assignment of values to object traits. It
#  verifies (through the TraitHandler's 'validate' method) that a specified
#  value is consistent with the object trait, and generates a 'TraitError'
#  exception if not.
#
#  Written by: David C. Morrill
#
#  Date: 06/21/2002
#
#  Refactored into a separate module: 07/04/2003
#
#  Symbols defined: TraitHandler
#                   TraitRange
#                   TraitType
#                   TraitString
#                   TraitInstance
#                   TraitFunction
#                   TraitEnum
#                   TraitPrefixList
#                   TraitMap
#                   TraitPrefixMap
#                   TraitCompound
#                   TraitList
#                   TraitDict
#
#  (c) Copyright 2002, 2003 by Enthought, Inc.
#
#-------------------------------------------------------------------------------
 
#-------------------------------------------------------------------------------
#  Imports:
#-------------------------------------------------------------------------------
           
import sys           
import re
import copy
           
from types        import StringType, InstanceType, TypeType, FunctionType, \
                         MethodType
from ctraits      import CTraitMethod                            
from trait_base   import strx, SequenceTypes, Undefined, TypeTypes, ClassTypes,\
                         CoercableTypes, class_of, enumerate           
from trait_errors import TraitError
                 
# Placeholders that are patched by 'traits.py' once the corresponding classes
# are defined. Neither name is referenced by the code visible in this part of
# the module:
Trait   = Event = None
# Placeholder for the warning-reporting callable. It is called by the
# deprecated 'TraitThisClass' and 'TraitComplex' constructors below, so it must
# have been replaced with a real callable before either is instantiated
# (presumably also patched by 'traits.py' -- TODO confirm):
warning = None

#-------------------------------------------------------------------------------
#  Constants:
#-------------------------------------------------------------------------------

# Types 'TraitRange' accepts for its bounds; also the value types accepted by
# 'TraitRange.float_validate':
RangeTypes    = ( int, float )

# Types 'TraitFunction' accepts as its validator function:
CallableTypes = ( FunctionType, MethodType, CTraitMethod )

#-------------------------------------------------------------------------------
#  Forward references:
#-------------------------------------------------------------------------------

# Called by 'TraitTuple.__init__' to convert each of its arguments:
trait_from = None  # Patched by 'traits.py' when real 'trait_from' is defined
          
#-------------------------------------------------------------------------------
#  'TraitHandler' class (base class for all trait handlers):
#-------------------------------------------------------------------------------

class TraitHandler ( object ):
    """ Base class for all trait handlers.

        Concrete handlers override 'validate' (and normally 'info') to
        enforce and describe the values a trait accepts. The other methods
        are shared helpers for reporting errors, formatting values and
        supplying a default editor.
    """

    # Class-level defaults; several subclasses in this module override
    # 'is_mapped' and 'has_items' or set 'editor' lazily:
    default_value_type = -1
    has_items          = False
    is_mapped          = False
    editor             = None

    __traits_metadata__ = { 'type': 'trait' }

    def validate ( self, object, name, value ):
        """ Always raises a TraitError; subclasses must supply a real
            implementation.
        """
        template = ( "The '%s' trait of %s instance has an unknown type. "
                     "Contact the developer to correct the problem." )
        raise TraitError( template % ( name, class_of( object ) ) )

    def error ( self, object, name, value ):
        """ Raises a TraitError constructed from the object, the trait name,
            this handler's 'info' text and the offending value.
        """
        raise TraitError( object, name, self.info(), value )

    def arg_error ( self, method, arg_num, object, name, value ):
        """ Raises a TraitError for an invalid positional method argument.
        """
        template = ( "The '%s' parameter (argument %d) of the %s method "
                     "of %s instance must be %s, but a value of %s was "
                     "specified." )
        raise TraitError( template % ( name, arg_num, method.tm_name,
                                       class_of( object ), self.info(),
                                       value ) )

    def keyword_error ( self, method, object, name, value ):
        """ Raises a TraitError for an invalid keyword method argument.
        """
        template = ( "The '%s' keyword argument of the %s method of "
                     "%s instance must be %s, but a value of %s was "
                     "specified." )
        raise TraitError( template % ( name, method.tm_name,
                                       class_of( object ), self.info(),
                                       value ) )

    def missing_arg_error ( self, method, arg_num, object, name ):
        """ Raises a TraitError for a required method argument that was
            omitted.
        """
        template = ( "The '%s' parameter (argument %d) of the %s method "
                     "of %s instance must be specified, but was omitted." )
        raise TraitError( template % ( name, arg_num, method.tm_name,
                                       class_of( object ) ) )

    def dup_arg_error ( self, method, arg_num, object, name ):
        """ Raises a TraitError for a method argument given both
            positionally and by keyword.
        """
        template = ( "The '%s' parameter (argument %d) of the %s method "
                     "of %s instance was specified as both a positional "
                     "and keyword value." )
        raise TraitError( template % ( name, arg_num, method.tm_name,
                                       class_of( object ) ) )

    def return_error ( self, method, object, value ):
        """ Raises a TraitError for an invalid method return value.
        """
        template = ( "The result of the %s method of %s instance must "
                     "be %s, but a value of %s was returned." )
        raise TraitError( template % ( method.tm_name, class_of( object ),
                                       self.info(), value ) )

    def info ( self ):
        """ Returns the text describing the values this handler accepts.
        """
        return 'a legal value'

    def repr ( self, value ):
        """ Returns the text used to show 'value' in an error message:
            'class <name>' for old-style instances, the builtin repr for
            everything else.
        """
        if type( value ) is not InstanceType:
            return repr( value )
        return 'class '  + value.__class__.__name__

    def get_editor ( self, trait ):
        """ Returns the editor for this handler, creating (and caching on the
            handler) a TextEditor the first time one is needed.
        """
        editor = self.editor
        if editor is None:
            # Imported lazily so the UI package is only loaded on demand:
            from matplotlib.enthought.traits.ui import TextEditor
            editor = self.editor = TextEditor()
        return editor

    def metadata ( self ):
        """ Returns the handler's '__traits_metadata__' dictionary, or an
            empty dictionary if the attribute is missing.
        """
        return getattr( self, '__traits_metadata__', {} )

#-------------------------------------------------------------------------------
#  'TraitRange' class:
#-------------------------------------------------------------------------------

class TraitRange ( TraitHandler ):
    """ Accepts int or float values lying within optional 'low' and 'high'
        bounds (both bounds are inclusive).
    """

    def __init__ ( self, low = None, high = None ):
        """ Records the bounds, selects the int or float validator and builds
            the 'fast_validate' tuple.

            Raises a TraitError if the type used to classify the range (see
            below) is neither int nor float.
        """
        self.low  = low
        self.high = high

        # 'high' determines the kind of range, unless 'low' was given and
        # 'high' is not a float, in which case 'low' determines it:
        range_type = type( high )
        if (range_type is not float) and (low is not None):
            range_type = type( low )
        if range_type not in RangeTypes:
            raise TraitError( "TraitRange can only be use for int or float "
                              "values, but a value of type %s was specified." %
                              range_type )

        if range_type is float:
            kind, coerce   = 4, float
            self.validate  = self.float_validate
            self.type_desc = 'a floating point number'
        else:
            kind, coerce   = 3, int
            self.validate  = self.int_validate
            self.type_desc = 'an integer'

        # Only the copies stored in 'fast_validate' are coerced; 'self.low'
        # and 'self.high' keep the values exactly as they were passed in:
        if low is not None:
            low = coerce( low )
        if high is not None:
            high = coerce( high )
        self.fast_validate = ( kind, low, high )

    def float_validate ( self, object, name, value ):
        """ Returns 'value' as a float if it is an int or float within the
            bounds; otherwise reports an error.
        """
        try:
            if isinstance( value, RangeTypes ):
                above_low = (self.low is None) or (self.low <= value)
                if above_low and ((self.high is None) or
                                  (self.high >= value)):
                    return float( value )
        except:
            pass
        self.error( object, name, self.repr( value ) )

    def int_validate ( self, object, name, value ):
        """ Returns 'value' unchanged if it is an int within the bounds;
            otherwise reports an error.
        """
        try:
            if isinstance( value, int ):
                above_low = (self.low is None) or (self.low <= value)
                if above_low and ((self.high is None) or
                                  (self.high >= value)):
                    return value
        except:
            pass
        self.error( object, name, self.repr( value ) )

    def info ( self ):
        """ Returns a description of the accepted range.
        """
        if self.low is not None:
            if self.high is not None:
                return '%s in the range from %s to %s' % (
                       self.type_desc, self.low, self.high )
            return  '%s >= %s' % ( self.type_desc, self.low )
        if self.high is not None:
            return '%s <= %s' % ( self.type_desc, self.high )
        return self.type_desc

    def get_editor ( self, trait ):
        """ Returns a new RangeEditor configured from the metadata of
            'trait' ('auto_set' defaults to True when it is None).
        """
        auto_set = trait.auto_set
        if auto_set is None:
            auto_set = True
        from matplotlib.enthought.traits.ui import RangeEditor
        options = { 'cols':       trait.cols or 3,
                    'auto_set':   auto_set,
                    'enter_set':  trait.enter_set or False,
                    'low_label':  trait.low  or '',
                    'high_label': trait.high or '' }
        return RangeEditor( self, **options )

#-------------------------------------------------------------------------------
#  'TraitString' class:
#-------------------------------------------------------------------------------

class TraitString ( TraitHandler ):
    """ Accepts values that can be converted to a string by 'strx', with an
        optional length range and an optional regular expression that the
        string must match.
    """

    def __init__ ( self, minlen = 0, maxlen = sys.maxint, regex = '' ):
        """ Stores the (normalized) length limits and the pattern, then
            selects the cheapest validator that enforces them.
        """
        # A negative minimum is treated as 0, and the maximum is never
        # allowed to be smaller than the minimum:
        self.minlen = max( 0, minlen )
        self.maxlen = max( self.minlen, maxlen )
        self.regex  = regex
        self._init()

    def _init ( self ):
        """ Compiles the pattern (if any) and binds the specialized
            validator. When both a pattern and a length limit are in effect
            no binding is made, so the general 'validate' method is used.
        """
        any_length = (self.minlen == 0) and (self.maxlen == sys.maxint)
        if self.regex == '':
            if any_length:
                self.validate = self.validate_str
            else:
                self.validate = self.validate_len
        else:
            self.match = re.compile( self.regex ).match
            if any_length:
                self.validate = self.validate_regex

    def validate ( self, object, name, value ):
        """ Checks both the length limits and the pattern.
        """
        try:
            value = strx( value )
            if self.minlen <= len( value ) <= self.maxlen:
                if self.match( value ) is not None:
                    return value
        except:
            pass
        self.error( object, name, self.repr( value ) )

    def validate_str ( self, object, name, value ):
        """ Only requires that the value can be converted by 'strx'.
        """
        try:
            return strx( value )
        except:
            pass
        self.error( object, name, self.repr( value ) )

    def validate_len ( self, object, name, value ):
        """ Checks the length limits only.
        """
        try:
            value = strx( value )
            if self.minlen <= len( value ) <= self.maxlen:
                return value
        except:
            pass
        self.error( object, name, self.repr( value ) )

    def validate_regex ( self, object, name, value ):
        """ Checks the pattern only.
        """
        try:
            value = strx( value )
            if self.match( value ) is not None:
                return value
        except:
            pass
        self.error( object, name, self.repr( value ) )

    def info ( self ):
        """ Returns a description of the accepted strings.
        """
        has_min = (self.minlen != 0)
        has_max = (self.maxlen != sys.maxint)
        if has_min and has_max:
            msg = ' between %d and %d characters long' % (
                  self.minlen, self.maxlen )
        elif has_max:
            msg = ' <= %d characters long' % self.maxlen
        elif has_min:
            msg = ' >= %d characters long' % self.minlen
        else:
            msg = ''
        if self.regex != '':
            if msg != '':
                msg += ' and'
            msg += (" matching the pattern '%s'" % self.regex)
        return 'a string' + msg

    def __getstate__ ( self ):
        """ Returns the instance state without the 'validate' and 'match'
            bound methods; '__setstate__' rebuilds them.
        """
        state = self.__dict__.copy()
        for name in ( 'validate', 'match' ):
            try:
                del state[ name ]
            except KeyError:
                pass
        return state

    def __setstate__ ( self, state ):
        """ Restores the saved state and re-creates the derived bindings.
        """
        self.__dict__.update( state )
        self._init()

#-------------------------------------------------------------------------------
#  'TraitType' class:
#-------------------------------------------------------------------------------

class TraitType ( TraitHandler ):
    """ Accepts values of a given type, plus values of any type listed as
        coercable to it in 'CoercableTypes' (those are converted).
    """

    def __init__ ( self, aType ):
        """ 'aType' may be a type or a value; for a value its type is used.
        """
        if not isinstance( aType, TypeType ):
            aType = type( aType )
        self.aType = aType

        # Use the predefined entry when the type has one; otherwise fall
        # back to a tuple naming just the type itself:
        try:
            fast_validate = CoercableTypes[ aType ]
        except:
            fast_validate = ( 11, aType )
        self.fast_validate = fast_validate

    def validate ( self, object, name, value ):
        """ Returns 'value' if it already has the target type, a converted
            value if its type is one of the coercable types, and reports an
            error otherwise.
        """
        spec       = self.fast_validate
        target     = spec[1]
        value_type = type( value )

        # Exact type match, no conversion needed:
        if value_type is target:
            return value

        # Entries after the target are the types that may be coerced to it:
        for coercable in spec[2:]:
            if value_type is coercable:
                return target( value )

        # Anything else is rejected:
        if value_type is InstanceType:
            kind = class_of( value )
        else:
            kind = repr( value )
        self.error( object, name,
                    '%s (i.e. %s)' % ( str( value_type )[1:-1], kind ) )

    def info ( self ):
        """ Returns a description of the accepted type.
        """
        return 'a value of %s' % str( self.aType )[1:-1]

    def get_editor ( self, trait ):
        """ Returns a (cached) BooleanEditor for 'bool', and a new TextEditor
            that evaluates input with the target type for everything else.
        """
        if self.aType is not bool:
            auto_set = trait.auto_set
            if auto_set is None:
                auto_set = True
            from matplotlib.enthought.traits.ui import TextEditor
            return TextEditor( auto_set  = auto_set,
                               enter_set = trait.enter_set or False,
                               evaluate  = self.fast_validate[1] )

        # Special case: 'bool' uses the boolean editor:
        if self.editor is None:
            from matplotlib.enthought.traits.ui import BooleanEditor
            self.editor = BooleanEditor()
        return self.editor

#-------------------------------------------------------------------------------
#  'TraitCastType' class:
#-------------------------------------------------------------------------------

class TraitCastType ( TraitType ):
    """ Accepts any value that the target type's constructor can convert.
    """

    def __init__ ( self, aType ):
        """ 'aType' may be a type or a value; for a value its type is used.
        """
        if not isinstance( aType, TypeType ):
            aType = type( aType )
        self.aType         = aType
        self.fast_validate = ( 12, aType )

    def validate ( self, object, name, value ):
        """ Returns 'value' if it already has the target type; otherwise
            returns the result of casting it, reporting an error if the cast
            raises.
        """
        target = self.aType
        if type( value ) is target:
            return value

        try:
            return target( value )
        except:
            # The cast failed, so describe the rejected value:
            value_type = type( value )
            if value_type is InstanceType:
                kind = class_of( value )
            else:
                kind = repr( value )
            self.error( object, name, '%s (i.e. %s)' % (
                                      str( value_type )[1:-1], kind ) )

#-------------------------------------------------------------------------------
#  'ThisClass' class:
#-------------------------------------------------------------------------------

class ThisClass ( TraitHandler ):
    """ Accepts instances of the same class as the object that owns the
        trait, and optionally None.
    """

    def __init__ ( self, or_none = 0 ):
        """ Passing 'or_none = None' makes None an acceptable value; any
            other argument value (including the default) does not.
        """
        if or_none is None:
            self.allow_none()
            self.fast_validate = ( 2, None )
        else:
            self.fast_validate = ( 2, )

    def allow_none ( self ):
        """ Makes None an acceptable value by switching this handler over to
            the '*_none' variants of 'validate' and 'info'.

            '__init__' relies on this method; without it, constructing a
            handler with 'or_none = None' failed with an AttributeError.
            'TraitInstance' overrides it with its own implementation.
        """
        self.validate = self.validate_none
        self.info     = self.info_none

    def validate ( self, object, name, value ):
        """ Returns 'value' if it is an instance of the owner's class;
            otherwise reports an error.
        """
        if isinstance( value, object.__class__ ):
            return value
        self.validate_failed( object, name, value )

    def validate_none ( self, object, name, value ):
        """ Same as 'validate', but None is also accepted.
        """
        if isinstance( value, object.__class__ ) or (value is None):
            return value
        self.validate_failed( object, name, value )

    def info ( self ):
        """ Returns a description of the accepted values.
        """
        return 'an instance of the same type as the receiver'

    def info_none ( self ):
        """ Returns a description of the accepted values when None is
            allowed.
        """
        return 'an instance of the same type as the receiver or None'

    def validate_failed ( self, object, name, value ):
        """ Reports a rejected value, described either by its class name
            (old-style instances) or by its type and repr.
        """
        kind = type( value )
        if kind is InstanceType:
            msg = 'class %s' % value.__class__.__name__
        else:
            msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
        self.error( object, name, msg )

    def get_editor ( self, trait ):
        """ Returns the InstanceEditor for this handler, creating and caching
            it from the metadata of 'trait' on first use.
        """
        if self.editor is None:
            from matplotlib.enthought.traits.ui import InstanceEditor
            self.editor = InstanceEditor( label = trait.label or '',
                                          view  = trait.view  or '',
                                          kind  = trait.kind  or 'live' )
        return self.editor
        
class TraitThisClass ( ThisClass ):
    """ Deprecated alias for 'ThisClass'; issues a warning each time it is
        instantiated.
    """

    def __init__ ( self, or_none = 0 ):
        message = ( "Use of 'TraitThisClass' is deprecated, use 'ThisClass' "
                    "instead." )
        warning( message )
        ThisClass.__init__( self, or_none )

#-------------------------------------------------------------------------------
#  'TraitInstance' class:
#-------------------------------------------------------------------------------

class TraitInstance ( ThisClass ):
    """ Accepts instances of a given class, and optionally None.

        The class may be given as a class, as an instance (its class is
        used), or as a string naming the class. A string is resolved lazily,
        the first time a value is validated or a default value is created.
    """

    def __init__ ( self, aClass, or_none = False, module = '' ):
        """ 'aClass' is the class, an instance of it, or its name;
            'or_none' says whether None is acceptable; 'module' names the
            module used to look up a class name that has no module prefix.
        """
        # Support the argument order '( None, aClass )' as well:
        if aClass is None:
            aClass, or_none = or_none, aClass
        self.or_none = (or_none != False)
        self.module  = module
        if type( aClass ) is str:
            # Leave the name unresolved for now; 'fast_validate' is only set
            # once the real class is known:
            self.aClass = aClass
        else:
            if not isinstance( aClass, ClassTypes ):
                aClass = aClass.__class__
            self.aClass = aClass
            self.set_fast_validate()

    def allow_none ( self ):
        """ Makes None an acceptable value.
        """
        self.or_none = True
        if hasattr( self, 'fast_validate' ):
            self.set_fast_validate()

    def set_fast_validate ( self ):
        """ Rebuilds 'fast_validate' from the current class and 'or_none'
            setting.
        """
        fast_validate = [ 1, self.aClass ]
        if self.or_none:
            fast_validate = [ 1, None, self.aClass ]
        if self.aClass in TypeTypes:
            fast_validate[0] = 0
        self.fast_validate = tuple( fast_validate )

    def validate ( self, object, name, value ):
        """ Returns 'value' if it is an instance of the class (or None, when
            allowed); otherwise reports an error. A class given by name is
            resolved first.
        """
        if type( self.aClass ) is str:
            self.resolve_class( object, name, value )
        if (isinstance( value, self.aClass ) or
            (self.or_none and (value is None))):
            return value
        self.validate_failed( object, name, value )

    def info ( self ):
        """ Returns a description of the accepted values.
        """
        aClass = self.aClass
        if type( aClass ) is not str:
            aClass = aClass.__name__
        result = class_of( aClass )
        # 'or_none' is always a bool, so test its truth value. The previous
        # 'is None' test could never succeed and ' or None' was never shown:
        if self.or_none:
            return result + ' or None'
        return result

    def resolve_class ( self, object, name, value ):
        """ Replaces the class name with the real class and passes the
            resulting 'fast_validate' to the trait named 'name'. Reports
            'value' as invalid if the class cannot be found.
        """
        aClass = self.find_class()
        if aClass is None:
            self.validate_failed( object, name, value )
        self.aClass = aClass
        self.set_fast_validate()
        trait = object.base_trait( name )
        trait.validate( self.fast_validate )

    def find_class ( self ):
        """ Returns the class named by 'self.aClass', or None if it cannot be
            located.

            A name of the form 'package.module.Class' is looked up in (and,
            if necessary, imported from) 'package.module'; a bare name is
            looked up in the already imported module 'self.module'.
        """
        module = self.module
        aClass = self.aClass
        col    = aClass.rfind( '.' )
        if col >= 0:
            module = aClass[ : col ]
            aClass = aClass[ col + 1: ]
        theClass = getattr( sys.modules.get( module ), aClass, None )
        if (theClass is None) and (col >= 0):
            # Deliberately best effort: any failure while importing simply
            # means the class could not be found.
            try:
                # '__import__' returns the top-level package for a dotted
                # name, so fetch the module that was actually requested from
                # 'sys.modules' instead of using the return value:
                __import__( module )
                theClass = getattr( sys.modules[ module ], aClass, None )
            except:
                pass
        return theClass

    def create_default_value ( self, *args, **kw ):
        """ Returns a new instance of the class, created with the given
            arguments. Raises a TraitError if the class was given by name and
            cannot be located.
        """
        aClass = self.aClass
        if type( aClass ) is str:
            aClass = self.find_class()
            if aClass is None:
                raise TraitError( 'Unable to locate class: ' + self.aClass )
        return aClass( *args, **kw )

#-------------------------------------------------------------------------------
#  'TraitClass' class:
#-------------------------------------------------------------------------------

class TraitClass ( TraitHandler ):
    """ Accepts classes that are subclasses of a given class. A class may
        also be specified by a string containing its name.
    """

    def __init__ ( self, aClass ):
        """ 'aClass' is the required base class; if an old-style instance is
            given, its class is used.
        """
        if type( aClass ) is InstanceType:
            aClass = aClass.__class__
        self.aClass = aClass

    def validate ( self, object, name, value ):
        """ Returns the class denoted by 'value' if it is a subclass of the
            required class; otherwise reports an error.

            A string of the form 'module.Class' is resolved by importing
            'module' if necessary; a string without a '.' is looked up in
            this module's globals.
        """
        try:
            if type( value ) == StringType:
                value = value.strip()
                col   = value.rfind( '.' )
                if col >= 0:
                    module_name = value[:col]
                    class_name  = value[col + 1:]
                    module      = sys.modules.get( module_name )
                    if module is None:
                        # Import by name instead of building an 'import'
                        # statement for 'exec', which would run whatever
                        # code the supplied string happened to contain:
                        __import__( module_name )
                        module = sys.modules[ module_name ]
                    value = getattr( module, class_name )
                else:
                    value = globals().get( value )

            if issubclass( value, self.aClass ):
                return value
        except:
            # Any failure (unknown module or class, non-class value, ...)
            # means the value is not acceptable:
            pass

        self.error( object, name, self.repr( value ) )

    def info ( self ):
        """ Returns a description of the accepted values.
        """
        return 'a subclass of ' + self.aClass.__name__

#-------------------------------------------------------------------------------
#  'TraitFunction' class:
#-------------------------------------------------------------------------------

class TraitFunction ( TraitHandler ):
    """ Delegates validation to a user supplied function that is called as
        'function( object, name, value )'.
    """

    def __init__ ( self, aFunc ):
        """ Raises a TraitError unless 'aFunc' is one of 'CallableTypes'.
        """
        if not isinstance( aFunc, CallableTypes ):
            raise TraitError( "Argument must be callable." )
        self.aFunc         = aFunc
        self.fast_validate = ( 13, aFunc )

    def validate ( self, object, name, value ):
        """ Returns the function's result; a TraitError raised by the
            function is replaced by this handler's standard error.
        """
        validator = self.aFunc
        try:
            return validator( object, name, value )
        except TraitError:
            self.error( object, name, self.repr( value ) )

    def info ( self ):
        """ Returns the function's 'info' attribute if it has one, else its
            doc string if that is non-empty, else a generic description.
        """
        func = self.aFunc
        try:
            return func.info
        except:
            pass
        return func.__doc__ or 'a legal value'

#-------------------------------------------------------------------------------
#  'TraitEnum' class:
#-------------------------------------------------------------------------------

class TraitEnum ( TraitHandler ):
    """ Accepts any value contained in a fixed collection of values.
    """

    def __init__ ( self, *values ):
        """ The legal values may be given as separate arguments or as a
            single sequence.
        """
        single_sequence = ((len( values ) == 1) and
                           (type( values[0] ) in SequenceTypes))
        if single_sequence:
            values = values[0]
        legal              = tuple( values )
        self.values        = legal
        self.fast_validate = ( 5, legal )

    def validate ( self, object, name, value ):
        """ Returns 'value' if it is one of the legal values; otherwise
            reports an error.
        """
        legal = self.values
        if value in legal:
            return value
        self.error( object, name, self.repr( value ) )

    def info ( self ):
        """ Returns the legal values, separated by ' or '.
        """
        descriptions = [ repr( item ) for item in self.values ]
        return ' or '.join( descriptions )

    def get_editor ( self, trait ):
        """ Returns a new EnumEditor for the legal values.
        """
        from matplotlib.enthought.traits.ui import EnumEditor
        columns = trait.cols or 3
        return EnumEditor( values = self, cols = columns )

#-------------------------------------------------------------------------------
#  'TraitPrefixList' class:
#-------------------------------------------------------------------------------

class TraitPrefixList ( TraitHandler ):
    """ Accepts any of a fixed collection of values, or any prefix that
        matches exactly one of them; the full value is what gets returned.
    """

    def __init__ ( self, *values ):
        """ The legal values may be given as separate arguments or as a
            single sequence.
        """
        if (len( values ) == 1) and (type( values[0] ) in SequenceTypes):
            values = values[0]
        self.values  = values[:]
        # Maps every accepted input to its full value. It starts out with
        # just the full values and is extended by 'validate':
        self.values_ = dict( [ ( key, key ) for key in values ] )
        self.fast_validate = ( 10, self.values_, self.validate )

    def validate ( self, object, name, value ):
        """ Returns the full value corresponding to 'value', which must be a
            legal value or an unambiguous prefix of one; otherwise reports an
            error.
        """
        try:
            known = self.values_
            if value not in known:
                size  = len( value )
                found = None
                for key in self.values:
                    if value == key[ : size ]:
                        if found is not None:
                            # A second match makes the prefix ambiguous:
                            found = None
                            break
                        found = key
                if found is None:
                    self.error( object, name, self.repr( value ) )
                # Remember the result so the next lookup is direct:
                known[ value ] = found
            return known[ value ]
        except:
            self.error( object, name, self.repr( value ) )

    def info ( self ):
        """ Returns a description of the accepted values.
        """
        choices = ' or '.join( [ repr( item ) for item in self.values ] )
        return choices + ' (or any unique prefix)'

    def get_editor ( self, trait ):
        """ Returns a new EnumEditor for the legal values.
        """
        from matplotlib.enthought.traits.ui import EnumEditor
        columns = trait.cols or 3
        return EnumEditor( values = self, cols = columns )

    def __getstate__ ( self ):
        """ Returns the instance state without 'fast_validate' (which holds
            a reference to a bound method of this handler).
        """
        state = self.__dict__.copy()
        try:
            del state[ 'fast_validate' ]
        except KeyError:
            pass
        return state

#-------------------------------------------------------------------------------
#  'TraitMap' class:
#-------------------------------------------------------------------------------

class TraitMap ( TraitHandler ):
    """ Accepts any key of a mapping; the corresponding mapped value is
        stored in a shadow attribute whose name is the trait name followed
        by '_'.
    """

    is_mapped = True

    def __init__ ( self, map ):
        """ 'map' is the mapping whose keys are the legal values.
        """
        self.map           = map
        self.fast_validate = ( 6, map )

    def validate ( self, object, name, value ):
        """ Returns 'value' if it is a key of the mapping; otherwise reports
            an error (values that cannot be looked up count as invalid).
        """
        try:
            is_key = self.map.has_key( value )
            if is_key:
                return value
        except:
            pass
        self.error( object, name, self.repr( value ) )

    def mapped_value ( self, value ):
        """ Returns the value that 'value' maps to.
        """
        return self.map[ value ]

    def post_setattr ( self, object, name, value ):
        """ Stores the mapped value in the shadow attribute. Any failure is
            reported as a TraitError.
        """
        try:
            mapped = self.mapped_value( value )
            setattr( object, name + '_', mapped )
        except:
            # We don't need a fancy error message, because this exception
            # should always be caught by a TraitCompound handler:
            raise TraitError( 'Unmappable' )

    def info ( self ):
        """ Returns the sorted reprs of the keys, separated by ' or '.
        """
        descriptions = [ repr( key ) for key in self.map.keys() ]
        descriptions.sort()
        return ' or '.join( descriptions )

    def get_editor ( self, trait ):
        """ Returns a new EnumEditor for the legal values.
        """
        from matplotlib.enthought.traits.ui import EnumEditor
        columns = trait.cols or 3
        return EnumEditor( values = self, cols = columns )

#-------------------------------------------------------------------------------
#  'TraitPrefixMap' class:
#-------------------------------------------------------------------------------

class TraitPrefixMap ( TraitMap ):
    """ A 'TraitMap' that also accepts any prefix matching exactly one key
        of the mapping; the full key is what gets returned.
    """

    def __init__ ( self, map ):
        """ 'map' is the mapping whose keys are the legal values.
        """
        self.map  = map
        # Maps every accepted input to its full key. It starts out with just
        # the keys themselves and is extended by 'validate':
        self._map = dict( [ ( key, key ) for key in map.keys() ] )
        self.fast_validate = ( 10, self._map, self.validate )

    def validate ( self, object, name, value ):
        """ Returns the full key corresponding to 'value', which must be a
            key or an unambiguous prefix of one; otherwise reports an error.
        """
        try:
            known = self._map
            if value not in known:
                size  = len( value )
                found = None
                for key in self.map.keys():
                    if value == key[ : size ]:
                        if found is not None:
                            # A second match makes the prefix ambiguous:
                            found = None
                            break
                        found = key
                if found is None:
                    self.error( object, name, self.repr( value ) )
                # Remember the result so the next lookup is direct:
                known[ value ] = found
            return known[ value ]
        except:
            self.error( object, name, self.repr( value ) )

    def info ( self ):
        """ Returns a description of the accepted values.
        """
        key_info = TraitMap.info( self )
        return key_info + ' (or any unique prefix)'

#-------------------------------------------------------------------------------
#  'TraitCompound' class:
#-------------------------------------------------------------------------------

class TraitCompound ( TraitHandler ):
    """ Accepts any value that is accepted by at least one of a collection
        of handlers; the handlers are tried in the order given.
    """

    def __init__ ( self, *handlers ):
        """ The handlers may be given as separate arguments or as a single
            sequence.
        """
        if (len( handlers ) == 1) and (type( handlers[0] ) in SequenceTypes):
            handlers = handlers[0]

        self.handlers       = handlers
        self.validates      = fast_checks = []
        self.slow_validates = slow_checks = []
        self.reversable     = True
        fast_specs          = []
        mapped              = []
        setters             = []

        for handler in handlers:
            # Handlers with a 'fast_validate' go on the first list that
            # 'validate' tries; the rest are only tried by 'slow_validate':
            if not hasattr( handler, 'fast_validate' ):
                slow_checks.append( handler.validate )
            else:
                fast_checks.append( handler.validate )
                spec = handler.fast_validate
                if spec[0] != 7:
                    # An ordinary validator is added as a whole:
                    fast_specs.append( spec )
                else:
                    # A nested compound validator is flattened into our list:
                    fast_specs.extend( spec[1] )

            if hasattr( handler, 'post_setattr' ):
                setters.append( handler.post_setattr )

            if not handler.is_mapped:
                self.reversable = False
            else:
                self.is_mapped = True
                mapped.append( handler )
                self.mapped_handlers = mapped

            if handler.has_items:
                self.has_items = True

        # With at least one fast validator, build a composite one. If there
        # are also slow validators, a trailing entry referring back to this
        # handler is added so that they get their turn as well:
        if fast_specs:
            if slow_checks:
                fast_specs.append( ( 8, self ) )
            self.fast_validate = ( 7, tuple( fast_specs ) )

        if setters:
            self.post_setattrs = setters
            self.post_setattr  = self._post_setattr

    def validate ( self, object, name, value ):
        """ Returns the result of the first handler that accepts 'value',
            trying the handlers in 'validates' before the slow ones.
        """
        for check in self.validates:
            try:
                return check( object, name, value )
            except TraitError:
                pass
        return self.slow_validate( object, name, value )

    def slow_validate ( self, object, name, value ):
        """ Returns the result of the first slow handler that accepts
            'value'; reports an error if none does.
        """
        for check in self.slow_validates:
            try:
                return check( object, name, value )
            except TraitError:
                pass
        self.error( object, name, self.repr( value ) )

    def info ( self ):
        """ Returns the descriptions of all handlers, separated by ' or '.
        """
        descriptions = [ handler.info() for handler in self.handlers ]
        return ' or '.join( descriptions )

    def mapped_value ( self, value ):
        """ Returns the mapped value from the first mapped handler that can
            supply one, or 'value' itself if none can.
        """
        for handler in self.mapped_handlers:
            try:
                return handler.mapped_value( value )
            except:
                pass
        return value

    def _post_setattr ( self, object, name, value ):
        """ Runs the handlers' 'post_setattr' methods until one succeeds. If
            every one raises a TraitError, 'value' itself is stored in the
            shadow attribute (the trait name followed by '_').
        """
        for setter in self.post_setattrs:
            try:
                setter( object, name, value )
                return
            except TraitError:
                pass
        setattr( object, name + '_', value )

    def get_editor ( self, trait ):
        """ Returns a CompoundEditor combining the editors of all handlers,
            keeping only the first of any editors that a plain TextEditor is
            an instance of.
        """
        from matplotlib.enthought.traits.ui import TextEditor, CompoundEditor

        candidates = [ handler.get_editor( trait )
                       for handler in self.handlers ]
        probe      = TextEditor()
        seen_text  = False
        editors    = []
        for editor in candidates:
            if isinstance( probe, editor.__class__ ):
                if seen_text:
                    continue
                seen_text = True
            editors.append( editor )
        return CompoundEditor( editors = editors )

#-------------------------------------------------------------------------------
#  'TraitComplex' class:  (deprecated, will be removed)
#-------------------------------------------------------------------------------

class TraitComplex ( TraitCompound ):
    """ Deprecated alias for 'TraitCompound'; issues a warning each time it
        is instantiated.
    """

    def __init__ ( self, *handlers ):
        # The message previously began with a stray, unbalanced quote; it now
        # reads like the 'TraitThisClass' deprecation message:
        warning(
           "Use of 'TraitComplex' is deprecated, use 'TraitCompound' instead." )
        TraitCompound.__init__( self, *handlers )

#-------------------------------------------------------------------------------
#  'TraitTuple' class:
#-------------------------------------------------------------------------------

class TraitTuple ( TraitHandler ):
    """ Trait handler for fixed-length tuples in which each position has its
        own trait.
    """
    
    def __init__ ( self, *args ):
        """ Creates the handler; each positional argument describes the trait
            for the corresponding tuple element and is converted with
            'trait_from'.
        """
        self.traits = tuple( [ trait_from( arg ) for arg in args ] )
        # presumably consumed by the C-level validator -- TODO confirm the
        # meaning of code 9 against ctraits:
        self.fast_validate = ( 9, self.traits )
 
    def validate ( self, object, name, value ):
        """ Returns a new tuple holding the validated elements of 'value'.

            'value' must be a tuple with exactly one element per trait. Each
            element is validated by the handler of the trait at the same
            position; an element whose trait has no handler accepts any value
            and is passed through unchanged. Anything else is reported
            through 'self.error'.
        """
        try:
            if isinstance( value, tuple ):
                traits = self.traits
                if len( value ) == len( traits ):
                    values = []
                    for i, trait in enumerate( traits ):
                        handler = trait.handler
                        if handler is None:
                            # No handler means 'any value' (see '_trait_info'),
                            # so the element is accepted as is:
                            values.append( value[i] )
                        else:
                            values.append( handler.validate( object, name,
                                                             value[i] ) )
                    return tuple( values )
        except:
            # Any failure while validating an element is reported below as a
            # failure of the tuple as a whole:
            pass
        self.error( object, name, self.repr( value ) ) 
 
    def info ( self ):
        """ Returns a description of the accepted tuple, listing the
            description of each element trait in order.
        """
        return 'a tuple of the form: (%s)' % (', '.join( 
               [ self._trait_info( trait ) for trait in self.traits ] ))
               
    def _trait_info ( self, trait ):
        """ Returns the description of a single element trait ('any value'
            when the trait has no handler).
        """
        handler = trait.handler
        if handler is None:
            return 'any value'
        else:
            return handler.info()
            
    def get_editor ( self, trait ):
        """ Returns a TupleEditor for the element traits, using the 'labels'
            and 'cols' attributes of 'trait' (defaulting to no labels and a
            single column).
        """
        from matplotlib.enthought.traits.ui import TupleEditor
        return TupleEditor( traits = self.traits,
                            labels = trait.labels or [],
                            cols   = trait.cols   or 1  )
   
#-------------------------------------------------------------------------------
#  'TraitAny' class: (This class is deprecated)
#-------------------------------------------------------------------------------

class TraitAny ( TraitHandler ):
    """ Trait handler that accepts every value (deprecated). """

    def validate ( self, object, name, value ):
        """ Accepts 'value' unconditionally and returns it unchanged. """
        return value

    def info ( self ):
        """ Returns the description of the accepted values. """
        return 'any value'
   
#-------------------------------------------------------------------------------
#  'TraitCallable' class:
#-------------------------------------------------------------------------------

class TraitCallable ( TraitHandler ):
    """ Trait handler that accepts only callable values. """

    def validate ( self, object, name, value ):
        """ Returns 'value' if it is callable; otherwise reports the problem
            through 'self.error'.
        """
        if not callable( value ):
            self.error( object, name, self.repr( value ) )
            return None
        return value

    def info ( self ):
        """ Returns the description of the accepted values. """
        return 'a callable value'
       
#-------------------------------------------------------------------------------
#  'TraitListEvent' class:
#-------------------------------------------------------------------------------
       
class TraitListEvent:
    """ Describes a single change made to a list: the index at which the
        change occurred, the items removed and the items added.
    """
    
    #---------------------------------------------------------------------------
    #  Initialize the object:
    #---------------------------------------------------------------------------
    
    def __init__ ( self, index, removed = None, added = None ):
        """ Records the change. An omitted (or None) 'removed' or 'added'
            argument is stored as a new empty list.
        """
        if removed is None:
            removed = []
        if added is None:
            added = []
        self.index   = index
        self.removed = removed
        self.added   = added

#-------------------------------------------------------------------------------
#  'TraitList' class:
#-------------------------------------------------------------------------------

class TraitList ( TraitHandler ):
    """ Trait handler for lists whose length lies within given bounds and
        whose items are (optionally) validated by an item trait.
    """

    info_trait         = None
    # presumably a default-value kind code understood by the traits core --
    # TODO confirm:
    default_value_type = 5
    # Shared 'items' event trait, created on first use by 'items_event':
    _items_event       = None 
    
    def __init__ ( self, trait = None, minlen = 0, maxlen = sys.maxint,
                         has_items = True ):
        """ Creates the handler.

            trait     -- describes each list item (converted with 'trait_from')
            minlen    -- minimum list length (negative values are treated as 0)
            maxlen    -- maximum list length (never less than the effective
                         minimum length)
            has_items -- whether list objects created by this handler report
                         their changes through a '<name>_items' attribute
        """
        self.item_trait = trait_from( trait )
        self.minlen     = max( 0, minlen )
        # Clamp against the effective (non-negative) minimum, so that the
        # bounds can always be satisfied by some list length:
        self.maxlen     = max( self.minlen, maxlen )
        self.has_items  = has_items
 
    def validate ( self, object, name, value ):
        """ Returns a TraitListObject wrapping 'value' if it is a list whose
            length is within bounds; otherwise reports the problem through
            'self.error'.
        """
        if (isinstance( value, list ) and
           (self.minlen <= len( value ) <= self.maxlen)):
            return TraitListObject( self, object, name, value )
        self.error( object, name, self.repr( value ) ) 
 
    def info ( self ):
        """ Returns a description of the accepted lists, including the length
            bounds and (if there is an item handler) the item description.
        """
        if self.minlen == 0:
            if self.maxlen == sys.maxint:
                size = 'items'
            else:
                size = 'at most %d items' % self.maxlen
        else:
            if self.maxlen == sys.maxint:
                size = 'at least %d items' % self.minlen
            else:
                size = 'from %d to %d items' % ( 
                       self.minlen, self.maxlen )
        handler = self.item_trait.handler
        if handler is None:
            info = ''
        else:
            info = ' which are %s' % handler.info()
        return 'a list of %s%s' % ( size, info ) 
               
    def get_editor ( self, trait ):
        """ Returns a ListEditor configured from the 'rows', 'use_notebook'
            and 'page_name' attributes of 'trait'.
        """
        from matplotlib.enthought.traits.ui import ListEditor
        return ListEditor( trait_handler = self, 
                           rows          = trait.rows or 5,
                           use_notebook  = trait.use_notebook is True,
                           page_name     = trait.page_name or '' )
        
    def items_event ( self ):
        """ Returns the 'items' event trait shared by all TraitList handlers,
            creating it the first time it is requested.
        """
        if TraitList._items_event is None:
            TraitList._items_event = Event( TraitListEvent, is_base = False )
        return TraitList._items_event
       
#-------------------------------------------------------------------------------
#  'TraitListObject' class:
#-------------------------------------------------------------------------------

class TraitListObject ( list ):
    
    def __init__ ( self, trait, object, name, value ):
        self.trait      = trait
        self.object     = object
        self.name       = name
        self.name_items = None
        if trait.has_items:
            self.name_items = name + '_items'
          
        # Do the validated 'setslice' assignment without raising an 
        # 'items_changed' event:
        if trait.minlen <= len( value ) <= trait.maxlen:
            try:
                handler = trait.item_trait.handler
                if handler is not None:
                    validate = handler.validate
                    value    = [ validate( object, name, val ) for val in value]
                list.__setslice__( self, 0, 0, value )
                return
            except TraitError, excp:
                excp.set_prefix( 'Each element of the' )
                raise excp
        self.len_error( len( value ) )
        
    def __deepcopy__ ( self, memo ):
        id_self = id( self )
        if id_self in memo:
            return memo[ id_self ]
        memo[ id_self ] = result = TraitListObject( self.trait, self.object, 
                        self.name, [ copy.deepcopy( x, memo ) for x in self ]  )          
        return result

    def __setitem__ ( self, key, value ):
        try:
            removed = [ self[ key ] ]
        except:
            pass
        try:
            handler = self.trait.item_trait.handler
            if handler is not None:
                value = handler.validate( self.object, self.name, value )
            list.__setitem__( self, key, value )
            if self.name_items is not None:
                if key < 0:
                    key = len( self ) + key
                setattr( self.object, self.name_items, 
                         TraitListEvent( key, removed, [ value ] ) )
        except TraitError, excp:
            excp.set_prefix( 'Each element of the' )
            raise excp
 
    def __setslice__ ( self, i, j, values ):
        try:
            delta = len( values ) - (min( j, len( self ) ) - max( 0, i ))
        except:
            raise TypeError, 'must assign sequence (not "%s") to slice' % ( 
                             values.__class__.__name__ )
        if self.trait.minlen <= (len(self) + delta) <= self.trait.maxlen:
            try:
                object  = self.object
                name    = self.name
                trait   = self.trait.item_trait
                removed = self[ i: j ]
                handler = trait.handler
                if handler is not None:
                    validate = handler.validate
                    values   = [ validate( object, name, value )
                                 for value in values ]
                list.__setslice__( self, i, j, values )
                if self.name_items is not None:
                   setattr( self.object, self.name_items, 
                            TraitListEvent( max( 0, i ), removed, values ) )
                return
            except TraitError, excp:
                excp.set_prefix( 'Each element of the' )
                raise excp
        self.len_error( len( self ) + delta )
    
    def __delitem__ ( self, key ):
        if self.trait.minlen <= (len( self ) - 1):
            try:
                removed = [ self[ key ] ]
            except:
                pass
            list.__delitem__( self, key )
            if self.name_items is not None:
                if key < 0:
                    key = len( self ) + key + 1
                setattr( self.object, self.name_items, 
                         TraitListEvent( key, removed ) )
            return
        self.len_error( len( self ) - 1 )
    
    def __delslice__ ( self, i, j ):
        delta = min( j, len( self ) ) - max( 0, i )
        if self.trait.minlen <= (len( self ) - delta):
            removed = self[ i: j ]
            list.__delslice__( self, i, j )
            if self.name_items is not None:
                setattr( self.object, self.name_items, 
                         TraitListEvent( max( 0, i ), removed ) )
            return
        self.len_error( len( self ) - delta )
        
    def append ( self, value ):
        if self.trait.minlen <= (len( self ) + 1) <= self.trait.maxlen:
            try:
                handler = self.trait.item_trait.handler
                if handler is not None:
                    value = handler.validate( self.object, self.name, value )
                list.append( self, value )
                if self.name_items is not None:
                    setattr( self.object, self.name_items, 
                            TraitListEvent( len( self ) - 1, None, [ value ] ) )
                return
            except TraitError, excp:
                excp.set_prefix( 'Each element of the' )
                raise excp
        self.len_error( len( self ) + 1 )
        
    def insert ( self, index, value ):
        if self.trait.minlen <= (len( self ) + 1) <= self.trait.maxlen:
            try:
                handler = self.trait.item_trait.handler
                if handler is not None:
                    value = handler.validate( self.object, self.name, value )
                list.insert( self, index, value )
                if self.name_items is not None:
                    if index < 0:
                        index = len( self ) + index - 1
                    setattr( self.object, self.name_items, 
                             TraitListEvent( index, None, [ value ] ) )
                return
            except TraitError, excp:
                excp.set_prefix( 'Each element of the' )
                raise excp
        self.len_error( len( self ) + 1 )
        
    def extend ( self, xlist ):
        try:
            len_xlist = len( xlist )
        except:
            raise TypeError, "list.extend() argument must be iterable"
        if (self.trait.minlen <= (len( self ) + len_xlist) <= 
            self.trait.maxlen):
            object  = self.object
            name    = self.name
            handler = self.trait.item_trait.handler
            try:
                if handler is not None:
                    validate = handler.validate
                    xlist    = [ validate( object, name, value )
                                 for value in xlist ]
                list.extend( self, xlist )
                if self.name_items is not None:
                    setattr( self.object, self.name_items, 
                             TraitListEvent( len( self ) - len( xlist ), None, 
                                             xlist ) )
                return
            except TraitError, excp:
                excp.set_prefix( 'The elements of the' )
                raise excp
        self.len_error( len( self ) + len( xlist ) )
        
    def remove ( self, value ):
        if self.trait.minlen < len( self ):
            try:
                index   = self.index( value )
                removed = [ self[ index ] ]
            except:
                pass
            list.remove( self, value )
            if self.name_items is not None:
                setattr( self.object, self.name_items, 
                         TraitListEvent( index, removed ) )
        else:
            self.len_error( len( self ) - 1 )
        
    def len_error ( self, len ):
        raise TraitError, ( "The '%s' trait of %s instance must be %s, "
                  "but you attempted to change its length to %d element%s." % (
                  self.name, class_of( self.object ), self.trait.info(),
                  len, 's'[ len == 1: ] ) )
                  
    def sort ( self, cmpfunc = None ):
        removed = self[:]
        list.sort( self, cmpfunc )
        if self.name_items is not None:
            setattr( self.object, self.name_items, 
                     TraitListEvent( 0, removed, self[:] ) )
                  
    def reverse ( self ):
        removed = self[:]
        if len( self ) > 1:
            list.reverse( self )
            if self.name_items is not None:
                setattr( self.object, self.name_items, 
                         TraitListEvent( 0, removed, self[:] ) )
                
    def pop ( self, *args ):
        if self.trait.minlen < len( self ):
            if len( args ) > 0:
                index = args[0]
            else:
                index = -1
            try:
                removed = [ self[ index ] ]
            except:
                pass
            result = list.pop( self, *args )
            if self.name_items is not None:
                if index < 0:
                    index = len( self ) + index + 1
                setattr( self.object, self.name_items, 
                         TraitListEvent( index, removed ) )
            return result
        else:
            self.len_error( len( self ) - 1 )
    
    def __getstate__ ( self ):
        result = self.__dict__.copy()
        del result[ 'trait' ]
        return result
        
    def __setstate__ ( self, state ):
        self.__dict__.update( state )
        self.trait = self.object._trait( self.name, 0 ).handler

#-------------------------------------------------------------------------------
#  'TraitDict' class:
#-------------------------------------------------------------------------------

class TraitDict ( TraitHandler ):
    """ Trait handler for dictionaries whose keys and values are (optionally)
        validated by a key trait and a value trait.
    """
    
    info_trait         = None
    # presumably a default-value kind code understood by the traits core --
    # TODO confirm:
    default_value_type = 6
    # Shared 'items' event trait, created on first use by 'items_event':
    _items_event       = None
 
    def __init__ ( self, key_trait = None, value_trait = None, 
                         has_items = True ):
        """ Creates the handler.

            key_trait   -- describes each key (converted with 'trait_from')
            value_trait -- describes each value (converted with 'trait_from')
            has_items   -- whether dictionary objects created by this handler
                           report their changes through a '<name>_items'
                           attribute
        """
        self.has_items   = has_items
        self.key_trait   = trait_from( key_trait )
        self.value_trait = trait_from( value_trait )
 
    def validate ( self, object, name, value ):
        """ Returns a TraitDictObject wrapping 'value' if it is a dictionary;
            otherwise reports the problem through 'self.error'.
        """
        if not isinstance( value, dict ):
            self.error( object, name, self.repr( value ) ) 
            return None
        return TraitDictObject( self, object, name, value )
 
    def info ( self ):
        """ Returns a description of the accepted dictionaries, mentioning
            the key and value descriptions for whichever of the two traits
            has a handler.
        """
        clauses     = []
        key_handler = self.key_trait.handler
        if key_handler is not None:
            clauses.append( 'keys which are %s' % key_handler.info() )
        value_handler = self.value_trait.handler
        if value_handler is not None:
            clauses.append( 'values which are %s' % value_handler.info() )
        if len( clauses ) == 0:
            return 'a dictionary'
        return 'a dictionary with ' + ' and '.join( clauses )
               
    def get_editor ( self, trait ):
        """ Returns the text editor for the dictionary, creating and caching
            it in 'self.editor' the first time it is needed.
        """
        editor = self.editor
        if editor is None:
            from matplotlib.enthought.traits.ui import TextEditor
            # NOTE(review): the editor is handed the builtin 'eval';
            # presumably it is applied to the text typed by the user, which
            # would be unsafe for untrusted input -- confirm.
            editor = self.editor = TextEditor( evaluate = eval )
        return editor
        
    def items_event ( self ):
        """ Returns the 'items' event trait shared by all TraitDict handlers,
            creating it the first time it is requested.
        """
        event = TraitDict._items_event
        if event is None:
            event = TraitDict._items_event = Event( 0, is_base = False )
        return event

#-------------------------------------------------------------------------------
#  'TraitDictObject' class:
#-------------------------------------------------------------------------------

class TraitDictObject ( dict ):
    
    def __init__ ( self, trait, object, name, value ):
        self.trait      = trait
        self.object     = object
        self.name       = name
        self.name_items = None
        if trait.has_items:
            self.name_items = name + '_items'
        if len( value ) > 0:
            dict.update( self, value )
        
    def __setitem__ ( self, key, value ):
        try:
            handler = self.trait.key_trait.handler
            if handler is not None:
                key = self.trait.key_trait.handler.validate( 
                                           self.object, self.name, key )
        except TraitError, excp:
            excp.set_prefix( 'Each key of the' )
            raise excp
        except AttributeError:
            # This is to handle the fact that on unpickling a serialized
            # dictionary, the items are simply stored in the dictionary before
            # '__setstate__' is called:
            dict.__setitem__( self, key, value ) 
            return
        try:                                          
            handler = self.trait.value_trait.handler
            if handler is not None:
                value = handler.validate( self.object, self.name, value )
            dict.__setitem__( self, key, value ) 
            if self.name_items is not None:
                setattr( self.object, self.name_items, 0 )
        except TraitError, excp:
            excp.set_prefix( 'Each value of the' )
            raise excp
    
    def __delitem__ ( self, key ):
        dict.__delitem__( self, key )
        if self.name_items is not None:
            setattr( self.object, self.name_items, 0 )
           
    def clear ( self ):
        if len( self ) > 0:
            dict.clear( self )
            if self.name_items is not None:
                setattr( self.object, self.name_items, 0 )
        
    def update ( self, dic ):
        if len( dic ) > 0:
            dict.update( self, dic )
            if self.name_items is not None:
                setattr( self.object, self.name_items, 0 )
      
    def setdefault ( self, key, value = None ):
        if self.has_key( key ):
            return self[ key ]
        self[ key ] = value
        return value
      
    def pop ( self, key, value = Undefined ):
        if (value is Undefined) or self.has_key( key ):
            result = dict.pop( key )
            if self.name_items is not None:
                setattr( self.object, self.name_items, 0 )
            return result
        return value
           
    def popitem ( self ):
        result = dict.popitem( self )
        if self.name_items is not None:
            setattr( self.object, self.name_items, 0 )
        return result
    
    def __getstate__ ( self ):
        result = self.__dict__.copy()
        del result[ 'trait' ]
        return result
        
    def __setstate__ ( self, state ):
        self.__dict__.update( state )
        self.trait = self.object._trait( self.name, 0 ).handler
            
#-------------------------------------------------------------------------------
#  Tell the C-based traits module about 'TraitListObject' and 'TraitDictObject': 
#-------------------------------------------------------------------------------
        
# This has to run after both classes have been defined, since the class
# objects themselves are what is handed to 'ctraits':
import ctraits
ctraits._list_classes( TraitListObject, TraitDictObject )


Generated by  Doxygen 1.6.0   Back to index