Introduction

This page is the technical reference for the interfaces and data structures you implement when integrating AI models with Dify through a model plugin.
Before diving into this API reference, we recommend reading Model Design Rules for the conceptual model and Creating a New Model Provider for a step-by-step walkthrough.

Quick Decision: Which Method Do I Implement?

If your model is a… | Implement
Chat/completion LLM | LargeLanguageModel._invoke, get_num_tokens
Embedding model | TextEmbeddingModel._invoke, get_num_tokens
Rerank model | RerankModel._invoke
Speech-to-text | Speech2TextModel._invoke
Text-to-speech | Text2SpeechModel._invoke
Moderation | ModerationModel._invoke
Every provider also implements validate_provider_credentials (provider-level auth) and, if the model is user-configurable, validate_credentials per model type.

Provider Implementation

Learn how to implement model provider classes for different AI service providers

Model Types

Implementation details for the six supported model types: LLM, Embedding, Rerank, Speech2Text, Text2Speech, and Moderation

Data Structures

Comprehensive reference for all data structures used in the model API

Error Handling

Guidelines for proper error mapping and exception handling

Model Provider

Every model provider must inherit from the __base.model_provider.ModelProvider base class and implement the credential validation interface.

Provider Credential Validation

def validate_provider_credentials(self, credentials: dict) -> None:
    """
    Validate provider credentials by making a test API call
    
    Parameters:
        credentials: Provider credentials as defined in `provider_credential_schema`
        
    Raises:
        CredentialsValidateFailedError: If validation fails
    """
    try:
        # Example implementation: validate using an LLM model instance
        model_instance = self.get_model_instance(ModelType.LLM)
        model_instance.validate_credentials(
            model="example-model", 
            credentials=credentials
        )
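    except Exception as ex:
        # logger.exception records the traceback; no f-string is needed here
        logger.exception("Credential validation failed")
        # Chain the original exception so the root cause stays visible
        raise CredentialsValidateFailedError(f"Invalid credentials: {ex}") from ex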
credentials (dict): Credential information as defined in the provider’s YAML configuration under provider_credential_schema, typically fields such as api_key and organization_id.
If validation fails, your implementation must raise a CredentialsValidateFailedError exception. This ensures proper error handling in the Dify UI.
For predefined model providers, implement a thorough validation method that verifies the credentials against your API. For custom model providers (where each model has its own credentials), a simplified implementation is sufficient.

Models

Dify supports six distinct model types (LLM, text embedding, rerank, speech-to-text, text-to-speech, and moderation), each with its own interface. All model types share the common requirements below.

Common Interfaces

Every model implementation, regardless of type, must implement these two members: a credential validation method and an error mapping property.

1. Model Credential Validation

def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate that the provided credentials work with the specified model
    
    Parameters:
        model: The specific model identifier (e.g., "gpt-4")
        credentials: Authentication details for the model
        
    Raises:
        CredentialsValidateFailedError: If validation fails
    """
    try:
        # Make a lightweight API call to verify credentials
        # Example: List available models or check account status
        response = self._api_client.validate_api_key(credentials["api_key"])
        
        # Verify the specific model is available if applicable
        if model not in response.get("available_models", []):
            raise CredentialsValidateFailedError(f"Model {model} is not available")
            
    except ApiException as e:
        # ApiException stands for your provider SDK's exception type
        raise CredentialsValidateFailedError(str(e)) from e
model (string, required): The specific model identifier to validate (e.g., “gpt-4”, “claude-3-opus”)
credentials (dict, required): Credential information as defined in the provider’s configuration

2. Error Mapping

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
    """
    Map provider-specific exceptions to standardized Dify error types
    
    Returns:
        Dictionary mapping Dify error types to lists of provider exception types
    """
    return {
        InvokeConnectionError: [
            requests.exceptions.ConnectionError,
            requests.exceptions.Timeout,
            ConnectionRefusedError
        ],
        InvokeServerUnavailableError: [
            ServiceUnavailableError,
            HTTPStatusError
        ],
        InvokeRateLimitError: [
            RateLimitExceededError,
            QuotaExceededError
        ],
        InvokeAuthorizationError: [
            AuthenticationError,
            InvalidAPIKeyError,
            PermissionDeniedError
        ],
        InvokeBadRequestError: [
            InvalidRequestError,
            ValidationError
        ]
    }
InvokeConnectionError (class): Network connection failures, timeouts
InvokeServerUnavailableError (class): Service provider is down or unavailable
InvokeRateLimitError (class): Rate limits or quota limits reached
InvokeAuthorizationError (class): Authentication or permission issues
InvokeBadRequestError (class): Invalid parameters or requests
You can alternatively raise these standardized error types directly in your code instead of relying on the error mapping. This approach gives you more control over error messages.
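As a rough illustration of what such a mapping enables, the sketch below translates a provider exception into a standardized error by walking the mapping. The error classes here are simplified stand-ins, and translate_error and ProviderRateLimit are hypothetical names; this page does not show how Dify's base class applies the mapping, so treat this as one possible reading rather than the actual implementation.

```python
class InvokeError(Exception):
    """Stand-in for Dify's InvokeError base class (simplified for illustration)."""


class InvokeConnectionError(InvokeError):
    pass


class InvokeRateLimitError(InvokeError):
    pass


class ProviderRateLimit(Exception):
    """Hypothetical exception raised by a provider SDK."""


ERROR_MAPPING: dict[type[InvokeError], list[type[Exception]]] = {
    InvokeConnectionError: [ConnectionError, TimeoutError],
    InvokeRateLimitError: [ProviderRateLimit],
}


def translate_error(error: Exception) -> Exception:
    """Return the mapped InvokeError, or a generic InvokeError if nothing matches."""
    for invoke_error, provider_errors in ERROR_MAPPING.items():
        # isinstance also matches subclasses, e.g. ConnectionRefusedError
        if isinstance(error, tuple(provider_errors)):
            return invoke_error(str(error))
    return InvokeError(f"Unmapped error: {error}")
```

Because the lookup uses isinstance, listing a broad exception class also covers its subclasses, so the mapping can stay short.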

LLM Implementation

To implement a Large Language Model provider, inherit from the __base.large_language_model.LargeLanguageModel base class and implement these methods:

1. Model Invocation

This core method handles both streaming and non-streaming API calls to language models.
def _invoke(
    self, 
    model: str, 
    credentials: dict,
    prompt_messages: list[PromptMessage], 
    model_parameters: dict,
    tools: Optional[list[PromptMessageTool]] = None, 
    stop: Optional[list[str]] = None,
    stream: bool = True, 
    user: Optional[str] = None
) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
    """
    Invoke the language model
    """
    # Prepare API parameters
    api_params = self._prepare_api_parameters(
        model, 
        credentials, 
        prompt_messages, 
        model_parameters,
        tools, 
        stop
    )
    
    try:
        # Choose between streaming and non-streaming implementation
        if stream:
            return self._invoke_stream(model, api_params, user)
        else:
            return self._invoke_sync(model, api_params, user)
            
    except Exception as e:
        # Map errors using the error mapping property
        self._handle_api_error(e)

# Helper methods for streaming and non-streaming calls
def _invoke_stream(self, model, api_params, user):
    # Implement streaming call and yield chunks
    pass
    
def _invoke_sync(self, model, api_params, user):
    # Implement synchronous call and return complete result
    pass
model (string, required): Model identifier (e.g., “gpt-4”, “claude-3”)
credentials (dict, required): Authentication credentials for the API
prompt_messages (list[PromptMessage], required): Message list in Dify’s standardized format:
  • For completion models: include a single UserPromptMessage.
  • For chat models: include SystemPromptMessage, UserPromptMessage, AssistantPromptMessage, and ToolPromptMessage as needed.
model_parameters (dict, required): Model-specific parameters (temperature, top_p, etc.) as defined in the model’s YAML configuration
tools (list[PromptMessageTool]): Tool definitions for function calling capabilities
stop (list[string]): Stop sequences that will halt model generation when encountered
stream (boolean, default: true): Whether to return a streaming response
user (string): User identifier for API monitoring
Returns, stream=True (Generator[LLMResultChunk, None, None]): A generator yielding chunks of the response as they become available
Returns, stream=False (LLMResult): A complete response object with the full generated text
We recommend implementing separate helper methods for streaming and non-streaming calls to keep your code organized and maintainable.
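One detail worth keeping in mind when the streaming helper is a generator: its body only runs while the caller iterates, so a try/except around the call that creates the generator does not see errors raised mid-stream. The sketch below shows one way to translate such errors inside the generator itself. All names here (provider_stream, invoke_stream, collect, and the InvokeError stand-in) are hypothetical and exist only for illustration.

```python
from typing import Iterator, Optional


class InvokeError(Exception):
    """Stand-in for Dify's InvokeError (simplified for illustration)."""


def provider_stream() -> Iterator[str]:
    """Hypothetical provider stream that fails after two pieces."""
    yield "Hel"
    yield "lo"
    raise ConnectionError("stream dropped")


def invoke_stream() -> Iterator[str]:
    """Wrap the iteration so errors raised mid-stream are translated."""
    try:
        for piece in provider_stream():
            yield piece
    except ConnectionError as e:
        raise InvokeError(str(e)) from e


def collect(stream: Iterator[str]) -> tuple[list[str], Optional[str]]:
    """Consume a stream, returning the pieces received and any error message."""
    received: list[str] = []
    error = None
    try:
        for piece in stream:
            received.append(piece)
    except InvokeError as e:
        error = str(e)
    return received, error
```

With this shape, the caller still receives every piece produced before the failure and then sees a standardized error rather than the raw provider exception.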

2. Token Counting

def get_num_tokens(
    self, 
    model: str, 
    credentials: dict, 
    prompt_messages: list[PromptMessage],
    tools: Optional[list[PromptMessageTool]] = None
) -> int:
    """
    Calculate the number of tokens in the prompt
    """
    # Convert prompt_messages to the format expected by the tokenizer
    text = self._convert_messages_to_text(prompt_messages)
    
    try:
        # Use the appropriate tokenizer for this model
        tokenizer = self._get_tokenizer(model)
        return len(tokenizer.encode(text))
    except Exception:
        # Fall back to a generic tokenizer
        return self._get_num_tokens_by_gpt2(text)
If the model doesn’t provide a tokenizer, you can use the base class’s _get_num_tokens_by_gpt2(text) method for a reasonable approximation.
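The token counting example relies on a helper that flattens messages into text. A minimal sketch of such a helper follows, assuming simplified stand-ins (Message and Content) for PromptMessage and PromptMessageContent. The "role: text" layout and the choice to skip image content are assumptions made for illustration, not Dify's behavior.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class Content:
    """Stand-in for PromptMessageContent (simplified)."""
    type: str  # "text" or "image"
    data: str


@dataclass
class Message:
    """Stand-in for PromptMessage (simplified)."""
    role: str
    content: Optional[Union[str, list[Content]]] = None


def convert_messages_to_text(messages: list[Message]) -> str:
    """Flatten messages into 'role: text' lines, skipping non-text content."""
    lines = []
    for message in messages:
        if isinstance(message.content, str):
            text = message.content
        elif message.content:
            # Multimodal content list: keep only the text parts
            text = " ".join(c.data for c in message.content if c.type == "text")
        else:
            text = ""
        lines.append(f"{message.role}: {text}")
    return "\n".join(lines)
```

Images usually contribute tokens as well, so a count derived from this text is a lower bound for multimodal prompts.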

3. Custom Model Schema (Optional)

def get_customizable_model_schema(
    self, 
    model: str, 
    credentials: dict
) -> Optional[AIModelEntity]:
    """
    Get parameter schema for custom models
    """
    # For fine-tuned models, you might return the base model's schema
    if model.startswith("ft:"):
        base_model = self._extract_base_model(model)
        return self._get_predefined_model_schema(base_model)
    
    # For standard models, return None to use the predefined schema
    return None
This method is only necessary for providers that support custom models. It allows custom models to inherit parameter rules from base models.
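The example above calls a helper to recover the base model from a fine-tuned identifier. A possible sketch is shown here, assuming identifiers of the form "ft:<base-model>:..." where the second colon-separated field is the base model. That layout is an assumption, so check it against your provider's naming scheme.

```python
def extract_base_model(model: str) -> str:
    """Return the base model name from an 'ft:<base>:...' identifier.

    Identifiers without the 'ft:' prefix are returned unchanged.
    """
    if not model.startswith("ft:"):
        return model
    parts = model.split(":")
    # parts[0] is "ft"; parts[1] is assumed to hold the base model name
    if not parts[1]:
        raise ValueError(f"Cannot extract base model from {model!r}")
    return parts[1]
```

Raising on a malformed identifier keeps a bad model name from silently resolving to an empty schema lookup.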

TextEmbedding Implementation

Text embedding models convert text into high-dimensional vectors that capture semantic meaning, which is useful for retrieval, similarity search, and classification.
To implement a Text Embedding provider, inherit from the __base.text_embedding_model.TextEmbeddingModel base class:

1. Core Embedding Method

def _invoke(
    self, 
    model: str, 
    credentials: dict,
    texts: list[str], 
    user: Optional[str] = None
) -> TextEmbeddingResult:
    """
    Generate embedding vectors for multiple texts
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    # Handle batching if needed
    batch_size = self._get_batch_size(model)
    all_embeddings = []
    total_tokens = 0
    start_time = time.time()
    
    # Process in batches to avoid API limits
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        
        # Make API call to the embeddings endpoint
        response = client.embeddings.create(
            model=model,
            input=batch,
            user=user
        )
        
        # Extract embeddings from response
        batch_embeddings = [item.embedding for item in response.data]
        all_embeddings.extend(batch_embeddings)
        
        # Track token usage
        total_tokens += response.usage.total_tokens
    
    # Calculate usage metrics
    elapsed_time = time.time() - start_time
    usage = self._create_embedding_usage(
        model=model,
        tokens=total_tokens,
        latency=elapsed_time
    )
    
    return TextEmbeddingResult(
        model=model,
        embeddings=all_embeddings,
        usage=usage
    )
model (string, required): Embedding model identifier
credentials (dict, required): Authentication credentials for the embedding service
texts (list[string], required): List of text inputs to embed
user (string): User identifier for API monitoring
Returns TextEmbeddingResult (object): A structured response containing:
  • model: The model used for embedding.
  • embeddings: Embedding vectors in the same order as the input texts.
  • usage: Metadata about token usage and costs.
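Since the embeddings must come back in input order, it can help to isolate the batching loop and check that each batch returns one vector per text. The sketch below does this with a hypothetical embed_batch callable that returns the vectors and a token count. fake_embed_batch is a toy stand-in for a real API call, not a real embedding.

```python
from typing import Callable


def embed_in_batches(
    texts: list[str],
    batch_size: int,
    embed_batch: Callable[[list[str]], tuple[list[list[float]], int]],
) -> tuple[list[list[float]], int]:
    """Embed texts batch by batch, preserving input order and summing tokens."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    vectors: list[list[float]] = []
    total_tokens = 0
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        batch_vectors, tokens = embed_batch(batch)
        # One vector per input text, otherwise ordering can no longer be trusted
        if len(batch_vectors) != len(batch):
            raise ValueError("Provider returned a different number of vectors than inputs")
        vectors.extend(batch_vectors)
        total_tokens += tokens
    return vectors, total_tokens


def fake_embed_batch(batch: list[str]) -> tuple[list[list[float]], int]:
    """Toy embedder: one-dimensional 'vector' holding the text length."""
    return [[float(len(t))] for t in batch], sum(len(t.split()) for t in batch)
```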

2. Token Counting Method

def get_num_tokens(
    self, 
    model: str, 
    credentials: dict, 
    texts: list[str]
) -> int:
    """
    Calculate the number of tokens in the texts to be embedded
    """
    # Join all texts to estimate token count
    combined_text = " ".join(texts)
    
    try:
        # Use the appropriate tokenizer for this model
        tokenizer = self._get_tokenizer(model)
        return len(tokenizer.encode(combined_text))
    except Exception:
        # Fall back to a generic tokenizer
        return self._get_num_tokens_by_gpt2(combined_text)
For embedding models, accurate token counting is important for cost estimation, but not critical for functionality. The _get_num_tokens_by_gpt2 method provides a reasonable approximation for most models.

Rerank Implementation

Reranking models help improve search quality by re-ordering a set of candidate documents based on their relevance to a query, typically after an initial retrieval phase.
To implement a Reranking provider, inherit from the __base.rerank_model.RerankModel base class:
def _invoke(
    self, 
    model: str, 
    credentials: dict,
    query: str, 
    docs: list[str], 
    score_threshold: Optional[float] = None, 
    top_n: Optional[int] = None,
    user: Optional[str] = None
) -> RerankResult:
    """
    Rerank documents based on relevance to the query
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    # Prepare request data
    request_data = {
        "query": query,
        "documents": docs,
    }
    
    # Call reranking API endpoint
    response = client.rerank(
        model=model,
        **request_data,
        user=user
    )
    
    # Process results
    ranked_results = []
    for result in response.results:
        # Create RerankDocument for each result
        doc = RerankDocument(
            index=result.document_index,  # Original index in docs list
            text=docs[result.document_index],  # Original text
            score=result.relevance_score  # Relevance score
        )
        ranked_results.append(doc)
    
    # Sort by score in descending order
    ranked_results.sort(key=lambda x: x.score, reverse=True)
    
    # Apply score threshold filtering if specified
    if score_threshold is not None:
        ranked_results = [doc for doc in ranked_results if doc.score >= score_threshold]
    
    # Apply top_n limit if specified
    if top_n is not None and top_n > 0:
        ranked_results = ranked_results[:top_n]
    
    return RerankResult(
        model=model,
        docs=ranked_results
    )
model (string, required): Reranking model identifier
credentials (dict, required): Authentication credentials for the API
query (string, required): The search query text
docs (list[string], required): List of document texts to be reranked
score_threshold (float): Minimum score a document must reach to be included in the results
top_n (int): Maximum number of results to return
user (string): User identifier for API monitoring
Returns RerankResult (object): A structured response containing:
  • model: The model used for reranking.
  • docs: List of RerankDocument objects with index, text, and score.
Reranking can be computationally expensive, especially with large document sets. Implement batching for large document collections to avoid timeouts or excessive resource consumption.
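One way to batch a large document set is to score each slice separately and offset the indices so they still refer to positions in the original docs list. The sketch below assumes scores are comparable across batches, which holds when each document is scored independently against the query. ScoredDoc is a simplified stand-in for RerankDocument, and overlap_score is a toy scorer used only to make the example runnable.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class ScoredDoc:
    """Stand-in for RerankDocument (simplified)."""
    index: int
    text: str
    score: float


def rerank_in_batches(
    query: str,
    docs: list[str],
    score_batch: Callable[[str, list[str]], list[float]],
    batch_size: int,
) -> list[ScoredDoc]:
    """Score docs in batches and return them sorted by descending score."""
    scored: list[ScoredDoc] = []
    for start in range(0, len(docs), batch_size):
        batch = docs[start:start + batch_size]
        scores = score_batch(query, batch)
        for offset, (text, score) in enumerate(zip(batch, scores)):
            # start + offset is the position in the original docs list
            scored.append(ScoredDoc(index=start + offset, text=text, score=score))
    scored.sort(key=lambda d: d.score, reverse=True)
    return scored


def overlap_score(query: str, batch: list[str]) -> list[float]:
    """Toy scorer: number of query words that also appear in the document."""
    query_words = set(query.lower().split())
    return [float(len(query_words & set(d.lower().split()))) for d in batch]
```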

Speech2Text Implementation

Speech-to-text models convert spoken language from audio files into written text, enabling applications like transcription services, voice commands, and accessibility features.
To implement a Speech-to-Text provider, inherit from the __base.speech2text_model.Speech2TextModel base class:
def _invoke(
    self, 
    model: str, 
    credentials: dict,
    file: IO[bytes], 
    user: Optional[str] = None
) -> str:
    """
    Convert speech audio to text
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    try:
        # Determine the file format
        file_format = self._detect_audio_format(file)
        
        # Prepare the file for API submission
        # Most APIs require either a file path or binary data
        audio_data = file.read()
        
        # Call the speech-to-text API
        response = client.audio.transcriptions.create(
            model=model,
            file=(f"audio.{file_format}", audio_data),  # Assumes the helper returns a file extension such as "mp3"
            user=user
        )
        
        # Extract and return the transcribed text
        return response.text
        
    except Exception as e:
        # Map to appropriate error type
        self._handle_api_error(e)
        
    finally:
        # Reset file pointer for potential reuse
        file.seek(0)
model (string, required): Speech-to-text model identifier
credentials (dict, required): Authentication credentials for the API
file (IO[bytes], required): Binary file object containing the audio to transcribe
user (string): User identifier for API monitoring
Returns text (string): The transcribed text from the audio file
Audio format detection is important for proper handling of different file types. Consider implementing a helper method that detects the format from the file header, like the _detect_audio_format helper called in the example.
Some speech-to-text APIs have file size limitations. Consider implementing chunking for large audio files if necessary.
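A header-based detector could look like the following sketch. It checks the well-known leading bytes of WAV, FLAC, Ogg, and MP3 files and restores the file pointer afterwards. The function name and the returned labels are illustrative choices, and the check is a heuristic rather than a full parser.

```python
import io
from typing import IO


def detect_audio_format(file: IO[bytes]) -> str:
    """Guess the audio format from the first bytes; return 'unknown' if unrecognized."""
    position = file.tell()
    header = file.read(12)
    file.seek(position)  # Restore the pointer so the caller can still read the file

    if header[:4] == b"RIFF" and header[8:12] == b"WAVE":
        return "wav"
    if header[:4] == b"fLaC":
        return "flac"
    if header[:4] == b"OggS":
        return "ogg"
    if header[:3] == b"ID3":
        return "mp3"  # MP3 file starting with an ID3v2 tag
    # MPEG audio frames start with an 11-bit sync word: 0xFF then 0b111xxxxx
    if len(header) >= 2 and header[0] == 0xFF and (header[1] & 0xE0) == 0xE0:
        return "mp3"
    return "unknown"
```

For example, detect_audio_format(io.BytesIO(b"fLaC\x00")) returns "flac". Other formats that share the same sync bits would also match the last check, so treat the result as a hint.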

Text2Speech Implementation

Text-to-speech models convert written text into natural-sounding speech, enabling applications such as voice assistants, screen readers, and audio content generation.
To implement a Text-to-Speech provider, inherit from the __base.text2speech_model.Text2SpeechModel base class:
def _invoke(
    self, 
    model: str, 
    credentials: dict, 
    content_text: str, 
    streaming: bool,
    user: Optional[str] = None
) -> Union[bytes, Generator[bytes, None, None]]:
    """
    Convert text to speech audio
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    # Get voice settings based on model
    voice = self._get_voice_for_model(model)
    
    try:
        # Choose implementation based on streaming preference
        if streaming:
            return self._stream_audio(
                client=client,
                model=model,
                text=content_text,
                voice=voice,
                user=user
            )
        else:
            return self._generate_complete_audio(
                client=client,
                model=model,
                text=content_text,
                voice=voice,
                user=user
            )
    except Exception as e:
        self._handle_api_error(e)
model (string, required): Text-to-speech model identifier
credentials (dict, required): Authentication credentials for the API
content_text (string, required): Text content to be converted to speech
streaming (boolean, required): Whether to return streaming audio or complete file
user (string): User identifier for API monitoring
Returns, streaming=True (Generator[bytes, None, None]): A generator yielding audio chunks as they become available
Returns, streaming=False (bytes): Complete audio data as bytes
Most text-to-speech APIs require you to specify a voice along with the model. Consider implementing a mapping between Dify’s model identifiers and the provider’s voice options.
Long text inputs may need to be chunked for better speech synthesis quality. Consider implementing text preprocessing to handle punctuation, numbers, and special characters properly.
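A simple chunking strategy is to split on sentence boundaries and pack sentences into chunks up to a character limit. The sketch below illustrates that idea. The function name, the regular expression, and the character-based limit are assumptions, and real providers may impose limits in other units.

```python
import re


def split_text_for_tts(text: str, max_chars: int) -> list[str]:
    """Split text into chunks of at most max_chars, preferring sentence boundaries."""
    if max_chars <= 0:
        raise ValueError("max_chars must be positive")
    # Split after sentence-ending punctuation followed by whitespace
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    chunks: list[str] = []
    current = ""
    for sentence in sentences:
        # A single sentence longer than the limit is cut at the limit
        while len(sentence) > max_chars:
            if current:
                chunks.append(current)
                current = ""
            chunks.append(sentence[:max_chars])
            sentence = sentence[max_chars:]
        candidate = f"{current} {sentence}".strip()
        if len(candidate) <= max_chars:
            current = candidate
        else:
            chunks.append(current)
            current = sentence
    if current:
        chunks.append(current)
    return chunks
```

Each chunk can then be synthesized separately and the audio concatenated or streamed in order.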

Moderation Implementation

Moderation models analyze content for potentially harmful, inappropriate, or unsafe material, helping maintain platform safety and content policies.
To implement a Moderation provider, inherit from the __base.moderation_model.ModerationModel base class:
def _invoke(
    self, 
    model: str, 
    credentials: dict,
    text: str, 
    user: Optional[str] = None
) -> bool:
    """
    Analyze text for harmful content
    
    Returns:
        bool: False if the text is safe, True if it contains harmful content
    """
    # Set up API client with credentials
    client = self._get_client(credentials)
    
    try:
        # Call moderation API
        response = client.moderations.create(
            model=model,
            input=text,
            user=user
        )
        
        # Check if any categories were flagged
        result = response.results[0]
        
        # Return True if flagged in any category, False if safe
        return result.flagged
        
    except Exception as e:
        # Log the error and treat the content as safe if the API call fails.
        # This fails open, which is the permissive choice; systems that must
        # not let harmful content through should fail closed instead.
        logger.error(f"Moderation API error: {str(e)}")
        return False
model (string, required): Moderation model identifier
credentials (dict, required): Authentication credentials for the API
text (string, required): Text content to be analyzed
user (string): User identifier for API monitoring
Returns result (boolean): Boolean indicating content safety:
  • False: The content is safe.
  • True: The content contains harmful material.
Moderation is often used as a safety mechanism. Consider the implications of false negatives (letting harmful content through) versus false positives (blocking safe content) when implementing your solution.
Many moderation APIs provide detailed category scores rather than just a binary result. Consider extending this implementation to return more detailed information about specific categories of harmful content if your application needs it.
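If the provider returns per-category scores, one possible extension is to apply a threshold per category and derive the boolean from the flagged set. The sketch below assumes scores arrive as a dict of category name to float in the range 0 to 1. The category names, thresholds, and function names are illustrative only.

```python
def flagged_categories(
    scores: dict[str, float],
    thresholds: dict[str, float],
    default_threshold: float = 0.5,
) -> list[str]:
    """Return, sorted by name, the categories whose score meets its threshold."""
    return sorted(
        category
        for category, score in scores.items()
        if score >= thresholds.get(category, default_threshold)
    )


def is_harmful(scores: dict[str, float], thresholds: dict[str, float]) -> bool:
    """True if at least one category is flagged, matching the _invoke contract."""
    return bool(flagged_categories(scores, thresholds))
```

Lowering a category's threshold trades more false positives for fewer false negatives in that category, which makes the trade-off described above tunable per category.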

Entities

PromptMessageRole

The role of a message in a conversation.
class PromptMessageRole(Enum):
    """
    Enum class for prompt message.
    """
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"

PromptMessageContentType

The type of message content: plain text or image.
class PromptMessageContentType(Enum):
    """
    Enum class for prompt message content type.
    """
    TEXT = 'text'
    IMAGE = 'image'

PromptMessageContent

Base class for message content. It exists only for type declarations—do not instantiate it directly.
class PromptMessageContent(BaseModel):
    """
    Model class for prompt message content.
    """
    type: PromptMessageContentType
    data: str  # Content data
Content currently supports two types, text and image, and a single message can combine text with multiple images. Instantiate TextPromptMessageContent and ImagePromptMessageContent instead.

TextPromptMessageContent

class TextPromptMessageContent(PromptMessageContent):
    """
    Model class for text prompt message content.
    """
    type: PromptMessageContentType = PromptMessageContentType.TEXT
When a message combines text and images, wrap the text in this entity and add it to the content list.

ImagePromptMessageContent

class ImagePromptMessageContent(PromptMessageContent):
    """
    Model class for image prompt message content.
    """
    class DETAIL(Enum):
        LOW = 'low'
        HIGH = 'high'

    type: PromptMessageContentType = PromptMessageContentType.IMAGE
    detail: DETAIL = DETAIL.LOW  # Resolution
When a message combines text and images, wrap each image in this entity and add it to the content list. data accepts an image URL or a base64-encoded image string.

PromptMessage

Base class for all role-specific messages. It exists only for type declarations—do not instantiate it directly.
class PromptMessage(ABC, BaseModel):
    """
    Model class for prompt message.
    """
    role: PromptMessageRole  # Message role
    content: Optional[str | list[PromptMessageContent]] = None  # Either a string or a content list; the list form supports multimodal input, see PromptMessageContent
    name: Optional[str] = None  # Optional name

UserPromptMessage

Represents a user message.
class UserPromptMessage(PromptMessage):
    """
    Model class for user prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.USER

AssistantPromptMessage

Represents a model response, typically used for few-shot examples or chat history input.
class AssistantPromptMessage(PromptMessage):
    """
    Model class for assistant prompt message.
    """
    class ToolCall(BaseModel):
        """
        Model class for assistant prompt message tool call.
        """
        class ToolCallFunction(BaseModel):
            """
            Model class for assistant prompt message tool call function.
            """
            name: str  # Tool name
            arguments: str  # Tool parameters

        id: str  # Tool call ID; only meaningful for OpenAI tool calls. Uniquely identifies one invocation, since the same tool can be called multiple times
        type: str  # Defaults to "function"
        function: ToolCallFunction  # Tool call information

    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    tool_calls: list[ToolCall] = []  # Model's tool call results (only returned when tools are passed in and the model decides to call them)
tool_calls holds the tool calls the model returns when the request includes tools.

SystemPromptMessage

Represents a system message, typically used to set system instructions for the model.
class SystemPromptMessage(PromptMessage):
    """
    Model class for system prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.SYSTEM

ToolPromptMessage

Represents a tool message, which passes a tool’s execution result back to the model for next-step planning.
class ToolPromptMessage(PromptMessage):
    """
    Model class for tool prompt message.
    """
    role: PromptMessageRole = PromptMessageRole.TOOL
    tool_call_id: str  # Tool call ID; if the provider doesn't support OpenAI tool calls, you can pass the tool name instead
Pass the tool’s execution result through the inherited content field.
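The round trip from a tool call to a tool message might look like the sketch below. ToolCall and ToolMessage are flattened stand-ins for AssistantPromptMessage.ToolCall and ToolPromptMessage, get_weather is a hypothetical tool, and the arguments string is assumed to be a JSON-encoded object, as with OpenAI-style tool calls.

```python
import json
from dataclasses import dataclass


@dataclass
class ToolCall:
    """Stand-in for AssistantPromptMessage.ToolCall (flattened for brevity)."""
    id: str
    name: str
    arguments: str  # Assumed to be a JSON-encoded object


@dataclass
class ToolMessage:
    """Stand-in for ToolPromptMessage (simplified)."""
    tool_call_id: str
    content: str
    role: str = "tool"


def get_weather(city: str) -> dict:
    """Hypothetical tool used only for this example."""
    return {"city": city, "forecast": "sunny"}


TOOLS = {"get_weather": get_weather}


def run_tool_call(call: ToolCall) -> ToolMessage:
    """Execute a tool call and wrap the result (or the error) for the model."""
    try:
        arguments = json.loads(call.arguments or "{}")
        result = TOOLS[call.name](**arguments)
        content = json.dumps(result)
    except (KeyError, TypeError, json.JSONDecodeError) as e:
        # Report failures back to the model instead of aborting the conversation
        content = json.dumps({"error": f"{type(e).__name__}: {e}"})
    return ToolMessage(tool_call_id=call.id, content=content)
```

Echoing the call's id in tool_call_id lets the model match each result to the invocation that produced it, even when the same tool is called several times.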

PromptMessageTool

class PromptMessageTool(BaseModel):
    """
    Model class for prompt message tool.
    """
    name: str  # Tool name
    description: str  # Tool description
    parameters: dict  # Tool parameters dict

LLMResult

class LLMResult(BaseModel):
    """
    Model class for llm result.
    """
    model: str  # Model actually used
    prompt_messages: list[PromptMessage]  # Prompt message list
    message: AssistantPromptMessage  # Reply message
    usage: LLMUsage  # Token usage and cost information
    system_fingerprint: Optional[str] = None  # Request fingerprint; see OpenAI's parameter definition

LLMResultChunkDelta

The incremental delta within each chunk of a streaming response.
class LLMResultChunkDelta(BaseModel):
    """
    Model class for llm result chunk delta.
    """
    index: int  # Sequence number
    message: AssistantPromptMessage  # Reply message
    usage: Optional[LLMUsage] = None  # Token usage and cost information; only returned in the last chunk
    finish_reason: Optional[str] = None  # Completion reason; only returned in the last chunk

LLMResultChunk

A single chunk in a streaming response.
class LLMResultChunk(BaseModel):
    """
    Model class for llm result chunk.
    """
    model: str  # Model actually used
    prompt_messages: list[PromptMessage]  # Prompt message list
    system_fingerprint: Optional[str] = None  # Request fingerprint; see OpenAI's parameter definition
    delta: LLMResultChunkDelta  # Content changes in this chunk
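
To show how the chunk and delta structures fit together, the following sketch folds a stream of chunks into a single result by joining the text deltas and keeping the usage and finish reason from whichever chunk carries them (normally the last). Chunk and Delta are simplified stand-ins for LLMResultChunk and LLMResultChunkDelta, with the message reduced to a plain text field and usage reduced to a dict.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class Delta:
    """Stand-in for LLMResultChunkDelta (message reduced to plain text)."""
    index: int
    text: str
    usage: Optional[dict] = None
    finish_reason: Optional[str] = None


@dataclass
class Chunk:
    """Stand-in for LLMResultChunk (simplified)."""
    model: str
    delta: Delta


def collect_chunks(chunks: Iterable[Chunk]) -> dict:
    """Fold streamed chunks into one result dict."""
    parts: list[str] = []
    model = usage = finish_reason = None
    for chunk in chunks:
        model = chunk.model
        parts.append(chunk.delta.text)
        # usage and finish_reason are only set on the final chunk
        if chunk.delta.usage is not None:
            usage = chunk.delta.usage
        if chunk.delta.finish_reason is not None:
            finish_reason = chunk.delta.finish_reason
    return {
        "model": model,
        "text": "".join(parts),
        "usage": usage,
        "finish_reason": finish_reason,
    }
```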

LLMUsage

class LLMUsage(ModelUsage):
    """
    Model class for llm usage.
    """
    prompt_tokens: int  # Tokens used by the prompt
    prompt_unit_price: Decimal  # Prompt unit price
    prompt_price_unit: Decimal  # Prompt price unit, i.e., the number of tokens the unit price applies to
    prompt_price: Decimal  # Prompt cost
    completion_tokens: int  # Tokens used by the completion
    completion_unit_price: Decimal  # Completion unit price
    completion_price_unit: Decimal  # Completion price unit, i.e., the number of tokens the unit price applies to
    completion_price: Decimal  # Completion cost
    total_tokens: int  # Total tokens used
    total_price: Decimal  # Total cost
    currency: str  # Currency unit
    latency: float  # Request latency in seconds

TextEmbeddingResult

class TextEmbeddingResult(BaseModel):
    """
    Model class for text embedding result.
    """
    model: str  # Model actually used
    embeddings: list[list[float]]  # Embedding vectors, in the same order as the input texts
    usage: EmbeddingUsage  # Usage information

EmbeddingUsage

class EmbeddingUsage(ModelUsage):
    """
    Model class for embedding usage.
    """
    tokens: int  # Tokens used
    total_tokens: int  # Total tokens used
    unit_price: Decimal  # Unit price
    price_unit: Decimal  # Price unit, i.e., the number of tokens the unit price applies to
    total_price: Decimal  # Total cost
    currency: str  # Currency unit
    latency: float  # Request latency in seconds

RerankResult

class RerankResult(BaseModel):
    """
    Model class for rerank result.
    """
    model: str  # Model actually used
    docs: list[RerankDocument]  # List of reranked documents

RerankDocument

class RerankDocument(BaseModel):
    """
    Model class for rerank document.
    """
    index: int  # Index in the original docs list
    text: str  # Document text
    score: float  # Relevance score
