Files
mistral-go-sdk/audio_api.go
vikingowl 776b693f2d chore: move module path to github.com/VikingOwl91/mistral-go-sdk
Public discoverability on pkg.go.dev. Also fixes stream tool call
test fixture to match real Mistral API responses (finish_reason, usage).
2026-04-03 12:01:11 +02:00

222 lines
6.2 KiB
Go

package mistral
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"

	"github.com/VikingOwl91/mistral-go-sdk/audio"
)
// Transcribe submits a transcription request and decodes the complete
// (non-streaming) response.
//
// When file is non-nil, the audio is uploaded as a multipart request under
// filename. The form always carries "model" and "stream" ("false"), and carries
// "language", "file_id" and "diarize" only when the matching field of req is
// set. When file is nil, req is sent as a JSON body instead, so req must
// reference the audio itself (FileURL or FileID).
func (c *Client) Transcribe(ctx context.Context, req *audio.TranscriptionRequest, filename string, file io.Reader) (*audio.TranscriptionResponse, error) {
	const endpoint = "/v1/audio/transcriptions"

	out := new(audio.TranscriptionResponse)

	// No reader supplied: the request body itself identifies the audio.
	if file == nil {
		if err := c.doJSON(ctx, "POST", endpoint, req, out); err != nil {
			return nil, err
		}
		return out, nil
	}

	form := map[string]string{
		"model":  req.Model,
		"stream": "false",
	}
	if lang := req.Language; lang != nil {
		form["language"] = *lang
	}
	if id := req.FileID; id != nil {
		form["file_id"] = *id
	}
	if req.Diarize {
		form["diarize"] = "true"
	}

	if err := c.doMultipart(ctx, endpoint, filename, file, form, out); err != nil {
		return nil, err
	}
	return out, nil
}
// TranscribeStream uploads an audio file as a multipart request with "stream"
// set to "true" and returns an AudioStream over the response body.
//
// Only Model, Language (when set) and Diarize (when true) are forwarded as
// form fields. Call Close on the returned stream to release the underlying
// connection.
func (c *Client) TranscribeStream(ctx context.Context, req *audio.TranscriptionRequest, filename string, file io.Reader) (*AudioStream, error) {
	form := make(map[string]string, 4)
	form["model"] = req.Model
	form["stream"] = "true"
	if lang := req.Language; lang != nil {
		form["language"] = *lang
	}
	if req.Diarize {
		form["diarize"] = "true"
	}

	httpResp, err := c.doMultipartStream(ctx, "/v1/audio/transcriptions", filename, file, form)
	if err != nil {
		return nil, err
	}
	return newAudioStream(httpResp.Body), nil
}
// AudioStream wraps the generic Stream to provide typed audio events.
//
// Iterate with Next, read the decoded event with Current, and inspect Err once
// Next has returned false. Close closes the underlying stream.
type AudioStream struct {
// stream yields the raw JSON payload of each event.
stream *Stream[json.RawMessage]
// event is the event decoded by the most recent successful call to Next.
event audio.StreamEvent
// err is the error recorded by Next; once set, Next keeps returning false.
err error
}
// newAudioStream builds an AudioStream whose raw JSON events are read from body.
func newAudioStream(body readCloser) *AudioStream {
	s := new(AudioStream)
	s.stream = newStream[json.RawMessage](body)
	return s
}
// Next advances to the next event and reports whether one is available.
//
// It returns false when the underlying stream is exhausted or when an error
// occurs, either from the stream or while decoding the raw event. The error
// is then available through Err. Once an error has been recorded, later calls
// return false without reading from the underlying stream again.
func (s *AudioStream) Next() bool {
	switch {
	case s.err != nil:
		return false
	case !s.stream.Next():
		s.err = s.stream.Err()
		return false
	}

	decoded, decodeErr := audio.UnmarshalStreamEvent(s.stream.Current())
	if decodeErr != nil {
		s.err = decodeErr
		return false
	}
	s.event = decoded
	return true
}
// Current returns the event decoded by the most recent successful call to Next.
func (s *AudioStream) Current() audio.StreamEvent {
	return s.event
}
// Err returns the error recorded by Next, or nil if none has been recorded.
func (s *AudioStream) Err() error {
	return s.err
}
// Close closes the underlying stream, releasing its connection, and returns
// the error reported by that close.
func (s *AudioStream) Close() error {
	return s.stream.Close()
}
// Speech sends a text-to-speech request to /v1/audio/speech and returns the
// full response decoded into an audio.SpeechResponse.
func (c *Client) Speech(ctx context.Context, req *audio.SpeechRequest) (*audio.SpeechResponse, error) {
	out := new(audio.SpeechResponse)
	err := c.doJSON(ctx, "POST", "/v1/audio/speech", req, out)
	if err != nil {
		return nil, err
	}
	return out, nil
}
// SpeechStream sends a text-to-speech request and returns a stream of audio
// events read from the response body.
//
// req.EnableStream is invoked on the caller's request before it is sent
// (presumably setting its stream flag; confirm against the audio package).
// Call Close on the returned stream to release the underlying connection.
func (c *Client) SpeechStream(ctx context.Context, req *audio.SpeechRequest) (*SpeechStream, error) {
	const endpoint = "/v1/audio/speech"

	req.EnableStream()

	httpResp, err := c.doStream(ctx, "POST", endpoint, req)
	if err != nil {
		return nil, err
	}

	events := newSpeechStream(httpResp.Body)
	return events, nil
}
// SpeechStream wraps the generic Stream for speech streaming events.
//
// Iterate with Next, read the decoded event with Current, and inspect Err once
// Next has returned false. Close closes the underlying stream.
type SpeechStream struct {
// stream yields the raw JSON payload of each event.
stream *Stream[json.RawMessage]
// event is the event decoded by the most recent successful call to Next.
event audio.SpeechStreamEvent
// err is the error recorded by Next; once set, Next keeps returning false.
err error
}
// newSpeechStream builds a SpeechStream whose raw JSON events are read from body.
func newSpeechStream(body readCloser) *SpeechStream {
	s := new(SpeechStream)
	s.stream = newStream[json.RawMessage](body)
	return s
}
// Next advances to the next event and reports whether one is available.
//
// It returns false when the underlying stream is exhausted or when an error
// occurs, either from the stream or while decoding the raw event. The error
// is then available through Err. Once an error has been recorded, later calls
// return false without reading from the underlying stream again.
func (s *SpeechStream) Next() bool {
	switch {
	case s.err != nil:
		return false
	case !s.stream.Next():
		s.err = s.stream.Err()
		return false
	}

	decoded, decodeErr := audio.UnmarshalSpeechStreamEvent(s.stream.Current())
	if decodeErr != nil {
		s.err = decodeErr
		return false
	}
	s.event = decoded
	return true
}
// Current returns the event decoded by the most recent successful call to Next.
func (s *SpeechStream) Current() audio.SpeechStreamEvent {
	return s.event
}
// Err returns the error recorded by Next, or nil if none has been recorded.
func (s *SpeechStream) Err() error {
	return s.err
}
// Close closes the underlying stream, releasing its connection, and returns
// the error reported by that close.
func (s *SpeechStream) Close() error {
	return s.stream.Close()
}
// ListVoices fetches the available voices with a GET to /v1/audio/voices and
// returns the response decoded into an audio.VoiceListResponse.
func (c *Client) ListVoices(ctx context.Context) (*audio.VoiceListResponse, error) {
	voices := new(audio.VoiceListResponse)
	err := c.doJSON(ctx, "GET", "/v1/audio/voices", nil, voices)
	if err != nil {
		return nil, err
	}
	return voices, nil
}
// CreateVoice creates a custom voice by POSTing req to /v1/audio/voices and
// returns the response decoded into an audio.VoiceResponse.
func (c *Client) CreateVoice(ctx context.Context, req *audio.VoiceCreateRequest) (*audio.VoiceResponse, error) {
	voice := new(audio.VoiceResponse)
	err := c.doJSON(ctx, "POST", "/v1/audio/voices", req, voice)
	if err != nil {
		return nil, err
	}
	return voice, nil
}
// GetVoice retrieves the voice identified by voiceID.
//
// voiceID is path-escaped before being placed in the URL, so an ID containing
// characters such as '/', '?' or '#' cannot change which endpoint is
// addressed. The response is decoded into an audio.VoiceResponse.
func (c *Client) GetVoice(ctx context.Context, voiceID string) (*audio.VoiceResponse, error) {
	var resp audio.VoiceResponse
	path := fmt.Sprintf("/v1/audio/voices/%s", url.PathEscape(voiceID))
	if err := c.doJSON(ctx, "GET", path, nil, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}
// UpdateVoice updates the voice identified by voiceID by sending req with the
// PATCH method.
//
// voiceID is path-escaped before being placed in the URL, so an ID containing
// characters such as '/', '?' or '#' cannot change which endpoint is
// addressed. The response is decoded into an audio.VoiceResponse.
func (c *Client) UpdateVoice(ctx context.Context, voiceID string, req *audio.VoiceUpdateRequest) (*audio.VoiceResponse, error) {
	var resp audio.VoiceResponse
	path := fmt.Sprintf("/v1/audio/voices/%s", url.PathEscape(voiceID))
	if err := c.doJSON(ctx, "PATCH", path, req, &resp); err != nil {
		return nil, err
	}
	return &resp, nil
}
// DeleteVoice deletes the voice identified by voiceID.
//
// voiceID is path-escaped before being placed in the URL, so an ID containing
// characters such as '/', '?' or '#' cannot change which endpoint is
// addressed. A response status of 400 or above is returned as the error
// produced by parseAPIError; any other status yields nil.
func (c *Client) DeleteVoice(ctx context.Context, voiceID string) error {
	path := fmt.Sprintf("/v1/audio/voices/%s", url.PathEscape(voiceID))
	resp, err := c.do(ctx, "DELETE", path, nil)
	if err != nil {
		return err
	}
	defer func() {
		// Drain whatever is left of the body before closing so the transport
		// can reuse the connection. A failed drain is not actionable here, so
		// its error is deliberately ignored.
		_, _ = io.Copy(io.Discard, resp.Body)
		resp.Body.Close()
	}()

	if resp.StatusCode >= 400 {
		return parseAPIError(resp)
	}
	return nil
}
// GetVoiceSampleAudio retrieves the sample audio for the voice identified by
// voiceID.
//
// voiceID is path-escaped before being placed in the URL, so an ID containing
// characters such as '/', '?' or '#' cannot change which endpoint is
// addressed. On success the raw HTTP response is returned and the caller must
// close its body. When the status is 400 or above, the body is closed here and
// the error produced by parseAPIError is returned instead.
func (c *Client) GetVoiceSampleAudio(ctx context.Context, voiceID string) (*http.Response, error) {
	path := fmt.Sprintf("/v1/audio/voices/%s/sample", url.PathEscape(voiceID))
	resp, err := c.do(ctx, "GET", path, nil)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode >= 400 {
		// The caller never sees this response, so release it here.
		defer resp.Body.Close()
		return nil, parseAPIError(resp)
	}
	return resp, nil
}