never executed always true always false
1 module PureClaw.Providers.OpenAI
2 ( -- * Provider type
3 OpenAIProvider
4 , mkOpenAIProvider
5 -- * Errors
6 , OpenAIError (..)
7 -- * Request/response encoding (exported for testing)
8 , encodeRequest
9 , decodeResponse
10 ) where
11
12 import Control.Exception
13 import Data.Aeson
14 import Data.Aeson.Types
15 import Data.ByteString (ByteString)
16 import Data.ByteString.Lazy qualified as BL
17 import Data.Maybe
18 import Data.Text (Text)
19 import Data.Text qualified as T
20 import Data.Text.Encoding qualified as TE
21 import Network.HTTP.Client qualified as HTTP
22 import Network.HTTP.Types.Status qualified as Status
23
24 import PureClaw.Core.Errors
25 import PureClaw.Core.Types
26 import PureClaw.Providers.Class
27 import PureClaw.Security.Secrets
28
-- | OpenAI API provider: the HTTP connection manager used for requests, the
-- API key used for bearer authentication, and the URL requests are sent to.
--
-- Fields are strict: they are configuration values captured once at
-- construction and read on every request, so there is no reason to keep them
-- as thunks.
data OpenAIProvider = OpenAIProvider
  { _oai_manager :: !HTTP.Manager
  -- ^ HTTP connection manager passed to every request.
  , _oai_apiKey :: !ApiKey
  -- ^ Sent as the @Authorization: Bearer@ header value.
  , _oai_baseUrl :: !String
  -- ^ Full endpoint URL. Despite the name this is not a base URL: requests
  -- are POSTed to it as-is.
  }
35
-- | Build an 'OpenAIProvider' that talks to OpenAI's public Chat Completions
-- endpoint, using the given connection manager and API key.
mkOpenAIProvider :: HTTP.Manager -> ApiKey -> OpenAIProvider
mkOpenAIProvider manager key = OpenAIProvider
  { _oai_manager = manager
  , _oai_apiKey = key
  , _oai_baseUrl = chatCompletionsUrl
  }
  where
    -- The complete endpoint URL (not just the API root).
    chatCompletionsUrl :: String
    chatCompletionsUrl = "https://api.openai.com/v1/chat/completions"
39
-- | Completions are served by 'openAIComplete'.
instance Provider OpenAIProvider where
  complete prov request = openAIComplete prov request
42
-- | Errors raised by the OpenAI provider.
data OpenAIError
  = OpenAIAPIError Int ByteString
  -- ^ The API answered with a non-200 status: the status code and the raw
  -- response body.
  | OpenAIParseError Text
  -- ^ A 200 response whose body could not be decoded; carries the decoder's
  -- error message.
  deriving stock (Show)

-- | Allows the provider to raise these errors with 'throwIO'.
instance Exception OpenAIError
50
-- | Map provider failures to public errors: HTTP 429 is a rate limit, HTTP 401
-- means the key was rejected, and everything else (other statuses and parse
-- failures) is reported as a generic temporary error.
instance ToPublicError OpenAIError where
  toPublicError err = case err of
    OpenAIAPIError status _
      | status == 429 -> RateLimitError
      | status == 401 -> NotAllowedError
    -- Statuses not matched by the guards above fall through to here.
    _ -> TemporaryError "Provider error"
55
-- | Send one completion request to the provider's endpoint.
--
-- The request is POSTed as JSON (see 'encodeRequest') with an
-- @Authorization: Bearer@ header built from the provider's 'ApiKey'.
--
-- Throws 'OpenAIAPIError' (status code plus raw body) for any status other
-- than 200, and 'OpenAIParseError' when a 200 body cannot be decoded by
-- 'decodeResponse'. Exceptions from 'HTTP.parseRequest' and 'HTTP.httpLbs'
-- are not caught here.
openAIComplete :: OpenAIProvider -> CompletionRequest -> IO CompletionResponse
openAIComplete provider req = do
  baseReq <- HTTP.parseRequest (_oai_baseUrl provider)
  resp <- HTTP.httpLbs (asJsonPost baseReq) (_oai_manager provider)
  let body = HTTP.responseBody resp
  case Status.statusCode (HTTP.responseStatus resp) of
    200 -> either (throwIO . OpenAIParseError . T.pack) pure (decodeResponse body)
    code -> throwIO (OpenAIAPIError code (BL.toStrict body))
  where
    -- Turn the parsed URL into an authenticated JSON POST carrying the request.
    asJsonPost r = r
      { HTTP.method = "POST"
      , HTTP.requestBody = HTTP.RequestBodyLBS (encodeRequest req)
      , HTTP.requestHeaders =
          [ ("Authorization", "Bearer " <> withApiKey (_oai_apiKey provider) id)
          , ("content-type", "application/json")
          ]
      }
74
-- | Encode a completion request as an OpenAI Chat Completions JSON body.
--
-- @model@ and @messages@ are always present. @max_tokens@ and @tool_choice@
-- are included only when set, and @tools@ only when the tool list is
-- non-empty.
encodeRequest :: CompletionRequest -> BL.ByteString
encodeRequest req = encode (object fields)
  where
    tools = _cr_tools req
    required =
      [ "model" .= unModelId (_cr_model req)
      , "messages" .= encodeMessages req
      ]
    optional = catMaybes
      [ ("max_tokens" .=) <$> _cr_maxTokens req
      , if null tools then Nothing else Just ("tools" .= map encodeTool tools)
      , ("tool_choice" .=) . encodeToolChoice <$> _cr_toolChoice req
      ]
    fields = required <> optional
84
-- | Build the OpenAI @messages@ array.
--
-- OpenAI has no separate system-prompt field, so the request's system prompt
-- (if any) is sent as a leading @system@ message. It is followed by the
-- encoded conversation, where one conversation message may expand into several
-- OpenAI messages (see 'encodeMsg').
encodeMessages :: CompletionRequest -> [Value]
encodeMessages req = systemMsgs ++ concatMap encodeMsg (_cr_messages req)
  where
    systemMsgs =
      [ object ["role" .= ("system" :: Text), "content" .= s]
      | Just s <- [_cr_systemPrompt req]
      ]

-- | Encode one conversation message as zero or more OpenAI chat messages.
--
-- The OpenAI Chat Completions format does not carry tool traffic as content
-- parts:
--
-- * tool calls go in a @tool_calls@ array on the message itself, each with
--   @type: "function"@ and the arguments as a JSON-encoded string;
-- * each tool result is its own message with role @tool@ and a
--   @tool_call_id@.
--
-- Output order:
--
-- 1. Tool results, emitted first as @tool@ messages, so they directly follow
--    the preceding assistant message.
-- 2. One message with the original role holding the text parts (or @null@
--    content when there are none) and any tool calls.
--
-- That second message is omitted when it would carry neither text nor tool
-- calls. A message that is exactly one text block keeps plain string content
-- for compatibility. Only text, tool-use and tool-result blocks are encoded.
encodeMsg :: Message -> [Value]
encodeMsg msg = case _msg_content msg of
  [TextBlock t] ->
    -- Simple text message: use string content for compatibility
    [object ["role" .= role, "content" .= t]]
  blocks ->
    let textParts = [textPart t | TextBlock t <- blocks]
        toolCalls = [toolCall cid name input | ToolUseBlock cid name input <- blocks]
        toolMsgs = [toolResult cid content | ToolResultBlock cid content _ <- blocks]
        mainMsg = object $
          [ "role" .= role
          , "content" .= (if null textParts then Null else toJSON textParts)
          ]
          ++ ["tool_calls" .= toolCalls | not (null toolCalls)]
    in toolMsgs ++ [mainMsg | not (null textParts && null toolCalls)]
  where
    role = roleToText (_msg_role msg)

    -- A text content part: {"type": "text", "text": ...}
    textPart t = object ["type" .= ("text" :: Text), "text" .= t]

    -- One entry of an assistant message's tool_calls array.
    toolCall cid name input = object
      [ "id" .= unToolCallId cid
      , "type" .= ("function" :: Text)
      , "function" .= object
          [ "name" .= name
            -- OpenAI expects the arguments as a JSON-encoded string. aeson's
            -- encoder always produces valid UTF-8, so strict decoding is safe.
          , "arguments" .= TE.decodeUtf8 (BL.toStrict (encode input))
          ]
      ]

    -- A standalone tool-result message answering the call with this id.
    toolResult cid content = object
      [ "role" .= ("tool" :: Text)
      , "tool_call_id" .= unToolCallId cid
      , "content" .= content
      ]
116
-- | Encode a tool definition as an OpenAI @function@ tool, with the tool's
-- input schema sent as the function's @parameters@.
encodeTool :: ToolDefinition -> Value
encodeTool td = object ["type" .= ("function" :: Text), "function" .= fn]
  where
    fn = object
      [ "name" .= _td_name td
      , "description" .= _td_description td
      , "parameters" .= _td_inputSchema td
      ]
126
-- | Encode a tool-choice policy in OpenAI's @tool_choice@ format.
--
-- 'AutoTool' becomes @"auto"@ and 'AnyTool' becomes @"required"@ (the model
-- must call some tool). 'SpecificTool' becomes an object naming the single
-- function to call.
encodeToolChoice :: ToolChoice -> Value
encodeToolChoice choice = case choice of
  AutoTool -> String "auto"
  AnyTool -> String "required"
  SpecificTool name ->
    object
      [ "type" .= ("function" :: Text)
      , "function" .= object ["name" .= name]
      ]
134
-- | Decode an OpenAI Chat Completions response body.
--
-- Only the first entry of @choices@ is used:
--
-- * its message's non-empty string @content@ becomes a 'TextBlock';
-- * each entry of its @tool_calls@ becomes a 'ToolUseBlock'.
--
-- Token counts are read from @usage@ when that field is present. Returns
-- @Left@ with an error message on malformed JSON, a missing required field,
-- or an empty @choices@ array.
decodeResponse :: BL.ByteString -> Either String CompletionResponse
decodeResponse raw = do
  doc <- eitherDecode raw
  parseEither parseResp doc
  where
    parseResp :: Value -> Parser CompletionResponse
    parseResp = withObject "OpenAIResponse" $ \o -> do
      choices <- o .: "choices"
      firstChoice <- case choices of
        choice : _ -> pure choice
        [] -> fail "No choices in response"
      blocks <- parseMessage =<< firstChoice .: "message"
      modelText <- o .: "model"
      usage <- traverse parseUsage =<< o .:? "usage"
      pure CompletionResponse
        { _crsp_content = blocks
        , _crsp_model = ModelId modelText
        , _crsp_usage = usage
        }

    -- Prompt tokens are read before completion tokens.
    parseUsage u = Usage <$> (u .: "prompt_tokens") <*> (u .: "completion_tokens")

    parseMessage :: Value -> Parser [ContentBlock]
    parseMessage = withObject "Message" $ \m -> do
      content <- m .:? "content"
      calls <- m .:? "tool_calls" .!= ([] :: [Value])
      toolBlocks <- traverse parseToolCall calls
      pure (textBlock content ++ toolBlocks)

    -- Missing, null, non-string and empty content all yield no text block.
    textBlock :: Maybe Value -> [ContentBlock]
    textBlock (Just (String t)) | not (T.null t) = [TextBlock t]
    textBlock _ = []

    parseToolCall :: Value -> Parser ContentBlock
    parseToolCall = withObject "ToolCall" $ \tc -> do
      callId <- tc .: "id"
      fn <- tc .: "function"
      name <- fn .: "name"
      args <- fn .: "arguments"
      pure (ToolUseBlock (ToolCallId callId) name (decodeArgs args))

    -- OpenAI sends tool arguments as a JSON-encoded string. Anything that does
    -- not parse as JSON falls back to an empty object.
    decodeArgs :: Text -> Value
    decodeArgs = fromMaybe (object []) . decode . BL.fromStrict . TE.encodeUtf8
178 pure (ToolUseBlock (ToolCallId callId) name input)