never executed always true always false
-- | OpenAI Chat Completions provider.
--
-- Encodes a provider-neutral 'CompletionRequest' as OpenAI Chat
-- Completions JSON, POSTs it with @http-client@, and decodes the reply
-- into a 'CompletionResponse'.
module PureClaw.Providers.OpenAI
  ( -- * Provider type
    OpenAIProvider
  , mkOpenAIProvider
    -- * Errors
  , OpenAIError (..)
    -- * Request/response encoding (exported for testing)
  , encodeRequest
  , decodeResponse
  ) where
   11 
   12 import Control.Exception
   13 import Data.Aeson
   14 import Data.Aeson.Types
   15 import Data.ByteString (ByteString)
   16 import Data.ByteString.Lazy qualified as BL
   17 import Data.Maybe
   18 import Data.Text (Text)
   19 import Data.Text qualified as T
   20 import Data.Text.Encoding qualified as TE
   21 import Network.HTTP.Client qualified as HTTP
   22 import Network.HTTP.Types.Status qualified as Status
   23 
   24 import PureClaw.Core.Errors
   25 import PureClaw.Core.Types
   26 import PureClaw.Providers.Class
   27 import PureClaw.Security.Secrets
   28 
-- | OpenAI API provider: the HTTP manager, credentials and endpoint used
-- by 'openAIComplete'.
data OpenAIProvider = OpenAIProvider
  { _oai_manager :: HTTP.Manager
    -- ^ Connection manager passed to every 'HTTP.httpLbs' call.
  , _oai_apiKey  :: ApiKey
    -- ^ Sent as a @Bearer@ token in the @Authorization@ header.
  , _oai_baseUrl :: String
    -- ^ URL the request is POSTed to. Despite the name, this is the
    -- complete chat-completions endpoint, not a base URL to append to.
  }
   35 
   36 -- | Create an OpenAI provider. Uses the standard OpenAI API base URL.
   37 mkOpenAIProvider :: HTTP.Manager -> ApiKey -> OpenAIProvider
   38 mkOpenAIProvider mgr key = OpenAIProvider mgr key "https://api.openai.com/v1/chat/completions"
   39 
-- | 'complete' delegates directly to 'openAIComplete'.
instance Provider OpenAIProvider where
  complete = openAIComplete
   42 
-- | Errors from the OpenAI API.
data OpenAIError
  = OpenAIAPIError Int ByteString
    -- ^ Non-200 HTTP response: the status code and the raw response body.
  | OpenAIParseError Text
    -- ^ A 200 response whose body 'decodeResponse' rejected; carries the
    -- decoder's error message.
  deriving stock (Show)

-- | Lets 'openAIComplete' raise these errors with 'throwIO'.
instance Exception OpenAIError
   50 
   51 instance ToPublicError OpenAIError where
   52   toPublicError (OpenAIAPIError 429 _) = RateLimitError
   53   toPublicError (OpenAIAPIError 401 _) = NotAllowedError
   54   toPublicError _                       = TemporaryError "Provider error"
   55 
   56 openAIComplete :: OpenAIProvider -> CompletionRequest -> IO CompletionResponse
   57 openAIComplete provider req = do
   58   initReq <- HTTP.parseRequest (_oai_baseUrl provider)
   59   let httpReq = initReq
   60         { HTTP.method = "POST"
   61         , HTTP.requestBody = HTTP.RequestBodyLBS (encodeRequest req)
   62         , HTTP.requestHeaders =
   63             [ ("Authorization", "Bearer " <> withApiKey (_oai_apiKey provider) id)
   64             , ("content-type", "application/json")
   65             ]
   66         }
   67   resp <- HTTP.httpLbs httpReq (_oai_manager provider)
   68   let status = Status.statusCode (HTTP.responseStatus resp)
   69   if status /= 200
   70     then throwIO (OpenAIAPIError status (BL.toStrict (HTTP.responseBody resp)))
   71     else case decodeResponse (HTTP.responseBody resp) of
   72       Left err -> throwIO (OpenAIParseError (T.pack err))
   73       Right response -> pure response
   74 
   75 -- | Encode a completion request as OpenAI Chat Completions JSON.
   76 encodeRequest :: CompletionRequest -> BL.ByteString
   77 encodeRequest req = encode $ object $
   78   [ "model"    .= unModelId (_cr_model req)
   79   , "messages" .= encodeMessages req
   80   ]
   81   ++ maybe [] (\mt -> ["max_tokens" .= mt]) (_cr_maxTokens req)
   82   ++ ["tools" .= map encodeTool (_cr_tools req) | not (null (_cr_tools req))]
   83   ++ maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
   84 
   85 -- | OpenAI puts system prompt as a system message in the messages array.
   86 encodeMessages :: CompletionRequest -> [Value]
   87 encodeMessages req =
   88   maybe [] (\s -> [object ["role" .= ("system" :: Text), "content" .= s]]) (_cr_systemPrompt req)
   89   ++ map encodeMsg (_cr_messages req)
   90 
   91 encodeMsg :: Message -> Value
   92 encodeMsg msg = case _msg_content msg of
   93   [TextBlock t] ->
   94     -- Simple text message — use string content for compatibility
   95     object ["role" .= roleToText (_msg_role msg), "content" .= t]
   96   blocks ->
   97     object [ "role"    .= roleToText (_msg_role msg)
   98            , "content" .= map encodeContentBlock blocks
   99            ]
  100 
  101 encodeContentBlock :: ContentBlock -> Value
  102 encodeContentBlock (TextBlock t) = object
  103   [ "type" .= ("text" :: Text), "text" .= t ]
  104 encodeContentBlock (ImageBlock mediaType imageData) = object
  105   [ "type" .= ("image_url" :: Text)
  106   , "image_url" .= object
  107       [ "url" .= ("data:" <> mediaType <> ";base64," <> TE.decodeUtf8 imageData) ]
  108   ]
  109 encodeContentBlock (ToolUseBlock callId name input) = object
  110   [ "type" .= ("function" :: Text)
  111   , "id"   .= unToolCallId callId
  112   , "function" .= object ["name" .= name, "arguments" .= TE.decodeUtf8 (BL.toStrict (encode input))]
  113   ]
  114 encodeContentBlock (ToolResultBlock callId parts _) = object
  115   [ "type"         .= ("tool_result" :: Text)
  116   , "tool_call_id" .= unToolCallId callId
  117   , "content"      .= T.intercalate "\n" [t | TRPText t <- parts]
  118   ]
  119 
  120 encodeTool :: ToolDefinition -> Value
  121 encodeTool td = object
  122   [ "type" .= ("function" :: Text)
  123   , "function" .= object
  124       [ "name"        .= _td_name td
  125       , "description" .= _td_description td
  126       , "parameters"  .= _td_inputSchema td
  127       ]
  128   ]
  129 
  130 encodeToolChoice :: ToolChoice -> Value
  131 encodeToolChoice AutoTool = String "auto"
  132 encodeToolChoice AnyTool = String "required"
  133 encodeToolChoice (SpecificTool name) = object
  134   [ "type" .= ("function" :: Text)
  135   , "function" .= object ["name" .= name]
  136   ]
  137 
  138 -- | Decode an OpenAI Chat Completions response.
  139 decodeResponse :: BL.ByteString -> Either String CompletionResponse
  140 decodeResponse bs = eitherDecode bs >>= parseEither parseResp
  141   where
  142     parseResp :: Value -> Parser CompletionResponse
  143     parseResp = withObject "OpenAIResponse" $ \o -> do
  144       choices <- o .: "choices"
  145       case choices of
  146         [] -> fail "No choices in response"
  147         (firstChoice : _) -> do
  148           msg <- firstChoice .: "message"
  149           blocks <- parseMessage msg
  150           modelText <- o .: "model"
  151           usageObj <- o .:? "usage"
  152           usage <- case usageObj of
  153             Nothing -> pure Nothing
  154             Just u -> do
  155               inToks <- u .: "prompt_tokens"
  156               outToks <- u .: "completion_tokens"
  157               pure (Just (Usage inToks outToks))
  158           pure CompletionResponse
  159             { _crsp_content = blocks
  160             , _crsp_model   = ModelId modelText
  161             , _crsp_usage   = usage
  162             }
  163 
  164     parseMessage :: Value -> Parser [ContentBlock]
  165     parseMessage = withObject "Message" $ \m -> do
  166       contentVal <- m .:? "content"
  167       toolCalls <- m .:? "tool_calls" .!= ([] :: [Value])
  168       let textBlocks = case contentVal of
  169             Just (String t) | not (T.null t) -> [TextBlock t]
  170             _ -> []
  171       toolBlocks <- mapM parseToolCall toolCalls
  172       pure (textBlocks ++ toolBlocks)
  173 
  174     parseToolCall :: Value -> Parser ContentBlock
  175     parseToolCall = withObject "ToolCall" $ \tc -> do
  176       callId <- tc .: "id"
  177       fn <- tc .: "function"
  178       name <- fn .: "name"
  179       argsStr <- fn .: "arguments"
  180       let input = fromMaybe (object []) (decode (BL.fromStrict (TE.encodeUtf8 argsStr)))
  181       pure (ToolUseBlock (ToolCallId callId) name input)