never executed always true always false
1 module PureClaw.Providers.OpenAI
2 ( -- * Provider type
3 OpenAIProvider
4 , mkOpenAIProvider
5 -- * Errors
6 , OpenAIError (..)
7 -- * Request/response encoding (exported for testing)
8 , encodeRequest
9 , decodeResponse
10 ) where
11
12 import Control.Exception
13 import Data.Aeson
14 import Data.Aeson.Types
15 import Data.ByteString (ByteString)
16 import Data.ByteString.Lazy qualified as BL
17 import Data.Maybe
18 import Data.Text (Text)
19 import Data.Text qualified as T
20 import Data.Text.Encoding qualified as TE
21 import Network.HTTP.Client qualified as HTTP
22 import Network.HTTP.Types.Status qualified as Status
23
24 import PureClaw.Core.Errors
25 import PureClaw.Core.Types
26 import PureClaw.Providers.Class
27 import PureClaw.Security.Secrets
28
-- | Connection settings for talking to the OpenAI Chat Completions API.
data OpenAIProvider = OpenAIProvider
  { _oai_manager :: HTTP.Manager
    -- ^ HTTP connection manager used to issue requests
  , _oai_apiKey :: ApiKey
    -- ^ credential sent as a @Bearer@ token in the @Authorization@ header
  , _oai_baseUrl :: String
    -- ^ URL that requests are POSTed to; despite the name this is the
    -- complete chat-completions endpoint, not just a base
  }

-- | Build a provider targeting the public OpenAI chat-completions endpoint.
mkOpenAIProvider :: HTTP.Manager -> ApiKey -> OpenAIProvider
mkOpenAIProvider mgr key =
  OpenAIProvider
    { _oai_manager = mgr
    , _oai_apiKey = key
    , _oai_baseUrl = "https://api.openai.com/v1/chat/completions"
    }

-- | Completions for this provider are served by 'openAIComplete'.
instance Provider OpenAIProvider where
  complete = openAIComplete
42
-- | Failures raised by the OpenAI provider.
data OpenAIError
  = -- | The API answered with a status other than 200: the status code and
    -- the raw response body.
    OpenAIAPIError Int ByteString
  | -- | A 200 response whose body could not be decoded.
    OpenAIParseError Text
  deriving stock (Show)

instance Exception OpenAIError

-- | Map provider failures onto the public error vocabulary. HTTP 429 and
-- 401 get dedicated mappings; every other status and every parse failure
-- collapses into a generic temporary error.
instance ToPublicError OpenAIError where
  toPublicError err = case err of
    OpenAIAPIError 429 _ -> RateLimitError
    OpenAIAPIError 401 _ -> NotAllowedError
    _ -> TemporaryError "Provider error"
55
-- | POST a completion request to the provider's endpoint and decode the
-- reply.
--
-- Throws 'OpenAIAPIError' (status code plus raw body) when the response
-- status is anything other than 200, and 'OpenAIParseError' when a 200
-- body is rejected by 'decodeResponse'. Exceptions raised by
-- 'HTTP.parseRequest' or 'HTTP.httpLbs' are not caught here.
openAIComplete :: OpenAIProvider -> CompletionRequest -> IO CompletionResponse
openAIComplete provider req = do
  baseReq <- HTTP.parseRequest (_oai_baseUrl provider)
  resp <- HTTP.httpLbs (asJsonPost baseReq) (_oai_manager provider)
  let body = HTTP.responseBody resp
      code = Status.statusCode (HTTP.responseStatus resp)
  if code == 200
    then either (throwIO . OpenAIParseError . T.pack) pure (decodeResponse body)
    else throwIO (OpenAIAPIError code (BL.toStrict body))
  where
    -- Turn the parsed URL into an authenticated JSON POST carrying the
    -- encoded request.
    asJsonPost r =
      r
        { HTTP.method = "POST"
        , HTTP.requestBody = HTTP.RequestBodyLBS (encodeRequest req)
        , HTTP.requestHeaders =
            [ ("Authorization", "Bearer " <> withApiKey (_oai_apiKey provider) id)
            , ("content-type", "application/json")
            ]
        }
74
-- | Encode a completion request as an OpenAI Chat Completions JSON body.
--
-- @model@ and @messages@ are always present. @max_tokens@ and
-- @tool_choice@ appear only when set, and @tools@ appears only when the
-- tool list is non-empty.
encodeRequest :: CompletionRequest -> BL.ByteString
encodeRequest req = encode (object fields)
  where
    tools = _cr_tools req
    fields =
      concat
        [ ["model" .= unModelId (_cr_model req), "messages" .= encodeMessages req]
        , ["max_tokens" .= mt | Just mt <- [_cr_maxTokens req]]
        , ["tools" .= map encodeTool tools | not (null tools)]
        , ["tool_choice" .= encodeToolChoice tc | Just tc <- [_cr_toolChoice req]]
        ]
84
-- | Build the Chat Completions @messages@ array.
--
-- OpenAI expects the system prompt as a leading @system@ message rather than
-- a top-level field. Each conversation 'Message' may expand to several
-- OpenAI messages (see 'encodeMsg').
encodeMessages :: CompletionRequest -> [Value]
encodeMessages req =
  maybe [] (\s -> [object ["role" .= ("system" :: Text), "content" .= s]]) (_cr_systemPrompt req)
    ++ concatMap encodeMsg (_cr_messages req)

-- | Encode one conversation message as zero or more OpenAI messages.
--
-- OpenAI's Chat Completions format does not carry tools inside @content@:
--
-- * each tool result becomes its own message with role @tool@ and a
--   @tool_call_id@;
-- * tool calls go in the @tool_calls@ array of the message, each as
--   @{id, type: "function", function: {name, arguments}}@.
--
-- Tool-result messages are emitted first so they directly follow the
-- assistant message that issued the calls. Any remaining text/image content
-- (and any tool calls) follows as one message with the original role. A
-- message consisting solely of tool results produces only the @tool@
-- messages.
encodeMsg :: Message -> [Value]
encodeMsg msg = toolResultMsgs ++ primaryMsg
  where
    blocks = _msg_content msg

    toolResultMsgs =
      [ object
          [ "role" .= ("tool" :: Text)
          , "tool_call_id" .= unToolCallId callId
          , "content" .= T.intercalate "\n" [t | TRPText t <- parts]
          ]
      | ToolResultBlock callId parts _ <- blocks
      ]

    toolCalls =
      [ object
          [ "id" .= unToolCallId callId
          , "type" .= ("function" :: Text)
          , "function"
              .= object
                [ "name" .= name
                , -- OpenAI expects the arguments as a JSON-encoded string.
                  "arguments" .= TE.decodeUtf8 (BL.toStrict (encode input))
                ]
          ]
      | ToolUseBlock callId name input <- blocks
      ]

    contentBlocks = filter (not . isToolBlock) blocks

    contentValue = case contentBlocks of
      -- Simple text message — use string content for compatibility
      [TextBlock t] -> toJSON t
      -- An assistant turn that only calls tools carries null content.
      [] | not (null toolCalls) -> Null
      bs -> toJSON (mapMaybe encodeContentBlock bs)

    primaryMsg
      | null contentBlocks && null toolCalls && not (null toolResultMsgs) = []
      | otherwise =
          [ object $
              [ "role" .= roleToText (_msg_role msg)
              , "content" .= contentValue
              ]
                ++ ["tool_calls" .= toolCalls | not (null toolCalls)]
          ]

-- | Whether a block is a tool call or tool result. OpenAI carries these
-- outside the message @content@ field.
isToolBlock :: ContentBlock -> Bool
isToolBlock ToolUseBlock {} = True
isToolBlock ToolResultBlock {} = True
isToolBlock _ = False

-- | Encode a text or image block as an OpenAI content part. Tool blocks have
-- no content-part representation and yield 'Nothing'.
encodeContentBlock :: ContentBlock -> Maybe Value
encodeContentBlock (TextBlock t) =
  Just $ object ["type" .= ("text" :: Text), "text" .= t]
encodeContentBlock (ImageBlock mediaType imageData) =
  Just $
    object
      [ "type" .= ("image_url" :: Text)
      , "image_url"
          .= object
            -- NOTE(review): assumes imageData already holds base64 text;
            -- TE.decodeUtf8 would throw on raw binary bytes — confirm.
            ["url" .= ("data:" <> mediaType <> ";base64," <> TE.decodeUtf8 imageData)]
      ]
encodeContentBlock _ = Nothing
119
-- | Describe a tool as an OpenAI @function@ tool. The tool's input schema
-- is passed through as the function's @parameters@.
encodeTool :: ToolDefinition -> Value
encodeTool td = object ["type" .= String "function", "function" .= fnSpec]
  where
    fnSpec =
      object
        [ "name" .= _td_name td
        , "description" .= _td_description td
        , "parameters" .= _td_inputSchema td
        ]
129
-- | Encode the tool-selection policy in OpenAI's @tool_choice@ format: a bare
-- string (@"auto"@ / @"required"@) or an object naming one function.
encodeToolChoice :: ToolChoice -> Value
encodeToolChoice choice = case choice of
  AutoTool -> String "auto"
  AnyTool -> String "required"
  SpecificTool name ->
    object
      [ "type" .= String "function"
      , "function" .= object ["name" .= name]
      ]
137
-- | Decode an OpenAI Chat Completions response body.
--
-- Only the first element of @choices@ is read:
--
-- * a non-empty string @content@ in its message becomes a 'TextBlock';
-- * each @tool_calls@ entry becomes a 'ToolUseBlock';
-- * @usage@ is optional.
--
-- Returns 'Left' when the body is not JSON, a required field is missing, or
-- @choices@ is empty.
decodeResponse :: BL.ByteString -> Either String CompletionResponse
decodeResponse bs = parseEither responseP =<< eitherDecode bs
  where
    responseP :: Value -> Parser CompletionResponse
    responseP = withObject "OpenAIResponse" $ \o -> do
      choice <- headChoice =<< o .: "choices"
      blocks <- messageP =<< choice .: "message"
      model <- o .: "model"
      usage <- traverse usageP =<< o .:? "usage"
      pure
        CompletionResponse
          { _crsp_content = blocks
          , _crsp_model = ModelId model
          , _crsp_usage = usage
          }

    headChoice :: [Object] -> Parser Object
    headChoice (c : _) = pure c
    headChoice [] = fail "No choices in response"

    usageP u = Usage <$> u .: "prompt_tokens" <*> u .: "completion_tokens"

    messageP :: Value -> Parser [ContentBlock]
    messageP = withObject "Message" $ \m -> do
      content <- m .:? "content"
      calls <- m .:? "tool_calls" .!= ([] :: [Value])
      toolUses <- traverse toolCallP calls
      let texts = [TextBlock t | Just (String t) <- [content], not (T.null t)]
      pure (texts ++ toolUses)

    -- Tool arguments arrive as a JSON-encoded string; if that string does
    -- not decode, an empty object is substituted.
    toolCallP :: Value -> Parser ContentBlock
    toolCallP = withObject "ToolCall" $ \tc -> do
      callId <- tc .: "id"
      fn <- tc .: "function"
      name <- fn .: "name"
      rawArgs <- fn .: "arguments"
      let args = fromMaybe (object []) (decode (BL.fromStrict (TE.encodeUtf8 rawArgs)))
      pure (ToolUseBlock (ToolCallId callId) name args)