Interfaces and data structures for implementing Dify model plugins, covering LLM, TextEmbedding, Rerank, Speech2Text, Text2Speech, and Moderation models
Every provider implements validate_provider_credentials for provider-level authentication and, if its models are user-configurable, validate_credentials for each model type.
Provider Implementation
Learn how to implement model provider classes for different AI service providers
Model Types
Implementation details for the six supported model types: LLM, TextEmbedding, Rerank, Speech2Text, Text2Speech, and Moderation
Data Structures
Comprehensive reference for all data structures used in the model API
Error Handling
Guidelines for proper error mapping and exception handling
    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials by making a test API call

        Parameters:
            credentials: Provider credentials as defined in `provider_credential_schema`

        Raises:
            CredentialsValidateFailedError: If validation fails
        """
        try:
            # Example implementation: validate using an LLM model instance
            model_instance = self.get_model_instance(ModelType.LLM)
            model_instance.validate_credentials(
                model="example-model",
                credentials=credentials
            )
        except Exception as ex:
            # logger.exception already records the traceback
            logger.exception("Credential validation failed")
            # Chain the original exception so the root cause stays visible
            raise CredentialsValidateFailedError(f"Invalid credentials: {str(ex)}") from ex
Credential information as defined in the provider’s YAML configuration under provider_credential_schema, typically fields such as api_key and organization_id.
If validation fails, your implementation must raise a CredentialsValidateFailedError exception. This ensures proper error handling in the Dify UI.
For predefined model providers, implement a thorough validation method that verifies the credentials against your API. For custom model providers (where each model has its own credentials), a simplified implementation is sufficient.
    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate that the provided credentials work with the specified model

        Parameters:
            model: The specific model identifier (e.g., "gpt-4")
            credentials: Authentication details for the model

        Raises:
            CredentialsValidateFailedError: If validation fails
        """
        try:
            # Make a lightweight API call to verify credentials
            # Example: List available models or check account status
            response = self._api_client.validate_api_key(credentials["api_key"])

            # Verify the specific model is available if applicable
            if model not in response.get("available_models", []):
                raise CredentialsValidateFailedError(f"Model {model} is not available")
        except ApiException as e:
            raise CredentialsValidateFailedError(str(e))
Instead of relying on the error mapping, you can raise the standardized error types directly in your code. This approach gives you more control over error messages.
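The two approaches can be contrasted with a minimal sketch. It does not reproduce Dify's own classes: the `InvokeError` hierarchy below is a local stand-in for the standardized error types, `ProviderTimeout` and `ProviderAuthFailed` are hypothetical provider SDK exceptions, and `map_error` is a hypothetical helper that plays the role of the error mapping.

```python
class InvokeError(Exception):
    """Stand-in base class for standardized invocation errors."""


class InvokeConnectionError(InvokeError):
    pass


class InvokeAuthorizationError(InvokeError):
    pass


# Hypothetical exceptions raised by a provider SDK.
class ProviderTimeout(Exception):
    pass


class ProviderAuthFailed(Exception):
    pass


# Maps each standardized error type to the provider exceptions it covers.
ERROR_MAPPING: dict[type[InvokeError], list[type[Exception]]] = {
    InvokeConnectionError: [ProviderTimeout],
    InvokeAuthorizationError: [ProviderAuthFailed],
}


def map_error(error: Exception) -> Exception:
    """Return the standardized error for a provider exception, or the original."""
    if isinstance(error, InvokeError):
        # Already standardized, e.g. raised directly in provider code
        return error
    for invoke_error, provider_errors in ERROR_MAPPING.items():
        if isinstance(error, tuple(provider_errors)):
            return invoke_error(str(error))
    return error


# Path 1: a provider exception is translated through the mapping.
mapped = map_error(ProviderTimeout("request timed out"))

# Path 2: a standardized error raised directly passes through unchanged,
# keeping the custom message.
direct = map_error(InvokeAuthorizationError("API key revoked"))
```

Raising the standardized type directly is useful when the provider's own message is too terse to show to users.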
    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        prompt_messages: list[PromptMessage],
        tools: Optional[list[PromptMessageTool]] = None
    ) -> int:
        """
        Calculate the number of tokens in the prompt
        """
        # Convert prompt_messages to the format expected by the tokenizer
        text = self._convert_messages_to_text(prompt_messages)

        try:
            # Use the appropriate tokenizer for this model
            tokenizer = self._get_tokenizer(model)
            return len(tokenizer.encode(text))
        except Exception:
            # Fall back to a generic tokenizer
            return self._get_num_tokens_by_gpt2(text)
If the model doesn’t provide a tokenizer, you can use the base class’s _get_num_tokens_by_gpt2(text) method for a reasonable approximation.
    def get_customizable_model_schema(
        self,
        model: str,
        credentials: dict
    ) -> Optional[AIModelEntity]:
        """
        Get parameter schema for custom models
        """
        # For fine-tuned models, you might return the base model's schema
        if model.startswith("ft:"):
            base_model = self._extract_base_model(model)
            return self._get_predefined_model_schema(base_model)

        # For standard models, return None to use the predefined schema
        return None
This method is only necessary for providers that support custom models. It allows custom models to inherit parameter rules from base models.
Text embedding models convert text into high-dimensional vectors that capture semantic meaning, which is useful for retrieval, similarity search, and classification.
To implement a Text Embedding provider, inherit from the __base.text_embedding_model.TextEmbeddingModel base class:
    def get_num_tokens(
        self,
        model: str,
        credentials: dict,
        texts: list[str]
    ) -> int:
        """
        Calculate the number of tokens in the texts to be embedded
        """
        # Join all texts to estimate token count
        combined_text = " ".join(texts)

        try:
            # Use the appropriate tokenizer for this model
            tokenizer = self._get_tokenizer(model)
            return len(tokenizer.encode(combined_text))
        except Exception:
            # Fall back to a generic tokenizer
            return self._get_num_tokens_by_gpt2(combined_text)
For embedding models, accurate token counting is important for cost estimation, but not critical for functionality. The _get_num_tokens_by_gpt2 method provides a reasonable approximation for most models.
Reranking models help improve search quality by re-ordering a set of candidate documents based on their relevance to a query, typically after an initial retrieval phase.
To implement a Reranking provider, inherit from the __base.rerank_model.RerankModel base class:
    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None
    ) -> RerankResult:
        """
        Rerank documents based on relevance to the query
        """
        # Set up API client with credentials
        client = self._get_client(credentials)

        # Prepare request data
        request_data = {
            "query": query,
            "documents": docs,
        }

        # Call reranking API endpoint
        response = client.rerank(
            model=model,
            **request_data,
            user=user
        )

        # Process results
        ranked_results = []
        for result in response.results:
            # Create RerankDocument for each result
            doc = RerankDocument(
                index=result.document_index,  # Original index in docs list
                text=docs[result.document_index],  # Original text
                score=result.relevance_score  # Relevance score
            )
            ranked_results.append(doc)

        # Sort by score in descending order
        ranked_results.sort(key=lambda x: x.score, reverse=True)

        # Apply score threshold filtering if specified
        if score_threshold is not None:
            ranked_results = [doc for doc in ranked_results if doc.score >= score_threshold]

        # Apply top_n limit if specified
        if top_n is not None and top_n > 0:
            ranked_results = ranked_results[:top_n]

        return RerankResult(
            model=model,
            docs=ranked_results
        )
docs: List of RerankDocument objects with index, text, and score.
Reranking can be computationally expensive, especially with large document sets. Implement batching for large document collections to avoid timeouts or excessive resource consumption.
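One way to batch a large document collection is sketched below. It assumes the reranker scores each document independently of the others, so scores from different batches are comparable. `rerank_in_batches` and `overlap_scorer` are hypothetical names, and the toy scorer only stands in for a real reranking API call.

```python
from typing import Callable


def rerank_in_batches(
    query: str,
    docs: list[str],
    score_batch: Callable[[str, list[str]], list[float]],
    batch_size: int = 32,
) -> list[tuple[int, float]]:
    """Score docs batch by batch.

    Returns (original_index, score) pairs sorted by score, highest first.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    scored: list[tuple[int, float]] = []
    for start in range(0, len(docs), batch_size):
        batch = docs[start:start + batch_size]
        scores = score_batch(query, batch)
        # Shift batch-local positions back to indices in the full docs list
        scored.extend((start + i, s) for i, s in enumerate(scores))

    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored


# Toy scorer for demonstration only: counts query words found in each document.
def overlap_scorer(query: str, batch: list[str]) -> list[float]:
    words = set(query.lower().split())
    return [float(len(words & set(d.lower().split()))) for d in batch]


ranked = rerank_in_batches(
    "red apple",
    ["green pear", "red apple pie", "red car"],
    overlap_scorer,
    batch_size=2,
)
```

The key detail is the index offset: each batch reports positions relative to itself, so they must be converted back before building the final result list.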
Speech-to-text models convert spoken language from audio files into written text, enabling applications like transcription services, voice commands, and accessibility features.
To implement a Speech-to-Text provider, inherit from the __base.speech2text_model.Speech2TextModel base class:
    def _invoke(
        self,
        model: str,
        credentials: dict,
        file: IO[bytes],
        user: Optional[str] = None
    ) -> str:
        """
        Convert speech audio to text
        """
        # Set up API client with credentials
        client = self._get_client(credentials)

        try:
            # Determine the file format (assumed here to be an extension such as "mp3")
            file_format = self._detect_audio_format(file)

            # Prepare the file for API submission
            # Most APIs require either a file path or binary data
            audio_data = file.read()

            # Call the speech-to-text API
            response = client.audio.transcriptions.create(
                model=model,
                file=(f"audio.{file_format}", audio_data),  # Filename reflects the detected format
                user=user
            )

            # Extract and return the transcribed text
            return response.text
        except Exception as e:
            # Map to appropriate error type
            self._handle_api_error(e)
        finally:
            # Reset file pointer for potential reuse
            file.seek(0)
Audio format detection is important for proper handling of different file types. Consider implementing a helper method, such as the _detect_audio_format call in the example, that detects the format from the file header.
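A header-based detector could look like the following sketch. The function name and the returned extension strings are assumptions made for illustration. The byte signatures are the standard magic numbers for each container, but the list is not exhaustive.

```python
import io
from typing import IO, Optional


def detect_audio_format(file: IO[bytes]) -> Optional[str]:
    """Guess the audio container from the first bytes of a seekable stream."""
    start = file.tell()
    header = file.read(12)
    file.seek(start)  # Leave the stream position unchanged

    if header[:4] == b"RIFF" and header[8:12] == b"WAVE":
        return "wav"
    if header[:4] == b"fLaC":
        return "flac"
    if header[:4] == b"OggS":
        return "ogg"
    if header[4:8] == b"ftyp":
        return "m4a"  # ISO base media file (MP4 family); may also hold video
    if header[:4] == b"\x1a\x45\xdf\xa3":
        return "webm"  # EBML header, shared by WebM and Matroska
    # MP3: either an ID3v2 tag or an MPEG frame sync (eleven set bits)
    if header[:3] == b"ID3" or (
        len(header) >= 2 and header[0] == 0xFF and (header[1] & 0xE0) == 0xE0
    ):
        return "mp3"
    return None


wav_stream = io.BytesIO(b"RIFF\x24\x00\x00\x00WAVEfmt ")
detected = detect_audio_format(wav_stream)
```

Returning None for unknown input lets the caller decide whether to reject the file or fall back to a default format.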
Some speech-to-text APIs have file size limitations. Consider implementing chunking for large audio files if necessary.
Text-to-speech models convert written text into natural-sounding speech, enabling applications such as voice assistants, screen readers, and audio content generation.
To implement a Text-to-Speech provider, inherit from the __base.text2speech_model.Text2SpeechModel base class:
    def _invoke(
        self,
        model: str,
        credentials: dict,
        content_text: str,
        streaming: bool,
        user: Optional[str] = None
    ) -> Union[bytes, Generator[bytes, None, None]]:
        """
        Convert text to speech audio
        """
        # Set up API client with credentials
        client = self._get_client(credentials)

        # Get voice settings based on model
        voice = self._get_voice_for_model(model)

        try:
            # Choose implementation based on streaming preference
            if streaming:
                return self._stream_audio(
                    client=client,
                    model=model,
                    text=content_text,
                    voice=voice,
                    user=user
                )
            else:
                return self._generate_complete_audio(
                    client=client,
                    model=model,
                    text=content_text,
                    voice=voice,
                    user=user
                )
        except Exception as e:
            self._handle_api_error(e)
Most text-to-speech APIs require you to specify a voice along with the model. Consider implementing a mapping between Dify’s model identifiers and the provider’s voice options.
Long text inputs may need to be chunked for better speech synthesis quality. Consider implementing text preprocessing to handle punctuation, numbers, and special characters properly.
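A simple chunking strategy is sketched below, assuming a character limit per request and sentence boundaries marked by `.`, `!`, or `?`. The function name and the default limit are hypothetical. Real providers may count limits in bytes or tokens instead.

```python
import re


def split_text_for_tts(text: str, max_chars: int = 200) -> list[str]:
    """Split text into chunks of at most max_chars, preferring sentence ends."""
    if max_chars <= 0:
        raise ValueError("max_chars must be positive")

    # Split after sentence-ending punctuation that is followed by whitespace
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]

    chunks: list[str] = []
    current = ""
    for sentence in sentences:
        # Hard-split any single sentence that exceeds the limit on its own
        while len(sentence) > max_chars:
            if current:
                chunks.append(current)
                current = ""
            chunks.append(sentence[:max_chars])
            sentence = sentence[max_chars:]
        if not sentence:
            continue

        candidate = f"{current} {sentence}" if current else sentence
        if len(candidate) <= max_chars:
            current = candidate  # Sentence still fits in the current chunk
        else:
            chunks.append(current)  # Close the chunk and start a new one
            current = sentence

    if current:
        chunks.append(current)
    return chunks


chunks = split_text_for_tts("Hello there. How are you? Fine!", max_chars=15)
```

Splitting at sentence ends keeps each synthesized segment prosodically complete, which tends to sound more natural than cutting mid-sentence.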
Moderation models analyze content for potentially harmful, inappropriate, or unsafe material, helping maintain platform safety and content policies.
To implement a Moderation provider, inherit from the __base.moderation_model.ModerationModel base class:
    def _invoke(
        self,
        model: str,
        credentials: dict,
        text: str,
        user: Optional[str] = None
    ) -> bool:
        """
        Analyze text for harmful content

        Returns:
            bool: False if the text is safe, True if it contains harmful content
        """
        # Set up API client with credentials
        client = self._get_client(credentials)

        try:
            # Call moderation API
            response = client.moderations.create(
                model=model,
                input=text,
                user=user
            )

            # Check if any categories were flagged
            result = response.results[0]

            # Return True if flagged in any category, False if safe
            return result.flagged
        except Exception as e:
            # Log the error and treat the text as safe if the API call fails.
            # This fails open, so unchecked content passes through; production
            # systems might want different fallback behavior, such as failing closed.
            logger.error(f"Moderation API error: {str(e)}")
            return False
Moderation is often used as a safety mechanism. Consider the implications of false negatives (letting harmful content through) versus false positives (blocking safe content) when implementing your solution.
Many moderation APIs provide detailed category scores rather than just a binary result. Consider extending this implementation to return more detailed information about specific categories of harmful content if your application needs it.
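A category-aware result might be derived as in this sketch. `ModerationDetail`, `evaluate_scores`, the category names, and the threshold values are all hypothetical. The sketch assumes the provider returns one score between 0 and 1 per category.

```python
from dataclasses import dataclass, field


@dataclass
class ModerationDetail:
    flagged: bool
    # Categories whose score met or exceeded their threshold
    categories: list[str] = field(default_factory=list)


def evaluate_scores(
    scores: dict[str, float],
    thresholds: dict[str, float],
    default_threshold: float = 0.5,
) -> ModerationDetail:
    """Flag content when any category score reaches its threshold."""
    hit = sorted(
        name
        for name, score in scores.items()
        if score >= thresholds.get(name, default_threshold)
    )
    return ModerationDetail(flagged=bool(hit), categories=hit)


detail = evaluate_scores(
    {"hate": 0.02, "violence": 0.81, "self-harm": 0.30},
    thresholds={"self-harm": 0.25},  # Stricter limit for one category
)
```

Per-category thresholds make the trade-off between false negatives and false positives explicit: lowering a threshold blocks more content in that category only.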
Base class for message content. It exists only for type declarations—do not instantiate it directly.
    class PromptMessageContent(BaseModel):
        """
        Model class for prompt message content.
        """
        type: PromptMessageContentType
        data: str  # Content data
Content currently supports two types, text and image, and a single message can combine text with multiple images. Instantiate TextPromptMessageContent and ImagePromptMessageContent instead.
    class TextPromptMessageContent(PromptMessageContent):
        """
        Model class for text prompt message content.
        """
        type: PromptMessageContentType = PromptMessageContentType.TEXT
When a message combines text and images, wrap the text in this entity and add it to the content list.
    class ImagePromptMessageContent(PromptMessageContent):
        """
        Model class for image prompt message content.
        """
        class DETAIL(Enum):
            LOW = 'low'
            HIGH = 'high'

        type: PromptMessageContentType = PromptMessageContentType.IMAGE
        detail: DETAIL = DETAIL.LOW  # Resolution
When a message combines text and images, wrap each image in this entity and add it to the content list. data accepts an image URL or a base64-encoded image string.
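Building a mixed content list could look like the sketch below. To stay self-contained it uses small dataclass stand-ins rather than the real Pydantic entities, and `image_from_bytes` is a hypothetical helper. Whether a provider expects bare base64 or a data URL prefix is not specified here, so the sketch emits bare base64 as an assumption.

```python
import base64
from dataclasses import dataclass
from enum import Enum


class ContentType(Enum):  # Stand-in for PromptMessageContentType
    TEXT = "text"
    IMAGE = "image"


@dataclass
class TextContent:  # Stand-in for TextPromptMessageContent
    data: str
    type: ContentType = ContentType.TEXT


@dataclass
class ImageContent:  # Stand-in for ImagePromptMessageContent
    data: str  # Image URL or base64-encoded image string
    detail: str = "low"
    type: ContentType = ContentType.IMAGE


def image_from_bytes(raw: bytes, detail: str = "low") -> ImageContent:
    """Wrap raw image bytes as a base64 string (hypothetical helper)."""
    encoded = base64.b64encode(raw).decode("ascii")
    return ImageContent(data=encoded, detail=detail)


# One text part followed by two images: one by URL, one from raw bytes.
content = [
    TextContent(data="What is in these images?"),
    ImageContent(data="https://example.com/cat.png", detail="high"),
    image_from_bytes(b"\x89PNG\r\n\x1a\n"),
]
```

The resulting list is what the content field of a message would carry in its multimodal form.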
Base class for all role-specific messages. It exists only for type declarations—do not instantiate it directly.
    class PromptMessage(ABC, BaseModel):
        """
        Model class for prompt message.
        """
        role: PromptMessageRole  # Message role
        # Either a string or a content list; the list form supports
        # multimodal input, see PromptMessageContent
        content: Optional[str | list[PromptMessageContent]] = None
        name: Optional[str] = None  # Optional name
Represents a model response, typically used for few-shot examples or chat history input.
    class AssistantPromptMessage(PromptMessage):
        """
        Model class for assistant prompt message.
        """
        class ToolCall(BaseModel):
            """
            Model class for assistant prompt message tool call.
            """
            class ToolCallFunction(BaseModel):
                """
                Model class for assistant prompt message tool call function.
                """
                name: str  # Tool name
                arguments: str  # Tool parameters

            # Tool call ID; only meaningful for OpenAI tool calls. Uniquely
            # identifies one invocation, since the same tool can be called
            # multiple times
            id: str
            type: str  # Defaults to "function"
            function: ToolCallFunction  # Tool call information

        role: PromptMessageRole = PromptMessageRole.ASSISTANT
        # Model's tool call results (only returned when tools are passed in
        # and the model decides to call them)
        tool_calls: list[ToolCall] = []
tool_calls holds the tool calls the model returns when the request includes tools.
Represents a tool message, which passes a tool’s execution result back to the model for next-step planning.
    class ToolPromptMessage(PromptMessage):
        """
        Model class for tool prompt message.
        """
        role: PromptMessageRole = PromptMessageRole.TOOL
        # Tool call ID; if the provider doesn't support OpenAI tool calls,
        # you can pass the tool name instead
        tool_call_id: str
Pass the tool’s execution result through the inherited content field.
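The round trip from an assistant tool call to a tool message can be sketched as follows. The dataclasses are simplified stand-ins for AssistantPromptMessage and ToolPromptMessage, `get_weather` and `run_tool_calls` are hypothetical, and the sketch assumes `arguments` holds a JSON-encoded object, which is common but not stated above.

```python
import json
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToolCallFunction:
    name: str
    arguments: str  # Assumed to be a JSON-encoded object


@dataclass
class ToolCall:
    id: str
    function: ToolCallFunction
    type: str = "function"


@dataclass
class AssistantMessage:  # Stand-in for AssistantPromptMessage
    content: Optional[str] = None
    tool_calls: list[ToolCall] = field(default_factory=list)


@dataclass
class ToolMessage:  # Stand-in for ToolPromptMessage
    content: str
    tool_call_id: str


def get_weather(city: str) -> str:  # Hypothetical tool
    return f"Sunny in {city}"


TOOLS = {"get_weather": get_weather}


def run_tool_calls(message: AssistantMessage) -> list[ToolMessage]:
    """Execute each requested tool and wrap its result for the next model turn."""
    results = []
    for call in message.tool_calls:
        kwargs = json.loads(call.function.arguments)
        output = TOOLS[call.function.name](**kwargs)
        # tool_call_id ties the result back to the specific invocation
        results.append(ToolMessage(content=output, tool_call_id=call.id))
    return results


assistant = AssistantMessage(tool_calls=[
    ToolCall(id="call_1", function=ToolCallFunction("get_weather", '{"city": "Paris"}')),
])
tool_messages = run_tool_calls(assistant)
```

Because the same tool may be invoked several times in one response, the call ID rather than the tool name is what pairs each result with its request.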
    class LLMResult(BaseModel):
        """
        Model class for llm result.
        """
        model: str  # Model actually used
        prompt_messages: list[PromptMessage]  # Prompt message list
        message: AssistantPromptMessage  # Reply message
        usage: LLMUsage  # Token usage and cost information
        # Request fingerprint; see OpenAI's parameter definition
        system_fingerprint: Optional[str] = None
The incremental delta within each chunk of a streaming response.
    class LLMResultChunkDelta(BaseModel):
        """
        Model class for llm result chunk delta.
        """
        index: int  # Sequence number
        message: AssistantPromptMessage  # Reply message
        # Token usage and cost information; only returned in the last chunk
        usage: Optional[LLMUsage] = None
        # Completion reason; only returned in the last chunk
        finish_reason: Optional[str] = None
    class LLMResultChunk(BaseModel):
        """
        Model class for llm result chunk.
        """
        model: str  # Model actually used
        prompt_messages: list[PromptMessage]  # Prompt message list
        # Request fingerprint; see OpenAI's parameter definition
        system_fingerprint: Optional[str] = None
        delta: LLMResultChunkDelta  # Content changes in this chunk
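A consumer of a streaming response typically concatenates the deltas and reads usage and finish_reason from the final chunk. The sketch below illustrates this with reduced dataclass stand-ins. `Usage`, `Delta`, `Chunk`, and `collect_stream` are hypothetical names, and `Delta.content` stands in for the text inside the delta's message.

```python
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class Usage:  # Stand-in for LLMUsage, reduced to token counts
    prompt_tokens: int
    completion_tokens: int


@dataclass
class Delta:  # Stand-in for LLMResultChunkDelta
    index: int
    content: str  # Stands in for message.content
    usage: Optional[Usage] = None
    finish_reason: Optional[str] = None


@dataclass
class Chunk:  # Stand-in for LLMResultChunk
    model: str
    delta: Delta


def collect_stream(
    chunks: Iterable[Chunk],
) -> tuple[str, Optional[Usage], Optional[str]]:
    """Join streamed text and pick up usage and finish_reason when present."""
    parts: list[str] = []
    usage: Optional[Usage] = None
    finish_reason: Optional[str] = None
    for chunk in chunks:
        parts.append(chunk.delta.content)
        # Both fields are expected only on the last chunk
        if chunk.delta.usage is not None:
            usage = chunk.delta.usage
        if chunk.delta.finish_reason is not None:
            finish_reason = chunk.delta.finish_reason
    return "".join(parts), usage, finish_reason


stream = [
    Chunk("example-model", Delta(0, "Hel")),
    Chunk("example-model", Delta(1, "lo")),
    Chunk("example-model", Delta(2, "", Usage(5, 2), "stop")),
]
text, usage, finish_reason = collect_stream(stream)
```

Checking every chunk for usage, rather than only the last one, keeps the collector correct even if a provider attaches it earlier.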
    class LLMUsage(ModelUsage):
        """
        Model class for llm usage.
        """
        prompt_tokens: int  # Tokens used by the prompt
        prompt_unit_price: Decimal  # Prompt unit price
        # Prompt price unit, i.e., the number of tokens the unit price applies to
        prompt_price_unit: Decimal
        prompt_price: Decimal  # Prompt cost
        completion_tokens: int  # Tokens used by the completion
        completion_unit_price: Decimal  # Completion unit price
        # Completion price unit, i.e., the number of tokens the unit price applies to
        completion_price_unit: Decimal
        completion_price: Decimal  # Completion cost
        total_tokens: int  # Total tokens used
        total_price: Decimal  # Total cost
        currency: str  # Currency unit
        latency: float  # Request latency in seconds
    class TextEmbeddingResult(BaseModel):
        """
        Model class for text embedding result.
        """
        model: str  # Model actually used
        # Embedding vectors, in the same order as the input texts
        embeddings: list[list[float]]
        usage: EmbeddingUsage  # Usage information
    class EmbeddingUsage(ModelUsage):
        """
        Model class for embedding usage.
        """
        tokens: int  # Tokens used
        total_tokens: int  # Total tokens used
        unit_price: Decimal  # Unit price
        # Price unit, i.e., the number of tokens the unit price applies to
        price_unit: Decimal
        total_price: Decimal  # Total cost
        currency: str  # Currency unit
        latency: float  # Request latency in seconds
    class RerankResult(BaseModel):
        """
        Model class for rerank result.
        """
        model: str  # Model actually used
        docs: list[RerankDocument]  # List of reranked documents
    class RerankDocument(BaseModel):
        """
        Model class for rerank document.
        """
        index: int  # Index in the original docs list
        text: str  # Document text
        score: float  # Relevance score