never executed always true always false
    1 module PureClaw.Providers.Anthropic
    2   ( -- * Provider type (constructor intentionally NOT exported)
    3     AnthropicProvider
    4   , mkAnthropicProvider
    5   , mkAnthropicProviderOAuth
    6     -- * Auth headers (exported for testing)
    7   , buildAuthHeaders
    8     -- * Errors
    9   , AnthropicError (..)
   10     -- * Request/response encoding (exported for testing)
   11   , encodeRequest
   12   , decodeResponse
   13     -- * SSE parsing (exported for testing)
   14   , parseSSELine
   15   ) where
   16 
   17 import Control.Exception
   18 import Data.Aeson
   19 import Data.Aeson.Types
   20 import Data.ByteString (ByteString)
   21 import Data.ByteString qualified as BS
   22 import Data.ByteString.Lazy qualified as BL
   23 import Data.IORef
   24 import Data.Maybe
   25 import Data.Text (Text)
   26 import Data.Text qualified as T
   27 import Data.Text.Encoding qualified as TE
   28 import Data.Time.Clock (getCurrentTime)
   29 import Network.HTTP.Client qualified as HTTP
   30 import Network.HTTP.Types.Header qualified as Header
   31 import Network.HTTP.Types.Status qualified as Status
   32 
   33 import PureClaw.Auth.AnthropicOAuth
   34 import PureClaw.Core.Errors
   35 import PureClaw.Core.Types
   36 import PureClaw.Providers.Class
   37 import PureClaw.Security.Secrets
   38 
   39 -- | Authentication method for the Anthropic provider.
   40 data AnthropicAuth
   41   = ApiKeyAuth ApiKey
   42     -- ^ Authenticate with a static API key via the @x-api-key@ header.
   43   | OAuthAuth OAuthHandle
   44     -- ^ Authenticate via OAuth 2.0; access tokens are refreshed automatically.
   45 
   46 -- | Anthropic API provider. Constructor is not exported — use
   47 -- 'mkAnthropicProvider' or 'mkAnthropicProviderOAuth'.
   48 data AnthropicProvider = AnthropicProvider
   49   { _ap_manager :: HTTP.Manager
   50   , _ap_auth    :: AnthropicAuth
   51   }
   52 
   53 -- | Create an Anthropic provider authenticated with an API key.
   54 mkAnthropicProvider :: HTTP.Manager -> ApiKey -> AnthropicProvider
   55 mkAnthropicProvider m k = AnthropicProvider m (ApiKeyAuth k)
   56 
   57 -- | Create an Anthropic provider authenticated via OAuth.
   58 -- Tokens are refreshed automatically when they expire.
   59 mkAnthropicProviderOAuth :: HTTP.Manager -> OAuthHandle -> AnthropicProvider
   60 mkAnthropicProviderOAuth m h = AnthropicProvider m (OAuthAuth h)
   61 
   62 -- | Build the authentication headers for a request.
   63 -- For OAuth, checks token expiry and refreshes if needed.
   64 buildAuthHeaders :: AnthropicProvider -> IO [Header.Header]
   65 buildAuthHeaders p = case _ap_auth p of
   66   ApiKeyAuth key ->
   67     pure [ ("x-api-key",          withApiKey key id)
   68          , ("anthropic-version",  "2023-06-01")
   69          , ("content-type",       "application/json")
   70          ]
   71   OAuthAuth oauthHandle -> do
   72     tokens <- readIORef (_oah_tokensRef oauthHandle)
   73     now    <- getCurrentTime
   74     freshTokens <-
   75       if _oat_expiresAt tokens <= now
   76         then do
   77           newT <- _oah_refresh oauthHandle (_oat_refreshToken tokens)
   78           writeIORef (_oah_tokensRef oauthHandle) newT
   79           pure newT
   80         else pure tokens
   81     pure [ ("authorization",     "Bearer " <> withBearerToken (_oat_accessToken freshTokens) id)
   82          , ("anthropic-version", "2023-06-01")
   83          , ("anthropic-beta",    "oauth-2025-04-20")
   84          , ("content-type",      "application/json")
   85          ]
   86 
   87 instance Provider AnthropicProvider where
   88   complete = anthropicComplete
   89   completeStream = anthropicCompleteStream
   90 
   91 -- | Errors from the Anthropic API.
   92 data AnthropicError
   93   = AnthropicAPIError Int ByteString   -- ^ HTTP status code + response body
   94   | AnthropicParseError Text           -- ^ JSON parse/decode error
   95   deriving stock (Show)
   96 
   97 instance Exception AnthropicError
   98 
   99 instance ToPublicError AnthropicError where
  100   toPublicError (AnthropicAPIError 429 _) = RateLimitError
  101   toPublicError (AnthropicAPIError 401 _) = NotAllowedError
  102   toPublicError _                         = TemporaryError "Provider error"
  103 
  104 -- | Anthropic Messages API base URL.
  105 anthropicBaseUrl :: String
  106 anthropicBaseUrl = "https://api.anthropic.com/v1/messages"
  107 
  108 -- | Call the Anthropic Messages API.
  109 anthropicComplete :: AnthropicProvider -> CompletionRequest -> IO CompletionResponse
  110 anthropicComplete provider req = do
  111   authHeaders <- buildAuthHeaders provider
  112   initReq <- HTTP.parseRequest anthropicBaseUrl
  113   let httpReq = initReq
  114         { HTTP.method         = "POST"
  115         , HTTP.requestBody    = HTTP.RequestBodyLBS (encodeRequest req)
  116         , HTTP.requestHeaders = authHeaders
  117         }
  118   resp <- HTTP.httpLbs httpReq (_ap_manager provider)
  119   let status = Status.statusCode (HTTP.responseStatus resp)
  120   if status /= 200
  121     then throwIO (AnthropicAPIError status (BL.toStrict (HTTP.responseBody resp)))
  122     else case decodeResponse (HTTP.responseBody resp) of
  123       Left err -> throwIO (AnthropicParseError (T.pack err))
  124       Right response -> pure response
  125 
  126 -- | Encode a completion request as JSON for the Anthropic API.
  127 encodeRequest :: CompletionRequest -> BL.ByteString
  128 encodeRequest req = encode $ object $
  129   [ "model"      .= unModelId (_cr_model req)
  130   , "max_tokens" .= fromMaybe 4096 (_cr_maxTokens req)
  131   , "messages"   .= map encodeMsg (_cr_messages req)
  132   ]
  133   ++ maybe [] (\s -> ["system" .= s]) (_cr_systemPrompt req)
  134   ++ if null (_cr_tools req)
  135      then maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
  136      else ("tools" .= map encodeTool (_cr_tools req))
  137         : maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
  138 
  139 encodeMsg :: Message -> Value
  140 encodeMsg msg = object
  141   [ "role"    .= roleToText (_msg_role msg)
  142   , "content" .= map encodeContentBlock (_msg_content msg)
  143   ]
  144 
  145 encodeContentBlock :: ContentBlock -> Value
  146 encodeContentBlock (TextBlock t) = object
  147   [ "type" .= ("text" :: Text)
  148   , "text" .= t
  149   ]
  150 encodeContentBlock (ToolUseBlock callId name input) = object
  151   [ "type"  .= ("tool_use" :: Text)
  152   , "id"    .= unToolCallId callId
  153   , "name"  .= name
  154   , "input" .= input
  155   ]
  156 encodeContentBlock (ImageBlock mediaType imageData) = object
  157   [ "type" .= ("image" :: Text)
  158   , "source" .= object
  159       [ "type"       .= ("base64" :: Text)
  160       , "media_type" .= mediaType
  161       , "data"       .= TE.decodeUtf8 imageData
  162       ]
  163   ]
  164 encodeContentBlock (ToolResultBlock callId parts isErr) = object $
  165   [ "type"        .= ("tool_result" :: Text)
  166   , "tool_use_id" .= unToolCallId callId
  167   , "content"     .= map encodeToolResultPart parts
  168   ]
  169   ++ ["is_error" .= True | isErr]
  170 
  171 encodeToolResultPart :: ToolResultPart -> Value
  172 encodeToolResultPart (TRPText t) = object
  173   [ "type" .= ("text" :: Text), "text" .= t ]
  174 encodeToolResultPart (TRPImage mediaType imageData) = object
  175   [ "type" .= ("image" :: Text)
  176   , "source" .= object
  177       [ "type"       .= ("base64" :: Text)
  178       , "media_type" .= mediaType
  179       , "data"       .= TE.decodeUtf8 imageData
  180       ]
  181   ]
  182 
  183 encodeTool :: ToolDefinition -> Value
  184 encodeTool td = object
  185   [ "name"         .= _td_name td
  186   , "description"  .= _td_description td
  187   , "input_schema" .= _td_inputSchema td
  188   ]
  189 
  190 encodeToolChoice :: ToolChoice -> Value
  191 encodeToolChoice AutoTool = object ["type" .= ("auto" :: Text)]
  192 encodeToolChoice AnyTool = object ["type" .= ("any" :: Text)]
  193 encodeToolChoice (SpecificTool name) = object
  194   [ "type" .= ("tool" :: Text)
  195   , "name" .= name
  196   ]
  197 
  198 -- | Decode an Anthropic API response into a 'CompletionResponse'.
  199 decodeResponse :: BL.ByteString -> Either String CompletionResponse
  200 decodeResponse bs = eitherDecode bs >>= parseEither parseResp
  201   where
  202     parseResp :: Value -> Parser CompletionResponse
  203     parseResp = withObject "AnthropicResponse" $ \o -> do
  204       contentArr <- o .: "content"
  205       blocks <- mapM parseBlock contentArr
  206       modelText <- o .: "model"
  207       usageObj <- o .: "usage"
  208       inToks <- usageObj .: "input_tokens"
  209       outToks <- usageObj .: "output_tokens"
  210       pure CompletionResponse
  211         { _crsp_content = blocks
  212         , _crsp_model   = ModelId modelText
  213         , _crsp_usage   = Just (Usage inToks outToks)
  214         }
  215 
  216     parseBlock :: Value -> Parser ContentBlock
  217     parseBlock = withObject "ContentBlock" $ \b -> do
  218       bType <- b .: "type"
  219       case (bType :: Text) of
  220         "text" -> TextBlock <$> b .: "text"
  221         "tool_use" -> do
  222           callId <- b .: "id"
  223           name <- b .: "name"
  224           input <- b .: "input"
  225           pure (ToolUseBlock (ToolCallId callId) name input)
  226         other -> fail $ "Unknown content block type: " <> T.unpack other
  227 
  228 -- | Encode a streaming completion request (adds "stream": true).
  229 encodeStreamRequest :: CompletionRequest -> BL.ByteString
  230 encodeStreamRequest req = encode $ object $
  231   [ "model"      .= unModelId (_cr_model req)
  232   , "max_tokens" .= fromMaybe 4096 (_cr_maxTokens req)
  233   , "messages"   .= map encodeMsg (_cr_messages req)
  234   , "stream"     .= True
  235   ]
  236   ++ maybe [] (\s -> ["system" .= s]) (_cr_systemPrompt req)
  237   ++ if null (_cr_tools req)
  238      then maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
  239      else ("tools" .= map encodeTool (_cr_tools req))
  240         : maybe [] (\tc -> ["tool_choice" .= encodeToolChoice tc]) (_cr_toolChoice req)
  241 
  242 -- | Stream a completion from the Anthropic API.
  243 -- Processes SSE events and emits StreamEvent callbacks. Accumulates
  244 -- the full response for the final StreamDone event.
  245 anthropicCompleteStream :: AnthropicProvider -> CompletionRequest -> (StreamEvent -> IO ()) -> IO ()
  246 anthropicCompleteStream provider req callback = do
  247   authHeaders <- buildAuthHeaders provider
  248   initReq <- HTTP.parseRequest anthropicBaseUrl
  249   let httpReq = initReq
  250         { HTTP.method         = "POST"
  251         , HTTP.requestBody    = HTTP.RequestBodyLBS (encodeStreamRequest req)
  252         , HTTP.requestHeaders = authHeaders
  253         }
  254   HTTP.withResponse httpReq (_ap_manager provider) $ \resp -> do
  255     let status = Status.statusCode (HTTP.responseStatus resp)
  256     if status /= 200
  257       then do
  258         body <- BL.toStrict <$> HTTP.brReadSome (HTTP.responseBody resp) (1024 * 1024)
  259         throwIO (AnthropicAPIError status body)
  260       else do
  261         -- Accumulate content blocks and usage as events arrive
  262         blocksRef <- newIORef ([] :: [ContentBlock])
  263         modelRef <- newIORef (ModelId "")
  264         usageRef <- newIORef (Nothing :: Maybe Usage)
  265         bufRef <- newIORef BS.empty
  266         let readChunks = do
  267               chunk <- HTTP.brRead (HTTP.responseBody resp)
  268               if BS.null chunk
  269                 then do
  270                   -- Stream ended — emit final response
  271                   blocks <- readIORef blocksRef
  272                   model <- readIORef modelRef
  273                   usage <- readIORef usageRef
  274                   callback $ StreamDone CompletionResponse
  275                     { _crsp_content = reverse blocks
  276                     , _crsp_model   = model
  277                     , _crsp_usage   = usage
  278                     }
  279                 else do
  280                   buf <- readIORef bufRef
  281                   let fullBuf = buf <> chunk
  282                       (lines', remaining) = splitSSELines fullBuf
  283                   writeIORef bufRef remaining
  284                   mapM_ (processSSELine blocksRef modelRef usageRef callback) lines'
  285                   readChunks
  286         readChunks
  287 
  288 -- | Split a buffer into complete SSE lines and remaining partial data.
  289 splitSSELines :: ByteString -> ([ByteString], ByteString)
  290 splitSSELines bs =
  291   let parts = BS.splitWith (== 0x0a) bs  -- split on newline
  292   in case parts of
  293     [] -> ([], BS.empty)
  294     _ -> (init parts, last parts)
  295 
  296 -- | Process a single SSE line.
  297 processSSELine :: IORef [ContentBlock] -> IORef ModelId -> IORef (Maybe Usage) -> (StreamEvent -> IO ()) -> ByteString -> IO ()
  298 processSSELine blocksRef modelRef usageRef callback line =
  299   case parseSSELine line of
  300     Nothing -> pure ()
  301     Just json -> case parseEither parseStreamEvent json of
  302       Left _ -> pure ()
  303       Right evt -> case evt of
  304         SSEContentText t -> do
  305           callback (StreamText t)
  306           -- Accumulate streamed text into a single TextBlock rather than
  307           -- creating one per chunk (which would insert spurious newlines
  308           -- when responseText joins them with "\n").
  309           modifyIORef blocksRef $ \blocks -> case blocks of
  310             (TextBlock prev : rest) -> TextBlock (prev <> t) : rest
  311             _                       -> TextBlock t : blocks
  312         SSEToolStart callId name ->
  313           callback (StreamToolUse callId name)
  314         SSEToolDelta t ->
  315           callback (StreamToolInput t)
  316         SSEMessageStart model -> writeIORef modelRef model
  317         SSEUsage usage -> writeIORef usageRef (Just usage)
  318         SSEMessageStop -> pure ()
  319 
  320 -- | Parse an SSE "data: ..." line into a JSON value.
  321 parseSSELine :: ByteString -> Maybe Value
  322 parseSSELine bs
  323   | BS.isPrefixOf "data: " bs =
  324       let jsonBs = BS.drop 6 bs
  325       in decode (BL.fromStrict jsonBs)
  326   | otherwise = Nothing
  327 
  328 -- | Internal SSE event types.
  329 data SSEEvent
  330   = SSEContentText Text
  331   | SSEToolStart ToolCallId Text
  332   | SSEToolDelta Text
  333   | SSEMessageStart ModelId
  334   | SSEUsage Usage
  335   | SSEMessageStop
  336 
  337 -- | Parse a JSON SSE event.
  338 parseStreamEvent :: Value -> Parser SSEEvent
  339 parseStreamEvent = withObject "SSEEvent" $ \o -> do
  340   eventType <- o .: "type"
  341   case (eventType :: Text) of
  342     "message_start" -> do
  343       msg <- o .: "message"
  344       model <- msg .: "model"
  345       pure (SSEMessageStart (ModelId model))
  346     "content_block_delta" -> do
  347       delta <- o .: "delta"
  348       deltaType <- delta .: "type"
  349       case (deltaType :: Text) of
  350         "text_delta" -> SSEContentText <$> delta .: "text"
  351         "input_json_delta" -> SSEToolDelta <$> delta .: "partial_json"
  352         _ -> fail $ "Unknown delta type: " <> T.unpack deltaType
  353     "content_block_start" -> do
  354       block <- o .: "content_block"
  355       blockType <- block .: "type"
  356       case (blockType :: Text) of
  357         "tool_use" -> do
  358           callId <- block .: "id"
  359           name <- block .: "name"
  360           pure (SSEToolStart (ToolCallId callId) name)
  361         _ -> fail "Ignored block start"
  362     "message_delta" -> do
  363       usage <- o .:? "usage"
  364       case usage of
  365         Just u -> do
  366           outToks <- u .: "output_tokens"
  367           pure (SSEUsage (Usage 0 outToks))
  368         Nothing -> pure SSEMessageStop
  369     "message_stop" -> pure SSEMessageStop
  370     _ -> fail $ "Unknown event type: " <> T.unpack eventType